generic over float type

This commit is contained in:
Andrey Tkachenko 2020-04-28 22:33:03 +04:00
parent 09606936ad
commit c31758c883
9 changed files with 134 additions and 139 deletions

View File

@ -16,10 +16,12 @@ bench = false
[[bench]] [[bench]]
name = "bench" name = "bench"
path = "src/bench.rs"
harness = false harness = false
[dev-dependencies] [dev-dependencies]
quickcheck = "~0.9" quickcheck = "~0.9"
rand = "~0.7" rand = "~0.7"
bencher = "~0.1" bencher = "~0.1"
[dependencies]
num-traits = "0.2"

View File

@ -1,11 +1,10 @@
#[macro_use] #[macro_use]
extern crate bencher; extern crate bencher;
extern crate kdtree;
extern crate rand;
use bencher::Bencher; use bencher::Bencher;
use rand::Rng; use rand::Rng;
use kdtree::kdtree::KdTree;
use kdtree::kdtree::test_common::*; use kdtree::kdtree::test_common::*;
fn gen_random() -> f64 { fn gen_random() -> f64 {
@ -27,7 +26,7 @@ fn bench_creating_1000_node_tree(b: &mut Bencher) {
let points = generate_points(len); let points = generate_points(len);
b.iter(|| { b.iter(|| {
kdtree::kdtree::KdTree::new(&mut points.clone()); KdTree::new(&mut points.clone());
}); });
} }
@ -35,7 +34,7 @@ fn bench_single_loop_times_for_1000_node_tree(b: &mut Bencher) {
let len = 1000usize; let len = 1000usize;
let points = generate_points(len); let points = generate_points(len);
let tree = kdtree::kdtree::KdTree::new(&mut points.clone()); let tree = KdTree::new(&mut points.clone());
b.iter(|| tree.nearest_search(&points[0])); b.iter(|| tree.nearest_search(&points[0]));
@ -47,14 +46,14 @@ fn bench_creating_1000_000_node_tree(b: &mut Bencher) {
let points = generate_points(len); let points = generate_points(len);
b.iter(|| { b.iter(|| {
kdtree::kdtree::KdTree::new(&mut points.clone()); KdTree::new(&mut points.clone());
}); });
} }
fn bench_adding_same_node_to_1000_tree(b: &mut Bencher) { fn bench_adding_same_node_to_1000_tree(b: &mut Bencher) {
let len = 1000usize; let len = 1000usize;
let mut points = generate_points(len); let mut points = generate_points(len);
let mut tree = kdtree::kdtree::KdTree::new(&mut points); let mut tree = KdTree::new(&mut points);
let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random()); let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random());
b.iter(|| { b.iter(|| {
@ -66,7 +65,7 @@ fn bench_incrementally_building_the_1000_tree(b: &mut Bencher) {
b.iter(|| { b.iter(|| {
let len = 1usize; let len = 1usize;
let mut points = generate_points(len); let mut points = generate_points(len);
let mut tree = kdtree::kdtree::KdTree::new(&mut points); let mut tree = KdTree::new(&mut points);
for _ in 0 .. 1000 { for _ in 0 .. 1000 {
let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random()); let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random());
tree.insert_node(point); tree.insert_node(point);

View File

@ -1,31 +1,31 @@
use num_traits::Float;
use crate::kdtree::KdTreePoint; use crate::kdtree::KdTreePoint;
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct Bounds { pub struct Bounds<F: Float> {
pub bounds: [(f64, f64); 3], pub bounds: [(F, F); 3],
widest_dim: usize, widest_dim: usize,
midvalue_of_widest_dim: f64, midvalue_of_widest_dim: F,
} }
impl Bounds { impl<F: Float> Bounds<F> {
pub fn new_from_points<T: KdTreePoint>(points: &[T]) -> Bounds { pub fn new_from_points<T: KdTreePoint<F>>(points: &[T]) -> Bounds<F> {
let mut bounds = Bounds { let mut bounds = Bounds {
bounds: [(0., 0.), (0., 0.), (0., 0.)], bounds: [(F::zero(), F::zero()), (F::zero(), F::zero()), (F::zero(), F::zero())],
widest_dim: 0, widest_dim: 0,
midvalue_of_widest_dim: 0., midvalue_of_widest_dim: F::zero(),
}; };
for i in 0..points[0].dims().len() { for i in 0..points[0].dims() {
bounds.bounds[i].0 = points[0].dims()[i]; bounds.bounds[i].0 = points[0].dim(i);
bounds.bounds[i].1 = points[0].dims()[i]; bounds.bounds[i].1 = points[0].dim(i);
} }
for v in points.iter() { for v in points.iter() {
for dim in 0..v.dims().len() { for dim in 0..v.dims() {
bounds.bounds[dim].0 = bounds.bounds[dim].0.min(v.dims()[dim]); bounds.bounds[dim].0 = bounds.bounds[dim].0.min(v.dim(dim));
bounds.bounds[dim].1 = bounds.bounds[dim].1.max(v.dims()[dim]); bounds.bounds[dim].1 = bounds.bounds[dim].1.max(v.dim(dim));
} }
} }
@ -34,15 +34,18 @@ impl Bounds {
bounds bounds
} }
#[inline]
pub fn get_widest_dim(&self) -> usize { pub fn get_widest_dim(&self) -> usize {
self.widest_dim self.widest_dim
} }
pub fn get_midvalue_of_widest_dim(&self) -> f64 { #[inline]
pub fn get_midvalue_of_widest_dim(&self) -> F {
self.midvalue_of_widest_dim self.midvalue_of_widest_dim
} }
pub fn clone_moving_max(&self, value: f64, dimension: usize) -> Bounds { #[inline]
pub fn clone_moving_max(&self, value: F, dimension: usize) -> Bounds<F> {
let mut cloned = Bounds { let mut cloned = Bounds {
bounds: self.bounds.clone(), bounds: self.bounds.clone(),
..*self ..*self
@ -55,7 +58,7 @@ impl Bounds {
cloned cloned
} }
pub fn clone_moving_min(&self, value: f64, dimension: usize) -> Bounds { pub fn clone_moving_min(&self, value: F, dimension: usize) -> Bounds<F> {
let mut cloned = Bounds { let mut cloned = Bounds {
bounds: self.bounds.clone(), bounds: self.bounds.clone(),
..*self ..*self
@ -85,7 +88,7 @@ impl Bounds {
fn calculate_variables(&mut self) { fn calculate_variables(&mut self) {
self.calculate_widest_dim(); self.calculate_widest_dim();
self.midvalue_of_widest_dim = (self.bounds[self.get_widest_dim()].0 + self.bounds[self.get_widest_dim()].1) / 2.0; self.midvalue_of_widest_dim = (self.bounds[self.get_widest_dim()].0 + self.bounds[self.get_widest_dim()].1) / F::from(2.0f32).unwrap();
} }
} }

View File

@ -1,39 +0,0 @@
pub fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
debug_assert!(a.len() == b.len());
a.iter().zip(b.iter())
.map(|(x, y)| (x - y) * (x - y))
.sum()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn squared_euclidean_test_1d() {
let a = [2.];
let b = [4.];
let c = [-2.];
assert_eq!(0., squared_euclidean(&a, &a));
assert_eq!(4., squared_euclidean(&a, &b));
assert_eq!(16., squared_euclidean(&a, &c));
}
#[test]
fn squared_euclidean_test_2d() {
let a = [2., 2.];
let b = [4., 2.];
let c = [4., 4.];
assert_eq!(0., squared_euclidean(&a, &a));
assert_eq!(4., squared_euclidean(&a, &b));
assert_eq!(8., squared_euclidean(&a, &c));
}
}

View File

@ -1,38 +1,55 @@
pub mod test_common; pub mod test_common;
pub mod distance;
mod partition; mod partition;
mod bounds; mod bounds;
use self::bounds::*; use self::bounds::*;
use self::distance::*;
use std::cmp; use num_traits::Float;
use core::cmp;
pub trait KdTreePoint: Copy + PartialEq { pub trait KdTreePoint<F: Float>: Copy + PartialEq {
fn dist_1d(left: f64, right: f64, _dim: usize) -> f64 { fn dist_1d(left: F, right: F, _dim: usize) -> F {
let diff = left - right; let diff = left - right;
diff * diff diff * diff
} }
fn dims(&self) -> &[f64]; fn dims(&self) -> usize;
fn dist(&self, other: &Self) -> f64 { fn dim(&self, i: usize) -> F;
squared_euclidean(self.dims(), other.dims()) fn dist(&self, other: &Self) -> F {
let mut sum = F::zero();
for i in 0..self.dims() {
let x = self.dim(i);
let y = other.dim(i);
let diff = x - y;
sum = sum + diff * diff;
}
sum
}
#[inline]
fn to_vec(&self) -> Vec<F> {
(0..self.dims())
.map(|x| self.dim(x))
.collect()
} }
} }
pub struct NearestNeighboursIter<'a, T> { pub struct NearestNeighboursIter<'a, F: Float, T> {
range: f64, range: F,
kdtree: &'a KdTree<T>, kdtree: &'a KdTree<F, T>,
ref_node: T, ref_node: T,
node_stack: Vec<usize>, node_stack: Vec<usize>,
} }
impl<'a, T> Iterator for NearestNeighboursIter<'a, T> impl<'a, F: Float, T> Iterator for NearestNeighboursIter<'a, F, T>
where T: KdTreePoint where T: KdTreePoint<F>
{ {
type Item = (f64, &'a T); type Item = (F, &'a T);
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let p = &self.ref_node; let p = &self.ref_node;
@ -42,7 +59,7 @@ impl<'a, T> Iterator for NearestNeighboursIter<'a, T>
let node = &self.kdtree.nodes[node_idx]; let node = &self.kdtree.nodes[node_idx];
let splitting_value = node.split_on; let splitting_value = node.split_on;
let point_splitting_dim_value = p.dims()[node.dimension]; let point_splitting_dim_value = p.dim(node.dimension);
let distance_on_single_dimension = T::dist_1d(splitting_value, point_splitting_dim_value, node.dimension); let distance_on_single_dimension = T::dist_1d(splitting_value, point_splitting_dim_value, node.dimension);
if distance_on_single_dimension <= self.range { if distance_on_single_dimension <= self.range {
@ -71,15 +88,15 @@ impl<'a, T> Iterator for NearestNeighboursIter<'a, T>
} }
} }
pub struct KdTree<KP> { pub struct KdTree<F: Float, KP> {
nodes: Vec<KdTreeNode<KP>>, nodes: Vec<KdTreeNode<F, KP>>,
node_adding_dimension: usize, node_adding_dimension: usize,
node_depth_during_last_rebuild: usize, node_depth_during_last_rebuild: usize,
current_node_depth: usize, current_node_depth: usize,
} }
impl<KP: KdTreePoint> KdTree<KP> { impl<F: Float, KP: KdTreePoint<F>> KdTree<F, KP> {
#[inline] #[inline]
pub fn empty() -> Self { pub fn empty() -> Self {
KdTree { KdTree {
@ -115,13 +132,13 @@ impl<KP: KdTreePoint> KdTree<KP> {
/// Can be used if you are sure that the tree is degenerated or if you will never again insert the nodes into the tree. /// Can be used if you are sure that the tree is degenerated or if you will never again insert the nodes into the tree.
pub fn gather_points_and_rebuild(&mut self) { pub fn gather_points_and_rebuild(&mut self) {
let original = std::mem::replace(self, Self::empty()); let original = core::mem::replace(self, Self::empty());
let mut points: Vec<_> = original.into_iter().collect(); let mut points: Vec<_> = original.into_iter().collect();
self.rebuild_tree(&mut points); self.rebuild_tree(&mut points);
} }
pub fn nearest_search(&self, node: &KP) -> (f64, &KP) { pub fn nearest_search(&self, node: &KP) -> (F, &KP) {
let mut nearest_neighbor = 0usize; let mut nearest_neighbor = 0usize;
let mut best_distance = self.nodes[0].point.dist(&node); let mut best_distance = self.nodes[0].point.dist(&node);
@ -130,7 +147,7 @@ impl<KP: KdTreePoint> KdTree<KP> {
(best_distance, &self.nodes[nearest_neighbor].point) (best_distance, &self.nodes[nearest_neighbor].point)
} }
pub fn nearest_search_dist(&self, node: KP, dist: f64) -> NearestNeighboursIter<'_, KP> { pub fn nearest_search_dist(&self, node: KP, dist: F) -> NearestNeighboursIter<'_, F, KP> {
let mut node_stack = Vec::with_capacity(16); let mut node_stack = Vec::with_capacity(16);
node_stack.push(0); node_stack.push(0);
@ -142,13 +159,15 @@ impl<KP: KdTreePoint> KdTree<KP> {
} }
} }
pub fn has_neighbor_in_range(&self, node: &KP, range: f64) -> bool { #[inline]
pub fn has_neighbor_in_range(&self, node: &KP, range: F) -> bool {
let squared_range = range * range; let squared_range = range * range;
self.distance_squared_to_nearest(node) <= squared_range self.distance_squared_to_nearest(node) <= squared_range
} }
pub fn distance_squared_to_nearest(&self, node: &KP) -> f64 { #[inline]
pub fn distance_squared_to_nearest(&self, node: &KP) -> F {
self.nearest_search(node).0 self.nearest_search(node).0
} }
@ -165,7 +184,7 @@ impl<KP: KdTreePoint> KdTree<KP> {
pub fn insert_node(&mut self, node_to_add: KP) { pub fn insert_node(&mut self, node_to_add: KP) {
let mut current_index = 0; let mut current_index = 0;
let dimension = self.node_adding_dimension; let dimension = self.node_adding_dimension;
let dims = node_to_add.dims().to_vec(); let dims = node_to_add.to_vec();
let index_of_new_node = self.add_node(node_to_add, dimension,dims[dimension]); let index_of_new_node = self.add_node(node_to_add, dimension,dims[dimension]);
self.node_adding_dimension = (dimension + 1) % dims.len(); self.node_adding_dimension = (dimension + 1) % dims.len();
@ -211,16 +230,16 @@ impl<KP: KdTreePoint> KdTree<KP> {
self.nodes.pop(); self.nodes.pop();
} }
if self.node_depth_during_last_rebuild as f64 * 4.0 < depth as f64 { if F::from(self.node_depth_during_last_rebuild).unwrap() * F::from(4.0).unwrap() < F::from(depth).unwrap() {
self.gather_points_and_rebuild(); self.gather_points_and_rebuild();
} }
} }
fn nearest_search_impl(&self, p: &KP, searched_index: usize, best_distance_squared: &mut f64, best_leaf_found: &mut usize) { fn nearest_search_impl(&self, p: &KP, searched_index: usize, best_distance_squared: &mut F, best_leaf_found: &mut usize) {
let node = &self.nodes[searched_index]; let node = &self.nodes[searched_index];
let splitting_value = node.split_on; let splitting_value = node.split_on;
let point_splitting_dim_value = p.dims()[node.dimension]; let point_splitting_dim_value = p.dim(node.dimension);
let (closer_node, farther_node) = if point_splitting_dim_value <= splitting_value { let (closer_node, farther_node) = if point_splitting_dim_value <= splitting_value {
(node.left_node, node.right_node) (node.left_node, node.right_node)
@ -247,16 +266,16 @@ impl<KP: KdTreePoint> KdTree<KP> {
} }
} }
fn add_node(&mut self, p: KP, dimension: usize, split_on: f64) -> usize { fn add_node(&mut self, p: KP, dimension: usize, split_on: F) -> usize {
let node = KdTreeNode::new(p, dimension, split_on); let node = KdTreeNode::new(p, dimension, split_on);
self.nodes.push(node); self.nodes.push(node);
self.nodes.len() - 1 self.nodes.len() - 1
} }
fn build_tree(&mut self, nodes: &mut [KP], bounds: &Bounds, depth : usize) -> usize { fn build_tree(&mut self, nodes: &mut [KP], bounds: &Bounds<F>, depth : usize) -> usize {
let splitting_index = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim()); let splitting_index = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim());
let pivot_value = nodes[splitting_index].dims()[bounds.get_widest_dim()]; let pivot_value = nodes[splitting_index].dim(bounds.get_widest_dim());
let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), pivot_value); let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), pivot_value);
let nodes_len = nodes.len(); let nodes_len = nodes.len();
@ -287,17 +306,17 @@ impl<KP: KdTreePoint> KdTree<KP> {
} }
} }
pub struct KdTreeNode<T> { pub struct KdTreeNode<F: Float, T> {
left_node: Option<usize>, left_node: Option<usize>,
right_node: Option<usize>, right_node: Option<usize>,
point: T, point: T,
dimension: usize, dimension: usize,
split_on: f64 split_on: F
} }
impl<T: KdTreePoint> KdTreeNode<T> { impl<F: Float, T: KdTreePoint<F>> KdTreeNode<F, T> {
fn new(p: T, splitting_dimension: usize, split_on_value: f64) -> KdTreeNode<T> { fn new(p: T, splitting_dimension: usize, split_on_value: F) -> KdTreeNode<F, T> {
KdTreeNode { KdTreeNode {
left_node: None, left_node: None,
right_node: None, right_node: None,
@ -401,8 +420,8 @@ mod tests {
assert_eq!(tree.nodes[0].dimension, 0); assert_eq!(tree.nodes[0].dimension, 0);
assert_eq!(tree.nodes[0].left_node.is_some(), true); assert_eq!(tree.nodes[0].left_node.is_some(), true);
assert_eq!(tree.nodes[1].point.dims()[0], 1.); assert_eq!(tree.nodes[1].point.dim(0), 1.);
assert_eq!(tree.nodes[2].point.dims()[0], -1.); assert_eq!(tree.nodes[2].point.dim(0), -1.);
assert_eq!(tree.nodes[0].right_node.is_some(), true); assert_eq!(tree.nodes[0].right_node.is_some(), true);
} }

View File

@ -1,3 +1,4 @@
use num_traits::Float;
use crate::kdtree::KdTreePoint; use crate::kdtree::KdTreePoint;
enum PointsWereOnSide { enum PointsWereOnSide {
@ -11,9 +12,9 @@ struct PartitionPointHelper {
index_of_splitter: usize, index_of_splitter: usize,
} }
fn partition_sliding_midpoint_helper<T: KdTreePoint>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> PartitionPointHelper { fn partition_sliding_midpoint_helper<F: Float, T: KdTreePoint<F>>(vec: &mut [T], midpoint_value: F, partition_on_dimension: usize) -> PartitionPointHelper {
let mut closest_index = 0; let mut closest_index = 0;
let mut closest_distance = (vec[0].dims()[partition_on_dimension] - midpoint_value).abs(); let mut closest_distance = (vec[0].dim(partition_on_dimension) - midpoint_value).abs();
const HAS_POINTS_ON_LEFT_SIDE: i32 = 0b01; const HAS_POINTS_ON_LEFT_SIDE: i32 = 0b01;
const HAS_POINTS_ON_RIGHT_SIDE: i32 = 0b10; const HAS_POINTS_ON_RIGHT_SIDE: i32 = 0b10;
@ -22,13 +23,13 @@ fn partition_sliding_midpoint_helper<T: KdTreePoint>(vec: &mut [T], midpoint_val
for i in 0..vec.len() { for i in 0..vec.len() {
let p = vec.get(i).unwrap(); let p = vec.get(i).unwrap();
if p.dims()[partition_on_dimension] <= midpoint_value { if p.dim(partition_on_dimension) <= midpoint_value {
has_points_on_sides |= HAS_POINTS_ON_LEFT_SIDE; has_points_on_sides |= HAS_POINTS_ON_LEFT_SIDE;
} else { } else {
has_points_on_sides |= HAS_POINTS_ON_RIGHT_SIDE; has_points_on_sides |= HAS_POINTS_ON_RIGHT_SIDE;
} }
let dist = (p.dims()[partition_on_dimension] - midpoint_value).abs(); let dist = (p.dim(partition_on_dimension) - midpoint_value).abs();
if dist < closest_distance { if dist < closest_distance {
closest_distance = dist; closest_distance = dist;
@ -48,9 +49,9 @@ fn partition_sliding_midpoint_helper<T: KdTreePoint>(vec: &mut [T], midpoint_val
} }
} }
pub fn partition_sliding_midpoint<T: KdTreePoint>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> usize { pub fn partition_sliding_midpoint<F: Float, T: KdTreePoint<F>>(vec: &mut [T], midpoint_value: F, partition_on_dimension: usize) -> usize {
let vec_len = vec.len(); let vec_len = vec.len();
debug_assert!(vec[0].dims().len() > partition_on_dimension); debug_assert!(vec[0].dims() > partition_on_dimension);
if vec.len() == 1 { if vec.len() == 1 {
return 0; return 0;
@ -74,12 +75,12 @@ pub fn partition_sliding_midpoint<T: KdTreePoint>(vec: &mut [T], midpoint_value:
} }
} }
fn partition_kdtree<T: KdTreePoint>(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize { fn partition_kdtree<F: Float, T: KdTreePoint<F>>(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize {
if vec.len() == 1 { if vec.len() == 1 {
return 0; return 0;
} }
let pivot = vec[index_of_splitting_point].dims()[partition_on_dimension]; let pivot = vec[index_of_splitting_point].dim(partition_on_dimension);
let vec_len = vec.len(); let vec_len = vec.len();
vec.swap(index_of_splitting_point, vec_len - 1); vec.swap(index_of_splitting_point, vec_len - 1);
@ -90,11 +91,11 @@ fn partition_kdtree<T: KdTreePoint>(vec: &mut [T], index_of_splitting_point: usi
//variant of Lomuto algo. //variant of Lomuto algo.
loop { loop {
while left <= right && vec[left].dims()[partition_on_dimension] <= pivot { while left <= right && vec[left].dim(partition_on_dimension) <= pivot {
left += 1; left += 1;
} }
while right > left && vec[right].dims()[partition_on_dimension] > pivot { while right > left && vec[right].dim(partition_on_dimension) > pivot {
right -= 1; right -= 1;
} }
@ -109,10 +110,10 @@ fn partition_kdtree<T: KdTreePoint>(vec: &mut [T], index_of_splitting_point: usi
} }
} }
if last_succesful_swap == vec_len - 1 && vec[right].dims()[partition_on_dimension] > pivot { if last_succesful_swap == vec_len - 1 && vec[right].dim(partition_on_dimension) > pivot {
vec.swap(right, last_succesful_swap); vec.swap(right, last_succesful_swap);
last_succesful_swap = right; last_succesful_swap = right;
} else if vec[left].dims()[partition_on_dimension] > pivot { } else if vec[left].dim(partition_on_dimension) > pivot {
vec.swap(left, vec_len - 1); vec.swap(left, vec_len - 1);
last_succesful_swap = left; last_succesful_swap = left;
} else { } else {
@ -228,16 +229,16 @@ mod tests {
} }
fn assert_partition(v: &Vec<Point1WithId>, index_of_splitting_point: usize) -> bool { fn assert_partition(v: &Vec<Point1WithId>, index_of_splitting_point: usize) -> bool {
let pivot = v[index_of_splitting_point].dims()[0]; let pivot = v[index_of_splitting_point].dim(0);
for i in 0..index_of_splitting_point { for i in 0..index_of_splitting_point {
if v[i].dims()[0] > pivot { if v[i].dim(0) > pivot {
return false; return false;
} }
} }
for i in index_of_splitting_point + 1..v.len() { for i in index_of_splitting_point + 1..v.len() {
if v[i].dims()[0] < pivot { if v[i].dim(0) < pivot {
return false; return false;
} }
} }

View File

@ -15,10 +15,15 @@ impl Point3WithId {
} }
} }
impl KdTreePoint for Point3WithId { impl KdTreePoint<f64> for Point3WithId {
#[inline] #[inline]
fn dims(&self) -> &[f64] { fn dims(&self) -> usize {
return &self.dims; self.dims.len()
}
#[inline]
fn dim(&self, i: usize) -> f64 {
self.dims[i]
} }
} }
@ -37,10 +42,15 @@ impl Point2WithId {
} }
} }
impl KdTreePoint for Point2WithId { impl KdTreePoint<f64> for Point2WithId {
#[inline] #[inline]
fn dims(&self) -> &[f64] { fn dims(&self) -> usize {
return &self.dims; self.dims.len()
}
#[inline]
fn dim(&self, i: usize) -> f64 {
self.dims[i]
} }
} }
@ -59,9 +69,14 @@ impl Point1WithId {
} }
} }
impl KdTreePoint for Point1WithId { impl KdTreePoint<f64> for Point1WithId {
#[inline] #[inline]
fn dims(&self) -> &[f64] { fn dims(&self) -> usize {
return &self.dims; self.dims.len()
}
#[inline]
fn dim(&self, i: usize) -> f64 {
self.dims[i]
} }
} }

View File

@ -7,5 +7,3 @@ extern crate rand;
pub mod kdtree; pub mod kdtree;
#[cfg(test)]
mod bench;

View File

@ -1,22 +1,19 @@
extern crate kdtree;
extern crate rand;
use rand::Rng; use rand::Rng;
use kdtree::kdtree::test_common::*; use kdtree::kdtree::test_common::*;
use kdtree::kdtree::KdTreePoint; use kdtree::kdtree::KdTreePoint;
use kdtree::kdtree::distance::squared_euclidean; use kdtree::kdtree::KdTree;
fn gen_random() -> f64 { fn gen_random() -> f64 {
rand::thread_rng().gen_range(0., 1000.) rand::thread_rng().gen_range(0., 1000.)
} }
fn find_nn_with_linear_search(points : &Vec<Point3WithId>, find_for : Point3WithId) -> (f64, &Point3WithId) { fn find_nn_with_linear_search(points : &Vec<Point3WithId>, find_for : Point3WithId) -> (f64, &Point3WithId) {
let mut best_found_distance = squared_euclidean(find_for.dims(), points[0].dims()); let mut best_found_distance = find_for.dist(&points[0]);
let mut closed_found_point = &points[0]; let mut closed_found_point = &points[0];
for p in points { for p in points {
let dist = squared_euclidean(find_for.dims(), p.dims()); let dist = find_for.dist(p);
if dist < best_found_distance { if dist < best_found_distance {
best_found_distance = dist; best_found_distance = dist;
@ -31,7 +28,7 @@ fn find_neigbours_with_linear_search(points : &Vec<Point3WithId>, find_for : Poi
let mut result = Vec::new(); let mut result = Vec::new();
for p in points { for p in points {
let d = squared_euclidean(find_for.dims(), p.dims()); let d = find_for.dist(p);
if d <= dist { if d <= dist {
result.push((d, p)); result.push((d, p));
@ -56,9 +53,9 @@ fn generate_points(point_count : usize) -> Vec<Point3WithId> {
fn test_against_1000_random_points() { fn test_against_1000_random_points() {
let point_count = 1000usize; let point_count = 1000usize;
let points = generate_points(point_count); let points = generate_points(point_count);
kdtree::kdtree::test_common::Point1WithId::new(0,0.); Point1WithId::new(0,0.);
let tree = kdtree::kdtree::KdTree::new(&mut points.clone()); let tree = KdTree::new(&mut points.clone());
//test points pushed into the tree, id should be equal. //test points pushed into the tree, id should be equal.
for i in 0 .. point_count { for i in 0 .. point_count {
@ -83,8 +80,8 @@ fn test_incrementally_build_tree_against_built_at_once() {
let point_count = 2000usize; let point_count = 2000usize;
let mut points = generate_points(point_count); let mut points = generate_points(point_count);
let tree_built_at_once = kdtree::kdtree::KdTree::new(&mut points.clone()); let tree_built_at_once = KdTree::new(&mut points.clone());
let mut tree_built_incrementally = kdtree::kdtree::KdTree::new(&mut points[0..1]); let mut tree_built_incrementally = KdTree::new(&mut points[0..1]);
for i in 1 .. point_count { for i in 1 .. point_count {
let p = &points[i]; let p = &points[i];
@ -113,7 +110,7 @@ fn test_incrementally_build_tree_against_built_at_once() {
fn test_neighbour_search_with_distance() { fn test_neighbour_search_with_distance() {
let point_count = 1000usize; let point_count = 1000usize;
let points = generate_points(point_count); let points = generate_points(point_count);
let tree = kdtree::kdtree::KdTree::new(&mut points.clone()); let tree = KdTree::new(&mut points.clone());
for _ in 0 .. 500 { for _ in 0 .. 500 {
let dist = 100.0; let dist = 100.0;