diff --git a/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py b/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py index 65b1602631..8e048c0969 100644 --- a/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py +++ b/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py @@ -27,7 +27,7 @@ _OUTPUT_PROTO = flags.DEFINE_string( ) _PARAMS = flags.DEFINE_string( "params", - "num_search_workers:16,log_search_progress:true,max_time_in_seconds:45", + "num_search_workers:16,log_search_progress:false,max_time_in_seconds:45", "Sat solver parameters.", ) _PREPROCESS = flags.DEFINE_bool( @@ -503,6 +503,7 @@ def single_machine_scheduling(): if parameters: text_format.Parse(parameters, solver.parameters) solution_printer = SolutionPrinter() + solver.best_bound_callback = lambda a : print(f"New objective lower bound: {a}") solver.solve(model, solution_printer) for job_id in all_jobs: print( diff --git a/ortools/java/com/google/ortools/sat/CpSolver.java b/ortools/java/com/google/ortools/sat/CpSolver.java index 704bcbdd63..d5f67fbf7b 100644 --- a/ortools/java/com/google/ortools/sat/CpSolver.java +++ b/ortools/java/com/google/ortools/sat/CpSolver.java @@ -30,6 +30,7 @@ public final class CpSolver { public CpSolver() { this.solveParameters = SatParameters.newBuilder(); this.logCallback = null; + this.bestBoundCallback = null; this.solveWrapper = null; } @@ -52,6 +53,9 @@ public final class CpSolver { if (logCallback != null) { solveWrapper.addLogCallback(logCallback); } + if (bestBoundCallback != null) { + solveWrapper.addBestBoundCallback(bestBoundCallback); + } solveResponse = solveWrapper.solve(model.model()); @@ -201,5 +205,6 @@ public final class CpSolver { private CpSolverResponse solveResponse; private final SatParameters.Builder solveParameters; private Consumer logCallback; + private Consumer bestBoundCallback; private SolveWrapper solveWrapper; } diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 35a0a9a9eb..d7e7ca75b5 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -28,7 +28,6 @@ #include "absl/log/check.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/port/proto_utils.h" diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 7e974b5b1d..8a6e3ba340 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1088,6 +1088,7 @@ struct SolutionObservers { std::vector> observers; std::vector> log_callbacks; + std::vector> best_bound_callbacks; }; std::function NewFeasibleSolutionObserver( @@ -1105,6 +1106,14 @@ std::function NewFeasibleSolutionLogCallback( }; } +std::function NewBestBoundCallback( + const std::function& callback) { + return [=](Model* model) { + model->GetOrCreate()->best_bound_callbacks.push_back( + callback); + }; +} + #if !defined(__PORTABLE_PLATFORM__) // TODO(user): Support it on android. std::function NewSatParameters( @@ -4185,6 +4194,12 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { shared_response_manager->AddLogCallback(callback); } + const auto& best_bound_callbacks = + model->GetOrCreate()->best_bound_callbacks; + for (const auto& callback : best_bound_callbacks) { + shared_response_manager->AddBestBoundCallback(callback); + } + // Make sure everything stops when we have a first solution if requested. if (params.stop_after_first_solution()) { shared_response_manager->AddSolutionCallback( diff --git a/ortools/sat/cp_model_solver.h b/ortools/sat/cp_model_solver.h index c29b468b6b..24fd56cee2 100644 --- a/ortools/sat/cp_model_solver.h +++ b/ortools/sat/cp_model_solver.h @@ -103,6 +103,15 @@ std::function NewFeasibleSolutionLogCallback( const std::function& callback); +/** Creates a callbacks that will be called on each new best objective bound + * found. + * + * Note that adding a callback is not free since some computation will be done + * before this is called, and it will only be called on optimization models. + */ +std::function NewBestBoundCallback( + const std::function& callback); + /** * Creates parameters for the solver, which you can add to the model with * \code diff --git a/ortools/sat/csharp/CpSolver.cs b/ortools/sat/csharp/CpSolver.cs index 51203daf66..bdf763b4b2 100644 --- a/ortools/sat/csharp/CpSolver.cs +++ b/ortools/sat/csharp/CpSolver.cs @@ -41,6 +41,10 @@ public class CpSolver { solve_wrapper_.AddLogCallbackFromClass(log_callback_); } + if (best_bound_callback_ is not null) + { + solve_wrapper_.AddBestBoundCallbackFromClass(best_bound_callback_); + } if (cb is not null) { solve_wrapper_.AddSolutionCallback(cb); @@ -141,6 +145,11 @@ public class CpSolver log_callback_ = new LogCallbackDelegate(del); } + public void SetBestBoundCallback(DoubleToVoidDelegate del) + { + best_bound_callback_ = new BestBoundCallbackDelegate(del); + } + public CpSolverResponse Response { get { @@ -283,6 +292,7 @@ public class CpSolver private CpSolverResponse response_; private LogCallback log_callback_; + private BestBoundCallback best_bound_callback_; private string string_parameters_; private SolveWrapper solve_wrapper_; private Queue terms_; @@ -303,4 +313,19 @@ class LogCallbackDelegate : LogCallback private StringToVoidDelegate delegate_; } +class BestBoundCallbackDelegate : BestBoundCallback +{ + public BestBoundCallbackDelegate(DoubleToVoidDelegate del) + { + this.delegate_ = del; + } + + public override void NewBestBound(double bound) + { + delegate_(bound); + } + + private DoubleToVoidDelegate delegate_; +} + } // namespace Google.OrTools.Sat diff --git a/ortools/sat/csharp/sat.i b/ortools/sat/csharp/sat.i index 1f3dedeffd..ff8c09f0df 100644 --- a/ortools/sat/csharp/sat.i +++ b/ortools/sat/csharp/sat.i @@ -39,6 +39,8 @@ using Google.OrTools.Util; %typemap(csimports) operations_research::sat::SolveWrapper %{ // Used to wrap log callbacks (std::function) public delegate void StringToVoidDelegate(string message); +// Used to wrap best bound callbacks (std::function) +public delegate void DoubleToVoidDelegate(double bound); %} PROTO_INPUT(operations_research::sat::CpModelProto, @@ -82,8 +84,16 @@ JAGGED_MATRIX_AS_CSHARP_ARRAY(int64_t, int64_t, long, Int64VectorVector); %unignore operations_research::sat::LogCallback::~LogCallback; %unignore operations_research::sat::LogCallback::NewMessage; +// Temporary wrapper class for the DoubleToVoidDelegate. +%feature("director") operations_research::sat::BestBoundCallback; +%unignore operations_research::sat::BestBoundCallback; +%unignore operations_research::sat::BestBoundCallback::~BestBoundCallback; +%unignore operations_research::sat::BestBoundCallback::NewBestBound; + + // Wrap the SolveWrapper class. %unignore operations_research::sat::SolveWrapper; +%unignore operations_research::sat::SolveWrapper::AddBestBoundCallbackFromClass; %unignore operations_research::sat::SolveWrapper::AddLogCallbackFromClass; %unignore operations_research::sat::SolveWrapper::AddSolutionCallback; %unignore operations_research::sat::SolveWrapper::ClearSolutionCallback; diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 7d3d9fcd33..f95de4c846 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -256,33 +256,34 @@ std::function LpPseudoCostHeuristic(Model* model) { // average, it is good anyway? if (!is_reliable && is_integer) continue; - // For Booleans, for some reason it seems the up-branch first work better? - if (lb == 0 && ub == 1) { - const double score = pseudo_costs->LpPseudoCost(var, lp_value); - if (score > best_score) { - const LiteralIndex index = encoder->GetAssociatedLiteral( - IntegerLiteral::GreaterOrEqual(var, 1)); - if (index != kNoLiteralIndex) { - best_score = score; - decision = BooleanOrIntegerLiteral(Literal(index)); - } - } - } - // There are some corner cases if we are at the bound. Note that it is // important to be in sync with the SplitAroundLpValue() below. double down_fractionality = lp_value - std::floor(lp_value); - if (lp_value >= ToDouble(ub)) down_fractionality = 1.0; - if (lp_value <= ToDouble(lb)) down_fractionality = 0.0; - const double score = pseudo_costs->LpPseudoCost(var, down_fractionality); + IntegerValue down_target = IntegerValue(std::floor(lp_value)); + if (lp_value >= ToDouble(ub)) { + down_fractionality = 1.0; + down_target = ub - 1; + } else if (lp_value <= ToDouble(lb)) { + down_fractionality = 0.0; + down_target = lb; + } + const auto [down_score, up_score] = + pseudo_costs->LpPseudoCost(var, down_fractionality); + const double score = pseudo_costs->CombineScores(down_score, up_score); // We delay to subsequent heuristic if the score is 0.0. if (score > best_score) { best_score = score; - // This choose <= value if possible. - decision = BooleanOrIntegerLiteral(SplitAroundGivenValue( - var, IntegerValue(std::floor(lp_value)), model)); + // This direction works better than the inverse in the benchs. But + // always branching up seems even better. TODO(user): investigate. + if (down_score > up_score) { + decision = BooleanOrIntegerLiteral( + IntegerLiteral::LowerOrEqual(var, down_target)); + } else { + decision = BooleanOrIntegerLiteral( + IntegerLiteral::GreaterOrEqual(var, down_target + 1)); + } } } return decision; diff --git a/ortools/sat/java/sat.i b/ortools/sat/java/sat.i index 1f0e32731c..e2c6cad058 100644 --- a/ortools/sat/java/sat.i +++ b/ortools/sat/java/sat.i @@ -115,6 +115,40 @@ PROTO2_RETURN(operations_research::sat::CpSolverResponse, %typemap(jstype) std::function "java.util.function.Consumer" // Type used in the Proxy class. %typemap(javain) std::function "$javainput" // passing the Callback to JNI java class. +%typemap(in) std::function %{ + // $input will be deleted once this function return. + // So we create a JNI global reference to keep it alive. + jobject $input_object = jenv->NewGlobalRef($input); + // and we wrap it in a GlobalRefGuard object which will call the + // JNI global reference deleter to avoid leak at destruction. + JavaVM* jvm; + jenv->GetJavaVM(&jvm); + auto $input_guard = std::make_shared(jvm, $input_object); + + jclass $input_object_class = jenv->GetObjectClass($input); + if (nullptr == $input_object_class) return $null; + jmethodID $input_method_id = jenv->GetMethodID( + $input_object_class, "accept", "(Ljava/lang/Double;)V"); + assert($input_method_id != nullptr); + + // When the lambda will be destroyed, input_guard's destructor will be called. + $1 = [jvm, $input_object, $input_method_id, $input_guard]( + double bound) -> void { + JNIEnv *jenv = NULL; + JavaVMAttachArgs args; + args.version = JNI_VERSION_1_2; + args.name = NULL; + args.group = NULL; + jvm->AttachCurrentThread((void**)&jenv, &args); + jenv->CallVoidMethod($input_object, $input_method_id, bound); + jvm->DetachCurrentThread(); + }; +%} +%typemap(jni) std::function "jobject" // Type used in the JNI.java. +%typemap(jtype) std::function "java.util.function.Consumer" // Type used in the JNI.java. +%typemap(jstype) std::function "java.util.function.Consumer" // Type used in the Proxy class. +%typemap(javain) std::function "$javainput" // passing the Callback to JNI java class. + %ignoreall %unignore operations_research; @@ -122,6 +156,7 @@ PROTO2_RETURN(operations_research::sat::CpSolverResponse, // Wrap the SolveWrapper class. %unignore operations_research::sat::SolveWrapper; +%rename (addBestBoundCallback) operations_research::sat::SolveWrapper::AddBestBoundCallback; %rename (addLogCallback) operations_research::sat::SolveWrapper::AddLogCallback; %rename (addSolutionCallback) operations_research::sat::SolveWrapper::AddSolutionCallback; %rename (clearSolutionCallback) operations_research::sat::SolveWrapper::ClearSolutionCallback; diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index a9fef9c255..5a60479798 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -1484,22 +1484,21 @@ bool PresolveContext::CanonicalizeEncoding(int* ref, int64_t* value) { return true; } -bool PresolveContext::InsertVarValueEncoding(int literal, int ref, +bool PresolveContext::InsertVarValueEncoding(int literal, int var, int64_t value) { - if (!CanonicalizeEncoding(&ref, &value) || !DomainOf(ref).Contains(value)) { + if (!CanonicalizeEncoding(&var, &value) || !DomainOf(var).Contains(value)) { return SetLiteralToFalse(literal); } literal = GetLiteralRepresentative(literal); - InsertVarValueEncodingInternal(literal, ref, value, /*add_constraints=*/true); + InsertVarValueEncodingInternal(literal, var, value, /*add_constraints=*/true); if (hint_is_loaded_) { const int bool_var = PositiveRef(literal); - const int int_var = PositiveRef(ref); - if (!hint_has_value_[bool_var] && hint_has_value_[int_var]) { - const int64_t int_value = RefIsPositive(ref) ? value : -value; - const int64_t hint_value = hint_[int_var] == int_value ? 1 : 0; + DCHECK(RefIsPositive(var)); + if (!hint_has_value_[bool_var] && hint_has_value_[var]) { + const int64_t bool_value = hint_[var] == value ? 1 : 0; hint_has_value_[bool_var] = true; - hint_[bool_var] = RefIsPositive(literal) ? hint_value : 1 - hint_value; + hint_[bool_var] = RefIsPositive(literal) ? bool_value : 1 - bool_value; } } return true; diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index d81532f929..7b8b8240d3 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -365,7 +365,7 @@ class PresolveContext { // Clears the "rules" statistics. void ClearStats(); - // Inserts the given literal to encode ref == value. + // Inserts the given literal to encode var == value. // If an encoding already exists, it adds the two implications between // the previous encoding and the new encoding. // @@ -376,9 +376,9 @@ class PresolveContext { // Returns false if the model become UNSAT. // // TODO(user): This function is not always correct if - // !context->DomainOf(ref).contains(value), we could make it correct but it + // !context->DomainOf(var).contains(value), we could make it correct but it // might be a bit expansive to do so. For now we just have a DCHECK(). - bool InsertVarValueEncoding(int literal, int ref, int64_t value); + bool InsertVarValueEncoding(int literal, int var, int64_t value); // Gets the associated literal if it is already created. Otherwise // create it, add the corresponding constraints and returns it. diff --git a/ortools/sat/pseudo_costs.cc b/ortools/sat/pseudo_costs.cc index 92d4ac95ea..79a17e58de 100644 --- a/ortools/sat/pseudo_costs.cc +++ b/ortools/sat/pseudo_costs.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "absl/log/check.h" @@ -38,7 +39,7 @@ namespace operations_research { namespace sat { // We prefer the product to combine the cost of two branches. -double PseudoCosts::CombineCosts(double down_branch, double up_branch) const { +double PseudoCosts::CombineScores(double down_branch, double up_branch) const { if (true) { return std::max(1e-6, down_branch) * std::max(1e-6, up_branch); } else { @@ -96,10 +97,10 @@ void PseudoCosts::BeforeTakingDecision(Literal decision) { bound_changes_ = GetBoundChanges(decision); } -double PseudoCosts::LpPseudoCost(IntegerVariable var, - double down_fractionality) const { +std::pair PseudoCosts::LpPseudoCost( + IntegerVariable var, double down_fractionality) const { const int max_index = std::max(var.value(), NegationOf(var).value()); - if (max_index >= average_unit_objective_increase_.size()) return 0.0; + if (max_index >= average_unit_objective_increase_.size()) return {0.0, 0.0}; const double up_fractionality = 1.0 - down_fractionality; const double up_branch = @@ -107,7 +108,7 @@ double PseudoCosts::LpPseudoCost(IntegerVariable var, const double down_branch = down_fractionality * average_unit_objective_increase_[NegationOf(var)].CurrentAverage(); - return CombineCosts(down_branch, up_branch); + return {down_branch, up_branch}; } void PseudoCosts::UpdateBoolPseudoCosts(absl::Span reason, @@ -132,7 +133,7 @@ double PseudoCosts::BoolPseudoCost(Literal lit, double lp_value) const { const double down_branch = down_fractionality * lit_pseudo_costs_[lit.NegatedIndex()].CurrentAverage(); - return CombineCosts(down_branch, up_branch); + return CombineScores(down_branch, up_branch); } int PseudoCosts::LpReliability(IntegerVariable var) const { @@ -204,7 +205,7 @@ void PseudoCosts::AfterTakingDecision(bool conflict) { pseudo_costs_[negative_var].NumRecords(); if (count >= parameters_.pseudo_cost_reliability_threshold()) { scores_[positive_var] = - CombineCosts(GetCost(positive_var), GetCost(negative_var)); + CombineScores(GetCost(positive_var), GetCost(negative_var)); if (!is_relevant_[positive_var]) { is_relevant_[positive_var] = true; relevant_variables_.push_back(positive_var); diff --git a/ortools/sat/pseudo_costs.h b/ortools/sat/pseudo_costs.h index b5b950f4cf..a7b8c0447c 100644 --- a/ortools/sat/pseudo_costs.h +++ b/ortools/sat/pseudo_costs.h @@ -61,8 +61,11 @@ class PseudoCosts { } // Alternative pseudo-costs. This relies on the LP more heavily and is more - // in line with what a MIP solver would do. - double LpPseudoCost(IntegerVariable var, double down_fractionality) const; + // in line with what a MIP solver would do. Return the (down, up) costs which + // can be combined with CombineScores(); + double CombineScores(double down_branch, double up_branch) const; + std::pair LpPseudoCost(IntegerVariable var, + double down_fractionality) const; // Returns the pseudo cost "reliability". int LpReliability(IntegerVariable var) const; @@ -83,8 +86,6 @@ class PseudoCosts { std::vector GetBoundChanges(Literal decision); private: - double CombineCosts(double down_branch, double up_branch) const; - // Returns the current objective info. struct ObjectiveInfo { std::string DebugString() const; diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index f09baffbce..9ba1854330 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -3132,6 +3132,7 @@ class CpSolver: sat_parameters_pb2.SatParameters() ) self.log_callback: Optional[Callable[[str], None]] = None + self.best_bound_callback: Optional[Callable[[double], None]] = None self.__solve_wrapper: Optional[swig_helper.SolveWrapper] = None self.__lock: threading.Lock = threading.Lock() @@ -3151,6 +3152,9 @@ class CpSolver: if self.log_callback is not None: self.__solve_wrapper.add_log_callback(self.log_callback) + if self.best_bound_callback is not None: + self.__solve_wrapper.add_best_bound_callback(self.best_bound_callback) + solution: cp_model_pb2.CpSolverResponse = self.__solve_wrapper.solve( model.proto ) diff --git a/ortools/sat/python/swig_helper.cc b/ortools/sat/python/swig_helper.cc index 1a581f0919..5f225cb954 100644 --- a/ortools/sat/python/swig_helper.cc +++ b/ortools/sat/python/swig_helper.cc @@ -93,6 +93,8 @@ PYBIND11_MODULE(swig_helper, m) { .def("add_solution_callback", &SolveWrapper::AddSolutionCallback, arg("callback")) .def("clear_solution_callback", &SolveWrapper::ClearSolutionCallback) + .def("add_best_bound_callback", &SolveWrapper::AddBestBoundCallback, + arg("best_bound_callback")) .def("set_parameters", &SolveWrapper::SetParameters, arg("parameters")) .def("solve", [](SolveWrapper* solve_wrapper, diff --git a/ortools/sat/swig_helper.cc b/ortools/sat/swig_helper.cc index 63b664eedd..8488320642 100644 --- a/ortools/sat/swig_helper.cc +++ b/ortools/sat/swig_helper.cc @@ -140,6 +140,18 @@ void SolveWrapper::AddLogCallbackFromClass(LogCallback* log_callback) { }); } +void SolveWrapper::AddBestBoundCallback( + std::function best_bound_callback) { + if (best_bound_callback != nullptr) { + model_.Add(NewBestBoundCallback(best_bound_callback)); + } +} + +void SolveWrapper::AddBestBoundCallbackFromClass(BestBoundCallback* callback) { + model_.Add(NewBestBoundCallback( + [callback](double bound) { callback->NewBestBound(bound); })); +} + operations_research::sat::CpSolverResponse SolveWrapper::Solve( const operations_research::sat::CpModelProto& model_proto) { FixFlagsAndEnvironmentForSwig(); diff --git a/ortools/sat/swig_helper.h b/ortools/sat/swig_helper.h index b46135c1b0..e9821b620d 100644 --- a/ortools/sat/swig_helper.h +++ b/ortools/sat/swig_helper.h @@ -88,6 +88,11 @@ class LogCallback { virtual ~LogCallback() = default; virtual void NewMessage(const std::string& message) = 0; }; +class BestBoundCallback { + public: + virtual ~BestBoundCallback() = default; + virtual void NewBestBound(double bound) = 0; +}; // This class is not meant to be reused after one solve. class SolveWrapper { @@ -108,9 +113,11 @@ class SolveWrapper { void AddSolutionCallback(const SolutionCallback& callback); void ClearSolutionCallback(const SolutionCallback& callback); void AddLogCallback(std::function log_callback); + void AddBestBoundCallback(std::function best_bound_callback); // Workaround for C#. void AddLogCallbackFromClass(LogCallback* log_callback); + void AddBestBoundCallbackFromClass(BestBoundCallback* best_bound_callback); operations_research::sat::CpSolverResponse Solve( const operations_research::sat::CpModelProto& model_proto); diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 60f266aca7..213d60146c 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -304,6 +304,8 @@ void SharedResponseManager::UpdateInnerObjectiveBounds( const bool change = (lb > inner_objective_lower_bound_ || ub < inner_objective_upper_bound_); + if (!change) return; + if (lb > inner_objective_lower_bound_) { // When the improving problem is infeasible, it is possible to report // arbitrary high inner_objective_lower_bound_. We make sure it never cross @@ -328,21 +330,26 @@ void SharedResponseManager::UpdateInnerObjectiveBounds( SatProgressMessage("Done", wall_timer_.Get(), update_info)); return; } - if (logger_->LoggingIsEnabled() && change) { + if (logger_->LoggingIsEnabled() || !best_bound_callbacks_.empty()) { const CpObjectiveProto& obj = *objective_or_null_; const double best = ScaleObjectiveValue(obj, best_solution_objective_value_); double new_lb = ScaleObjectiveValue(obj, inner_objective_lower_bound_); - double new_ub = ScaleObjectiveValue(obj, inner_objective_upper_bound_); - if (obj.scaling_factor() < 0) { - std::swap(new_lb, new_ub); + for (const auto& callback_entry : best_bound_callbacks_) { + callback_entry.second(new_lb); + } + if (logger_->LoggingIsEnabled()) { + double new_ub = ScaleObjectiveValue(obj, inner_objective_upper_bound_); + if (obj.scaling_factor() < 0) { + std::swap(new_lb, new_ub); + } + RegisterObjectiveBoundImprovement(update_info); + logger_->ThrottledLog(bounds_logging_id_, + ProgressMessage("Bound", wall_timer_.Get(), best, + new_lb, new_ub, update_info)); } - RegisterObjectiveBoundImprovement(update_info); - logger_->ThrottledLog(bounds_logging_id_, - ProgressMessage("Bound", wall_timer_.Get(), best, - new_lb, new_ub, update_info)); } - if (change) TestGapLimitsIfNeeded(); + TestGapLimitsIfNeeded(); } // Invariant: the status always start at UNKNOWN and can only evolve as follow: @@ -473,6 +480,25 @@ void SharedResponseManager::UnregisterLogCallback(int callback_id) { LOG(DFATAL) << "Callback id " << callback_id << " not registered."; } +int SharedResponseManager::AddBestBoundCallback( + std::function callback) { + absl::MutexLock mutex_lock(&mutex_); + const int id = next_best_bound_callback_id_++; + best_bound_callbacks_.emplace_back(id, std::move(callback)); + return id; +} + +void SharedResponseManager::UnregisterBestBoundCallback(int callback_id) { + absl::MutexLock mutex_lock(&mutex_); + for (int i = 0; i < best_bound_callbacks_.size(); ++i) { + if (best_bound_callbacks_[i].first == callback_id) { + best_bound_callbacks_.erase(best_bound_callbacks_.begin() + i); + return; + } + } + LOG(DFATAL) << "Callback id " << callback_id << " not registered."; +} + CpSolverResponse SharedResponseManager::GetResponseInternal( absl::Span variable_values, const std::string& solution_info) { diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index bf310bcf78..b24a651eb4 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -269,6 +269,17 @@ class SharedResponseManager { std::function callback); void UnregisterLogCallback(int callback_id); + // Adds a callback that will be called on each new best objective bound + // found. Returns its id so it can be unregistered if needed. + // + // Note that adding a callback is not free since some computation will be done + // before this is called. + // + // Note that currently the class is waiting for the callback to finish before + // accepting any new updates. That could be changed if needed. + int AddBestBoundCallback(std::function callback); + void UnregisterBestBoundCallback(int callback_id); + // The "inner" objective is the CpModelProto objective without scaling/offset. // Note that these bound correspond to valid bound for the problem of finding // a strictly better objective than the current one. Thus the lower bound is @@ -443,6 +454,10 @@ class SharedResponseManager { std::pair>> search_log_callbacks_ ABSL_GUARDED_BY(mutex_); + int next_best_bound_callback_id_ ABSL_GUARDED_BY(mutex_) = 0; + std::vector>> best_bound_callbacks_ + ABSL_GUARDED_BY(mutex_); + std::vector*)>> solution_postprocessors_ ABSL_GUARDED_BY(mutex_); std::vector> postprocessors_