[CP-SAT] more work on lrat; canonicalize Boolean variables in a few more places

This commit is contained in:
Laurent Perron
2025-12-05 15:59:03 +01:00
committed by Corentin Le Molgat
parent d41e49724a
commit b7d1dc65dc
15 changed files with 854 additions and 179 deletions

View File

@@ -1288,6 +1288,7 @@ cc_library(
"@abseil-cpp//absl/container:btree",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/container:inlined_vector",
"@abseil-cpp//absl/flags:flag",
"@abseil-cpp//absl/hash",
"@abseil-cpp//absl/log",

View File

@@ -467,6 +467,14 @@ bool ClauseManager::InprocessingAddUnitClause(ClauseId unit_clause_id,
Literal true_literal) {
DCHECK_EQ(trail_->CurrentDecisionLevel(), 0);
if (trail_->Assignment().LiteralIsTrue(true_literal)) return true;
if (trail_->Assignment().LiteralIsFalse(true_literal)) {
if (lrat_proof_handler_ != nullptr) {
lrat_proof_handler_->AddInferredClause(
clause_id_generator_->GetNextId(), {},
{unit_clause_id, trail_->GetUnitClauseId(true_literal.Variable())});
}
return false;
}
trail_->EnqueueWithUnitReason(unit_clause_id, true_literal);

View File

@@ -35,6 +35,7 @@
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/flags/flag.h"
#include "absl/hash/hash.h"
#include "absl/log/check.h"
@@ -9467,6 +9468,344 @@ bool CpModelPresolver::MergeNoOverlap2DConstraints() {
return true;
}
namespace {
bool ConstraintIsEncodingBound(const ConstraintProto& ct) {
if (ct.constraint_case() != ConstraintProto::kLinear) return false;
if (ct.linear().vars_size() != 1) return false;
if (ct.linear().coeffs(0) != 1) return false;
if (ct.enforcement_literal_size() != 1) return false;
return true;
}
} // namespace
// Return true if something changed.
bool CpModelPresolver::DetectEncodedComplexDomain(
PresolveContext* context, ConstraintProto* ct,
const Bitset64<int>& pertinent_bools) {
if (context->ModelIsUnsat()) return false;
if (ct->constraint_case() != ConstraintProto::kAtMostOne &&
ct->constraint_case() != ConstraintProto::kExactlyOne &&
ct->constraint_case() != ConstraintProto::kBoolOr) {
return false;
}
// Handling exaclty_one, at_most_one and bool_or is pretty similar. If we have
// l1 <=> v \in D1
// l2 <=> v \in D2
//
// We built
// l <=> v \in (D1 U D2).
//
// Moreover, if we have exactly_one(l1, l2, ...) or at_most_one(l1, l2, ...),
// we know that v cannot be in the intersection of D1 and D2. Thus, we first
// unconditionally remove (D1 ∩ D2) from the domain of v, making
// (l1=true and l2=true) impossible and allowing us to write our clauses as
// exactly_one(l1 or l2, ...) or at_most_one(l1 or l2, ...).
//
// Thus, other than the domain reduction that should not be done for the
// bool_or, all we need is to create a variable
// (l1 or l2) == l <=> (v \in (D1 U D2)).
google::protobuf::RepeatedField<int32_t>& literals =
ct->constraint_case() == ConstraintProto::kAtMostOne
? *ct->mutable_at_most_one()->mutable_literals()
: (ct->constraint_case() == ConstraintProto::kExactlyOne
? *ct->mutable_exactly_one()->mutable_literals()
: *ct->mutable_bool_or()->mutable_literals());
if (literals.size() <= 1) return false;
if (!ct->enforcement_literal().empty()) {
// TODO(user): support this case if it any problem needs it.
return false;
}
struct Linear1Info {
int lit = -1;
int positive_linear1_ct = -1;
int negative_linear1_ct = -1;
};
absl::flat_hash_map<int, absl::InlinedVector<Linear1Info, 1>> var_to_linear1;
for (const int lit : literals) {
if (PositiveRef(lit) < pertinent_bools.size() &&
!pertinent_bools[PositiveRef(lit)]) {
continue;
}
bool or_and_single_var_linear1 = true;
Linear1Info info;
int var = -1;
for (const int c : context->VarToConstraints(PositiveRef(lit))) {
if (c < 0) {
or_and_single_var_linear1 = false;
break;
}
const ConstraintProto& other_ct = context->working_model->constraints(c);
if (&other_ct == ct) continue;
if (!ConstraintIsEncodingBound(other_ct)) {
or_and_single_var_linear1 = false;
break;
}
if (other_ct.enforcement_literal(0) != lit &&
other_ct.enforcement_literal(0) != NegatedRef(lit)) {
or_and_single_var_linear1 = false;
break;
}
if (var == -1) {
var = other_ct.linear().vars(0);
} else if (var != other_ct.linear().vars(0)) {
or_and_single_var_linear1 = false;
break;
}
info.lit = lit;
if (other_ct.enforcement_literal(0) == lit) {
info.positive_linear1_ct = c;
} else {
DCHECK_EQ(other_ct.enforcement_literal(0), NegatedRef(lit));
info.negative_linear1_ct = c;
}
}
// When we have
// lit => var in D1
// ~lit => var in D2
// we can represent this on a line:
//
// ----------------D1----------------
// ----------------D2---------------
// |+++++++++++|*********************|++++++++++|
// lit=false lit unconstrained lit=true
//
// Handling the case where the variable is unconstrained by the lit is a
// bit of a pain: we want to replace two literals in a exactly_one by a
// single one, and if they are both unconstrained we might be forced to pick
// one arbitrarily to set to true. In any case, this is not a proper
// encoding of a complex domain, so we just ignore it.
// TODO(user): This can be implemented if it turns out to be common.
if (or_and_single_var_linear1 && info.negative_linear1_ct != -1 &&
info.positive_linear1_ct != -1) {
const Domain domain_enforced_lit = ReadDomainFromProto(
context->working_model->constraints(info.positive_linear1_ct)
.linear());
// ~lit1 => var in domain_enforced_not_lit1
const Domain domain_enforced_not_lit = ReadDomainFromProto(
context->working_model->constraints(info.negative_linear1_ct)
.linear());
if (domain_enforced_lit.IntersectionWith(domain_enforced_not_lit)
.IsEmpty()) {
var_to_linear1[var].push_back(info);
}
}
}
// Ignore all variables that only appear once.
std::vector<std::pair<int, std::vector<Linear1Info>>> var_to_linear1_infos;
for (const auto& [var, linear1_infos] : var_to_linear1) {
if (linear1_infos.size() > 1) {
var_to_linear1_infos.push_back(
{var, std::vector<Linear1Info>(linear1_infos.begin(),
linear1_infos.end())});
}
}
if (var_to_linear1_infos.empty()) return false;
// We have some variables to simplify! Start by sorting to make the code
// deterministic.
absl::c_sort(var_to_linear1_infos,
[](const std::pair<int, std::vector<Linear1Info>>& a,
const std::pair<int, std::vector<Linear1Info>>& b) {
return a.first < b.first;
});
// Doing the general code is rather complex, so we will just simplify one
// variable and two literals at a time, and leave for the presolve fixpoint
// to do the rest.
for (const auto& [var, infos] : var_to_linear1_infos) {
const Linear1Info& info1 = infos[0];
const Linear1Info& info2 = infos[1];
const int lit1 = info1.lit;
const int lit2 = info2.lit;
const Domain original_var_domain = context->DomainOf(var);
DCHECK_NE(info1.positive_linear1_ct, -1);
DCHECK_NE(info2.positive_linear1_ct, -1);
DCHECK_NE(info1.negative_linear1_ct, -1);
DCHECK_NE(info2.negative_linear1_ct, -1);
// lit1 => var in domain_enforced_lit1
const Domain domain_enforced_lit1 = ReadDomainFromProto(
context->working_model->constraints(info1.positive_linear1_ct)
.linear());
// ~lit1 => var in domain_enforced_not_lit1
const Domain domain_enforced_not_lit1 = ReadDomainFromProto(
context->working_model->constraints(info1.negative_linear1_ct)
.linear());
// lit2 => var in domain_enforced_lit2
const Domain domain_enforced_lit2 = ReadDomainFromProto(
context->working_model->constraints(info2.positive_linear1_ct)
.linear());
// ~lit2 => var in domain_enforced_not_lit2
const Domain domain_enforced_not_lit2 = ReadDomainFromProto(
context->working_model->constraints(info2.negative_linear1_ct)
.linear());
DCHECK(domain_enforced_lit1.IntersectionWith(domain_enforced_not_lit1)
.IsEmpty());
DCHECK(domain_enforced_lit2.IntersectionWith(domain_enforced_not_lit2)
.IsEmpty());
// First, the variable must be in the domain of either the lit or of its
// negation.
if (!context->IntersectDomainWith(
var, domain_enforced_lit1.UnionWith(domain_enforced_not_lit1))) {
return true;
}
if (!context->IntersectDomainWith(
var, domain_enforced_lit2.UnionWith(domain_enforced_not_lit2))) {
return true;
}
if (ct->constraint_case() != ConstraintProto::kBoolOr) {
// In virtue of the AMO, var must not be in the intersection of the two
// domains where both literals are true.
if (!context->IntersectDomainWith(
var, domain_enforced_lit2.IntersectionWith(domain_enforced_lit1)
.Complement())) {
return true;
}
}
const Domain domain_new_var_false = context->DomainOf(var).IntersectionWith(
domain_enforced_not_lit1.IntersectionWith(domain_enforced_not_lit2));
const Domain domain_new_var_true = context->DomainOf(var).IntersectionWith(
domain_new_var_false.Complement());
// Now we want to build a lit3 = (lit1 or lit2) to use in the AMO/bool_or.
const int new_var = context->NewBoolVarWithClause({lit1, lit2});
if (domain_new_var_true.IsEmpty()) {
if (!context->SetLiteralToFalse(new_var)) return true;
} else if (domain_new_var_false.IsEmpty()) {
if (!context->SetLiteralToTrue(new_var)) return true;
} else {
ConstraintProto* new_ct = context->working_model->add_constraints();
new_ct->add_enforcement_literal(new_var);
new_ct->mutable_linear()->add_vars(var);
new_ct->mutable_linear()->add_coeffs(1);
FillDomainInProto(domain_new_var_true, new_ct->mutable_linear());
new_ct = context->working_model->add_constraints();
new_ct->add_enforcement_literal(NegatedRef(new_var));
new_ct->mutable_linear()->add_vars(var);
new_ct->mutable_linear()->add_coeffs(1);
FillDomainInProto(domain_new_var_false, new_ct->mutable_linear());
}
// Remove the two literals from the AMO.
int new_size = 0;
for (int i = 0; i < literals.size(); ++i) {
if (literals.Get(i) != lit1 && literals.Get(i) != lit2) {
literals.Set(new_size++, literals.Get(i));
}
}
literals.Truncate(new_size);
literals.Add(new_var);
context->UpdateNewConstraintsVariableUsage();
context->UpdateRuleStats(
"variables: detected encoding of a complex domain with multiple "
"linear1");
}
return true;
}
void CpModelPresolver::DetectEncodedComplexDomains(PresolveContext* context) {
PresolveTimer timer(__FUNCTION__, logger_, time_limit_);
// Constraints taking a list of literals that can, under some conditions,
// accept the following substitution:
// constraint(a, b, ...) => constraint(a | b, ...)
// one obvious case is bool_or. But if we can know that a and b cannot be
// both true, we can also apply this to at_most_one and exactly_one.
std::vector<int> constraint_encoding_or; // bool_or, exactly_one, at_most_one
// To make sure this is not too slow, first do a pass to gather all linear1
// constraints that shares the same variable with other three linear1.
absl::flat_hash_map<int, absl::InlinedVector<int, 1>> var_to_linear1;
for (int i = 0; i < context->working_model->constraints_size(); ++i) {
const ConstraintProto& ct = context->working_model->constraints(i);
if (ct.constraint_case() == ConstraintProto::kBoolOr ||
ct.constraint_case() == ConstraintProto::kAtMostOne ||
ct.constraint_case() == ConstraintProto::kExactlyOne) {
constraint_encoding_or.push_back(i);
continue;
}
if (!ConstraintIsEncodingBound(ct)) {
continue;
}
var_to_linear1[ct.linear().vars(0)].push_back(i);
}
absl::erase_if(var_to_linear1,
[](const auto& p) { return p.second.size() <= 3; });
// Now that we reduced cheaply our set of "interesting" linear1, let's use the
// variable->constraint graph to restrict it further.
for (auto& [var, linear1_cts] : var_to_linear1) {
int new_size = 0;
for (const int ct : linear1_cts) {
const int ref =
context->working_model->constraints(ct).enforcement_literal(0);
// We want to focus on literals that become removable once we undo the
// encoding, otherwise this whole step might just make the problem harder.
// So we want it to appear in two linear1 and a bool_or/amo/exactly_one.
if (context->VarToConstraints(PositiveRef(ref)).size() <= 3) {
linear1_cts[new_size++] = ct;
}
}
linear1_cts.resize(new_size);
}
absl::erase_if(var_to_linear1,
[](const auto& p) { return p.second.size() <= 3; });
if (var_to_linear1.empty()) return;
// Now we use the linear1 we found to see which bool_or/amo/exactly_one could
// be applied to the heuristic.
Bitset64<int> booleans_potentially_encoding_domain(
context_->working_model->variables_size());
for (const auto& [unused, linear1_cts] : var_to_linear1) {
for (const int ct : linear1_cts) {
booleans_potentially_encoding_domain.Set(PositiveRef(
context->working_model->constraints(ct).enforcement_literal(0)));
}
}
int new_encoding_or_count = 0;
for (int i = 0; i < constraint_encoding_or.size(); ++i) {
const int c = constraint_encoding_or[i];
const ConstraintProto& ct = context->working_model->constraints(c);
const BoolArgumentProto& bool_ct =
ct.constraint_case() == ConstraintProto::kAtMostOne
? ct.at_most_one()
: (ct.constraint_case() == ConstraintProto::kExactlyOne
? ct.exactly_one()
: ct.bool_or());
if (absl::c_count_if(
bool_ct.literals(),
[booleans_potentially_encoding_domain](int ref) {
return booleans_potentially_encoding_domain[PositiveRef(ref)];
}) < 2) {
continue;
}
constraint_encoding_or[new_encoding_or_count++] = c;
}
constraint_encoding_or.resize(new_encoding_or_count);
for (const int c : constraint_encoding_or) {
ConstraintProto* ct = context->working_model->mutable_constraints(c);
bool changed = false;
do {
changed = DetectEncodedComplexDomain(
context, ct, booleans_potentially_encoding_domain);
if (changed) {
context->UpdateConstraintVariableUsage(c);
}
} while (changed);
}
}
// TODO(user): Should we take into account the exactly_one constraints? note
// that such constraint cannot be extended. If if a literal implies two literals
// at one inside an exactly one constraint then it must be false. Similarly if
@@ -14155,6 +14494,10 @@ CpSolverStatus CpModelPresolver::Presolve() {
ProcessSetPPC();
TransformClausesToExactlyOne();
if (!time_limit_->LimitReached()) {
DetectEncodedComplexDomains(context_);
}
// These operations might break symmetry. Or at the very least, the newly
// created variable must be incorporated in the generators.
if (context_->params().find_big_linear_overlap() &&

View File

@@ -222,6 +222,33 @@ class CpModelPresolver {
// related presolve.
void DetectDominatedLinearConstraints();
// Detects encodings of the form:
// b1 => x \in Domain1
// ~b1 => x \in Domain1.Complement()
// b2 => x \in Domain2
// ~b2 => x \in Domain2.Complement()
// b3 => x \in Domain3
// ~b3 => x \in Domain3.Complement()
// ...
// bool_or(b1, b2, ..., bn, y, z, ...)
// Where the bi do not appear in any other constraints. When we finds this
// pattern, we create a new boolean variable `l` and replaces all the
// constraints above by three new constraints:
// l => x \in Domain1 U Domain2 U ... U Domainn
// ~l => x \in (Domain1 U Domain2 U ... U Domainn).Complement()
// bool_or(l, y, z, ...),
// Note that `l` is equivalent to at least one of the bi to be true, which is
// a consequence that it is encoding a domain that is the union of the domains
// of the bis.
//
// It does the same when bool_or is replaced by an at_most_one or exactly_one
// but we need to add an extra constraint that
// x \notin (Domain_a U Domain_b) for all a != b.
void DetectEncodedComplexDomains(PresolveContext* context);
bool DetectEncodedComplexDomain(PresolveContext* context, ConstraintProto* ct,
const Bitset64<int>& pertinent_bools);
// Precomputes info about at most one, and use it to presolve linear
// constraints. It can be interesting to know for a given linear constraint
// that a subset of its variables are in at most one relation.

View File

@@ -81,8 +81,7 @@ inline void FillKeyAndBitmask(absl::Span<const Literal> clause,
// Returns true iff the truth table encoded in bitmask encode a function
// Xi = f(Xj, j != i);
template <int num_bits>
bool IsFunction(int i, SmallBitset truth_table) {
inline bool IsFunction(int i, int num_bits, SmallBitset truth_table) {
DCHECK_GE(i, 0);
DCHECK_LT(i, num_bits);

View File

@@ -64,9 +64,9 @@ TEST(FillKeyAndBitmaskTest, BasicBehavior1) {
}
TEST(IsFunctionTest, ConstantValue) {
EXPECT_TRUE(IsFunction<3>(0, 0b10101010));
EXPECT_FALSE(IsFunction<3>(1, 0b10101010));
EXPECT_FALSE(IsFunction<3>(2, 0b10101010));
EXPECT_TRUE(IsFunction(0, 3, 0b10101010));
EXPECT_FALSE(IsFunction(1, 3, 0b10101010));
EXPECT_FALSE(IsFunction(2, 3, 0b10101010));
}
TEST(AddHoleAtPositionTest, BasicTest) {

View File

@@ -343,6 +343,7 @@ inline void InclusionDetector<Storage>::DetectInclusions(
const uint64_t superset_signature = signatures_.back();
const auto is_in_superset_view = is_in_superset_.const_view();
for (const int superset_e : superset_elements_) {
work_done_ += one_watcher_[superset_e].size();
for (int i = 0; i < one_watcher_[superset_e].size(); ++i) {
const int c_index = one_watcher_[superset_e][i];
const Candidate& subset = candidates_[c_index];

View File

@@ -1353,6 +1353,7 @@ IntegerSearchHelper::IntegerSearchHelper(Model* model)
: parameters_(*model->GetOrCreate<SatParameters>()),
model_(model),
sat_solver_(model->GetOrCreate<SatSolver>()),
binary_implication_graph_(model->GetOrCreate<BinaryImplicationGraph>()),
integer_trail_(model->GetOrCreate<IntegerTrail>()),
encoder_(model->GetOrCreate<IntegerEncoder>()),
implied_bounds_(model->GetOrCreate<ImpliedBounds>()),
@@ -1487,6 +1488,15 @@ bool IntegerSearchHelper::GetDecision(
}
bool IntegerSearchHelper::TakeDecision(Literal decision) {
// If we are about to take a decision on a redundant literal, always
// prefer to branch on the representative. This should helps learn more
// consistent conflict.
//
// TODO(user): Ideally never learn anything on redundant variable. This is
// a bit of work.
decision = binary_implication_graph_->RepresentativeOf(decision);
CHECK(!sat_solver_->Assignment().LiteralIsAssigned(decision));
pseudo_costs_->BeforeTakingDecision(decision);
// Note that kUnsatTrailIndex might also mean ASSUMPTIONS_UNSAT.

View File

@@ -312,6 +312,7 @@ class IntegerSearchHelper {
const SatParameters& parameters_;
Model* model_;
SatSolver* sat_solver_;
BinaryImplicationGraph* binary_implication_graph_;
IntegerTrail* integer_trail_;
IntegerEncoder* encoder_;
ImpliedBounds* implied_bounds_;

View File

@@ -642,13 +642,33 @@ void LratProofHandler::Close(bool model_is_unsat) {
}
}
bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
ClauseId new_id, absl::Span<const Literal> new_clause,
ClauseId LratProofHandler::AddAndProveInferredClauseByEnumeration(
absl::Span<const Literal> new_clause,
absl::Span<const ClauseId> ids_for_proof,
const CompactVectorVector<int, Literal>& clauses_for_proof) {
CHECK_EQ(ids_for_proof.size(), clauses_for_proof.size());
CHECK(!clauses_for_proof.empty());
// helper function to report some info on proof failure.
const auto error = [&, this](absl::string_view message) {
if (debug_crash_on_error_) {
LOG(INFO) << "Proving " << new_clause;
for (int i = 0; i < ids_for_proof.size(); ++i) {
LOG(INFO) << "input id= " << ids_for_proof[i]
<< " clause=" << clauses_for_proof[i];
}
LOG(FATAL) << message;
} else {
VLOG(2) << "Proving " << new_clause;
for (int i = 0; i < ids_for_proof.size(); ++i) {
VLOG(2) << "input id= " << ids_for_proof[i]
<< " clause=" << clauses_for_proof[i];
}
VLOG(2) << message;
}
return kNoClauseId;
};
// First we count the number of variables appearing and have a separate dense
// index for them. The first new_clause.size() dense index are exactly the
// literal of the new_clause.
@@ -659,20 +679,33 @@ bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
const auto [it, inserted] =
to_dense_index.insert({lit.Variable(), to_dense_index.size()});
if (!inserted) {
VLOG(2) << "Duplicate variable in new_clause! " << new_clause;
return false;
return error("Duplicate variable in new clause");
}
}
// Then any new BooleanVariable appearing get the next dense index.
std::vector<Literal> relevant_literals;
for (int i = 0; i < clauses_for_proof.size(); ++i) {
int max_index = 0;
for (const Literal lit : clauses_for_proof[i]) {
const auto [it, inserted] =
to_dense_index.insert({lit.Variable(), to_dense_index.size()});
if (inserted) {
relevant_literals.push_back(lit);
}
max_index = std::max(max_index, it->second);
}
if (max_index < new_clause.size()) {
VLOG(2) << "The new clause is trivially true since one of the clauses is "
"included inside "
<< clauses_for_proof[i] << " in " << new_clause;
if (clauses_for_proof[i].size() == new_clause.size()) {
return ids_for_proof[i];
}
// TODO(user): if this ever happen we can create a new id and prove it
// with clauses_for_proof[i], but for now I never saw that.
return error("Case not yet supported");
}
}
@@ -680,8 +713,7 @@ bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
//
// TODO(user): The limit can be increased a bit if needed.
if (to_dense_index.size() > 6) {
VLOG(2) << "Too many variables";
return false;
return error("Too many variables");
}
// For the proof we will need all clauses of the form
@@ -691,15 +723,11 @@ bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
// That give us 2^(n + 1) intermediate clauses.
// Their ids will be stored in (1 << k) + binary_encoding_of_the_li.
const int n = to_dense_index.size() - new_clause.size();
CHECK_GT(n, 0); // We dealt with this above.
CHECK_EQ(n, relevant_literals.size());
const int num_intermediates = 1 << (n + 1);
std::vector<ClauseId> ids(num_intermediates, kNoClauseId);
if (n == 0) {
VLOG(2) << "Nothing to prove! An existing clause is included inside";
return false;
}
VLOG(2) << "Starting proof n= " << n << " " << num_intermediates;
// Any initial clause can be used to prove all the intermediates that contains
@@ -773,21 +801,20 @@ bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
const ClauseId id1 = ids[higher1];
const ClauseId id2 = ids[higher2];
if (id1 == kNoClauseId || id2 == kNoClauseId) {
VLOG(2) << "missing higher level clauses in the resolution."
<< " index: " << std::bitset<8>(index)
<< " higher1: " << std::bitset<8>(higher1)
<< " higher2: " << std::bitset<8>(higher2);
return false;
return error(
absl::StrCat("missing higher level clauses in the resolution.",
" index: ", std::bitset<8>(index).to_string(),
" higher1: ", std::bitset<8>(higher1).to_string(),
" higher2: ", std::bitset<8>(higher2).to_string()));
}
ids[index] = k == 0 ? new_id : id_generator_->GetNextId();
ids[index] = id_generator_->GetNextId();
if (k != 0) {
VLOG(2) << "temporary !! " << ids[index] << " " << tmp_clause;
id_need_deletion[index] = true; // temporary.
}
if (!AddInferredClause(ids[index], tmp_clause, {id1, id2})) {
VLOG(2) << "Failed resolution step";
return false;
return error("Failed resolution step");
}
if (k == 0) {
@@ -811,7 +838,7 @@ bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
}
}
return true;
return ids[1];
}
} // namespace sat

View File

@@ -158,8 +158,12 @@ class LratProofHandler {
// The last two arguments must have the same size and are in one to one
// correspondence. Note that we might not need all the given clauses in the
// proof.
bool AddAndProveInferredClauseByEnumeration(
ClauseId new_id, absl::Span<const Literal> new_clause,
//
// Return the new clause id. Note that in some corner cases, this could be
// one of the id passed in ids_for_proof. Return kNoClauseId if the proof
// is wrong.
ClauseId AddAndProveInferredClauseByEnumeration(
absl::Span<const Literal> new_clause,
absl::Span<const ClauseId> ids_for_proof,
const CompactVectorVector<int, Literal>& clauses_for_proof);

View File

@@ -62,13 +62,13 @@ TEST(AddAndProveInferredClauseByEnumerationTest, XorEquivalence) {
// This should be enough to prove equivalence.
{
std::vector<Literal> to_prove = {b.Negated(), a};
EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration(
id_generator->GetNextId(), to_prove, clause_ids, clauses));
EXPECT_NE(kNoClauseId, lrat->AddAndProveInferredClauseByEnumeration(
to_prove, clause_ids, clauses));
}
{
std::vector<Literal> to_prove = {a.Negated(), b};
EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration(
id_generator->GetNextId(), to_prove, clause_ids, clauses));
EXPECT_NE(kNoClauseId, lrat->AddAndProveInferredClauseByEnumeration(
to_prove, clause_ids, clauses));
}
}

View File

@@ -1920,22 +1920,32 @@ GateCongruenceClosure::~GateCongruenceClosure() {
template <int arity>
void GateCongruenceClosure::AddToTruthTable(
absl::Span<const Literal> clause,
absl::flat_hash_map<std::array<BooleanVariable, arity>, SmallBitset>&
data) {
CHECK_EQ(clause.size(), arity);
SatClause* clause,
absl::flat_hash_map<std::array<BooleanVariable, arity>, TruthTableId>&
ids) {
CHECK_EQ(clause->size(), arity);
std::array<BooleanVariable, arity> key;
SmallBitset bitmask;
FillKeyAndBitmask(clause, absl::MakeSpan(key), bitmask);
for (const BooleanVariable var : key) {
CHECK(!implication_graph_->IsRemoved(Literal(var, true)));
}
auto [it, inserted] = data.insert({key, bitmask});
if (!inserted) {
const SmallBitset old = it->second;
it->second &= bitmask; // Remove one value.
if (old != it->second) {
// TODO(user): keep id for proof.
FillKeyAndBitmask(clause->AsSpan(), absl::MakeSpan(key), bitmask);
TruthTableId next_id(truth_tables_bitset_.size());
auto [it, inserted] = ids.insert({key, next_id});
const TruthTableId id = it->second;
if (inserted) {
truth_tables_inputs_.Add(key);
truth_tables_bitset_.push_back(bitmask);
if (lrat_proof_handler_ != nullptr) {
tmp_ids_.push_back(id);
tmp_clauses_.push_back(clause);
}
} else {
const SmallBitset old = truth_tables_bitset_[id];
// Remove one value.
truth_tables_bitset_[id] &= bitmask;
if (lrat_proof_handler_ != nullptr && old != truth_tables_bitset_[id]) {
tmp_ids_.push_back(id);
tmp_clauses_.push_back(clause);
}
}
}
@@ -1944,8 +1954,13 @@ void GateCongruenceClosure::AddToTruthTable(
// the congruence closure should be quite fast.
void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
PresolveTimer& timer) {
truth_table3_.clear();
truth_table4_.clear();
ids3_.clear();
ids4_.clear();
truth_tables_inputs_.clear();
truth_tables_bitset_.clear();
truth_tables_clauses_.clear();
tmp_ids_.clear();
tmp_clauses_.clear();
std::vector<Literal> candidates;
for (SatClause* clause : clause_manager_->AllClausesInCreationOrder()) {
@@ -1953,9 +1968,9 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
if (clause->size() == 0) continue;
if (clause->size() == 3) {
AddToTruthTable<3>(clause->AsSpan(), truth_table3_);
AddToTruthTable<3>(clause, ids3_);
} else if (clause->size() == 4) {
AddToTruthTable<4>(clause->AsSpan(), truth_table4_);
AddToTruthTable<4>(clause, ids4_);
}
// Used for an optimization below.
@@ -2066,17 +2081,17 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
gates_target_.push_back(target.Index());
gates_type_.push_back(kAndGateType);
const int index = gates_inputs_.Add({});
const GateId gate_id = GateId(gates_inputs_.Add({}));
for (const Literal l : clause->AsSpan()) {
if (l == target) continue;
gates_inputs_.AppendToLastVector(l.NegatedIndex());
}
if (lrat_proof_handler_ != nullptr) {
gates_clause_.push_back(clause);
gates_clauses_.Add({clause});
}
// Canonicalize.
absl::Span<LiteralIndex> gate = gates_inputs_[index];
absl::Span<LiteralIndex> gate = gates_inputs_[gate_id];
std::sort(gate.begin(), gate.end());
// Even if we detected an and_gate from a base clause, we keep going
@@ -2088,27 +2103,35 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
timer.AddCounter("and_gates", gates_inputs_.size());
}
template <int arity>
void GateCongruenceClosure::ExtractShortGates(
PresolveTimer& timer,
const absl::flat_hash_map<std::array<BooleanVariable, arity>, SmallBitset>&
data) {
// For a table on n variables, we look for function x = f(n - 1) variable.
const int num_bits = arity - 1;
void GateCongruenceClosure::ExtractShortGates(PresolveTimer& timer) {
if (lrat_proof_handler_ != nullptr) {
truth_tables_clauses_.ResetFromFlatMapping(tmp_ids_, tmp_clauses_);
CHECK_EQ(truth_tables_bitset_.size(), truth_tables_clauses_.size());
}
// TODO(user): This is non-deterministic order. We need to fix that or
// initially sort the queue of gates to process.
int num_functions = 0;
for (const auto [key, truth_table] : data) {
for (int i = 0; i < arity; ++i) {
if (!IsFunction<arity>(i, truth_table)) continue;
++num_functions;
// Counters.
// We only fill a subset of the entries, but that makes the code shorter.
std::vector<int> num_tables(5);
std::vector<int> num_functions(5);
gates_target_.push_back(Literal(key[i], true));
// Note that using the indirection via TruthTableId allow this code to
// be deterministic.
CHECK_EQ(truth_tables_bitset_.size(), truth_tables_inputs_.size());
for (TruthTableId id(0); id < truth_tables_inputs_.size(); ++id) {
const absl::Span<const BooleanVariable> inputs = truth_tables_inputs_[id];
const SmallBitset truth_table = truth_tables_bitset_[id];
++num_tables[inputs.size()];
for (int i = 0; i < inputs.size(); ++i) {
if (!IsFunction(i, inputs.size(), truth_table)) continue;
const int num_bits = inputs.size() - 1;
++num_functions[num_bits];
gates_target_.push_back(Literal(inputs[i], true));
gates_inputs_.Add({});
for (int j = 0; j < arity; ++j) {
for (int j = 0; j < inputs.size(); ++j) {
if (i != j) {
gates_inputs_.AppendToLastVector(Literal(key[j], true));
gates_inputs_.AppendToLastVector(Literal(inputs[j], true));
}
}
@@ -2116,7 +2139,8 @@ void GateCongruenceClosure::ExtractShortGates(
// We will canonicalize it further in the main loop.
unsigned int type = 0;
for (int p = 0; p < (1 << num_bits); ++p) {
// Expand from (arity - 1) bits to (arity) bits.
// Expand from (num_bits == inputs.size() - 1) bits to (inputs.size())
// bits.
const int bigger_p = AddHoleAtPosition(i, p);
if ((truth_table >> (bigger_p + (1 << i))) & 1) {
@@ -2132,31 +2156,36 @@ void GateCongruenceClosure::ExtractShortGates(
}
gates_type_.push_back(type);
gates_clause_.push_back(nullptr);
if (lrat_proof_handler_ != nullptr) {
gates_clauses_.Add(truth_tables_clauses_[id]);
}
}
}
timer.AddCounter(absl::StrCat("table", arity), data.size());
timer.AddCounter(absl::StrCat("fn", num_bits), num_functions);
// Note that we only display non-zero counters.
for (int i = 2; i < 5; ++i) {
timer.AddCounter(absl::StrCat("table", i), num_tables[i]);
timer.AddCounter(absl::StrCat("fn", i), num_functions[i]);
}
}
namespace {
// Helper class to add LRAT proofs for equivalent gate target literals.
class LratGateCongruenceHelper {
public:
LratGateCongruenceHelper(const BinaryImplicationGraph* implication_graph,
const ClauseManager* clause_manager,
ClauseIdGenerator* clause_id_generator,
LratProofHandler* lrat_proof_handler,
absl::Span<const LiteralIndex> gates_target,
absl::Span<const SatClause* const> gates_clause,
DenseConnectedComponentsFinder& union_find)
LratGateCongruenceHelper(
const BinaryImplicationGraph* implication_graph,
ClauseManager* clause_manager, ClauseIdGenerator* clause_id_generator,
LratProofHandler* lrat_proof_handler,
const util_intops::StrongVector<GateId, LiteralIndex>& gates_target,
const CompactVectorVector<GateId, const SatClause*>& gates_clauses,
DenseConnectedComponentsFinder& union_find)
: implication_graph_(implication_graph),
clause_manager_(clause_manager),
clause_id_generator_(clause_id_generator),
lrat_proof_handler_(lrat_proof_handler),
gates_target_(gates_target),
gates_clause_(gates_clause),
gates_clauses_(gates_clauses),
union_find_(union_find) {}
~LratGateCongruenceHelper() {
@@ -2235,7 +2264,7 @@ class LratGateCongruenceHelper {
// Returns an LRAT clause rep(gates_target[gate_a_id]) =>
// rep(gates_target[gate_b_id]). ShortenEquivalencesWithRepresentative() must
// be called first on the two gate target literals.
ClauseId AddGateTargetImplication(int gate_a_id, int gate_b_id) {
ClauseId AddAndGateTargetImplication(GateId gate_a_id, GateId gate_b_id) {
const Literal a = Literal(gates_target_[gate_a_id]);
const Literal b = Literal(gates_target_[gate_b_id]);
const Literal rep_a = GetParent(a);
@@ -2251,7 +2280,7 @@ class LratGateCongruenceHelper {
// inputs are the negation of each clause literal other than the target.
// TODO(user): this can add redundant clauses if two original inputs
// have the same representative.
for (const Literal lit : gates_clause_[gate_a_id]->AsSpan()) {
for (const Literal lit : gates_clauses_[gate_a_id][0]->AsSpan()) {
if (lit == a) continue;
const Literal l = lit.Negated();
clause_ids.push_back(implication_graph_->GetClauseId(a.Negated(), l));
@@ -2260,7 +2289,7 @@ class LratGateCongruenceHelper {
}
// For each original input l of b, rep(l) => l. The original inputs are
// the negation of each gate clause literal other than its target b.
for (const Literal lit : gates_clause_[gate_b_id]->AsSpan()) {
for (const Literal lit : gates_clauses_[gate_b_id][0]->AsSpan()) {
if (lit == b) continue;
const Literal l = lit.Negated();
ShortenEquivalencesWithRepresentative(l);
@@ -2268,7 +2297,7 @@ class LratGateCongruenceHelper {
}
// The original inputs of gate_b_id imply its target b:
clause_ids.push_back(
clause_manager_->GetClauseId(gates_clause_[gate_b_id]));
clause_manager_->GetClauseId(gates_clauses_[gate_b_id][0]));
// b => rep(b):
Append(clause_ids, GetLiteralImpliesRepresentativeClause(b));
@@ -2278,6 +2307,178 @@ class LratGateCongruenceHelper {
return clause_id;
}
void ClearTemporaryProof() {
CHECK(lrat_proof_handler_ != nullptr);
tmp_index_to_delete_.clear();
tmp_proof_clauses_id_.clear();
tmp_proof_clauses_.clear();
marked_.ClearAndResize(LiteralIndex(clause_manager_->literal_size()));
}
Literal GetRepresentativeWithProofSupport(Literal lit) {
const int lit_index_as_int = lit.Index().value();
if (union_find_.GetParent(lit_index_as_int) == lit_index_as_int) {
return lit;
}
if (lrat_proof_handler_ != nullptr) {
ShortenEquivalencesWithRepresentative(lit);
}
return Literal(LiteralIndex(union_find_.FindRoot(lit_index_as_int)));
}
void AddGateClausesToTemporaryProof(GateId id) {
CHECK(lrat_proof_handler_ != nullptr);
for (const SatClause* clause : gates_clauses_[id]) {
// We rewrite each clause using new equivalences found.
marked_.ResetAllToFalse();
tmp_literals_.clear();
tmp_clause_ids_.clear();
bool clause_is_trivial = false;
bool some_change = false;
for (const Literal lit : clause->AsSpan()) {
const Literal rep = GetRepresentativeWithProofSupport(lit);
if (rep != lit) {
some_change = true;
tmp_clause_ids_.push_back(GetLiteralImpliesRepresentativeClause(lit));
}
if (marked_[rep]) continue;
if (marked_[rep.Negated()]) {
clause_is_trivial = true;
break;
}
marked_.Set(rep);
tmp_literals_.push_back(rep);
}
// If this is the case, we shouldn't need it for the proof.
if (clause_is_trivial) continue;
ClauseId new_id = clause_manager_->GetClauseId(clause);
if (some_change) {
// If there is some change, we add a temporary clause id with the
// proof to go from the original clause to this one.
tmp_index_to_delete_.push_back(tmp_proof_clauses_.size());
tmp_clause_ids_.push_back(new_id);
new_id = clause_id_generator_->GetNextId();
lrat_proof_handler_->AddInferredClause(new_id, tmp_literals_,
tmp_clause_ids_);
}
// Add that clause and its id to the set of clauses needed for the proof.
tmp_proof_clauses_id_.push_back(new_id);
tmp_proof_clauses_.Add(tmp_literals_);
}
// Hacky: If we have a single clause, then there is a chance this was
// an and_gate. We must add all the implications target => inputs[i].
// Note that the inputs are the negation of the literals in the unique
// clause, so we really have target => not(lit) for lit in clause.
// which gives (not(target), not(lit)) for the needed clause.
if (gates_clauses_[id].size() == 1) {
// Tricky: The target might have been negated ! so we recover it from
// the unique clause.
const Literal current = Literal(gates_target_[id]);
LiteralIndex real_target = kNoLiteralIndex;
for (const Literal lit : gates_clauses_[id][0]->AsSpan()) {
if (current.Variable() == lit.Variable()) {
real_target = lit.Index();
}
}
if (real_target == kNoLiteralIndex) return;
const Literal neg_target = Literal(real_target).Negated();
const Literal neg_target_rep =
GetRepresentativeWithProofSupport(neg_target);
for (const Literal lit : gates_clauses_[id][0]->AsSpan()) {
const Literal neg_lit = lit.Negated();
if (neg_lit == neg_target) continue;
const Literal neg_lit_rep = GetRepresentativeWithProofSupport(neg_lit);
ClauseId new_id = implication_graph_->GetClauseId(neg_target, neg_lit);
if (new_id == kNoClauseId) {
// We where likely not a bool_and to start with, so we shouldn't need
// these clauses.
break;
}
if (neg_lit != neg_lit_rep || neg_target != neg_target_rep) {
tmp_clause_ids_.clear();
tmp_index_to_delete_.push_back(tmp_proof_clauses_.size());
if (neg_lit != neg_lit_rep) {
tmp_clause_ids_.push_back(
GetLiteralImpliesRepresentativeClause(neg_lit));
}
if (neg_target != neg_target_rep) {
tmp_clause_ids_.push_back(
GetLiteralImpliesRepresentativeClause(neg_target));
}
tmp_clause_ids_.push_back(new_id);
new_id = clause_id_generator_->GetNextId();
lrat_proof_handler_->AddInferredClause(
new_id, {neg_target_rep, neg_lit_rep}, tmp_clause_ids_);
}
tmp_proof_clauses_id_.push_back(new_id);
tmp_proof_clauses_.Add({neg_target_rep, neg_lit_rep});
}
}
}
// Same as AddAndGateTargetImplication() but with truth table based gates.
std::pair<ClauseId, ClauseId> AddShortGateTargetEquivalence(
GateId gate_a_id, GateId gate_b_id) {
// Just add all clauses from both gates.
// But note that we need to remap them.
ClearTemporaryProof();
AddGateClausesToTemporaryProof(gate_a_id);
AddGateClausesToTemporaryProof(gate_b_id);
// All clauses are now in tmp_proof_clauses_/tmp_proof_clauses_id_.
// We can add both implications with proof.
const Literal a = Literal(gates_target_[gate_a_id]);
const Literal b = Literal(gates_target_[gate_b_id]);
const Literal rep_a = GetParent(a);
const Literal rep_b = GetParent(b);
DCHECK(IsRepresentative(rep_a));
DCHECK(IsRepresentative(rep_b));
const ClauseId rep_a_implies_rep_b =
lrat_proof_handler_->AddAndProveInferredClauseByEnumeration(
{rep_a.Negated(), rep_b}, tmp_proof_clauses_id_,
tmp_proof_clauses_);
const ClauseId rep_b_implies_rep_a =
lrat_proof_handler_->AddAndProveInferredClauseByEnumeration(
{rep_b.Negated(), rep_a}, tmp_proof_clauses_id_,
tmp_proof_clauses_);
for (const int i : tmp_index_to_delete_) {
// Corner case if the proof used a temporary id.
if (tmp_proof_clauses_id_[i] == rep_a_implies_rep_b) continue;
if (tmp_proof_clauses_id_[i] == rep_b_implies_rep_a) continue;
lrat_proof_handler_->DeleteClause(tmp_proof_clauses_id_[i],
tmp_proof_clauses_[i]);
}
return {rep_a_implies_rep_b, rep_b_implies_rep_a};
}
ClauseId ProofForFixingLiteral(Literal to_fix, GateId id) {
CHECK(IsRepresentative(to_fix));
ClearTemporaryProof();
AddGateClausesToTemporaryProof(id);
const ClauseId new_id =
lrat_proof_handler_->AddAndProveInferredClauseByEnumeration(
{to_fix}, tmp_proof_clauses_id_, tmp_proof_clauses_);
for (const int i : tmp_index_to_delete_) {
// Corner case if the proof used a temporary id.
if (tmp_proof_clauses_id_[i] == new_id) continue;
lrat_proof_handler_->DeleteClause(tmp_proof_clauses_id_[i],
tmp_proof_clauses_[i]);
}
return new_id;
}
void AddGateEquivalenceClauses(Literal child, ClauseId child_implies_parent,
ClauseId parent_implies_child) {
DCHECK(!parent_equivalence_.contains(child));
@@ -2293,13 +2494,14 @@ class LratGateCongruenceHelper {
// precondition is that two original inputs l and m with rep(l) = rep and
// rep(m) = not(rep) must exist.
void AppendFixAndGateTargetClauses(
int gate_id, Literal rep, absl::InlinedVector<ClauseId, 4>& clause_ids) {
GateId gate_id, Literal rep,
absl::InlinedVector<ClauseId, 4>& clause_ids) {
const Literal target = Literal(gates_target_[gate_id]);
LiteralIndex l_index = kNoLiteralIndex;
LiteralIndex m_index = kNoLiteralIndex;
// Find l and m in the original inputs (the negation of each gate clause
// literal other than its target).
for (const Literal lit : gates_clause_[gate_id]->AsSpan()) {
for (const Literal lit : gates_clauses_[gate_id][0]->AsSpan()) {
if (l_index != kNoLiteralIndex && m_index != kNoLiteralIndex) break;
const Literal l = lit.Negated();
ShortenEquivalencesWithRepresentative(l);
@@ -2350,11 +2552,11 @@ class LratGateCongruenceHelper {
}
const BinaryImplicationGraph* implication_graph_;
const ClauseManager* clause_manager_;
ClauseManager* clause_manager_;
ClauseIdGenerator* clause_id_generator_;
LratProofHandler* lrat_proof_handler_;
absl::Span<const LiteralIndex> gates_target_;
absl::Span<const SatClause* const> gates_clause_;
const util_intops::StrongVector<GateId, LiteralIndex>& gates_target_;
const CompactVectorVector<GateId, const SatClause*>& gates_clauses_;
DenseConnectedComponentsFinder& union_find_;
// For each gate target with a parent in `union_find_` different from itself,
@@ -2365,11 +2567,25 @@ class LratGateCongruenceHelper {
// The literals of the clauses in `to_delete_`. Only needed when checking
// DRAT.
std::vector<std::pair<Literal, Literal>> clauses_to_delete_;
// For AddShortGateTargetImplication().
std::vector<int> tmp_index_to_delete_;
std::vector<ClauseId> tmp_proof_clauses_id_;
CompactVectorVector<int, Literal> tmp_proof_clauses_;
// For the simplification of clauses using equivalences in
// AddGateClausesToTemporaryProof().
SparseBitset<LiteralIndex> marked_;
std::vector<ClauseId> tmp_clause_ids_;
std::vector<Literal> tmp_literals_;
};
} // namespace
bool GateCongruenceClosure::DoOneRound(bool log_info) {
if (implication_graph_->IsEmpty()) return true;
clause_manager_->DetachAllClauses();
PresolveTimer timer("GateCongruenceClosure", logger_, time_limit_);
timer.OverrideLogging(log_info);
@@ -2381,24 +2597,22 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
gates_target_.clear();
gates_inputs_.clear();
gates_type_.clear();
gates_clause_.clear();
gates_clauses_.clear();
ExtractAndGatesAndFillShortTruthTables(timer);
// TODO(user): We currently do not support this with lrat. Fix.
if (lrat_proof_handler_ == nullptr) {
ExtractShortGates<3>(timer, truth_table3_);
ExtractShortGates<4>(timer, truth_table4_);
}
ExtractShortGates(timer);
// All vector have the same size.
// Except gates_clause_ which is only filled if we need proof.
// Except gates_clauses_ which is only filled if we need proof.
CHECK_EQ(gates_target_.size(), gates_type_.size());
CHECK_EQ(gates_target_.size(), gates_inputs_.size());
if (lrat_proof_handler_ != nullptr) {
CHECK_EQ(gates_target_.size(), gates_clauses_.size());
}
// If two gates have the same type and the same inputs, their targets are
// equivalent. We use an hash set to detect that the inputs are the same.
absl::flat_hash_set<int, GateHash, GateEq> gate_set(
absl::flat_hash_set<GateId, GateHash, GateEq> gate_set(
/*capacity=*/gates_inputs_.size(), GateHash(&gates_type_, &gates_inputs_),
GateEq(&gates_type_, &gates_inputs_));
@@ -2413,18 +2627,18 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// Tricky: we need to resize this to num_literals because the union_find that
// merges target can choose for a representative a literal that is not in the
// set of gate inputs.
MergeableOccurrenceList<LiteralIndex, int> input_literals_to_gate;
MergeableOccurrenceList<LiteralIndex, GateId> input_literals_to_gate;
input_literals_to_gate.ResetFromTranspose(gates_inputs_, num_literals);
LratGateCongruenceHelper lrat_helper(
implication_graph_, clause_manager_, clause_id_generator_,
lrat_proof_handler_, gates_target_, gates_clause_, union_find);
lrat_proof_handler_, gates_target_, gates_clauses_, union_find);
// Starts with all gates in the queue.
const int num_gates = gates_inputs_.size();
std::vector<bool> in_queue(num_gates, true);
std::vector<int> queue(num_gates);
for (int id = 0; id < num_gates; ++id) queue[id] = id;
std::vector<GateId> queue(num_gates);
for (GateId id(0); id < num_gates; ++id) queue[id.value()] = id;
// Main loop.
int num_units = 0;
@@ -2433,9 +2647,9 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
int arity1_equivalences = 0;
while (!queue.empty()) {
++num_processed;
const int id = queue.back();
const GateId id = queue.back();
queue.pop_back();
in_queue[id] = false;
in_queue[id.value()] = false;
// Tricky: the hash-map might contain id not yet canonicalized. And in
// particular might already contain the id we are processing.
@@ -2444,7 +2658,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// version" and remove id if it was already there. The second will do it on
// the canonicalized version.
for (int pass = 0; pass < 2; ++pass) {
int other_id = -1;
GateId other_id(-1);
bool is_equivalent = false;
if (pass == 0) {
const auto it = gate_set.find(id);
@@ -2473,42 +2687,47 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
CHECK_EQ(absl::Span<const LiteralIndex>(gates_inputs_[id]),
absl::Span<const LiteralIndex>(gates_inputs_[other_id]));
// We detected a <=> b (or, equivalently, rep(a) <=> rep(b)).
const LiteralIndex a = gates_target_[id];
const LiteralIndex b = gates_target_[other_id];
input_literals_to_gate.RemoveFromFutureOutput(id);
if (lrat_proof_handler_ != nullptr) {
lrat_helper.ShortenEquivalencesWithRepresentative(Literal(a));
lrat_helper.ShortenEquivalencesWithRepresentative(Literal(b));
}
const LiteralIndex rep_a(union_find.FindRoot(a.value()));
const LiteralIndex rep_b(union_find.FindRoot(b.value()));
// We detected a <=> b (or, equivalently, rep(a) <=> rep(b)).
const Literal a(gates_target_[id]);
const Literal b(gates_target_[other_id]);
const Literal rep_a = lrat_helper.GetRepresentativeWithProofSupport(a);
const Literal rep_b = lrat_helper.GetRepresentativeWithProofSupport(b);
if (rep_a != rep_b) {
++num_equivalences;
const Literal rep_lit_a(rep_a);
const Literal rep_lit_b(rep_b);
ClauseId rep_a_implies_rep_b = kNoClauseId;
ClauseId rep_b_implies_rep_a = kNoClauseId;
if (lrat_proof_handler_ != nullptr) {
rep_a_implies_rep_b =
lrat_helper.AddGateTargetImplication(id, other_id);
rep_b_implies_rep_a =
lrat_helper.AddGateTargetImplication(other_id, id);
if (gates_type_[id] == kAndGateType) {
rep_a_implies_rep_b =
lrat_helper.AddAndGateTargetImplication(id, other_id);
rep_b_implies_rep_a =
lrat_helper.AddAndGateTargetImplication(other_id, id);
} else {
const auto [x, y] =
lrat_helper.AddShortGateTargetEquivalence(id, other_id);
rep_a_implies_rep_b = x;
rep_b_implies_rep_a = y;
}
}
if (!implication_graph_->AddBinaryClause(
rep_a_implies_rep_b, rep_lit_a.Negated(), rep_lit_b) ||
!implication_graph_->AddBinaryClause(
rep_b_implies_rep_a, rep_lit_b.Negated(), rep_lit_a)) {
if (!implication_graph_->AddBinaryClause(rep_a_implies_rep_b,
rep_a.Negated(), rep_b) ||
!implication_graph_->AddBinaryClause(rep_b_implies_rep_a,
rep_b.Negated(), rep_a)) {
return false;
}
for (const bool negate : {false, true}) {
const LiteralIndex x =
negate ? Literal(rep_a).NegatedIndex() : rep_a;
negate ? rep_a.NegatedIndex() : rep_a.Index();
const LiteralIndex y =
negate ? Literal(rep_b).NegatedIndex() : rep_b;
negate ? rep_b.NegatedIndex() : rep_b.Index();
// TODO(user): We need to change the union_find algo used to be sure
// the that if rep(x) = y then rep(not(x)) = not(y), otherwise we
// might miss some reductions.
union_find.AddEdge(x.value(), y.value());
const LiteralIndex rep(union_find.FindRoot(y.value()));
const LiteralIndex to_merge = rep == x ? y : x;
@@ -2533,10 +2752,10 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// TODO(user): I think we could only add the gates of "to_merge"
// before we merge. This part of the code is quite quick in any
// case.
for (const int gate_id : input_literals_to_gate[rep]) {
if (in_queue[gate_id]) continue;
for (const GateId gate_id : input_literals_to_gate[rep]) {
if (in_queue[gate_id.value()]) continue;
queue.push_back(gate_id);
in_queue[gate_id] = true;
in_queue[gate_id.value()] = true;
}
}
}
@@ -2599,10 +2818,22 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
std::sort(inputs.begin(), inputs.begin() + new_size);
gates_inputs_.Shrink(id, new_size);
// Lets convert to the short "type" if we can. The truth table is simply
// a 1 on the last position (where all inputs are ones). We fall back to
// the case below to canonicalize further.
if (new_size > 4 || lrat_proof_handler_ != nullptr) continue;
// Lets convert a kAndGateType to the short "type" if we can. The truth
// table is simply a 1 on the last position (where all inputs are ones).
// We fall back to the case below to canonicalize further.
//
// This is needed because while our generic and_gate use 1 clause +
// binary, it wont detect a kAndGateType "badly" encoded with 4 ternary
// clauses for instance:
//
// b & c => a
// not(b) & c => not(a)
// b & not(c) => not(a)
// not(b) & not(c) => not(a)
//
// This is even more important since "generic" short gates might get
// simplified as we detect equialences, and become an and_gate later.
if (new_size > 4) continue;
gates_type_[id] = 1 << ((1 << new_size) - 1);
}
@@ -2612,7 +2843,9 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
CHECK_GE(gates_type_[id], 0);
CHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
for (LiteralIndex& lit_ref : inputs) {
lit_ref = LiteralIndex(union_find.FindRoot(lit_ref.value()));
lit_ref =
lrat_helper.GetRepresentativeWithProofSupport(Literal(lit_ref))
.Index();
}
const int new_size = CanonicalizeFunctionTruthTable(
@@ -2622,21 +2855,32 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
}
if (new_size == 1) {
// We have a function of size 1! this is an equivalence.
// We have a function of size 1! This is an equivalence.
//
// TODO(user): deal with it.
++arity1_equivalences;
input_literals_to_gate.RemoveFromFutureOutput(id);
break;
} else if (new_size == 0) {
// We have a fixed function ! just fix the literal.
// We have a fixed function! Just fix the literal.
CHECK(Literal(gates_target_[id]).IsPositive());
const Literal to_fix{Literal(gates_target_[id]).Variable(),
(gates_type_[id] & 1) == 1};
const Literal initial_to_fix =
(gates_type_[id] & 1) == 1 ? Literal(gates_target_[id])
: Literal(gates_target_[id]).Negated();
const Literal to_fix =
lrat_helper.GetRepresentativeWithProofSupport(initial_to_fix);
if (!assignment_.LiteralIsTrue(to_fix)) {
absl::InlinedVector<ClauseId, 4> clause_ids;
if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) {
return false;
if (lrat_proof_handler_ == nullptr) {
if (!clause_manager_->InprocessingFixLiteral(to_fix, {})) {
return false;
}
} else {
const ClauseId clause_id =
lrat_helper.ProofForFixingLiteral(to_fix, id);
if (!clause_manager_->InprocessingAddUnitClause(clause_id,
to_fix)) {
return false;
}
}
++num_units;
}

View File

@@ -437,6 +437,7 @@ class BoundedVariableElimination {
//
// TODO(user): What is the relation with symmetry ? It feel like all the
// equivalences found here should be in the same symmetry orbit ?
DEFINE_STRONG_INDEX_TYPE(GateId);
class GateCongruenceClosure {
public:
explicit GateCongruenceClosure(Model* model)
@@ -455,30 +456,32 @@ class GateCongruenceClosure {
bool DoOneRound(bool log_info);
private:
DEFINE_STRONG_INDEX_TYPE(TruthTableId);
struct GateHash {
explicit GateHash(std::vector<int>* t,
CompactVectorVector<int, LiteralIndex>* g)
explicit GateHash(util_intops::StrongVector<GateId, int>* t,
CompactVectorVector<GateId, LiteralIndex>* g)
: gates_type(t), gates_inputs(g) {}
std::size_t operator()(int gate_id) const {
std::size_t operator()(GateId gate_id) const {
return absl::HashOf((*gates_type)[gate_id], (*gates_inputs)[gate_id]);
}
const std::vector<int>* gates_type;
const CompactVectorVector<int, LiteralIndex>* gates_inputs;
const util_intops::StrongVector<GateId, int>* gates_type;
const CompactVectorVector<GateId, LiteralIndex>* gates_inputs;
};
struct GateEq {
explicit GateEq(std::vector<int>* t,
CompactVectorVector<int, LiteralIndex>* g)
explicit GateEq(util_intops::StrongVector<GateId, int>* t,
CompactVectorVector<GateId, LiteralIndex>* g)
: gates_type(t), gates_inputs(g) {}
bool operator()(int gate_a, int gate_b) const {
bool operator()(GateId gate_a, GateId gate_b) const {
if (gate_a == gate_b) return true;
// We use absl::span<> comparison.
return ((*gates_type)[gate_a] == (*gates_type)[gate_b]) &&
((*gates_inputs)[gate_a] == (*gates_inputs)[gate_b]);
}
const std::vector<int>* gates_type;
const CompactVectorVector<int, LiteralIndex>* gates_inputs;
const util_intops::StrongVector<GateId, int>* gates_type;
const CompactVectorVector<GateId, LiteralIndex>* gates_inputs;
};
// Recovers "target_literal = and(literals)" from the model.
@@ -493,16 +496,15 @@ class GateCongruenceClosure {
// (not(literal[i]) for all i, target_literal).
void ExtractAndGatesAndFillShortTruthTables(PresolveTimer& timer);
// From possible assignment of "arity" given variables, extract functions.
// From possible assignment of small set of variables (truth_tables), extract
// functions of the form one_var = f(other_vars).
void ExtractShortGates(PresolveTimer& timer);
// Add a small clause to the corresponding truth table.
template <int arity>
void ExtractShortGates(
PresolveTimer& timer,
const absl::flat_hash_map<std::array<BooleanVariable, arity>,
SmallBitset>& data);
template <int arity>
void AddToTruthTable(absl::Span<const Literal> clause,
void AddToTruthTable(SatClause* clause,
absl::flat_hash_map<std::array<BooleanVariable, arity>,
SmallBitset>& data);
TruthTableId>& ids);
const VariablesAssignment& assignment_;
SatSolver* sat_solver_;
@@ -528,25 +530,33 @@ class GateCongruenceClosure {
// truth table. i.e. target = type[sum value_of_inputs[i] * 2^i]. For such
// gate, the target and inputs will always be canonicalized to positive and
// sorted literal. We just update the truth table accordingly.
//
// If lrat_proof_handler_ != nullptr, we also store all the SatClause* needed
// to infer such gate from the clause database. The case of kAndGateType is
// special, because we don't have SatClause for the binary clauses used to
// infer it. We will thus only store the base clause used, if we have a =
// and(x,y,...) we only store the clause "x and y and ... => a".
static constexpr int kAndGateType = -1;
std::vector<LiteralIndex> gates_target_;
std::vector<int> gates_type_;
CompactVectorVector<int, LiteralIndex> gates_inputs_;
// For each gate, "the" corresponding clause. For a gate a = and(x,y,...) this
// is the clause "x and y and ... => a". Only used for LRAT.
std::vector<const SatClause*> gates_clause_;
util_intops::StrongVector<GateId, LiteralIndex> gates_target_;
util_intops::StrongVector<GateId, int> gates_type_;
CompactVectorVector<GateId, LiteralIndex> gates_inputs_;
CompactVectorVector<GateId, const SatClause*> gates_clauses_;
// Map (Xi) (sorted) to a bitmask corresponding to the allowed values.
// We loop over all short clauses to fill this.
// We loop over all short clauses to fill this. We actually store an "id"
// pointing in the vectors below.
//
// TODO(user): Shorter clauses impact larger truth table too and we can
// combine two size 3 to construct a size 4 (needed for ITE-gate).
// not ideal.
absl::flat_hash_map<std::array<BooleanVariable, 3>, SmallBitset>
truth_table3_;
absl::flat_hash_map<std::array<BooleanVariable, 4>, SmallBitset>
truth_table4_;
// TruthTableIds are assigned in insertion order. We copy the map key in
// truth_tables_inputs_, this is a bit wasted but simplify the code.
absl::flat_hash_map<std::array<BooleanVariable, 3>, TruthTableId> ids3_;
absl::flat_hash_map<std::array<BooleanVariable, 4>, TruthTableId> ids4_;
CompactVectorVector<TruthTableId, BooleanVariable> truth_tables_inputs_;
util_intops::StrongVector<TruthTableId, SmallBitset> truth_tables_bitset_;
CompactVectorVector<TruthTableId, SatClause*> truth_tables_clauses_;
// Temporary vector used to construct truth_tables_clauses_.
std::vector<TruthTableId> tmp_ids_;
std::vector<SatClause*> tmp_clauses_;
// For stats.
double total_dtime_ = 0.0;

View File

@@ -187,7 +187,7 @@ class MergeableOccurrenceList {
int min_transpose_size = 0) {
rows_.ResetFromTranspose(input, min_transpose_size);
next_.assign(rows_.size(), K(-1));
marked_.ClearAndResize(input.size());
marked_.ClearAndResize(V(input.size()));
}
int size() const { return rows_.size(); }
@@ -1108,7 +1108,7 @@ inline void CompactVectorVector<K, V>::ResetFromTranspose(
// Compute maximum index.
int max_key = min_transpose_size;
for (V v = 0; v < other.size(); ++v) {
for (V v(0); v < other.size(); ++v) {
for (const K k : other[v]) {
max_key = std::max(max_key, InternalKey(k) + 1);
}
@@ -1116,7 +1116,7 @@ inline void CompactVectorVector<K, V>::ResetFromTranspose(
// Compute sizes_;
sizes_.assign(max_key, 0);
for (V v = 0; v < other.size(); ++v) {
for (V v(0); v < other.size(); ++v) {
for (const K k : other[v]) {
sizes_[InternalKey(k)]++;
}
@@ -1130,7 +1130,7 @@ inline void CompactVectorVector<K, V>::ResetFromTranspose(
// Copy data and uses starts as temporary indices.
buffer_.resize(other.num_entries());
for (V v = 0; v < other.size(); ++v) {
for (V v(0); v < other.size(); ++v) {
for (const K k : other[v]) {
buffer_[starts_[InternalKey(k)]++] = v;
}