Fix unsolvable bug, add more tests
If a column contains only invalid (e.g. infinite) weights, we cannot solve that matrix. This resulted in an infinite loop.
This commit is contained in:
parent
8f04199cc0
commit
7e07dd5941
67
src/lib.rs
67
src/lib.rs
@ -688,23 +688,78 @@ fn test_solve_random10() {
|
||||
fn test_invalid() {
|
||||
use std::f32;
|
||||
const N: usize = 3;
|
||||
|
||||
let c = vec![
|
||||
// row 0
|
||||
250.0,
|
||||
400.0,
|
||||
400.0, // (a)
|
||||
350.0,
|
||||
400.0,
|
||||
// row 1
|
||||
400.0, // (b)
|
||||
600.0,
|
||||
f32::INFINITY,
|
||||
// row 2
|
||||
200.0,
|
||||
400.0,
|
||||
250.0, // (c)
|
||||
];
|
||||
|
||||
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
|
||||
let matching = solve_assignment(&mut weights).unwrap();
|
||||
assert_eq!(N, matching.len());
|
||||
assert_eq!(1050.0, calc_cost(0.0, &c[..], &matching[..], N));
|
||||
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_be_unsolvable_if_first_cell_of_each_row_is_invalid() {
|
||||
use std::f32;
|
||||
const N: usize = 3;
|
||||
|
||||
let c = vec![
|
||||
// row 0
|
||||
f32::INFINITY,
|
||||
400.0,
|
||||
350.0,
|
||||
// row 1
|
||||
f32::INFINITY,
|
||||
400.0,
|
||||
600.0,
|
||||
// row 2
|
||||
f32::INFINITY,
|
||||
400.0,
|
||||
250.0,
|
||||
];
|
||||
|
||||
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c);
|
||||
let matching = solve_assignment(&mut weights).unwrap();
|
||||
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
|
||||
let res = solve_assignment(&mut weights);
|
||||
assert_eq!(Err(Error::MatrixNotSolvable), res);
|
||||
}
|
||||
|
||||
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching);
|
||||
#[test]
|
||||
fn test_invalid2() {
|
||||
use std::f32;
|
||||
const N: usize = 3;
|
||||
|
||||
let c = vec![
|
||||
// row 0
|
||||
f32::INFINITY,
|
||||
400.0,
|
||||
350.0, // (a)
|
||||
// row 1
|
||||
f32::INFINITY,
|
||||
400.0, // (b)
|
||||
600.0,
|
||||
// row 2
|
||||
400.0, // (c)
|
||||
250.0,
|
||||
f32::INFINITY,
|
||||
];
|
||||
|
||||
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
|
||||
let matching = solve_assignment(&mut weights).unwrap();
|
||||
assert_eq!(N, matching.len());
|
||||
assert_eq!(1150.0, calc_cost(0.0, &c[..], &matching[..], N));
|
||||
assert_eq!(vec![pos(0, 2), pos(1, 1), pos(2, 0)], matching);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -42,8 +42,13 @@ impl<T: WeightNum> Weights for WeightMatrix<T> {
|
||||
}
|
||||
|
||||
fn is_solvable(&self) -> bool {
|
||||
for row in 0..self.n() {
|
||||
if self.c.row(row).iter().all(|c| !c.is_valid()) {
|
||||
for row in self.c.genrows() {
|
||||
if row.iter().all(|c| !c.is_valid()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for column in self.c.gencolumns() {
|
||||
if column.iter().all(|c| !c.is_valid()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -67,14 +72,13 @@ impl<T: WeightNum> WeightMatrix<T> {
|
||||
|
||||
/// Return the minimum element of row `row`.
|
||||
fn min_of_row(&self, row: usize) -> T {
|
||||
let row_slice = self.c.row(row);
|
||||
let mut min = row_slice[0];
|
||||
for &val in row_slice.iter().skip(1) {
|
||||
if val.is_valid() && val < min {
|
||||
min = val;
|
||||
}
|
||||
}
|
||||
min
|
||||
let row_iter = self.c.row(row);
|
||||
let mut valid_iter = row_iter.iter().filter(|cost| cost.is_valid()).cloned();
|
||||
let first_min = valid_iter.next().unwrap();
|
||||
valid_iter.fold(
|
||||
first_min,
|
||||
|total_min, val| if val < total_min { val } else { total_min },
|
||||
)
|
||||
}
|
||||
|
||||
// Subtract `val` from every element in row `row`.
|
||||
|
Loading…
Reference in New Issue
Block a user