Files
ortools-clone/ortools/sat/python/swig_helper.cc

118 lines
5.2 KiB
C++

// Copyright 2010-2024 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file wraps the swig_helper.h classes in Python using pybind11.
//
// Because pybind11_protobuf does not support building with CMake for OR-Tools,
// the API has been transformed to use serialized protos from Python to C++ and
// from C++ to Python:
//   - From Python to C++: use proto.SerializeToString(). This creates a Python
//     string that is passed to C++ and parsed back to a proto.
//   - From C++ to Python: we cast the result of proto.SerializeAsString() to
//     pybind11::bytes. This is passed back to Python, which will reconstruct
//     the proto using PythonProto.FromString(byte[]).
//
// NOTE(review): the bindings in this file include
// pybind11_protobuf/native_proto_caster.h, call ImportNativeProtoCasters() and
// pass CpModelProto / CpSolverResponse objects directly, so the serialized
// proto description above looks out of date -- confirm and update.
#include "ortools/sat/swig_helper.h"
#include <string>
#include "absl/strings/string_view.h"
#include "ortools/sat/cp_model.pb.h"
#include "ortools/util/sorted_interval_list.h"
#include "pybind11/cast.h"
#include "pybind11/functional.h"
#include "pybind11/gil.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h"
using ::operations_research::Domain;
using ::operations_research::sat::CpModelProto;
using ::operations_research::sat::CpSatHelper;
using ::operations_research::sat::CpSolverResponse;
using ::operations_research::sat::IntegerVariableProto;
using ::operations_research::sat::SatParameters;
using ::operations_research::sat::SolutionCallback;
using ::operations_research::sat::SolveWrapper;
using ::pybind11::arg;
// Trampoline class registered alongside SolutionCallback so that a Python
// subclass can provide the implementation of OnSolutionCallback().
class PySolutionCallback : public SolutionCallback {
 public:
  /* Re-use every constructor declared by SolutionCallback. */
  using SolutionCallback::SolutionCallback;

  // Dispatches to the Python override of OnSolutionCallback(). The GIL is
  // taken explicitly before control is handed over to pybind11, and released
  // again when `gil_lock` goes out of scope.
  void OnSolutionCallback() const override {
    ::pybind11::gil_scoped_acquire gil_lock;
    PYBIND11_OVERRIDE_PURE(
        void,               /* Return type */
        SolutionCallback,   /* Parent class */
        OnSolutionCallback, /* Name of function */
        /* No arguments are forwarded. The trailing comma after the
           function name is required by some compilers. */
    );
  }
};
// Defines the `swig_helper` Python extension module, exposing
// SolutionCallback, SolveWrapper and CpSatHelper to Python.
PYBIND11_MODULE(swig_helper, m) {
  // Enable the pybind11_protobuf native casters used by the proto-typed
  // parameters and return values below (e.g. CpModelProto, CpSolverResponse).
  pybind11_protobuf::ImportNativeProtoCasters();
  // Import the sorted_interval_list Python module up front.
  // NOTE(review): presumably this registers the Domain type used by
  // CpSatHelper::VariableDomain -- confirm against that module.
  pybind11::module::import("ortools.util.python.sorted_interval_list");

  // SolutionCallback is bound with PySolutionCallback as its trampoline so
  // that OnSolutionCallback() can be overridden from Python.
  pybind11::class_<SolutionCallback, PySolutionCallback>(m, "SolutionCallback")
      .def(pybind11::init<>())
      .def("OnSolutionCallback", &SolutionCallback::OnSolutionCallback)
      .def("BestObjectiveBound", &SolutionCallback::BestObjectiveBound)
      .def("DeterministicTime", &SolutionCallback::DeterministicTime)
      .def("HasResponse", &SolutionCallback::HasResponse)
      .def("NumBinaryPropagations", &SolutionCallback::NumBinaryPropagations)
      .def("NumBooleans", &SolutionCallback::NumBooleans)
      .def("NumBranches", &SolutionCallback::NumBranches)
      .def("NumConflicts", &SolutionCallback::NumConflicts)
      .def("NumIntegerPropagations", &SolutionCallback::NumIntegerPropagations)
      .def("ObjectiveValue", &SolutionCallback::ObjectiveValue)
      .def("Response", &SolutionCallback::Response)
      .def("SolutionBooleanValue", &SolutionCallback::SolutionBooleanValue,
           arg("index"))
      .def("SolutionIntegerValue", &SolutionCallback::SolutionIntegerValue,
           arg("index"))
      .def("StopSearch", &SolutionCallback::StopSearch)
      .def("UserTime", &SolutionCallback::UserTime)
      .def("WallTime", &SolutionCallback::WallTime);

  pybind11::class_<SolveWrapper>(m, "SolveWrapper")
      .def(pybind11::init<>())
      .def("add_log_callback", &SolveWrapper::AddLogCallback,
           arg("log_callback"))
      .def("add_solution_callback", &SolveWrapper::AddSolutionCallback,
           arg("callback"))
      .def("clear_solution_callback", &SolveWrapper::ClearSolutionCallback)
      .def("add_best_bound_callback", &SolveWrapper::AddBestBoundCallback,
           arg("best_bound_callback"))
      .def("set_parameters", &SolveWrapper::SetParameters, arg("parameters"))
      .def(
          "solve",
          [](SolveWrapper* solve_wrapper,
             const CpModelProto& model_proto) -> CpSolverResponse {
            // Drop the GIL while the solver runs; it is re-acquired when
            // `release` is destroyed, before the response is converted back
            // to a Python object. PySolutionCallback::OnSolutionCallback()
            // acquires the GIL itself when it is invoked.
            ::pybind11::gil_scoped_release release;
            return solve_wrapper->Solve(model_proto);
          },
          // Name the parameter like every other binding in this module so
          // that `solve(model_proto=...)` works from Python.
          arg("model_proto"))
      .def("stop_search", &SolveWrapper::StopSearch);

  // CpSatHelper only exposes static helpers, so no constructor is bound.
  pybind11::class_<CpSatHelper>(m, "CpSatHelper")
      .def_static("model_stats", &CpSatHelper::ModelStats, arg("model_proto"))
      .def_static("solver_response_stats", &CpSatHelper::SolverResponseStats,
                  arg("response"))
      .def_static("validate_model", &CpSatHelper::ValidateModel,
                  arg("model_proto"))
      .def_static("variable_domain", &CpSatHelper::VariableDomain,
                  arg("variable_proto"))
      .def_static("write_model_to_file", &CpSatHelper::WriteModelToFile,
                  arg("model_proto"), arg("filename"));
}