Fix unsolvable bug, add more tests
If a column contains only invalid (e.g. infinite) weights, we cannot solve that matrix. This resulted in an infinite loop.
This commit is contained in:
parent
8f04199cc0
commit
7e07dd5941
67
src/lib.rs
67
src/lib.rs
@ -688,23 +688,78 @@ fn test_solve_random10() {
|
|||||||
fn test_invalid() {
|
fn test_invalid() {
|
||||||
use std::f32;
|
use std::f32;
|
||||||
const N: usize = 3;
|
const N: usize = 3;
|
||||||
|
|
||||||
let c = vec![
|
let c = vec![
|
||||||
|
// row 0
|
||||||
250.0,
|
250.0,
|
||||||
400.0,
|
400.0, // (a)
|
||||||
350.0,
|
350.0,
|
||||||
400.0,
|
// row 1
|
||||||
|
400.0, // (b)
|
||||||
600.0,
|
600.0,
|
||||||
f32::INFINITY,
|
f32::INFINITY,
|
||||||
|
// row 2
|
||||||
200.0,
|
200.0,
|
||||||
400.0,
|
400.0,
|
||||||
|
250.0, // (c)
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
|
||||||
|
let matching = solve_assignment(&mut weights).unwrap();
|
||||||
|
assert_eq!(N, matching.len());
|
||||||
|
assert_eq!(1050.0, calc_cost(0.0, &c[..], &matching[..], N));
|
||||||
|
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_be_unsolvable_if_first_cell_of_each_row_is_invalid() {
|
||||||
|
use std::f32;
|
||||||
|
const N: usize = 3;
|
||||||
|
|
||||||
|
let c = vec![
|
||||||
|
// row 0
|
||||||
|
f32::INFINITY,
|
||||||
|
400.0,
|
||||||
|
350.0,
|
||||||
|
// row 1
|
||||||
|
f32::INFINITY,
|
||||||
|
400.0,
|
||||||
|
600.0,
|
||||||
|
// row 2
|
||||||
|
f32::INFINITY,
|
||||||
|
400.0,
|
||||||
250.0,
|
250.0,
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c);
|
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
|
||||||
let matching = solve_assignment(&mut weights).unwrap();
|
let res = solve_assignment(&mut weights);
|
||||||
|
assert_eq!(Err(Error::MatrixNotSolvable), res);
|
||||||
|
}
|
||||||
|
|
||||||
assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching);
|
#[test]
|
||||||
|
fn test_invalid2() {
|
||||||
|
use std::f32;
|
||||||
|
const N: usize = 3;
|
||||||
|
|
||||||
|
let c = vec![
|
||||||
|
// row 0
|
||||||
|
f32::INFINITY,
|
||||||
|
400.0,
|
||||||
|
350.0, // (a)
|
||||||
|
// row 1
|
||||||
|
f32::INFINITY,
|
||||||
|
400.0, // (b)
|
||||||
|
600.0,
|
||||||
|
// row 2
|
||||||
|
400.0, // (c)
|
||||||
|
250.0,
|
||||||
|
f32::INFINITY,
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut weights: WeightMatrix<f32> = WeightMatrix::from_row_vec(N, c.clone());
|
||||||
|
let matching = solve_assignment(&mut weights).unwrap();
|
||||||
|
assert_eq!(N, matching.len());
|
||||||
|
assert_eq!(1150.0, calc_cost(0.0, &c[..], &matching[..], N));
|
||||||
|
assert_eq!(vec![pos(0, 2), pos(1, 1), pos(2, 0)], matching);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -42,8 +42,13 @@ impl<T: WeightNum> Weights for WeightMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_solvable(&self) -> bool {
|
fn is_solvable(&self) -> bool {
|
||||||
for row in 0..self.n() {
|
for row in self.c.genrows() {
|
||||||
if self.c.row(row).iter().all(|c| !c.is_valid()) {
|
if row.iter().all(|c| !c.is_valid()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for column in self.c.gencolumns() {
|
||||||
|
if column.iter().all(|c| !c.is_valid()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -67,14 +72,13 @@ impl<T: WeightNum> WeightMatrix<T> {
|
|||||||
|
|
||||||
/// Return the minimum element of row `row`.
|
/// Return the minimum element of row `row`.
|
||||||
fn min_of_row(&self, row: usize) -> T {
|
fn min_of_row(&self, row: usize) -> T {
|
||||||
let row_slice = self.c.row(row);
|
let row_iter = self.c.row(row);
|
||||||
let mut min = row_slice[0];
|
let mut valid_iter = row_iter.iter().filter(|cost| cost.is_valid()).cloned();
|
||||||
for &val in row_slice.iter().skip(1) {
|
let first_min = valid_iter.next().unwrap();
|
||||||
if val.is_valid() && val < min {
|
valid_iter.fold(
|
||||||
min = val;
|
first_min,
|
||||||
}
|
|total_min, val| if val < total_min { val } else { total_min },
|
||||||
}
|
)
|
||||||
min
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subtract `val` from every element in row `row`.
|
// Subtract `val` from every element in row `row`.
|
||||||
|
Loading…
Reference in New Issue
Block a user