diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 46c06284b6..d2e6fd6fff 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -386,6 +386,18 @@ cc_test( ], ) +cc_library( + name = "presolve_encoding", + srcs = ["presolve_encoding.cc"], + hdrs = ["presolve_encoding.h"], + deps = [ + ":cp_model_utils", + ":presolve_context", + "//ortools/util:sorted_interval_list", + "@abseil-cpp//absl/log", + ], +) + cc_proto_library( name = "cp_model_cc_proto", visibility = ["//visibility:public"], @@ -1322,6 +1334,7 @@ cc_library( ":model", ":precedences", ":presolve_context", + ":presolve_encoding", ":presolve_util", ":probing", ":sat_base", diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index e8ebda8e3a..ae01a7177c 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -76,6 +76,7 @@ #include "ortools/sat/model.h" #include "ortools/sat/precedences.h" #include "ortools/sat/presolve_context.h" +#include "ortools/sat/presolve_encoding.h" #include "ortools/sat/presolve_util.h" #include "ortools/sat/probing.h" #include "ortools/sat/sat_base.h" @@ -430,20 +431,9 @@ bool CpModelPresolver::PresolveBoolOr(ConstraintProto* ct) { // done elsewhere. ABSL_MUST_USE_RESULT bool CpModelPresolver::MarkConstraintAsFalse( ConstraintProto* ct, std::string_view reason) { - DCHECK(!reason.empty()); - if (HasEnforcementLiteral(*ct)) { - // Change the constraint to a bool_or. - ct->mutable_bool_or()->clear_literals(); - for (const int lit : ct->enforcement_literal()) { - ct->mutable_bool_or()->add_literals(NegatedRef(lit)); - } - ct->clear_enforcement_literal(); - PresolveBoolOr(ct); - context_->UpdateRuleStats(reason); - return true; - } else { - return context_->NotifyThatModelIsUnsat(reason); - } + if (!context_->MarkConstraintAsFalse(ct, reason)) return false; + if (ct->constraint_case() == ConstraintProto::kBoolOr) PresolveBoolOr(ct); + return true; } ABSL_MUST_USE_RESULT bool CpModelPresolver::MarkOptionalIntervalAsFalse( @@ -870,30 +860,6 @@ int GetFirstVar(ExpressionList exprs) { return -1; } -bool IsAffineIntAbs(const ConstraintProto& ct) { - if (ct.constraint_case() != ConstraintProto::kLinMax || - ct.lin_max().exprs_size() != 2 || ct.lin_max().target().vars_size() > 1 || - ct.lin_max().exprs(0).vars_size() != 1 || - ct.lin_max().exprs(1).vars_size() != 1) { - return false; - } - - const LinearArgumentProto& lin_max = ct.lin_max(); - if (lin_max.exprs(0).offset() != -lin_max.exprs(1).offset()) return false; - if (PositiveRef(lin_max.exprs(0).vars(0)) != - PositiveRef(lin_max.exprs(1).vars(0))) { - return false; - } - - const int64_t left_coeff = RefIsPositive(lin_max.exprs(0).vars(0)) - ? lin_max.exprs(0).coeffs(0) - : -lin_max.exprs(0).coeffs(0); - const int64_t right_coeff = RefIsPositive(lin_max.exprs(1).vars(0)) - ? lin_max.exprs(1).coeffs(0) - : -lin_max.exprs(1).coeffs(0); - return left_coeff == -right_coeff; -} - } // namespace bool CpModelPresolver::PropagateAndReduceAffineMax(ConstraintProto* ct) { @@ -12293,9 +12259,9 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( for (int x = 0; x < context_->working_model->variables().size(); ++x) { // We pick a variable x that appear in some AMO. + if (helper->NumAmoForVariable(x) == 0) continue; if (time_limit_->LimitReached()) break; if (timer.WorkLimitIsReached()) break; - if (helper->NumAmoForVariable(x) == 0) continue; amo_cts.clear(); timer.TrackSimpleLoop(context_->VarToConstraints(x).size()); @@ -13363,121 +13329,6 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { } } -// If we have a bunch of constraint of the form literal => Y \in domain and -// another constraint Y = f(X), we can remove Y, that constraint, and transform -// all linear1 from constraining Y to constraining X. -// -// We can for instance do it for Y = abs(X) or Y = X^2 easily. More complex -// function might be trickier. -// -// Note that we can't always do it in the reverse direction though! -// If we have l => X = -1, we can't transfer that to abs(X) for instance, since -// X=1 will also map to abs(-1). We can only do it if for all implied domain D -// we have f^-1(f(D)) = D, which is not easy to check. -void CpModelPresolver::MaybeTransferLinear1ToAnotherVariable(int var) { - // Find the extra constraint and do basic CHECKs. - int other_c; - int num_others = 0; - std::vector to_rewrite; - for (const int c : context_->VarToConstraints(var)) { - if (c >= 0) { - const ConstraintProto& ct = context_->working_model->constraints(c); - if (ct.constraint_case() == ConstraintProto::kLinear && - ct.linear().vars().size() == 1) { - to_rewrite.push_back(c); - continue; - } - } - ++num_others; - other_c = c; - } - if (num_others != 1) return; - if (other_c < 0) return; - - // In general constraint with more than two variable can't be removed. - // Similarly for linear2 with non-fixed rhs as we would need to check the form - // of all implied domain. - const auto& other_ct = context_->working_model->constraints(other_c); - if (context_->ConstraintToVars(other_c).size() != 2 || - !other_ct.enforcement_literal().empty() || - other_ct.constraint_case() == ConstraintProto::kLinear) { - return; - } - - // This will be the rewriting function. It takes the implied domain of var - // from linear1, and return a pair {new_var, new_var_implied_domain}. - std::function(const Domain& implied)> transfer_f = - nullptr; - - // We only support a few cases. - // - // TODO(user): implement more! Note that the linear2 case was tempting, but if - // we don't have an equality, we can't transfer, and if we do, we actually - // have affine equivalence already. - if (other_ct.constraint_case() == ConstraintProto::kLinMax && - other_ct.lin_max().target().vars().size() == 1 && - other_ct.lin_max().target().vars(0) == var && - std::abs(other_ct.lin_max().target().coeffs(0)) == 1 && - IsAffineIntAbs(other_ct)) { - context_->UpdateRuleStats("linear1: transferred from abs(X) to X"); - const LinearExpressionProto& target = other_ct.lin_max().target(); - const LinearExpressionProto& expr = other_ct.lin_max().exprs(0); - transfer_f = [target = target, expr = expr](const Domain& implied) { - Domain target_domain = - implied.ContinuousMultiplicationBy(target.coeffs(0)) - .AdditionWith(Domain(target.offset())); - target_domain = target_domain.IntersectionWith( - Domain(0, std::numeric_limits::max())); - - // We have target = abs(expr). - const Domain expr_domain = - target_domain.UnionWith(target_domain.Negation()); - const Domain new_domain = expr_domain.AdditionWith(Domain(-expr.offset())) - .InverseMultiplicationBy(expr.coeffs(0)); - return std::make_pair(expr.vars(0), new_domain); - }; - } - - if (transfer_f == nullptr) { - context_->UpdateRuleStats( - "TODO linear1: appear in only one extra 2-var constraint"); - return; - } - - // Applies transfer_f to all linear1. - std::sort(to_rewrite.begin(), to_rewrite.end()); - const Domain var_domain = context_->DomainOf(var); - for (const int c : to_rewrite) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); - if (ct->linear().vars(0) != var || ct->linear().coeffs(0) != 1) { - // This shouldn't happen. - LOG(INFO) << "Aborted in MaybeTransferLinear1ToAnotherVariable()"; - return; - } - - const Domain implied = - var_domain.IntersectionWith(ReadDomainFromProto(ct->linear())); - auto [new_var, new_domain] = transfer_f(implied); - const Domain current = context_->DomainOf(new_var); - new_domain = new_domain.IntersectionWith(current); - if (new_domain.IsEmpty()) { - if (!MarkConstraintAsFalse(ct, "linear1: unsat transfer")) return; - } else if (new_domain == current) { - ct->Clear(); - } else { - ct->mutable_linear()->set_vars(0, new_var); - FillDomainInProto(new_domain, ct->mutable_linear()); - } - context_->UpdateConstraintVariableUsage(c); - } - - // Copy other_ct to the mapping model and delete var! - context_->NewMappingConstraint(other_ct, __FILE__, __LINE__); - context_->working_model->mutable_constraints(other_c)->Clear(); - context_->UpdateConstraintVariableUsage(other_c); - context_->MarkVariableAsRemoved(var); -} - // TODO(user): We can still remove the variable even if we want to keep // all feasible solutions for the cases when we have a full encoding. // Similarly this shouldn't break symmetry, but we do need to do it for all @@ -13499,13 +13350,46 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { return; } - if (!context_->VariableIsOnlyUsedInEncodingAndMaybeInObjective(var)) { - if (context_->VariableIsOnlyUsedInLinear1AndOneExtraConstraint(var)) { - MaybeTransferLinear1ToAnotherVariable(var); - return; + const bool is_only_used_in_linear1 = + context_->VariableIsOnlyUsedInLinear1AndOneExtraConstraint(var); + const bool is_only_used_in_encoding = + context_->VariableIsOnlyUsedInEncodingAndMaybeInObjective(var); + if (!is_only_used_in_encoding && is_only_used_in_linear1) { + VariableEncodingLocalModel local_model; + local_model.var = var; + local_model.single_constraint_using_the_var_outside_the_local_model = -1; + local_model.var_in_more_than_one_constraint_outside_the_local_model = false; + for (const int c : context_->VarToConstraints(var)) { + if (c >= 0) { + const ConstraintProto& ct = context_->working_model->constraints(c); + if (ct.constraint_case() == ConstraintProto::kLinear && + ct.linear().vars().size() == 1 && ct.linear().vars(0) == var) { + local_model.linear1_constraints.push_back(c); + continue; + } + } + if (c == kObjectiveConstraint) { + local_model.variable_coeff_in_objective = + context_->ObjectiveMap().at(var); + } else if ( + local_model.single_constraint_using_the_var_outside_the_local_model == + -1 && + c >= 0) { + // First "other" constraint. + local_model.single_constraint_using_the_var_outside_the_local_model = c; + } else { + // We have a second "other" constraint. + local_model.single_constraint_using_the_var_outside_the_local_model = + -1; + local_model.var_in_more_than_one_constraint_outside_the_local_model = + true; + } } + + MaybeTransferLinear1ToAnotherVariable(local_model, context_); return; } + if (!is_only_used_in_encoding) return; // Presolve newly created constraints. const int old_size = context_->working_model->constraints_size(); @@ -13643,18 +13527,19 @@ bool CpModelPresolver::ProcessChangedVariables(std::vector* in_queue, if (!context_->CanonicalizeOneObjectiveVariable(v)) return false; in_queue->resize(context_->working_model->constraints_size(), false); + const int size_before = queue->size(); for (const int c : context_->VarToConstraints(v)) { if (c >= 0 && !(*in_queue)[c]) { (*in_queue)[c] = true; queue->push_back(c); } } + + // Make sure the order is deterministic! because var_to_constraints[] + // order changes from one run to the next. + std::sort(queue->begin() + size_before, queue->end()); } context_->modified_domains.ResetAllToFalse(); - - // Make sure the order is deterministic! because var_to_constraints[] - // order changes from one run to the next. - std::sort(queue->begin(), queue->end()); return !queue->empty(); } @@ -13871,47 +13756,58 @@ void CpModelPresolver::PresolveToFixPoint() { // TODO(user): ideally we should "wake-up" any constraint that contains an // absent interval in the main propagation loop above. But we currently don't // maintain such list. - const int num_constraints = context_->working_model->constraints_size(); - for (int c = 0; c < num_constraints; ++c) { - if (time_limit_->LimitReached()) break; - ConstraintProto* ct = context_->working_model->mutable_constraints(c); - switch (ct->constraint_case()) { - case ConstraintProto::kNoOverlap: - // Filter out absent intervals. - if (PresolveNoOverlap(ct)) { - context_->UpdateConstraintVariableUsage(c); - } - break; - case ConstraintProto::kNoOverlap2D: - // Filter out absent intervals. - if (PresolveNoOverlap2D(c, ct)) { - context_->UpdateConstraintVariableUsage(c); - } - break; - case ConstraintProto::kCumulative: - // Filter out absent intervals. - if (PresolveCumulative(ct)) { - context_->UpdateConstraintVariableUsage(c); - } - break; - case ConstraintProto::kBoolOr: { - // Try to infer domain reductions from clauses and the saved "implies in - // domain" relations. - for (const auto& pair : - context_->deductions.ProcessClause(ct->bool_or().literals())) { - bool modified = false; - if (!context_->IntersectDomainWith(pair.first, pair.second, - &modified)) { - return; + if (!time_limit_->LimitReached()) { + const int num_constraints = context_->working_model->constraints_size(); + TimeLimitCheckEveryNCalls bool_or_check_time_limit(100, time_limit_); + for (int c = 0; c < num_constraints; ++c) { + ConstraintProto* ct = context_->working_model->mutable_constraints(c); + // We don't want to check the time limit at each "small" constraint as + // there could be many. + bool check_time_limit = false; + + switch (ct->constraint_case()) { + case ConstraintProto::kNoOverlap: + // Filter out absent intervals. + if (PresolveNoOverlap(ct)) { + context_->UpdateConstraintVariableUsage(c); } - if (modified) { - context_->UpdateRuleStats("deductions: reduced variable domain"); + check_time_limit = true; + break; + case ConstraintProto::kNoOverlap2D: + // Filter out absent intervals. + if (PresolveNoOverlap2D(c, ct)) { + context_->UpdateConstraintVariableUsage(c); } + check_time_limit = true; + break; + case ConstraintProto::kCumulative: + // Filter out absent intervals. + if (PresolveCumulative(ct)) { + context_->UpdateConstraintVariableUsage(c); + } + check_time_limit = true; + break; + case ConstraintProto::kBoolOr: { + // Try to infer domain reductions from clauses and the saved "implies + // in domain" relations. + for (const auto& pair : + context_->deductions.ProcessClause(ct->bool_or().literals())) { + bool modified = false; + if (!context_->IntersectDomainWith(pair.first, pair.second, + &modified)) { + return; + } + if (modified) { + context_->UpdateRuleStats("deductions: reduced variable domain"); + } + } + if (bool_or_check_time_limit.LimitReached()) check_time_limit = true; + break; } - break; + default: + break; } - default: - break; + if (check_time_limit && time_limit_->LimitReached()) break; } } diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index f17aea0bc4..b6068867a2 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -335,7 +335,6 @@ class CpModelPresolver { // merge this with what ExpandObjective() is doing. void ShiftObjectiveWithExactlyOnes(); - void MaybeTransferLinear1ToAnotherVariable(int var); void ProcessVariableOnlyUsedInEncoding(int var); void TryToSimplifyDomain(int var); diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc index 1d925ea1c2..8725d300d6 100644 --- a/ortools/sat/cp_model_utils.cc +++ b/ortools/sat/cp_model_utils.cc @@ -1152,5 +1152,29 @@ int CombineSeed(int base_seed, int64_t delta) { return static_cast(FingerprintSingleField(base_seed, fp) & (0x7FFFFFFF)); } +bool IsAffineIntAbs(const ConstraintProto& ct) { + if (ct.constraint_case() != ConstraintProto::kLinMax || + ct.lin_max().exprs_size() != 2 || ct.lin_max().target().vars_size() > 1 || + ct.lin_max().exprs(0).vars_size() != 1 || + ct.lin_max().exprs(1).vars_size() != 1) { + return false; + } + + const LinearArgumentProto& lin_max = ct.lin_max(); + if (lin_max.exprs(0).offset() != -lin_max.exprs(1).offset()) return false; + if (PositiveRef(lin_max.exprs(0).vars(0)) != + PositiveRef(lin_max.exprs(1).vars(0))) { + return false; + } + + const int64_t left_coeff = RefIsPositive(lin_max.exprs(0).vars(0)) + ? lin_max.exprs(0).coeffs(0) + : -lin_max.exprs(0).coeffs(0); + const int64_t right_coeff = RefIsPositive(lin_max.exprs(1).vars(0)) + ? lin_max.exprs(1).coeffs(0) + : -lin_max.exprs(1).coeffs(0); + return left_coeff == -right_coeff; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model_utils.h b/ortools/sat/cp_model_utils.h index 8d8979677f..87eba2eaf0 100644 --- a/ortools/sat/cp_model_utils.h +++ b/ortools/sat/cp_model_utils.h @@ -289,6 +289,9 @@ bool SafeAddLinearExpressionToLinearConstraint( const LinearExpressionProto& expr, int64_t coefficient, LinearConstraintProto* linear); +// Returns if a constraint is of the form y = lin_max(x, -x). +bool IsAffineIntAbs(const ConstraintProto& ct); + // Returns true iff a == b * b_scaling. bool LinearExpressionProtosAreEqual(const LinearExpressionProto& a, const LinearExpressionProto& b, diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 19bf5ffcac..3787156be4 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -633,6 +633,22 @@ bool PresolveContext::ConstraintIsInactive(int index) const { return false; } +bool PresolveContext::MarkConstraintAsFalse(ConstraintProto* ct, + std::string_view reason) { + DCHECK(!reason.empty()); + if (!HasEnforcementLiteral(*ct)) { + return NotifyThatModelIsUnsat(reason); + } + // Change the constraint to a bool_or. + ct->mutable_bool_or()->clear_literals(); + for (const int lit : ct->enforcement_literal()) { + ct->mutable_bool_or()->add_literals(NegatedRef(lit)); + } + ct->clear_enforcement_literal(); + UpdateRuleStats(reason); + return true; +} + bool PresolveContext::ConstraintIsOptional(int ct_ref) const { const ConstraintProto& ct = working_model->constraints(ct_ref); bool contains_one_free_literal = false; diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index f1348b2029..13fd7063bb 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -616,6 +616,10 @@ class PresolveContext { return interval_usage_[c]; } + // Note this function does not update the constraint graph. It assumes this is + // done elsewhere. + bool MarkConstraintAsFalse(ConstraintProto* ct, std::string_view reason); + // Checks if a constraint contains an enforcement literal set to false, // or if it has been cleared. bool ConstraintIsInactive(int ct_index) const; diff --git a/ortools/sat/presolve_encoding.cc b/ortools/sat/presolve_encoding.cc new file mode 100644 index 0000000000..33398ff76b --- /dev/null +++ b/ortools/sat/presolve_encoding.cc @@ -0,0 +1,136 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/presolve_encoding.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/presolve_context.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { + +bool MaybeTransferLinear1ToAnotherVariable( + VariableEncodingLocalModel& local_model, PresolveContext* context) { + if (local_model.var == -1) return true; + if (local_model.variable_coeff_in_objective != 0) return true; + if (local_model.single_constraint_using_the_var_outside_the_local_model == + -1) { + return true; + } + const int other_c = + local_model.single_constraint_using_the_var_outside_the_local_model; + + const std::vector& to_rewrite = local_model.linear1_constraints; + + // In general constraint with more than two variable can't be removed. + // Similarly for linear2 with non-fixed rhs as we would need to check the form + // of all implied domain. + const auto& other_ct = context->working_model->constraints(other_c); + if (context->ConstraintToVars(other_c).size() != 2 || + !other_ct.enforcement_literal().empty() || + other_ct.constraint_case() == ConstraintProto::kLinear) { + return true; + } + + // This will be the rewriting function. It takes the implied domain of var + // from linear1, and return a pair {new_var, new_var_implied_domain}. + std::function(const Domain& implied)> transfer_f = + nullptr; + + const int var = local_model.var; + // We only support a few cases. + // + // TODO(user): implement more! Note that the linear2 case was tempting, but if + // we don't have an equality, we can't transfer, and if we do, we actually + // have affine equivalence already. + if (other_ct.constraint_case() == ConstraintProto::kLinMax && + other_ct.lin_max().target().vars().size() == 1 && + other_ct.lin_max().target().vars(0) == var && + std::abs(other_ct.lin_max().target().coeffs(0)) == 1 && + IsAffineIntAbs(other_ct)) { + context->UpdateRuleStats("linear1: transferred from abs(X) to X"); + const LinearExpressionProto& target = other_ct.lin_max().target(); + const LinearExpressionProto& expr = other_ct.lin_max().exprs(0); + transfer_f = [target = target, expr = expr](const Domain& implied) { + Domain target_domain = + implied.ContinuousMultiplicationBy(target.coeffs(0)) + .AdditionWith(Domain(target.offset())); + target_domain = target_domain.IntersectionWith( + Domain(0, std::numeric_limits::max())); + + // We have target = abs(expr). + const Domain expr_domain = + target_domain.UnionWith(target_domain.Negation()); + const Domain new_domain = expr_domain.AdditionWith(Domain(-expr.offset())) + .InverseMultiplicationBy(expr.coeffs(0)); + return std::make_pair(expr.vars(0), new_domain); + }; + } + + if (transfer_f == nullptr) { + context->UpdateRuleStats( + "TODO linear1: appear in only one extra 2-var constraint"); + return true; + } + + // Applies transfer_f to all linear1. + const Domain var_domain = context->DomainOf(var); + for (const int c : to_rewrite) { + ConstraintProto* ct = context->working_model->mutable_constraints(c); + if (ct->linear().vars(0) != var || ct->linear().coeffs(0) != 1) { + // This shouldn't happen. + LOG(INFO) << "Aborted in MaybeTransferLinear1ToAnotherVariable()"; + return true; + } + + const Domain implied = + var_domain.IntersectionWith(ReadDomainFromProto(ct->linear())); + auto [new_var, new_domain] = transfer_f(implied); + const Domain current = context->DomainOf(new_var); + new_domain = new_domain.IntersectionWith(current); + if (new_domain.IsEmpty()) { + if (!context->MarkConstraintAsFalse(ct, "linear1: unsat transfer")) { + return false; + } + } else if (new_domain == current) { + // Note that we don't need to remove this constraint from + // local_model.linear1_constraints since we will set + // local_model.var = -1 below. + ct->Clear(); + } else { + ct->mutable_linear()->set_vars(0, new_var); + FillDomainInProto(new_domain, ct->mutable_linear()); + } + context->UpdateConstraintVariableUsage(c); + } + + // Copy other_ct to the mapping model and delete var! + context->NewMappingConstraint(other_ct, __FILE__, __LINE__); + context->working_model->mutable_constraints(other_c)->Clear(); + context->UpdateConstraintVariableUsage(other_c); + context->MarkVariableAsRemoved(var); + local_model.var = -1; + return true; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/presolve_encoding.h b/ortools/sat/presolve_encoding.h new file mode 100644 index 0000000000..6dad2318bd --- /dev/null +++ b/ortools/sat/presolve_encoding.h @@ -0,0 +1,65 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ORTOOLS_SAT_PRESOLVE_ENCODING_H_ +#define ORTOOLS_SAT_PRESOLVE_ENCODING_H_ + +#include +#include + +#include "ortools/sat/presolve_context.h" + +namespace operations_research { +namespace sat { + +struct VariableEncodingLocalModel { + // The integer variable that is encoded. Internally it can be replaced by + // -1 if some presolve rule removed the variable. + int var; + + // The linear1 constraint indexes that define conditional bounds on the + // variable. Those linear1 should have exactly one enforcement literal and + // satisfy `PositiveRef(enf) != var`. All linear1 restraining `var` and + // fulfilling the conditions above will appear here. + std::vector linear1_constraints; + + // Zero if `var` doesn't appear in the objective. + int64_t variable_coeff_in_objective = 0; + + // Note: the objective doesn't count as a constraint outside the local model. + bool var_in_more_than_one_constraint_outside_the_local_model; + + // Set to -1 if there is none or if the variable appears in more than one + // constraint outside the local model. + int single_constraint_using_the_var_outside_the_local_model = -1; +}; + +// If we have a bunch of constraint of the form literal => Y \in domain and +// another constraint Y = f(X), we can remove Y, that constraint, and transform +// all linear1 from constraining Y to constraining X. +// +// We can for instance do it for Y = abs(X) or Y = X^2 easily. More complex +// function might be trickier. +// +// Note that we can't always do it in the reverse direction though! +// If we have l => X = -1, we can't transfer that to abs(X) for instance, since +// X=1 will also map to abs(-1). We can only do it if for all implied domain D +// we have f^-1(f(D)) = D, which is not easy to check. +// Returns false if we prove unsat. +bool MaybeTransferLinear1ToAnotherVariable( + VariableEncodingLocalModel& local_model, PresolveContext* context); + +} // namespace sat +} // namespace operations_research + +#endif // ORTOOLS_SAT_PRESOLVE_ENCODING_H_ diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index 051be3dd50..bc791468b8 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -142,6 +142,22 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +inline std::ostream& operator<<(std::ostream& os, + absl::Span literals) { + os << "["; + bool first = true; + for (const LiteralIndex index : literals) { + if (first) { + first = false; + } else { + os << ","; + } + os << Literal(index).DebugString(); + } + os << "]"; + return os; +} + // Only used for testing to use the classical SAT notation for a literal. This // allows to write Literals({+1, -4, +3}) for the clause with BooleanVariable 0 // and 2 appearing positively and 3 negatively. diff --git a/ortools/sat/sat_inprocessing.cc b/ortools/sat/sat_inprocessing.cc index 69769ce84b..1baddcf50c 100644 --- a/ortools/sat/sat_inprocessing.cc +++ b/ortools/sat/sat_inprocessing.cc @@ -2017,11 +2017,18 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables( // been cleaned up yet, as these are needed to really recover all gates. // // TODO(user): Ideally the detection code should be robust to that. + // TODO(user): Maybe we should always have an hash-map of binary up to date? int num_fn1 = 0; std::vector> binary_used; for (LiteralIndex a(0); a < implication_graph_->literal_size(); ++a) { + // TODO(user): If we know we have too many implications for the time limit + // We should just be better of not doing that loop at all. + if (timer.WorkLimitIsReached()) break; if (implication_graph_->IsRedundant(Literal(a))) continue; - for (const Literal b : implication_graph_->Implications(Literal(a))) { + const absl::Span implied = + implication_graph_->Implications(Literal(a)); + timer.TrackHashLookups(implied.size()); + for (const Literal b : implied) { if (implication_graph_->IsRedundant(b)) continue; std::array key2; @@ -2066,9 +2073,7 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables( // The AND gate of size 3 should be detected by the short table code, no // need to do the algo here which should be slower. - // - // TODO(user): This seems to be less strong. I think we have some bug - // in our fixed point loop when we fix variables. + continue; } else if (clause->size() == 4) { AddToTruthTable<4>(clause, ids4_); } else if (clause->size() == 5) { @@ -2867,6 +2872,7 @@ class LratGateCongruenceHelper { implication_graph_->GetClauseId(target.Negated(), Literal(m_index))); Append(clause_ids, GetLiteralImpliesRepresentativeClauseId(Literal(m_index))); + Append(clause_ids, GetLiteralImpliesRepresentativeClauseId(target)); } private: @@ -2943,7 +2949,8 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { PresolveTimer timer("GateCongruenceClosure", logger_, time_limit_); timer.OverrideLogging(log_info); - const int num_literals(sat_solver_->NumVariables() * 2); + const int num_variables(sat_solver_->NumVariables()); + const int num_literals(num_variables * 2); marked_.ClearAndResize(Literal(num_literals)); seen_.ClearAndResize(Literal(num_literals)); next_seen_.ClearAndResize(Literal(num_literals)); @@ -2955,7 +2962,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { // Lets release the memory on exit. CHECK(tmp_binary_clauses_.empty()); - absl::Cleanup cleanup = [this] { tmp_binary_clauses_.clear(); }; + absl::Cleanup binary_cleanup = [this] { tmp_binary_clauses_.clear(); }; ExtractAndGatesAndFillShortTruthTables(timer); ExtractShortGates(timer); @@ -2985,37 +2992,67 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { // Tricky: we need to resize this to num_literals because the union_find that // merges target can choose for a representative a literal that is not in the // set of gate inputs. - MergeableOccurrenceList input_literals_to_gate; - input_literals_to_gate.ResetFromTranspose(gates_inputs_, num_literals); + MergeableOccurrenceList input_var_to_gate; + struct GetVarMapper { + BooleanVariable operator()(LiteralIndex l) const { + return Literal(l).Variable(); + } + }; + input_var_to_gate.ResetFromTransposeMap(gates_inputs_, + num_variables); LratGateCongruenceHelper lrat_helper( trail_, implication_graph_, clause_manager_, clause_id_generator_, lrat_proof_handler_, gates_target_, gates_clauses_, union_find); + // Stats + make sure we run it at exit. + int num_units = 0; + int num_equivalences = 0; + int num_processed = 0; + int arity1_equivalences = 0; + absl::Cleanup stat_cleanup = [&] { + total_wtime_ += timer.wtime(); + total_dtime_ += timer.deterministic_time(); + total_equivalences_ += num_equivalences; + total_num_units_ += num_units; + timer.AddCounter("processed", num_processed); + timer.AddCounter("units", num_units); + timer.AddCounter("f1_equiv", arity1_equivalences); + timer.AddCounter("equiv", num_equivalences); + }; + // Starts with all gates in the queue. const int num_gates = gates_inputs_.size(); + total_gates_ += num_gates; std::vector in_queue(num_gates, true); std::vector queue(num_gates); for (GateId id(0); id < num_gates; ++id) queue[id.value()] = id; - int num_units = 0; + int num_processed_fixed_variables = trail_->Index(); + const auto fix_literal = [&, this](Literal to_fix, absl::Span clause_ids) { + DCHECK_EQ(to_fix, lrat_helper.GetRepresentativeWithProofSupport(to_fix)); if (assignment_.LiteralIsTrue(to_fix)) return true; if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) { return false; } + // This is quite tricky: as we fix a literal, we propagate right away + // everything implied by it in the binary implication graph. So we need to + // loop over all newly_fixed variable in order to properly reach the fix + // point! ++num_units; - for (const GateId gate_id : input_literals_to_gate[to_fix]) { - if (in_queue[gate_id.value()]) continue; - queue.push_back(gate_id); - in_queue[gate_id.value()] = true; - } - for (const GateId gate_id : input_literals_to_gate[to_fix.Negated()]) { - if (in_queue[gate_id.value()]) continue; - queue.push_back(gate_id); - in_queue[gate_id.value()] = true; + for (; num_processed_fixed_variables < trail_->Index(); + ++num_processed_fixed_variables) { + const Literal to_update = lrat_helper.GetRepresentativeWithProofSupport( + (*trail_)[num_processed_fixed_variables]); + for (const GateId gate_id : input_var_to_gate[to_update.Variable()]) { + if (in_queue[gate_id.value()]) continue; + queue.push_back(gate_id); + in_queue[gate_id.value()] = true; + } + input_var_to_gate.ClearList(to_update.Variable()); } return true; }; @@ -3025,7 +3062,6 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { return trail_->GetUnitClauseId(a.Variable()); }; - int num_equivalences = 0; const auto new_equivalence = [&, this](Literal a, Literal b, ClauseId a_implies_b, ClauseId b_implies_a) { @@ -3052,6 +3088,8 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { return false; } + BooleanVariable to_merge_var = kNoBooleanVariable; + BooleanVariable rep_var = kNoBooleanVariable; for (const bool negate : {false, true}) { const LiteralIndex x = negate ? a.NegatedIndex() : a.Index(); const LiteralIndex y = negate ? b.NegatedIndex() : b.Index(); @@ -3064,7 +3102,14 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { union_find.AddEdge(x.value(), y.value()); const LiteralIndex rep(union_find.FindRoot(y.value())); const LiteralIndex to_merge = rep == x ? y : x; - input_literals_to_gate.MergeInto(to_merge, rep); + if (to_merge_var == kNoBooleanVariable) { + to_merge_var = Literal(to_merge).Variable(); + rep_var = Literal(rep).Variable(); + } else { + // We should have the same var. + CHECK_EQ(to_merge_var, Literal(to_merge).Variable()); + CHECK_EQ(rep_var, Literal(rep).Variable()); + } if (lrat_proof_handler_ != nullptr) { if (rep == x) { @@ -3075,17 +3120,6 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { y_implies_x); } } - - // Re-add to the queue all gates with touched inputs. - // - // TODO(user): I think we could only add the gates of "to_merge" - // before we merge. This part of the code is quite quick in any - // case. - for (const GateId gate_id : input_literals_to_gate[rep]) { - if (in_queue[gate_id.value()]) continue; - queue.push_back(gate_id); - in_queue[gate_id.value()] = true; - } } // Invariant. @@ -3095,16 +3129,28 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { CHECK_EQ( lrat_helper.GetRepresentativeWithProofSupport(b), lrat_helper.GetRepresentativeWithProofSupport(b.Negated()).Negated()); + + // Re-add to the queue all gates with touched inputs. + // + // TODO(user): I think we could only add the gates of "to_merge" + // before we merge. This part of the code is quite quick in any + // case. + input_var_to_gate.MergeInto(to_merge_var, rep_var); + for (const GateId gate_id : input_var_to_gate[rep_var]) { + if (in_queue[gate_id.value()]) continue; + queue.push_back(gate_id); + in_queue[gate_id.value()] = true; + } + return true; }; // Main loop. - int num_processed = 0; - int arity1_equivalences = 0; while (!queue.empty()) { ++num_processed; const GateId id = queue.back(); queue.pop_back(); + CHECK(in_queue[id.value()]); in_queue[id.value()] = false; // Tricky: the hash-map might contain id not yet canonicalized. And in @@ -3140,17 +3186,15 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { CHECK_NE(id, other_id); CHECK_GE(other_id, 0); CHECK_EQ(gates_type_[id], gates_type_[other_id]); - CHECK_EQ(absl::Span(gates_inputs_[id]), - absl::Span(gates_inputs_[other_id])); + CHECK_EQ(gates_inputs_[id], gates_inputs_[other_id]); - input_literals_to_gate.RemoveFromFutureOutput(id); + input_var_to_gate.RemoveFromFutureOutput(id); // We detected a <=> b (or, equivalently, rep(a) <=> rep(b)). const Literal a(gates_target_[id]); const Literal b(gates_target_[other_id]); const Literal rep_a = lrat_helper.GetRepresentativeWithProofSupport(a); const Literal rep_b = lrat_helper.GetRepresentativeWithProofSupport(b); - if (rep_a != rep_b) { ClauseId rep_a_implies_rep_b = kNoClauseId; ClauseId rep_b_implies_rep_a = kNoClauseId; @@ -3200,9 +3244,11 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { // then target must be false. if (marked_[Literal(rep).Negated()]) { is_unit = true; - input_literals_to_gate.RemoveFromFutureOutput(id); + input_var_to_gate.RemoveFromFutureOutput(id); - const Literal to_fix = Literal(gates_target_[id]).Negated(); + const Literal initial_to_fix = Literal(gates_target_[id]).Negated(); + const Literal to_fix = + lrat_helper.GetRepresentativeWithProofSupport(initial_to_fix); if (!assignment_.LiteralIsTrue(to_fix)) { absl::InlinedVector clause_ids; if (lrat_proof_handler_ != nullptr) { @@ -3249,10 +3295,9 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { // Generic "short" gates. // We just take the representative and re-canonicalize. - absl::Span inputs = gates_inputs_[id]; DCHECK_GE(gates_type_[id], 0); - DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0); - for (LiteralIndex& lit_ref : inputs) { + DCHECK_EQ(gates_type_[id] >> (1 << (gates_inputs_[id].size())), 0); + for (LiteralIndex& lit_ref : gates_inputs_[id]) { lit_ref = lrat_helper.GetRepresentativeWithProofSupport(Literal(lit_ref)) .Index(); @@ -3261,7 +3306,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { const int new_size = CanonicalizeShortGate(id); if (new_size == 1) { // We have a function of size 1! This is an equivalence. - input_literals_to_gate.RemoveFromFutureOutput(id); + input_var_to_gate.RemoveFromFutureOutput(id); const Literal a = Literal(gates_target_[id]); const Literal b = Literal(gates_inputs_[id][0]); const Literal rep_a = lrat_helper.GetRepresentativeWithProofSupport(a); @@ -3277,7 +3322,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { break; } else if (new_size == 0) { // We have a fixed function! Just fix the literal. - input_literals_to_gate.RemoveFromFutureOutput(id); + input_var_to_gate.RemoveFromFutureOutput(id); const Literal initial_to_fix = (gates_type_[id] & 1) == 1 ? Literal(gates_target_[id]) : Literal(gates_target_[id]).Negated(); @@ -3293,16 +3338,44 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { } } - total_wtime_ += timer.wtime(); - total_dtime_ += timer.deterministic_time(); - total_gates_ += num_gates; - total_equivalences_ += num_equivalences; - total_num_units_ += num_units; + // DEBUG check that we reached the fix point correctly. + if (DEBUG_MODE) { + CHECK(queue.empty()); + gate_set.clear(); + for (GateId id(0); id < num_gates; ++id) { + if (gates_type_[id] == kAndGateType) continue; + if (assignment_.LiteralIsAssigned(Literal(gates_target_[id]))) continue; + + const int new_size = CanonicalizeShortGate(id); + if (new_size == 0) { + CHECK_EQ(gates_type_[id] & 1, 0); + const Literal initial_to_fix = Literal(gates_target_[id]).Negated(); + const Literal to_fix = + lrat_helper.GetRepresentativeWithProofSupport(initial_to_fix); + CHECK(assignment_.LiteralIsTrue(to_fix)); + } else if (new_size == 1) { + CHECK(!assignment_.LiteralIsAssigned(Literal(gates_target_[id]))); + CHECK(!assignment_.LiteralIsAssigned(Literal(gates_inputs_[id][0]))); + CHECK_EQ(lrat_helper.GetRepresentativeWithProofSupport( + Literal(gates_target_[id])), + lrat_helper.GetRepresentativeWithProofSupport( + Literal(gates_inputs_[id][0]))) + << id << " "; + } else { + const auto [it, inserted] = gate_set.insert(id); + if (!inserted) { + const GateId other_id = *it; + CHECK_EQ(lrat_helper.GetRepresentativeWithProofSupport( + Literal(gates_target_[id])), + lrat_helper.GetRepresentativeWithProofSupport( + Literal(gates_target_[other_id]))) + << id << " " << gates_inputs_[id] << " " << other_id << " " + << gates_inputs_[other_id]; + } + } + } + } - timer.AddCounter("arity1_equivalences", arity1_equivalences); - timer.AddCounter("units", num_units); - timer.AddCounter("processed", num_processed); - timer.AddCounter("equivalences", num_equivalences); return true; } diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 98786296d0..4077df4ddb 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -193,11 +193,14 @@ class MergeableOccurrenceList { public: MergeableOccurrenceList() = default; - void ResetFromTranspose(const CompactVectorVector& input, - int min_transpose_size = 0) { - rows_.ResetFromTranspose(input, min_transpose_size); + template + void ResetFromTransposeMap(const Container& input, + int min_transpose_size = 0) { + rows_.template ResetFromTransposeMap(input, + min_transpose_size); next_.assign(rows_.size(), K(-1)); marked_.ClearAndResize(V(input.size())); + merged_.ClearAndResize(K(rows_.size())); } int size() const { return rows_.size(); } @@ -212,6 +215,7 @@ class MergeableOccurrenceList { // This is not const because it lazily merges lists. absl::Span operator[](K key) { if (key >= rows_.size()) return {}; + CHECK(!merged_[key]); tmp_result_.clear(); K previous(-1); @@ -247,9 +251,13 @@ class MergeableOccurrenceList { // // And otherwise key should never be accessed anymore. void MergeInto(K to_merge, K representative) { + CHECK(!merged_[to_merge]); + DCHECK_GE(to_merge, 0); + DCHECK_GE(representative, 0); DCHECK_LT(to_merge, rows_.size()); DCHECK_LT(representative, rows_.size()); if (to_merge == representative) return; + merged_.Set(to_merge); // Find the end of the representative list to happen to_merge there. // @@ -259,10 +267,16 @@ class MergeableOccurrenceList { K last_list = representative; while (next_[InternalKey(last_list)] >= 0) { last_list = next_[InternalKey(last_list)]; + DCHECK_NE(last_list, to_merge); } next_[InternalKey(last_list)] = to_merge; } + void ClearList(K key) { + next_[InternalKey(key)] = -1; + rows_.Shrink(key, 0); + } + private: // Convert int and StrongInt to normal int. int InternalKey(K key) const; @@ -271,6 +285,7 @@ class MergeableOccurrenceList { // The bitset is used to remove duplicates when merging lists. std::vector tmp_result_; Bitset64 marked_; + Bitset64 merged_; // Each "row" contains a set of values (we lazily remove duplicate). CompactVectorVector rows_; diff --git a/ortools/sat/util_test.cc b/ortools/sat/util_test.cc index 19ac63aa37..a0c076012d 100644 --- a/ortools/sat/util_test.cc +++ b/ortools/sat/util_test.cc @@ -172,7 +172,10 @@ TEST(MergeableOccurrenceList, BasicTest) { storage.ResetFromFlatMapping(keys, vals); MergeableOccurrenceList occ; - occ.ResetFromTranspose(storage); + struct GetVarMapper { + int operator()(int i) const { return i; } + }; + occ.ResetFromTransposeMap(storage); // The first access should be ordered. EXPECT_THAT(occ.size(), 6); diff --git a/ortools/third_party_solvers/BUILD.bazel b/ortools/third_party_solvers/BUILD.bazel index 188f5b4deb..2fbd496623 100644 --- a/ortools/third_party_solvers/BUILD.bazel +++ b/ortools/third_party_solvers/BUILD.bazel @@ -48,7 +48,6 @@ cc_library( hdrs = ["xpress_environment.h"], deps = [ ":dynamic_library", - "//ortools/base", "//ortools/base:base_export", "//ortools/base:status_builder", "@abseil-cpp//absl/base", diff --git a/ortools/util/logging.h b/ortools/util/logging.h index 4907d1fdb7..4516a3628e 100644 --- a/ortools/util/logging.h +++ b/ortools/util/logging.h @@ -135,6 +135,7 @@ class PresolveTimer { // By default we want a limit of around 1 deterministic seconds. void AddToWork(double dtime) { work_ += dtime; } void TrackSimpleLoop(int size) { work_ += 5e-9 * size; } + void TrackHashLookups(int size) { work_ += 5e-8 * size; } void TrackFastLoop(int size) { work_ += 1e-9 * size; } bool WorkLimitIsReached() const { return work_ >= 1.0; }