From d3e2c23415ff6eea18f5520208d9fc0e60e40ae8 Mon Sep 17 00:00:00 2001 From: Michael Killough Date: Tue, 7 Nov 2017 12:12:20 +0000 Subject: [PATCH] Support for disallowed assignments. Add support for disallowed assignments, and include a check to avoid attempting to solve unsolvable matrices (where a row only has disallowed values). For now only allow `f32`/`f64`, as we can use `INFINITY` to encode a disallowed assignment. This is a re-implementation of most of bmc/munkres#20, but is missing the changes to `step6`, which attempt to detect when we're not making progress. I couldn't implement this without regressing performance and I am not sure they add much given we're checking `is_solvable()` before attempting to solve. --- src/lib.rs | 61 ++++++++++++++++++++++++++++++++++++++++++-- src/weight_matrix.rs | 35 ++++++++++++++++++++++--- 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 50e4dac..afb625c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,9 @@ extern crate ndarray; #[cfg(test)] extern crate test; +use std::{f32, f64}; use std::ops::{Add, Sub}; + use coverage::Coverage; use mark_matrix::MarkMatrix; pub use weight_matrix::WeightMatrix; @@ -36,6 +38,7 @@ pub mod weight_matrix; pub trait WeightNum: PartialOrd + Copy + Sub + Add { fn is_zero(&self) -> bool; + fn is_disallowed(&self) -> bool { false } } impl WeightNum for usize { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } } @@ -48,8 +51,28 @@ impl WeightNum for u16 { #[inline(always)] fn is_zero(&self) -> bool { *self == impl WeightNum for i16 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } } impl WeightNum for u8 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } } impl WeightNum for i8 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } } -impl WeightNum for f64 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0.0 } } -impl WeightNum for f32 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0.0 } } + +impl WeightNum for f64 { + #[inline(always)] + fn is_zero(&self) -> bool { + *self == 0.0 + } + + fn is_disallowed(&self) -> bool { + *self == f64::INFINITY + } +} + +impl WeightNum for f32 { + #[inline(always)] + fn is_zero(&self) -> bool { + *self == 0.0 + } + + fn is_disallowed(&self) -> bool { + *self == f32::INFINITY + } +} pub trait Weights { type T: WeightNum; @@ -64,6 +87,8 @@ pub trait Weights { fn is_element_zero(&self, pos: (usize, usize)) -> bool { self.element_at(pos).is_zero() } + + fn is_solvable(&self) -> bool; } #[derive(Debug, Eq, PartialEq)] @@ -275,6 +300,10 @@ fn step6(c: &mut W, cov: &Coverage) -> Step pub fn solve_assignment(weights: &mut W) -> Vec<(usize, usize)> where W: Weights { + if !weights.is_solvable() { + panic!("Matrix can not be solved"); + } + let n = weights.n(); let mut marks = MarkMatrix::new(n); @@ -719,6 +748,34 @@ fn test_solve_random10() { assert_eq!(exp, &matching[..]); } +#[test] +fn test_disallowed() { + let c = vec![ + 250.0, 400.0, 350.0, + 400.0, 600.0, f32::INFINITY, + 200.0, 400.0, 250.0 + ]; + + let mut weights: WeightMatrix = WeightMatrix::from_row_vec(3, c); + let matching = solve_assignment(&mut weights); + + assert_eq!(vec![(0, 1), (1, 0), (2, 2)], matching); +} + +#[test] +#[should_panic] +fn test_unsolvable() { + const N: usize = 3; + let c = vec![ + 1.0, 1.0, 1.0, + f32::INFINITY, f32::INFINITY, f32::INFINITY, + 1.0, 1.0, 1.0, + ]; + + let mut weights: WeightMatrix = WeightMatrix::from_row_vec(N, c.clone()); + solve_assignment(&mut weights); +} + #[cfg(test)] fn gen_matrix(n: usize) -> Vec { (0..n * n) diff --git a/src/weight_matrix.rs b/src/weight_matrix.rs index c9a1248..39f3499 100644 --- a/src/weight_matrix.rs +++ b/src/weight_matrix.rs @@ -31,12 +31,33 @@ impl Weights for WeightMatrix { // Add `val` to every element in row `row`. fn add_row(&mut self, row: usize, val: T) { - self.c.row_mut(row).mapv_inplace(|cur| cur + val); + self.c.row_mut(row).mapv_inplace(|cur| { + if cur.is_disallowed() { + cur + } else { + cur + val + } + }); } // Subtract `val` from every element in column `col`. fn sub_col(&mut self, col: usize, val: T) { - self.c.column_mut(col).mapv_inplace(|cur| cur - val); + self.c.column_mut(col).mapv_inplace(|cur| { + if cur.is_disallowed() { + cur + } else { + cur - val + } + }); + } + + fn is_solvable(&self) -> bool { + for row in 0..self.n() { + if self.c.row(row).iter().all(|c| c.is_disallowed()) { + return false; + } + } + true } } @@ -55,7 +76,7 @@ impl WeightMatrix { let row_slice = self.c.row(row); let mut min = row_slice[0]; for &val in row_slice.iter().skip(1) { - if val < min { + if !val.is_disallowed() && val < min { min = val; } } @@ -64,7 +85,13 @@ impl WeightMatrix { // Subtract `val` from every element in row `row`. fn sub_row(&mut self, row: usize, val: T) { - self.c.row_mut(row).mapv_inplace(|cur| cur - val); + self.c.row_mut(row).mapv_inplace(|cur| { + if cur.is_disallowed() { + cur + } else { + cur - val + } + }); } pub fn as_slice(&self) -> &[T] {