Add neighbour_search_with_distance

This commit is contained in:
Andrey Tkachenko 2018-12-07 20:35:51 +04:00
parent d9769850d8
commit caa2e5cbf8
6 changed files with 171 additions and 65 deletions

View File

@ -26,7 +26,7 @@ fn bench_creating_1000_node_tree(b: &mut Bencher) {
let points = generate_points(len);
b.iter(|| {
kdtree::kdtree::Kdtree::new(&mut points.clone());
kdtree::kdtree::KdTree::new(&mut points.clone());
});
}
@ -34,7 +34,7 @@ fn bench_single_loop_times_for_1000_node_tree(b: &mut Bencher) {
let len = 1000usize;
let points = generate_points(len);
let tree = kdtree::kdtree::Kdtree::new(&mut points.clone());
let tree = kdtree::kdtree::KdTree::new(&mut points.clone());
b.iter(|| tree.nearest_search(&points[0]));
@ -46,14 +46,14 @@ fn bench_creating_1000_000_node_tree(b: &mut Bencher) {
let points = generate_points(len);
b.iter(|| {
kdtree::kdtree::Kdtree::new(&mut points.clone());
kdtree::kdtree::KdTree::new(&mut points.clone());
});
}
fn bench_adding_same_node_to_1000_tree(b: &mut Bencher) {
let len = 1000usize;
let mut points = generate_points(len);
let mut tree = kdtree::kdtree::Kdtree::new(&mut points);
let mut tree = kdtree::kdtree::KdTree::new(&mut points);
let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random());
b.iter(|| {
@ -65,7 +65,7 @@ fn bench_incrementally_building_the_1000_tree(b: &mut Bencher) {
b.iter(|| {
let len = 1usize;
let mut points = generate_points(len);
let mut tree = kdtree::kdtree::Kdtree::new(&mut points);
let mut tree = kdtree::kdtree::KdTree::new(&mut points);
for _ in 0 .. 1000 {
let point = Point3WithId::new(-1 as i32, gen_random(), gen_random(), gen_random());
tree.insert_node(point);

View File

@ -1,5 +1,6 @@
use ::kdtree::*;
#[derive(Clone, Copy)]
pub struct Bounds {
pub bounds: [(f64, f64); 3],
@ -8,7 +9,7 @@ pub struct Bounds {
}
impl Bounds {
pub fn new_from_points<T: KdtreePointTrait>(points: &[T]) -> Bounds {
pub fn new_from_points<T: KdTreePoint>(points: &[T]) -> Bounds {
let mut bounds = Bounds {
bounds: [(0., 0.), (0., 0.), (0., 0.)],
widest_dim: 0,
@ -45,6 +46,7 @@ impl Bounds {
bounds: self.bounds.clone(),
..*self
};
cloned.bounds[dimension].1 = value;
cloned.calculate_variables();

View File

@ -9,27 +9,82 @@ use self::distance::*;
use std::cmp;
pub trait KdtreePointTrait: Copy + PartialEq {
fn dims(&self) -> &[f64];
pub trait KdTreePoint: Copy + PartialEq {
fn dist_1d(left: f64, right: f64, dim: usize) -> f64 {
let diff = left - right;
diff * diff
}
pub struct Kdtree<KdtreePoint> {
nodes: Vec<KdtreeNode<KdtreePoint>>,
fn dims(&self) -> &[f64];
fn dist(&self, other: &Self) -> f64 {
squared_euclidean(self.dims(), other.dims())
}
}
/// Iterator over the points stored in a [`KdTree`] that lie within `range`
/// of a reference point.
///
/// The traversal is driven by an explicit stack of node indices, so points are
/// produced one at a time as the iterator is advanced.
pub struct NearestNeighboursIter<'a, 'b, T> {
/// Inclusive threshold. It is compared directly against the values returned
/// by `T::dist_1d` and `T::dist`, so it must be expressed in the same unit
/// as those functions (squared distance for the default implementations).
range: f64,
/// Tree being searched; yielded references borrow from it.
kdtree: &'a KdTree<T>,
/// Reference (query) point the distances are measured from.
ref_node: &'b T,
/// Indices into `kdtree.nodes` of the nodes that still have to be visited.
/// Used as a LIFO stack.
node_stack: Vec<usize>,
}
/// Walks the tree depth-first and yields every stored point whose `dist` to
/// the reference point is at most `range`.
///
/// For each visited node the 1-D distance between the reference point and the
/// node's splitting value decides how the search continues:
///
/// * within `range`: both children are scheduled and the node's own point is
///   tested against `range`;
/// * otherwise: the node's own point is not tested and only the child on the
///   reference point's side of the split is scheduled (values equal to the
///   splitting value go left).
impl<'a, 'b, T> Iterator for NearestNeighboursIter<'a, 'b, T>
where
    T: KdTreePoint,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        let query = self.ref_node;
        let tree = self.kdtree;

        while let Some(index) = self.node_stack.pop() {
            let current = &tree.nodes[index];
            let query_coord = query.dims()[current.dimension];
            let plane_gap = T::dist_1d(current.split_on, query_coord, current.dimension);

            if plane_gap <= self.range {
                // The splitting plane is within reach, so matches may exist on
                // either side. Left is pushed first, hence right is visited first.
                self.node_stack.extend(current.left_node);
                self.node_stack.extend(current.right_node);

                if query.dist(&current.point) <= self.range {
                    return Some(&current.point);
                }
                continue;
            }

            // The plane is out of reach: only the half containing the query
            // can hold matches.
            let near_child = if query_coord <= current.split_on {
                current.left_node
            } else {
                current.right_node
            };
            self.node_stack.extend(near_child);
        }

        // Stack exhausted: nothing left to visit.
        None
    }
}
/// Kd-tree whose nodes are stored contiguously in a single vector.
///
/// Nodes refer to their children by index into `nodes`; the searches in this
/// file start from index 0.
pub struct KdTree<KP> {
/// Storage for every node of the tree.
nodes: Vec<KdTreeNode<KP>>,
/// Dimension used as the splitting dimension for the next node added by
/// `insert_node`.
node_adding_dimension: usize,
/// Reset to 0 at the start of `rebuild_tree`.
// NOTE(review): presumably the tree depth recorded by the last rebuild — confirm against `build_tree`.
node_depth_during_last_rebuild: usize,
/// Reset to 0 at the start of `rebuild_tree`.
// NOTE(review): presumably the current maximum depth, updated on insertion — confirm against `insert_node`.
current_node_depth: usize,
}
impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
pub fn new(mut points: &mut [KdtreePoint]) -> Kdtree<KdtreePoint> {
impl<KP: KdTreePoint> KdTree<KP> {
pub fn new(mut points: &mut [KP]) -> Self {
if points.len() == 0 {
panic!("empty vector point not allowed");
}
let mut tree = Kdtree {
let mut tree = KdTree {
nodes: vec![],
node_adding_dimension: 0,
node_depth_during_last_rebuild: 0,
@ -41,8 +96,9 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
tree
}
pub fn rebuild_tree(&mut self, points : &mut [KdtreePoint]) {
pub fn rebuild_tree(&mut self, points: &mut [KP]) {
self.nodes.clear();
self.nodes.reserve(points.len());
self.node_depth_during_last_rebuild = 0;
self.current_node_depth = 0;
@ -53,41 +109,52 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
/// Can be used if you are sure that the tree is degenerated or if you will never again insert the nodes into the tree.
pub fn gather_points_and_rebuild(&mut self) {
let mut points : Vec<KdtreePoint> = vec![];
let mut points : Vec<KP> = vec![];
self.gather_points(0,&mut points);
self.rebuild_tree(&mut points);
}
pub fn nearest_search(&self, node: &KdtreePoint) -> KdtreePoint
{
pub fn nearest_search(&self, node: &KP) -> KP {
let mut nearest_neighbor = 0usize;
let mut best_distance = squared_euclidean(node.dims(), &self.nodes[0].point.dims());
let mut best_distance = self.nodes[0].point.dist(&node);
self.nearest_search_impl(node, 0usize, &mut best_distance, &mut nearest_neighbor);
self.nodes[nearest_neighbor].point
}
pub fn has_neighbor_in_range(&self, node: &KdtreePoint, range: f64) -> bool {
/// Returns a lazy iterator over all stored points whose `dist` to `node` is
/// at most `dist`.
///
/// `dist` is compared directly against `KdTreePoint::dist` /
/// `KdTreePoint::dist_1d` results, so with the default implementations it
/// must be a squared distance.
pub fn nearest_search_dist<'a, 'b>(&'a self, node: &'b KP, dist: f64) -> NearestNeighboursIter<'a, 'b, KP> {
    // The traversal starts at node index 0.
    let mut pending = Vec::with_capacity(16);
    pending.push(0);

    NearestNeighboursIter {
        node_stack: pending,
        ref_node: node,
        kdtree: self,
        range: dist,
    }
}
/// Returns `true` when the point nearest to `node` is no farther than `range`.
///
/// `range` is squared before the comparison because the nearest distance is
/// obtained from `distance_squared_to_nearest`.
pub fn has_neighbor_in_range(&self, node: &KP, range: f64) -> bool {
    let nearest_squared = self.distance_squared_to_nearest(node);
    nearest_squared <= range * range
}
pub fn distance_squared_to_nearest(&self, node: &KdtreePoint) -> f64 {
squared_euclidean(&self.nearest_search(node).dims(), node.dims())
pub fn distance_squared_to_nearest(&self, node: &KP) -> f64 {
self.nearest_search(node).dist(&node)
}
pub fn insert_nodes_and_rebuild(&mut self, nodes_to_add : &mut [KdtreePoint]) {
let mut pts : Vec<KdtreePoint> = vec![];
/// Adds `nodes_to_add` to the tree by rebuilding it.
///
/// The points gathered from the current tree (starting at node 0) are
/// combined with the new ones and handed to `rebuild_tree`.
pub fn insert_nodes_and_rebuild(&mut self, nodes_to_add : &mut [KP]) {
    let mut all_points = Vec::new();
    self.gather_points(0, &mut all_points);
    // Existing points first, new points appended after them.
    all_points.extend_from_slice(nodes_to_add);
    self.rebuild_tree(&mut all_points);
}
pub fn insert_node(&mut self, node_to_add : KdtreePoint) {
pub fn insert_node(&mut self, node_to_add : KP) {
let mut current_index = 0;
let dimension = self.node_adding_dimension;
let index_of_new_node = self.add_node(node_to_add, dimension,node_to_add.dims()[dimension]);
@ -96,7 +163,6 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
let mut depth = 0;
loop {
depth +=1 ;
let current_node = &mut self.nodes[current_index];
@ -134,7 +200,7 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
}
}
fn nearest_search_impl(&self, p: &KdtreePoint, searched_index: usize, best_distance_squared: &mut f64, best_leaf_found: &mut usize) {
fn nearest_search_impl(&self, p: &KP, searched_index: usize, best_distance_squared: &mut f64, best_leaf_found: &mut usize) {
let node = &self.nodes[searched_index];
let splitting_value = node.split_on;
@ -150,14 +216,14 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
self.nearest_search_impl(p, closer_node, best_distance_squared, best_leaf_found);
}
let distance = squared_euclidean(p.dims(), node.point.dims());
let distance = p.dist(&node.point);
if distance < *best_distance_squared {
*best_distance_squared = distance;
*best_leaf_found = searched_index;
}
if let Some(farther_node) = farther_node {
let distance_on_single_dimension = squared_euclidean(&[splitting_value], &[point_splitting_dim_value]);
let distance_on_single_dimension = KP::dist_1d(splitting_value, point_splitting_dim_value, node.dimension);
if distance_on_single_dimension <= *best_distance_squared {
self.nearest_search_impl(p, farther_node, best_distance_squared, best_leaf_found);
@ -165,14 +231,14 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
}
}
fn add_node(&mut self, p: KdtreePoint, dimension: usize, split_on: f64) -> usize {
let node = KdtreeNode::new(p, dimension, split_on);
/// Appends a new node for `p` to the node storage and returns its index.
///
/// `dimension` and `split_on` are forwarded unchanged to `KdTreeNode::new`.
fn add_node(&mut self, p: KP, dimension: usize, split_on: f64) -> usize {
    // The new node lands at the current end of the vector.
    let new_index = self.nodes.len();
    self.nodes.push(KdTreeNode::new(p, dimension, split_on));
    new_index
}
fn build_tree(&mut self, nodes: &mut [KdtreePoint], bounds: &Bounds, depth : usize) -> usize {
fn build_tree(&mut self, nodes: &mut [KP], bounds: &Bounds, depth : usize) -> usize {
let splitting_index = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim());
let pivot_value = nodes[splitting_index].dims()[bounds.get_widest_dim()];
@ -197,8 +263,9 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
node_id
}
fn gather_points(&self, current_index: usize, points : &mut Vec<KdtreePoint>){
fn gather_points(&self, current_index: usize, points : &mut Vec<KP>){
points.push(self.nodes[current_index].point);
if let Some(left_index) = self.nodes[current_index].left_node {
self.gather_points(left_index, points);
}
@ -209,7 +276,7 @@ impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
}
}
pub struct KdtreeNode<T> {
pub struct KdTreeNode<T> {
left_node: Option<usize>,
right_node: Option<usize>,
@ -218,9 +285,9 @@ pub struct KdtreeNode<T> {
split_on: f64
}
impl<T: KdtreePointTrait> KdtreeNode<T> {
fn new(p: T, splitting_dimension: usize, split_on_value: f64) -> KdtreeNode<T> {
KdtreeNode {
impl<T: KdTreePoint> KdTreeNode<T> {
fn new(p: T, splitting_dimension: usize, split_on_value: f64) -> KdTreeNode<T> {
KdTreeNode {
left_node: None,
right_node: None,
@ -242,7 +309,7 @@ mod tests {
fn should_panic_given_empty_vector() {
let mut empty_vec: Vec<Point2WithId> = vec![];
Kdtree::new(&mut empty_vec);
KdTree::new(&mut empty_vec);
}
quickcheck! {
@ -257,7 +324,7 @@ mod tests {
vec.push(p);
}
let tree = Kdtree::new(&mut qc_value_vec_to_2d_points_vec(&xs));
let tree = KdTree::new(&mut qc_value_vec_to_2d_points_vec(&xs));
let mut to_iterate : Vec<usize> = vec![];
to_iterate.push(0);
@ -284,7 +351,7 @@ mod tests {
}
let point_vec = qc_value_vec_to_2d_points_vec(&xs);
let tree = Kdtree::new(&mut point_vec.clone());
let tree = KdTree::new(&mut point_vec.clone());
for p in &point_vec {
let found_nn = tree.nearest_search(p);
@ -300,7 +367,7 @@ mod tests {
fn has_neighbor_in_range() {
let mut vec: Vec<Point2WithId> = vec![Point2WithId::new(0,2.,0.)];
let tree = Kdtree::new(&mut vec);
let tree = KdTree::new(&mut vec);
assert_eq!(false,tree.has_neighbor_in_range(&Point2WithId::new(0,0.,0.), 0.));
assert_eq!(false,tree.has_neighbor_in_range(&Point2WithId::new(0,0.,0.), 1.));
@ -314,7 +381,7 @@ mod tests {
let mut vec = vec![Point2WithId::new(0,0.,0.)];
let mut tree = Kdtree::new(&mut vec);
let mut tree = KdTree::new(&mut vec);
tree.insert_node(Point2WithId::new(0,1.,0.));
tree.insert_node(Point2WithId::new(0,-1.,0.));
@ -333,7 +400,7 @@ mod tests {
fn incremental_add_filters_duplicates() {
let mut vec = vec![Point2WithId::new(0,0.,0.)];
let mut tree = Kdtree::new(&mut vec);
let mut tree = KdTree::new(&mut vec);
let node = Point2WithId::new(0,1.,0.);
tree.insert_node(node);

View File

@ -11,7 +11,7 @@ struct PartitionPointHelper {
index_of_splitter: usize,
}
fn partition_sliding_midpoint_helper<T: KdtreePointTrait>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> PartitionPointHelper {
fn partition_sliding_midpoint_helper<T: KdTreePoint>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> PartitionPointHelper {
let mut closest_index = 0;
let mut closest_distance = (vec[0].dims()[partition_on_dimension] - midpoint_value).abs();
@ -48,7 +48,7 @@ fn partition_sliding_midpoint_helper<T: KdtreePointTrait>(vec: &mut [T], midpoin
}
}
pub fn partition_sliding_midpoint<T: KdtreePointTrait>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> usize {
pub fn partition_sliding_midpoint<T: KdTreePoint>(vec: &mut [T], midpoint_value: f64, partition_on_dimension: usize) -> usize {
let vec_len = vec.len();
debug_assert!(vec[0].dims().len() > partition_on_dimension);
@ -74,7 +74,7 @@ pub fn partition_sliding_midpoint<T: KdtreePointTrait>(vec: &mut [T], midpoint_v
}
}
fn partition_kdtree<T: KdtreePointTrait>(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize {
fn partition_kdtree<T: KdTreePoint>(vec: &mut [T], index_of_splitting_point: usize, partition_on_dimension: usize) -> usize {
if vec.len() == 1 {
return 0;
}

View File

@ -1,6 +1,6 @@
use super::KdtreePointTrait;
use super::KdTreePoint;
#[derive(Copy, Clone, PartialEq)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Point3WithId {
dims: [f64; 3],
pub id: i32,
@ -15,14 +15,14 @@ impl Point3WithId {
}
}
impl KdtreePointTrait for Point3WithId {
impl KdTreePoint for Point3WithId {
#[inline]
fn dims(&self) -> &[f64] {
return &self.dims;
}
}
#[derive(Copy, Clone, PartialEq)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Point2WithId {
dims: [f64; 2],
pub id: i32,
@ -37,14 +37,14 @@ impl Point2WithId {
}
}
impl KdtreePointTrait for Point2WithId {
impl KdTreePoint for Point2WithId {
#[inline]
fn dims(&self) -> &[f64] {
return &self.dims;
}
}
#[derive(Copy, Clone, PartialEq)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Point1WithId {
dims: [f64; 1],
pub id: i32,
@ -59,7 +59,7 @@ impl Point1WithId {
}
}
impl KdtreePointTrait for Point1WithId {
impl KdTreePoint for Point1WithId {
#[inline]
fn dims(&self) -> &[f64] {
return &self.dims;

View File

@ -4,21 +4,19 @@ extern crate rand;
use rand::Rng;
use kdtree::kdtree::test_common::*;
use kdtree::kdtree::KdtreePointTrait;
use kdtree::kdtree::KdTreePoint;
use kdtree::kdtree::distance::squared_euclidean;
fn gen_random() -> f64 {
rand::thread_rng().gen_range(0., 10000.)
rand::thread_rng().gen_range(0., 1000.)
}
fn find_nn_with_linear_search(points : &Vec<Point3WithId>, find_for : Point3WithId) -> &Point3WithId {
let distance_fun = kdtree::kdtree::distance::squared_euclidean;
let mut best_found_distance = distance_fun(find_for.dims(), points[0].dims());
let mut best_found_distance = squared_euclidean(find_for.dims(), points[0].dims());
let mut closed_found_point = &points[0];
for p in points {
let dist = distance_fun(find_for.dims(), p.dims());
let dist = squared_euclidean(find_for.dims(), p.dims());
if dist < best_found_distance {
best_found_distance = dist;
@ -29,6 +27,20 @@ fn find_nn_with_linear_search(points : &Vec<Point3WithId>, find_for : Point3With
closed_found_point
}
/// Brute-force reference implementation of a range query.
///
/// Returns references to every element of `points` whose squared Euclidean
/// distance to `find_for` is at most `dist` (inclusive), in the order they
/// appear in `points`.
///
/// `points` is taken as a slice so callers are not forced to own a `Vec`;
/// passing `&Vec<Point3WithId>` still works through deref coercion.
fn find_neigbours_with_linear_search(points: &[Point3WithId], find_for: Point3WithId, dist: f64) -> Vec<&Point3WithId> {
    points
        .iter()
        .filter(|candidate| squared_euclidean(find_for.dims(), candidate.dims()) <= dist)
        .collect()
}
fn generate_points(point_count : usize) -> Vec<Point3WithId> {
let mut points : Vec<Point3WithId> = vec![];
@ -46,7 +58,7 @@ fn test_against_1000_random_points() {
let points = generate_points(point_count);
kdtree::kdtree::test_common::Point1WithId::new(0,0.);
let tree = kdtree::kdtree::Kdtree::new(&mut points.clone());
let tree = kdtree::kdtree::KdTree::new(&mut points.clone());
//test points pushed into the tree, id should be equal.
for i in 0 .. point_count {
@ -71,8 +83,8 @@ fn test_incrementally_build_tree_against_built_at_once() {
let point_count = 2000usize;
let mut points = generate_points(point_count);
let tree_built_at_once = kdtree::kdtree::Kdtree::new(&mut points.clone());
let mut tree_built_incrementally = kdtree::kdtree::Kdtree::new(&mut points[0..1]);
let tree_built_at_once = kdtree::kdtree::KdTree::new(&mut points.clone());
let mut tree_built_incrementally = kdtree::kdtree::KdTree::new(&mut points[0..1]);
for i in 1 .. point_count {
let p = &points[i];
@ -95,3 +107,28 @@ fn test_incrementally_build_tree_against_built_at_once() {
assert_eq!(tree_built_at_once.nearest_search(&p).id, tree_built_incrementally.nearest_search(&p).id);
}
}
/// Cross-checks `KdTree::nearest_search_dist` against a brute-force scan for
/// 500 random query points over a tree of 1000 random points.
#[test]
fn test_neighbour_search_with_distance() {
    let point_count = 1000usize;
    let points = generate_points(point_count);
    let tree = kdtree::kdtree::KdTree::new(&mut points.clone());

    // Both searches compare squared distances, so the radius is squared once up front.
    let radius = 100.0;
    let squared_radius = radius * radius;

    for _ in 0..500 {
        let query = Point3WithId::new(0i32, gen_random(), gen_random(), gen_random());

        let mut expected = find_neigbours_with_linear_search(&points, query, squared_radius);
        let mut actual: Vec<_> = tree.nearest_search_dist(&query, squared_radius).collect();

        assert_eq!(expected.len(), actual.len());

        // The two searches visit points in different orders; normalise by id before comparing.
        expected.sort_by_key(|point| point.id);
        actual.sort_by_key(|point| point.id);

        assert_eq!(actual, expected);
    }
}