[CP-SAT] support linear expressions in decision strategies

This commit is contained in:
Laurent Perron
2023-12-08 14:49:43 +01:00
parent 5f52c87b50
commit 3219d12658
9 changed files with 109 additions and 16 deletions

View File

@@ -983,13 +983,24 @@ public final class CpModel {
// DecisionStrategy
/** Adds {@code DecisionStrategy(variables, varStr, domStr)}. */
public void addDecisionStrategy(IntVar[] variables,
/** Adds {@code DecisionStrategy(expressions, varStr, domStr)}. */
public void addDecisionStrategy(LinearArgument[] expressions,
DecisionStrategyProto.VariableSelectionStrategy varStr,
DecisionStrategyProto.DomainReductionStrategy domStr) {
DecisionStrategyProto.Builder ds = modelBuilder.addSearchStrategyBuilder();
for (IntVar var : variables) {
ds.addVariables(var.getIndex());
for (LinearArgument arg : expressions) {
ds.addExprs(getLinearExpressionProtoBuilderFromLinearArgument(arg, /* negate= */ false));
}
ds.setVariableSelectionStrategy(varStr).setDomainReductionStrategy(domStr);
}
/** Adds {@code DecisionStrategy(expressions, varStr, domStr)}. */
public void addDecisionStrategy(Iterable<? extends LinearArgument> expressions,
DecisionStrategyProto.VariableSelectionStrategy varStr,
DecisionStrategyProto.DomainReductionStrategy domStr) {
DecisionStrategyProto.Builder ds = modelBuilder.addSearchStrategyBuilder();
for (LinearArgument arg : expressions) {
ds.addExprs(getLinearExpressionProtoBuilderFromLinearArgument(arg, /* negate= */ false));
}
ds.setVariableSelectionStrategy(varStr).setDomainReductionStrategy(domStr);
}

View File

@@ -1243,19 +1243,39 @@ void CpModelBuilder::AddDecisionStrategy(
DecisionStrategyProto::DomainReductionStrategy domain_strategy) {
DecisionStrategyProto* const proto = cp_model_.add_search_strategy();
for (const IntVar& var : variables) {
proto->add_variables(var.index_);
LinearExpressionProto* expr = proto->add_exprs();
if (var.index_ >= 0) {
expr->add_vars(var.index_);
expr->add_coeffs(1);
} else {
expr->add_vars(PositiveRef(var.index_));
expr->add_coeffs(-1);
expr->set_offset(1);
}
}
proto->set_variable_selection_strategy(var_strategy);
proto->set_domain_reduction_strategy(domain_strategy);
}
void CpModelBuilder::AddDecisionStrategy(
absl::Span<const BoolVar> variables,
absl::Span<const LinearExpr> expressions,
DecisionStrategyProto::VariableSelectionStrategy var_strategy,
DecisionStrategyProto::DomainReductionStrategy domain_strategy) {
DecisionStrategyProto* const proto = cp_model_.add_search_strategy();
for (const BoolVar& var : variables) {
proto->add_variables(var.index_);
for (const LinearExpr& expr : expressions) {
*proto->add_exprs() = LinearExprToProto(expr);
}
proto->set_variable_selection_strategy(var_strategy);
proto->set_domain_reduction_strategy(domain_strategy);
}
void CpModelBuilder::AddDecisionStrategy(
std::initializer_list<LinearExpr> expressions,
DecisionStrategyProto::VariableSelectionStrategy var_strategy,
DecisionStrategyProto::DomainReductionStrategy domain_strategy) {
DecisionStrategyProto* const proto = cp_model_.add_search_strategy();
for (const LinearExpr& expr : expressions) {
*proto->add_exprs() = LinearExprToProto(expr);
}
proto->set_variable_selection_strategy(var_strategy);
proto->set_domain_reduction_strategy(domain_strategy);

View File

@@ -1066,9 +1066,15 @@ class CpModelBuilder {
DecisionStrategyProto::VariableSelectionStrategy var_strategy,
DecisionStrategyProto::DomainReductionStrategy domain_strategy);
/// Adds a decision strategy on a list of boolean variables.
/// Adds a decision strategy on a list of affine expressions.
void AddDecisionStrategy(
absl::Span<const BoolVar> variables,
absl::Span<const LinearExpr> expressions,
DecisionStrategyProto::VariableSelectionStrategy var_strategy,
DecisionStrategyProto::DomainReductionStrategy domain_strategy);
/// Adds a decision strategy on a list of affine expressions.
void AddDecisionStrategy(
std::initializer_list<LinearExpr> expressions,
DecisionStrategyProto::VariableSelectionStrategy var_strategy,
DecisionStrategyProto::DomainReductionStrategy domain_strategy);

View File

@@ -1026,7 +1026,10 @@ public class CpModel
ds.Variables.TrySetCapacity(vars);
foreach (IntVar var in vars)
{
ds.Variables.Add(var.Index);
LinearExpressionProto expr = new LinearExpressionProto();
expr.Vars.add(var.Index);
expr.Coeffs.add(1);
ds.Exprs.Add(expr);
}
ds.VariableSelectionStrategy = var_str;
ds.DomainReductionStrategy = dom_str;

View File

@@ -16,6 +16,7 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

View File

@@ -18,6 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
import com.google.ortools.Loader;
import com.google.ortools.sat.CpSolverStatus;
import com.google.ortools.sat.DecisionStrategyProto;
import com.google.ortools.sat.LinearArgumentProto;
import com.google.ortools.util.Domain;
import java.util.ArrayList;
@@ -470,6 +471,40 @@ public final class CpModelTest {
assertThat(model.model().getConstraints(2).getNoOverlap2D().getYIntervalsCount()).isEqualTo(2);
}
@Test
public void testCpModelAddDecisionStrategy() throws Exception {
final CpModel model = new CpModel();
assertNotNull(model);
final Literal x1 = model.newBoolVar("x1");
final Literal x2 = model.newBoolVar("x2");
final IntVar x3 = model.newIntVar(0, 2, "x3");
model.addDecisionStrategy(new LinearArgument[] {x1, x2.not(), x3},
DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_FIRST,
DecisionStrategyProto.DomainReductionStrategy.SELECT_MIN_VALUE);
assertThat(model.model().getSearchStrategyCount()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprsCount()).isEqualTo(3);
assertThat(model.model().getSearchStrategy(0).getExprs(0).getVarsCount()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(0).getCoeffsCount()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(0).getVars(0)).isEqualTo(x1.getIndex());
assertThat(model.model().getSearchStrategy(0).getExprs(0).getCoeffs(0)).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(0).getOffset()).isEqualTo(0);
assertThat(model.model().getSearchStrategy(0).getExprs(1).getVarsCount()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(1).getCoeffsCount()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(1).getVars(0)).isEqualTo(x2.getIndex());
assertThat(model.model().getSearchStrategy(0).getExprs(1).getCoeffs(0)).isEqualTo(-1);
assertThat(model.model().getSearchStrategy(0).getExprs(1).getOffset()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(2).getVarsCount()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(2).getCoeffsCount()).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(2).getVars(0)).isEqualTo(x3.getIndex());
assertThat(model.model().getSearchStrategy(0).getExprs(2).getCoeffs(0)).isEqualTo(1);
assertThat(model.model().getSearchStrategy(0).getExprs(2).getOffset()).isEqualTo(0);
}
@Test
public void testCpModelModelStats() throws Exception {
final CpModel model = new CpModel();

View File

@@ -2860,7 +2860,15 @@ class CpModel:
strategy = self.__model.search_strategy.add()
for v in variables:
strategy.variables.append(v.index)
expr = strategy.exprs.add()
if v.index >= 0:
expr.vars.append(v.index)
expr.coeffs.append(1)
else:
expr.vars.append(self.negated(v.index))
expr.coeffs.append(-1)
expr.offset = 1
strategy.variable_selection_strategy = var_strategy
strategy.domain_reduction_strategy = domain_strategy

View File

@@ -1314,14 +1314,22 @@ class CpModelTest(absltest.TestCase):
model = cp_model.CpModel()
x = model.new_int_var(0, 5, "x")
y = model.new_int_var(0, 5, "y")
z = model.new_bool_var("z")
model.add_decision_strategy(
[y, x], cp_model.CHOOSE_MIN_DOMAIN_SIZE, cp_model.SELECT_MAX_VALUE
[y, x, z.negated()],
cp_model.CHOOSE_MIN_DOMAIN_SIZE,
cp_model.SELECT_MAX_VALUE,
)
self.assertLen(model.proto.search_strategy, 1)
strategy = model.proto.search_strategy[0]
self.assertLen(strategy.variables, 2)
self.assertEqual(y.index, strategy.variables[0])
self.assertEqual(x.index, strategy.variables[1])
self.assertLen(strategy.exprs, 3)
self.assertEqual(y.index, strategy.exprs[0].vars[0])
self.assertEqual(1, strategy.exprs[0].coeffs[0])
self.assertEqual(x.index, strategy.exprs[1].vars[0])
self.assertEqual(1, strategy.exprs[1].coeffs[0])
self.assertEqual(z.index, strategy.exprs[2].vars[0])
self.assertEqual(-1, strategy.exprs[2].coeffs[0])
self.assertEqual(1, strategy.exprs[2].offset)
self.assertEqual(
cp_model.CHOOSE_MIN_DOMAIN_SIZE, strategy.variable_selection_strategy
)

View File

@@ -15,6 +15,7 @@
#define OR_TOOLS_SAT_STAT_TABLES_H_
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"