math_opt: export from google3
This commit is contained in:
@@ -47,7 +47,8 @@ enum CallbackEventProto {
|
||||
// node). Useful for early termination. Note that this event does not provide
|
||||
// information on LP relaxations nor about new incumbent solutions.
|
||||
//
|
||||
// This event is supported for MIP models by SOLVER_TYPE_GUROBI only.
|
||||
// This event is fully supported for MIP models by SOLVER_TYPE_GUROBI only. If
|
||||
// used with SOLVER_TYPE_CP_SAT, it is called when the dual bound is improved.
|
||||
CALLBACK_EVENT_MIP = 3;
|
||||
|
||||
// Called every time a new MIP incumbent is found.
|
||||
@@ -127,7 +128,8 @@ message CallbackDataProto {
|
||||
BarrierStats barrier_stats = 6;
|
||||
|
||||
// MIP B&B stats. Only available during CALLBACK_EVENT_MIPxxxx events.
|
||||
// Not supported for CP-SAT.
|
||||
// When using CP-SAT, only primal_bound, dual_bound and
|
||||
// number_of_solutions_found are populated.
|
||||
message MipStats {
|
||||
optional double primal_bound = 1;
|
||||
optional double dual_bound = 2;
|
||||
|
||||
@@ -33,8 +33,9 @@ namespace operations_research::math_opt {
|
||||
|
||||
// The API of solvers (in-process, sub-process and streaming RPC ones).
|
||||
//
|
||||
// Thread-safety: methods Solve() and Update() must not be called concurrently;
|
||||
// they should immediately return with an error status if this happens.
|
||||
// Thread-safety: methods Solve(), ComputeInfeasibleSubsystem() and Update()
|
||||
// must not be called concurrently; they should immediately return with an error
|
||||
// status if this happens.
|
||||
//
|
||||
// TODO: b/350984134 - Rename `Solver` into `InProcessSolver` and then rename
|
||||
// `BaseSolver` into `Solver`.
|
||||
@@ -65,7 +66,14 @@ class BaseSolver {
|
||||
// printed on stdout/stderr/logs anymore.
|
||||
MessageCallback message_callback = nullptr;
|
||||
|
||||
// Registration parameter controlling calls to user_cb.
|
||||
CallbackRegistrationProto callback_registration;
|
||||
|
||||
// An optional MIP/LP callback. Only called for events registered in
|
||||
// callback_registration.
|
||||
//
|
||||
// Solve() returns an error if called without a user_cb but with some
|
||||
// non-empty callback_registration.request_registration.
|
||||
Callback user_cb = nullptr;
|
||||
|
||||
// An optional interrupter that the solver can use to interrupt the solve
|
||||
|
||||
@@ -120,10 +120,16 @@ absl::StatusOr<SolveResultProto> Solver::Solve(const SolveArgs& arguments) {
|
||||
ValidateModelSolveParameters(arguments.model_parameters, model_summary_))
|
||||
<< "invalid model_parameters";
|
||||
|
||||
RETURN_IF_ERROR(ValidateCallbackRegistration(arguments.callback_registration,
|
||||
model_summary_));
|
||||
SolverInterface::Callback cb = nullptr;
|
||||
if (!arguments.callback_registration.request_registration().empty() &&
|
||||
arguments.user_cb == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"no callback function was provided but callback events were "
|
||||
"registered");
|
||||
}
|
||||
if (arguments.user_cb != nullptr) {
|
||||
RETURN_IF_ERROR(ValidateCallbackRegistration(
|
||||
arguments.callback_registration, model_summary_));
|
||||
cb = [&](const CallbackDataProto& callback_data)
|
||||
-> absl::StatusOr<CallbackResultProto> {
|
||||
RETURN_IF_ERROR(ValidateCallbackDataProto(
|
||||
|
||||
@@ -43,8 +43,9 @@ namespace math_opt {
|
||||
//
|
||||
// This interface is not meant to be used directly. The actual API is the one of
|
||||
// the Solver class. The Solver class validates the models before calling this
|
||||
// interface. It makes sure no concurrent calls happen on Solve(), CanUpdate()
|
||||
// and Update(). It makes sure no other function is called after Solve(),
|
||||
// interface. It makes sure no concurrent calls happen on Solve(),
|
||||
// ComputeInfeasibleSubsystem(), CanUpdate() and Update(). It makes sure no
|
||||
// other function is called after Solve(), ComputeInfeasibleSubsystem(),
|
||||
// Update() or a callback have failed.
|
||||
//
|
||||
// Implementations of this interface should not have public constructors but
|
||||
@@ -69,12 +70,28 @@ class SolverInterface {
|
||||
// See Solver::MessageCallback documentation for details.
|
||||
using MessageCallback = std::function<void(const std::vector<std::string>&)>;
|
||||
|
||||
// A callback function (if non null) is a function that validates its input
|
||||
// and its output, and if fails, return a status. The invariant is that the
|
||||
// solver implementation can rely on receiving valid data. The implementation
|
||||
// of this interface must provide valid input (which will be validated) and
|
||||
// in error, it will return a status (without actually calling the callback
|
||||
// function). This is enforced in the solver.cc layer.
|
||||
// A callback function (if non null) provided by the Solver class to its
|
||||
// SolverInterface that wraps the user callback function
|
||||
// (BaseSolver::Callback) and validates its inputs (provided by the
|
||||
// SolverInterface implementation) and outputs (provided by the user). A
|
||||
// failing status is returned if those inputs or outputs are invalid.
|
||||
//
|
||||
// To be clear the SolverInterface::Callback is implemented by the Solver
|
||||
// class and looks like:
|
||||
//
|
||||
// absl::Status Callback(const CallbackDataProto& callback_data) {
|
||||
// RETURN_IF_ERROR(ValidateCallbackDataProto(callback_data, ...));
|
||||
// CallbackResultProto result = user_cb(callback_data);
|
||||
// RETURN_IF_ERROR(ValidateCallbackResultProto(result));
|
||||
// return result;
|
||||
// }
|
||||
//
|
||||
// As a consequence SolverInterface implementations can rely on receiving a
|
||||
// valid CallbackResultProto.
|
||||
//
|
||||
// When the SolverInterface::Callback returns an error the SolverInterface
|
||||
// implementation must interrupt the Solve() as soon as possible and return
|
||||
// this error.
|
||||
using Callback = std::function<absl::StatusOr<CallbackResultProto>(
|
||||
const CallbackDataProto&)>;
|
||||
|
||||
@@ -114,7 +131,11 @@ class SolverInterface {
|
||||
// When parameter `message_cb` is not null and the underlying solver does not
|
||||
// supports message callbacks, it should ignore it.
|
||||
//
|
||||
// Solvers should return a InvalidArgumentError when called with events on
|
||||
// The parameter `cb` won't be null when
|
||||
// callback_registration.request_registration is not empty (solver.cc will
|
||||
// return an error in that case before calling SolverInterface::Solve()).
|
||||
//
|
||||
// Solvers should return an InvalidArgumentError when called with events on
|
||||
// callback_registration that are not supported by the solver for the type of
|
||||
// model being solved (for example MIP events if the model is an LP, or events
|
||||
// that are not emitted by the solver). Solvers should use
|
||||
|
||||
@@ -895,6 +895,7 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "incremental_solver",
|
||||
srcs = ["incremental_solver.cc"],
|
||||
hdrs = ["incremental_solver.h"],
|
||||
deps = [
|
||||
":compute_infeasible_subsystem_arguments",
|
||||
@@ -903,10 +904,25 @@ cc_library(
|
||||
":solve_arguments",
|
||||
":solve_result",
|
||||
":update_result",
|
||||
"//ortools/base:status_macros",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "incremental_solver_test",
|
||||
srcs = ["incremental_solver_test.cc"],
|
||||
deps = [
|
||||
":incremental_solver",
|
||||
":matchers",
|
||||
":math_opt",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "remote_streaming_mode",
|
||||
srcs = ["remote_streaming_mode.cc"],
|
||||
|
||||
@@ -109,7 +109,9 @@ enum class CallbackEvent {
|
||||
// node). Useful for early termination. Note that this event does not provide
|
||||
// information on LP relaxations nor about new incumbent solutions.
|
||||
//
|
||||
// This event is supported for MIP models with SolverType::kGurobi only.
|
||||
// This event is fully supported for MIP models with SolverType::kGurobi only.
|
||||
// If used with SolverType::kCpSat, it is called when the dual bound is
|
||||
// improved.
|
||||
kMip = CALLBACK_EVENT_MIP,
|
||||
|
||||
// Called every time a new MIP incumbent is found.
|
||||
|
||||
34
ortools/math_opt/cpp/incremental_solver.cc
Normal file
34
ortools/math_opt/cpp/incremental_solver.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/cpp/incremental_solver.h"
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "ortools/base/status_macros.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
absl::StatusOr<SolveResult> IncrementalSolver::Solve(
|
||||
const SolveArguments& arguments) {
|
||||
RETURN_IF_ERROR(Update().status());
|
||||
return SolveWithoutUpdate(arguments);
|
||||
}
|
||||
|
||||
absl::StatusOr<ComputeInfeasibleSubsystemResult>
|
||||
IncrementalSolver::ComputeInfeasibleSubsystem(
|
||||
const ComputeInfeasibleSubsystemArguments& arguments) {
|
||||
RETURN_IF_ERROR(Update().status());
|
||||
return ComputeInfeasibleSubsystemWithoutUpdate(arguments);
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
@@ -112,21 +112,14 @@ class IncrementalSolver {
|
||||
//
|
||||
// See callback.h for documentation on arguments.callback and
|
||||
// arguments.callback_registration.
|
||||
virtual absl::StatusOr<SolveResult> Solve(
|
||||
const SolveArguments& arguments) = 0;
|
||||
absl::StatusOr<SolveResult> Solve() { return Solve({}); }
|
||||
absl::StatusOr<SolveResult> Solve(const SolveArguments& arguments = {});
|
||||
|
||||
// Updates the underlying solver with latest model changes and runs the
|
||||
// computation.
|
||||
//
|
||||
// Same as Solve() but compute the infeasible subsystem.
|
||||
virtual absl::StatusOr<ComputeInfeasibleSubsystemResult>
|
||||
ComputeInfeasibleSubsystem(
|
||||
const ComputeInfeasibleSubsystemArguments& arguments) = 0;
|
||||
absl::StatusOr<ComputeInfeasibleSubsystemResult>
|
||||
ComputeInfeasibleSubsystem() {
|
||||
return ComputeInfeasibleSubsystem({});
|
||||
}
|
||||
absl::StatusOr<ComputeInfeasibleSubsystemResult> ComputeInfeasibleSubsystem(
|
||||
const ComputeInfeasibleSubsystemArguments& arguments = {});
|
||||
|
||||
// Updates the model to solve.
|
||||
//
|
||||
|
||||
121
ortools/math_opt/cpp/incremental_solver_test.cc
Normal file
121
ortools/math_opt/cpp/incremental_solver_test.cc
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/cpp/incremental_solver.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/math_opt/cpp/matchers.h"
|
||||
#include "ortools/math_opt/cpp/math_opt.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
using ::testing::_;
|
||||
using ::testing::Return;
|
||||
using ::testing::status::IsOkAndHolds;
|
||||
using ::testing::status::StatusIs;
|
||||
|
||||
class MockIncrementalSolver final : public IncrementalSolver {
|
||||
public:
|
||||
MOCK_METHOD(absl::StatusOr<UpdateResult>, Update, (), (override));
|
||||
MOCK_METHOD(absl::StatusOr<SolveResult>, SolveWithoutUpdate,
|
||||
(const SolveArguments&), (const, override));
|
||||
MOCK_METHOD(absl::StatusOr<ComputeInfeasibleSubsystemResult>,
|
||||
ComputeInfeasibleSubsystemWithoutUpdate,
|
||||
(const ComputeInfeasibleSubsystemArguments&), (const, override));
|
||||
MOCK_METHOD(SolverType, solver_type, (), (const, override));
|
||||
};
|
||||
|
||||
TEST(IncrementalSolverTest, SolveWithFailingUpdate) {
|
||||
MockIncrementalSolver incremental_solver;
|
||||
EXPECT_CALL(incremental_solver, Update())
|
||||
.WillOnce(Return(absl::InternalError("oops")));
|
||||
EXPECT_THAT(incremental_solver.Solve(),
|
||||
StatusIs(absl::StatusCode::kInternal, "oops"));
|
||||
}
|
||||
|
||||
TEST(IncrementalSolverTest, SolveWithFailingSolveWithoutUpdate) {
|
||||
MockIncrementalSolver incremental_solver;
|
||||
EXPECT_CALL(incremental_solver, Update())
|
||||
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
|
||||
EXPECT_CALL(incremental_solver, SolveWithoutUpdate(_))
|
||||
.WillOnce(Return(absl::InternalError("oops")));
|
||||
EXPECT_THAT(incremental_solver.Solve(),
|
||||
StatusIs(absl::StatusCode::kInternal, "oops"));
|
||||
}
|
||||
|
||||
TEST(IncrementalSolverTest, SuccessfulSolve) {
|
||||
MockIncrementalSolver incremental_solver;
|
||||
EXPECT_CALL(incremental_solver, Update())
|
||||
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
|
||||
constexpr double kObjectiveValue = 3.5;
|
||||
constexpr absl::string_view kDetail = "found the optimum!";
|
||||
EXPECT_CALL(incremental_solver, SolveWithoutUpdate(_))
|
||||
.WillOnce(Return(
|
||||
SolveResult(Termination::Optimal(/*objective_value=*/kObjectiveValue,
|
||||
/*detail=*/std::string(kDetail)))));
|
||||
ASSERT_OK_AND_ASSIGN(const SolveResult solve_result,
|
||||
incremental_solver.Solve());
|
||||
EXPECT_THAT(solve_result.termination,
|
||||
TerminationIsOptimal(/*primal_objective_value=*/kObjectiveValue));
|
||||
EXPECT_EQ(solve_result.termination.detail, kDetail);
|
||||
}
|
||||
|
||||
TEST(IncrementalSolverTest, ComputeInfeasibleSubsystemWithFailingUpdate) {
|
||||
MockIncrementalSolver incremental_solver;
|
||||
EXPECT_CALL(incremental_solver, Update())
|
||||
.WillOnce(Return(absl::InternalError("oops")));
|
||||
EXPECT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
|
||||
StatusIs(absl::StatusCode::kInternal, "oops"));
|
||||
}
|
||||
|
||||
TEST(IncrementalSolverTest,
|
||||
ComputeInfeasibleSubsystemWithFailingComputeWithoutUpdate) {
|
||||
MockIncrementalSolver incremental_solver;
|
||||
EXPECT_CALL(incremental_solver, Update())
|
||||
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
|
||||
EXPECT_CALL(incremental_solver, ComputeInfeasibleSubsystemWithoutUpdate(_))
|
||||
.WillOnce(Return(absl::InternalError("oops")));
|
||||
EXPECT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
|
||||
StatusIs(absl::StatusCode::kInternal, "oops"));
|
||||
}
|
||||
|
||||
TEST(IncrementalSolverTest, SuccessfulComputeInfeasibleSubsystem) {
|
||||
MockIncrementalSolver incremental_solver;
|
||||
EXPECT_CALL(incremental_solver, Update())
|
||||
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
|
||||
Model model;
|
||||
const Variable v = model.AddBinaryVariable("v");
|
||||
const ModelSubset model_subset = {
|
||||
.variable_integrality = {v},
|
||||
};
|
||||
EXPECT_CALL(incremental_solver, ComputeInfeasibleSubsystemWithoutUpdate(_))
|
||||
.WillOnce(Return(ComputeInfeasibleSubsystemResult{
|
||||
.feasibility = FeasibilityStatus::kInfeasible,
|
||||
.infeasible_subsystem = model_subset,
|
||||
.is_minimal = false,
|
||||
}));
|
||||
ASSERT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
|
||||
IsOkAndHolds(IsInfeasible(
|
||||
/*expected_is_minimal=*/false,
|
||||
/*expected_infeasible_subsystem=*/model_subset)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
@@ -209,8 +209,8 @@ class MapToDoubleMatcher
|
||||
|
||||
} // namespace
|
||||
|
||||
Matcher<VariableMap<double>> IsNearlySubsetOf(VariableMap<double> expected,
|
||||
double tolerance) {
|
||||
Matcher<VariableMap<double>> IsNearlySupersetOf(VariableMap<double> expected,
|
||||
double tolerance) {
|
||||
return Matcher<VariableMap<double>>(new MapToDoubleMatcher<Variable>(
|
||||
std::move(expected), /*all_keys=*/false, tolerance));
|
||||
}
|
||||
@@ -221,7 +221,7 @@ Matcher<VariableMap<double>> IsNear(VariableMap<double> expected,
|
||||
std::move(expected), /*all_keys=*/true, tolerance));
|
||||
}
|
||||
|
||||
Matcher<LinearConstraintMap<double>> IsNearlySubsetOf(
|
||||
Matcher<LinearConstraintMap<double>> IsNearlySupersetOf(
|
||||
LinearConstraintMap<double> expected, double tolerance) {
|
||||
return Matcher<LinearConstraintMap<double>>(
|
||||
new MapToDoubleMatcher<LinearConstraint>(std::move(expected),
|
||||
@@ -243,7 +243,7 @@ Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNear(
|
||||
std::move(expected), /*all_keys=*/true, tolerance));
|
||||
}
|
||||
|
||||
Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNearlySubsetOf(
|
||||
Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNearlySupersetOf(
|
||||
absl::flat_hash_map<QuadraticConstraint, double> expected,
|
||||
double tolerance) {
|
||||
return Matcher<absl::flat_hash_map<QuadraticConstraint, double>>(
|
||||
@@ -260,7 +260,7 @@ Matcher<absl::flat_hash_map<K, double>> IsNear(
|
||||
}
|
||||
|
||||
template <typename K>
|
||||
Matcher<absl::flat_hash_map<K, double>> IsNearlySubsetOf(
|
||||
Matcher<absl::flat_hash_map<K, double>> IsNearlySupersetOf(
|
||||
absl::flat_hash_map<K, double> expected, const double tolerance) {
|
||||
return Matcher<absl::flat_hash_map<K, double>>(new MapToDoubleMatcher<K>(
|
||||
std::move(expected), /*all_keys=*/false, tolerance));
|
||||
|
||||
@@ -121,11 +121,11 @@ constexpr double kMatcherDefaultTolerance = 1e-5;
|
||||
testing::Matcher<VariableMap<double>> IsNear(
|
||||
VariableMap<double> expected, double tolerance = kMatcherDefaultTolerance);
|
||||
|
||||
// Checks that the keys of actual are a subset of the keys of expected, and that
|
||||
// for all shared keys, the values are within tolerance. This factory will
|
||||
// Checks that the keys of actual are a superset of the keys of expected, and
|
||||
// that for all shared keys, the values are within tolerance. This factory will
|
||||
// CHECK-fail if expected contains any NaN values, and any NaN values in the
|
||||
// expression compared against will result in the matcher failing.
|
||||
testing::Matcher<VariableMap<double>> IsNearlySubsetOf(
|
||||
testing::Matcher<VariableMap<double>> IsNearlySupersetOf(
|
||||
VariableMap<double> expected, double tolerance = kMatcherDefaultTolerance);
|
||||
|
||||
// Checks that the maps have identical keys and values within tolerance. This
|
||||
@@ -135,11 +135,11 @@ testing::Matcher<LinearConstraintMap<double>> IsNear(
|
||||
LinearConstraintMap<double> expected,
|
||||
double tolerance = kMatcherDefaultTolerance);
|
||||
|
||||
// Checks that the keys of actual are a subset of the keys of expected, and that
|
||||
// for all shared keys, the values are within tolerance. This factory will
|
||||
// Checks that the keys of actual are a superset of the keys of expected, and
|
||||
// that for all shared keys, the values are within tolerance. This factory will
|
||||
// CHECK-fail if expected contains any NaN values, and any NaN values in the
|
||||
// expression compared against will result in the matcher failing.
|
||||
testing::Matcher<LinearConstraintMap<double>> IsNearlySubsetOf(
|
||||
testing::Matcher<LinearConstraintMap<double>> IsNearlySupersetOf(
|
||||
LinearConstraintMap<double> expected,
|
||||
double tolerance = kMatcherDefaultTolerance);
|
||||
|
||||
@@ -149,13 +149,13 @@ testing::Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNear(
|
||||
absl::flat_hash_map<QuadraticConstraint, double> expected,
|
||||
double tolerance = kMatcherDefaultTolerance);
|
||||
|
||||
// Checks that the keys of actual are a subset of the keys of expected, and that
|
||||
// for all shared keys, the values are within tolerance. This factory will
|
||||
// Checks that the keys of actual are a superset of the keys of expected, and
|
||||
// that for all shared keys, the values are within tolerance. This factory will
|
||||
// CHECK-fail if expected contains any NaN values, and any NaN values in the
|
||||
// expression compared against will result in the matcher failing.
|
||||
testing::Matcher<absl::flat_hash_map<QuadraticConstraint, double>>
|
||||
IsNearlySubsetOf(absl::flat_hash_map<QuadraticConstraint, double> expected,
|
||||
double tolerance = kMatcherDefaultTolerance);
|
||||
IsNearlySupersetOf(absl::flat_hash_map<QuadraticConstraint, double> expected,
|
||||
double tolerance = kMatcherDefaultTolerance);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Matchers for various Variable expressions (e.g. LinearExpression)
|
||||
|
||||
@@ -115,20 +115,21 @@ TEST(ApproximateMapMatcherTest, VariableIsNear) {
|
||||
EXPECT_THAT(actual, Not(IsNear({{z, -2.5}})));
|
||||
}
|
||||
|
||||
TEST(ApproximateMapMatcherTest, VariableIsNearlySubsetOf) {
|
||||
TEST(ApproximateMapMatcherTest, VariableIsNearlySupersetOf) {
|
||||
Model model;
|
||||
const Variable w = model.AddBinaryVariable("w");
|
||||
const Variable x = model.AddBinaryVariable("x");
|
||||
const Variable y = model.AddBinaryVariable("y");
|
||||
const Variable z = model.AddBinaryVariable("z");
|
||||
const VariableMap<double> actual = {{x, 2.0}, {y, 4.1}, {z, -2.5}};
|
||||
EXPECT_THAT(actual, IsNearlySubsetOf(actual));
|
||||
EXPECT_THAT(actual, IsNearlySubsetOf({{y, 4.1}, {z, -2.5}}));
|
||||
EXPECT_THAT(actual, Not(IsNearlySubsetOf({{w, 1}, {y, 4.1}, {z, -2.5}})));
|
||||
EXPECT_THAT(actual, Not(IsNearlySubsetOf({{y, 4.4}, {z, -2.5}})));
|
||||
EXPECT_THAT(actual, IsNearlySupersetOf(actual));
|
||||
EXPECT_THAT(actual, IsNearlySupersetOf({{y, 4.1}, {z, -2.5}}));
|
||||
EXPECT_THAT(actual, Not(IsNearlySupersetOf({{w, 1}, {y, 4.1}, {z, -2.5}})));
|
||||
EXPECT_THAT(actual, Not(IsNearlySupersetOf({{y, 4.4}, {z, -2.5}})));
|
||||
}
|
||||
|
||||
TEST(ApproximateMapMatcherTest, QuadraticConstraintIsNearAndIsNearlySubsetOf) {
|
||||
TEST(ApproximateMapMatcherTest,
|
||||
QuadraticConstraintIsNearAndIsNearlySupersetOf) {
|
||||
Model model;
|
||||
const Variable x = model.AddBinaryVariable("x");
|
||||
const QuadraticConstraint c = model.AddQuadraticConstraint(x * x <= 0, "c");
|
||||
@@ -137,29 +138,29 @@ TEST(ApproximateMapMatcherTest, QuadraticConstraintIsNearAndIsNearlySubsetOf) {
|
||||
|
||||
const absl::flat_hash_map<QuadraticConstraint, double> actual = {{c, 2},
|
||||
{e, 5}};
|
||||
EXPECT_THAT(actual, IsNearlySubsetOf(actual));
|
||||
EXPECT_THAT(actual, IsNearlySupersetOf(actual));
|
||||
EXPECT_THAT(actual, IsNear(actual));
|
||||
EXPECT_THAT(actual, IsNear({{c, 2 + 1e-8}, {e, 5}}));
|
||||
EXPECT_THAT(actual, Not(IsNear({{e, 5}})));
|
||||
EXPECT_THAT(actual, Not(IsNear({{c, 2 + 1e-2}, {e, 5}})));
|
||||
EXPECT_THAT(actual, Not(IsNear({{d, 5}})));
|
||||
EXPECT_THAT(actual, IsNearlySubsetOf({{e, 5}}));
|
||||
EXPECT_THAT(actual, IsNearlySupersetOf({{e, 5}}));
|
||||
}
|
||||
|
||||
TEST(ApproximateMapMatcherTest, LinearConstraintIsNearAndIsNearlySubsetOf) {
|
||||
TEST(ApproximateMapMatcherTest, LinearConstraintIsNearAndIsNearlySupersetOf) {
|
||||
Model model;
|
||||
const LinearConstraint c = model.AddLinearConstraint("c");
|
||||
const LinearConstraint d = model.AddLinearConstraint("d");
|
||||
const LinearConstraint e = model.AddLinearConstraint("e");
|
||||
|
||||
const LinearConstraintMap<double> actual = {{c, 2}, {e, 5}};
|
||||
EXPECT_THAT(actual, IsNearlySubsetOf(actual));
|
||||
EXPECT_THAT(actual, IsNearlySupersetOf(actual));
|
||||
EXPECT_THAT(actual, IsNear(actual));
|
||||
EXPECT_THAT(actual, IsNear({{c, 2 + 1e-8}, {e, 5}}));
|
||||
EXPECT_THAT(actual, Not(IsNear({{e, 5}})));
|
||||
EXPECT_THAT(actual, Not(IsNear({{c, 2 + 1e-2}, {e, 5}})));
|
||||
EXPECT_THAT(actual, Not(IsNear({{d, 5}})));
|
||||
EXPECT_THAT(actual, IsNearlySubsetOf({{e, 5}}));
|
||||
EXPECT_THAT(actual, IsNearlySupersetOf({{e, 5}}));
|
||||
}
|
||||
|
||||
TEST(LinearExpressionMatcherTest, IsIdentical) {
|
||||
|
||||
@@ -193,21 +193,6 @@ IncrementalSolverImpl::IncrementalSolverImpl(
|
||||
update_tracker_(std::move(update_tracker)),
|
||||
solver_(std::move(solver)) {}
|
||||
|
||||
absl::StatusOr<SolveResult> IncrementalSolverImpl::Solve(
|
||||
const SolveArguments& arguments) {
|
||||
// TODO: b/260337466 - Add permanent errors and concurrency protection.
|
||||
RETURN_IF_ERROR(Update().status());
|
||||
return SolveWithoutUpdate(arguments);
|
||||
}
|
||||
|
||||
absl::StatusOr<ComputeInfeasibleSubsystemResult>
|
||||
IncrementalSolverImpl::ComputeInfeasibleSubsystem(
|
||||
const ComputeInfeasibleSubsystemArguments& arguments) {
|
||||
// TODO: b/260337466 - Add permanent errors and concurrency protection.
|
||||
RETURN_IF_ERROR(Update().status());
|
||||
return ComputeInfeasibleSubsystemWithoutUpdate(arguments);
|
||||
}
|
||||
|
||||
absl::StatusOr<UpdateResult> IncrementalSolverImpl::Update() {
|
||||
// TODO: b/260337466 - Add permanent errors and concurrency protection.
|
||||
ASSIGN_OR_RETURN(std::optional<ModelUpdateProto> model_update,
|
||||
|
||||
@@ -81,11 +81,6 @@ class IncrementalSolverImpl : public IncrementalSolver {
|
||||
BaseSolverFactory solver_factory, Model* model, SolverType solver_type,
|
||||
const SolveInterrupter* user_canceller, bool remove_names);
|
||||
|
||||
absl::StatusOr<SolveResult> Solve(const SolveArguments& arguments) override;
|
||||
|
||||
absl::StatusOr<ComputeInfeasibleSubsystemResult> ComputeInfeasibleSubsystem(
|
||||
const ComputeInfeasibleSubsystemArguments& arguments) override;
|
||||
|
||||
absl::StatusOr<UpdateResult> Update() override;
|
||||
|
||||
absl::StatusOr<SolveResult> SolveWithoutUpdate(
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines how to request a callback and the input and output of a callback."""
|
||||
|
||||
import dataclasses
|
||||
import datetime
|
||||
import enum
|
||||
@@ -36,7 +37,8 @@ class Event(enum.Enum):
|
||||
* MIP: The solver is in the MIP loop (called periodically before starting a
|
||||
new node). Useful for early termination. Note that this event does not
|
||||
provide information on LP relaxations nor about new incumbent solutions.
|
||||
Gurobi only.
|
||||
Fully supported by Gurobi only. If used with CP-SAT, it is called when the
|
||||
dual bound is improved.
|
||||
* MIP_SOLUTION: Called every time a new MIP incumbent is found. Fully
|
||||
supported by Gurobi, partially supported by CP-SAT (you can observe new
|
||||
solutions, but not add lazy constraints).
|
||||
|
||||
@@ -52,6 +52,7 @@ cc_library(
|
||||
"//ortools/math_opt/io:mps_converter",
|
||||
"//ortools/port:proto_utils",
|
||||
"//ortools/port:scoped_std_stream_capture",
|
||||
"//ortools/util:fp_roundtrip_conv",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/status",
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
#include "ortools/math_opt/solver_tests/callback_tests.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
@@ -44,6 +46,7 @@
|
||||
#include "ortools/math_opt/solver_tests/test_models.h"
|
||||
#include "ortools/port/proto_utils.h"
|
||||
#include "ortools/port/scoped_std_stream_capture.h"
|
||||
#include "ortools/util/fp_roundtrip_conv.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace math_opt {
|
||||
@@ -418,8 +421,8 @@ TEST_P(CallbackTest, EventSolutionAlwaysCalled) {
|
||||
SolveArguments args = {
|
||||
.callback_registration = {.events = {CallbackEvent::kMipSolution}}};
|
||||
absl::Mutex mutex;
|
||||
bool cb_called = false;
|
||||
bool cb_called_on_optimal = false;
|
||||
std::atomic<bool> cb_called = false;
|
||||
std::atomic<bool> cb_called_on_optimal = false;
|
||||
args.callback = [&](const CallbackData& callback_data) {
|
||||
const absl::MutexLock lock(mutex);
|
||||
cb_called = true;
|
||||
@@ -433,6 +436,8 @@ TEST_P(CallbackTest, EventSolutionAlwaysCalled) {
|
||||
EXPECT_THAT(
|
||||
sol, AnyOf(IsNear({{x, 0.0}, {y, 0.0}}), IsNear({{x, 1.0}, {y, 0.0}}),
|
||||
IsNear({{x, 0.0}, {y, 1.0}})));
|
||||
EXPECT_LE(callback_data.mip_stats.primal_bound(), 2.05);
|
||||
EXPECT_GE(callback_data.mip_stats.dual_bound(), 1.95);
|
||||
if (gtl::FindWithDefault(sol, y) > 0.5) {
|
||||
cb_called_on_optimal = true;
|
||||
}
|
||||
@@ -646,8 +651,8 @@ TEST_P(CallbackTest, EventSolutionFilter) {
|
||||
.events = {CallbackEvent::kMipSolution},
|
||||
.mip_solution_filter = MakeKeepKeysFilter({y})}};
|
||||
absl::Mutex mutex;
|
||||
bool cb_called = false;
|
||||
bool cb_called_on_optimal = false;
|
||||
std::atomic<bool> cb_called = false;
|
||||
std::atomic<bool> cb_called_on_optimal = false;
|
||||
args.callback = [&](const CallbackData& callback_data) {
|
||||
const absl::MutexLock lock(mutex);
|
||||
cb_called = true;
|
||||
@@ -795,6 +800,47 @@ TEST_P(CallbackTest, EventNodeFilter) {
|
||||
EXPECT_THAT(solutions, Each(UnorderedElementsAre(Pair(x0, _), Pair(x2, _))));
|
||||
}
|
||||
|
||||
TEST_P(CallbackTest, EventMip) {
|
||||
if (!GetParam().supported_events.contains(CallbackEvent::kMip)) {
|
||||
GTEST_SKIP() << "Test skipped because this solver does not support "
|
||||
"CallbackEvent::kMip.";
|
||||
}
|
||||
|
||||
// This test must use integer variables.
|
||||
ASSERT_TRUE(GetParam().integer_variables);
|
||||
|
||||
// Use the MIPLIB instance 23588, which has optimal solution 8090 and LP
|
||||
// relaxation of 7649.87. This instance was selected because every
|
||||
// supported solver can solve it quickly (a few seconds), but no solver can
|
||||
// solve it in one node (so the node callback will be invoked).
|
||||
ASSERT_OK_AND_ASSIGN(const std::unique_ptr<Model> model,
|
||||
LoadMiplibInstance("23588"));
|
||||
|
||||
std::atomic<double> best_primal_bound =
|
||||
std::numeric_limits<double>::infinity();
|
||||
std::atomic<double> best_dual_bound =
|
||||
-std::numeric_limits<double>::infinity();
|
||||
const SolveArguments args = {
|
||||
.callback_registration = {.events = {CallbackEvent::kMip}},
|
||||
.callback = [&](const CallbackData& callback_data) {
|
||||
CHECK_EQ(callback_data.event, CallbackEvent::kMip);
|
||||
const double primal_bound = callback_data.mip_stats.primal_bound();
|
||||
const double dual_bound = callback_data.mip_stats.dual_bound();
|
||||
best_primal_bound = std::fmin(best_primal_bound, primal_bound);
|
||||
best_dual_bound = std::fmax(best_dual_bound, dual_bound);
|
||||
return CallbackResult();
|
||||
}};
|
||||
EXPECT_THAT(Solve(*model, GetParam().solver_type, args),
|
||||
IsOkAndHolds(IsOptimal(8090)));
|
||||
LOG(INFO) << "best_primal_bound: "
|
||||
<< RoundTripDoubleFormat(best_primal_bound.load());
|
||||
LOG(INFO) << "best_dual_bound: "
|
||||
<< RoundTripDoubleFormat(best_dual_bound.load());
|
||||
EXPECT_THAT(best_primal_bound.load(), testing::DoubleNear(8090, 0.5));
|
||||
EXPECT_LE(best_dual_bound.load(), 8090.5);
|
||||
EXPECT_GE(best_dual_bound.load(), 7640);
|
||||
}
|
||||
|
||||
TEST_P(CallbackTest, StatusPropagation) {
|
||||
if (!GetParam().supported_events.contains(CallbackEvent::kMipSolution)) {
|
||||
GTEST_SKIP() << "Test skipped because this solver does not support "
|
||||
|
||||
@@ -238,13 +238,17 @@ cc_library(
|
||||
"//ortools/port:proto_utils",
|
||||
"//ortools/sat:sat_parameters_cc_proto",
|
||||
"//ortools/util:solve_interrupter",
|
||||
"@abseil-cpp//absl/base:core_headers",
|
||||
"@abseil-cpp//absl/base:nullability",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/functional:any_invocable",
|
||||
"@abseil-cpp//absl/log",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/memory",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/synchronization",
|
||||
"@abseil-cpp//absl/time",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
|
||||
@@ -14,15 +14,21 @@
|
||||
#include "ortools/math_opt/solvers/cp_sat_solver.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/functional/any_invocable.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/memory/memory.h"
|
||||
@@ -33,6 +39,7 @@
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
@@ -65,6 +72,8 @@ namespace math_opt {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr double kInf = std::numeric_limits<double>::infinity();
|
||||
|
||||
constexpr SupportedProblemStructures kCpSatSupportedStructures = {
|
||||
.integer_variables = SupportType::kSupported,
|
||||
.quadratic_objectives = SupportType::kNotImplemented,
|
||||
@@ -316,6 +325,162 @@ absl::StatusOr<TerminationProto> GetTermination(
|
||||
absl::StrCat("unimplemented solve status: ", response.status()));
|
||||
}
|
||||
|
||||
// This class gathers the solution callback and best bound callback together
|
||||
// with some solver state that we need to update as the solver progresses.
|
||||
class CpSatCallbacks {
|
||||
public:
|
||||
CpSatCallbacks(const absl_nullable SolverInterface::Callback& cb
|
||||
ABSL_ATTRIBUTE_LIFETIME_BOUND,
|
||||
SolveInterrupter* absl_nonnull local_interrupter
|
||||
ABSL_ATTRIBUTE_LIFETIME_BOUND,
|
||||
absl_nonnull absl::AnyInvocable<
|
||||
SparseDoubleVectorProto(absl::Span<const double>) const>
|
||||
extract_solution,
|
||||
absl::flat_hash_set<CallbackEventProto> events,
|
||||
bool is_maximize);
|
||||
|
||||
// CpSatCallbacks is neither copyable nor movable as callbacks point to it.
|
||||
CpSatCallbacks(const CpSatCallbacks&) = delete;
|
||||
CpSatCallbacks& operator=(const CpSatCallbacks&) = delete;
|
||||
|
||||
// Returns a solution callback that wraps the user callback and updates the
|
||||
// state of CpSatCallbacks. Returns nullptr if it is not needed.
|
||||
absl_nullable std::function<void(const MPSolution&)> MakeSolutionCallback();
|
||||
|
||||
// Returns a best bound callback that wraps the user callback and updates the
|
||||
// state of CpSatCallbacks. Returns nullptr if it is not needed.
|
||||
absl_nullable std::function<void(const double)> MakeBestBoundCallback();
|
||||
|
||||
absl::Status error() const {
|
||||
absl::MutexLock lock(mutex_);
|
||||
return error_;
|
||||
}
|
||||
|
||||
private:
|
||||
void ExecuteCallback(const CallbackDataProto& cb_data);
|
||||
void UpdateMipStatsFromNewSolution(const MPSolution& mp_solution)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
|
||||
|
||||
const SolverInterface::Callback& cb_;
|
||||
SolveInterrupter* absl_nonnull const local_interrupter_;
|
||||
const absl::AnyInvocable<SparseDoubleVectorProto(absl::Span<const double>)
|
||||
const>
|
||||
extract_solution_;
|
||||
const bool has_mip_solution_event_;
|
||||
const bool has_mip_event_;
|
||||
const bool is_maximize_;
|
||||
|
||||
mutable absl::Mutex mutex_;
|
||||
absl::Status error_ ABSL_GUARDED_BY(mutex_) = absl::OkStatus();
|
||||
CallbackDataProto::MipStats current_mip_stats_ ABSL_GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
CpSatCallbacks::CpSatCallbacks(
|
||||
const SolverInterface::Callback& cb ABSL_ATTRIBUTE_LIFETIME_BOUND,
|
||||
SolveInterrupter* absl_nonnull local_interrupter
|
||||
ABSL_ATTRIBUTE_LIFETIME_BOUND,
|
||||
absl_nonnull
|
||||
absl::AnyInvocable<SparseDoubleVectorProto(absl::Span<const double>) const>
|
||||
extract_solution ABSL_ATTRIBUTE_LIFETIME_BOUND,
|
||||
absl::flat_hash_set<CallbackEventProto> events, const bool is_maximize)
|
||||
: cb_(cb),
|
||||
local_interrupter_(local_interrupter),
|
||||
extract_solution_(std::move(extract_solution)),
|
||||
// If there is no user callback, we make sure not calling it.
|
||||
has_mip_solution_event_(cb != nullptr &&
|
||||
events.contains(CALLBACK_EVENT_MIP_SOLUTION)),
|
||||
has_mip_event_(cb != nullptr && events.contains(CALLBACK_EVENT_MIP)),
|
||||
is_maximize_(is_maximize) {
|
||||
current_mip_stats_.set_primal_bound(is_maximize ? -kInf : kInf);
|
||||
current_mip_stats_.set_dual_bound(is_maximize ? kInf : -kInf);
|
||||
current_mip_stats_.set_number_of_solutions_found(0);
|
||||
}
|
||||
|
||||
std::function<void(const MPSolution&)> absl_nullable
|
||||
CpSatCallbacks::MakeSolutionCallback() {
|
||||
if (!has_mip_solution_event_ && !has_mip_event_) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!has_mip_solution_event_) {
|
||||
return [this](const MPSolution& mp_solution) {
|
||||
absl::MutexLock lock(mutex_);
|
||||
UpdateMipStatsFromNewSolution(mp_solution);
|
||||
};
|
||||
}
|
||||
return [this](const MPSolution& mp_solution) {
|
||||
CallbackDataProto cb_data;
|
||||
cb_data.set_event(CALLBACK_EVENT_MIP_SOLUTION);
|
||||
*cb_data.mutable_primal_solution_vector() =
|
||||
extract_solution_(mp_solution.variable_value());
|
||||
{
|
||||
absl::MutexLock lock(mutex_);
|
||||
UpdateMipStatsFromNewSolution(mp_solution);
|
||||
*cb_data.mutable_mip_stats() = current_mip_stats_;
|
||||
}
|
||||
ExecuteCallback(cb_data);
|
||||
};
|
||||
}
|
||||
|
||||
std::function<void(const double)> absl_nullable
|
||||
CpSatCallbacks::MakeBestBoundCallback() {
|
||||
if (!has_mip_solution_event_ && !has_mip_event_) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!has_mip_event_) {
|
||||
return [this](const double best_bound) {
|
||||
absl::MutexLock lock(mutex_);
|
||||
current_mip_stats_.set_dual_bound(best_bound);
|
||||
};
|
||||
}
|
||||
return [this](const double best_bound) {
|
||||
CallbackDataProto cb_data;
|
||||
cb_data.set_event(CALLBACK_EVENT_MIP);
|
||||
{
|
||||
absl::MutexLock lock(mutex_);
|
||||
current_mip_stats_.set_dual_bound(best_bound);
|
||||
*cb_data.mutable_mip_stats() = current_mip_stats_;
|
||||
}
|
||||
ExecuteCallback(cb_data);
|
||||
};
|
||||
}
|
||||
|
||||
void CpSatCallbacks::ExecuteCallback(const CallbackDataProto& cb_data) {
|
||||
{
|
||||
absl::MutexLock lock(mutex_);
|
||||
if (!error_.ok()) {
|
||||
// A previous callback failed.
|
||||
return;
|
||||
}
|
||||
}
|
||||
const absl::StatusOr<CallbackResultProto> cb_result = cb_(cb_data);
|
||||
// Note cb_result.cuts and cb_result.suggested solutions are not supported
|
||||
// by CP-SAT and we have validated they are empty.
|
||||
if (!cb_result.ok()) {
|
||||
{
|
||||
absl::MutexLock lock(mutex_);
|
||||
error_ = cb_result.status();
|
||||
}
|
||||
// Note: we will be returning a status error, we do not need to worry
|
||||
// about interpreting this as TERMINATION_REASON_INTERRUPTED.
|
||||
local_interrupter_->Interrupt();
|
||||
} else if (cb_result->terminate()) {
|
||||
local_interrupter_->Interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
void CpSatCallbacks::UpdateMipStatsFromNewSolution(
|
||||
const MPSolution& mp_solution) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||
if (is_maximize_) {
|
||||
current_mip_stats_.set_primal_bound(std::fmax(
|
||||
current_mip_stats_.primal_bound(), mp_solution.objective_value()));
|
||||
} else {
|
||||
current_mip_stats_.set_primal_bound(std::fmin(
|
||||
current_mip_stats_.primal_bound(), mp_solution.objective_value()));
|
||||
}
|
||||
current_mip_stats_.set_number_of_solutions_found(
|
||||
current_mip_stats_.number_of_solutions_found() + 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<SolverInterface>> CpSatSolver::New(
|
||||
@@ -345,7 +510,7 @@ absl::StatusOr<SolveResultProto> CpSatSolver::Solve(
|
||||
|
||||
RETURN_IF_ERROR(CheckRegisteredCallbackEvents(
|
||||
callback_registration,
|
||||
/*supported_events=*/{CALLBACK_EVENT_MIP_SOLUTION}));
|
||||
/*supported_events=*/{CALLBACK_EVENT_MIP_SOLUTION, CALLBACK_EVENT_MIP}));
|
||||
if (callback_registration.add_lazy_constraints()) {
|
||||
return absl::InvalidArgumentError(
|
||||
"CallbackRegistrationProto.add_lazy_constraints=true is not supported "
|
||||
@@ -392,7 +557,7 @@ absl::StatusOr<SolveResultProto> CpSatSolver::Solve(
|
||||
}
|
||||
|
||||
// We need to chain the user interrupter through a local interrupter, because
|
||||
// if we termiante early from a callback request, we don't want to incorrectly
|
||||
// if we terminate early from a callback request, we don't want to incorrectly
|
||||
// modify the input state.
|
||||
SolveInterrupter local_interrupter;
|
||||
std::atomic<bool> interrupt_solve = false;
|
||||
@@ -411,41 +576,21 @@ absl::StatusOr<SolveResultProto> CpSatSolver::Solve(
|
||||
|
||||
const absl::flat_hash_set<CallbackEventProto> events =
|
||||
EventSet(callback_registration);
|
||||
std::function<void(const MPSolution&)> solution_callback;
|
||||
absl::Status callback_error = absl::OkStatus();
|
||||
if (events.contains(CALLBACK_EVENT_MIP_SOLUTION)) {
|
||||
solution_callback =
|
||||
[this, &cb, &callback_error, &local_interrupter,
|
||||
&callback_registration](const MPSolution& mp_solution) {
|
||||
if (!callback_error.ok()) {
|
||||
// A previous callback failed.
|
||||
return;
|
||||
}
|
||||
CallbackDataProto cb_data;
|
||||
cb_data.set_event(CALLBACK_EVENT_MIP_SOLUTION);
|
||||
*cb_data.mutable_primal_solution_vector() =
|
||||
ExtractSolution(mp_solution.variable_value(),
|
||||
callback_registration.mip_solution_filter());
|
||||
const absl::StatusOr<CallbackResultProto> cb_result = cb(cb_data);
|
||||
if (!cb_result.ok()) {
|
||||
callback_error = cb_result.status();
|
||||
// Note: we will be returning a status error, we do not need to
|
||||
// worry about interpreting this as TERMINATION_REASON_INTERRUPTED.
|
||||
local_interrupter.Interrupt();
|
||||
} else if (cb_result->terminate()) {
|
||||
local_interrupter.Interrupt();
|
||||
}
|
||||
// Note cb_result.cuts and cb_result.suggested solutions are not
|
||||
// supported by CP-SAT and we have validated they are empty.
|
||||
};
|
||||
}
|
||||
absl::AnyInvocable<SparseDoubleVectorProto(absl::Span<const double>) const>
|
||||
extract_solution = [&](absl::Span<const double> cp_sat_variable_values) {
|
||||
return ExtractSolution(cp_sat_variable_values,
|
||||
callback_registration.mip_solution_filter());
|
||||
};
|
||||
CpSatCallbacks callbacks(cb, &local_interrupter, std::move(extract_solution),
|
||||
events, cp_sat_model_.maximize());
|
||||
|
||||
// CP-SAT returns "infeasible" for inverted bounds.
|
||||
RETURN_IF_ERROR(ListInvertedBounds().ToStatus());
|
||||
|
||||
const MPSolutionResponse response = SatSolveProto(
|
||||
std::move(req), &interrupt_solve, logging_callback, solution_callback);
|
||||
RETURN_IF_ERROR(callback_error) << "error in callback";
|
||||
std::move(req), &interrupt_solve, logging_callback,
|
||||
callbacks.MakeSolutionCallback(), callbacks.MakeBestBoundCallback());
|
||||
RETURN_IF_ERROR(callbacks.error()) << "error in callback";
|
||||
ASSIGN_OR_RETURN(*result.mutable_termination(),
|
||||
GetTermination(local_interrupter.IsInterrupted(),
|
||||
/*maximize=*/cp_sat_model_.maximize(),
|
||||
|
||||
@@ -328,15 +328,16 @@ SolveParameters AllSolutions() {
|
||||
return result;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(CpSatCallbackTest, CallbackTest,
|
||||
Values(CallbackTestParams(
|
||||
SolverType::kCpSat,
|
||||
/*integer_variables=*/true,
|
||||
/*add_lazy_constraints=*/false,
|
||||
/*add_cuts=*/false,
|
||||
/*supported_events=*/{CallbackEvent::kMipSolution},
|
||||
/*all_solutions=*/AllSolutions(),
|
||||
/*reaches_cut_callback=*/std::nullopt)));
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
CpSatCallbackTest, CallbackTest,
|
||||
Values(CallbackTestParams(
|
||||
SolverType::kCpSat,
|
||||
/*integer_variables=*/true,
|
||||
/*add_lazy_constraints=*/false,
|
||||
/*add_cuts=*/false,
|
||||
/*supported_events=*/{CallbackEvent::kMipSolution, CallbackEvent::kMip},
|
||||
/*all_solutions=*/AllSolutions(),
|
||||
/*reaches_cut_callback=*/std::nullopt)));
|
||||
|
||||
TEST(CpSatInvalidCallbackTest, RequestLazyConstraints) {
|
||||
Model model("model");
|
||||
|
||||
Reference in New Issue
Block a user