Support for disallowed assignments.

Add support for disallowed assignments, and include a check to avoid
attempting to solve unsolvable matrices (where a row only has disallowed
values).

For now only allow `f32`/`f64`, as we can use `INFINITY` to encode a
disallowed assignment.

This is a re-implementation of most of bmc/munkres#20, but is missing
the changes to `step6`, which attempt to detect when we're not making
progress. I couldn't implement this without regressing performance and I
am not sure they add much given we're checking `is_solvable()` before
attempting to solve.
This commit is contained in:
Michael Killough 2017-11-07 12:12:20 +00:00
parent ad0d85a2c4
commit d3e2c23415
2 changed files with 90 additions and 6 deletions

View File

@ -21,7 +21,9 @@ extern crate ndarray;
#[cfg(test)]
extern crate test;
use std::{f32, f64};
use std::ops::{Add, Sub};
use coverage::Coverage;
use mark_matrix::MarkMatrix;
pub use weight_matrix::WeightMatrix;
@ -36,6 +38,7 @@ pub mod weight_matrix;
pub trait WeightNum: PartialOrd + Copy + Sub<Output=Self> + Add<Output=Self> {
fn is_zero(&self) -> bool;
fn is_disallowed(&self) -> bool { false }
}
impl WeightNum for usize { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
@ -48,8 +51,28 @@ impl WeightNum for u16 { #[inline(always)] fn is_zero(&self) -> bool { *self ==
impl WeightNum for i16 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
impl WeightNum for u8 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
impl WeightNum for i8 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
impl WeightNum for f64 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0.0 } }
impl WeightNum for f32 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0.0 } }
impl WeightNum for f64 {
#[inline(always)]
fn is_zero(&self) -> bool {
*self == 0.0
}
fn is_disallowed(&self) -> bool {
*self == f64::INFINITY
}
}
impl WeightNum for f32 {
#[inline(always)]
fn is_zero(&self) -> bool {
*self == 0.0
}
fn is_disallowed(&self) -> bool {
*self == f32::INFINITY
}
}
pub trait Weights {
type T: WeightNum;
@ -64,6 +87,8 @@ pub trait Weights {
fn is_element_zero(&self, pos: (usize, usize)) -> bool {
self.element_at(pos).is_zero()
}
fn is_solvable(&self) -> bool;
}
#[derive(Debug, Eq, PartialEq)]
@ -275,6 +300,10 @@ fn step6<W>(c: &mut W, cov: &Coverage) -> Step
pub fn solve_assignment<W>(weights: &mut W) -> Vec<(usize, usize)>
where W: Weights
{
if !weights.is_solvable() {
panic!("Matrix can not be solved");
}
let n = weights.n();
let mut marks = MarkMatrix::new(n);
@ -719,6 +748,34 @@ fn test_solve_random10() {
assert_eq!(exp, &matching[..]);
}
#[test]
fn test_disallowed() {
let c = vec![
250.0, 400.0, 350.0,
400.0, 600.0, f32::INFINITY,
200.0, 400.0, 250.0
];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(3, c);
let matching = solve_assignment(&mut weights);
assert_eq!(vec![(0, 1), (1, 0), (2, 2)], matching);
}
#[test]
#[should_panic]
fn test_unsolvable() {
const N: usize = 3;
let c = vec![
1.0, 1.0, 1.0,
f32::INFINITY, f32::INFINITY, f32::INFINITY,
1.0, 1.0, 1.0,
];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
solve_assignment(&mut weights);
}
#[cfg(test)]
fn gen_matrix(n: usize) -> Vec<i32> {
(0..n * n)

View File

@ -31,12 +31,33 @@ impl<T: WeightNum> Weights for WeightMatrix<T> {
// Add `val` to every element in row `row`.
fn add_row(&mut self, row: usize, val: T) {
self.c.row_mut(row).mapv_inplace(|cur| cur + val);
self.c.row_mut(row).mapv_inplace(|cur| {
if cur.is_disallowed() {
cur
} else {
cur + val
}
});
}
// Subtract `val` from every element in column `col`.
fn sub_col(&mut self, col: usize, val: T) {
self.c.column_mut(col).mapv_inplace(|cur| cur - val);
self.c.column_mut(col).mapv_inplace(|cur| {
if cur.is_disallowed() {
cur
} else {
cur - val
}
});
}
fn is_solvable(&self) -> bool {
for row in 0..self.n() {
if self.c.row(row).iter().all(|c| c.is_disallowed()) {
return false;
}
}
true
}
}
@ -55,7 +76,7 @@ impl<T: WeightNum> WeightMatrix<T> {
let row_slice = self.c.row(row);
let mut min = row_slice[0];
for &val in row_slice.iter().skip(1) {
if val < min {
if !val.is_disallowed() && val < min {
min = val;
}
}
@ -64,7 +85,13 @@ impl<T: WeightNum> WeightMatrix<T> {
// Subtract `val` from every element in row `row`.
fn sub_row(&mut self, row: usize, val: T) {
self.c.row_mut(row).mapv_inplace(|cur| cur - val);
self.c.row_mut(row).mapv_inplace(|cur| {
if cur.is_disallowed() {
cur
} else {
cur - val
}
});
}
pub fn as_slice(&self) -> &[T] {