[FZ] use ortools_count_eq to reduce the size of fzn files

This commit is contained in:
Laurent Perron
2025-03-21 19:47:01 -07:00
parent 9f884f4118
commit 49056e7c60
3 changed files with 154 additions and 0 deletions

View File

@@ -1304,6 +1304,8 @@ CallMap CreateCallMap() {
m["ortools_array_var_bool_element"] = CheckOrtoolsArrayIntElement;
m["ortools_array_var_int_element"] = CheckOrtoolsArrayIntElement;
m["ortools_circuit"] = CheckCircuit;
m["ortools_count_eq"] = CheckCountEq;
m["ortools_count_eq_cst"] = CheckCountEq;
m["ortools_cumulative_opt"] = CheckCumulativeOpt;
m["ortools_disjunctive_strict_opt"] = CheckDisjunctiveStrictOpt;
m["ortools_inverse"] = CheckInverse;

View File

@@ -75,8 +75,15 @@ struct CpModelProtoWithMapping {
LinearExpressionProto LookupExprAt(const fz::Argument& argument, int pos,
bool negate = false);
std::vector<int> LookupVars(const fz::Argument& argument);
VarOrValue LookupVarOrValue(const fz::Argument& argument);
std::vector<VarOrValue> LookupVarsOrValues(const fz::Argument& argument);
// Get or create a literal that is equivalent tovar == value.
int GetOrCreateLiteralForVarEqValue(int var, int64_t value);
// Get or create a literal that is equivalent to var1 == var2.
int GetOrCreateLiteralForVarEqVar(int var1, int var2);
// Create and return the indices of the IntervalConstraint corresponding
// to the flatzinc "interval" specified by a start var and a size var.
// This method will cache intervals with the key <start, size>.
@@ -137,6 +144,8 @@ struct CpModelProtoWithMapping {
absl::flat_hash_map<std::tuple<int, int64_t, int, int64_t, int>, int>
interval_key_to_index;
absl::flat_hash_map<int, int> var_to_lit_implies_greater_than_zero;
absl::flat_hash_map<std::pair<int, int64_t>, int> var_eq_value_to_literal;
absl::flat_hash_map<std::pair<int, int>, int> var_eq_var_to_literal;
};
int CpModelProtoWithMapping::LookupConstant(int64_t value) {
@@ -205,6 +214,22 @@ std::vector<int> CpModelProtoWithMapping::LookupVars(
return result;
}
VarOrValue CpModelProtoWithMapping::LookupVarOrValue(
const fz::Argument& argument) {
if (argument.type == fz::Argument::INT_VALUE) {
return {kNoVar, argument.Value()};
} else {
CHECK_EQ(argument.type, fz::Argument::VAR_REF);
fz::Variable* var = argument.Var();
CHECK(var != nullptr);
if (var->domain.HasOneValue()) {
return {kNoVar, var->domain.Value()};
} else {
return {fz_var_to_index[var], 0};
}
}
}
std::vector<VarOrValue> CpModelProtoWithMapping::LookupVarsOrValues(
const fz::Argument& argument) {
std::vector<VarOrValue> result;
@@ -238,6 +263,71 @@ ConstraintProto* CpModelProtoWithMapping::AddEnforcedConstraint(int literal) {
return result;
}
int CpModelProtoWithMapping::GetOrCreateLiteralForVarEqValue(int var,
int64_t value) {
const std::pair<int, int64_t> key = {var, value};
const auto it = var_eq_value_to_literal.find(key);
if (it != var_eq_value_to_literal.end()) return it->second;
const int bool_var = proto.variables_size();
IntegerVariableProto* var_proto = proto.add_variables();
var_proto->add_domain(0);
var_proto->add_domain(1);
ConstraintProto* is_eq = AddEnforcedConstraint(TrueLiteral(bool_var));
is_eq->mutable_linear()->add_vars(var);
is_eq->mutable_linear()->add_coeffs(1);
is_eq->mutable_linear()->add_domain(value);
is_eq->mutable_linear()->add_domain(value);
ConstraintProto* is_not_eq = AddEnforcedConstraint(FalseLiteral(bool_var));
is_not_eq->mutable_linear()->add_vars(var);
is_not_eq->mutable_linear()->add_coeffs(1);
is_not_eq->mutable_linear()->add_domain(std::numeric_limits<int64_t>::min());
is_not_eq->mutable_linear()->add_domain(value - 1);
is_not_eq->mutable_linear()->add_domain(value + 1);
is_not_eq->mutable_linear()->add_domain(std::numeric_limits<int64_t>::max());
var_eq_value_to_literal[key] = bool_var;
return bool_var;
}
int CpModelProtoWithMapping::GetOrCreateLiteralForVarEqVar(int var1, int var2) {
CHECK_NE(var1, kNoVar);
CHECK_NE(var2, kNoVar);
if (var1 > var2) std::swap(var1, var2);
if (var1 == var2) return LookupConstant(1);
const std::pair<int, int> key = {var1, var2};
const auto it = var_eq_var_to_literal.find(key);
if (it != var_eq_var_to_literal.end()) return it->second;
const int bool_var = proto.variables_size();
IntegerVariableProto* var_proto = proto.add_variables();
var_proto->add_domain(0);
var_proto->add_domain(1);
ConstraintProto* is_eq = AddEnforcedConstraint(TrueLiteral(bool_var));
is_eq->mutable_linear()->add_vars(var1);
is_eq->mutable_linear()->add_coeffs(1);
is_eq->mutable_linear()->add_vars(var2);
is_eq->mutable_linear()->add_coeffs(-1);
is_eq->mutable_linear()->add_domain(0);
is_eq->mutable_linear()->add_domain(0);
ConstraintProto* is_not_eq = AddEnforcedConstraint(FalseLiteral(bool_var));
is_not_eq->mutable_linear()->add_vars(var1);
is_not_eq->mutable_linear()->add_coeffs(1);
is_not_eq->mutable_linear()->add_vars(var2);
is_not_eq->mutable_linear()->add_coeffs(-1);
is_not_eq->mutable_linear()->add_domain(std::numeric_limits<int64_t>::min());
is_not_eq->mutable_linear()->add_domain(-1);
is_not_eq->mutable_linear()->add_domain(1);
is_not_eq->mutable_linear()->add_domain(std::numeric_limits<int64_t>::max());
var_eq_var_to_literal[key] = bool_var;
return bool_var;
}
int CpModelProtoWithMapping::GetOrCreateOptionalInterval(VarOrValue start,
VarOrValue size,
int opt_var) {
@@ -801,6 +891,60 @@ void CpModelProtoWithMapping::FillConstraint(const fz::Constraint& fz_ct,
for (int i = 0; i < fz_ct.arguments[0].Size(); ++i) {
*arg->add_exprs() = LookupExprAt(fz_ct.arguments[0], i);
}
} else if (fz_ct.type == "ortools_count_eq_cst") {
const std::vector<VarOrValue> counts =
LookupVarsOrValues(fz_ct.arguments[0]);
const int64_t value = fz_ct.arguments[1].Value();
const VarOrValue target = LookupVarOrValue(fz_ct.arguments[2]);
LinearConstraintProto* arg = ct->mutable_linear();
int64_t fixed_contributions = 0;
for (const VarOrValue& count : counts) {
if (count.var == kNoVar) {
fixed_contributions += count.value == value ? 1 : 0;
} else {
const int boolvar = GetOrCreateLiteralForVarEqValue(count.var, value);
CHECK_GE(boolvar, 0);
arg->add_vars(boolvar);
arg->add_coeffs(1);
}
}
if (target.var == kNoVar) {
arg->add_domain(target.value - fixed_contributions);
arg->add_domain(target.value - fixed_contributions);
} else {
arg->add_vars(target.var);
arg->add_coeffs(-1);
arg->add_domain(-fixed_contributions);
arg->add_domain(-fixed_contributions);
}
} else if (fz_ct.type == "ortools_count_eq") {
const std::vector<VarOrValue> counts =
LookupVarsOrValues(fz_ct.arguments[0]);
const int var = LookupVar(fz_ct.arguments[1]);
const VarOrValue target = LookupVarOrValue(fz_ct.arguments[2]);
LinearConstraintProto* arg = ct->mutable_linear();
for (const VarOrValue& count : counts) {
if (count.var == kNoVar) {
const int boolvar = GetOrCreateLiteralForVarEqValue(var, count.value);
CHECK_GE(boolvar, 0);
arg->add_vars(boolvar);
arg->add_coeffs(1);
} else {
const int boolvar = GetOrCreateLiteralForVarEqVar(var, count.var);
CHECK_GE(boolvar, 0);
arg->add_vars(boolvar);
arg->add_coeffs(1);
}
}
if (target.var == kNoVar) {
arg->add_domain(target.value);
arg->add_domain(target.value);
} else {
arg->add_vars(target.var);
arg->add_coeffs(-1);
arg->add_domain(0);
arg->add_domain(0);
}
} else if (fz_ct.type == "ortools_circuit" ||
fz_ct.type == "ortools_subcircuit") {
const int64_t min_index = fz_ct.arguments[1].Value();

View File

@@ -0,0 +1,8 @@
predicate ortools_count_eq(array [int] of var int: x, var int: y, var int: c);
predicate ortools_count_eq_cst(array [int] of var int: x, int: y, var int: c);
predicate fzn_count_eq(array [int] of var int: x, int: y, var int: c) =
ortools_count_eq_cst(x, y, c);
predicate fzn_count_eq(array [int] of var int: x, var int: y, var int: c) =
ortools_count_eq(x, y, c);