Fix unsolvable bug, add more tests

If a column contains only invalid (e.g. infinite) weights, we cannot
solve that matrix. This resulted in an infinite loop.
This commit is contained in:
Michael Neumann 2019-04-07 16:22:04 +02:00
parent 8f04199cc0
commit 7e07dd5941
2 changed files with 75 additions and 16 deletions

View File

@ -688,23 +688,78 @@ fn test_solve_random10() {
fn test_invalid() { fn test_invalid() {
use std::f32; use std::f32;
const N: usize = 3; const N: usize = 3;
let c = vec![ let c = vec![
// row 0
250.0, 250.0,
400.0, 400.0, // (a)
350.0, 350.0,
400.0, // row 1
400.0, // (b)
600.0, 600.0,
f32::INFINITY, f32::INFINITY,
// row 2
200.0, 200.0,
400.0, 400.0,
250.0, // (c)
];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
let matching = solve_assignment(&mut weights).unwrap();
assert_eq!(N, matching.len());
assert_eq!(1050.0, calc_cost(0.0, &c[..], &matching[..], N));
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching);
}
#[test]
fn test_should_be_unsolvable_if_first_cell_of_each_row_is_invalid() {
use std::f32;
const N: usize = 3;
let c = vec![
// row 0
f32::INFINITY,
400.0,
350.0,
// row 1
f32::INFINITY,
400.0,
600.0,
// row 2
f32::INFINITY,
400.0,
250.0, 250.0,
]; ];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c); let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
let matching = solve_assignment(&mut weights).unwrap(); let res = solve_assignment(&mut weights);
assert_eq!(Err(Error::MatrixNotSolvable), res);
}
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching); #[test]
fn test_invalid2() {
use std::f32;
const N: usize = 3;
let c = vec![
// row 0
f32::INFINITY,
400.0,
350.0, // (a)
// row 1
f32::INFINITY,
400.0, // (b)
600.0,
// row 2
400.0, // (c)
250.0,
f32::INFINITY,
];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
let matching = solve_assignment(&mut weights).unwrap();
assert_eq!(N, matching.len());
assert_eq!(1150.0, calc_cost(0.0, &c[..], &matching[..], N));
assert_eq!(vec![pos(0, 2), pos(1, 1), pos(2, 0)], matching);
} }
#[test] #[test]

View File

@ -42,8 +42,13 @@ impl<T: WeightNum> Weights for WeightMatrix<T> {
} }
fn is_solvable(&self) -> bool { fn is_solvable(&self) -> bool {
for row in 0..self.n() { for row in self.c.genrows() {
if self.c.row(row).iter().all(|c| !c.is_valid()) { if row.iter().all(|c| !c.is_valid()) {
return false;
}
}
for column in self.c.gencolumns() {
if column.iter().all(|c| !c.is_valid()) {
return false; return false;
} }
} }
@ -67,14 +72,13 @@ impl<T: WeightNum> WeightMatrix<T> {
/// Return the minimum element of row `row`. /// Return the minimum element of row `row`.
fn min_of_row(&self, row: usize) -> T { fn min_of_row(&self, row: usize) -> T {
let row_slice = self.c.row(row); let row_iter = self.c.row(row);
let mut min = row_slice[0]; let mut valid_iter = row_iter.iter().filter(|cost| cost.is_valid()).cloned();
for &val in row_slice.iter().skip(1) { let first_min = valid_iter.next().unwrap();
if val.is_valid() && val < min { valid_iter.fold(
min = val; first_min,
} |total_min, val| if val < total_min { val } else { total_min },
} )
min
} }
// Subtract `val` from every element in row `row`. // Subtract `val` from every element in row `row`.