From c31758c88355f31c7044e2a103444bdb296a314d Mon Sep 17 00:00:00 2001 From: Andrey Tkachenko Date: Tue, 28 Apr 2020 22:33:03 +0400 Subject: [PATCH] generic over float type --- Cargo.toml | 4 +- {src => benches}/bench.rs | 13 +++--- src/kdtree/bounds.rs | 39 ++++++++-------- src/kdtree/distance.rs | 39 ---------------- src/kdtree/mod.rs | 91 +++++++++++++++++++++++--------------- src/kdtree/partition.rs | 31 ++++++------- src/kdtree/test_common.rs | 33 ++++++++++---- src/lib.rs | 2 - tests/integration_tests.rs | 21 ++++----- 9 files changed, 134 insertions(+), 139 deletions(-) rename {src => benches}/bench.rs (83%) delete mode 100644 src/kdtree/distance.rs diff --git a/Cargo.toml b/Cargo.toml index 85d30e2..8ec3b4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,10 +16,12 @@ bench = false [[bench]] name = "bench" -path = "src/bench.rs" harness = false [dev-dependencies] quickcheck = "~0.9" rand = "~0.7" bencher = "~0.1" + +[dependencies] +num-traits = "0.2" diff --git a/src/bench.rs b/benches/bench.rs similarity index 83% rename from src/bench.rs rename to benches/bench.rs index 7fe5b7b..936b619 100644 --- a/src/bench.rs +++ b/benches/bench.rs @@ -1,11 +1,10 @@ #[macro_use] extern crate bencher; -extern crate kdtree; -extern crate rand; use bencher::Bencher; use rand::Rng; +use kdtree::kdtree::KdTree; use kdtree::kdtree::test_common::*; fn gen_random() -> f64 { @@ -27,7 +26,7 @@ fn bench_creating_1000_node_tree(b: &mut Bencher) { let points = generate_points(len); b.iter(|| { - kdtree::kdtree::KdTree::new(&mut points.clone()); + KdTree::new(&mut points.clone()); }); } @@ -35,7 +34,7 @@ fn bench_single_loop_times_for_1000_node_tree(b: &mut Bencher) { let len = 1000usize; let points = generate_points(len); - let tree = kdtree::kdtree::KdTree::new(&mut points.clone()); + let tree = KdTree::new(&mut points.clone()); b.iter(|| tree.nearest_search(&points[0])); @@ -47,14 +46,14 @@ fn bench_creating_1000_000_node_tree(b: &mut Bencher) { let points = generate_points(len); b.iter(|| { - kdtree::kdtree::KdTree::new(&mut points.clone()); + KdTree::new(&mut points.clone()); }); } fn bench_adding_same_node_to_1000_tree(b: &mut Bencher) { let len = 1000usize; let mut points = generate_points(len); - let mut tree = kdtree::kdtree::KdTree::new(&mut points); + let mut tree = KdTree::new(&mut points); let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random()); b.iter(|| { @@ -66,7 +65,7 @@ fn bench_incrementally_building_the_1000_tree(b: &mut Bencher) { b.iter(|| { let len = 1usize; let mut points = generate_points(len); - let mut tree = kdtree::kdtree::KdTree::new(&mut points); + let mut tree = KdTree::new(&mut points); for _ in 0 .. 1000 { let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random()); tree.insert_node(point); diff --git a/src/kdtree/bounds.rs b/src/kdtree/bounds.rs index 39fb6e2..9a23994 100644 --- a/src/kdtree/bounds.rs +++ b/src/kdtree/bounds.rs @@ -1,31 +1,31 @@ +use num_traits::Float; use crate::kdtree::KdTreePoint; - #[derive(Clone, Copy)] -pub struct Bounds { - pub bounds: [(f64, f64); 3], +pub struct Bounds { + pub bounds: [(F, F); 3], widest_dim: usize, - midvalue_of_widest_dim: f64, + midvalue_of_widest_dim: F, } -impl Bounds { - pub fn new_from_points(points: &[T]) -> Bounds { +impl Bounds { + pub fn new_from_points>(points: &[T]) -> Bounds { let mut bounds = Bounds { - bounds: [(0., 0.), (0., 0.), (0., 0.)], + bounds: [(F::zero(), F::zero()), (F::zero(), F::zero()), (F::zero(), F::zero())], widest_dim: 0, - midvalue_of_widest_dim: 0., + midvalue_of_widest_dim: F::zero(), }; - for i in 0..points[0].dims().len() { - bounds.bounds[i].0 = points[0].dims()[i]; - bounds.bounds[i].1 = points[0].dims()[i]; + for i in 0..points[0].dims() { + bounds.bounds[i].0 = points[0].dim(i); + bounds.bounds[i].1 = points[0].dim(i); } for v in points.iter() { - for dim in 0..v.dims().len() { - bounds.bounds[dim].0 = bounds.bounds[dim].0.min(v.dims()[dim]); - bounds.bounds[dim].1 = bounds.bounds[dim].1.max(v.dims()[dim]); + for dim in 0..v.dims() { + bounds.bounds[dim].0 = bounds.bounds[dim].0.min(v.dim(dim)); + bounds.bounds[dim].1 = bounds.bounds[dim].1.max(v.dim(dim)); } } @@ -34,15 +34,18 @@ impl Bounds { bounds } + #[inline] pub fn get_widest_dim(&self) -> usize { self.widest_dim } - pub fn get_midvalue_of_widest_dim(&self) -> f64 { + #[inline] + pub fn get_midvalue_of_widest_dim(&self) -> F { self.midvalue_of_widest_dim } - pub fn clone_moving_max(&self, value: f64, dimension: usize) -> Bounds { + #[inline] + pub fn clone_moving_max(&self, value: F, dimension: usize) -> Bounds { let mut cloned = Bounds { bounds: self.bounds.clone(), ..*self @@ -55,7 +58,7 @@ impl Bounds { cloned } - pub fn clone_moving_min(&self, value: f64, dimension: usize) -> Bounds { + pub fn clone_moving_min(&self, value: F, dimension: usize) -> Bounds { let mut cloned = Bounds { bounds: self.bounds.clone(), ..*self @@ -85,7 +88,7 @@ impl Bounds { fn calculate_variables(&mut self) { self.calculate_widest_dim(); - self.midvalue_of_widest_dim = (self.bounds[self.get_widest_dim()].0 + self.bounds[self.get_widest_dim()].1) / 2.0; + self.midvalue_of_widest_dim = (self.bounds[self.get_widest_dim()].0 + self.bounds[self.get_widest_dim()].1) / F::from(2.0f32).unwrap(); } } diff --git a/src/kdtree/distance.rs b/src/kdtree/distance.rs deleted file mode 100644 index b0ba3cc..0000000 --- a/src/kdtree/distance.rs +++ /dev/null @@ -1,39 +0,0 @@ -pub fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 { - debug_assert!(a.len() == b.len()); - - a.iter().zip(b.iter()) - .map(|(x, y)| (x - y) * (x - y)) - .sum() -} - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn squared_euclidean_test_1d() { - let a = [2.]; - let b = [4.]; - let c = [-2.]; - - assert_eq!(0., squared_euclidean(&a, &a)); - - assert_eq!(4., squared_euclidean(&a, &b)); - - assert_eq!(16., squared_euclidean(&a, &c)); - } - - #[test] - fn squared_euclidean_test_2d() { - let a = [2., 2.]; - let b = [4., 2.]; - let c = [4., 4.]; - - assert_eq!(0., squared_euclidean(&a, &a)); - - assert_eq!(4., squared_euclidean(&a, &b)); - - assert_eq!(8., squared_euclidean(&a, &c)); - } -} \ No newline at end of file diff --git a/src/kdtree/mod.rs b/src/kdtree/mod.rs index e320b9b..913a851 100644 --- a/src/kdtree/mod.rs +++ b/src/kdtree/mod.rs @@ -1,38 +1,55 @@ pub mod test_common; -pub mod distance; mod partition; mod bounds; use self::bounds::*; -use self::distance::*; -use std::cmp; +use num_traits::Float; +use core::cmp; -pub trait KdTreePoint: Copy + PartialEq { - fn dist_1d(left: f64, right: f64, _dim: usize) -> f64 { +pub trait KdTreePoint: Copy + PartialEq { + fn dist_1d(left: F, right: F, _dim: usize) -> F { let diff = left - right; diff * diff } - fn dims(&self) -> &[f64]; - fn dist(&self, other: &Self) -> f64 { - squared_euclidean(self.dims(), other.dims()) + fn dims(&self) -> usize; + fn dim(&self, i: usize) -> F; + fn dist(&self, other: &Self) -> F { + let mut sum = F::zero(); + + for i in 0..self.dims() { + let x = self.dim(i); + let y = other.dim(i); + let diff = x - y; + + sum = sum + diff * diff; + } + + sum } + + #[inline] + fn to_vec(&self) -> Vec { + (0..self.dims()) + .map(|x| self.dim(x)) + .collect() + } } -pub struct NearestNeighboursIter<'a, T> { - range: f64, - kdtree: &'a KdTree, +pub struct NearestNeighboursIter<'a, F: Float, T> { + range: F, + kdtree: &'a KdTree, ref_node: T, node_stack: Vec, } -impl<'a, T> Iterator for NearestNeighboursIter<'a, T> - where T: KdTreePoint +impl<'a, F: Float, T> Iterator for NearestNeighboursIter<'a, F, T> + where T: KdTreePoint { - type Item = (f64, &'a T); + type Item = (F, &'a T); fn next(&mut self) -> Option { let p = &self.ref_node; @@ -42,7 +59,7 @@ impl<'a, T> Iterator for NearestNeighboursIter<'a, T> let node = &self.kdtree.nodes[node_idx]; let splitting_value = node.split_on; - let point_splitting_dim_value = p.dims()[node.dimension]; + let point_splitting_dim_value = p.dim(node.dimension); let distance_on_single_dimension = T::dist_1d(splitting_value, point_splitting_dim_value, node.dimension); if distance_on_single_dimension <= self.range { @@ -71,15 +88,15 @@ impl<'a, T> Iterator for NearestNeighboursIter<'a, T> } } -pub struct KdTree { - nodes: Vec>, +pub struct KdTree { + nodes: Vec>, node_adding_dimension: usize, node_depth_during_last_rebuild: usize, current_node_depth: usize, } -impl KdTree { +impl> KdTree { #[inline] pub fn empty() -> Self { KdTree { @@ -115,13 +132,13 @@ impl KdTree { /// Can be used if you are sure that the tree is degenerated or if you will never again insert the nodes into the tree. pub fn gather_points_and_rebuild(&mut self) { - let original = std::mem::replace(self, Self::empty()); + let original = core::mem::replace(self, Self::empty()); let mut points: Vec<_> = original.into_iter().collect(); self.rebuild_tree(&mut points); } - pub fn nearest_search(&self, node: &KP) -> (f64, &KP) { + pub fn nearest_search(&self, node: &KP) -> (F, &KP) { let mut nearest_neighbor = 0usize; let mut best_distance = self.nodes[0].point.dist(&node); @@ -130,7 +147,7 @@ impl KdTree { (best_distance, &self.nodes[nearest_neighbor].point) } - pub fn nearest_search_dist(&self, node: KP, dist: f64) -> NearestNeighboursIter<'_, KP> { + pub fn nearest_search_dist(&self, node: KP, dist: F) -> NearestNeighboursIter<'_, F, KP> { let mut node_stack = Vec::with_capacity(16); node_stack.push(0); @@ -142,13 +159,15 @@ impl KdTree { } } - pub fn has_neighbor_in_range(&self, node: &KP, range: f64) -> bool { + #[inline] + pub fn has_neighbor_in_range(&self, node: &KP, range: F) -> bool { let squared_range = range * range; self.distance_squared_to_nearest(node) <= squared_range } - pub fn distance_squared_to_nearest(&self, node: &KP) -> f64 { + #[inline] + pub fn distance_squared_to_nearest(&self, node: &KP) -> F { self.nearest_search(node).0 } @@ -165,7 +184,7 @@ impl KdTree { pub fn insert_node(&mut self, node_to_add: KP) { let mut current_index = 0; let dimension = self.node_adding_dimension; - let dims = node_to_add.dims().to_vec(); + let dims = node_to_add.to_vec(); let index_of_new_node = self.add_node(node_to_add, dimension,dims[dimension]); self.node_adding_dimension = (dimension + 1) % dims.len(); @@ -211,16 +230,16 @@ impl KdTree { self.nodes.pop(); } - if self.node_depth_during_last_rebuild as f64 * 4.0 < depth as f64 { + if F::from(self.node_depth_during_last_rebuild).unwrap() * F::from(4.0).unwrap() < F::from(depth).unwrap() { self.gather_points_and_rebuild(); } } - fn nearest_search_impl(&self, p: &KP, searched_index: usize, best_distance_squared: &mut f64, best_leaf_found: &mut usize) { + fn nearest_search_impl(&self, p: &KP, searched_index: usize, best_distance_squared: &mut F, best_leaf_found: &mut usize) { let node = &self.nodes[searched_index]; let splitting_value = node.split_on; - let point_splitting_dim_value = p.dims()[node.dimension]; + let point_splitting_dim_value = p.dim(node.dimension); let (closer_node, farther_node) = if point_splitting_dim_value <= splitting_value { (node.left_node, node.right_node) @@ -247,16 +266,16 @@ impl KdTree { } } - fn add_node(&mut self, p: KP, dimension: usize, split_on: f64) -> usize { + fn add_node(&mut self, p: KP, dimension: usize, split_on: F) -> usize { let node = KdTreeNode::new(p, dimension, split_on); self.nodes.push(node); self.nodes.len() - 1 } - fn build_tree(&mut self, nodes: &mut [KP], bounds: &Bounds, depth : usize) -> usize { + fn build_tree(&mut self, nodes: &mut [KP], bounds: &Bounds, depth : usize) -> usize { let splitting_index = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim()); - let pivot_value = nodes[splitting_index].dims()[bounds.get_widest_dim()]; + let pivot_value = nodes[splitting_index].dim(bounds.get_widest_dim()); let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), pivot_value); let nodes_len = nodes.len(); @@ -287,17 +306,17 @@ impl KdTree { } } -pub struct KdTreeNode { +pub struct KdTreeNode { left_node: Option, right_node: Option, point: T, dimension: usize, - split_on: f64 + split_on: F } -impl KdTreeNode { - fn new(p: T, splitting_dimension: usize, split_on_value: f64) -> KdTreeNode { +impl> KdTreeNode { + fn new(p: T, splitting_dimension: usize, split_on_value: F) -> KdTreeNode { KdTreeNode { left_node: None, right_node: None, @@ -401,8 +420,8 @@ mod tests { assert_eq!(tree.nodes[0].dimension, 0); assert_eq!(tree.nodes[0].left_node.is_some(), true); - assert_eq!(tree.nodes[1].point.dims()[0], 1.); - assert_eq!(tree.nodes[2].point.dims()[0], -1.); + assert_eq!(tree.nodes[1].point.dim(0), 1.); + assert_eq!(tree.nodes[2].point.dim(0), -1.); assert_eq!(tree.nodes[0].right_node.is_some(), true); } diff --git a/src/kdtree/partition.rs b/src/kdtree/partition.rs index abb9b60..d9faf1d 100644 --- a/src/kdtree/partition.rs +++ b/src/kdtree/partition.rs @@ -1,3 +1,4 @@ +use num_traits::Float; use crate::kdtree::KdTreePoint; enum PointsWereOnSide { @@ -11,9 +12,9 @@ struct PartitionPointHelper { index_of_splitter: usize, } -fn partition_sliding_midpoint_helper(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> PartitionPointHelper { +fn partition_sliding_midpoint_helper>(vec: &mut [T], midpoint_value: F, partition_on_dimension: usize) -> PartitionPointHelper { let mut closest_index = 0; - let mut closest_distance = (vec[0].dims()[partition_on_dimension] - midpoint_value).abs(); + let mut closest_distance = (vec[0].dim(partition_on_dimension) - midpoint_value).abs(); const HAS_POINTS_ON_LEFT_SIDE: i32 = 0b01; const HAS_POINTS_ON_RIGHT_SIDE: i32 = 0b10; @@ -22,13 +23,13 @@ fn partition_sliding_midpoint_helper(vec: &mut [T], midpoint_val for i in 0..vec.len() { let p = vec.get(i).unwrap(); - if p.dims()[partition_on_dimension] <= midpoint_value { + if p.dim(partition_on_dimension) <= midpoint_value { has_points_on_sides |= HAS_POINTS_ON_LEFT_SIDE; } else { has_points_on_sides |= HAS_POINTS_ON_RIGHT_SIDE; } - let dist = (p.dims()[partition_on_dimension] - midpoint_value).abs(); + let dist = (p.dim(partition_on_dimension) - midpoint_value).abs(); if dist < closest_distance { closest_distance = dist; @@ -48,9 +49,9 @@ fn partition_sliding_midpoint_helper(vec: &mut [T], midpoint_val } } -pub fn partition_sliding_midpoint(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> usize { +pub fn partition_sliding_midpoint>(vec: &mut [T], midpoint_value: F, partition_on_dimension: usize) -> usize { let vec_len = vec.len(); - debug_assert!(vec[0].dims().len() > partition_on_dimension); + debug_assert!(vec[0].dims() > partition_on_dimension); if vec.len() == 1 { return 0; @@ -74,12 +75,12 @@ pub fn partition_sliding_midpoint(vec: &mut [T], midpoint_value: } } -fn partition_kdtree(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize { +fn partition_kdtree>(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize { if vec.len() == 1 { return 0; } - let pivot = vec[index_of_splitting_point].dims()[partition_on_dimension]; + let pivot = vec[index_of_splitting_point].dim(partition_on_dimension); let vec_len = vec.len(); vec.swap(index_of_splitting_point, vec_len - 1); @@ -90,11 +91,11 @@ fn partition_kdtree(vec: &mut [T], index_of_splitting_point: usi //variant of Lomuto algo. loop { - while left <= right && vec[left].dims()[partition_on_dimension] <= pivot { + while left <= right && vec[left].dim(partition_on_dimension) <= pivot { left += 1; } - while right > left && vec[right].dims()[partition_on_dimension] > pivot { + while right > left && vec[right].dim(partition_on_dimension) > pivot { right -= 1; } @@ -109,10 +110,10 @@ fn partition_kdtree(vec: &mut [T], index_of_splitting_point: usi } } - if last_succesful_swap == vec_len - 1 && vec[right].dims()[partition_on_dimension] > pivot { + if last_succesful_swap == vec_len - 1 && vec[right].dim(partition_on_dimension) > pivot { vec.swap(right, last_succesful_swap); last_succesful_swap = right; - } else if vec[left].dims()[partition_on_dimension] > pivot { + } else if vec[left].dim(partition_on_dimension) > pivot { vec.swap(left, vec_len - 1); last_succesful_swap = left; } else { @@ -228,16 +229,16 @@ mod tests { } fn assert_partition(v: &Vec, index_of_splitting_point: usize) -> bool { - let pivot = v[index_of_splitting_point].dims()[0]; + let pivot = v[index_of_splitting_point].dim(0); for i in 0..index_of_splitting_point { - if v[i].dims()[0] > pivot { + if v[i].dim(0) > pivot { return false; } } for i in index_of_splitting_point + 1..v.len() { - if v[i].dims()[0] < pivot { + if v[i].dim(0) < pivot { return false; } } diff --git a/src/kdtree/test_common.rs b/src/kdtree/test_common.rs index 38f91c3..15e21f8 100644 --- a/src/kdtree/test_common.rs +++ b/src/kdtree/test_common.rs @@ -15,10 +15,15 @@ impl Point3WithId { } } -impl KdTreePoint for Point3WithId { +impl KdTreePoint for Point3WithId { #[inline] - fn dims(&self) -> &[f64] { - return &self.dims; + fn dims(&self) -> usize { + self.dims.len() + } + + #[inline] + fn dim(&self, i: usize) -> f64 { + self.dims[i] } } @@ -37,10 +42,15 @@ impl Point2WithId { } } -impl KdTreePoint for Point2WithId { +impl KdTreePoint for Point2WithId { #[inline] - fn dims(&self) -> &[f64] { - return &self.dims; + fn dims(&self) -> usize { + self.dims.len() + } + + #[inline] + fn dim(&self, i: usize) -> f64 { + self.dims[i] } } @@ -59,9 +69,14 @@ impl Point1WithId { } } -impl KdTreePoint for Point1WithId { +impl KdTreePoint for Point1WithId { #[inline] - fn dims(&self) -> &[f64] { - return &self.dims; + fn dims(&self) -> usize { + self.dims.len() + } + + #[inline] + fn dim(&self, i: usize) -> f64 { + self.dims[i] } } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 3c78b33..44b15e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,5 +7,3 @@ extern crate rand; pub mod kdtree; -#[cfg(test)] -mod bench; diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 844be6f..407903e 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1,22 +1,19 @@ -extern crate kdtree; -extern crate rand; - use rand::Rng; use kdtree::kdtree::test_common::*; use kdtree::kdtree::KdTreePoint; -use kdtree::kdtree::distance::squared_euclidean; +use kdtree::kdtree::KdTree; fn gen_random() -> f64 { rand::thread_rng().gen_range(0., 1000.) } fn find_nn_with_linear_search(points : &Vec, find_for : Point3WithId) -> (f64, &Point3WithId) { - let mut best_found_distance = squared_euclidean(find_for.dims(), points[0].dims()); + let mut best_found_distance = find_for.dist(&points[0]); let mut closed_found_point = &points[0]; for p in points { - let dist = squared_euclidean(find_for.dims(), p.dims()); + let dist = find_for.dist(p); if dist < best_found_distance { best_found_distance = dist; @@ -31,7 +28,7 @@ fn find_neigbours_with_linear_search(points : &Vec, find_for : Poi let mut result = Vec::new(); for p in points { - let d = squared_euclidean(find_for.dims(), p.dims()); + let d = find_for.dist(p); if d <= dist { result.push((d, p)); @@ -56,9 +53,9 @@ fn generate_points(point_count : usize) -> Vec { fn test_against_1000_random_points() { let point_count = 1000usize; let points = generate_points(point_count); - kdtree::kdtree::test_common::Point1WithId::new(0,0.); + Point1WithId::new(0,0.); - let tree = kdtree::kdtree::KdTree::new(&mut points.clone()); + let tree = KdTree::new(&mut points.clone()); //test points pushed into the tree, id should be equal. for i in 0 .. point_count { @@ -83,8 +80,8 @@ fn test_incrementally_build_tree_against_built_at_once() { let point_count = 2000usize; let mut points = generate_points(point_count); - let tree_built_at_once = kdtree::kdtree::KdTree::new(&mut points.clone()); - let mut tree_built_incrementally = kdtree::kdtree::KdTree::new(&mut points[0..1]); + let tree_built_at_once = KdTree::new(&mut points.clone()); + let mut tree_built_incrementally = KdTree::new(&mut points[0..1]); for i in 1 .. point_count { let p = &points[i]; @@ -113,7 +110,7 @@ fn test_incrementally_build_tree_against_built_at_once() { fn test_neighbour_search_with_distance() { let point_count = 1000usize; let points = generate_points(point_count); - let tree = kdtree::kdtree::KdTree::new(&mut points.clone()); + let tree = KdTree::new(&mut points.clone()); for _ in 0 .. 500 { let dist = 100.0;