math_opt: export from google3

This commit is contained in:
Corentin Le Molgat
2025-11-24 08:30:12 +01:00
parent 444331501f
commit d15a2e67e3
20 changed files with 501 additions and 118 deletions

View File

@@ -47,7 +47,8 @@ enum CallbackEventProto {
// node). Useful for early termination. Note that this event does not provide
// information on LP relaxations nor about new incumbent solutions.
//
// This event is supported for MIP models by SOLVER_TYPE_GUROBI only.
// This event is fully supported for MIP models by SOLVER_TYPE_GUROBI only. If
// used with SOLVER_TYPE_CP_SAT, it is called when the dual bound is improved.
CALLBACK_EVENT_MIP = 3;
// Called every time a new MIP incumbent is found.
@@ -127,7 +128,8 @@ message CallbackDataProto {
BarrierStats barrier_stats = 6;
// MIP B&B stats. Only available during CALLBACK_EVENT_MIPxxxx events.
// Not supported for CP-SAT.
// When using CP-SAT, only primal_bound, dual_bound and
// number_of_solutions_found are populated.
message MipStats {
optional double primal_bound = 1;
optional double dual_bound = 2;

View File

@@ -33,8 +33,9 @@ namespace operations_research::math_opt {
// The API of solvers (in-process, sub-process and streaming RPC ones).
//
// Thread-safety: methods Solve() and Update() must not be called concurrently;
// they should immediately return with an error status if this happens.
// Thread-safety: methods Solve(), ComputeInfeasibleSubsystem() and Update()
// must not be called concurrently; they should immediately return with an error
// status if this happens.
//
// TODO: b/350984134 - Rename `Solver` into `InProcessSolver` and then rename
// `BaseSolver` into `Solver`.
@@ -65,7 +66,14 @@ class BaseSolver {
// printed on stdout/stderr/logs anymore.
MessageCallback message_callback = nullptr;
// Registration parameter controlling calls to user_cb.
CallbackRegistrationProto callback_registration;
// An optional MIP/LP callback. Only called for events registered in
// callback_registration.
//
// Solve() returns an error if called without a user_cb but with some
// non-empty callback_registration.request_registration.
Callback user_cb = nullptr;
// An optional interrupter that the solver can use to interrupt the solve

View File

@@ -120,10 +120,16 @@ absl::StatusOr<SolveResultProto> Solver::Solve(const SolveArgs& arguments) {
ValidateModelSolveParameters(arguments.model_parameters, model_summary_))
<< "invalid model_parameters";
RETURN_IF_ERROR(ValidateCallbackRegistration(arguments.callback_registration,
model_summary_));
SolverInterface::Callback cb = nullptr;
if (!arguments.callback_registration.request_registration().empty() &&
arguments.user_cb == nullptr) {
return absl::InvalidArgumentError(
"no callback function was provided but callback events were "
"registered");
}
if (arguments.user_cb != nullptr) {
RETURN_IF_ERROR(ValidateCallbackRegistration(
arguments.callback_registration, model_summary_));
cb = [&](const CallbackDataProto& callback_data)
-> absl::StatusOr<CallbackResultProto> {
RETURN_IF_ERROR(ValidateCallbackDataProto(

View File

@@ -43,8 +43,9 @@ namespace math_opt {
//
// This interface is not meant to be used directly. The actual API is the one of
// the Solver class. The Solver class validates the models before calling this
// interface. It makes sure no concurrent calls happen on Solve(), CanUpdate()
// and Update(). It makes sure no other function is called after Solve(),
// interface. It makes sure no concurrent calls happen on Solve(),
// ComputeInfeasibleSubsystem(), CanUpdate() and Update(). It makes sure no
// other function is called after Solve(), ComputeInfeasibleSubsystem(),
// Update() or a callback have failed.
//
// Implementations of this interface should not have public constructors but
@@ -69,12 +70,28 @@ class SolverInterface {
// See Solver::MessageCallback documentation for details.
using MessageCallback = std::function<void(const std::vector<std::string>&)>;
// A callback function (if non null) is a function that validates its input
// and its output, and if fails, return a status. The invariant is that the
// solver implementation can rely on receiving valid data. The implementation
// of this interface must provide valid input (which will be validated) and
// in error, it will return a status (without actually calling the callback
// function). This is enforced in the solver.cc layer.
// A callback function (if non null) provided by the Solver class to its
// SolverInterface that wraps the user callback function
// (BaseSolver::Callback) and validates its inputs (provided by the
// SolverInterface implementation) and outputs (provided by the user). A
// failing status is returned if those inputs or outputs are invalid.
//
// To be clear the SolverInterface::Callback is implemented by the Solver
// class and looks like:
//
// absl::Status Callback(const CallbackDataProto& callback_data) {
// RETURN_IF_ERROR(ValidateCallbackDataProto(callback_data, ...));
// CallbackResultProto result = user_cb(callback_data);
// RETURN_IF_ERROR(ValidateCallbackResultProto(result));
// return result;
// }
//
// As a consequence SolverInterface implementations can rely on receiving a
// valid CallbackResultProto.
//
// When the SolverInterface::Callback returns an error the SolverInterface
// implementation must interrupt the Solve() as soon as possible and return
// this error.
using Callback = std::function<absl::StatusOr<CallbackResultProto>(
const CallbackDataProto&)>;
@@ -114,7 +131,11 @@ class SolverInterface {
// When parameter `message_cb` is not null and the underlying solver does not
// supports message callbacks, it should ignore it.
//
// Solvers should return a InvalidArgumentError when called with events on
// The parameter `cb` won't be null when
// callback_registration.request_registration is not empty (solver.cc will
// return an error in that case before calling SolverInterface::Solve()).
//
// Solvers should return an InvalidArgumentError when called with events on
// callback_registration that are not supported by the solver for the type of
// model being solved (for example MIP events if the model is an LP, or events
// that are not emitted by the solver). Solvers should use

View File

@@ -895,6 +895,7 @@ cc_library(
cc_library(
name = "incremental_solver",
srcs = ["incremental_solver.cc"],
hdrs = ["incremental_solver.h"],
deps = [
":compute_infeasible_subsystem_arguments",
@@ -903,10 +904,25 @@ cc_library(
":solve_arguments",
":solve_result",
":update_result",
"//ortools/base:status_macros",
"@abseil-cpp//absl/status:statusor",
],
)
cc_test(
name = "incremental_solver_test",
srcs = ["incremental_solver_test.cc"],
deps = [
":incremental_solver",
":matchers",
":math_opt",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings:string_view",
],
)
cc_library(
name = "remote_streaming_mode",
srcs = ["remote_streaming_mode.cc"],

View File

@@ -109,7 +109,9 @@ enum class CallbackEvent {
// node). Useful for early termination. Note that this event does not provide
// information on LP relaxations nor about new incumbent solutions.
//
// This event is supported for MIP models with SolverType::kGurobi only.
// This event is fully supported for MIP models with SolverType::kGurobi only.
// If used with SolverType::kCpSat, it is called when the dual bound is
// improved.
kMip = CALLBACK_EVENT_MIP,
// Called every time a new MIP incumbent is found.

View File

@@ -0,0 +1,34 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/cpp/incremental_solver.h"
#include "absl/status/statusor.h"
#include "ortools/base/status_macros.h"
namespace operations_research::math_opt {
absl::StatusOr<SolveResult> IncrementalSolver::Solve(
const SolveArguments& arguments) {
RETURN_IF_ERROR(Update().status());
return SolveWithoutUpdate(arguments);
}
absl::StatusOr<ComputeInfeasibleSubsystemResult>
IncrementalSolver::ComputeInfeasibleSubsystem(
const ComputeInfeasibleSubsystemArguments& arguments) {
RETURN_IF_ERROR(Update().status());
return ComputeInfeasibleSubsystemWithoutUpdate(arguments);
}
} // namespace operations_research::math_opt

View File

@@ -112,21 +112,14 @@ class IncrementalSolver {
//
// See callback.h for documentation on arguments.callback and
// arguments.callback_registration.
virtual absl::StatusOr<SolveResult> Solve(
const SolveArguments& arguments) = 0;
absl::StatusOr<SolveResult> Solve() { return Solve({}); }
absl::StatusOr<SolveResult> Solve(const SolveArguments& arguments = {});
// Updates the underlying solver with latest model changes and runs the
// computation.
//
// Same as Solve() but compute the infeasible subsystem.
virtual absl::StatusOr<ComputeInfeasibleSubsystemResult>
ComputeInfeasibleSubsystem(
const ComputeInfeasibleSubsystemArguments& arguments) = 0;
absl::StatusOr<ComputeInfeasibleSubsystemResult>
ComputeInfeasibleSubsystem() {
return ComputeInfeasibleSubsystem({});
}
absl::StatusOr<ComputeInfeasibleSubsystemResult> ComputeInfeasibleSubsystem(
const ComputeInfeasibleSubsystemArguments& arguments = {});
// Updates the model to solve.
//

View File

@@ -0,0 +1,121 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/cpp/incremental_solver.h"
#include <string>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/math_opt/cpp/matchers.h"
#include "ortools/math_opt/cpp/math_opt.h"
namespace operations_research::math_opt {
namespace {
using ::testing::_;
using ::testing::Return;
using ::testing::status::IsOkAndHolds;
using ::testing::status::StatusIs;
class MockIncrementalSolver final : public IncrementalSolver {
public:
MOCK_METHOD(absl::StatusOr<UpdateResult>, Update, (), (override));
MOCK_METHOD(absl::StatusOr<SolveResult>, SolveWithoutUpdate,
(const SolveArguments&), (const, override));
MOCK_METHOD(absl::StatusOr<ComputeInfeasibleSubsystemResult>,
ComputeInfeasibleSubsystemWithoutUpdate,
(const ComputeInfeasibleSubsystemArguments&), (const, override));
MOCK_METHOD(SolverType, solver_type, (), (const, override));
};
TEST(IncrementalSolverTest, SolveWithFailingUpdate) {
MockIncrementalSolver incremental_solver;
EXPECT_CALL(incremental_solver, Update())
.WillOnce(Return(absl::InternalError("oops")));
EXPECT_THAT(incremental_solver.Solve(),
StatusIs(absl::StatusCode::kInternal, "oops"));
}
TEST(IncrementalSolverTest, SolveWithFailingSolveWithoutUpdate) {
MockIncrementalSolver incremental_solver;
EXPECT_CALL(incremental_solver, Update())
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
EXPECT_CALL(incremental_solver, SolveWithoutUpdate(_))
.WillOnce(Return(absl::InternalError("oops")));
EXPECT_THAT(incremental_solver.Solve(),
StatusIs(absl::StatusCode::kInternal, "oops"));
}
TEST(IncrementalSolverTest, SuccessfulSolve) {
MockIncrementalSolver incremental_solver;
EXPECT_CALL(incremental_solver, Update())
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
constexpr double kObjectiveValue = 3.5;
constexpr absl::string_view kDetail = "found the optimum!";
EXPECT_CALL(incremental_solver, SolveWithoutUpdate(_))
.WillOnce(Return(
SolveResult(Termination::Optimal(/*objective_value=*/kObjectiveValue,
/*detail=*/std::string(kDetail)))));
ASSERT_OK_AND_ASSIGN(const SolveResult solve_result,
incremental_solver.Solve());
EXPECT_THAT(solve_result.termination,
TerminationIsOptimal(/*primal_objective_value=*/kObjectiveValue));
EXPECT_EQ(solve_result.termination.detail, kDetail);
}
TEST(IncrementalSolverTest, ComputeInfeasibleSubsystemWithFailingUpdate) {
MockIncrementalSolver incremental_solver;
EXPECT_CALL(incremental_solver, Update())
.WillOnce(Return(absl::InternalError("oops")));
EXPECT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
StatusIs(absl::StatusCode::kInternal, "oops"));
}
TEST(IncrementalSolverTest,
ComputeInfeasibleSubsystemWithFailingComputeWithoutUpdate) {
MockIncrementalSolver incremental_solver;
EXPECT_CALL(incremental_solver, Update())
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
EXPECT_CALL(incremental_solver, ComputeInfeasibleSubsystemWithoutUpdate(_))
.WillOnce(Return(absl::InternalError("oops")));
EXPECT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
StatusIs(absl::StatusCode::kInternal, "oops"));
}
TEST(IncrementalSolverTest, SuccessfulComputeInfeasibleSubsystem) {
MockIncrementalSolver incremental_solver;
EXPECT_CALL(incremental_solver, Update())
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
Model model;
const Variable v = model.AddBinaryVariable("v");
const ModelSubset model_subset = {
.variable_integrality = {v},
};
EXPECT_CALL(incremental_solver, ComputeInfeasibleSubsystemWithoutUpdate(_))
.WillOnce(Return(ComputeInfeasibleSubsystemResult{
.feasibility = FeasibilityStatus::kInfeasible,
.infeasible_subsystem = model_subset,
.is_minimal = false,
}));
ASSERT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
IsOkAndHolds(IsInfeasible(
/*expected_is_minimal=*/false,
/*expected_infeasible_subsystem=*/model_subset)));
}
} // namespace
} // namespace operations_research::math_opt

View File

@@ -209,8 +209,8 @@ class MapToDoubleMatcher
} // namespace
Matcher<VariableMap<double>> IsNearlySubsetOf(VariableMap<double> expected,
double tolerance) {
Matcher<VariableMap<double>> IsNearlySupersetOf(VariableMap<double> expected,
double tolerance) {
return Matcher<VariableMap<double>>(new MapToDoubleMatcher<Variable>(
std::move(expected), /*all_keys=*/false, tolerance));
}
@@ -221,7 +221,7 @@ Matcher<VariableMap<double>> IsNear(VariableMap<double> expected,
std::move(expected), /*all_keys=*/true, tolerance));
}
Matcher<LinearConstraintMap<double>> IsNearlySubsetOf(
Matcher<LinearConstraintMap<double>> IsNearlySupersetOf(
LinearConstraintMap<double> expected, double tolerance) {
return Matcher<LinearConstraintMap<double>>(
new MapToDoubleMatcher<LinearConstraint>(std::move(expected),
@@ -243,7 +243,7 @@ Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNear(
std::move(expected), /*all_keys=*/true, tolerance));
}
Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNearlySubsetOf(
Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNearlySupersetOf(
absl::flat_hash_map<QuadraticConstraint, double> expected,
double tolerance) {
return Matcher<absl::flat_hash_map<QuadraticConstraint, double>>(
@@ -260,7 +260,7 @@ Matcher<absl::flat_hash_map<K, double>> IsNear(
}
template <typename K>
Matcher<absl::flat_hash_map<K, double>> IsNearlySubsetOf(
Matcher<absl::flat_hash_map<K, double>> IsNearlySupersetOf(
absl::flat_hash_map<K, double> expected, const double tolerance) {
return Matcher<absl::flat_hash_map<K, double>>(new MapToDoubleMatcher<K>(
std::move(expected), /*all_keys=*/false, tolerance));

View File

@@ -121,11 +121,11 @@ constexpr double kMatcherDefaultTolerance = 1e-5;
testing::Matcher<VariableMap<double>> IsNear(
VariableMap<double> expected, double tolerance = kMatcherDefaultTolerance);
// Checks that the keys of actual are a subset of the keys of expected, and that
// for all shared keys, the values are within tolerance. This factory will
// Checks that the keys of actual are a superset of the keys of expected, and
// that for all shared keys, the values are within tolerance. This factory will
// CHECK-fail if expected contains any NaN values, and any NaN values in the
// expression compared against will result in the matcher failing.
testing::Matcher<VariableMap<double>> IsNearlySubsetOf(
testing::Matcher<VariableMap<double>> IsNearlySupersetOf(
VariableMap<double> expected, double tolerance = kMatcherDefaultTolerance);
// Checks that the maps have identical keys and values within tolerance. This
@@ -135,11 +135,11 @@ testing::Matcher<LinearConstraintMap<double>> IsNear(
LinearConstraintMap<double> expected,
double tolerance = kMatcherDefaultTolerance);
// Checks that the keys of actual are a subset of the keys of expected, and that
// for all shared keys, the values are within tolerance. This factory will
// Checks that the keys of actual are a superset of the keys of expected, and
// that for all shared keys, the values are within tolerance. This factory will
// CHECK-fail if expected contains any NaN values, and any NaN values in the
// expression compared against will result in the matcher failing.
testing::Matcher<LinearConstraintMap<double>> IsNearlySubsetOf(
testing::Matcher<LinearConstraintMap<double>> IsNearlySupersetOf(
LinearConstraintMap<double> expected,
double tolerance = kMatcherDefaultTolerance);
@@ -149,13 +149,13 @@ testing::Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNear(
absl::flat_hash_map<QuadraticConstraint, double> expected,
double tolerance = kMatcherDefaultTolerance);
// Checks that the keys of actual are a subset of the keys of expected, and that
// for all shared keys, the values are within tolerance. This factory will
// Checks that the keys of actual are a superset of the keys of expected, and
// that for all shared keys, the values are within tolerance. This factory will
// CHECK-fail if expected contains any NaN values, and any NaN values in the
// expression compared against will result in the matcher failing.
testing::Matcher<absl::flat_hash_map<QuadraticConstraint, double>>
IsNearlySubsetOf(absl::flat_hash_map<QuadraticConstraint, double> expected,
double tolerance = kMatcherDefaultTolerance);
IsNearlySupersetOf(absl::flat_hash_map<QuadraticConstraint, double> expected,
double tolerance = kMatcherDefaultTolerance);
////////////////////////////////////////////////////////////////////////////////
// Matchers for various Variable expressions (e.g. LinearExpression)

View File

@@ -115,20 +115,21 @@ TEST(ApproximateMapMatcherTest, VariableIsNear) {
EXPECT_THAT(actual, Not(IsNear({{z, -2.5}})));
}
TEST(ApproximateMapMatcherTest, VariableIsNearlySubsetOf) {
TEST(ApproximateMapMatcherTest, VariableIsNearlySupersetOf) {
Model model;
const Variable w = model.AddBinaryVariable("w");
const Variable x = model.AddBinaryVariable("x");
const Variable y = model.AddBinaryVariable("y");
const Variable z = model.AddBinaryVariable("z");
const VariableMap<double> actual = {{x, 2.0}, {y, 4.1}, {z, -2.5}};
EXPECT_THAT(actual, IsNearlySubsetOf(actual));
EXPECT_THAT(actual, IsNearlySubsetOf({{y, 4.1}, {z, -2.5}}));
EXPECT_THAT(actual, Not(IsNearlySubsetOf({{w, 1}, {y, 4.1}, {z, -2.5}})));
EXPECT_THAT(actual, Not(IsNearlySubsetOf({{y, 4.4}, {z, -2.5}})));
EXPECT_THAT(actual, IsNearlySupersetOf(actual));
EXPECT_THAT(actual, IsNearlySupersetOf({{y, 4.1}, {z, -2.5}}));
EXPECT_THAT(actual, Not(IsNearlySupersetOf({{w, 1}, {y, 4.1}, {z, -2.5}})));
EXPECT_THAT(actual, Not(IsNearlySupersetOf({{y, 4.4}, {z, -2.5}})));
}
TEST(ApproximateMapMatcherTest, QuadraticConstraintIsNearAndIsNearlySubsetOf) {
TEST(ApproximateMapMatcherTest,
QuadraticConstraintIsNearAndIsNearlySupersetOf) {
Model model;
const Variable x = model.AddBinaryVariable("x");
const QuadraticConstraint c = model.AddQuadraticConstraint(x * x <= 0, "c");
@@ -137,29 +138,29 @@ TEST(ApproximateMapMatcherTest, QuadraticConstraintIsNearAndIsNearlySubsetOf) {
const absl::flat_hash_map<QuadraticConstraint, double> actual = {{c, 2},
{e, 5}};
EXPECT_THAT(actual, IsNearlySubsetOf(actual));
EXPECT_THAT(actual, IsNearlySupersetOf(actual));
EXPECT_THAT(actual, IsNear(actual));
EXPECT_THAT(actual, IsNear({{c, 2 + 1e-8}, {e, 5}}));
EXPECT_THAT(actual, Not(IsNear({{e, 5}})));
EXPECT_THAT(actual, Not(IsNear({{c, 2 + 1e-2}, {e, 5}})));
EXPECT_THAT(actual, Not(IsNear({{d, 5}})));
EXPECT_THAT(actual, IsNearlySubsetOf({{e, 5}}));
EXPECT_THAT(actual, IsNearlySupersetOf({{e, 5}}));
}
TEST(ApproximateMapMatcherTest, LinearConstraintIsNearAndIsNearlySubsetOf) {
TEST(ApproximateMapMatcherTest, LinearConstraintIsNearAndIsNearlySupersetOf) {
Model model;
const LinearConstraint c = model.AddLinearConstraint("c");
const LinearConstraint d = model.AddLinearConstraint("d");
const LinearConstraint e = model.AddLinearConstraint("e");
const LinearConstraintMap<double> actual = {{c, 2}, {e, 5}};
EXPECT_THAT(actual, IsNearlySubsetOf(actual));
EXPECT_THAT(actual, IsNearlySupersetOf(actual));
EXPECT_THAT(actual, IsNear(actual));
EXPECT_THAT(actual, IsNear({{c, 2 + 1e-8}, {e, 5}}));
EXPECT_THAT(actual, Not(IsNear({{e, 5}})));
EXPECT_THAT(actual, Not(IsNear({{c, 2 + 1e-2}, {e, 5}})));
EXPECT_THAT(actual, Not(IsNear({{d, 5}})));
EXPECT_THAT(actual, IsNearlySubsetOf({{e, 5}}));
EXPECT_THAT(actual, IsNearlySupersetOf({{e, 5}}));
}
TEST(LinearExpressionMatcherTest, IsIdentical) {

View File

@@ -193,21 +193,6 @@ IncrementalSolverImpl::IncrementalSolverImpl(
update_tracker_(std::move(update_tracker)),
solver_(std::move(solver)) {}
absl::StatusOr<SolveResult> IncrementalSolverImpl::Solve(
const SolveArguments& arguments) {
// TODO: b/260337466 - Add permanent errors and concurrency protection.
RETURN_IF_ERROR(Update().status());
return SolveWithoutUpdate(arguments);
}
absl::StatusOr<ComputeInfeasibleSubsystemResult>
IncrementalSolverImpl::ComputeInfeasibleSubsystem(
const ComputeInfeasibleSubsystemArguments& arguments) {
// TODO: b/260337466 - Add permanent errors and concurrency protection.
RETURN_IF_ERROR(Update().status());
return ComputeInfeasibleSubsystemWithoutUpdate(arguments);
}
absl::StatusOr<UpdateResult> IncrementalSolverImpl::Update() {
// TODO: b/260337466 - Add permanent errors and concurrency protection.
ASSIGN_OR_RETURN(std::optional<ModelUpdateProto> model_update,

View File

@@ -81,11 +81,6 @@ class IncrementalSolverImpl : public IncrementalSolver {
BaseSolverFactory solver_factory, Model* model, SolverType solver_type,
const SolveInterrupter* user_canceller, bool remove_names);
absl::StatusOr<SolveResult> Solve(const SolveArguments& arguments) override;
absl::StatusOr<ComputeInfeasibleSubsystemResult> ComputeInfeasibleSubsystem(
const ComputeInfeasibleSubsystemArguments& arguments) override;
absl::StatusOr<UpdateResult> Update() override;
absl::StatusOr<SolveResult> SolveWithoutUpdate(

View File

@@ -13,6 +13,7 @@
# limitations under the License.
"""Defines how to request a callback and the input and output of a callback."""
import dataclasses
import datetime
import enum
@@ -36,7 +37,8 @@ class Event(enum.Enum):
* MIP: The solver is in the MIP loop (called periodically before starting a
new node). Useful for early termination. Note that this event does not
provide information on LP relaxations nor about new incumbent solutions.
Gurobi only.
Fully supported by Gurobi only. If used with CP-SAT, it is called when the
dual bound is improved.
* MIP_SOLUTION: Called every time a new MIP incumbent is found. Fully
supported by Gurobi, partially supported by CP-SAT (you can observe new
solutions, but not add lazy constraints).

View File

@@ -52,6 +52,7 @@ cc_library(
"//ortools/math_opt/io:mps_converter",
"//ortools/port:proto_utils",
"//ortools/port:scoped_std_stream_capture",
"//ortools/util:fp_roundtrip_conv",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/status",

View File

@@ -13,6 +13,8 @@
#include "ortools/math_opt/solver_tests/callback_tests.h"
#include <atomic>
#include <cmath>
#include <limits>
#include <memory>
#include <optional>
@@ -44,6 +46,7 @@
#include "ortools/math_opt/solver_tests/test_models.h"
#include "ortools/port/proto_utils.h"
#include "ortools/port/scoped_std_stream_capture.h"
#include "ortools/util/fp_roundtrip_conv.h"
namespace operations_research {
namespace math_opt {
@@ -418,8 +421,8 @@ TEST_P(CallbackTest, EventSolutionAlwaysCalled) {
SolveArguments args = {
.callback_registration = {.events = {CallbackEvent::kMipSolution}}};
absl::Mutex mutex;
bool cb_called = false;
bool cb_called_on_optimal = false;
std::atomic<bool> cb_called = false;
std::atomic<bool> cb_called_on_optimal = false;
args.callback = [&](const CallbackData& callback_data) {
const absl::MutexLock lock(mutex);
cb_called = true;
@@ -433,6 +436,8 @@ TEST_P(CallbackTest, EventSolutionAlwaysCalled) {
EXPECT_THAT(
sol, AnyOf(IsNear({{x, 0.0}, {y, 0.0}}), IsNear({{x, 1.0}, {y, 0.0}}),
IsNear({{x, 0.0}, {y, 1.0}})));
EXPECT_LE(callback_data.mip_stats.primal_bound(), 2.05);
EXPECT_GE(callback_data.mip_stats.dual_bound(), 1.95);
if (gtl::FindWithDefault(sol, y) > 0.5) {
cb_called_on_optimal = true;
}
@@ -646,8 +651,8 @@ TEST_P(CallbackTest, EventSolutionFilter) {
.events = {CallbackEvent::kMipSolution},
.mip_solution_filter = MakeKeepKeysFilter({y})}};
absl::Mutex mutex;
bool cb_called = false;
bool cb_called_on_optimal = false;
std::atomic<bool> cb_called = false;
std::atomic<bool> cb_called_on_optimal = false;
args.callback = [&](const CallbackData& callback_data) {
const absl::MutexLock lock(mutex);
cb_called = true;
@@ -795,6 +800,47 @@ TEST_P(CallbackTest, EventNodeFilter) {
EXPECT_THAT(solutions, Each(UnorderedElementsAre(Pair(x0, _), Pair(x2, _))));
}
TEST_P(CallbackTest, EventMip) {
if (!GetParam().supported_events.contains(CallbackEvent::kMip)) {
GTEST_SKIP() << "Test skipped because this solver does not support "
"CallbackEvent::kMip.";
}
// This test must use integer variables.
ASSERT_TRUE(GetParam().integer_variables);
// Use the MIPLIB instance 23588, which has optimal solution 8090 and LP
// relaxation of 7649.87. This instance was selected because every
// supported solver can solve it quickly (a few seconds), but no solver can
// solve it in one node (so the node callback will be invoked).
ASSERT_OK_AND_ASSIGN(const std::unique_ptr<Model> model,
LoadMiplibInstance("23588"));
std::atomic<double> best_primal_bound =
std::numeric_limits<double>::infinity();
std::atomic<double> best_dual_bound =
-std::numeric_limits<double>::infinity();
const SolveArguments args = {
.callback_registration = {.events = {CallbackEvent::kMip}},
.callback = [&](const CallbackData& callback_data) {
CHECK_EQ(callback_data.event, CallbackEvent::kMip);
const double primal_bound = callback_data.mip_stats.primal_bound();
const double dual_bound = callback_data.mip_stats.dual_bound();
best_primal_bound = std::fmin(best_primal_bound, primal_bound);
best_dual_bound = std::fmax(best_dual_bound, dual_bound);
return CallbackResult();
}};
EXPECT_THAT(Solve(*model, GetParam().solver_type, args),
IsOkAndHolds(IsOptimal(8090)));
LOG(INFO) << "best_primal_bound: "
<< RoundTripDoubleFormat(best_primal_bound.load());
LOG(INFO) << "best_dual_bound: "
<< RoundTripDoubleFormat(best_dual_bound.load());
EXPECT_THAT(best_primal_bound.load(), testing::DoubleNear(8090, 0.5));
EXPECT_LE(best_dual_bound.load(), 8090.5);
EXPECT_GE(best_dual_bound.load(), 7640);
}
TEST_P(CallbackTest, StatusPropagation) {
if (!GetParam().supported_events.contains(CallbackEvent::kMipSolution)) {
GTEST_SKIP() << "Test skipped because this solver does not support "

View File

@@ -238,13 +238,17 @@ cc_library(
"//ortools/port:proto_utils",
"//ortools/sat:sat_parameters_cc_proto",
"//ortools/util:solve_interrupter",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/base:nullability",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/functional:any_invocable",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/memory",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/synchronization",
"@abseil-cpp//absl/time",
"@abseil-cpp//absl/types:span",
],

View File

@@ -14,15 +14,21 @@
#include "ortools/math_opt/solvers/cp_sat_solver.h"
#include <atomic>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
@@ -33,6 +39,7 @@
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
@@ -65,6 +72,8 @@ namespace math_opt {
namespace {
constexpr double kInf = std::numeric_limits<double>::infinity();
constexpr SupportedProblemStructures kCpSatSupportedStructures = {
.integer_variables = SupportType::kSupported,
.quadratic_objectives = SupportType::kNotImplemented,
@@ -316,6 +325,162 @@ absl::StatusOr<TerminationProto> GetTermination(
absl::StrCat("unimplemented solve status: ", response.status()));
}
// This class gathers the solution callback and best bound callback together
// with some solver state that we need to update as the solver progresses.
class CpSatCallbacks {
public:
CpSatCallbacks(const absl_nullable SolverInterface::Callback& cb
ABSL_ATTRIBUTE_LIFETIME_BOUND,
SolveInterrupter* absl_nonnull local_interrupter
ABSL_ATTRIBUTE_LIFETIME_BOUND,
absl_nonnull absl::AnyInvocable<
SparseDoubleVectorProto(absl::Span<const double>) const>
extract_solution,
absl::flat_hash_set<CallbackEventProto> events,
bool is_maximize);
// CpSatCallbacks is neither copyable nor movable as callbacks point to it.
CpSatCallbacks(const CpSatCallbacks&) = delete;
CpSatCallbacks& operator=(const CpSatCallbacks&) = delete;
// Returns a solution callback that wraps the user callback and updates the
// state of CpSatCallbacks. Returns nullptr if it is not needed.
absl_nullable std::function<void(const MPSolution&)> MakeSolutionCallback();
// Returns a best bound callback that wraps the user callback and updates the
// state of CpSatCallbacks. Returns nullptr if it is not needed.
absl_nullable std::function<void(const double)> MakeBestBoundCallback();
absl::Status error() const {
absl::MutexLock lock(mutex_);
return error_;
}
private:
void ExecuteCallback(const CallbackDataProto& cb_data);
void UpdateMipStatsFromNewSolution(const MPSolution& mp_solution)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
const SolverInterface::Callback& cb_;
SolveInterrupter* absl_nonnull const local_interrupter_;
const absl::AnyInvocable<SparseDoubleVectorProto(absl::Span<const double>)
const>
extract_solution_;
const bool has_mip_solution_event_;
const bool has_mip_event_;
const bool is_maximize_;
mutable absl::Mutex mutex_;
absl::Status error_ ABSL_GUARDED_BY(mutex_) = absl::OkStatus();
CallbackDataProto::MipStats current_mip_stats_ ABSL_GUARDED_BY(mutex_);
};
CpSatCallbacks::CpSatCallbacks(
const SolverInterface::Callback& cb ABSL_ATTRIBUTE_LIFETIME_BOUND,
SolveInterrupter* absl_nonnull local_interrupter
ABSL_ATTRIBUTE_LIFETIME_BOUND,
absl_nonnull
absl::AnyInvocable<SparseDoubleVectorProto(absl::Span<const double>) const>
extract_solution ABSL_ATTRIBUTE_LIFETIME_BOUND,
absl::flat_hash_set<CallbackEventProto> events, const bool is_maximize)
: cb_(cb),
local_interrupter_(local_interrupter),
extract_solution_(std::move(extract_solution)),
// If there is no user callback, we make sure not calling it.
has_mip_solution_event_(cb != nullptr &&
events.contains(CALLBACK_EVENT_MIP_SOLUTION)),
has_mip_event_(cb != nullptr && events.contains(CALLBACK_EVENT_MIP)),
is_maximize_(is_maximize) {
current_mip_stats_.set_primal_bound(is_maximize ? -kInf : kInf);
current_mip_stats_.set_dual_bound(is_maximize ? kInf : -kInf);
current_mip_stats_.set_number_of_solutions_found(0);
}
std::function<void(const MPSolution&)> absl_nullable
CpSatCallbacks::MakeSolutionCallback() {
if (!has_mip_solution_event_ && !has_mip_event_) {
return nullptr;
}
if (!has_mip_solution_event_) {
return [this](const MPSolution& mp_solution) {
absl::MutexLock lock(mutex_);
UpdateMipStatsFromNewSolution(mp_solution);
};
}
return [this](const MPSolution& mp_solution) {
CallbackDataProto cb_data;
cb_data.set_event(CALLBACK_EVENT_MIP_SOLUTION);
*cb_data.mutable_primal_solution_vector() =
extract_solution_(mp_solution.variable_value());
{
absl::MutexLock lock(mutex_);
UpdateMipStatsFromNewSolution(mp_solution);
*cb_data.mutable_mip_stats() = current_mip_stats_;
}
ExecuteCallback(cb_data);
};
}
std::function<void(const double)> absl_nullable
CpSatCallbacks::MakeBestBoundCallback() {
if (!has_mip_solution_event_ && !has_mip_event_) {
return nullptr;
}
if (!has_mip_event_) {
return [this](const double best_bound) {
absl::MutexLock lock(mutex_);
current_mip_stats_.set_dual_bound(best_bound);
};
}
return [this](const double best_bound) {
CallbackDataProto cb_data;
cb_data.set_event(CALLBACK_EVENT_MIP);
{
absl::MutexLock lock(mutex_);
current_mip_stats_.set_dual_bound(best_bound);
*cb_data.mutable_mip_stats() = current_mip_stats_;
}
ExecuteCallback(cb_data);
};
}
void CpSatCallbacks::ExecuteCallback(const CallbackDataProto& cb_data) {
{
absl::MutexLock lock(mutex_);
if (!error_.ok()) {
// A previous callback failed.
return;
}
}
const absl::StatusOr<CallbackResultProto> cb_result = cb_(cb_data);
// Note cb_result.cuts and cb_result.suggested solutions are not supported
// by CP-SAT and we have validated they are empty.
if (!cb_result.ok()) {
{
absl::MutexLock lock(mutex_);
error_ = cb_result.status();
}
// Note: we will be returning a status error, we do not need to worry
// about interpreting this as TERMINATION_REASON_INTERRUPTED.
local_interrupter_->Interrupt();
} else if (cb_result->terminate()) {
local_interrupter_->Interrupt();
}
}
void CpSatCallbacks::UpdateMipStatsFromNewSolution(
const MPSolution& mp_solution) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
if (is_maximize_) {
current_mip_stats_.set_primal_bound(std::fmax(
current_mip_stats_.primal_bound(), mp_solution.objective_value()));
} else {
current_mip_stats_.set_primal_bound(std::fmin(
current_mip_stats_.primal_bound(), mp_solution.objective_value()));
}
current_mip_stats_.set_number_of_solutions_found(
current_mip_stats_.number_of_solutions_found() + 1);
}
} // namespace
absl::StatusOr<std::unique_ptr<SolverInterface>> CpSatSolver::New(
@@ -345,7 +510,7 @@ absl::StatusOr<SolveResultProto> CpSatSolver::Solve(
RETURN_IF_ERROR(CheckRegisteredCallbackEvents(
callback_registration,
/*supported_events=*/{CALLBACK_EVENT_MIP_SOLUTION}));
/*supported_events=*/{CALLBACK_EVENT_MIP_SOLUTION, CALLBACK_EVENT_MIP}));
if (callback_registration.add_lazy_constraints()) {
return absl::InvalidArgumentError(
"CallbackRegistrationProto.add_lazy_constraints=true is not supported "
@@ -392,7 +557,7 @@ absl::StatusOr<SolveResultProto> CpSatSolver::Solve(
}
// We need to chain the user interrupter through a local interrupter, because
// if we termiante early from a callback request, we don't want to incorrectly
// if we terminate early from a callback request, we don't want to incorrectly
// modify the input state.
SolveInterrupter local_interrupter;
std::atomic<bool> interrupt_solve = false;
@@ -411,41 +576,21 @@ absl::StatusOr<SolveResultProto> CpSatSolver::Solve(
const absl::flat_hash_set<CallbackEventProto> events =
EventSet(callback_registration);
std::function<void(const MPSolution&)> solution_callback;
absl::Status callback_error = absl::OkStatus();
if (events.contains(CALLBACK_EVENT_MIP_SOLUTION)) {
solution_callback =
[this, &cb, &callback_error, &local_interrupter,
&callback_registration](const MPSolution& mp_solution) {
if (!callback_error.ok()) {
// A previous callback failed.
return;
}
CallbackDataProto cb_data;
cb_data.set_event(CALLBACK_EVENT_MIP_SOLUTION);
*cb_data.mutable_primal_solution_vector() =
ExtractSolution(mp_solution.variable_value(),
callback_registration.mip_solution_filter());
const absl::StatusOr<CallbackResultProto> cb_result = cb(cb_data);
if (!cb_result.ok()) {
callback_error = cb_result.status();
// Note: we will be returning a status error, we do not need to
// worry about interpreting this as TERMINATION_REASON_INTERRUPTED.
local_interrupter.Interrupt();
} else if (cb_result->terminate()) {
local_interrupter.Interrupt();
}
// Note cb_result.cuts and cb_result.suggested solutions are not
// supported by CP-SAT and we have validated they are empty.
};
}
absl::AnyInvocable<SparseDoubleVectorProto(absl::Span<const double>) const>
extract_solution = [&](absl::Span<const double> cp_sat_variable_values) {
return ExtractSolution(cp_sat_variable_values,
callback_registration.mip_solution_filter());
};
CpSatCallbacks callbacks(cb, &local_interrupter, std::move(extract_solution),
events, cp_sat_model_.maximize());
// CP-SAT returns "infeasible" for inverted bounds.
RETURN_IF_ERROR(ListInvertedBounds().ToStatus());
const MPSolutionResponse response = SatSolveProto(
std::move(req), &interrupt_solve, logging_callback, solution_callback);
RETURN_IF_ERROR(callback_error) << "error in callback";
std::move(req), &interrupt_solve, logging_callback,
callbacks.MakeSolutionCallback(), callbacks.MakeBestBoundCallback());
RETURN_IF_ERROR(callbacks.error()) << "error in callback";
ASSIGN_OR_RETURN(*result.mutable_termination(),
GetTermination(local_interrupter.IsInterrupted(),
/*maximize=*/cp_sat_model_.maximize(),

View File

@@ -328,15 +328,16 @@ SolveParameters AllSolutions() {
return result;
}
INSTANTIATE_TEST_SUITE_P(CpSatCallbackTest, CallbackTest,
Values(CallbackTestParams(
SolverType::kCpSat,
/*integer_variables=*/true,
/*add_lazy_constraints=*/false,
/*add_cuts=*/false,
/*supported_events=*/{CallbackEvent::kMipSolution},
/*all_solutions=*/AllSolutions(),
/*reaches_cut_callback=*/std::nullopt)));
INSTANTIATE_TEST_SUITE_P(
CpSatCallbackTest, CallbackTest,
Values(CallbackTestParams(
SolverType::kCpSat,
/*integer_variables=*/true,
/*add_lazy_constraints=*/false,
/*add_cuts=*/false,
/*supported_events=*/{CallbackEvent::kMipSolution, CallbackEvent::kMip},
/*all_solutions=*/AllSolutions(),
/*reaches_cut_callback=*/std::nullopt)));
TEST(CpSatInvalidCallbackTest, RequestLazyConstraints) {
Model model("model");