[CP-SAT] add StopSearch C++ function.

This commit is contained in:
Laurent Perron
2025-03-24 04:55:15 -07:00
parent 35c27ab31f
commit 8dd492498f
9 changed files with 41 additions and 37 deletions

View File

@@ -403,9 +403,6 @@ class NetworkRoutingSolver {
cp_model.AddAllDifferent(node_vars);
Model model;
// Create an atomic Boolean that will be periodically checked by the limit.
std::atomic<bool> stopped(false);
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
const int path_id = all_paths_[demand_index].size();
@@ -415,7 +412,7 @@ class NetworkRoutingSolver {
all_paths_[demand_index].back().insert(arc);
}
if (all_paths_[demand_index].size() >= max_paths) {
stopped = true;
StopSearch(&model);
}
}));

View File

@@ -182,10 +182,14 @@ void CheckNumberOfSolutions(int size, int num_solutions) {
if (absl::GetFlag(FLAGS_use_symmetry)) {
if (size - 1 < kKnownUniqueSolutions) {
CHECK_EQ(num_solutions, kNumUniqueSolutions[size - 1]);
} else if (!absl::GetFlag(FLAGS_cp_disable_solve)) {
CHECK_GT(num_solutions, 0);
}
} else {
if (size - 1 < kKnownSolutions) {
CHECK_EQ(num_solutions, kNumSolutions[size - 1]);
} else if (!absl::GetFlag(FLAGS_cp_disable_solve)) {
CHECK_GT(num_solutions, 0);
}
}
}

View File

@@ -53,10 +53,6 @@ void Solve() {
parameters.set_enumerate_all_solutions(true);
model.Add(NewSatParameters(parameters));
// Create an atomic Boolean that will be periodically checked by the limit.
std::atomic<bool> stopped(false);
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
const int kSolutionLimit = 100;
int num_solutions = 0;
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -68,7 +64,7 @@ void Solve() {
LOG(INFO) << " start_ins = " << SolutionIntegerValue(r, start_ins);
num_solutions++;
if (num_solutions >= kSolutionLimit) {
stopped = true;
StopSearch(&model);
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
}
}));

View File

@@ -5457,19 +5457,26 @@ bool CpModelPresolver::PresolveTable(ConstraintProto* ct) {
namespace {
// A container that is valid if only one value was added.
struct UniqueNonNegativeValue {
int index = -1;
void Add(int new_index) {
DCHECK_GE(index, 0);
if (index == -1) {
index = new_index;
class UniqueNonNegativeValue {
public:
void Add(int value) {
DCHECK_GE(value, 0);
if (value_ == -1) {
value_ = value;
} else {
index = -2;
value_ = -2;
}
}
bool IsValid() const { return index >= 0; }
bool HasUniqueValue() const { return value_ >= 0; }
int64_t value() const {
DCHECK(HasUniqueValue());
return value_;
}
private:
int value_ = -1;
};
} // namespace
@@ -5492,7 +5499,7 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) {
return RemoveConstraint(ct);
}
if (size == 1) {
context_->UpdateRuleStats("all_diff: only one expression");
context_->UpdateRuleStats("all_diff: one expression");
return RemoveConstraint(ct);
}
@@ -5530,7 +5537,7 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) {
}
}
if (propagated) {
context_->UpdateRuleStats("all_diff: propagate fixed values");
context_->UpdateRuleStats("all_diff: propagate fixed expressions");
}
}
@@ -5610,9 +5617,10 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) {
bool propagated = false;
for (const auto& [value, unique_index] : value_to_index) {
if (!unique_index.IsValid()) continue;
if (!unique_index.HasUniqueValue()) continue;
const LinearExpressionProto& expr = all_diff.exprs(unique_index.index);
const LinearExpressionProto& expr =
all_diff.exprs(unique_index.value());
if (!context_->IntersectDomainWith(expr, Domain(value), &propagated)) {
return true;
}
@@ -7762,6 +7770,8 @@ void CpModelPresolver::Probe() {
return (void)context_->NotifyThatModelIsUnsat("during probing");
}
time_limit_->ResetHistory();
// Update the presolve context with fixed Boolean variables.
int num_fixed = 0;
CHECK_EQ(sat_solver->CurrentDecisionLevel(), 0);
@@ -8694,6 +8704,7 @@ void CpModelPresolver::MergeNoOverlapConstraints() {
// We reuse the max-clique code from sat.
Model local_model;
local_model.GetOrCreate<Trail>()->Resize(num_constraints);
local_model.GetOrCreate<TimeLimit>()->MergeWithGlobalTimeLimit(time_limit_);
auto* graph = local_model.GetOrCreate<BinaryImplicationGraph>();
graph->Resize(num_constraints);
for (const std::vector<Literal>& clique : cliques) {
@@ -8730,6 +8741,7 @@ void CpModelPresolver::MergeNoOverlapConstraints() {
new_num_intervals, " intervals).");
context_->UpdateRuleStats("no_overlap: merged constraints");
}
time_limit_->ResetHistory();
}
// TODO(user): Should we take into account the exactly_one constraints? note

View File

@@ -2269,6 +2269,10 @@ std::function<SatParameters(Model*)> NewSatParameters(
};
}
void StopSearch(Model* model) {
model->GetOrCreate<ModelSharedTimeLimit>()->Stop();
}
namespace {
void RegisterSearchStatisticCallback(Model* global_model) {
global_model->GetOrCreate<SharedResponseManager>()

View File

@@ -128,6 +128,9 @@ std::function<SatParameters(Model*)> NewSatParameters(
std::function<SatParameters(Model*)> NewSatParameters(
const SatParameters& parameters);
/// Stops the current search.
void StopSearch(Model* model);
// TODO(user): Clean this up.
/// Solves a CpModelProto without any processing. Only used for unit tests.
void LoadAndSolveCpModelForTest(const CpModelProto& model_proto, Model* model);

View File

@@ -1025,10 +1025,6 @@ void StopAfterNSolutionsSampleSat() {
parameters.set_enumerate_all_solutions(true);
model.Add(NewSatParameters(parameters));
// Create an atomic Boolean that will be periodically checked by the limit.
std::atomic<bool> stopped(false);
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
const int kSolutionLimit = 5;
int num_solutions = 0;
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -1038,7 +1034,7 @@ void StopAfterNSolutionsSampleSat() {
LOG(INFO) << " z = " << SolutionIntegerValue(r, z);
num_solutions++;
if (num_solutions >= kSolutionLimit) {
stopped = true;
StopSearch(&model);
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
}
}));

View File

@@ -138,10 +138,6 @@ void NurseSat() {
// Display the first five solutions.
// [START solution_printer]
// Create an atomic Boolean that will be periodically checked by the limit.
std::atomic<bool> stopped(false);
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
const int kSolutionLimit = 5;
int num_solutions = 0;
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -165,7 +161,7 @@ void NurseSat() {
}
num_solutions++;
if (num_solutions >= kSolutionLimit) {
stopped = true;
StopSearch(&model);
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
}
}));

View File

@@ -43,10 +43,6 @@ void StopAfterNSolutionsSampleSat() {
parameters.set_enumerate_all_solutions(true);
model.Add(NewSatParameters(parameters));
// Create an atomic Boolean that will be periodically checked by the limit.
std::atomic<bool> stopped(false);
model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped);
const int kSolutionLimit = 5;
int num_solutions = 0;
model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) {
@@ -56,7 +52,7 @@ void StopAfterNSolutionsSampleSat() {
LOG(INFO) << " z = " << SolutionIntegerValue(r, z);
num_solutions++;
if (num_solutions >= kSolutionLimit) {
stopped = true;
StopSearch(&model);
LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions.";
}
}));