[CP-SAT] code cleanup; more pseudo-cost experimental code; add objective best bound callback
This commit is contained in:
committed by
Corentin Le Molgat
parent
8dbfb730b1
commit
d28edd701c
@@ -27,7 +27,7 @@ _OUTPUT_PROTO = flags.DEFINE_string(
|
||||
)
|
||||
_PARAMS = flags.DEFINE_string(
|
||||
"params",
|
||||
"num_search_workers:16,log_search_progress:true,max_time_in_seconds:45",
|
||||
"num_search_workers:16,log_search_progress:false,max_time_in_seconds:45",
|
||||
"Sat solver parameters.",
|
||||
)
|
||||
_PREPROCESS = flags.DEFINE_bool(
|
||||
@@ -503,6 +503,7 @@ def single_machine_scheduling():
|
||||
if parameters:
|
||||
text_format.Parse(parameters, solver.parameters)
|
||||
solution_printer = SolutionPrinter()
|
||||
solver.best_bound_callback = lambda a : print(f"New objective lower bound: {a}")
|
||||
solver.solve(model, solution_printer)
|
||||
for job_id in all_jobs:
|
||||
print(
|
||||
|
||||
@@ -30,6 +30,7 @@ public final class CpSolver {
|
||||
public CpSolver() {
|
||||
this.solveParameters = SatParameters.newBuilder();
|
||||
this.logCallback = null;
|
||||
this.bestBoundCallback = null;
|
||||
this.solveWrapper = null;
|
||||
}
|
||||
|
||||
@@ -52,6 +53,9 @@ public final class CpSolver {
|
||||
if (logCallback != null) {
|
||||
solveWrapper.addLogCallback(logCallback);
|
||||
}
|
||||
if (bestBoundCallback != null) {
|
||||
solveWrapper.addBestBoundCallback(bestBoundCallback);
|
||||
}
|
||||
|
||||
solveResponse = solveWrapper.solve(model.model());
|
||||
|
||||
@@ -201,5 +205,6 @@ public final class CpSolver {
|
||||
private CpSolverResponse solveResponse;
|
||||
private final SatParameters.Builder solveParameters;
|
||||
private Consumer<String> logCallback;
|
||||
private Consumer<Double> bestBoundCallback;
|
||||
private SolveWrapper solveWrapper;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/meta/type_traits.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/port/proto_utils.h"
|
||||
|
||||
@@ -1088,6 +1088,7 @@ struct SolutionObservers {
|
||||
std::vector<std::function<void(const CpSolverResponse& response)>> observers;
|
||||
std::vector<std::function<std::string(const CpSolverResponse& response)>>
|
||||
log_callbacks;
|
||||
std::vector<std::function<void(double)>> best_bound_callbacks;
|
||||
};
|
||||
|
||||
std::function<void(Model*)> NewFeasibleSolutionObserver(
|
||||
@@ -1105,6 +1106,14 @@ std::function<void(Model*)> NewFeasibleSolutionLogCallback(
|
||||
};
|
||||
}
|
||||
|
||||
std::function<void(Model*)> NewBestBoundCallback(
|
||||
const std::function<void(double)>& callback) {
|
||||
return [=](Model* model) {
|
||||
model->GetOrCreate<SolutionObservers>()->best_bound_callbacks.push_back(
|
||||
callback);
|
||||
};
|
||||
}
|
||||
|
||||
#if !defined(__PORTABLE_PLATFORM__)
|
||||
// TODO(user): Support it on android.
|
||||
std::function<SatParameters(Model*)> NewSatParameters(
|
||||
@@ -4185,6 +4194,12 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
|
||||
shared_response_manager->AddLogCallback(callback);
|
||||
}
|
||||
|
||||
const auto& best_bound_callbacks =
|
||||
model->GetOrCreate<SolutionObservers>()->best_bound_callbacks;
|
||||
for (const auto& callback : best_bound_callbacks) {
|
||||
shared_response_manager->AddBestBoundCallback(callback);
|
||||
}
|
||||
|
||||
// Make sure everything stops when we have a first solution if requested.
|
||||
if (params.stop_after_first_solution()) {
|
||||
shared_response_manager->AddSolutionCallback(
|
||||
|
||||
@@ -103,6 +103,15 @@ std::function<void(Model*)> NewFeasibleSolutionLogCallback(
|
||||
const std::function<std::string(const CpSolverResponse& response)>&
|
||||
callback);
|
||||
|
||||
/** Creates a callbacks that will be called on each new best objective bound
|
||||
* found.
|
||||
*
|
||||
* Note that adding a callback is not free since some computation will be done
|
||||
* before this is called, and it will only be called on optimization models.
|
||||
*/
|
||||
std::function<void(Model*)> NewBestBoundCallback(
|
||||
const std::function<void(double)>& callback);
|
||||
|
||||
/**
|
||||
* Creates parameters for the solver, which you can add to the model with
|
||||
* \code
|
||||
|
||||
@@ -41,6 +41,10 @@ public class CpSolver
|
||||
{
|
||||
solve_wrapper_.AddLogCallbackFromClass(log_callback_);
|
||||
}
|
||||
if (best_bound_callback_ is not null)
|
||||
{
|
||||
solve_wrapper_.AddBestBoundCallbackFromClass(best_bound_callback_);
|
||||
}
|
||||
if (cb is not null)
|
||||
{
|
||||
solve_wrapper_.AddSolutionCallback(cb);
|
||||
@@ -141,6 +145,11 @@ public class CpSolver
|
||||
log_callback_ = new LogCallbackDelegate(del);
|
||||
}
|
||||
|
||||
public void SetBestBoundCallback(DoubleToVoidDelegate del)
|
||||
{
|
||||
best_bound_callback_ = new BestBoundCallbackDelegate(del);
|
||||
}
|
||||
|
||||
public CpSolverResponse Response
|
||||
{
|
||||
get {
|
||||
@@ -283,6 +292,7 @@ public class CpSolver
|
||||
|
||||
private CpSolverResponse response_;
|
||||
private LogCallback log_callback_;
|
||||
private BestBoundCallback best_bound_callback_;
|
||||
private string string_parameters_;
|
||||
private SolveWrapper solve_wrapper_;
|
||||
private Queue<Term> terms_;
|
||||
@@ -303,4 +313,19 @@ class LogCallbackDelegate : LogCallback
|
||||
private StringToVoidDelegate delegate_;
|
||||
}
|
||||
|
||||
class BestBoundCallbackDelegate : BestBoundCallback
|
||||
{
|
||||
public BestBoundCallbackDelegate(DoubleToVoidDelegate del)
|
||||
{
|
||||
this.delegate_ = del;
|
||||
}
|
||||
|
||||
public override void NewBestBound(double bound)
|
||||
{
|
||||
delegate_(bound);
|
||||
}
|
||||
|
||||
private DoubleToVoidDelegate delegate_;
|
||||
}
|
||||
|
||||
} // namespace Google.OrTools.Sat
|
||||
|
||||
@@ -39,6 +39,8 @@ using Google.OrTools.Util;
|
||||
%typemap(csimports) operations_research::sat::SolveWrapper %{
|
||||
// Used to wrap log callbacks (std::function<void(const std::string&>)
|
||||
public delegate void StringToVoidDelegate(string message);
|
||||
// Used to wrap best bound callbacks (std::function<void(double>)
|
||||
public delegate void DoubleToVoidDelegate(double bound);
|
||||
%}
|
||||
|
||||
PROTO_INPUT(operations_research::sat::CpModelProto,
|
||||
@@ -82,8 +84,16 @@ JAGGED_MATRIX_AS_CSHARP_ARRAY(int64_t, int64_t, long, Int64VectorVector);
|
||||
%unignore operations_research::sat::LogCallback::~LogCallback;
|
||||
%unignore operations_research::sat::LogCallback::NewMessage;
|
||||
|
||||
// Temporary wrapper class for the DoubleToVoidDelegate.
|
||||
%feature("director") operations_research::sat::BestBoundCallback;
|
||||
%unignore operations_research::sat::BestBoundCallback;
|
||||
%unignore operations_research::sat::BestBoundCallback::~BestBoundCallback;
|
||||
%unignore operations_research::sat::BestBoundCallback::NewBestBound;
|
||||
|
||||
|
||||
// Wrap the SolveWrapper class.
|
||||
%unignore operations_research::sat::SolveWrapper;
|
||||
%unignore operations_research::sat::SolveWrapper::AddBestBoundCallbackFromClass;
|
||||
%unignore operations_research::sat::SolveWrapper::AddLogCallbackFromClass;
|
||||
%unignore operations_research::sat::SolveWrapper::AddSolutionCallback;
|
||||
%unignore operations_research::sat::SolveWrapper::ClearSolutionCallback;
|
||||
|
||||
@@ -256,33 +256,34 @@ std::function<BooleanOrIntegerLiteral()> LpPseudoCostHeuristic(Model* model) {
|
||||
// average, it is good anyway?
|
||||
if (!is_reliable && is_integer) continue;
|
||||
|
||||
// For Booleans, for some reason it seems the up-branch first work better?
|
||||
if (lb == 0 && ub == 1) {
|
||||
const double score = pseudo_costs->LpPseudoCost(var, lp_value);
|
||||
if (score > best_score) {
|
||||
const LiteralIndex index = encoder->GetAssociatedLiteral(
|
||||
IntegerLiteral::GreaterOrEqual(var, 1));
|
||||
if (index != kNoLiteralIndex) {
|
||||
best_score = score;
|
||||
decision = BooleanOrIntegerLiteral(Literal(index));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// There are some corner cases if we are at the bound. Note that it is
|
||||
// important to be in sync with the SplitAroundLpValue() below.
|
||||
double down_fractionality = lp_value - std::floor(lp_value);
|
||||
if (lp_value >= ToDouble(ub)) down_fractionality = 1.0;
|
||||
if (lp_value <= ToDouble(lb)) down_fractionality = 0.0;
|
||||
const double score = pseudo_costs->LpPseudoCost(var, down_fractionality);
|
||||
IntegerValue down_target = IntegerValue(std::floor(lp_value));
|
||||
if (lp_value >= ToDouble(ub)) {
|
||||
down_fractionality = 1.0;
|
||||
down_target = ub - 1;
|
||||
} else if (lp_value <= ToDouble(lb)) {
|
||||
down_fractionality = 0.0;
|
||||
down_target = lb;
|
||||
}
|
||||
const auto [down_score, up_score] =
|
||||
pseudo_costs->LpPseudoCost(var, down_fractionality);
|
||||
const double score = pseudo_costs->CombineScores(down_score, up_score);
|
||||
|
||||
// We delay to subsequent heuristic if the score is 0.0.
|
||||
if (score > best_score) {
|
||||
best_score = score;
|
||||
|
||||
// This choose <= value if possible.
|
||||
decision = BooleanOrIntegerLiteral(SplitAroundGivenValue(
|
||||
var, IntegerValue(std::floor(lp_value)), model));
|
||||
// This direction works better than the inverse in the benchs. But
|
||||
// always branching up seems even better. TODO(user): investigate.
|
||||
if (down_score > up_score) {
|
||||
decision = BooleanOrIntegerLiteral(
|
||||
IntegerLiteral::LowerOrEqual(var, down_target));
|
||||
} else {
|
||||
decision = BooleanOrIntegerLiteral(
|
||||
IntegerLiteral::GreaterOrEqual(var, down_target + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
return decision;
|
||||
|
||||
@@ -115,6 +115,40 @@ PROTO2_RETURN(operations_research::sat::CpSolverResponse,
|
||||
%typemap(jstype) std::function<void(const std::string&)> "java.util.function.Consumer<String>" // Type used in the Proxy class.
|
||||
%typemap(javain) std::function<void(const std::string&)> "$javainput" // passing the Callback to JNI java class.
|
||||
|
||||
%typemap(in) std::function<void(double)> %{
|
||||
// $input will be deleted once this function return.
|
||||
// So we create a JNI global reference to keep it alive.
|
||||
jobject $input_object = jenv->NewGlobalRef($input);
|
||||
// and we wrap it in a GlobalRefGuard object which will call the
|
||||
// JNI global reference deleter to avoid leak at destruction.
|
||||
JavaVM* jvm;
|
||||
jenv->GetJavaVM(&jvm);
|
||||
auto $input_guard = std::make_shared<GlobalRefGuard>(jvm, $input_object);
|
||||
|
||||
jclass $input_object_class = jenv->GetObjectClass($input);
|
||||
if (nullptr == $input_object_class) return $null;
|
||||
jmethodID $input_method_id = jenv->GetMethodID(
|
||||
$input_object_class, "accept", "(Ljava/lang/Double;)V");
|
||||
assert($input_method_id != nullptr);
|
||||
|
||||
// When the lambda will be destroyed, input_guard's destructor will be called.
|
||||
$1 = [jvm, $input_object, $input_method_id, $input_guard](
|
||||
double bound) -> void {
|
||||
JNIEnv *jenv = NULL;
|
||||
JavaVMAttachArgs args;
|
||||
args.version = JNI_VERSION_1_2;
|
||||
args.name = NULL;
|
||||
args.group = NULL;
|
||||
jvm->AttachCurrentThread((void**)&jenv, &args);
|
||||
jenv->CallVoidMethod($input_object, $input_method_id, bound);
|
||||
jvm->DetachCurrentThread();
|
||||
};
|
||||
%}
|
||||
%typemap(jni) std::function<void(double)> "jobject" // Type used in the JNI.java.
|
||||
%typemap(jtype) std::function<void(double)> "java.util.function.Consumer<Double>" // Type used in the JNI.java.
|
||||
%typemap(jstype) std::function<void(double)> "java.util.function.Consumer<Double>" // Type used in the Proxy class.
|
||||
%typemap(javain) std::function<void(double)> "$javainput" // passing the Callback to JNI java class.
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore operations_research;
|
||||
@@ -122,6 +156,7 @@ PROTO2_RETURN(operations_research::sat::CpSolverResponse,
|
||||
|
||||
// Wrap the SolveWrapper class.
|
||||
%unignore operations_research::sat::SolveWrapper;
|
||||
%rename (addBestBoundCallback) operations_research::sat::SolveWrapper::AddBestBoundCallback;
|
||||
%rename (addLogCallback) operations_research::sat::SolveWrapper::AddLogCallback;
|
||||
%rename (addSolutionCallback) operations_research::sat::SolveWrapper::AddSolutionCallback;
|
||||
%rename (clearSolutionCallback) operations_research::sat::SolveWrapper::ClearSolutionCallback;
|
||||
|
||||
@@ -1484,22 +1484,21 @@ bool PresolveContext::CanonicalizeEncoding(int* ref, int64_t* value) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PresolveContext::InsertVarValueEncoding(int literal, int ref,
|
||||
bool PresolveContext::InsertVarValueEncoding(int literal, int var,
|
||||
int64_t value) {
|
||||
if (!CanonicalizeEncoding(&ref, &value) || !DomainOf(ref).Contains(value)) {
|
||||
if (!CanonicalizeEncoding(&var, &value) || !DomainOf(var).Contains(value)) {
|
||||
return SetLiteralToFalse(literal);
|
||||
}
|
||||
literal = GetLiteralRepresentative(literal);
|
||||
InsertVarValueEncodingInternal(literal, ref, value, /*add_constraints=*/true);
|
||||
InsertVarValueEncodingInternal(literal, var, value, /*add_constraints=*/true);
|
||||
|
||||
if (hint_is_loaded_) {
|
||||
const int bool_var = PositiveRef(literal);
|
||||
const int int_var = PositiveRef(ref);
|
||||
if (!hint_has_value_[bool_var] && hint_has_value_[int_var]) {
|
||||
const int64_t int_value = RefIsPositive(ref) ? value : -value;
|
||||
const int64_t hint_value = hint_[int_var] == int_value ? 1 : 0;
|
||||
DCHECK(RefIsPositive(var));
|
||||
if (!hint_has_value_[bool_var] && hint_has_value_[var]) {
|
||||
const int64_t bool_value = hint_[var] == value ? 1 : 0;
|
||||
hint_has_value_[bool_var] = true;
|
||||
hint_[bool_var] = RefIsPositive(literal) ? hint_value : 1 - hint_value;
|
||||
hint_[bool_var] = RefIsPositive(literal) ? bool_value : 1 - bool_value;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
@@ -365,7 +365,7 @@ class PresolveContext {
|
||||
// Clears the "rules" statistics.
|
||||
void ClearStats();
|
||||
|
||||
// Inserts the given literal to encode ref == value.
|
||||
// Inserts the given literal to encode var == value.
|
||||
// If an encoding already exists, it adds the two implications between
|
||||
// the previous encoding and the new encoding.
|
||||
//
|
||||
@@ -376,9 +376,9 @@ class PresolveContext {
|
||||
// Returns false if the model become UNSAT.
|
||||
//
|
||||
// TODO(user): This function is not always correct if
|
||||
// !context->DomainOf(ref).contains(value), we could make it correct but it
|
||||
// !context->DomainOf(var).contains(value), we could make it correct but it
|
||||
// might be a bit expansive to do so. For now we just have a DCHECK().
|
||||
bool InsertVarValueEncoding(int literal, int ref, int64_t value);
|
||||
bool InsertVarValueEncoding(int literal, int var, int64_t value);
|
||||
|
||||
// Gets the associated literal if it is already created. Otherwise
|
||||
// create it, add the corresponding constraints and returns it.
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
@@ -38,7 +39,7 @@ namespace operations_research {
|
||||
namespace sat {
|
||||
|
||||
// We prefer the product to combine the cost of two branches.
|
||||
double PseudoCosts::CombineCosts(double down_branch, double up_branch) const {
|
||||
double PseudoCosts::CombineScores(double down_branch, double up_branch) const {
|
||||
if (true) {
|
||||
return std::max(1e-6, down_branch) * std::max(1e-6, up_branch);
|
||||
} else {
|
||||
@@ -96,10 +97,10 @@ void PseudoCosts::BeforeTakingDecision(Literal decision) {
|
||||
bound_changes_ = GetBoundChanges(decision);
|
||||
}
|
||||
|
||||
double PseudoCosts::LpPseudoCost(IntegerVariable var,
|
||||
double down_fractionality) const {
|
||||
std::pair<double, double> PseudoCosts::LpPseudoCost(
|
||||
IntegerVariable var, double down_fractionality) const {
|
||||
const int max_index = std::max(var.value(), NegationOf(var).value());
|
||||
if (max_index >= average_unit_objective_increase_.size()) return 0.0;
|
||||
if (max_index >= average_unit_objective_increase_.size()) return {0.0, 0.0};
|
||||
|
||||
const double up_fractionality = 1.0 - down_fractionality;
|
||||
const double up_branch =
|
||||
@@ -107,7 +108,7 @@ double PseudoCosts::LpPseudoCost(IntegerVariable var,
|
||||
const double down_branch =
|
||||
down_fractionality *
|
||||
average_unit_objective_increase_[NegationOf(var)].CurrentAverage();
|
||||
return CombineCosts(down_branch, up_branch);
|
||||
return {down_branch, up_branch};
|
||||
}
|
||||
|
||||
void PseudoCosts::UpdateBoolPseudoCosts(absl::Span<const Literal> reason,
|
||||
@@ -132,7 +133,7 @@ double PseudoCosts::BoolPseudoCost(Literal lit, double lp_value) const {
|
||||
const double down_branch =
|
||||
down_fractionality *
|
||||
lit_pseudo_costs_[lit.NegatedIndex()].CurrentAverage();
|
||||
return CombineCosts(down_branch, up_branch);
|
||||
return CombineScores(down_branch, up_branch);
|
||||
}
|
||||
|
||||
int PseudoCosts::LpReliability(IntegerVariable var) const {
|
||||
@@ -204,7 +205,7 @@ void PseudoCosts::AfterTakingDecision(bool conflict) {
|
||||
pseudo_costs_[negative_var].NumRecords();
|
||||
if (count >= parameters_.pseudo_cost_reliability_threshold()) {
|
||||
scores_[positive_var] =
|
||||
CombineCosts(GetCost(positive_var), GetCost(negative_var));
|
||||
CombineScores(GetCost(positive_var), GetCost(negative_var));
|
||||
if (!is_relevant_[positive_var]) {
|
||||
is_relevant_[positive_var] = true;
|
||||
relevant_variables_.push_back(positive_var);
|
||||
|
||||
@@ -61,8 +61,11 @@ class PseudoCosts {
|
||||
}
|
||||
|
||||
// Alternative pseudo-costs. This relies on the LP more heavily and is more
|
||||
// in line with what a MIP solver would do.
|
||||
double LpPseudoCost(IntegerVariable var, double down_fractionality) const;
|
||||
// in line with what a MIP solver would do. Return the (down, up) costs which
|
||||
// can be combined with CombineScores();
|
||||
double CombineScores(double down_branch, double up_branch) const;
|
||||
std::pair<double, double> LpPseudoCost(IntegerVariable var,
|
||||
double down_fractionality) const;
|
||||
|
||||
// Returns the pseudo cost "reliability".
|
||||
int LpReliability(IntegerVariable var) const;
|
||||
@@ -83,8 +86,6 @@ class PseudoCosts {
|
||||
std::vector<VariableBoundChange> GetBoundChanges(Literal decision);
|
||||
|
||||
private:
|
||||
double CombineCosts(double down_branch, double up_branch) const;
|
||||
|
||||
// Returns the current objective info.
|
||||
struct ObjectiveInfo {
|
||||
std::string DebugString() const;
|
||||
|
||||
@@ -3132,6 +3132,7 @@ class CpSolver:
|
||||
sat_parameters_pb2.SatParameters()
|
||||
)
|
||||
self.log_callback: Optional[Callable[[str], None]] = None
|
||||
self.best_bound_callback: Optional[Callable[[double], None]] = None
|
||||
self.__solve_wrapper: Optional[swig_helper.SolveWrapper] = None
|
||||
self.__lock: threading.Lock = threading.Lock()
|
||||
|
||||
@@ -3151,6 +3152,9 @@ class CpSolver:
|
||||
if self.log_callback is not None:
|
||||
self.__solve_wrapper.add_log_callback(self.log_callback)
|
||||
|
||||
if self.best_bound_callback is not None:
|
||||
self.__solve_wrapper.add_best_bound_callback(self.best_bound_callback)
|
||||
|
||||
solution: cp_model_pb2.CpSolverResponse = self.__solve_wrapper.solve(
|
||||
model.proto
|
||||
)
|
||||
|
||||
@@ -93,6 +93,8 @@ PYBIND11_MODULE(swig_helper, m) {
|
||||
.def("add_solution_callback", &SolveWrapper::AddSolutionCallback,
|
||||
arg("callback"))
|
||||
.def("clear_solution_callback", &SolveWrapper::ClearSolutionCallback)
|
||||
.def("add_best_bound_callback", &SolveWrapper::AddBestBoundCallback,
|
||||
arg("best_bound_callback"))
|
||||
.def("set_parameters", &SolveWrapper::SetParameters, arg("parameters"))
|
||||
.def("solve",
|
||||
[](SolveWrapper* solve_wrapper,
|
||||
|
||||
@@ -140,6 +140,18 @@ void SolveWrapper::AddLogCallbackFromClass(LogCallback* log_callback) {
|
||||
});
|
||||
}
|
||||
|
||||
void SolveWrapper::AddBestBoundCallback(
|
||||
std::function<void(double)> best_bound_callback) {
|
||||
if (best_bound_callback != nullptr) {
|
||||
model_.Add(NewBestBoundCallback(best_bound_callback));
|
||||
}
|
||||
}
|
||||
|
||||
void SolveWrapper::AddBestBoundCallbackFromClass(BestBoundCallback* callback) {
|
||||
model_.Add(NewBestBoundCallback(
|
||||
[callback](double bound) { callback->NewBestBound(bound); }));
|
||||
}
|
||||
|
||||
operations_research::sat::CpSolverResponse SolveWrapper::Solve(
|
||||
const operations_research::sat::CpModelProto& model_proto) {
|
||||
FixFlagsAndEnvironmentForSwig();
|
||||
|
||||
@@ -88,6 +88,11 @@ class LogCallback {
|
||||
virtual ~LogCallback() = default;
|
||||
virtual void NewMessage(const std::string& message) = 0;
|
||||
};
|
||||
class BestBoundCallback {
|
||||
public:
|
||||
virtual ~BestBoundCallback() = default;
|
||||
virtual void NewBestBound(double bound) = 0;
|
||||
};
|
||||
|
||||
// This class is not meant to be reused after one solve.
|
||||
class SolveWrapper {
|
||||
@@ -108,9 +113,11 @@ class SolveWrapper {
|
||||
void AddSolutionCallback(const SolutionCallback& callback);
|
||||
void ClearSolutionCallback(const SolutionCallback& callback);
|
||||
void AddLogCallback(std::function<void(const std::string&)> log_callback);
|
||||
void AddBestBoundCallback(std::function<void(double)> best_bound_callback);
|
||||
|
||||
// Workaround for C#.
|
||||
void AddLogCallbackFromClass(LogCallback* log_callback);
|
||||
void AddBestBoundCallbackFromClass(BestBoundCallback* best_bound_callback);
|
||||
|
||||
operations_research::sat::CpSolverResponse Solve(
|
||||
const operations_research::sat::CpModelProto& model_proto);
|
||||
|
||||
@@ -304,6 +304,8 @@ void SharedResponseManager::UpdateInnerObjectiveBounds(
|
||||
|
||||
const bool change =
|
||||
(lb > inner_objective_lower_bound_ || ub < inner_objective_upper_bound_);
|
||||
if (!change) return;
|
||||
|
||||
if (lb > inner_objective_lower_bound_) {
|
||||
// When the improving problem is infeasible, it is possible to report
|
||||
// arbitrary high inner_objective_lower_bound_. We make sure it never cross
|
||||
@@ -328,21 +330,26 @@ void SharedResponseManager::UpdateInnerObjectiveBounds(
|
||||
SatProgressMessage("Done", wall_timer_.Get(), update_info));
|
||||
return;
|
||||
}
|
||||
if (logger_->LoggingIsEnabled() && change) {
|
||||
if (logger_->LoggingIsEnabled() || !best_bound_callbacks_.empty()) {
|
||||
const CpObjectiveProto& obj = *objective_or_null_;
|
||||
const double best =
|
||||
ScaleObjectiveValue(obj, best_solution_objective_value_);
|
||||
double new_lb = ScaleObjectiveValue(obj, inner_objective_lower_bound_);
|
||||
double new_ub = ScaleObjectiveValue(obj, inner_objective_upper_bound_);
|
||||
if (obj.scaling_factor() < 0) {
|
||||
std::swap(new_lb, new_ub);
|
||||
for (const auto& callback_entry : best_bound_callbacks_) {
|
||||
callback_entry.second(new_lb);
|
||||
}
|
||||
if (logger_->LoggingIsEnabled()) {
|
||||
double new_ub = ScaleObjectiveValue(obj, inner_objective_upper_bound_);
|
||||
if (obj.scaling_factor() < 0) {
|
||||
std::swap(new_lb, new_ub);
|
||||
}
|
||||
RegisterObjectiveBoundImprovement(update_info);
|
||||
logger_->ThrottledLog(bounds_logging_id_,
|
||||
ProgressMessage("Bound", wall_timer_.Get(), best,
|
||||
new_lb, new_ub, update_info));
|
||||
}
|
||||
RegisterObjectiveBoundImprovement(update_info);
|
||||
logger_->ThrottledLog(bounds_logging_id_,
|
||||
ProgressMessage("Bound", wall_timer_.Get(), best,
|
||||
new_lb, new_ub, update_info));
|
||||
}
|
||||
if (change) TestGapLimitsIfNeeded();
|
||||
TestGapLimitsIfNeeded();
|
||||
}
|
||||
|
||||
// Invariant: the status always start at UNKNOWN and can only evolve as follow:
|
||||
@@ -473,6 +480,25 @@ void SharedResponseManager::UnregisterLogCallback(int callback_id) {
|
||||
LOG(DFATAL) << "Callback id " << callback_id << " not registered.";
|
||||
}
|
||||
|
||||
int SharedResponseManager::AddBestBoundCallback(
|
||||
std::function<void(double)> callback) {
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
const int id = next_best_bound_callback_id_++;
|
||||
best_bound_callbacks_.emplace_back(id, std::move(callback));
|
||||
return id;
|
||||
}
|
||||
|
||||
void SharedResponseManager::UnregisterBestBoundCallback(int callback_id) {
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
for (int i = 0; i < best_bound_callbacks_.size(); ++i) {
|
||||
if (best_bound_callbacks_[i].first == callback_id) {
|
||||
best_bound_callbacks_.erase(best_bound_callbacks_.begin() + i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG(DFATAL) << "Callback id " << callback_id << " not registered.";
|
||||
}
|
||||
|
||||
CpSolverResponse SharedResponseManager::GetResponseInternal(
|
||||
absl::Span<const int64_t> variable_values,
|
||||
const std::string& solution_info) {
|
||||
|
||||
@@ -269,6 +269,17 @@ class SharedResponseManager {
|
||||
std::function<std::string(const CpSolverResponse&)> callback);
|
||||
void UnregisterLogCallback(int callback_id);
|
||||
|
||||
// Adds a callback that will be called on each new best objective bound
|
||||
// found. Returns its id so it can be unregistered if needed.
|
||||
//
|
||||
// Note that adding a callback is not free since some computation will be done
|
||||
// before this is called.
|
||||
//
|
||||
// Note that currently the class is waiting for the callback to finish before
|
||||
// accepting any new updates. That could be changed if needed.
|
||||
int AddBestBoundCallback(std::function<void(double)> callback);
|
||||
void UnregisterBestBoundCallback(int callback_id);
|
||||
|
||||
// The "inner" objective is the CpModelProto objective without scaling/offset.
|
||||
// Note that these bound correspond to valid bound for the problem of finding
|
||||
// a strictly better objective than the current one. Thus the lower bound is
|
||||
@@ -443,6 +454,10 @@ class SharedResponseManager {
|
||||
std::pair<int, std::function<std::string(const CpSolverResponse&)>>>
|
||||
search_log_callbacks_ ABSL_GUARDED_BY(mutex_);
|
||||
|
||||
int next_best_bound_callback_id_ ABSL_GUARDED_BY(mutex_) = 0;
|
||||
std::vector<std::pair<int, std::function<void(double)>>> best_bound_callbacks_
|
||||
ABSL_GUARDED_BY(mutex_);
|
||||
|
||||
std::vector<std::function<void(std::vector<int64_t>*)>>
|
||||
solution_postprocessors_ ABSL_GUARDED_BY(mutex_);
|
||||
std::vector<std::function<void(CpSolverResponse*)>> postprocessors_
|
||||
|
||||
Reference in New Issue
Block a user