diff --git a/examples/cpp/costas_array_sat.cc b/examples/cpp/costas_array_sat.cc index a805d4f7e9..9523dbabd3 100644 --- a/examples/cpp/costas_array_sat.cc +++ b/examples/cpp/costas_array_sat.cc @@ -110,8 +110,7 @@ void CostasHard(const int dim) { for (int j = 0; j < dim - i; ++j) { IntVar diff = cp_model.NewIntVar(difference_domain); subset.push_back(diff); - cp_model.AddEquality( - diff, LinearExpr::ScalProd({vars[j + i], vars[j]}, {1, -1})); + cp_model.AddEquality(diff, vars[j + i] - vars[j]); } cp_model.AddAllDifferent(subset); @@ -255,10 +254,8 @@ void CostasBoolSoft(const int dim) { cp_model.NewIntVar(Domain(0, positive_diffs.size())); const IntVar neg_var = cp_model.NewIntVar(Domain(0, negative_diffs.size())); - cp_model.AddGreaterOrEqual( - pos_var, LinearExpr::Sum(positive_diffs).AddConstant(-1)); - cp_model.AddGreaterOrEqual( - neg_var, LinearExpr::Sum(negative_diffs).AddConstant(-1)); + cp_model.AddGreaterOrEqual(pos_var, LinearExpr::Sum(positive_diffs) - 1); + cp_model.AddGreaterOrEqual(neg_var, LinearExpr::Sum(negative_diffs) - 1); all_violations.push_back(pos_var); all_violations.push_back(neg_var); } diff --git a/examples/cpp/cryptarithm_sat.cc b/examples/cpp/cryptarithm_sat.cc index 38af00b071..e3e9c83d16 100644 --- a/examples/cpp/cryptarithm_sat.cc +++ b/examples/cpp/cryptarithm_sat.cc @@ -46,24 +46,20 @@ void SendMoreMoney() { // Force all letters to take on different values. cp_model.AddAllDifferent({s, e, n, d, m, o, r, y}); - // Column 0: Force c0 == m. + // Column 0: cp_model.AddEquality(c0, m); - // Column 1: Force c1 + s + m + o == 10*c0. - cp_model.AddEquality(LinearExpr::Sum({c1, s, m, o}), - LinearExpr::ScalProd({c0}, {10})); + // Column 1: + cp_model.AddEquality(c1 + s + m + o, 10 * c0); - // Column 2: Force c2 + e + o == n + 10*c1. - cp_model.AddEquality(LinearExpr::Sum({c2, e, o}), - LinearExpr::ScalProd({n, c1}, {1, 10})); + // Column 2: + cp_model.AddEquality(c2 + e + o, n + 10 * c1); - // Column 3: Force c3 + n + r == e + 10*c2. - cp_model.AddEquality(LinearExpr::Sum({c3, n, r}), - LinearExpr::ScalProd({e, c2}, {1, 10})); + // Column 3: + cp_model.AddEquality(c3 + n + r, e + 10 * c2); - // Column 4: Force d + e == y + 10*c3. - cp_model.AddEquality(LinearExpr::Sum({d, e}), - LinearExpr::ScalProd({y, c3}, {1, 10})); + // Column 4: + cp_model.AddEquality(d + e, y + 10 * c3); // Declare the model, solve it, and display the results. const CpSolverResponse response = Solve(cp_model.Build()); diff --git a/examples/cpp/golomb_sat.cc b/examples/cpp/golomb_sat.cc index edaa400a65..6164639a3f 100644 --- a/examples/cpp/golomb_sat.cc +++ b/examples/cpp/golomb_sat.cc @@ -67,7 +67,7 @@ void GolombRuler(int size) { for (int i = 0; i < size; ++i) { for (int j = i + 1; j < size; ++j) { const IntVar diff = cp_model.NewIntVar(domain); - cp_model.AddEquality(LinearExpr::Sum({diff, ticks[i]}), ticks[j]); + cp_model.AddEquality(diff, ticks[j] - ticks[i]); diffs.push_back(diff); } } diff --git a/examples/cpp/jobshop_sat.cc b/examples/cpp/jobshop_sat.cc index 372918150b..e14986a9a6 100644 --- a/examples/cpp/jobshop_sat.cc +++ b/examples/cpp/jobshop_sat.cc @@ -302,9 +302,10 @@ void AddAlternativeTaskDurationRelaxation( // end == start + min_duration + // sum(shifted_duration[i] * presence_literals[i]) cp_model.AddEquality( - LinearExpr::ScalProd({tasks[t].end, tasks[t].start}, {1, -1}), - LinearExpr::ScalProd(presence_literals, shifted_durations) - .AddConstant(min_duration)); + tasks[t].end, + tasks[t].start + + LinearExpr::ScalProd(presence_literals, shifted_durations) + + min_duration); } } } @@ -432,18 +433,18 @@ void CreateMachines( // Note that we use the start + duration + transition as this is more // precise than the non-propagated end. cp_model - .AddLessOrEqual(tail.interval.StartExpr().AddConstant( - tail.fixed_duration + transition), - head.interval.StartExpr()) + .AddLessOrEqual( + tail.interval.StartExpr() + tail.fixed_duration + transition, + head.interval.StartExpr()) .OnlyEnforceIf(lit); } } // Add a linear equation to define the size of the tail interval. if (absl::GetFlag(FLAGS_use_variable_duration_to_encode_transition)) { - cp_model.AddEquality(tail.interval.SizeExpr(), - LinearExpr::ScalProd(literals, transitions) - .AddConstant(tail.fixed_duration)); + cp_model.AddEquality( + tail.interval.SizeExpr(), + LinearExpr::ScalProd(literals, transitions) + tail.fixed_duration); } } LOG(INFO) << "Machine " << m @@ -493,8 +494,7 @@ void CreateObjective( objective_coeffs.push_back(lateness_penalty); } else { const IntVar lateness_var = cp_model.NewIntVar(Domain(0, horizon)); - cp_model.AddMaxEquality(lateness_var, - {0, job_end.AddConstant(-due_date)}); + cp_model.AddMaxEquality(lateness_var, {0, job_end - due_date}); objective_vars.push_back(lateness_var); objective_coeffs.push_back(lateness_penalty); } @@ -508,9 +508,7 @@ void CreateObjective( if (due_date > 0) { const IntVar earliness_var = cp_model.NewIntVar(Domain(0, horizon)); - cp_model.AddMaxEquality( - earliness_var, - {0, LinearExpr::Term(job_end, -1).AddConstant(due_date)}); + cp_model.AddMaxEquality(earliness_var, {0, due_date - job_end}); objective_vars.push_back(earliness_var); objective_coeffs.push_back(earliness_penalty); } @@ -531,11 +529,11 @@ void CreateObjective( problem.scaling_factor().value()); } cp_model.Minimize( - DoubleLinearExpr::ScalProd(objective_vars, double_objective_coeffs) - .AddConstant(objective_offset)); + DoubleLinearExpr::ScalProd(objective_vars, double_objective_coeffs) + + static_cast(objective_offset)); } else { - cp_model.Minimize(LinearExpr::ScalProd(objective_vars, objective_coeffs) - .AddConstant(objective_offset)); + cp_model.Minimize(LinearExpr::ScalProd(objective_vars, objective_coeffs) + + objective_offset); } } @@ -627,7 +625,7 @@ void AddMakespanRedundantConstraints( } } cp_model.AddLessOrEqual(LinearExpr::Sum(all_task_durations), - LinearExpr::Term(makespan, num_machines)); + makespan * num_machines); } void DisplayJobStatistics( @@ -749,7 +747,7 @@ void Solve(const JsspInputProblem& problem) { const IntVar start = job_to_tasks[precedence.second_job_index()].front().start; const IntVar end = job_to_tasks[precedence.first_job_index()].back().end; - cp_model.AddLessOrEqual(end.AddConstant(precedence.min_delay()), start); + cp_model.AddLessOrEqual(end + precedence.min_delay(), start); } // Objective. diff --git a/examples/cpp/network_routing_sat.cc b/examples/cpp/network_routing_sat.cc index cd8bb737ef..3f8ca09ed9 100644 --- a/examples/cpp/network_routing_sat.cc +++ b/examples/cpp/network_routing_sat.cc @@ -601,8 +601,7 @@ class NetworkRoutingSolver { LinearExpr traffic_expr; for (int i = 0; i < path_vars.size(); ++i) { sum_of_traffic += demands_array_[i].traffic; - traffic_expr.AddTerm(path_vars[i][arc_index], - demands_array_[i].traffic); + traffic_expr += path_vars[i][arc_index] * demands_array_[i].traffic; } const IntVar traffic_var = cp_model.NewIntVar(Domain(0, sum_of_traffic)); traffic_vars[arc_index] = traffic_var; @@ -611,8 +610,7 @@ class NetworkRoutingSolver { const int64_t capacity = arc_capacity_[arc_index]; IntVar scaled_traffic = cp_model.NewIntVar(Domain(0, sum_of_traffic * 1000)); - cp_model.AddEquality(LinearExpr::ScalProd({traffic_var}, {1000}), - scaled_traffic); + cp_model.AddEquality(traffic_var * 1000, scaled_traffic); IntVar normalized_traffic = cp_model.NewIntVar(Domain(0, sum_of_traffic * 1000 / capacity)); max_normalized_traffic = @@ -634,12 +632,8 @@ class NetworkRoutingSolver { cp_model.NewIntVar(Domain(0, max_normalized_traffic)); cp_model.AddMaxEquality(max_usage_cost, normalized_traffic_vars); - LinearExpr objective_expr; - objective_expr.AddVar(max_usage_cost); - for (const BoolVar var : comfortable_traffic_vars) { - objective_expr.AddVar(var); - } - cp_model.Minimize(objective_expr); + cp_model.Minimize(LinearExpr::Sum(comfortable_traffic_vars) + + max_usage_cost); Model model; if (!absl::GetFlag(FLAGS_params).empty()) { diff --git a/examples/cpp/sports_scheduling_sat.cc b/examples/cpp/sports_scheduling_sat.cc index 1cffa8f8e7..444af85140 100644 --- a/examples/cpp/sports_scheduling_sat.cc +++ b/examples/cpp/sports_scheduling_sat.cc @@ -97,8 +97,7 @@ void OpponentModel(int num_teams) { // Link opponent, home_away, and signed_opponent. builder.AddEquality(opp, signed_opp).OnlyEnforceIf(Not(home)); - builder.AddEquality(LinearExpr(opp).AddConstant(num_teams), signed_opp) - .OnlyEnforceIf(home); + builder.AddEquality(opp + num_teams, signed_opp).OnlyEnforceIf(home); } } @@ -118,7 +117,7 @@ void OpponentModel(int num_teams) { IntVar second_home = builder.NewBoolVar(); builder.AddVariableElement(day_opponents[first_team], day_home_aways, second_home); - builder.AddEquality(LinearExpr::Sum({first_home, second_home}), 1); + builder.AddEquality(first_home + second_home, 1); } builder.AddEquality(LinearExpr::Sum(day_home_aways), num_teams / 2); diff --git a/examples/cpp/weighted_tardiness_sat.cc b/examples/cpp/weighted_tardiness_sat.cc index 5074a3a59a..56f2c64e3b 100644 --- a/examples/cpp/weighted_tardiness_sat.cc +++ b/examples/cpp/weighted_tardiness_sat.cc @@ -107,7 +107,7 @@ void Solve(const std::vector& durations, // tardiness_vars >= end - due_date cp_model.AddGreaterOrEqual(tardiness_vars[i], - task_ends[i].AddConstant(-due_dates[i])); + task_ends[i] - due_dates[i]); } } diff --git a/ortools/sat/cp_model.cc b/ortools/sat/cp_model.cc index 18d6cd8118..94fc0b27a1 100644 --- a/ortools/sat/cp_model.cc +++ b/ortools/sat/cp_model.cc @@ -115,15 +115,16 @@ std::string IntVar::DebugString() const { } else { std::string output; if (var_proto.name().empty()) { - absl::StrAppend(&output, "IntVar", index_, "("); + absl::StrAppend(&output, "V", index_, "("); } else { absl::StrAppend(&output, var_proto.name(), "("); } + + // TODO(user): Use domain pretty print function. if (var_proto.domain_size() == 2 && var_proto.domain(0) == var_proto.domain(1)) { absl::StrAppend(&output, var_proto.domain(0), ")"); } else { - // TODO(user): Use domain pretty print function. absl::StrAppend(&output, var_proto.domain(0), ", ", var_proto.domain(1), ")"); } @@ -136,8 +137,6 @@ std::ostream& operator<<(std::ostream& os, const IntVar& var) { return os; } -LinearExpr::LinearExpr() {} - LinearExpr::LinearExpr(BoolVar var) { AddVar(var); } LinearExpr::LinearExpr(IntVar var) { AddVar(var); } @@ -246,12 +245,28 @@ LinearExpr& LinearExpr::AddTerm(IntVar var, int64_t coeff) { return *this; } -LinearExpr& LinearExpr::AddExpression(const LinearExpr& expr) { - constant_ += expr.constant_; - variables_.insert(variables_.end(), expr.variables_.begin(), - expr.variables_.end()); - coefficients_.insert(coefficients_.end(), expr.coefficients_.begin(), - expr.coefficients_.end()); +LinearExpr& LinearExpr::operator+=(const LinearExpr& other) { + constant_ += other.constant_; + variables_.insert(variables_.end(), other.variables_.begin(), + other.variables_.end()); + coefficients_.insert(coefficients_.end(), other.coefficients_.begin(), + other.coefficients_.end()); + return *this; +} + +LinearExpr& LinearExpr::operator-=(const LinearExpr& other) { + constant_ -= other.constant_; + variables_.insert(variables_.end(), other.variables_.begin(), + other.variables_.end()); + for (const int64_t coeff : other.coefficients_) { + coefficients_.push_back(-coeff); + } + return *this; +} + +LinearExpr& LinearExpr::operator*=(int64_t factor) { + constant_ *= factor; + for (int64_t& coeff : coefficients_) coeff *= factor; return *this; } @@ -311,16 +326,16 @@ std::ostream& operator<<(std::ostream& os, const LinearExpr& e) { DoubleLinearExpr::DoubleLinearExpr() {} -DoubleLinearExpr::DoubleLinearExpr(BoolVar var) { AddVar(var); } +DoubleLinearExpr::DoubleLinearExpr(BoolVar var) { AddTerm(var, 1.0); } -DoubleLinearExpr::DoubleLinearExpr(IntVar var) { AddVar(var); } +DoubleLinearExpr::DoubleLinearExpr(IntVar var) { AddTerm(var, 1); } DoubleLinearExpr::DoubleLinearExpr(double constant) { constant_ = constant; } DoubleLinearExpr DoubleLinearExpr::Sum(absl::Span vars) { DoubleLinearExpr result; for (const IntVar& var : vars) { - result.AddVar(var); + result.AddTerm(var, 1.0); } return result; } @@ -328,7 +343,7 @@ DoubleLinearExpr DoubleLinearExpr::Sum(absl::Span vars) { DoubleLinearExpr DoubleLinearExpr::Sum(absl::Span vars) { DoubleLinearExpr result; for (const BoolVar& var : vars) { - result.AddVar(var); + result.AddTerm(var, 1.0); } return result; } @@ -336,7 +351,7 @@ DoubleLinearExpr DoubleLinearExpr::Sum(absl::Span vars) { DoubleLinearExpr DoubleLinearExpr::Sum(std::initializer_list vars) { DoubleLinearExpr result; for (const IntVar& var : vars) { - result.AddVar(var); + result.AddTerm(var, 1.0); } return result; } @@ -378,16 +393,30 @@ DoubleLinearExpr DoubleLinearExpr::Term(IntVar var, double coefficient) { return result; } -DoubleLinearExpr& DoubleLinearExpr::AddConstant(double value) { +DoubleLinearExpr& DoubleLinearExpr::operator+=(double value) { constant_ += value; return *this; } -DoubleLinearExpr& DoubleLinearExpr::AddVar(IntVar var) { +DoubleLinearExpr& DoubleLinearExpr::AddConstant(double constant) { + constant_ += constant; + return *this; +} + +DoubleLinearExpr& DoubleLinearExpr::operator+=(IntVar var) { AddTerm(var, 1); return *this; } +DoubleLinearExpr& DoubleLinearExpr::operator+=(const DoubleLinearExpr& expr) { + constant_ += expr.constant_; + variables_.insert(variables_.end(), expr.variables_.begin(), + expr.variables_.end()); + coefficients_.insert(coefficients_.end(), expr.coefficients_.begin(), + expr.coefficients_.end()); + return *this; +} + DoubleLinearExpr& DoubleLinearExpr::AddTerm(IntVar var, double coeff) { const int index = var.index_; if (RefIsPositive(index)) { @@ -401,13 +430,31 @@ DoubleLinearExpr& DoubleLinearExpr::AddTerm(IntVar var, double coeff) { return *this; } -DoubleLinearExpr& DoubleLinearExpr::AddExpression( - const DoubleLinearExpr& expr) { - constant_ += expr.constant_; +DoubleLinearExpr& DoubleLinearExpr::operator-=(double value) { + constant_ -= value; + return *this; +} + +DoubleLinearExpr& DoubleLinearExpr::operator-=(IntVar var) { + AddTerm(var, -1.0); + return *this; +} + +DoubleLinearExpr& DoubleLinearExpr::operator-=(const DoubleLinearExpr& expr) { + constant_ -= expr.constant_; variables_.insert(variables_.end(), expr.variables_.begin(), expr.variables_.end()); - coefficients_.insert(coefficients_.end(), expr.coefficients_.begin(), - expr.coefficients_.end()); + for (const double coeff : expr.coefficients()) { + coefficients_.push_back(-coeff); + } + return *this; +} + +DoubleLinearExpr& DoubleLinearExpr::operator*=(double coeff) { + constant_ *= coeff; + for (double& c : coefficients_) { + c *= coeff; + } return *this; } diff --git a/ortools/sat/cp_model.h b/ortools/sat/cp_model.h index 625fc5dcd1..4fa7800ff6 100644 --- a/ortools/sat/cp_model.h +++ b/ortools/sat/cp_model.h @@ -216,10 +216,10 @@ std::ostream& operator<<(std::ostream& os, const IntVar& var); /** * A dedicated container for linear expressions. * - * This class helps building and manipulating linear expressions. * With the use of implicit constructors, it can accept integer values, Boolean - * and Integer variables. Note that Not(x) will be silently transformed into - * 1 - x when added to the linear expression. + * and Integer variables. Note that Not(x) will be silently transformed into 1 - + * x when added to the linear expression. It also support operator overloads to + * construct the linear expression naturally. * * Furthermore, static methods allows sums and scalar products, with or without * an additional constant. @@ -231,31 +231,29 @@ std::ostream& operator<<(std::ostream& os, const IntVar& var); IntVar y = model.NewIntVar({0, 10}).WithName("y"); BoolVar b = model.NewBoolVar().WithName("b"); BoolVar c = model.NewBoolVar().WithName("c"); - LinearExpr e1(x); // e1 = x. - LinearExpr e2 = LinearExpr::Sum({x, y}).AddConstant(5); // e2 = x + y + 5; - LinearExpr e3 = LinearExpr::ScalProd({x, y}, {2, -1}); // e3 = 2 * x - y. - LinearExpr e4(b); // e4 = b. - LinearExpr e5(b.Not()); // e5 = 1 - b. - // If passing a std::vector, a specialized method must be called. + LinearExpr e1(x); // Or e1 = x. + LinearExpr e2 = x + y + 5; + LinearExpr e3 = 2 * x - y; + LinearExpr e4 = b; + LinearExpr e5 = b.Not(); // 1 - b. std::vector bools = {b, Not(c)}; - LinearExpr e6 = LinearExpr::Sum(bools); // e6 = b + 1 - c; - // e7 = -3 * b + 1 - c; - LinearExpr e7 = LinearExpr::ScalProd(bools, {-3, 1}); + LinearExpr e6 = LinearExpr::Sum(bools); // b + 1 - c; + LinearExpr e7 = -3 * b + Not(c); // -3 * b + 1 - c; \endcode * This can be used implicitly in some of the CpModelBuilder methods. * \code - cp_model.AddGreaterThan(x, 5); // x > 5 - cp_model.AddEquality(x, LinearExpr(y).AddConstant(5)); // x == y + 5 + cp_model.AddGreaterThan(x, 5); + cp_model.AddEquality(x, y + 5); \endcode */ class LinearExpr { public: - LinearExpr(); + LinearExpr() = default; /** * Constructs a linear expression from a Boolean variable. * - * It deals with logical negation correctly. + * It deals with logical negation correctly. */ LinearExpr(BoolVar var); // NOLINT(runtime/explicit) @@ -265,18 +263,6 @@ class LinearExpr { /// Constructs a constant linear expression. LinearExpr(int64_t constant); // NOLINT(runtime/explicit) - /// Adds a constant value to the linear expression. - LinearExpr& AddConstant(int64_t value); - - /// Adds a single integer variable to the linear expression. - LinearExpr& AddVar(IntVar var); - - /// Adds a term (var * coeff) to the linear expression. - LinearExpr& AddTerm(IntVar var, int64_t coeff); - - /// Adds another linear expression to the linear expression. - LinearExpr& AddExpression(const LinearExpr& expr); - /// Constructs the sum of a list of variables. static LinearExpr Sum(absl::Span vars); @@ -310,6 +296,17 @@ class LinearExpr { /// Constructs var * coefficient. static LinearExpr Term(IntVar var, int64_t coefficient); + // Operators. + LinearExpr& operator+=(const LinearExpr& other); + LinearExpr& operator-=(const LinearExpr& other); + LinearExpr& operator*=(int64_t factor); + + // Deprecated. Use operators instead. + LinearExpr& AddConstant(int64_t value); + LinearExpr& AddVar(IntVar var); + LinearExpr& AddTerm(IntVar var, int64_t coeff); + LinearExpr& AddExpression(const LinearExpr& expr) { return *this += expr; } + /// Returns the vector of variables. const std::vector& variables() const { return variables_; } @@ -390,16 +387,31 @@ class DoubleLinearExpr { explicit DoubleLinearExpr(double constant); /// Adds a constant value to the linear expression. - DoubleLinearExpr& AddConstant(double value); + DoubleLinearExpr& operator+=(double value); /// Adds a single integer variable to the linear expression. - DoubleLinearExpr& AddVar(IntVar var); + DoubleLinearExpr& operator+=(IntVar var); + + /// Adds another linear expression to the linear expression. + DoubleLinearExpr& operator+=(const DoubleLinearExpr& expr); /// Adds a term (var * coeff) to the linear expression. DoubleLinearExpr& AddTerm(IntVar var, double coeff); + /// Deprecated. Use +=. + DoubleLinearExpr& AddConstant(double constant); + + /// Adds a constant value to the linear expression. + DoubleLinearExpr& operator-=(double value); + + /// Adds a single integer variable to the linear expression. + DoubleLinearExpr& operator-=(IntVar var); + /// Adds another linear expression to the linear expression. - DoubleLinearExpr& AddExpression(const DoubleLinearExpr& expr); + DoubleLinearExpr& operator-=(const DoubleLinearExpr& expr); + + /// Multiply the linear expression by a constant. + DoubleLinearExpr& operator*=(double coeff); /// Constructs the sum of a list of variables. static DoubleLinearExpr Sum(absl::Span vars); @@ -1130,6 +1142,180 @@ int64_t SolutionIntegerValue(const CpSolverResponse& r, const LinearExpr& expr); /// Evaluates the value of a Boolean literal in a solver response. bool SolutionBooleanValue(const CpSolverResponse& r, BoolVar x); +// ============================================================================ +// Minimal support for "natural" API to create LinearExpr. +// +// Note(user): This might be optimized further by optimizing LinearExpr for +// holding one term, or introducing an LinearTerm class, but these should mainly +// be used to construct small expressions. Revisit if we run into performance +// issues. Note that if perf become a bottleneck for a client, then probably +// directly writing the proto will be even faster. +// ============================================================================ + +inline LinearExpr operator-(LinearExpr expr) { return expr *= -1; } + +inline LinearExpr operator+(const LinearExpr& lhs, const LinearExpr& rhs) { + LinearExpr temp(lhs); + temp += rhs; + return temp; +} +inline LinearExpr operator+(LinearExpr&& lhs, const LinearExpr& rhs) { + lhs += rhs; + return std::move(lhs); +} +inline LinearExpr operator+(const LinearExpr& lhs, LinearExpr&& rhs) { + rhs += lhs; + return std::move(rhs); +} +inline LinearExpr operator+(LinearExpr&& lhs, LinearExpr&& rhs) { + if (lhs.variables().size() < rhs.variables().size()) { + rhs += lhs; + return std::move(rhs); + } else { + lhs += rhs; + return std::move(lhs); + } +} + +inline LinearExpr operator-(const LinearExpr& lhs, const LinearExpr& rhs) { + LinearExpr temp(lhs); + temp -= rhs; + return temp; +} +inline LinearExpr operator-(LinearExpr&& lhs, const LinearExpr& rhs) { + lhs -= rhs; + return std::move(lhs); +} +inline LinearExpr operator-(const LinearExpr& lhs, LinearExpr&& rhs) { + rhs -= lhs; + return std::move(rhs); +} +inline LinearExpr operator-(LinearExpr&& lhs, LinearExpr&& rhs) { + if (lhs.variables().size() < rhs.variables().size()) { + rhs -= lhs; + return std::move(rhs); + } else { + lhs -= rhs; + return std::move(lhs); + } +} + +inline LinearExpr operator*(LinearExpr expr, int64_t factor) { + expr *= factor; + return expr; +} +inline LinearExpr operator*(int64_t factor, LinearExpr expr) { + expr *= factor; + return expr; +} + +// For DoubleLinearExpr. + +inline DoubleLinearExpr operator-(DoubleLinearExpr expr) { return expr *= -1; } + +inline DoubleLinearExpr operator+(const DoubleLinearExpr& lhs, + const DoubleLinearExpr& rhs) { + DoubleLinearExpr temp(lhs); + temp += rhs; + return temp; +} +inline DoubleLinearExpr operator+(DoubleLinearExpr&& lhs, + const DoubleLinearExpr& rhs) { + lhs += rhs; + return std::move(lhs); +} +inline DoubleLinearExpr operator+(const DoubleLinearExpr& lhs, + DoubleLinearExpr&& rhs) { + rhs += lhs; + return std::move(rhs); +} +inline DoubleLinearExpr operator+(DoubleLinearExpr&& lhs, + DoubleLinearExpr&& rhs) { + if (lhs.variables().size() < rhs.variables().size()) { + rhs += lhs; + return std::move(rhs); + } else { + lhs += rhs; + return std::move(lhs); + } +} + +inline DoubleLinearExpr operator+(const DoubleLinearExpr& lhs, double rhs) { + DoubleLinearExpr temp(lhs); + temp += rhs; + return temp; +} +inline DoubleLinearExpr operator+(DoubleLinearExpr&& lhs, double rhs) { + lhs += rhs; + return std::move(lhs); +} +inline DoubleLinearExpr operator+(double lhs, DoubleLinearExpr&& rhs) { + rhs += lhs; + return std::move(rhs); +} +inline DoubleLinearExpr operator+(double lhs, const DoubleLinearExpr& rhs) { + DoubleLinearExpr temp(rhs); + temp += lhs; + return temp; +} + +inline DoubleLinearExpr operator-(const DoubleLinearExpr& lhs, + const DoubleLinearExpr& rhs) { + DoubleLinearExpr temp(lhs); + temp -= rhs; + return temp; +} +inline DoubleLinearExpr operator-(DoubleLinearExpr&& lhs, + const DoubleLinearExpr& rhs) { + lhs -= rhs; + return std::move(lhs); +} +inline DoubleLinearExpr operator-(const DoubleLinearExpr& lhs, + DoubleLinearExpr&& rhs) { + rhs -= lhs; + return std::move(rhs); +} +inline DoubleLinearExpr operator-(DoubleLinearExpr&& lhs, + DoubleLinearExpr&& rhs) { + if (lhs.variables().size() < rhs.variables().size()) { + rhs -= lhs; + return std::move(rhs); + } else { + lhs -= rhs; + return std::move(lhs); + } +} + +inline DoubleLinearExpr operator-(const DoubleLinearExpr& lhs, double rhs) { + DoubleLinearExpr temp(lhs); + temp -= rhs; + return temp; +} +inline DoubleLinearExpr operator-(DoubleLinearExpr&& lhs, double rhs) { + lhs -= rhs; + return std::move(lhs); +} +inline DoubleLinearExpr operator-(double lhs, DoubleLinearExpr&& rhs) { + rhs *= -1; + rhs += lhs; + return std::move(rhs); +} +inline DoubleLinearExpr operator-(double lhs, const DoubleLinearExpr& rhs) { + DoubleLinearExpr temp = -rhs; + temp += lhs; + return temp; +} + +inline DoubleLinearExpr operator*(DoubleLinearExpr expr, double factor) { + expr *= factor; + return expr; +} + +inline DoubleLinearExpr operator*(double factor, DoubleLinearExpr expr) { + expr *= factor; + return expr; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index ebf32104a5..c628b3e651 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -796,7 +796,8 @@ message CpSolverResponse { // The integral of log(1 + absolute_objective_gap) over time. double gap_integral = 22; - // Additional information about how the solution was found. + // Additional information about how the solution was found. It also stores + // model or parameters errors that caused the model to be invalid. string solution_info = 20; // The solve log will be filled if the parameter log_to_response is set to diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index d220792c93..2064b6527e 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -2990,6 +2990,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { // when it become needed. Or rename to INVALID_INPUT ? shared_response_manager->MutableResponse()->set_status( CpSolverStatus::MODEL_INVALID); + shared_response_manager->MutableResponse()->set_solution_info(error); return shared_response_manager->GetResponse(); } } @@ -3037,6 +3038,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { SOLVER_LOG(logger, "Invalid model: ", error); shared_response_manager->MutableResponse()->set_status( CpSolverStatus::MODEL_INVALID); + shared_response_manager->MutableResponse()->set_solution_info(error); return shared_response_manager->GetResponse(); } } diff --git a/ortools/sat/csharp/Constraints.cs b/ortools/sat/csharp/Constraints.cs index 3c59fdf1f1..f3f0196cfa 100644 --- a/ortools/sat/csharp/Constraints.cs +++ b/ortools/sat/csharp/Constraints.cs @@ -13,50 +13,50 @@ namespace Google.OrTools.Sat { - using System; - using System.Collections.Generic; +using System; +using System.Collections.Generic; - public class Constraint +public class Constraint +{ + public Constraint(CpModelProto model) { - public Constraint(CpModelProto model) - { - index_ = model.Constraints.Count; - constraint_ = new ConstraintProto(); - model.Constraints.Add(constraint_); - } + index_ = model.Constraints.Count; + constraint_ = new ConstraintProto(); + model.Constraints.Add(constraint_); + } - public void OnlyEnforceIf(ILiteral lit) + public void OnlyEnforceIf(ILiteral lit) + { + constraint_.EnforcementLiteral.Add(lit.GetIndex()); + } + + public void OnlyEnforceIf(ILiteral[] lits) + { + foreach (ILiteral lit in lits) { constraint_.EnforcementLiteral.Add(lit.GetIndex()); } - - public void OnlyEnforceIf(ILiteral[] lits) - { - foreach (ILiteral lit in lits) - { - constraint_.EnforcementLiteral.Add(lit.GetIndex()); - } - } - - public int Index - { - get { - return index_; - } - } - - public ConstraintProto Proto - { - get { - return constraint_; - } - set { - constraint_ = value; - } - } - - private int index_; - private ConstraintProto constraint_; } + public int Index + { + get { + return index_; + } + } + + public ConstraintProto Proto + { + get { + return constraint_; + } + set { + constraint_ = value; + } + } + + private int index_; + private ConstraintProto constraint_; +} + } // namespace Google.OrTools.Sat diff --git a/ortools/sat/csharp/CpModel.cs b/ortools/sat/csharp/CpModel.cs index 48bc1783c7..1e24251c00 100644 --- a/ortools/sat/csharp/CpModel.cs +++ b/ortools/sat/csharp/CpModel.cs @@ -13,900 +13,898 @@ namespace Google.OrTools.Sat { - using System; - using System.Collections.Generic; - using Google.OrTools.Util; +using System; +using System.Collections.Generic; +using Google.OrTools.Util; - /// - /// Wrapper class around the cp_model proto. - /// - public class CpModel +/// +/// Wrapper class around the cp_model proto. +/// +public class CpModel +{ + public CpModel() { - public CpModel() - { - model_ = new CpModelProto(); - constant_map_ = new Dictionary(); - } - - // Getters. - - public CpModelProto Model - { - get { - return model_; - } - } - - int Negated(int index) - { - return -index - 1; - } - - // Integer variables and constraints. - - public IntVar NewIntVar(long lb, long ub, string name) - { - return new IntVar(model_, new Domain(lb, ub), name); - } - - public IntVar NewIntVarFromDomain(Domain domain, string name) - { - return new IntVar(model_, domain, name); - } - // Constants (named or not). - - // TODO: Cache constant. - public IntVar NewConstant(long value) - { - return new IntVar(model_, new Domain(value), String.Format("{0}", value)); - } - - public IntVar NewConstant(long value, string name) - { - return new IntVar(model_, new Domain(value), name); - } - - public IntVar NewBoolVar(string name) - { - return new IntVar(model_, new Domain(0, 1), name); - } - - public Constraint AddLinearConstraint(LinearExpr linear_expr, long lb, long ub) - { - return AddLinearExpressionInDomain(linear_expr, new Domain(lb, ub)); - } - - public Constraint AddLinearExpressionInDomain(LinearExpr linear_expr, Domain domain) - { - Dictionary dict = new Dictionary(); - long constant = LinearExpr.GetVarValueMap(linear_expr, 1L, dict); - Constraint ct = new Constraint(model_); - LinearConstraintProto linear = new LinearConstraintProto(); - foreach (KeyValuePair term in dict) - { - linear.Vars.Add(term.Key.Index); - linear.Coeffs.Add(term.Value); - } - foreach (long value in domain.FlattenedIntervals()) - { - if (value == Int64.MinValue || value == Int64.MaxValue) - { - linear.Domain.Add(value); - } - else - { - linear.Domain.Add(value - constant); - } - } - ct.Proto.Linear = linear; - return ct; - } - - public Constraint Add(BoundedLinearExpression lin) - { - switch (lin.CtType) - { - case BoundedLinearExpression.Type.BoundExpression: { - return AddLinearExpressionInDomain(lin.Left, new Domain(lin.Lb, lin.Ub)); - } - case BoundedLinearExpression.Type.VarEqVar: { - return AddLinearExpressionInDomain(lin.Left - lin.Right, new Domain(0)); - } - case BoundedLinearExpression.Type.VarDiffVar: { - return AddLinearExpressionInDomain( - lin.Left - lin.Right, - Domain.FromFlatIntervals(new long[] { Int64.MinValue, -1, 1, Int64.MaxValue })); - } - case BoundedLinearExpression.Type.VarEqCst: { - return AddLinearExpressionInDomain(lin.Left, new Domain(lin.Lb, lin.Lb)); - } - case BoundedLinearExpression.Type.VarDiffCst: { - return AddLinearExpressionInDomain( - lin.Left, - Domain.FromFlatIntervals(new long[] { Int64.MinValue, lin.Lb - 1, lin.Lb + 1, Int64.MaxValue })); - } - } - return null; - } - - public Constraint AddAllDifferent(IEnumerable vars) - { - Constraint ct = new Constraint(model_); - AllDifferentConstraintProto alldiff = new AllDifferentConstraintProto(); - foreach (IntVar var in vars) - { - alldiff.Exprs.Add(GetLinearExpressionProto(var)); - } - ct.Proto.AllDiff = alldiff; - return ct; - } - - public Constraint AddAllDifferent(IEnumerable exprs) - { - Constraint ct = new Constraint(model_); - AllDifferentConstraintProto alldiff = new AllDifferentConstraintProto(); - foreach (LinearExpr expr in exprs) - { - alldiff.Exprs.Add(GetLinearExpressionProto(expr)); - } - - ct.Proto.AllDiff = alldiff; - return ct; - } - - public Constraint AddElement(IntVar index, IEnumerable vars, IntVar target) - { - Constraint ct = new Constraint(model_); - ElementConstraintProto element = new ElementConstraintProto(); - element.Index = index.Index; - foreach (IntVar var in vars) - { - element.Vars.Add(var.Index); - } - element.Target = target.Index; - ct.Proto.Element = element; - return ct; - } - - public Constraint AddElement(IntVar index, IEnumerable values, IntVar target) - { - Constraint ct = new Constraint(model_); - ElementConstraintProto element = new ElementConstraintProto(); - element.Index = index.Index; - foreach (long value in values) - { - element.Vars.Add(ConvertConstant(value)); - } - element.Target = target.Index; - ct.Proto.Element = element; - return ct; - } - - public Constraint AddElement(IntVar index, IEnumerable values, IntVar target) - { - Constraint ct = new Constraint(model_); - ElementConstraintProto element = new ElementConstraintProto(); - element.Index = index.Index; - foreach (int value in values) - { - element.Vars.Add(ConvertConstant(value)); - } - element.Target = target.Index; - ct.Proto.Element = element; - return ct; - } - - public Constraint AddCircuit(IEnumerable> arcs) - { - Constraint ct = new Constraint(model_); - CircuitConstraintProto circuit = new CircuitConstraintProto(); - foreach (var arc in arcs) - { - circuit.Tails.Add(arc.Item1); - circuit.Heads.Add(arc.Item2); - circuit.Literals.Add(arc.Item3.GetIndex()); - } - ct.Proto.Circuit = circuit; - return ct; - } - - public Constraint AddAllowedAssignments(IEnumerable vars, long[,] tuples) - { - Constraint ct = new Constraint(model_); - TableConstraintProto table = new TableConstraintProto(); - foreach (IntVar var in vars) - { - table.Vars.Add(var.Index); - } - for (int i = 0; i < tuples.GetLength(0); ++i) - { - for (int j = 0; j < tuples.GetLength(1); ++j) - { - table.Values.Add(tuples[i, j]); - } - } - ct.Proto.Table = table; - return ct; - } - - public Constraint AddForbiddenAssignments(IEnumerable vars, long[,] tuples) - { - Constraint ct = AddAllowedAssignments(vars, tuples); - ct.Proto.Table.Negated = true; - return ct; - } - - public Constraint AddAutomaton(IEnumerable vars, long starting_state, long[,] transitions, - IEnumerable final_states) - { - Constraint ct = new Constraint(model_); - AutomatonConstraintProto aut = new AutomatonConstraintProto(); - foreach (IntVar var in vars) - { - aut.Vars.Add(var.Index); - } - aut.StartingState = starting_state; - foreach (long f in final_states) - { - aut.FinalStates.Add(f); - } - for (int i = 0; i < transitions.GetLength(0); ++i) - { - aut.TransitionTail.Add(transitions[i, 0]); - aut.TransitionLabel.Add(transitions[i, 1]); - aut.TransitionHead.Add(transitions[i, 2]); - } - - ct.Proto.Automaton = aut; - return ct; - } - - public Constraint AddAutomaton(IEnumerable vars, long starting_state, - IEnumerable> transitions, IEnumerable final_states) - { - Constraint ct = new Constraint(model_); - AutomatonConstraintProto aut = new AutomatonConstraintProto(); - foreach (IntVar var in vars) - { - aut.Vars.Add(var.Index); - } - aut.StartingState = starting_state; - foreach (long f in final_states) - { - aut.FinalStates.Add(f); - } - foreach (Tuple transition in transitions) - { - aut.TransitionHead.Add(transition.Item1); - aut.TransitionLabel.Add(transition.Item2); - aut.TransitionTail.Add(transition.Item3); - } - - ct.Proto.Automaton = aut; - return ct; - } - - public Constraint AddInverse(IEnumerable direct, IEnumerable reverse) - { - Constraint ct = new Constraint(model_); - InverseConstraintProto inverse = new InverseConstraintProto(); - foreach (IntVar var in direct) - { - inverse.FDirect.Add(var.Index); - } - foreach (IntVar var in reverse) - { - inverse.FInverse.Add(var.Index); - } - ct.Proto.Inverse = inverse; - return ct; - } - - public Constraint AddReservoirConstraint(IEnumerable times, IEnumerable levelChanges, - long minLevel, long maxLevel) - { - Constraint ct = new Constraint(model_); - ReservoirConstraintProto res = new ReservoirConstraintProto(); - foreach (IntVar time in times) - { - res.TimeExprs.Add(GetLinearExpressionProto(time)); - } - foreach (I d in levelChanges) - { - res.LevelChanges.Add(Convert.ToInt64(d)); - } - - res.MinLevel = minLevel; - res.MaxLevel = maxLevel; - ct.Proto.Reservoir = res; - - return ct; - } - - public Constraint AddReservoirConstraintWithActive(IEnumerable times, IEnumerable levelChanges, - IEnumerable actives, long minLevel, long maxLevel) - { - Constraint ct = new Constraint(model_); - ReservoirConstraintProto res = new ReservoirConstraintProto(); - foreach (IntVar time in times) - { - res.TimeExprs.Add(GetLinearExpressionProto(time)); - } - foreach (I d in levelChanges) - { - res.LevelChanges.Add(Convert.ToInt64(d)); - } - foreach (IntVar var in actives) - { - res.ActiveLiterals.Add(var.Index); - } - res.MinLevel = minLevel; - res.MaxLevel = maxLevel; - ct.Proto.Reservoir = res; - - return ct; - } - - public Constraint AddReservoirConstraint(IEnumerable times, IEnumerable levelChanges, - long minLevel, long maxLevel) - { - Constraint ct = new Constraint(model_); - ReservoirConstraintProto res = new ReservoirConstraintProto(); - foreach (LinearExpr time in times) - { - res.TimeExprs.Add(GetLinearExpressionProto(time)); - } - foreach (I d in levelChanges) - { - res.LevelChanges.Add(Convert.ToInt64(d)); - } - - res.MinLevel = minLevel; - res.MaxLevel = maxLevel; - ct.Proto.Reservoir = res; - - return ct; - } - - public Constraint AddReservoirConstraintWithActive(IEnumerable times, - IEnumerable levelChanges, IEnumerable actives, - long minLevel, long maxLevel) - { - Constraint ct = new Constraint(model_); - ReservoirConstraintProto res = new ReservoirConstraintProto(); - foreach (LinearExpr time in times) - { - res.TimeExprs.Add(GetLinearExpressionProto(time)); - } - foreach (I d in levelChanges) - { - res.LevelChanges.Add(Convert.ToInt64(d)); - } - foreach (IntVar var in actives) - { - res.ActiveLiterals.Add(var.Index); - } - res.MinLevel = minLevel; - res.MaxLevel = maxLevel; - ct.Proto.Reservoir = res; - - return ct; - } - - public void AddMapDomain(IntVar var, IEnumerable bool_vars, long offset = 0) - { - int i = 0; - foreach (IntVar bool_var in bool_vars) - { - int b_index = bool_var.Index; - int var_index = var.Index; - - ConstraintProto ct1 = new ConstraintProto(); - LinearConstraintProto lin1 = new LinearConstraintProto(); - lin1.Vars.Add(var_index); - lin1.Coeffs.Add(1L); - lin1.Domain.Add(offset + i); - lin1.Domain.Add(offset + i); - ct1.Linear = lin1; - ct1.EnforcementLiteral.Add(b_index); - model_.Constraints.Add(ct1); - - ConstraintProto ct2 = new ConstraintProto(); - LinearConstraintProto lin2 = new LinearConstraintProto(); - lin2.Vars.Add(var_index); - lin2.Coeffs.Add(1L); - lin2.Domain.Add(Int64.MinValue); - lin2.Domain.Add(offset + i - 1); - lin2.Domain.Add(offset + i + 1); - lin2.Domain.Add(Int64.MaxValue); - ct2.Linear = lin2; - ct2.EnforcementLiteral.Add(-b_index - 1); - model_.Constraints.Add(ct2); - - i++; - } - } - - public Constraint AddImplication(ILiteral a, ILiteral b) - { - Constraint ct = new Constraint(model_); - BoolArgumentProto or = new BoolArgumentProto(); - or.Literals.Add(a.Not().GetIndex()); - or.Literals.Add(b.GetIndex()); - ct.Proto.BoolOr = or; - return ct; - } - - public Constraint AddBoolOr(IEnumerable literals) - { - Constraint ct = new Constraint(model_); - BoolArgumentProto bool_argument = new BoolArgumentProto(); - foreach (ILiteral lit in literals) - { - bool_argument.Literals.Add(lit.GetIndex()); - } - ct.Proto.BoolOr = bool_argument; - return ct; - } - - public Constraint AddBoolAnd(IEnumerable literals) - { - Constraint ct = new Constraint(model_); - BoolArgumentProto bool_argument = new BoolArgumentProto(); - foreach (ILiteral lit in literals) - { - bool_argument.Literals.Add(lit.GetIndex()); - } - ct.Proto.BoolAnd = bool_argument; - return ct; - } - - public Constraint AddBoolXor(IEnumerable literals) - { - Constraint ct = new Constraint(model_); - BoolArgumentProto bool_argument = new BoolArgumentProto(); - foreach (ILiteral lit in literals) - { - bool_argument.Literals.Add(lit.GetIndex()); - } - ct.Proto.BoolXor = bool_argument; - return ct; - } - - public Constraint AddMinEquality(LinearExpr target, IEnumerable vars) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - foreach (IntVar var in vars) - { - args.Exprs.Add(GetLinearExpressionProto(var, /*negate=*/true)); - } - args.Target = GetLinearExpressionProto(target, /*negate=*/true); - ct.Proto.LinMax = args; - return ct; - } - - public Constraint AddMinEquality(LinearExpr target, IEnumerable exprs) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - foreach (LinearExpr expr in exprs) - { - args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/true)); - } - args.Target = GetLinearExpressionProto(target, /*negate=*/true); - ct.Proto.LinMax = args; - return ct; - } - - public Constraint AddMaxEquality(IntVar target, IEnumerable vars) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - foreach (IntVar var in vars) - { - args.Exprs.Add(GetLinearExpressionProto(var)); - } - args.Target = GetLinearExpressionProto(target); - ct.Proto.LinMax = args; - return ct; - } - - public Constraint AddMaxEquality(LinearExpr target, IEnumerable exprs) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - foreach (LinearExpr expr in exprs) - { - args.Exprs.Add(GetLinearExpressionProto(expr)); - } - args.Target = GetLinearExpressionProto(target); - ct.Proto.LinMax = args; - return ct; - } - - public Constraint AddDivisionEquality(T target, N num, D denom) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(num))); - args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(denom))); - args.Target = GetLinearExpressionProto(GetLinearExpr(target)); - ct.Proto.IntDiv = args; - return ct; - } - - public Constraint AddAbsEquality(LinearExpr target, LinearExpr expr) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - args.Exprs.Add(GetLinearExpressionProto(expr)); - args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/true)); - args.Target = GetLinearExpressionProto(target); - ct.Proto.LinMax = args; - return ct; - } - - public Constraint AddModuloEquality(T target, V v, M m) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(v))); - args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(m))); - args.Target = GetLinearExpressionProto(GetLinearExpr(target)); - ct.Proto.IntMod = args; - return ct; - } - - public Constraint AddMultiplicationEquality(LinearExpr target, IEnumerable vars) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - args.Target = GetLinearExpressionProto(target); - foreach (IntVar var in vars) - { - args.Exprs.Add(GetLinearExpressionProto(var)); - } - ct.Proto.IntProd = args; - return ct; - } - - public Constraint AddMultiplicationEquality(LinearExpr target, IEnumerable exprs) - { - Constraint ct = new Constraint(model_); - LinearArgumentProto args = new LinearArgumentProto(); - args.Target = GetLinearExpressionProto(target); - foreach (LinearExpr expr in exprs) - { - args.Exprs.Add(GetLinearExpressionProto(expr)); - } - ct.Proto.IntProd = args; - return ct; - } - - public Constraint AddProdEquality(IntVar target, IEnumerable vars) - { - return AddMultiplicationEquality(target, vars); - } - - // Scheduling support - - public IntervalVar NewIntervalVar(S start, D duration, E end, string name) - { - LinearExpr startExpr = GetLinearExpr(start); - LinearExpr durationExpr = GetLinearExpr(duration); - LinearExpr endExpr = GetLinearExpr(end); - Add(startExpr + durationExpr == endExpr); - - LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); - LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); - LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); - return new IntervalVar(model_, startProto, durationProto, endProto, name); - } - - public IntervalVar NewFixedSizeIntervalVar(S start, long duration, string name) - { - LinearExpr startExpr = GetLinearExpr(start); - LinearExpr durationExpr = GetLinearExpr(duration); - LinearExpr endExpr = LinearExpr.Sum(new LinearExpr[] { startExpr, durationExpr }); - - LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); - LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); - LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); - return new IntervalVar(model_, startProto, durationProto, endProto, name); - } - - public IntervalVar NewOptionalIntervalVar(S start, D duration, E end, ILiteral is_present, string name) - { - LinearExpr startExpr = GetLinearExpr(start); - LinearExpr durationExpr = GetLinearExpr(duration); - LinearExpr endExpr = GetLinearExpr(end); - Add(startExpr + durationExpr == endExpr).OnlyEnforceIf(is_present); - - LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); - LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); - LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); - return new IntervalVar(model_, startProto, durationProto, endProto, is_present.GetIndex(), name); - } - - public IntervalVar NewOptionalFixedSizeIntervalVar(S start, long duration, ILiteral is_present, string name) - { - LinearExpr startExpr = GetLinearExpr(start); - LinearExpr durationExpr = GetLinearExpr(duration); - LinearExpr endExpr = LinearExpr.Sum(new LinearExpr[] { startExpr, durationExpr }); - - LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); - LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); - LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); - return new IntervalVar(model_, startProto, durationProto, endProto, is_present.GetIndex(), name); - } - - public Constraint AddNoOverlap(IEnumerable intervals) - { - Constraint ct = new Constraint(model_); - NoOverlapConstraintProto args = new NoOverlapConstraintProto(); - foreach (IntervalVar var in intervals) - { - args.Intervals.Add(var.GetIndex()); - } - ct.Proto.NoOverlap = args; - return ct; - } - - public Constraint AddNoOverlap2D(IEnumerable x_intervals, IEnumerable y_intervals) - { - Constraint ct = new Constraint(model_); - NoOverlap2DConstraintProto args = new NoOverlap2DConstraintProto(); - foreach (IntervalVar var in x_intervals) - { - args.XIntervals.Add(var.GetIndex()); - } - foreach (IntervalVar var in y_intervals) - { - args.YIntervals.Add(var.GetIndex()); - } - ct.Proto.NoOverlap2D = args; - return ct; - } - - public Constraint AddCumulative(IEnumerable intervals, IEnumerable demands, C capacity) - { - Constraint ct = new Constraint(model_); - CumulativeConstraintProto cumul = new CumulativeConstraintProto(); - foreach (IntervalVar var in intervals) - { - cumul.Intervals.Add(var.GetIndex()); - } - foreach (D demand in demands) - { - LinearExpr demandExpr = GetLinearExpr(demand); - cumul.Demands.Add(GetLinearExpressionProto(demandExpr)); - } - LinearExpr capacityExpr = GetLinearExpr(capacity); - cumul.Capacity = GetLinearExpressionProto(capacityExpr); - ct.Proto.Cumulative = cumul; - return ct; - } - - // Objective. - public void Minimize(LinearExpr obj) - { - SetObjective(obj, true); - } - - public void Maximize(LinearExpr obj) - { - SetObjective(obj, false); - } - - public void Minimize() - { - SetObjective(null, true); - } - - public void Maximize() - { - SetObjective(null, false); - } - - public void AddVarToObjective(IntVar var) - { - if ((Object)var == null) - return; - model_.Objective.Vars.Add(var.Index); - model_.Objective.Coeffs.Add(model_.Objective.ScalingFactor > 0 ? 1 : -1); - } - - public void AddTermToObjective(IntVar var, long coeff) - { - if (coeff == 0 || (Object)var == null) - return; - model_.Objective.Vars.Add(var.Index); - model_.Objective.Coeffs.Add(model_.Objective.ScalingFactor > 0 ? coeff : -coeff); - } - - bool HasObjective() - { - return model_.Objective == null; - } - - // Search Decision. - - public void AddDecisionStrategy(IEnumerable vars, - DecisionStrategyProto.Types.VariableSelectionStrategy var_str, - DecisionStrategyProto.Types.DomainReductionStrategy dom_str) - { - DecisionStrategyProto ds = new DecisionStrategyProto(); - foreach (IntVar var in vars) - { - ds.Variables.Add(var.Index); - } - ds.VariableSelectionStrategy = var_str; - ds.DomainReductionStrategy = dom_str; - model_.SearchStrategy.Add(ds); - } - - public void AddHint(IntVar var, long value) - { - if (model_.SolutionHint == null) - { - model_.SolutionHint = new PartialVariableAssignment(); - } - model_.SolutionHint.Vars.Add(var.GetIndex()); - model_.SolutionHint.Values.Add(value); - } - - public void ClearHints() - { - model_.SolutionHint = null; - } - - public void AddAssumption(ILiteral lit) - { - model_.Assumptions.Add(lit.GetIndex()); - } - - public void AddAssumptions(IEnumerable literals) - { - foreach (ILiteral lit in literals) - { - AddAssumption(lit); - } - } - - public void ClearAssumptions() - { - model_.Assumptions.Clear(); - } - - // Internal methods. - - void SetObjective(LinearExpr obj, bool minimize) - { - CpObjectiveProto objective = new CpObjectiveProto(); - if (obj == null) - { - objective.Offset = 0L; - objective.ScalingFactor = minimize ? 1L : -1; - } - else if (obj is IntVar) - { - objective.Offset = 0L; - objective.Vars.Add(obj.Index); - if (minimize) - { - objective.Coeffs.Add(1L); - objective.ScalingFactor = 1L; - } - else - { - objective.Coeffs.Add(-1L); - objective.ScalingFactor = -1L; - } - } - else - { - Dictionary dict = new Dictionary(); - long constant = LinearExpr.GetVarValueMap(obj, 1L, dict); - if (minimize) - { - objective.ScalingFactor = 1L; - objective.Offset = constant; - } - else - { - objective.ScalingFactor = -1L; - objective.Offset = -constant; - } - foreach (KeyValuePair it in dict) - { - objective.Vars.Add(it.Key.Index); - objective.Coeffs.Add(minimize ? it.Value : -it.Value); - } - } - model_.Objective = objective; - } - public String ModelStats() - { - return CpSatHelper.ModelStats(model_); - } - - public Boolean ExportToFile(String filename) - { - return CpSatHelper.WriteModelToFile(model_, filename); - } - - public String Validate() - { - return CpSatHelper.ValidateModel(model_); - } - - private int ConvertConstant(long value) - { - if (constant_map_.ContainsKey(value)) - { - return constant_map_[value]; - } - else - { - int index = model_.Variables.Count; - IntegerVariableProto var = new IntegerVariableProto(); - var.Domain.Add(value); - var.Domain.Add(value); - constant_map_.Add(value, index); - model_.Variables.Add(var); - return index; - } - } - - private int GetOrCreateIndex(X x) - { - if (typeof(X) == typeof(IntVar)) - { - IntVar vx = (IntVar)(Object)x; - return vx.Index; - } - if (typeof(X) == typeof(long) || typeof(X) == typeof(int)) - { - return ConvertConstant(Convert.ToInt64(x)); - } - throw new ArgumentException("Cannot extract index from argument"); - } - - private LinearExpr GetLinearExpr(X x) - { - if (typeof(X) == typeof(IntVar)) - { - return (IntVar)(Object)x; - } - if (typeof(X) == typeof(long) || typeof(X) == typeof(int) || typeof(X) == typeof(short)) - { - return new ConstantExpr(Convert.ToInt64(x)); - } - if (typeof(X) == typeof(LinearExpr)) - { - return (LinearExpr)(Object)x; - } - throw new ArgumentException("Cannot convert argument to LinearExpr"); - } - - private LinearExpressionProto GetLinearExpressionProto(LinearExpr expr, bool negate = false) - { - Dictionary dict = new Dictionary(); - long constant = LinearExpr.GetVarValueMap(expr, 1L, dict); - long mult = negate ? -1 : 1; - LinearExpressionProto linear = new LinearExpressionProto(); - foreach (KeyValuePair term in dict) - { - linear.Vars.Add(term.Key.Index); - linear.Coeffs.Add(term.Value * mult); - } - linear.Offset = constant * mult; - return linear; - } - - private CpModelProto model_; - private Dictionary constant_map_; + model_ = new CpModelProto(); + constant_map_ = new Dictionary(); } + // Getters. + + public CpModelProto Model + { + get { + return model_; + } + } + + int Negated(int index) + { + return -index - 1; + } + + // Integer variables and constraints. + + public IntVar NewIntVar(long lb, long ub, string name) + { + return new IntVar(model_, new Domain(lb, ub), name); + } + + public IntVar NewIntVarFromDomain(Domain domain, string name) + { + return new IntVar(model_, domain, name); + } + // Constants (named or not). + + // TODO: Cache constant. + public IntVar NewConstant(long value) + { + return new IntVar(model_, new Domain(value), String.Format("{0}", value)); + } + + public IntVar NewConstant(long value, string name) + { + return new IntVar(model_, new Domain(value), name); + } + + public IntVar NewBoolVar(string name) + { + return new IntVar(model_, new Domain(0, 1), name); + } + + public Constraint AddLinearConstraint(LinearExpr linear_expr, long lb, long ub) + { + return AddLinearExpressionInDomain(linear_expr, new Domain(lb, ub)); + } + + public Constraint AddLinearExpressionInDomain(LinearExpr linear_expr, Domain domain) + { + Dictionary dict = new Dictionary(); + long constant = LinearExpr.GetVarValueMap(linear_expr, 1L, dict); + Constraint ct = new Constraint(model_); + LinearConstraintProto linear = new LinearConstraintProto(); + foreach (KeyValuePair term in dict) + { + linear.Vars.Add(term.Key.Index); + linear.Coeffs.Add(term.Value); + } + foreach (long value in domain.FlattenedIntervals()) + { + if (value == Int64.MinValue || value == Int64.MaxValue) + { + linear.Domain.Add(value); + } + else + { + linear.Domain.Add(value - constant); + } + } + ct.Proto.Linear = linear; + return ct; + } + + public Constraint Add(BoundedLinearExpression lin) + { + switch (lin.CtType) + { + case BoundedLinearExpression.Type.BoundExpression: { + return AddLinearExpressionInDomain(lin.Left, new Domain(lin.Lb, lin.Ub)); + } + case BoundedLinearExpression.Type.VarEqVar: { + return AddLinearExpressionInDomain(lin.Left - lin.Right, new Domain(0)); + } + case BoundedLinearExpression.Type.VarDiffVar: { + return AddLinearExpressionInDomain( + lin.Left - lin.Right, Domain.FromFlatIntervals(new long[] { Int64.MinValue, -1, 1, Int64.MaxValue })); + } + case BoundedLinearExpression.Type.VarEqCst: { + return AddLinearExpressionInDomain(lin.Left, new Domain(lin.Lb, lin.Lb)); + } + case BoundedLinearExpression.Type.VarDiffCst: { + return AddLinearExpressionInDomain( + lin.Left, + Domain.FromFlatIntervals(new long[] { Int64.MinValue, lin.Lb - 1, lin.Lb + 1, Int64.MaxValue })); + } + } + return null; + } + + public Constraint AddAllDifferent(IEnumerable vars) + { + Constraint ct = new Constraint(model_); + AllDifferentConstraintProto alldiff = new AllDifferentConstraintProto(); + foreach (IntVar var in vars) + { + alldiff.Exprs.Add(GetLinearExpressionProto(var)); + } + ct.Proto.AllDiff = alldiff; + return ct; + } + + public Constraint AddAllDifferent(IEnumerable exprs) + { + Constraint ct = new Constraint(model_); + AllDifferentConstraintProto alldiff = new AllDifferentConstraintProto(); + foreach (LinearExpr expr in exprs) + { + alldiff.Exprs.Add(GetLinearExpressionProto(expr)); + } + + ct.Proto.AllDiff = alldiff; + return ct; + } + + public Constraint AddElement(IntVar index, IEnumerable vars, IntVar target) + { + Constraint ct = new Constraint(model_); + ElementConstraintProto element = new ElementConstraintProto(); + element.Index = index.Index; + foreach (IntVar var in vars) + { + element.Vars.Add(var.Index); + } + element.Target = target.Index; + ct.Proto.Element = element; + return ct; + } + + public Constraint AddElement(IntVar index, IEnumerable values, IntVar target) + { + Constraint ct = new Constraint(model_); + ElementConstraintProto element = new ElementConstraintProto(); + element.Index = index.Index; + foreach (long value in values) + { + element.Vars.Add(ConvertConstant(value)); + } + element.Target = target.Index; + ct.Proto.Element = element; + return ct; + } + + public Constraint AddElement(IntVar index, IEnumerable values, IntVar target) + { + Constraint ct = new Constraint(model_); + ElementConstraintProto element = new ElementConstraintProto(); + element.Index = index.Index; + foreach (int value in values) + { + element.Vars.Add(ConvertConstant(value)); + } + element.Target = target.Index; + ct.Proto.Element = element; + return ct; + } + + public Constraint AddCircuit(IEnumerable> arcs) + { + Constraint ct = new Constraint(model_); + CircuitConstraintProto circuit = new CircuitConstraintProto(); + foreach (var arc in arcs) + { + circuit.Tails.Add(arc.Item1); + circuit.Heads.Add(arc.Item2); + circuit.Literals.Add(arc.Item3.GetIndex()); + } + ct.Proto.Circuit = circuit; + return ct; + } + + public Constraint AddAllowedAssignments(IEnumerable vars, long[,] tuples) + { + Constraint ct = new Constraint(model_); + TableConstraintProto table = new TableConstraintProto(); + foreach (IntVar var in vars) + { + table.Vars.Add(var.Index); + } + for (int i = 0; i < tuples.GetLength(0); ++i) + { + for (int j = 0; j < tuples.GetLength(1); ++j) + { + table.Values.Add(tuples[i, j]); + } + } + ct.Proto.Table = table; + return ct; + } + + public Constraint AddForbiddenAssignments(IEnumerable vars, long[,] tuples) + { + Constraint ct = AddAllowedAssignments(vars, tuples); + ct.Proto.Table.Negated = true; + return ct; + } + + public Constraint AddAutomaton(IEnumerable vars, long starting_state, long[,] transitions, + IEnumerable final_states) + { + Constraint ct = new Constraint(model_); + AutomatonConstraintProto aut = new AutomatonConstraintProto(); + foreach (IntVar var in vars) + { + aut.Vars.Add(var.Index); + } + aut.StartingState = starting_state; + foreach (long f in final_states) + { + aut.FinalStates.Add(f); + } + for (int i = 0; i < transitions.GetLength(0); ++i) + { + aut.TransitionTail.Add(transitions[i, 0]); + aut.TransitionLabel.Add(transitions[i, 1]); + aut.TransitionHead.Add(transitions[i, 2]); + } + + ct.Proto.Automaton = aut; + return ct; + } + + public Constraint AddAutomaton(IEnumerable vars, long starting_state, + IEnumerable> transitions, IEnumerable final_states) + { + Constraint ct = new Constraint(model_); + AutomatonConstraintProto aut = new AutomatonConstraintProto(); + foreach (IntVar var in vars) + { + aut.Vars.Add(var.Index); + } + aut.StartingState = starting_state; + foreach (long f in final_states) + { + aut.FinalStates.Add(f); + } + foreach (Tuple transition in transitions) + { + aut.TransitionHead.Add(transition.Item1); + aut.TransitionLabel.Add(transition.Item2); + aut.TransitionTail.Add(transition.Item3); + } + + ct.Proto.Automaton = aut; + return ct; + } + + public Constraint AddInverse(IEnumerable direct, IEnumerable reverse) + { + Constraint ct = new Constraint(model_); + InverseConstraintProto inverse = new InverseConstraintProto(); + foreach (IntVar var in direct) + { + inverse.FDirect.Add(var.Index); + } + foreach (IntVar var in reverse) + { + inverse.FInverse.Add(var.Index); + } + ct.Proto.Inverse = inverse; + return ct; + } + + public Constraint AddReservoirConstraint(IEnumerable times, IEnumerable levelChanges, long minLevel, + long maxLevel) + { + Constraint ct = new Constraint(model_); + ReservoirConstraintProto res = new ReservoirConstraintProto(); + foreach (IntVar time in times) + { + res.TimeExprs.Add(GetLinearExpressionProto(time)); + } + foreach (I d in levelChanges) + { + res.LevelChanges.Add(Convert.ToInt64(d)); + } + + res.MinLevel = minLevel; + res.MaxLevel = maxLevel; + ct.Proto.Reservoir = res; + + return ct; + } + + public Constraint AddReservoirConstraintWithActive(IEnumerable times, IEnumerable levelChanges, + IEnumerable actives, long minLevel, long maxLevel) + { + Constraint ct = new Constraint(model_); + ReservoirConstraintProto res = new ReservoirConstraintProto(); + foreach (IntVar time in times) + { + res.TimeExprs.Add(GetLinearExpressionProto(time)); + } + foreach (I d in levelChanges) + { + res.LevelChanges.Add(Convert.ToInt64(d)); + } + foreach (IntVar var in actives) + { + res.ActiveLiterals.Add(var.Index); + } + res.MinLevel = minLevel; + res.MaxLevel = maxLevel; + ct.Proto.Reservoir = res; + + return ct; + } + + public Constraint AddReservoirConstraint(IEnumerable times, IEnumerable levelChanges, + long minLevel, long maxLevel) + { + Constraint ct = new Constraint(model_); + ReservoirConstraintProto res = new ReservoirConstraintProto(); + foreach (LinearExpr time in times) + { + res.TimeExprs.Add(GetLinearExpressionProto(time)); + } + foreach (I d in levelChanges) + { + res.LevelChanges.Add(Convert.ToInt64(d)); + } + + res.MinLevel = minLevel; + res.MaxLevel = maxLevel; + ct.Proto.Reservoir = res; + + return ct; + } + + public Constraint AddReservoirConstraintWithActive(IEnumerable times, IEnumerable levelChanges, + IEnumerable actives, long minLevel, long maxLevel) + { + Constraint ct = new Constraint(model_); + ReservoirConstraintProto res = new ReservoirConstraintProto(); + foreach (LinearExpr time in times) + { + res.TimeExprs.Add(GetLinearExpressionProto(time)); + } + foreach (I d in levelChanges) + { + res.LevelChanges.Add(Convert.ToInt64(d)); + } + foreach (IntVar var in actives) + { + res.ActiveLiterals.Add(var.Index); + } + res.MinLevel = minLevel; + res.MaxLevel = maxLevel; + ct.Proto.Reservoir = res; + + return ct; + } + + public void AddMapDomain(IntVar var, IEnumerable bool_vars, long offset = 0) + { + int i = 0; + foreach (IntVar bool_var in bool_vars) + { + int b_index = bool_var.Index; + int var_index = var.Index; + + ConstraintProto ct1 = new ConstraintProto(); + LinearConstraintProto lin1 = new LinearConstraintProto(); + lin1.Vars.Add(var_index); + lin1.Coeffs.Add(1L); + lin1.Domain.Add(offset + i); + lin1.Domain.Add(offset + i); + ct1.Linear = lin1; + ct1.EnforcementLiteral.Add(b_index); + model_.Constraints.Add(ct1); + + ConstraintProto ct2 = new ConstraintProto(); + LinearConstraintProto lin2 = new LinearConstraintProto(); + lin2.Vars.Add(var_index); + lin2.Coeffs.Add(1L); + lin2.Domain.Add(Int64.MinValue); + lin2.Domain.Add(offset + i - 1); + lin2.Domain.Add(offset + i + 1); + lin2.Domain.Add(Int64.MaxValue); + ct2.Linear = lin2; + ct2.EnforcementLiteral.Add(-b_index - 1); + model_.Constraints.Add(ct2); + + i++; + } + } + + public Constraint AddImplication(ILiteral a, ILiteral b) + { + Constraint ct = new Constraint(model_); + BoolArgumentProto or = new BoolArgumentProto(); + or.Literals.Add(a.Not().GetIndex()); + or.Literals.Add(b.GetIndex()); + ct.Proto.BoolOr = or; + return ct; + } + + public Constraint AddBoolOr(IEnumerable literals) + { + Constraint ct = new Constraint(model_); + BoolArgumentProto bool_argument = new BoolArgumentProto(); + foreach (ILiteral lit in literals) + { + bool_argument.Literals.Add(lit.GetIndex()); + } + ct.Proto.BoolOr = bool_argument; + return ct; + } + + public Constraint AddBoolAnd(IEnumerable literals) + { + Constraint ct = new Constraint(model_); + BoolArgumentProto bool_argument = new BoolArgumentProto(); + foreach (ILiteral lit in literals) + { + bool_argument.Literals.Add(lit.GetIndex()); + } + ct.Proto.BoolAnd = bool_argument; + return ct; + } + + public Constraint AddBoolXor(IEnumerable literals) + { + Constraint ct = new Constraint(model_); + BoolArgumentProto bool_argument = new BoolArgumentProto(); + foreach (ILiteral lit in literals) + { + bool_argument.Literals.Add(lit.GetIndex()); + } + ct.Proto.BoolXor = bool_argument; + return ct; + } + + public Constraint AddMinEquality(LinearExpr target, IEnumerable vars) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + foreach (IntVar var in vars) + { + args.Exprs.Add(GetLinearExpressionProto(var, /*negate=*/true)); + } + args.Target = GetLinearExpressionProto(target, /*negate=*/true); + ct.Proto.LinMax = args; + return ct; + } + + public Constraint AddMinEquality(LinearExpr target, IEnumerable exprs) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + foreach (LinearExpr expr in exprs) + { + args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/true)); + } + args.Target = GetLinearExpressionProto(target, /*negate=*/true); + ct.Proto.LinMax = args; + return ct; + } + + public Constraint AddMaxEquality(IntVar target, IEnumerable vars) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + foreach (IntVar var in vars) + { + args.Exprs.Add(GetLinearExpressionProto(var)); + } + args.Target = GetLinearExpressionProto(target); + ct.Proto.LinMax = args; + return ct; + } + + public Constraint AddMaxEquality(LinearExpr target, IEnumerable exprs) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + foreach (LinearExpr expr in exprs) + { + args.Exprs.Add(GetLinearExpressionProto(expr)); + } + args.Target = GetLinearExpressionProto(target); + ct.Proto.LinMax = args; + return ct; + } + + public Constraint AddDivisionEquality(T target, N num, D denom) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(num))); + args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(denom))); + args.Target = GetLinearExpressionProto(GetLinearExpr(target)); + ct.Proto.IntDiv = args; + return ct; + } + + public Constraint AddAbsEquality(LinearExpr target, LinearExpr expr) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + args.Exprs.Add(GetLinearExpressionProto(expr)); + args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/true)); + args.Target = GetLinearExpressionProto(target); + ct.Proto.LinMax = args; + return ct; + } + + public Constraint AddModuloEquality(T target, V v, M m) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(v))); + args.Exprs.Add(GetLinearExpressionProto(GetLinearExpr(m))); + args.Target = GetLinearExpressionProto(GetLinearExpr(target)); + ct.Proto.IntMod = args; + return ct; + } + + public Constraint AddMultiplicationEquality(LinearExpr target, IEnumerable vars) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + args.Target = GetLinearExpressionProto(target); + foreach (IntVar var in vars) + { + args.Exprs.Add(GetLinearExpressionProto(var)); + } + ct.Proto.IntProd = args; + return ct; + } + + public Constraint AddMultiplicationEquality(LinearExpr target, IEnumerable exprs) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + args.Target = GetLinearExpressionProto(target); + foreach (LinearExpr expr in exprs) + { + args.Exprs.Add(GetLinearExpressionProto(expr)); + } + ct.Proto.IntProd = args; + return ct; + } + + public Constraint AddProdEquality(IntVar target, IEnumerable vars) + { + return AddMultiplicationEquality(target, vars); + } + + // Scheduling support + + public IntervalVar NewIntervalVar(S start, D duration, E end, string name) + { + LinearExpr startExpr = GetLinearExpr(start); + LinearExpr durationExpr = GetLinearExpr(duration); + LinearExpr endExpr = GetLinearExpr(end); + Add(startExpr + durationExpr == endExpr); + + LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); + LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); + LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); + return new IntervalVar(model_, startProto, durationProto, endProto, name); + } + + public IntervalVar NewFixedSizeIntervalVar(S start, long duration, string name) + { + LinearExpr startExpr = GetLinearExpr(start); + LinearExpr durationExpr = GetLinearExpr(duration); + LinearExpr endExpr = LinearExpr.Sum(new LinearExpr[] { startExpr, durationExpr }); + + LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); + LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); + LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); + return new IntervalVar(model_, startProto, durationProto, endProto, name); + } + + public IntervalVar NewOptionalIntervalVar(S start, D duration, E end, ILiteral is_present, string name) + { + LinearExpr startExpr = GetLinearExpr(start); + LinearExpr durationExpr = GetLinearExpr(duration); + LinearExpr endExpr = GetLinearExpr(end); + Add(startExpr + durationExpr == endExpr).OnlyEnforceIf(is_present); + + LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); + LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); + LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); + return new IntervalVar(model_, startProto, durationProto, endProto, is_present.GetIndex(), name); + } + + public IntervalVar NewOptionalFixedSizeIntervalVar(S start, long duration, ILiteral is_present, string name) + { + LinearExpr startExpr = GetLinearExpr(start); + LinearExpr durationExpr = GetLinearExpr(duration); + LinearExpr endExpr = LinearExpr.Sum(new LinearExpr[] { startExpr, durationExpr }); + + LinearExpressionProto startProto = GetLinearExpressionProto(startExpr); + LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr); + LinearExpressionProto endProto = GetLinearExpressionProto(endExpr); + return new IntervalVar(model_, startProto, durationProto, endProto, is_present.GetIndex(), name); + } + + public Constraint AddNoOverlap(IEnumerable intervals) + { + Constraint ct = new Constraint(model_); + NoOverlapConstraintProto args = new NoOverlapConstraintProto(); + foreach (IntervalVar var in intervals) + { + args.Intervals.Add(var.GetIndex()); + } + ct.Proto.NoOverlap = args; + return ct; + } + + public Constraint AddNoOverlap2D(IEnumerable x_intervals, IEnumerable y_intervals) + { + Constraint ct = new Constraint(model_); + NoOverlap2DConstraintProto args = new NoOverlap2DConstraintProto(); + foreach (IntervalVar var in x_intervals) + { + args.XIntervals.Add(var.GetIndex()); + } + foreach (IntervalVar var in y_intervals) + { + args.YIntervals.Add(var.GetIndex()); + } + ct.Proto.NoOverlap2D = args; + return ct; + } + + public Constraint AddCumulative(IEnumerable intervals, IEnumerable demands, C capacity) + { + Constraint ct = new Constraint(model_); + CumulativeConstraintProto cumul = new CumulativeConstraintProto(); + foreach (IntervalVar var in intervals) + { + cumul.Intervals.Add(var.GetIndex()); + } + foreach (D demand in demands) + { + LinearExpr demandExpr = GetLinearExpr(demand); + cumul.Demands.Add(GetLinearExpressionProto(demandExpr)); + } + LinearExpr capacityExpr = GetLinearExpr(capacity); + cumul.Capacity = GetLinearExpressionProto(capacityExpr); + ct.Proto.Cumulative = cumul; + return ct; + } + + // Objective. + public void Minimize(LinearExpr obj) + { + SetObjective(obj, true); + } + + public void Maximize(LinearExpr obj) + { + SetObjective(obj, false); + } + + public void Minimize() + { + SetObjective(null, true); + } + + public void Maximize() + { + SetObjective(null, false); + } + + public void AddVarToObjective(IntVar var) + { + if ((Object)var == null) + return; + model_.Objective.Vars.Add(var.Index); + model_.Objective.Coeffs.Add(model_.Objective.ScalingFactor > 0 ? 1 : -1); + } + + public void AddTermToObjective(IntVar var, long coeff) + { + if (coeff == 0 || (Object)var == null) + return; + model_.Objective.Vars.Add(var.Index); + model_.Objective.Coeffs.Add(model_.Objective.ScalingFactor > 0 ? coeff : -coeff); + } + + bool HasObjective() + { + return model_.Objective == null; + } + + // Search Decision. + + public void AddDecisionStrategy(IEnumerable vars, + DecisionStrategyProto.Types.VariableSelectionStrategy var_str, + DecisionStrategyProto.Types.DomainReductionStrategy dom_str) + { + DecisionStrategyProto ds = new DecisionStrategyProto(); + foreach (IntVar var in vars) + { + ds.Variables.Add(var.Index); + } + ds.VariableSelectionStrategy = var_str; + ds.DomainReductionStrategy = dom_str; + model_.SearchStrategy.Add(ds); + } + + public void AddHint(IntVar var, long value) + { + if (model_.SolutionHint == null) + { + model_.SolutionHint = new PartialVariableAssignment(); + } + model_.SolutionHint.Vars.Add(var.GetIndex()); + model_.SolutionHint.Values.Add(value); + } + + public void ClearHints() + { + model_.SolutionHint = null; + } + + public void AddAssumption(ILiteral lit) + { + model_.Assumptions.Add(lit.GetIndex()); + } + + public void AddAssumptions(IEnumerable literals) + { + foreach (ILiteral lit in literals) + { + AddAssumption(lit); + } + } + + public void ClearAssumptions() + { + model_.Assumptions.Clear(); + } + + // Internal methods. + + void SetObjective(LinearExpr obj, bool minimize) + { + CpObjectiveProto objective = new CpObjectiveProto(); + if (obj == null) + { + objective.Offset = 0L; + objective.ScalingFactor = minimize ? 1L : -1; + } + else if (obj is IntVar) + { + objective.Offset = 0L; + objective.Vars.Add(obj.Index); + if (minimize) + { + objective.Coeffs.Add(1L); + objective.ScalingFactor = 1L; + } + else + { + objective.Coeffs.Add(-1L); + objective.ScalingFactor = -1L; + } + } + else + { + Dictionary dict = new Dictionary(); + long constant = LinearExpr.GetVarValueMap(obj, 1L, dict); + if (minimize) + { + objective.ScalingFactor = 1L; + objective.Offset = constant; + } + else + { + objective.ScalingFactor = -1L; + objective.Offset = -constant; + } + foreach (KeyValuePair it in dict) + { + objective.Vars.Add(it.Key.Index); + objective.Coeffs.Add(minimize ? it.Value : -it.Value); + } + } + model_.Objective = objective; + } + public String ModelStats() + { + return CpSatHelper.ModelStats(model_); + } + + public Boolean ExportToFile(String filename) + { + return CpSatHelper.WriteModelToFile(model_, filename); + } + + public String Validate() + { + return CpSatHelper.ValidateModel(model_); + } + + private int ConvertConstant(long value) + { + if (constant_map_.ContainsKey(value)) + { + return constant_map_[value]; + } + else + { + int index = model_.Variables.Count; + IntegerVariableProto var = new IntegerVariableProto(); + var.Domain.Add(value); + var.Domain.Add(value); + constant_map_.Add(value, index); + model_.Variables.Add(var); + return index; + } + } + + private int GetOrCreateIndex(X x) + { + if (typeof(X) == typeof(IntVar)) + { + IntVar vx = (IntVar)(Object)x; + return vx.Index; + } + if (typeof(X) == typeof(long) || typeof(X) == typeof(int)) + { + return ConvertConstant(Convert.ToInt64(x)); + } + throw new ArgumentException("Cannot extract index from argument"); + } + + private LinearExpr GetLinearExpr(X x) + { + if (typeof(X) == typeof(IntVar)) + { + return (IntVar)(Object)x; + } + if (typeof(X) == typeof(long) || typeof(X) == typeof(int) || typeof(X) == typeof(short)) + { + return new ConstantExpr(Convert.ToInt64(x)); + } + if (typeof(X) == typeof(LinearExpr)) + { + return (LinearExpr)(Object)x; + } + throw new ArgumentException("Cannot convert argument to LinearExpr"); + } + + private LinearExpressionProto GetLinearExpressionProto(LinearExpr expr, bool negate = false) + { + Dictionary dict = new Dictionary(); + long constant = LinearExpr.GetVarValueMap(expr, 1L, dict); + long mult = negate ? -1 : 1; + LinearExpressionProto linear = new LinearExpressionProto(); + foreach (KeyValuePair term in dict) + { + linear.Vars.Add(term.Key.Index); + linear.Coeffs.Add(term.Value * mult); + } + linear.Offset = constant * mult; + return linear; + } + + private CpModelProto model_; + private Dictionary constant_map_; +} + } // namespace Google.OrTools.Sat diff --git a/ortools/sat/csharp/CpSolver.cs b/ortools/sat/csharp/CpSolver.cs index 4a5ad66788..654055e01d 100644 --- a/ortools/sat/csharp/CpSolver.cs +++ b/ortools/sat/csharp/CpSolver.cs @@ -17,227 +17,232 @@ using System.Runtime.CompilerServices; namespace Google.OrTools.Sat { - public class CpSolver +public class CpSolver +{ + public CpSolverStatus Solve(CpModel model, SolutionCallback cb = null) { - public CpSolverStatus Solve(CpModel model, SolutionCallback cb = null) + // Setup search. + CreateSolveWrapper(); + if (string_parameters_ != null) { - // Setup search. - CreateSolveWrapper(); - if (string_parameters_ != null) + solve_wrapper_.SetStringParameters(string_parameters_); + } + if (log_callback_ != null) + { + solve_wrapper_.AddLogCallbackFromClass(log_callback_); + } + if (cb != null) + { + solve_wrapper_.AddSolutionCallback(cb); + } + + response_ = solve_wrapper_.Solve(model.Model); + + // Cleanup search. + if (cb != null) + { + solve_wrapper_.ClearSolutionCallback(cb); + } + ReleaseSolveWrapper(); + + return response_.Status; + } + + [ObsoleteAttribute("This method is obsolete. Call Solve instead.", false)] + public CpSolverStatus SolveWithSolutionCallback(CpModel model, SolutionCallback cb) + { + return Solve(model, cb); + } + + public CpSolverStatus SearchAllSolutions(CpModel model, SolutionCallback cb) + { + string old_parameters = string_parameters_; + string_parameters_ += " enumerate_all_solutions:true"; + Solve(model, cb); + string_parameters_ = old_parameters; + return response_.Status; + } + + [MethodImpl(MethodImplOptions.Synchronized)] + public void StopSearch() + { + if (solve_wrapper_ != null) + { + solve_wrapper_.StopSearch(); + } + } + + [MethodImpl(MethodImplOptions.Synchronized)] + private void CreateSolveWrapper() + { + solve_wrapper_ = new SolveWrapper(); + } + + [MethodImpl(MethodImplOptions.Synchronized)] + private void ReleaseSolveWrapper() + { + solve_wrapper_ = null; + } + + public String ResponseStats() + { + return CpSatHelper.SolverResponseStats(response_); + } + + public double ObjectiveValue + { + get { + return response_.ObjectiveValue; + } + } + + public double BestObjectiveBound + { + get { + return response_.BestObjectiveBound; + } + } + + public string StringParameters + { + get { + return string_parameters_; + } + set { + string_parameters_ = value; + } + } + + public void SetLogCallback(StringToVoidDelegate del) + { + log_callback_ = new LogCallbackDelegate(del); + } + + public CpSolverResponse Response + { + get { + return response_; + } + } + + public long Value(LinearExpr e) + { + List exprs = new List(); + List coeffs = new List(); + exprs.Add(e); + coeffs.Add(1L); + long constant = 0; + + while (exprs.Count > 0) + { + LinearExpr expr = exprs[0]; + exprs.RemoveAt(0); + long coeff = coeffs[0]; + coeffs.RemoveAt(0); + if (coeff == 0) + continue; + + if (expr is ProductCst) { - solve_wrapper_.SetStringParameters(string_parameters_); - } - if (log_callback_ != null) - { - solve_wrapper_.AddLogCallbackFromClass(log_callback_); - } - if (cb != null) - { - solve_wrapper_.AddSolutionCallback(cb); - } - - response_ = solve_wrapper_.Solve(model.Model); - - // Cleanup search. - if (cb != null) - { - solve_wrapper_.ClearSolutionCallback(cb); - } - ReleaseSolveWrapper(); - - return response_.Status; - } - - [ObsoleteAttribute("This method is obsolete. Call Solve instead.", false)] - public CpSolverStatus SolveWithSolutionCallback(CpModel model, SolutionCallback cb) - { - return Solve(model, cb); - } - - public CpSolverStatus SearchAllSolutions(CpModel model, SolutionCallback cb) - { - string old_parameters = string_parameters_; - string_parameters_ += " enumerate_all_solutions:true"; - Solve(model, cb); - string_parameters_ = old_parameters; - return response_.Status; - } - - [MethodImpl(MethodImplOptions.Synchronized)] - public void StopSearch() - { - if (solve_wrapper_ != null) - { - solve_wrapper_.StopSearch(); - } - } - - [MethodImpl(MethodImplOptions.Synchronized)] - private void CreateSolveWrapper() - { - solve_wrapper_ = new SolveWrapper(); - } - - [MethodImpl(MethodImplOptions.Synchronized)] - private void ReleaseSolveWrapper() - { - solve_wrapper_ = null; - } - - public String ResponseStats() - { - return CpSatHelper.SolverResponseStats(response_); - } - - public double ObjectiveValue - { - get { - return response_.ObjectiveValue; - } - } - - public double BestObjectiveBound - { - get { - return response_.BestObjectiveBound; - } - } - - public string StringParameters - { - get { - return string_parameters_; - } - set { - string_parameters_ = value; - } - } - - public void SetLogCallback(StringToVoidDelegate del) - { - log_callback_ = new LogCallbackDelegate(del); - } - - public CpSolverResponse Response - { - get { - return response_; - } - } - - public long Value(LinearExpr e) - { - List exprs = new List(); - List coeffs = new List(); - exprs.Add(e); - coeffs.Add(1L); - long constant = 0; - - while (exprs.Count > 0) - { - LinearExpr expr = exprs[0]; - exprs.RemoveAt(0); - long coeff = coeffs[0]; - coeffs.RemoveAt(0); - if (coeff == 0) - continue; - - if (expr is ProductCst) + ProductCst p = (ProductCst)expr; + if (p.Coeff != 0) { - ProductCst p = (ProductCst)expr; - if (p.Coeff != 0) - { - exprs.Add(p.Expr); - coeffs.Add(p.Coeff * coeff); - } - } - else if (expr is SumArray) - { - SumArray a = (SumArray)expr; - constant += coeff * a.Offset; - foreach (LinearExpr sub in a.Expressions) - { - exprs.Add(sub); - coeffs.Add(coeff); - } - } - else if (expr is IntVar) - { - int index = expr.Index; - long value = index >= 0 ? response_.Solution[index] : -response_.Solution[-index - 1]; - constant += coeff * value; - } - else if (expr is NotBooleanVariable) - { - throw new ArgumentException("Cannot evaluate a literal in an integer expression."); - } - else - { - throw new ArgumentException("Cannot evaluate '" + expr.ToString() + "' in an integer expression"); + exprs.Add(p.Expr); + coeffs.Add(p.Coeff * coeff); } } - return constant; - } - - public Boolean BooleanValue(ILiteral literal) - { - if (literal is IntVar || literal is NotBooleanVariable) + else if (expr is SumArray) { - int index = literal.GetIndex(); - if (index >= 0) + SumArray a = (SumArray)expr; + constant += coeff * a.Offset; + foreach (LinearExpr sub in a.Expressions) { - return response_.Solution[index] != 0; - } - else - { - return response_.Solution[-index - 1] == 0; + exprs.Add(sub); + coeffs.Add(coeff); } } + else if (expr is IntVar) + { + int index = expr.Index; + long value = index >= 0 ? response_.Solution[index] : -response_.Solution[-index - 1]; + constant += coeff * value; + } + else if (expr is NotBooleanVariable) + { + throw new ArgumentException("Cannot evaluate a literal in an integer expression."); + } else { - throw new ArgumentException("Cannot evaluate '" + literal.ToString() + "' as a boolean literal"); + throw new ArgumentException("Cannot evaluate '" + expr.ToString() + "' in an integer expression"); } } - - public long NumBranches() - { - return response_.NumBranches; - } - - public long NumConflicts() - { - return response_.NumConflicts; - } - - public double WallTime() - { - return response_.WallTime; - } - - public IList SufficientAssumptionsForInfeasibility() - { - return response_.SufficientAssumptionsForInfeasibility; - } - - private CpSolverResponse response_; - private LogCallback log_callback_; - private string string_parameters_; - private SolveWrapper solve_wrapper_; + return constant; } - class LogCallbackDelegate : LogCallback + public Boolean BooleanValue(ILiteral literal) { - public LogCallbackDelegate(StringToVoidDelegate del) + if (literal is IntVar || literal is NotBooleanVariable) { - this.delegate_ = del; + int index = literal.GetIndex(); + if (index >= 0) + { + return response_.Solution[index] != 0; + } + else + { + return response_.Solution[-index - 1] == 0; + } } - - public override void NewMessage(string message) + else { - delegate_(message); + throw new ArgumentException("Cannot evaluate '" + literal.ToString() + "' as a boolean literal"); } - - private StringToVoidDelegate delegate_; } + public long NumBranches() + { + return response_.NumBranches; + } + + public long NumConflicts() + { + return response_.NumConflicts; + } + + public double WallTime() + { + return response_.WallTime; + } + + public IList SufficientAssumptionsForInfeasibility() + { + return response_.SufficientAssumptionsForInfeasibility; + } + + public String SolutionInfo() + { + return response_.SolutionInfo; + } + + private CpSolverResponse response_; + private LogCallback log_callback_; + private string string_parameters_; + private SolveWrapper solve_wrapper_; +} + +class LogCallbackDelegate : LogCallback +{ + public LogCallbackDelegate(StringToVoidDelegate del) + { + this.delegate_ = del; + } + + public override void NewMessage(string message) + { + delegate_(message); + } + + private StringToVoidDelegate delegate_; +} + } // namespace Google.OrTools.Sat diff --git a/ortools/sat/csharp/IntegerExpressions.cs b/ortools/sat/csharp/IntegerExpressions.cs index f48f73a576..d31c16cf15 100644 --- a/ortools/sat/csharp/IntegerExpressions.cs +++ b/ortools/sat/csharp/IntegerExpressions.cs @@ -13,868 +13,867 @@ namespace Google.OrTools.Sat { - using System; - using System.Collections.Generic; - using Google.OrTools.Util; +using System; +using System.Collections.Generic; +using Google.OrTools.Util; - // Helpers. +// Helpers. - // IntVar[] helper class. - public static class IntVarArrayHelper +// IntVar[] helper class. +public static class IntVarArrayHelper +{ + [Obsolete("This Sum method is deprecated, please use LinearExpr.Sum() instead.")] + public static LinearExpr Sum(this IntVar[] vars) { - [Obsolete("This Sum method is deprecated, please use LinearExpr.Sum() instead.")] - public static LinearExpr Sum(this IntVar[] vars) - { - return LinearExpr.Sum(vars); - } - [Obsolete("This ScalProd method is deprecated, please use LinearExpr.ScalProd() instead.")] - public static LinearExpr ScalProd(this IntVar[] vars, int[] coeffs) - { - return LinearExpr.ScalProd(vars, coeffs); - } - [Obsolete("This ScalProd method is deprecated, please use LinearExpr.ScalProd() instead.")] - public static LinearExpr ScalProd(this IntVar[] vars, long[] coeffs) - { - return LinearExpr.ScalProd(vars, coeffs); - } + return LinearExpr.Sum(vars); + } + [Obsolete("This ScalProd method is deprecated, please use LinearExpr.ScalProd() instead.")] + public static LinearExpr ScalProd(this IntVar[] vars, int[] coeffs) + { + return LinearExpr.ScalProd(vars, coeffs); + } + [Obsolete("This ScalProd method is deprecated, please use LinearExpr.ScalProd() instead.")] + public static LinearExpr ScalProd(this IntVar[] vars, long[] coeffs) + { + return LinearExpr.ScalProd(vars, coeffs); + } +} + +public interface ILiteral +{ + ILiteral Not(); + int GetIndex(); +} + +// Holds a linear expression. +public class LinearExpr +{ + public static LinearExpr Sum(IEnumerable vars) + { + return new SumArray(vars); } - public interface ILiteral + public static LinearExpr Sum(IEnumerable exprs) { - ILiteral Not(); - int GetIndex(); + return new SumArray(exprs); } - // Holds a linear expression. - public class LinearExpr + public static LinearExpr ScalProd(IEnumerable vars, IEnumerable coeffs) { - public static LinearExpr Sum(IEnumerable vars) - { - return new SumArray(vars); - } + return new SumArray(vars, coeffs); + } - public static LinearExpr Sum(IEnumerable exprs) - { - return new SumArray(exprs); - } + public static LinearExpr ScalProd(IEnumerable vars, IEnumerable coeffs) + { + return new SumArray(vars, coeffs); + } - public static LinearExpr ScalProd(IEnumerable vars, IEnumerable coeffs) - { - return new SumArray(vars, coeffs); - } + public static LinearExpr Term(IntVar var, long coeff) + { + return Prod(var, coeff); + } - public static LinearExpr ScalProd(IEnumerable vars, IEnumerable coeffs) - { - return new SumArray(vars, coeffs); - } - - public static LinearExpr Term(IntVar var, long coeff) + public static LinearExpr Affine(IntVar var, long coeff, long offset) + { + if (offset == 0) { return Prod(var, coeff); } - - public static LinearExpr Affine(IntVar var, long coeff, long offset) + else { - if (offset == 0) - { - return Prod(var, coeff); - } - else - { - return new SumArray(Prod(var, coeff), offset); - } + return new SumArray(Prod(var, coeff), offset); } + } - public static LinearExpr Constant(long value) + public static LinearExpr Constant(long value) + { + return new ConstantExpr(value); + } + + public int Index + { + get { + return GetIndex(); + } + } + + public virtual int GetIndex() + { + throw new NotImplementedException(); + } + + public virtual string ShortString() + { + return ToString(); + } + + public static LinearExpr operator +(LinearExpr a, LinearExpr b) + { + return new SumArray(a, b); + } + + public static LinearExpr operator +(LinearExpr a, long v) + { + if (v == 0) { - return new ConstantExpr(value); + return a; } + return new SumArray(a, v); + } - public int Index + public static LinearExpr operator +(long v, LinearExpr a) + { + if (v == 0) { - get { - return GetIndex(); - } + return a; } + return new SumArray(a, v); + } - public virtual int GetIndex() + public static LinearExpr operator -(LinearExpr a, LinearExpr b) + { + return new SumArray(a, Prod(b, -1)); + } + + public static LinearExpr operator -(LinearExpr a, long v) + { + if (v == 0) { - throw new NotImplementedException(); + return a; } + return new SumArray(a, -v); + } - public virtual string ShortString() - { - return ToString(); - } - - public static LinearExpr operator +(LinearExpr a, LinearExpr b) - { - return new SumArray(a, b); - } - - public static LinearExpr operator +(LinearExpr a, long v) - { - if (v == 0) - { - return a; - } - return new SumArray(a, v); - } - - public static LinearExpr operator +(long v, LinearExpr a) - { - if (v == 0) - { - return a; - } - return new SumArray(a, v); - } - - public static LinearExpr operator -(LinearExpr a, LinearExpr b) - { - return new SumArray(a, Prod(b, -1)); - } - - public static LinearExpr operator -(LinearExpr a, long v) - { - if (v == 0) - { - return a; - } - return new SumArray(a, -v); - } - - public static LinearExpr operator -(long v, LinearExpr a) - { - if (v == 0) - { - return Prod(a, -1); - } - return new SumArray(Prod(a, -1), v); - } - - public static LinearExpr operator *(LinearExpr a, long v) - { - return Prod(a, v); - } - - public static LinearExpr operator *(long v, LinearExpr a) - { - return Prod(a, v); - } - - public static LinearExpr operator -(LinearExpr a) + public static LinearExpr operator -(long v, LinearExpr a) + { + if (v == 0) { return Prod(a, -1); } + return new SumArray(Prod(a, -1), v); + } - public static BoundedLinearExpression operator ==(LinearExpr a, LinearExpr b) + public static LinearExpr operator *(LinearExpr a, long v) + { + return Prod(a, v); + } + + public static LinearExpr operator *(long v, LinearExpr a) + { + return Prod(a, v); + } + + public static LinearExpr operator -(LinearExpr a) + { + return Prod(a, -1); + } + + public static BoundedLinearExpression operator ==(LinearExpr a, LinearExpr b) + { + return new BoundedLinearExpression(a, b, true); + } + + public static BoundedLinearExpression operator !=(LinearExpr a, LinearExpr b) + { + return new BoundedLinearExpression(a, b, false); + } + + public static BoundedLinearExpression operator ==(LinearExpr a, long v) + { + return new BoundedLinearExpression(a, v, true); + } + + public static BoundedLinearExpression operator !=(LinearExpr a, long v) + { + return new BoundedLinearExpression(a, v, false); + } + + public static BoundedLinearExpression operator >=(LinearExpr a, long v) + { + return new BoundedLinearExpression(v, a, Int64.MaxValue); + } + + public static BoundedLinearExpression operator >=(long v, LinearExpr a) + { + return a <= v; + } + + public static BoundedLinearExpression operator>(LinearExpr a, long v) + { + return new BoundedLinearExpression(v + 1, a, Int64.MaxValue); + } + + public static BoundedLinearExpression operator>(long v, LinearExpr a) + { + return a < v; + } + + public static BoundedLinearExpression operator <=(LinearExpr a, long v) + { + return new BoundedLinearExpression(Int64.MinValue, a, v); + } + + public static BoundedLinearExpression operator <=(long v, LinearExpr a) + { + return a >= v; + } + + public static BoundedLinearExpression operator<(LinearExpr a, long v) + { + return new BoundedLinearExpression(Int64.MinValue, a, v - 1); + } + + public static BoundedLinearExpression operator<(long v, LinearExpr a) + { + return a > v; + } + + public static BoundedLinearExpression operator >=(LinearExpr a, LinearExpr b) + { + return new BoundedLinearExpression(0, a - b, Int64.MaxValue); + } + + public static BoundedLinearExpression operator>(LinearExpr a, LinearExpr b) + { + return new BoundedLinearExpression(1, a - b, Int64.MaxValue); + } + + public static BoundedLinearExpression operator <=(LinearExpr a, LinearExpr b) + { + return new BoundedLinearExpression(Int64.MinValue, a - b, 0); + } + + public static BoundedLinearExpression operator<(LinearExpr a, LinearExpr b) + { + return new BoundedLinearExpression(Int64.MinValue, a - b, -1); + } + + public static LinearExpr Prod(LinearExpr e, long v) + { + if (v == 1) { - return new BoundedLinearExpression(a, b, true); + return e; } - - public static BoundedLinearExpression operator !=(LinearExpr a, LinearExpr b) + else if (e is ProductCst) { - return new BoundedLinearExpression(a, b, false); + ProductCst p = (ProductCst)e; + return new ProductCst(p.Expr, p.Coeff * v); } - - public static BoundedLinearExpression operator ==(LinearExpr a, long v) + else { - return new BoundedLinearExpression(a, v, true); + return new ProductCst(e, v); } + } - public static BoundedLinearExpression operator !=(LinearExpr a, long v) + public static long GetVarValueMap(LinearExpr e, long initial_coeff, Dictionary dict) + { + List exprs = new List(); + List coeffs = new List(); + if ((Object)e != null) { - return new BoundedLinearExpression(a, v, false); + exprs.Add(e); + coeffs.Add(initial_coeff); } + long constant = 0; - public static BoundedLinearExpression operator >=(LinearExpr a, long v) + while (exprs.Count > 0) { - return new BoundedLinearExpression(v, a, Int64.MaxValue); - } + LinearExpr expr = exprs[0]; + exprs.RemoveAt(0); + long coeff = coeffs[0]; + coeffs.RemoveAt(0); + if (coeff == 0 || (Object)expr == null) + continue; - public static BoundedLinearExpression operator >=(long v, LinearExpr a) - { - return a <= v; - } - - public static BoundedLinearExpression operator>(LinearExpr a, long v) - { - return new BoundedLinearExpression(v + 1, a, Int64.MaxValue); - } - - public static BoundedLinearExpression operator>(long v, LinearExpr a) - { - return a < v; - } - - public static BoundedLinearExpression operator <=(LinearExpr a, long v) - { - return new BoundedLinearExpression(Int64.MinValue, a, v); - } - - public static BoundedLinearExpression operator <=(long v, LinearExpr a) - { - return a >= v; - } - - public static BoundedLinearExpression operator<(LinearExpr a, long v) - { - return new BoundedLinearExpression(Int64.MinValue, a, v - 1); - } - - public static BoundedLinearExpression operator<(long v, LinearExpr a) - { - return a > v; - } - - public static BoundedLinearExpression operator >=(LinearExpr a, LinearExpr b) - { - return new BoundedLinearExpression(0, a - b, Int64.MaxValue); - } - - public static BoundedLinearExpression operator>(LinearExpr a, LinearExpr b) - { - return new BoundedLinearExpression(1, a - b, Int64.MaxValue); - } - - public static BoundedLinearExpression operator <=(LinearExpr a, LinearExpr b) - { - return new BoundedLinearExpression(Int64.MinValue, a - b, 0); - } - - public static BoundedLinearExpression operator<(LinearExpr a, LinearExpr b) - { - return new BoundedLinearExpression(Int64.MinValue, a - b, -1); - } - - public static LinearExpr Prod(LinearExpr e, long v) - { - if (v == 1) + if (expr is ProductCst) { - return e; - } - else if (e is ProductCst) - { - ProductCst p = (ProductCst)e; - return new ProductCst(p.Expr, p.Coeff * v); - } - else - { - return new ProductCst(e, v); - } - } - - public static long GetVarValueMap(LinearExpr e, long initial_coeff, Dictionary dict) - { - List exprs = new List(); - List coeffs = new List(); - if ((Object)e != null) - { - exprs.Add(e); - coeffs.Add(initial_coeff); - } - long constant = 0; - - while (exprs.Count > 0) - { - LinearExpr expr = exprs[0]; - exprs.RemoveAt(0); - long coeff = coeffs[0]; - coeffs.RemoveAt(0); - if (coeff == 0 || (Object)expr == null) - continue; - - if (expr is ProductCst) + ProductCst p = (ProductCst)expr; + if (p.Coeff != 0) { - ProductCst p = (ProductCst)expr; - if (p.Coeff != 0) - { - exprs.Add(p.Expr); - coeffs.Add(p.Coeff * coeff); - } + exprs.Add(p.Expr); + coeffs.Add(p.Coeff * coeff); } - else if (expr is SumArray) + } + else if (expr is SumArray) + { + SumArray a = (SumArray)expr; + constant += coeff * a.Offset; + foreach (LinearExpr sub in a.Expressions) { - SumArray a = (SumArray)expr; - constant += coeff * a.Offset; - foreach (LinearExpr sub in a.Expressions) + if (sub is IntVar) { - if (sub is IntVar) + IntVar i = (IntVar)sub; + if (dict.ContainsKey(i)) { - IntVar i = (IntVar)sub; - if (dict.ContainsKey(i)) - { - dict[i] += coeff; - } - else - { - dict.Add(i, coeff); - } - } - else if (sub is ProductCst && ((ProductCst)sub).Expr is IntVar) - { - ProductCst sub_prod = (ProductCst)sub; - IntVar i = (IntVar)sub_prod.Expr; - long sub_coeff = sub_prod.Coeff; - - if (dict.ContainsKey(i)) - { - dict[i] += coeff * sub_coeff; - } - else - { - dict.Add(i, coeff * sub_coeff); - } + dict[i] += coeff; } else { - exprs.Add(sub); - coeffs.Add(coeff); + dict.Add(i, coeff); } } - } - else if (expr is ConstantExpr) - { - ConstantExpr cte = (ConstantExpr)expr; - constant += coeff * cte.Value; - } - else if (expr is IntVar) - { - IntVar i = (IntVar)expr; - if (dict.ContainsKey(i)) + else if (sub is ProductCst && ((ProductCst)sub).Expr is IntVar) { - dict[i] += coeff; + ProductCst sub_prod = (ProductCst)sub; + IntVar i = (IntVar)sub_prod.Expr; + long sub_coeff = sub_prod.Coeff; + + if (dict.ContainsKey(i)) + { + dict[i] += coeff * sub_coeff; + } + else + { + dict.Add(i, coeff * sub_coeff); + } } else { - dict.Add(i, coeff); + exprs.Add(sub); + coeffs.Add(coeff); } } - else if (expr is NotBooleanVariable) + } + else if (expr is ConstantExpr) + { + ConstantExpr cte = (ConstantExpr)expr; + constant += coeff * cte.Value; + } + else if (expr is IntVar) + { + IntVar i = (IntVar)expr; + if (dict.ContainsKey(i)) { - IntVar i = ((NotBooleanVariable)expr).NotVar(); - if (dict.ContainsKey(i)) - { - dict[i] -= coeff; - } - else - { - dict.Add(i, -coeff); - } - constant += coeff; + dict[i] += coeff; } else { - throw new ArgumentException("Cannot interpret '" + expr.ToString() + "' in an integer expression"); + dict.Add(i, coeff); } } - return constant; - } - - public static LinearExpr RebuildLinearExprFromLinearExpressionProto(LinearExpressionProto proto, - CpModelProto model) - { - int numElements = proto.Vars.Count; - long offset = proto.Offset; - if (numElements == 0) + else if (expr is NotBooleanVariable) { - return LinearExpr.Constant(offset); - } - else if (numElements == 1) - { - IntVar var = new IntVar(model, proto.Vars[0]); - long coeff = proto.Coeffs[0]; - return LinearExpr.Affine(var, coeff, offset); + IntVar i = ((NotBooleanVariable)expr).NotVar(); + if (dict.ContainsKey(i)) + { + dict[i] -= coeff; + } + else + { + dict.Add(i, -coeff); + } + constant += coeff; } else { - LinearExpr[] exprs = new LinearExpr[numElements]; - for (int i = 0; i < numElements; ++i) - { - IntVar var = new IntVar(model, proto.Vars[i]); - long coeff = proto.Coeffs[i]; - exprs[i] = Prod(var, coeff); - } - SumArray sum = new SumArray(exprs); - sum.Offset = sum.Offset + offset; - return sum; + throw new ArgumentException("Cannot interpret '" + expr.ToString() + "' in an integer expression"); } } + return constant; + } + + public static LinearExpr RebuildLinearExprFromLinearExpressionProto(LinearExpressionProto proto, CpModelProto model) + { + int numElements = proto.Vars.Count; + long offset = proto.Offset; + if (numElements == 0) + { + return LinearExpr.Constant(offset); + } + else if (numElements == 1) + { + IntVar var = new IntVar(model, proto.Vars[0]); + long coeff = proto.Coeffs[0]; + return LinearExpr.Affine(var, coeff, offset); + } + else + { + LinearExpr[] exprs = new LinearExpr[numElements]; + for (int i = 0; i < numElements; ++i) + { + IntVar var = new IntVar(model, proto.Vars[i]); + long coeff = proto.Coeffs[i]; + exprs[i] = Prod(var, coeff); + } + SumArray sum = new SumArray(exprs); + sum.Offset = sum.Offset + offset; + return sum; + } + } +} + +public class ProductCst : LinearExpr +{ + public ProductCst(LinearExpr e, long v) + { + expr_ = e; + coeff_ = v; + } + + public LinearExpr Expr + { + get { + return expr_; + } } - public class ProductCst : LinearExpr + public long Coeff { - public ProductCst(LinearExpr e, long v) - { - expr_ = e; - coeff_ = v; + get { + return coeff_; } - - public LinearExpr Expr - { - get { - return expr_; - } - } - - public long Coeff - { - get { - return coeff_; - } - } - - private LinearExpr expr_; - private long coeff_; } - public class SumArray : LinearExpr + private LinearExpr expr_; + private long coeff_; +} + +public class SumArray : LinearExpr +{ + public SumArray(LinearExpr a, LinearExpr b) { - public SumArray(LinearExpr a, LinearExpr b) - { - expressions_ = new List(); - AddExpr(a); - AddExpr(b); - offset_ = 0L; - } - - public SumArray(LinearExpr a, long b) - { - expressions_ = new List(); - AddExpr(a); - offset_ = b; - } - - public SumArray(IEnumerable exprs) - { - expressions_ = new List(exprs); - offset_ = 0L; - } - - public SumArray(IEnumerable vars) - { - expressions_ = new List(vars); - offset_ = 0L; - } - - public SumArray(IntVar[] vars, long[] coeffs) - { - expressions_ = new List(vars.Length); - for (int i = 0; i < vars.Length; ++i) - { - AddExpr(Prod(vars[i], coeffs[i])); - } - offset_ = 0L; - } - - public SumArray(IEnumerable vars, IEnumerable coeffs) - { - List tmp_vars = new List(); - foreach (IntVar v in vars) - { - tmp_vars.Add(v); - } - List tmp_coeffs = new List(); - foreach (long c in coeffs) - { - tmp_coeffs.Add(c); - } - if (tmp_vars.Count != tmp_coeffs.Count) - { - throw new ArgumentException("in SumArray(vars, coeffs), the two lists do not have the same length"); - } - IntVar[] flat_vars = tmp_vars.ToArray(); - long[] flat_coeffs = tmp_coeffs.ToArray(); - expressions_ = new List(flat_vars.Length); - for (int i = 0; i < flat_vars.Length; ++i) - { - expressions_.Add(Prod(flat_vars[i], flat_coeffs[i])); - } - offset_ = 0L; - } - - public SumArray(IEnumerable vars, IEnumerable coeffs) - { - List tmp_vars = new List(); - foreach (IntVar v in vars) - { - tmp_vars.Add(v); - } - List tmp_coeffs = new List(); - foreach (int c in coeffs) - { - tmp_coeffs.Add(c); - } - if (tmp_vars.Count != tmp_coeffs.Count) - { - throw new ArgumentException("in SumArray(vars, coeffs), the two lists do not have the same length"); - } - IntVar[] flat_vars = tmp_vars.ToArray(); - long[] flat_coeffs = tmp_coeffs.ToArray(); - expressions_ = new List(flat_vars.Length); - for (int i = 0; i < flat_vars.Length; ++i) - { - expressions_.Add(Prod(flat_vars[i], flat_coeffs[i])); - } - offset_ = 0L; - } - - public void AddExpr(LinearExpr expr) - { - if ((Object)expr != null) - { - expressions_.Add(expr); - } - } - - public List Expressions - { - get { - return expressions_; - } - } - - public long Offset - { - get { - return offset_; - } - set { - offset_ = value; - } - } - - public override string ShortString() - { - return String.Format("({0})", ToString()); - } - - public override string ToString() - { - string result = ""; - foreach (LinearExpr expr in expressions_) - { - if ((Object)expr == null) - continue; - if (!String.IsNullOrEmpty(result)) - { - result += String.Format(" + "); - } - - result += expr.ShortString(); - } - if (offset_ != 0) - { - result += String.Format(" + {0}", offset_); - } - return result; - } - - private List expressions_; - private long offset_; + expressions_ = new List(); + AddExpr(a); + AddExpr(b); + offset_ = 0L; } - public class ConstantExpr : LinearExpr + public SumArray(LinearExpr a, long b) { - public ConstantExpr(long value) - { - value_ = value; - } - - public long Value - { - get { - return value_; - } - } - - public override string ShortString() - { - return String.Format("{0}", value_); - } - - public override string ToString() - { - return String.Format("ConstantExpr({0})", value_); - } - - private long value_; + expressions_ = new List(); + AddExpr(a); + offset_ = b; } - public class IntVar : LinearExpr, ILiteral + public SumArray(IEnumerable exprs) { - public IntVar(CpModelProto model, Domain domain, string name) - { - model_ = model; - index_ = model.Variables.Count; - var_ = new IntegerVariableProto(); - var_.Name = name; - var_.Domain.Add(domain.FlattenedIntervals()); - model.Variables.Add(var_); - negation_ = null; - } + expressions_ = new List(exprs); + offset_ = 0L; + } - public IntVar(CpModelProto model, int index) - { - model_ = model; - index_ = index; - var_ = model.Variables[index]; - negation_ = null; - } + public SumArray(IEnumerable vars) + { + expressions_ = new List(vars); + offset_ = 0L; + } - public override int GetIndex() + public SumArray(IntVar[] vars, long[] coeffs) + { + expressions_ = new List(vars.Length); + for (int i = 0; i < vars.Length; ++i) { - return index_; + AddExpr(Prod(vars[i], coeffs[i])); } + offset_ = 0L; + } - public IntegerVariableProto Proto + public SumArray(IEnumerable vars, IEnumerable coeffs) + { + List tmp_vars = new List(); + foreach (IntVar v in vars) { - get { - return var_; - } - set { - var_ = value; - } + tmp_vars.Add(v); } - - public Domain Domain + List tmp_coeffs = new List(); + foreach (long c in coeffs) { - get { - return CpSatHelper.VariableDomain(var_); - } + tmp_coeffs.Add(c); } - - public override string ToString() + if (tmp_vars.Count != tmp_coeffs.Count) { - return var_.ToString(); + throw new ArgumentException("in SumArray(vars, coeffs), the two lists do not have the same length"); } - - public override string ShortString() + IntVar[] flat_vars = tmp_vars.ToArray(); + long[] flat_coeffs = tmp_coeffs.ToArray(); + expressions_ = new List(flat_vars.Length); + for (int i = 0; i < flat_vars.Length; ++i) { - if (var_.Name != null) + expressions_.Add(Prod(flat_vars[i], flat_coeffs[i])); + } + offset_ = 0L; + } + + public SumArray(IEnumerable vars, IEnumerable coeffs) + { + List tmp_vars = new List(); + foreach (IntVar v in vars) + { + tmp_vars.Add(v); + } + List tmp_coeffs = new List(); + foreach (int c in coeffs) + { + tmp_coeffs.Add(c); + } + if (tmp_vars.Count != tmp_coeffs.Count) + { + throw new ArgumentException("in SumArray(vars, coeffs), the two lists do not have the same length"); + } + IntVar[] flat_vars = tmp_vars.ToArray(); + long[] flat_coeffs = tmp_coeffs.ToArray(); + expressions_ = new List(flat_vars.Length); + for (int i = 0; i < flat_vars.Length; ++i) + { + expressions_.Add(Prod(flat_vars[i], flat_coeffs[i])); + } + offset_ = 0L; + } + + public void AddExpr(LinearExpr expr) + { + if ((Object)expr != null) + { + expressions_.Add(expr); + } + } + + public List Expressions + { + get { + return expressions_; + } + } + + public long Offset + { + get { + return offset_; + } + set { + offset_ = value; + } + } + + public override string ShortString() + { + return String.Format("({0})", ToString()); + } + + public override string ToString() + { + string result = ""; + foreach (LinearExpr expr in expressions_) + { + if ((Object)expr == null) + continue; + if (!String.IsNullOrEmpty(result)) { - return var_.Name; + result += String.Format(" + "); } - else - { - return var_.ToString(); - } - } - public string Name() + result += expr.ShortString(); + } + if (offset_ != 0) + { + result += String.Format(" + {0}", offset_); + } + return result; + } + + private List expressions_; + private long offset_; +} + +public class ConstantExpr : LinearExpr +{ + public ConstantExpr(long value) + { + value_ = value; + } + + public long Value + { + get { + return value_; + } + } + + public override string ShortString() + { + return String.Format("{0}", value_); + } + + public override string ToString() + { + return String.Format("ConstantExpr({0})", value_); + } + + private long value_; +} + +public class IntVar : LinearExpr, ILiteral +{ + public IntVar(CpModelProto model, Domain domain, string name) + { + model_ = model; + index_ = model.Variables.Count; + var_ = new IntegerVariableProto(); + var_.Name = name; + var_.Domain.Add(domain.FlattenedIntervals()); + model.Variables.Add(var_); + negation_ = null; + } + + public IntVar(CpModelProto model, int index) + { + model_ = model; + index_ = index; + var_ = model.Variables[index]; + negation_ = null; + } + + public override int GetIndex() + { + return index_; + } + + public IntegerVariableProto Proto + { + get { + return var_; + } + set { + var_ = value; + } + } + + public Domain Domain + { + get { + return CpSatHelper.VariableDomain(var_); + } + } + + public override string ToString() + { + return var_.ToString(); + } + + public override string ShortString() + { + if (var_.Name != null) { return var_.Name; } - - public ILiteral Not() + else { - foreach (long b in var_.Domain) - { - if (b < 0 || b > 1) - { - throw new ArgumentException("Cannot call Not() on a non boolean variable"); - } - } - if (negation_ == null) - { - negation_ = new NotBooleanVariable(this); - } - return negation_; + return var_.ToString(); } - - private CpModelProto model_; - private int index_; - private IntegerVariableProto var_; - private NotBooleanVariable negation_; } - public class NotBooleanVariable : LinearExpr, ILiteral + public string Name() { - public NotBooleanVariable(IntVar boolvar) - { - boolvar_ = boolvar; - } - - public override int GetIndex() - { - return -boolvar_.Index - 1; - } - - public ILiteral Not() - { - return boolvar_; - } - - public IntVar NotVar() - { - return boolvar_; - } - - public override string ShortString() - { - return String.Format("Not({0})", boolvar_.ShortString()); - } - - private IntVar boolvar_; + return var_.Name; } - public class BoundedLinearExpression + public ILiteral Not() { - public enum Type + foreach (long b in var_.Domain) { - BoundExpression, - VarEqVar, - VarDiffVar, - VarEqCst, - VarDiffCst, - } - - public BoundedLinearExpression(long lb, LinearExpr expr, long ub) - { - left_ = expr; - right_ = null; - lb_ = lb; - ub_ = ub; - type_ = Type.BoundExpression; - } - - public BoundedLinearExpression(LinearExpr left, LinearExpr right, bool equality) - { - left_ = left; - right_ = right; - lb_ = 0; - ub_ = 0; - type_ = equality ? Type.VarEqVar : Type.VarDiffVar; - } - - public BoundedLinearExpression(LinearExpr left, long v, bool equality) - { - left_ = left; - right_ = null; - lb_ = v; - ub_ = 0; - type_ = equality ? Type.VarEqCst : Type.VarDiffCst; - } - - bool IsTrue() - { - if (type_ == Type.VarEqVar) + if (b < 0 || b > 1) { - return (object)left_ == (object)right_; - } - else if (type_ == Type.VarDiffVar) - { - return (object)left_ != (object)right_; - } - return false; - } - - public static bool operator true(BoundedLinearExpression bie) - { - return bie.IsTrue(); - } - - public static bool operator false(BoundedLinearExpression bie) - { - return !bie.IsTrue(); - } - - public override string ToString() - { - switch (type_) - { - case Type.BoundExpression: - return String.Format("{0} <= {1} <= {2}", lb_, left_, ub_); - case Type.VarEqVar: - return String.Format("{0} == {1}", left_, right_); - case Type.VarDiffVar: - return String.Format("{0} != {1}", left_, right_); - case Type.VarEqCst: - return String.Format("{0} == {1}", left_, lb_); - case Type.VarDiffCst: - return String.Format("{0} != {1}", left_, lb_); - default: - throw new ArgumentException("Wrong mode in BoundedLinearExpression."); + throw new ArgumentException("Cannot call Not() on a non boolean variable"); } } - - public static BoundedLinearExpression operator <=(BoundedLinearExpression a, long v) + if (negation_ == null) { - if (a.CtType != Type.BoundExpression || a.Ub != Int64.MaxValue) - { - throw new ArgumentException("Operator <= not supported for this BoundedLinearExpression"); - } - return new BoundedLinearExpression(a.Lb, a.Left, v); + negation_ = new NotBooleanVariable(this); } - - public static BoundedLinearExpression operator<(BoundedLinearExpression a, long v) - { - if (a.CtType != Type.BoundExpression || a.Ub != Int64.MaxValue) - { - throw new ArgumentException("Operator < not supported for this BoundedLinearExpression"); - } - return new BoundedLinearExpression(a.Lb, a.Left, v - 1); - } - - public static BoundedLinearExpression operator >=(BoundedLinearExpression a, long v) - { - if (a.CtType != Type.BoundExpression || a.Lb != Int64.MinValue) - { - throw new ArgumentException("Operator >= not supported for this BoundedLinearExpression"); - } - return new BoundedLinearExpression(v, a.Left, a.Ub); - } - - public static BoundedLinearExpression operator>(BoundedLinearExpression a, long v) - { - if (a.CtType != Type.BoundExpression || a.Lb != Int64.MinValue) - { - throw new ArgumentException("Operator < not supported for this BoundedLinearExpression"); - } - return new BoundedLinearExpression(v + 1, a.Left, a.Ub); - } - - public LinearExpr Left - { - get { - return left_; - } - } - - public LinearExpr Right - { - get { - return right_; - } - } - - public long Lb - { - get { - return lb_; - } - } - - public long Ub - { - get { - return ub_; - } - } - - public Type CtType - { - get { - return type_; - } - } - - private LinearExpr left_; - private LinearExpr right_; - private long lb_; - private long ub_; - private Type type_; + return negation_; } + private CpModelProto model_; + private int index_; + private IntegerVariableProto var_; + private NotBooleanVariable negation_; +} + +public class NotBooleanVariable : LinearExpr, ILiteral +{ + public NotBooleanVariable(IntVar boolvar) + { + boolvar_ = boolvar; + } + + public override int GetIndex() + { + return -boolvar_.Index - 1; + } + + public ILiteral Not() + { + return boolvar_; + } + + public IntVar NotVar() + { + return boolvar_; + } + + public override string ShortString() + { + return String.Format("Not({0})", boolvar_.ShortString()); + } + + private IntVar boolvar_; +} + +public class BoundedLinearExpression +{ + public enum Type + { + BoundExpression, + VarEqVar, + VarDiffVar, + VarEqCst, + VarDiffCst, + } + + public BoundedLinearExpression(long lb, LinearExpr expr, long ub) + { + left_ = expr; + right_ = null; + lb_ = lb; + ub_ = ub; + type_ = Type.BoundExpression; + } + + public BoundedLinearExpression(LinearExpr left, LinearExpr right, bool equality) + { + left_ = left; + right_ = right; + lb_ = 0; + ub_ = 0; + type_ = equality ? Type.VarEqVar : Type.VarDiffVar; + } + + public BoundedLinearExpression(LinearExpr left, long v, bool equality) + { + left_ = left; + right_ = null; + lb_ = v; + ub_ = 0; + type_ = equality ? Type.VarEqCst : Type.VarDiffCst; + } + + bool IsTrue() + { + if (type_ == Type.VarEqVar) + { + return (object)left_ == (object)right_; + } + else if (type_ == Type.VarDiffVar) + { + return (object)left_ != (object)right_; + } + return false; + } + + public static bool operator true(BoundedLinearExpression bie) + { + return bie.IsTrue(); + } + + public static bool operator false(BoundedLinearExpression bie) + { + return !bie.IsTrue(); + } + + public override string ToString() + { + switch (type_) + { + case Type.BoundExpression: + return String.Format("{0} <= {1} <= {2}", lb_, left_, ub_); + case Type.VarEqVar: + return String.Format("{0} == {1}", left_, right_); + case Type.VarDiffVar: + return String.Format("{0} != {1}", left_, right_); + case Type.VarEqCst: + return String.Format("{0} == {1}", left_, lb_); + case Type.VarDiffCst: + return String.Format("{0} != {1}", left_, lb_); + default: + throw new ArgumentException("Wrong mode in BoundedLinearExpression."); + } + } + + public static BoundedLinearExpression operator <=(BoundedLinearExpression a, long v) + { + if (a.CtType != Type.BoundExpression || a.Ub != Int64.MaxValue) + { + throw new ArgumentException("Operator <= not supported for this BoundedLinearExpression"); + } + return new BoundedLinearExpression(a.Lb, a.Left, v); + } + + public static BoundedLinearExpression operator<(BoundedLinearExpression a, long v) + { + if (a.CtType != Type.BoundExpression || a.Ub != Int64.MaxValue) + { + throw new ArgumentException("Operator < not supported for this BoundedLinearExpression"); + } + return new BoundedLinearExpression(a.Lb, a.Left, v - 1); + } + + public static BoundedLinearExpression operator >=(BoundedLinearExpression a, long v) + { + if (a.CtType != Type.BoundExpression || a.Lb != Int64.MinValue) + { + throw new ArgumentException("Operator >= not supported for this BoundedLinearExpression"); + } + return new BoundedLinearExpression(v, a.Left, a.Ub); + } + + public static BoundedLinearExpression operator>(BoundedLinearExpression a, long v) + { + if (a.CtType != Type.BoundExpression || a.Lb != Int64.MinValue) + { + throw new ArgumentException("Operator < not supported for this BoundedLinearExpression"); + } + return new BoundedLinearExpression(v + 1, a.Left, a.Ub); + } + + public LinearExpr Left + { + get { + return left_; + } + } + + public LinearExpr Right + { + get { + return right_; + } + } + + public long Lb + { + get { + return lb_; + } + } + + public long Ub + { + get { + return ub_; + } + } + + public Type CtType + { + get { + return type_; + } + } + + private LinearExpr left_; + private LinearExpr right_; + private long lb_; + private long ub_; + private Type type_; +} + } // namespace Google.OrTools.Sat diff --git a/ortools/sat/csharp/IntervalVariables.cs b/ortools/sat/csharp/IntervalVariables.cs index 82c2ebb37d..a9e3f17aa1 100644 --- a/ortools/sat/csharp/IntervalVariables.cs +++ b/ortools/sat/csharp/IntervalVariables.cs @@ -13,87 +13,87 @@ namespace Google.OrTools.Sat { - using System; - using System.Collections.Generic; +using System; +using System.Collections.Generic; - public class IntervalVar +public class IntervalVar +{ + public IntervalVar(CpModelProto model, LinearExpressionProto start, LinearExpressionProto size, + LinearExpressionProto end, int is_present_index, string name) { - public IntervalVar(CpModelProto model, LinearExpressionProto start, LinearExpressionProto size, - LinearExpressionProto end, int is_present_index, string name) - { - model_ = model; - index_ = model.Constraints.Count; - interval_ = new IntervalConstraintProto(); - interval_.Start = start; - interval_.Size = size; - interval_.End = end; + model_ = model; + index_ = model.Constraints.Count; + interval_ = new IntervalConstraintProto(); + interval_.Start = start; + interval_.Size = size; + interval_.End = end; - ConstraintProto ct = new ConstraintProto(); - ct.Interval = interval_; - ct.Name = name; - ct.EnforcementLiteral.Add(is_present_index); - model.Constraints.Add(ct); - } - - public IntervalVar(CpModelProto model, LinearExpressionProto start, LinearExpressionProto size, - LinearExpressionProto end, string name) - { - model_ = model; - index_ = model.Constraints.Count; - interval_ = new IntervalConstraintProto(); - interval_.Start = start; - interval_.Size = size; - interval_.End = end; - - ConstraintProto ct = new ConstraintProto(); - ct.Interval = interval_; - ct.Name = name; - model_.Constraints.Add(ct); - } - - public int GetIndex() - { - return index_; - } - - public LinearExpr StartExpr() - { - return LinearExpr.RebuildLinearExprFromLinearExpressionProto(interval_.Start, model_); - } - - public LinearExpr SizeExpr() - { - return LinearExpr.RebuildLinearExprFromLinearExpressionProto(interval_.Size, model_); - } - - public LinearExpr EndExpr() - { - return LinearExpr.RebuildLinearExprFromLinearExpressionProto(interval_.End, model_); - } - - public IntervalConstraintProto Proto - { - get { - return interval_; - } - set { - interval_ = value; - } - } - - public override string ToString() - { - return model_.Constraints[index_].ToString(); - } - - public string Name() - { - return model_.Constraints[index_].Name; - } - - private CpModelProto model_; - private int index_; - private IntervalConstraintProto interval_; + ConstraintProto ct = new ConstraintProto(); + ct.Interval = interval_; + ct.Name = name; + ct.EnforcementLiteral.Add(is_present_index); + model.Constraints.Add(ct); } + public IntervalVar(CpModelProto model, LinearExpressionProto start, LinearExpressionProto size, + LinearExpressionProto end, string name) + { + model_ = model; + index_ = model.Constraints.Count; + interval_ = new IntervalConstraintProto(); + interval_.Start = start; + interval_.Size = size; + interval_.End = end; + + ConstraintProto ct = new ConstraintProto(); + ct.Interval = interval_; + ct.Name = name; + model_.Constraints.Add(ct); + } + + public int GetIndex() + { + return index_; + } + + public LinearExpr StartExpr() + { + return LinearExpr.RebuildLinearExprFromLinearExpressionProto(interval_.Start, model_); + } + + public LinearExpr SizeExpr() + { + return LinearExpr.RebuildLinearExprFromLinearExpressionProto(interval_.Size, model_); + } + + public LinearExpr EndExpr() + { + return LinearExpr.RebuildLinearExprFromLinearExpressionProto(interval_.End, model_); + } + + public IntervalConstraintProto Proto + { + get { + return interval_; + } + set { + interval_ = value; + } + } + + public override string ToString() + { + return model_.Constraints[index_].ToString(); + } + + public string Name() + { + return model_.Constraints[index_].Name; + } + + private CpModelProto model_; + private int index_; + private IntervalConstraintProto interval_; +} + } // namespace Google.OrTools.Sat diff --git a/ortools/sat/csharp/SearchHelpers.cs b/ortools/sat/csharp/SearchHelpers.cs index b1fd136178..146e107553 100644 --- a/ortools/sat/csharp/SearchHelpers.cs +++ b/ortools/sat/csharp/SearchHelpers.cs @@ -17,102 +17,102 @@ using System.Collections.Generic; namespace Google.OrTools.Sat { - public class CpSolverSolutionCallback : SolutionCallback +public class CpSolverSolutionCallback : SolutionCallback +{ + public long Value(LinearExpr e) { - public long Value(LinearExpr e) + List exprs = new List(); + List coeffs = new List(); + exprs.Add(e); + coeffs.Add(1L); + long constant = 0; + + while (exprs.Count > 0) { - List exprs = new List(); - List coeffs = new List(); - exprs.Add(e); - coeffs.Add(1L); - long constant = 0; + LinearExpr expr = exprs[0]; + exprs.RemoveAt(0); + long coeff = coeffs[0]; + coeffs.RemoveAt(0); + if (coeff == 0) + continue; - while (exprs.Count > 0) + if (expr is ProductCst) { - LinearExpr expr = exprs[0]; - exprs.RemoveAt(0); - long coeff = coeffs[0]; - coeffs.RemoveAt(0); - if (coeff == 0) - continue; - - if (expr is ProductCst) + ProductCst p = (ProductCst)expr; + if (p.Coeff != 0) { - ProductCst p = (ProductCst)expr; - if (p.Coeff != 0) - { - exprs.Add(p.Expr); - coeffs.Add(p.Coeff * coeff); - } - } - else if (expr is SumArray) - { - SumArray a = (SumArray)expr; - constant += coeff * a.Offset; - foreach (LinearExpr sub in a.Expressions) - { - exprs.Add(sub); - coeffs.Add(coeff); - } - } - else if (expr is IntVar) - { - int index = expr.Index; - long value = SolutionIntegerValue(index); - constant += coeff * value; - } - else if (expr is NotBooleanVariable) - { - throw new ArgumentException("Cannot evaluate a literal in an integer expression."); - } - else - { - throw new ArgumentException("Cannot evaluate '" + expr.ToString() + "' in an integer expression"); + exprs.Add(p.Expr); + coeffs.Add(p.Coeff * coeff); } } - return constant; - } - - public Boolean BooleanValue(ILiteral literal) - { - if (literal is IntVar || literal is NotBooleanVariable) + else if (expr is SumArray) { - int index = literal.GetIndex(); - return SolutionBooleanValue(index); + SumArray a = (SumArray)expr; + constant += coeff * a.Offset; + foreach (LinearExpr sub in a.Expressions) + { + exprs.Add(sub); + coeffs.Add(coeff); + } + } + else if (expr is IntVar) + { + int index = expr.Index; + long value = SolutionIntegerValue(index); + constant += coeff * value; + } + else if (expr is NotBooleanVariable) + { + throw new ArgumentException("Cannot evaluate a literal in an integer expression."); } else { - throw new ArgumentException("Cannot evaluate '" + literal.ToString() + "' as a boolean literal"); + throw new ArgumentException("Cannot evaluate '" + expr.ToString() + "' in an integer expression"); } } + return constant; } - public class ObjectiveSolutionPrinter : CpSolverSolutionCallback + public Boolean BooleanValue(ILiteral literal) { - private DateTime _startTime; - private int _solutionCount; - - public ObjectiveSolutionPrinter() + if (literal is IntVar || literal is NotBooleanVariable) { - _startTime = DateTime.Now; + int index = literal.GetIndex(); + return SolutionBooleanValue(index); } - - public override void OnSolutionCallback() + else { - var currentTime = DateTime.Now; - var objective = ObjectiveValue(); - var objectiveBound = BestObjectiveBound(); - var objLb = Math.Min(objective, objectiveBound); - var objUb = Math.Max(objective, objectiveBound); - var time = currentTime - _startTime; - - Console.WriteLine( - value: $"Solution {_solutionCount}, time = {time.TotalSeconds} s, objective = [{objLb}, {objUb}]"); - - _solutionCount++; + throw new ArgumentException("Cannot evaluate '" + literal.ToString() + "' as a boolean literal"); } - - public int solutionCount() => _solutionCount; } +} + +public class ObjectiveSolutionPrinter : CpSolverSolutionCallback +{ + private DateTime _startTime; + private int _solutionCount; + + public ObjectiveSolutionPrinter() + { + _startTime = DateTime.Now; + } + + public override void OnSolutionCallback() + { + var currentTime = DateTime.Now; + var objective = ObjectiveValue(); + var objectiveBound = BestObjectiveBound(); + var objLb = Math.Min(objective, objectiveBound); + var objUb = Math.Max(objective, objectiveBound); + var time = currentTime - _startTime; + + Console.WriteLine( + value: $"Solution {_solutionCount}, time = {time.TotalSeconds} s, objective = [{objLb}, {objUb}]"); + + _solutionCount++; + } + + public int solutionCount() => _solutionCount; +} } // namespace Google.OrTools.Sat diff --git a/ortools/sat/doc/channeling.md b/ortools/sat/doc/channeling.md index 26cdf25276..7621cd4048 100644 --- a/ortools/sat/doc/channeling.md +++ b/ortools/sat/doc/channeling.md @@ -140,7 +140,7 @@ void ChannelingSampleSat() { // Create our two half-reified constraints. // First, b implies (y == 10 - x). - cp_model.AddEquality(LinearExpr::Sum({x, y}), 10).OnlyEnforceIf(b); + cp_model.AddEquality(x + y, 10).OnlyEnforceIf(b); // Second, not(b) implies y == 0. cp_model.AddEquality(y, 0).OnlyEnforceIf(Not(b)); @@ -467,7 +467,7 @@ void BinpackingProblemSat() { for (int b = 0; b < kNumBins; ++b) { LinearExpr expr; for (int i = 0; i < num_items; ++i) { - expr.AddTerm(x[i][b], items[i][0]); + expr += x[i][b] * items[i][0]; } cp_model.AddEquality(expr, load[b]); } @@ -488,7 +488,7 @@ void BinpackingProblemSat() { } // Maximize sum of slacks. - cp_model.Maximize(LinearExpr::BooleanSum(slacks)); + cp_model.Maximize(LinearExpr::Sum(slacks)); // Solving part. const CpSolverResponse response = Solve(cp_model.Build()); diff --git a/ortools/sat/doc/integer_arithmetic.md b/ortools/sat/doc/integer_arithmetic.md index 454e887af2..d733de17ad 100644 --- a/ortools/sat/doc/integer_arithmetic.md +++ b/ortools/sat/doc/integer_arithmetic.md @@ -191,8 +191,8 @@ void RabbitsAndPheasantsSat() { const IntVar pheasants = cp_model.NewIntVar(all_animals).WithName("pheasants"); - cp_model.AddEquality(LinearExpr::Sum({rabbits, pheasants}), 20); - cp_model.AddEquality(LinearExpr::ScalProd({rabbits, pheasants}, {4, 2}), 56); + cp_model.AddEquality(rabbits + pheasants, 20); + cp_model.AddEquality(4 * rabbits + 2 * pheasants, 56); const CpSolverResponse response = Solve(cp_model.Build()); @@ -436,21 +436,9 @@ void EarlinessTardinessCostSampleSat() { const int64_t kLargeConstant = 1000; const IntVar expr = cp_model.NewIntVar({0, kLargeConstant}); - // First segment. - const IntVar s1 = cp_model.NewIntVar({-kLargeConstant, kLargeConstant}); - cp_model.AddEquality(s1, LinearExpr::ScalProd({x}, {-kEarlinessCost}) - .AddConstant(kEarlinessCost * kEarlinessDate)); - - // Second segment. - const IntVar s2 = cp_model.NewConstant(0); - - // Third segment. - const IntVar s3 = cp_model.NewIntVar({-kLargeConstant, kLargeConstant}); - cp_model.AddEquality(s3, LinearExpr::ScalProd({x}, {kLatenessCost}) - .AddConstant(-kLatenessCost * kLatenessDate)); - - // Link together expr and x through s1, s2, and s3. - cp_model.AddMaxEquality(expr, {s1, s2, s3}); + // Link together expr and x through the 3 segments. + cp_model.AddMaxEquality(expr, {(kEarlinessDate - x) * kEarlinessCost, 0, + (x - kLatenessDate) * kLatenessCost}); // Search for x values in increasing order. cp_model.AddDecisionStrategy({x}, DecisionStrategyProto::CHOOSE_FIRST, diff --git a/ortools/sat/doc/model.md b/ortools/sat/doc/model.md index 2a099231c5..c01340d754 100644 --- a/ortools/sat/doc/model.md +++ b/ortools/sat/doc/model.md @@ -145,7 +145,7 @@ void SolutionHintingSampleSat() { cp_model.AddNotEqual(x, y); - cp_model.Maximize(LinearExpr::ScalProd({x, y, z}, {1, 2, 3})); + cp_model.Maximize(x + 2 * y + 3 * z); // Solution hinting: x <- 1, y <- 2 cp_model.AddHint(x, 1); @@ -381,7 +381,7 @@ void CopyModelSat() { cp_model.AddNotEqual(x, y); - cp_model.Maximize(LinearExpr::ScalProd({x, y, z}, {1, 2, 3})); + cp_model.Maximize(x + 2 * y + 3 * z); const CpSolverResponse initial_response = Solve(cp_model.Build()); LOG(INFO) << "Optimal value of the original model: " @@ -394,7 +394,7 @@ void CopyModelSat() { IntVar copy_of_x = copy.GetIntVarFromProtoIndex(x.index()); IntVar copy_of_y = copy.GetIntVarFromProtoIndex(y.index()); - copy.AddLessOrEqual(LinearExpr::Sum({copy_of_x, copy_of_y}), 1); + copy.AddLessOrEqual(copy_of_x + copy_of_y, 1); const CpSolverResponse modified_response = Solve(copy.Build()); LOG(INFO) << "Optimal value of the modified model: " diff --git a/ortools/sat/doc/scheduling.md b/ortools/sat/doc/scheduling.md index 1d9c9f55dc..ee307065aa 100644 --- a/ortools/sat/doc/scheduling.md +++ b/ortools/sat/doc/scheduling.md @@ -117,8 +117,7 @@ void IntervalSampleSat() { const IntVar z = cp_model.NewIntVar(horizon).WithName("z"); const IntervalVar interval_var = - cp_model.NewIntervalVar(x, y, LinearExpr(z).AddConstant(2)) - .WithName("interval"); + cp_model.NewIntervalVar(x, y, z + 2).WithName("interval"); LOG(INFO) << "start = " << interval_var.StartExpr() << ", size = " << interval_var.SizeExpr() << ", end = " << interval_var.EndExpr() @@ -288,9 +287,7 @@ void OptionalIntervalSampleSat() { const BoolVar presence_var = cp_model.NewBoolVar().WithName("presence"); const IntervalVar interval_var = - cp_model - .NewOptionalIntervalVar(x, y, LinearExpr(z).AddConstant(2), - presence_var) + cp_model.NewOptionalIntervalVar(x, y, z + 2, presence_var) .WithName("interval"); LOG(INFO) << "start = " << interval_var.StartExpr() << ", size = " << interval_var.SizeExpr() @@ -904,7 +901,7 @@ void RankingSampleSat() { for (int i = 0; i < num_tasks; ++i) { LinearExpr sum_of_predecessors(-1); for (int j = 0; j < num_tasks; ++j) { - sum_of_predecessors.AddVar(precedences[j][i]); + sum_of_predecessors += precedences[j][i]; } cp_model.AddEquality(ranks[i], sum_of_predecessors); } @@ -953,10 +950,9 @@ void RankingSampleSat() { // Create objective: minimize 2 * makespan - 7 * sum of presences. // That is you gain 7 by interval performed, but you pay 2 by day of delays. - LinearExpr objective; - objective.AddTerm(makespan, 2); + LinearExpr objective = 2 * makespan; for (int t = 0; t < kNumTasks; ++t) { - objective.AddTerm(presences[t], -7); + objective -= 7 * presences[t]; } cp_model.Minimize(objective); diff --git a/ortools/sat/doc/solver.md b/ortools/sat/doc/solver.md index 840487ac82..a05a448ea4 100644 --- a/ortools/sat/doc/solver.md +++ b/ortools/sat/doc/solver.md @@ -289,7 +289,7 @@ void SolveAndPrintIntermediateSolutionsSampleSat() { cp_model.AddNotEqual(x, y); - cp_model.Maximize(LinearExpr::ScalProd({x, y, z}, {1, 2, 3})); + cp_model.Maximize(x + 2 * y + 3 * z); Model model; int num_solutions = 0; diff --git a/ortools/sat/probing.cc b/ortools/sat/probing.cc index 52145915e7..e46f04bdde 100644 --- a/ortools/sat/probing.cc +++ b/ortools/sat/probing.cc @@ -690,12 +690,14 @@ bool FailedLiteralProbingRound(ProbingOptions options, Model* model) { // even better reasony. Maybe it is just better to change all the // reason above to a binary one so we don't have an issue here. if (trail.AssignmentType(w.blocking_literal.Variable()) != id) { - ++num_new_binary; - implication_graph->AddBinaryClause(last_decision.Negated(), - w.blocking_literal); - + // If the variable was true at level zero, there is no point + // adding the clause. const auto& info = trail.Info(w.blocking_literal.Variable()); if (info.level > 0) { + ++num_new_binary; + implication_graph->AddBinaryClause(last_decision.Negated(), + w.blocking_literal); + const Literal d = sat_solver->Decisions()[info.level - 1].literal; if (d != w.blocking_literal) { implication_graph->ChangeReason(info.trail_index, d); diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 5ad2abea7d..28dc6cbaac 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -2251,6 +2251,14 @@ class CpSolver(object): """Returns the indices of the infeasible assumptions.""" return self.__solution.sufficient_assumptions_for_infeasibility + def SolutionInfo(self): + """Returns some information on the solve process. + + Returns some information on how the solution was found, or the reason + why the model or the parameters are invalid. + """ + return self.__solution.solution_info + class CpSolverSolutionCallback(pywrapsat.SolutionCallback): """Solution callback. diff --git a/ortools/sat/samples/assignment_sat.cc b/ortools/sat/samples/assignment_sat.cc index e39a9b85b3..02078ac40d 100644 --- a/ortools/sat/samples/assignment_sat.cc +++ b/ortools/sat/samples/assignment_sat.cc @@ -21,7 +21,7 @@ namespace sat { void IntegerProgrammingExample() { // Data // [START data_model] - const std::vector> costs{ + const std::vector> costs{ {90, 80, 75, 70}, {35, 85, 55, 65}, {125, 95, 90, 95}, {45, 110, 95, 115}, {50, 100, 90, 100}, }; @@ -51,11 +51,7 @@ void IntegerProgrammingExample() { // [START constraints] // Each worker is assigned to at most one task. for (int i = 0; i < num_workers; ++i) { - LinearExpr worker_sum; - for (int j = 0; j < num_tasks; ++j) { - worker_sum.AddTerm(x[i][j], 1); - } - cp_model.AddLessOrEqual(worker_sum, 1); + cp_model.AddLessOrEqual(LinearExpr::Sum(x[i]), 1); } // Each task is assigned to exactly one worker. for (int j = 0; j < num_tasks; ++j) { @@ -72,7 +68,7 @@ void IntegerProgrammingExample() { LinearExpr total_cost; for (int i = 0; i < num_workers; ++i) { for (int j = 0; j < num_tasks; ++j) { - total_cost.AddTerm(x[i][j], costs[i][j]); + total_cost += x[i][j] * costs[i][j]; } } cp_model.Minimize(total_cost); diff --git a/ortools/sat/samples/binpacking_problem_sat.cc b/ortools/sat/samples/binpacking_problem_sat.cc index 7d06887f16..5218ee0e0a 100644 --- a/ortools/sat/samples/binpacking_problem_sat.cc +++ b/ortools/sat/samples/binpacking_problem_sat.cc @@ -54,7 +54,7 @@ void BinpackingProblemSat() { for (int b = 0; b < kNumBins; ++b) { LinearExpr expr; for (int i = 0; i < num_items; ++i) { - expr.AddTerm(x[i][b], items[i][0]); + expr += x[i][b] * items[i][0]; } cp_model.AddEquality(expr, load[b]); } diff --git a/ortools/sat/samples/channeling_sample_sat.cc b/ortools/sat/samples/channeling_sample_sat.cc index ac381a1bfe..702471a413 100644 --- a/ortools/sat/samples/channeling_sample_sat.cc +++ b/ortools/sat/samples/channeling_sample_sat.cc @@ -35,7 +35,7 @@ void ChannelingSampleSat() { // Create our two half-reified constraints. // First, b implies (y == 10 - x). - cp_model.AddEquality(LinearExpr::Sum({x, y}), 10).OnlyEnforceIf(b); + cp_model.AddEquality(x + y, 10).OnlyEnforceIf(b); // Second, not(b) implies y == 0. cp_model.AddEquality(y, 0).OnlyEnforceIf(Not(b)); diff --git a/ortools/sat/samples/copy_model_sample_sat.cc b/ortools/sat/samples/copy_model_sample_sat.cc index 8927ffcb36..4bd02b0eb6 100644 --- a/ortools/sat/samples/copy_model_sample_sat.cc +++ b/ortools/sat/samples/copy_model_sample_sat.cc @@ -35,7 +35,7 @@ void CopyModelSat() { // [END constraints] // [START objective] - cp_model.Maximize(LinearExpr::ScalProd({x, y, z}, {1, 2, 3})); + cp_model.Maximize(x + 2 * y + 3 * z); // [END objective] const CpSolverResponse initial_response = Solve(cp_model.Build()); @@ -49,7 +49,7 @@ void CopyModelSat() { IntVar copy_of_x = copy.GetIntVarFromProtoIndex(x.index()); IntVar copy_of_y = copy.GetIntVarFromProtoIndex(y.index()); - copy.AddLessOrEqual(LinearExpr::Sum({copy_of_x, copy_of_y}), 1); + copy.AddLessOrEqual(copy_of_x + copy_of_y, 1); const CpSolverResponse modified_response = Solve(copy.Build()); LOG(INFO) << "Optimal value of the modified model: " diff --git a/ortools/sat/samples/cp_is_fun_sat.cc b/ortools/sat/samples/cp_is_fun_sat.cc index 8dfb32f015..8c8a8559ba 100644 --- a/ortools/sat/samples/cp_is_fun_sat.cc +++ b/ortools/sat/samples/cp_is_fun_sat.cc @@ -61,10 +61,8 @@ void CPIsFunSat() { // CP + IS + FUN = TRUE cp_model.AddEquality( - LinearExpr::ScalProd({c, p, i, s, f, u, n}, - {kBase, 1, kBase, 1, kBase * kBase, kBase, 1}), - LinearExpr::ScalProd({t, r, u, e}, - {kBase * kBase * kBase, kBase * kBase, kBase, 1})); + c * kBase + p + i * kBase + s + f * kBase * kBase + u * kBase + n, + kBase * kBase * kBase * t + kBase * kBase * r + kBase * u + e); // [END constraints] // [START solution_printer] diff --git a/ortools/sat/samples/cp_sat_example.cc b/ortools/sat/samples/cp_sat_example.cc index 60c645e2cd..9de29000d5 100644 --- a/ortools/sat/samples/cp_sat_example.cc +++ b/ortools/sat/samples/cp_sat_example.cc @@ -35,13 +35,13 @@ void CpSatExample() { // [END variables] // [START constraints] - cp_model.AddLessOrEqual(LinearExpr::ScalProd({x, y, z}, {2, 7, 3}), 50); - cp_model.AddLessOrEqual(LinearExpr::ScalProd({x, y, z}, {3, -5, 7}), 45); - cp_model.AddLessOrEqual(LinearExpr::ScalProd({x, y, z}, {5, 2, -6}), 37); + cp_model.AddLessOrEqual(2 * x + 7 * y + 3 * z, 50); + cp_model.AddLessOrEqual(3 * x - 5 * y + 7 * z, 45); + cp_model.AddLessOrEqual(5 * x + 2 * y - 6 * z, 37); // [END constraints] // [START objective] - cp_model.Maximize(LinearExpr::ScalProd({x, y, z}, {2, 2, 3})); + cp_model.Maximize(2 * x + 2 * y + 3 * z); // [END objective] // Solving part. diff --git a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.cc b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.cc index fcc16f6a71..4ffe5b1440 100644 --- a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.cc +++ b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.cc @@ -41,21 +41,9 @@ void EarlinessTardinessCostSampleSat() { const int64_t kLargeConstant = 1000; const IntVar expr = cp_model.NewIntVar({0, kLargeConstant}); - // First segment. - const IntVar s1 = cp_model.NewIntVar({-kLargeConstant, kLargeConstant}); - cp_model.AddEquality(s1, LinearExpr::ScalProd({x}, {-kEarlinessCost}) - .AddConstant(kEarlinessCost * kEarlinessDate)); - - // Second segment. - const IntVar s2 = cp_model.NewConstant(0); - - // Third segment. - const IntVar s3 = cp_model.NewIntVar({-kLargeConstant, kLargeConstant}); - cp_model.AddEquality(s3, LinearExpr::ScalProd({x}, {kLatenessCost}) - .AddConstant(-kLatenessCost * kLatenessDate)); - - // Link together expr and x through s1, s2, and s3. - cp_model.AddMaxEquality(expr, {s1, s2, s3}); + // Link together expr and x through the 3 segments. + cp_model.AddMaxEquality(expr, {(kEarlinessDate - x) * kEarlinessCost, 0, + (x - kLatenessDate) * kLatenessCost}); // Search for x values in increasing order. cp_model.AddDecisionStrategy({x}, DecisionStrategyProto::CHOOSE_FIRST, diff --git a/ortools/sat/samples/interval_sample_sat.cc b/ortools/sat/samples/interval_sample_sat.cc index 98423d392a..f6e04ea491 100644 --- a/ortools/sat/samples/interval_sample_sat.cc +++ b/ortools/sat/samples/interval_sample_sat.cc @@ -27,8 +27,7 @@ void IntervalSampleSat() { const IntVar z = cp_model.NewIntVar(horizon).WithName("z"); const IntervalVar interval_var = - cp_model.NewIntervalVar(x, y, LinearExpr(z).AddConstant(2)) - .WithName("interval"); + cp_model.NewIntervalVar(x, y, z + 2).WithName("interval"); LOG(INFO) << "start = " << interval_var.StartExpr() << ", size = " << interval_var.SizeExpr() << ", end = " << interval_var.EndExpr() diff --git a/ortools/sat/samples/multiple_knapsack_sat.cc b/ortools/sat/samples/multiple_knapsack_sat.cc index 3fe0cd013e..6eaf81f7e5 100644 --- a/ortools/sat/samples/multiple_knapsack_sat.cc +++ b/ortools/sat/samples/multiple_knapsack_sat.cc @@ -64,8 +64,7 @@ void MultipleKnapsackSat() { for (int i : all_items) { LinearExpr expr; for (int b : all_bins) { - auto key = std::make_tuple(i, b); - expr.AddTerm(x[key], 1); + expr += x[std::make_tuple(i, b)]; } cp_model.AddLessOrEqual(expr, 1); } @@ -74,8 +73,7 @@ void MultipleKnapsackSat() { for (int b : all_bins) { LinearExpr bin_weight; for (int i : all_items) { - auto key = std::make_tuple(i, b); - bin_weight.AddTerm(x[key], weights[i]); + bin_weight += x[std::make_tuple(i, b)] * weights[i]; } cp_model.AddLessOrEqual(bin_weight, bin_capacities[b]); } @@ -87,8 +85,7 @@ void MultipleKnapsackSat() { LinearExpr objective; for (int i : all_items) { for (int b : all_bins) { - auto key = std::make_tuple(i, b); - objective.AddTerm(x[key], values[i]); + objective += x[std::make_tuple(i, b)] * values[i]; } } cp_model.Maximize(objective); diff --git a/ortools/sat/samples/nqueens_sat.cc b/ortools/sat/samples/nqueens_sat.cc index c29151b93c..3307c2fe09 100644 --- a/ortools/sat/samples/nqueens_sat.cc +++ b/ortools/sat/samples/nqueens_sat.cc @@ -54,8 +54,8 @@ void NQueensSat(const int board_size) { std::vector diag_2; diag_2.reserve(board_size); for (int i = 0; i < board_size; ++i) { - diag_1.push_back(queens[i].AddConstant(i)); - diag_2.push_back(queens[i].AddConstant(-i)); + diag_1.push_back(queens[i] + i); + diag_2.push_back(queens[i] - i); } cp_model.AddAllDifferentExpr(diag_1); cp_model.AddAllDifferentExpr(diag_2); diff --git a/ortools/sat/samples/optional_interval_sample_sat.cc b/ortools/sat/samples/optional_interval_sample_sat.cc index 846e91cc3c..fd876e39b0 100644 --- a/ortools/sat/samples/optional_interval_sample_sat.cc +++ b/ortools/sat/samples/optional_interval_sample_sat.cc @@ -29,9 +29,7 @@ void OptionalIntervalSampleSat() { const BoolVar presence_var = cp_model.NewBoolVar().WithName("presence"); const IntervalVar interval_var = - cp_model - .NewOptionalIntervalVar(x, y, LinearExpr(z).AddConstant(2), - presence_var) + cp_model.NewOptionalIntervalVar(x, y, z + 2, presence_var) .WithName("interval"); LOG(INFO) << "start = " << interval_var.StartExpr() << ", size = " << interval_var.SizeExpr() diff --git a/ortools/sat/samples/rabbits_and_pheasants_sat.cc b/ortools/sat/samples/rabbits_and_pheasants_sat.cc index 226296581e..43ab382b20 100644 --- a/ortools/sat/samples/rabbits_and_pheasants_sat.cc +++ b/ortools/sat/samples/rabbits_and_pheasants_sat.cc @@ -24,8 +24,8 @@ void RabbitsAndPheasantsSat() { const IntVar pheasants = cp_model.NewIntVar(all_animals).WithName("pheasants"); - cp_model.AddEquality(LinearExpr::Sum({rabbits, pheasants}), 20); - cp_model.AddEquality(LinearExpr::ScalProd({rabbits, pheasants}, {4, 2}), 56); + cp_model.AddEquality(rabbits + pheasants, 20); + cp_model.AddEquality(4 * rabbits + 2 * pheasants, 56); const CpSolverResponse response = Solve(cp_model.Build()); diff --git a/ortools/sat/samples/ranking_sample_sat.cc b/ortools/sat/samples/ranking_sample_sat.cc index 260cf0ba9f..80cd15cd5f 100644 --- a/ortools/sat/samples/ranking_sample_sat.cc +++ b/ortools/sat/samples/ranking_sample_sat.cc @@ -67,7 +67,7 @@ void RankingSampleSat() { for (int i = 0; i < num_tasks; ++i) { LinearExpr sum_of_predecessors(-1); for (int j = 0; j < num_tasks; ++j) { - sum_of_predecessors.AddVar(precedences[j][i]); + sum_of_predecessors += precedences[j][i]; } cp_model.AddEquality(ranks[i], sum_of_predecessors); } @@ -116,10 +116,9 @@ void RankingSampleSat() { // Create objective: minimize 2 * makespan - 7 * sum of presences. // That is you gain 7 by interval performed, but you pay 2 by day of delays. - LinearExpr objective; - objective.AddTerm(makespan, 2); + LinearExpr objective = 2 * makespan; for (int t = 0; t < kNumTasks; ++t) { - objective.AddTerm(presences[t], -7); + objective -= 7 * presences[t]; } cp_model.Minimize(objective); diff --git a/ortools/sat/samples/schedule_requests_sat.cc b/ortools/sat/samples/schedule_requests_sat.cc index fd12879e04..da17997cf2 100644 --- a/ortools/sat/samples/schedule_requests_sat.cc +++ b/ortools/sat/samples/schedule_requests_sat.cc @@ -168,21 +168,18 @@ void ScheduleRequestsSat() { // [END assign_nurses_evenly] // [START objective] - std::vector tmp; + LinearExpr objective_expr; for (int n : all_nurses) { for (int d : all_days) { for (int s : all_shifts) { if (shift_requests[n][d][s] == 1) { auto key = std::make_tuple(n, d, s); - // tmp.push_back(shifts[key]); - tmp.push_back( - LinearExpr::ScalProd({shifts[key]}, {shift_requests[n][d][s]}) - .Var()); + objective_expr += shifts[key] * shift_requests[n][d][s]; } } } } - cp_model.Maximize(LinearExpr::Sum(tmp)); + cp_model.Maximize(objective_expr); // [END objective] // [START solve] diff --git a/ortools/sat/samples/solution_hinting_sample_sat.cc b/ortools/sat/samples/solution_hinting_sample_sat.cc index 155b168472..c7086dd7d6 100644 --- a/ortools/sat/samples/solution_hinting_sample_sat.cc +++ b/ortools/sat/samples/solution_hinting_sample_sat.cc @@ -35,7 +35,7 @@ void SolutionHintingSampleSat() { // [END constraints] // [START objective] - cp_model.Maximize(LinearExpr::ScalProd({x, y, z}, {1, 2, 3})); + cp_model.Maximize(x + 2 * y + 3 * z); // [END objective] // Solution hinting: x <- 1, y <- 2 diff --git a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.cc b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.cc index a6b689d238..eba1e2fde9 100644 --- a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.cc +++ b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.cc @@ -35,7 +35,7 @@ void SolveAndPrintIntermediateSolutionsSampleSat() { // [END constraints] // [START objective] - cp_model.Maximize(LinearExpr::ScalProd({x, y, z}, {1, 2, 3})); + cp_model.Maximize(x + 2 * y + 3 * z); // [END objective] // [START print_solution]