From 856bfaafdbb8495c8bc5c30a716d86e5fef68045 Mon Sep 17 00:00:00 2001 From: "lperron@google.com" Date: Fri, 6 Jul 2012 09:15:51 +0000 Subject: [PATCH] fix slowdown. remove IntExprArrayElement expression --- src/constraint_solver/element.cc | 219 ++++--------------------------- 1 file changed, 22 insertions(+), 197 deletions(-) diff --git a/src/constraint_solver/element.cc b/src/constraint_solver/element.cc index 5a40c9ea34..4f8d8914f5 100644 --- a/src/constraint_solver/element.cc +++ b/src/constraint_solver/element.cc @@ -1412,200 +1412,7 @@ class IntExprArrayPositionCt : public Constraint { IntVarIterator* const index_iterator_; }; -// ----- IntExprArrayElement ----- - -class IntExprArrayElement : public BaseIntExpr { - public: - IntExprArrayElement(Solver* const s, - const IntVar* const * vars, - int size, - IntVar* const expr); - virtual ~IntExprArrayElement() {} - - virtual int64 Min() const; - virtual void SetMin(int64 m); - virtual int64 Max() const; - virtual void SetMax(int64 m); - virtual void SetRange(int64 mi, int64 ma); - virtual bool Bound() const; - virtual string name() const { - return StringPrintf("IntArrayElement(%s, %s)", - NameArray(vars_.get(), size_, ", ").c_str(), - var_->name().c_str()); - } - virtual string DebugString() const { - return StringPrintf("IntArrayElement(%s, %s)", - DebugStringArray(vars_.get(), size_, ", ").c_str(), - var_->DebugString().c_str()); - } - virtual void WhenRange(Demon* d) { - var_->WhenRange(d); - for (int i = 0; i < size_; ++i) { - vars_[i]->WhenRange(d); - } - } - virtual IntVar* CastToVar() { - Solver* const s = solver(); - int64 vmin = 0LL; - int64 vmax = 0LL; - Range(&vmin, &vmax); - IntVar* var = solver()->MakeIntVar(vmin, vmax); - s->AddCastConstraint(s->RevAlloc(new IntExprArrayElementCt(s, - vars_.get(), - size_, - var_, - var)), - var, - this); - return var; - } - - virtual void Accept(ModelVisitor* const visitor) const { - visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this); - visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument, - vars_.get(), - size_); - visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, - var_); - visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this); - } - - private: - scoped_array vars_; - int size_; - IntVar* const var_; -}; - -IntExprArrayElement::IntExprArrayElement(Solver* const s, - const IntVar* const * vars, - int size, - IntVar* const v) - : BaseIntExpr(s), vars_(new IntVar*[size]), - size_(size), var_(v) { - CHECK(vars); - memcpy(vars_.get(), vars, size_ * sizeof(*vars)); -} - -int64 IntExprArrayElement::Min() const { - const int64 emin = std::max(0LL, var_->Min()); - const int64 emax = std::min(size_ - 1LL, var_->Max()); - int64 res = kint64max; - for (int i = emin; i <= emax; ++i) { - const int64 vmin = vars_[i]->Min(); - if (vmin < res && var_->Contains(i)) { - res = vmin; - } - } - return res; -} - -void IntExprArrayElement::SetMin(int64 m) { - const int64 emin = std::max(0LL, var_->Min()); - const int64 emax = std::min(size_ - 1LL, var_->Max()); - if (emin == emax) { - var_->SetValue(emin); // in case it was reduced by the above min/max. - vars_[emin]->SetMin(m); - } else { - int64 nmin = emin; - while (nmin <= emax && vars_[nmin]->Max() < m) { - nmin++; - } - if (nmin > emax) { - solver()->Fail(); - } - int64 nmax = emax; - while (nmax >= nmin && vars_[nmax]->Max() < m) { - nmax--; - } - var_->SetRange(nmin, nmax); - if (var_->Bound()) { - vars_[var_->Min()]->SetMin(m); - } - } -} - -int64 IntExprArrayElement::Max() const { - const int64 emin = std::max(0LL, var_->Min()); - const int64 emax = std::min(size_ - 1LL, var_->Max()); - int64 res = kint64min; - for (int i = emin; i <= emax; ++i) { - const int64 vmax = vars_[i]->Max(); - if (vmax > res && var_->Contains(i)) { - res = vmax; - } - } - return res; -} - -void IntExprArrayElement::SetMax(int64 m) { - const int64 emin = std::max(0LL, var_->Min()); - const int64 emax = std::min(size_ - 1LL, var_->Max()); - if (emin == emax) { - var_->SetValue(emin); // in case it was reduced by the above min/max. - vars_[emin]->SetMax(m); - } else { - int64 nmin = emin; - while (nmin <= emax && vars_[nmin]->Min() > m) { - nmin++; - } - if (nmin > emax) { - solver()->Fail(); - } - int64 nmax = emax; - while (nmax >= nmin && vars_[nmax]->Min() > m) { - nmax--; - } - var_->SetRange(nmin, nmax); - if (var_->Bound()) { - vars_[var_->Min()]->SetMax(m); - } - } -} - -void IntExprArrayElement::SetRange(int64 mi, int64 ma) { - if (mi > ma) { - solver()->Fail(); - } - const int64 emin = std::max(0LL, var_->Min()); - const int64 emax = std::min(size_ - 1LL, var_->Max()); - if (emin == emax) { - var_->SetValue(emin); // in case it was reduced by the above min/max. - vars_[emin]->SetRange(mi, ma); - } else { - int64 nmin = emin; - while (nmin <= emax && (vars_[nmin]->Min() > ma || - vars_[nmin]->Max() < mi)) { - nmin++; - } - if (nmin > emax) { - solver()->Fail(); - } - int64 nmax = emax; - while (nmax >= nmin && (vars_[nmax]->Max() < mi || - vars_[nmax]->Min() > ma)) { - nmax--; - } - if (nmax < emin) { - solver()->Fail(); - } - var_->SetRange(nmin, nmax); - if (var_->Bound()) { - vars_[var_->Min()]->SetRange(mi, ma); - } - } -} - -bool IntExprArrayElement::Bound() const { - const int64 emin = std::max(0LL, var_->Min()); - const int64 emax = std::min(size_ - 1LL, var_->Max()); - const int64 v = vars_[emin]->Min(); - for (int i = emin; i <= emax; ++i) { - if (var_->Contains(i) && (!vars_[i]->Bound() || vars_[i]->Value() != v)) { - return false; - } - } - return true; -} +// ----- Misc ----- bool AreAllBound(const std::vector& vars) { for (int i = 0; i < vars.size(); ++i) { @@ -1625,9 +1432,27 @@ IntExpr* Solver::MakeElement(const std::vector& vars, IntVar* const ind } return MakeElement(values, index); } - CHECK_EQ(this, index->solver()); - return RegisterIntExpr(RevAlloc( - new IntExprArrayElement(this, vars.data(), vars.size(), index))); + int64 emin = kint64max; + int64 emax = kint64min; + scoped_ptr iterator(index->MakeDomainIterator(false)); + for (iterator->Init(); iterator->Ok(); iterator->Next()) { + const int64 index_value = iterator->Value(); + emin = std::min(emin, vars[index_value]->Min()); + emax = std::max(emax, vars[index_value]->Max()); + } + const string vname = + vars.size() > 10 ? + StringPrintf("ElementVar(var array of size %" GG_LL_FORMAT "d, %s)", + vars.size(), index->DebugString().c_str()) : + StringPrintf("ElementVar([%s], %s)", + NameVector(vars, ", ").c_str(), index->name().c_str()); + IntVar* const element_var = MakeIntVar(emin, emax, vname); + AddConstraint(RevAlloc(new IntExprArrayElementCt(this, + vars.data(), + vars.size(), + index, + element_var))); + return element_var; } Constraint* Solver::MakeElementEquality(const std::vector& vals,