diff --git a/src/weight_matrix.rs b/src/weight_matrix.rs index f2ba33d..233ec85 100644 --- a/src/weight_matrix.rs +++ b/src/weight_matrix.rs @@ -31,25 +31,19 @@ impl Weights for WeightMatrix { fn add_row(&mut self, row: usize, val: T) { self.c .row_mut(row) - .mapv_inplace(|cur| if cur.is_disallowed() { cur } else { cur + val }); + .mapv_inplace(|cur| if cur.is_valid() { cur + val } else { cur }); } // Subtract `val` from every element in column `col`. fn sub_column(&mut self, col: usize, val: T) { - self.c.column_mut(col).mapv_inplace( - |cur| { - if cur.is_disallowed() { - cur - } else { - cur - val - } - }, - ); + self.c + .column_mut(col) + .mapv_inplace(|cur| if cur.is_valid() { cur - val } else { cur }); } fn is_solvable(&self) -> bool { for row in 0..self.n() { - if self.c.row(row).iter().all(|c| c.is_disallowed()) { + if self.c.row(row).iter().all(|c| !c.is_valid()) { return false; } } @@ -76,7 +70,7 @@ impl WeightMatrix { let row_slice = self.c.row(row); let mut min = row_slice[0]; for &val in row_slice.iter().skip(1) { - if !val.is_disallowed() && val < min { + if val.is_valid() && val < min { min = val; } } @@ -87,7 +81,7 @@ impl WeightMatrix { fn sub_row(&mut self, row: usize, val: T) { self.c .row_mut(row) - .mapv_inplace(|cur| if cur.is_disallowed() { cur } else { cur - val }); + .mapv_inplace(|cur| if cur.is_valid() { cur - val } else { cur }); } pub fn as_slice(&self) -> &[T] { diff --git a/src/weight_num.rs b/src/weight_num.rs index 7aef23c..57b37ef 100644 --- a/src/weight_num.rs +++ b/src/weight_num.rs @@ -3,8 +3,8 @@ use std::{f32, f64}; pub trait WeightNum: PartialOrd + Copy + Sub + Add { fn is_zero(&self) -> bool; - fn is_disallowed(&self) -> bool { - false + fn is_valid(&self) -> bool { + true } } @@ -84,8 +84,9 @@ impl WeightNum for f64 { *self == 0.0 } - fn is_disallowed(&self) -> bool { - *self == f64::INFINITY + #[inline(always)] + fn is_valid(&self) -> bool { + self.is_finite() } } @@ -95,7 +96,8 @@ impl WeightNum for f32 { *self == 0.0 } - fn is_disallowed(&self) -> bool { - *self == f32::INFINITY + #[inline(always)] + fn is_valid(&self) -> bool { + self.is_finite() } }