From 044bc5ebce23fa2d4bdabc34e01ebbb3cb238bca Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Wed, 25 Oct 2023 17:05:22 +0200 Subject: [PATCH] speed up sat scheduling code --- ortools/sat/cumulative_energy.cc | 7 ++-- ortools/sat/diffn.cc | 9 ++--- ortools/sat/integer_expr.h | 6 ++-- ortools/sat/intervals.cc | 50 +++++++++++++++++----------- ortools/sat/intervals.h | 43 ++++++++++++++---------- ortools/sat/linear_relaxation.cc | 2 +- ortools/sat/timetable_edgefinding.cc | 10 +++--- 7 files changed, 70 insertions(+), 57 deletions(-) diff --git a/ortools/sat/cumulative_energy.cc b/ortools/sat/cumulative_energy.cc index a90b72b49d..72a82a07a8 100644 --- a/ortools/sat/cumulative_energy.cc +++ b/ortools/sat/cumulative_energy.cc @@ -87,10 +87,9 @@ bool CumulativeEnergyConstraint::Propagate() { bool tree_has_mandatory_intervals = false; // Main loop: insert tasks by increasing end_max, check for overloads. - for (const auto task_time : - ::gtl::reversed_view(helper_->TaskByDecreasingEndMax())) { - const int current_task = task_time.task_index; - const IntegerValue current_end = task_time.time; + const auto by_decreasing_end_max = helper_->TaskByDecreasingEndMax(); + for (const auto [current_task, current_end] : + ::gtl::reversed_view(by_decreasing_end_max)) { if (task_to_start_event_[current_task] == -1) continue; // Add the current task to the tree. diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index 63c616bc17..f6f43f6056 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -32,7 +32,6 @@ #include "ortools/sat/intervals.h" #include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" -#include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/timetable.h" @@ -47,8 +46,7 @@ namespace { // TODO(user): Use the faster variable only version if all expressions reduce // to a single variable? void AddIsEqualToMinOf(IntegerVariable min_var, - const std::vector& exprs, - Model* model) { + absl::Span exprs, Model* model) { std::vector converted; for (const AffineExpression& affine : exprs) { LinearExpression e; @@ -66,8 +64,7 @@ void AddIsEqualToMinOf(IntegerVariable min_var, } void AddIsEqualToMaxOf(IntegerVariable max_var, - const std::vector& exprs, - Model* model) { + absl::Span exprs, Model* model) { std::vector converted; for (const AffineExpression& affine : exprs) { LinearExpression e; @@ -331,7 +328,7 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: // Compute relevant boxes, the one with a mandatory part of y. Because we will // need to sort it this way, we consider them by increasing start max. indexed_boxes_.clear(); - const std::vector& temp = y->TaskByDecreasingStartMax(); + const auto temp = y->TaskByDecreasingStartMax(); for (int i = temp.size(); --i >= 0;) { const int box = temp[i].task_index; // Ignore absent boxes. diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index c3cf2fb0fe..ce8cdf398b 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -139,9 +139,9 @@ class LinearConstraintPropagator : public PropagatorInterface { // (resp. coefficients) contained in the range [0, rev_num_fixed_vars_) of // vars_ (resp. coeffs_) are fixed (resp. belong to fixed variables). const int size_; - std::unique_ptr vars_; - std::unique_ptr coeffs_; - std::unique_ptr max_variations_; + const std::unique_ptr vars_; + const std::unique_ptr coeffs_; + const std::unique_ptr max_variations_; // This is just the negation of the enforcement literal and it never changes. std::vector literal_reason_; diff --git a/ortools/sat/intervals.cc b/ortools/sat/intervals.cc index be796085f9..4cffbfa663 100644 --- a/ortools/sat/intervals.cc +++ b/ortools/sat/intervals.cc @@ -169,9 +169,10 @@ SchedulingConstraintHelper* IntervalsRepository::GetOrCreateHelper( SchedulingDemandHelper* IntervalsRepository::GetOrCreateDemandHelper( SchedulingConstraintHelper* helper, - const std::vector& demands) { + absl::Span demands) { const std::pair> - key = {helper, demands}; + key = {helper, + std::vector(demands.begin(), demands.end())}; const auto it = demand_helper_repository_.find(key); if (it != demand_helper_repository_.end()) return it->second; @@ -194,7 +195,15 @@ SchedulingConstraintHelper::SchedulingConstraintHelper( integer_trail_(model->GetOrCreate()), precedence_relations_(model->GetOrCreate()), precedences_(model->GetOrCreate()), - interval_variables_(tasks) { + interval_variables_(tasks), + capacity_(tasks.size()), + cached_size_min_(new IntegerValue[capacity_]), + cached_start_min_(new IntegerValue[capacity_]), + cached_end_min_(new IntegerValue[capacity_]), + cached_negated_start_max_(new IntegerValue[capacity_]), + cached_negated_end_max_(new IntegerValue[capacity_]), + cached_shifted_start_min_(new IntegerValue[capacity_]), + cached_negated_shifted_end_max_(new IntegerValue[capacity_]) { starts_.clear(); ends_.clear(); minus_ends_.clear(); @@ -228,7 +237,15 @@ SchedulingConstraintHelper::SchedulingConstraintHelper(int num_tasks, : trail_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), precedence_relations_(model->GetOrCreate()), - precedences_(model->GetOrCreate()) { + precedences_(model->GetOrCreate()), + capacity_(num_tasks), + cached_size_min_(new IntegerValue[capacity_]), + cached_start_min_(new IntegerValue[capacity_]), + cached_end_min_(new IntegerValue[capacity_]), + cached_negated_start_max_(new IntegerValue[capacity_]), + cached_negated_end_max_(new IntegerValue[capacity_]), + cached_shifted_start_min_(new IntegerValue[capacity_]), + cached_negated_shifted_end_max_(new IntegerValue[capacity_]) { starts_.resize(num_tasks); CHECK_EQ(NumTasks(), num_tasks); } @@ -380,13 +397,8 @@ void SchedulingConstraintHelper::InitSortedVectors() { recompute_all_cache_ = true; recompute_cache_.resize(num_tasks, true); - cached_shifted_start_min_.resize(num_tasks); - cached_negated_shifted_end_max_.resize(num_tasks); - cached_size_min_.resize(num_tasks); - cached_start_min_.resize(num_tasks); - cached_end_min_.resize(num_tasks); - cached_negated_start_max_.resize(num_tasks); - cached_negated_end_max_.resize(num_tasks); + // Make sure all the cached_* arrays can hold enough data. + CHECK_LE(num_tasks, capacity_); task_by_increasing_start_min_.resize(num_tasks); task_by_increasing_end_min_.resize(num_tasks); @@ -491,7 +503,7 @@ void SchedulingConstraintHelper::AddLevelZeroPrecedence(int a, int b) { } } -const std::vector& +absl::Span SchedulingConstraintHelper::TaskByIncreasingStartMin() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -503,7 +515,7 @@ SchedulingConstraintHelper::TaskByIncreasingStartMin() { return task_by_increasing_start_min_; } -const std::vector& +absl::Span SchedulingConstraintHelper::TaskByIncreasingEndMin() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -515,7 +527,7 @@ SchedulingConstraintHelper::TaskByIncreasingEndMin() { return task_by_increasing_end_min_; } -const std::vector& +absl::Span SchedulingConstraintHelper::TaskByDecreasingStartMax() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -528,7 +540,7 @@ SchedulingConstraintHelper::TaskByDecreasingStartMax() { return task_by_decreasing_start_max_; } -const std::vector& +absl::Span SchedulingConstraintHelper::TaskByDecreasingEndMax() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -540,7 +552,7 @@ SchedulingConstraintHelper::TaskByDecreasingEndMax() { return task_by_decreasing_end_max_; } -const std::vector& +absl::Span SchedulingConstraintHelper::TaskByIncreasingShiftedStartMin() { if (recompute_shifted_start_min_) { recompute_shifted_start_min_ = false; @@ -822,13 +834,13 @@ IntegerValue ComputeEnergyMinInWindow( } SchedulingDemandHelper::SchedulingDemandHelper( - std::vector demands, SchedulingConstraintHelper* helper, - Model* model) + absl::Span demands, + SchedulingConstraintHelper* helper, Model* model) : integer_trail_(model->GetOrCreate()), product_decomposer_(model->GetOrCreate()), sat_solver_(model->GetOrCreate()), assignment_(model->GetOrCreate()->Assignment()), - demands_(std::move(demands)), + demands_(demands.begin(), demands.end()), helper_(helper) { const int num_tasks = helper->NumTasks(); linearized_energies_.resize(num_tasks); diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index 5428e10c19..cf7bdb0d7c 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -167,7 +167,7 @@ class IntervalsRepository { // demands must be the compatible. SchedulingDemandHelper* GetOrCreateDemandHelper( SchedulingConstraintHelper* helper, - const std::vector& demands); + absl::Span demands); // Calls InitDecomposedEnergies on all SchedulingDemandHelper created. void InitAllDecomposedEnergies(); @@ -362,11 +362,11 @@ class SchedulingConstraintHelper : public PropagatorInterface, // // TODO(user): we could merge the first loop of IncrementalSort() with the // loop that fill TaskTime.time at each call. - const std::vector& TaskByIncreasingStartMin(); - const std::vector& TaskByIncreasingEndMin(); - const std::vector& TaskByDecreasingStartMax(); - const std::vector& TaskByDecreasingEndMax(); - const std::vector& TaskByIncreasingShiftedStartMin(); + absl::Span TaskByIncreasingStartMin(); + absl::Span TaskByIncreasingEndMin(); + absl::Span TaskByDecreasingStartMax(); + absl::Span TaskByDecreasingEndMax(); + absl::Span TaskByIncreasingShiftedStartMin(); // Returns a sorted vector where each task appear twice, the first occurrence // is at size (end_min - size_min) and the second one at (end_min). @@ -434,12 +434,13 @@ class SchedulingConstraintHelper : public PropagatorInterface, IntegerLiteral lit); // Returns the underlying affine expressions. - const std::vector& IntervalVariables() const { + absl::Span IntervalVariables() const { return interval_variables_; } - const std::vector& Starts() const { return starts_; } - const std::vector& Ends() const { return ends_; } - const std::vector& Sizes() const { return sizes_; } + absl::Span Starts() const { return starts_; } + absl::Span Ends() const { return ends_; } + absl::Span Sizes() const { return sizes_; } + Literal PresenceLiteral(int index) const { DCHECK(IsOptional(index)); return Literal(reason_for_presence_[index]); @@ -528,13 +529,19 @@ class SchedulingConstraintHelper : public PropagatorInterface, int previous_level_ = 0; // The caches of all relevant interval values. - std::vector cached_size_min_; - std::vector cached_start_min_; - std::vector cached_end_min_; - std::vector cached_negated_start_max_; - std::vector cached_negated_end_max_; - std::vector cached_shifted_start_min_; - std::vector cached_negated_shifted_end_max_; + // These are initially of size capacity and never resized. + // + // TODO(user): Because of std::swap() in SetTimeDirection, we cannot mark + // most of them as "const" and as a result we loose some performance since + // the address need to be re-fetched on most access. + const int capacity_; + const std::unique_ptr cached_size_min_; + std::unique_ptr cached_start_min_; + std::unique_ptr cached_end_min_; + std::unique_ptr cached_negated_start_max_; + std::unique_ptr cached_negated_end_max_; + std::unique_ptr cached_shifted_start_min_; + std::unique_ptr cached_negated_shifted_end_max_; // Sorted vectors returned by the TasksBy*() functions. std::vector task_by_increasing_start_min_; @@ -581,7 +588,7 @@ class SchedulingDemandHelper { public: // Hack: this can be called with and empty demand vector as long as // OverrideEnergies() is called to define the energies. - SchedulingDemandHelper(std::vector demands, + SchedulingDemandHelper(absl::Span demands, SchedulingConstraintHelper* helper, Model* model); // When defined, the interval will consume this much demand during its whole diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 5deae946a7..8ba9863beb 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -688,7 +688,7 @@ std::optional DetectMakespanFromPrecedences( const SchedulingConstraintHelper& helper, Model* model) { if (helper.NumTasks() == 0) return {}; - const std::vector& ends = helper.Ends(); + const absl::Span ends = helper.Ends(); std::vector end_vars; for (const AffineExpression& end : ends) { // TODO(user): Deal with constant end. diff --git a/ortools/sat/timetable_edgefinding.cc b/ortools/sat/timetable_edgefinding.cc index 2bb8397ace..a7177f2e60 100644 --- a/ortools/sat/timetable_edgefinding.cc +++ b/ortools/sat/timetable_edgefinding.cc @@ -68,8 +68,8 @@ void TimeTableEdgeFinding::BuildTimeTable() { ecp_.clear(); // Build start of compulsory part events. - for (const auto task_time : - ::gtl::reversed_view(helper_->TaskByDecreasingStartMax())) { + const auto by_decreasing_start_max = helper_->TaskByDecreasingStartMax(); + for (const auto task_time : ::gtl::reversed_view(by_decreasing_start_max)) { const int t = task_time.task_index; if (!helper_->IsPresent(t)) continue; if (task_time.time < helper_->EndMin(t)) { @@ -88,10 +88,8 @@ void TimeTableEdgeFinding::BuildTimeTable() { DCHECK_EQ(scp_.size(), ecp_.size()); - const std::vector& by_decreasing_end_max = - helper_->TaskByDecreasingEndMax(); - const std::vector& by_start_min = - helper_->TaskByIncreasingStartMin(); + const auto by_decreasing_end_max = helper_->TaskByDecreasingEndMax(); + const auto by_start_min = helper_->TaskByIncreasingStartMin(); IntegerValue height = IntegerValue(0); IntegerValue energy = IntegerValue(0);