fix slowdown. remove IntExprArrayElement expression

This commit is contained in:
lperron@google.com
2012-07-06 09:15:51 +00:00
parent af4e2e2b48
commit 856bfaafdb

View File

@@ -1412,200 +1412,7 @@ class IntExprArrayPositionCt : public Constraint {
IntVarIterator* const index_iterator_;
};
// ----- IntExprArrayElement -----
class IntExprArrayElement : public BaseIntExpr {
public:
IntExprArrayElement(Solver* const s,
const IntVar* const * vars,
int size,
IntVar* const expr);
virtual ~IntExprArrayElement() {}
virtual int64 Min() const;
virtual void SetMin(int64 m);
virtual int64 Max() const;
virtual void SetMax(int64 m);
virtual void SetRange(int64 mi, int64 ma);
virtual bool Bound() const;
virtual string name() const {
return StringPrintf("IntArrayElement(%s, %s)",
NameArray(vars_.get(), size_, ", ").c_str(),
var_->name().c_str());
}
virtual string DebugString() const {
return StringPrintf("IntArrayElement(%s, %s)",
DebugStringArray(vars_.get(), size_, ", ").c_str(),
var_->DebugString().c_str());
}
virtual void WhenRange(Demon* d) {
var_->WhenRange(d);
for (int i = 0; i < size_; ++i) {
vars_[i]->WhenRange(d);
}
}
virtual IntVar* CastToVar() {
Solver* const s = solver();
int64 vmin = 0LL;
int64 vmax = 0LL;
Range(&vmin, &vmax);
IntVar* var = solver()->MakeIntVar(vmin, vmax);
s->AddCastConstraint(s->RevAlloc(new IntExprArrayElementCt(s,
vars_.get(),
size_,
var_,
var)),
var,
this);
return var;
}
virtual void Accept(ModelVisitor* const visitor) const {
visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
vars_.get(),
size_);
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
var_);
visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
}
private:
scoped_array<IntVar*> vars_;
int size_;
IntVar* const var_;
};
IntExprArrayElement::IntExprArrayElement(Solver* const s,
const IntVar* const * vars,
int size,
IntVar* const v)
: BaseIntExpr(s), vars_(new IntVar*[size]),
size_(size), var_(v) {
CHECK(vars);
memcpy(vars_.get(), vars, size_ * sizeof(*vars));
}
int64 IntExprArrayElement::Min() const {
const int64 emin = std::max(0LL, var_->Min());
const int64 emax = std::min(size_ - 1LL, var_->Max());
int64 res = kint64max;
for (int i = emin; i <= emax; ++i) {
const int64 vmin = vars_[i]->Min();
if (vmin < res && var_->Contains(i)) {
res = vmin;
}
}
return res;
}
void IntExprArrayElement::SetMin(int64 m) {
const int64 emin = std::max(0LL, var_->Min());
const int64 emax = std::min(size_ - 1LL, var_->Max());
if (emin == emax) {
var_->SetValue(emin); // in case it was reduced by the above min/max.
vars_[emin]->SetMin(m);
} else {
int64 nmin = emin;
while (nmin <= emax && vars_[nmin]->Max() < m) {
nmin++;
}
if (nmin > emax) {
solver()->Fail();
}
int64 nmax = emax;
while (nmax >= nmin && vars_[nmax]->Max() < m) {
nmax--;
}
var_->SetRange(nmin, nmax);
if (var_->Bound()) {
vars_[var_->Min()]->SetMin(m);
}
}
}
int64 IntExprArrayElement::Max() const {
const int64 emin = std::max(0LL, var_->Min());
const int64 emax = std::min(size_ - 1LL, var_->Max());
int64 res = kint64min;
for (int i = emin; i <= emax; ++i) {
const int64 vmax = vars_[i]->Max();
if (vmax > res && var_->Contains(i)) {
res = vmax;
}
}
return res;
}
void IntExprArrayElement::SetMax(int64 m) {
const int64 emin = std::max(0LL, var_->Min());
const int64 emax = std::min(size_ - 1LL, var_->Max());
if (emin == emax) {
var_->SetValue(emin); // in case it was reduced by the above min/max.
vars_[emin]->SetMax(m);
} else {
int64 nmin = emin;
while (nmin <= emax && vars_[nmin]->Min() > m) {
nmin++;
}
if (nmin > emax) {
solver()->Fail();
}
int64 nmax = emax;
while (nmax >= nmin && vars_[nmax]->Min() > m) {
nmax--;
}
var_->SetRange(nmin, nmax);
if (var_->Bound()) {
vars_[var_->Min()]->SetMax(m);
}
}
}
void IntExprArrayElement::SetRange(int64 mi, int64 ma) {
if (mi > ma) {
solver()->Fail();
}
const int64 emin = std::max(0LL, var_->Min());
const int64 emax = std::min(size_ - 1LL, var_->Max());
if (emin == emax) {
var_->SetValue(emin); // in case it was reduced by the above min/max.
vars_[emin]->SetRange(mi, ma);
} else {
int64 nmin = emin;
while (nmin <= emax && (vars_[nmin]->Min() > ma ||
vars_[nmin]->Max() < mi)) {
nmin++;
}
if (nmin > emax) {
solver()->Fail();
}
int64 nmax = emax;
while (nmax >= nmin && (vars_[nmax]->Max() < mi ||
vars_[nmax]->Min() > ma)) {
nmax--;
}
if (nmax < emin) {
solver()->Fail();
}
var_->SetRange(nmin, nmax);
if (var_->Bound()) {
vars_[var_->Min()]->SetRange(mi, ma);
}
}
}
bool IntExprArrayElement::Bound() const {
const int64 emin = std::max(0LL, var_->Min());
const int64 emax = std::min(size_ - 1LL, var_->Max());
const int64 v = vars_[emin]->Min();
for (int i = emin; i <= emax; ++i) {
if (var_->Contains(i) && (!vars_[i]->Bound() || vars_[i]->Value() != v)) {
return false;
}
}
return true;
}
// ----- Misc -----
bool AreAllBound(const std::vector<IntVar*>& vars) {
for (int i = 0; i < vars.size(); ++i) {
@@ -1625,9 +1432,27 @@ IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars, IntVar* const ind
}
return MakeElement(values, index);
}
CHECK_EQ(this, index->solver());
return RegisterIntExpr(RevAlloc(
new IntExprArrayElement(this, vars.data(), vars.size(), index)));
int64 emin = kint64max;
int64 emax = kint64min;
scoped_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
for (iterator->Init(); iterator->Ok(); iterator->Next()) {
const int64 index_value = iterator->Value();
emin = std::min(emin, vars[index_value]->Min());
emax = std::max(emax, vars[index_value]->Max());
}
const string vname =
vars.size() > 10 ?
StringPrintf("ElementVar(var array of size %" GG_LL_FORMAT "d, %s)",
vars.size(), index->DebugString().c_str()) :
StringPrintf("ElementVar([%s], %s)",
NameVector(vars, ", ").c_str(), index->name().c_str());
IntVar* const element_var = MakeIntVar(emin, emax, vname);
AddConstraint(RevAlloc(new IntExprArrayElementCt(this,
vars.data(),
vars.size(),
index,
element_var)));
return element_var;
}
Constraint* Solver::MakeElementEquality(const std::vector<int64>& vals,