Added kdtree-lookup

This commit is contained in:
Olek 2016-12-26 20:53:39 +01:00
parent 7a6734cd4d
commit 9ac0a899e2
2 changed files with 138 additions and 15 deletions

40
src/kdtree/distance.rs Normal file
View File

@ -0,0 +1,40 @@
use ::kdtree::KdtreePointTrait;
pub fn squared_euclidean(a : &[f64], b: &[f64]) -> f64 {
debug_assert!(a.len() == b.len());
a.iter().zip(b.iter())
.map(|(x,y)| (x - y) * (x-y))
.sum()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn squared_euclidean_test_1d() {
let a = [2.];
let b = [4.];
let c = [-2.];
assert_eq!(0., squared_euclidean(&a,&a));
assert_eq!(4., squared_euclidean(&a,&b));
assert_eq!(16., squared_euclidean(&a,&c));
}
#[test]
fn squared_euclidean_test_2d() {
let a = [2.,2.];
let b = [4.,2.];
let c = [4.,4.];
assert_eq!(0., squared_euclidean(&a,&a));
assert_eq!(4., squared_euclidean(&a,&b));
assert_eq!(8., squared_euclidean(&a,&c));
}
}

View File

@ -1,9 +1,10 @@
mod test_common;
mod partition;
mod bounds;
mod distance;
use ::std::cmp;
use self::bounds::*;
use self::distance::*;
pub trait KdtreePointTrait {
fn dims(&self) -> &[f64];
@ -30,8 +31,56 @@ impl<T: KdtreePointTrait + Copy> Kdtree<T> {
tree
}
fn add_node(&mut self, p: T) -> usize {
let node = KdtreeNode::new(p);
pub fn nearest_search(&self, node : &T) -> T
{
let mut nearest_neighbor = 0usize;
let mut best_distance = squared_euclidean(node.dims(), &self.nodes[0].point.dims());
self.nearest_search_impl(node, 0usize, &mut best_distance , &mut nearest_neighbor);
self.nodes[nearest_neighbor].point
}
fn nearest_search_impl(&self, p : &T, searched_index: usize, best_distance_squared : &mut f64, best_leaf_found : &mut usize) {
let node = &self.nodes[searched_index];
let dimension = node.dimension;
let splitting_value = node.split_on;
let point_splitting_dim_value = p.dims()[dimension];
let mut closer_node : Option<usize>;
let mut farther_node : Option<usize>;
if point_splitting_dim_value <= splitting_value {
closer_node = node.left_node;
farther_node = node.right_node;
} else {
closer_node = node.right_node;
farther_node = node.left_node;
}
if closer_node.is_some() {
self.nearest_search_impl(p, closer_node.unwrap(), best_distance_squared, best_leaf_found);
}
let distance = squared_euclidean(p.dims(), node.point.dims());
if distance < *best_distance_squared {
*best_distance_squared = distance;
*best_leaf_found = searched_index;
}
if(farther_node.is_some()) {
let distance_on_single_dimension = squared_euclidean(&[splitting_value],&[point_splitting_dim_value]);
if distance_on_single_dimension <= *best_distance_squared {
self.nearest_search_impl(p, farther_node.unwrap(), best_distance_squared, best_leaf_found);
}
}
}
fn add_node(&mut self, p: T, dimension : usize, split_on : f64) -> usize {
let node = KdtreeNode::new(p, dimension, split_on );
self.nodes.push(node);
self.nodes.len() - 1
@ -40,7 +89,7 @@ impl<T: KdtreePointTrait + Copy> Kdtree<T> {
fn build_tree(&mut self, nodes: &mut [T], bounds: &Bounds) -> usize {
let (splitting_index, pivot_value) = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim());
let node_id = self.add_node(nodes[splitting_index]);
let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), bounds.get_midvalue_of_widest_dim());
let nodes_len = nodes.len();
if splitting_index > 0 {
@ -65,15 +114,19 @@ pub struct KdtreeNode<T> {
right_node: Option<usize>,
point: T,
dimension: usize,
split_on: f64
}
impl<T: KdtreePointTrait> KdtreeNode<T> {
fn new(p: T) -> KdtreeNode<T> {
fn new(p: T, splitting_dimension: usize, split_on_value : f64) -> KdtreeNode<T> {
KdtreeNode {
left_node: None,
right_node: None,
point: p,
dimension : splitting_dimension,
split_on : split_on_value
}
}
}
@ -82,6 +135,7 @@ impl<T: KdtreePointTrait> KdtreeNode<T> {
#[cfg(test)]
mod tests {
use ::kdtree::test_common::tests_utils::Point2WithId;
use super::*;
#[test]
@ -104,13 +158,11 @@ mod tests {
vec.push(p);
}
let tree = Kdtree::new(vec);
let tree = Kdtree::new(qc_value_vec_to_2d_points_vec(&xs));
let mut to_iterate : Vec<usize> = vec![];
to_iterate.push(0);
let mut str = String::new();
while to_iterate.len() > 0 {
let last_index = to_iterate.last().unwrap().clone();
let ref x = tree.nodes.get(last_index).unwrap();
@ -121,15 +173,46 @@ mod tests {
if x.right_node.is_some() {
to_iterate.push(x.right_node.unwrap());
}
str.push_str(&format!("Index: {} has ln {} has rn {} \n", last_index, x.left_node.is_some(), x.right_node.is_some()));
}
// println!("str is: {}", str);
xs.len() == tree.nodes.len()
}
}
quickcheck! {
fn nearest_neighbor_search_using_qc(xs : Vec<f64>) -> bool {
if(xs.len() == 0) {
return true;
}
let point_vec = qc_value_vec_to_2d_points_vec(&xs);
let tree = Kdtree::new(point_vec.clone());
for p in &point_vec {
let found_nn = tree.nearest_search(p);
assert_eq!(p.id,found_nn.id);
}
true
}
}
fn qc_value_vec_to_2d_points_vec(xs : &Vec<f64>) -> Vec<Point2WithId> {
let mut vec : Vec<Point2WithId> = vec![];
for i in 0 .. xs.len() {
let mut is_duplicated_value = false;
for j in 0 .. i {
if xs[i] == xs[j] {
is_duplicated_value = true;
break;
}
}
if !is_duplicated_value {
let p = Point2WithId::new(i as i32, xs[i], xs[i]);
vec.push(p);
}
}
vec
}
}