From ab6267b46c23f7065f14e7bd9dac6f9fea93a630 Mon Sep 17 00:00:00 2001 From: David Wang Date: Wed, 15 Mar 2017 07:53:23 +1100 Subject: [PATCH] Add support for rational coefficients in linear expressions. Coef has been redefined to Rational32. --- src/constraint/equality.rs | 43 +++++++++---------- src/lib.rs | 3 +- src/linexpr.rs | 85 +++++++++++++++++++++++++++----------- 3 files changed, 86 insertions(+), 45 deletions(-) diff --git a/src/constraint/equality.rs b/src/constraint/equality.rs index 0312f4b..1a96f37 100644 --- a/src/constraint/equality.rs +++ b/src/constraint/equality.rs @@ -1,9 +1,10 @@ //! Equality implementation. use std::rc::Rc; +use num_rational::Ratio; +use num_traits::Zero; use ::{Constraint,LinExpr,PsResult,PuzzleSearch,Val,VarToken}; -use intdiv::IntDiv; pub struct Equality { // The equation: 0 = constant + coef1 * var1 + coef2 * var2 + ... @@ -42,7 +43,7 @@ impl Constraint for Equality { for (&var, &coef) in self.eqn.coef.iter() { if let Some(val) = search.get_assigned(var) { - sum += coef * val; + sum = sum + coef * Ratio::from_integer(val); } else { // If we find more than one unassigned variable, // cannot assign any other variables. @@ -58,13 +59,13 @@ impl Constraint for Equality { if let Some((var, coef)) = unassigned_var { // sum + coef * var = 0. let val = -sum / coef; - if sum + coef * val == 0 { - try!(search.set_candidate(var, val)); + if val.is_integer() { + try!(search.set_candidate(var, val.to_integer())); } else { return Err(()); } } else { - if sum != 0 { + if !sum.is_zero() { return Err(()); } } @@ -78,12 +79,12 @@ impl Constraint for Equality { for (&var, &coef) in self.eqn.coef.iter() { let (min_val, max_val) = try!(search.get_min_max(var)); - if coef > 0 { - sum_min += coef * min_val; - sum_max += coef * max_val; + if coef > Ratio::zero() { + sum_min = sum_min + coef * Ratio::from_integer(min_val); + sum_max = sum_max + coef * Ratio::from_integer(max_val); } else { - sum_min += coef * max_val; - sum_max += coef * min_val; + sum_min = sum_min + coef * Ratio::from_integer(max_val); + sum_max = sum_max + coef * Ratio::from_integer(min_val); } } @@ -94,7 +95,7 @@ impl Constraint for Equality { let mut iter = self.eqn.coef.iter().cycle(); while iters > 0 { iters = iters - 1; - if !(sum_min <= 0 && 0 <= sum_max) { + if !(sum_min <= Ratio::zero() && Ratio::zero() <= sum_max) { return Err(()); } @@ -106,24 +107,24 @@ impl Constraint for Equality { let (min_val, max_val) = try!(search.get_min_max(var)); let (min_bnd, max_bnd); - if coef > 0 { - min_bnd = (coef * max_val - sum_max).div_round_up(coef); - max_bnd = (coef * min_val - sum_min).div_round_down(coef); + if coef > Ratio::zero() { + min_bnd = ((coef * Ratio::from_integer(max_val) - sum_max) / coef).ceil().to_integer(); + max_bnd = ((coef * Ratio::from_integer(min_val) - sum_min) / coef).floor().to_integer(); } else { - min_bnd = (coef * max_val - sum_min).div_round_up(coef); - max_bnd = (coef * min_val - sum_max).div_round_down(coef); + min_bnd = ((coef * Ratio::from_integer(max_val) - sum_min) / coef).ceil().to_integer(); + max_bnd = ((coef * Ratio::from_integer(min_val) - sum_max) / coef).floor().to_integer(); } if min_val < min_bnd || max_bnd < max_val { let (new_min, new_max) = try!(search.bound_candidate_range(var, min_bnd, max_bnd)); - if coef > 0 { - sum_min = sum_min + coef * (new_min - min_val); - sum_max = sum_max + coef * (new_max - max_val); + if coef > Ratio::zero() { + sum_min = sum_min + coef * Ratio::from_integer(new_min - min_val); + sum_max = sum_max + coef * Ratio::from_integer(new_max - max_val); } else { - sum_min = sum_min + coef * (new_max - max_val); - sum_max = sum_max + coef * (new_min - min_val); + sum_min = sum_min + coef * Ratio::from_integer(new_max - max_val); + sum_max = sum_max + coef * Ratio::from_integer(new_min - min_val); } iters = self.eqn.coef.len(); diff --git a/src/lib.rs b/src/lib.rs index bf8bf66..83f10a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ extern crate num_traits; use std::collections::HashMap; use std::ops::Index; +use num_rational::Rational32; pub use constraint::Constraint; pub use puzzle::Puzzle; @@ -20,7 +21,7 @@ pub struct VarToken(usize); pub type Val = i32; /// The type of the coefficients in a linear expression. -pub type Coef = i32; +pub type Coef = Rational32; /// A linear expression. /// diff --git a/src/linexpr.rs b/src/linexpr.rs index 1fe45e5..7992e5f 100644 --- a/src/linexpr.rs +++ b/src/linexpr.rs @@ -4,8 +4,10 @@ use std::collections::HashMap; use std::collections::hash_map::Entry; use std::convert::From; use std::ops::{Add,Mul,Neg,Sub}; +use num_rational::{Ratio,Rational32}; +use num_traits::{One,Zero}; -use ::{LinExpr,VarToken}; +use ::{Coef,LinExpr,VarToken}; macro_rules! impl_commutative_op { ($LHS:ident + $RHS:ident) => { @@ -31,12 +33,24 @@ macro_rules! impl_subtract_op { } } +pub trait IntoCoef: Zero { + fn into_coef(self) -> Coef; +} + +impl IntoCoef for i32 { + fn into_coef(self) -> Coef { Ratio::from_integer(self) } +} + +impl IntoCoef for Rational32 { + fn into_coef(self) -> Coef { self } +} + /*--------------------------------------------------------------*/ -impl From for LinExpr { - fn from(constant: i32) -> Self { +impl From for LinExpr { + fn from(constant: T) -> Self { LinExpr { - constant: constant, + constant: constant.into_coef(), coef: HashMap::new(), } } @@ -45,10 +59,10 @@ impl From for LinExpr { impl From for LinExpr { fn from(var: VarToken) -> Self { let mut coef = HashMap::new(); - coef.insert(var, 1); + coef.insert(var, Ratio::one()); LinExpr { - constant: 0, + constant: Ratio::zero(), coef: coef, } } @@ -65,26 +79,30 @@ impl Neg for VarToken { } } -impl Add for VarToken { +impl Add for VarToken { type Output = LinExpr; - fn add(self, rhs: i32) -> Self::Output { + fn add(self, rhs: T) -> Self::Output { LinExpr::from(self) + rhs } } impl_commutative_op!(i32 + VarToken); +impl_commutative_op!(Rational32 + VarToken); impl_subtract_op!(VarToken - i32); impl_subtract_op!(i32 - VarToken); +impl_subtract_op!(VarToken - Rational32); +impl_subtract_op!(Rational32 - VarToken); -impl Mul for VarToken { +impl Mul for VarToken { type Output = LinExpr; - fn mul(self, rhs: i32) -> Self::Output { + fn mul(self, rhs: T) -> Self::Output { LinExpr::from(self) * rhs } } impl_commutative_op!(i32 * VarToken); +impl_commutative_op!(Rational32 * VarToken); /*--------------------------------------------------------------*/ /* Var-Var */ @@ -110,29 +128,35 @@ impl Neg for LinExpr { } } -impl Add for LinExpr { +impl Add for LinExpr { type Output = LinExpr; - fn add(mut self, rhs: i32) -> Self::Output { - self.constant = self.constant + rhs; + fn add(mut self, rhs: T) -> Self::Output { + self.constant = self.constant + rhs.into_coef(); self } } impl_commutative_op!(i32 + LinExpr); +impl_commutative_op!(Rational32 + LinExpr); impl_subtract_op!(LinExpr - i32); impl_subtract_op!(i32 - LinExpr); +impl_subtract_op!(LinExpr - Rational32); +impl_subtract_op!(Rational32 - LinExpr); -impl Mul for LinExpr { +impl Mul for LinExpr { type Output = LinExpr; - fn mul(mut self, rhs: i32) -> Self::Output { - if rhs == 0 { - self.constant = 0; + fn mul(mut self, rhs: T) -> Self::Output { + if rhs.is_zero() { + self.constant = Ratio::zero(); self.coef = HashMap::new(); - } else if rhs != 1 { - self.constant = self.constant * rhs; - for coef in self.coef.values_mut() { - *coef *= rhs; + } else { + let rhs = rhs.into_coef(); + if rhs != Ratio::one() { + self.constant = self.constant * rhs; + for coef in self.coef.values_mut() { + *coef = *coef * rhs; + } } } @@ -141,6 +165,7 @@ impl Mul for LinExpr { } impl_commutative_op!(i32 * LinExpr); +impl_commutative_op!(Rational32 * LinExpr); /*--------------------------------------------------------------*/ /* Expr-Var */ @@ -173,10 +198,11 @@ impl Add for LinExpr { e.insert(a2); }, Entry::Occupied(mut e) => { - if *e.get() + a2 == 0 { + let new_coef = *e.get() + a2; + if new_coef.is_zero() { e.remove(); } else { - *e.get_mut() += a2; + *e.get_mut() = new_coef; } }, } @@ -192,6 +218,7 @@ impl_subtract_op!(LinExpr - LinExpr); #[cfg(test)] mod tests { + use num_rational::Ratio; use ::Puzzle; #[test] @@ -204,11 +231,17 @@ mod tests { let _ = x + 1; let _ = x - 1; let _ = x * 1; + let _ = x + Ratio::new(1, 2); + let _ = x - Ratio::new(1, 2); + let _ = x * Ratio::new(1, 2); // expr = const + var; let _ = 1 + x; let _ = 1 - x; let _ = 1 * x; + let _ = Ratio::new(1, 2) + x; + let _ = Ratio::new(1, 2) - x; + let _ = Ratio::new(1, 2) * x; // expr = var + var; let _ = -x; @@ -219,11 +252,17 @@ mod tests { let _ = (x + y) + 1; let _ = (x + y) - 1; let _ = (x + y) * 1; + let _ = (x + y) + Ratio::new(1, 2); + let _ = (x + y) - Ratio::new(1, 2); + let _ = (x + y) * Ratio::new(1, 2); // expr = const + expr; let _ = 1 + (x + y); let _ = 1 - (x + y); let _ = 1 * (x + y); + let _ = Ratio::new(1, 2) + (x + y); + let _ = Ratio::new(1, 2) - (x + y); + let _ = Ratio::new(1, 2) * (x + y); // expr = expr + var; let _ = (x + 1) + y;