Support for disallowed assignments.
Add support for disallowed assignments, and include a check to avoid attempting to solve unsolvable matrices (where a row only has disallowed values). For now only allow `f32`/`f64`, as we can use `INFINITY` to encode a disallowed assignment. This is a re-implementation of most of bmc/munkres#20, but is missing the changes to `step6`, which attempt to detect when we're not making progress. I couldn't implement this without regressing performance and I am not sure they add much given we're checking `is_solvable()` before attempting to solve.
This commit is contained in:
parent
ad0d85a2c4
commit
d3e2c23415
61
src/lib.rs
61
src/lib.rs
@ -21,7 +21,9 @@ extern crate ndarray;
|
||||
#[cfg(test)]
|
||||
extern crate test;
|
||||
|
||||
use std::{f32, f64};
|
||||
use std::ops::{Add, Sub};
|
||||
|
||||
use coverage::Coverage;
|
||||
use mark_matrix::MarkMatrix;
|
||||
pub use weight_matrix::WeightMatrix;
|
||||
@ -36,6 +38,7 @@ pub mod weight_matrix;
|
||||
|
||||
pub trait WeightNum: PartialOrd + Copy + Sub<Output=Self> + Add<Output=Self> {
|
||||
fn is_zero(&self) -> bool;
|
||||
fn is_disallowed(&self) -> bool { false }
|
||||
}
|
||||
|
||||
impl WeightNum for usize { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
|
||||
@ -48,8 +51,28 @@ impl WeightNum for u16 { #[inline(always)] fn is_zero(&self) -> bool { *self ==
|
||||
impl WeightNum for i16 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
|
||||
impl WeightNum for u8 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
|
||||
impl WeightNum for i8 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0 } }
|
||||
impl WeightNum for f64 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0.0 } }
|
||||
impl WeightNum for f32 { #[inline(always)] fn is_zero(&self) -> bool { *self == 0.0 } }
|
||||
|
||||
impl WeightNum for f64 {
|
||||
#[inline(always)]
|
||||
fn is_zero(&self) -> bool {
|
||||
*self == 0.0
|
||||
}
|
||||
|
||||
fn is_disallowed(&self) -> bool {
|
||||
*self == f64::INFINITY
|
||||
}
|
||||
}
|
||||
|
||||
impl WeightNum for f32 {
|
||||
#[inline(always)]
|
||||
fn is_zero(&self) -> bool {
|
||||
*self == 0.0
|
||||
}
|
||||
|
||||
fn is_disallowed(&self) -> bool {
|
||||
*self == f32::INFINITY
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Weights {
|
||||
type T: WeightNum;
|
||||
@ -64,6 +87,8 @@ pub trait Weights {
|
||||
fn is_element_zero(&self, pos: (usize, usize)) -> bool {
|
||||
self.element_at(pos).is_zero()
|
||||
}
|
||||
|
||||
fn is_solvable(&self) -> bool;
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
@ -275,6 +300,10 @@ fn step6<W>(c: &mut W, cov: &Coverage) -> Step
|
||||
pub fn solve_assignment<W>(weights: &mut W) -> Vec<(usize, usize)>
|
||||
where W: Weights
|
||||
{
|
||||
if !weights.is_solvable() {
|
||||
panic!("Matrix can not be solved");
|
||||
}
|
||||
|
||||
let n = weights.n();
|
||||
|
||||
let mut marks = MarkMatrix::new(n);
|
||||
@ -719,6 +748,34 @@ fn test_solve_random10() {
|
||||
assert_eq!(exp, &matching[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_disallowed() {
|
||||
let c = vec![
|
||||
250.0, 400.0, 350.0,
|
||||
400.0, 600.0, f32::INFINITY,
|
||||
200.0, 400.0, 250.0
|
||||
];
|
||||
|
||||
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(3, c);
|
||||
let matching = solve_assignment(&mut weights);
|
||||
|
||||
assert_eq!(vec![(0, 1), (1, 0), (2, 2)], matching);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_unsolvable() {
|
||||
const N: usize = 3;
|
||||
let c = vec![
|
||||
1.0, 1.0, 1.0,
|
||||
f32::INFINITY, f32::INFINITY, f32::INFINITY,
|
||||
1.0, 1.0, 1.0,
|
||||
];
|
||||
|
||||
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
|
||||
solve_assignment(&mut weights);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn gen_matrix(n: usize) -> Vec<i32> {
|
||||
(0..n * n)
|
||||
|
@ -31,12 +31,33 @@ impl<T: WeightNum> Weights for WeightMatrix<T> {
|
||||
|
||||
// Add `val` to every element in row `row`.
|
||||
fn add_row(&mut self, row: usize, val: T) {
|
||||
self.c.row_mut(row).mapv_inplace(|cur| cur + val);
|
||||
self.c.row_mut(row).mapv_inplace(|cur| {
|
||||
if cur.is_disallowed() {
|
||||
cur
|
||||
} else {
|
||||
cur + val
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Subtract `val` from every element in column `col`.
|
||||
fn sub_col(&mut self, col: usize, val: T) {
|
||||
self.c.column_mut(col).mapv_inplace(|cur| cur - val);
|
||||
self.c.column_mut(col).mapv_inplace(|cur| {
|
||||
if cur.is_disallowed() {
|
||||
cur
|
||||
} else {
|
||||
cur - val
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn is_solvable(&self) -> bool {
|
||||
for row in 0..self.n() {
|
||||
if self.c.row(row).iter().all(|c| c.is_disallowed()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,7 +76,7 @@ impl<T: WeightNum> WeightMatrix<T> {
|
||||
let row_slice = self.c.row(row);
|
||||
let mut min = row_slice[0];
|
||||
for &val in row_slice.iter().skip(1) {
|
||||
if val < min {
|
||||
if !val.is_disallowed() && val < min {
|
||||
min = val;
|
||||
}
|
||||
}
|
||||
@ -64,7 +85,13 @@ impl<T: WeightNum> WeightMatrix<T> {
|
||||
|
||||
// Subtract `val` from every element in row `row`.
|
||||
fn sub_row(&mut self, row: usize, val: T) {
|
||||
self.c.row_mut(row).mapv_inplace(|cur| cur - val);
|
||||
self.c.row_mut(row).mapv_inplace(|cur| {
|
||||
if cur.is_disallowed() {
|
||||
cur
|
||||
} else {
|
||||
cur - val
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn as_slice(&self) -> &[T] {
|
||||
|
Loading…
Reference in New Issue
Block a user