[CP-SAT] use precedences in completion time cuts; improve glue clause sharing

This commit is contained in:
Laurent Perron
2025-05-16 16:47:50 +02:00
parent 0e0194eb52
commit b28b0625f9
13 changed files with 548 additions and 571 deletions

View File

@@ -2893,6 +2893,7 @@ cc_library(
":linear_constraint",
":linear_constraint_manager",
":model",
":precedences",
":sat_base",
":sat_solver",
":scheduling_helpers",
@@ -3548,6 +3549,7 @@ cc_library(
"//ortools/util:strong_integers",
"//ortools/util:time_limit",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/container:inlined_vector",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/log:vlog_is_on",

View File

@@ -1127,7 +1127,9 @@ class FullProblemSolver : public SubSolver {
// Note that this is done after the loading, so we will never export
// problem clauses.
if (shared_->clauses != nullptr) {
const int id = shared_->clauses->RegisterNewId();
const int id = shared_->clauses->RegisterNewId(
/*may_terminate_early=*/stop_at_first_solution_ &&
local_model_.GetOrCreate<CpModelProto>()->has_objective());
shared_->clauses->SetWorkerNameForId(id, local_model_.Name());
RegisterClausesLevelZeroImport(id, shared_->clauses.get(),

View File

@@ -953,27 +953,31 @@ void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager,
if (!model->GetOrCreate<SatParameters>()->share_glue_clauses()) {
return;
}
auto* clause_stream = shared_clauses_manager->GetClauseStream(id);
const int max_lbd =
model->GetOrCreate<SatParameters>()->clause_cleanup_lbd_bound();
// Note that this callback takes no global locks, everything operates on this
// worker's own clause stream, whose lock is only used by this worker, and
// briefly when generating a batch in SharedClausesManager::Synchronize().
auto share_clause = [mapping, clause_stream, max_lbd,
clause = std::vector<int>()](
const double share_interval =
model->GetOrCreate<SatParameters>()->share_glue_clauses_dtime();
auto* clause_stream = model->GetOrCreate<UniqueClauseStream>();
auto* time_limit = model->GetOrCreate<TimeLimit>();
auto share_clause = [mapping, clause_stream, time_limit, id,
shared_clauses_manager, share_interval,
next_batch_dtime = -1.0, clause = std::vector<int>()](
int lbd, absl::Span<const Literal> literals) mutable {
if (lbd <= 0 || lbd > max_lbd ||
!clause_stream->CanAccept(literals.size(), lbd)) {
return;
if (literals.size() >= UniqueClauseStream::kMinClauseSize &&
literals.size() <= UniqueClauseStream::kMaxClauseSize) {
clause.clear();
for (const Literal& lit : literals) {
const int var =
mapping->GetProtoVariableFromBooleanVariable(lit.Variable());
if (var == -1) return;
clause.push_back(lit.IsPositive() ? var : NegatedRef(var));
}
clause_stream->Add(clause, lbd);
}
clause.clear();
for (const Literal& lit : literals) {
const int var =
mapping->GetProtoVariableFromBooleanVariable(lit.Variable());
if (var == -1) return;
clause.push_back(lit.IsPositive() ? var : NegatedRef(var));
const double elapsed_dtime = time_limit->GetElapsedDeterministicTime();
if (next_batch_dtime < 0) next_batch_dtime = elapsed_dtime + share_interval;
if (elapsed_dtime >= next_batch_dtime) {
shared_clauses_manager->AddBatch(id, clause_stream->NextBatch());
next_batch_dtime = elapsed_dtime + share_interval;
}
clause_stream->Add(clause);
};
model->GetOrCreate<ClauseManager>()->SetAddClauseCallback(
std::move(share_clause));
@@ -994,16 +998,16 @@ int RegisterClausesLevelZeroImport(int id,
auto* implications = model->GetOrCreate<BinaryImplicationGraph>();
const bool share_glue_clauses =
model->GetOrCreate<SatParameters>()->share_glue_clauses();
auto* clause_stream =
share_glue_clauses ? model->GetOrCreate<UniqueClauseStream>() : nullptr;
const bool minimize_shared_clauses =
model->GetOrCreate<SatParameters>()->minimize_shared_clauses();
auto* clause_stream = share_glue_clauses
? shared_clauses_manager->GetClauseStream(id)
: nullptr;
auto* clause_manager = model->GetOrCreate<ClauseManager>();
const auto& import_level_zero_clauses = [shared_clauses_manager, id, mapping,
sat_solver, implications,
clause_stream, clause_manager,
minimize_shared_clauses]() {
minimize_shared_clauses,
clause_stream,
clause_manager]() mutable {
std::vector<std::pair<int, int>> new_binary_clauses;
shared_clauses_manager->GetUnseenBinaryClauses(id, &new_binary_clauses);
implications->EnableSharing(false);
@@ -1020,28 +1024,27 @@ int RegisterClausesLevelZeroImport(int id,
int new_clauses = 0;
std::array<Literal, UniqueClauseStream::kMaxClauseSize> local_clause;
sat_solver->EnsureNewClauseIndexInitialized();
// Temporarily disable clause sharing so we don't immediately re-export the
// clauses we just imported.
// Temporarily disable clause sharing.
auto callback = clause_manager->TakeAddClauseCallback();
for (const absl::Span<const int> shared_clause :
shared_clauses_manager->GetUnseenClauses(id)) {
// Check this clause was not already learned by this worker.
// We can delete the fingerprint because we should not learn an identical
// clause, and the global stream will not emit the same clause while any
// worker hasn't consumed this clause (and thus also shouldn't relearn the
// clause).
if (clause_stream->Delete(shared_clause)) continue;
for (int i = 0; i < shared_clause.size(); ++i) {
local_clause[i] = mapping->Literal(shared_clause[i]);
while (true) {
auto batch = shared_clauses_manager->GetUnseenClauses(id);
if (batch.empty()) break;
for (int clause_index = 0; clause_index < batch.size(); ++clause_index) {
const absl::Span<const int>& shared_clause = batch[clause_index];
// Check this clause was not already learned by this worker.
if (!clause_stream->BlockClause(shared_clause)) continue;
++new_clauses;
for (int i = 0; i < shared_clause.size(); ++i) {
local_clause[i] = mapping->Literal(shared_clause[i]);
}
if (!sat_solver->AddProblemClause(
absl::MakeSpan(local_clause)
.subspan(0, shared_clause.size()))) {
return false;
}
}
if (!sat_solver->AddProblemClause(
absl::MakeSpan(local_clause).subspan(0, shared_clause.size()))) {
return false;
}
++new_clauses;
}
clause_manager->SetAddClauseCallback(std::move(callback));
clause_stream->RemoveWorstClauses();
if (minimize_shared_clauses && new_clauses > 0) {
// The new clauses may be subsumed, so try to minimize them to reduce
// overhead of sharing.
@@ -2110,8 +2113,7 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model)
!params.interleave_search() || params.num_workers() <= 1;
response->SetSynchronizationMode(always_synchronize);
if (params.share_binary_clauses() && params.num_workers() > 1) {
clauses = std::make_unique<SharedClausesManager>(always_synchronize,
absl::Seconds(1));
clauses = std::make_unique<SharedClausesManager>(always_synchronize);
}
}

View File

@@ -87,9 +87,9 @@ TEST(LoadCpModelTest, PureSatProblem) {
TEST(LoadCpModelTest, PureSatProblemWithLimit) {
const CpModelProto model_proto = Random3SatProblem(500);
LOG(INFO) << CpModelStats(model_proto);
Model model;
model.Add(NewSatParameters("max_deterministic_time:0.00001"));
const CpSolverResponse response = SolveCpModel(model_proto, &model);
SatParameters params;
params.set_max_deterministic_time(0.00001);
const CpSolverResponse response = SolveWithParameters(model_proto, params);
EXPECT_EQ(response.status(), CpSolverStatus::UNKNOWN);
LOG(INFO) << CpSolverResponseStats(response);
}
@@ -193,7 +193,8 @@ TEST(LoadCpModelTest, SimpleCumulative) {
}
TEST(SolverCpModelTest, EmptyModel) {
const CpModelProto cp_model = ParseTestProto("solution_hint {}");
CpModelProto cp_model;
cp_model.mutable_solution_hint();
SatParameters params;
params.set_debug_crash_if_presolve_breaks_hint(true);
@@ -329,6 +330,7 @@ TEST(SolveCpModelTest, TrivialModelWithCore) {
response.solution().end())));
}
#if !defined(__EMBEDDED_PLATFORM__)
TEST(SolveCpModelTest, TrivialLinearTranslatedModel) {
const CpModelProto model_proto = ParseTestProto(R"pb(
variables { domain: -10 domain: 10 }
@@ -4803,6 +4805,7 @@ TEST(PresolveCpModelTest, CumulativeBug4) {
response = SolveWithParameters(cp_model, params);
EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL);
}
#endif // !defined(__EMBEDDED_PLATFORM__)
} // namespace
} // namespace sat

View File

@@ -32,7 +32,7 @@
#include "absl/log/vlog_is_on.h"
#include "absl/numeric/bits.h"
#include "absl/types/span.h"
#include "ortools/base/stl_util.h"
// #include "ortools/base/stl_util.h"
#include "ortools/sat/2d_mandatory_overlap_propagator.h"
#include "ortools/sat/2d_orthogonal_packing.h"
#include "ortools/sat/2d_try_edge_propagator.h"

View File

@@ -87,6 +87,7 @@ std::string ValidateParameters(const SatParameters& params) {
TEST_IS_FINITE(relative_gap_limit);
TEST_IS_FINITE(restart_dl_average_ratio);
TEST_IS_FINITE(restart_lbd_average_ratio);
TEST_IS_FINITE(share_glue_clauses_dtime);
TEST_IS_FINITE(shared_tree_open_leaves_per_worker);
TEST_IS_FINITE(shaving_deterministic_time_in_probing_search);
TEST_IS_FINITE(shaving_search_deterministic_time);
@@ -156,6 +157,7 @@ std::string ValidateParameters(const SatParameters& params) {
TEST_NON_NEGATIVE(presolve_probing_deterministic_time_limit);
TEST_NON_NEGATIVE(probing_deterministic_time_limit);
TEST_NON_NEGATIVE(symmetry_detection_deterministic_time_limit);
TEST_POSITIVE(share_glue_clauses_dtime);
if (params.enumerate_all_solutions() &&
(params.num_search_workers() > 1 || params.num_workers() > 1)) {

View File

@@ -24,7 +24,7 @@ option java_multiple_files = true;
// Contains the definitions for all the sat algorithm parameters and their
// default values.
//
// NEXT TAG: 322
// NEXT TAG: 323
message SatParameters {
// In some context, like in a portfolio of search, it makes sense to name a
// given parameters set for logging purpose.
@@ -705,6 +705,9 @@ message SatParameters {
// are imported.
optional bool minimize_shared_clauses = 300 [default = true];
// The amount of dtime between each export of shared glue clauses.
optional double share_glue_clauses_dtime = 322 [default = 1.0];
// ==========================================================================
// Debugging parameters
// ==========================================================================

View File

@@ -43,6 +43,7 @@
#include "ortools/sat/linear_constraint.h"
#include "ortools/sat/linear_constraint_manager.h"
#include "ortools/sat/model.h"
#include "ortools/sat/precedences.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/sat_solver.h"
#include "ortools/sat/scheduling_helpers.h"
@@ -1053,15 +1054,16 @@ CutGenerator CreateNoOverlapPrecedenceCutGenerator(
}
CompletionTimeEvent::CompletionTimeEvent(int t,
SchedulingConstraintHelper* x_helper,
SchedulingConstraintHelper* helper,
SchedulingDemandHelper* demands_helper)
: task_index(t),
start_min(x_helper->StartMin(t)),
start_max(x_helper->StartMax(t)),
end_min(x_helper->EndMin(t)),
end_max(x_helper->EndMax(t)),
size_min(x_helper->SizeMin(t)),
end(x_helper->Ends()[t]) {
start_min(helper->StartMin(t)),
start_max(helper->StartMax(t)),
end_min(helper->EndMin(t)),
end_max(helper->EndMax(t)),
size_min(helper->SizeMin(t)),
start(helper->Starts()[t]),
end(helper->Ends()[t]) {
if (demands_helper == nullptr) {
demand_min = 1;
demand_is_fixed = true;
@@ -1099,20 +1101,68 @@ std::string CompletionTimeEvent::DebugString() const {
"]");
}
// Collects, for every event, the set of tasks that must end before it starts
// (according to the level-zero binary relations), and records the largest
// task index seen. Populates predecessors_ so that predecessors_[task] lists
// the task indices of its known predecessors among `events`.
// NOTE(review): predecessors_ may be truncated by the cap below, so callers
// must not assume every event has an entry — TODO confirm all call sites
// guard on predecessors().size().
void CtExhaustiveHelper::Init(
const absl::Span<const CompletionTimeEvent> events, Model* model) {
BinaryRelationsMaps* binary_relations =
model->GetOrCreate<BinaryRelationsMaps>();
// Recompute the maximum task index over the given events.
max_task_index_ = 0;
for (const auto& event : events) {
max_task_index_ = std::max(max_task_index_, event.task_index);
}
predecessors_.reserve(max_task_index_ + 1);
for (const auto& e1 : events) {
// Events are expected in increasing task_index order; pad with empty
// predecessor lists for any task index we have not seen.
CHECK_LE(predecessors_.size(), e1.task_index);
while (predecessors_.size() <= e1.task_index) {
predecessors_.Add({});
}
// Cap the number of precedences to avoid O(n^2) time complexity.
if (predecessors_.num_entries() > 20000) break;
// e2 precedes e1 if the relation "end(e2) <= start(e1)" is known true.
for (const auto& e2 : events) {
if (e2.task_index == e1.task_index) continue;
if (binary_relations->GetPrecedenceStatus(e2.end, e1.start) ==
RelationStatus::IS_TRUE) {
predecessors_.AppendToLastVector(e2.task_index);
}
}
}
VLOG(2) << "num_tasks:" << max_task_index_ + 1
<< " num_precedences:" << predecessors_.num_entries();
}
// Returns true if ordering the events as given by `permutation` does not
// violate any collected precedence: scanning from the last position to the
// first, a task is rejected if one of its known predecessors appears later
// in the permutation (i.e. was already marked visited).
bool CtExhaustiveHelper::PermutationIsCompatibleWithPrecedences(
absl::Span<const CompletionTimeEvent> events,
absl::Span<const int> permutation) {
// visited_[t] is true once task t has been seen at a later position.
visited_.assign(max_task_index_ + 1, false);
for (int i = permutation.size() - 1; i >= 0; --i) {
const CompletionTimeEvent& event = events[permutation[i]];
for (const int predecessor : predecessors_[event.task_index]) {
// A predecessor scheduled after this task makes the order infeasible.
if (visited_[predecessor]) return false;
}
visited_[event.task_index] = true;
}
return true;
}
namespace {
bool ComputeWeightedSumOfEndMinsOfOnePermutationForNoOverlap(
absl::Span<const CompletionTimeEvent> events,
absl::Span<const int> permutation, IntegerValue& sum_of_ends,
IntegerValue& sum_of_weighted_ends) {
// Reset the two sums.
sum_of_ends = 0;
sum_of_weighted_ends = 0;
// Loop over the permutation.
IntegerValue end_min_of_previous_task = kMinIntegerValue;
for (const int index : permutation) {
const CompletionTimeEvent& event = events[index];
const IntegerValue threshold =
std::max(event.start_min, end_min_of_previous_task);
if (event.start_max < threshold) return false; // Infeasible.
end_min_of_previous_task = threshold + event.size_min;
sum_of_ends += end_min_of_previous_task;
sum_of_weighted_ends += event.energy_min * end_min_of_previous_task;
@@ -1131,9 +1181,8 @@ bool ComputeWeightedSumOfEndMinsOfOnePermutationForNoOverlap(
bool ComputeWeightedSumOfEndMinsOfOnePermutation(
absl::Span<const CompletionTimeEvent> events,
absl::Span<const int> permutation, IntegerValue capacity_max,
IntegerValue& sum_of_ends, IntegerValue& sum_of_weighted_ends,
std::vector<std::pair<IntegerValue, IntegerValue>>& profile,
std::vector<std::pair<IntegerValue, IntegerValue>>& new_profile) {
CtExhaustiveHelper& helper, IntegerValue& sum_of_ends,
IntegerValue& sum_of_weighted_ends, bool& cut_use_precedences) {
DCHECK_EQ(permutation.size(), events.size());
if (capacity_max == 1) {
@@ -1141,11 +1190,11 @@ bool ComputeWeightedSumOfEndMinsOfOnePermutation(
events, permutation, sum_of_ends, sum_of_weighted_ends);
}
// Set default values.
// Reset the two sums.
sum_of_ends = 0;
sum_of_weighted_ends = 0;
// Is the permutation feasible ?
// Quick check to see if the permutation is feasible:
// ei = events[permutation[i]], ej = events[permutation[j]], i < j
// - start_max(ej) >= start_min(ei)
IntegerValue demand_min_of_previous_task = 0;
@@ -1161,6 +1210,7 @@ bool ComputeWeightedSumOfEndMinsOfOnePermutation(
if (event.start_max < threshold) {
return false;
}
start_min_of_previous_task = threshold;
end_min_of_previous_task = threshold + event.size_min;
demand_min_of_previous_task = event.demand_min;
@@ -1168,9 +1218,12 @@ bool ComputeWeightedSumOfEndMinsOfOnePermutation(
// The profile (and new profile) is a set of (time, capa_left) pairs,
// ordered by increasing time and capa_left.
profile.clear();
profile.emplace_back(kMinIntegerValue, capacity_max);
profile.emplace_back(kMaxIntegerValue, capacity_max);
helper.profile_.clear();
helper.profile_.emplace_back(kMinIntegerValue, capacity_max);
helper.profile_.emplace_back(kMaxIntegerValue, capacity_max);
// Loop over the permutation.
helper.assigned_ends_.assign(helper.max_task_index() + 1, kMinIntegerValue);
IntegerValue start_of_previous_task = kMinIntegerValue;
for (const int index : permutation) {
const CompletionTimeEvent& event = events[index];
@@ -1180,15 +1233,30 @@ bool ComputeWeightedSumOfEndMinsOfOnePermutation(
// Iterate on the profile to find the step that contains start_min.
// Then push until we find a step with enough capacity.
int current = 0;
while (profile[current + 1].first <= start_min ||
profile[current].second < event.demand_min) {
while (helper.profile_[current + 1].first <= start_min ||
helper.profile_[current].second < event.demand_min) {
++current;
}
const IntegerValue actual_start =
std::max(start_min, profile[current].first);
IntegerValue actual_start =
std::max(start_min, helper.profile_[current].first);
const IntegerValue initial_start_min = actual_start;
start_of_previous_task = actual_start;
// Propagate precedences.
//
// helper.predecessors() can be truncated. We need to be careful here.
if (event.task_index < helper.predecessors().size()) {
for (const int predecessor : helper.predecessors()[event.task_index]) {
if (helper.assigned_ends_[predecessor] == kMinIntegerValue) continue;
actual_start =
std::max(actual_start, helper.assigned_ends_[predecessor]);
}
}
if (actual_start > initial_start_min) {
cut_use_precedences = true;
VLOG(3) << "push from " << initial_start_min << " to " << actual_start;
}
// Compatible with the event.start_max ?
if (actual_start > event.start_max) {
@@ -1197,33 +1265,37 @@ bool ComputeWeightedSumOfEndMinsOfOnePermutation(
const IntegerValue actual_end = actual_start + event.size_min;
// Bookkeeping.
helper.assigned_ends_[event.task_index] = actual_end;
sum_of_ends += actual_end;
sum_of_weighted_ends += event.energy_min * actual_end;
start_of_previous_task = actual_start;
// No need to update the profile on the last loop.
if (event.task_index == events[permutation.back()].task_index) break;
// Update the profile.
new_profile.clear();
new_profile.push_back(
{actual_start, profile[current].second - event.demand_min});
helper.new_profile_.clear();
helper.new_profile_.push_back(
{actual_start, helper.profile_[current].second - event.demand_min});
++current;
while (profile[current].first < actual_end) {
new_profile.push_back(
{profile[current].first, profile[current].second - event.demand_min});
while (helper.profile_[current].first < actual_end) {
helper.new_profile_.push_back(
{helper.profile_[current].first,
helper.profile_[current].second - event.demand_min});
++current;
}
if (profile[current].first > actual_end) {
new_profile.push_back(
{actual_end, new_profile.back().second + event.demand_min});
if (helper.profile_[current].first > actual_end) {
helper.new_profile_.push_back(
{actual_end, helper.new_profile_.back().second + event.demand_min});
}
while (current < profile.size()) {
new_profile.push_back(profile[current]);
while (current < helper.profile_.size()) {
helper.new_profile_.push_back(helper.profile_[current]);
++current;
}
profile.swap(new_profile);
helper.profile_.swap(helper.new_profile_);
}
return true;
}
@@ -1232,36 +1304,37 @@ bool ComputeWeightedSumOfEndMinsOfOnePermutation(
bool ComputeMinSumOfWeightedEndMins(
absl::Span<const CompletionTimeEvent> events, IntegerValue capacity_max,
double sum_of_ends_lp, double sum_of_weighted_ends_lp,
IntegerValue& min_sum_of_end_mins,
IntegerValue& min_sum_of_weighted_end_mins) {
double unweighted_threshold, double weighted_threshold,
CtExhaustiveHelper& helper, double& min_sum_of_ends,
double& min_sum_of_weighted_ends, bool& cut_use_precedences) {
// Reset the events based sums.
min_sum_of_ends = std::numeric_limits<double>::max();
min_sum_of_weighted_ends = std::numeric_limits<double>::max();
// Local stats.
int num_explored = 0;
int num_pruned = 0;
min_sum_of_end_mins = kMaxIntegerValue;
min_sum_of_weighted_end_mins = kMaxIntegerValue;
bool aborted = false;
const int64_t unweighted_threshold =
static_cast<int64_t>(std::floor(sum_of_ends_lp + kMinCutViolation));
const int64_t weighted_threshold = static_cast<int64_t>(
std::floor(sum_of_weighted_ends_lp + kMinCutViolation));
// Reusable storage for ComputeWeightedSumOfEndMinsOfOnePermutation().
std::vector<std::pair<IntegerValue, IntegerValue>> profile;
std::vector<std::pair<IntegerValue, IntegerValue>> new_profile;
std::vector<int> permutation(events.size());
std::iota(permutation.begin(), permutation.end(), 0);
do {
IntegerValue sum_of_ends(0);
IntegerValue sum_of_weighted_ends(0);
IntegerValue sum_of_ends = 0;
IntegerValue sum_of_weighted_ends = 0;
if (!helper.PermutationIsCompatibleWithPrecedences(events, permutation)) {
cut_use_precedences = true;
continue;
}
if (ComputeWeightedSumOfEndMinsOfOnePermutation(
events, permutation, capacity_max, sum_of_ends,
sum_of_weighted_ends, profile, new_profile)) {
min_sum_of_end_mins = std::min(sum_of_ends, min_sum_of_end_mins);
min_sum_of_weighted_end_mins =
std::min(sum_of_weighted_ends, min_sum_of_weighted_end_mins);
events, permutation, capacity_max, helper, sum_of_ends,
sum_of_weighted_ends, cut_use_precedences)) {
min_sum_of_ends = std::min(ToDouble(sum_of_ends), min_sum_of_ends);
min_sum_of_weighted_ends =
std::min(ToDouble(sum_of_weighted_ends), min_sum_of_weighted_ends);
num_explored++;
if (min_sum_of_end_mins <= unweighted_threshold &&
min_sum_of_weighted_end_mins <= weighted_threshold) {
if (min_sum_of_ends <= unweighted_threshold &&
min_sum_of_weighted_ends <= weighted_threshold) {
aborted = true;
break;
}
@@ -1271,8 +1344,8 @@ bool ComputeMinSumOfWeightedEndMins(
} while (std::next_permutation(permutation.begin(), permutation.end()));
VLOG(3) << "DP: size=" << events.size() << ", explored = " << num_explored
<< ", pruned = " << num_pruned << ", aborted = " << aborted
<< ", min_sum_of_end_mins = " << min_sum_of_end_mins
<< ", min_sum_of_weighted_end_mins = " << min_sum_of_weighted_end_mins
<< ", min_sum_of_end_mins = " << min_sum_of_ends
<< ", min_sum_of_weighted_end_mins = " << min_sum_of_weighted_ends
<< ", unweighted_threshold = " << unweighted_threshold
<< ", weighted_threshold = " << weighted_threshold;
return num_explored > 0;
@@ -1283,7 +1356,8 @@ bool ComputeMinSumOfWeightedEndMins(
// - better caching of explored states
ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound(
const std::string& cut_name, std::vector<CompletionTimeEvent> events,
IntegerValue capacity_max, Model* model, LinearConstraintManager* manager) {
IntegerValue capacity_max, CtExhaustiveHelper& helper, Model* model,
LinearConstraintManager* manager) {
TopNCuts top_n_cuts(5);
// Sort by start min to bucketize by start_min.
std::sort(
@@ -1298,6 +1372,7 @@ ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound(
continue;
}
bool cut_use_precedences = false; // Used for naming the cut.
const IntegerValue sequence_start_min = events[start].start_min;
std::vector<CompletionTimeEvent> residual_tasks(events.begin() + start,
events.end());
@@ -1321,40 +1396,43 @@ ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound(
double sum_of_ends_lp = 0.0;
double sum_of_weighted_ends_lp = 0.0;
IntegerValue sum_of_demands = 0;
IntegerValue sum_of_energies = 0;
double sum_of_square_energies = 0;
double min_sum_of_ends = std::numeric_limits<double>::max();
double min_sum_of_weighted_ends = std::numeric_limits<double>::max();
for (int i = 0; i < std::min<int>(residual_tasks.size(), 7); ++i) {
const CompletionTimeEvent& event = residual_tasks[i];
const double energy = ToDouble(event.energy_min);
sum_of_ends_lp += event.lp_end;
sum_of_weighted_ends_lp += event.lp_end * ToDouble(event.energy_min);
sum_of_weighted_ends_lp += event.lp_end * energy;
sum_of_demands += event.demand_min;
sum_of_energies += event.energy_min;
sum_of_square_energies += energy * energy;
// Both cases with 1 or 2 tasks are trivial and independent of the order.
// Also, if capacity is not exceeded, pushing all ends left is a valid LP
// assignment.
if (i <= 1 || sum_of_demands <= capacity_max) continue;
IntegerValue min_sum_of_end_mins = kMaxIntegerValue;
IntegerValue min_sum_of_weighted_end_mins = kMaxIntegerValue;
if (!ComputeMinSumOfWeightedEndMins(
absl::MakeSpan(residual_tasks).first(i + 1), capacity_max,
sum_of_ends_lp, sum_of_weighted_ends_lp, min_sum_of_end_mins,
min_sum_of_weighted_end_mins)) {
/* unweighted_threshold= */ sum_of_ends_lp + kMinCutViolation,
/* weighted_threshold= */ sum_of_weighted_ends_lp +
kMinCutViolation,
helper, min_sum_of_ends, min_sum_of_weighted_ends,
cut_use_precedences)) {
return false;
}
const double unweigthed_violation =
(ToDouble(min_sum_of_end_mins) - sum_of_ends_lp) / ToDouble(i + 1);
(min_sum_of_ends - sum_of_ends_lp) / std::sqrt(ToDouble(i + 1));
const double weighted_violation =
(ToDouble(min_sum_of_weighted_end_mins) - sum_of_weighted_ends_lp) /
ToDouble(sum_of_energies);
(min_sum_of_weighted_ends - sum_of_weighted_ends_lp) /
std::sqrt(sum_of_square_energies);
// Unweighted cuts.
if (unweigthed_violation > weighted_violation &&
unweigthed_violation > kMinCutViolation) {
LinearConstraintBuilder cut(model, min_sum_of_end_mins,
kMaxIntegerValue);
LinearConstraintBuilder cut(model, min_sum_of_ends, kMaxIntegerValue);
bool is_lifted = false;
for (int j = 0; j <= i; ++j) {
const CompletionTimeEvent& event = residual_tasks[j];
@@ -1362,6 +1440,7 @@ ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound(
cut.AddTerm(event.end, IntegerValue(1));
}
std::string full_name = cut_name;
if (cut_use_precedences) full_name.append("_prec");
if (is_lifted) full_name.append("_lifted");
top_n_cuts.AddCut(cut.Build(), full_name, manager->LpValues());
}
@@ -1369,7 +1448,7 @@ ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound(
// Weighted cuts.
if (weighted_violation >= unweigthed_violation &&
weighted_violation > kMinCutViolation) {
LinearConstraintBuilder cut(model, min_sum_of_weighted_end_mins,
LinearConstraintBuilder cut(model, min_sum_of_weighted_ends,
kMaxIntegerValue);
bool is_lifted = false;
for (int j = 0; j <= i; ++j) {
@@ -1379,6 +1458,7 @@ ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound(
}
std::string full_name = cut_name;
if (is_lifted) full_name.append("_lifted");
if (cut_use_precedences) full_name.append("_prec");
full_name.append("_weighted");
top_n_cuts.AddCut(cut.Build(), full_name, manager->LpValues());
}
@@ -1686,11 +1766,14 @@ CutGenerator CreateNoOverlapCompletionTimeCutGenerator(
}
}
CtExhaustiveHelper helper;
helper.Init(events, model);
const std::string mirror_str = time_is_forward ? "" : "_mirror";
if (!GenerateShortCompletionTimeCutsWithExactBound(
absl::StrCat("NoOverlapCompletionTimeExhaustive", mirror_str),
events,
/*capacity_max=*/IntegerValue(1), model, manager)) {
/*capacity_max=*/IntegerValue(1), helper, model, manager)) {
return false;
}
@@ -1748,11 +1831,14 @@ CutGenerator CreateCumulativeCompletionTimeCutGenerator(
}
}
CtExhaustiveHelper helper;
helper.Init(events, model);
const IntegerValue capacity_max = integer_trail->UpperBound(capacity);
const std::string mirror_str = time_is_forward ? "" : "_mirror";
if (!GenerateShortCompletionTimeCutsWithExactBound(
absl::StrCat("CumulativeCompletionTimeExhaustive", mirror_str),
events, capacity_max, model, manager)) {
events, capacity_max, helper, model, manager)) {
return false;
}

View File

@@ -16,6 +16,7 @@
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/types/span.h"
@@ -24,6 +25,7 @@
#include "ortools/sat/integer_base.h"
#include "ortools/sat/model.h"
#include "ortools/sat/scheduling_helpers.h"
#include "ortools/sat/util.h"
namespace operations_research {
namespace sat {
@@ -117,7 +119,8 @@ struct CompletionTimeEvent {
IntegerValue end_max;
IntegerValue size_min;
// The lp value of the end of the interval.
// Start and end affine expressions and lp value of the end of the interval.
AffineExpression start;
AffineExpression end;
double lp_end = 0.0;
@@ -147,6 +150,30 @@ struct CompletionTimeEvent {
std::string DebugString() const;
};
// Helper for the exhaustive completion-time cut computation. Stores the
// precedence graph between tasks (built once in Init()) plus scratch buffers
// reused across permutations to avoid reallocations.
class CtExhaustiveHelper {
public:
// Largest task index seen during Init().
int max_task_index() const { return max_task_index_; }
// Per-task predecessor lists; may cover fewer tasks than max_task_index()+1
// if collection was capped — callers must bound-check.
const CompactVectorVector<int>& predecessors() const { return predecessors_; }
// Temporary data, reused by ComputeWeightedSumOfEndMinsOfOnePermutation():
// capacity profile as (time, capa_left) steps, its double buffer, and the
// assigned end time of each already-placed task.
std::vector<std::pair<IntegerValue, IntegerValue>> profile_;
std::vector<std::pair<IntegerValue, IntegerValue>> new_profile_;
std::vector<IntegerValue> assigned_ends_;
// Collect precedences, set max_task_index.
// TODO(user): Do some transitive closure.
void Init(absl::Span<const CompletionTimeEvent> events, Model* model);
// True if `permutation` of `events` respects all collected precedences.
bool PermutationIsCompatibleWithPrecedences(
absl::Span<const CompletionTimeEvent> events,
absl::Span<const int> permutation);
private:
CompactVectorVector<int> predecessors_;
int max_task_index_ = 0;
// Scratch bitmap for PermutationIsCompatibleWithPrecedences().
std::vector<bool> visited_;
};
// Computes the minimum sum of the end min and the minimum sum of the end min
// weighted by weight of all events. It returns false if no permutation is
// valid w.r.t. the range of starts.
@@ -157,9 +184,9 @@ struct CompletionTimeEvent {
// Optim: If both sums are proven <= to the corresponding threshold, we abort.
bool ComputeMinSumOfWeightedEndMins(
absl::Span<const CompletionTimeEvent> events, IntegerValue capacity_max,
double sum_of_ends_lp, double sum_of_weighted_ends_lp,
IntegerValue& min_sum_of_end_mins,
IntegerValue& min_sum_of_weighted_end_mins);
double unweighted_threshold, double weighted_threshold,
CtExhaustiveHelper& helper, double& min_sum_of_ends,
double& min_sum_of_weighted_ends, bool& cut_use_precedences);
} // namespace sat
} // namespace operations_research

View File

@@ -406,13 +406,17 @@ TEST(ComputeMinSumOfEndMinsTest, CombinationOf3) {
CompletionTimeEvent e3(2, helper, demands_helper);
const std::vector<CompletionTimeEvent> events = {e1, e2, e3};
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(events, two, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
double min_sum_of_end_mins = 0;
double min_sum_of_weighted_end_mins = 0;
CtExhaustiveHelper ct_helper;
ct_helper.Init(events, &model);
bool cut_use_precedences = false;
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(
events, two, 0.01, 0.01, ct_helper, min_sum_of_end_mins,
min_sum_of_weighted_end_mins, cut_use_precedences));
EXPECT_EQ(min_sum_of_end_mins, 17);
EXPECT_EQ(min_sum_of_weighted_end_mins, 86);
EXPECT_FALSE(cut_use_precedences);
}
TEST(ComputeMinSumOfEndMinsTest, CombinationOf3ConstraintStart) {
@@ -451,11 +455,14 @@ TEST(ComputeMinSumOfEndMinsTest, CombinationOf3ConstraintStart) {
CompletionTimeEvent e3(2, helper, demands_helper);
const std::vector<CompletionTimeEvent> events = {e1, e2, e3};
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(events, two, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
double min_sum_of_end_mins = 0;
double min_sum_of_weighted_end_mins = 0;
CtExhaustiveHelper ct_helper;
ct_helper.Init(events, &model);
bool cut_use_precedences = false;
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(
events, two, 0.01, 0.01, ct_helper, min_sum_of_end_mins,
min_sum_of_weighted_end_mins, cut_use_precedences));
EXPECT_EQ(min_sum_of_end_mins, 18);
EXPECT_EQ(min_sum_of_weighted_end_mins, 86);
}
@@ -496,15 +503,18 @@ TEST(ComputeMinSumOfEndMinsTest, Infeasible) {
CompletionTimeEvent e3(2, helper, demands_helper);
const std::vector<CompletionTimeEvent> events = {e1, e2, e3};
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
ASSERT_FALSE(ComputeMinSumOfWeightedEndMins(events, two, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
double min_sum_of_end_mins = 0;
double min_sum_of_weighted_end_mins = 0;
CtExhaustiveHelper ct_helper;
ct_helper.Init(events, &model);
bool cut_use_precedences = false;
ASSERT_FALSE(ComputeMinSumOfWeightedEndMins(
events, two, 0.01, 0.01, ct_helper, min_sum_of_end_mins,
min_sum_of_weighted_end_mins, cut_use_precedences));
}
int64_t ExactMakespan(absl::Span<const int> sizes, std::vector<int>& demands,
int capacity) {
double ExactMakespan(absl::Span<const int> sizes, std::vector<int>& demands,
int capacity) {
const int64_t kHorizon = 1000;
CpModelBuilder builder;
LinearExpr obj;
@@ -519,11 +529,11 @@ int64_t ExactMakespan(absl::Span<const int> sizes, std::vector<int>& demands,
const CpSolverResponse response =
SolveWithParameters(builder.Build(), "num_search_workers:8");
EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL);
return static_cast<int64_t>(response.objective_value());
return response.objective_value();
}
int64_t ExactMakespanBruteForce(absl::Span<const int> sizes,
std::vector<int>& demands, int capacity) {
double ExactMakespanBruteForce(absl::Span<const int> sizes,
std::vector<int>& demands, int capacity) {
const int64_t kHorizon = 1000;
Model model;
auto* intervals_repository = model.GetOrCreate<IntervalsRepository>();
@@ -555,12 +565,15 @@ int64_t ExactMakespanBruteForce(absl::Span<const int> sizes,
events.push_back(e);
}
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
EXPECT_TRUE(ComputeMinSumOfWeightedEndMins(events, capacity, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
return min_sum_of_end_mins.value();
double min_sum_of_end_mins = 0;
double min_sum_of_weighted_end_mins = 0;
CtExhaustiveHelper ct_helper;
ct_helper.Init(events, &model);
bool cut_use_precedences = false;
EXPECT_TRUE(ComputeMinSumOfWeightedEndMins(
events, capacity, 0.01, 0.01, ct_helper, min_sum_of_end_mins,
min_sum_of_weighted_end_mins, cut_use_precedences));
return min_sum_of_end_mins;
}
TEST(ComputeMinSumOfEndMinsTest, RandomCases) {
@@ -576,8 +589,8 @@ TEST(ComputeMinSumOfEndMinsTest, RandomCases) {
demands.push_back(absl::Uniform<int>(random, 1, capacity));
}
EXPECT_EQ(ExactMakespan(sizes, demands, capacity),
ExactMakespanBruteForce(sizes, demands, capacity));
EXPECT_NEAR(ExactMakespan(sizes, demands, capacity),
ExactMakespanBruteForce(sizes, demands, capacity), 1e-6);
}
}

View File

@@ -1130,76 +1130,79 @@ int SharedBoundsManager::NumBoundsExported(absl::string_view worker_name) {
UniqueClauseStream::UniqueClauseStream() {
for (auto& buffer : clauses_by_size_) {
buffer.reserve(kMaxBufferedLiterals);
buffer.reserve(kMaxLiteralsPerBatch);
}
fingerprints_.reserve(kMaxFingerprints);
}
bool UniqueClauseStream::Add(absl::Span<const int> clause) {
absl::MutexLock mutex_lock(&mutex_);
if (clause.size() > kMaxClauseSize || clause.size() <= 2) return false;
// This is just a safety check, the caller should have called CanAccept().
if (NumLiteralsOfSize(clause.size()) + clause.size() > kMaxBufferedLiterals) {
return false;
}
if (BlockClause(clause)) {
std::vector<int>* buffer = MutableBufferForSize(clause.size());
bool UniqueClauseStream::Add(absl::Span<const int> clause, int lbd) {
if (!BlockClause(clause) || lbd > lbd_threshold_) return false;
std::vector<int>* buffer = MutableBufferForSize(clause.size());
CHECK_NE(buffer, nullptr);
if (buffer->size() + clause.size() <= kMaxLiteralsPerBatch) {
buffer->insert(buffer->end(), clause.begin(), clause.end());
return true;
} else {
// Maybe replace an old buffered clause of the same size if it has a smaller
// hash value. This means that the buffer will contain a deterministic
// sample of the clauses added independent of insertion order.
const int64_t replaced_clause_id =
HashClause(clause, 1) % NumClausesOfSize(clause.size());
absl::Span<int> replaced_clause = absl::MakeSpan(*buffer).subspan(
replaced_clause_id * clause.size(), clause.size());
dropped_literals_since_last_batch_ += clause.size();
if (HashClause(clause, 2) < HashClause(replaced_clause, 2)) {
std::copy(clause.begin(), clause.end(), replaced_clause.begin());
}
}
return false;
return true;
}
bool UniqueClauseStream::BlockClause(absl::Span<const int> clause) {
if (clause.size() > kMaxClauseSize) return false;
if (clause.size() <= 2) return false;
return fingerprints_.emplace(HashClause(clause)).second;
}
bool UniqueClauseStream::Delete(absl::Span<const int> clause) {
const size_t fingerprint = HashClause(clause);
absl::MutexLock mutex_lock(&mutex_);
// Note a clause with this hash may be buffered, but not yet exported.
return fingerprints_.erase(fingerprint) == 1;
const auto hash = HashClause(clause);
return fingerprints_.emplace(hash).second &&
!old_fingerprints_.contains(hash);
}
CompactVectorVector<int> UniqueClauseStream::NextBatch() {
CompactVectorVector<int> buffer;
buffer.reserve(kMaxLiteralsPerBatch / kMinClauseSize, kMaxLiteralsPerBatch);
CompactVectorVector<int> batch;
batch.reserve(kMaxLiteralsPerBatch / kMinClauseSize, kMaxLiteralsPerBatch);
int to_fill = kMaxLiteralsPerBatch;
absl::MutexLock mutex_lock(&mutex_);
for (int size = kMinClauseSize; size <= kMaxClauseSize; ++size) {
CHECK_EQ(NumLiteralsOfSize(size) % size, 0);
while (to_fill >= size && NumLiteralsOfSize(size) > 0) {
absl::Span<const int> clause = NextClause(size);
if (fingerprints_.contains(HashClause(clause))) {
buffer.Add(NextClause(size));
to_fill -= size;
}
std::vector<int>* buffer = MutableBufferForSize(size);
while (to_fill >= size && !buffer->empty()) {
batch.Add(NextClause(size));
to_fill -= size;
PopClause(size);
}
}
return buffer;
}
int UniqueClauseStream::FillUpstreamBuffer(UniqueClauseStream& upstream,
int size,
int max_clauses_to_export) {
int num_exported_clauses = 0;
absl::MutexLock mutex_lock(&mutex_);
while (NumLiteralsOfSize(size) > 0 &&
num_exported_clauses < max_clauses_to_export) {
absl::Span<const int> clause = NextClause(size);
// Don't emit deleted clauses.
if (fingerprints_.contains(HashClause(clause)) && upstream.Add(clause)) {
++num_exported_clauses;
if (to_fill < size) {
dropped_literals_since_last_batch_ += buffer->size();
buffer->clear();
}
PopClause(size);
}
return num_exported_clauses;
if (fingerprints_.size() >= kMaxFingerprints / 2) {
VLOG(2) << "Clearing fingerprints: " << fingerprints_.size() / 1024 << "Ki";
std::swap(fingerprints_, old_fingerprints_);
fingerprints_.clear();
fingerprints_.reserve(kMaxFingerprints);
}
if (to_fill > kMaxLiteralsPerBatch / 2 && lbd_threshold_ < kMaxLbd) {
lbd_threshold_ += 1;
VLOG(2) << "Inc lbd: " << lbd_threshold_;
} else if (dropped_literals_since_last_batch_ > 0 &&
lbd_threshold_ > kMinLbd) {
lbd_threshold_ -= 1;
VLOG(2) << "Dec lbd: " << lbd_threshold_;
}
dropped_literals_since_last_batch_ = 0;
return batch;
}
int UniqueClauseStream::NumBufferedLiterals() const {
absl::MutexLock mutex_lock(&mutex_);
int result = 0;
for (const auto& buffer : clauses_by_size_) {
result += buffer.size();
@@ -1207,42 +1210,6 @@ int UniqueClauseStream::NumBufferedLiterals() const {
return result;
}
bool UniqueClauseStream::CanAccept(int size, int lbd) const {
if (size <= 2 || size > kMaxClauseSize) return false;
absl::MutexLock mutex_lock(&mutex_);
if (lbd > lbd_threshold_) return false;
int num_literals_up_to_size = 0;
for (int i = kMinClauseSize; i <= size; ++i) {
num_literals_up_to_size += NumLiteralsOfSize(i);
}
return num_literals_up_to_size + size <= kMaxBufferedLiterals;
}
void UniqueClauseStream::RemoveWorstClauses() {
absl::MutexLock mutex_lock(&mutex_);
int literals_to_remove = 0;
for (const auto& buffer : clauses_by_size_) {
literals_to_remove += buffer.size();
}
literals_to_remove -= kMaxBufferedLiterals;
for (int size = kMaxClauseSize; size >= kMinClauseSize; --size) {
while (NumLiteralsOfSize(size) > 0) {
// Stop if removing one more clause of the current size would
// leave the buffer under full. Otherwise we might remove a shorter
// clause later!
if (literals_to_remove < size) return;
fingerprints_.erase(HashClause(NextClause(size)));
PopClause(size);
literals_to_remove -= size;
}
}
}
void UniqueClauseStream::set_lbd_threshold(int lbd) {
absl::MutexLock mutex_lock(&mutex_);
lbd_threshold_ = lbd;
}
size_t UniqueClauseStream::HashClause(absl::Span<const int> clause,
size_t hash_seed) {
size_t hash = absl::HashOf(hash_seed, clause.size());
@@ -1270,22 +1237,24 @@ int UniqueClauseStream::NumLiteralsOfSize(int size) const {
return BufferForSize(size).size();
}
SharedClausesManager::SharedClausesManager(bool always_synchronize,
absl::Duration share_frequency)
: always_synchronize_(always_synchronize),
share_frequency_(share_frequency) {}
SharedClausesManager::SharedClausesManager(bool always_synchronize)
: always_synchronize_(always_synchronize) {}
int SharedClausesManager::RegisterNewId() {
int SharedClausesManager::RegisterNewId(bool may_terminate_early) {
absl::MutexLock mutex_lock(&mutex_);
num_full_workers_ += may_terminate_early ? 0 : 1;
const int id = id_to_last_processed_binary_clause_.size();
id_to_last_processed_binary_clause_.resize(id + 1, 0);
id_to_last_returned_batch_.resize(id + 1, 0);
id_to_last_finished_batch_.resize(id + 1, 0);
id_to_last_returned_batch_.resize(id + 1, -1);
id_to_last_finished_batch_.resize(id + 1, -1);
id_to_clauses_exported_.resize(id + 1, 0);
id_to_clause_stream_.emplace_back();
return id;
}
bool SharedClausesManager::ShouldReadBatch(int reader_id, int writer_id) {
return reader_id != writer_id;
}
void SharedClausesManager::SetWorkerNameForId(int id,
absl::string_view worker_name) {
absl::MutexLock mutex_lock(&mutex_);
@@ -1312,18 +1281,25 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
}
}
std::vector<absl::Span<const int>> SharedClausesManager::GetUnseenClauses(
int id) {
std::vector<absl::Span<const int>> result;
void SharedClausesManager::AddBatch(int id, CompactVectorVector<int> batch) {
absl::MutexLock mutex_lock(&mutex_);
for (int i = id_to_last_returned_batch_[id]; i < batches_.size(); ++i) {
for (int j = 0; j < batches_[i].size(); ++j) {
result.push_back(batches_[i][j]);
id_to_clauses_exported_[id] += batch.size();
pending_batches_.push_back(std::move(batch));
}
const CompactVectorVector<int>& SharedClausesManager::GetUnseenClauses(int id) {
std::vector<absl::Span<const int>> result;
{
absl::MutexLock mutex_lock(&mutex_);
id_to_last_finished_batch_[id] = id_to_last_returned_batch_[id];
if (id_to_last_returned_batch_[id] + 1 < batches_.size()) {
id_to_last_returned_batch_[id] += 1;
return batches_[id_to_last_returned_batch_[id]];
}
}
id_to_last_finished_batch_[id] = id_to_last_returned_batch_[id];
id_to_last_returned_batch_[id] = batches_.size();
return result;
static CompactVectorVector<int>* const empty_batch =
new CompactVectorVector<int>();
return *empty_batch;
}
void SharedClausesManager::GetUnseenBinaryClauses(
@@ -1357,96 +1333,47 @@ void SharedClausesManager::LogStatistics(SolverLogger* logger) {
}
void SharedClausesManager::Synchronize() {
absl::MutexLock mutex_lock(&mutex_);
last_visible_binary_clause_ = added_binary_clauses_.size();
const int num_workers = id_to_clause_stream_.size();
if (num_workers <= 1) return;
if (!share_timer_.IsRunning()) share_timer_.Start();
if (share_timer_.GetDuration() < share_frequency_) return;
share_timer_.Restart();
std::vector<CompactVectorVector<int>> batches_to_merge;
{
absl::MutexLock mutex_lock(&mutex_);
last_visible_binary_clause_ = added_binary_clauses_.size();
const int num_workers = id_to_last_processed_binary_clause_.size();
if (num_workers <= 1) return;
// Tune LBD threshold for individual workers based on how the worker's buffer
// is. We aim to ensure workers can always export their fair share of clauses.
for (int id = 0; id < num_workers; ++id) {
UniqueClauseStream& stream = id_to_clause_stream_[id];
const int lbd_threshold = stream.lbd_threshold();
const int num_buffered_literals = stream.NumBufferedLiterals();
const bool underfull =
num_buffered_literals <
UniqueClauseStream::kMaxLiteralsPerBatch / num_workers;
const bool overfull =
num_buffered_literals >
2 * UniqueClauseStream::kMaxLiteralsPerBatch / num_workers;
const int new_lbd = std::clamp(lbd_threshold + underfull - overfull, 2,
UniqueClauseStream::kMaxClauseSize);
if (new_lbd != lbd_threshold) {
VLOG(2) << id_to_worker_name_[id]
<< " sharing clauses with lbd <= " << new_lbd;
stream.set_lbd_threshold(new_lbd);
if (pending_batches_.size() >= num_full_workers_) {
batches_to_merge = std::move(pending_batches_);
}
}
std::vector<int> ids(num_workers);
int literals_to_fill = UniqueClauseStream::kMaxLiteralsPerBatch;
for (int size = UniqueClauseStream::kMinClauseSize;
size <= UniqueClauseStream::kMaxClauseSize; ++size) {
ids.clear();
for (int id = 0; id < num_workers; ++id) {
if (id_to_clause_stream_[id].NumBufferedLiteralsOfSize(size) > 0) {
ids.push_back(id);
// Delete batches that have been consumed by all workers.
// Keep a few batches around for startup (min finished batch doesn't count
// workers that haven't registered yet).
if (batches_.size() > kMinBatches) {
const int min_finished_batch =
std::min<int>(batches_.size() - kMinBatches,
*absl::c_min_element(id_to_last_finished_batch_) + 1);
for (int i = 0; i < min_finished_batch; ++i) {
VLOG(2) << "Erasing batch";
batches_.pop_front();
}
for (int id = 0; id < id_to_last_finished_batch_.size(); ++id) {
id_to_last_returned_batch_[id] -= min_finished_batch;
id_to_last_finished_batch_[id] -= min_finished_batch;
}
}
// Use progressive filling to attempt to fill the batch with clauses of
// minimum size, this is max-min fair.
while (!ids.empty()) {
const int clauses_to_fill = literals_to_fill / size;
if (clauses_to_fill == 0) break;
// Some workers need to export more clauses to fill the batch due to
// rounding, but we don't want all workers to round up.
const int num_to_round_up = clauses_to_fill % ids.size();
for (int i = 0; i < ids.size(); ++i) {
const bool round_up = i < num_to_round_up;
const int id = ids[i];
const int shared = id_to_clause_stream_[id].FillUpstreamBuffer(
all_clauses_, size, clauses_to_fill / ids.size() + round_up);
id_to_clauses_exported_[id] += shared;
if (shared == 0 ||
id_to_clause_stream_[id].NumBufferedLiteralsOfSize(size) == 0) {
ids[i] = ids.back();
ids.pop_back();
--i;
}
}
// TODO(user): We could cleanup binary clauses that have been consumed.
}
if (batches_to_merge.empty()) return;
UniqueClauseStream next_batch;
for (const auto& batch : batches_to_merge) {
for (int i = 0; i < batch.size(); ++i) {
next_batch.Add(batch[i]);
}
}
if (all_clauses_.NumBufferedLiterals() > 0) {
batches_.push_back(all_clauses_.NextBatch());
VLOG(2) << "Batch #" << batches_.size() << " w/ " << batches_.back().size()
<< " clauses max size = "
<< batches_.back()[batches_.back().size() - 1].size();
if (next_batch.NumBufferedLiterals() > 0) {
absl::MutexLock mutex_lock(&mutex_);
VLOG(2) << "Merging batch";
batches_.push_back(next_batch.NextBatch());
}
// Delete batches that have been consumed by all workers.
// Keep a few batches around for startup (min finished batch doesn't count
// workers that haven't registered yet).
// This also ensures that our fingerprint table always contains the last few
// batches, so we reduce the chance of an old buffered duplicate clause on
// a worker being emitted from the global stream multiple times.
if (batches_.size() < kMinBatches) return;
const int min_finished_batch =
std::min<int>(batches_.size() - kMinBatches,
*absl::c_min_element(id_to_last_finished_batch_));
for (int i = 0; i < min_finished_batch; ++i) {
VLOG(2) << "Erasing batch";
for (int i = 0; i < batches_.front().size(); ++i) {
all_clauses_.Delete(batches_.front()[i]);
}
batches_.pop_front();
}
for (int id = 0; id < id_to_last_finished_batch_.size(); ++id) {
id_to_last_returned_batch_[id] -= min_finished_batch;
id_to_last_finished_batch_[id] -= min_finished_batch;
}
// TODO(user): We could cleanup binary clauses that have been consumed.
}
void SharedStatistics::AddStats(

View File

@@ -622,108 +622,87 @@ class SharedBoundsManager {
// It has a finite size internal buffer that is a small multiple of the batch
// size.
//
// This class is thread-safe, the idea is to have one per worker plus a
// global one to deduplicate between workers to minimize contention.
//
// This uses a finite buffer, so some clauses may be dropped if we generate too
// many more than we export, but that is rarely a problem because we never
// overfill the "global" stream, and if we drop a clause on a worker, one of the
// following will most likely happen:
// many more than we export, but that is rarely a problem because if we drop a
// clause on a worker, one of the following will most likely happen:
// 1. Some other worker learns the clause and shares it later.
// 2. All other workers also learn and drop the clause.
// 3. No other worker learns the clause, so it was not that helpful anyway.
//
// Note that this uses literals as encoded in a cp_model.proto. Thus, the
// literals can be negative numbers.
//
// TODO(user): This class might not want to live in this file now it no
// longer needs to be thread-safe.
class UniqueClauseStream {
public:
static constexpr int kMinClauseSize = 3;
static constexpr int kMaxClauseSize = 32;
static constexpr int kMinLbd = 2;
static constexpr int kMaxLbd = 5;
// Export 4KiB of clauses per batch.
static constexpr int kMaxLiteralsPerBatch = 4096 / sizeof(int);
// Bound the total literals we buffer, approximately enforced so shorter
// clauses can replace longer ones. This can be larger than
// kMaxLiteralsPerBatch (hence the separate constant), but experiments suggest
// that this doesn't help.
static constexpr int kMaxBufferedLiterals = kMaxLiteralsPerBatch;
UniqueClauseStream();
// Move only - this is an expensive class to copy.
UniqueClauseStream(const UniqueClauseStream&) = delete;
UniqueClauseStream(UniqueClauseStream&&) = default;
// Adds the clause to a future batch and returns true if the clause was added.
// Otherwise returns false. This may return false if the buffer is full.
// It will not block the clause if it is dropped to avoid unbounded growth of
// the hash table.
bool Add(absl::Span<const int> clause) ABSL_LOCKS_EXCLUDED(mutex_);
// Adds the clause to a future batch and returns true if the clause is new,
// otherwise returns false.
bool Add(absl::Span<const int> clause, int lbd = 2);
// Lazily deletes a clause with the same hash, returns true if it was present.
// The deleted clause will not be exported (either via NextBatch or
// FillUpstreamBuffer). A clause with the same hash may be re-added after
// calling Delete. If another clause with the same hash is added before the
// deleted clause is emitted then both clauses may be emitted.
bool Delete(absl::Span<const int> clause) ABSL_LOCKS_EXCLUDED(mutex_);
// Stop a clause being added to future batches.
// Returns true if the clause is new.
// This is approximate and can have false positives and negatives, it is still
// guaranteed to prevent adding the same clause twice to the next batch.
bool BlockClause(absl::Span<const int> clause);
// Returns a set of clauses totalling up to kMaxLiteralsPerBatch and removes
// exported clauses from the internal buffer.
CompactVectorVector<int> NextBatch() ABSL_LOCKS_EXCLUDED(mutex_);
// Returns a set of clauses totalling up to kMaxLiteralsPerBatch and clears
// the internal buffer.
// Increases the LBD threshold if the batch is underfull, and decreases it if
// too many clauses were dropped.
CompactVectorVector<int> NextBatch();
// Adds up to max_clauses_to_export clauses of a given size to upstream and
// removes them from the internal buffer.
int FillUpstreamBuffer(UniqueClauseStream& upstream, int clause_size,
int max_clauses_to_export) ABSL_LOCKS_EXCLUDED(mutex_);
// Returns the number of literals in the buffer in clauses with size <=
// max_size.
int NumBufferedLiteralsOfSize(int size) const ABSL_LOCKS_EXCLUDED(mutex_) {
absl::MutexLock lock(&mutex_);
return NumLiteralsOfSize(size);
void ClearFingerprints() {
old_fingerprints_.clear();
fingerprints_.clear();
fingerprints_.reserve(kMaxFingerprints);
}
int NumBufferedLiterals() const ABSL_LOCKS_EXCLUDED(mutex_);
// Returns true if the stream can accept a clause of the specified size and
// LBD without dropping it.
bool CanAccept(int size, int lbd) const;
// Returns the number of buffered literals in clauses of a given size.
int NumLiteralsOfSize(int size) const;
int NumBufferedLiterals() const;
// Delete longest clauses while keeping at least kMaxBufferedLiterals.
// This guarantees that CanAccept will return the same result as before, and
// at least the next batch will contain the same clauses, but we will emit
// fewer old, long clauses in the future.
void RemoveWorstClauses();
int lbd_threshold() const ABSL_LOCKS_EXCLUDED(mutex_) {
absl::MutexLock lock(&mutex_);
return lbd_threshold_;
}
void set_lbd_threshold(int lbd) ABSL_LOCKS_EXCLUDED(mutex_);
int lbd_threshold() const { return lbd_threshold_; }
void set_lbd_threshold(int lbd_threshold) { lbd_threshold_ = lbd_threshold; }
// Computes a hash that is independent of the order of literals in the clause.
static size_t HashClause(absl::Span<const int> clause, size_t hash_seed = 0);
private:
bool BlockClause(absl::Span<const int> clause)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
std::vector<int>* MutableBufferForSize(int size)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
// This needs to be >> the number of clauses we can plausibly learn in
// a few seconds.
constexpr static size_t kMaxFingerprints = 1024 * 1024 / sizeof(size_t);
constexpr static int kNumSizes = kMaxClauseSize - kMinClauseSize + 1;
std::vector<int>* MutableBufferForSize(int size) {
return &clauses_by_size_[size - kMinClauseSize];
}
absl::Span<const int> BufferForSize(int size) const
ABSL_SHARED_LOCKS_REQUIRED(mutex_) {
absl::Span<const int> BufferForSize(int size) const {
return clauses_by_size_[size - kMinClauseSize];
}
absl::Span<const int> NextClause(int size) const
ABSL_SHARED_LOCKS_REQUIRED(mutex_);
void PopClause(int size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
absl::Span<const int> NextClause(int size) const;
void PopClause(int size);
// Computes the number of clauses of a given size.
int NumClausesOfSize(int size) const ABSL_SHARED_LOCKS_REQUIRED(mutex_);
int NumLiteralsOfSize(int size) const ABSL_SHARED_LOCKS_REQUIRED(mutex_);
int NumClausesOfSize(int size) const;
mutable absl::Mutex mutex_;
int lbd_threshold_ ABSL_GUARDED_BY(mutex_) = 2;
absl::flat_hash_set<size_t> fingerprints_ ABSL_GUARDED_BY(mutex_);
std::array<std::vector<int>, kMaxClauseSize - kMinClauseSize + 1>
clauses_by_size_ ABSL_GUARDED_BY(mutex_);
int lbd_threshold_ = kMinLbd;
int64_t dropped_literals_since_last_batch_ = 0;
absl::flat_hash_set<size_t> fingerprints_;
absl::flat_hash_set<size_t> old_fingerprints_;
std::array<std::vector<int>, kNumSizes> clauses_by_size_;
};
// This class holds clauses found and shared by workers.
@@ -735,14 +714,15 @@ class UniqueClauseStream {
// literals can be negative numbers.
class SharedClausesManager {
public:
explicit SharedClausesManager(bool always_synchronize,
absl::Duration share_frequency);
explicit SharedClausesManager(bool always_synchronize);
void AddBinaryClause(int id, int lit1, int lit2);
// Returns new glue clauses.
// The spans are guaranteed to remain valid until the next call to
// SyncClauses().
std::vector<absl::Span<const int>> GetUnseenClauses(int id);
const CompactVectorVector<int>& GetUnseenClauses(int id);
void AddBatch(int id, CompactVectorVector<int> batch);
// Fills new_clauses with
// {{lit1 of clause1, lit2 of clause1},
@@ -752,16 +732,9 @@ class SharedClausesManager {
std::vector<std::pair<int, int>>* new_clauses);
// Ids are used to identify which worker is exporting/importing clauses.
int RegisterNewId();
int RegisterNewId(bool may_terminate_early);
void SetWorkerNameForId(int id, absl::string_view worker_name);
// A worker can add or remove clauses from its own clause set.
// Retains ownership of the returned ClauseFilter.
UniqueClauseStream* GetClauseStream(int id) {
absl::ReaderMutexLock mutex_lock(&mutex_);
return &id_to_clause_stream_[id];
}
// Search statistics.
void LogStatistics(SolverLogger* logger);
@@ -770,8 +743,12 @@ class SharedClausesManager {
void Synchronize();
private:
static constexpr int kMinBatches = 10;
absl::Mutex mutex_;
// Returns true if `reader_id` should read batches produced by `writer_id`.
bool ShouldReadBatch(int reader_id, int writer_id)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
static constexpr int kMinBatches = 64;
mutable absl::Mutex mutex_;
// Binary clauses:
// Cache to avoid adding the same binary clause twice.
@@ -782,18 +759,22 @@ class SharedClausesManager {
std::vector<int> id_to_last_processed_binary_clause_ ABSL_GUARDED_BY(mutex_);
int last_visible_binary_clause_ ABSL_GUARDED_BY(mutex_) = 0;
// Longer clauses:
UniqueClauseStream all_clauses_ ABSL_GUARDED_BY(mutex_);
// This is slightly subtle - we need to track the batches that might be
// currently being processed by each worker.
// currently being processed by each worker to make sure we don't erase any
// batch that a worker might currently be reading.
std::vector<int> id_to_last_returned_batch_ ABSL_GUARDED_BY(mutex_);
std::vector<int> id_to_last_finished_batch_ ABSL_GUARDED_BY(mutex_);
std::deque<CompactVectorVector<int>> batches_ ABSL_GUARDED_BY(mutex_);
std::deque<UniqueClauseStream> id_to_clause_stream_ ABSL_GUARDED_BY(mutex_);
WallTimer share_timer_ ABSL_GUARDED_BY(mutex_);
// pending_batches_ contains clauses produced by individual workers that have
// not yet been merged into batches_, which can be read by other workers. When
// this is long enough they will be merged into a single batch and appended to
// batches_.
std::vector<CompactVectorVector<int>> pending_batches_
ABSL_GUARDED_BY(mutex_);
int num_full_workers_ ABSL_GUARDED_BY(mutex_) = 0;
const bool always_synchronize_ = true;
const absl::Duration share_frequency_;
// Stats:
std::vector<int64_t> id_to_clauses_exported_;

View File

@@ -833,10 +833,9 @@ TEST(SharedResponseManagerTest, Callback) {
}
TEST(SharedClausesManagerTest, SyncApi) {
SharedClausesManager manager(/*always_synchronize=*/true,
/*share_frequency=*/absl::ZeroDuration());
EXPECT_EQ(0, manager.RegisterNewId());
EXPECT_EQ(1, manager.RegisterNewId());
SharedClausesManager manager(/*always_synchronize=*/true);
EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
manager.AddBinaryClause(/*id=*/0, 1, 2);
std::vector<std::pair<int, int>> new_clauses;
@@ -868,19 +867,6 @@ TEST(UniqueClauseStreamTest, AddIgnoresDuplicates) {
EXPECT_EQ(stream.NumBufferedLiterals(), 3);
}
TEST(UniqueClauseStreamTest, Delete) {
UniqueClauseStream stream;
EXPECT_TRUE(stream.Add({3, 2, 1}));
EXPECT_TRUE(stream.Delete({1, 2, 3}));
EXPECT_FALSE(stream.Delete({1, 2, 3, 4}));
EXPECT_THAT(stream.NextBatch(), ::testing::IsEmpty());
EXPECT_TRUE(stream.Add({2, 3, 1}));
EXPECT_EQ(stream.NumBufferedLiterals(), 3);
stream.NextBatch();
EXPECT_TRUE(stream.Delete({1, 2, 3}));
}
TEST(UniqueClauseStreamTest, AddIgnoresInvalidSizeClauses) {
UniqueClauseStream stream;
std::vector<int> long_clause;
@@ -905,46 +891,20 @@ TEST(UniqueClauseStreamTest, ExportsShortestClauses) {
}
// Batch 1 should be filled with size 3 clauses.
EXPECT_EQ(stream.NextBatch().size(), 1024 / 3);
// Batch 2 should be filled with size 4 clauses.
EXPECT_EQ(stream.NextBatch().size(), 1024 / 4);
// Batch 3 should be filled with size 5 clauses.
EXPECT_EQ(stream.NextBatch().size(), 1024 / 5);
}
TEST(UniqueClauseStreamTest, RemoveWorstClauses) {
UniqueClauseStream stream;
// Fill the buffer
for (int i = 0; i < UniqueClauseStream::kMaxBufferedLiterals / 6; ++i) {
stream.Add({i + 1, i + 256, i + 512, -4, -3, -2});
}
for (int i = 0; i < UniqueClauseStream::kMaxLiteralsPerBatch / 2 / 3; ++i) {
stream.Add({i + 1, i + 256, i + 512});
}
stream.RemoveWorstClauses();
EXPECT_GE(stream.NumBufferedLiterals(),
UniqueClauseStream::kMaxBufferedLiterals);
EXPECT_LT(stream.NumBufferedLiterals(),
UniqueClauseStream::kMaxBufferedLiterals + 6);
EXPECT_TRUE(stream.CanAccept(3, /*lbd=*/2));
EXPECT_FALSE(stream.CanAccept(6, /*lbd=*/2));
// Make sure none of the size 3 clauses were removed.
EXPECT_EQ(stream.NextBatch().size(),
UniqueClauseStream::kMaxLiteralsPerBatch / 2 / 3 +
UniqueClauseStream::kMaxBufferedLiterals / 2 / 6);
UniqueClauseStream::kMaxLiteralsPerBatch / 3);
// Batch 2 should be empty.
EXPECT_TRUE(stream.NextBatch().empty());
}
TEST(UniqueClauseStreamTest, DropsClauses) {
UniqueClauseStream stream;
// We shouldn't drop any clause where Add returns true.
int literals_successfully_added = 0;
for (int i = 0; i < 256 / 4; ++i) {
literals_successfully_added +=
4 * stream.Add({i + 1, i + 256, i + 512, -4});
}
for (int i = 0; i < 256 / 3; ++i) {
for (int i = 0; i < UniqueClauseStream::kMaxLiteralsPerBatch / 3; ++i) {
literals_successfully_added += 3 * stream.Add({i + 1, i + 256, i + 512});
}
for (int i = 0; i < 1024 * 1024 / 5; ++i) {
@@ -952,26 +912,18 @@ TEST(UniqueClauseStreamTest, DropsClauses) {
5 * stream.Add({i + 1, i + 256, i + 512, i + 1024, -2048});
}
EXPECT_FALSE(stream.CanAccept(3, /*lbd=*/3));
EXPECT_TRUE(stream.CanAccept(3, /*lbd=*/2));
EXPECT_TRUE(stream.CanAccept(4, /*lbd=*/2));
EXPECT_FALSE(stream.CanAccept(5, /*lbd=*/2));
EXPECT_EQ(stream.NumBufferedLiterals(), literals_successfully_added);
EXPECT_EQ(
literals_successfully_added,
256 - 256 % 3 + // size 3 clauses
256 - 256 % 4 + // size 4 clauses
UniqueClauseStream::kMaxBufferedLiterals -
UniqueClauseStream::kMaxBufferedLiterals % 5); // size 5 clauses
// Batch 1 should be filled with size 3 clauses.
EXPECT_EQ(stream.NextBatch().size(), 256 / 3 + 256 / 4 + 512 / 5);
EXPECT_GT(stream.NumBufferedLiterals(),
UniqueClauseStream::kMaxLiteralsPerBatch - 5);
// Batch should be filled with size 3 clauses.
EXPECT_EQ(stream.NextBatch().size(),
UniqueClauseStream::kMaxLiteralsPerBatch / 3);
EXPECT_TRUE(stream.NextBatch().empty());
}
TEST(SharedClausesManagerTest, NonSyncApi) {
SharedClausesManager manager(/*always_synchronize=*/false,
/*share_frequency=*/absl::ZeroDuration());
EXPECT_EQ(0, manager.RegisterNewId());
EXPECT_EQ(1, manager.RegisterNewId());
SharedClausesManager manager(/*always_synchronize=*/false);
EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
manager.AddBinaryClause(/*id=*/0, 1, 2);
std::vector<std::pair<int, int>> new_clauses;
@@ -1018,115 +970,92 @@ TEST(SharedClausesManagerTest, NonSyncApi) {
}
TEST(SharedClausesManagerTest, ShareGlueClauses) {
SharedClausesManager manager(/*always_synchronize=*/true,
absl::ZeroDuration());
ASSERT_EQ(0, manager.RegisterNewId());
ASSERT_EQ(1, manager.RegisterNewId());
auto* stream0 = manager.GetClauseStream(0);
auto* stream1 = manager.GetClauseStream(1);
// Add a bunch of clauses that will be skipped in the first batch.
for (int i = 0; i < 1024 / 8; ++i) {
EXPECT_TRUE(stream0->Add({1, 2, 3, 4, 5, 6, 7, i + 8}));
SharedClausesManager manager(/*always_synchronize=*/true);
ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
UniqueClauseStream stream0;
UniqueClauseStream stream1;
// Add a bunch of clauses that will be skipped batch.
for (int i = 0; i < UniqueClauseStream::kMaxLiteralsPerBatch / 8; ++i) {
EXPECT_TRUE(stream0.Add({1, 2, 3, 4, 5, 6, 7, i + 8}));
}
EXPECT_EQ(stream0->NumBufferedLiterals(), 1024);
EXPECT_EQ(stream0.NumBufferedLiterals(),
UniqueClauseStream::kMaxLiteralsPerBatch);
// Fill 1 batch of shorter clauses.
for (int i = 0; i < 1024 / 4; ++i) {
stream1->Add({1, 2, 3, i + 4});
for (int i = 0; i < UniqueClauseStream::kMaxLiteralsPerBatch / 4; ++i) {
stream1.Add({1, 2, 3, i + 4});
}
EXPECT_EQ(stream1->NumBufferedLiterals(), 1024);
manager.AddBatch(0, stream0.NextBatch());
manager.AddBatch(1, stream1.NextBatch());
manager.Synchronize();
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
manager.Synchronize();
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::SizeIs(1024 / 4));
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::SizeIs(1024 / 4));
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
manager.Synchronize();
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::SizeIs(1024 / 8));
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::SizeIs(1024 / 8));
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
manager.Synchronize();
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
}
TEST(SharedClausesManagerTest, ShareFrequency) {
SharedClausesManager manager(/*always_synchronize=*/true,
absl::InfiniteDuration());
ASSERT_EQ(0, manager.RegisterNewId());
ASSERT_EQ(1, manager.RegisterNewId());
auto* stream0 = manager.GetClauseStream(0);
auto* stream1 = manager.GetClauseStream(1);
for (int i = 0; i < 1024 / 5; ++i) {
stream0->Add({i + 1, i + 513, 2048, 2049, -10});
stream1->Add({i + 1, i + 513, 2048, 2049, -10});
}
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
manager.Synchronize();
EXPECT_THAT(manager.GetUnseenClauses(0),
::testing::SizeIs(UniqueClauseStream::kMaxLiteralsPerBatch / 4));
EXPECT_THAT(manager.GetUnseenClauses(1),
::testing::SizeIs(UniqueClauseStream::kMaxLiteralsPerBatch / 4));
EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
}
// Verifies that the per-stream LBD acceptance threshold increases when workers
// stop producing enough clauses to fill their export batches.
TEST(SharedClausesManagerTest, LbdThresholdIncrease) {
  SharedClausesManager manager(/*always_synchronize=*/true);
  ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
  ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
  UniqueClauseStream stream0;
  UniqueClauseStream stream1;
  const int kExpectedClauses = UniqueClauseStream::kMaxLiteralsPerBatch / 5;
  // Fill one full batch of size-5 clauses per worker.
  for (int i = 0; i < kExpectedClauses; ++i) {
    stream0.Add({i + 1, i + 513, 2048, 2049, -10});
    stream1.Add({i + 1, i + 513, 2048, 2049, -10});
  }
  manager.AddBatch(0, stream0.NextBatch());
  manager.AddBatch(1, stream1.NextBatch());
  manager.Synchronize();
  EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::SizeIs(kExpectedClauses));
  EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::SizeIs(kExpectedClauses));
  // Full batches were shared, so the threshold stays at its minimum.
  EXPECT_EQ(stream0.lbd_threshold(), 2);
  EXPECT_EQ(stream1.lbd_threshold(), 2);
  manager.Synchronize();
  // Export empty batches: the streams have no buffered clauses left.
  manager.AddBatch(0, stream0.NextBatch());
  manager.AddBatch(1, stream1.NextBatch());
  EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
  EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
  manager.Synchronize();
  EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
  EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
  // Empty batches signal spare sharing capacity, so the threshold is raised to
  // accept higher-LBD clauses.
  EXPECT_EQ(stream0.lbd_threshold(), 3);
  EXPECT_EQ(stream1.lbd_threshold(), 3);
}
// Verifies that the LBD acceptance threshold decreases again for a stream that
// overflows its export batch, while short clauses are still shared first.
TEST(SharedClausesManagerTest, LbdThresholdDecrease) {
  SharedClausesManager manager(/*always_synchronize=*/true);
  ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
  ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
  UniqueClauseStream stream0;
  UniqueClauseStream stream1;
  // Exporting empty batches should increase the LBD threshold.
  manager.Synchronize();
  manager.AddBatch(0, stream0.NextBatch());
  manager.AddBatch(1, stream1.NextBatch());
  const int kSize4Clauses = UniqueClauseStream::kMaxLiteralsPerBatch / 4 / 2;
  const int kSize5ClausesAdded = UniqueClauseStream::kMaxLiteralsPerBatch / 5;
  // Then add 1/2 batch of size 4 clauses to each worker.
  for (int i = 0; i < kSize4Clauses; ++i) {
    stream0.Add({i + 1, i + 512, 2048, 2049});
    stream1.Add({i + 1, i + 513, 2048, -123});
  }
  // Then add loads of longer clauses to just stream0.
  for (int i = 0; i < kSize5ClausesAdded; ++i) {
    stream0.Add({i + 1, 2, 3, -10, 12});
  }
  // Nothing is visible before the batches are exported and synchronized.
  EXPECT_THAT(manager.GetUnseenClauses(0), ::testing::IsEmpty());
  EXPECT_THAT(manager.GetUnseenClauses(1), ::testing::IsEmpty());
  // The earlier empty batches raised both thresholds.
  EXPECT_EQ(stream0.lbd_threshold(), 3);
  EXPECT_EQ(stream1.lbd_threshold(), 3);
  manager.AddBatch(0, stream0.NextBatch());
  manager.AddBatch(1, stream1.NextBatch());
  manager.Synchronize();
  // Each worker sees the other's clauses plus its own shared ones; the short
  // size-4 clauses from both workers are shared first.
  EXPECT_THAT(manager.GetUnseenClauses(0),
              ::testing::SizeIs(2 * kSize4Clauses));
  EXPECT_THAT(manager.GetUnseenClauses(1),
              ::testing::SizeIs(2 * kSize4Clauses));
  // stream0 overflowed its batch, so its threshold drops back to the minimum.
  EXPECT_EQ(stream0.lbd_threshold(), 2);
}
} // namespace
} // namespace sat