generic over float type

This commit is contained in:
Andrey Tkachenko 2020-04-28 22:33:03 +04:00
parent 09606936ad
commit c31758c883
9 changed files with 134 additions and 139 deletions

View File

@ -16,10 +16,12 @@ bench = false
[[bench]]
name = "bench"
path = "src/bench.rs"
harness = false
[dev-dependencies]
quickcheck = "~0.9"
rand = "~0.7"
bencher = "~0.1"
[dependencies]
num-traits = "0.2"

View File

@ -1,11 +1,10 @@
#[macro_use]
extern crate bencher;
extern crate kdtree;
extern crate rand;
use bencher::Bencher;
use rand::Rng;
use kdtree::kdtree::KdTree;
use kdtree::kdtree::test_common::*;
fn gen_random() -> f64 {
@ -27,7 +26,7 @@ fn bench_creating_1000_node_tree(b: &mut Bencher) {
let points = generate_points(len);
b.iter(|| {
kdtree::kdtree::KdTree::new(&mut points.clone());
KdTree::new(&mut points.clone());
});
}
@ -35,7 +34,7 @@ fn bench_single_loop_times_for_1000_node_tree(b: &mut Bencher) {
let len = 1000usize;
let points = generate_points(len);
let tree = kdtree::kdtree::KdTree::new(&mut points.clone());
let tree = KdTree::new(&mut points.clone());
b.iter(|| tree.nearest_search(&points[0]));
@ -47,14 +46,14 @@ fn bench_creating_1000_000_node_tree(b: &mut Bencher) {
let points = generate_points(len);
b.iter(|| {
kdtree::kdtree::KdTree::new(&mut points.clone());
KdTree::new(&mut points.clone());
});
}
fn bench_adding_same_node_to_1000_tree(b: &mut Bencher) {
let len = 1000usize;
let mut points = generate_points(len);
let mut tree = kdtree::kdtree::KdTree::new(&mut points);
let mut tree = KdTree::new(&mut points);
let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random());
b.iter(|| {
@ -66,7 +65,7 @@ fn bench_incrementally_building_the_1000_tree(b: &mut Bencher) {
b.iter(|| {
let len = 1usize;
let mut points = generate_points(len);
let mut tree = kdtree::kdtree::KdTree::new(&mut points);
let mut tree = KdTree::new(&mut points);
for _ in 0 .. 1000 {
let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random());
tree.insert_node(point);

View File

@ -1,31 +1,31 @@
use num_traits::Float;
use crate::kdtree::KdTreePoint;
#[derive(Clone, Copy)]
pub struct Bounds {
pub bounds: [(f64, f64); 3],
pub struct Bounds<F: Float> {
pub bounds: [(F, F); 3],
widest_dim: usize,
midvalue_of_widest_dim: f64,
midvalue_of_widest_dim: F,
}
impl Bounds {
pub fn new_from_points<T: KdTreePoint>(points: &[T]) -> Bounds {
impl<F: Float> Bounds<F> {
pub fn new_from_points<T: KdTreePoint<F>>(points: &[T]) -> Bounds<F> {
let mut bounds = Bounds {
bounds: [(0., 0.), (0., 0.), (0., 0.)],
bounds: [(F::zero(), F::zero()), (F::zero(), F::zero()), (F::zero(), F::zero())],
widest_dim: 0,
midvalue_of_widest_dim: 0.,
midvalue_of_widest_dim: F::zero(),
};
for i in 0..points[0].dims().len() {
bounds.bounds[i].0 = points[0].dims()[i];
bounds.bounds[i].1 = points[0].dims()[i];
for i in 0..points[0].dims() {
bounds.bounds[i].0 = points[0].dim(i);
bounds.bounds[i].1 = points[0].dim(i);
}
for v in points.iter() {
for dim in 0..v.dims().len() {
bounds.bounds[dim].0 = bounds.bounds[dim].0.min(v.dims()[dim]);
bounds.bounds[dim].1 = bounds.bounds[dim].1.max(v.dims()[dim]);
for dim in 0..v.dims() {
bounds.bounds[dim].0 = bounds.bounds[dim].0.min(v.dim(dim));
bounds.bounds[dim].1 = bounds.bounds[dim].1.max(v.dim(dim));
}
}
@ -34,15 +34,18 @@ impl Bounds {
bounds
}
#[inline]
pub fn get_widest_dim(&self) -> usize {
self.widest_dim
}
pub fn get_midvalue_of_widest_dim(&self) -> f64 {
#[inline]
pub fn get_midvalue_of_widest_dim(&self) -> F {
self.midvalue_of_widest_dim
}
pub fn clone_moving_max(&self, value: f64, dimension: usize) -> Bounds {
#[inline]
pub fn clone_moving_max(&self, value: F, dimension: usize) -> Bounds<F> {
let mut cloned = Bounds {
bounds: self.bounds.clone(),
..*self
@ -55,7 +58,7 @@ impl Bounds {
cloned
}
pub fn clone_moving_min(&self, value: f64, dimension: usize) -> Bounds {
pub fn clone_moving_min(&self, value: F, dimension: usize) -> Bounds<F> {
let mut cloned = Bounds {
bounds: self.bounds.clone(),
..*self
@ -85,7 +88,7 @@ impl Bounds {
fn calculate_variables(&mut self) {
self.calculate_widest_dim();
self.midvalue_of_widest_dim = (self.bounds[self.get_widest_dim()].0 + self.bounds[self.get_widest_dim()].1) / 2.0;
self.midvalue_of_widest_dim = (self.bounds[self.get_widest_dim()].0 + self.bounds[self.get_widest_dim()].1) / F::from(2.0f32).unwrap();
}
}

View File

@ -1,39 +0,0 @@
pub fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
debug_assert!(a.len() == b.len());
a.iter().zip(b.iter())
.map(|(x, y)| (x - y) * (x - y))
.sum()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn squared_euclidean_test_1d() {
let a = [2.];
let b = [4.];
let c = [-2.];
assert_eq!(0., squared_euclidean(&a, &a));
assert_eq!(4., squared_euclidean(&a, &b));
assert_eq!(16., squared_euclidean(&a, &c));
}
#[test]
fn squared_euclidean_test_2d() {
let a = [2., 2.];
let b = [4., 2.];
let c = [4., 4.];
assert_eq!(0., squared_euclidean(&a, &a));
assert_eq!(4., squared_euclidean(&a, &b));
assert_eq!(8., squared_euclidean(&a, &c));
}
}

View File

@ -1,38 +1,55 @@
pub mod test_common;
pub mod distance;
mod partition;
mod bounds;
use self::bounds::*;
use self::distance::*;
use std::cmp;
use num_traits::Float;
use core::cmp;
pub trait KdTreePoint: Copy + PartialEq {
fn dist_1d(left: f64, right: f64, _dim: usize) -> f64 {
pub trait KdTreePoint<F: Float>: Copy + PartialEq {
fn dist_1d(left: F, right: F, _dim: usize) -> F {
let diff = left - right;
diff * diff
}
fn dims(&self) -> &[f64];
fn dist(&self, other: &Self) -> f64 {
squared_euclidean(self.dims(), other.dims())
fn dims(&self) -> usize;
fn dim(&self, i: usize) -> F;
fn dist(&self, other: &Self) -> F {
let mut sum = F::zero();
for i in 0..self.dims() {
let x = self.dim(i);
let y = other.dim(i);
let diff = x - y;
sum = sum + diff * diff;
}
sum
}
#[inline]
fn to_vec(&self) -> Vec<F> {
(0..self.dims())
.map(|x| self.dim(x))
.collect()
}
}
pub struct NearestNeighboursIter<'a, T> {
range: f64,
kdtree: &'a KdTree<T>,
pub struct NearestNeighboursIter<'a, F: Float, T> {
range: F,
kdtree: &'a KdTree<F, T>,
ref_node: T,
node_stack: Vec<usize>,
}
impl<'a, T> Iterator for NearestNeighboursIter<'a, T>
where T: KdTreePoint
impl<'a, F: Float, T> Iterator for NearestNeighboursIter<'a, F, T>
where T: KdTreePoint<F>
{
type Item = (f64, &'a T);
type Item = (F, &'a T);
fn next(&mut self) -> Option<Self::Item> {
let p = &self.ref_node;
@ -42,7 +59,7 @@ impl<'a, T> Iterator for NearestNeighboursIter<'a, T>
let node = &self.kdtree.nodes[node_idx];
let splitting_value = node.split_on;
let point_splitting_dim_value = p.dims()[node.dimension];
let point_splitting_dim_value = p.dim(node.dimension);
let distance_on_single_dimension = T::dist_1d(splitting_value, point_splitting_dim_value, node.dimension);
if distance_on_single_dimension <= self.range {
@ -71,15 +88,15 @@ impl<'a, T> Iterator for NearestNeighboursIter<'a, T>
}
}
pub struct KdTree<KP> {
nodes: Vec<KdTreeNode<KP>>,
pub struct KdTree<F: Float, KP> {
nodes: Vec<KdTreeNode<F, KP>>,
node_adding_dimension: usize,
node_depth_during_last_rebuild: usize,
current_node_depth: usize,
}
impl<KP: KdTreePoint> KdTree<KP> {
impl<F: Float, KP: KdTreePoint<F>> KdTree<F, KP> {
#[inline]
pub fn empty() -> Self {
KdTree {
@ -115,13 +132,13 @@ impl<KP: KdTreePoint> KdTree<KP> {
/// Can be used if you are sure that the tree is degenerated or if you will never again insert the nodes into the tree.
pub fn gather_points_and_rebuild(&mut self) {
let original = std::mem::replace(self, Self::empty());
let original = core::mem::replace(self, Self::empty());
let mut points: Vec<_> = original.into_iter().collect();
self.rebuild_tree(&mut points);
}
pub fn nearest_search(&self, node: &KP) -> (f64, &KP) {
pub fn nearest_search(&self, node: &KP) -> (F, &KP) {
let mut nearest_neighbor = 0usize;
let mut best_distance = self.nodes[0].point.dist(&node);
@ -130,7 +147,7 @@ impl<KP: KdTreePoint> KdTree<KP> {
(best_distance, &self.nodes[nearest_neighbor].point)
}
pub fn nearest_search_dist(&self, node: KP, dist: f64) -> NearestNeighboursIter<'_, KP> {
pub fn nearest_search_dist(&self, node: KP, dist: F) -> NearestNeighboursIter<'_, F, KP> {
let mut node_stack = Vec::with_capacity(16);
node_stack.push(0);
@ -142,13 +159,15 @@ impl<KP: KdTreePoint> KdTree<KP> {
}
}
pub fn has_neighbor_in_range(&self, node: &KP, range: f64) -> bool {
#[inline]
pub fn has_neighbor_in_range(&self, node: &KP, range: F) -> bool {
let squared_range = range * range;
self.distance_squared_to_nearest(node) <= squared_range
}
pub fn distance_squared_to_nearest(&self, node: &KP) -> f64 {
#[inline]
pub fn distance_squared_to_nearest(&self, node: &KP) -> F {
self.nearest_search(node).0
}
@ -165,7 +184,7 @@ impl<KP: KdTreePoint> KdTree<KP> {
pub fn insert_node(&mut self, node_to_add: KP) {
let mut current_index = 0;
let dimension = self.node_adding_dimension;
let dims = node_to_add.dims().to_vec();
let dims = node_to_add.to_vec();
let index_of_new_node = self.add_node(node_to_add, dimension,dims[dimension]);
self.node_adding_dimension = (dimension + 1) % dims.len();
@ -211,16 +230,16 @@ impl<KP: KdTreePoint> KdTree<KP> {
self.nodes.pop();
}
if self.node_depth_during_last_rebuild as f64 * 4.0 < depth as f64 {
if F::from(self.node_depth_during_last_rebuild).unwrap() * F::from(4.0).unwrap() < F::from(depth).unwrap() {
self.gather_points_and_rebuild();
}
}
fn nearest_search_impl(&self, p: &KP, searched_index: usize, best_distance_squared: &mut f64, best_leaf_found: &mut usize) {
fn nearest_search_impl(&self, p: &KP, searched_index: usize, best_distance_squared: &mut F, best_leaf_found: &mut usize) {
let node = &self.nodes[searched_index];
let splitting_value = node.split_on;
let point_splitting_dim_value = p.dims()[node.dimension];
let point_splitting_dim_value = p.dim(node.dimension);
let (closer_node, farther_node) = if point_splitting_dim_value <= splitting_value {
(node.left_node, node.right_node)
@ -247,16 +266,16 @@ impl<KP: KdTreePoint> KdTree<KP> {
}
}
fn add_node(&mut self, p: KP, dimension: usize, split_on: f64) -> usize {
fn add_node(&mut self, p: KP, dimension: usize, split_on: F) -> usize {
let node = KdTreeNode::new(p, dimension, split_on);
self.nodes.push(node);
self.nodes.len() - 1
}
fn build_tree(&mut self, nodes: &mut [KP], bounds: &Bounds, depth : usize) -> usize {
fn build_tree(&mut self, nodes: &mut [KP], bounds: &Bounds<F>, depth : usize) -> usize {
let splitting_index = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim());
let pivot_value = nodes[splitting_index].dims()[bounds.get_widest_dim()];
let pivot_value = nodes[splitting_index].dim(bounds.get_widest_dim());
let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), pivot_value);
let nodes_len = nodes.len();
@ -287,17 +306,17 @@ impl<KP: KdTreePoint> KdTree<KP> {
}
}
pub struct KdTreeNode<T> {
pub struct KdTreeNode<F: Float, T> {
left_node: Option<usize>,
right_node: Option<usize>,
point: T,
dimension: usize,
split_on: f64
split_on: F
}
impl<T: KdTreePoint> KdTreeNode<T> {
fn new(p: T, splitting_dimension: usize, split_on_value: f64) -> KdTreeNode<T> {
impl<F: Float, T: KdTreePoint<F>> KdTreeNode<F, T> {
fn new(p: T, splitting_dimension: usize, split_on_value: F) -> KdTreeNode<F, T> {
KdTreeNode {
left_node: None,
right_node: None,
@ -401,8 +420,8 @@ mod tests {
assert_eq!(tree.nodes[0].dimension, 0);
assert_eq!(tree.nodes[0].left_node.is_some(), true);
assert_eq!(tree.nodes[1].point.dims()[0], 1.);
assert_eq!(tree.nodes[2].point.dims()[0], -1.);
assert_eq!(tree.nodes[1].point.dim(0), 1.);
assert_eq!(tree.nodes[2].point.dim(0), -1.);
assert_eq!(tree.nodes[0].right_node.is_some(), true);
}

View File

@ -1,3 +1,4 @@
use num_traits::Float;
use crate::kdtree::KdTreePoint;
enum PointsWereOnSide {
@ -11,9 +12,9 @@ struct PartitionPointHelper {
index_of_splitter: usize,
}
fn partition_sliding_midpoint_helper<T: KdTreePoint>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> PartitionPointHelper {
fn partition_sliding_midpoint_helper<F: Float, T: KdTreePoint<F>>(vec: &mut [T], midpoint_value: F, partition_on_dimension: usize) -> PartitionPointHelper {
let mut closest_index = 0;
let mut closest_distance = (vec[0].dims()[partition_on_dimension] - midpoint_value).abs();
let mut closest_distance = (vec[0].dim(partition_on_dimension) - midpoint_value).abs();
const HAS_POINTS_ON_LEFT_SIDE: i32 = 0b01;
const HAS_POINTS_ON_RIGHT_SIDE: i32 = 0b10;
@ -22,13 +23,13 @@ fn partition_sliding_midpoint_helper<T: KdTreePoint>(vec: &mut [T], midpoint_val
for i in 0..vec.len() {
let p = vec.get(i).unwrap();
if p.dims()[partition_on_dimension] <= midpoint_value {
if p.dim(partition_on_dimension) <= midpoint_value {
has_points_on_sides |= HAS_POINTS_ON_LEFT_SIDE;
} else {
has_points_on_sides |= HAS_POINTS_ON_RIGHT_SIDE;
}
let dist = (p.dims()[partition_on_dimension] - midpoint_value).abs();
let dist = (p.dim(partition_on_dimension) - midpoint_value).abs();
if dist < closest_distance {
closest_distance = dist;
@ -48,9 +49,9 @@ fn partition_sliding_midpoint_helper<T: KdTreePoint>(vec: &mut [T], midpoint_val
}
}
pub fn partition_sliding_midpoint<T: KdTreePoint>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> usize {
pub fn partition_sliding_midpoint<F: Float, T: KdTreePoint<F>>(vec: &mut [T], midpoint_value: F, partition_on_dimension: usize) -> usize {
let vec_len = vec.len();
debug_assert!(vec[0].dims().len() > partition_on_dimension);
debug_assert!(vec[0].dims() > partition_on_dimension);
if vec.len() == 1 {
return 0;
@ -74,12 +75,12 @@ pub fn partition_sliding_midpoint<T: KdTreePoint>(vec: &mut [T], midpoint_value:
}
}
fn partition_kdtree<T: KdTreePoint>(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize {
fn partition_kdtree<F: Float, T: KdTreePoint<F>>(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize {
if vec.len() == 1 {
return 0;
}
let pivot = vec[index_of_splitting_point].dims()[partition_on_dimension];
let pivot = vec[index_of_splitting_point].dim(partition_on_dimension);
let vec_len = vec.len();
vec.swap(index_of_splitting_point, vec_len - 1);
@ -90,11 +91,11 @@ fn partition_kdtree<T: KdTreePoint>(vec: &mut [T], index_of_splitting_point: usi
//variant of Lomuto algo.
loop {
while left <= right && vec[left].dims()[partition_on_dimension] <= pivot {
while left <= right && vec[left].dim(partition_on_dimension) <= pivot {
left += 1;
}
while right > left && vec[right].dims()[partition_on_dimension] > pivot {
while right > left && vec[right].dim(partition_on_dimension) > pivot {
right -= 1;
}
@ -109,10 +110,10 @@ fn partition_kdtree<T: KdTreePoint>(vec: &mut [T], index_of_splitting_point: usi
}
}
if last_succesful_swap == vec_len - 1 && vec[right].dims()[partition_on_dimension] > pivot {
if last_succesful_swap == vec_len - 1 && vec[right].dim(partition_on_dimension) > pivot {
vec.swap(right, last_succesful_swap);
last_succesful_swap = right;
} else if vec[left].dims()[partition_on_dimension] > pivot {
} else if vec[left].dim(partition_on_dimension) > pivot {
vec.swap(left, vec_len - 1);
last_succesful_swap = left;
} else {
@ -228,16 +229,16 @@ mod tests {
}
fn assert_partition(v: &Vec<Point1WithId>, index_of_splitting_point: usize) -> bool {
let pivot = v[index_of_splitting_point].dims()[0];
let pivot = v[index_of_splitting_point].dim(0);
for i in 0..index_of_splitting_point {
if v[i].dims()[0] > pivot {
if v[i].dim(0) > pivot {
return false;
}
}
for i in index_of_splitting_point + 1..v.len() {
if v[i].dims()[0] < pivot {
if v[i].dim(0) < pivot {
return false;
}
}

View File

@ -15,10 +15,15 @@ impl Point3WithId {
}
}
impl KdTreePoint for Point3WithId {
impl KdTreePoint<f64> for Point3WithId {
#[inline]
fn dims(&self) -> &[f64] {
return &self.dims;
fn dims(&self) -> usize {
self.dims.len()
}
#[inline]
fn dim(&self, i: usize) -> f64 {
self.dims[i]
}
}
@ -37,10 +42,15 @@ impl Point2WithId {
}
}
impl KdTreePoint for Point2WithId {
impl KdTreePoint<f64> for Point2WithId {
#[inline]
fn dims(&self) -> &[f64] {
return &self.dims;
fn dims(&self) -> usize {
self.dims.len()
}
#[inline]
fn dim(&self, i: usize) -> f64 {
self.dims[i]
}
}
@ -59,9 +69,14 @@ impl Point1WithId {
}
}
impl KdTreePoint for Point1WithId {
impl KdTreePoint<f64> for Point1WithId {
#[inline]
fn dims(&self) -> &[f64] {
return &self.dims;
fn dims(&self) -> usize {
self.dims.len()
}
#[inline]
fn dim(&self, i: usize) -> f64 {
self.dims[i]
}
}

View File

@ -7,5 +7,3 @@ extern crate rand;
pub mod kdtree;
#[cfg(test)]
mod bench;

View File

@ -1,22 +1,19 @@
extern crate kdtree;
extern crate rand;
use rand::Rng;
use kdtree::kdtree::test_common::*;
use kdtree::kdtree::KdTreePoint;
use kdtree::kdtree::distance::squared_euclidean;
use kdtree::kdtree::KdTree;
fn gen_random() -> f64 {
rand::thread_rng().gen_range(0., 1000.)
}
fn find_nn_with_linear_search(points : &Vec<Point3WithId>, find_for : Point3WithId) -> (f64, &Point3WithId) {
let mut best_found_distance = squared_euclidean(find_for.dims(), points[0].dims());
let mut best_found_distance = find_for.dist(&points[0]);
let mut closed_found_point = &points[0];
for p in points {
let dist = squared_euclidean(find_for.dims(), p.dims());
let dist = find_for.dist(p);
if dist < best_found_distance {
best_found_distance = dist;
@ -31,7 +28,7 @@ fn find_neigbours_with_linear_search(points : &Vec<Point3WithId>, find_for : Poi
let mut result = Vec::new();
for p in points {
let d = squared_euclidean(find_for.dims(), p.dims());
let d = find_for.dist(p);
if d <= dist {
result.push((d, p));
@ -56,9 +53,9 @@ fn generate_points(point_count : usize) -> Vec<Point3WithId> {
fn test_against_1000_random_points() {
let point_count = 1000usize;
let points = generate_points(point_count);
kdtree::kdtree::test_common::Point1WithId::new(0,0.);
Point1WithId::new(0,0.);
let tree = kdtree::kdtree::KdTree::new(&mut points.clone());
let tree = KdTree::new(&mut points.clone());
//test points pushed into the tree, id should be equal.
for i in 0 .. point_count {
@ -83,8 +80,8 @@ fn test_incrementally_build_tree_against_built_at_once() {
let point_count = 2000usize;
let mut points = generate_points(point_count);
let tree_built_at_once = kdtree::kdtree::KdTree::new(&mut points.clone());
let mut tree_built_incrementally = kdtree::kdtree::KdTree::new(&mut points[0..1]);
let tree_built_at_once = KdTree::new(&mut points.clone());
let mut tree_built_incrementally = KdTree::new(&mut points[0..1]);
for i in 1 .. point_count {
let p = &points[i];
@ -113,7 +110,7 @@ fn test_incrementally_build_tree_against_built_at_once() {
fn test_neighbour_search_with_distance() {
let point_count = 1000usize;
let points = generate_points(point_count);
let tree = kdtree::kdtree::KdTree::new(&mut points.clone());
let tree = KdTree::new(&mut points.clone());
for _ in 0 .. 500 {
let dist = 100.0;