diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..2a0642e --- /dev/null +++ b/.travis.yml @@ -0,0 +1,2 @@ +language: rust +script: cargo bench diff --git a/Cargo.toml b/Cargo.toml index 623ff09..c146ed7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,19 @@ [package] -name = "kdtree-rust" +name = "fux_kdtree" version = "0.1.0" -authors = ["Aleksander Fular "] +authors = ["fulara "] +description = "K-dimensional tree implemented in Rust for fast NN querying." -[dependencies] -rand = "*" +[lib] +name = "kdtree" +path = "src//lib.rs" +bench = false + +[[bench]] +name = "bench" +harness = false [dev-dependencies] -quickcheck = "0.3" \ No newline at end of file +quickcheck = "0.3" +rand = "*" +bencher = "*" \ No newline at end of file diff --git a/README.md b/README.md index 4c41701..b0feb88 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,89 @@ -# kdtree-rust +# kdtree-rust [![Build Status](https://travis-ci.org/fulara/kdtree-rust.svg?branch=develop)](https://travis-ci.org/fulara/kdtree-rust) kdtree implementation for rust. + +Implementation uses sliding midpoint variation of the tree. [More Info here](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.74.210&rep=rep1&type=pdf) + +###Usage +Tree can only be used with types implementing trait: +``` +pub trait KdtreePointTrait : Copy { + fn dims(&self) -> &[f64]; +} +``` + +Thanks to this trait you can use any dimension. Keep in mind that the tree currently only supports up to 3D. +Examplary implementation would be: +``` +pub struct Point3WithId { + dims: [f64; 3], + pub id: i32, +} + +impl KdtreePointTrait for Point3WithId { + fn dims(&self) -> &[f64] { + return &self.dims; + } +} +``` +Where id is just a example of the way in which I carry the data. +With that trait implemented you are good to go to use the tree. Keep in mind that the kdtree is not a self balancing tree, so it should not support continous add. right now the tree just handles the build up from Vec. Basic usage can be found in the integration test, fragment copied below: +``` +let tree = kdtree::kdtree::Kdtree::new(&mut points.clone()); + +//test points pushed into the tree, id should be equal. +for i in 0 .. point_count { + let p = &points[i]; + + assert_eq!(p.id, tree.nearest_search(p).id ); +} +``` + + +##Benchmark +`cargo bench` using travis :) +``` +running 3 tests +test bench_creating_1000_000_node_tree ... bench: 275,155,622 ns/iter (+/- 32,713,321) +test bench_creating_1000_node_tree ... bench: 121,314 ns/iter (+/- 1,977) +test bench_single_loop_times_for_1000_node_tree ... bench: 162 ns/iter (+/- 76) +test result: ok. 0 passed; 0 failed; 0 ignored; 3 measured +``` + +~275ms to create a 1000_000 node tree. << this bench is now disabled. +~120us to create a 1000 node tree. +160ns to query the tree. + +###Benchmark - comparison with CGAL. +Since raw values arent saying much I've created the benchmark comparing this implementation against CGAL. code of the benchmark is available here: https://github.com/fulara/kdtree-benchmarks +``` +Benchmark Time CPU Iterations +----------------------------------------------------------------- +Cgal_tree_buildup/10 2226 ns 2221 ns 313336 +Cgal_tree_buildup/100 18357 ns 18315 ns 37968 +Cgal_tree_buildup/1000 288135 ns 287345 ns 2369 +Cgal_tree_buildup/9.76562k 3296740 ns 3290815 ns 211 +Cgal_tree_buildup/97.6562k 42909150 ns 42813307 ns 12 +Cgal_tree_buildup/976.562k 734566227 ns 733267760 ns 1 +Cgal_tree_lookup/10 72 ns 72 ns 9392612 +Cgal_tree_lookup/100 95 ns 95 ns 7103628 +Cgal_tree_lookup/1000 174 ns 174 ns 4010773 +Cgal_tree_lookup/9.76562k 268 ns 267 ns 2759487 +Cgal_tree_lookup/97.6562k 881 ns 876 ns 1262454 +Cgal_tree_lookup/976.562k 993 ns 991 ns 713751 +Rust_tree_buildup/10 726 ns 724 ns 856791 +Rust_tree_buildup/100 7103 ns 7092 ns 96132 +Rust_tree_buildup/1000 84879 ns 84720 ns 7927 +Rust_tree_buildup/9.76562k 1012983 ns 1010856 ns 630 +Rust_tree_buildup/97.6562k 12406293 ns 12382399 ns 51 +Rust_tree_buildup/976.562k 197175067 ns 196763387 ns 3 +Rust_tree_lookup/10 62 ns 62 ns 11541505 +Rust_tree_lookup/100 139 ns 139 ns 4058837 +Rust_tree_lookup/1000 220 ns 220 ns 2890813 +Rust_tree_lookup/9.76562k 307 ns 307 ns 2508133 +Rust_tree_lookup/97.6562k 362 ns 362 ns 2035671 +Rust_tree_lookup/976.562k 442 ns 441 ns 1636130 +``` +Rust_tree_lookup has some overhead since the libraries are being invoked from C code into Rust, and there is minor overhead of that in between, my experience indicates around 50 ns overhead. + +##License +The Unlicense diff --git a/src/bench.rs b/src/bench.rs new file mode 100644 index 0000000..83f89e4 --- /dev/null +++ b/src/bench.rs @@ -0,0 +1,91 @@ +#[macro_use] extern crate bencher; +extern crate kdtree; +extern crate rand; + +use bencher::Bencher; + + +#[derive(Copy, Clone, PartialEq)] +pub struct Point2WithId { + dims: [f64; 2], + pub id: i32, +} + +impl Point2WithId { + pub fn new(id: i32, x: f64, y: f64) -> Point2WithId { + Point2WithId { + dims: [x, y], + id: id, + } + } +} + +impl kdtree::kdtree::KdtreePointTrait for Point2WithId { + fn dims(&self) -> &[f64] { + return &self.dims; + } +} + +#[derive(Copy, Clone, PartialEq)] +pub struct Point3WithId { + dims: [f64; 3], + pub id: i32, +} + +impl Point3WithId { + pub fn new(id: i32, x: f64, y: f64, z: f64) -> Point3WithId { + Point3WithId { + dims: [x, y, z], + id: id, + } + } +} + +impl kdtree::kdtree::KdtreePointTrait for Point3WithId { + fn dims(&self) -> &[f64] { + return &self.dims; + } +} + +fn bench_creating_1000_node_tree(b: &mut Bencher) { + let len = 1000usize; + let mut points: Vec = vec![]; + for id in 0..len { + let x: f64 = rand::random(); + points.push(Point2WithId::new(id as i32, x, x)); + } + + b.iter(|| { + kdtree::kdtree::Kdtree::new(&mut points.clone()); + }); +} + +fn bench_single_loop_times_for_1000_node_tree(b: &mut Bencher) { + let len = 1000usize; + let mut points: Vec = vec![]; + + for i in 0..len { + points.push(Point3WithId::new(i as i32, rand::random(), rand::random(), rand::random())) + } + + let tree = kdtree::kdtree::Kdtree::new(&mut points.clone()); + + + b.iter(|| tree.nearest_search(&points[0])); +} + +fn bench_creating_1000_000_node_tree(b: &mut Bencher) { + let len = 1000_000usize; + let mut points: Vec = vec![]; + for id in 0..len { + let x: f64 = rand::random(); + points.push(Point2WithId::new(id as i32, x, x)); + } + + b.iter(|| { + kdtree::kdtree::Kdtree::new(&mut points.clone()); + }); +} + +benchmark_group!(benches, bench_creating_1000_node_tree,bench_single_loop_times_for_1000_node_tree); +benchmark_main!(benches); \ No newline at end of file diff --git a/src/kdtree/bounds.rs b/src/kdtree/bounds.rs new file mode 100644 index 0000000..2843f73 --- /dev/null +++ b/src/kdtree/bounds.rs @@ -0,0 +1,109 @@ +use ::kdtree::*; + +pub struct Bounds { + pub bounds: [(f64, f64); 3], + + widest_dim: usize, + midvalue_of_widest_dim: f64, +} + +impl Bounds { + pub fn new_from_points(points: &[T]) -> Bounds { + let mut bounds = Bounds { + bounds: [(0., 0.), (0., 0.), (0., 0.)], + widest_dim: 0, + midvalue_of_widest_dim: 0., + }; + + for i in 0..points[0].dims().len() { + bounds.bounds[i].0 = points[0].dims()[i]; + bounds.bounds[i].1 = points[0].dims()[i]; + } + + for v in points.iter() { + for dim in 0..v.dims().len() { + bounds.bounds[dim].0 = bounds.bounds[dim].0.min(v.dims()[dim]); + bounds.bounds[dim].1 = bounds.bounds[dim].1.max(v.dims()[dim]); + } + } + + bounds.calculate_variables(); + + bounds + } + + pub fn get_widest_dim(&self) -> usize { + self.widest_dim + } + + pub fn get_midvalue_of_widest_dim(&self) -> f64 { + self.midvalue_of_widest_dim + } + + pub fn clone_moving_max(&self, value: f64, dimension: usize) -> Bounds { + let mut cloned = Bounds { + bounds: self.bounds.clone(), + ..*self + }; + cloned.bounds[dimension].1 = value; + + cloned.calculate_variables(); + + cloned + } + + pub fn clone_moving_min(&self, value: f64, dimension: usize) -> Bounds { + let mut cloned = Bounds { + bounds: self.bounds.clone(), + ..*self + }; + cloned.bounds[dimension].0 = value; + + cloned.calculate_variables(); + + cloned + } + + fn calculate_widest_dim(&mut self) { + let mut widest_dimension = 0usize; + let mut max_found_spread = self.bounds[0].1 - self.bounds[0].0; + + for i in 0..self.bounds.len() { + let dimension_spread = self.bounds[i].1 - self.bounds[i].0; + + if dimension_spread > max_found_spread { + max_found_spread = dimension_spread; + widest_dimension = i; + } + } + + self.widest_dim = widest_dimension; + } + + fn calculate_variables(&mut self) { + self.calculate_widest_dim(); + self.midvalue_of_widest_dim = (self.bounds[self.get_widest_dim()].0 + self.bounds[self.get_widest_dim()].1) / 2.0; + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use ::kdtree::test_common::tests_utils::*; + + #[test] + fn bounds_test() { + let p1 = Point2WithId::new(1, 1.0, 0.5); + let p2 = Point2WithId::new(1, 3.0, 4.0); + let v = vec![p1, p2]; + + + let bounds = Bounds::new_from_points(&v); + + assert_eq!((1., 3.0), bounds.bounds[0]); + assert_eq!((0.5, 4.0), bounds.bounds[1]); + + assert_eq!(1, bounds.get_widest_dim()); + } +} diff --git a/src/kdtree/distance.rs b/src/kdtree/distance.rs new file mode 100644 index 0000000..b0ba3cc --- /dev/null +++ b/src/kdtree/distance.rs @@ -0,0 +1,39 @@ +pub fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 { + debug_assert!(a.len() == b.len()); + + a.iter().zip(b.iter()) + .map(|(x, y)| (x - y) * (x - y)) + .sum() +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn squared_euclidean_test_1d() { + let a = [2.]; + let b = [4.]; + let c = [-2.]; + + assert_eq!(0., squared_euclidean(&a, &a)); + + assert_eq!(4., squared_euclidean(&a, &b)); + + assert_eq!(16., squared_euclidean(&a, &c)); + } + + #[test] + fn squared_euclidean_test_2d() { + let a = [2., 2.]; + let b = [4., 2.]; + let c = [4., 4.]; + + assert_eq!(0., squared_euclidean(&a, &a)); + + assert_eq!(4., squared_euclidean(&a, &b)); + + assert_eq!(8., squared_euclidean(&a, &c)); + } +} \ No newline at end of file diff --git a/src/kdtree/mod.rs b/src/kdtree/mod.rs index 9fe1678..612c246 100644 --- a/src/kdtree/mod.rs +++ b/src/kdtree/mod.rs @@ -1,7 +1,15 @@ -mod test_common; -mod partition; +#[cfg(test)] +pub mod test_common; -pub trait KdtreePointTrait { +pub mod distance; + +mod partition; +mod bounds; + +use self::bounds::*; +use self::distance::*; + +pub trait KdtreePointTrait: Copy { fn dims(&self) -> &[f64]; } @@ -10,29 +18,92 @@ pub struct Kdtree { } impl Kdtree { - pub fn new(points: Vec) -> Kdtree { + pub fn new(mut points: &mut [T]) -> Kdtree { if points.len() == 0 { panic!("empty vector point not allowed"); } - Kdtree { + let rect = Bounds::new_from_points(points); + + let mut tree = Kdtree { nodes: vec![], + }; + + tree.build_tree(&mut points, &rect); + + tree + } + + pub fn nearest_search(&self, node: &T) -> T + { + let mut nearest_neighbor = 0usize; + let mut best_distance = squared_euclidean(node.dims(), &self.nodes[0].point.dims()); + self.nearest_search_impl(node, 0usize, &mut best_distance, &mut nearest_neighbor); + + self.nodes[nearest_neighbor].point + } + + fn nearest_search_impl(&self, p: &T, searched_index: usize, best_distance_squared: &mut f64, best_leaf_found: &mut usize) { + let node = &self.nodes[searched_index]; + + let dimension = node.dimension; + let splitting_value = node.split_on; + let point_splitting_dim_value = p.dims()[dimension]; + + let (closer_node, farther_node) = + if point_splitting_dim_value <= splitting_value { + (node.left_node, node.right_node) + } else { + (node.right_node, node.left_node) + }; + + if let Some(closer_node) = closer_node { + self.nearest_search_impl(p, closer_node, best_distance_squared, best_leaf_found); + } + + let distance = squared_euclidean(p.dims(), node.point.dims()); + if distance < *best_distance_squared { + *best_distance_squared = distance; + *best_leaf_found = searched_index; + } + + if let Some(farther_node) = farther_node { + let distance_on_single_dimension = squared_euclidean(&[splitting_value], &[point_splitting_dim_value]); + + if distance_on_single_dimension <= *best_distance_squared { + self.nearest_search_impl(p, farther_node, best_distance_squared, best_leaf_found); + } } } - fn add_node(&mut self, p: T) { - let node = KdtreeNode::new(p); + + fn add_node(&mut self, p: T, dimension: usize, split_on: f64) -> usize { + let node = KdtreeNode::new(p, dimension, split_on); self.nodes.push(node); + self.nodes.len() - 1 } - fn add_left_node(&mut self, for_node: usize, ) { - { - let len = self.nodes.len(); - let node = self.nodes.get_mut(for_node).unwrap(); - node.left_node = Some(len); + fn build_tree(&mut self, nodes: &mut [T], bounds: &Bounds) -> usize { + let (splitting_index, pivot_value) = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim()); + + let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), bounds.get_midvalue_of_widest_dim()); + let nodes_len = nodes.len(); + + if splitting_index > 0 { + let left_rect = bounds.clone_moving_max(pivot_value, bounds.get_widest_dim()); + let left_child_id = self.build_tree(&mut nodes[0..splitting_index], &left_rect); + self.nodes[node_id].left_node = Some(left_child_id); } - //self.nodes.push(KdtreeNode::new()); + + if splitting_index < nodes.len() - 1 { + let right_rect = bounds.clone_moving_min(pivot_value, bounds.get_widest_dim()); + + let right_child_id = self.build_tree(&mut nodes[splitting_index + 1..nodes_len], &right_rect); + self.nodes[node_id].right_node = Some(right_child_id); + } + + node_id } } @@ -41,39 +112,104 @@ pub struct KdtreeNode { right_node: Option, point: T, + dimension: usize, + split_on: f64 } impl KdtreeNode { - fn new(p: T) -> KdtreeNode { + fn new(p: T, splitting_dimension: usize, split_on_value: f64) -> KdtreeNode { KdtreeNode { left_node: None, right_node: None, point: p, + dimension: splitting_dimension, + split_on: split_on_value } } } - #[cfg(test)] -mod tests3 { +mod tests { use ::kdtree::test_common::tests_utils::Point2WithId; + use super::*; #[test] #[should_panic(expected = "empty vector point not allowed")] fn should_panic_given_empty_vector() { - let empty_vec: Vec = vec![]; + let mut empty_vec: Vec = vec![]; - let tree = Kdtree::new(empty_vec); + Kdtree::new(&mut empty_vec); } - #[test] - fn test2() { - let p1 = Point2WithId::new(1, 1., 2.); - let p2 = Point2WithId::new(1, 1., 2.); - let vec = vec![p1, p2]; + quickcheck! { + fn tree_build_creates_tree_with_as_many_leafs_as_there_is_points(xs : Vec) -> bool { + if xs.len() == 0 { + return true; + } + let mut vec : Vec = vec![]; + for i in 0 .. xs.len() { + let p = Point2WithId::new(i as i32, xs[i], xs[i]); - let tree = Kdtree::new(vec); + vec.push(p); + } + + let tree = Kdtree::new(&mut qc_value_vec_to_2d_points_vec(&xs)); + + let mut to_iterate : Vec = vec![]; + to_iterate.push(0); + + while to_iterate.len() > 0 { + let last_index = to_iterate.last().unwrap().clone(); + let ref x = tree.nodes.get(last_index).unwrap(); + to_iterate.pop(); + if x.left_node.is_some() { + to_iterate.push(x.left_node.unwrap()); + } + if x.right_node.is_some() { + to_iterate.push(x.right_node.unwrap()); + } + } + xs.len() == tree.nodes.len() + } + } + + quickcheck! { + fn nearest_neighbor_search_using_qc(xs : Vec) -> bool { + if xs.len() == 0 { + return true; + } + + let point_vec = qc_value_vec_to_2d_points_vec(&xs); + let tree = Kdtree::new(&mut point_vec.clone()); + + for p in &point_vec { + let found_nn = tree.nearest_search(p); + + assert_eq!(p.id,found_nn.id); + } + + true + } + } + + fn qc_value_vec_to_2d_points_vec(xs: &Vec) -> Vec { + let mut vec: Vec = vec![]; + for i in 0..xs.len() { + let mut is_duplicated_value = false; + for j in 0..i { + if xs[i] == xs[j] { + is_duplicated_value = true; + break; + } + } + if !is_duplicated_value { + let p = Point2WithId::new(i as i32, xs[i], xs[i]); + vec.push(p); + } + } + + vec } } \ No newline at end of file diff --git a/src/kdtree/partition.rs b/src/kdtree/partition.rs index f520679..5cb8e96 100644 --- a/src/kdtree/partition.rs +++ b/src/kdtree/partition.rs @@ -12,7 +12,7 @@ struct PartitionPointHelper { index_of_splitter: usize, } -fn partition_sliding_midpoint_helper(vec: &mut Vec, midpoint_value: f64, partition_on_dimension: usize) -> PartitionPointHelper { +fn partition_sliding_midpoint_helper(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> PartitionPointHelper { let mut closest_index = 0; let mut closest_distance = (vec[0].dims()[partition_on_dimension] - midpoint_value).abs(); @@ -51,10 +51,13 @@ fn partition_sliding_midpoint_helper(vec: &mut Vec, midp } } -pub fn partition_sliding_midpoint(vec: &mut Vec, midpoint_value: f64, partition_on_dimension: usize) -> (usize, f64) { +pub fn partition_sliding_midpoint(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> (usize, f64) { let vec_len = vec.len(); debug_assert!(vec[0].dims().len() > partition_on_dimension); - debug_assert!(vec.len() > 1); + + if vec.len() == 1 { + return (0, vec[0].dims()[partition_on_dimension]); + } let partition_point_data = partition_sliding_midpoint_helper(vec, midpoint_value, partition_on_dimension); @@ -74,24 +77,52 @@ pub fn partition_sliding_midpoint(vec: &mut Vec, midpoin } } -fn partition_kdtree(vec: &mut Vec, index_of_splitting_point: usize, partition_on_dimension: usize) -> usize { +fn partition_kdtree(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize { + if vec.len() == 1 { + return 0; + } + let pivot = vec[index_of_splitting_point].dims()[partition_on_dimension]; let vec_len = vec.len(); vec.swap(index_of_splitting_point, vec_len - 1); - //using Lomuto variant of partition here, change it to hoare sometime? - let mut store_index = 0; - for left in 0..vec_len - 1 { - if vec[left].dims()[partition_on_dimension] <= pivot { - vec.swap(left, store_index); - store_index += 1; + let mut left = 0usize; + let mut right = vec.len() - 2; + let mut last_succesful_swap = vec.len() - 1; + + //variant of Lomuto algo. + loop { + while left <= right && vec[left].dims()[partition_on_dimension] <= pivot { + left += 1; + } + + while right > left && vec[right].dims()[partition_on_dimension] > pivot { + right -= 1; + } + + if right > left { + vec.swap(left, right); + last_succesful_swap = right; + + left += 1; + right -= 1; + } else { + break; } } - vec.swap(store_index, vec_len - 1); + if last_succesful_swap == vec_len - 1 && vec[right].dims()[partition_on_dimension] > pivot { + vec.swap(right, last_succesful_swap); + last_succesful_swap = right; + } else if vec[left].dims()[partition_on_dimension] > pivot { + vec.swap(left, vec_len - 1); + last_succesful_swap = left; + } else { + vec.swap(last_succesful_swap, vec_len - 1); + } - store_index + last_succesful_swap } @@ -116,7 +147,7 @@ mod tests { let p6 = Point2WithId::new(5, 3., 8.); let p7 = Point2WithId::new(6, 4., 8.); - let mut vec = vec![p1, p2, p3, p4, p5, p6, p7]; + let vec = vec![p1, p2, p3, p4, p5, p6, p7]; assert_eq! (1, partition_kdtree(&mut vec.clone(), 3, 0)); assert_eq! (6, partition_kdtree(&mut vec.clone(), 6, 0)); @@ -135,13 +166,13 @@ mod tests { vec.push(p); } - if(xs.len() == 0 ) { + if xs.len() == 0 { return true; } let between = Range::new(0, xs.len()); let mut rng = thread_rng(); - for i in 0 .. 5 { + for _ in 0 .. 5 { let random_splitting_index = between.ind_sample(&mut rng); let mut vec = vec.clone(); @@ -209,13 +240,13 @@ mod tests { fn assert_partition(v: &Vec, index_of_splitting_point: usize) -> bool { let pivot = v[index_of_splitting_point].dims()[0]; - for i in 0 .. index_of_splitting_point { + for i in 0..index_of_splitting_point { if v[i].dims()[0] > pivot { return false; } } - for i in index_of_splitting_point + 1 .. v.len() { + for i in index_of_splitting_point + 1..v.len() { if v[i].dims()[0] < pivot { return false; } diff --git a/src/kdtree/test_common.rs b/src/kdtree/test_common.rs index ef26bdc..ab82b46 100644 --- a/src/kdtree/test_common.rs +++ b/src/kdtree/test_common.rs @@ -1,6 +1,22 @@ #[cfg(test)] pub mod tests_utils { - use ::kdtree::*; + use super::super::*; + + #[derive(Copy, Clone, PartialEq)] + pub struct Point3WithId { + dims: [f64; 3], + pub id: i32, + } + + impl Point3WithId { + pub fn new(id: i32, x: f64, y: f64, z: f64) -> Point3WithId { + Point3WithId { + dims: [x, y, z], + id: id, + } + } + } + #[derive(Copy, Clone, PartialEq)] pub struct Point2WithId { dims: [f64; 2], diff --git a/src/lib.rs b/src/lib.rs index fe15ea7..eeebc28 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ #[macro_use] extern crate quickcheck; +#[cfg(test)] extern crate rand; pub mod kdtree; \ No newline at end of file diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs new file mode 100644 index 0000000..b8735a3 --- /dev/null +++ b/tests/integration_tests.rs @@ -0,0 +1,83 @@ +extern crate kdtree; +extern crate rand; + +use rand::Rng; + +use kdtree::kdtree::*; + +//these could be taken from test_common, but I dont fully understand the module thingy yet. +#[derive(Copy, Clone, PartialEq)] +pub struct Point3WithId { + dims: [f64; 3], + pub id: i32, +} + +impl Point3WithId { + pub fn new(id: i32, x: f64, y: f64, z: f64) -> Point3WithId { + Point3WithId { + dims: [x, y, z], + id: id, + } + } +} + +impl KdtreePointTrait for Point3WithId { + fn dims(&self) -> &[f64] { + return &self.dims; + } +} + +fn gen_random() -> f64 { + rand::thread_rng().gen_range(0., 10000.) +} + +fn gen_random_usize( max_value : usize) -> usize { + rand::thread_rng().gen_range(0usize, max_value) +} + +fn find_nn_with_linear_search<'a>(points : &'a Vec, find_for : Point3WithId) -> &Point3WithId { + let distance_fun = kdtree::kdtree::distance::squared_euclidean; + + let mut best_found_distance = distance_fun(find_for.dims(), points[0].dims()); + let mut closed_found_point = &points[0]; + + for p in points { + let dist = distance_fun(find_for.dims(), p.dims()); + + if dist < best_found_distance { + best_found_distance = dist; + closed_found_point = &p; + } + } + + closed_found_point +} + +#[test] +fn test_against_1000_random_points() { + let mut points : Vec = vec![]; + + let point_count = 1000usize; + for i in 0 .. point_count { + points.push(Point3WithId::new(i as i32, gen_random(),gen_random(),gen_random())); + } + + let tree = kdtree::kdtree::Kdtree::new(&mut points.clone()); + + //test points pushed into the tree, id should be equal. + for i in 0 .. point_count { + let p = &points[i]; + + assert_eq!(p.id, tree.nearest_search(p).id ); + } + + //test randomly generated points within the cube. and do the linear search. should match + for _ in 0 .. 500 { + let p = Point3WithId::new(0i32, gen_random(), gen_random(), gen_random()); + + let found_by_linear_search = find_nn_with_linear_search(&points, p); + let point_found_by_kdtree = tree.nearest_search(&p); + + assert_eq!(point_found_by_kdtree.id, found_by_linear_search.id); + } +} \ No newline at end of file