[CP-SAT] bugfixes

This commit is contained in:
Laurent Perron
2025-12-15 13:42:37 +01:00
committed by Corentin Le Molgat
parent c0b5917c07
commit 4dab47eaa6
8 changed files with 131 additions and 92 deletions

View File

@@ -14634,6 +14634,16 @@ CpSolverStatus CpModelPresolver::Presolve() {
// Sync the domains and initialize the mapping model variables.
context_->WriteVariableDomainsToProto();
// Some vars may have been fixed by the affine relations. This may can impact
// the objective. Let's re-do the canonicalization.
if (context_->working_model->has_objective()) {
// We re-do a canonicalization with the final linear expression.
if (!context_->CanonicalizeObjective()) return InfeasibleStatus();
context_->WriteObjectiveToProto();
}
// Starts the postsolve mapping model.
InitializeMappingModelVariables(context_->AllDomains(),
&fixed_postsolve_mapping,
context_->mapping_model);
@@ -14711,12 +14721,6 @@ CpSolverStatus CpModelPresolver::Presolve() {
*postsolve_mapping_ = std::move(new_postsolve_mapping);
}
if (context_->working_model->has_objective()) {
// We re-do a canonicalization with the final linear expression.
if (!context_->CanonicalizeObjective()) return InfeasibleStatus();
context_->WriteObjectiveToProto();
}
DCHECK(context_->ConstraintVariableUsageIsConsistent());
CanonicalizeRoutesConstraintNodeExpressions(context_);
UpdateHintInProto(context_);

View File

@@ -250,7 +250,6 @@ void DumpNoOverlap2dProblem(const ConstraintProto& ct,
std::vector<Rectangle> sizes_to_render;
IntegerValue x = bounding_box.x_min;
IntegerValue y = 0;
int i = 0;
for (const auto& r : non_fixed_boxes) {
sizes_to_render.push_back(Rectangle{
.x_min = x, .x_max = x + r.x_size, .y_min = y, .y_max = y + r.y_size});
@@ -259,7 +258,6 @@ void DumpNoOverlap2dProblem(const ConstraintProto& ct,
x = 0;
y += r.y_size;
}
++i;
}
VLOG(3) << "Sizes: " << RenderDot(bounding_box, sizes_to_render);
}

View File

@@ -165,19 +165,6 @@ class BestBoundCallback:
self.best_bound = bb
class BestBoundTimeCallback:
def __init__(self) -> None:
self.__last_time: float = 0.0
def new_best_bound(self, unused_bb: float):
self.__last_time = time.time()
@property
def last_time(self) -> float:
return self.__last_time
class CpModelTest(absltest.TestCase):
def tearDown(self) -> None:
@@ -264,7 +251,7 @@ class CpModelTest(absltest.TestCase):
self.assertEqual(nb.index, -b.index - 1)
self.assertRaises(TypeError, x.negated)
def test_issue_4654(self) -> None:
def test_issue4654(self) -> None:
model = cp_model.CpModel()
x = model.NewIntVar(0, 1, "x")
y = model.NewIntVar(0, 2, "y")
@@ -2457,18 +2444,15 @@ TRFM"""
# Solve.
solver = cp_model.CpSolver()
solver.parameters.num_workers = 8
solver.parameters.num_workers = 1
solver.parameters.max_time_in_seconds = 50
solver.parameters.log_search_progress = True
solution_callback = TimeRecorder()
best_bound_callback = BestBoundTimeCallback()
best_bound_callback = BestBoundCallback()
solver.best_bound_callback = best_bound_callback.new_best_bound
status = solver.Solve(model, solution_callback)
status = solver.Solve(model)
if status == cp_model.OPTIMAL:
last_activity = max(
best_bound_callback.last_time, solution_callback.last_time
)
self.assertLess(time.time(), last_activity + 30.0)
# Optimal is 28. The first bound found is 19.0.
self.assertGreaterEqual(best_bound_callback.best_bound, 19.0)
def test_issue4434(self) -> None:
model = cp_model.CpModel()

View File

@@ -1059,20 +1059,58 @@ void SatSolver::ProcessCurrentConflict(
Backtrack(backtrack_level);
DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_));
// Tricky: in case of propagation not at the right level we might need to
// backjump further.
for (const auto& [id, is_redundant, min_lbd, clause] : delayed_to_add_) {
// Add the conflict here, so we process all "newly learned" clause in the
// same way.
learned_clauses_.push_back({learned_conflict_clause_id, is_redundant,
min_lbd_of_subsumed_clauses,
std::move(learned_conflict_)});
// Preprocess the new clauses.
// We might need to backtrack further !
for (auto& [id, is_redundant, min_lbd, clause] : learned_clauses_) {
if (clause.empty()) return (void)SetModelUnsat();
// TODO(user): just remove redundant literal from learned clauses. This
// should just be better. We just have to deal with the proof correctly.
if (clause.size() == 2 &&
binary_implication_graph_->RepresentativeOf(clause[0]) ==
binary_implication_graph_->RepresentativeOf(clause[1])) {
Backtrack(0);
break;
// Make sure each clause is "canonicalized" with respect to equivalent
// literals.
//
// TODO(user): Maybe we should do that on each reason before we use them in
// conflict analysis/minimization, but it might be a bit costly.
bool some_change = false;
tmp_clause_ids_.clear();
for (Literal& lit : clause) {
const Literal rep = binary_implication_graph_->RepresentativeOf(lit);
if (rep != lit) {
some_change = true;
if (lrat_proof_handler_ != nullptr) {
// We need not(rep) => not(lit) for the proof.
tmp_clause_ids_.push_back(
binary_implication_graph_->GetClauseId(lit.Negated(), rep));
CHECK_NE(tmp_clause_ids_.back(), kNoClauseId) << lit << " " << rep;
}
lit = rep;
}
}
if (some_change) {
gtl::STLSortAndRemoveDuplicates(&clause);
// This shouldn't happen since it is a new learned clause, otherwise
// something is wrong.
for (int i = 1; i < clause.size(); ++i) {
CHECK_NE(clause[i], clause[i - 1].Negated()) << "trivial new clause?";
}
if (lrat_proof_handler_ != nullptr) {
// We need a new clause id for the canonicalized version, and the proof
// for how we derived that canonicalization.
const ClauseId new_id = clause_id_generator_->GetNextId();
tmp_clause_ids_.push_back(id);
lrat_proof_handler_->AddInferredClause(new_id, clause, tmp_clause_ids_);
id = new_id;
}
}
// Tricky: in case of propagation not at the right level we might need to
// backjump further.
int num_false = 0;
for (const Literal l : clause) {
if (Assignment().LiteralIsFalse(l)) ++num_false;
@@ -1094,19 +1132,15 @@ void SatSolver::ProcessCurrentConflict(
}
}
// Add any delayed clause before the final conflict.
for (const auto& [id, is_redundant, min_lbd, clause] : delayed_to_add_) {
// Learn the new clauses.
int best_lbd = std::numeric_limits<int>::max();
for (const auto& [id, is_redundant, min_lbd, clause] : learned_clauses_) {
DCHECK((lrat_proof_handler_ == nullptr) || (id != kNoClauseId));
AddLearnedClauseAndEnqueueUnitPropagation(id, clause, is_redundant,
min_lbd);
const int lbd = AddLearnedClauseAndEnqueueUnitPropagation(
id, clause, is_redundant, min_lbd);
best_lbd = std::min(best_lbd, lbd);
}
// Create and attach the new learned clause.
const int conflict_lbd = AddLearnedClauseAndEnqueueUnitPropagation(
learned_conflict_clause_id, learned_conflict_, is_redundant,
min_lbd_of_subsumed_clauses);
restart_->OnConflict(conflict_trail_index, conflict_level, conflict_lbd);
restart_->OnConflict(conflict_trail_index, conflict_level, best_lbd);
}
namespace {
@@ -1128,7 +1162,7 @@ std::pair<bool, int> SatSolver::SubsumptionsInConflictResolution(
ClauseId learned_conflict_id, absl::Span<const Literal> conflict,
absl::Span<const Literal> reason_used) {
CHECK_NE(CurrentDecisionLevel(), 0);
delayed_to_add_.clear();
learned_clauses_.clear();
// This is used to see if the learned conflict subsumes some clauses.
// Note that conflict is not yet in the clauses_propagator_.
@@ -1249,7 +1283,7 @@ std::pair<bool, int> SatSolver::SubsumptionsInConflictResolution(
// We can only add them after backtracking, since these are currently
// conflict.
delayed_to_add_.push_back(
learned_clauses_.push_back(
{new_id, new_clause_is_redundant, new_clause_min_lbd,
std::vector<Literal>(subsuming_clauses_[i].begin(),
subsuming_clauses_[i].end())});
@@ -1345,8 +1379,8 @@ std::pair<bool, int> SatSolver::SubsumptionsInConflictResolution(
}
// Also learn the "decision" conflict.
delayed_to_add_.push_back({new_clause_id, decision_is_redundant,
decision_min_lbd, decision_clause});
learned_clauses_.push_back({new_clause_id, decision_is_redundant,
decision_min_lbd, decision_clause});
}
// Sparse clear.

View File

@@ -879,13 +879,15 @@ class SatSolver {
CompactVectorVector<int, Literal> subsuming_clauses_;
CompactVectorVector<int, SatClause*> subsuming_groups_;
struct DelayedNewClause {
// On each conflict, we learn at least one clause, but depending on the cases,
// we can learn more than one.
struct NewClauses {
ClauseId id;
bool is_redundant;
int min_lbd_of_subsumed_clauses;
std::vector<Literal> clause;
};
std::vector<DelayedNewClause> delayed_to_add_;
std::vector<NewClauses> learned_clauses_;
// When true, temporarily disable the deletion of clauses that are not needed
// anymore. This is a hack for TryToMinimizeClause() because we use

View File

@@ -236,38 +236,29 @@ void SolutionCrush::SetOrUpdateVarToDomain(int var, const Domain& domain) {
}
}
void SolutionCrush::SetOrUpdateVarToDomain(
int var, const Domain& domain,
const absl::btree_map<int64_t, int>& encoding,
void SolutionCrush::SetOrUpdateVarToDomainWithOptionalEscapeValue(
int var, const Domain& reduced_var_domain,
std::optional<int64_t> unique_escape_value,
bool push_down_when_repairing_hints) {
DCHECK_EQ(domain.Size(), encoding.size());
bool push_down_when_not_in_domain,
const absl::btree_map<int64_t, int>& encoding) {
if (!solution_is_loaded_) return;
if (HasValue(var)) {
const int64_t old_value = GetVarValue(var);
if (domain.Contains(old_value)) return;
int64_t new_value = old_value;
if (unique_escape_value.has_value()) { // Only one escape value.
if (reduced_var_domain.Contains(old_value)) return;
if (unique_escape_value.has_value()) {
new_value = unique_escape_value.value();
} else if (push_down_when_repairing_hints) {
DCHECK_GT(old_value, domain.Min());
new_value = domain.ValueAtOrBefore(old_value);
} else if (push_down_when_not_in_domain) {
DCHECK_GT(old_value, reduced_var_domain.Min());
new_value = reduced_var_domain.ValueAtOrBefore(old_value);
} else {
new_value = domain.ValueAtOrAfter(old_value);
}
for (const auto [value, lit] : encoding) {
SetLiteralValue(lit, value == new_value);
DCHECK_LT(old_value, reduced_var_domain.Max());
new_value = reduced_var_domain.ValueAtOrAfter(old_value);
}
SetLiteralValue(encoding.at(new_value), true);
CHECK(!encoding.contains(old_value));
SetVarValue(var, new_value);
VLOG(3) << "SetOrUpdateVarToDomain: " << var << ", old_value: " << old_value
<< ", new_value: " << new_value
<< ", domain: " << domain.ToString();
DCHECK(encoding.contains(new_value))
<< "domain: " << domain.ToString() << "old_value: " << old_value
<< " new_value: " << new_value;
} else if (domain.IsFixed()) {
SetVarValue(var, domain.FixedValue());
}
}

View File

@@ -151,15 +151,20 @@ class SolutionCrush {
// value. Otherwise does nothing.
void SetOrUpdateVarToDomain(int var, const Domain& domain);
// If `var` already has a value, updates it to be within the given domain
// following the given encoding and the status of the variable w.r.t. the
// escape value, and the objective. Otherwise, if the domain is fixed, sets
// the value of `var` to this fixed value. Otherwise does nothing. In the
// process, update the encoding literals to reflect the new value of `var`.
void SetOrUpdateVarToDomain(int var, const Domain& domain,
const absl::btree_map<int64_t, int>& encoding,
std::optional<int64_t> unique_escape_value,
bool push_down_when_repairing_hints);
// If `var` already has a value, updates it to be within the given domain.
// There are 3 cases to consider:
// 1/ The hinted value is in reduced_var_domain. Nothing to do.
// 2/ The hinted value is not in the domain, and there is an escape value.
// Update the hinted value to the escape value, and update the encoding
// literals to reflect the new value of `var`.
// 3/ The hinted value is not in the domain, and there is no escape value.
// Update the hinted value to be in the domain by pushing it in the given
// direction, and update the encoding literals to reflect the new value
void SetOrUpdateVarToDomainWithOptionalEscapeValue(
int var, const Domain& reduced_var_domain,
std::optional<int64_t> unique_escape_value,
bool push_down_when_not_in_domain,
const absl::btree_map<int64_t, int>& encoding);
// Updates the value of the given literals to false if their current values
// are different (or does nothing otherwise).

View File

@@ -658,11 +658,32 @@ void TryToReplaceVariableByItsEncoding(int var, PresolveContext* context,
values.CreateAllValueEncodingLiterals();
// Fix the hinted value if needed.
//
// The logic follows the classes of equivalence induced by the value of the
// literals from the enforced linear1 constraining this variable.
// Two values are in the same class if all the literals have the same value.
//
// We have a heuristic method here:
// - If the variable is in the domain, we do nothing.
// - If the variable has only var==value and var!=value encodings. All values
// not touched by these linear1 are equivalent. We will reassign them to the
// unique escape value.
// - If the variable also has var>=value and var<=value encodings, we will
// push the value of the variable to the closest value in the domain in the
// direction of the objective. To this effect, for every contiguous set of
// values not in the set of referenced values. the min of the max of that
// set has been added to the encoded domain, such that the push up or down
// always falls back on an encoded value.
//
// TODO(user): we could optimize this as, for instance, we only need to
// look at values from the order encodings, and not all values when creating
// the equivalence class in the last case.
const bool push_down_when_unconstrained =
!var_in_objective || var_has_positive_objective_coefficient;
solution_crush.SetOrUpdateVarToDomain(
var, Domain::FromValues(values.encoded_values()), values.encoding(),
values.unique_escape_value(), push_down_when_unconstrained);
solution_crush.SetOrUpdateVarToDomainWithOptionalEscapeValue(
var, Domain::FromValues(values.encoded_values()),
values.unique_escape_value(), push_down_when_unconstrained,
values.encoding());
order.CreateAllOrderEncodingLiterals(values);
// Link all Boolean in our linear1 to the encoding literals.