Add support for rational coefficients in linear expressions.

Coef has been redefined to Rational32.
This commit is contained in:
David Wang 2017-03-15 07:53:23 +11:00
parent 20e5242624
commit ab6267b46c
3 changed files with 86 additions and 45 deletions

View File

@ -1,9 +1,10 @@
//! Equality implementation.
use std::rc::Rc;
use num_rational::Ratio;
use num_traits::Zero;
use ::{Constraint,LinExpr,PsResult,PuzzleSearch,Val,VarToken};
use intdiv::IntDiv;
pub struct Equality {
// The equation: 0 = constant + coef1 * var1 + coef2 * var2 + ...
@ -42,7 +43,7 @@ impl Constraint for Equality {
for (&var, &coef) in self.eqn.coef.iter() {
if let Some(val) = search.get_assigned(var) {
sum += coef * val;
sum = sum + coef * Ratio::from_integer(val);
} else {
// If we find more than one unassigned variable,
// cannot assign any other variables.
@ -58,13 +59,13 @@ impl Constraint for Equality {
if let Some((var, coef)) = unassigned_var {
// sum + coef * var = 0.
let val = -sum / coef;
if sum + coef * val == 0 {
try!(search.set_candidate(var, val));
if val.is_integer() {
try!(search.set_candidate(var, val.to_integer()));
} else {
return Err(());
}
} else {
if sum != 0 {
if !sum.is_zero() {
return Err(());
}
}
@ -78,12 +79,12 @@ impl Constraint for Equality {
for (&var, &coef) in self.eqn.coef.iter() {
let (min_val, max_val) = try!(search.get_min_max(var));
if coef > 0 {
sum_min += coef * min_val;
sum_max += coef * max_val;
if coef > Ratio::zero() {
sum_min = sum_min + coef * Ratio::from_integer(min_val);
sum_max = sum_max + coef * Ratio::from_integer(max_val);
} else {
sum_min += coef * max_val;
sum_max += coef * min_val;
sum_min = sum_min + coef * Ratio::from_integer(max_val);
sum_max = sum_max + coef * Ratio::from_integer(min_val);
}
}
@ -94,7 +95,7 @@ impl Constraint for Equality {
let mut iter = self.eqn.coef.iter().cycle();
while iters > 0 {
iters = iters - 1;
if !(sum_min <= 0 && 0 <= sum_max) {
if !(sum_min <= Ratio::zero() && Ratio::zero() <= sum_max) {
return Err(());
}
@ -106,24 +107,24 @@ impl Constraint for Equality {
let (min_val, max_val) = try!(search.get_min_max(var));
let (min_bnd, max_bnd);
if coef > 0 {
min_bnd = (coef * max_val - sum_max).div_round_up(coef);
max_bnd = (coef * min_val - sum_min).div_round_down(coef);
if coef > Ratio::zero() {
min_bnd = ((coef * Ratio::from_integer(max_val) - sum_max) / coef).ceil().to_integer();
max_bnd = ((coef * Ratio::from_integer(min_val) - sum_min) / coef).floor().to_integer();
} else {
min_bnd = (coef * max_val - sum_min).div_round_up(coef);
max_bnd = (coef * min_val - sum_max).div_round_down(coef);
min_bnd = ((coef * Ratio::from_integer(max_val) - sum_min) / coef).ceil().to_integer();
max_bnd = ((coef * Ratio::from_integer(min_val) - sum_max) / coef).floor().to_integer();
}
if min_val < min_bnd || max_bnd < max_val {
let (new_min, new_max)
= try!(search.bound_candidate_range(var, min_bnd, max_bnd));
if coef > 0 {
sum_min = sum_min + coef * (new_min - min_val);
sum_max = sum_max + coef * (new_max - max_val);
if coef > Ratio::zero() {
sum_min = sum_min + coef * Ratio::from_integer(new_min - min_val);
sum_max = sum_max + coef * Ratio::from_integer(new_max - max_val);
} else {
sum_min = sum_min + coef * (new_max - max_val);
sum_max = sum_max + coef * (new_min - min_val);
sum_min = sum_min + coef * Ratio::from_integer(new_max - max_val);
sum_max = sum_max + coef * Ratio::from_integer(new_min - min_val);
}
iters = self.eqn.coef.len();

View File

@ -7,6 +7,7 @@ extern crate num_traits;
use std::collections::HashMap;
use std::ops::Index;
use num_rational::Rational32;
pub use constraint::Constraint;
pub use puzzle::Puzzle;
@ -20,7 +21,7 @@ pub struct VarToken(usize);
pub type Val = i32;
/// The type of the coefficients in a linear expression.
pub type Coef = i32;
pub type Coef = Rational32;
/// A linear expression.
///

View File

@ -4,8 +4,10 @@ use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::convert::From;
use std::ops::{Add,Mul,Neg,Sub};
use num_rational::{Ratio,Rational32};
use num_traits::{One,Zero};
use ::{LinExpr,VarToken};
use ::{Coef,LinExpr,VarToken};
macro_rules! impl_commutative_op {
($LHS:ident + $RHS:ident) => {
@ -31,12 +33,24 @@ macro_rules! impl_subtract_op {
}
}
pub trait IntoCoef: Zero {
fn into_coef(self) -> Coef;
}
impl IntoCoef for i32 {
fn into_coef(self) -> Coef { Ratio::from_integer(self) }
}
impl IntoCoef for Rational32 {
fn into_coef(self) -> Coef { self }
}
/*--------------------------------------------------------------*/
impl From<i32> for LinExpr {
fn from(constant: i32) -> Self {
impl<T: IntoCoef> From<T> for LinExpr {
fn from(constant: T) -> Self {
LinExpr {
constant: constant,
constant: constant.into_coef(),
coef: HashMap::new(),
}
}
@ -45,10 +59,10 @@ impl From<i32> for LinExpr {
impl From<VarToken> for LinExpr {
fn from(var: VarToken) -> Self {
let mut coef = HashMap::new();
coef.insert(var, 1);
coef.insert(var, Ratio::one());
LinExpr {
constant: 0,
constant: Ratio::zero(),
coef: coef,
}
}
@ -65,26 +79,30 @@ impl Neg for VarToken {
}
}
impl Add<i32> for VarToken {
impl<T: IntoCoef> Add<T> for VarToken {
type Output = LinExpr;
fn add(self, rhs: i32) -> Self::Output {
fn add(self, rhs: T) -> Self::Output {
LinExpr::from(self) + rhs
}
}
impl_commutative_op!(i32 + VarToken);
impl_commutative_op!(Rational32 + VarToken);
impl_subtract_op!(VarToken - i32);
impl_subtract_op!(i32 - VarToken);
impl_subtract_op!(VarToken - Rational32);
impl_subtract_op!(Rational32 - VarToken);
impl Mul<i32> for VarToken {
impl<T: IntoCoef> Mul<T> for VarToken {
type Output = LinExpr;
fn mul(self, rhs: i32) -> Self::Output {
fn mul(self, rhs: T) -> Self::Output {
LinExpr::from(self) * rhs
}
}
impl_commutative_op!(i32 * VarToken);
impl_commutative_op!(Rational32 * VarToken);
/*--------------------------------------------------------------*/
/* Var-Var */
@ -110,29 +128,35 @@ impl Neg for LinExpr {
}
}
impl Add<i32> for LinExpr {
impl<T: IntoCoef> Add<T> for LinExpr {
type Output = LinExpr;
fn add(mut self, rhs: i32) -> Self::Output {
self.constant = self.constant + rhs;
fn add(mut self, rhs: T) -> Self::Output {
self.constant = self.constant + rhs.into_coef();
self
}
}
impl_commutative_op!(i32 + LinExpr);
impl_commutative_op!(Rational32 + LinExpr);
impl_subtract_op!(LinExpr - i32);
impl_subtract_op!(i32 - LinExpr);
impl_subtract_op!(LinExpr - Rational32);
impl_subtract_op!(Rational32 - LinExpr);
impl Mul<i32> for LinExpr {
impl<T: IntoCoef> Mul<T> for LinExpr {
type Output = LinExpr;
fn mul(mut self, rhs: i32) -> Self::Output {
if rhs == 0 {
self.constant = 0;
fn mul(mut self, rhs: T) -> Self::Output {
if rhs.is_zero() {
self.constant = Ratio::zero();
self.coef = HashMap::new();
} else if rhs != 1 {
} else {
let rhs = rhs.into_coef();
if rhs != Ratio::one() {
self.constant = self.constant * rhs;
for coef in self.coef.values_mut() {
*coef *= rhs;
*coef = *coef * rhs;
}
}
}
@ -141,6 +165,7 @@ impl Mul<i32> for LinExpr {
}
impl_commutative_op!(i32 * LinExpr);
impl_commutative_op!(Rational32 * LinExpr);
/*--------------------------------------------------------------*/
/* Expr-Var */
@ -173,10 +198,11 @@ impl Add for LinExpr {
e.insert(a2);
},
Entry::Occupied(mut e) => {
if *e.get() + a2 == 0 {
let new_coef = *e.get() + a2;
if new_coef.is_zero() {
e.remove();
} else {
*e.get_mut() += a2;
*e.get_mut() = new_coef;
}
},
}
@ -192,6 +218,7 @@ impl_subtract_op!(LinExpr - LinExpr);
#[cfg(test)]
mod tests {
use num_rational::Ratio;
use ::Puzzle;
#[test]
@ -204,11 +231,17 @@ mod tests {
let _ = x + 1;
let _ = x - 1;
let _ = x * 1;
let _ = x + Ratio::new(1, 2);
let _ = x - Ratio::new(1, 2);
let _ = x * Ratio::new(1, 2);
// expr = const + var;
let _ = 1 + x;
let _ = 1 - x;
let _ = 1 * x;
let _ = Ratio::new(1, 2) + x;
let _ = Ratio::new(1, 2) - x;
let _ = Ratio::new(1, 2) * x;
// expr = var + var;
let _ = -x;
@ -219,11 +252,17 @@ mod tests {
let _ = (x + y) + 1;
let _ = (x + y) - 1;
let _ = (x + y) * 1;
let _ = (x + y) + Ratio::new(1, 2);
let _ = (x + y) - Ratio::new(1, 2);
let _ = (x + y) * Ratio::new(1, 2);
// expr = const + expr;
let _ = 1 + (x + y);
let _ = 1 - (x + y);
let _ = 1 * (x + y);
let _ = Ratio::new(1, 2) + (x + y);
let _ = Ratio::new(1, 2) - (x + y);
let _ = Ratio::new(1, 2) * (x + y);
// expr = expr + var;
let _ = (x + 1) + y;