diff --git a/src/lib.rs b/src/lib.rs index e48a8e4..00fefb2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ // * Reuse path Vec in step5 // * Cleanup // * More test cases +// * Non-square matrices extern crate bit_vec; @@ -38,12 +39,14 @@ pub struct WeightMatrix { impl WeightMatrix { pub fn from_row_vec(n: usize, data: Vec) -> WeightMatrix { + assert!(n > 0); WeightMatrix{c: SquareMatrix::from_row_vec(n, data)} } #[inline(always)] fn n(&self) -> usize { self.c.n() } + #[inline(always)] fn is_element_zero(&self, pos: (usize, usize)) -> bool { self.c[pos] == T::zero() } @@ -59,23 +62,17 @@ impl WeightMatrix { // Subtract `val` from every element in row `row`. fn sub_row(&mut self, row: usize, val: T) { - for col in 0 .. self.n() { - self.c[(row, col)] = self.c[(row, col)] - val; - } + self.c.map_row(row, |cur| cur - val); } // Subtract `val` from every element in column `col`. fn sub_col(&mut self, col: usize, val: T) { - for row in 0 .. self.n() { - self.c[(row, col)] = self.c[(row, col)] - val; - } + self.c.map_col(col, |cur| cur - val); } // Add `val` to every element in row `row`. fn add_row(&mut self, row: usize, val: T) { - for col in 0 .. self.n() { - self.c[(row, col)] = self.c[(row, col)] - val; - } + self.c.map_row(row, |cur| cur + val); } /// Find the first uncovered element with value 0 `find_a_zero` diff --git a/src/square_matrix.rs b/src/square_matrix.rs index a6efa34..3ca3f68 100644 --- a/src/square_matrix.rs +++ b/src/square_matrix.rs @@ -34,12 +34,30 @@ impl IndexMut<(usize, usize)> for SquareMatrix { } } -impl SquareMatrix { +impl SquareMatrix { pub fn from_row_vec(n: usize, data: Vec) -> SquareMatrix { + assert!(n > 0); assert!(data.len() == n*n); SquareMatrix {n: n, data: data} } #[inline(always)] pub fn n(&self) -> usize { self.n } pub fn into_vec(self) -> Vec { self.data } + + #[inline(always)] + pub fn map_row T>(&mut self, row: usize, f: F) { + for col in 0..self.n { + let n = f(self[(row, col)]); + self[(row, col)] = n; + } + } + + #[inline(always)] + pub fn map_col T>(&mut self, col: usize, f: F) { + for row in 0..self.n { + let n = f(self[(row, col)]); + self[(row, col)] = n; + } + } + }