diff --git a/makefiles/Makefile.gen.mk b/makefiles/Makefile.gen.mk index d0160c744b..5cbdd67249 100644 --- a/makefiles/Makefile.gen.mk +++ b/makefiles/Makefile.gen.mk @@ -1565,6 +1565,7 @@ SAT_LIB_OBJS = \ $(OBJ_DIR)/sat/clause.$O \ $(OBJ_DIR)/sat/cp_constraints.$O \ $(OBJ_DIR)/sat/cp_model_checker.$O \ + $(OBJ_DIR)/sat/cp_model_expand.$O \ $(OBJ_DIR)/sat/cp_model_presolve.$O \ $(OBJ_DIR)/sat/cp_model_search.$O \ $(OBJ_DIR)/sat/cp_model_solver.$O \ @@ -1652,6 +1653,9 @@ $(SRC_DIR)/ortools/sat/cp_model_checker.h: \ $(GEN_DIR)/ortools/sat/cp_model.pb.h \ $(SRC_DIR)/ortools/base/integral_types.h +$(SRC_DIR)/ortools/sat/cp_model_expand.h: \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h + $(SRC_DIR)/ortools/sat/cp_model_presolve.h: \ $(GEN_DIR)/ortools/sat/cp_model.pb.h @@ -1963,6 +1967,14 @@ $(OBJ_DIR)/sat/cp_model_checker.$O: \ $(SRC_DIR)/ortools/util/sorted_interval_list.h $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Scp_model_checker.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_model_checker.$O +$(OBJ_DIR)/sat/cp_model_expand.$O: \ + $(SRC_DIR)/ortools/sat/cp_model_expand.cc \ + $(SRC_DIR)/ortools/sat/cp_model_expand.h \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h \ + $(SRC_DIR)/ortools/base/hash.h \ + $(SRC_DIR)/ortools/base/map_util.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Scp_model_expand.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_model_expand.$O + $(OBJ_DIR)/sat/cp_model_presolve.$O: \ $(SRC_DIR)/ortools/sat/cp_model_presolve.cc \ $(SRC_DIR)/ortools/sat/cp_model_checker.h \ @@ -1993,6 +2005,7 @@ $(OBJ_DIR)/sat/cp_model_solver.$O: \ $(SRC_DIR)/ortools/sat/circuit.h \ $(SRC_DIR)/ortools/sat/cp_constraints.h \ $(SRC_DIR)/ortools/sat/cp_model_checker.h \ + $(SRC_DIR)/ortools/sat/cp_model_expand.h \ $(SRC_DIR)/ortools/sat/cp_model_presolve.h \ $(SRC_DIR)/ortools/sat/cp_model_search.h \ $(SRC_DIR)/ortools/sat/cp_model_solver.h \ diff --git a/ortools/linear_solver/linear_expr.h b/ortools/linear_solver/linear_expr.h index 373440f487..bc54bf13c5 100644 --- a/ortools/linear_solver/linear_expr.h +++ b/ortools/linear_solver/linear_expr.h @@ -152,7 +152,7 @@ LinearExpr operator*(double lhs, LinearExpr rhs); // MPSolver::AddRowConstraint(const LinearRange& range[, const std::string& name]); class LinearRange { public: - LinearRange(); + LinearRange() : lower_bound_(0), upper_bound_(0) {} // The bounds of the linear range are updated so that they include the offset // from "linear_expr", i.e., we form the range: // lower_bound - offset <= linear_expr - offset <= upper_bound - offset. diff --git a/ortools/sat/all_different.cc b/ortools/sat/all_different.cc index d2d22ca6bc..fb0cccb046 100644 --- a/ortools/sat/all_different.cc +++ b/ortools/sat/all_different.cc @@ -415,16 +415,13 @@ bool AllDifferentConstraint::Propagate() { } } - const int index = trail_->Index(); LiteralIndex li = VariableLiteralIndexOf(x, offset_value + min_all_values_); DCHECK_NE(li, kTrueLiteralIndex); DCHECK_NE(li, kFalseLiteralIndex); - const Literal deduction = Literal(li).Negated(); - trail_->Enqueue(deduction, AssignmentType::kCachedReason); - *trail_->GetVectorToStoreReason(index) = *reason; - trail_->NotifyThatReasonIsCached(deduction.Variable()); + *trail_->GetVectorToStoreReason() = *reason; + trail_->EnqueueWithStoredReason(Literal(li).Negated()); return true; } } diff --git a/ortools/sat/circuit.cc b/ortools/sat/circuit.cc index 6970fccff8..6c6de4e80b 100644 --- a/ortools/sat/circuit.cc +++ b/ortools/sat/circuit.cc @@ -14,8 +14,8 @@ #include "ortools/sat/circuit.h" #include -#include +#include #include "ortools/base/map_util.h" #include "ortools/sat/sat_solver.h" @@ -23,12 +23,12 @@ namespace operations_research { namespace sat { CircuitPropagator::CircuitPropagator( - const std::vector>& graph, Options options, - Trail* trail) + std::vector> graph, Options options, Trail* trail) : num_nodes_(graph.size()), + graph_(std::move(graph)), options_(options), trail_(trail), - propagation_trail_index_(0) { + assignment_(trail->Assignment()) { // TODO(user): add a way to properly handle trivially UNSAT cases. // For now we just check that they don't occur at construction. CHECK_GT(num_nodes_, 1) @@ -36,18 +36,20 @@ CircuitPropagator::CircuitPropagator( next_.resize(num_nodes_, -1); prev_.resize(num_nodes_, -1); next_literal_.resize(num_nodes_); - self_arcs_.resize(num_nodes_); + must_be_in_cycle_.resize(num_nodes_); std::unordered_map literal_to_watch_index; - const VariablesAssignment& assignment = trail->Assignment(); for (int tail = 0; tail < num_nodes_; ++tail) { - self_arcs_[tail] = graph[tail][tail]; + if (LiteralIndexIsFalse(graph_[tail][tail])) { + // For the multiple_subcircuit_through_zero case, must_be_in_cycle_ will + // be const and only contains zero. + if (tail == 0 || !options_.multiple_subcircuit_through_zero) { + must_be_in_cycle_[rev_must_be_in_cycle_size_++] = tail; + } + } for (int head = 0; head < num_nodes_; ++head) { - const LiteralIndex index = graph[tail][head]; - // Note that we need to test for both "special" cases before we can - // call assignment.LiteralIsTrue() or LiteralIsFalse(). - if (index == kFalseLiteralIndex) continue; - if (index == kTrueLiteralIndex || - assignment.LiteralIsTrue(Literal(index))) { + LiteralIndex index = graph_[tail][head]; + if (LiteralIndexIsFalse(index)) continue; + if (LiteralIndexIsTrue(index)) { CHECK_EQ(next_[tail], -1) << "Trivially UNSAT or duplicate arcs while adding " << tail << " -> " << head; @@ -57,20 +59,32 @@ CircuitPropagator::CircuitPropagator( AddArc(tail, head, kNoLiteralIndex); continue; } - if (assignment.LiteralIsFalse(Literal(index))) continue; - int watch_index_ = FindWithDefault(literal_to_watch_index, index, -1); - if (watch_index_ == -1) { - watch_index_ = watch_index_to_literal_.size(); - literal_to_watch_index[index] = watch_index_; + // Tricky: For self-arc, we watch instead when the arc become false. + if (tail == head) index = Literal(index).NegatedIndex(); + + int watch_index = FindWithDefault(literal_to_watch_index, index, -1); + if (watch_index == -1) { + watch_index = watch_index_to_literal_.size(); + literal_to_watch_index[index] = watch_index; watch_index_to_literal_.push_back(Literal(index)); watch_index_to_arcs_.push_back(std::vector()); } - watch_index_to_arcs_[watch_index_].push_back({tail, head}); + watch_index_to_arcs_[watch_index].push_back({tail, head}); } } } +void CircuitPropagator::RegisterWith(GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + for (int w = 0; w < watch_index_to_literal_.size(); ++w) { + watcher->WatchLiteral(watch_index_to_literal_[w], id, w); + } + watcher->RegisterReversibleClass(id, this); + watcher->RegisterReversibleInt(id, &propagation_trail_index_); + watcher->RegisterReversibleInt(id, &rev_must_be_in_cycle_size_); +} + void CircuitPropagator::SetLevel(int level) { if (level == level_ends_.size()) return; if (level > level_ends_.size()) { @@ -90,21 +104,20 @@ void CircuitPropagator::SetLevel(int level) { level_ends_.resize(level); } -void CircuitPropagator::FillConflictFromCircuitAt(int start) { - std::vector* conflict = trail_->MutableConflict(); - conflict->clear(); - int node = start; - do { - CHECK_NE(node, -1); +void CircuitPropagator::FillReasonForPath(int start_node, + std::vector* reason) const { + CHECK_NE(start_node, -1); + reason->clear(); + int node = start_node; + while (next_[node] != -1) { if (next_literal_[node] != kNoLiteralIndex) { - conflict->push_back(Literal(next_literal_[node]).Negated()); + reason->push_back(Literal(next_literal_[node]).Negated()); } node = next_[node]; - } while (node != start); + if (node == start_node) break; + } } -bool CircuitPropagator::Propagate() { return true; } - // If multiple_subcircuit_through_zero is true, we never fill next_[0] and // prev_[0]. void CircuitPropagator::AddArc(int tail, int head, LiteralIndex literal_index) { @@ -122,8 +135,14 @@ bool CircuitPropagator::IncrementalPropagate( for (const int w : watch_indices) { const Literal literal = watch_index_to_literal_[w]; for (const Arc arc : watch_index_to_arcs_[w]) { - // Get rid of the trivial conflicts: - // - At most one incoming and one ougtoing arc for each nodes. + // Special case for self-arc. + if (arc.tail == arc.head) { + must_be_in_cycle_[rev_must_be_in_cycle_size_++] = arc.tail; + continue; + } + + // Get rid of the trivial conflicts: At most one incoming and one outgoing + // arc for each nodes. if (next_[arc.tail] != -1) { std::vector* conflict = trail_->MutableConflict(); if (next_literal_[arc.tail] != kNoLiteralIndex) { @@ -148,91 +167,135 @@ bool CircuitPropagator::IncrementalPropagate( // Add the arc. AddArc(arc.tail, arc.head, literal.Index()); added_arcs_.push_back(arc); + } + } + return Propagate(); +} - // Circuit? - in_circuit_.assign(num_nodes_, false); - in_circuit_[arc.tail] = true; - int size = 1; - int node = arc.head; - while (node != arc.tail && node != -1) { - in_circuit_[node] = true; - node = next_[node]; - size++; +// This function assumes that next_, prev_, next_literal_ and must_be_in_cycle_ +// are all up to date. +bool CircuitPropagator::Propagate() { + processed_.assign(num_nodes_, false); + for (int n = 0; n < num_nodes_; ++n) { + if (processed_[n]) continue; + if (next_[n] == n) continue; + if (next_[n] == -1 && prev_[n] == -1) continue; + + // TODO(user): both this and the loop on must_be_in_cycle_ might take some + // time on large graph. Optimize if this become an issue. + in_current_path_.assign(num_nodes_, false); + + // Find the start and end of the path containing node n. If this is a + // circuit, we will have start_node == end_node. + int start_node = n; + int end_node = n; + in_current_path_[n] = true; + processed_[n] = true; + while (next_[end_node] != -1) { + end_node = next_[end_node]; + in_current_path_[end_node] = true; + processed_[end_node] = true; + if (end_node == n) break; + } + while (prev_[start_node] != -1) { + start_node = prev_[start_node]; + in_current_path_[start_node] = true; + processed_[start_node] = true; + if (start_node == n) break; + } + + // Check if we miss any node that must be in the circuit. Note that the ones + // for which graph_[i][i] is kFalseLiteralIndex are first. This is good as + // it will produce shorter reason. Otherwise we prefer the first that was + // assigned in the trail. + bool miss_some_nodes = false; + LiteralIndex extra_reason = kFalseLiteralIndex; + for (int i = 0; i < rev_must_be_in_cycle_size_; ++i) { + const int node = must_be_in_cycle_[i]; + if (!in_current_path_[node]) { + miss_some_nodes = true; + extra_reason = graph_[node][node]; + break; } + } - if (options_.multiple_subcircuit_through_zero) { - // If we reached zero, this is a valid path provided that we can - // reach the beginning of the path from zero. Note that we only check - // the basic case that the beginning of the path must have "open" arcs - // thanks to ExactlyOnePerRowAndPerColumn(). - if (node == 0 || node != arc.tail) continue; - - // We have a cycle not touching zero, this is a conflict. - FillConflictFromCircuitAt(arc.tail); + if (miss_some_nodes) { + // A circuit that miss a mandatory node is a conflict. + if (start_node == end_node) { + FillReasonForPath(start_node, trail_->MutableConflict()); + if (extra_reason != kFalseLiteralIndex) { + trail_->MutableConflict()->push_back(Literal(extra_reason)); + } return false; } - if (node != arc.tail) continue; + // We have an unclosed path. Propagate the fact that it cannot + // be closed into a cycle, i.e. not(end_node -> start_node). + if (start_node != end_node) { + const LiteralIndex literal_index = graph_[end_node][start_node]; + if (LiteralIndexIsFalse(literal_index)) continue; - // We have one circuit. - if (size == num_nodes_) return true; - if (size == 1) continue; // self-arc. + // We would have detected a cycle otherwise. + // TODO(user): This may actually fail in corner cases where the same + // literal is used for more than one arc and we propagate it here. Fix + // if this happen. + CHECK(!LiteralIndexIsTrue(literal_index)); - // HACK: we can reuse the conflict vector even though we don't have a - // conflict. - FillConflictFromCircuitAt(arc.tail); - BooleanVariable variable_with_same_reason = kNoBooleanVariable; - - // We can propagate all the other nodes to point to themselves. - // If this is not already the case, we have a conflict. - for (int node = 0; node < num_nodes_; ++node) { - if (in_circuit_[node] || next_[node] == node) continue; - if (next_[node] != -1) { - std::vector* conflict = trail_->MutableConflict(); - if (next_literal_[node] != kNoLiteralIndex) { - conflict->push_back(Literal(next_literal_[node]).Negated()); - } - return false; - } else if (self_arcs_[node] == kFalseLiteralIndex) { - return false; - } else { - DCHECK_NE(self_arcs_[node], kTrueLiteralIndex); - const Literal literal(self_arcs_[node]); - - // We may not have processed this literal yet. - if (trail_->Assignment().LiteralIsTrue(literal)) continue; - if (trail_->Assignment().LiteralIsFalse(literal)) { - std::vector* conflict = trail_->MutableConflict(); - conflict->push_back(literal); - return false; - } - - // Propagate. - if (variable_with_same_reason == kNoBooleanVariable) { - variable_with_same_reason = literal.Variable(); - const int index = trail_->Index(); - trail_->Enqueue(literal, AssignmentType::kCachedReason); - *trail_->GetVectorToStoreReason(index) = *trail_->MutableConflict(); - trail_->NotifyThatReasonIsCached(literal.Variable()); - } else { - trail_->EnqueueWithSameReasonAs(literal, variable_with_same_reason); - } + // Propagate. + std::vector* reason = trail_->GetVectorToStoreReason(); + FillReasonForPath(start_node, reason); + if (extra_reason != kFalseLiteralIndex) { + reason->push_back(Literal(extra_reason)); } + trail_->EnqueueWithStoredReason(Literal(literal_index).Negated()); + } + } + + // If we have a cycle, we can propagate all the other nodes to point to + // themselves. Otherwise there is nothing else to do. + if (start_node != end_node) continue; + if (options_.multiple_subcircuit_through_zero) continue; + BooleanVariable variable_with_same_reason = kNoBooleanVariable; + for (int node = 0; node < num_nodes_; ++node) { + if (in_current_path_[node]) continue; + if (LiteralIndexIsTrue(graph_[node][node])) continue; + + // We should have detected that above (miss_some_nodes == true). But we + // still need this for corner cases where the same literal is used for + // many arcs, and we just propagated it here. + if (LiteralIndexIsFalse(graph_[node][node])) { + CHECK_NE(graph_[node][node], kFalseLiteralIndex); + FillReasonForPath(start_node, trail_->MutableConflict()); + trail_->MutableConflict()->push_back(Literal(graph_[node][node])); + return false; + } + + // This shouldn't happen because ExactlyOnePerRowAndPerColumn() should + // have executed first and propagated graph_[node][node] to false. We + // still keep the code for safety though. + if (next_[node] != -1) { + FillReasonForPath(start_node, trail_->MutableConflict()); + if (next_literal_[node] != kNoLiteralIndex) { + trail_->MutableConflict()->push_back( + Literal(next_literal_[node]).Negated()); + } + return false; + } + + // Propagate. + const Literal literal(graph_[node][node]); + if (variable_with_same_reason == kNoBooleanVariable) { + variable_with_same_reason = literal.Variable(); + FillReasonForPath(start_node, trail_->GetVectorToStoreReason()); + trail_->EnqueueWithStoredReason(literal); + } else { + trail_->EnqueueWithSameReasonAs(literal, variable_with_same_reason); } } } return true; } -void CircuitPropagator::RegisterWith(GenericLiteralWatcher* watcher) { - const int id = watcher->Register(this); - for (int w = 0; w < watch_index_to_literal_.size(); ++w) { - watcher->WatchLiteral(watch_index_to_literal_[w], id, w); - } - watcher->RegisterReversibleClass(id, this); - watcher->RegisterReversibleInt(id, &propagation_trail_index_); -} - std::function ExactlyOnePerRowAndPerColumn( const std::vector>& square_matrix, bool ignore_row_and_col_zero) { diff --git a/ortools/sat/circuit.h b/ortools/sat/circuit.h index f8886a66da..7e881006cc 100644 --- a/ortools/sat/circuit.h +++ b/ortools/sat/circuit.h @@ -51,7 +51,7 @@ class CircuitPropagator : PropagatorInterface, ReversibleInterface { // being present when the given literal is true. The special values // kTrueLiteralIndex and kFalseLiteralIndex can be used for arcs that are // either always there or never there. - CircuitPropagator(const std::vector>& graph, + CircuitPropagator(std::vector> graph, Options options, Trail* trail); void SetLevel(int level) final; @@ -60,29 +60,42 @@ class CircuitPropagator : PropagatorInterface, ReversibleInterface { void RegisterWith(GenericLiteralWatcher* watcher); private: + // Helper to deal with kTrueLiteralIndex and kFalseLiteralIndex. + bool LiteralIndexIsTrue(LiteralIndex index) { + if (index == kTrueLiteralIndex) return true; + if (index == kFalseLiteralIndex) return false; + return assignment_.LiteralIsTrue(Literal(index)); + } + bool LiteralIndexIsFalse(LiteralIndex index) { + if (index == kTrueLiteralIndex) return false; + if (index == kFalseLiteralIndex) return true; + return assignment_.LiteralIsFalse(Literal(index)); + } + // Updates the structures when the given arc is added to the paths. void AddArc(int tail, int head, LiteralIndex literal_index); - // Clears and fills trail_->MutableConflict() with the literals of the arcs - // that form a cycle containing the given node. - void FillConflictFromCircuitAt(int start); + // Clears and fills reason with the literals of the arcs that form a path from + // the given node. The path can be a cycle, but in this case it must end at + // start (not like a rho shape). + void FillReasonForPath(int start_node, std::vector* reason) const; const int num_nodes_; + const std::vector> graph_; const Options options_; Trail* trail_; + const VariablesAssignment& assignment_; - // Internal representation of the graph given at construction. Const. + // Data used to interpret the watch indices passed to IncrementalPropagate(). struct Arc { int tail; int head; }; - std::vector self_arcs_; - std::vector watch_index_to_literal_; std::vector> watch_index_to_arcs_; // Index in trail_ up to which we propagated all the assigned Literals. - int propagation_trail_index_; + int propagation_trail_index_ = 0; // Current partial chains of arc that are present. std::vector next_; // -1 if not assigned yet. @@ -94,8 +107,14 @@ class CircuitPropagator : PropagatorInterface, ReversibleInterface { std::vector level_ends_; std::vector added_arcs_; - // Temporary vector. - std::vector in_circuit_; + // Reversible list of node that must be in a cycle. A node must be in a cycle + // iff graph_[node][node] is false. This graph entry can be used as a reason. + int rev_must_be_in_cycle_size_ = 0; + std::vector must_be_in_cycle_; + + // Temporary vectors. + std::vector processed_; + std::vector in_current_path_; DISALLOW_COPY_AND_ASSIGN(CircuitPropagator); }; diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index a31fd0bfa0..f8c00e097c 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -140,6 +140,23 @@ message CumulativeConstraintProto { repeated int32 demands = 3; // Same size as intervals. } +// Maintain a reservoir level within bounds. The water level starts at 0, and at +// any time >= 0, it must be within min_level, and max_level. Furthermore, this +// constraints expect all times variables to be >= 0. +// If the variable times[i] is assigned a value t, then the current level +// changes by demands[i] (which is constant) at the time t. +// +// Note that level min can be > 0, or level max can be < 0. It just forces +// some demands to be executed at time 0 to make sure that we are within those +// bounds with the executed demands. Therefore, at any time t >= 0: +// sum(demands[i] if times[i] <= t) in [min_level, max_level] +message ReservoirConstraintProto { + int64 min_level = 1; + int64 max_level = 2; + repeated int32 times = 3; + repeated int64 demands = 4; +} + // The "next" variable of a node i represents its successor in a graph. Any // value that fall outside [0, n = next_variables.size()) or is a self-loop // (next[i] == i) takes the special meaning of no-successor. @@ -250,6 +267,7 @@ message ConstraintProto { TableConstraintProto table = 16; AutomataConstraintProto automata = 17; InverseConstraintProto inverse = 18; + ReservoirConstraintProto reservoir = 24; // Constraints on intervals. // diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 5ce9c9efba..8c07d2ca37 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -593,6 +593,33 @@ class ConstraintChecker { return true; } + bool ReservoirConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + const int num_variables = ct.reservoir().times_size(); + const int64 min_level = ct.reservoir().min_level(); + const int64 max_level = ct.reservoir().min_level(); + std::map deltas; + deltas[0] = 0; + for (int i = 0; i < num_variables; i++) { + const int t = Value(ct.reservoir().times(i)); + if (t < 0) { + VLOG(1) << "reservoir times(" << i << ") is negative."; + return false; + } + deltas[t] += ct.reservoir().demands(i); + } + int64 current_level = 0; + for (const auto& delta : deltas) { + current_level += delta.second; + if (current_level < min_level || current_level > max_level) { + VLOG(1) << "Reservoir level " << current_level + << " is out of bounds at time" << delta.first; + return false; + } + } + return true; + } + private: std::vector variable_values_; }; @@ -607,7 +634,11 @@ bool SolutionIsFeasible(const CpModelProto& model, } // Check that all values fall in the variable domains. + int num_optional_vars = 0; for (int i = 0; i < model.variables_size(); ++i) { + if (!model.variables(i).enforcement_literal().empty()) { + ++num_optional_vars; + } if (!DomainInProtoContains(model.variables(i), variable_values[i])) { VLOG(1) << "Variable #" << i << " has value " << variable_values[i] << " which do not fall in its domain: " @@ -623,7 +654,11 @@ bool SolutionIsFeasible(const CpModelProto& model, const ConstraintProto& ct = model.constraints(c); if (!checker.ConstraintIsEnforced(ct)) continue; - if (checker.ConstraintHasNonEnforcedVariables(model, ct)) continue; + if (num_optional_vars > 0) { + // This function can be slow because it uses reflection. So we only + // call it if there is any optional variables. + if (checker.ConstraintHasNonEnforcedVariables(model, ct)) continue; + } bool is_feasible = true; const ConstraintProto::ConstraintCase type = ct.constraint_case(); @@ -685,6 +720,9 @@ bool SolutionIsFeasible(const CpModelProto& model, case ConstraintProto::ConstraintCase::kInverse: is_feasible = checker.InverseConstraintIsFeasible(model, ct); break; + case ConstraintProto::ConstraintCase::kReservoir: + is_feasible = checker.ReservoirConstraintIsFeasible(model, ct); + break; case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: // Empty constraint is always feasible. break; diff --git a/ortools/sat/cp_model_checker.h b/ortools/sat/cp_model_checker.h index 1928161755..74ebf73c84 100644 --- a/ortools/sat/cp_model_checker.h +++ b/ortools/sat/cp_model_checker.h @@ -31,7 +31,7 @@ namespace sat { std::string ValidateCpModel(const CpModelProto& model); // Verifies that the given variable assignment is a feasible solution of the -// given model. The values vector should be in one to one correspondance with +// given model. The values vector should be in one to one correspondence with // the model.variables() list of variables. bool SolutionIsFeasible(const CpModelProto& model, const std::vector& variable_values); diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc new file mode 100644 index 0000000000..87dce7a69e --- /dev/null +++ b/ortools/sat/cp_model_expand.cc @@ -0,0 +1,255 @@ +// Copyright 2010-2017 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_expand.h" + +#include +#include "ortools/base/map_util.h" +#include "ortools/base/hash.h" +#include "ortools/sat/cp_model.pb.h" + +namespace operations_research { +namespace sat { +namespace { + +int Not(int a) { return -a - 1; } + +void AddImplication(int a, int b, CpModelProto* model_proto) { + ConstraintProto* const ct = model_proto->add_constraints(); + ct->add_enforcement_literal(a); + ct->mutable_bool_or()->add_literals(b); +} + +void AddImplyInDomain(int b, int x, int64 lb, int64 ub, + CpModelProto* model_proto) { + ConstraintProto* const imply = model_proto->add_constraints(); + imply->add_enforcement_literal(b); + imply->mutable_linear()->add_vars(x); + imply->mutable_linear()->add_coeffs(1); + imply->mutable_linear()->add_domain(lb); + imply->mutable_linear()->add_domain(ub); +} + +int AddBoolVar(CpModelProto* model_proto) { + IntegerVariableProto* const var = model_proto->add_variables(); + var->add_domain(0); + var->add_domain(1); + return model_proto->variables_size() - 1; +} + +bool IsOptional(const IntegerVariableProto& v) { + return v.enforcement_literal_size() > 0; +} + +// b <=> (x <= 0). +void AddVarEqualLessOrEqualZero(int b, int x, CpModelProto* model_proto) { + if (model_proto->variables(x).enforcement_literal_size() == 0) { + AddImplyInDomain(b, x, kint64min, 0, model_proto); + AddImplyInDomain(Not(b), x, 1, kint64max, model_proto); + } else { + const int opt_x = model_proto->variables(x).enforcement_literal(0); + AddImplyInDomain(b, x, kint64min, 0, model_proto); + AddImplication(b, opt_x, model_proto); + const int g = AddBoolVar(model_proto); + AddImplyInDomain(g, x, 1, kint64max, model_proto); + AddImplication(g, opt_x, model_proto); + ConstraintProto* const bool_or = model_proto->add_constraints(); + bool_or->mutable_bool_or()->add_literals(b); + bool_or->mutable_bool_or()->add_literals(g); + bool_or->mutable_bool_or()->add_literals(Not(opt_x)); + } +} + +// x_lesseq_y <=> (x <= y) && x enforced && y enforced. +void AddVarEqualPrecedence(int x_lesseq_y, int x, int y, + CpModelProto* model_proto) { + const IntegerVariableProto& var_x = model_proto->variables(x); + const IntegerVariableProto& var_y = model_proto->variables(y); + const bool x_is_optional = IsOptional(var_x); + const bool y_is_optional = IsOptional(var_y); + const int x_lit = x_is_optional ? var_x.enforcement_literal(0) : -1; + const int y_lit = y_is_optional ? var_y.enforcement_literal(0) : -1; + + // x_lesseq_y => (x <= y) && x => enforced && y enforced. + ConstraintProto* const lesseq = model_proto->add_constraints(); + lesseq->add_enforcement_literal(x_lesseq_y); + lesseq->mutable_linear()->add_vars(x); + lesseq->mutable_linear()->add_vars(y); + lesseq->mutable_linear()->add_coeffs(-1); + lesseq->mutable_linear()->add_coeffs(1); + lesseq->mutable_linear()->add_domain(0); + lesseq->mutable_linear()->add_domain(kint64max); + + if (IsOptional(var_x)) { + AddImplication(x_lesseq_y, x_lit, model_proto); + } + if (IsOptional(var_y)) { + AddImplication(x_lesseq_y, y_lit, model_proto); + } + + // x_greater_y => (x > y) && x enforced && y enforced. + const int x_greater_y = x_is_optional || y_is_optional + ? AddBoolVar(model_proto) + : Not(x_lesseq_y); + + ConstraintProto* const greater = model_proto->add_constraints(); + greater->add_enforcement_literal(x_greater_y); + greater->mutable_linear()->add_vars(x); + greater->mutable_linear()->add_vars(y); + greater->mutable_linear()->add_coeffs(-1); + greater->mutable_linear()->add_coeffs(1); + greater->mutable_linear()->add_domain(kint64min); + greater->mutable_linear()->add_domain(-1); + + if (IsOptional(var_x)) { + AddImplication(x_greater_y, x_lit, model_proto); + } + if (IsOptional(var_y)) { + AddImplication(x_greater_y, y_lit, model_proto); + } + + // Consistency between x_lesseq_y, x_greater_y, x_lit, y_lit. + ConstraintProto* const bool_or = model_proto->add_constraints(); + bool_or->mutable_bool_or()->add_literals(x_lesseq_y); + bool_or->mutable_bool_or()->add_literals(x_greater_y); + if (x_is_optional) { + AddImplication(Not(x_lit), Not(x_lesseq_y), model_proto); + AddImplication(Not(x_lit), Not(x_greater_y), model_proto); + bool_or->mutable_bool_or()->add_literals(Not(x_lit)); + } + if (y_is_optional) { + AddImplication(Not(y_lit), Not(x_lesseq_y), model_proto); + AddImplication(Not(y_lit), Not(x_greater_y), model_proto); + bool_or->mutable_bool_or()->add_literals(Not(y_lit)); + } + + // TODO(user): Do we add x_lesseq_y => Not(x_greater_y) + // and x_greater_y => Not(x_lesseq_y)? +} + +struct RewriteContext { + CpModelProto expanded_proto; + std::unordered_map, int> precedence_cache; + std::unordered_map statistics; +}; + +void ExpandReservoir(ConstraintProto* ct, RewriteContext* context) { + const ReservoirConstraintProto& reservoir = ct->reservoir(); + const int num_variables = reservoir.times_size(); + CpModelProto& expanded = context->expanded_proto; + + // Creates boolean variables equivalent to (start[i] <= start[j]) i != j, + for (int i = 0; i < num_variables - 1; ++i) { + const int ti = reservoir.times(i); + for (int j = i + 1; j < num_variables; ++j) { + const int tj = reservoir.times(j); + const std::pair p = std::make_pair(ti, tj); + const std::pair rev_p = std::make_pair(tj, ti); + if (ContainsKey(context->precedence_cache, p)) continue; + + const int i_lesseq_j = AddBoolVar(&expanded); + context->precedence_cache[p] = i_lesseq_j; + const int j_lesseq_i = AddBoolVar(&expanded); + context->precedence_cache[rev_p] = j_lesseq_i; + AddVarEqualPrecedence(i_lesseq_j, ti, tj, &expanded); + AddVarEqualPrecedence(j_lesseq_i, tj, ti, &expanded); + // Consistency. + ConstraintProto* const bool_or = expanded.add_constraints(); + bool_or->mutable_bool_or()->add_literals(i_lesseq_j); + bool_or->mutable_bool_or()->add_literals(j_lesseq_i); + const IntegerVariableProto& var_i = expanded.variables(ti); + if (IsOptional(var_i)) { + bool_or->mutable_bool_or()->add_literals( + Not(var_i.enforcement_literal(0))); + } + const IntegerVariableProto& var_j = expanded.variables(tj); + if (IsOptional(var_j)) { + bool_or->mutable_bool_or()->add_literals( + Not(var_j.enforcement_literal(0))); + } + } + } + + // Constrains the reservoir level to be consistent at time 0. + // We need to do it only if 0 is not in [min_level..max_level]. + // Otherwise, the regular propagation will already check it. + if (reservoir.min_level() > 0 || reservoir.max_level() < 0) { + ConstraintProto* const initial = expanded.add_constraints(); + for (int i = 0; i < num_variables; ++i) { + const int ti = reservoir.times(i); + const int b = AddBoolVar(&expanded); + initial->mutable_linear()->add_vars(b); + AddVarEqualLessOrEqualZero(b, ti, &expanded); + initial->mutable_linear()->add_coeffs(reservoir.demands(i)); + } + initial->mutable_linear()->add_domain(reservoir.min_level()); + initial->mutable_linear()->add_domain(reservoir.max_level()); + } + + // Constrains the running level to be consistent at all times. + for (int i = 0; i < num_variables; ++i) { + const int ti = reservoir.times(i); + // Accumulates demands of all predecessors. + ConstraintProto* const level = expanded.add_constraints(); + for (int j = 0; j < num_variables; ++j) { + if (i == j) continue; + const int tj = reservoir.times(j); + const std::pair p = std::make_pair(tj, ti); + level->mutable_linear()->add_vars( + FindOrDieNoPrint(context->precedence_cache, p)); + level->mutable_linear()->add_coeffs(reservoir.demands(j)); + } + // Accounts for own demand. + const int64 demand_i = reservoir.demands(i); + level->mutable_linear()->add_domain(reservoir.min_level() - demand_i); + level->mutable_linear()->add_domain(reservoir.max_level() - demand_i); + const IntegerVariableProto& var_i = expanded.variables(ti); + if (IsOptional(var_i)) { + level->add_enforcement_literal(var_i.enforcement_literal(0)); + } + } + + // Constrains all times to be >= 0. + for (int i = 0; i < num_variables; ++i) { + const int ti = reservoir.times(i); + ConstraintProto* const positive = expanded.add_constraints(); + positive->mutable_linear()->add_vars(ti); + positive->mutable_linear()->add_coeffs(1); + positive->mutable_linear()->add_domain(0); + positive->mutable_linear()->add_domain(kint64max); + } + + ct->Clear(); + context->statistics["kReservoir"]++; +} + +} // namespace + +CpModelProto ExpandCpModel(const CpModelProto& initial_model) { + RewriteContext context; + context.expanded_proto = initial_model; + for (int i = 0; i < initial_model.constraints_size(); ++i) { + ConstraintProto* const ct = context.expanded_proto.mutable_constraints(i); + switch (ct->constraint_case()) { + case ConstraintProto::ConstraintCase::kReservoir: + ExpandReservoir(ct, &context); + break; + default: + break; + } + } + + return context.expanded_proto; +} +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_expand.h b/ortools/sat/cp_model_expand.h new file mode 100644 index 0000000000..42a9128842 --- /dev/null +++ b/ortools/sat/cp_model_expand.h @@ -0,0 +1,31 @@ +// Copyright 2010-2017 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_CP_MODEL_EXPAND_H_ +#define OR_TOOLS_SAT_CP_MODEL_EXPAND_H_ + +#include "ortools/sat/cp_model.pb.h" + +namespace operations_research { +namespace sat { + +// Expands a given CpModelProto by rewriting complex constraints into +// simpler constraints. +// This is different from PresolveCpModel() as there are no reduction or +// simplification of the model. Furthermore, this expansion is mandatory. +CpModelProto ExpandCpModel(const CpModelProto& initial_model); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CP_MODEL_EXPAND_H_ diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index ee61bfbeb2..bba5e2b555 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -23,12 +23,12 @@ #include #include #include -#include #include #include #include "ortools/base/integral_types.h" #include "ortools/base/logging.h" +#include #include "ortools/base/map_util.h" #include "ortools/base/stl_util.h" #include "ortools/base/hash.h" @@ -1673,8 +1673,7 @@ void PresolveCpModel(const CpModelProto& initial_model, // not enter an infinite loop, we call each (var, constraint) pair at most // once. for (int v = 0; v < context.var_to_constraints.size(); ++v) { - const std::unordered_set& constraints = - context.var_to_constraints[v]; + const auto& constraints = context.var_to_constraints[v]; if (constraints.size() != 1) continue; const int c = *constraints.begin(); if (c < 0) continue; diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index 3adb5e2e0d..73a6975d3a 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -23,20 +23,21 @@ namespace sat { // Presolves the given CpModelProto into presolved_model. // -// This also creates a mapping model that encode the correspondance between the +// This also creates a mapping model that encode the correspondence between the // two problems. This works as follow: -// - The first variables of mapping_model are in one to one correspondance with +// - The first variables of mapping_model are in one to one correspondence with // the variables of the initial model. -// - The presolved_model variables are in one to one correspondance with the +// - The presolved_model variables are in one to one correspondence with the // variable at the indices given by postsolve_mapping in the mapping model. // - Fixing one of the two sets of variables and solving the model will assign // the other set to a feasible solution of the other problem. Moreover, the -// objective value of these solution will be the same. Note that solving such -// problem will take little time in practice because the propagation will +// objective value of these solutions will be the same. Note that solving such +// problems will take little time in practice because the propagation will // basically do all the work. // -// Note(user): an optimization model can be transformed in a decision one if for -// instance the objective is fixed, or independent on the rest of the problem. +// Note(user): an optimization model can be transformed into a decision problem, +// if for instance the objective is fixed, or independent from the rest of the +// problem. // // TODO(user): Identify disconnected components and returns a vector of // presolved model? If we go this route, it may be nicer to store the indices diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index aa01ada7f3..b2d138eb25 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -40,6 +40,7 @@ #include "ortools/sat/circuit.h" #include "ortools/sat/cp_constraints.h" #include "ortools/sat/cp_model_checker.h" +#include "ortools/sat/cp_model_expand.h" #include "ortools/sat/cp_model_presolve.h" #include "ortools/sat/cp_model_search.h" #include "ortools/sat/cp_model_utils.h" @@ -1552,7 +1553,7 @@ void FillSolutionInResponse(const CpModelProto& model_proto, response->clear_solution_lower_bounds(); response->clear_solution_upper_bounds(); if (!solution.empty()) { - CHECK(SolutionIsFeasible(model_proto, solution)); + DCHECK(SolutionIsFeasible(model_proto, solution)); for (const int64 value : solution) response->add_solution(value); } else { // Not all variables are fixed. @@ -2080,13 +2081,12 @@ IntegerVariable AddLPConstraints(const CpModelProto& model_proto, // constraints have been added. for (auto* lp_constraint : lp_constraints) { lp_constraint->RegisterWith(m->GetOrCreate()); + VLOG(1) << "LP constraint: " << lp_constraint->DimensionString() << "."; } VLOG(1) << top_level_cp_terms.size() << " terms in the main objective linear equation (" << num_components_containing_objective << " from LP constraints)."; - VLOG_IF(1, !lp_constraints.empty()) - << "Added " << lp_constraints.size() << " LP constraints."; return main_objective_var; } @@ -2418,7 +2418,7 @@ CpSolverResponse SolveCpModelInternal( response.set_wall_time(wall_timer.Get()); response.set_user_time(user_timer.Get()); response.set_deterministic_time( - model->Get()->deterministic_time()); + model->Get()->GetElapsedDeterministicTime()); return response; } @@ -2506,24 +2506,25 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { file::Defaults())); } + const CpModelProto expanded_proto = ExpandCpModel(model_proto); + const auto& observers = model->GetOrCreate()->observers; const SatParameters& parameters = model->GetOrCreate()->parameters(); if (!parameters.cp_model_presolve()) { - return SolveCpModelInternal( - model_proto, true, - [&](const CpSolverResponse& response) { - for (const auto& observer : observers) { - observer(response); - } - }, - model); + return SolveCpModelInternal(expanded_proto, true, + [&](const CpSolverResponse& response) { + for (const auto& observer : observers) { + observer(response); + } + }, + model); } CpModelProto presolved_proto; CpModelProto mapping_proto; std::vector postsolve_mapping; - PresolveCpModel(model_proto, &presolved_proto, &mapping_proto, + PresolveCpModel(expanded_proto, &presolved_proto, &mapping_proto, &postsolve_mapping); VLOG(1) << CpModelStats(presolved_proto); diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc index 06262fe39b..1e39292aaf 100644 --- a/ortools/sat/cp_model_utils.cc +++ b/ortools/sat/cp_model_utils.cc @@ -20,7 +20,7 @@ namespace { template void AddIndices(const IntList& indices, std::unordered_set* output) { - for (const int index : indices) output->insert(index); + output->insert(indices.begin(), indices.end()); } } // namespace @@ -78,6 +78,9 @@ void AddReferencesUsedByConstraint(const ConstraintProto& ct, AddIndices(ct.inverse().f_direct(), &output->variables); AddIndices(ct.inverse().f_inverse(), &output->variables); break; + case ConstraintProto::ConstraintCase::kReservoir: + AddIndices(ct.reservoir().times(), &output->variables); + break; case ConstraintProto::ConstraintCase::kTable: AddIndices(ct.table().vars(), &output->variables); break; @@ -155,6 +158,8 @@ void ApplyToAllLiteralIndices(const std::function& f, break; case ConstraintProto::ConstraintCase::kInverse: break; + case ConstraintProto::ConstraintCase::kReservoir: + break; case ConstraintProto::ConstraintCase::kTable: break; case ConstraintProto::ConstraintCase::kAutomata: @@ -221,6 +226,9 @@ void ApplyToAllVariableIndices(const std::function& f, APPLY_TO_REPEATED_FIELD(inverse, f_direct); APPLY_TO_REPEATED_FIELD(inverse, f_inverse); break; + case ConstraintProto::ConstraintCase::kReservoir: + APPLY_TO_REPEATED_FIELD(reservoir, times); + break; case ConstraintProto::ConstraintCase::kTable: APPLY_TO_REPEATED_FIELD(table, vars); break; @@ -276,6 +284,8 @@ void ApplyToAllIntervalIndices(const std::function& f, break; case ConstraintProto::ConstraintCase::kInverse: break; + case ConstraintProto::ConstraintCase::kReservoir: + break; case ConstraintProto::ConstraintCase::kTable: break; case ConstraintProto::ConstraintCase::kAutomata: @@ -330,6 +340,8 @@ std::string ConstraintCaseName(ConstraintProto::ConstraintCase constraint_case) return "kRoutes"; case ConstraintProto::ConstraintCase::kInverse: return "kInverse"; + case ConstraintProto::ConstraintCase::kReservoir: + return "kReservoir"; case ConstraintProto::ConstraintCase::kTable: return "kTable"; case ConstraintProto::ConstraintCase::kAutomata: diff --git a/ortools/sat/cp_model_utils.h b/ortools/sat/cp_model_utils.h index 13b776f35b..5ae2a5e2be 100644 --- a/ortools/sat/cp_model_utils.h +++ b/ortools/sat/cp_model_utils.h @@ -17,11 +17,11 @@ #include #include #include -#include #include #include "ortools/base/integral_types.h" #include "ortools/base/logging.h" +#include #include "ortools/sat/cp_model.pb.h" #include "ortools/util/sorted_interval_list.h" diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index dcb146394d..3a69b987eb 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -116,6 +116,8 @@ class LinearProgrammingConstraint : public PropagatorInterface { bool IncrementalPropagate(const std::vector& watch_indices) override; void RegisterWith(GenericLiteralWatcher* watcher); + std::string DimensionString() const { return lp_data_.GetDimensionString(); } + private: // Generates a set of IntegerLiterals explaining why the best solution can not // be improved using reduced costs. This is used to generate explanations for diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index d4e5d820f4..c90a0d83d9 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -77,7 +77,7 @@ def CapSub(x, y): if x == y: if x == INT_MAX or x == INT_MIN: raise OverflowError( - 'Integer NaN: substracting INT_MAX or INT_MIN to itself') + 'Integer NaN: subtracting INT_MAX or INT_MIN to itself') return 0 if x == INT_MAX or x == INT_MIN: return x @@ -628,7 +628,7 @@ class CpModel(object): model_ct.automata.transition_tail.append(t[0]) def AddInverse(self, variables, inverse_variables): - """Adds AddInverse(variables, inverse_variables).""" + """Adds Inverse(variables, inverse_variables).""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] model_ct.inverse.f_direct.extend( @@ -637,6 +637,16 @@ class CpModel(object): [self.GetOrMakeIndex(x) for x in inverse_variables]) return ct + def AddReservoirConstraint(self, times, demands, min_level, max_level): + """Adds a Reservoir(times, demands, min_level, max_level).""" + ct = Constraint(self.__model.constraints) + model_ct = self.__model.constraints[ct.Index()] + model_ct.reservoir.times.extend([self.GetOrMakeIndex(x) for x in times]) + model_ct.reservoir.extend(demands) + model_ct.reservoir.min_level = min_level + model_ct.reservoir.max_level = max_level + return ct + def AddMapDomain(self, var, bool_var_array, offset=0): """Creates var == i + offset <=> bool_var_array[i] == true for all i.""" @@ -764,22 +774,22 @@ class CpModel(object): [self.GetIntervalIndex(x) for x in interval_vars]) return ct - def AddNoOverlap2D(self, x_transition_triples, y_transition_triples): - """Adds NoOverlap2D(x_transition_triples, y_transition_triples).""" + def AddNoOverlap2D(self, x_intervals, y_intervals): + """Adds NoOverlap2D(x_tintervals, y_intervals).""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] model_ct.no_overlap_2d.x_intervals.extend( - [self.GetIntervalIndex(x) for x in x_transition_triples]) + [self.GetIntervalIndex(x) for x in x_intervals]) model_ct.no_overlap_2d.y_intervals.extend( - [self.GetIntervalIndex(x) for x in y_transition_triples]) + [self.GetIntervalIndex(x) for x in y_intervals]) return ct - def AddCumulative(self, transition_triples, demands, capacity): - """Adds Cumulative(transition_triples, demands, capacity).""" + def AddCumulative(self, intervals, demands, capacity): + """Adds Cumulative(intervals, demands, capacity).""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] model_ct.cumulative.intervals.extend( - [self.GetIntervalIndex(x) for x in transition_triples]) + [self.GetIntervalIndex(x) for x in intervals]) model_ct.cumulative.demands.extend( [self.GetOrMakeIndex(x) for x in demands]) model_ct.cumulative.capacity = self.GetOrMakeIndex(capacity) diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index 48a643cab0..b4cadae1c5 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -268,6 +268,15 @@ class Trail { Enqueue(true_literal, AssignmentType::kSameReasonAs); } + // Stores first the reason in GetVectorToStoreReason() before calling this. + void EnqueueWithStoredReason(Literal true_literal) { + Enqueue(true_literal, AssignmentType::kCachedReason); + const BooleanVariable var = true_literal.Variable(); + reasons_[var] = reasons_repository_[info_[var].trail_index]; + old_type_[var] = info_[var].type; + info_[var].type = AssignmentType::kCachedReason; + } + // Returns the reason why this variable was assigned. gtl::Span Reason(BooleanVariable var) const; @@ -290,14 +299,10 @@ class Trail { return &reasons_repository_[trail_index]; } - // After this is called, Reason(var) will return the content of the - // GetVectorToStoreReason(trail_index_of_var) and will not call the virtual - // Reason() function of the associated propagator. - void NotifyThatReasonIsCached(BooleanVariable var) const { - DCHECK(assignment_.VariableIsAssigned(var)); - reasons_[var] = reasons_repository_[info_[var].trail_index]; - old_type_[var] = info_[var].type; - info_[var].type = AssignmentType::kCachedReason; + // TODO(user): GetVectorToStoreReason() is always called with Index(), only + // keep this function instead. + std::vector* GetVectorToStoreReason() const { + return GetVectorToStoreReason(Index()); } // Dequeues the last assigned literal and returns it.