diff --git a/src/lib.rs b/src/lib.rs index 9356db2..b6fca92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -688,23 +688,78 @@ fn test_solve_random10() { fn test_invalid() { use std::f32; const N: usize = 3; - let c = vec![ + // row 0 250.0, - 400.0, + 400.0, // (a) 350.0, - 400.0, + // row 1 + 400.0, // (b) 600.0, f32::INFINITY, + // row 2 200.0, 400.0, + 250.0, // (c) + ]; + + let mut weights: WeightMatrix = WeightMatrix::from_row_vec(N, c.clone()); + let matching = solve_assignment(&mut weights).unwrap(); + assert_eq!(N, matching.len()); + assert_eq!(1050.0, calc_cost(0.0, &c[..], &matching[..], N)); + assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching); +} + +#[test] +fn test_should_be_unsolvable_if_first_cell_of_each_row_is_invalid() { + use std::f32; + const N: usize = 3; + + let c = vec![ + // row 0 + f32::INFINITY, + 400.0, + 350.0, + // row 1 + f32::INFINITY, + 400.0, + 600.0, + // row 2 + f32::INFINITY, + 400.0, 250.0, ]; - let mut weights: WeightMatrix = WeightMatrix::from_row_vec(N, c); - let matching = solve_assignment(&mut weights).unwrap(); + let mut weights: WeightMatrix = WeightMatrix::from_row_vec(N, c.clone()); + let res = solve_assignment(&mut weights); + assert_eq!(Err(Error::MatrixNotSolvable), res); +} - assert_eq!(vec![pos(0, 1), pos(1, 0), pos(2, 2)], matching); +#[test] +fn test_invalid2() { + use std::f32; + const N: usize = 3; + + let c = vec![ + // row 0 + f32::INFINITY, + 400.0, + 350.0, // (a) + // row 1 + f32::INFINITY, + 400.0, // (b) + 600.0, + // row 2 + 400.0, // (c) + 250.0, + f32::INFINITY, + ]; + + let mut weights: WeightMatrix = WeightMatrix::from_row_vec(N, c.clone()); + let matching = solve_assignment(&mut weights).unwrap(); + assert_eq!(N, matching.len()); + assert_eq!(1150.0, calc_cost(0.0, &c[..], &matching[..], N)); + assert_eq!(vec![pos(0, 2), pos(1, 1), pos(2, 0)], matching); } #[test] diff --git a/src/weight_matrix.rs b/src/weight_matrix.rs index 1580b2e..ff58641 100644 --- a/src/weight_matrix.rs +++ b/src/weight_matrix.rs @@ -42,8 +42,13 @@ impl Weights for WeightMatrix { } fn is_solvable(&self) -> bool { - for row in 0..self.n() { - if self.c.row(row).iter().all(|c| !c.is_valid()) { + for row in self.c.genrows() { + if row.iter().all(|c| !c.is_valid()) { + return false; + } + } + for column in self.c.gencolumns() { + if column.iter().all(|c| !c.is_valid()) { return false; } } @@ -67,14 +72,13 @@ impl WeightMatrix { /// Return the minimum element of row `row`. fn min_of_row(&self, row: usize) -> T { - let row_slice = self.c.row(row); - let mut min = row_slice[0]; - for &val in row_slice.iter().skip(1) { - if val.is_valid() && val < min { - min = val; - } - } - min + let row_iter = self.c.row(row); + let mut valid_iter = row_iter.iter().filter(|cost| cost.is_valid()).cloned(); + let first_min = valid_iter.next().unwrap(); + valid_iter.fold( + first_min, + |total_min, val| if val < total_min { val } else { total_min }, + ) } // Subtract `val` from every element in row `row`.