[CP-SAT] polish diffn code; improve presolve w.r.t. intervals; simplify use of GCD in linear constraints; polish samples

This commit is contained in:
Laurent Perron
2019-02-27 14:26:44 +01:00
parent cfbf1db09b
commit 58ee3bde43
15 changed files with 287 additions and 113 deletions

View File

@@ -108,7 +108,7 @@ std::string ValidateArgumentReferencesInConstraint(const CpModelProto& model,
ProtobufShortDebugString(ct));
}
}
for (const int i : references.intervals) {
for (const int i : UsedIntervals(ct)) {
if (i < 0 || i >= model.constraints_size()) {
return absl::StrCat("Out of bound interval ", i, " in constraint #", c,
" : ", ProtobufShortDebugString(ct));

View File

@@ -172,12 +172,23 @@ void PresolveContext::UpdateRuleStats(const std::string& name) {
stats_by_rule_name[name]++;
}
void PresolveContext::AddVariableUsage(int c) {
const ConstraintProto& ct = working_model->constraints(c);
constraint_to_vars[c] = UsedVariables(working_model->constraints(c));
constraint_to_intervals[c] = UsedIntervals(ct);
for (const int v : constraint_to_vars[c]) var_to_constraints[v].insert(c);
for (const int i : constraint_to_intervals[c]) interval_usage[i]++;
}
void PresolveContext::UpdateConstraintVariableUsage(int c) {
CHECK_EQ(constraint_to_vars.size(), working_model->constraints_size());
const ConstraintProto& ct = working_model->constraints(c);
// Remove old usage.
for (const int v : constraint_to_vars[c]) var_to_constraints[v].erase(c);
constraint_to_vars[c] = UsedVariables(ct);
for (const int v : constraint_to_vars[c]) var_to_constraints[v].insert(c);
for (const int i : constraint_to_intervals[c]) interval_usage[i]--;
AddVariableUsage(c);
}
void PresolveContext::UpdateNewConstraintsVariableUsage() {
@@ -185,9 +196,10 @@ void PresolveContext::UpdateNewConstraintsVariableUsage() {
const int new_size = working_model->constraints_size();
CHECK_LE(old_size, new_size);
constraint_to_vars.resize(new_size);
constraint_to_intervals.resize(new_size);
interval_usage.resize(new_size);
for (int c = old_size; c < new_size; ++c) {
constraint_to_vars[c] = UsedVariables(working_model->constraints(c));
for (const int v : constraint_to_vars[c]) var_to_constraints[v].insert(c);
AddVariableUsage(c);
}
}
@@ -1546,11 +1558,28 @@ bool PresolveLinearOnBooleans(ConstraintProto* ct, PresolveContext* context) {
return RemoveConstraint(ct, context);
}
bool PresolveInterval(ConstraintProto* ct, PresolveContext* context) {
if (!ct->enforcement_literal().empty()) return false;
bool PresolveInterval(int c, ConstraintProto* ct, PresolveContext* context) {
const int start = ct->interval().start();
const int end = ct->interval().end();
const int size = ct->interval().size();
if (context->interval_usage[c] == 0) {
// Convert to linear.
ConstraintProto* new_ct = context->working_model->add_constraints();
*(new_ct->mutable_enforcement_literal()) = ct->enforcement_literal();
new_ct->mutable_linear()->add_domain(0);
new_ct->mutable_linear()->add_domain(0);
new_ct->mutable_linear()->add_vars(start);
new_ct->mutable_linear()->add_coeffs(1);
new_ct->mutable_linear()->add_vars(size);
new_ct->mutable_linear()->add_coeffs(1);
new_ct->mutable_linear()->add_vars(end);
new_ct->mutable_linear()->add_coeffs(-1);
context->UpdateRuleStats("interval: unused, converted to linear");
return RemoveConstraint(ct, context);
}
if (!ct->enforcement_literal().empty()) return false;
bool changed = false;
changed |= context->IntersectDomainWith(
end, context->DomainOf(start).AdditionWith(context->DomainOf(size)));
@@ -3292,7 +3321,7 @@ bool PresolveOneConstraint(int c, PresolveContext* context) {
return false;
}
case ConstraintProto::ConstraintCase::kInterval:
return PresolveInterval(ct, context);
return PresolveInterval(c, ct, context);
case ConstraintProto::ConstraintCase::kElement:
return PresolveElement(ct, context);
case ConstraintProto::ConstraintCase::kTable:

View File

@@ -144,6 +144,10 @@ struct PresolveContext {
std::vector<std::vector<int>> constraint_to_vars;
std::vector<absl::flat_hash_set<int>> var_to_constraints;
// We maintain how many time each interval is used.
std::vector<std::vector<int>> constraint_to_intervals;
std::vector<int> interval_usage;
CpModelProto* working_model;
CpModelProto* mapping_model;
@@ -167,6 +171,8 @@ struct PresolveContext {
SparseBitset<int64> modified_domains;
private:
void AddVariableUsage(int c);
// The current domain of each variables.
std::vector<Domain> domains;
};

View File

@@ -902,7 +902,6 @@ IntegerVariable AddLPConstraints(const CpModelProto& model_proto,
// TODO(user): It should be possible to speed this up if needed.
refs.variables.clear();
refs.literals.clear();
refs.intervals.clear();
AddReferencesUsedByConstraint(ct, &refs);
bool ok = true;
for (const int literal_ref : refs.literals) {

View File

@@ -26,6 +26,11 @@ void AddIndices(const IntList& indices, absl::flat_hash_set<int>* output) {
output->insert(indices.begin(), indices.end());
}
template <typename IntList>
void AddIndices(const IntList& indices, std::vector<int>* output) {
output->insert(output->end(), indices.begin(), indices.end());
}
} // namespace
void AddReferencesUsedByConstraint(const ConstraintProto& ct,
@@ -102,19 +107,14 @@ void AddReferencesUsedByConstraint(const ConstraintProto& ct,
output->variables.insert(ct.interval().size());
break;
case ConstraintProto::ConstraintCase::kNoOverlap:
AddIndices(ct.no_overlap().intervals(), &output->intervals);
break;
case ConstraintProto::ConstraintCase::kNoOverlap2D:
AddIndices(ct.no_overlap_2d().x_intervals(), &output->intervals);
AddIndices(ct.no_overlap_2d().y_intervals(), &output->intervals);
break;
case ConstraintProto::ConstraintCase::kCumulative:
output->variables.insert(ct.cumulative().capacity());
AddIndices(ct.cumulative().intervals(), &output->intervals);
AddIndices(ct.cumulative().demands(), &output->variables);
break;
case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET:
// Empty constraint.
break;
}
}
@@ -387,10 +387,11 @@ std::string ConstraintCaseName(
}
}
// TODO(user): Optimize this function, it appear in the presolve profile.
// We could get rid of AddReferencesUsedByConstraint().
std::vector<int> UsedVariables(const ConstraintProto& ct) {
IndexReferences references;
AddReferencesUsedByConstraint(ct, &references);
std::vector<int> used_variables;
for (const int var : references.variables) {
used_variables.push_back(PositiveRef(var));
@@ -405,5 +406,65 @@ std::vector<int> UsedVariables(const ConstraintProto& ct) {
return used_variables;
}
std::vector<int> UsedIntervals(const ConstraintProto& ct) {
std::vector<int> used_intervals;
switch (ct.constraint_case()) {
case ConstraintProto::ConstraintCase::kBoolOr:
break;
case ConstraintProto::ConstraintCase::kBoolAnd:
break;
case ConstraintProto::ConstraintCase::kAtMostOne:
break;
case ConstraintProto::ConstraintCase::kBoolXor:
break;
case ConstraintProto::ConstraintCase::kIntDiv:
break;
case ConstraintProto::ConstraintCase::kIntMod:
break;
case ConstraintProto::ConstraintCase::kIntMax:
break;
case ConstraintProto::ConstraintCase::kIntMin:
break;
case ConstraintProto::ConstraintCase::kIntProd:
break;
case ConstraintProto::ConstraintCase::kLinear:
break;
case ConstraintProto::ConstraintCase::kAllDiff:
break;
case ConstraintProto::ConstraintCase::kElement:
break;
case ConstraintProto::ConstraintCase::kCircuit:
break;
case ConstraintProto::ConstraintCase::kRoutes:
break;
case ConstraintProto::ConstraintCase::kCircuitCovering:
break;
case ConstraintProto::ConstraintCase::kInverse:
break;
case ConstraintProto::ConstraintCase::kReservoir:
break;
case ConstraintProto::ConstraintCase::kTable:
break;
case ConstraintProto::ConstraintCase::kAutomaton:
break;
case ConstraintProto::ConstraintCase::kInterval:
break;
case ConstraintProto::ConstraintCase::kNoOverlap:
AddIndices(ct.no_overlap().intervals(), &used_intervals);
break;
case ConstraintProto::ConstraintCase::kNoOverlap2D:
AddIndices(ct.no_overlap_2d().x_intervals(), &used_intervals);
AddIndices(ct.no_overlap_2d().y_intervals(), &used_intervals);
break;
case ConstraintProto::ConstraintCase::kCumulative:
AddIndices(ct.cumulative().intervals(), &used_intervals);
break;
case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET:
break;
}
gtl::STLSortAndRemoveDuplicates(&used_intervals);
return used_intervals;
}
} // namespace sat
} // namespace operations_research

View File

@@ -49,7 +49,6 @@ inline int EnforcementLiteral(const ConstraintProto& ct) {
struct IndexReferences {
absl::flat_hash_set<int> variables;
absl::flat_hash_set<int> literals;
absl::flat_hash_set<int> intervals;
};
void AddReferencesUsedByConstraint(const ConstraintProto& ct,
IndexReferences* output);
@@ -69,8 +68,12 @@ void ApplyToAllIntervalIndices(const std::function<void(int*)>& function,
std::string ConstraintCaseName(ConstraintProto::ConstraintCase constraint_case);
// Returns the sorted list of variables used by a constraint.
// Note that this include variable used as a literal.
std::vector<int> UsedVariables(const ConstraintProto& ct);
// Returns the sorted list of interval used by a constraint.
std::vector<int> UsedIntervals(const ConstraintProto& ct);
// Returns true if a proto.domain() contain the given value.
// The domain is expected to be encoded as a sorted disjoint interval list.
template <typename ProtoWithDomain>

View File

@@ -75,18 +75,9 @@ void NonOverlappingRectanglesPropagator::RegisterWith(
namespace {
// Returns true iff the 2 given intervals are disjoint. If their union is one
// point, this also returns true.
bool IntervalAreDisjointForSure(IntegerValue min_a, IntegerValue max_a,
IntegerValue min_b, IntegerValue max_b) {
return min_a >= max_b || min_b >= max_a;
}
// Returns the distance from interval a to the "bounding interval" of a and b.
IntegerValue DistanceToBoundingInterval(IntegerValue min_a, IntegerValue max_a,
IntegerValue min_b,
IntegerValue max_b) {
return std::max(min_a - min_b, max_b - max_a);
IntegerValue MaxSpan(IntegerValue min_a, IntegerValue max_a, IntegerValue min_b,
IntegerValue max_b) {
return std::max(max_a, max_b) - std::min(min_a, min_b) + 1;
}
} // namespace
@@ -110,10 +101,8 @@ void NonOverlappingRectanglesPropagator::SortNeighbors(int box) {
neighbors_.push_back(other);
cached_distance_to_bounding_box_[other] =
std::max(DistanceToBoundingInterval(box_x_min, box_x_max, other_x_min,
other_x_max),
DistanceToBoundingInterval(box_y_min, box_y_max, other_y_min,
other_y_max));
MaxSpan(box_x_min, box_x_max, other_x_min, other_x_max) *
MaxSpan(box_y_min, box_y_max, other_y_min, other_y_max);
}
IncrementalSort(neighbors_.begin(), neighbors_.begin(), [this](int i, int j) {
return cached_distance_to_bounding_box_[i] <
@@ -261,23 +250,6 @@ bool CheckOverload(bool time_direction, IntegerValue other_time,
return true;
}
void AddOtherReasons(const std::vector<int>& tasks, IntegerValue other_time,
int main_task, SchedulingConstraintHelper* other) {
other->ClearReason();
bool main_task_seen = false;
for (const int task : tasks) {
other->AddStartMaxReason(task, other_time);
other->AddEndMinReason(task, other_time + 1);
if (task == main_task) {
main_task_seen = true;
}
}
if (!main_task_seen) {
other->AddStartMaxReason(main_task, other_time);
other->AddEndMinReason(main_task, other_time + 1);
}
}
bool DetectPrecedences(bool time_direction, IntegerValue other_time,
const absl::flat_hash_set<int>& active_boxes,
SchedulingConstraintHelper* helper,
@@ -324,6 +296,7 @@ bool DetectPrecedences(bool time_direction, IntegerValue other_time,
if (end_min_of_critical_tasks > helper->StartMin(t)) {
const std::vector<TaskSet::Entry>& sorted_tasks = task_set.SortedTasks();
helper->ClearReason();
other->ClearReason();
// We need:
// - StartMax(ct) < EndMin(t) for the detectable precedence.
@@ -334,8 +307,6 @@ bool DetectPrecedences(bool time_direction, IntegerValue other_time,
tasks[i - critical_index] = sorted_tasks[i].task;
}
AddOtherReasons(tasks, other_time, t, other);
for (int i = critical_index; i < sorted_tasks.size(); ++i) {
const int ct = sorted_tasks[i].task;
if (ct == t) continue;
@@ -344,10 +315,14 @@ bool DetectPrecedences(bool time_direction, IntegerValue other_time,
helper->AddEnergyAfterReason(ct, sorted_tasks[i].duration_min,
window_start);
helper->AddStartMaxReason(ct, end_min - 1);
other->AddStartMaxReason(ct, other_time);
other->AddEndMinReason(ct, other_time + 1);
}
// Add the reason for t (we only need the end-min).
helper->AddEndMinReason(t, end_min);
other->AddStartMaxReason(t, other_time);
other->AddEndMinReason(t, other_time + 1);
// Import reasons from the 'other' dimension.
helper->ImportOtherReasons(*other);
@@ -368,6 +343,101 @@ bool DetectPrecedences(bool time_direction, IntegerValue other_time,
}
return true;
}
// Specialized propagation on only two boxes are mandatory on a single line
// (parallel to x or parallel to y). In that case, we can improve the reason
// why these two boxes overlap on one dimension, forcing them to be disjoint
// in the other dimension.
bool PropagateTwoBoxes(int b1, int b2, SchedulingConstraintHelper* helper,
SchedulingConstraintHelper* other) {
// For each direction and each order, we test if the boxes can be disjoint.
const int state = (helper->EndMin(b1) <= helper->StartMax(b2)) +
2 * (helper->EndMin(b2) <= helper->StartMax(b1));
const auto left_box_before_right_box =
[](int left, int right, SchedulingConstraintHelper* helper) {
// left box pushes right box.
const IntegerValue left_end_min = helper->EndMin(left);
if (left_end_min > helper->StartMin(right)) {
// Store reasons state.
const int literal_size = helper->MutableLiteralReason()->size();
const int integer_size = helper->MutableIntegerReason()->size();
helper->AddEndMinReason(left, left_end_min);
if (!helper->IncreaseStartMin(right, left_end_min)) {
return false;
}
// Restore the reasons to the state before the increase of the start.
helper->MutableLiteralReason()->resize(literal_size);
helper->MutableIntegerReason()->resize(integer_size);
}
// right box pushes left box.
const IntegerValue right_start_max = helper->StartMax(right);
if (right_start_max < helper->EndMax(left)) {
helper->AddStartMaxReason(right, right_start_max);
return helper->DecreaseEndMax(left, right_start_max);
}
return true;
};
// Clean up reasons.
helper->ClearReason();
other->ClearReason();
// This is an "hack" to be able to easily test for none or for one
// and only one of the conditions below.
switch (state) {
case 0: {
helper->AddReasonForBeingBefore(b1, b2);
helper->AddReasonForBeingBefore(b2, b1);
other->AddReasonForBeingBefore(b1, b2);
other->AddReasonForBeingBefore(b2, b1);
helper->ImportOtherReasons(*other);
return helper->ReportConflict();
}
case 1: {
other->AddReasonForBeingBefore(b1, b2);
other->AddReasonForBeingBefore(b2, b1);
helper->AddReasonForBeingBefore(b1, b2);
helper->ImportOtherReasons(*other);
return left_box_before_right_box(b1, b2, helper);
}
case 2: {
other->AddReasonForBeingBefore(b1, b2);
other->AddReasonForBeingBefore(b2, b1);
helper->AddReasonForBeingBefore(b2, b1);
helper->ImportOtherReasons(*other);
return left_box_before_right_box(b2, b1, helper);
}
default: {
return true;
}
}
}
IntegerValue FindCanonicalValue(IntegerValue lb, IntegerValue ub) {
if (lb == ub) return lb;
if (lb < 0 && ub > 0) return IntegerValue(0);
if (lb < 0 && ub <= 0) {
return -FindCanonicalValue(-ub, -lb);
}
int64 mask = 0;
IntegerValue candidate = ub;
for (int o = 0; o < 62; ++o) {
mask = 2 * mask + 1;
const IntegerValue masked_ub(ub.value() & ~mask);
if (masked_ub >= lb) {
candidate = masked_ub;
} else {
break;
}
}
return candidate;
}
} // namespace
bool NonOverlappingRectanglesPropagator::PropagateOnProjections() {
@@ -430,38 +500,39 @@ bool NonOverlappingRectanglesPropagator::FindMandatoryBoxesOnOneDimension(
}
for (const auto& it : mandatory_boxes) {
// Compute the 'canonical' line to use when explaining that boxes overlap
// on the 'other' dimension.
// Collect the common mandatory coordinates of all boxes.
IntegerValue lb(kint64min);
IntegerValue ub(kint64max);
for (const int task : it.second) {
lb = std::max(lb, other->StartMax(task));
ub = std::min(ub, other->EndMin(task));
ub = std::min(ub, other->EndMin(task) - 1);
}
const IntegerValue span = ub - lb + 1;
IntegerValue selected = lb;
for (int shift = 30; shift >= 0; --shift) {
const IntegerValue mask(static_cast<int64>(1) >> shift);
if (mask <= span) {
selected = (ub / mask) * mask;
break;
}
}
// Compute the 'canonical' line to use when explaining that boxes overlap
// on the 'other' dimension. We compute the multiple of the biggest power of
// two that is common to all boxes.
const IntegerValue candidate = FindCanonicalValue(lb, ub);
// And propagate.
RETURN_IF_FALSE(PropagateMandatoryBoxesOnOneDimension(selected, it.second,
RETURN_IF_FALSE(PropagateMandatoryBoxesOnOneDimension(candidate, it.second,
helper, other));
}
return true;
}
bool NonOverlappingRectanglesPropagator::PropagateMandatoryBoxesOnOneDimension(
IntegerValue event, const std::vector<int>& boxes,
IntegerValue other_time, const std::vector<int>& boxes,
SchedulingConstraintHelper* helper, SchedulingConstraintHelper* other) {
if (boxes.size() == 2) {
return PropagateTwoBoxes(boxes[0], boxes[1], helper, other);
}
const absl::flat_hash_set<int> active_boxes(boxes.begin(), boxes.end());
RETURN_IF_FALSE(CheckOverload(true, event, active_boxes, helper, other));
RETURN_IF_FALSE(DetectPrecedences(true, event, active_boxes, helper, other));
RETURN_IF_FALSE(DetectPrecedences(false, event, active_boxes, helper, other));
RETURN_IF_FALSE(CheckOverload(true, other_time, active_boxes, helper, other));
RETURN_IF_FALSE(
DetectPrecedences(true, other_time, active_boxes, helper, other));
RETURN_IF_FALSE(
DetectPrecedences(false, other_time, active_boxes, helper, other));
return true;
}

View File

@@ -13,6 +13,8 @@
#include "ortools/sat/linear_constraint.h"
#include "ortools/base/mathutil.h"
namespace operations_research {
namespace sat {
@@ -71,19 +73,12 @@ namespace {
// TODO(user): Template for any integer type and expose this?
IntegerValue ComputeGcd(const std::vector<IntegerValue>& values) {
if (values.empty()) return IntegerValue(1);
IntegerValue gcd = IntTypeAbs(values.front());
const int size = values.size();
for (int i = 1; i < size; ++i) {
// GCD(gcd, value) = GCD(value, gcd % value);
IntegerValue value = IntTypeAbs(values[i]);
while (value != 0) {
const IntegerValue r = gcd % value;
gcd = value;
value = r;
}
int64 gcd = 0;
for (const IntegerValue value : values) {
gcd = MathUtil::GCD64(gcd, std::abs(value.value()));
if (gcd == 1) break;
}
return gcd;
return IntegerValue(gcd);
}
} // namespace

View File

@@ -39,22 +39,24 @@ public class ChannelingSampleSat
{
static void Main()
{
// Model.
// Create the CP-SAT model.
CpModel model = new CpModel();
// Variables.
// Declare our two primary variables.
IntVar x = model.NewIntVar(0, 10, "x");
IntVar y = model.NewIntVar(0, 10, "y");
// Declare our intermediate boolean variable.
IntVar b = model.NewBoolVar("b");
// Implement b == (x >= 5).
model.Add(x >= 5).OnlyEnforceIf(b);
model.Add(x < 5).OnlyEnforceIf(b.Not());
// b implies (y == 10 - x).
// Create our two half-reified constraints.
// First, b implies (y == 10 - x).
model.Add(y == 10 - x).OnlyEnforceIf(b);
// not(b) implies y == 0.
// Second, not(b) implies y == 0.
model.Add(y == 0).OnlyEnforceIf(b.Not());
// Search for x values in increasing order.
@@ -63,7 +65,7 @@ public class ChannelingSampleSat
DecisionStrategyProto.Types.VariableSelectionStrategy.ChooseFirst,
DecisionStrategyProto.Types.DomainReductionStrategy.SelectMinValue);
// Create a solver and solve with a fixed search.
// Create the solver.
CpSolver solver = new CpSolver();
// Force solver to follow the decision strategy exactly.

View File

@@ -25,36 +25,38 @@ public class ChannelingSampleSat {
}
public static void main(String[] args) throws Exception {
// Model.
// Create the CP-SAT model.
CpModel model = new CpModel();
// Variables.
// Declare our two primary variables.
IntVar x = model.newIntVar(0, 10, "x");
IntVar y = model.newIntVar(0, 10, "y");
// Declare our intermediate boolean variable.
IntVar b = model.newBoolVar("b");
// Implements b == (x >= 5).
// Implement b == (x >= 5).
model.addGreaterOrEqual(x, 5).onlyEnforceIf(b);
model.addLessOrEqual(x, 4).onlyEnforceIf(b.not());
// b implies (y == 10 - x).
// Create our two half-reified constraints.
// First, b implies (y == 10 - x).
model.addLinearSumEqual(new IntVar[] {x, y}, 10).onlyEnforceIf(b);
// not(b) implies y == 0.
// Second, not(b) implies y == 0.
model.addEquality(y, 0).onlyEnforceIf(b.not());
// Searches for x values in increasing order.
// Search for x values in increasing order.
model.addDecisionStrategy(new IntVar[] {x},
DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_FIRST,
DecisionStrategyProto.DomainReductionStrategy.SELECT_MIN_VALUE);
// Creates the solver.
// Create the solver.
CpSolver solver = new CpSolver();
// Forces the solver to follow the decision strategy exactly.
// Force the solver to follow the decision strategy exactly.
solver.getParameters().setSearchBranching(SatParameters.SearchBranching.FIXED_SEARCH);
// Solves the problem with the printer callback.
// Solve the problem with the printer callback.
solver.searchAllSolutions(model, new CpSolverSolutionCallback() {
public CpSolverSolutionCallback init(IntVar[] variables) {
variableArray = variables;

View File

@@ -19,27 +19,31 @@ namespace operations_research {
namespace sat {
void ChannelingSampleSat() {
// Model.
// Create the CP-SAT model.
CpModelBuilder cp_model;
// Main variables.
// Declare our two primary variables.
const IntVar x = cp_model.NewIntVar({0, 10});
const IntVar y = cp_model.NewIntVar({0, 10});
// Declare our intermediate boolean variable.
const BoolVar b = cp_model.NewBoolVar();
// b == (x >= 5).
// Implement b == (x >= 5).
cp_model.AddGreaterOrEqual(x, 5).OnlyEnforceIf(b);
cp_model.AddLessThan(x, 5).OnlyEnforceIf(Not(b));
// b implies (y == 10 - x).
// Create our two half-reified constraints.
// First, b implies (y == 10 - x).
cp_model.AddEquality(LinearExpr::Sum({x, y}), 10).OnlyEnforceIf(b);
// not(b) implies y == 0.
// Second, not(b) implies y == 0.
cp_model.AddEquality(y, 0).OnlyEnforceIf(Not(b));
// Search for x values in increasing order.
cp_model.AddDecisionStrategy({x}, DecisionStrategyProto::CHOOSE_FIRST,
DecisionStrategyProto::SELECT_MIN_VALUE);
// Create a solver and solve with a fixed search.
Model model;
SatParameters parameters;
parameters.set_search_branching(SatParameters::FIXED_SEARCH);

View File

@@ -40,22 +40,24 @@ class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback):
def ChannelingSampleSat():
"""Demonstrates how to link integer constraints together."""
# Model.
# Create the CP-SAT model.
model = cp_model.CpModel()
# Variables.
# Declare our two primary variables.
x = model.NewIntVar(0, 10, 'x')
y = model.NewIntVar(0, 10, 'y')
# Declare our intermediate boolean variable.
b = model.NewBoolVar('b')
# Implement b == (x >= 5).
model.Add(x >= 5).OnlyEnforceIf(b)
model.Add(x < 5).OnlyEnforceIf(b.Not())
# b implies (y == 10 - x).
# Create our two half-reified constraints.
# First, b implies (y == 10 - x).
model.Add(y == 10 - x).OnlyEnforceIf(b)
# not(b) implies y == 0.
# Second, not(b) implies y == 0.
model.Add(y == 0).OnlyEnforceIf(b.Not())
# Search for x values in increasing order.
@@ -65,10 +67,10 @@ def ChannelingSampleSat():
# Create a solver and solve with a fixed search.
solver = cp_model.CpSolver()
# Force solver to follow the decision strategy exactly.
# Force the solver to follow the decision strategy exactly.
solver.parameters.search_branching = cp_model.FIXED_SEARCH
# Searches and prints out all solutions.
# Search and print out all solutions.
solution_printer = VarArraySolutionPrinter([x, y, b])
solver.SearchForAllSolutions(model, solution_printer)

View File

@@ -76,8 +76,8 @@ def CPIsFunSat():
model.AddAllDifferent(letters)
# CP + IS + FUN = TRUE
model.Add(c * base + p + i * base + s + f * base * base + u * base +
n == t * base * base * base + r * base * base + u * base + e)
model.Add(c * base + p + i * base + s + f * base * base + u * base + n ==
t * base * base * base + r * base * base + u * base + e)
# [END constraints]
# [START solve]

View File

@@ -36,8 +36,8 @@ def RabbitsAndPheasantsSat():
status = solver.Solve(model)
if status == cp_model.FEASIBLE:
print(
'%i rabbits and %i pheasants' % (solver.Value(r), solver.Value(p)))
print('%i rabbits and %i pheasants' % (solver.Value(r),
solver.Value(p)))
RabbitsAndPheasantsSat()

View File

@@ -56,8 +56,8 @@ def main():
for n in all_nurses:
for d in all_days:
for s in all_shifts:
shifts[(n, d,
s)] = model.NewBoolVar('shift_n%id%is%i' % (n, d, s))
shifts[(n, d, s)] = model.NewBoolVar('shift_n%id%is%i' % (n, d,
s))
# [END variables]
# Each shift is assigned to exactly one nurse in .
@@ -88,8 +88,8 @@ def main():
# [START objective]
model.Maximize(
sum(shift_requests[n][d][s] * shifts[(n, d, s)] for n in all_nurses
for d in all_days for s in all_shifts))
sum(shift_requests[n][d][s] * shifts[(n, d, s)]
for n in all_nurses for d in all_days for s in all_shifts))
# [END objective]
# Creates the solver and solve.
# [START solve]