fixes, safe and unsafe time pos cst expt, plenty of mod sub cases

This commit is contained in:
lperron@google.com
2012-06-16 16:19:27 +00:00
parent f6e225cb11
commit 4dd782207d
4 changed files with 281 additions and 76 deletions

View File

@@ -798,9 +798,9 @@ Constraint* Solver::MakePathCumul(const std::vector<IntVar*>& nexts,
}
namespace {
class Modulo : public Constraint {
class IntModulo : public Constraint {
public:
Modulo(Solver* const solver, IntVar* const x, int64 mod, IntVar* const y)
IntModulo(Solver* const solver, IntVar* const x, int64 mod, IntVar* const y)
: Constraint(solver),
x_(x),
mod_(mod),
@@ -810,12 +810,12 @@ class Modulo : public Constraint {
CHECK_GE(mod_, 0);
}
virtual ~Modulo() {}
virtual ~IntModulo() {}
virtual void Post() {
Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
x_->WhenDomain(demon);
y_->WhenDomain(demon);
demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
x_->WhenDomain(demon_);
y_->WhenDomain(demon_);
}
virtual void InitialPropagate() {
@@ -825,32 +825,46 @@ class Modulo : public Constraint {
}
y_->SetRange(0, mod_ - 1);
if (y_->Size() <= mod_) {
if (y_->Bound()) {
const int64 result = y_->Min();
to_remove_.clear();
for (x_iterator_->Init(); x_iterator_->Ok(); x_iterator_->Next()) {
const int64 value = x_iterator_->Value();
if (!y_->Contains(value % mod_)) {
if (value % mod_ != result) {
to_remove_.push_back(value);
}
}
x_->RemoveValues(to_remove_);
to_remove_.clear();
}
demon_->inhibit(solver());
} else {
if (y_->Size() <= mod_) {
to_remove_.clear();
for (x_iterator_->Init(); x_iterator_->Ok(); x_iterator_->Next()) {
const int64 value = x_iterator_->Value();
if (!y_->Contains(value % mod_)) {
to_remove_.push_back(value);
}
}
x_->RemoveValues(to_remove_);
to_remove_.clear();
}
for (y_iterator_->Init(); y_iterator_->Ok(); y_iterator_->Next()) {
const int64 value = y_iterator_->Value();
bool support = false;
for (int64 w = 0; w <= x_->Max() / mod_; ++w) {
if (x_->Contains(w * mod_ + value)) {
support = true;
break;
for (y_iterator_->Init(); y_iterator_->Ok(); y_iterator_->Next()) {
const int64 value = y_iterator_->Value();
bool support = false;
for (int64 w = 0; w <= x_->Max() / mod_; ++w) {
if (x_->Contains(w * mod_ + value)) {
support = true;
break;
}
}
if (!support) {
to_remove_.push_back(value);
}
}
if (!support) {
to_remove_.push_back(value);
}
y_->RemoveValues(to_remove_);
}
y_->RemoveValues(to_remove_);
}
virtual string DebugString() const {
@@ -867,6 +881,7 @@ class Modulo : public Constraint {
IntVarIterator* const x_iterator_;
IntVarIterator* const y_iterator_;
std::vector<int64> to_remove_;
Demon* demon_;
};
class VariableModulo : public Constraint {
@@ -891,27 +906,116 @@ class VariableModulo : public Constraint {
Solver* const s = solver();
IntVar* const d = s->MakeIntVar(std::min(x_->Min(), -x_->Max()),
std::max(x_->Max(), -x_->Min()));
s->AddConstraint(s->MakeEquality(x_, s->MakeSum(s->MakeProd(mod_, d), y_)->Var()));
s->AddConstraint(s->MakeGreater(y_, s->MakeOpposite(s->MakeAbs(mod_))->Var()));
s->AddConstraint(
s->MakeEquality(x_, s->MakeSum(s->MakeProd(mod_, d), y_)->Var()));
s->AddConstraint(
s->MakeGreater(y_, s->MakeOpposite(s->MakeAbs(mod_))->Var()));
s->AddConstraint(s->MakeLess(y_, s->MakeAbs(mod_)->Var()));
s->AddConstraint(s->MakeGreaterOrEqual(d, s->MakeMin(x_, s->MakeOpposite(x_))->Var()));
s->AddConstraint(s->MakeLessOrEqual(d, s->MakeMax(x_, s->MakeOpposite(x_))->Var()));
s->AddConstraint(
s->MakeGreaterOrEqual(d, s->MakeMin(x_, s->MakeOpposite(x_))->Var()));
s->AddConstraint(
s->MakeLessOrEqual(d, s->MakeMax(x_, s->MakeOpposite(x_))->Var()));
}
virtual void InitialPropagate() {
mod_->RemoveValue(0);
}
virtual string DebugString() const {
return StringPrintf("VariableModulo(%s, %s, %s)",
x_->DebugString().c_str(),
mod_->DebugString().c_str(),
y_->DebugString().c_str());
}
private:
IntVar* const x_;
IntVar* const mod_;
IntVar* const y_;
};
class BoundModulo : public Constraint {
public:
BoundModulo(Solver* const solver, IntVar* const x, IntVar* const mod)
: Constraint(solver), x_(x), mod_(mod) {
CHECK_NOTNULL(solver);
CHECK_NOTNULL(x);
CHECK_NOTNULL(mod);
}
virtual ~BoundModulo() {}
virtual void Post() {
Solver* const s = solver();
IntVar* const d = s->MakeIntVar(std::min(x_->Min(), -x_->Max()),
std::max(x_->Max(), -x_->Min()));
s->AddConstraint(s->MakeEquality(x_, s->MakeProd(mod_, d)->Var()));
}
virtual void InitialPropagate() {
mod_->RemoveValue(0);
}
virtual string DebugString() const {
return StringPrintf("BoundModulo(%s, %s)",
x_->DebugString().c_str(),
mod_->DebugString().c_str());
}
private:
IntVar* const x_;
IntVar* const mod_;
};
class PositiveModulo : public Constraint {
public:
PositiveModulo(Solver* const solver,
IntVar* const x,
IntVar* const mod,
IntVar* const y)
: Constraint(solver),
x_(x),
mod_(mod),
y_(y) {
CHECK_NOTNULL(solver);
CHECK_NOTNULL(x);
CHECK_NOTNULL(mod);
CHECK_NOTNULL(y);
}
virtual ~PositiveModulo() {}
virtual void Post() {
Solver* const s = solver();
IntVar* const d = s->MakeIntVar(1, x_->Max());
s->AddConstraint(
s->MakeEquality(x_, s->MakeSum(s->MakeProd(mod_, d), y_)->Var()));
s->AddConstraint(s->MakeLess(y_, mod_));
}
virtual void InitialPropagate() {
mod_->RemoveValue(0);
y_->SetMin(0);
}
virtual string DebugString() const {
return StringPrintf("PositiveModulo(%s, %s, %s)",
x_->DebugString().c_str(),
mod_->DebugString().c_str(),
y_->DebugString().c_str());
}
private:
IntVar* const x_;
IntVar* const mod_;
IntVar* const y_;
};
} // namespace
Constraint* Solver::MakeModuloConstraint(IntVar* const x,
int64 mod,
IntVar* const y) {
return RevAlloc(new Modulo(this, x, mod, y));
return RevAlloc(new IntModulo(this, x, mod, y));
}
Constraint* Solver::MakeModuloConstraint(IntVar* const x,
@@ -919,6 +1023,10 @@ Constraint* Solver::MakeModuloConstraint(IntVar* const x,
IntVar* const y) {
if (mod->Bound()) {
return MakeModuloConstraint(x, mod->Min(), y);
} else if (x->Min() >= 0 && y->Min() >= 0 && mod->Min() >= 0) {
return RevAlloc(new PositiveModulo(this, x, mod, y));
} else if (y->Bound() && y->Min() == 0) {
return RevAlloc(new BoundModulo(this, x, mod));
} else {
return RevAlloc(new VariableModulo(this, x, mod, y));
}

View File

@@ -306,7 +306,7 @@ class DomainIntVar : public IntVar {
IntVar* boolvar = NULL;
if (variable_->Contains(value)) {
if (variable_->Bound()) {
boolvar = solver()->MakeIntConst(0);
boolvar = solver()->MakeIntConst(1);
} else {
const string vname =
variable_->HasName() ?
@@ -317,6 +317,7 @@ class DomainIntVar : public IntVar {
value);
boolvar = solver()->MakeBoolVar(bname);
}
active_watchers_.Incr(solver());
} else {
boolvar = variable_->solver()->MakeIntConst(0);
}
@@ -3081,29 +3082,62 @@ IntVar* OppIntExpr::CastToVar() {
class TimesIntPosCstExpr : public BaseIntExpr {
public:
TimesIntPosCstExpr(Solver* const s, IntExpr* const e, int64 v);
virtual ~TimesIntPosCstExpr();
TimesIntPosCstExpr(Solver* const s, IntExpr* const e, int64 v)
: BaseIntExpr(s), expr_(e), value_(v) {
CHECK_GE(v, 0);
}
virtual ~TimesIntPosCstExpr(){}
virtual int64 Min() const {
return CapProd(expr_->Min(), value_);
return expr_->Min() * value_;
}
virtual void SetMin(int64 m);
virtual void SetMin(int64 m) {
expr_->SetMin(PosIntDivUp(m, value_));
}
virtual int64 Max() const {
return CapProd(expr_->Max(), value_);
return expr_->Max() * value_;
}
virtual void SetMax(int64 m);
virtual bool Bound() const { return (expr_->Bound()); }
virtual void SetMax(int64 m) {
expr_->SetMax(PosIntDivDown(m, value_));
}
virtual bool Bound() const {
return (expr_->Bound());
}
virtual string name() const {
return StringPrintf("(%s * %" GG_LL_FORMAT "d)",
expr_->name().c_str(), value_);
}
virtual string DebugString() const {
return StringPrintf("(%s * %" GG_LL_FORMAT "d)",
expr_->DebugString().c_str(), value_);
}
virtual void WhenRange(Demon* d) {
expr_->WhenRange(d);
}
virtual IntVar* CastToVar();
virtual IntVar* CastToVar() {
Solver* const s = solver();
IntVar* var = NULL;
if (expr_->IsVar() &&
reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
var = s->RegisterIntVar(s->RevAlloc(
new TimesPosCstBoolVar(s,
reinterpret_cast<BooleanVar*>(expr_),
value_)));
} else {
var = s->RegisterIntVar(
s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
}
return var;
}
virtual void Accept(ModelVisitor* const visitor) const {
visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
@@ -3118,42 +3152,81 @@ class TimesIntPosCstExpr : public BaseIntExpr {
const int64 value_;
};
TimesIntPosCstExpr::TimesIntPosCstExpr(Solver* const s,
IntExpr* const e,
int64 v)
: BaseIntExpr(s), expr_(e), value_(v) {
CHECK_GE(v, 0);
}
TimesIntPosCstExpr::~TimesIntPosCstExpr() {}
void TimesIntPosCstExpr::SetMin(int64 m) {
if (m != kint64min) {
expr_->SetMin(PosIntDivUp(m, value_));
class SafeTimesIntPosCstExpr : public BaseIntExpr {
public:
SafeTimesIntPosCstExpr(Solver* const s, IntExpr* const e, int64 v)
: BaseIntExpr(s), expr_(e), value_(v) {
CHECK_GE(v, 0);
}
}
void TimesIntPosCstExpr::SetMax(int64 m) {
if (m != kint64max) {
expr_->SetMax(PosIntDivDown(m, value_));
}
}
virtual ~SafeTimesIntPosCstExpr(){}
IntVar* TimesIntPosCstExpr::CastToVar() {
Solver* const s = solver();
IntVar* var = NULL;
if (expr_->IsVar() &&
reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
var = s->RegisterIntVar(s->RevAlloc(
new TimesPosCstBoolVar(s,
reinterpret_cast<BooleanVar*>(expr_),
value_)));
} else {
var = s->RegisterIntVar(
s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
virtual int64 Min() const {
return CapProd(expr_->Min(), value_);
}
return var;
}
virtual void SetMin(int64 m) {
if (m != kint64min) {
expr_->SetMin(PosIntDivUp(m, value_));
}
}
virtual int64 Max() const {
return CapProd(expr_->Max(), value_);
}
virtual void SetMax(int64 m) {
if (m != kint64max) {
expr_->SetMax(PosIntDivDown(m, value_));
}
}
virtual bool Bound() const {
return (expr_->Bound());
}
virtual string name() const {
return StringPrintf("(%s * %" GG_LL_FORMAT "d)",
expr_->name().c_str(), value_);
}
virtual string DebugString() const {
return StringPrintf("(%s * %" GG_LL_FORMAT "d)",
expr_->DebugString().c_str(), value_);
}
virtual void WhenRange(Demon* d) {
expr_->WhenRange(d);
}
virtual IntVar* CastToVar() {
Solver* const s = solver();
IntVar* var = NULL;
if (expr_->IsVar() &&
reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
var = s->RegisterIntVar(s->RevAlloc(
new TimesPosCstBoolVar(s,
reinterpret_cast<BooleanVar*>(expr_),
value_)));
} else {
var = s->RegisterIntVar(
s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
}
return var;
}
virtual void Accept(ModelVisitor* const visitor) const {
visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
expr_);
visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
}
private:
IntExpr* const expr_;
const int64 value_;
};
// ----- TimesIntNegCstExpr -----
@@ -5286,7 +5359,12 @@ IntExpr* Solver::MakeProd(IntExpr* const e, int64 v) {
} else if (v == -1) {
return MakeOpposite(e);
} else if (v > 0) {
result = RegisterIntExpr(RevAlloc(new TimesIntPosCstExpr(this, e, v)));
if (e->Max() > kint64max / v || e->Min() < kint64min / v) {
result =
RegisterIntExpr(RevAlloc(new SafeTimesIntPosCstExpr(this, e, v)));
} else {
result = RegisterIntExpr(RevAlloc(new TimesIntPosCstExpr(this, e, v)));
}
} else if (v == 0) {
result = MakeIntConst(0);
} else {

View File

@@ -120,7 +120,8 @@ void ParserState::ComputeViableTarget(
id == "int_minus" ||
id == "int_times" ||
id == "array_var_int_element" ||
id == "array_int_element") {
id == "array_int_element" ||
id == "int_abs") {
// Defines an int var.
const int define = FindTarget(spec->annotations());
if (define != CtSpec::kNoDefinition) {

View File

@@ -522,7 +522,7 @@ void p_int_plus(FlatZincModel* const model, CtSpec* const spec) {
IntVar* const right = model->GetIntVar(spec->Arg(1));
IntVar* const target = model->GetIntVar(spec->Arg(2));
IntVar* const left = solver->MakeDifference(target, right)->Var();
VLOG(1) << "Created " << spec->Arg(2)->DebugString() << " == "
VLOG(1) << "Created " << spec->Arg(0)->DebugString() << " == "
<< left->DebugString();
CHECK(model->IntegerVariable(spec->Arg(0)->getIntVar()) == NULL);
model->SetIntegerVariable(spec->Arg(0)->getIntVar(), left);
@@ -531,7 +531,7 @@ void p_int_plus(FlatZincModel* const model, CtSpec* const spec) {
IntVar* const left = model->GetIntVar(spec->Arg(0));
IntVar* const target = model->GetIntVar(spec->Arg(2));
IntVar* const right = solver->MakeDifference(target, left)->Var();
VLOG(1) << "Created " << spec->Arg(2)->DebugString() << " == "
VLOG(1) << "Created " << spec->Arg(1)->DebugString() << " == "
<< right->DebugString();
CHECK(model->IntegerVariable(spec->Arg(1)->getIntVar()) == NULL);
model->SetIntegerVariable(spec->Arg(1)->getIntVar(), left);
@@ -680,6 +680,11 @@ void p_array_bool_and(FlatZincModel* const model, CtSpec* const spec) {
IntVar* const boolvar = solver->MakeMin(variables)->Var();
CHECK(model->BooleanVariable(node_boolvar->getBoolVar()) == NULL);
model->SetBooleanVariable(node_boolvar->getBoolVar(), boolvar);
} else if (node_boolvar->isBool() && node_boolvar->getBool() == 1) {
VLOG(1) << "forcing array_bool_and to 1";
for (int i = 0; i < size; ++i) {
variables[i]->SetValue(1);
}
} else {
IntVar* const boolvar = model->GetIntVar(node_boolvar);
Constraint* const ct =
@@ -705,6 +710,11 @@ void p_array_bool_or(FlatZincModel* const model, CtSpec* const spec) {
IntVar* const boolvar = solver->MakeMax(variables)->Var();
CHECK(model->BooleanVariable(node_boolvar->getBoolVar()) == NULL);
model->SetBooleanVariable(node_boolvar->getBoolVar(), boolvar);
} else if (node_boolvar->isBool() && node_boolvar->getBool() == 0) {
VLOG(1) << "forcing array_bool_or to 0";
for (int i = 0; i < size; ++i) {
variables[i]->SetValue(0);
}
} else {
IntVar* const boolvar = model->GetIntVar(node_boolvar);
Constraint* const ct =
@@ -950,11 +960,19 @@ void p_int_in(FlatZincModel* const model, CtSpec* const spec) {
void p_abs(FlatZincModel* const model, CtSpec* const spec) {
Solver* const solver = model->solver();
IntVar* const left = model->GetIntVar(spec->Arg(0));
IntVar* const target = model->GetIntVar(spec->Arg(1));
Constraint* const ct =
solver->MakeEquality(solver->MakeAbs(left)->Var(), target);
VLOG(1) << "Posted " << ct->DebugString();
solver->AddConstraint(ct);
if (spec->Arg(1)->isIntVar() &&
spec->defines() == spec->Arg(1)->getIntVar()) {
VLOG(1) << "Aliasing int_abs";
CHECK(model->IntegerVariable(spec->defines()) == NULL);
IntVar* const target = solver->MakeAbs(left)->Var();
model->SetIntegerVariable(spec->defines(), target);
} else {
IntVar* const target = model->GetIntVar(spec->Arg(1));
Constraint* const ct =
solver->MakeEquality(solver->MakeAbs(left)->Var(), target);
VLOG(1) << "Posted " << ct->DebugString();
solver->AddConstraint(ct);
}
}
void p_all_different_int(FlatZincModel* const model, CtSpec* const spec) {