Fix unsolvable bug, add more tests

If a column contains only invalid (e.g. infinite) weights, we cannot
solve that matrix. This resulted in an infinite loop.
This commit is contained in:
Michael Neumann 2019-04-07 16:22:04 +02:00
parent 8f04199cc0
commit 7e07dd5941
2 changed files with 75 additions and 16 deletions

View File

@ -688,23 +688,78 @@ fn test_solve_random10() {
fn test_invalid() {
use std::f32;
const N: usize = 3;
let c = vec![
// row 0
250.0,
400.0,
400.0, // (a)
350.0,
400.0,
// row 1
400.0, // (b)
600.0,
f32::INFINITY,
// row 2
200.0,
400.0,
250.0, // (c)
];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
let matching = solve_assignment(&mut weights).unwrap();
assert_eq!(N, matching.len());
assert_eq!(1050.0, calc_cost(0.0, &c[..], &matching[..], N));
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching);
}
#[test]
fn test_should_be_unsolvable_if_first_cell_of_each_row_is_invalid() {
use std::f32;
const N: usize = 3;
let c = vec![
// row 0
f32::INFINITY,
400.0,
350.0,
// row 1
f32::INFINITY,
400.0,
600.0,
// row 2
f32::INFINITY,
400.0,
250.0,
];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c);
let matching = solve_assignment(&mut weights).unwrap();
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
let res = solve_assignment(&mut weights);
assert_eq!(Err(Error::MatrixNotSolvable), res);
}
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching);
#[test]
fn test_invalid2() {
use std::f32;
const N: usize = 3;
let c = vec![
// row 0
f32::INFINITY,
400.0,
350.0, // (a)
// row 1
f32::INFINITY,
400.0, // (b)
600.0,
// row 2
400.0, // (c)
250.0,
f32::INFINITY,
];
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
let matching = solve_assignment(&mut weights).unwrap();
assert_eq!(N, matching.len());
assert_eq!(1150.0, calc_cost(0.0, &c[..], &matching[..], N));
assert_eq!(vec![pos(0, 2), pos(1, 1), pos(2, 0)], matching);
}
#[test]

View File

@ -42,8 +42,13 @@ impl<T: WeightNum> Weights for WeightMatrix<T> {
}
fn is_solvable(&self) -> bool {
for row in 0..self.n() {
if self.c.row(row).iter().all(|c| !c.is_valid()) {
for row in self.c.genrows() {
if row.iter().all(|c| !c.is_valid()) {
return false;
}
}
for column in self.c.gencolumns() {
if column.iter().all(|c| !c.is_valid()) {
return false;
}
}
@ -67,14 +72,13 @@ impl<T: WeightNum> WeightMatrix<T> {
/// Return the minimum element of row `row`.
fn min_of_row(&self, row: usize) -> T {
let row_slice = self.c.row(row);
let mut min = row_slice[0];
for &val in row_slice.iter().skip(1) {
if val.is_valid() && val < min {
min = val;
}
}
min
let row_iter = self.c.row(row);
let mut valid_iter = row_iter.iter().filter(|cost| cost.is_valid()).cloned();
let first_min = valid_iter.next().unwrap();
valid_iter.fold(
first_min,
|total_min, val| if val < total_min { val } else { total_min },
)
}
// Subtract `val` from every element in row `row`.