mirror of
https://github.com/andreytkachenko/kdtree-rust.git
synced 2024-11-22 09:26:25 +04:00
Added kdtree-lookup
This commit is contained in:
parent
7a6734cd4d
commit
9ac0a899e2
40
src/kdtree/distance.rs
Normal file
40
src/kdtree/distance.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use ::kdtree::KdtreePointTrait;
|
||||
|
||||
pub fn squared_euclidean(a : &[f64], b: &[f64]) -> f64 {
|
||||
debug_assert!(a.len() == b.len());
|
||||
|
||||
a.iter().zip(b.iter())
|
||||
.map(|(x,y)| (x - y) * (x-y))
|
||||
.sum()
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn squared_euclidean_test_1d() {
|
||||
let a = [2.];
|
||||
let b = [4.];
|
||||
let c = [-2.];
|
||||
|
||||
assert_eq!(0., squared_euclidean(&a,&a));
|
||||
|
||||
assert_eq!(4., squared_euclidean(&a,&b));
|
||||
|
||||
assert_eq!(16., squared_euclidean(&a,&c));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn squared_euclidean_test_2d() {
|
||||
let a = [2.,2.];
|
||||
let b = [4.,2.];
|
||||
let c = [4.,4.];
|
||||
|
||||
assert_eq!(0., squared_euclidean(&a,&a));
|
||||
|
||||
assert_eq!(4., squared_euclidean(&a,&b));
|
||||
|
||||
assert_eq!(8., squared_euclidean(&a,&c));
|
||||
}
|
||||
}
|
@ -1,9 +1,10 @@
|
||||
mod test_common;
|
||||
mod partition;
|
||||
mod bounds;
|
||||
mod distance;
|
||||
|
||||
use ::std::cmp;
|
||||
use self::bounds::*;
|
||||
use self::distance::*;
|
||||
|
||||
pub trait KdtreePointTrait {
|
||||
fn dims(&self) -> &[f64];
|
||||
@ -30,8 +31,56 @@ impl<T: KdtreePointTrait + Copy> Kdtree<T> {
|
||||
tree
|
||||
}
|
||||
|
||||
fn add_node(&mut self, p: T) -> usize {
|
||||
let node = KdtreeNode::new(p);
|
||||
pub fn nearest_search(&self, node : &T) -> T
|
||||
{
|
||||
let mut nearest_neighbor = 0usize;
|
||||
let mut best_distance = squared_euclidean(node.dims(), &self.nodes[0].point.dims());
|
||||
self.nearest_search_impl(node, 0usize, &mut best_distance , &mut nearest_neighbor);
|
||||
|
||||
self.nodes[nearest_neighbor].point
|
||||
}
|
||||
|
||||
fn nearest_search_impl(&self, p : &T, searched_index: usize, best_distance_squared : &mut f64, best_leaf_found : &mut usize) {
|
||||
let node = &self.nodes[searched_index];
|
||||
|
||||
let dimension = node.dimension;
|
||||
let splitting_value = node.split_on;
|
||||
let point_splitting_dim_value = p.dims()[dimension];
|
||||
|
||||
let mut closer_node : Option<usize>;
|
||||
let mut farther_node : Option<usize>;
|
||||
|
||||
if point_splitting_dim_value <= splitting_value {
|
||||
closer_node = node.left_node;
|
||||
farther_node = node.right_node;
|
||||
} else {
|
||||
closer_node = node.right_node;
|
||||
farther_node = node.left_node;
|
||||
}
|
||||
|
||||
if closer_node.is_some() {
|
||||
self.nearest_search_impl(p, closer_node.unwrap(), best_distance_squared, best_leaf_found);
|
||||
}
|
||||
|
||||
let distance = squared_euclidean(p.dims(), node.point.dims());
|
||||
if distance < *best_distance_squared {
|
||||
*best_distance_squared = distance;
|
||||
*best_leaf_found = searched_index;
|
||||
}
|
||||
|
||||
if(farther_node.is_some()) {
|
||||
let distance_on_single_dimension = squared_euclidean(&[splitting_value],&[point_splitting_dim_value]);
|
||||
|
||||
if distance_on_single_dimension <= *best_distance_squared {
|
||||
self.nearest_search_impl(p, farther_node.unwrap(), best_distance_squared, best_leaf_found);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
fn add_node(&mut self, p: T, dimension : usize, split_on : f64) -> usize {
|
||||
let node = KdtreeNode::new(p, dimension, split_on );
|
||||
|
||||
self.nodes.push(node);
|
||||
self.nodes.len() - 1
|
||||
@ -40,7 +89,7 @@ impl<T: KdtreePointTrait + Copy> Kdtree<T> {
|
||||
fn build_tree(&mut self, nodes: &mut [T], bounds: &Bounds) -> usize {
|
||||
let (splitting_index, pivot_value) = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim());
|
||||
|
||||
let node_id = self.add_node(nodes[splitting_index]);
|
||||
let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), bounds.get_midvalue_of_widest_dim());
|
||||
let nodes_len = nodes.len();
|
||||
|
||||
if splitting_index > 0 {
|
||||
@ -65,15 +114,19 @@ pub struct KdtreeNode<T> {
|
||||
right_node: Option<usize>,
|
||||
|
||||
point: T,
|
||||
dimension: usize,
|
||||
split_on: f64
|
||||
}
|
||||
|
||||
impl<T: KdtreePointTrait> KdtreeNode<T> {
|
||||
fn new(p: T) -> KdtreeNode<T> {
|
||||
fn new(p: T, splitting_dimension: usize, split_on_value : f64) -> KdtreeNode<T> {
|
||||
KdtreeNode {
|
||||
left_node: None,
|
||||
right_node: None,
|
||||
|
||||
point: p,
|
||||
dimension : splitting_dimension,
|
||||
split_on : split_on_value
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -82,6 +135,7 @@ impl<T: KdtreePointTrait> KdtreeNode<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ::kdtree::test_common::tests_utils::Point2WithId;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@ -104,13 +158,11 @@ mod tests {
|
||||
vec.push(p);
|
||||
}
|
||||
|
||||
let tree = Kdtree::new(vec);
|
||||
let tree = Kdtree::new(qc_value_vec_to_2d_points_vec(&xs));
|
||||
|
||||
let mut to_iterate : Vec<usize> = vec![];
|
||||
to_iterate.push(0);
|
||||
|
||||
let mut str = String::new();
|
||||
|
||||
while to_iterate.len() > 0 {
|
||||
let last_index = to_iterate.last().unwrap().clone();
|
||||
let ref x = tree.nodes.get(last_index).unwrap();
|
||||
@ -121,15 +173,46 @@ mod tests {
|
||||
if x.right_node.is_some() {
|
||||
to_iterate.push(x.right_node.unwrap());
|
||||
}
|
||||
|
||||
str.push_str(&format!("Index: {} has ln {} has rn {} \n", last_index, x.left_node.is_some(), x.right_node.is_some()));
|
||||
|
||||
|
||||
}
|
||||
|
||||
// println!("str is: {}", str);
|
||||
|
||||
xs.len() == tree.nodes.len()
|
||||
}
|
||||
}
|
||||
|
||||
quickcheck! {
|
||||
fn nearest_neighbor_search_using_qc(xs : Vec<f64>) -> bool {
|
||||
if(xs.len() == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
let point_vec = qc_value_vec_to_2d_points_vec(&xs);
|
||||
let tree = Kdtree::new(point_vec.clone());
|
||||
|
||||
for p in &point_vec {
|
||||
let found_nn = tree.nearest_search(p);
|
||||
|
||||
assert_eq!(p.id,found_nn.id);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn qc_value_vec_to_2d_points_vec(xs : &Vec<f64>) -> Vec<Point2WithId> {
|
||||
let mut vec : Vec<Point2WithId> = vec![];
|
||||
for i in 0 .. xs.len() {
|
||||
let mut is_duplicated_value = false;
|
||||
for j in 0 .. i {
|
||||
if xs[i] == xs[j] {
|
||||
is_duplicated_value = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !is_duplicated_value {
|
||||
let p = Point2WithId::new(i as i32, xs[i], xs[i]);
|
||||
vec.push(p);
|
||||
}
|
||||
}
|
||||
|
||||
vec
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user