#include "cppds/kdtree.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <utility>
#include <vector>

#include <stdio.h>
#include <windows.h>

// Packs three coordinates into a single static buffer and returns a pointer to it.
// NOTE: every call overwrites the same buffer, so the returned pointer is only
// valid until the next call to d3(); the shared buffer also makes it non-reentrant.
const double *d3(double d1, double d2, double d3){
	static double buffer[3];
	const double values[3] = { d1, d2, d3 };
	for (int k = 0; k < 3; ++k){
		buffer[k] = values[k];
	}
	return buffer;
}

// 3-dimensional kd-tree whose mapped values are ints (main() stores index + 1 per point).
typedef cppds::kdtree<int, 3> kdtree_t;
// Plain tree iterator (not referenced elsewhere in this file).
typedef kdtree_t::iterator kdtree_iter_t;
// Iterator yielded by begin_range()/end_range(); presumably walks entries within a
// Euclidean ("eukleide") radius of a query key -- confirm against cppds/kdtree.h.
typedef kdtree_t::iterator_eukleide_range kdtree_iter_er_t;

// Global tree filled by main(); the test_* helpers take a tree by reference instead.
kdtree_t kdt;

// A batch of `count` random 3-D points with every coordinate in [-100, 100).
// data[j] points at the j-th xyz triple; all triples live in one contiguous
// allocation owned through data[0]. Values come from std::rand(), so seed
// with srand() beforehand to control reproducibility.
struct testset{
	double **data;
	int count;

	// Generates `num` random points. A non-positive `num` yields an empty set
	// (count == 0, data == NULL) instead of indexing a zero-length array.
	explicit testset(int num) : data(NULL), count(num > 0 ? num : 0){
		if (count == 0){
			return;
		}
		data = new double *[count];
		data[0] = new double[count * 3];
		for (int i = 1; i < count; ++i){
			data[i] = data[i-1] + 3;
		}
		// Promote RAND_MAX to double before adding 1: RAND_MAX + 1 in int
		// arithmetic overflows where RAND_MAX == INT_MAX (e.g. glibc).
		const double scale = 200.0 / ((double)RAND_MAX + 1.0);
		for (int j = 0; j < count; ++j){
			for (int i = 0; i < 3; ++i){
				data[j][i] = (double)std::rand() * scale - 100.0;
			}
		}
	}

	~testset(){
		if (data != NULL){
			delete[] data[0];
			delete[] data;	// allocated with new[], so it must be released with delete[]
		}
	}

private:
	// The struct owns raw arrays; a member-wise copy would double-free them,
	// so copying is disabled (declared but never defined).
	testset(const testset &);
	testset &operator=(const testset &);
};

// Euclidean distance between two 3-D points, each given as an array of 3 doubles.
double get_distance(const double *d3, const double *d3_2){
	double squared = 0;
	int axis = 0;
	while (axis < 3){
		const double delta = d3[axis] - d3_2[axis];
		squared += delta * delta;
		++axis;
	}
	return std::sqrt(squared);
}

// Brute-force range query: scans every point in `ts`, counting those strictly
// closer than `range` to `query_key`. Prints the hit count and the sum of the
// hit distances, and returns the hit count.
int test_linear_range(testset &ts, const double *query_key, double range){
	printf("linear query\n");

	int hits = 0;
	double total = 0;
	for (int idx = 0; idx < ts.count; ++idx){
		const double dist = get_distance(query_key, ts.data[idx]);
		if (!(dist < range)){
			continue;	// outside (or exactly on) the query radius
		}
		++hits;
		total += dist;
	}
	printf("%d queries (sum:%f)\n", hits, total);
	return hits;
}

// Range query through the kd-tree. Walks every entry from
// kdt.begin_range(query_key, range) up to kdt.end_range(), then prints the hit
// count and the sum of iter.get_distance() over the hits. The sum is meant to be
// compared with test_linear_range's output; that get_distance() is the plain,
// unsquared Euclidean distance is an assumption -- confirm in cppds/kdtree.h.
// Returns the hit count.
int test_kdtree_range(kdtree_t &kdt, const double *query_key, double range){
	int count = 0;
	printf("kdtree query\n");

	double distance_sum = 0;
	for (kdtree_iter_er_t iter = kdt.begin_range(query_key, range); iter != kdt.end_range(); ++iter){
		++count;
		distance_sum += iter.get_distance();
	}
	printf("%d queries (sum:%f)\n", count, distance_sum);
	return count;
}

// Brute-force k-nearest-neighbour query: prints the `count` points of `ts`
// closest to `query_key`, nearest first, each followed by its distance.
// `count` is clamped to [0, ts.count] so an oversized request cannot read past
// the end of the distance table.
void test_linear_knn(testset &ts, const double *query_key, int count){
	typedef std::pair<double, double *> pair_t;
	struct inner{
		static bool compare(const pair_t &p1, const pair_t &p2){
			return p1.first < p2.first;
		}
	};
	printf("linear query\n");

	if (count > ts.count){
		count = ts.count;
	}
	if (count <= 0){
		return;
	}

	std::vector<pair_t> pl(ts.count);
	for (int i = 0; i < ts.count; ++i){
		pl[i].first = get_distance(query_key, ts.data[i]);
		pl[i].second = ts.data[i];
	}
	// Only the first `count` entries need ordering: O(n log k) instead of a full O(n log n) sort.
	std::partial_sort(pl.begin(), pl.begin() + count, pl.end(), inner::compare);
	for (int i = 0; i < count; ++i){
		printf("(%f, %f, %f): (%f)\n", pl[i].second[0],  pl[i].second[1],  pl[i].second[2],  pl[i].first);
	}
}

// k-nearest-neighbour query through the kd-tree. Asks query_eukleide_knn() for up
// to `count` results and prints each hit's coordinates, mapped value and distance.
// distance2[] is treated as a squared distance (sqrt is taken before printing).
void test_kdtree_knn(kdtree_t &kdt, const double *query_key, int count){
	printf("kdtree query\n");
	// &results[0] on an empty vector is undefined behaviour, and a negative size
	// would make the vector constructor throw, so bail out for non-positive k.
	if (count <= 0){
		return;
	}
	std::vector<kdtree_t::iterator_io> results(count);
	std::vector<double> distance2(count);
	int count_res = kdt.query_eukleide_knn(query_key, count, &results[0], &distance2[0]);
	for (int i = 0; i < count_res; ++i){
		const double *pos = results[i]->first;
		printf("(%f, %f, %f):%d (%f)\n", pos[0], pos[1], pos[2], results[i]->second, std::sqrt(distance2[i]));
	}
}

int main(void){
	srand(timeGetTime());
	testset ts(1600000);
	int collision = 0;
	for (int i = 0; i < ts.count; ++i){
		std::pair<kdtree_t::iterator_io,bool> res = kdt.insert(ts.data[i], i+1);
		if (res.second && res.first->second != i+1){
			printf("\n");
		}
		if (!res.second) ++collision;
	}
	printf("Փ:%d\n", collision);
	for (int i = 0; i < ts.count; ++i){
		if (kdt.insert(ts.data[i], i+1).second || kdt.find(ts.data[i]) == kdt.end_io()){
			printf("\n");
		}
	}
	kdt.balance();
	kdtree_t kdt_copy(kdt);

	DWORD time1, time2, time_d1, time_d2, time_d3;
	timeBeginPeriod(1);

	const double *query_key;
	for (int i = 0; i < 6; ++i){
		double range_max = 20.0 * (double)i;
		printf("query range:%f\n", range_max);

		query_key = d3(
			(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
			(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
			(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0);

		time1 = timeGetTime();
		test_kdtree_range(kdt_copy, query_key, range_max);
		time2 = timeGetTime();
		time_d1 = time2 - time1;

		time1 = timeGetTime();
		test_linear_range(ts, query_key, range_max);
		time2 = timeGetTime();
		time_d2 = time2 - time1;

		printf("elapsed:%d(kdtree) vs. %d(linear)\n", time_d1, time_d2);
	}

	query_key = d3(
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0);

	time1 = timeGetTime();
	test_kdtree_knn(kdt_copy, query_key, 1);
	time2 = timeGetTime();
	time_d1 = time2 - time1;

	time1 = timeGetTime();
	test_linear_knn(ts, query_key, 1);
	time2 = timeGetTime();
	time_d2 = time2 - time1;
	printf("elapsed:%d(kdtree) vs. %d(linear)\n", time_d1, time_d2);

	query_key = d3(
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0);

	time1 = timeGetTime();
	test_kdtree_knn(kdt_copy, query_key, 3);
	time2 = timeGetTime();
	time_d1 = time2 - time1;

	time1 = timeGetTime();
	test_linear_knn(ts, query_key, 3);
	time2 = timeGetTime();
	time_d2 = time2 - time1;
	printf("elapsed:%d(kdtree) vs. %d(linear)\n", time_d1, time_d2);

	query_key = d3(
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0,
		(double)rand()*200.0/(double)(RAND_MAX+1) - 100.0);

	time1 = timeGetTime();
	test_kdtree_knn(kdt_copy, query_key, 10);
	time2 = timeGetTime();
	time_d1 = time2 - time1;

	time1 = timeGetTime();
	test_linear_knn(ts, query_key, 10);
	time2 = timeGetTime();
	time_d2 = time2 - time1;
	printf("elapsed:%d(kdtree) vs. %d(linear)\n", time_d1, time_d2);

	timeEndPeriod(1);
	getchar();
	kdtree_t for_swap;
	for_swap.swap(kdt_copy);
	return 0;
}
