diff --git a/cmake/python.cmake b/cmake/python.cmake index c843e6baee..a72a2df5c3 100644 --- a/cmake/python.cmake +++ b/cmake/python.cmake @@ -387,7 +387,7 @@ if(BUILD_MATH_OPT) endif() file(COPY ortools/sat/python/cp_model.py - ortools/sat/python/cp_model_helper.py + ortools/sat/python/cp_model_numbers.py DESTINATION ${PYTHON_PROJECT_DIR}/sat/python) file(COPY ortools/sat/colab/flags.py @@ -637,7 +637,7 @@ add_custom_command( $,copy,true> $<$:$> ${PYTHON_PROJECT}/pdlp/python COMMAND ${CMAKE_COMMAND} -E copy - $ ${PYTHON_PROJECT}/sat/python + $ ${PYTHON_PROJECT}/sat/python COMMAND ${CMAKE_COMMAND} -E copy $ ${PYTHON_PROJECT}/scheduling/python COMMAND ${CMAKE_COMMAND} -E copy @@ -657,7 +657,7 @@ add_custom_command( model_builder_helper_pybind11 math_opt_pybind11 $ - swig_helper_pybind11 + cp_model_helper_pybind11 rcpsp_pybind11 sorted_interval_list_pybind11 WORKING_DIRECTORY python @@ -694,7 +694,7 @@ add_custom_command( COMMAND ${stubgen_EXECUTABLE} -p pybind11_abseil.status --output . COMMAND ${stubgen_EXECUTABLE} -p ortools.math_opt.core.python.solver --output . COMMAND ${stubgen_EXECUTABLE} -p ortools.pdlp.python.pdlp --output . - COMMAND ${stubgen_EXECUTABLE} -p ortools.sat.python.swig_helper --output . + COMMAND ${stubgen_EXECUTABLE} -p ortools.sat.python.cp_model_helper --output . COMMAND ${stubgen_EXECUTABLE} -p ortools.scheduling.python.rcpsp --output . COMMAND ${stubgen_EXECUTABLE} -p ortools.util.python.sorted_interval_list --output . COMMAND ${CMAKE_COMMAND} -E touch ${PROJECT_BINARY_DIR}/python/stub_timestamp diff --git a/cmake/samples/python/sample.py b/cmake/samples/python/sample.py index daa81e54a3..758346bcf1 100644 --- a/cmake/samples/python/sample.py +++ b/cmake/samples/python/sample.py @@ -21,7 +21,7 @@ from ortools.constraint_solver import pywrapcp # from ortools.graph.python import min_cost_flow from ortools.linear_solver import pywraplp # from ortools.linear_solver import linear_solver_pb2 -# from ortools.sat.python import swig_helper +# from ortools.sat.python import cp_model_helper # from ortools.sat.python import cp_model # from ortools.scheduling import rcpsp # from ortools.util.python import sorted_interval_list diff --git a/examples/cpp/dimacs_assignment.cc b/examples/cpp/dimacs_assignment.cc index 3130643b8c..bcf43712ab 100644 --- a/examples/cpp/dimacs_assignment.cc +++ b/examples/cpp/dimacs_assignment.cc @@ -103,10 +103,7 @@ CostValue BuildAndSolveHungarianInstance( template void DisplayAssignment(const LinearSumAssignment& assignment) { - for (typename LinearSumAssignment::BipartiteLeftNodeIterator - node_it(assignment); - node_it.Ok(); node_it.Next()) { - const NodeIndex left_node = node_it.Index(); + for (const auto left_node : assignment.BipartiteLeftNodes()) { const ArcIndex matching_arc = assignment.GetAssignmentArc(left_node); const NodeIndex right_node = assignment.Head(matching_arc); VLOG(5) << "assigned (" << left_node << ", " << right_node diff --git a/examples/cpp/flow_api.cc b/examples/cpp/flow_api.cc index 007f0ca24c..5be499e739 100644 --- a/examples/cpp/flow_api.cc +++ b/examples/cpp/flow_api.cc @@ -56,7 +56,7 @@ void MinCostFlowOn4x4Matrix() { min_cost_flow.SetNodeSupply(kNumSources + target, -1); } CHECK(min_cost_flow.Solve()); - CHECK_EQ(MinCostFlow::OPTIMAL, min_cost_flow.status()); + CHECK_EQ(GenericMinCostFlow::OPTIMAL, min_cost_flow.status()); CostValue total_flow_cost = min_cost_flow.GetOptimalCost(); CHECK_EQ(kExpectedCost, total_flow_cost); } diff --git a/examples/cpp/min_cost_flow.cc b/examples/cpp/min_cost_flow.cc index 26ba8dd3d6..92b0cb4a9a 100644 --- a/examples/cpp/min_cost_flow.cc +++ b/examples/cpp/min_cost_flow.cc @@ -22,15 +22,16 @@ namespace operations_research { struct Arc { - std::pair nodes; - FlowQuantity capacity; - FlowQuantity unit_cost; + std::pair nodes; + SimpleMinCostFlow::FlowQuantity capacity; + SimpleMinCostFlow::FlowQuantity unit_cost; }; void SolveMinCostFlow() { // Define supply of each node. - const std::vector > supplies = { - {0, 20}, {1, 0}, {2, 0}, {3, -5}, {4, -15}}; + const std::vector< + std::pair > + supplies = {{0, 20}, {1, 0}, {2, 0}, {3, -5}, {4, -15}}; // Define each arc // Can't use std::tuple @@ -58,7 +59,7 @@ void SolveMinCostFlow() { if (status != SimpleMinCostFlow::OPTIMAL) { LOG(FATAL) << "Solving the max flow is not optimal!"; } - FlowQuantity total_flow_cost = min_cost_flow.OptimalCost(); + SimpleMinCostFlow::FlowQuantity total_flow_cost = min_cost_flow.OptimalCost(); LOG(INFO) << "Minimum cost flow: " << total_flow_cost; LOG(INFO) << ""; LOG(INFO) << "Arc : Flow / Capacity / Cost"; diff --git a/examples/cpp/print_dimacs_assignment.h b/examples/cpp/print_dimacs_assignment.h index 8d08963403..93154b3be5 100644 --- a/examples/cpp/print_dimacs_assignment.h +++ b/examples/cpp/print_dimacs_assignment.h @@ -46,10 +46,9 @@ void PrintDimacsAssignmentProblem( absl::StrFormat("p asn %d %d\n", graph.num_nodes(), graph.num_arcs()); CHECK_OK(file::WriteString(output, output_line, file::Defaults())); - for (typename LinearSumAssignment::BipartiteLeftNodeIterator - node_it(assignment); - node_it.Ok(); node_it.Next()) { - output_line = absl::StrFormat("n %d\n", node_it.Index() + 1); + for (const typename GraphType::NodeIndex left_node : + assignment.BipartiteLeftNodes()) { + output_line = absl::StrFormat("n %d\n", left_node + 1); CHECK_OK(file::WriteString(output, output_line, file::Defaults())); } diff --git a/ortools/base/BUILD.bazel b/ortools/base/BUILD.bazel index 08fb3ba1e4..327a9aa70c 100644 --- a/ortools/base/BUILD.bazel +++ b/ortools/base/BUILD.bazel @@ -215,6 +215,19 @@ cc_library( ], ) +cc_library( + name = "fuzztest", + testonly = 1, + hdrs = ["fuzztest.h"], + deps = [ + "@com_google_absl//absl/log:check", + "@com_google_fuzztest//fuzztest", + "@com_google_fuzztest//fuzztest:googletest_fixture_adapter", + "@com_google_fuzztest//fuzztest:init_fuzztest", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "status_matchers", hdrs = ["status_matchers.h"], @@ -309,6 +322,31 @@ cc_library( deps = [":base"], ) +cc_library( + name = "constant_divisor", + srcs = ["constant_divisor.cc"], + hdrs = ["constant_divisor.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:int128", + ], +) + +cc_test( + name = "constant_divisor_test", + srcs = ["constant_divisor_test.cc"], + deps = [ + ":constant_divisor", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@com_google_absl//absl/random:distributions", + "@com_google_benchmark//:benchmark", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "linked_hash_map", hdrs = ["linked_hash_map.h"], diff --git a/ortools/base/constant_divisor.cc b/ortools/base/constant_divisor.cc new file mode 100644 index 0000000000..cfc859fed6 --- /dev/null +++ b/ortools/base/constant_divisor.cc @@ -0,0 +1,56 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/base/constant_divisor.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/numeric/int128.h" + +namespace util { +namespace math { + +// Fast div/mod implementation based on +// "Faster Remainder by Direct Computation: Applications to Compilers and +// Software Libraries" Daniel Lemire, Owen Kaser, Nathan Kurz arXiv:1902.01961 +ConstantDivisor::ConstantDivisor(value_type d) + : ConstantDivisorBase((absl::Uint128Max() / d) + 1, d) { + CHECK_GT(d, 1) << "ConstantDivisor only supports denominators > 1."; +} + +// If we hardcode shift_amount to 32, the 32-bit formula is: +// magic_number = 2 ^ 64 / d +// value / d = value * magic_number >> 64 +// +// One caveat is that for d == 1, magic_number takes 65 bits overflowing a +// uint64_t. So, we again disallow inputs with d == 1. +ConstantDivisor::ConstantDivisor(value_type d) + : ConstantDivisorBase((std::numeric_limits::max() / d) + 1, + d) { + CHECK_GT(d, 1) << "ConstantDivisor only supports denominators > 1."; +} + +ConstantDivisor::ConstantDivisor(value_type d) + : ConstantDivisorBase((MagicValueType{1} << kShift) / d + 1, d) { + CHECK_GT(d, 0); +} + +ConstantDivisor::ConstantDivisor(value_type d) + : ConstantDivisorBase((MagicValueType{1} << kShift) / d + 1, d) { + CHECK_GT(d, 0); +} + +} // namespace math +} // namespace util diff --git a/ortools/base/constant_divisor.h b/ortools/base/constant_divisor.h new file mode 100644 index 0000000000..36629cf617 --- /dev/null +++ b/ortools/base/constant_divisor.h @@ -0,0 +1,194 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_BASE_CONSTANT_DIVISOR_H_ +#define OR_TOOLS_BASE_CONSTANT_DIVISOR_H_ + +// Provides faster division in situations where the same divisor is used +// repeatedly but is not known at compile time. For example, a hash table might +// not be sized until the model is loaded, but once loaded it is not resized for +// the life of the model. As this is not a compile time constant, the compiler +// can not optimize away the division. +// +// However the cost of precomputing coefficients when the hash table is sized +// can be dwarfed by the cycles saved avoiding hardware division on every +// lookup. See benchmark section below to estimate the breakeven point. This +// reduces the CPU penalty of non-power of two sized hash tables, bloom filters, +// etc. +// +// +// The following template and specializations are defined in this file: +// - ConstantDivisor +// - ConstantDivisor +// - ConstantDivisor +// - ConstantDivisor (Only supports denominators > 1) +// - ConstantDivisor (Only supports denominators > 1) +// +// Fast div/mod implementation based on +// "Faster Remainder by Direct Computation: Applications to Compilers and +// Software Libraries" Daniel Lemire, Owen Kaser, Nathan Kurz arXiv:1902.01961 +// +// Usage: +// uint64_t n; // Not known at compile time! +// ConstantDivisor divisor(n); +// uint64_t m = ...; +// EXPECT_EQ(m / n, divisor.div(m)); +// EXPECT_EQ(m % n, divisor.mod(m)); +// + +#include + +#include "absl/numeric/int128.h" + +namespace util { +namespace math { + +template +class ConstantDivisor { + public: + typedef T value_type; + + explicit ConstantDivisor(value_type denominator) + : denominator_(denominator) {} + + value_type div(value_type n) const { return n / denominator_; } + + value_type mod(value_type n) const { return n % denominator_; } + + friend value_type operator/(value_type a, const ConstantDivisor& b) { + return b.div(a); + } + + friend value_type operator%(value_type a, const ConstantDivisor& b) { + return b.mod(a); + } + + private: + value_type denominator_; +}; + +namespace internal { + +// Common code for all specializations. +template +class ConstantDivisorBase { + public: + using value_type = T; + + explicit ConstantDivisorBase(MagicT magic, value_type denominator) + : magic_(magic), denominator_(denominator) {} + + value_type mod(value_type numerator) const { + return numerator - + static_cast(this)->div(numerator) * denominator_; + } + + friend value_type operator/(value_type a, const Impl& b) { return b.div(a); } + + friend value_type operator%(value_type a, const Impl& b) { return b.mod(a); } + + value_type denominator() const { return denominator_; } + + protected: + using MagicValueType = MagicT; + static_assert(sizeof(MagicT) >= 2 * sizeof(value_type)); + MagicT magic_; + + private: + value_type denominator_; +}; +} // namespace internal + +// Division and modulus using uint64_t numerators and denominators. +template <> +class ConstantDivisor + : public internal::ConstantDivisorBase> { + public: + // REQUIRES: denominator > 1 + explicit ConstantDivisor(value_type denominator); + + value_type div(value_type numerator) const { + return MultiplyHi(magic_, numerator); + } + + private: + static uint64_t MultiplyHi(absl::uint128 a, uint64_t b) { + absl::uint128 lo(absl::Uint128Low64(a)); + absl::uint128 hi(absl::Uint128High64(a)); + absl::uint128 bottom = (lo * b) >> 64; + absl::uint128 top = (hi * b); + return absl::Uint128High64(bottom + top); + } +}; + +// Division and modulus using uint32_t numerators and denominators. +template <> +class ConstantDivisor + : public internal::ConstantDivisorBase> { + public: + using value_type = uint32_t; + + // REQUIRES: denominator > 1 + explicit ConstantDivisor(value_type denominator); + + value_type div(value_type numerator) const { + return absl::Uint128High64(absl::uint128(numerator) * magic_); + } +}; + +// Division and modulus using uint16_t numerators and denominators. +template <> +class ConstantDivisor + : public internal::ConstantDivisorBase> { + public: + using value_type = uint16_t; + + explicit ConstantDivisor(value_type denominator); + + value_type div(value_type numerator) const { + return static_cast( + (magic_ * static_cast(numerator)) >> kShift); + } + + private: + // Any value in [32;48] works here. + static constexpr MagicValueType kShift = 32; +}; + +// Division and modulus using uint8_t numerators and denominators. +template <> +class ConstantDivisor + : public internal::ConstantDivisorBase> { + public: + using value_type = uint8_t; + + explicit ConstantDivisor(value_type denominator); + + value_type div(value_type numerator) const { + return static_cast( + (magic_ * static_cast(numerator)) >> kShift); + } + + private: + // Any value in [16;24] works here. + static constexpr MagicValueType kShift = 16; +}; + +} // namespace math +} // namespace util + +#endif // OR_TOOLS_BASE_CONSTANT_DIVISOR_H_ diff --git a/ortools/base/constant_divisor_test.cc b/ortools/base/constant_divisor_test.cc new file mode 100644 index 0000000000..eaa1687480 --- /dev/null +++ b/ortools/base/constant_divisor_test.cc @@ -0,0 +1,264 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/base/constant_divisor.h" + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" + +ABSL_FLAG(int32_t, random_iterations, 100000, + "Number of iterations for ConstantDivisorTest::RandomCases."); + +namespace util { +namespace math { +namespace { + +template +class NativeDivisor { + public: + typedef T value_type; + + explicit NativeDivisor(T denominator) : denominator_(denominator) {} + + T div(T value) const { return value / denominator_; } + + T mod(T value) const { return value % denominator_; } + + friend value_type operator/(value_type a, const NativeDivisor& b) { + return b.div(a); + } + + friend value_type operator%(value_type a, const NativeDivisor& b) { + return b.mod(a); + } + + private: + const T denominator_; +}; + +template +class ConstantDivisorTest : public ::testing::Test {}; + +// Currently there is no specialization for int, so it should default to the +// builtin version. +TEST(ConstantDivisorTemplateTest, Simple) { + ConstantDivisor divisor(3); + EXPECT_EQ(4, divisor.div(12)); + EXPECT_EQ(1, divisor.mod(13)); + EXPECT_EQ(4, 12 / divisor); + EXPECT_EQ(1, 13 % divisor); +} + +TEST(ConstantDivisorUint64Test, Bugs) { + // If forumula (27) from p231 is ever implemented, these divisors will break + // if a >= is accidentally used instead of >. + EXPECT_EQ(uint64_t{828560257293048160}, + ConstantDivisor(21).div(uint64_t{17399765403154011380u})); + EXPECT_EQ(uint64_t{185733693349184273}, + ConstantDivisor(99).div(uint64_t{18387635641569243125u})); +} + +TEST(ConstantDivisorUint16Test, Supports1) { + ConstantDivisor divisor(1); + ASSERT_EQ(42, 42 / divisor); + ASSERT_EQ(0, 42 % divisor); +} + +TEST(ConstantDivisorUint8Test, Exhaustive) { + // This is cheap, so test all values. + for (int denominator = 1; denominator < 256; ++denominator) { + ConstantDivisor divisor(denominator); + for (int value = 0; value < 256; ++value) { + ASSERT_EQ(value / denominator, divisor.div(value)) + << "denominator: " << denominator << " value: " << value; + ASSERT_EQ(value % denominator, divisor.mod(value)) + << "denominator: " << denominator << " value: " << value; + } + } +} + +typedef ::testing::Types, ConstantDivisor, + ConstantDivisor, NativeDivisor, + NativeDivisor, NativeDivisor > + Divisors; +TYPED_TEST_SUITE(ConstantDivisorTest, Divisors); + +TYPED_TEST(ConstantDivisorTest, Simple) { + TypeParam divisor(3); + EXPECT_EQ(4, divisor.div(12)); + EXPECT_EQ(1, divisor.mod(13)); + EXPECT_EQ(4, 12 / divisor); + EXPECT_EQ(1, 13 % divisor); +} + +TYPED_TEST(ConstantDivisorTest, CornerCases) { + EXPECT_EQ(1, TypeParam(5).div(5)); + EXPECT_EQ(2, TypeParam(2).div(4)); + if constexpr (sizeof(typename TypeParam::value_type) >= sizeof(uint16_t)) { + EXPECT_EQ(100, TypeParam(5).div(500)); + } + const auto kTypeMax = + std::numeric_limits::max(); + if constexpr (sizeof(typename TypeParam::value_type) >= sizeof(uint16_t)) { + EXPECT_EQ(kTypeMax / 345, TypeParam(345).div(kTypeMax)); + } + EXPECT_EQ(1, TypeParam(kTypeMax).div(kTypeMax)); + EXPECT_EQ(1, TypeParam(kTypeMax - 1).div(kTypeMax)); + EXPECT_EQ(0, TypeParam(kTypeMax).div((kTypeMax - 1))); +} + +TYPED_TEST(ConstantDivisorTest, Bugs) { + if constexpr (sizeof(typename TypeParam::value_type) < sizeof(uint32_t)) { + GTEST_SKIP() << "This test is only for 32-bit and above."; + } else { + // Cases that triggered bugs found during initial implementation. + EXPECT_EQ(0, TypeParam(2969932030).div(265448460)); + EXPECT_EQ(2, TypeParam(978790915).div(2489284541)); + EXPECT_EQ(1, TypeParam(4113163180).div(4220126436)); + EXPECT_EQ(2072455839, TypeParam(2).div(4144911678)); + } +} + +// Choose a random value of type T, biased towards smaller values. +template +T ChooseValue(absl::BitGenRef gen) { + return absl::Uniform(gen, 0, std::numeric_limits::max()) >> + absl::Uniform(gen, 0, 8 * sizeof(T)); +} + +TYPED_TEST(ConstantDivisorTest, RandomCases) { + typedef typename TypeParam::value_type T; + absl::BitGen gen; + for (int i = 0; i < absl::GetFlag(FLAGS_random_iterations); ++i) { + T denominator = std::max(2, ChooseValue(gen)); + T value = ChooseValue(gen); + TypeParam divisor(denominator); + ASSERT_EQ(value / denominator, divisor.div(value)) + << value << " / " << denominator; + ASSERT_EQ(value % denominator, divisor.mod(value)); + } +} + +// Gives a sense of benchmark overhead. +class NoopDivisor { + public: + typedef uint32_t value_type; + + explicit NoopDivisor(uint32_t) {} + + uint32_t div(uint32_t value) const { return value; } + + uint32_t mod(uint32_t value) const { return value; } +}; + +// Choose a random denominator which is supported by all our implementations, +// biased towards smaller denominators for uint64_t/uint32_t/uint16_t. +template +T ChooseDenominator(absl::BitGenRef random) { + return std::max(uint8_t{2}, ChooseValue(random)); +} + +template +void BM_Divide(benchmark::State& state) { + typedef typename Divisor::value_type T; + absl::BitGen gen; + std::vector values; + for (int i = 0; i < 100000; ++i) { + values.push_back(ChooseValue(gen)); + } + + for (auto _ : state) { + state.PauseTiming(); + Divisor divisor(ChooseDenominator(gen)); + state.ResumeTiming(); + for (T value : values) { + benchmark::DoNotOptimize(divisor.div(value)); + } + } +} +BENCHMARK_TEMPLATE(BM_Divide, NoopDivisor); +BENCHMARK_TEMPLATE(BM_Divide, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Divide, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_Divide, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Divide, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_Divide, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Divide, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_Divide, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Divide, ConstantDivisor); + +template +void BM_Modulo(benchmark::State& state) { + typedef typename Divisor::value_type T; + absl::BitGen gen; + std::vector values; + for (int i = 0; i < 100000; ++i) { + values.push_back(ChooseValue(gen)); + } + + for (auto _ : state) { + state.PauseTiming(); + Divisor divisor(ChooseDenominator(gen)); + state.ResumeTiming(); + for (T value : values) { + benchmark::DoNotOptimize(divisor.mod(value)); + } + } +} +BENCHMARK_TEMPLATE(BM_Modulo, NoopDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, NativeDivisor); +BENCHMARK_TEMPLATE(BM_Modulo, ConstantDivisor); + +template +void BM_ConstructDivisor(benchmark::State& state) { + typedef typename Divisor::value_type T; + absl::BitGen gen; + std::vector values; + for (int i = 0; i < 2048; ++i) { + values.push_back(ChooseDenominator(gen)); + } + + int mask = values.size() - 1; + int i = 0; + for (auto _ : state) { + Divisor divisor(values[i & mask]); + benchmark::DoNotOptimize(divisor.div(values[(i + 1) & mask])); + i++; + } +} +BENCHMARK_TEMPLATE(BM_ConstructDivisor, NoopDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, NativeDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, NativeDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, NativeDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, ConstantDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, NativeDivisor); +BENCHMARK_TEMPLATE(BM_ConstructDivisor, ConstantDivisor); + +} // namespace +} // namespace math +} // namespace util diff --git a/ortools/base/fuzztest.h b/ortools/base/fuzztest.h new file mode 100644 index 0000000000..8c7c992aa1 --- /dev/null +++ b/ortools/base/fuzztest.h @@ -0,0 +1,55 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_BASE_FUZZTEST_H_ +#define OR_TOOLS_BASE_FUZZTEST_H_ + +#include +#include +#include +#include +#include + +#include "fuzztest/domain.h" +#include "fuzztest/fuzztest.h" +#include "fuzztest/googletest_fixture_adapter.h" +#include "fuzztest/init_fuzztest.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "ortools/base/logging.h" + +namespace fuzztest { + +// Reads protos from directory and returns a vector usable by the .WithSeeds() +// function to aid in fuzz test migrations. `is_text_format` should be true iff +// the protos are in text format. +template +std::vector> ReadFilesFromDirectory( + std::string_view dir) { + std::vector> corpus; + + for (std::tuple& proto_tuple : ReadFilesFromDirectory(dir)) { + std::string text_proto = std::get<0>(proto_tuple); + ProtoType proto; + bool was_parsed = + google::protobuf::TextFormat::ParseFromString(text_proto, &proto); + if (was_parsed) { + corpus.push_back(std::make_tuple(proto)); + } + } + return corpus; +} + +} // namespace fuzztest + +#endif // OR_TOOLS_BASE_FUZZTEST_H_ diff --git a/ortools/graph/BUILD.bazel b/ortools/graph/BUILD.bazel index 2ea9056b9b..35525277bf 100644 --- a/ortools/graph/BUILD.bazel +++ b/ortools/graph/BUILD.bazel @@ -32,22 +32,13 @@ config_setting( constraint_values = ["@platforms//os:windows"], ) -# Main Target -cc_library( - name = "graphs", - hdrs = ["graphs.h"], - deps = [ - ":ebert_graph", - ":graph", - ], -) - cc_library( name = "graph", hdrs = ["graph.h"], deps = [ ":iterators", "//ortools/base", + "//ortools/base:constant_divisor", "@com_google_absl//absl/debugging:leak_check", "@com_google_absl//absl/types:span", ], @@ -101,7 +92,6 @@ cc_library( srcs = ["minimum_vertex_cover.cc"], hdrs = ["minimum_vertex_cover.h"], deps = [ - ":ebert_graph", ":max_flow", "@com_google_absl//absl/log:check", ], @@ -350,28 +340,6 @@ cc_test( cc_library( name = "ebert_graph", hdrs = ["ebert_graph.h"], - deps = [ - ":iterators", - "//ortools/base", - "//ortools/util:permutation", - "//ortools/util:zvector", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "ebert_graph_test", - size = "small", - srcs = ["ebert_graph_test.cc"], - deps = [ - ":ebert_graph", - "//ortools/base:gmock_main", - "//ortools/util:permutation", - "@com_google_absl//absl/random:distributions", - "@com_google_absl//absl/strings", - "@com_google_benchmark//:benchmark", - ], ) cc_library( @@ -398,7 +366,6 @@ cc_test( srcs = ["shortest_paths_test.cc"], tags = ["noasan"], # Times out occasionally in ASAN mode. deps = [ - ":ebert_graph", ":graph", ":shortest_paths", ":strongly_connected_components", @@ -471,7 +438,6 @@ cc_test( srcs = ["max_flow_test.cc"], data = ["//ortools/graph/testdata:max_flow_test1.pb.txt"], deps = [ - ":ebert_graph", ":flow_problem_cc_proto", ":max_flow", "//ortools/base:gmock_main", @@ -487,7 +453,6 @@ cc_library( deps = [ ":ebert_graph", ":flow_problem_cc_proto", - ":graphs", "//ortools/base", "//ortools/util:stats", "//ortools/util:zvector", @@ -504,7 +469,6 @@ cc_test( ":ebert_graph", ":generic_max_flow", ":graph", - ":graphs", "//ortools/base", "//ortools/base:gmock_main", "//ortools/linear_solver", @@ -528,10 +492,8 @@ cc_library( "//conditions:default": [], }), deps = [ - ":ebert_graph", ":generic_max_flow", ":graph", - ":graphs", "//ortools/base:mathutil", "//ortools/util:saturated_arithmetic", "//ortools/util:stats", @@ -551,7 +513,6 @@ cc_test( srcs = ["min_cost_flow_test.cc"], deps = [ ":ebert_graph", - ":graphs", ":min_cost_flow", "//ortools/base:gmock_main", "@com_google_absl//absl/log", @@ -634,6 +595,7 @@ cc_library( hdrs = ["linear_assignment.h"], deps = [ ":ebert_graph", + ":iterators", "//ortools/base", "//ortools/util:permutation", "//ortools/util:zvector", @@ -647,11 +609,11 @@ cc_test( size = "small", srcs = ["linear_assignment_test.cc"], deps = [ - ":ebert_graph", ":graph", ":linear_assignment", - "//ortools/base", "//ortools/base:gmock_main", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/random:distributions", "@com_google_absl//absl/types:span", "@com_google_benchmark//:benchmark", @@ -681,7 +643,6 @@ cc_library( srcs = ["dag_shortest_path.cc"], hdrs = ["dag_shortest_path.h"], deps = [ - ":ebert_graph", ":graph", ":topologicalsorter", "@com_google_absl//absl/algorithm:container", diff --git a/ortools/graph/csharp/graph.i b/ortools/graph/csharp/graph.i index d033db2894..d51b07c3b5 100644 --- a/ortools/graph/csharp/graph.i +++ b/ortools/graph/csharp/graph.i @@ -27,8 +27,6 @@ %include "ortools/base/base.i" -%import "ortools/graph/ebert_graph.h" - %{ #include "ortools/graph/assignment.h" #include "ortools/graph/max_flow.h" @@ -89,6 +87,7 @@ %unignore operations_research::SimpleMinCostFlow::SimpleMinCostFlow; %unignore operations_research::SimpleMinCostFlow::~SimpleMinCostFlow; %unignore operations_research::SimpleMinCostFlow::AddArcWithCapacityAndUnitCost; +%unignore operations_research::SimpleMinCostFlow::SetArcCapacity; %unignore operations_research::SimpleMinCostFlow::SetNodeSupply; %unignore operations_research::SimpleMinCostFlow::Solve; %unignore operations_research::SimpleMinCostFlow::SolveMaxFlowWithMinCost; diff --git a/ortools/graph/dag_shortest_path.cc b/ortools/graph/dag_shortest_path.cc index e80891d722..f8c15379b4 100644 --- a/ortools/graph/dag_shortest_path.cc +++ b/ortools/graph/dag_shortest_path.cc @@ -19,7 +19,6 @@ #include "absl/log/check.h" #include "absl/types/span.h" -#include "ortools/graph/ebert_graph.h" #include "ortools/graph/graph.h" #include "ortools/graph/topologicalsorter.h" @@ -109,7 +108,7 @@ std::vector KShortestPathsOnDag( const ShortestPathOnDagProblem problem = ReadProblem(num_nodes, arcs_with_length); - KShortestPathsOnDagWrapper> shortest_paths_on_dag( + KShortestPathsOnDagWrapper shortest_paths_on_dag( &problem.graph, &problem.arc_lengths, problem.topological_order, path_count); shortest_paths_on_dag.RunKShortestPathOnDag({source}); @@ -119,9 +118,9 @@ std::vector KShortestPathsOnDag( } std::vector lengths = shortest_paths_on_dag.LengthsTo(destination); - std::vector> arc_paths = + std::vector> arc_paths = shortest_paths_on_dag.ArcPathsTo(destination); - std::vector> node_paths = + std::vector> node_paths = shortest_paths_on_dag.NodePathsTo(destination); std::vector paths; paths.reserve(lengths.size()); diff --git a/ortools/graph/ebert_graph.h b/ortools/graph/ebert_graph.h index 57fc1a39eb..ce226e765e 100644 --- a/ortools/graph/ebert_graph.h +++ b/ortools/graph/ebert_graph.h @@ -11,1714 +11,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +// DEPRECATED: Use the graph types in //util/graph instead. + #ifndef OR_TOOLS_GRAPH_EBERT_GRAPH_H_ #define OR_TOOLS_GRAPH_EBERT_GRAPH_H_ -// A few variations on a theme of the "star" graph representation by -// Ebert, as described in J. Ebert, "A versatile data structure for -// edge-oriented graph algorithms." Communications of the ACM -// 30(6):513-519 (June 1987). -// http://portal.acm.org/citation.cfm?id=214769 -// -// In this file there are three representations that have much in -// common. The general one, called simply EbertGraph, contains both -// forward- and backward-star representations. The other, called -// ForwardEbertGraph, contains only the forward-star representation of -// the graph, and is appropriate for applications where the reverse -// arcs are not needed. -// -// The point of including all the representations in this one file is -// to capitalize, where possible, on the commonalities among them, and -// those commonalities are mostly factored out into base classes as -// described below. Despite the commonalities, however, each of the -// three representations presents a somewhat different interface -// because of their different underlying semantics. -// -// Many clients are expected to use the interfaces to the graph -// objects directly, but some clients are parameterized by graph type -// and need a consistent interface for their underlying graph -// objects. For such clients, a small library of class templates is -// provided to give a consistent interface to clients where the -// underlying graph interfaces differ. Examples are the -// AnnotatedGraphBuildManager<> template, which provides a uniform -// interface for building the various types of graphs; and the -// TailArrayManager<> template, which provides a uniform interface for -// applications that need to map from arc indices to arc tail nodes, -// accounting for the fact that such a mapping has to be requested -// explicitly from the ForwardStarGraph representation. -// -// There are two base class templates, StarGraphBase, and -// EbertGraphBase; their purpose is to hold methods and data -// structures that are in common among their descendants. Only classes -// that are leaves in the following hierarchy tree are eligible for -// free-standing instantiation and use by clients. The parentheses -// around StarGraphBase and EbertGraphBase indicate that they should -// not normally be instantiated by clients: -// -// (StarGraphBase) | -// / | -// / | -// / | -// / | -// (EbertGraphBase) | -// / \ | -// / \ | -// EbertGraph ForwardEbertGraph | -// -// In the general EbertGraph case, the graph is represented with three -// arrays. -// Let n be the number of nodes and m be the number of arcs. -// Let i be an integer in [0..m-1], denoting the index of an arc. -// * head_[i] contains the end-node of arc i, -// * head_[-i-1] contains the start-node of arc i. -// Note that in two's-complement arithmetic, -i-1 = ~i. -// Consequently: -// * head_[~i] contains the end-node of the arc reverse to arc i, -// * head_[i] contains the start-node of the arc reverse to arc i. -// Note that if arc (u, v) is defined, then the data structure also stores -// (v, u). -// Arc ~i thus denotes the arc reverse to arc i. -// This is what makes this representation useful for undirected graphs and for -// implementing algorithms like bidirectional shortest paths. -// Also note that the representation handles multi-graphs. If several arcs -// going from node u to node v are added to the graph, they will be handled as -// separate arcs. -// -// Now, for an integer u in [0..n-1] denoting the index of a node: -// * first_incident_arc_[u] denotes the first arc in the adjacency list of u. -// * going from an arc i, the adjacency list can be traversed using -// j = next_adjacent_arc_[i]. -// -// The EbertGraph implementation has the following benefits: -// * It is able to handle both directed or undirected graphs. -// * Being based on indices, it is easily serializable. Only the contents -// of the head_ array need to be stored. Even so, serialization is -// currently not implemented. -// * The node indices and arc indices can be stored in 32 bits, while -// still allowing to go a bit further than the 4-gigabyte -// limitation (48 gigabytes for a pure graph, without capacities or -// costs.) -// * The representation can be recomputed if edges have been loaded from -// * The representation can be recomputed if edges have been loaded from -// external memory or if edges have been re-ordered. -// * The memory consumption is: 2 * m * sizeof(NodeIndexType) -// + 2 * m * sizeof(ArcIndexType) -// + n * sizeof(ArcIndexType) -// plus a small constant. -// -// The EbertGraph implementation differs from the implementation described in -// [Ebert 1987] in the following respects: -// * arcs are represented using an (i, ~i) approach, whereas Ebert used -// (i, -i). Indices for direct arcs thus start at 0, in a fashion that is -// compatible with the index numbering in C and C++. Note that we also tested -// a (2*i, 2*i+1) storage pattern, which did not show any speed benefit, and -// made the use of the API much more difficult. -// * because of this, the 'nil' values for nodes and arcs are not 0, as Ebert -// first described. The value for the 'nil' node is set to -1, while the -// value for the 'nil' arc is set to the smallest integer representable with -// ArcIndexSize bytes. -// * it is possible to add arcs to the graph, with AddArc, in a much simpler -// way than described by Ebert. -// * TODO(user) although it is already possible, using the -// GroupForwardArcsByFunctor method, to group all the outgoing (resp. -// incoming) arcs of a node, the iterator logic could still be improved to -// allow traversing the outgoing (resp. incoming) arcs in O(out_degree(node)) -// (resp. O(in_degree(node))) instead of O(degree(node)). -// * TODO(user) it is possible to implement arc deletion and garbage collection -// in an efficient (relatively) manner. For the time being we haven't seen an -// application for this. -// -// The ForwardEbertGraph representation is like the EbertGraph case described -// above, with the following modifications: -// * The part of the head_[] array with negative indices is absent. In its -// place is a pointer tail_ which, if assigned, points to an array of tail -// nodes indexed by (nonnegative) arc index. In typical usage tail_ is NULL -// and the memory for the tail nodes need not be allocated. -// * The array of arc tails can be allocated as needed and populated from the -// adjacency lists of the graph. -// * Representing only the forward star of each node implies that the graph -// cannot be serialized directly nor rebuilt from scratch from just the head_ -// array. Rebuilding from scratch requires constructing the array of arc -// tails from the adjacency lists first, and serialization can be done either -// by first constructing the array of arc tails from the adjacency lists, or -// by serializing directly from the adjacency lists. -// * The memory consumption is: m * sizeof(NodeIndexType) -// + m * sizeof(ArcIndexType) -// + n * sizeof(ArcIndexType) -// plus a small constant when the array of arc tails is absent. Allocating -// the arc tail array adds another m * sizeof(NodeIndexType). - -#include -#include #include -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/strings/str_cat.h" -#include "ortools/base/logging.h" -#include "ortools/graph/iterators.h" -#include "ortools/util/permutation.h" -#include "ortools/util/zvector.h" namespace operations_research { -// Forward declarations. -template -class EbertGraph; -template -class ForwardEbertGraph; - -// Standard instantiation of ForwardEbertGraph (named 'ForwardStarGraph') of -// EbertGraph (named 'StarGraph'); and relevant type shortcuts. Unless their use -// cases prevent them from doing so, users are encouraged to use StarGraph or -// ForwardStarGraph according to whether or not they require reverse arcs to be -// represented explicitly. Along with either graph representation, the other -// type shortcuts here will often come in handy. +// DEPRECATED: Global node and arc types for graphs. This have been retired in +// favor of parameterizing the graph types in //util/graph. typedef int32_t NodeIndex; typedef int32_t ArcIndex; + +// DEPRECATED: Global types for flow algorithms. Thes have been retired in favor +// of directly parameterizing those algorithms. typedef int64_t FlowQuantity; typedef int64_t CostValue; -typedef EbertGraph StarGraph; -typedef ForwardEbertGraph ForwardStarGraph; - -// Adapt our old iteration style to support range-based for loops. Add typedefs -// required by std::iterator_traits. -#define DEFINE_STL_ITERATOR_FUNCTIONS(iterator_class_name) \ - using iterator_category = std::input_iterator_tag; \ - using difference_type = ptrdiff_t; \ - using pointer = const ArcIndexType*; \ - using value_type = ArcIndexType; \ - using reference = value_type; \ - bool operator!=(const iterator_class_name& other) const { \ - return this->arc_ != other.arc_; \ - } \ - bool operator==(const iterator_class_name& other) const { \ - return this->arc_ == other.arc_; \ - } \ - ArcIndexType operator*() const { return this->arc_; } \ - void operator++() { this->Next(); } - -template -class StarGraphBase { - public: - // The index of the 'nil' node in the graph. - static const NodeIndexType kNilNode; - - // The index of the 'nil' arc in the graph. - static const ArcIndexType kNilArc; - - // The index of the first node in the graph. - static const NodeIndexType kFirstNode; - - // The index of the first arc in the graph. - static const ArcIndexType kFirstArc; - - // The maximum possible number of nodes in the graph. (The maximum - // index is kMaxNumNodes-1, since indices start at 0. Unfortunately - // we waste a value representing this and the max_num_nodes_ member.) - static const NodeIndexType kMaxNumNodes; - - // The maximum possible number of arcs in the graph. (The maximum - // index is kMaxNumArcs-1, since indices start at 0. Unfortunately - // we waste a value representing this and the max_num_arcs_ member.) - static const ArcIndexType kMaxNumArcs; - // Returns the number of nodes in the graph. - NodeIndexType num_nodes() const { return num_nodes_; } - - // Returns the number of original arcs in the graph - // (The ones with positive indices.) - ArcIndexType num_arcs() const { return num_arcs_; } - - // Returns one more than the largest index of an extant node, - // meaning a node that is mentioned as the head or tail of some arc - // in the graph. To be used as a helper when clients need to - // dimension or iterate over arrays of node annotation information. - NodeIndexType end_node_index() const { return kFirstNode + num_nodes_; } - - // Returns one more than the largest index of an extant direct - // arc. To be used as a helper when clients need to dimension or - // iterate over arrays of arc annotation information. - ArcIndexType end_arc_index() const { return kFirstArc + num_arcs_; } - - // Returns the maximum possible number of nodes in the graph. - NodeIndexType max_num_nodes() const { return max_num_nodes_; } - - // Returns the maximum possible number of original arcs in the graph. - // (The ones with positive indices.) - ArcIndexType max_num_arcs() const { return max_num_arcs_; } - - // Returns one more than the largest valid index of a node. To be - // used as a helper when clients need to dimension or iterate over - // arrays of node annotation information. - NodeIndexType max_end_node_index() const { - return kFirstNode + max_num_nodes_; - } - - // Returns one more than the largest valid index of a direct arc. To - // be used as a helper when clients need to dimension or iterate - // over arrays of arc annotation information. - ArcIndexType max_end_arc_index() const { return kFirstArc + max_num_arcs_; } - - // Utility function to check that a node index is within the bounds AND - // different from kNilNode. - // Returns true if node is in the range [kFirstNode .. max_num_nodes_). - // It is exported so that users of the DerivedGraph class can use it. - // To be used in a DCHECK; also used internally to validate - // arguments passed to our methods from clients (e.g., AddArc()). - bool IsNodeValid(NodeIndexType node) const { - return node >= kFirstNode && node < max_num_nodes_; - } - - // Returns the first arc going from tail to head, if it exists, or kNilArc - // if such an arc does not exist. - ArcIndexType LookUpArc(const NodeIndexType tail, - const NodeIndexType head) const { - for (ArcIndexType arc = FirstOutgoingArc(tail); arc != kNilArc; - arc = ThisAsDerived()->NextOutgoingArc(tail, arc)) { - if (Head(arc) == head) { - return arc; - } - } - return kNilArc; - } - - // Returns the head or end-node of arc. - NodeIndexType Head(const ArcIndexType arc) const { - DCHECK(ThisAsDerived()->CheckArcValidity(arc)); - return head_[arc]; - } - - std::string NodeDebugString(const NodeIndexType node) const { - if (node == kNilNode) { - return "NilNode"; - } else { - return absl::StrCat(static_cast(node)); - } - } - - std::string ArcDebugString(const ArcIndexType arc) const { - if (arc == kNilArc) { - return "NilArc"; - } else { - return absl::StrCat(static_cast(arc)); - } - } - -#if !defined(SWIG) - // Iterator class for traversing all the nodes in the graph. - class NodeIterator { - public: - explicit NodeIterator(const DerivedGraph& graph) - : graph_(graph), head_(graph_.StartNode(kFirstNode)) {} - - // Returns true unless all the nodes have been traversed. - bool Ok() const { return head_ != kNilNode; } - - // Advances the current node index. - void Next() { head_ = graph_.NextNode(head_); } - - // Returns the index of the node currently pointed to by the iterator. - NodeIndexType Index() const { return head_; } - - private: - // A reference to the current DerivedGraph considered. - const DerivedGraph& graph_; - - // The index of the current node considered. - NodeIndexType head_; - }; - - // Iterator class for traversing the arcs in the graph. - class ArcIterator { - public: - explicit ArcIterator(const DerivedGraph& graph) - : graph_(graph), arc_(graph_.StartArc(kFirstArc)) {} - - // Returns true unless all the arcs have been traversed. - bool Ok() const { return arc_ != kNilArc; } - - // Advances the current arc index. - void Next() { arc_ = graph_.NextArc(arc_); } - - // Returns the index of the arc currently pointed to by the iterator. - ArcIndexType Index() const { return arc_; } - - private: - // A reference to the current DerivedGraph considered. - const DerivedGraph& graph_; - - // The index of the current arc considered. - ArcIndexType arc_; - }; - - // Iterator class for traversing the outgoing arcs associated to a given node. - class OutgoingArcIterator { - public: - OutgoingArcIterator(const DerivedGraph& graph, NodeIndexType node) - : graph_(graph), - node_(graph_.StartNode(node)), - arc_(graph_.StartArc(graph_.FirstOutgoingArc(node))) { - DCHECK(CheckInvariant()); - } - - // This constructor takes an arc as extra argument and makes the iterator - // start at arc. - OutgoingArcIterator(const DerivedGraph& graph, NodeIndexType node, - ArcIndexType arc) - : graph_(graph), - node_(graph_.StartNode(node)), - arc_(graph_.StartArc(arc)) { - DCHECK(CheckInvariant()); - } - - // Can only assign from an iterator on the same graph. - void operator=(const OutgoingArcIterator& iterator) { - DCHECK(&iterator.graph_ == &graph_); - node_ = iterator.node_; - arc_ = iterator.arc_; - } - - // Returns true unless all the outgoing arcs have been traversed. - bool Ok() const { return arc_ != kNilArc; } - - // Advances the current outgoing arc index. - void Next() { - arc_ = graph_.NextOutgoingArc(node_, arc_); - DCHECK(CheckInvariant()); - } - - // Returns the index of the arc currently pointed to by the iterator. - ArcIndexType Index() const { return arc_; } - - DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator); - - private: - // Returns true if the invariant for the iterator is verified. - // To be used in a DCHECK. - bool CheckInvariant() const { - if (arc_ == kNilArc) { - return true; // This occurs when the iterator has reached the end. - } - DCHECK(graph_.IsOutgoing(arc_, node_)); - return true; - } - - // A reference to the current DerivedGraph considered. - const DerivedGraph& graph_; - - // The index of the node on which arcs are iterated. - NodeIndexType node_; - - // The index of the current arc considered. - ArcIndexType arc_; - }; -#endif // SWIG - - protected: - StarGraphBase() - : max_num_nodes_(0), - max_num_arcs_(0), - num_nodes_(0), - num_arcs_(0), - first_incident_arc_() {} - - ~StarGraphBase() {} - - // Returns kNilNode if the graph has no nodes or node if it has at least one - // node. Useful for initializing iterators correctly in the case of empty - // graphs. - NodeIndexType StartNode(NodeIndexType node) const { - return num_nodes_ == 0 ? kNilNode : node; - } - - // Returns kNilArc if the graph has no arcs arc if it has at least one arc. - // Useful for initializing iterators correctly in the case of empty graphs. - ArcIndexType StartArc(ArcIndexType arc) const { - return num_arcs_ == 0 ? kNilArc : arc; - } - - // Returns the node following the argument in the graph. - // Returns kNilNode (= end) if the range of nodes has been exhausted. - // It is called by NodeIterator::Next() and as such does not expect to be - // passed an argument equal to kNilNode. - // This is why the return line is simplified from - // return (node == kNilNode || next_node >= num_nodes_) - // ? kNilNode : next_node; - // to - // return next_node < num_nodes_ ? next_node : kNilNode; - NodeIndexType NextNode(const NodeIndexType node) const { - DCHECK(IsNodeValid(node)); - const NodeIndexType next_node = node + 1; - return next_node < num_nodes_ ? next_node : kNilNode; - } - - // Returns the arc following the argument in the graph. - // Returns kNilArc (= end) if the range of arcs has been exhausted. - // It is called by ArcIterator::Next() and as such does not expect to be - // passed an argument equal to kNilArc. - // This is why the return line is simplified from - // return ( arc == kNilArc || next_arc >= num_arcs_) ? kNilArc : next_arc; - // to - // return next_arc < num_arcs_ ? next_arc : kNilArc; - ArcIndexType NextArc(const ArcIndexType arc) const { - DCHECK(ThisAsDerived()->CheckArcValidity(arc)); - const ArcIndexType next_arc = arc + 1; - return next_arc < num_arcs_ ? next_arc : kNilArc; - } - - // Returns the first outgoing arc for node. - ArcIndexType FirstOutgoingArc(const NodeIndexType node) const { - DCHECK(IsNodeValid(node)); - return ThisAsDerived()->FindNextOutgoingArc( - ThisAsDerived()->FirstOutgoingOrOppositeIncomingArc(node)); - } - - // The maximum number of nodes that the graph can hold. - NodeIndexType max_num_nodes_; - - // The maximum number of arcs that the graph can hold. - ArcIndexType max_num_arcs_; - - // The maximum index of the node currently held by the graph. - NodeIndexType num_nodes_; - - // The current number of arcs held by the graph. - ArcIndexType num_arcs_; - - // Array of node indices. head_[i] contains the head node of arc i. - ZVector head_; - - // Array of arc indices. first_incident_arc_[i] contains the first arc - // incident to node i. - ZVector first_incident_arc_; - - private: - // Shorthand: returns a const DerivedGraph*-typed version of our - // "this" pointer. - inline const DerivedGraph* ThisAsDerived() const { - return static_cast(this); - } - - // Shorthand: returns a DerivedGraph*-typed version of our "this" - // pointer. - inline DerivedGraph* ThisAsDerived() { - return static_cast(this); - } -}; - -// The index of the 'nil' node in the graph. -template -const NodeIndexType - StarGraphBase::kNilNode = -1; - -// The index of the 'nil' arc in the graph. -template -const ArcIndexType - StarGraphBase::kNilArc = - std::numeric_limits::min(); - -// The index of the first node in the graph. -template -const NodeIndexType - StarGraphBase::kFirstNode = 0; - -// The index of the first arc in the graph. -template -const ArcIndexType - StarGraphBase::kFirstArc = 0; - -// The maximum possible node index in the graph. -template -const NodeIndexType - StarGraphBase::kMaxNumNodes = - std::numeric_limits::max(); - -// The maximum possible number of arcs in the graph. -// (The maximum index is kMaxNumArcs-1, since indices start at 0.) -template -const ArcIndexType - StarGraphBase::kMaxNumArcs = - std::numeric_limits::max(); - -// A template for the base class that holds the functionality that exists in -// common between the EbertGraph<> template and the ForwardEbertGraph<> -// template. -// -// This template is for internal use only, and this is enforced by making all -// constructors for this class template protected. Clients should use one of the -// two derived-class templates. Most clients will not even use those directly, -// but will use the StarGraph and ForwardStarGraph typenames declared above. -// -// The DerivedGraph template argument must be the type of the class (typically -// itself built from a template) that: -// 1. implements the full interface expected for either ForwardEbertGraph or -// EbertGraph, and -// 2. inherits from an instance of this template. -// The base class needs access to some members of the derived class such as, for -// example, NextOutgoingArc(), and it gets this access via the DerivedGraph -// template argument. -template -class EbertGraphBase - : public StarGraphBase { - typedef StarGraphBase Base; - friend class StarGraphBase; - - protected: - using Base::first_incident_arc_; - using Base::head_; - using Base::max_num_arcs_; - using Base::max_num_nodes_; - using Base::num_arcs_; - using Base::num_nodes_; - - public: -#if !SWIG - using Base::end_arc_index; - using Base::IsNodeValid; - - using Base::kFirstArc; - using Base::kFirstNode; - using Base::kMaxNumArcs; - using Base::kMaxNumNodes; - using Base::kNilArc; - using Base::kNilNode; -#endif // SWIG - - // Reserves memory needed for max_num_nodes nodes and max_num_arcs arcs. - // Returns false if the parameters passed are not OK. - // It can be used to enlarge the graph, but does not shrink memory - // if called with smaller values. - bool Reserve(NodeIndexType new_max_num_nodes, ArcIndexType new_max_num_arcs) { - if (new_max_num_nodes < 0 || new_max_num_nodes > kMaxNumNodes) { - return false; - } - if (new_max_num_arcs < 0 || new_max_num_arcs > kMaxNumArcs) { - return false; - } - first_incident_arc_.Reserve(kFirstNode, new_max_num_nodes - 1); - for (NodeIndexType node = max_num_nodes_; - node <= first_incident_arc_.max_index(); ++node) { - first_incident_arc_.Set(node, kNilArc); - } - ThisAsDerived()->ReserveInternal(new_max_num_nodes, new_max_num_arcs); - max_num_nodes_ = new_max_num_nodes; - max_num_arcs_ = new_max_num_arcs; - return true; - } - - // Adds an arc to the graph and returns its index. - // Returns kNilArc if the arc could not be added. - // Note that for a given pair (tail, head) AddArc does not overwrite an - // already-existing arc between tail and head: Another arc is created - // instead. This makes it possible to handle multi-graphs. - ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head) { - if (num_arcs_ >= max_num_arcs_ || !IsNodeValid(tail) || - !IsNodeValid(head)) { - return kNilArc; - } - if (tail + 1 > num_nodes_) { - num_nodes_ = tail + 1; // max does not work on int16_t. - } - if (head + 1 > num_nodes_) { - num_nodes_ = head + 1; - } - ArcIndexType arc = num_arcs_; - ++num_arcs_; - ThisAsDerived()->RecordArc(arc, tail, head); - return arc; - } - -// TODO(user): Configure SWIG to handle the GroupForwardArcsByFunctor -// member template and the CycleHandlerForAnnotatedArcs class. -#if !SWIG - template - void GroupForwardArcsByFunctor( - const ArcIndexTypeStrictWeakOrderingFunctor& compare, - PermutationCycleHandler* annotation_handler) { - std::unique_ptr arc_permutation( - new ArcIndexType[end_arc_index()]); - - // Determine the permutation that groups arcs by their tail nodes. - for (ArcIndexType i = 0; i < end_arc_index(); ++i) { - // Start with the identity permutation. - arc_permutation[i] = i; - } - std::sort(&arc_permutation[kFirstArc], &arc_permutation[end_arc_index()], - compare); - - // Now we actually permute the head_ array and the - // scaled_arc_cost_ array according to the sorting permutation. - CycleHandlerForAnnotatedArcs cycle_handler(annotation_handler, - ThisAsDerived()); - PermutationApplier permutation(&cycle_handler); - permutation.Apply(&arc_permutation[0], kFirstArc, end_arc_index()); - - // Finally, rebuild the graph from its permuted head_ array. - ThisAsDerived()->BuildRepresentation(); - } - - class CycleHandlerForAnnotatedArcs - : public PermutationCycleHandler { - public: - CycleHandlerForAnnotatedArcs( - PermutationCycleHandler* annotation_handler, - DerivedGraph* graph) - : annotation_handler_(annotation_handler), - graph_(graph), - head_temp_(kNilNode), - tail_temp_(kNilNode) {} - - // This type is neither copyable nor movable. - CycleHandlerForAnnotatedArcs(const CycleHandlerForAnnotatedArcs&) = delete; - CycleHandlerForAnnotatedArcs& operator=( - const CycleHandlerForAnnotatedArcs&) = delete; - - void SetTempFromIndex(ArcIndexType source) override { - if (annotation_handler_ != nullptr) { - annotation_handler_->SetTempFromIndex(source); - } - head_temp_ = graph_->Head(source); - tail_temp_ = graph_->Tail(source); - } - - void SetIndexFromIndex(ArcIndexType source, - ArcIndexType destination) const override { - if (annotation_handler_ != nullptr) { - annotation_handler_->SetIndexFromIndex(source, destination); - } - graph_->SetHead(destination, graph_->Head(source)); - graph_->SetTail(destination, graph_->Tail(source)); - } - - void SetIndexFromTemp(ArcIndexType destination) const override { - if (annotation_handler_ != nullptr) { - annotation_handler_->SetIndexFromTemp(destination); - } - graph_->SetHead(destination, head_temp_); - graph_->SetTail(destination, tail_temp_); - } - - // Since we are free to destroy the permutation array we use the - // kNilArc value to mark entries in the array that have been - // processed already. There is no need to be able to recover the - // original permutation array entries once they have been seen. - void SetSeen(ArcIndexType* permutation_element) const override { - *permutation_element = kNilArc; - } - - bool Unseen(ArcIndexType permutation_element) const override { - return permutation_element != kNilArc; - } - - ~CycleHandlerForAnnotatedArcs() override {} - - private: - PermutationCycleHandler* annotation_handler_; - DerivedGraph* graph_; - NodeIndexType head_temp_; - NodeIndexType tail_temp_; - }; -#endif // SWIG - - // Using the SetHead() method implies that the BuildRepresentation() - // method must be called to restore consistency before the graph is - // used. - // - // Visible for testing. - void SetHead(const ArcIndexType arc, const NodeIndexType head) { - representation_clean_ = false; - head_.Set(arc, head); - } - - protected: - EbertGraphBase() : next_adjacent_arc_(), representation_clean_(true) {} - - ~EbertGraphBase() {} - - void Initialize(NodeIndexType max_num_nodes, ArcIndexType max_num_arcs) { - if (!Reserve(max_num_nodes, max_num_arcs)) { - LOG(DFATAL) << "Could not reserve memory for " - << static_cast(max_num_nodes) << " nodes and " - << static_cast(max_num_arcs) << " arcs."; - } - first_incident_arc_.SetAll(kNilArc); - ThisAsDerived()->InitializeInternal(max_num_nodes, max_num_arcs); - } - - // Returns the first arc in node's incidence list. - ArcIndexType FirstOutgoingOrOppositeIncomingArc( - const NodeIndexType node) const { - DCHECK(representation_clean_); - DCHECK(IsNodeValid(node)); - return first_incident_arc_[node]; - } - - // Returns the next arc following the passed argument in its adjacency list. - ArcIndexType NextAdjacentArc(const ArcIndexType arc) const { - DCHECK(representation_clean_); - DCHECK(ThisAsDerived()->CheckArcValidity(arc)); - return next_adjacent_arc_[arc]; - } - - // Returns the outgoing arc following the argument in the adjacency list. - ArcIndexType NextOutgoingArc(const NodeIndexType unused_node, - const ArcIndexType arc) const { - DCHECK(ThisAsDerived()->CheckArcValidity(arc)); - DCHECK(ThisAsDerived()->IsDirect(arc)); - return ThisAsDerived()->FindNextOutgoingArc(NextAdjacentArc(arc)); - } - - // Array of next indices. - // next_adjacent_arc_[i] contains the next arc in the adjacency list of arc i. - ZVector next_adjacent_arc_; - - // Flag to indicate that BuildRepresentation() needs to be called - // before the adjacency lists are examined. Only for DCHECK in debug - // builds. - bool representation_clean_; - - private: - // Shorthand: returns a const DerivedGraph*-typed version of our - // "this" pointer. - inline const DerivedGraph* ThisAsDerived() const { - return static_cast(this); - } - - // Shorthand: returns a DerivedGraph*-typed version of our "this" - // pointer. - inline DerivedGraph* ThisAsDerived() { - return static_cast(this); - } - - void InitializeInternal(NodeIndexType max_num_nodes, - ArcIndexType max_num_arcs) { - next_adjacent_arc_.SetAll(kNilArc); - } - - bool RepresentationClean() const { return representation_clean_; } -}; - -// Most users should only use StarGraph, which is EbertGraph, -// and other type shortcuts; see the bottom of this file. -template -class ABSL_DEPRECATED("Use `::util::ListGraph<>` instead.") EbertGraph - : public EbertGraphBase > { - typedef EbertGraphBase > - Base; - friend class EbertGraphBase >; - friend class StarGraphBase >; - - using Base::ArcDebugString; - using Base::FirstOutgoingOrOppositeIncomingArc; - using Base::Initialize; - using Base::NextAdjacentArc; - using Base::NodeDebugString; - - using Base::first_incident_arc_; - using Base::head_; - using Base::max_num_arcs_; - using Base::max_num_nodes_; - using Base::next_adjacent_arc_; - using Base::num_arcs_; - using Base::num_nodes_; - using Base::representation_clean_; - - public: -#if !SWIG - using Base::Head; - using Base::IsNodeValid; - - using Base::kFirstArc; - using Base::kFirstNode; - using Base::kNilArc; - using Base::kNilNode; -#endif // SWIG - - typedef NodeIndexType NodeIndex; - typedef ArcIndexType ArcIndex; - static constexpr bool kHasNegativeReverseArcs = true; - - EbertGraph() {} - - EbertGraph(NodeIndexType max_num_nodes, ArcIndexType max_num_arcs) { - Initialize(max_num_nodes, max_num_arcs); - } - - ~EbertGraph() {} - -#if !SWIG - // Iterator class for traversing the arcs incident to a given node in the - // graph. - class OutgoingOrOppositeIncomingArcIterator { - public: - OutgoingOrOppositeIncomingArcIterator(const EbertGraph& graph, - NodeIndexType node) - : graph_(graph), - node_(graph_.StartNode(node)), - arc_(graph_.StartArc( - graph_.FirstOutgoingOrOppositeIncomingArc(node))) { - DCHECK(CheckInvariant()); - } - - // This constructor takes an arc as extra argument and makes the iterator - // start at arc. - OutgoingOrOppositeIncomingArcIterator(const EbertGraph& graph, - NodeIndexType node, ArcIndexType arc) - : graph_(graph), - node_(graph_.StartNode(node)), - arc_(graph_.StartArc(arc)) { - DCHECK(CheckInvariant()); - } - - // Can only assign from an iterator on the same graph. - void operator=(const OutgoingOrOppositeIncomingArcIterator& iterator) { - DCHECK(&iterator.graph_ == &graph_); - node_ = iterator.node_; - arc_ = iterator.arc_; - } - - // Returns true unless all the adjancent arcs have been traversed. - bool Ok() const { return arc_ != kNilArc; } - - // Advances the current adjacent arc index. - void Next() { - arc_ = graph_.NextAdjacentArc(arc_); - DCHECK(CheckInvariant()); - } - - // Returns the index of the arc currently pointed to by the iterator. - ArcIndexType Index() const { return arc_; } - - DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingOrOppositeIncomingArcIterator); - - private: - // Returns true if the invariant for the iterator is verified. - // To be used in a DCHECK. - bool CheckInvariant() const { - if (arc_ == kNilArc) { - return true; // This occurs when the iterator has reached the end. - } - DCHECK(graph_.IsOutgoingOrOppositeIncoming(arc_, node_)); - return true; - } - // A reference to the current EbertGraph considered. - const EbertGraph& graph_; - - // The index of the node on which arcs are iterated. - NodeIndexType node_; - - // The index of the current arc considered. - ArcIndexType arc_; - }; - - // Iterator class for traversing the incoming arcs associated to a given node. - class IncomingArcIterator { - public: - IncomingArcIterator(const EbertGraph& graph, NodeIndexType node) - : graph_(graph), - node_(graph_.StartNode(node)), - arc_(graph_.StartArc(graph_.FirstIncomingArc(node))) { - DCHECK(CheckInvariant()); - } - - // This constructor takes an arc as extra argument and makes the iterator - // start at arc. - IncomingArcIterator(const EbertGraph& graph, NodeIndexType node, - ArcIndexType arc) - : graph_(graph), - node_(graph_.StartNode(node)), - arc_(arc == kNilArc ? kNilArc - : graph_.StartArc(graph_.Opposite(arc))) { - DCHECK(CheckInvariant()); - } - - // Can only assign from an iterator on the same graph. - void operator=(const IncomingArcIterator& iterator) { - DCHECK(&iterator.graph_ == &graph_); - node_ = iterator.node_; - arc_ = iterator.arc_; - } - - // Returns true unless all the incoming arcs have been traversed. - bool Ok() const { return arc_ != kNilArc; } - - // Advances the current incoming arc index. - void Next() { - arc_ = graph_.NextIncomingArc(arc_); - DCHECK(CheckInvariant()); - } - - // Returns the index of the arc currently pointed to by the iterator. - ArcIndexType Index() const { - return arc_ == kNilArc ? kNilArc : graph_.Opposite(arc_); - } - - private: - // Returns true if the invariant for the iterator is verified. - // To be used in a DCHECK. - bool CheckInvariant() const { - if (arc_ == kNilArc) { - return true; // This occurs when the iterator has reached the end. - } - DCHECK(graph_.IsIncoming(Index(), node_)); - return true; - } - // A reference to the current EbertGraph considered. - const EbertGraph& graph_; - - // The index of the node on which arcs are iterated. - NodeIndexType node_; - - // The index of the current arc considered. - ArcIndexType arc_; - }; -#endif // SWIG - - // Minimal change to use StarGraph with the new util/graph.h API. - // EbertGraph is going away, so this is just temporary. - void Build() {} - void Build(std::vector* permutation) { permutation->clear(); } - ArcIndexType OppositeArc(ArcIndex arc) const { return Opposite(arc); } - util::BeginEndWrapper OutgoingArcs( - NodeIndexType node) const { - return util::BeginEndWrapper( - typename Base::OutgoingArcIterator(*this, node), - typename Base::OutgoingArcIterator(*this, node, kNilArc)); - } - util::BeginEndWrapper - OutgoingOrOppositeIncomingArcs(NodeIndexType node) const { - return util::BeginEndWrapper( - OutgoingOrOppositeIncomingArcIterator(*this, node), - OutgoingOrOppositeIncomingArcIterator(*this, node, kNilArc)); - } - util::BeginEndWrapper - OutgoingOrOppositeIncomingArcsStartingFrom(NodeIndexType node, - ArcIndex arc) const { - return util::BeginEndWrapper( - OutgoingOrOppositeIncomingArcIterator(*this, node, arc), - OutgoingOrOppositeIncomingArcIterator(*this, node, kNilArc)); - } - - // Utility function to check that an arc index is within the bounds. - // It is exported so that users of the EbertGraph class can use it. - // To be used in a DCHECK. - bool CheckArcBounds(const ArcIndexType arc) const { - return (arc == kNilArc) || (arc >= -max_num_arcs_ && arc < max_num_arcs_); - } - - // Utility function to check that an arc index is within the bounds AND - // different from kNilArc. - // It is exported so that users of the EbertGraph class can use it. - // To be used in a DCHECK. - bool CheckArcValidity(const ArcIndexType arc) const { - return (arc != kNilArc) && (arc >= -max_num_arcs_ && arc < max_num_arcs_); - } - - // Returns the tail or start-node of arc. - NodeIndexType Tail(const ArcIndexType arc) const { - DCHECK(CheckArcValidity(arc)); - return head_[Opposite(arc)]; - } - - // Returns the tail or start-node of arc if it is positive - // (i.e. it is taken in the direction it was entered in the graph), - // and the head or end-node otherwise. 'This' in Ebert's paper. - NodeIndexType DirectArcTail(const ArcIndexType arc) const { - return Tail(DirectArc(arc)); - } - - // Returns the head or end-node of arc if it is positive - // (i.e. it is taken in the direction it was entered in the graph), - // and the tail or start-node otherwise. 'That' in Ebert's paper. - NodeIndexType DirectArcHead(const ArcIndexType arc) const { - return Head(DirectArc(arc)); - } - - // Returns the arc in normal/direct direction. - ArcIndexType DirectArc(const ArcIndexType arc) const { - DCHECK(CheckArcValidity(arc)); - return std::max(arc, Opposite(arc)); - } - - // Returns the arc in reverse direction. - ArcIndexType ReverseArc(const ArcIndexType arc) const { - DCHECK(CheckArcValidity(arc)); - return std::min(arc, Opposite(arc)); - } - - // Returns the opposite arc, i.e. the direct arc is the arc is in reverse - // direction, and the reverse arc if the arc is direct. - ArcIndexType Opposite(const ArcIndexType arc) const { - const ArcIndexType opposite = ~arc; - DCHECK(CheckArcValidity(arc)); - DCHECK(CheckArcValidity(opposite)); - return opposite; - } - - // Returns true if the arc is direct. - bool IsDirect(const ArcIndexType arc) const { - DCHECK(CheckArcBounds(arc)); - return arc != kNilArc && arc >= 0; - } - - // Returns true if the arc is in the reverse direction. - bool IsReverse(const ArcIndexType arc) const { - DCHECK(CheckArcBounds(arc)); - return arc != kNilArc && arc < 0; - } - - // Returns true if arc is incident to node. - bool IsOutgoingOrOppositeIncoming(ArcIndexType arc, - NodeIndexType node) const { - return Tail(arc) == node; - } - - // Returns true if arc is incoming to node. - bool IsIncoming(ArcIndexType arc, NodeIndexType node) const { - return IsDirect(arc) && Head(arc) == node; - } - - // Returns true if arc is outgoing from node. - bool IsOutgoing(ArcIndexType arc, NodeIndexType node) const { - return IsDirect(arc) && Tail(arc) == node; - } - - // Recreates the next_adjacent_arc_ and first_incident_arc_ variables from - // the array head_ in O(n + m) time. - // This is useful if head_ array has been sorted according to a given - // criterion, for example. - void BuildRepresentation() { - first_incident_arc_.SetAll(kNilArc); - for (ArcIndexType arc = kFirstArc; arc < max_num_arcs_; ++arc) { - Attach(arc); - } - representation_clean_ = true; - } - - // Returns a debug string containing all the information contained in the - // data structure in raw form. - std::string DebugString() const { - DCHECK(representation_clean_); - std::string result = "Arcs:(node, next arc) :\n"; - for (ArcIndexType arc = -num_arcs_; arc < num_arcs_; ++arc) { - result += " " + ArcDebugString(arc) + ":(" + NodeDebugString(head_[arc]) + - "," + ArcDebugString(next_adjacent_arc_[arc]) + ")\n"; - } - result += "Node:First arc :\n"; - for (NodeIndexType node = kFirstNode; node < num_nodes_; ++node) { - result += " " + NodeDebugString(node) + ":" + - ArcDebugString(first_incident_arc_[node]) + "\n"; - } - return result; - } - - private: - // Handles reserving space in the next_adjacent_arc_ and head_ - // arrays, which are always present and are therefore in the base - // class. Although they reside in the base class, those two arrays - // are maintained differently by different derived classes, - // depending on whether the derived class stores reverse arcs. Hence - // the code to set those arrays up is in a method of the derived - // class. - void ReserveInternal(NodeIndexType new_max_num_nodes, - ArcIndexType new_max_num_arcs) { - head_.Reserve(-new_max_num_arcs, new_max_num_arcs - 1); - next_adjacent_arc_.Reserve(-new_max_num_arcs, new_max_num_arcs - 1); - for (ArcIndexType arc = -new_max_num_arcs; arc < -max_num_arcs_; ++arc) { - head_.Set(arc, kNilNode); - next_adjacent_arc_.Set(arc, kNilArc); - } - for (ArcIndexType arc = max_num_arcs_; arc < new_max_num_arcs; ++arc) { - head_.Set(arc, kNilNode); - next_adjacent_arc_.Set(arc, kNilArc); - } - } - - // Returns the first incoming arc for node. - ArcIndexType FirstIncomingArc(const NodeIndexType node) const { - DCHECK_LE(kFirstNode, node); - DCHECK_GE(max_num_nodes_, node); - return FindNextIncomingArc(FirstOutgoingOrOppositeIncomingArc(node)); - } - - // Returns the incoming arc following the argument in the adjacency list. - ArcIndexType NextIncomingArc(const ArcIndexType arc) const { - DCHECK(CheckArcValidity(arc)); - DCHECK(IsReverse(arc)); - return FindNextIncomingArc(NextAdjacentArc(arc)); - } - - // Handles the part of AddArc() that is not in common with other - // graph classes based on the EbertGraphBase template. - void RecordArc(ArcIndexType arc, NodeIndexType tail, NodeIndexType head) { - head_.Set(Opposite(arc), tail); - head_.Set(arc, head); - Attach(arc); - } - - // Using the SetTail() method implies that the BuildRepresentation() - // method must be called to restore consistency before the graph is - // used. - void SetTail(const ArcIndexType arc, const NodeIndexType tail) { - representation_clean_ = false; - head_.Set(Opposite(arc), tail); - } - - // Utility method to attach a new arc. - void Attach(ArcIndexType arc) { - DCHECK(CheckArcValidity(arc)); - const NodeIndexType tail = head_[Opposite(arc)]; - DCHECK(IsNodeValid(tail)); - next_adjacent_arc_.Set(arc, first_incident_arc_[tail]); - first_incident_arc_.Set(tail, arc); - const NodeIndexType head = head_[arc]; - DCHECK(IsNodeValid(head)); - next_adjacent_arc_.Set(Opposite(arc), first_incident_arc_[head]); - first_incident_arc_.Set(head, Opposite(arc)); - } - - // Utility method that finds the next outgoing arc. - ArcIndexType FindNextOutgoingArc(ArcIndexType arc) const { - DCHECK(CheckArcBounds(arc)); - while (IsReverse(arc)) { - arc = NextAdjacentArc(arc); - DCHECK(CheckArcBounds(arc)); - } - return arc; - } - - // Utility method that finds the next incoming arc. - ArcIndexType FindNextIncomingArc(ArcIndexType arc) const { - DCHECK(CheckArcBounds(arc)); - while (IsDirect(arc)) { - arc = NextAdjacentArc(arc); - DCHECK(CheckArcBounds(arc)); - } - return arc; - } -}; - -// A forward-star-only graph representation for greater efficiency in -// those algorithms that don't need reverse arcs. -template -class ABSL_DEPRECATED("Use `::util::ListGraph<>` instead.") ForwardEbertGraph - : public EbertGraphBase > { - typedef EbertGraphBase > - Base; - friend class EbertGraphBase >; - friend class StarGraphBase >; - - using Base::ArcDebugString; - using Base::Initialize; - using Base::NextAdjacentArc; - using Base::NodeDebugString; - - using Base::first_incident_arc_; - using Base::head_; - using Base::max_num_arcs_; - using Base::max_num_nodes_; - using Base::next_adjacent_arc_; - using Base::num_arcs_; - using Base::num_nodes_; - using Base::representation_clean_; - - public: -#if !SWIG - using Base::Head; - using Base::IsNodeValid; - - using Base::kFirstArc; - using Base::kFirstNode; - using Base::kNilArc; - using Base::kNilNode; -#endif // SWIG - - typedef NodeIndexType NodeIndex; - typedef ArcIndexType ArcIndex; - - ForwardEbertGraph() {} - - ForwardEbertGraph(NodeIndexType max_num_nodes, ArcIndexType max_num_arcs) { - Initialize(max_num_nodes, max_num_arcs); - } - - ~ForwardEbertGraph() {} - - // Utility function to check that an arc index is within the bounds. - // It is exported so that users of the ForwardEbertGraph class can use it. - // To be used in a DCHECK. - bool CheckArcBounds(const ArcIndexType arc) const { - return (arc == kNilArc) || (arc >= kFirstArc && arc < max_num_arcs_); - } - - // Utility function to check that an arc index is within the bounds AND - // different from kNilArc. - // It is exported so that users of the ForwardEbertGraph class can use it. - // To be used in a DCHECK. - bool CheckArcValidity(const ArcIndexType arc) const { - return (arc != kNilArc) && (arc >= kFirstArc && arc < max_num_arcs_); - } - - // Returns true if arc is a valid index into the (*tail_) array. - bool CheckTailIndexValidity(const ArcIndexType arc) const { - return (tail_ != nullptr) && (arc >= kFirstArc) && - (arc <= tail_->max_index()); - } - - // Returns the tail or start-node of arc. - NodeIndexType Tail(const ArcIndexType arc) const { - DCHECK(CheckArcValidity(arc)); - DCHECK(CheckTailIndexValidity(arc)); - return (*tail_)[arc]; - } - - // Returns true if arc is incoming to node. - bool IsIncoming(ArcIndexType arc, NodeIndexType node) const { - return IsDirect(arc) && Head(arc) == node; - } - - // Recreates the next_adjacent_arc_ and first_incident_arc_ - // variables from the arrays head_ and tail_ in O(n + m) time. This - // is useful if the head_ and tail_ arrays have been sorted - // according to a given criterion, for example. - void BuildRepresentation() { - first_incident_arc_.SetAll(kNilArc); - DCHECK(TailArrayComplete()); - for (ArcIndexType arc = kFirstArc; arc < max_num_arcs_; ++arc) { - DCHECK(CheckTailIndexValidity(arc)); - Attach((*tail_)[arc], arc); - } - representation_clean_ = true; - } - - bool BuildTailArray() { - // If (*tail_) is already allocated, we have the invariant that - // its contents are canonical, so we do not need to do anything - // here in that case except return true. - if (tail_ == nullptr) { - if (!representation_clean_) { - // We have been asked to build the (*tail_) array, but we have - // no valid information from which to build it. The graph is - // in an unrecoverable, inconsistent state. - return false; - } - // Reallocate (*tail_) and rebuild its contents from the - // adjacency lists. - tail_.reset(new ZVector); - tail_->Reserve(kFirstArc, max_num_arcs_ - 1); - typename Base::NodeIterator node_it(*this); - for (; node_it.Ok(); node_it.Next()) { - NodeIndexType node = node_it.Index(); - typename Base::OutgoingArcIterator arc_it(*this, node); - for (; arc_it.Ok(); arc_it.Next()) { - (*tail_)[arc_it.Index()] = node; - } - } - } - DCHECK(TailArrayComplete()); - return true; - } - - void ReleaseTailArray() { tail_.reset(nullptr); } - - // To be used in a DCHECK(). - bool TailArrayComplete() const { - CHECK(tail_); - for (ArcIndexType arc = kFirstArc; arc < num_arcs_; ++arc) { - CHECK(CheckTailIndexValidity(arc)); - CHECK(IsNodeValid((*tail_)[arc])); - } - return true; - } - - // Returns a debug string containing all the information contained in the - // data structure in raw form. - std::string DebugString() const { - DCHECK(representation_clean_); - std::string result = "Arcs:(node, next arc) :\n"; - for (ArcIndexType arc = kFirstArc; arc < num_arcs_; ++arc) { - result += " " + ArcDebugString(arc) + ":(" + NodeDebugString(head_[arc]) + - "," + ArcDebugString(next_adjacent_arc_[arc]) + ")\n"; - } - result += "Node:First arc :\n"; - for (NodeIndexType node = kFirstNode; node < num_nodes_; ++node) { - result += " " + NodeDebugString(node) + ":" + - ArcDebugString(first_incident_arc_[node]) + "\n"; - } - return result; - } - - private: - // Reserves space for the (*tail_) array. - // - // This method is separate from ReserveInternal() because our - // practice of making the (*tail_) array optional implies that the - // tail_ pointer might not be constructed when the ReserveInternal() - // method is called. Therefore we have this method also, and we - // ensure that it is called only when tail_ is guaranteed to have - // been initialized. - void ReserveTailArray(ArcIndexType new_max_num_arcs) { - if (tail_ != nullptr) { - // The (*tail_) values are already canonical, so we're just - // reserving additional space for new arcs that haven't been - // added yet. - if (tail_->Reserve(kFirstArc, new_max_num_arcs - 1)) { - for (ArcIndexType arc = tail_->max_index() + 1; arc < new_max_num_arcs; - ++arc) { - tail_->Set(arc, kNilNode); - } - } - } - } - - // Reserves space for the arrays indexed by arc indices, except - // (*tail_) even if it is present. We cannot grow the (*tail_) array - // in this method because this method is called from - // Base::Reserve(), which in turn is called from the base template - // class constructor. That base class constructor is called on *this - // before tail_ is constructed. Hence when this method is called, - // tail_ might contain garbage. This method can safely refer only to - // fields of the base template class, not to fields of *this outside - // the base template class. - // - // The strange situation in which this method of a derived class can - // refer only to members of the base class arises because different - // derived classes use the data members of the base class in - // slightly different ways. The purpose of this derived class - // method, then, is only to encode the derived-class-specific - // conventions for how the derived class uses the data members of - // the base class. - // - // To be specific, the forward-star graph representation, lacking - // reverse arcs, allocates only the positive index range for the - // head_ and next_adjacent_arc_ arrays, while the general - // representation allocates space for both positive- and - // negative-indexed arcs (i.e., both forward and reverse arcs). - void ReserveInternal(NodeIndexType new_max_num_nodes, - ArcIndexType new_max_num_arcs) { - head_.Reserve(kFirstArc, new_max_num_arcs - 1); - next_adjacent_arc_.Reserve(kFirstArc, new_max_num_arcs - 1); - for (ArcIndexType arc = max_num_arcs_; arc < new_max_num_arcs; ++arc) { - head_.Set(arc, kNilNode); - next_adjacent_arc_.Set(arc, kNilArc); - } - ReserveTailArray(new_max_num_arcs); - } - - // Handles the part of AddArc() that is not in common wth other - // graph classes based on the EbertGraphBase template. - void RecordArc(ArcIndexType arc, NodeIndexType tail, NodeIndexType head) { - head_.Set(arc, head); - Attach(tail, arc); - } - - // Using the SetTail() method implies that the BuildRepresentation() - // method must be called to restore consistency before the graph is - // used. - void SetTail(const ArcIndexType arc, const NodeIndexType tail) { - DCHECK(CheckTailIndexValidity(arc)); - CHECK(tail_); - representation_clean_ = false; - tail_->Set(arc, tail); - } - - // Utility method to attach a new arc. - void Attach(NodeIndexType tail, ArcIndexType arc) { - DCHECK(CheckArcValidity(arc)); - DCHECK(IsNodeValid(tail)); - next_adjacent_arc_.Set(arc, first_incident_arc_[tail]); - first_incident_arc_.Set(tail, arc); - const NodeIndexType head = head_[arc]; - DCHECK(IsNodeValid(head)); - // Because Attach() is a public method, keeping (*tail_) canonical - // requires us to record the new arc's tail here. - if (tail_ != nullptr) { - DCHECK(CheckTailIndexValidity(arc)); - tail_->Set(arc, tail); - } - } - - // Utility method that finds the next outgoing arc. - ArcIndexType FindNextOutgoingArc(ArcIndexType arc) const { - DCHECK(CheckArcBounds(arc)); - return arc; - } - - private: - // Always returns true because for any ForwardEbertGraph, only - // direct arcs are represented, so all valid arc indices refer to - // arcs that are outgoing from their tail nodes. - bool IsOutgoing(const ArcIndex unused_arc, - const NodeIndex unused_node) const { - return true; - } - - // Always returns true because for any ForwardEbertGraph, only - // outgoing arcs are represented, so all valid arc indices refer to - // direct arcs. - bool IsDirect(const ArcIndex unused_arc) const { return true; } - - // Array of node indices, not always present. (*tail_)[i] contains - // the tail node of arc i. This array is not needed for normal graph - // traversal operations, but is used in optimizing the graph's - // layout so arcs are grouped by tail node, and can be used in one - // approach to serializing the graph. - // - // Invariants: At any time when we are not executing a method of - // this class, either tail_ == NULL or the tail_ array's contents - // are kept canonical. If tail_ != NULL, any method that modifies - // adjacency lists must also ensure (*tail_) is modified - // correspondingly. The converse does not hold: Modifications to - // (*tail_) are allowed without updating the adjacency lists. If - // such modifications take place, representation_clean_ must be set - // to false, of course, to indicate that the adjacency lists are no - // longer current. - std::unique_ptr > tail_; -}; - -// Traits for EbertGraphBase types, for use in testing and clients -// that work with both forward-only and forward/reverse graphs. -// -// The default is to assume reverse arcs so if someone forgets to -// specialize the traits of a new forward-only graph type, they will -// get errors from tests rather than incomplete testing. -template -struct graph_traits { - static constexpr bool has_reverse_arcs = true; - static constexpr bool is_dynamic = true; -}; - -template -struct graph_traits > { - static constexpr bool has_reverse_arcs = false; - static constexpr bool is_dynamic = true; -}; - -namespace or_internal { - -// The TailArrayBuilder class template is not expected to be used by -// clients. It is a helper for the TailArrayManager template. -// -// The TailArrayBuilder for graphs with reverse arcs does nothing. -template -struct TailArrayBuilder { - explicit TailArrayBuilder(GraphType* unused_graph) {} - - bool BuildTailArray() const { return true; } -}; - -// The TailArrayBuilder for graphs without reverse arcs calls the -// appropriate method on the graph from the TailArrayBuilder -// constructor. -template -struct TailArrayBuilder { - explicit TailArrayBuilder(GraphType* graph) : graph_(graph) {} - - bool BuildTailArray() const { return graph_->BuildTailArray(); } - - GraphType* const graph_; -}; - -// The TailArrayReleaser class template is not expected to be used by -// clients. It is a helper for the TailArrayManager template. -// -// The TailArrayReleaser for graphs with reverse arcs does nothing. -template -struct TailArrayReleaser { - explicit TailArrayReleaser(GraphType* unused_graph) {} - - void ReleaseTailArray() const {} -}; - -// The TailArrayReleaser for graphs without reverse arcs calls the -// appropriate method on the graph from the TailArrayReleaser -// constructor. -template -struct TailArrayReleaser { - explicit TailArrayReleaser(GraphType* graph) : graph_(graph) {} - - void ReleaseTailArray() const { graph_->ReleaseTailArray(); } - - GraphType* const graph_; -}; - -} // namespace or_internal - -template -class TailArrayManager { - public: - explicit TailArrayManager(GraphType* g) : graph_(g) {} - - bool BuildTailArrayFromAdjacencyListsIfForwardGraph() const { - or_internal::TailArrayBuilder::has_reverse_arcs> - tail_array_builder(graph_); - return tail_array_builder.BuildTailArray(); - } - - void ReleaseTailArrayIfForwardGraph() const { - or_internal::TailArrayReleaser::has_reverse_arcs> - tail_array_releaser(graph_); - tail_array_releaser.ReleaseTailArray(); - } - - private: - GraphType* graph_; -}; - -template -class ArcFunctorOrderingByTailAndHead { - public: - explicit ArcFunctorOrderingByTailAndHead(const GraphType& graph) - : graph_(graph) {} - - bool operator()(typename GraphType::ArcIndex a, - typename GraphType::ArcIndex b) const { - return ((graph_.Tail(a) < graph_.Tail(b)) || - ((graph_.Tail(a) == graph_.Tail(b)) && - (graph_.Head(a) < graph_.Head(b)))); - } - - private: - const GraphType& graph_; -}; - -namespace or_internal { - -// The GraphBuilderFromArcs class template is not expected to be used -// by clients. It is a helper for the AnnotatedGraphBuildManager -// template. -// -// Deletes itself upon returning the graph! -template -class GraphBuilderFromArcs { - public: - GraphBuilderFromArcs(typename GraphType::NodeIndex max_num_nodes, - typename GraphType::ArcIndex max_num_arcs, - bool sort_arcs) - : num_arcs_(0), sort_arcs_(sort_arcs) { - Reserve(max_num_nodes, max_num_arcs); - } - - typename GraphType::ArcIndex AddArc(typename GraphType::NodeIndex tail, - typename GraphType::NodeIndex head) { - DCHECK_LT(num_arcs_, max_num_arcs_); - DCHECK_LT(tail, GraphType::kFirstNode + max_num_nodes_); - DCHECK_LT(head, GraphType::kFirstNode + max_num_nodes_); - if (num_arcs_ < max_num_arcs_ && - tail < GraphType::kFirstNode + max_num_nodes_ && - head < GraphType::kFirstNode + max_num_nodes_) { - typename GraphType::ArcIndex result = GraphType::kFirstArc + num_arcs_; - arcs_.push_back(std::make_pair(tail, head)); - num_arcs_ += 1; - return result; - } else { - // Too many arcs or node index out of bounds! - return GraphType::kNilArc; - } - } - - // Builds the graph from the given arcs. - GraphType* Graph(PermutationCycleHandler* - client_cycle_handler) { - GraphType* graph = new GraphType(max_num_nodes_, num_arcs_, sort_arcs_, - &arcs_, client_cycle_handler); - delete this; - return graph; - } - - private: - bool Reserve(typename GraphType::NodeIndex new_max_num_nodes, - typename GraphType::ArcIndex new_max_num_arcs) { - max_num_nodes_ = new_max_num_nodes; - max_num_arcs_ = new_max_num_arcs; - arcs_.reserve(new_max_num_arcs); - return true; - } - - typename GraphType::NodeIndex max_num_nodes_; - typename GraphType::ArcIndex max_num_arcs_; - typename GraphType::ArcIndex num_arcs_; - - std::vector< - std::pair > - arcs_; - - const bool sort_arcs_; -}; - -// Trivial delegating specialization for dynamic graphs. -// -// Deletes itself upon returning the graph! -template -class GraphBuilderFromArcs { - public: - GraphBuilderFromArcs(typename GraphType::NodeIndex max_num_nodes, - typename GraphType::ArcIndex max_num_arcs, - bool sort_arcs) - : graph_(new GraphType(max_num_nodes, max_num_arcs)), - sort_arcs_(sort_arcs) {} - - bool Reserve(const typename GraphType::NodeIndex new_max_num_nodes, - const typename GraphType::ArcIndex new_max_num_arcs) { - return graph_->Reserve(new_max_num_nodes, new_max_num_arcs); - } - - typename GraphType::ArcIndex AddArc( - const typename GraphType::NodeIndex tail, - const typename GraphType::NodeIndex head) { - return graph_->AddArc(tail, head); - } - - GraphType* Graph(PermutationCycleHandler* - client_cycle_handler) { - if (sort_arcs_) { - TailArrayManager tail_array_manager(graph_); - tail_array_manager.BuildTailArrayFromAdjacencyListsIfForwardGraph(); - ArcFunctorOrderingByTailAndHead arc_ordering(*graph_); - graph_->GroupForwardArcsByFunctor(arc_ordering, client_cycle_handler); - tail_array_manager.ReleaseTailArrayIfForwardGraph(); - } - GraphType* result = graph_; - delete this; - return result; - } - - private: - GraphType* const graph_; - const bool sort_arcs_; -}; - -} // namespace or_internal - -template -class AnnotatedGraphBuildManager - : public or_internal::GraphBuilderFromArcs< - GraphType, graph_traits::is_dynamic> { - public: - AnnotatedGraphBuildManager(typename GraphType::NodeIndex num_nodes, - typename GraphType::ArcIndex num_arcs, - bool sort_arcs) - : or_internal::GraphBuilderFromArcs::is_dynamic>( - num_nodes, num_arcs, sort_arcs) {} -}; - -#undef DEFINE_STL_ITERATOR_FUNCTIONS } // namespace operations_research #endif // OR_TOOLS_GRAPH_EBERT_GRAPH_H_ diff --git a/ortools/graph/ebert_graph_test.cc b/ortools/graph/ebert_graph_test.cc deleted file mode 100644 index 9c13432df6..0000000000 --- a/ortools/graph/ebert_graph_test.cc +++ /dev/null @@ -1,1169 +0,0 @@ -// Copyright 2010-2025 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ortools/graph/ebert_graph.h" - -#include -#include -#include - -#include "absl/random/distributions.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "benchmark/benchmark.h" -#include "gtest/gtest.h" -#include "ortools/util/permutation.h" - -namespace operations_research { - -template -std::string Stringify(const GraphType& graph, - typename GraphType::ArcIndex arc) { - return absl::StrCat(" Arc ", arc, ": ", graph.Tail(arc), " -> ", - graph.Head(arc), "\n"); -} - -template -std::string Stringify(const GraphType& graph, - typename GraphType::NodeIndex tail, - typename GraphType::ArcIndex arc) { - return absl::StrCat(" Arc ", arc, ": ", tail, " -> ", graph.Head(arc), - "\n"); -} - -// We have to base the TestEbertGraph utility function template on a -// class/struct template because partial specialization is not allowed for -// function templates, and we use partial specialization to turn off reverse-arc -// testing for graph representations that don't have reverse arcs. -// -// First, we have the complete case for graphs that represent their reverse arcs -// and therefore can be more thoroughly tested. The constructor does all the -// testing work. -template -struct TestEbertGraphRunner { - TestEbertGraphRunner( - const GraphType& graph, absl::string_view expected_graph_arc_list, - absl::string_view expected_adjacency_list, - absl::string_view expected_incoming_arc_list, - absl::string_view expected_outgoing_arc_list, - absl::string_view expected_debug_string, - absl::string_view unused_expected_forward_debug_string, - absl::string_view unused_expected_forward_static_debug_string) { - std::string graph_arc_list = ""; - for (typename GraphType::ArcIterator arc_it(graph); arc_it.Ok(); - arc_it.Next()) { - typename GraphType::ArcIndex arc = arc_it.Index(); - absl::StrAppend(&graph_arc_list, Stringify(graph, arc)); - EXPECT_EQ(graph.DirectArc(arc), graph.Opposite(graph.ReverseArc(arc))); - } - EXPECT_EQ(expected_graph_arc_list, graph_arc_list); - - std::string adjacency_list = ""; - for (typename GraphType::NodeIterator node_it(graph); node_it.Ok(); - node_it.Next()) { - typename GraphType::NodeIndex node = node_it.Index(); - absl::StrAppend(&adjacency_list, " Node ", node, ":\n"); - for (typename GraphType::OutgoingOrOppositeIncomingArcIterator arc_it( - graph, node); - arc_it.Ok(); arc_it.Next()) { - typename GraphType::ArcIndex arc = arc_it.Index(); - EXPECT_TRUE(graph.IsOutgoingOrOppositeIncoming(arc, node)); - absl::StrAppend(&adjacency_list, Stringify(graph, arc)); - EXPECT_EQ(node, graph.Tail(arc)); - } - } - EXPECT_EQ(expected_adjacency_list, adjacency_list); - - std::string incoming_arc_list = ""; - for (typename GraphType::NodeIterator node_it(graph); node_it.Ok(); - node_it.Next()) { - typename GraphType::NodeIndex node = node_it.Index(); - absl::StrAppend(&incoming_arc_list, " Node ", node, ":\n"); - for (typename GraphType::IncomingArcIterator arc_it(graph, node); - arc_it.Ok(); arc_it.Next()) { - typename GraphType::ArcIndex arc = arc_it.Index(); - EXPECT_TRUE(graph.IsIncoming(arc, node)); - // We assume there are no self-loops in the graph. - EXPECT_FALSE(graph.IsOutgoing(arc, node)); - absl::StrAppend(&incoming_arc_list, Stringify(graph, arc)); - EXPECT_FALSE(graph.IsReverse(arc)); - EXPECT_EQ(node, graph.Head(arc)); - } - } - EXPECT_EQ(expected_incoming_arc_list, incoming_arc_list); - - std::string outgoing_arc_list = ""; - for (typename GraphType::NodeIterator node_it(graph); node_it.Ok(); - node_it.Next()) { - typename GraphType::NodeIndex node = node_it.Index(); - absl::StrAppend(&outgoing_arc_list, " Node ", node, ":\n"); - for (typename GraphType::OutgoingArcIterator arc_it(graph, node); - arc_it.Ok(); arc_it.Next()) { - typename GraphType::ArcIndex arc = arc_it.Index(); - // We assume there are no self-loops in the graph. - EXPECT_FALSE(graph.IsIncoming(arc, node)); - EXPECT_TRUE(graph.IsOutgoing(arc, node)); - absl::StrAppend(&outgoing_arc_list, Stringify(graph, arc)); - EXPECT_TRUE(graph.IsDirect(arc)); - EXPECT_EQ(node, graph.Tail(arc)); - EXPECT_EQ(node, graph.DirectArcTail(arc)); - } - } - EXPECT_EQ(expected_outgoing_arc_list, outgoing_arc_list); - EXPECT_EQ(expected_debug_string, graph.DebugString()); - } -}; - -// Case for graphs that don't have their reverse arcs and therefore cannot be -// tested as completely as if they did, but that are dynamic and so can be -// tested somewhat more than if they weren't. Again, the constructor does all -// the testing work. -template -struct TestEbertGraphRunner { - TestEbertGraphRunner( - const GraphType& graph, absl::string_view unused_expected_graph_arc_list, - absl::string_view unused_expected_adjacency_list, - absl::string_view unused_expected_incoming_arc_list, - absl::string_view expected_outgoing_arc_list, - absl::string_view unused_expected_debug_string, - absl::string_view expected_forward_debug_string, - absl::string_view unused_expected_forward_static_debug_string) { - std::string outgoing_arc_list = ""; - for (typename GraphType::NodeIterator node_it(graph); node_it.Ok(); - node_it.Next()) { - typename GraphType::NodeIndex node = node_it.Index(); - absl::StrAppend(&outgoing_arc_list, " Node ", node, ":\n"); - for (typename GraphType::OutgoingArcIterator arc_it(graph, node); - arc_it.Ok(); arc_it.Next()) { - typename GraphType::ArcIndex arc = arc_it.Index(); - // We assume no self-loops in the graph. - EXPECT_FALSE(graph.IsIncoming(arc, node)); - absl::StrAppend(&outgoing_arc_list, Stringify(graph, node, arc)); - } - } - EXPECT_EQ(expected_outgoing_arc_list, outgoing_arc_list); - EXPECT_EQ(expected_forward_debug_string, graph.DebugString()); - } -}; - -// Case for graphs that don't have their reverse arcs and are static. Due to -// these restrictions, there is less to check about these graphs than for the -// more complicated kinds of graphs. -template -struct TestEbertGraphRunner { - TestEbertGraphRunner(const GraphType& graph, - absl::string_view unused_expected_graph_arc_list, - absl::string_view unused_expected_adjacency_list, - absl::string_view unused_expected_incoming_arc_list, - absl::string_view unused_expected_outgoing_arc_list, - absl::string_view unused_expected_debug_string, - absl::string_view unused_expected_forward_debug_string, - absl::string_view expected_forward_static_debug_string) { - EXPECT_EQ(expected_forward_static_debug_string, graph.DebugString()); - } -}; - -// Tests that various string representations of the given graph match -// the given strings. -template -void TestEbertGraph(const GraphType& graph, - absl::string_view expected_graph_arc_list, - absl::string_view expected_adjacency_list, - absl::string_view expected_incoming_arc_list, - absl::string_view expected_outgoing_arc_list, - absl::string_view expected_debug_string, - absl::string_view expected_forward_debug_string, - absl::string_view expected_forward_static_debug_string) { - TestEbertGraphRunner::has_reverse_arcs, - graph_traits::is_dynamic> - test_object(graph, expected_graph_arc_list, expected_adjacency_list, - expected_incoming_arc_list, expected_outgoing_arc_list, - expected_debug_string, expected_forward_debug_string, - expected_forward_static_debug_string); -} - -template -class DebugStringEbertGraphTest : public ::testing::Test {}; - -typedef ::testing::Types, - ForwardEbertGraph > - EbertGraphTypesForDebugStringTesting; - -TYPED_TEST_SUITE(DebugStringEbertGraphTest, - EbertGraphTypesForDebugStringTesting); - -TYPED_TEST(DebugStringEbertGraphTest, Test1) { - TypeParam graph(4, 6); - graph.AddArc(0, 1); - graph.AddArc(0, 2); - graph.AddArc(1, 3); - graph.AddArc(2, 3); - graph.AddArc(2, 1); - graph.AddArc(1, 2); - - const std::string kGraphArcList = - " Arc 0: 0 -> 1\n" - " Arc 1: 0 -> 2\n" - " Arc 2: 1 -> 3\n" - " Arc 3: 2 -> 3\n" - " Arc 4: 2 -> 1\n" - " Arc 5: 1 -> 2\n"; - - const std::string kExpectedAdjacencyList = - " Node 0:\n" - " Arc 1: 0 -> 2\n" - " Arc 0: 0 -> 1\n" - " Node 1:\n" - " Arc 5: 1 -> 2\n" - " Arc -5: 1 -> 2\n" - " Arc 2: 1 -> 3\n" - " Arc -1: 1 -> 0\n" - " Node 2:\n" - " Arc -6: 2 -> 1\n" - " Arc 4: 2 -> 1\n" - " Arc 3: 2 -> 3\n" - " Arc -2: 2 -> 0\n" - " Node 3:\n" - " Arc -4: 3 -> 2\n" - " Arc -3: 3 -> 1\n"; - - const std::string kExpectedIncomingArcList = - " Node 0:\n" - " Node 1:\n" - " Arc 4: 2 -> 1\n" - " Arc 0: 0 -> 1\n" - " Node 2:\n" - " Arc 5: 1 -> 2\n" - " Arc 1: 0 -> 2\n" - " Node 3:\n" - " Arc 3: 2 -> 3\n" - " Arc 2: 1 -> 3\n"; - - const std::string kExpectedOutgoingArcList = - " Node 0:\n" - " Arc 1: 0 -> 2\n" - " Arc 0: 0 -> 1\n" - " Node 1:\n" - " Arc 5: 1 -> 2\n" - " Arc 2: 1 -> 3\n" - " Node 2:\n" - " Arc 4: 2 -> 1\n" - " Arc 3: 2 -> 3\n" - " Node 3:\n"; - - const std::string kExpectedDebugString = - "Arcs:(node, next arc) :\n" - " -6:(1,4)\n" - " -5:(2,2)\n" - " -4:(2,-3)\n" - " -3:(1,NilArc)\n" - " -2:(0,NilArc)\n" - " -1:(0,NilArc)\n" - " 0:(1,NilArc)\n" - " 1:(2,0)\n" - " 2:(3,-1)\n" - " 3:(3,-2)\n" - " 4:(1,3)\n" - " 5:(2,-5)\n" - "Node:First arc :\n" - " 0:1\n" - " 1:5\n" - " 2:-6\n" - " 3:-4\n"; - - const std::string kExpectedForwardDebugString = - "Arcs:(node, next arc) :\n" - " 0:(1,NilArc)\n" - " 1:(2,0)\n" - " 2:(3,NilArc)\n" - " 3:(3,NilArc)\n" - " 4:(1,3)\n" - " 5:(2,2)\n" - "Node:First arc :\n" - " 0:1\n" - " 1:5\n" - " 2:4\n" - " 3:NilArc\n"; - - TestEbertGraph(graph, kGraphArcList, kExpectedAdjacencyList, - kExpectedIncomingArcList, kExpectedOutgoingArcList, - kExpectedDebugString, kExpectedForwardDebugString, ""); - - // The graph representation is already built, but nothing forbids us from - // testing that it can be rebuilt. To test this for forward graphs, we must - // first collect arc tail information from the adjacency lists, because those - // lists are the only source of information from which we can rebuild a - // forward graph if we haven't maintained its optional arc tail information - // until now. - TailArrayManager tail_array_manager(&graph); - tail_array_manager.BuildTailArrayFromAdjacencyListsIfForwardGraph(); - graph.BuildRepresentation(); - - TestEbertGraph(graph, kGraphArcList, kExpectedAdjacencyList, - kExpectedIncomingArcList, kExpectedOutgoingArcList, - kExpectedDebugString, kExpectedForwardDebugString, ""); -} - -// Unfortunately, this class template has to be defined outside the Test2 method -// where it is used, or a compiler bug gets tickled. -template -class ArcFunctorByHead { - public: - explicit ArcFunctorByHead(const GraphType& graph) : graph_(graph) {} - - bool operator()(typename GraphType::ArcIndex a, - typename GraphType::ArcIndex b) const { - return ((graph_.Head(a) < graph_.Head(b)) || - ((graph_.Head(a) == graph_.Head(b)) && - (graph_.Tail(a) < graph_.Tail(b)))); - } - - private: - const GraphType& graph_; -}; - -// Unfortunately this class template has to be defined outside the Test2 method -// where it is used, or a compiler bug gets tickled. -template -class ArcFunctorByTail { - public: - explicit ArcFunctorByTail(const GraphType& graph) : graph_(graph) {} - - bool operator()(typename GraphType::ArcIndex a, - typename GraphType::ArcIndex b) const { - return ((graph_.Tail(a) < graph_.Tail(b)) || - ((graph_.Tail(a) == graph_.Tail(b)) && - (graph_.Head(a) < graph_.Head(b)))); - } - - private: - const GraphType& graph_; -}; - -TYPED_TEST(DebugStringEbertGraphTest, Test2) { - TypeParam graph(3, 6); - graph.AddArc(0, 1); - graph.AddArc(1, 0); - graph.AddArc(1, 2); - graph.AddArc(2, 1); - graph.AddArc(0, 2); - graph.AddArc(2, 0); - - const std::string kGraphArcList = - " Arc 0: 0 -> 1\n" - " Arc 1: 1 -> 0\n" - " Arc 2: 1 -> 2\n" - " Arc 3: 2 -> 1\n" - " Arc 4: 0 -> 2\n" - " Arc 5: 2 -> 0\n"; - - const std::string kExpectedAdjacencyList = - " Node 0:\n" - " Arc -6: 0 -> 2\n" - " Arc 4: 0 -> 2\n" - " Arc -2: 0 -> 1\n" - " Arc 0: 0 -> 1\n" - " Node 1:\n" - " Arc -4: 1 -> 2\n" - " Arc 2: 1 -> 2\n" - " Arc 1: 1 -> 0\n" - " Arc -1: 1 -> 0\n" - " Node 2:\n" - " Arc 5: 2 -> 0\n" - " Arc -5: 2 -> 0\n" - " Arc 3: 2 -> 1\n" - " Arc -3: 2 -> 1\n"; - - const std::string kExpectedIncomingArcList = - " Node 0:\n" - " Arc 5: 2 -> 0\n" - " Arc 1: 1 -> 0\n" - " Node 1:\n" - " Arc 3: 2 -> 1\n" - " Arc 0: 0 -> 1\n" - " Node 2:\n" - " Arc 4: 0 -> 2\n" - " Arc 2: 1 -> 2\n"; - - const std::string kExpectedOutgoingArcList = - " Node 0:\n" - " Arc 4: 0 -> 2\n" - " Arc 0: 0 -> 1\n" - " Node 1:\n" - " Arc 2: 1 -> 2\n" - " Arc 1: 1 -> 0\n" - " Node 2:\n" - " Arc 5: 2 -> 0\n" - " Arc 3: 2 -> 1\n"; - - const std::string kExpectedDebugString = - "Arcs:(node, next arc) :\n" - " -6:(2,4)\n" - " -5:(0,3)\n" - " -4:(2,2)\n" - " -3:(1,NilArc)\n" - " -2:(1,0)\n" - " -1:(0,NilArc)\n" - " 0:(1,NilArc)\n" - " 1:(0,-1)\n" - " 2:(2,1)\n" - " 3:(1,-3)\n" - " 4:(2,-2)\n" - " 5:(0,-5)\n" - "Node:First arc :\n" - " 0:-6\n" - " 1:-4\n" - " 2:5\n"; - - const std::string kExpectedForwardDebugString = - "Arcs:(node, next arc) :\n" - " 0:(1,NilArc)\n" - " 1:(0,NilArc)\n" - " 2:(2,1)\n" - " 3:(1,NilArc)\n" - " 4:(2,0)\n" - " 5:(0,3)\n" - "Node:First arc :\n" - " 0:4\n" - " 1:2\n" - " 2:5\n"; - - TestEbertGraph(graph, kGraphArcList, kExpectedAdjacencyList, - kExpectedIncomingArcList, kExpectedOutgoingArcList, - kExpectedDebugString, kExpectedForwardDebugString, ""); - - TailArrayManager tail_array_manager(&graph); - tail_array_manager.BuildTailArrayFromAdjacencyListsIfForwardGraph(); - graph.GroupForwardArcsByFunctor(ArcFunctorByHead(graph), nullptr); - - const std::string kGraphHeadGroupedArcList = - " Arc 0: 1 -> 0\n" - " Arc 1: 2 -> 0\n" - " Arc 2: 0 -> 1\n" - " Arc 3: 2 -> 1\n" - " Arc 4: 0 -> 2\n" - " Arc 5: 1 -> 2\n"; - - const std::string kExpectedHeadGroupedAdjacencyList = - " Node 0:\n" - " Arc 4: 0 -> 2\n" - " Arc 2: 0 -> 1\n" - " Arc -2: 0 -> 2\n" - " Arc -1: 0 -> 1\n" - " Node 1:\n" - " Arc 5: 1 -> 2\n" - " Arc -4: 1 -> 2\n" - " Arc -3: 1 -> 0\n" - " Arc 0: 1 -> 0\n" - " Node 2:\n" - " Arc -6: 2 -> 1\n" - " Arc -5: 2 -> 0\n" - " Arc 3: 2 -> 1\n" - " Arc 1: 2 -> 0\n"; - - const std::string kExpectedHeadGroupedIncomingArcList = - " Node 0:\n" - " Arc 1: 2 -> 0\n" - " Arc 0: 1 -> 0\n" - " Node 1:\n" - " Arc 3: 2 -> 1\n" - " Arc 2: 0 -> 1\n" - " Node 2:\n" - " Arc 5: 1 -> 2\n" - " Arc 4: 0 -> 2\n"; - - const std::string kExpectedHeadGroupedOutgoingArcList = - " Node 0:\n" - " Arc 4: 0 -> 2\n" - " Arc 2: 0 -> 1\n" - " Node 1:\n" - " Arc 5: 1 -> 2\n" - " Arc 0: 1 -> 0\n" - " Node 2:\n" - " Arc 3: 2 -> 1\n" - " Arc 1: 2 -> 0\n"; - - const std::string kExpectedHeadGroupedDebugString = - "Arcs:(node, next arc) :\n" - " -6:(1,-5)\n" - " -5:(0,3)\n" - " -4:(2,-3)\n" - " -3:(0,0)\n" - " -2:(2,-1)\n" - " -1:(1,NilArc)\n" - " 0:(0,NilArc)\n" - " 1:(0,NilArc)\n" - " 2:(1,-2)\n" - " 3:(1,1)\n" - " 4:(2,2)\n" - " 5:(2,-4)\n" - "Node:First arc :\n" - " 0:4\n" - " 1:5\n" - " 2:-6\n"; - - const std::string kExpectedHeadGroupedForwardDebugString = - "Arcs:(node, next arc) :\n" - " 0:(0,NilArc)\n" - " 1:(0,NilArc)\n" - " 2:(1,NilArc)\n" - " 3:(1,1)\n" - " 4:(2,2)\n" - " 5:(2,0)\n" - "Node:First arc :\n" - " 0:4\n" - " 1:5\n" - " 2:3\n"; - - TestEbertGraph( - graph, kGraphHeadGroupedArcList, kExpectedHeadGroupedAdjacencyList, - kExpectedHeadGroupedIncomingArcList, kExpectedHeadGroupedOutgoingArcList, - kExpectedHeadGroupedDebugString, kExpectedHeadGroupedForwardDebugString, - ""); - - // Test that the GroupForwardArcsByFunctor method correctly permutes arc - // annotation data. - int arc_annotations[] = {103, 105, 101, 106, 102, 104}; - ArrayIndexCycleHandler handler( - arc_annotations); - graph.GroupForwardArcsByFunctor(ArcFunctorByTail(graph), &handler); - - for (int i = 0; i < 6; ++i) { - EXPECT_EQ(101 + i, arc_annotations[i]); - } - - const std::string kGraphTailGroupedArcList = - " Arc 0: 0 -> 1\n" - " Arc 1: 0 -> 2\n" - " Arc 2: 1 -> 0\n" - " Arc 3: 1 -> 2\n" - " Arc 4: 2 -> 0\n" - " Arc 5: 2 -> 1\n"; - - const std::string kExpectedTailGroupedAdjacencyList = - " Node 0:\n" - " Arc -5: 0 -> 2\n" - " Arc -3: 0 -> 1\n" - " Arc 1: 0 -> 2\n" - " Arc 0: 0 -> 1\n" - " Node 1:\n" - " Arc -6: 1 -> 2\n" - " Arc 3: 1 -> 2\n" - " Arc 2: 1 -> 0\n" - " Arc -1: 1 -> 0\n" - " Node 2:\n" - " Arc 5: 2 -> 1\n" - " Arc 4: 2 -> 0\n" - " Arc -4: 2 -> 1\n" - " Arc -2: 2 -> 0\n"; - - const std::string kExpectedTailGroupedIncomingArcList = - " Node 0:\n" - " Arc 4: 2 -> 0\n" - " Arc 2: 1 -> 0\n" - " Node 1:\n" - " Arc 5: 2 -> 1\n" - " Arc 0: 0 -> 1\n" - " Node 2:\n" - " Arc 3: 1 -> 2\n" - " Arc 1: 0 -> 2\n"; - - const std::string kExpectedTailGroupedOutgoingArcList = - " Node 0:\n" - " Arc 1: 0 -> 2\n" - " Arc 0: 0 -> 1\n" - " Node 1:\n" - " Arc 3: 1 -> 2\n" - " Arc 2: 1 -> 0\n" - " Node 2:\n" - " Arc 5: 2 -> 1\n" - " Arc 4: 2 -> 0\n"; - - const std::string kExpectedTailGroupedDebugString = - "Arcs:(node, next arc) :\n" - " -6:(2,3)\n" - " -5:(2,-3)\n" - " -4:(1,-2)\n" - " -3:(1,1)\n" - " -2:(0,NilArc)\n" - " -1:(0,NilArc)\n" - " 0:(1,NilArc)\n" - " 1:(2,0)\n" - " 2:(0,-1)\n" - " 3:(2,2)\n" - " 4:(0,-4)\n" - " 5:(1,4)\n" - "Node:First arc :\n" - " 0:-5\n" - " 1:-6\n" - " 2:5\n"; - - const std::string kExpectedTailGroupedForwardDebugString = - "Arcs:(node, next arc) :\n" - " 0:(1,NilArc)\n" - " 1:(2,0)\n" - " 2:(0,NilArc)\n" - " 3:(2,2)\n" - " 4:(0,NilArc)\n" - " 5:(1,4)\n" - "Node:First arc :\n" - " 0:1\n" - " 1:3\n" - " 2:5\n"; - - TestEbertGraph( - graph, kGraphTailGroupedArcList, kExpectedTailGroupedAdjacencyList, - kExpectedTailGroupedIncomingArcList, kExpectedTailGroupedOutgoingArcList, - kExpectedTailGroupedDebugString, kExpectedTailGroupedForwardDebugString, - ""); -} - -template -class DebugStringTestWithGraphBuildManager : public ::testing::Test {}; - -typedef ::testing::Types, - ForwardEbertGraph > - GraphTypesForDebugStringTestWithGraphBuildManager; - -TYPED_TEST_SUITE(DebugStringTestWithGraphBuildManager, - GraphTypesForDebugStringTestWithGraphBuildManager); - -TYPED_TEST(DebugStringTestWithGraphBuildManager, - UnsortedArcsWithoutAnnotation) { - AnnotatedGraphBuildManager* builder = - new AnnotatedGraphBuildManager( - 4, 6, false /* don't sort adjacency lists */); - - EXPECT_EQ(0, builder->AddArc(0, 2)); - EXPECT_EQ(1, builder->AddArc(2, 0)); - EXPECT_EQ(2, builder->AddArc(2, 3)); - EXPECT_EQ(3, builder->AddArc(3, 2)); - EXPECT_EQ(4, builder->AddArc(0, 3)); - EXPECT_EQ(5, builder->AddArc(3, 0)); - - const TypeParam* graph = builder->Graph(nullptr); - ASSERT_TRUE(graph != nullptr); - - const std::string kGraphArcList = - " Arc 0: 0 -> 2\n" - " Arc 1: 2 -> 0\n" - " Arc 2: 2 -> 3\n" - " Arc 3: 3 -> 2\n" - " Arc 4: 0 -> 3\n" - " Arc 5: 3 -> 0\n"; - - const std::string kExpectedAdjacencyList = - " Node 0:\n" - " Arc -6: 0 -> 3\n" - " Arc 4: 0 -> 3\n" - " Arc -2: 0 -> 2\n" - " Arc 0: 0 -> 2\n" - " Node 1:\n" - " Node 2:\n" - " Arc -4: 2 -> 3\n" - " Arc 2: 2 -> 3\n" - " Arc 1: 2 -> 0\n" - " Arc -1: 2 -> 0\n" - " Node 3:\n" - " Arc 5: 3 -> 0\n" - " Arc -5: 3 -> 0\n" - " Arc 3: 3 -> 2\n" - " Arc -3: 3 -> 2\n"; - - const std::string kExpectedIncomingArcList = - " Node 0:\n" - " Arc 5: 3 -> 0\n" - " Arc 1: 2 -> 0\n" - " Node 1:\n" - " Node 2:\n" - " Arc 3: 3 -> 2\n" - " Arc 0: 0 -> 2\n" - " Node 3:\n" - " Arc 4: 0 -> 3\n" - " Arc 2: 2 -> 3\n"; - - const std::string kExpectedOutgoingArcList = - " Node 0:\n" - " Arc 4: 0 -> 3\n" - " Arc 0: 0 -> 2\n" - " Node 1:\n" - " Node 2:\n" - " Arc 2: 2 -> 3\n" - " Arc 1: 2 -> 0\n" - " Node 3:\n" - " Arc 5: 3 -> 0\n" - " Arc 3: 3 -> 2\n"; - - const std::string kExpectedDebugString = - "Arcs:(node, next arc) :\n" - " -6:(3,4)\n" - " -5:(0,3)\n" - " -4:(3,2)\n" - " -3:(2,NilArc)\n" - " -2:(2,0)\n" - " -1:(0,NilArc)\n" - " 0:(2,NilArc)\n" - " 1:(0,-1)\n" - " 2:(3,1)\n" - " 3:(2,-3)\n" - " 4:(3,-2)\n" - " 5:(0,-5)\n" - "Node:First arc :\n" - " 0:-6\n" - " 1:NilArc\n" - " 2:-4\n" - " 3:5\n"; - - const std::string kExpectedForwardDebugString = - "Arcs:(node, next arc) :\n" - " 0:(2,NilArc)\n" - " 1:(0,NilArc)\n" - " 2:(3,1)\n" - " 3:(2,NilArc)\n" - " 4:(3,0)\n" - " 5:(0,3)\n" - "Node:First arc :\n" - " 0:4\n" - " 1:NilArc\n" - " 2:2\n" - " 3:5\n"; - - const std::string kExpectedForwardStaticDebugString = - "Arcs:(node) :\n" - " 0:(2)\n" - " 1:(3)\n" - " 2:(0)\n" - " 3:(3)\n" - " 4:(2)\n" - " 5:(0)\n" - "Node:First arc :\n" - " 0:0\n" - " 1:2\n" - " 2:2\n" - " 3:4\n" - " 4:6\n"; - - TestEbertGraph(*graph, kGraphArcList, kExpectedAdjacencyList, - kExpectedIncomingArcList, kExpectedOutgoingArcList, - kExpectedDebugString, kExpectedForwardDebugString, - kExpectedForwardStaticDebugString); - - delete graph; -} - -TYPED_TEST(DebugStringTestWithGraphBuildManager, SortedArcsWithAnnotation) { - AnnotatedGraphBuildManager* builder = - new AnnotatedGraphBuildManager( - 4, 6, true /* sort adjacency lists */); - - EXPECT_EQ(0, builder->AddArc(0, 2)); - EXPECT_EQ(1, builder->AddArc(2, 0)); - EXPECT_EQ(2, builder->AddArc(2, 3)); - EXPECT_EQ(3, builder->AddArc(3, 2)); - EXPECT_EQ(4, builder->AddArc(0, 3)); - EXPECT_EQ(5, builder->AddArc(3, 0)); - - // Test that the graph building and arc sorting operations correctly - // permute arc annotation data. - int arc_annotations[] = {101, 103, 104, 106, 102, 105}; - ArrayIndexCycleHandler handler( - arc_annotations); - const TypeParam* graph = builder->Graph(&handler); - ASSERT_TRUE(graph != nullptr); - for (int i = 0; i < 6; ++i) { - EXPECT_EQ(101 + i, arc_annotations[i]); - } - - const std::string kGraphArcList = - " Arc 0: 0 -> 2\n" - " Arc 1: 0 -> 3\n" - " Arc 2: 2 -> 0\n" - " Arc 3: 2 -> 3\n" - " Arc 4: 3 -> 0\n" - " Arc 5: 3 -> 2\n"; - - const std::string kExpectedAdjacencyList = - " Node 0:\n" - " Arc -5: 0 -> 3\n" - " Arc -3: 0 -> 2\n" - " Arc 1: 0 -> 3\n" - " Arc 0: 0 -> 2\n" - " Node 1:\n" - " Node 2:\n" - " Arc -6: 2 -> 3\n" - " Arc 3: 2 -> 3\n" - " Arc 2: 2 -> 0\n" - " Arc -1: 2 -> 0\n" - " Node 3:\n" - " Arc 5: 3 -> 2\n" - " Arc 4: 3 -> 0\n" - " Arc -4: 3 -> 2\n" - " Arc -2: 3 -> 0\n"; - - const std::string kExpectedIncomingArcList = - " Node 0:\n" - " Arc 4: 3 -> 0\n" - " Arc 2: 2 -> 0\n" - " Node 1:\n" - " Node 2:\n" - " Arc 5: 3 -> 2\n" - " Arc 0: 0 -> 2\n" - " Node 3:\n" - " Arc 3: 2 -> 3\n" - " Arc 1: 0 -> 3\n"; - - const std::string kExpectedOutgoingArcList = - " Node 0:\n" - " Arc 1: 0 -> 3\n" - " Arc 0: 0 -> 2\n" - " Node 1:\n" - " Node 2:\n" - " Arc 3: 2 -> 3\n" - " Arc 2: 2 -> 0\n" - " Node 3:\n" - " Arc 5: 3 -> 2\n" - " Arc 4: 3 -> 0\n"; - - const std::string kExpectedDebugString = - "Arcs:(node, next arc) :\n" - " -6:(3,3)\n" - " -5:(3,-3)\n" - " -4:(2,-2)\n" - " -3:(2,1)\n" - " -2:(0,NilArc)\n" - " -1:(0,NilArc)\n" - " 0:(2,NilArc)\n" - " 1:(3,0)\n" - " 2:(0,-1)\n" - " 3:(3,2)\n" - " 4:(0,-4)\n" - " 5:(2,4)\n" - "Node:First arc :\n" - " 0:-5\n" - " 1:NilArc\n" - " 2:-6\n" - " 3:5\n"; - - const std::string kExpectedForwardDebugString = - "Arcs:(node, next arc) :\n" - " 0:(2,NilArc)\n" - " 1:(3,0)\n" - " 2:(0,NilArc)\n" - " 3:(3,2)\n" - " 4:(0,NilArc)\n" - " 5:(2,4)\n" - "Node:First arc :\n" - " 0:1\n" - " 1:NilArc\n" - " 2:3\n" - " 3:5\n"; - - const std::string kExpectedForwardStaticDebugString = - "Arcs:(node) :\n" - " 0:(2)\n" - " 1:(3)\n" - " 2:(0)\n" - " 3:(3)\n" - " 4:(0)\n" - " 5:(2)\n" - "Node:First arc :\n" - " 0:0\n" - " 1:2\n" - " 2:2\n" - " 3:4\n" - " 4:6\n"; - - TestEbertGraph(*graph, kGraphArcList, kExpectedAdjacencyList, - kExpectedIncomingArcList, kExpectedOutgoingArcList, - kExpectedDebugString, kExpectedForwardDebugString, - kExpectedForwardStaticDebugString); - - delete graph; -} - -TYPED_TEST(DebugStringTestWithGraphBuildManager, SortedArcsWithoutAnnotation) { - AnnotatedGraphBuildManager* builder = - new AnnotatedGraphBuildManager( - 4, 6, true /* sort adjacency lists */); - - EXPECT_EQ(0, builder->AddArc(0, 2)); - EXPECT_EQ(1, builder->AddArc(2, 0)); - EXPECT_EQ(2, builder->AddArc(2, 3)); - EXPECT_EQ(3, builder->AddArc(3, 2)); - EXPECT_EQ(4, builder->AddArc(0, 3)); - EXPECT_EQ(5, builder->AddArc(3, 0)); - - const TypeParam* graph = builder->Graph(nullptr); - - const std::string kGraphArcList = - " Arc 0: 0 -> 2\n" - " Arc 1: 0 -> 3\n" - " Arc 2: 2 -> 0\n" - " Arc 3: 2 -> 3\n" - " Arc 4: 3 -> 0\n" - " Arc 5: 3 -> 2\n"; - - const std::string kExpectedAdjacencyList = - " Node 0:\n" - " Arc -5: 0 -> 3\n" - " Arc -3: 0 -> 2\n" - " Arc 1: 0 -> 3\n" - " Arc 0: 0 -> 2\n" - " Node 1:\n" - " Node 2:\n" - " Arc -6: 2 -> 3\n" - " Arc 3: 2 -> 3\n" - " Arc 2: 2 -> 0\n" - " Arc -1: 2 -> 0\n" - " Node 3:\n" - " Arc 5: 3 -> 2\n" - " Arc 4: 3 -> 0\n" - " Arc -4: 3 -> 2\n" - " Arc -2: 3 -> 0\n"; - - const std::string kExpectedIncomingArcList = - " Node 0:\n" - " Arc 4: 3 -> 0\n" - " Arc 2: 2 -> 0\n" - " Node 1:\n" - " Node 2:\n" - " Arc 5: 3 -> 2\n" - " Arc 0: 0 -> 2\n" - " Node 3:\n" - " Arc 3: 2 -> 3\n" - " Arc 1: 0 -> 3\n"; - - const std::string kExpectedOutgoingArcList = - " Node 0:\n" - " Arc 1: 0 -> 3\n" - " Arc 0: 0 -> 2\n" - " Node 1:\n" - " Node 2:\n" - " Arc 3: 2 -> 3\n" - " Arc 2: 2 -> 0\n" - " Node 3:\n" - " Arc 5: 3 -> 2\n" - " Arc 4: 3 -> 0\n"; - - const std::string kExpectedDebugString = - "Arcs:(node, next arc) :\n" - " -6:(3,3)\n" - " -5:(3,-3)\n" - " -4:(2,-2)\n" - " -3:(2,1)\n" - " -2:(0,NilArc)\n" - " -1:(0,NilArc)\n" - " 0:(2,NilArc)\n" - " 1:(3,0)\n" - " 2:(0,-1)\n" - " 3:(3,2)\n" - " 4:(0,-4)\n" - " 5:(2,4)\n" - "Node:First arc :\n" - " 0:-5\n" - " 1:NilArc\n" - " 2:-6\n" - " 3:5\n"; - - const std::string kExpectedForwardDebugString = - "Arcs:(node, next arc) :\n" - " 0:(2,NilArc)\n" - " 1:(3,0)\n" - " 2:(0,NilArc)\n" - " 3:(3,2)\n" - " 4:(0,NilArc)\n" - " 5:(2,4)\n" - "Node:First arc :\n" - " 0:1\n" - " 1:NilArc\n" - " 2:3\n" - " 3:5\n"; - - const std::string kExpectedForwardStaticDebugString = - "Arcs:(node) :\n" - " 0:(2)\n" - " 1:(3)\n" - " 2:(0)\n" - " 3:(3)\n" - " 4:(0)\n" - " 5:(2)\n" - "Node:First arc :\n" - " 0:0\n" - " 1:2\n" - " 2:2\n" - " 3:4\n" - " 4:6\n"; - - TestEbertGraph(*graph, kGraphArcList, kExpectedAdjacencyList, - kExpectedIncomingArcList, kExpectedOutgoingArcList, - kExpectedDebugString, kExpectedForwardDebugString, - kExpectedForwardStaticDebugString); - - delete graph; -} - -// An empty fixture template to collect the types of tiny graphs for which we -// want to do very basic tests. -template -class TinyEbertGraphTest : public ::testing::Test {}; - -typedef ::testing::Types, - ForwardEbertGraph > - TinyEbertGraphTypesForTesting; - -TYPED_TEST_SUITE(TinyEbertGraphTest, TinyEbertGraphTypesForTesting); - -TYPED_TEST(TinyEbertGraphTest, CheckDeathOnBadBounds) { - typedef TypeParam SmallStarGraph; - int num_nodes = SmallStarGraph::kMaxNumNodes; - int num_arcs = SmallStarGraph::kMaxNumArcs; - SmallStarGraph(num_nodes, num_arcs); // Construct an unused graph. All fine. -} - -// An empty fixture to collect the types of small graphs for which we want to do -// some fairly trivial tests. -template -class SmallEbertGraphTest : public ::testing::Test {}; - -typedef ::testing::Types< - EbertGraph, EbertGraph, - ForwardEbertGraph, ForwardEbertGraph > - SmallEbertGraphTypesForTesting; - -TYPED_TEST_SUITE(SmallEbertGraphTest, SmallEbertGraphTypesForTesting); - -TYPED_TEST(SmallEbertGraphTest, EmptyGraph) { - TypeParam graph(3, 6); - const std::string kGraphArcList = ""; - const std::string kExpectedAdjacencyList = ""; - const std::string kExpectedIncomingArcList = ""; - const std::string kExpectedOutgoingArcList = ""; - const std::string kExpectedDebugString = - "Arcs:(node, next arc) :\n" - "Node:First arc :\n"; - TestEbertGraph(graph, kGraphArcList, kExpectedAdjacencyList, - kExpectedIncomingArcList, kExpectedOutgoingArcList, - kExpectedDebugString, kExpectedDebugString, - kExpectedDebugString); -} - -TEST(EbertGraphTest, CheckBounds) { - typedef EbertGraph SmallStarGraph; - SmallStarGraph g(SmallStarGraph::kMaxNumNodes, SmallStarGraph::kMaxNumArcs); - EXPECT_TRUE(g.CheckArcBounds(SmallStarGraph::kNilArc)); - EXPECT_FALSE(g.CheckArcValidity(SmallStarGraph::kNilArc)); - EXPECT_FALSE(g.CheckArcValidity(SmallStarGraph::kMaxNumArcs)); - EXPECT_TRUE(g.CheckArcValidity(g.SmallStarGraph::kMaxNumArcs - 1)); - EXPECT_TRUE(g.CheckArcValidity(g.Opposite(SmallStarGraph::kMaxNumArcs - 1))); -} - -TEST(ForwardEbertGraphTest, CheckBounds) { - typedef ForwardEbertGraph SmallStarGraph; - SmallStarGraph g(SmallStarGraph::kMaxNumNodes, SmallStarGraph::kMaxNumArcs); - EXPECT_TRUE(g.CheckArcBounds(SmallStarGraph::kNilArc)); - EXPECT_FALSE(g.CheckArcValidity(SmallStarGraph::kNilArc)); - EXPECT_FALSE(g.CheckArcValidity(SmallStarGraph::kMaxNumArcs)); - EXPECT_TRUE(g.CheckArcValidity(g.SmallStarGraph::kMaxNumArcs - 1)); -} - -TEST(ForwardEbertGraphTest, ImpossibleBuildTailArray) { - typedef ForwardEbertGraph SmallStarGraph; - SmallStarGraph g(3, 3); - ArcIndex arc = g.AddArc(0, 1); - // The SetHead() method is the easiest way to dirty the representation. - // Alternatively, since this is a FRIEND_TEST of the EbertGraphBase<> - // template, we could just set g.representation_clean_ = false. Once the - // representation is dirty, the adjacency lists are invalid and we haven't - // bothered to allocate and maintain the optional array of arc tails, so - // rebuilding that optional tail array is impossible. - g.SetHead(arc, 2); - EXPECT_FALSE(g.BuildTailArray()); -} - -template -static void BM_RandomArcs(benchmark::State& state) { - const int kRandomSeed = 0; - const int kNodes = 10 * 1000 * 1000; - const int kArcs = 5 * kNodes; - for (auto _ : state) { - AnnotatedGraphBuildManager* builder = - new AnnotatedGraphBuildManager(kNodes, kArcs, sort_arcs); - std::mt19937 randomizer(kRandomSeed); - for (int i = 0; i < kArcs; ++i) { - builder->AddArc(absl::Uniform(randomizer, 0, kNodes), - absl::Uniform(randomizer, 0, kNodes)); - } - (void)builder->Graph(nullptr); - } - // An item is an arc here. - state.SetItemsProcessed(static_cast(state.max_iterations) * kArcs); -} - -BENCHMARK_TEMPLATE2(BM_RandomArcs, StarGraph, false); -BENCHMARK_TEMPLATE2(BM_RandomArcs, ForwardStarGraph, false); - -BENCHMARK_TEMPLATE2(BM_RandomArcs, StarGraph, true); -BENCHMARK_TEMPLATE2(BM_RandomArcs, ForwardStarGraph, true); - -template -static void BM_RandomAnnotatedArcs(benchmark::State& state) { - const int kRandomSeed = 0; - const int kNodes = 10 * 1000 * 1000; - const int kArcs = 5 * kNodes; - int* annotation = new int[kArcs]; - for (auto _ : state) { - AnnotatedGraphBuildManager* builder = - new AnnotatedGraphBuildManager(kNodes, kArcs, sort_arcs); - std::mt19937 randomizer(kRandomSeed); - for (int i = 0; i < kArcs; ++i) { - ArcIndex arc = builder->AddArc(absl::Uniform(randomizer, 0, kNodes), - absl::Uniform(randomizer, 0, kNodes)); - annotation[arc] = absl::Uniform(randomizer, 0, kNodes); - } - ArrayIndexCycleHandler cycle_handler(annotation); - (void)builder->Graph(&cycle_handler); - } - delete[] annotation; - // An item is an arc here. - state.SetItemsProcessed(static_cast(state.max_iterations) * kArcs); -} - -BENCHMARK_TEMPLATE2(BM_RandomAnnotatedArcs, StarGraph, false); -BENCHMARK_TEMPLATE2(BM_RandomAnnotatedArcs, ForwardStarGraph, false); - -BENCHMARK_TEMPLATE2(BM_RandomAnnotatedArcs, StarGraph, true); -BENCHMARK_TEMPLATE2(BM_RandomAnnotatedArcs, ForwardStarGraph, true); - -template -static void BM_AddRandomArcsAndDoNotRetrieveGraph(benchmark::State& state) { - const int kRandomSeed = 0; - const int kNodes = 10 * 1000 * 1000; - const int kArcs = 5 * kNodes; - for (auto _ : state) { - AnnotatedGraphBuildManager* builder = - new AnnotatedGraphBuildManager(kNodes, kArcs, false); - std::mt19937 randomizer(kRandomSeed); - for (int i = 0; i < kArcs; ++i) { - builder->AddArc(absl::Uniform(randomizer, 0, kNodes), - absl::Uniform(randomizer, 0, kNodes)); - } - delete builder; - } - // An item is an arc here. - state.SetItemsProcessed(static_cast(state.max_iterations) * kArcs); -} - -BENCHMARK_TEMPLATE(BM_AddRandomArcsAndDoNotRetrieveGraph, StarGraph); -BENCHMARK_TEMPLATE(BM_AddRandomArcsAndDoNotRetrieveGraph, ForwardStarGraph); - -} // namespace operations_research diff --git a/ortools/graph/generic_max_flow.h b/ortools/graph/generic_max_flow.h index 6ecaf6f415..5c441da43c 100644 --- a/ortools/graph/generic_max_flow.h +++ b/ortools/graph/generic_max_flow.h @@ -130,12 +130,10 @@ #include #include -#include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "ortools/base/logging.h" #include "ortools/graph/ebert_graph.h" #include "ortools/graph/flow_problem.pb.h" -#include "ortools/graph/graphs.h" #include "ortools/util/stats.h" #include "ortools/util/zvector.h" @@ -573,7 +571,7 @@ GenericMaxFlow::GenericMaxFlow(const Graph* graph, NodeIndex source, SCOPED_TIME_STAT(&stats_); DCHECK(graph->IsNodeValid(source)); DCHECK(graph->IsNodeValid(sink)); - const NodeIndex max_num_nodes = Graphs::NodeReservation(*graph_); + const NodeIndex max_num_nodes = graph_->node_capacity(); if (max_num_nodes > 0) { // We will initialize them in InitializePreflow(), so no need for memset. // @@ -584,7 +582,7 @@ GenericMaxFlow::GenericMaxFlow(const Graph* graph, NodeIndex source, first_admissible_arc_ = std::make_unique(max_num_nodes); bfs_queue_.reserve(max_num_nodes); } - const ArcIndex max_num_arcs = Graphs::ArcReservation(*graph_); + const ArcIndex max_num_arcs = graph_->arc_capacity(); if (max_num_arcs > 0) { if constexpr (Graph::kHasNegativeReverseArcs) { residual_arc_capacity_.Reserve(-max_num_arcs, max_num_arcs - 1); @@ -767,7 +765,7 @@ void GenericMaxFlow::InitializePreflow() { // TODO(user): Ebert graph has an issue with nodes with no arcs, so we // use max_num_nodes here to resize vectors. const NodeIndex num_nodes = graph_->num_nodes(); - const NodeIndex max_num_nodes = Graphs::NodeReservation(*graph_); + const NodeIndex max_num_nodes = graph_->node_capacity(); // InitializePreflow() clears the whole flow that could have been computed // by a previous Solve(). This is not optimal in terms of complexity. @@ -1165,9 +1163,7 @@ template void GenericMaxFlow::RefineWithGlobalUpdate() { SCOPED_TIME_STAT(&stats_); - // TODO(user): This should be graph_->num_nodes(), but ebert graph does not - // have a correct size if the highest index nodes have no arcs. - const NodeIndex num_nodes = Graphs::NodeReservation(*graph_); + const NodeIndex num_nodes = graph_->num_nodes(); std::vector skip_active_node; // Usually SaturateOutgoingArcsFromSource() will saturate all the arcs from @@ -1304,7 +1300,7 @@ void GenericMaxFlow::Relabel(NodeIndex node) { template typename Graph::ArcIndex GenericMaxFlow::Opposite(ArcIndex arc) const { - return Graphs::OppositeArc(*graph_, arc); + return graph_->OppositeArc(arc); } template @@ -1314,7 +1310,7 @@ bool GenericMaxFlow::IsArcDirect(ArcIndex arc) const { template bool GenericMaxFlow::IsArcValid(ArcIndex arc) const { - return Graphs::IsArcValid(*graph_, arc); + return graph_->IsArcValid(arc); } template diff --git a/ortools/graph/generic_max_flow_test.cc b/ortools/graph/generic_max_flow_test.cc index faf6d425fe..9f42eb2e69 100644 --- a/ortools/graph/generic_max_flow_test.cc +++ b/ortools/graph/generic_max_flow_test.cc @@ -33,7 +33,6 @@ #include "ortools/base/logging.h" #include "ortools/graph/ebert_graph.h" #include "ortools/graph/graph.h" -#include "ortools/graph/graphs.h" #include "ortools/linear_solver/linear_solver.h" namespace operations_research { @@ -58,7 +57,7 @@ typename GenericMaxFlow::Status MaxFlowTester( graph.AddArc(tail[i], head[i]); } std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); GenericMaxFlow max_flow(&graph, 0, num_nodes - 1); for (typename Graph::ArcIndex arc = 0; arc < num_arcs; ++arc) { @@ -69,7 +68,8 @@ typename GenericMaxFlow::Status MaxFlowTester( } EXPECT_TRUE(max_flow.Solve()); if (max_flow.status() == GenericMaxFlow::OPTIMAL) { - const FlowQuantity total_flow = max_flow.GetOptimalFlow(); + const typename GenericMaxFlow::FlowQuantityT total_flow = + max_flow.GetOptimalFlow(); EXPECT_EQ(expected_total_flow, total_flow); for (int arc = 0; arc < num_arcs; ++arc) { const int image = arc < permutation.size() ? permutation[arc] : arc; @@ -80,13 +80,13 @@ typename GenericMaxFlow::Status MaxFlowTester( // Tests the min-cut functions. if (expected_source_min_cut != nullptr) { - std::vector cut; + std::vector cut; max_flow.GetSourceSideMinCut(&cut); std::sort(cut.begin(), cut.end()); EXPECT_THAT(*expected_source_min_cut, WhenSorted(ContainerEq(cut))); } if (expected_sink_min_cut != nullptr) { - std::vector cut; + std::vector cut; max_flow.GetSinkSideMinCut(&cut); std::sort(cut.begin(), cut.end()); EXPECT_THAT(*expected_sink_min_cut, WhenSorted(ContainerEq(cut))); @@ -98,7 +98,7 @@ typename GenericMaxFlow::Status MaxFlowTester( template class GenericMaxFlowTest : public ::testing::Test {}; -typedef ::testing::Types, +typedef ::testing::Types, util::ReverseArcStaticGraph<>, util::ReverseArcMixedGraph<>> GraphTypes; @@ -176,6 +176,7 @@ TYPED_TEST(GenericMaxFlowTest, HugeCapacity) { } TYPED_TEST(GenericMaxFlowTest, FlowQuantityOverflowLimitCase) { + using FlowQuantity = typename GenericMaxFlow::FlowQuantityT; const FlowQuantity kCapacityMax = std::numeric_limits::max(); const FlowQuantity kHalfLow = kCapacityMax / 2; const FlowQuantity kHalfHigh = kCapacityMax - kHalfLow; @@ -197,6 +198,7 @@ TYPED_TEST(GenericMaxFlowTest, FlowQuantityOverflowLimitCase) { } TYPED_TEST(GenericMaxFlowTest, FlowQuantityOverflow) { + using FlowQuantity = typename GenericMaxFlow::FlowQuantityT; const FlowQuantity kCapacityMax = std::numeric_limits::max(); const int kNumNodes = 4; const int kNumArcs = 4; @@ -394,7 +396,7 @@ void FullAssignment(std::optional unused, typename Graph::NodeIndex num_heads) { Graph graph; GenerateCompleteGraph(num_tails, num_heads, &graph); - Graphs::Build(&graph); + graph.Build(); std::vector arc_capacity(graph.num_arcs(), 1); std::unique_ptr> max_flow(new GenericMaxFlow( &graph, graph.num_nodes() - 2, graph.num_nodes() - 1)); @@ -470,7 +472,7 @@ void PartialRandomFlow(std::optional expected_flow, GenerateRandomArcValuations(random, graph, kCapacityRange, &arc_capacity); std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); arc_capacity.resize(graph.num_arcs(), 0); // In case Build() adds more arcs. util::Permute(permutation, &arc_capacity); @@ -518,7 +520,7 @@ void FullRandomFlow(std::optional expected_flow, GenerateRandomArcValuations(random, graph, kCapacityRange, &arc_capacity); std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); arc_capacity.resize(graph.num_arcs(), 0); // In case Build() adds more arcs. util::Permute(permutation, &arc_capacity); @@ -554,11 +556,7 @@ void FullRandomFlow(std::optional expected_flow, TEST(MaxFlowListGraphTest, test_name##size) { \ test_name>(std::nullopt, SolveMaxFlow, size, \ size); \ - } \ - TEST(MaxFlowStarGraphTest, test_name##size) { \ - test_name(std::nullopt, SolveMaxFlow, size, size); \ } - // These are absl::BitGen random test, so they will always work on different // graphs. LP_AND_FLOW_TEST(FullAssignment, 300); @@ -604,22 +602,18 @@ static void BM_FullRandomFlow(benchmark::State& state) { } // Note that these benchmark include the graph creation and generation... -BENCHMARK_TEMPLATE(BM_FullRandomAssignment, StarGraph); BENCHMARK_TEMPLATE(BM_FullRandomAssignment, util::ReverseArcListGraph<>); BENCHMARK_TEMPLATE(BM_FullRandomAssignment, util::ReverseArcStaticGraph<>); BENCHMARK_TEMPLATE(BM_FullRandomAssignment, util::ReverseArcMixedGraph<>); -BENCHMARK_TEMPLATE(BM_PartialRandomFlow, StarGraph); BENCHMARK_TEMPLATE(BM_PartialRandomFlow, util::ReverseArcListGraph<>); BENCHMARK_TEMPLATE(BM_PartialRandomFlow, util::ReverseArcStaticGraph<>); BENCHMARK_TEMPLATE(BM_PartialRandomFlow, util::ReverseArcMixedGraph<>); -BENCHMARK_TEMPLATE(BM_FullRandomFlow, StarGraph); BENCHMARK_TEMPLATE(BM_FullRandomFlow, util::ReverseArcListGraph<>); BENCHMARK_TEMPLATE(BM_FullRandomFlow, util::ReverseArcStaticGraph<>); BENCHMARK_TEMPLATE(BM_FullRandomFlow, util::ReverseArcMixedGraph<>); -BENCHMARK_TEMPLATE(BM_PartialRandomAssignment, StarGraph); BENCHMARK_TEMPLATE(BM_PartialRandomAssignment, util::ReverseArcListGraph<>); BENCHMARK_TEMPLATE(BM_PartialRandomAssignment, util::ReverseArcStaticGraph<>); BENCHMARK_TEMPLATE(BM_PartialRandomAssignment, util::ReverseArcMixedGraph<>); diff --git a/ortools/graph/graph.h b/ortools/graph/graph.h index 67797c8a8e..89e6dba2ac 100644 --- a/ortools/graph/graph.h +++ b/ortools/graph/graph.h @@ -169,6 +169,7 @@ #include "absl/debugging/leak_check.h" #include "absl/log/check.h" #include "absl/types/span.h" +#include "ortools/base/constant_divisor.h" #include "ortools/base/logging.h" #include "ortools/base/macros.h" #include "ortools/base/types.h" @@ -271,14 +272,6 @@ class BaseGraph { static const NodeIndexType kNilNode; static const ArcIndexType kNilArc; - // TODO(user): remove the public functions below. They are just here during - // the transition from the old ebert_graph api to this new graph api. - template - void GroupForwardArcsByFunctor(const A& a, B* b) { - LOG(FATAL) << "Not supported"; - } - ArcIndexType max_end_arc_index() const { return arc_capacity_; } - protected: // Functions commented when defined because they are implementation details. void ComputeCumulativeSum(std::vector* v); @@ -515,7 +508,8 @@ class ReverseArcListGraph // for (const Graph::ArcIndex arc : IterationFunction(node)) { ... } // // The StartingFrom() version are similar, but restart the iteration from a - // given arc position (which must be valid in the iteration context). + // given arc position (which must be valid in the iteration context), or + // `kNilArc`, in which case an empty range is returned. BeginEndWrapper OutgoingArcs(NodeIndexType node) const; BeginEndWrapper IncomingArcs(NodeIndexType node) const; BeginEndWrapper @@ -1097,19 +1091,21 @@ void BaseGraph:: // - t: the iteration type (Outgoing, Incoming, OutgoingOrOppositeIncoming // or OppositeIncoming). // - e: the "end" ArcIndexType. -#define DEFINE_RANGE_BASED_ARC_ITERATION(c, t, e) \ - template \ - BeginEndWrapper::t##ArcIterator> \ - c::t##Arcs(NodeIndexType node) const { \ - return BeginEndWrapper(t##ArcIterator(*this, node), \ - t##ArcIterator(*this, node, e)); \ - } \ - template \ - BeginEndWrapper::t##ArcIterator> \ - c::t##ArcsStartingFrom( \ - NodeIndexType node, ArcIndexType from) const { \ - return BeginEndWrapper(t##ArcIterator(*this, node, from), \ - t##ArcIterator(*this, node, e)); \ +#define DEFINE_RANGE_BASED_ARC_ITERATION(c, t) \ + template \ + BeginEndWrapper::t##ArcIterator> \ + c::t##Arcs(NodeIndexType node) const { \ + return BeginEndWrapper( \ + t##ArcIterator(*this, node), \ + t##ArcIterator(*this, node, Base::kNilArc)); \ + } \ + template \ + BeginEndWrapper::t##ArcIterator> \ + c::t##ArcsStartingFrom( \ + NodeIndexType node, ArcIndexType from) const { \ + return BeginEndWrapper( \ + t##ArcIterator(*this, node, from), \ + t##ArcIterator(*this, node, Base::kNilArc)); \ } // Adapt our old iteration style to support range-based for loops. Add typedefs @@ -1131,7 +1127,7 @@ void BaseGraph:: // ListGraph implementation ---------------------------------------------------- -DEFINE_RANGE_BASED_ARC_ITERATION(ListGraph, Outgoing, Base::kNilArc); +DEFINE_RANGE_BASED_ARC_ITERATION(ListGraph, Outgoing); template BeginEndWrapper< @@ -1289,7 +1285,7 @@ StaticGraph::FromArcs(NodeIndexType num_nodes, return g; } -DEFINE_RANGE_BASED_ARC_ITERATION(StaticGraph, Outgoing, DirectArcLimit(node)); +DEFINE_RANGE_BASED_ARC_ITERATION(StaticGraph, Outgoing); template absl::Span @@ -1443,11 +1439,12 @@ class StaticGraph::OutgoingArcIterator { : index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {} OutgoingArcIterator(const StaticGraph& graph, NodeIndexType node, ArcIndexType arc) - : index_(arc), limit_(graph.DirectArcLimit(node)) { + : limit_(graph.DirectArcLimit(node)) { + index_ = arc == Base::kNilArc ? limit_ : arc; DCHECK_GE(arc, graph.start_[node]); } - bool Ok() const { return index_ < limit_; } + bool Ok() const { return index_ != limit_; } ArcIndexType Index() const { return index_; } void Next() { DCHECK(Ok()); @@ -1470,12 +1467,11 @@ class StaticGraph::OutgoingArcIterator { // ReverseArcListGraph implementation ------------------------------------------ -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Outgoing, Base::kNilArc); -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Incoming, Base::kNilArc); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Outgoing); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Incoming); DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, - OutgoingOrOppositeIncoming, Base::kNilArc); -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, OppositeIncoming, - Base::kNilArc); + OutgoingOrOppositeIncoming); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, OppositeIncoming); template BeginEndWrapper::IncomingArcIterator : public OppositeIncomingArcIterator { public: IncomingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node) - : OppositeIncomingArcIterator(graph, node) {} + : OppositeIncomingArcIterator(graph, node), graph_(graph) {} IncomingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node, ArcIndexType arc) : OppositeIncomingArcIterator( graph, node, - arc == Base::kNilArc ? Base::kNilArc : graph.OppositeArc(arc)) {} + arc == Base::kNilArc ? Base::kNilArc : graph.OppositeArc(arc)), + graph_(graph) {} // We overwrite OppositeIncomingArcIterator::Index() here. ArcIndexType Index() const { - return this->index_ == Base::kNilArc - ? Base::kNilArc - : this->graph_.OppositeArc(this->index_); + return this->index_ == Base::kNilArc ? Base::kNilArc + : graph_.OppositeArc(this->index_); } DEFINE_STL_ITERATOR_FUNCTIONS(IncomingArcIterator); + + private: + const ReverseArcListGraph& graph_; }; template @@ -1723,15 +1722,11 @@ class ReverseArcListGraph::OutgoingHeadIterator { // ReverseArcStaticGraph implementation ---------------------------------------- -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, Outgoing, - DirectArcLimit(node)); -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, Incoming, - ReverseArcLimit(node)); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, Outgoing); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, Incoming); DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, - OutgoingOrOppositeIncoming, - DirectArcLimit(node)); -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, OppositeIncoming, - ReverseArcLimit(node)); + OutgoingOrOppositeIncoming); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, OppositeIncoming); template ArcIndexType ReverseArcStaticGraph::OutDegree( @@ -1858,11 +1853,12 @@ class ReverseArcStaticGraph::OutgoingArcIterator { : index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {} OutgoingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node, ArcIndexType arc) - : index_(arc), limit_(graph.DirectArcLimit(node)) { + : limit_(graph.DirectArcLimit(node)) { + index_ = arc == Base::kNilArc ? limit_ : arc; DCHECK_GE(arc, graph.start_[node]); } - bool Ok() const { return index_ < limit_; } + bool Ok() const { return index_ != limit_; } ArcIndexType Index() const { return index_; } void Next() { DCHECK(Ok()); @@ -1884,21 +1880,21 @@ class ReverseArcStaticGraph::IncomingArcIterator : public OppositeIncomingArcIterator { public: IncomingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node) - : OppositeIncomingArcIterator(graph, node) {} + : OppositeIncomingArcIterator(graph, node), graph_(graph) {} IncomingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node, ArcIndexType arc) : OppositeIncomingArcIterator(graph, node, - arc == graph.ReverseArcLimit(node) - ? graph.ReverseArcLimit(node) - : graph.OppositeArc(arc)) {} + arc == Base::kNilArc + ? Base::kNilArc + : (arc == graph.ReverseArcLimit(node) + ? graph.ReverseArcLimit(node) + : graph.OppositeArc(arc))), + graph_(graph) {} ArcIndexType Index() const { - return this->index_ == this->limit_ - ? this->limit_ - : this->graph_.OppositeArc(this->index_); + return this->index_ == this->limit_ ? this->limit_ + : graph_.OppositeArc(this->index_); } DEFINE_STL_ITERATOR_FUNCTIONS(IncomingArcIterator); + + private: + const ReverseArcStaticGraph& graph_; }; template @@ -1951,17 +1951,17 @@ class ReverseArcStaticGraph< } OutgoingOrOppositeIncomingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node, ArcIndexType arc) - : index_(arc), - first_limit_(graph.ReverseArcLimit(node)), + : first_limit_(graph.ReverseArcLimit(node)), next_start_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) { + index_ = arc == Base::kNilArc ? limit_ : arc; DCHECK(graph.IsNodeValid(node)); DCHECK((index_ >= graph.reverse_start_[node] && index_ < first_limit_) || (index_ >= next_start_)); } ArcIndexType Index() const { return index_; } - bool Ok() const { return index_ < limit_; } + bool Ok() const { return index_ != limit_; } void Next() { DCHECK(Ok()); index_++; @@ -1981,14 +1981,11 @@ class ReverseArcStaticGraph< // ReverseArcMixedGraph implementation ----------------------------------------- -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, Outgoing, - DirectArcLimit(node)); -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, Incoming, Base::kNilArc); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, Outgoing); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, Incoming); DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, - OutgoingOrOppositeIncoming, - DirectArcLimit(node)); -DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, OppositeIncoming, - Base::kNilArc); + OutgoingOrOppositeIncoming); +DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcMixedGraph, OppositeIncoming); template ArcIndexType ReverseArcMixedGraph::OutDegree( @@ -2100,11 +2097,12 @@ class ReverseArcMixedGraph::OutgoingArcIterator { : index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {} OutgoingArcIterator(const ReverseArcMixedGraph& graph, NodeIndexType node, ArcIndexType arc) - : index_(arc), limit_(graph.DirectArcLimit(node)) { + : limit_(graph.DirectArcLimit(node)) { + index_ = arc == Base::kNilArc ? limit_ : arc; DCHECK_GE(arc, graph.start_[node]); } - bool Ok() const { return index_ < limit_; } + bool Ok() const { return index_ != limit_; } ArcIndexType Index() const { return index_; } void Next() { DCHECK(Ok()); @@ -2162,7 +2160,8 @@ class ReverseArcMixedGraph::IncomingArcIterator IncomingArcIterator(const ReverseArcMixedGraph& graph, NodeIndexType node, ArcIndexType arc) : OppositeIncomingArcIterator( - graph, node, arc == Base::kNilArc ? arc : graph.OppositeArc(arc)) {} + graph, node, + arc == Base::kNilArc ? Base::kNilArc : graph.OppositeArc(arc)) {} ArcIndexType Index() const { return this->index_ == Base::kNilArc ? Base::kNilArc @@ -2190,14 +2189,11 @@ class ReverseArcMixedGraph< NodeIndexType node, ArcIndexType arc) : graph_(&graph) { limit_ = graph.DirectArcLimit(node); - index_ = arc; + index_ = arc == Base::kNilArc ? limit_ : arc; restart_ = graph.start_[node]; DCHECK(arc == Base::kNilArc || arc == limit_ || graph.Tail(arc) == node); } - bool Ok() const { - // Note that we always have limit_ <= Base::kNilArc. - return index_ < limit_; - } + bool Ok() const { return index_ != limit_; } ArcIndexType Index() const { return index_; } void Next() { DCHECK(Ok()); @@ -2234,7 +2230,10 @@ class CompleteGraph : public BaseGraph { public: // Builds a complete graph with num_nodes nodes. - explicit CompleteGraph(NodeIndexType num_nodes) { + explicit CompleteGraph(NodeIndexType num_nodes) + : // If there are 0 or 1 nodes, the divisor is arbitrary. We pick 2 as 0 + // and 1 are not supported by `ConstantDivisor`. + divisor_(num_nodes > 1 ? num_nodes : 2) { this->Reserve(num_nodes, num_nodes * num_nodes); this->FreezeCapacities(); num_nodes_ = num_nodes; @@ -2248,20 +2247,23 @@ class CompleteGraph : public BaseGraph { IntegerRange OutgoingArcsStartingFrom(NodeIndexType node, ArcIndexType from) const; IntegerRange operator[](NodeIndexType node) const; + + const ::util::math::ConstantDivisor> + divisor_; }; template NodeIndexType CompleteGraph::Head( ArcIndexType arc) const { DCHECK(this->IsArcValid(arc)); - return arc % num_nodes_; + return arc % divisor_; } template NodeIndexType CompleteGraph::Tail( ArcIndexType arc) const { DCHECK(this->IsArcValid(arc)); - return arc / num_nodes_; + return arc / divisor_; } template @@ -2316,7 +2318,12 @@ class CompleteBipartiteGraph // Indices of left nodes of the bipartite graph range from 0 to left_nodes-1; // indices of right nodes range from left_nodes to left_nodes+right_nodes-1. CompleteBipartiteGraph(NodeIndexType left_nodes, NodeIndexType right_nodes) - : left_nodes_(left_nodes), right_nodes_(right_nodes) { + : left_nodes_(left_nodes), + right_nodes_(right_nodes), + // If there are no right nodes, the divisor is arbitrary. We pick 2 as + // 0 and 1 are not supported by `ConstantDivisor`. We handle the case + // where `right_nodes` is 1 explicitly when dividing. + divisor_(right_nodes > 1 ? right_nodes : 2) { this->Reserve(left_nodes + right_nodes, left_nodes * right_nodes); this->FreezeCapacities(); num_nodes_ = left_nodes + right_nodes; @@ -2358,6 +2365,9 @@ class CompleteBipartiteGraph private: const NodeIndexType left_nodes_; const NodeIndexType right_nodes_; + // Note: only valid if `right_nodes_ > 1`. + const ::util::math::ConstantDivisor> + divisor_; }; template @@ -2374,14 +2384,16 @@ template NodeIndexType CompleteBipartiteGraph::Head( ArcIndexType arc) const { DCHECK(this->IsArcValid(arc)); - return left_nodes_ + arc % right_nodes_; + // See comment on `divisor_` in the constructor. + return right_nodes_ > 1 ? left_nodes_ + arc % divisor_ : left_nodes_; } template NodeIndexType CompleteBipartiteGraph::Tail( ArcIndexType arc) const { DCHECK(this->IsArcValid(arc)); - return arc / right_nodes_; + // See comment on `divisor_` in the constructor. + return right_nodes_ > 1 ? arc / divisor_ : arc; } template diff --git a/ortools/graph/graphs.h b/ortools/graph/graphs.h deleted file mode 100644 index cafe279a84..0000000000 --- a/ortools/graph/graphs.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2010-2025 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Temporary utility class needed as long as we have two slightly -// different graph interface: The one in ebert_graph.h and the one in graph.h - -#ifndef OR_TOOLS_GRAPH_GRAPHS_H_ -#define OR_TOOLS_GRAPH_GRAPHS_H_ - -#include - -#include "ortools/graph/ebert_graph.h" - -namespace operations_research { - -// Since StarGraph does not have exactly the same interface as the other -// graphs, we define a correspondence there. -template -struct Graphs { - typedef typename Graph::ArcIndex ArcIndex; - typedef typename Graph::NodeIndex NodeIndex; - static ArcIndex OppositeArc(const Graph& graph, ArcIndex arc) { - return graph.OppositeArc(arc); - } - static bool IsArcValid(const Graph& graph, ArcIndex arc) { - return graph.IsArcValid(arc); - } - static NodeIndex NodeReservation(const Graph& graph) { - return graph.node_capacity(); - } - static ArcIndex ArcReservation(const Graph& graph) { - return graph.arc_capacity(); - } - static void Build(Graph* graph) { graph->Build(); } - static void Build(Graph* graph, std::vector* permutation) { - graph->Build(permutation); - } -}; - -template <> -struct Graphs { - typedef operations_research::StarGraph Graph; -#if defined(_MSC_VER) - typedef Graph::ArcIndex ArcIndex; - typedef Graph::NodeIndex NodeIndex; -#else - typedef typename Graph::ArcIndex ArcIndex; - typedef typename Graph::NodeIndex NodeIndex; -#endif - static ArcIndex OppositeArc(const Graph& graph, ArcIndex arc) { - return graph.Opposite(arc); - } - static bool IsArcValid(const Graph& graph, ArcIndex arc) { - return graph.CheckArcValidity(arc); - } - static NodeIndex NodeReservation(const Graph& graph) { - return graph.max_num_nodes(); - } - static ArcIndex ArcReservation(const Graph& graph) { - return graph.max_num_arcs(); - } - static void Build(Graph* graph) {} - static void Build(Graph* graph, std::vector* permutation) { - permutation->clear(); - } -}; - -} // namespace operations_research - -#endif // OR_TOOLS_GRAPH_GRAPHS_H_ diff --git a/ortools/graph/java/graph.i b/ortools/graph/java/graph.i index dc3e12bed8..db096aa644 100644 --- a/ortools/graph/java/graph.i +++ b/ortools/graph/java/graph.i @@ -34,8 +34,6 @@ %include "ortools/base/base.i" -%import "ortools/graph/ebert_graph.h" - %{ #include "ortools/graph/assignment.h" #include "ortools/graph/max_flow.h" @@ -102,6 +100,7 @@ %unignore operations_research::SimpleMinCostFlow::~SimpleMinCostFlow; %rename (addArcWithCapacityAndUnitCost) operations_research::SimpleMinCostFlow::AddArcWithCapacityAndUnitCost; +%rename (setArcCapacity) operations_research::SimpleMinCostFlow::SetArcCapacity; %rename (setNodeSupply) operations_research::SimpleMinCostFlow::SetNodeSupply; %rename (solve) operations_research::SimpleMinCostFlow::Solve; %rename (solveMaxFlowWithMinCost) diff --git a/ortools/graph/k_shortest_paths.h b/ortools/graph/k_shortest_paths.h index c7a5842a64..cd61e2b002 100644 --- a/ortools/graph/k_shortest_paths.h +++ b/ortools/graph/k_shortest_paths.h @@ -70,7 +70,6 @@ #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/graph/bounded_dijkstra.h" -#include "ortools/graph/ebert_graph.h" #include "ortools/graph/shortest_paths.h" namespace operations_research { @@ -82,6 +81,7 @@ namespace operations_research { // The paths in `paths` start with `origin` and end at `destination`. // // If the computations are unsuccessful for any reason, the vectors are empty. +template struct KShortestPaths { // The paths are stored as vectors of nodes, like the other graph algorithms. // TODO(user): what about vectors of arcs? That might be faster @@ -89,7 +89,7 @@ struct KShortestPaths { // user really needs it). It would also have the nice benefit of removing the // need for `distances` (compute it on the fly), with a reference to the graph // and the costs. - std::vector> paths; + std::vector> paths; std::vector distances; }; @@ -113,10 +113,10 @@ struct KShortestPaths { // Science. 17 (11): 712–716, 1971. // https://doi.org/10.1287%2Fmnsc.17.11.712 template -KShortestPaths YenKShortestPaths(const GraphType& graph, - const std::vector& arc_lengths, - NodeIndex source, NodeIndex destination, - unsigned k); +KShortestPaths YenKShortestPaths( + const GraphType& graph, const std::vector& arc_lengths, + typename GraphType::NodeIndex source, + typename GraphType::NodeIndex destination, unsigned k); // End of the interface. Below is the implementation. @@ -137,23 +137,26 @@ const PathDistance kDisconnectedDistance = // In a multigraph, this function returns an index for one of the edges between // the source and the destination. template -ArcIndex FindArcIndex(const GraphType& graph, const NodeIndex source, - const NodeIndex destination) { +typename GraphType::ArcIndex FindArcIndex( + const GraphType& graph, const typename GraphType::NodeIndex source, + const typename GraphType::NodeIndex destination) { const auto outgoing_arcs_iter = graph.OutgoingArcs(source); - const auto arc = - std::find_if(outgoing_arcs_iter.begin(), outgoing_arcs_iter.end(), - [&graph, destination](const ArcIndex arc) { - return graph.Head(arc) == destination; - }); + const auto arc = std::find_if( + outgoing_arcs_iter.begin(), outgoing_arcs_iter.end(), + [&graph, destination](const typename GraphType::ArcIndex arc) { + return graph.Head(arc) == destination; + }); return (arc != outgoing_arcs_iter.end()) ? *arc : GraphType::kNilArc; } // Determines the shortest path from the given source and destination, returns a // tuple with the path (as a vector of node indices) and its cost. template -std::tuple, PathDistance> ComputeShortestPath( - const GraphType& graph, const std::vector& arc_lengths, - const NodeIndex source, const NodeIndex destination) { +std::tuple, PathDistance> +ComputeShortestPath(const GraphType& graph, + const std::vector& arc_lengths, + const typename GraphType::NodeIndex source, + const typename GraphType::NodeIndex destination) { BoundedDijkstraWrapper dijkstra(&graph, &arc_lengths); dijkstra.RunBoundedDijkstra(source, kMaxDistance); @@ -165,25 +168,29 @@ std::tuple, PathDistance> ComputeShortestPath( // This case only happens when some arcs have an infinite length (i.e. // larger than `kMaxDistance`): `BoundedDijkstraWrapper::NodePathTo` fails // to return a path, even empty. - return {std::vector{}, kDisconnectedDistance}; + return {std::vector{}, + kDisconnectedDistance}; } - if (std::vector path = std::move(dijkstra.NodePathTo(destination)); + if (std::vector path = + std::move(dijkstra.NodePathTo(destination)); !path.empty()) { return {std::move(path), path_length}; } else { - return {std::vector{}, kDisconnectedDistance}; + return {std::vector{}, + kDisconnectedDistance}; } } // Computes the total length of a path. template -PathDistance ComputePathLength(const GraphType& graph, - const absl::Span arc_lengths, - const absl::Span path) { +PathDistance ComputePathLength( + const GraphType& graph, const absl::Span arc_lengths, + const absl::Span path) { PathDistance distance = 0; - for (NodeIndex i = 0; i < path.size() - 1; ++i) { - const ArcIndex arc = internal::FindArcIndex(graph, path[i], path[i + 1]); + for (typename GraphType::NodeIndex i = 0; i < path.size() - 1; ++i) { + const typename GraphType::ArcIndex arc = + internal::FindArcIndex(graph, path[i], path[i + 1]); DCHECK_NE(arc, GraphType::kNilArc); distance += arc_lengths[arc]; } @@ -192,8 +199,11 @@ PathDistance ComputePathLength(const GraphType& graph, // Stores a path with a priority (typically, the distance), with a comparison // operator that operates on the priority. +template class PathWithPriority { public: + using NodeIndex = typename GraphType::NodeIndex; + PathWithPriority(PathDistance priority, std::vector path) : path_(std::move(path)), priority_(priority) {} bool operator<(const PathWithPriority& other) const { @@ -265,10 +275,12 @@ class UnderlyingContainerAdapter : public Container { // spur paths, the cheapest being: // S_1^2 = B - E - F - G - H template -KShortestPaths YenKShortestPaths(const GraphType& graph, - const std::vector& arc_lengths, - NodeIndex source, NodeIndex destination, - unsigned k) { +KShortestPaths YenKShortestPaths( + const GraphType& graph, const std::vector& arc_lengths, + typename GraphType::NodeIndex source, + typename GraphType::NodeIndex destination, unsigned k) { + using NodeIndex = typename GraphType::NodeIndex; + CHECK_GT(internal::kDisconnectedDistance, internal::kMaxDistance); CHECK_GE(k, 0) << "k must be nonnegative. Input value: " << k; @@ -289,7 +301,7 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, << destination << ". Number of nodes in the input graph: " << graph.num_nodes(); - KShortestPaths paths; + KShortestPaths paths; // First step: compute the shortest path. { @@ -306,7 +318,7 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, // Generate variant paths. internal::UnderlyingContainerAdapter< - std::priority_queue> + std::priority_queue>> variant_path_queue; // One path has already been generated (the shortest one). Only k-1 more @@ -364,7 +376,7 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, previous_path.begin() + root_path.length()); if (!has_same_prefix_as_root_path) continue; - const ArcIndex after_spur_node_arc = + const typename GraphType::ArcIndex after_spur_node_arc = internal::FindArcIndex(graph, previous_path[spur_node_position], previous_path[spur_node_position + 1]); VLOG(4) << " after_spur_node_arc: " << graph.Tail(after_spur_node_arc) @@ -417,8 +429,8 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, // coincide at the spur node). const bool root_path_leads_to_spur_path = absl::c_any_of( graph.OutgoingArcs(root_path.back()), - [&graph, node_after_spur_in_spur_path = - *(spur_path.begin() + 1)](const ArcIndex arc_index) { + [&graph, node_after_spur_in_spur_path = *(spur_path.begin() + 1)]( + const typename GraphType::ArcIndex arc_index) { return graph.Head(arc_index) == node_after_spur_in_spur_path; }); CHECK(root_path_leads_to_spur_path); @@ -471,12 +483,12 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, // filter by fingerprints? Due to the probability of error with // fingerprints, still use this slow-but-exact code, but after // filtering. - const bool is_new_path_already_known = - std::any_of(variant_path_queue.container().cbegin(), - variant_path_queue.container().cend(), - [&new_path](const internal::PathWithPriority& element) { - return element.path() == new_path; - }); + const bool is_new_path_already_known = std::any_of( + variant_path_queue.container().cbegin(), + variant_path_queue.container().cend(), + [&new_path](const internal::PathWithPriority& element) { + return element.path() == new_path; + }); if (is_new_path_already_known) continue; const PathDistance path_length = @@ -498,7 +510,7 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, // this iteration found no shorter one. if (variant_path_queue.empty()) break; - const internal::PathWithPriority& next_shortest_path = + const internal::PathWithPriority& next_shortest_path = variant_path_queue.top(); VLOG(5) << "> New path generated: " << absl::StrJoin(next_shortest_path.path(), " - ") << " (" diff --git a/ortools/graph/k_shortest_paths_test.cc b/ortools/graph/k_shortest_paths_test.cc index cacd65ccba..ed7bfe053a 100644 --- a/ortools/graph/k_shortest_paths_test.cc +++ b/ortools/graph/k_shortest_paths_test.cc @@ -128,8 +128,9 @@ TEST(KShortestPathsYenTest, ReducesToShortestPath) { (void)graph.Build(); std::vector lengths{1, 1}; - const KShortestPaths paths = YenKShortestPaths(graph, lengths, /*source=*/0, - /*destination=*/2, /*k=*/1); + const KShortestPaths> paths = + YenKShortestPaths(graph, lengths, /*source=*/0, + /*destination=*/2, /*k=*/1); EXPECT_THAT(paths.paths, ElementsAre(std::vector{0, 1, 2})); EXPECT_THAT(paths.distances, ElementsAre(2)); } @@ -141,8 +142,9 @@ TEST(KShortestPathsYenTest, OnlyHasOnePath) { (void)graph.Build(); std::vector lengths{1, 1}; - const KShortestPaths paths = YenKShortestPaths(graph, lengths, /*source=*/0, - /*destination=*/2, /*k=*/10); + const KShortestPaths> paths = + YenKShortestPaths(graph, lengths, /*source=*/0, + /*destination=*/2, /*k=*/10); EXPECT_THAT(paths.paths, ElementsAre(std::vector{0, 1, 2})); EXPECT_THAT(paths.distances, ElementsAre(2)); } @@ -155,8 +157,9 @@ TEST(KShortestPathsYenTest, HasTwoPaths) { (void)graph.Build(); std::vector lengths{1, 30, 1}; - const KShortestPaths paths = YenKShortestPaths(graph, lengths, /*source=*/0, - /*destination=*/2, /*k=*/10); + const KShortestPaths> paths = + YenKShortestPaths(graph, lengths, /*source=*/0, + /*destination=*/2, /*k=*/10); EXPECT_THAT(paths.paths, ElementsAre(std::vector{0, 1, 2}, std::vector{0, 2})); EXPECT_THAT(paths.distances, ElementsAre(2, 30)); @@ -172,8 +175,9 @@ TEST(KShortestPathsYenTest, HasTwoPathsWithLongerPath) { (void)graph.Build(); std::vector lengths{1, 30, 1, 1, 1}; - const KShortestPaths paths = YenKShortestPaths(graph, lengths, /*source=*/0, - /*destination=*/4, /*k=*/10); + const KShortestPaths> paths = + YenKShortestPaths(graph, lengths, /*source=*/0, + /*destination=*/4, /*k=*/10); EXPECT_THAT(paths.paths, ElementsAre(std::vector{0, 1, 2, 3, 4}, std::vector{0, 4})); EXPECT_THAT(paths.distances, ElementsAre(4, 30)); @@ -190,8 +194,9 @@ TEST(KShortestPathsYenTest, ReturnsTheRightNumberOfPaths) { (void)graph.Build(); std::vector lengths{1, 1, 1, 1, 1}; - const KShortestPaths paths = YenKShortestPaths(graph, lengths, /*source=*/0, - /*destination=*/2, /*k=*/2); + const KShortestPaths> paths = + YenKShortestPaths(graph, lengths, /*source=*/0, + /*destination=*/2, /*k=*/2); EXPECT_THAT(paths.paths, ElementsAre(std::vector{0, 2}, std::vector{0, 1, 2})); EXPECT_THAT(paths.distances, ElementsAre(1, 2)); diff --git a/ortools/graph/linear_assignment.h b/ortools/graph/linear_assignment.h index 868da44ed4..562f0a8069 100644 --- a/ortools/graph/linear_assignment.h +++ b/ortools/graph/linear_assignment.h @@ -209,6 +209,7 @@ #include "absl/strings/str_format.h" #include "ortools/base/logging.h" #include "ortools/graph/ebert_graph.h" +#include "ortools/graph/iterators.h" #include "ortools/util/permutation.h" #include "ortools/util/zvector.h" @@ -230,7 +231,7 @@ class LinearSumAssignment { // Constructor for the case in which we will build the graph // incrementally as we discover arc costs, as might be done with any - // of the dynamic graph representations such as StarGraph or ForwardStarGraph. + // of the dynamic graph representations such as `ReverseArcListGraph`. LinearSumAssignment(const GraphType& graph, NodeIndex num_left_nodes); // Constructor for the case in which the underlying graph cannot be built @@ -266,21 +267,6 @@ class LinearSumAssignment { operations_research::PermutationCycleHandler* ArcAnnotationCycleHandler(); - // Optimizes the layout of the graph for the access pattern our - // implementation will use. - // - // REQUIRES for LinearSumAssignment template instantiation if a call - // to the OptimizeGraphLayout() method is compiled: GraphType is a - // dynamic graph, i.e., one that implements the - // GroupForwardArcsByFunctor() member template method. - // - // If analogous optimization is needed for LinearSumAssignment - // instances based on static graphs, the graph layout should be - // constructed such that each node's outgoing arcs are sorted by - // head node index before the - // LinearSumAssignment::SetGraph() method is called. - void OptimizeGraphLayout(GraphType* graph); - // Allows tests, iterators, etc., to inspect our underlying graph. inline const GraphType& Graph() const { return *graph_; } @@ -360,24 +346,10 @@ class LinearSumAssignment { std::string StatsString() const { return total_stats_.StatsString(); } - class BipartiteLeftNodeIterator { - public: - BipartiteLeftNodeIterator(const GraphType& graph, NodeIndex num_left_nodes) - : num_left_nodes_(num_left_nodes), node_iterator_(0) {} - - explicit BipartiteLeftNodeIterator(const LinearSumAssignment& assignment) - : num_left_nodes_(assignment.NumLeftNodes()), node_iterator_(0) {} - - NodeIndex Index() const { return node_iterator_; } - - bool Ok() const { return node_iterator_ < num_left_nodes_; } - - void Next() { ++node_iterator_; } - - private: - const NodeIndex num_left_nodes_; - typename GraphType::NodeIndex node_iterator_; - }; + // Returns the range of valid left node indices. + ::util::IntegerRange BipartiteLeftNodes() const { + return ::util::IntegerRange(0, num_left_nodes_); + } // Returns true if and only if the current pseudoflow is // epsilon-optimal. To be used in a DCHECK. @@ -976,7 +948,7 @@ LinearSumAssignment::LinearSumAssignment( price_(num_left_nodes, 2 * num_left_nodes - 1), matched_arc_(num_left_nodes, 0), matched_node_(num_left_nodes, 2 * num_left_nodes - 1), - scaled_arc_cost_(graph.max_end_arc_index(), 0), + scaled_arc_cost_(graph.arc_capacity(), 0), active_nodes_(absl::GetFlag(FLAGS_assignment_stack_order) ? static_cast( new ActiveNodeStack()) @@ -1088,22 +1060,6 @@ LinearSumAssignment::ArcAnnotationCycleHandler() { &scaled_arc_cost_); } -template -void LinearSumAssignment::OptimizeGraphLayout( - GraphType* graph) { - // The graph argument is only to give us a non-const-qualified - // handle on the graph we already have. Any different graph is - // nonsense. - DCHECK_EQ(graph_, graph); - const ArcIndexOrderingByTailNode compare(*graph_); - CostValueCycleHandler cycle_handler( - &scaled_arc_cost_); - TailArrayManager tail_array_manager(graph); - tail_array_manager.BuildTailArrayFromAdjacencyListsIfForwardGraph(); - graph->GroupForwardArcsByFunctor(compare, &cycle_handler); - tail_array_manager.ReleaseTailArrayIfForwardGraph(); -} - template CostValue LinearSumAssignment::NewEpsilon( const CostValue current_epsilon) const { @@ -1149,9 +1105,7 @@ template void LinearSumAssignment::InitializeActiveNodeContainer() { DCHECK(active_nodes_->Empty()); - for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_); - node_it.Ok(); node_it.Next()) { - const NodeIndex node = node_it.Index(); + for (const NodeIndex node : BipartiteLeftNodes()) { if (IsActive(node)) { active_nodes_->Add(node); } @@ -1171,9 +1125,7 @@ void LinearSumAssignment void LinearSumAssignment::SaturateNegativeArcs() { total_excess_ = 0; - for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_); - node_it.Ok(); node_it.Next()) { - const NodeIndex node = node_it.Index(); + for (const NodeIndex node : BipartiteLeftNodes()) { if (IsActive(node)) { // This can happen in the first iteration when nothing is // matched yet. @@ -1358,9 +1310,7 @@ bool LinearSumAssignment::AllMatched() const { // Only for debugging. template bool LinearSumAssignment::EpsilonOptimal() const { - for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_); - node_it.Ok(); node_it.Next()) { - const NodeIndex left_node = node_it.Index(); + for (const NodeIndex left_node : BipartiteLeftNodes()) { // Get the implicit price of left_node and make sure the reduced // costs of left_node's incident arcs are in bounds. CostValue left_node_price = ImplicitPrice(left_node); @@ -1482,8 +1432,8 @@ CostValue LinearSumAssignment::GetCost() const { // an optimum assignment. DCHECK(success_); CostValue cost = 0; - for (BipartiteLeftNodeIterator node_it(*this); node_it.Ok(); node_it.Next()) { - cost += GetAssignmentCost(node_it.Index()); + for (const NodeIndex node : BipartiteLeftNodes()) { + cost += GetAssignmentCost(node); } return cost; } diff --git a/ortools/graph/linear_assignment_test.cc b/ortools/graph/linear_assignment_test.cc index e2eec2126d..66b5eb3d83 100644 --- a/ortools/graph/linear_assignment_test.cc +++ b/ortools/graph/linear_assignment_test.cc @@ -13,18 +13,20 @@ #include "ortools/graph/linear_assignment.h" +#include #include #include #include #include +#include "absl/flags/declare.h" +#include "absl/flags/flag.h" +#include "absl/log/check.h" #include "absl/random/distributions.h" #include "absl/types/span.h" #include "benchmark/benchmark.h" #include "gtest/gtest.h" -#include "ortools/base/commandlineflags.h" #include "ortools/base/gmock.h" -#include "ortools/graph/ebert_graph.h" #include "ortools/graph/graph.h" ABSL_DECLARE_FLAG(bool, assignment_stack_order); @@ -33,64 +35,55 @@ namespace operations_research { using ::testing::Eq; -template -static ArcIndex CreateArcWithCost( - NodeIndex tail, NodeIndex head, CostValue cost, - AnnotatedGraphBuildManager* builder, - LinearSumAssignment* assignment) { - const ArcIndex arc = builder->AddArc(tail, head); - assignment->SetArcCost(arc, cost); - return arc; -} - // A little package containing everything the AnnotatedGraphBuildManager-based // tests need to know about an assignment instance. template struct AssignmentProblemSetup { + using NodeIndex = typename GraphType::NodeIndex; + using ArcIndex = typename GraphType::ArcIndex; + using CostValue = int64_t; + // The usual constructor, for normal tests where the graph is balanced. - AssignmentProblemSetup(NodeIndex num_left_nodes, ArcIndex num_arcs, - bool optimize_layout) - : builder(new AnnotatedGraphBuildManager( - 2 * num_left_nodes, num_arcs, optimize_layout)), - assignment_scoped( - new LinearSumAssignment(num_left_nodes, num_arcs)), - assignment(assignment_scoped.get()), - cycle_handler_scoped(assignment_scoped->ArcAnnotationCycleHandler()), - cycle_handler(cycle_handler_scoped.get()) {} + AssignmentProblemSetup(NodeIndex num_left_nodes, ArcIndex num_arcs) + : num_left_nodes(num_left_nodes), + graph(2 * num_left_nodes, num_arcs), + arc_costs(num_arcs) {} // A constructor with separate specification of the numbers of left and right // nodes, so the tests can set up graphs where the assignment solution is sure // to fail. AssignmentProblemSetup(NodeIndex num_left_nodes, NodeIndex num_right_nodes, - ArcIndex num_arcs, bool optimize_layout) - : builder(new AnnotatedGraphBuildManager( - num_left_nodes + num_right_nodes, num_arcs, optimize_layout)), - assignment_scoped( - new LinearSumAssignment(num_left_nodes, num_arcs)), - assignment(assignment_scoped.get()), - cycle_handler_scoped(assignment_scoped->ArcAnnotationCycleHandler()), - cycle_handler(cycle_handler_scoped.get()) {} + ArcIndex num_arcs) + : num_left_nodes(num_left_nodes), + graph(num_left_nodes + num_right_nodes, num_arcs), + arc_costs(num_arcs) {} // This type is neither copyable nor movable. AssignmentProblemSetup(const AssignmentProblemSetup&) = delete; AssignmentProblemSetup& operator=(const AssignmentProblemSetup&) = delete; - virtual ~AssignmentProblemSetup() { delete &assignment->Graph(); } - - void Finalize() { - GraphType* graph = builder->Graph(cycle_handler); - ASSERT_TRUE(graph != nullptr); - assignment->SetGraph(graph); + ArcIndex CreateArcWithCost(NodeIndex tail, NodeIndex head, CostValue cost) { + const ArcIndex arc = graph.AddArc(tail, head); + arc_costs[arc] = cost; + return arc; } - AnnotatedGraphBuildManager* builder; + void Finalize() { + CHECK_EQ(graph.num_arcs(), arc_costs.size()); + graph.Build(&arc_permutation); + util::Permute(arc_permutation, &arc_costs); + assignment = std::make_unique>( + graph, num_left_nodes); + for (int arc = 0; arc < arc_costs.size(); ++arc) { + assignment->SetArcCost(arc, arc_costs[arc]); + } + } - std::unique_ptr> assignment_scoped; - LinearSumAssignment* assignment; // to avoid ".get()" everywhere - - std::unique_ptr> - cycle_handler_scoped; - PermutationCycleHandler* cycle_handler; + const int num_left_nodes; + GraphType graph; + std::vector arc_costs; + std::unique_ptr> assignment; + std::vector arc_permutation; }; // A fixture template to collect the types of graphs on which we want to base @@ -101,19 +94,16 @@ struct AssignmentProblemSetup { template class LinearSumAssignmentTestWithGraphBuilder : public ::testing::Test {}; -typedef ::testing::Types< - EbertGraph, ForwardEbertGraph, - EbertGraph, ForwardEbertGraph, - EbertGraph, ForwardEbertGraph, - StarGraph, ForwardStarGraph, util::ListGraph<>, util::ReverseArcListGraph<>> +typedef ::testing::Types, util::ReverseArcListGraph<>, + util::StaticGraph<>, util::ReverseArcStaticGraph<>> GraphTypesForAssignmentTestingWithGraphBuilder; TYPED_TEST_SUITE(LinearSumAssignmentTestWithGraphBuilder, GraphTypesForAssignmentTestingWithGraphBuilder); TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching0) { - AssignmentProblemSetup setup(1, 1, false); - CreateArcWithCost(0, 1, 0, setup.builder, setup.assignment); + AssignmentProblemSetup setup(1, 1); + setup.CreateArcWithCost(0, 1, 0); setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); EXPECT_EQ(0, setup.assignment->GetCost()); @@ -121,112 +111,112 @@ TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching0) { TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching1) { // A problem instance containing a node with no incident arcs. - AssignmentProblemSetup setup(2, 1, false); + AssignmentProblemSetup setup(2, 1); // We need the graph to include an arc that mentions the largest-indexed node // in order to get better test coverage. Without that node used in an arc, // infeasibility is detected very early because the number of nodes in the // graph isn't twice the stated number of left-side nodes. With that node // mentioned, the number of nodes in the graph alone is not enough to // establish infeasibility, so more code runs. - CreateArcWithCost(1, 3, 0, setup.builder, setup.assignment); + setup.CreateArcWithCost(1, 3, 0); setup.Finalize(); EXPECT_FALSE(setup.assignment->ComputeAssignment()); } TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching2) { - AssignmentProblemSetup setup(2, 4, false); - CreateArcWithCost(0, 2, 0, setup.builder, setup.assignment); - CreateArcWithCost(0, 3, 2, setup.builder, setup.assignment); - CreateArcWithCost(1, 2, 3, setup.builder, setup.assignment); - CreateArcWithCost(1, 3, 4, setup.builder, setup.assignment); + AssignmentProblemSetup setup(2, 4); + setup.CreateArcWithCost(0, 2, 0); + setup.CreateArcWithCost(0, 3, 2); + setup.CreateArcWithCost(1, 2, 3); + setup.CreateArcWithCost(1, 3, 4); setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); EXPECT_EQ(4, setup.assignment->GetCost()); } TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching3) { - AssignmentProblemSetup setup(4, 10, false); + AssignmentProblemSetup setup(4, 10); // Create arcs with tail nodes out of order to ensure that we test a case in // which the cost values must be nontrivially permuted if a static graph is // the underlying representation. - CreateArcWithCost(0, 5, 19, setup.builder, setup.assignment); - CreateArcWithCost(0, 6, 47, setup.builder, setup.assignment); - CreateArcWithCost(0, 7, 0, setup.builder, setup.assignment); - CreateArcWithCost(1, 4, 41, setup.builder, setup.assignment); - CreateArcWithCost(2, 4, 60, setup.builder, setup.assignment); - CreateArcWithCost(2, 5, 15, setup.builder, setup.assignment); - CreateArcWithCost(2, 7, 60, setup.builder, setup.assignment); - CreateArcWithCost(3, 4, 0, setup.builder, setup.assignment); - CreateArcWithCost(1, 6, 13, setup.builder, setup.assignment); - CreateArcWithCost(1, 7, 41, setup.builder, setup.assignment); + setup.CreateArcWithCost(0, 5, 19); + setup.CreateArcWithCost(0, 6, 47); + setup.CreateArcWithCost(0, 7, 0); + setup.CreateArcWithCost(1, 4, 41); + setup.CreateArcWithCost(2, 4, 60); + setup.CreateArcWithCost(2, 5, 15); + setup.CreateArcWithCost(2, 7, 60); + setup.CreateArcWithCost(3, 4, 0); + setup.CreateArcWithCost(1, 6, 13); + setup.CreateArcWithCost(1, 7, 41); setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); EXPECT_EQ(0 + 13 + 15 + 0, setup.assignment->GetCost()); } TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching4) { - AssignmentProblemSetup setup(4, 10, false); - CreateArcWithCost(0, 4, 0, setup.builder, setup.assignment); - CreateArcWithCost(1, 4, 60, setup.builder, setup.assignment); - CreateArcWithCost(1, 5, 15, setup.builder, setup.assignment); - CreateArcWithCost(1, 7, 60, setup.builder, setup.assignment); - CreateArcWithCost(2, 4, 41, setup.builder, setup.assignment); - CreateArcWithCost(2, 6, 13, setup.builder, setup.assignment); - CreateArcWithCost(2, 7, 41, setup.builder, setup.assignment); - CreateArcWithCost(3, 5, 19, setup.builder, setup.assignment); - CreateArcWithCost(3, 6, 47, setup.builder, setup.assignment); - CreateArcWithCost(3, 7, 0, setup.builder, setup.assignment); + AssignmentProblemSetup setup(4, 10); + setup.CreateArcWithCost(0, 4, 0); + setup.CreateArcWithCost(1, 4, 60); + setup.CreateArcWithCost(1, 5, 15); + setup.CreateArcWithCost(1, 7, 60); + setup.CreateArcWithCost(2, 4, 41); + setup.CreateArcWithCost(2, 6, 13); + setup.CreateArcWithCost(2, 7, 41); + setup.CreateArcWithCost(3, 5, 19); + setup.CreateArcWithCost(3, 6, 47); + setup.CreateArcWithCost(3, 7, 0); setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); EXPECT_EQ(0 + 13 + 15 + 0, setup.assignment->GetCost()); } TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching5) { - AssignmentProblemSetup setup(4, 10, false); - CreateArcWithCost(0, 4, 60, setup.builder, setup.assignment); - CreateArcWithCost(0, 5, 15, setup.builder, setup.assignment); - CreateArcWithCost(0, 7, 60, setup.builder, setup.assignment); - CreateArcWithCost(1, 4, 0, setup.builder, setup.assignment); - CreateArcWithCost(2, 4, 41, setup.builder, setup.assignment); - CreateArcWithCost(2, 6, 13, setup.builder, setup.assignment); - CreateArcWithCost(2, 7, 41, setup.builder, setup.assignment); - CreateArcWithCost(3, 5, 19, setup.builder, setup.assignment); - CreateArcWithCost(3, 6, 47, setup.builder, setup.assignment); - CreateArcWithCost(3, 7, 0, setup.builder, setup.assignment); + AssignmentProblemSetup setup(4, 10); + setup.CreateArcWithCost(0, 4, 60); + setup.CreateArcWithCost(0, 5, 15); + setup.CreateArcWithCost(0, 7, 60); + setup.CreateArcWithCost(1, 4, 0); + setup.CreateArcWithCost(2, 4, 41); + setup.CreateArcWithCost(2, 6, 13); + setup.CreateArcWithCost(2, 7, 41); + setup.CreateArcWithCost(3, 5, 19); + setup.CreateArcWithCost(3, 6, 47); + setup.CreateArcWithCost(3, 7, 0); setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); EXPECT_EQ(0 + 13 + 15 + 0, setup.assignment->GetCost()); } TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching6) { - AssignmentProblemSetup setup(4, 10, false); - CreateArcWithCost(0, 4, 41, setup.builder, setup.assignment); - CreateArcWithCost(0, 6, 13, setup.builder, setup.assignment); - CreateArcWithCost(0, 7, 41, setup.builder, setup.assignment); - CreateArcWithCost(1, 4, 60, setup.builder, setup.assignment); - CreateArcWithCost(1, 5, 15, setup.builder, setup.assignment); - CreateArcWithCost(1, 7, 60, setup.builder, setup.assignment); - CreateArcWithCost(2, 4, 0, setup.builder, setup.assignment); - CreateArcWithCost(3, 5, 19, setup.builder, setup.assignment); - CreateArcWithCost(3, 6, 47, setup.builder, setup.assignment); - CreateArcWithCost(3, 7, 0, setup.builder, setup.assignment); + AssignmentProblemSetup setup(4, 10); + setup.CreateArcWithCost(0, 4, 41); + setup.CreateArcWithCost(0, 6, 13); + setup.CreateArcWithCost(0, 7, 41); + setup.CreateArcWithCost(1, 4, 60); + setup.CreateArcWithCost(1, 5, 15); + setup.CreateArcWithCost(1, 7, 60); + setup.CreateArcWithCost(2, 4, 0); + setup.CreateArcWithCost(3, 5, 19); + setup.CreateArcWithCost(3, 6, 47); + setup.CreateArcWithCost(3, 7, 0); setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); EXPECT_EQ(0 + 13 + 15 + 0, setup.assignment->GetCost()); } TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, ZeroCostArcs) { - AssignmentProblemSetup setup(4, 10, false); - CreateArcWithCost(0, 4, 0, setup.builder, setup.assignment); - CreateArcWithCost(0, 6, 0, setup.builder, setup.assignment); - CreateArcWithCost(0, 7, 0, setup.builder, setup.assignment); - CreateArcWithCost(1, 4, 0, setup.builder, setup.assignment); - CreateArcWithCost(1, 5, 0, setup.builder, setup.assignment); - CreateArcWithCost(1, 7, 0, setup.builder, setup.assignment); - CreateArcWithCost(2, 4, 0, setup.builder, setup.assignment); - CreateArcWithCost(3, 5, 0, setup.builder, setup.assignment); - CreateArcWithCost(3, 6, 0, setup.builder, setup.assignment); - CreateArcWithCost(3, 7, 0, setup.builder, setup.assignment); + AssignmentProblemSetup setup(4, 10); + setup.CreateArcWithCost(0, 4, 0); + setup.CreateArcWithCost(0, 6, 0); + setup.CreateArcWithCost(0, 7, 0); + setup.CreateArcWithCost(1, 4, 0); + setup.CreateArcWithCost(1, 5, 0); + setup.CreateArcWithCost(1, 7, 0); + setup.CreateArcWithCost(2, 4, 0); + setup.CreateArcWithCost(3, 5, 0); + setup.CreateArcWithCost(3, 6, 0); + setup.CreateArcWithCost(3, 7, 0); setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); EXPECT_EQ(0, setup.assignment->GetCost()); @@ -235,13 +225,11 @@ TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, ZeroCostArcs) { // A helper function template for checking that we got the optimum assignment we // expected. template -static void VerifyAssignment(const LinearSumAssignment& a, - const NodeIndex expected_right_side[]) { - for (typename LinearSumAssignment::BipartiteLeftNodeIterator - node_it(a); - node_it.Ok(); node_it.Next()) { - const NodeIndex left_node = node_it.Index(); - const NodeIndex right_node = a.GetMate(left_node); +static void VerifyAssignment( + const LinearSumAssignment& a, + const typename GraphType::NodeIndex expected_right_side[]) { + for (const typename GraphType::NodeIndex left_node : a.BipartiteLeftNodes()) { + const typename GraphType::NodeIndex right_node = a.GetMate(left_node); EXPECT_EQ(expected_right_side[left_node], right_node); } } @@ -256,16 +244,15 @@ TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, OptimumMatching7) { {125, 95, 90, 105}, {45, 110, 95, 115}}; AssignmentProblemSetup setup(kMatrixHeight, - kMatrixHeight * kMatrixWidth, false); + kMatrixHeight * kMatrixWidth); for (int i = 0; i < kMatrixHeight; ++i) { for (int j = 0; j < kMatrixWidth; ++j) { - CreateArcWithCost(i, j + kMatrixHeight, kCost[i][j], - setup.builder, setup.assignment); + setup.CreateArcWithCost(i, j + kMatrixHeight, kCost[i][j]); } } setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); - const NodeIndex kExpectedAssignment[] = {7, 6, 5, 4}; + const typename TypeParam::NodeIndex kExpectedAssignment[] = {7, 6, 5, 4}; VerifyAssignment(*setup.assignment, kExpectedAssignment); EXPECT_EQ(80 + 55 + 95 + 45, setup.assignment->GetCost()); } @@ -284,13 +271,13 @@ TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, {-125, -95, -90, -105}, {-45, -110, -95, -115}}; AssignmentProblemSetup setup(kMatrixHeight, - kMatrixHeight * kMatrixWidth, false); + kMatrixHeight * kMatrixWidth); // Index of the arc we will remember and modify. - ArcIndex cost_100_arc = TypeParam::kNilArc; + typename TypeParam::ArcIndex cost_100_arc = TypeParam::kNilArc; for (int i = 0; i < kMatrixHeight; ++i) { for (int j = 0; j < kMatrixWidth; ++j) { - ArcIndex new_arc = CreateArcWithCost(i, j + kMatrixHeight, kCost[i][j], - setup.builder, setup.assignment); + typename TypeParam::ArcIndex new_arc = + setup.CreateArcWithCost(i, j + kMatrixHeight, kCost[i][j]); if (kCost[i][j] == 100) { cost_100_arc = new_arc; } @@ -298,7 +285,7 @@ TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, } setup.Finalize(); EXPECT_TRUE(setup.assignment->ComputeAssignment()); - const NodeIndex kExpectedAssignment1[] = {6, 7, 4, 5}; + const typename TypeParam::NodeIndex kExpectedAssignment1[] = {6, 7, 4, 5}; VerifyAssignment(*setup.assignment, kExpectedAssignment1); EXPECT_EQ(-75 + -65 + -125 + -110, setup.assignment->GetCost()); @@ -308,71 +295,72 @@ TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, ASSERT_EQ(100, setup.assignment->ArcCost(cost_100_arc)); setup.assignment->SetArcCost(cost_100_arc, -85); EXPECT_TRUE(setup.assignment->ComputeAssignment()); - const NodeIndex kExpectedAssignment2[] = {6, 5, 4, 7}; + const typename TypeParam::NodeIndex kExpectedAssignment2[] = {6, 5, 4, 7}; VerifyAssignment(*setup.assignment, kExpectedAssignment2); EXPECT_EQ(-75 + -85 + -125 + -115, setup.assignment->GetCost()); } TYPED_TEST(LinearSumAssignmentTestWithGraphBuilder, InfeasibleProblems) { // No arcs in the graph at all. - AssignmentProblemSetup setup0(1, 1, false); + AssignmentProblemSetup setup0(1, 0); setup0.Finalize(); EXPECT_FALSE(setup0.assignment->ComputeAssignment()); // Unbalanced graph: 4 nodes on the left, 2 on the right. - AssignmentProblemSetup setup1(4, 2, 4, false); - CreateArcWithCost(0, 4, 0, setup1.builder, setup1.assignment); - CreateArcWithCost(1, 4, 2, setup1.builder, setup1.assignment); - CreateArcWithCost(2, 5, 3, setup1.builder, setup1.assignment); - CreateArcWithCost(3, 5, 4, setup1.builder, setup1.assignment); + AssignmentProblemSetup setup1(4, 2, 4); + setup1.CreateArcWithCost(0, 4, 0); + setup1.CreateArcWithCost(1, 4, 2); + setup1.CreateArcWithCost(2, 5, 3); + setup1.CreateArcWithCost(3, 5, 4); setup1.Finalize(); EXPECT_FALSE(setup1.assignment->ComputeAssignment()); // Unbalanced graph: 2 nodes on the left, 4 on the right. - AssignmentProblemSetup setup2(2, 4, 4, false); - CreateArcWithCost(0, 2, 0, setup2.builder, setup2.assignment); - CreateArcWithCost(1, 3, 2, setup2.builder, setup2.assignment); - CreateArcWithCost(0, 4, 3, setup2.builder, setup2.assignment); - CreateArcWithCost(1, 5, 4, setup2.builder, setup2.assignment); + AssignmentProblemSetup setup2(2, 4, 4); + setup2.CreateArcWithCost(0, 2, 0); + setup2.CreateArcWithCost(1, 3, 2); + setup2.CreateArcWithCost(0, 4, 3); + setup2.CreateArcWithCost(1, 5, 4); setup2.Finalize(); EXPECT_FALSE(setup2.assignment->ComputeAssignment()); // Balanced graph with no perfect matching. - AssignmentProblemSetup setup3(3, 5, false); - CreateArcWithCost(0, 3, 0, setup3.builder, setup3.assignment); - CreateArcWithCost(1, 3, 2, setup3.builder, setup3.assignment); - CreateArcWithCost(2, 3, 3, setup3.builder, setup3.assignment); - CreateArcWithCost(0, 4, 4, setup3.builder, setup3.assignment); - CreateArcWithCost(0, 5, 5, setup3.builder, setup3.assignment); + AssignmentProblemSetup setup3(3, 5); + setup3.CreateArcWithCost(0, 3, 0); + setup3.CreateArcWithCost(1, 3, 2); + setup3.CreateArcWithCost(2, 3, 3); + setup3.CreateArcWithCost(0, 4, 4); + setup3.CreateArcWithCost(0, 5, 5); setup3.Finalize(); EXPECT_FALSE(setup3.assignment->ComputeAssignment()); // Another balanced graph with no perfect matching, but with plenty // of in/out degree for each node. - AssignmentProblemSetup setup4(5, 12, false); - CreateArcWithCost(0, 5, 0, setup4.builder, setup4.assignment); - CreateArcWithCost(0, 6, 2, setup4.builder, setup4.assignment); - CreateArcWithCost(1, 5, 3, setup4.builder, setup4.assignment); - CreateArcWithCost(1, 6, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(2, 5, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(2, 6, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(3, 7, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(3, 8, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(3, 9, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(4, 7, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(4, 8, 4, setup4.builder, setup4.assignment); - CreateArcWithCost(4, 9, 4, setup4.builder, setup4.assignment); + AssignmentProblemSetup setup4(5, 12); + setup4.CreateArcWithCost(0, 5, 0); + setup4.CreateArcWithCost(0, 6, 2); + setup4.CreateArcWithCost(1, 5, 3); + setup4.CreateArcWithCost(1, 6, 4); + setup4.CreateArcWithCost(2, 5, 4); + setup4.CreateArcWithCost(2, 6, 4); + setup4.CreateArcWithCost(3, 7, 4); + setup4.CreateArcWithCost(3, 8, 4); + setup4.CreateArcWithCost(3, 9, 4); + setup4.CreateArcWithCost(4, 7, 4); + setup4.CreateArcWithCost(4, 8, 4); + setup4.CreateArcWithCost(4, 9, 4); setup4.Finalize(); EXPECT_FALSE(setup4.assignment->ComputeAssignment()); } // A helper function template for setting up assignment problems based on -// dynamic graph types without an interposed GraphBuildManager. -template -static ArcIndex CreateArcWithCost( - NodeIndex tail, NodeIndex head, CostValue cost, DynamicGraphType* graph, - LinearSumAssignment* assignment) { - ArcIndex arc = graph->AddArc(tail, head); +// dynamic graph types without a need to `Build()`. +template +static typename GraphType::ArcIndex CreateArcWithCost( + typename GraphType::NodeIndex tail, typename GraphType::NodeIndex head, + int64_t cost, GraphType* graph, + LinearSumAssignment* assignment) { + typename GraphType::ArcIndex arc = graph->AddArc(tail, head); assignment->SetArcCost(arc, cost); return arc; } @@ -384,56 +372,19 @@ static ArcIndex CreateArcWithCost( template class LinearSumAssignmentTestWithDynamicGraph : public ::testing::Test {}; -typedef ::testing::Types< - EbertGraph, ForwardEbertGraph, - EbertGraph, ForwardEbertGraph, - EbertGraph, ForwardEbertGraph> +typedef ::testing::Types<::util::ListGraph<>> DynamicGraphTypesForAssignmentTesting; TYPED_TEST_SUITE(LinearSumAssignmentTestWithDynamicGraph, DynamicGraphTypesForAssignmentTesting); -TYPED_TEST(LinearSumAssignmentTestWithDynamicGraph, GraphLayoutTest) { - // A complete bipartite 3x3 graph (9 edges). - TypeParam g(6, 9); - LinearSumAssignment a(g, 3); - // We add arcs in a higgledy-piggledy order, with costs that indicate the - // order the arcs should have after the layout is optimized. - CreateArcWithCost(0, 3, 1, &g, &a); // in cycle [0] - CreateArcWithCost(2, 5, 9, &g, &a); // in cycle [1 8 3] - CreateArcWithCost(1, 5, 6, &g, &a); // in cycle [2 5] - CreateArcWithCost(0, 4, 2, &g, &a); // in cycle [1 8 3] - CreateArcWithCost(1, 4, 5, &g, &a); // in cycle [4] - CreateArcWithCost(0, 5, 3, &g, &a); // in cycle [2 5] - CreateArcWithCost(2, 4, 8, &g, &a); // in cycle [6 7] - CreateArcWithCost(2, 3, 7, &g, &a); // in cycle [6 7] - CreateArcWithCost(1, 3, 4, &g, &a); // in cycle [1 8 3] - - EXPECT_TRUE(a.ComputeAssignment()); - EXPECT_EQ(1 + 5 + 9, a.GetCost()); - - a.OptimizeGraphLayout(&g); - EXPECT_TRUE(a.ComputeAssignment()); - EXPECT_EQ(1 + 5 + 9, a.GetCost()); - // The optimized graph layout is supposed to group arcs by their tail nodes - // and sequence them within each group by their head nodes. - TailArrayManager tail_array_manager(&g); - tail_array_manager.BuildTailArrayFromAdjacencyListsIfForwardGraph(); - for (int i = 0; i < 9; ++i) { - EXPECT_EQ(i + 1, a.ArcCost(i)); - EXPECT_EQ(i / 3, g.Tail(i)); - EXPECT_EQ(3 + i % 3, g.Head(i)); - } - tail_array_manager.ReleaseTailArrayIfForwardGraph(); -} - // The EpsilonOptimal test and the PrecisionWarning test cannot be parameterized // by the type of the underlying graph because doing so is not supported by the // FRIEND_TEST() macro used in the LinearSumAssignment class template to grant // these tests access to private methods of LinearSumAssignment. TEST(LinearSumAssignmentFriendTest, EpsilonOptimal) { - StarGraph g(4, 4); - LinearSumAssignment a(g, 2); + ::util::ListGraph<> g(4, 4); + LinearSumAssignment<::util::ListGraph<>> a(g, 2); CreateArcWithCost(0, 2, 0, &g, &a); CreateArcWithCost(0, 3, 2, &g, &a); CreateArcWithCost(1, 2, 3, &g, &a); @@ -449,11 +400,11 @@ TEST(LinearSumAssignmentFriendTest, EpsilonOptimal) { #if LARGE TEST(LinearSumAssignmentPrecisionTest, PrecisionWarning) { const NodeIndex kNumLeftNodes = 10000000; - ForwardStarGraph g(2 * kNumLeftNodes, 2 * kNumLeftNodes); - LinearSumAssignment a(g, kNumLeftNodes); + util::ListGraph<> g(2 * kNumLeftNodes, 2 * kNumLeftNodes); + LinearSumAssignment> a(g, kNumLeftNodes); int64_t node_count = 0; - for (NodeIndex left_node = ForwardStarGraph::kFirstNode; - node_count < kNumLeftNodes; ++node_count, ++left_node) { + for (NodeIndex left_node = 0; node_count < kNumLeftNodes; + ++node_count, ++left_node) { CreateArcWithCost(left_node, kNumLeftNodes + left_node, kNumLeftNodes, &g, &a); } @@ -465,7 +416,9 @@ TEST(LinearSumAssignmentPrecisionTest, PrecisionWarning) { #endif // LARGE class MacholWien - : public ::testing::TestWithParam<::testing::tuple> {}; + : public ::testing::TestWithParam< + ::testing::tuple<::util::CompleteBipartiteGraph<>::NodeIndex, bool>> { +}; // The following test computes assignments on the instances described in: // Robert E. Machol and Michael Wien, "Errata: A Hard Assignment Problem", @@ -478,22 +431,19 @@ class MacholWien // list flag. TEST_P(MacholWien, SolveHardProblem) { using Graph = ::util::CompleteBipartiteGraph<>; - NodeIndex n = ::testing::get<0>(GetParam()); + Graph::NodeIndex n = ::testing::get<0>(GetParam()); absl::SetFlag(&FLAGS_assignment_stack_order, ::testing::get<1>(GetParam())); Graph graph(n, n); LinearSumAssignment assignment(graph, n); - for (NodeIndex i = 0; i < n; ++i) { - for (NodeIndex j = 0; j < n; ++j) { - const ArcIndex arc = graph.GetArc(i, n + j); + for (Graph::NodeIndex i = 0; i < n; ++i) { + for (Graph::NodeIndex j = 0; j < n; ++j) { + const Graph::ArcIndex arc = graph.GetArc(i, n + j); assignment.SetArcCost(arc, i * j); } } EXPECT_TRUE(assignment.ComputeAssignment()); - for (LinearSumAssignment::BipartiteLeftNodeIterator node_it( - assignment); - node_it.Ok(); node_it.Next()) { - const NodeIndex left_node = node_it.Index(); - const NodeIndex right_node = assignment.GetMate(left_node); + for (const Graph::NodeIndex left_node : assignment.BipartiteLeftNodes()) { + const Graph::NodeIndex right_node = assignment.GetMate(left_node); EXPECT_EQ(2 * n - 1, left_node + right_node); } } @@ -516,48 +466,23 @@ INSTANTIATE_TEST_CASE_P(MacholWienProblems, MacholWien, ::testing::Bool())); #endif // LARGE -// Helper function for random-assignment benchmarks. -template -void ConstructRandomAssignment( - const int left_nodes, const int average_degree, const int cost_limit, - std::unique_ptr* graph, - std::unique_ptr>* assignment) { - const int kNodes = 2 * left_nodes; - const int kArcs = left_nodes * average_degree; - const int kRandomSeed = 0; - std::mt19937 randomizer(kRandomSeed); - AnnotatedGraphBuildManager* builder = - new AnnotatedGraphBuildManager(kNodes, kArcs, optimize_layout); - assignment->reset(new LinearSumAssignment(left_nodes, kArcs)); - for (int i = 0; i < kArcs; ++i) { - const int left = absl::Uniform(randomizer, 0, left_nodes); - const int right = left_nodes + absl::Uniform(randomizer, 0, left_nodes); - const CostValue cost = absl::Uniform(randomizer, 0, cost_limit); - CreateArcWithCost(left, right, cost, builder, assignment->get()); - } - std::unique_ptr> cycle_handler( - assignment->get()->ArcAnnotationCycleHandler()); - graph->reset(builder->Graph(cycle_handler.get())); - assignment->get()->SetGraph(graph->get()); -} - // Same as ConstructRandomAssignment, but for the new API. template void ConstructRandomAssignmentForNewGraphApi( const int left_nodes, const int average_degree, const int cost_limit, std::unique_ptr* graph, - std::unique_ptr>* assignment) { + std::unique_ptr>* assignment) { const int kNodes = 2 * left_nodes; const int kArcs = left_nodes * average_degree; const int kRandomSeed = 0; std::mt19937 randomizer(kRandomSeed); - std::vector arc_costs; + std::vector arc_costs; arc_costs.reserve(kArcs); graph->reset(new GraphType(kNodes, kArcs)); for (int i = 0; i < kArcs; ++i) { const int left = absl::Uniform(randomizer, 0, left_nodes); const int right = left_nodes + absl::Uniform(randomizer, 0, left_nodes); - const CostValue cost = absl::Uniform(randomizer, 0, cost_limit); + const int64_t cost = absl::Uniform(randomizer, 0, cost_limit); graph->get()->AddArc(left, right); arc_costs.push_back(cost); } @@ -569,44 +494,21 @@ void ConstructRandomAssignmentForNewGraphApi( // Create the assignment. assignment->reset( - new LinearSumAssignment(*(graph->get()), left_nodes)); + new LinearSumAssignment(*(graph->get()), left_nodes)); for (int arc = 0; arc < kArcs; ++arc) { assignment->get()->SetArcCost(arc, arc_costs[arc]); } } -// Benchmark function for assignment-problem construction only, no solution. -template -void BM_ConstructRandomAssignmentProblem(benchmark::State& state) { - const int kLeftNodes = 10000; - const int kAverageDegree = 250; - const CostValue kCostLimit = 1000000; - for (auto _ : state) { - std::unique_ptr graph; - std::unique_ptr> assignment; - ConstructRandomAssignment( - kLeftNodes, kAverageDegree, kCostLimit, &graph, &assignment); - } - state.SetItemsProcessed(static_cast(state.max_iterations) * - kLeftNodes * kAverageDegree); -} - -BENCHMARK_TEMPLATE2(BM_ConstructRandomAssignmentProblem, StarGraph, false); -BENCHMARK_TEMPLATE2(BM_ConstructRandomAssignmentProblem, ForwardStarGraph, - false); -BENCHMARK_TEMPLATE2(BM_ConstructRandomAssignmentProblem, StarGraph, true); -BENCHMARK_TEMPLATE2(BM_ConstructRandomAssignmentProblem, ForwardStarGraph, - true); - template void BM_ConstructRandomAssignmentProblemWithNewGraphApi( benchmark::State& state) { const int kLeftNodes = 10000; const int kAverageDegree = 250; - const CostValue kCostLimit = 1000000; + const int64_t kCostLimit = 1000000; for (auto _ : state) { std::unique_ptr graph; - std::unique_ptr> assignment; + std::unique_ptr> assignment; ConstructRandomAssignmentForNewGraphApi( kLeftNodes, kAverageDegree, kCostLimit, &graph, &assignment); } @@ -619,42 +521,18 @@ BENCHMARK_TEMPLATE(BM_ConstructRandomAssignmentProblemWithNewGraphApi, BENCHMARK_TEMPLATE(BM_ConstructRandomAssignmentProblemWithNewGraphApi, util::StaticGraph<>); -// Benchmark function for assignment-problem solution only, with -// problem-construction timing excluded. -template -void BM_SolveRandomAssignmentProblem(benchmark::State& state) { - const int kLeftNodes = 10000; - const int kAverageDegree = 250; - const CostValue kCostLimit = 1000000; - std::unique_ptr graph; - std::unique_ptr> assignment; - ConstructRandomAssignment( - kLeftNodes, kAverageDegree, kCostLimit, &graph, &assignment); - for (auto _ : state) { - assignment->ComputeAssignment(); - EXPECT_EQ(65849286, assignment->GetCost()); - } - state.SetItemsProcessed(static_cast(state.max_iterations) * - kLeftNodes * kAverageDegree); -} - -BENCHMARK_TEMPLATE2(BM_SolveRandomAssignmentProblem, StarGraph, false); -BENCHMARK_TEMPLATE2(BM_SolveRandomAssignmentProblem, ForwardStarGraph, false); -BENCHMARK_TEMPLATE2(BM_SolveRandomAssignmentProblem, StarGraph, true); -BENCHMARK_TEMPLATE2(BM_SolveRandomAssignmentProblem, ForwardStarGraph, true); - template void BM_SolveRandomAssignmentProblemWithNewGraphApi(benchmark::State& state) { const int kLeftNodes = 10000; const int kAverageDegree = 250; - const CostValue kCostLimit = 1000000; + const int64_t kCostLimit = 1000000; std::unique_ptr graph; - std::unique_ptr> assignment; + std::unique_ptr> assignment; ConstructRandomAssignmentForNewGraphApi( kLeftNodes, kAverageDegree, kCostLimit, &graph, &assignment); for (auto _ : state) { assignment->ComputeAssignment(); - EXPECT_EQ(65849286, assignment->GetCost()); + EXPECT_EQ(65415697, assignment->GetCost()); } state.SetItemsProcessed(static_cast(state.max_iterations) * kLeftNodes * kAverageDegree); @@ -665,47 +543,19 @@ BENCHMARK_TEMPLATE(BM_SolveRandomAssignmentProblemWithNewGraphApi, BENCHMARK_TEMPLATE(BM_SolveRandomAssignmentProblemWithNewGraphApi, util::StaticGraph<>); -// Benchmark function for assignment-problem construction and solution, with -// problem-construction timing included. -template -void BM_ConstructAndSolveRandomAssignmentProblem(benchmark::State& state) { - const int kLeftNodes = 10000; - const int kAverageDegree = 250; - const CostValue kCostLimit = 1000000; - for (auto _ : state) { - std::unique_ptr graph; - std::unique_ptr> assignment; - ConstructRandomAssignment( - kLeftNodes, kAverageDegree, kCostLimit, &graph, &assignment); - assignment->ComputeAssignment(); - EXPECT_EQ(65849286, assignment->GetCost()); - } - state.SetItemsProcessed(static_cast(state.max_iterations) * - kLeftNodes * kAverageDegree); -} - -BENCHMARK_TEMPLATE2(BM_ConstructAndSolveRandomAssignmentProblem, StarGraph, - false); -BENCHMARK_TEMPLATE2(BM_ConstructAndSolveRandomAssignmentProblem, - ForwardStarGraph, false); -BENCHMARK_TEMPLATE2(BM_ConstructAndSolveRandomAssignmentProblem, StarGraph, - true); -BENCHMARK_TEMPLATE2(BM_ConstructAndSolveRandomAssignmentProblem, - ForwardStarGraph, true); - template void BM_ConstructAndSolveRandomAssignmentProblemWithNewGraphApi( benchmark::State& state) { const int kLeftNodes = 10000; const int kAverageDegree = 250; - const CostValue kCostLimit = 1000000; + const int64_t kCostLimit = 1000000; for (auto _ : state) { std::unique_ptr graph; - std::unique_ptr> assignment; + std::unique_ptr> assignment; ConstructRandomAssignmentForNewGraphApi( kLeftNodes, kAverageDegree, kCostLimit, &graph, &assignment); assignment->ComputeAssignment(); - EXPECT_EQ(65849286, assignment->GetCost()); + EXPECT_EQ(65415697, assignment->GetCost()); } state.SetItemsProcessed(static_cast(state.max_iterations) * kLeftNodes * kAverageDegree); diff --git a/ortools/graph/max_flow_test.cc b/ortools/graph/max_flow_test.cc index e0b6e2d382..ed6c7ca854 100644 --- a/ortools/graph/max_flow_test.cc +++ b/ortools/graph/max_flow_test.cc @@ -19,7 +19,6 @@ #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/base/path.h" -#include "ortools/graph/ebert_graph.h" #include "ortools/graph/flow_problem.pb.h" #include "ortools/util/file_util.h" diff --git a/ortools/graph/min_cost_flow.cc b/ortools/graph/min_cost_flow.cc index 9e64847e09..8e09a8e3bc 100644 --- a/ortools/graph/min_cost_flow.cc +++ b/ortools/graph/min_cost_flow.cc @@ -30,7 +30,6 @@ #include "ortools/base/mathutil.h" #include "ortools/graph/generic_max_flow.h" #include "ortools/graph/graph.h" -#include "ortools/graph/graphs.h" #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/stats.h" @@ -54,14 +53,14 @@ GenericMinCostFlow::GenericMinCostFlow( alpha_(absl::GetFlag(FLAGS_min_cost_flow_alpha)), stats_("MinCostFlow"), check_feasibility_(absl::GetFlag(FLAGS_min_cost_flow_check_feasibility)) { - const NodeIndex max_num_nodes = Graphs::NodeReservation(*graph_); + const NodeIndex max_num_nodes = graph_->node_capacity(); if (max_num_nodes > 0) { first_admissible_arc_.assign(max_num_nodes, Graph::kNilArc); node_potential_.assign(max_num_nodes, 0); node_excess_.assign(max_num_nodes, 0); initial_node_excess_.assign(max_num_nodes, 0); } - const ArcIndex max_num_arcs = Graphs::ArcReservation(*graph_); + const ArcIndex max_num_arcs = graph_->arc_capacity(); if (max_num_arcs > 0) { residual_arc_capacity_.Reserve(-max_num_arcs, max_num_arcs - 1); residual_arc_capacity_.SetAll(0); @@ -309,7 +308,6 @@ bool GenericMinCostFlownum_nodes(); ++node) { @@ -976,13 +974,13 @@ template typename Graph::ArcIndex GenericMinCostFlow::Opposite( ArcIndex arc) const { - return Graphs::OppositeArc(*graph_, arc); + return graph_->OppositeArc(arc); } template bool GenericMinCostFlow::IsArcValid( ArcIndex arc) const { - return Graphs::IsArcValid(*graph_, arc); + return graph_->IsArcValid(arc); } template @@ -996,7 +994,6 @@ bool GenericMinCostFlow::IsArcDirect( // // TODO(user): Move this code out of a .cc file and include it at the end of // the header so it can work with any graph implementation? -template class GenericMinCostFlow; template class GenericMinCostFlow<::util::ReverseArcListGraph<>>; template class GenericMinCostFlow<::util::ReverseArcStaticGraph<>>; template class GenericMinCostFlow<::util::ReverseArcMixedGraph<>>; @@ -1043,6 +1040,10 @@ SimpleMinCostFlow::ArcIndex SimpleMinCostFlow::AddArcWithCapacityAndUnitCost( return arc; } +void SimpleMinCostFlow::SetArcCapacity(ArcIndex arc, FlowQuantity capacity) { + arc_capacity_[arc] = capacity; +} + SimpleMinCostFlow::ArcIndex SimpleMinCostFlow::PermutedArc(ArcIndex arc) { return arc < arc_permutation_.size() ? arc_permutation_[arc] : arc; } diff --git a/ortools/graph/min_cost_flow.h b/ortools/graph/min_cost_flow.h index 187a6cf404..90964c7893 100644 --- a/ortools/graph/min_cost_flow.h +++ b/ortools/graph/min_cost_flow.h @@ -174,7 +174,7 @@ #include #include "absl/strings/string_view.h" -#include "ortools/graph/ebert_graph.h" +#include "ortools/base/logging.h" #include "ortools/graph/graph.h" #include "ortools/util/stats.h" #include "ortools/util/zvector.h" @@ -287,6 +287,11 @@ class SimpleMinCostFlow : public MinCostFlowBase { FlowQuantity capacity, CostValue unit_cost); + // Modifies the capacity of the given arc. The arc index must be non-negative + // (>= 0); it must be an index returned by a previous call to + // AddArcWithCapacityAndUnitCost(). + void SetArcCapacity(ArcIndex arc, FlowQuantity capacity); + // Sets the supply of the given node. The node index must be non-negative (>= // 0). Nodes implicitly created will have a default supply set to 0. A demand // is modeled as a negative supply. @@ -378,9 +383,9 @@ class SimpleMinCostFlow : public MinCostFlowBase { bool scale_prices_ = true; }; -// Generic MinCostFlow that works with StarGraph and all the graphs handling -// reverse arcs from graph.h, see the end of min_cost_flow.cc for the exact -// types this class is compiled for. +// Generic MinCostFlow that works with all the graphs handling reverse arcs from +// graph.h, see the end of min_cost_flow.cc for the exact types this class is +// compiled for. // // One can greatly decrease memory usage by using appropriately small integer // types: @@ -670,7 +675,6 @@ class GenericMinCostFlow : public MinCostFlowBase { // Note: SWIG does not seem to understand explicit template specialization and // instantiation declarations. -extern template class GenericMinCostFlow; extern template class GenericMinCostFlow<::util::ReverseArcListGraph<>>; extern template class GenericMinCostFlow<::util::ReverseArcStaticGraph<>>; extern template class GenericMinCostFlow<::util::ReverseArcMixedGraph<>>; @@ -683,11 +687,14 @@ extern template class GenericMinCostFlow< /*ArcFlowType=*/int16_t, /*ArcScaledCostType=*/int32_t>; -// Default MinCostFlow instance that uses StarGraph. -// New clients should use SimpleMinCostFlow if they can. -class MinCostFlow : public GenericMinCostFlow { - public: - explicit MinCostFlow(const StarGraph* graph) : GenericMinCostFlow(graph) {} +// TODO(b/385094969): Remove this alias after 2025-07-01 to give or-tools users +// a grace period. +struct MinCostFlow : public MinCostFlowBase { + template + MinCostFlow() { + LOG(FATAL) << "MinCostFlow is deprecated. Use `SimpleMinCostFlow` or " + "`GenericMinCostFlow` with a specific graph type instead."; + } }; #endif // SWIG diff --git a/ortools/graph/min_cost_flow_test.cc b/ortools/graph/min_cost_flow_test.cc index c9a3a69f33..f4b7d39e3e 100644 --- a/ortools/graph/min_cost_flow_test.cc +++ b/ortools/graph/min_cost_flow_test.cc @@ -30,7 +30,6 @@ #include "ortools/base/logging.h" #include "ortools/graph/ebert_graph.h" #include "ortools/graph/graph.h" -#include "ortools/graph/graphs.h" #include "ortools/linear_solver/linear_solver.h" namespace operations_research { @@ -116,7 +115,7 @@ void GenericMinCostFlowTester( graph.AddArc(tail[arc], head[arc]); } std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); EXPECT_TRUE(permutation.empty()); GenericMinCostFlow min_cost_flow(&graph); @@ -151,7 +150,7 @@ void GenericMinCostFlowTester( template class GenericMinCostFlowTest : public ::testing::Test {}; -typedef ::testing::Types, +typedef ::testing::Types, util::ReverseArcStaticGraph<>, util::ReverseArcMixedGraph<>> GraphTypes; @@ -248,7 +247,7 @@ TYPED_TEST(GenericMinCostFlowTest, Small4x4Matrix) { } } std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); EXPECT_TRUE(permutation.empty()); GenericMinCostFlow min_cost_flow(&graph); @@ -642,7 +641,7 @@ void FullRandomAssignment(typename MinCostFlowSolver::Solver f, Graph graph; GenerateCompleteGraph(num_sources, num_targets, &graph); std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); std::vector supply; GenerateAssignmentSupply(num_sources, num_targets, &supply); @@ -668,7 +667,7 @@ void PartialRandomAssignment(typename MinCostFlowSolver::Solver f, Graph graph; GeneratePartialRandomGraph(num_sources, num_targets, kDegree, &graph); std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); std::vector supply; GenerateAssignmentSupply(num_sources, num_targets, &supply); @@ -718,7 +717,7 @@ void PartialRandomFlow(typename MinCostFlowSolver::Solver f, Graph graph; GeneratePartialRandomGraph(num_sources, num_targets, kDegree, &graph); std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); std::vector supply; GenerateRandomSupply(num_sources, num_targets, kSupplyGens, kSupplyRange, @@ -756,7 +755,7 @@ void FullRandomFlow(typename MinCostFlowSolver::Solver f, Graph graph; GenerateCompleteGraph(num_sources, num_targets, &graph); std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); std::vector supply; GenerateRandomSupply(num_sources, num_targets, kSupplyGens, kSupplyRange, @@ -787,16 +786,16 @@ void FullRandomFlow(typename MinCostFlowSolver::Solver f, FLOW_ONLY_TEST(test_name, size, expected_cost1, expected_cost2) \ FLOW_ONLY_TEST_SG(test_name, size, expected_cost1, expected_cost2) -#define LP_ONLY_TEST(test_name, size, expected_cost1, expected_cost2) \ - TEST(LPMinCostFlowTest, test_name##size) { \ - test_name(SolveMinCostFlowWithLP, size, size, expected_cost1, \ - expected_cost2); \ +#define LP_ONLY_TEST(test_name, size, expected_cost1, expected_cost2) \ + TEST(LPMinCostFlowTest, test_name##size) { \ + test_name>(SolveMinCostFlowWithLP, size, size, \ + expected_cost1, expected_cost2); \ } -#define FLOW_ONLY_TEST(test_name, size, expected_cost1, expected_cost2) \ - TEST(MinCostFlowTest, test_name##size) { \ - test_name(SolveMinCostFlow, size, size, expected_cost1, \ - expected_cost2); \ +#define FLOW_ONLY_TEST(test_name, size, expected_cost1, expected_cost2) \ + TEST(MinCostFlowTest, test_name##size) { \ + test_name>(SolveMinCostFlow, size, size, \ + expected_cost1, expected_cost2); \ } #define FLOW_ONLY_TEST_SG(test_name, size, expected_cost1, expected_cost2) \ @@ -930,7 +929,7 @@ void BM_MinCostFlowOnMultiMatchingProblem(benchmark::State& state) { graph.AddArc(/*tail=*/kNumChannels + 1 + j, /*head=*/0); } std::vector permutation; - Graphs::Build(&graph, &permutation); + graph.Build(&permutation); // To spare memory, we added arcs in the right order, so that no permutation // is needed. See graph.h. CHECK(permutation.empty()); @@ -974,14 +973,11 @@ void BM_MinCostFlowOnMultiMatchingProblem(benchmark::State& state) { BENCHMARK(BM_MinCostFlowOnMultiMatchingProblem< util::ReverseArcStaticGraph, int16_t, int32_t, /*kNumChannels=*/20000, /*kNumUsers=*/20000>); -// We also benchmark with default parameter types and StarGraph for reference. +// We also benchmark with default parameter types for reference. // We use fewer channels and users to avoid running out of memory. BENCHMARK(BM_MinCostFlowOnMultiMatchingProblem< ::util::ReverseArcListGraph<>, int64_t, int64_t, /*kNumChannels=*/5000, /*kNumUsers=*/5000>); -BENCHMARK(BM_MinCostFlowOnMultiMatchingProblem); } // namespace } // namespace operations_research diff --git a/ortools/graph/minimum_vertex_cover.cc b/ortools/graph/minimum_vertex_cover.cc index 140c3bb49a..0676c50196 100644 --- a/ortools/graph/minimum_vertex_cover.cc +++ b/ortools/graph/minimum_vertex_cover.cc @@ -16,7 +16,6 @@ #include #include "absl/log/check.h" -#include "ortools/graph/ebert_graph.h" #include "ortools/graph/max_flow.h" namespace operations_research { @@ -31,7 +30,7 @@ std::vector BipartiteMinimumVertexCover( // alternating matched/unmatched edges to find a minimum vertex cover. SimpleMaxFlow max_flow; const int num_left = left_to_right_arcs.size(); - std::vector arcs; + std::vector arcs; for (int i = 0; i < num_left; ++i) { for (const int right_node : left_to_right_arcs[i]) { DCHECK_GE(right_node, num_left); @@ -56,7 +55,7 @@ std::vector BipartiteMinimumVertexCover( } CHECK(max_flow.Solve(source, sink) == SimpleMaxFlow::OPTIMAL); std::vector maximum_matching(num_left + num_right, -1); - for (const ArcIndex arc : arcs) { + for (const SimpleMaxFlow::ArcIndex arc : arcs) { if (max_flow.Flow(arc) > 0) { maximum_matching[max_flow.Tail(arc)] = max_flow.Head(arc); maximum_matching[max_flow.Head(arc)] = max_flow.Tail(arc); diff --git a/ortools/graph/python/CMakeLists.txt b/ortools/graph/python/CMakeLists.txt index 85c0474037..a6a394eccc 100644 --- a/ortools/graph/python/CMakeLists.txt +++ b/ortools/graph/python/CMakeLists.txt @@ -62,9 +62,3 @@ endif() target_link_libraries(min_cost_flow_pybind11 PRIVATE ${PROJECT_NAMESPACE}::ortools) add_library(${PROJECT_NAMESPACE}::min_cost_flow_pybind11 ALIAS min_cost_flow_pybind11) -if(BUILD_TESTING) - file(GLOB PYTHON_SRCS "*_test.py") - foreach(FILE_NAME IN LISTS PYTHON_SRCS) - add_python_test(${FILE_NAME}) - endforeach() -endif() diff --git a/ortools/graph/python/min_cost_flow.cc b/ortools/graph/python/min_cost_flow.cc index 362df7a3be..d297ea9930 100644 --- a/ortools/graph/python/min_cost_flow.cc +++ b/ortools/graph/python/min_cost_flow.cc @@ -29,11 +29,18 @@ PYBIND11_MODULE(min_cost_flow, m) { arg("head"), arg("capacity"), arg("unit_cost")); smcf.def( "add_arcs_with_capacity_and_unit_cost", - pybind11::vectorize(&SimpleMinCostFlow::AddArcWithCapacityAndUnitCost)); + pybind11::vectorize(&SimpleMinCostFlow::AddArcWithCapacityAndUnitCost), + arg("tails"), arg("heads"), arg("capacities"), arg("unit_costs")); + smcf.def("set_arc_capacity", &SimpleMinCostFlow::SetArcCapacity, arg("arc"), + arg("capacity")); + smcf.def("set_arc_capacities", + pybind11::vectorize(&SimpleMinCostFlow::SetArcCapacity), arg("arcs"), + arg("capacities")); smcf.def("set_node_supply", &SimpleMinCostFlow::SetNodeSupply, arg("node"), arg("supply")); smcf.def("set_nodes_supplies", - pybind11::vectorize(&SimpleMinCostFlow::SetNodeSupply)); + pybind11::vectorize(&SimpleMinCostFlow::SetNodeSupply), arg("nodes"), + arg("supplies")); smcf.def("num_nodes", &SimpleMinCostFlow::NumNodes); smcf.def("num_arcs", &SimpleMinCostFlow::NumArcs); smcf.def("tail", &SimpleMinCostFlow::Tail, arg("arc")); @@ -47,7 +54,7 @@ PYBIND11_MODULE(min_cost_flow, m) { smcf.def("optimal_cost", &SimpleMinCostFlow::OptimalCost); smcf.def("maximum_flow", &SimpleMinCostFlow::MaximumFlow); smcf.def("flow", &SimpleMinCostFlow::Flow, arg("arc")); - smcf.def("flows", pybind11::vectorize(&SimpleMinCostFlow::Flow)); + smcf.def("flows", pybind11::vectorize(&SimpleMinCostFlow::Flow), arg("arcs")); pybind11::enum_(smcf, "Status") .value("BAD_COST_RANGE", MinCostFlowBase::Status::BAD_COST_RANGE) diff --git a/ortools/graph/samples/assignment_min_flow.cc b/ortools/graph/samples/assignment_min_flow.cc index 1f1240b9e6..0597de159a 100644 --- a/ortools/graph/samples/assignment_min_flow.cc +++ b/ortools/graph/samples/assignment_min_flow.cc @@ -69,7 +69,7 @@ void AssignmentMinFlow() { // [END solve] // [START print_solution] - if (status == MinCostFlow::OPTIMAL) { + if (status == SimpleMinCostFlow::OPTIMAL) { LOG(INFO) << "Total cost: " << min_cost_flow.OptimalCost(); LOG(INFO) << ""; for (std::size_t i = 0; i < min_cost_flow.NumArcs(); ++i) { diff --git a/ortools/graph/samples/balance_min_flow.cc b/ortools/graph/samples/balance_min_flow.cc index b29ddaaa68..4521199dc1 100644 --- a/ortools/graph/samples/balance_min_flow.cc +++ b/ortools/graph/samples/balance_min_flow.cc @@ -75,7 +75,7 @@ void BalanceMinFlow() { // [END solve] // [START print_solution] - if (status == MinCostFlow::OPTIMAL) { + if (status == SimpleMinCostFlow::OPTIMAL) { LOG(INFO) << "Total cost: " << min_cost_flow.OptimalCost(); LOG(INFO) << ""; for (std::size_t i = 0; i < min_cost_flow.NumArcs(); ++i) { diff --git a/ortools/graph/samples/simple_min_cost_flow_program.cc b/ortools/graph/samples/simple_min_cost_flow_program.cc index 1761262100..75d89669b5 100644 --- a/ortools/graph/samples/simple_min_cost_flow_program.cc +++ b/ortools/graph/samples/simple_min_cost_flow_program.cc @@ -62,7 +62,7 @@ void SimpleMinCostFlowProgram() { // [END solve] // [START print_solution] - if (status == MinCostFlow::OPTIMAL) { + if (status == SimpleMinCostFlow::OPTIMAL) { LOG(INFO) << "Minimum cost flow: " << min_cost_flow.OptimalCost(); LOG(INFO) << ""; LOG(INFO) << " Arc Flow / Capacity Cost"; diff --git a/ortools/graph/shortest_paths_test.cc b/ortools/graph/shortest_paths_test.cc index 3443a7b83c..83452244eb 100644 --- a/ortools/graph/shortest_paths_test.cc +++ b/ortools/graph/shortest_paths_test.cc @@ -21,7 +21,6 @@ #include "absl/log/check.h" #include "absl/random/random.h" #include "gtest/gtest.h" -#include "ortools/graph/ebert_graph.h" #include "ortools/graph/graph.h" #include "ortools/graph/strongly_connected_components.h" @@ -31,8 +30,9 @@ template void CheckPathDataPair( const GenericPathContainer& container, const GenericPathContainer& distance_container, - PathDistance expected_distance, NodeIndex expected_predecessor, - NodeIndex tail, NodeIndex head) { + PathDistance expected_distance, + typename GraphType::NodeIndex expected_predecessor, + typename GraphType::NodeIndex tail, typename GraphType::NodeIndex head) { EXPECT_EQ(expected_distance, container.GetDistance(tail, head)); EXPECT_EQ(expected_distance, distance_container.GetDistance(tail, head)); EXPECT_EQ(expected_predecessor, @@ -40,7 +40,7 @@ void CheckPathDataPair( EXPECT_DEATH(distance_container.GetPenultimateNodeInPath(tail, head), "Path not stored."); // Checking path between tail and head. - std::vector paths; + std::vector paths; container.GetPath(tail, head, &paths); if (tail == head) { EXPECT_GE(1, paths.size()); @@ -49,7 +49,7 @@ void CheckPathDataPair( } } else if (!paths.empty()) { EXPECT_EQ(tail, paths[0]); - NodeIndex current = head; + typename GraphType::NodeIndex current = head; for (int i = paths.size() - 1; i >= 0; --i) { EXPECT_EQ(current, paths[i]); current = container.GetPenultimateNodeInPath(tail, current); @@ -63,12 +63,13 @@ template void CheckPathDataRow(const GraphType& graph, const GenericPathContainer& container, const GenericPathContainer& distance_container, - const NodeIndex expected_paths[], - const PathDistance expected_distances[], NodeIndex tail) { + const typename GraphType::NodeIndex expected_paths[], + const PathDistance expected_distances[], + typename GraphType::NodeIndex tail) { int index = tail * graph.num_nodes(); for (typename GraphType::NodeIterator iterator(graph); iterator.Ok(); iterator.Next()) { - const NodeIndex head(iterator.Index()); + const typename GraphType::NodeIndex head(iterator.Index()); CheckPathDataPair(container, distance_container, expected_distances[index], expected_paths[index], tail, head); ++index; @@ -79,8 +80,9 @@ template void CheckPathDataRowFromGraph( const GraphType& graph, const GenericPathContainer& container, const GenericPathContainer& distance_container, - const NodeIndex expected_paths[], const PathDistance expected_distances[], - NodeIndex tail) { + const typename GraphType::NodeIndex expected_paths[], + const PathDistance expected_distances[], + typename GraphType::NodeIndex tail) { int index = tail * graph.num_nodes(); for (typename GraphType::NodeIndex head : graph.AllNodes()) { CheckPathDataPair(container, distance_container, expected_distances[index], @@ -93,11 +95,11 @@ template void CheckPathData(const GraphType& graph, const GenericPathContainer& container, const GenericPathContainer& distance_container, - const NodeIndex expected_paths[], + const typename GraphType::NodeIndex expected_paths[], const PathDistance expected_distances[]) { for (typename GraphType::NodeIterator iterator(graph); iterator.Ok(); iterator.Next()) { - const NodeIndex tail(iterator.Index()); + const typename GraphType::NodeIndex tail(iterator.Index()); CheckPathDataRow(graph, container, distance_container, expected_paths, expected_distances, tail); } @@ -107,7 +109,8 @@ template void CheckPathDataFromGraph( const GraphType& graph, const GenericPathContainer& container, const GenericPathContainer& distance_container, - const NodeIndex expected_paths[], const PathDistance expected_distances[]) { + const typename GraphType::NodeIndex expected_paths[], + const PathDistance expected_distances[]) { for (typename GraphType::NodeIndex tail : graph.AllNodes()) { CheckPathDataRowFromGraph(graph, container, distance_container, expected_paths, expected_distances, tail); @@ -121,10 +124,10 @@ void CheckPathDataFromGraph( GenericPathContainer::BuildPathDistanceContainer() template -void TestShortestPathsFromGraph(const GraphType& graph, - const std::vector& lengths, - const NodeIndex expected_paths[], - const PathDistance expected_distances[]) { +void TestShortestPathsFromGraph( + const GraphType& graph, const std::vector& lengths, + const typename GraphType::NodeIndex expected_paths[], + const PathDistance expected_distances[]) { const int kThreads = 32; const typename GraphType::NodeIndex source = 0; std::vector some_nodes; @@ -170,7 +173,7 @@ void TestShortestPathsFromGraph(const GraphType& graph, // Many-to-all shortest paths with duplicates. { BUILD_CONTAINERS(); - std::vector sources(3, source); + std::vector sources(3, source); ComputeManyToAllShortestPathsWithMultipleThreads(graph, lengths, sources, kThreads, &container); ComputeManyToAllShortestPathsWithMultipleThreads( @@ -235,21 +238,6 @@ void TestShortestPathsFromGraph( // Series of shortest paths tests on small graphs. -// Empty fixture templates to collect the types of graphs on which -// we want to base the shortest paths template instances that we -// test. -template -class ShortestPathsDeathTest : public testing::Test {}; -template -class ShortestPathsTest : public testing::Test {}; - -typedef testing::Types - EbertGraphTypesForShortestPathsTesting; - -TYPED_TEST_SUITE(ShortestPathsDeathTest, - EbertGraphTypesForShortestPathsTesting); -TYPED_TEST_SUITE(ShortestPathsTest, EbertGraphTypesForShortestPathsTesting); - template class GraphShortestPathsDeathTest : public testing::Test {}; template diff --git a/ortools/linear_solver/csharp/model_builder.i b/ortools/linear_solver/csharp/model_builder.i index 857180a75e..4d76bac4f1 100644 --- a/ortools/linear_solver/csharp/model_builder.i +++ b/ortools/linear_solver/csharp/model_builder.i @@ -27,7 +27,7 @@ VECTOR_AS_CSHARP_ARRAY(double, double, double, DoubleVector); %module(directors="1") operations_research_model_builder -%extend operations_research::ModelBuilderHelper { +%extend operations_research::mb::ModelBuilderHelper { std::string ExportToMpsString(bool obfuscate) { operations_research::MPModelExportOptions options; options.obfuscate = obfuscate; @@ -45,136 +45,137 @@ VECTOR_AS_CSHARP_ARRAY(double, double, double, DoubleVector); options.obfuscate = obfuscate; return $self->WriteToMpsFile(filename, options); } -} // Extend operations_research::ModelBuilderHelper +} // Extend operations_research::mb::ModelBuilderHelper %ignoreall %unignore operations_research; +%unignore operations_research::mb; // Wrap the ModelBuilderHelper class. -%unignore operations_research::ModelBuilderHelper; -%unignore operations_research::ModelBuilderHelper::ModelBuilderHelper; -%unignore operations_research::ModelBuilderHelper::~ModelBuilderHelper; +%unignore operations_research::mb::ModelBuilderHelper; +%unignore operations_research::mb::ModelBuilderHelper::ModelBuilderHelper; +%unignore operations_research::mb::ModelBuilderHelper::~ModelBuilderHelper; // Var API. -%unignore operations_research::ModelBuilderHelper::AddVar; -%unignore operations_research::ModelBuilderHelper::VarIsIntegral; -%unignore operations_research::ModelBuilderHelper::VarLowerBound; -%unignore operations_research::ModelBuilderHelper::VarName; -%unignore operations_research::ModelBuilderHelper::VarObjectiveCoefficient; -%unignore operations_research::ModelBuilderHelper::VarUpperBound; -%unignore operations_research::ModelBuilderHelper::SetVarIntegrality; -%unignore operations_research::ModelBuilderHelper::SetVarLowerBound; -%unignore operations_research::ModelBuilderHelper::SetVarName; -%unignore operations_research::ModelBuilderHelper::SetVarObjectiveCoefficient; -%unignore operations_research::ModelBuilderHelper::SetVarUpperBound; +%unignore operations_research::mb::ModelBuilderHelper::AddVar; +%unignore operations_research::mb::ModelBuilderHelper::VarIsIntegral; +%unignore operations_research::mb::ModelBuilderHelper::VarLowerBound; +%unignore operations_research::mb::ModelBuilderHelper::VarName; +%unignore operations_research::mb::ModelBuilderHelper::VarObjectiveCoefficient; +%unignore operations_research::mb::ModelBuilderHelper::VarUpperBound; +%unignore operations_research::mb::ModelBuilderHelper::SetVarIntegrality; +%unignore operations_research::mb::ModelBuilderHelper::SetVarLowerBound; +%unignore operations_research::mb::ModelBuilderHelper::SetVarName; +%unignore operations_research::mb::ModelBuilderHelper::SetVarObjectiveCoefficient; +%unignore operations_research::mb::ModelBuilderHelper::SetVarUpperBound; // Linear Constraint API. -%unignore operations_research::ModelBuilderHelper::AddConstraintTerm; -%unignore operations_research::ModelBuilderHelper::AddLinearConstraint; -%unignore operations_research::ModelBuilderHelper::ClearConstraintTerms; -%unignore operations_research::ModelBuilderHelper::ConstraintCoefficients; -%unignore operations_research::ModelBuilderHelper::ConstraintLowerBound; -%unignore operations_research::ModelBuilderHelper::ConstraintName; -%unignore operations_research::ModelBuilderHelper::ConstraintUpperBound; -%unignore operations_research::ModelBuilderHelper::ConstraintVarIndices; -%unignore operations_research::ModelBuilderHelper::SafeAddConstraintTerm; -%unignore operations_research::ModelBuilderHelper::SetConstraintCoefficient; -%unignore operations_research::ModelBuilderHelper::SetConstraintLowerBound; -%unignore operations_research::ModelBuilderHelper::SetConstraintName; -%unignore operations_research::ModelBuilderHelper::SetConstraintUpperBound; +%unignore operations_research::mb::ModelBuilderHelper::AddConstraintTerm; +%unignore operations_research::mb::ModelBuilderHelper::AddLinearConstraint; +%unignore operations_research::mb::ModelBuilderHelper::ClearConstraintTerms; +%unignore operations_research::mb::ModelBuilderHelper::ConstraintCoefficients; +%unignore operations_research::mb::ModelBuilderHelper::ConstraintLowerBound; +%unignore operations_research::mb::ModelBuilderHelper::ConstraintName; +%unignore operations_research::mb::ModelBuilderHelper::ConstraintUpperBound; +%unignore operations_research::mb::ModelBuilderHelper::ConstraintVarIndices; +%unignore operations_research::mb::ModelBuilderHelper::SafeAddConstraintTerm; +%unignore operations_research::mb::ModelBuilderHelper::SetConstraintCoefficient; +%unignore operations_research::mb::ModelBuilderHelper::SetConstraintLowerBound; +%unignore operations_research::mb::ModelBuilderHelper::SetConstraintName; +%unignore operations_research::mb::ModelBuilderHelper::SetConstraintUpperBound; // Enforced Linear Constraints API. -%unignore operations_research::ModelBuilderHelper::AddEnforcedConstraintTerm; -%unignore operations_research::ModelBuilderHelper::AddEnforcedLinearConstraint; -%unignore operations_research::ModelBuilderHelper::ClearEnforcedConstraintTerms; -%unignore operations_research::ModelBuilderHelper::EnforcedConstraintCoefficients; -%unignore operations_research::ModelBuilderHelper::EnforcedConstraintLowerBound; -%unignore operations_research::ModelBuilderHelper::EnforcedConstraintName; -%unignore operations_research::ModelBuilderHelper::EnforcedConstraintUpperBound; -%unignore operations_research::ModelBuilderHelper::EnforcedConstraintVarIndices; -%unignore operations_research::ModelBuilderHelper::EnforcedIndicatorValue; -%unignore operations_research::ModelBuilderHelper::EnforcedIndicatorVariableIndex; -%unignore operations_research::ModelBuilderHelper::IsEnforcedConstraint; -%unignore operations_research::ModelBuilderHelper::SafeAddEnforcedConstraintTerm; -%unignore operations_research::ModelBuilderHelper::SetEnforcedConstraintCoefficient; -%unignore operations_research::ModelBuilderHelper::SetEnforcedConstraintLowerBound; -%unignore operations_research::ModelBuilderHelper::SetEnforcedConstraintName; -%unignore operations_research::ModelBuilderHelper::SetEnforcedConstraintUpperBound; -%unignore operations_research::ModelBuilderHelper::SetEnforcedIndicatorValue; -%unignore operations_research::ModelBuilderHelper::SetEnforcedIndicatorVariableIndex; +%unignore operations_research::mb::ModelBuilderHelper::AddEnforcedConstraintTerm; +%unignore operations_research::mb::ModelBuilderHelper::AddEnforcedLinearConstraint; +%unignore operations_research::mb::ModelBuilderHelper::ClearEnforcedConstraintTerms; +%unignore operations_research::mb::ModelBuilderHelper::EnforcedConstraintCoefficients; +%unignore operations_research::mb::ModelBuilderHelper::EnforcedConstraintLowerBound; +%unignore operations_research::mb::ModelBuilderHelper::EnforcedConstraintName; +%unignore operations_research::mb::ModelBuilderHelper::EnforcedConstraintUpperBound; +%unignore operations_research::mb::ModelBuilderHelper::EnforcedConstraintVarIndices; +%unignore operations_research::mb::ModelBuilderHelper::EnforcedIndicatorValue; +%unignore operations_research::mb::ModelBuilderHelper::EnforcedIndicatorVariableIndex; +%unignore operations_research::mb::ModelBuilderHelper::IsEnforcedConstraint; +%unignore operations_research::mb::ModelBuilderHelper::SafeAddEnforcedConstraintTerm; +%unignore operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintCoefficient; +%unignore operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintLowerBound; +%unignore operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintName; +%unignore operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintUpperBound; +%unignore operations_research::mb::ModelBuilderHelper::SetEnforcedIndicatorValue; +%unignore operations_research::mb::ModelBuilderHelper::SetEnforcedIndicatorVariableIndex; // Objective API. -%unignore operations_research::ModelBuilderHelper::ClearObjective; -%rename (Maximize) operations_research::ModelBuilderHelper::maximize; -%unignore operations_research::ModelBuilderHelper::SetMaximize; -%unignore operations_research::ModelBuilderHelper::ObjectiveOffset; -%unignore operations_research::ModelBuilderHelper::SetObjectiveOffset; +%unignore operations_research::mb::ModelBuilderHelper::ClearObjective; +%rename (Maximize) operations_research::mb::ModelBuilderHelper::maximize; +%unignore operations_research::mb::ModelBuilderHelper::SetMaximize; +%unignore operations_research::mb::ModelBuilderHelper::ObjectiveOffset; +%unignore operations_research::mb::ModelBuilderHelper::SetObjectiveOffset; // Hints -%unignore operations_research::ModelBuilderHelper::ClearHints; -%unignore operations_research::ModelBuilderHelper::AddHint; +%unignore operations_research::mb::ModelBuilderHelper::ClearHints; +%unignore operations_research::mb::ModelBuilderHelper::AddHint; // Model API. -%rename (VariablesCount) operations_research::ModelBuilderHelper::num_variables; -%rename (ConstraintsCount) operations_research::ModelBuilderHelper::num_constraints; -%rename (Name) operations_research::ModelBuilderHelper::name; -%unignore operations_research::ModelBuilderHelper::SetName; -%unignore operations_research::ModelBuilderHelper::ReadModelFromProtoFile; -%unignore operations_research::ModelBuilderHelper::WriteModelToProtoFile; -%unignore operations_research::ModelBuilderHelper::ImportFromMpsString; -%unignore operations_research::ModelBuilderHelper::ImportFromMpsFile; -%unignore operations_research::ModelBuilderHelper::ImportFromLpString; -%unignore operations_research::ModelBuilderHelper::ImportFromLpFile; -%unignore operations_research::ModelBuilderHelper::WriteToMpsFile(std::string, bool); -%unignore operations_research::ModelBuilderHelper::ExportToMpsString(bool); -%unignore operations_research::ModelBuilderHelper::ExportToLpString(bool); -%unignore operations_research::ModelBuilderHelper::OverwriteModel; +%rename (VariablesCount) operations_research::mb::ModelBuilderHelper::num_variables; +%rename (ConstraintsCount) operations_research::mb::ModelBuilderHelper::num_constraints; +%rename (Name) operations_research::mb::ModelBuilderHelper::name; +%unignore operations_research::mb::ModelBuilderHelper::SetName; +%unignore operations_research::mb::ModelBuilderHelper::ReadModelFromProtoFile; +%unignore operations_research::mb::ModelBuilderHelper::WriteModelToProtoFile; +%unignore operations_research::mb::ModelBuilderHelper::ImportFromMpsString; +%unignore operations_research::mb::ModelBuilderHelper::ImportFromMpsFile; +%unignore operations_research::mb::ModelBuilderHelper::ImportFromLpString; +%unignore operations_research::mb::ModelBuilderHelper::ImportFromLpFile; +%unignore operations_research::mb::ModelBuilderHelper::WriteToMpsFile(std::string, bool); +%unignore operations_research::mb::ModelBuilderHelper::ExportToMpsString(bool); +%unignore operations_research::mb::ModelBuilderHelper::ExportToLpString(bool); +%unignore operations_research::mb::ModelBuilderHelper::OverwriteModel; // Callbacks support. -%feature("director") operations_research::MbLogCallback; -%unignore operations_research::MbLogCallback; -%unignore operations_research::MbLogCallback::~MbLogCallback; -%unignore operations_research::MbLogCallback::NewMessage; +%feature("director") operations_research::mb::MbLogCallback; +%unignore operations_research::mb::MbLogCallback; +%unignore operations_research::mb::MbLogCallback::~MbLogCallback; +%unignore operations_research::mb::MbLogCallback::NewMessage; // Solver API. -%unignore operations_research::ModelSolverHelper; -%unignore operations_research::ModelSolverHelper::ModelSolverHelper(const std::string&); -%unignore operations_research::ModelSolverHelper::SolverIsSupported; -%unignore operations_research::ModelSolverHelper::Solve; -%unignore operations_research::ModelSolverHelper::InterruptSolve; -%rename (HasResponse) operations_research::ModelSolverHelper::has_response; -%rename (HasSolution) operations_research::ModelSolverHelper::has_solution; -%rename (Status) operations_research::ModelSolverHelper::status; -%rename (ObjectiveValue) operations_research::ModelSolverHelper::objective_value; -%rename (BestObjectiveBound) operations_research::ModelSolverHelper::best_objective_bound; -%rename (VariableValue) operations_research::ModelSolverHelper::variable_value; -%rename (ReducedCost) operations_research::ModelSolverHelper::reduced_cost; -%rename (DualValue) operations_research::ModelSolverHelper::dual_value; -%rename (Activity) operations_research::ModelSolverHelper::activity; -%rename (StatusString) operations_research::ModelSolverHelper::status_string; -%rename (WallTime) operations_research::ModelSolverHelper::wall_time; -%rename (UserTime) operations_research::ModelSolverHelper::user_time; -%unignore operations_research::ModelSolverHelper::EnableOutput; -%unignore operations_research::ModelSolverHelper::ClearLogCallback; -%unignore operations_research::ModelSolverHelper::SetLogCallbackFromDirectorClass; -%unignore operations_research::ModelSolverHelper::SetTimeLimitInSeconds; -%unignore operations_research::ModelSolverHelper::SetSolverSpecificParameters; +%unignore operations_research::mb::ModelSolverHelper; +%unignore operations_research::mb::ModelSolverHelper::ModelSolverHelper(const std::string&); +%unignore operations_research::mb::ModelSolverHelper::SolverIsSupported; +%unignore operations_research::mb::ModelSolverHelper::Solve; +%unignore operations_research::mb::ModelSolverHelper::InterruptSolve; +%rename (HasResponse) operations_research::mb::ModelSolverHelper::has_response; +%rename (HasSolution) operations_research::mb::ModelSolverHelper::has_solution; +%rename (Status) operations_research::mb::ModelSolverHelper::status; +%rename (ObjectiveValue) operations_research::mb::ModelSolverHelper::objective_value; +%rename (BestObjectiveBound) operations_research::mb::ModelSolverHelper::best_objective_bound; +%rename (VariableValue) operations_research::mb::ModelSolverHelper::variable_value; +%rename (ReducedCost) operations_research::mb::ModelSolverHelper::reduced_cost; +%rename (DualValue) operations_research::mb::ModelSolverHelper::dual_value; +%rename (Activity) operations_research::mb::ModelSolverHelper::activity; +%rename (StatusString) operations_research::mb::ModelSolverHelper::status_string; +%rename (WallTime) operations_research::mb::ModelSolverHelper::wall_time; +%rename (UserTime) operations_research::mb::ModelSolverHelper::user_time; +%unignore operations_research::mb::ModelSolverHelper::EnableOutput; +%unignore operations_research::mb::ModelSolverHelper::ClearLogCallback; +%unignore operations_research::mb::ModelSolverHelper::SetLogCallbackFromDirectorClass; +%unignore operations_research::mb::ModelSolverHelper::SetTimeLimitInSeconds; +%unignore operations_research::mb::ModelSolverHelper::SetSolverSpecificParameters; -%unignore operations_research::SolveStatus; -%unignore operations_research::OPTIMAL; -%unignore operations_research::FEASIBLE; -%unignore operations_research::INFEASIBLE; -%unignore operations_research::UNBOUNDED; -%unignore operations_research::ABNORMAL; -%unignore operations_research::NOT_SOLVED; -%unignore operations_research::MODEL_IS_VALID; -%unignore operations_research::CANCELLED_BY_USER; -%unignore operations_research::UNKNOWN_STATUS; -%unignore operations_research::MODEL_INVALID; -%unignore operations_research::INVALID_SOLVER_PARAMETERS; -%unignore operations_research::SOLVER_TYPE_UNAVAILABLE; -%unignore operations_research::INCOMPATIBLE_OPTIONS; +%unignore operations_research::mb::SolveStatus; +%unignore operations_research::mb::OPTIMAL; +%unignore operations_research::mb::FEASIBLE; +%unignore operations_research::mb::INFEASIBLE; +%unignore operations_research::mb::UNBOUNDED; +%unignore operations_research::mb::ABNORMAL; +%unignore operations_research::mb::NOT_SOLVED; +%unignore operations_research::mb::MODEL_IS_VALID; +%unignore operations_research::mb::CANCELLED_BY_USER; +%unignore operations_research::mb::UNKNOWN_STATUS; +%unignore operations_research::mb::MODEL_INVALID; +%unignore operations_research::mb::INVALID_SOLVER_PARAMETERS; +%unignore operations_research::mb::SOLVER_TYPE_UNAVAILABLE; +%unignore operations_research::mb::INCOMPATIBLE_OPTIONS; // For enums %include "ortools/linear_solver/wrappers/model_builder_helper.h" diff --git a/ortools/linear_solver/java/modelbuilder.i b/ortools/linear_solver/java/modelbuilder.i index ea68c3b7a2..f8c526c1b0 100644 --- a/ortools/linear_solver/java/modelbuilder.i +++ b/ortools/linear_solver/java/modelbuilder.i @@ -87,7 +87,7 @@ class GlobalRefGuard { %typemap(jstype) std::function "java.util.function.Consumer" // Type used in the Proxy class. %typemap(javain) std::function "$javainput" // passing the Callback to JNI java class. -%extend operations_research::ModelBuilderHelper { +%extend operations_research::mb::ModelBuilderHelper { std::string exportToMpsString(bool obfuscate) { operations_research::MPModelExportOptions options; options.obfuscate = obfuscate; @@ -105,129 +105,130 @@ class GlobalRefGuard { options.obfuscate = obfuscate; return $self->WriteToMpsFile(filename, options); } -} // Extend operations_research::ModelBuilderHelper +} // Extend operations_research::mb::ModelBuilderHelper %ignoreall %unignore operations_research; +%unignore operations_research::mb; // Wrap the ModelBuilderHelper class. -%unignore operations_research::ModelBuilderHelper; -%unignore operations_research::ModelBuilderHelper::ModelBuilderHelper; -%unignore operations_research::ModelBuilderHelper::~ModelBuilderHelper; +%unignore operations_research::mb::ModelBuilderHelper; +%unignore operations_research::mb::ModelBuilderHelper::ModelBuilderHelper; +%unignore operations_research::mb::ModelBuilderHelper::~ModelBuilderHelper; // Var API. -%rename (addVar) operations_research::ModelBuilderHelper::AddVar; -%rename (getVarIntegrality) operations_research::ModelBuilderHelper::VarIsIntegral; -%rename (getVarLowerBound) operations_research::ModelBuilderHelper::VarLowerBound; -%rename (getVarName) operations_research::ModelBuilderHelper::VarName; -%rename (getVarObjectiveCoefficient) operations_research::ModelBuilderHelper::VarObjectiveCoefficient; -%rename (getVarUpperBound) operations_research::ModelBuilderHelper::VarUpperBound; -%rename (setVarIntegrality) operations_research::ModelBuilderHelper::SetVarIntegrality; -%rename (setVarLowerBound) operations_research::ModelBuilderHelper::SetVarLowerBound; -%rename (setVarName) operations_research::ModelBuilderHelper::SetVarName; -%rename (setVarObjectiveCoefficient) operations_research::ModelBuilderHelper::SetVarObjectiveCoefficient; -%rename (setVarUpperBound) operations_research::ModelBuilderHelper::SetVarUpperBound; +%rename (addVar) operations_research::mb::ModelBuilderHelper::AddVar; +%rename (getVarIntegrality) operations_research::mb::ModelBuilderHelper::VarIsIntegral; +%rename (getVarLowerBound) operations_research::mb::ModelBuilderHelper::VarLowerBound; +%rename (getVarName) operations_research::mb::ModelBuilderHelper::VarName; +%rename (getVarObjectiveCoefficient) operations_research::mb::ModelBuilderHelper::VarObjectiveCoefficient; +%rename (getVarUpperBound) operations_research::mb::ModelBuilderHelper::VarUpperBound; +%rename (setVarIntegrality) operations_research::mb::ModelBuilderHelper::SetVarIntegrality; +%rename (setVarLowerBound) operations_research::mb::ModelBuilderHelper::SetVarLowerBound; +%rename (setVarName) operations_research::mb::ModelBuilderHelper::SetVarName; +%rename (setVarObjectiveCoefficient) operations_research::mb::ModelBuilderHelper::SetVarObjectiveCoefficient; +%rename (setVarUpperBound) operations_research::mb::ModelBuilderHelper::SetVarUpperBound; // Linear Constraint API. -%rename (addConstraintTerm) operations_research::ModelBuilderHelper::AddConstraintTerm; -%rename (addLinearConstraint) operations_research::ModelBuilderHelper::AddLinearConstraint; -%rename (clearConstraintTerms) operations_research::ModelBuilderHelper::ClearConstraintTerms; -%rename (getConstraintCoefficients) operations_research::ModelBuilderHelper::ConstraintCoefficients; -%rename (getConstraintLowerBound) operations_research::ModelBuilderHelper::ConstraintLowerBound; -%rename (getConstraintName) operations_research::ModelBuilderHelper::ConstraintName; -%rename (getConstraintUpperBound) operations_research::ModelBuilderHelper::ConstraintUpperBound; -%rename (getConstraintVarIndices) operations_research::ModelBuilderHelper::ConstraintVarIndices; -%rename (safeAddConstraintTerm) operations_research::ModelBuilderHelper::SafeAddConstraintTerm; -%rename (setConstraintCoefficient) operations_research::ModelBuilderHelper::SetConstraintCoefficient; -%rename (setConstraintLowerBound) operations_research::ModelBuilderHelper::SetConstraintLowerBound; -%rename (setConstraintName) operations_research::ModelBuilderHelper::SetConstraintName; -%rename (setConstraintUpperBound) operations_research::ModelBuilderHelper::SetConstraintUpperBound; +%rename (addConstraintTerm) operations_research::mb::ModelBuilderHelper::AddConstraintTerm; +%rename (addLinearConstraint) operations_research::mb::ModelBuilderHelper::AddLinearConstraint; +%rename (clearConstraintTerms) operations_research::mb::ModelBuilderHelper::ClearConstraintTerms; +%rename (getConstraintCoefficients) operations_research::mb::ModelBuilderHelper::ConstraintCoefficients; +%rename (getConstraintLowerBound) operations_research::mb::ModelBuilderHelper::ConstraintLowerBound; +%rename (getConstraintName) operations_research::mb::ModelBuilderHelper::ConstraintName; +%rename (getConstraintUpperBound) operations_research::mb::ModelBuilderHelper::ConstraintUpperBound; +%rename (getConstraintVarIndices) operations_research::mb::ModelBuilderHelper::ConstraintVarIndices; +%rename (safeAddConstraintTerm) operations_research::mb::ModelBuilderHelper::SafeAddConstraintTerm; +%rename (setConstraintCoefficient) operations_research::mb::ModelBuilderHelper::SetConstraintCoefficient; +%rename (setConstraintLowerBound) operations_research::mb::ModelBuilderHelper::SetConstraintLowerBound; +%rename (setConstraintName) operations_research::mb::ModelBuilderHelper::SetConstraintName; +%rename (setConstraintUpperBound) operations_research::mb::ModelBuilderHelper::SetConstraintUpperBound; // Enforced Linear Constraint API. -%rename (addEnforcedConstraintTerm) operations_research::ModelBuilderHelper::AddEnforcedConstraintTerm; -%rename (addEnforcedLinearConstraint) operations_research::ModelBuilderHelper::AddEnforcedLinearConstraint; -%rename (clearEnforcedConstraintTerms) operations_research::ModelBuilderHelper::ClearEnforcedConstraintTerms; -%rename (getEnforcedConstraintCoefficients) operations_research::ModelBuilderHelper::EnforcedConstraintCoefficients; -%rename (getEnforcedConstraintLowerBound) operations_research::ModelBuilderHelper::EnforcedConstraintLowerBound; -%rename (getEnforcedConstraintName) operations_research::ModelBuilderHelper::EnforcedConstraintName; -%rename (getEnforcedConstraintUpperBound) operations_research::ModelBuilderHelper::EnforcedConstraintUpperBound; -%rename (getEnforcedConstraintVarIndices) operations_research::ModelBuilderHelper::EnforcedConstraintVarIndices; -%rename (getEnforcedIndicatorValue) operations_research::ModelBuilderHelper::EnforcedIndicatorValue; -%rename (getEnforcedIndicatorVariableIndex) operations_research::ModelBuilderHelper::EnforcedIndicatorVariableIndex; -%rename (isEnforcedConstraint) operations_research::ModelBuilderHelper::IsEnforcedConstraint; -%rename (safeAddEnforcedConstraintTerm) operations_research::ModelBuilderHelper::SafeAddEnforcedConstraintTerm; -%rename (setEnforcedConstraintCoefficient) operations_research::ModelBuilderHelper::SetEnforcedConstraintCoefficient; -%rename (setEnforcedConstraintLowerBound) operations_research::ModelBuilderHelper::SetEnforcedConstraintLowerBound; -%rename (setEnforcedConstraintName) operations_research::ModelBuilderHelper::SetEnforcedConstraintName; -%rename (setEnforcedConstraintUpperBound) operations_research::ModelBuilderHelper::SetEnforcedConstraintUpperBound; -%rename (setEnforcedIndicatorValue) operations_research::ModelBuilderHelper::SetEnforcedIndicatorValue; -%rename (setEnforcedIndicatorVariableIndex) operations_research::ModelBuilderHelper::SetEnforcedIndicatorVariableIndex; +%rename (addEnforcedConstraintTerm) operations_research::mb::ModelBuilderHelper::AddEnforcedConstraintTerm; +%rename (addEnforcedLinearConstraint) operations_research::mb::ModelBuilderHelper::AddEnforcedLinearConstraint; +%rename (clearEnforcedConstraintTerms) operations_research::mb::ModelBuilderHelper::ClearEnforcedConstraintTerms; +%rename (getEnforcedConstraintCoefficients) operations_research::mb::ModelBuilderHelper::EnforcedConstraintCoefficients; +%rename (getEnforcedConstraintLowerBound) operations_research::mb::ModelBuilderHelper::EnforcedConstraintLowerBound; +%rename (getEnforcedConstraintName) operations_research::mb::ModelBuilderHelper::EnforcedConstraintName; +%rename (getEnforcedConstraintUpperBound) operations_research::mb::ModelBuilderHelper::EnforcedConstraintUpperBound; +%rename (getEnforcedConstraintVarIndices) operations_research::mb::ModelBuilderHelper::EnforcedConstraintVarIndices; +%rename (getEnforcedIndicatorValue) operations_research::mb::ModelBuilderHelper::EnforcedIndicatorValue; +%rename (getEnforcedIndicatorVariableIndex) operations_research::mb::ModelBuilderHelper::EnforcedIndicatorVariableIndex; +%rename (isEnforcedConstraint) operations_research::mb::ModelBuilderHelper::IsEnforcedConstraint; +%rename (safeAddEnforcedConstraintTerm) operations_research::mb::ModelBuilderHelper::SafeAddEnforcedConstraintTerm; +%rename (setEnforcedConstraintCoefficient) operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintCoefficient; +%rename (setEnforcedConstraintLowerBound) operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintLowerBound; +%rename (setEnforcedConstraintName) operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintName; +%rename (setEnforcedConstraintUpperBound) operations_research::mb::ModelBuilderHelper::SetEnforcedConstraintUpperBound; +%rename (setEnforcedIndicatorValue) operations_research::mb::ModelBuilderHelper::SetEnforcedIndicatorValue; +%rename (setEnforcedIndicatorVariableIndex) operations_research::mb::ModelBuilderHelper::SetEnforcedIndicatorVariableIndex; // Objective API. -%rename (clearObjective) operations_research::ModelBuilderHelper::ClearObjective; -%rename (getMaximize) operations_research::ModelBuilderHelper::maximize; -%rename (setMaximize) operations_research::ModelBuilderHelper::SetMaximize; -%rename (getObjectiveOffset) operations_research::ModelBuilderHelper::ObjectiveOffset; -%rename (setObjectiveOffset) operations_research::ModelBuilderHelper::SetObjectiveOffset; +%rename (clearObjective) operations_research::mb::ModelBuilderHelper::ClearObjective; +%rename (getMaximize) operations_research::mb::ModelBuilderHelper::maximize; +%rename (setMaximize) operations_research::mb::ModelBuilderHelper::SetMaximize; +%rename (getObjectiveOffset) operations_research::mb::ModelBuilderHelper::ObjectiveOffset; +%rename (setObjectiveOffset) operations_research::mb::ModelBuilderHelper::SetObjectiveOffset; // Hints. -%rename (clearHints) operations_research::ModelBuilderHelper::ClearHints; -%rename (addHint) operations_research::ModelBuilderHelper::AddHint; +%rename (clearHints) operations_research::mb::ModelBuilderHelper::ClearHints; +%rename (addHint) operations_research::mb::ModelBuilderHelper::AddHint; // Model API. -%rename (numVariables) operations_research::ModelBuilderHelper::num_variables; -%rename (numConstraints) operations_research::ModelBuilderHelper::num_constraints; -%rename (getName) operations_research::ModelBuilderHelper::name; -%rename (setName) operations_research::ModelBuilderHelper::SetName; -%rename (readModelFromProtoFile) operations_research::ModelBuilderHelper::ReadModelFromProtoFile; -%rename (writeModelToProtoFile) operations_research::ModelBuilderHelper::WriteModelToProtoFile; -%rename (importFromMpsString) operations_research::ModelBuilderHelper::ImportFromMpsString; -%rename (importFromMpsFile) operations_research::ModelBuilderHelper::ImportFromMpsFile; -%rename (importFromLpString) operations_research::ModelBuilderHelper::ImportFromLpString; -%rename (importFromLpFile) operations_research::ModelBuilderHelper::ImportFromLpFile; -%unignore operations_research::ModelBuilderHelper::exportToMpsString; -%unignore operations_research::ModelBuilderHelper::exportToLpString; -%unignore operations_research::ModelBuilderHelper::writeToMpsFile; -%rename (overwriteModel) operations_research::ModelBuilderHelper::OverwriteModel; +%rename (numVariables) operations_research::mb::ModelBuilderHelper::num_variables; +%rename (numConstraints) operations_research::mb::ModelBuilderHelper::num_constraints; +%rename (getName) operations_research::mb::ModelBuilderHelper::name; +%rename (setName) operations_research::mb::ModelBuilderHelper::SetName; +%rename (readModelFromProtoFile) operations_research::mb::ModelBuilderHelper::ReadModelFromProtoFile; +%rename (writeModelToProtoFile) operations_research::mb::ModelBuilderHelper::WriteModelToProtoFile; +%rename (importFromMpsString) operations_research::mb::ModelBuilderHelper::ImportFromMpsString; +%rename (importFromMpsFile) operations_research::mb::ModelBuilderHelper::ImportFromMpsFile; +%rename (importFromLpString) operations_research::mb::ModelBuilderHelper::ImportFromLpString; +%rename (importFromLpFile) operations_research::mb::ModelBuilderHelper::ImportFromLpFile; +%unignore operations_research::mb::ModelBuilderHelper::exportToMpsString; +%unignore operations_research::mb::ModelBuilderHelper::exportToLpString; +%unignore operations_research::mb::ModelBuilderHelper::writeToMpsFile; +%rename (overwriteModel) operations_research::mb::ModelBuilderHelper::OverwriteModel; -%unignore operations_research::ModelSolverHelper; -%unignore operations_research::ModelSolverHelper::ModelSolverHelper(const std::string&); -%rename (solverIsSupported) operations_research::ModelSolverHelper::SolverIsSupported; -%rename (solve) operations_research::ModelSolverHelper::Solve; -%rename (interruptSolve) operations_research::ModelSolverHelper::InterruptSolve; -%rename (hasResponse) operations_research::ModelSolverHelper::has_response; -%rename (hasSolution) operations_research::ModelSolverHelper::has_solution; -%rename (getStatus) operations_research::ModelSolverHelper::status; -%rename (getObjectiveValue) operations_research::ModelSolverHelper::objective_value; -%rename (getBestObjectiveBound) operations_research::ModelSolverHelper::best_objective_bound; -%rename (getVariableValue) operations_research::ModelSolverHelper::variable_value; -%rename (getReducedCost) operations_research::ModelSolverHelper::reduced_cost; -%rename (getDualValue) operations_research::ModelSolverHelper::dual_value; -%rename (getActivity) operations_research::ModelSolverHelper::activity; -%rename (getStatusString) operations_research::ModelSolverHelper::status_string; -%rename (getWallTime) operations_research::ModelSolverHelper::wall_time; -%rename (getUserTime) operations_research::ModelSolverHelper::user_time; -%rename (enableOutput) operations_research::ModelSolverHelper::EnableOutput; -%rename (clearLogCallback) operations_research::ModelSolverHelper::ClearLogCallback; -%rename (setLogCallback) operations_research::ModelSolverHelper::SetLogCallback; -%rename (setTimeLimitInSeconds) operations_research::ModelSolverHelper::SetTimeLimitInSeconds; -%rename (setSolverSpecificParameters) operations_research::ModelSolverHelper::SetSolverSpecificParameters; +%unignore operations_research::mb::ModelSolverHelper; +%unignore operations_research::mb::ModelSolverHelper::ModelSolverHelper(const std::string&); +%rename (solverIsSupported) operations_research::mb::ModelSolverHelper::SolverIsSupported; +%rename (solve) operations_research::mb::ModelSolverHelper::Solve; +%rename (interruptSolve) operations_research::mb::ModelSolverHelper::InterruptSolve; +%rename (hasResponse) operations_research::mb::ModelSolverHelper::has_response; +%rename (hasSolution) operations_research::mb::ModelSolverHelper::has_solution; +%rename (getStatus) operations_research::mb::ModelSolverHelper::status; +%rename (getObjectiveValue) operations_research::mb::ModelSolverHelper::objective_value; +%rename (getBestObjectiveBound) operations_research::mb::ModelSolverHelper::best_objective_bound; +%rename (getVariableValue) operations_research::mb::ModelSolverHelper::variable_value; +%rename (getReducedCost) operations_research::mb::ModelSolverHelper::reduced_cost; +%rename (getDualValue) operations_research::mb::ModelSolverHelper::dual_value; +%rename (getActivity) operations_research::mb::ModelSolverHelper::activity; +%rename (getStatusString) operations_research::mb::ModelSolverHelper::status_string; +%rename (getWallTime) operations_research::mb::ModelSolverHelper::wall_time; +%rename (getUserTime) operations_research::mb::ModelSolverHelper::user_time; +%rename (enableOutput) operations_research::mb::ModelSolverHelper::EnableOutput; +%rename (clearLogCallback) operations_research::mb::ModelSolverHelper::ClearLogCallback; +%rename (setLogCallback) operations_research::mb::ModelSolverHelper::SetLogCallback; +%rename (setTimeLimitInSeconds) operations_research::mb::ModelSolverHelper::SetTimeLimitInSeconds; +%rename (setSolverSpecificParameters) operations_research::mb::ModelSolverHelper::SetSolverSpecificParameters; -%unignore operations_research::SolveStatus; -%unignore operations_research::OPTIMAL; -%unignore operations_research::FEASIBLE; -%unignore operations_research::INFEASIBLE; -%unignore operations_research::UNBOUNDED; -%unignore operations_research::ABNORMAL; -%unignore operations_research::NOT_SOLVED; -%unignore operations_research::MODEL_IS_VALID; -%unignore operations_research::CANCELLED_BY_USER; -%unignore operations_research::UNKNOWN_STATUS; -%unignore operations_research::MODEL_INVALID; -%unignore operations_research::INVALID_SOLVER_PARAMETERS; -%unignore operations_research::SOLVER_TYPE_UNAVAILABLE; -%unignore operations_research::INCOMPATIBLE_OPTIONS; +%unignore operations_research::mb::SolveStatus; +%unignore operations_research::mb::OPTIMAL; +%unignore operations_research::mb::FEASIBLE; +%unignore operations_research::mb::INFEASIBLE; +%unignore operations_research::mb::UNBOUNDED; +%unignore operations_research::mb::ABNORMAL; +%unignore operations_research::mb::NOT_SOLVED; +%unignore operations_research::mb::MODEL_IS_VALID; +%unignore operations_research::mb::CANCELLED_BY_USER; +%unignore operations_research::mb::UNKNOWN_STATUS; +%unignore operations_research::mb::MODEL_INVALID; +%unignore operations_research::mb::INVALID_SOLVER_PARAMETERS; +%unignore operations_research::mb::SOLVER_TYPE_UNAVAILABLE; +%unignore operations_research::mb::INCOMPATIBLE_OPTIONS; // For enums %javaconst(1); diff --git a/ortools/linear_solver/python/model_builder.py b/ortools/linear_solver/python/model_builder.py index 48b6251c77..0b59c665fc 100644 --- a/ortools/linear_solver/python/model_builder.py +++ b/ortools/linear_solver/python/model_builder.py @@ -32,386 +32,37 @@ Other methods and functions listed are primarily used for developing OR-Tools, rather than for solving specific optimization problems. """ -import abc -import dataclasses import math import numbers import typing -from typing import Callable, List, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Optional, Union import numpy as np -from numpy import typing as npt import pandas as pd from ortools.linear_solver import linear_solver_pb2 from ortools.linear_solver.python import model_builder_helper as mbh from ortools.linear_solver.python import model_builder_numbers as mbn - # Custom types. NumberT = Union[int, float, numbers.Real, np.number] IntegerT = Union[int, numbers.Integral, np.integer] -LinearExprT = Union["LinearExpr", NumberT] -ConstraintT = Union["_BoundedLinearExpr", bool] +LinearExprT = Union[mbh.LinearExpr, NumberT] +ConstraintT = Union[mbh.BoundedLinearExpression, bool] _IndexOrSeries = Union[pd.Index, pd.Series] -_VariableOrConstraint = Union["LinearConstraint", "Variable"] +_VariableOrConstraint = Union["LinearConstraint", mbh.Variable] # Forward solve statuses. +AffineExpr = mbh.AffineExpr +BoundedLinearExpression = mbh.BoundedLinearExpression +FlatExpr = mbh.FlatExpr +LinearExpr = mbh.LinearExpr SolveStatus = mbh.SolveStatus - -# pylint: disable=protected-access - - -class LinearExpr(metaclass=abc.ABCMeta): - """Holds an linear expression. - - A linear expression is built from constants and variables. - For example, `x + 2.0 * (y - z + 1.0)`. - - Linear expressions are used in Model models in constraints and in the - objective: - - * You can define linear constraints as in: - - ``` - model.add(x + 2 * y <= 5.0) - model.add(sum(array_of_vars) == 5.0) - ``` - - * In Model, the objective is a linear expression: - - ``` - model.minimize(x + 2.0 * y + z) - ``` - - * For large arrays, using the LinearExpr class is faster that using the python - `sum()` function. You can create constraints and the objective from lists of - linear expressions or coefficients as follows: - - ``` - model.minimize(model_builder.LinearExpr.sum(expressions)) - model.add(model_builder.LinearExpr.weighted_sum(expressions, coeffs) >= 0) - ``` - """ - - @classmethod - def sum( # pytype: disable=annotation-type-mismatch # numpy-scalars - cls, expressions: Sequence[LinearExprT], *, constant: NumberT = 0.0 - ) -> LinearExprT: - """Creates `sum(expressions) + constant`. - - It can perform simple simplifications and returns different objects, - including the input. - - Args: - expressions: a sequence of linear expressions or constants. - constant: a numerical constant. - - Returns: - a LinearExpr instance or a numerical constant. - """ - checked_constant: np.double = mbn.assert_is_a_number(constant) - if not expressions: - return checked_constant - if len(expressions) == 1 and mbn.is_zero(checked_constant): - return expressions[0] - - return LinearExpr.weighted_sum( - expressions, np.ones(len(expressions)), constant=checked_constant - ) - - @classmethod - def weighted_sum( # pytype: disable=annotation-type-mismatch # numpy-scalars - cls, - expressions: Sequence[LinearExprT], - coefficients: Sequence[NumberT], - *, - constant: NumberT = 0.0, - ) -> Union[NumberT, "_LinearExpression"]: - """Creates `sum(expressions[i] * coefficients[i]) + constant`. - - It can perform simple simplifications and returns different object, - including the input. - - Args: - expressions: a sequence of linear expressions or constants. - coefficients: a sequence of numerical constants. - constant: a numerical constant. - - Returns: - a _LinearExpression instance or a numerical constant. - """ - if len(expressions) != len(coefficients): - raise ValueError( - "LinearExpr.weighted_sum: expressions and coefficients have" - " different lengths" - ) - checked_constant: np.double = mbn.assert_is_a_number(constant) - if not expressions: - return checked_constant - return _sum_as_flat_linear_expression( - to_process=list(zip(expressions, coefficients)), offset=checked_constant - ) - - @classmethod - def term( # pytype: disable=annotation-type-mismatch # numpy-scalars - cls, - expression: LinearExprT, - coefficient: NumberT, - *, - constant: NumberT = 0.0, - ) -> LinearExprT: - """Creates `expression * coefficient + constant`. - - It can perform simple simplifications and returns different object, - including the input. - Args: - expression: a linear expression or a constant. - coefficient: a numerical constant. - constant: a numerical constant. - - Returns: - a LinearExpr instance or a numerical constant. - """ - checked_coefficient: np.double = mbn.assert_is_a_number(coefficient) - checked_constant: np.double = mbn.assert_is_a_number(constant) - - if mbn.is_zero(checked_coefficient): - return checked_constant - if mbn.is_one(checked_coefficient) and mbn.is_zero(checked_constant): - return expression - if mbn.is_a_number(expression): - return np.double(expression) * checked_coefficient + checked_constant - if isinstance(expression, LinearExpr): - return _as_flat_linear_expression( - expression * checked_coefficient + checked_constant - ) - raise TypeError(f"Unknown expression {expression!r} of type {type(expression)}") - - def __hash__(self): - return object.__hash__(self) - - def __add__(self, arg: LinearExprT) -> "_Sum": - return _Sum(self, arg) - - def __radd__(self, arg: LinearExprT) -> "_Sum": - return self.__add__(arg) - - def __sub__(self, arg: LinearExprT) -> "_Sum": - return _Sum(self, -arg) - - def __rsub__(self, arg: LinearExprT) -> "_Sum": - return _Sum(-self, arg) - - def __mul__(self, arg: NumberT) -> "_Product": - return _Product(self, arg) - - def __rmul__(self, arg: NumberT) -> "_Product": - return self.__mul__(arg) - - def __truediv__(self, coeff: NumberT) -> "_Product": - return self.__mul__(1.0 / coeff) - - def __neg__(self) -> "_Product": - return _Product(self, -1) - - def __bool__(self): - raise NotImplementedError(f"Cannot use a LinearExpr {self} as a Boolean value") - - def __eq__(self, arg: LinearExprT) -> "BoundedLinearExpression": - return BoundedLinearExpression(self - arg, 0, 0) - - def __ge__(self, arg: LinearExprT) -> "BoundedLinearExpression": - return BoundedLinearExpression( - self - arg, 0, math.inf - ) # pytype: disable=wrong-arg-types # numpy-scalars - - def __le__(self, arg: LinearExprT) -> "BoundedLinearExpression": - return BoundedLinearExpression( - self - arg, -math.inf, 0 - ) # pytype: disable=wrong-arg-types # numpy-scalars - - -class Variable(LinearExpr): - """A variable (continuous or integral). - - A Variable is an object that can take on any integer value within defined - ranges. Variables appear in constraint like: - - x + y >= 5 - - Solving a model is equivalent to finding, for each variable, a single value - from the set of initial values (called the initial domain), such that the - model is feasible, or optimal if you provided an objective function. - """ - - def __init__( - self, - helper: mbh.ModelBuilderHelper, - lb: NumberT, - ub: Optional[NumberT], - is_integral: Optional[bool], - name: Optional[str], - ) -> None: - """See Model.new_var below.""" - LinearExpr.__init__(self) - self.__helper: mbh.ModelBuilderHelper = helper - # Python do not support multiple __init__ methods. - # This method is only called from the Model class. - # We hack the parameter to support the two cases: - # case 1: - # helper is a ModelBuilderHelper, lb is a double value, ub is a double - # value, is_integral is a Boolean value, and name is a string. - # case 2: - # helper is a ModelBuilderHelper, lb is an index (int), ub is None, - # is_integral is None, and name is None. - if mbn.is_integral(lb) and ub is None and is_integral is None: - self.__index: np.int32 = np.int32(lb) - self.__helper: mbh.ModelBuilderHelper = helper - else: - index: np.int32 = helper.add_var() - self.__index: np.int32 = np.int32(index) - self.__helper: mbh.ModelBuilderHelper = helper - helper.set_var_lower_bound(index, lb) - helper.set_var_upper_bound(index, ub) - helper.set_var_integrality(index, is_integral) - if name: - helper.set_var_name(index, name) - - @property - def index(self) -> np.int32: - """Returns the index of the variable in the helper.""" - return self.__index - - @property - def helper(self) -> mbh.ModelBuilderHelper: - """Returns the underlying ModelBuilderHelper.""" - return self.__helper - - def is_equal_to(self, other: LinearExprT) -> bool: - """Returns true if self == other in the python sense.""" - if not isinstance(other, Variable): - return False - return self.index == other.index and self.helper == other.helper - - def __str__(self) -> str: - return self.name - - def __repr__(self) -> str: - return self.__str__() - - @property - def name(self) -> str: - """Returns the name of the variable.""" - var_name = self.__helper.var_name(self.__index) - if var_name: - return var_name - return f"variable#{self.index}" - - @name.setter - def name(self, name: str) -> None: - """Sets the name of the variable.""" - self.__helper.set_var_name(self.__index, name) - - @property - def lower_bound(self) -> np.double: - """Returns the lower bound of the variable.""" - return self.__helper.var_lower_bound(self.__index) - - @lower_bound.setter - def lower_bound(self, bound: NumberT) -> None: - """Sets the lower bound of the variable.""" - self.__helper.set_var_lower_bound(self.__index, bound) - - @property - def upper_bound(self) -> np.double: - """Returns the upper bound of the variable.""" - return self.__helper.var_upper_bound(self.__index) - - @upper_bound.setter - def upper_bound(self, bound: NumberT) -> None: - """Sets the upper bound of the variable.""" - self.__helper.set_var_upper_bound(self.__index, bound) - - @property - def is_integral(self) -> bool: - """Returns whether the variable is integral.""" - return self.__helper.var_is_integral(self.__index) - - @is_integral.setter - def integrality(self, is_integral: bool) -> None: - """Sets the integrality of the variable.""" - self.__helper.set_var_integrality(self.__index, is_integral) - - @property - def objective_coefficient(self) -> NumberT: - return self.__helper.var_objective_coefficient(self.__index) - - @objective_coefficient.setter - def objective_coefficient(self, coeff: NumberT) -> None: - self.__helper.set_var_objective_coefficient(self.__index, coeff) - - def __eq__(self, arg: Optional[LinearExprT]) -> ConstraintT: - if arg is None: - return False - if isinstance(arg, Variable): - return VarEqVar(self, arg) - return BoundedLinearExpression( - self - arg, 0.0, 0.0 - ) # pytype: disable=wrong-arg-types # numpy-scalars - - def __hash__(self): - return hash((self.__helper, self.__index)) - - -class _BoundedLinearExpr(metaclass=abc.ABCMeta): - """Interface for types that can build bounded linear (boolean) expressions. - - Classes derived from _BoundedLinearExpr are used to build linear constraints - to be satisfied. - - * BoundedLinearExpression: a linear expression with upper and lower bounds. - * VarEqVar: an equality comparison between two variables. - """ - - @abc.abstractmethod - def _add_linear_constraint( - self, helper: mbh.ModelBuilderHelper, name: str - ) -> "LinearConstraint": - """Creates a new linear constraint in the helper. - - Args: - helper (mbh.ModelBuilderHelper): The helper to create the constraint. - name (str): The name of the linear constraint. - - Returns: - LinearConstraint: A reference to the linear constraint in the helper. - """ - - @abc.abstractmethod - def _add_enforced_linear_constraint( - self, - helper: mbh.ModelBuilderHelper, - var: Variable, - value: bool, - name: str, - ) -> "EnforcedLinearConstraint": - """Creates a new enforced linear constraint in the helper. - - Args: - helper (mbh.ModelBuilderHelper): The helper to create the constraint. - var (Variable): The indicator variable of the constraint. - value (bool): The indicator value of the constraint. - name (str): The name of the linear constraint. - - Returns: - Enforced LinearConstraint: A reference to the linear constraint in the - helper. - """ +Variable = mbh.Variable def _add_linear_constraint_to_helper( - bounded_expr: Union[bool, _BoundedLinearExpr], + bounded_expr: Union[bool, mbh.BoundedLinearExpression], helper: mbh.ModelBuilderHelper, name: Optional[str], ): @@ -448,14 +99,21 @@ def _add_linear_constraint_to_helper( helper.set_constraint_lower_bound(c.index, 1) helper.set_constraint_upper_bound(c.index, -1) return c - if isinstance(bounded_expr, _BoundedLinearExpr): + if isinstance(bounded_expr, mbh.BoundedLinearExpression): + c = LinearConstraint(helper) # pylint: disable=protected-access - return bounded_expr._add_linear_constraint(helper, name) - raise TypeError("invalid type={}".format(type(bounded_expr))) + helper.add_terms_to_constraint(c.index, bounded_expr.vars, bounded_expr.coeffs) + helper.set_constraint_lower_bound(c.index, bounded_expr.lower_bound) + helper.set_constraint_upper_bound(c.index, bounded_expr.upper_bound) + # pylint: enable=protected-access + if name is not None: + helper.set_constraint_name(c.index, name) + return c + raise TypeError(f"invalid type={type(bounded_expr).__name__!r}") def _add_enforced_linear_constraint_to_helper( - bounded_expr: Union[bool, _BoundedLinearExpr], + bounded_expr: Union[bool, mbh.BoundedLinearExpression], helper: mbh.ModelBuilderHelper, var: Variable, value: bool, @@ -502,153 +160,20 @@ def _add_enforced_linear_constraint_to_helper( helper.set_enforced_constraint_lower_bound(c.index, 1) helper.set_enforced_constraint_upper_bound(c.index, -1) return c - if isinstance(bounded_expr, _BoundedLinearExpr): - # pylint: disable=protected-access - return bounded_expr._add_enforced_linear_constraint(helper, var, value, name) - raise TypeError("invalid type={}".format(type(bounded_expr))) - - -@dataclasses.dataclass(repr=False, eq=False, frozen=True) -class VarEqVar(_BoundedLinearExpr): - """Represents var == var.""" - - __slots__ = ("left", "right") - - left: Variable - right: Variable - - def __str__(self): - return f"{self.left} == {self.right}" - - def __repr__(self): - return self.__str__() - - def __bool__(self) -> bool: - return hash(self.left) == hash(self.right) - - def _add_linear_constraint( - self, helper: mbh.ModelBuilderHelper, name: str - ) -> "LinearConstraint": - c = LinearConstraint(helper) - helper.set_constraint_lower_bound(c.index, 0.0) - helper.set_constraint_upper_bound(c.index, 0.0) - # pylint: disable=protected-access - helper.add_term_to_constraint(c.index, self.left.index, 1.0) - helper.add_term_to_constraint(c.index, self.right.index, -1.0) - # pylint: enable=protected-access - helper.set_constraint_name(c.index, name) - return c - - def _add_enforced_linear_constraint( - self, - helper: mbh.ModelBuilderHelper, - var: Variable, - value: bool, - name: str, - ) -> "EnforcedLinearConstraint": - """Adds an enforced linear constraint to the model.""" + if isinstance(bounded_expr, mbh.BoundedLinearExpression): c = EnforcedLinearConstraint(helper) c.indicator_variable = var c.indicator_value = value - helper.set_enforced_constraint_lower_bound(c.index, 0.0) - helper.set_enforced_constraint_upper_bound(c.index, 0.0) - # pylint: disable=protected-access - helper.add_term_to_enforced_constraint(c.index, self.left.index, 1.0) - helper.add_term_to_enforced_constraint(c.index, self.right.index, -1.0) - # pylint: enable=protected-access - helper.set_enforced_constraint_name(c.index, name) - return c - - -class BoundedLinearExpression(_BoundedLinearExpr): - """Represents a linear constraint: `lb <= linear expression <= ub`. - - The only use of this class is to be added to the Model through - `Model.add(bounded expression)`, as in: - - model.Add(x + 2 * y -1 >= z) - """ - - def __init__(self, expr: LinearExprT, lb: NumberT, ub: NumberT) -> None: - self.__expr: LinearExprT = expr - self.__lb: np.double = mbn.assert_is_a_number(lb) - self.__ub: np.double = mbn.assert_is_a_number(ub) - - def __str__(self) -> str: - if self.__lb > -math.inf and self.__ub < math.inf: - if self.__lb == self.__ub: - return f"{self.__expr} == {self.__lb}" - else: - return f"{self.__lb} <= {self.__expr} <= {self.__ub}" - elif self.__lb > -math.inf: - return f"{self.__expr} >= {self.__lb}" - elif self.__ub < math.inf: - return f"{self.__expr} <= {self.__ub}" - else: - return f"{self.__expr} free" - - def __repr__(self): - return self.__str__() - - @property - def expression(self) -> LinearExprT: - return self.__expr - - @property - def lower_bound(self) -> np.double: - return self.__lb - - @property - def upper_bound(self) -> np.double: - return self.__ub - - def __bool__(self) -> bool: - raise NotImplementedError( - f"Cannot use a BoundedLinearExpression {self} as a Boolean value" + helper.add_terms_to_enforced_constraint( + c.index, bounded_expr.vars, bounded_expr.coeffs ) - - def _add_linear_constraint( - self, helper: mbh.ModelBuilderHelper, name: Optional[str] - ) -> "LinearConstraint": - c = LinearConstraint(helper) - flat_expr = _as_flat_linear_expression(self.__expr) - # pylint: disable=protected-access - helper.add_terms_to_constraint( - c.index, flat_expr._variable_indices, flat_expr._coefficients - ) - helper.set_constraint_lower_bound(c.index, self.__lb - flat_expr._offset) - helper.set_constraint_upper_bound(c.index, self.__ub - flat_expr._offset) - # pylint: enable=protected-access + helper.set_enforced_constraint_lower_bound(c.index, bounded_expr.lower_bound) + helper.set_enforced_constraint_upper_bound(c.index, bounded_expr.upper_bound) if name is not None: helper.set_constraint_name(c.index, name) return c - def _add_enforced_linear_constraint( - self, - helper: mbh.ModelBuilderHelper, - var: Variable, - value: bool, - name: Optional[str], - ) -> "EnforcedLinearConstraint": - """Adds an enforced linear constraint to the model.""" - c = EnforcedLinearConstraint(helper) - c.indicator_variable = var - c.indicator_value = value - flat_expr = _as_flat_linear_expression(self.__expr) - # pylint: disable=protected-access - helper.add_terms_to_enforced_constraint( - c.index, flat_expr._variable_indices, flat_expr._coefficients - ) - helper.set_enforced_constraint_lower_bound( - c.index, self.__lb - flat_expr._offset - ) - helper.set_enforced_constraint_upper_bound( - c.index, self.__ub - flat_expr._offset - ) - # pylint: enable=protected-access - if name is not None: - helper.set_enforced_constraint_name(c.index, name) - return c + raise TypeError(f"invalid type={type(bounded_expr).__name__!r}") class LinearConstraint: @@ -683,6 +208,9 @@ class LinearConstraint: self.__helper: mbh.ModelBuilderHelper = helper self.__is_under_specified = is_under_specified + def __hash__(self): + return hash((self.__helper, self.__index)) + @property def index(self) -> IntegerT: """Returns the index of the constraint in the helper.""" @@ -837,7 +365,7 @@ class EnforcedLinearConstraint: enforcement_var_index = ( self.__helper.enforced_constraint_indicator_variable_index(self.__index) ) - return Variable(self.__helper, enforcement_var_index, None, None, None) + return Variable(self.__helper, enforcement_var_index) @indicator_variable.setter def indicator_variable(self, var: "Variable") -> None: @@ -984,15 +512,14 @@ class Model: """ return _attribute_series( # pylint: disable=g-long-lambda - func=lambda c: _as_flat_linear_expression( + func=lambda c: mbh.FlatExpr( # pylint: disable=g-complex-comprehension - sum( - coeff * Variable(self.__helper, var_id, None, None, None) - for var_id, coeff in zip( - c.helper.constraint_var_indices(c.index), - c.helper.constraint_coefficients(c.index), - ) - ) + [ + Variable(self.__helper, var_id) + for var_id in c.helper.constraint_var_indices(c.index) + ], + c.helper.constraint_coefficients(c.index), + 0.0, ), values=self._get_linear_constraints(constraints), ) @@ -1106,8 +633,10 @@ class Model: Returns: a variable whose domain is [lb, ub]. """ - - return Variable(self.__helper, lb, ub, is_integer, name) + if name: + return Variable(self.__helper, lb, ub, is_integer, name) + else: + return Variable(self.__helper, lb, ub, is_integer) def new_int_var( self, lb: NumberT, ub: NumberT, name: Optional[str] = None @@ -1187,16 +716,15 @@ class Model: if not isinstance(index, pd.Index): raise TypeError("Non-index object is used as index") if not name.isidentifier(): - raise ValueError("name={} is not a valid identifier".format(name)) + raise ValueError(f"name={name!r} is not a valid identifier") if ( mbn.is_a_number(lower_bounds) and mbn.is_a_number(upper_bounds) and lower_bounds > upper_bounds ): raise ValueError( - "lower_bound={} is greater than upper_bound={} for variable set={}".format( - lower_bounds, upper_bounds, name - ) + f"lower_bound={lower_bounds} is greater than" + f" upper_bound={upper_bounds} for variable set={name!r}" ) if ( isinstance(is_integral, bool) @@ -1208,10 +736,9 @@ class Model: and math.ceil(lower_bounds) > math.floor(upper_bounds) ): raise ValueError( - "ceil(lower_bound={})={}".format(lower_bounds, math.ceil(lower_bounds)) - + " is greater than floor(" - + "upper_bound={})={}".format(upper_bounds, math.floor(upper_bounds)) - + " for variable set={}".format(name) + f"ceil(lower_bound={lower_bounds})={math.ceil(lower_bounds)}" + f" is greater than floor({upper_bounds}) = {math.floor(upper_bounds)}" + f" for variable set={name!r}" ) lower_bounds = _convert_to_series_and_validate_index(lower_bounds, index) upper_bounds = _convert_to_series_and_validate_index(upper_bounds, index) @@ -1221,11 +748,11 @@ class Model: data=[ # pylint: disable=g-complex-comprehension Variable( - helper=self.__helper, - name=f"{name}[{i}]", - lb=lower_bounds[i], - ub=upper_bounds[i], - is_integral=is_integrals[i], + self.__helper, + lower_bounds[i], + upper_bounds[i], + is_integrals[i], + f"{name}[{i}]", ) for i in index ], @@ -1318,7 +845,7 @@ class Model: def var_from_index(self, index: IntegerT) -> Variable: """Rebuilds a variable object from the model and its index.""" - return Variable(self.__helper, index, None, None, None) + return Variable(self.__helper, index) # Linear constraints. @@ -1336,22 +863,18 @@ class Model: if mbn.is_a_number(linear_expr): self.__helper.set_constraint_lower_bound(ct.index, lb - linear_expr) self.__helper.set_constraint_upper_bound(ct.index, ub - linear_expr) - elif isinstance(linear_expr, Variable): - self.__helper.set_constraint_lower_bound(ct.index, lb) - self.__helper.set_constraint_upper_bound(ct.index, ub) - self.__helper.add_term_to_constraint(ct.index, linear_expr.index, 1.0) elif isinstance(linear_expr, LinearExpr): - flat_expr = _as_flat_linear_expression(linear_expr) + flat_expr = mbh.FlatExpr(linear_expr) # pylint: disable=protected-access - self.__helper.set_constraint_lower_bound(ct.index, lb - flat_expr._offset) - self.__helper.set_constraint_upper_bound(ct.index, ub - flat_expr._offset) + self.__helper.set_constraint_lower_bound(ct.index, lb - flat_expr.offset) + self.__helper.set_constraint_upper_bound(ct.index, ub - flat_expr.offset) self.__helper.add_terms_to_constraint( - ct.index, flat_expr._variable_indices, flat_expr._coefficients + ct.index, flat_expr.vars, flat_expr.coeffs ) else: raise TypeError( - f"Not supported: Model.add_linear_constraint({linear_expr})" - f" with type {type(linear_expr)}" + "Not supported:" + f" Model.add_linear_constraint({type(linear_expr).__name__!r})" ) return ct @@ -1381,8 +904,8 @@ class Model: you can check the if a constraint is under specified by reading the `LinearConstraint.is_under_specified` property. """ - if isinstance(ct, _BoundedLinearExpr): - return ct._add_linear_constraint(self.__helper, name) + if isinstance(ct, mbh.BoundedLinearExpression): + return _add_linear_constraint_to_helper(ct, self.__helper, name) elif isinstance(ct, bool): return _add_linear_constraint_to_helper(ct, self.__helper, name) elif isinstance(ct, pd.Series): @@ -1396,13 +919,13 @@ class Model: ], ) else: - raise TypeError("Not supported: Model.add(" + str(ct) + ")") + raise TypeError(f"Not supported: Model.add({type(ct).__name__!r})") def linear_constraint_from_index(self, index: IntegerT) -> LinearConstraint: """Rebuilds a linear constraint object from the model and its index.""" return LinearConstraint(self.__helper, index=index) - # EnforcedLinear constraints. + # Enforced Linear constraints. def add_enforced_linear_constraint( # pytype: disable=annotation-type-mismatch # numpy-scalars self, @@ -1422,23 +945,18 @@ class Model: if mbn.is_a_number(linear_expr): self.__helper.set_constraint_lower_bound(ct.index, lb - linear_expr) self.__helper.set_constraint_upper_bound(ct.index, ub - linear_expr) - elif isinstance(linear_expr, Variable): - self.__helper.set_constraint_lower_bound(ct.index, lb) - self.__helper.set_constraint_upper_bound(ct.index, ub) - self.__helper.add_term_to_constraint(ct.index, linear_expr.index, 1.0) elif isinstance(linear_expr, LinearExpr): - flat_expr = _as_flat_linear_expression(linear_expr) + flat_expr = mbh.FlatExpr(linear_expr) # pylint: disable=protected-access - self.__helper.set_constraint_lower_bound(ct.index, lb - flat_expr._offset) - self.__helper.set_constraint_upper_bound(ct.index, ub - flat_expr._offset) + self.__helper.set_constraint_lower_bound(ct.index, lb - flat_expr.offset) + self.__helper.set_constraint_upper_bound(ct.index, ub - flat_expr.offset) self.__helper.add_terms_to_constraint( - ct.index, flat_expr._variable_indices, flat_expr._coefficients + ct.index, flat_expr.vars, flat_expr.coeffs ) else: raise TypeError( "Not supported:" - f" Model.add_enforced_linear_constraint({linear_expr}) with" - f" type {type(linear_expr)}" + f" Model.add_enforced_linear_constraint({type(linear_expr).__name__!r})" ) return ct @@ -1472,8 +990,10 @@ class Model: you can check the if a constraint is always false (lb=inf, ub=-inf) by calling EnforcedLinearConstraint.is_always_false() """ - if isinstance(ct, _BoundedLinearExpr): - return ct._add_enforced_linear_constraint(self.__helper, var, value, name) + if isinstance(ct, mbh.BoundedLinearExpression): + return _add_enforced_linear_constraint_to_helper( + ct, self.__helper, var, value, name + ) elif ( isinstance(ct, bool) and isinstance(var, Variable) @@ -1499,7 +1019,7 @@ class Model: ], ) else: - raise TypeError("Not supported: Model.add_enforced(" + str(ct) + ")") + raise TypeError(f"Not supported: Model.add_enforced({type(ct).__name__!r}") def enforced_linear_constraint_from_index( self, index: IntegerT @@ -1525,14 +1045,16 @@ class Model: elif isinstance(linear_expr, Variable): self.helper.set_var_objective_coefficient(linear_expr.index, 1.0) elif isinstance(linear_expr, LinearExpr): - flat_expr = _as_flat_linear_expression(linear_expr) + flat_expr = mbh.FlatExpr(linear_expr) # pylint: disable=protected-access - self.helper.set_objective_offset(flat_expr._offset) - self.helper.set_objective_coefficients( - flat_expr._variable_indices, flat_expr._coefficients - ) + self.helper.set_objective_offset(flat_expr.offset) + var_indices = [var.index for var in flat_expr.vars] + self.helper.set_objective_coefficients(var_indices, flat_expr.coeffs) else: - raise TypeError(f"Not supported: Model.minimize/maximize({linear_expr})") + raise TypeError( + "Not supported:" + f" Model.minimize/maximize({type(linear_expr).__name__!r})" + ) @property def objective_offset(self) -> np.double: @@ -1543,16 +1065,16 @@ class Model: def objective_offset(self, value: NumberT) -> None: self.__helper.set_objective_offset(value) - def objective_expression(self) -> "_LinearExpression": + def objective_expression(self) -> "LinearExpr": """Returns the expression to optimize.""" - return _as_flat_linear_expression( - sum( - variable * self.__helper.var_objective_coefficient(variable.index) - for variable in self.get_variables() - if self.__helper.var_objective_coefficient(variable.index) != 0.0 - ) - + self.__helper.objective_offset() - ) + variables: list[Variable] = [] + coefficients: list[numbers.Real] = [] + for variable in self.get_variables(): + coeff = self.__helper.var_objective_coefficient(variable.index) + if coeff != 0.0: + variables.append(variable) + coefficients.append(coeff) + return mbh.FlatExpr(variables, coefficients, self.__helper.objective_offset()) # Hints. def clear_hints(self): @@ -1712,17 +1234,10 @@ class Solver: return pd.NA if mbn.is_a_number(expr): return expr - elif isinstance(expr, Variable): - return self.__solve_helper.var_value(expr.index) elif isinstance(expr, LinearExpr): - flat_expr = _as_flat_linear_expression(expr) - return self.__solve_helper.expression_value( - flat_expr._variable_indices, - flat_expr._coefficients, - flat_expr._offset, - ) + return self.__solve_helper.expression_value(expr) else: - raise TypeError(f"Unknown expression {expr!r} of type {type(expr)}") + raise TypeError(f"Unknown expression {type(expr).__name__!r}") def values(self, variables: _IndexOrSeries) -> pd.Series: """Returns the values of the input variables. @@ -1742,7 +1257,7 @@ class Solver: if not self.__solve_helper.has_solution(): return _attribute_series(func=lambda v: pd.NA, values=variables) return _attribute_series( - func=lambda v: self.__solve_helper.var_value(v.index), + func=lambda v: self.__solve_helper.variable_value(v.index), values=variables, ) @@ -1839,164 +1354,6 @@ class Solver: return self.__solve_helper.user_time() -# The maximum number of terms to display in a linear expression's repr. -_MAX_LINEAR_EXPRESSION_REPR_TERMS = 5 - - -@dataclasses.dataclass(repr=False, eq=False, frozen=True) -class _LinearExpression(LinearExpr): - """For variables x, an expression: offset + sum_{i in I} coeff_i * x_i.""" - - __slots__ = ("_variable_indices", "_coefficients", "_offset", "_helper") - - _variable_indices: npt.NDArray[np.int32] - _coefficients: npt.NDArray[np.double] - _offset: float - _helper: Optional[mbh.ModelBuilderHelper] - - @property - def variable_indices(self) -> npt.NDArray[np.int32]: - return self._variable_indices - - @property - def coefficients(self) -> npt.NDArray[np.double]: - return self._coefficients - - @property - def constant(self) -> float: - return self._offset - - @property - def helper(self) -> Optional[mbh.ModelBuilderHelper]: - return self._helper - - def __repr__(self): - return self.__str__() - - def __str__(self): - if self._helper is None: - return str(self._offset) - - result = [] - for index, coeff in zip(self.variable_indices, self.coefficients): - if len(result) >= _MAX_LINEAR_EXPRESSION_REPR_TERMS: - result.append(" + ...") - break - var_name = Variable(self._helper, index, None, None, None).name - if not result and mbn.is_one(coeff): - result.append(var_name) - elif not result and mbn.is_minus_one(coeff): - result.append(f"-{var_name}") - elif not result: - result.append(f"{coeff} * {var_name}") - elif mbn.is_one(coeff): - result.append(f" + {var_name}") - elif mbn.is_minus_one(coeff): - result.append(f" - {var_name}") - elif coeff > 0.0: - result.append(f" + {coeff} * {var_name}") - elif coeff < 0.0: - result.append(f" - {-coeff} * {var_name}") - - if not result: - return f"{self.constant}" - if self.constant > 0: - result.append(f" + {self.constant}") - elif self.constant < 0: - result.append(f" - {-self.constant}") - return "".join(result) - - -def _sum_as_flat_linear_expression( - to_process: List[Tuple[LinearExprT, float]], offset: float = 0.0 -) -> _LinearExpression: - """Creates a _LinearExpression as the sum of terms.""" - indices = [] - coeffs = [] - helper = None - while to_process: # Flatten AST of LinearTypes. - expr, coeff = to_process.pop() - if isinstance(expr, _Sum): - to_process.append((expr._left, coeff)) - to_process.append((expr._right, coeff)) - elif isinstance(expr, Variable): - indices.append([expr.index]) - coeffs.append([coeff]) - if helper is None: - helper = expr.helper - elif mbn.is_a_number(expr): - offset += coeff * cast(NumberT, expr) - elif isinstance(expr, _Product): - to_process.append((expr._expression, coeff * expr._coefficient)) - elif isinstance(expr, _LinearExpression): - offset += coeff * expr._offset - if expr._helper is not None: - indices.append(expr.variable_indices) - coeffs.append(np.multiply(expr.coefficients, coeff)) - if helper is None: - helper = expr._helper - else: - raise TypeError( - "Unrecognized linear expression: " + str(expr) + f" {type(expr)}" - ) - - if helper is not None: - all_indices: npt.NDArray[np.int32] = np.concatenate(indices, axis=0) - all_coeffs: npt.NDArray[np.double] = np.concatenate(coeffs, axis=0) - sorted_indices, sorted_coefficients = helper.sort_and_regroup_terms( - all_indices, all_coeffs - ) - return _LinearExpression(sorted_indices, sorted_coefficients, offset, helper) - else: - assert not indices - assert not coeffs - return _LinearExpression( - _variable_indices=np.zeros(dtype=np.int32, shape=[0]), - _coefficients=np.zeros(dtype=np.double, shape=[0]), - _offset=offset, - _helper=None, - ) - - -def _as_flat_linear_expression(base_expr: LinearExprT) -> _LinearExpression: - """Converts floats, ints and Linear objects to a LinearExpression.""" - if isinstance(base_expr, _LinearExpression): - return base_expr - return _sum_as_flat_linear_expression(to_process=[(base_expr, 1.0)], offset=0.0) - - -@dataclasses.dataclass(repr=False, eq=False, frozen=True) -class _Sum(LinearExpr): - """Represents the (deferred) sum of two expressions.""" - - __slots__ = ("_left", "_right") - - _left: LinearExprT - _right: LinearExprT - - def __repr__(self): - return self.__str__() - - def __str__(self): - return str(_as_flat_linear_expression(self)) - - -@dataclasses.dataclass(repr=False, eq=False, frozen=True) -class _Product(LinearExpr): - """Represents the (deferred) product of an expression by a constant.""" - - __slots__ = ("_expression", "_coefficient") - - _expression: LinearExpr - _coefficient: NumberT - - def __repr__(self): - return self.__str__() - - def __str__(self): - return str(_as_flat_linear_expression(self)) - - def _get_index(obj: _IndexOrSeries) -> pd.Index: """Returns the indices of `obj` as a `pd.Index`.""" if isinstance(obj, pd.Series): @@ -2048,7 +1405,7 @@ def _convert_to_series_and_validate_index( else: raise ValueError("index does not match") else: - raise TypeError("invalid type={}".format(type(value_or_series))) + raise TypeError("invalid type={type(value_or_series).__name!r}") return result @@ -2076,7 +1433,7 @@ def _convert_to_var_series_and_validate_index( else: raise ValueError("index does not match") else: - raise TypeError("invalid type={}".format(type(var_or_series))) + raise TypeError("invalid type={type(value_or_series).__name!r}") return result diff --git a/ortools/linear_solver/python/model_builder_helper.cc b/ortools/linear_solver/python/model_builder_helper.cc index 12478aa47a..0d65d87266 100644 --- a/ortools/linear_solver/python/model_builder_helper.cc +++ b/ortools/linear_solver/python/model_builder_helper.cc @@ -16,9 +16,8 @@ #include "ortools/linear_solver/wrappers/model_builder_helper.h" #include -#include +#include #include -#include #include #include #include @@ -28,12 +27,15 @@ #include "Eigen/Core" #include "Eigen/SparseCore" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/linear_solver/linear_solver.pb.h" #include "ortools/linear_solver/model_exporter.h" +#include "pybind11/cast.h" #include "pybind11/eigen.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" @@ -42,19 +44,32 @@ using ::Eigen::SparseMatrix; using ::Eigen::VectorXd; -using ::operations_research::ModelBuilderHelper; -using ::operations_research::ModelSolverHelper; using ::operations_research::MPConstraintProto; using ::operations_research::MPModelExportOptions; using ::operations_research::MPModelProto; using ::operations_research::MPModelRequest; using ::operations_research::MPSolutionResponse; using ::operations_research::MPVariableProto; -using ::operations_research::SolveStatus; +using ::operations_research::mb::AffineExpr; +using ::operations_research::mb::BoundedLinearExpression; +using ::operations_research::mb::FixedValue; +using ::operations_research::mb::FlatExpr; +using ::operations_research::mb::LinearExpr; +using ::operations_research::mb::ModelBuilderHelper; +using ::operations_research::mb::ModelSolverHelper; +using ::operations_research::mb::SolveStatus; +using ::operations_research::mb::SumArray; +using ::operations_research::mb::Variable; +using ::operations_research::mb::WeightedSumArray; namespace py = pybind11; using ::py::arg; +void ThrowError(PyObject* py_exception, const std::string& message) { + PyErr_SetString(py_exception, message.c_str()); + throw py::error_already_set(); +} + const MPModelProto& ToMPModelProto(ModelBuilderHelper* helper) { return helper->model(); } @@ -153,9 +168,369 @@ std::vector> SortedGroupedTerms( return terms; } +const char* kLinearExprClassDoc = R"doc( +Holds an linear expression. + +A linear expression is built from constants and variables. +For example, `x + 2.0 * (y - z + 1.0)`. + +Linear expressions are used in Model models in constraints and in the objective: + + * You can define linear constraints as in: + +``` + model.add(x + 2 * y <= 5.0) + model.add(sum(array_of_vars) == 5.0) +``` + + * In Model, the objective is a linear expression: + +``` + model.minimize(x + 2.0 * y + z) +``` + + * For large arrays, using the LinearExpr class is faster that using the python + `sum()` function. You can create constraints and the objective from lists of + linear expressions or coefficients as follows: + +``` + model.minimize(model_builder.LinearExpr.sum(expressions)) + model.add(model_builder.LinearExpr.weighted_sum(expressions, coeffs) >= 0) +``` +)doc"; + +const char* kVarClassDoc = R"doc(A variable (continuous or integral). + + A Variable is an object that can take on any integer value within defined + ranges. Variables appear in constraint like: + + x + y >= 5 + + Solving a model is equivalent to finding, for each variable, a single value + from the set of initial values (called the initial domain), such that the + model is feasible, or optimal if you provided an objective function. +)doc"; + +void ProcessExprArg(const py::handle& arg, LinearExpr*& expr, + double& float_value) { + if (py::isinstance(arg)) { + expr = arg.cast(); + } else { + float_value = arg.cast(); + } +} + +LinearExpr* SumArguments(py::args args, const py::kwargs& kwargs) { + std::vector linear_exprs; + double float_offset = 0.0; + + const auto process_arg = [&](const py::handle& arg) -> void { + if (py::isinstance(arg)) { + linear_exprs.push_back(arg.cast()); + } else { + float_offset += arg.cast(); + } + }; + + if (args.size() == 1 && py::isinstance(args[0])) { + // Normal list or tuple argument. + py::sequence elements = args[0].cast(); + linear_exprs.reserve(elements.size()); + for (const py::handle& arg : elements) { + process_arg(arg); + } + } else { // Direct sum(x, y, 3, ..) without []. + linear_exprs.reserve(args.size()); + for (const py::handle arg : args) { + process_arg(arg); + } + } + + if (kwargs) { + for (const auto arg : kwargs) { + const std::string arg_name = std::string(py::str(arg.first)); + if (arg_name == "constant") { + float_offset += arg.second.cast(); + } else { + ThrowError(PyExc_ValueError, + absl::StrCat("Unknown keyword argument: ", arg_name)); + } + } + } + + if (linear_exprs.empty()) { + return new FixedValue(float_offset); + } else if (linear_exprs.size() == 1) { + if (float_offset == 0.0) { + return linear_exprs[0]; + } else { + return new AffineExpr(linear_exprs[0], 1.0, float_offset); + } + } else { + return new SumArray(linear_exprs, float_offset); + } +} + +LinearExpr* WeightedSumArguments(py::sequence expressions, + const std::vector& coefficients, + double offset = 0.0) { + if (expressions.size() != coefficients.size()) { + ThrowError(PyExc_ValueError, + absl::StrCat("LinearExpr::weighted_sum() requires the same " + "number of arguments and coefficients: ", + expressions.size(), " != ", coefficients.size())); + } + + std::vector linear_exprs; + std::vector coeffs; + linear_exprs.reserve(expressions.size()); + coeffs.reserve(expressions.size()); + + for (int i = 0; i < expressions.size(); ++i) { + py::handle arg = expressions[i]; + LinearExpr* expr = nullptr; + double value = 0.0; + ProcessExprArg(arg, expr, value); + if (expr != nullptr && coefficients[i] != 0.0) { + linear_exprs.push_back(expr); + coeffs.push_back(coefficients[i]); + continue; + } else if (value != 0.0) { + offset += coefficients[i] * value; + } + } + + if (linear_exprs.empty()) { + return new FixedValue(offset); + } else if (linear_exprs.size() == 1) { + if (offset == 0.0 && coeffs[0] == 1.0) { + return linear_exprs[0]; + } else { + return new AffineExpr(linear_exprs[0], coeffs[0], offset); + } + } else { + return new WeightedSumArray(linear_exprs, coeffs, offset); + } +} + PYBIND11_MODULE(model_builder_helper, m) { pybind11_protobuf::ImportNativeProtoCasters(); + py::class_(m, "LinearExpr", kLinearExprClassDoc) + .def_static("sum", &SumArguments, + "Creates `sum(expressions) [+ constant]`.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static( + "weighted_sum", &WeightedSumArguments, + "Creates `sum(expressions[i] * coefficients[i]) [+ constant]`.", + arg("expressions"), arg("coefficients"), py::kw_only(), + arg("constant") = 0.0, py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("term", &LinearExpr::Term, arg("expr").none(false), + arg("coeff"), "Returns expr * coeff.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("term", &LinearExpr::Affine, arg("expr").none(false), + arg("coeff"), py::kw_only(), py::arg("constant"), + "Returns expr * coeff [+ constant].", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("term", &LinearExpr::AffineCst, arg("value"), arg("coeff"), + py::kw_only(), py::arg("constant"), + "Returns value * coeff [+ constant].", + py::return_value_policy::automatic) + .def_static("affine", &LinearExpr::Affine, arg("expr").none(false), + arg("coeff"), arg("constant") = 0.0, + "Returns expr * coeff + constant.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("affine", &LinearExpr::AffineCst, arg("value"), arg("coeff"), + arg("constant") = 0.0, "Returns value * coeff + constant.", + py::return_value_policy::automatic) + .def_static("constant", &LinearExpr::Constant, arg("value"), + "Returns a constant linear expression.", + py::return_value_policy::automatic) + // Methods. + .def("__str__", &LinearExpr::ToString) + .def("__repr__", &LinearExpr::DebugString) + // Operators. + // Note that we keep the 3 APIS (expr, int, double) instead of using an + // py::handle argument as this is more efficient. + .def("__add__", &LinearExpr::Add, arg("other").none(false), + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__add__", &LinearExpr::AddFloat, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__radd__", &LinearExpr::AddFloat, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__sub__", &LinearExpr::Sub, arg("other").none(false), + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__sub__", &LinearExpr::SubFloat, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rsub__", &LinearExpr::RSubFloat, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__mul__", &LinearExpr::MulFloat, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rmul__", &LinearExpr::MulFloat, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def( + "__truediv__", + [](LinearExpr* self, double cst) { + if (cst == 0.0) { + ThrowError(PyExc_ZeroDivisionError, + "Division by zero is not supported."); + } + return self->MulFloat(1.0 / cst); + }, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__neg__", &LinearExpr::Neg, py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + // Comparison operators. + .def("__eq__", &LinearExpr::Eq, arg("other").none(false), + "Creates the constraint `self == other`.", + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__eq__", &LinearExpr::EqCst, arg("cst"), + "Creates the constraint `self == cst`.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__le__", &LinearExpr::Le, arg("other").none(false), + "Creates the constraint `self <= other`.", + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__le__", &LinearExpr::LeCst, arg("cst"), + "Creates the constraint `self <= cst`.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__ge__", &LinearExpr::Ge, arg("other").none(false), + "Creates the constraint `self >= other`.", + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__ge__", &LinearExpr::GeCst, arg("cst"), + "Creates the constraint `self >= cst`.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + // Disable other operators as they are not supported. + .def("__floordiv__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling // on a linear expression is not supported."); + }) + .def("__mod__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling %% on a linear expression is not supported."); + }) + .def("__pow__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling ** on a linear expression is not supported."); + }) + .def("__lshift__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError( + PyExc_NotImplementedError, + "calling left shift on a linear expression is not supported"); + }) + .def("__rshift__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError( + PyExc_NotImplementedError, + "calling right shift on a linear expression is not supported"); + }) + .def("__and__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling and on a linear expression is not supported"); + }) + .def("__or__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling or on a linear expression is not supported"); + }) + .def("__xor__", + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling xor on a linear expression is not supported"); + }) + .def("__abs__", + [](LinearExpr* /*self*/) { + ThrowError( + PyExc_NotImplementedError, + "calling abs() on a linear expression is not supported."); + }) + .def("__bool__", [](LinearExpr* /*self*/) { + ThrowError(PyExc_NotImplementedError, + "Evaluating a LinearExpr instance as a Boolean is " + "not supported."); + }); + + // Expose Internal classes, mostly for testing. + py::class_(m, "FlatExpr") + .def(py::init()) + .def(py::init()) + .def(py::init&, + const std::vector&, double>(), + py::keep_alive<1, 2>()) + .def(py::init()) + .def_property_readonly("vars", &FlatExpr::vars) + .def("variable_indices", &FlatExpr::VarIndices) + .def_property_readonly("coeffs", &FlatExpr::coeffs) + .def_property_readonly("offset", &FlatExpr::offset); + + py::class_(m, "AffineExpr") + .def(py::init()) + .def_property_readonly("expression", &AffineExpr ::expression) + .def_property_readonly("coefficient", &AffineExpr::coefficient) + .def_property_readonly("offset", &AffineExpr::offset); + + py::class_(m, "Variable", kVarClassDoc) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def_property_readonly("index", &Variable::index, + "The index of the variable in the model.") + .def_property_readonly("helper", &Variable::helper, + "The ModelBuilderHelper instance.") + .def_property("name", &Variable::name, &Variable::SetName, + "The name of the variable in the model.") + .def_property("lower_bound", &Variable::lower_bounds, + &Variable::SetLowerBound) + .def_property("upper_bound", &Variable::upper_bound, + &Variable::SetUpperBound) + .def_property("is_integral", &Variable::is_integral, + &Variable::SetIsIntegral) + .def_property("objective_coefficient", &Variable::objective_coefficient, + &Variable::SetObjectiveCoefficient) + .def("__str__", &Variable::ToString) + .def("__repr__", &Variable::DebugString) + .def("__hash__", [](const Variable& self) { + return absl::HashOf(std::make_tuple(self.helper(), self.index())); + }); + + py::class_(m, "BoundedLinearExpression") + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def_property_readonly("vars", &BoundedLinearExpression::vars) + .def_property_readonly("coeffs", &BoundedLinearExpression::coeffs) + .def_property_readonly("lower_bound", + &BoundedLinearExpression::lower_bound) + .def_property_readonly("upper_bound", + &BoundedLinearExpression::upper_bound) + .def("__bool__", + [](const BoundedLinearExpression& self) { + bool result; + if (self.CastToBool(&result)) return result; + ThrowError(PyExc_NotImplementedError, + absl::StrCat("Evaluating a BoundedLinearExpression '", + self.ToString(), + "'instance as a Boolean is " + "not supported.") + .c_str()); + return false; + }) + .def("__str__", &BoundedLinearExpression::ToString) + .def("__repr__", &BoundedLinearExpression::DebugString); + m.def("to_mpmodel_proto", &ToMPModelProto, arg("helper")); py::class_(m, "MPModelExportOptions") @@ -314,11 +689,11 @@ PYBIND11_MODULE(model_builder_helper, m) { arg("ct_index"), arg("var_index"), arg("coeff")) .def("add_terms_to_constraint", [](ModelBuilderHelper* helper, int ct_index, - const std::vector& indices, + const std::vector& vars, const std::vector& coefficients) { - for (const auto& [i, c] : - SortedGroupedTerms(indices, coefficients)) { - helper->AddConstraintTerm(ct_index, i, c); + for (int i = 0; i < vars.size(); ++i) { + helper->AddConstraintTerm(ct_index, vars[i]->index(), + coefficients[i]); } }) .def("safe_add_term_to_constraint", @@ -354,11 +729,11 @@ PYBIND11_MODULE(model_builder_helper, m) { arg("var_index"), arg("coeff")) .def("add_terms_to_enforced_constraint", [](ModelBuilderHelper* helper, int ct_index, - const std::vector& indices, + const std::vector& vars, const std::vector& coefficients) { - for (const auto& [i, c] : - SortedGroupedTerms(indices, coefficients)) { - helper->AddEnforcedConstraintTerm(ct_index, i, c); + for (int i = 0; i < vars.size(); ++i) { + helper->AddEnforcedConstraintTerm(ct_index, vars[i]->index(), + coefficients[i]); } }) .def("safe_add_term_to_enforced_constraint", @@ -402,22 +777,7 @@ PYBIND11_MODULE(model_builder_helper, m) { .def("objective_offset", &ModelBuilderHelper::ObjectiveOffset) .def("clear_hints", &ModelBuilderHelper::ClearHints) .def("add_hint", &ModelBuilderHelper::AddHint, arg("var_index"), - arg("var_value")) - .def("sort_and_regroup_terms", - [](ModelBuilderHelper* helper, py::array_t indices, - py::array_t coefficients) { - const std::vector> terms = - SortedGroupedTerms(indices, coefficients); - std::vector sorted_indices; - std::vector sorted_coefficients; - sorted_indices.reserve(terms.size()); - sorted_coefficients.reserve(terms.size()); - for (const auto& [i, c] : terms) { - sorted_indices.push_back(i); - sorted_coefficients.push_back(c); - } - return std::make_pair(sorted_indices, sorted_coefficients); - }); + arg("var_value")); py::enum_(m, "SolveStatus") .value("OPTIMAL", SolveStatus::OPTIMAL) @@ -483,7 +843,16 @@ PYBIND11_MODULE(model_builder_helper, m) { .def("user_time", &ModelSolverHelper::user_time) .def("objective_value", &ModelSolverHelper::objective_value) .def("best_objective_bound", &ModelSolverHelper::best_objective_bound) - .def("var_value", &ModelSolverHelper::variable_value, arg("var_index")) + .def("variable_value", &ModelSolverHelper::variable_value, + arg("var_index")) + .def("expression_value", + [](const ModelSolverHelper& helper, LinearExpr* expr) { + if (!helper.has_response()) { + throw std::logic_error( + "Accessing a solution value when none has been found."); + } + return helper.expression_value(expr); + }) .def("reduced_cost", &ModelSolverHelper::reduced_cost, arg("var_index")) .def("dual_value", &ModelSolverHelper::dual_value, arg("ct_index")) .def("activity", &ModelSolverHelper::activity, arg("ct_index")) @@ -500,20 +869,6 @@ PYBIND11_MODULE(model_builder_helper, m) { } return vec; }) - .def("expression_value", - [](const ModelSolverHelper& helper, const std::vector& indices, - const std::vector& coefficients, double constant) { - if (!helper.has_response()) { - throw std::logic_error( - "Accessing a solution value when none has been found."); - } - const MPSolutionResponse& response = helper.response(); - for (int i = 0; i < indices.size(); ++i) { - constant += - response.variable_value(indices[i]) * coefficients[i]; - } - return constant; - }) .def("reduced_costs", [](const ModelSolverHelper& helper) { if (!helper.has_response()) { @@ -539,4 +894,4 @@ PYBIND11_MODULE(model_builder_helper, m) { } return vec; }); -} +} // NOLINT(readability/fn_size) diff --git a/ortools/linear_solver/python/model_builder_helper_test.py b/ortools/linear_solver/python/model_builder_helper_test.py index 250c8d3067..664cb6ab6e 100644 --- a/ortools/linear_solver/python/model_builder_helper_test.py +++ b/ortools/linear_solver/python/model_builder_helper_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for model_builder_helper.""" - import gzip import os import threading @@ -97,7 +95,7 @@ class PywrapModelBuilderHelperTest(absltest.TestCase): linear_solver_pb2.MPSolverResponseStatus.MPSOLVER_OPTIMAL, ) self.assertAlmostEqual(solver.objective_value(), 1.0) - self.assertAlmostEqual(solver.var_value(0), 1.0) + self.assertAlmostEqual(solver.variable_value(0), 1.0) values = solver.variable_values() self.assertEqual(1, len(values)) self.assertAlmostEqual(1.0, values[0]) diff --git a/ortools/linear_solver/python/model_builder_test.py b/ortools/linear_solver/python/model_builder_test.py index 198ae85e98..7fb0019916 100644 --- a/ortools/linear_solver/python/model_builder_test.py +++ b/ortools/linear_solver/python/model_builder_test.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ModelBuilder.""" - import math +import sys from typing import Any, Callable, Dict, Mapping, Union from absl.testing import absltest @@ -28,15 +27,16 @@ import os from google.protobuf import text_format from ortools.linear_solver import linear_solver_pb2 from ortools.linear_solver.python import model_builder as mb +from ortools.linear_solver.python import model_builder_helper as mbh -def build_dict(expr: mb.LinearExprT) -> Dict[mb.Variable, float]: +def build_dict(expr: mb.LinearExprT) -> Dict[mbh.Variable, float]: res = {} - for i, c in zip(expr.variable_indices, expr.coefficients): - if not c: + flat_expr = mbh.FlatExpr(expr) + for var, coeff in zip(flat_expr.vars, flat_expr.coeffs): + if not coeff: continue - var = mb.Variable(expr.helper, lb=i, ub=None, is_integral=None, name=None) - res[var] = c + res[var] = coeff return res @@ -45,6 +45,10 @@ class ModelBuilderTest(absltest.TestCase): # checking primal, dual, objective values and other values. NUM_PLACES = 5 + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + # pylint: disable=too-many-statements def run_minimal_linear_example(self, solver_name): """Minimal Linear Example.""" @@ -208,66 +212,75 @@ ENDATA t = model.new_int_var(3, 10, "t") e1 = mb.LinearExpr.sum([x, y, z]) + flat_e1 = mbh.FlatExpr(e1) expected_vars = np.array([0, 1, 2], dtype=np.int32) - np_testing.assert_array_equal(expected_vars, e1.variable_indices) + np_testing.assert_array_equal(expected_vars, flat_e1.variable_indices()) np_testing.assert_array_equal( - np.array([1, 1, 1], dtype=np.double), e1.coefficients + np.array([1, 1, 1], dtype=np.double), flat_e1.coeffs ) - self.assertEqual(e1.constant, 0.0) - self.assertEqual(e1.__str__(), "x + y + z") + self.assertEqual(flat_e1.offset, 0.0) + self.assertEqual(e1.__str__(), "(x + y + z)") e2 = mb.LinearExpr.sum([e1, 4.0]) - np_testing.assert_array_equal(expected_vars, e2.variable_indices) + flat_e2 = mbh.FlatExpr(e2) + np_testing.assert_array_equal(expected_vars, flat_e2.variable_indices()) np_testing.assert_array_equal( - np.array([1, 1, 1], dtype=np.double), e2.coefficients + np.array([1, 1, 1], dtype=np.double), flat_e2.coeffs ) - self.assertEqual(e2.constant, 4.0) - self.assertEqual(e2.__str__(), "x + y + z + 4.0") + self.assertEqual(flat_e2.offset, 4.0) + self.assertEqual(e2.__str__(), "((x + y + z) + 4)") + self.assertEqual(flat_e2.__str__(), "x + y + z + 4") e3 = mb.LinearExpr.term(e2, 2) - np_testing.assert_array_equal(expected_vars, e3.variable_indices) + flat_e3 = mbh.FlatExpr(e3) + np_testing.assert_array_equal(expected_vars, flat_e3.variable_indices()) np_testing.assert_array_equal( - np.array([2, 2, 2], dtype=np.double), e3.coefficients + np.array([2, 2, 2], dtype=np.double), flat_e3.coeffs ) - self.assertEqual(e3.constant, 8.0) - self.assertEqual(e3.__str__(), "2.0 * x + 2.0 * y + 2.0 * z + 8.0") + self.assertEqual(flat_e3.offset, 8.0) + self.assertEqual(e3.__str__(), "(2 * ((x + y + z) + 4))") + self.assertEqual(flat_e3.__str__(), "2 * x + 2 * y + 2 * z + 8") e4 = mb.LinearExpr.weighted_sum([x, t], [-1, 1], constant=2) + flat_e4 = mbh.FlatExpr(e4) np_testing.assert_array_equal( - np.array([0, 3], dtype=np.int32), e4.variable_indices + np.array([0, 3], dtype=np.int32), flat_e4.variable_indices() ) np_testing.assert_array_equal( - np.array([-1, 1], dtype=np.double), e4.coefficients + np.array([-1, 1], dtype=np.double), flat_e4.coeffs ) - self.assertEqual(e4.constant, 2.0) - self.assertEqual(e4.__str__(), "-x + t + 2.0") + self.assertEqual(flat_e4.offset, 2.0) + self.assertEqual(e4.__str__(), "(-x + t + 2)") e4b = mb.LinearExpr.weighted_sum([e4 * 3], [1]) + flat_e4b = mbh.FlatExpr(e4b) np_testing.assert_array_equal( - np.array([0, 3], dtype=np.int32), e4b.variable_indices + np.array([0, 3], dtype=np.int32), flat_e4b.variable_indices() ) np_testing.assert_array_equal( - np.array([-3, 3], dtype=np.double), e4b.coefficients + np.array([-3, 3], dtype=np.double), flat_e4b.coeffs ) - self.assertEqual(e4b.constant, 6.0) - self.assertEqual(e4b.__str__(), "-3.0 * x + 3.0 * t + 6.0") + self.assertEqual(flat_e4b.offset, 6.0) + self.assertEqual(e4b.__str__(), "(3 * (-x + t + 2))") e5 = mb.LinearExpr.sum([e1, -3, e4]) + flat_e5 = mbh.FlatExpr(e5) np_testing.assert_array_equal( - np.array([1, 2, 3], dtype=np.int32), e5.variable_indices + np.array([1, 2, 3], dtype=np.int32), flat_e5.variable_indices() ) np_testing.assert_array_equal( - np.array([1, 1, 1], dtype=np.double), e5.coefficients + np.array([1, 1, 1], dtype=np.double), flat_e5.coeffs ) - self.assertEqual(e5.constant, -1.0) - self.assertEqual(e5.__str__(), "y + z + t - 1.0") + self.assertEqual(flat_e5.offset, -1.0) + self.assertEqual(flat_e5.__str__(), "y + z + t - 1") e6 = mb.LinearExpr.term(x, 2.0, constant=1.0) + flat_e6 = mbh.FlatExpr(e6) np_testing.assert_array_equal( - np.array([0], dtype=np.int32), e6.variable_indices + np.array([0], dtype=np.int32), flat_e6.variable_indices() ) - np_testing.assert_array_equal(np.array([2], dtype=np.double), e6.coefficients) - self.assertEqual(e6.constant, 1.0) + np_testing.assert_array_equal(np.array([2], dtype=np.double), flat_e6.coeffs) + self.assertEqual(flat_e6.offset, 1.0) e7 = mb.LinearExpr.term(x, 1.0, constant=0.0) self.assertEqual(x, e7) @@ -278,10 +291,79 @@ ENDATA e9 = mb.LinearExpr.term(x * 2 + 3, 1, constant=0) e10 = mb.LinearExpr.term(x, 2, constant=3) self.assertEqual( - str(mb._as_flat_linear_expression(e9)), - str(mb._as_flat_linear_expression(e10)), + str(mbh.FlatExpr(e9)), + str(mbh.FlatExpr(e10)), ) + e10 = mb.LinearExpr.sum() + self.assertEqual(str(e10), "0") + + e11 = mb.LinearExpr.sum(x) + self.assertIsInstance(e11, mb.Variable) + self.assertEqual(x.index, e11.index) + + e12 = mb.LinearExpr.sum(-1.0, x, 1.0) + self.assertIsInstance(e12, mb.Variable) + self.assertEqual(x.index, e12.index) + + e13 = mb.LinearExpr.sum(-1.0, x, constant=1.0) + self.assertIsInstance(e13, mb.Variable) + self.assertEqual(x.index, e13.index) + + e14 = mb.LinearExpr.weighted_sum([x, t, 1.2], [1, -1, -1.0], constant=2) + flat_e14 = mbh.FlatExpr(e14) + np_testing.assert_array_equal( + np.array([0, 3], dtype=np.int32), flat_e14.variable_indices() + ) + np_testing.assert_array_equal( + np.array([1, -1], dtype=np.double), flat_e14.coeffs + ) + self.assertEqual(flat_e14.offset, 0.8) + self.assertEqual(e14.__str__(), "(x - t + 0.8)") + + e15 = mb.LinearExpr.weighted_sum([1, x, 1], [1, 1, -1]) + self.assertIsInstance(e15, mb.Variable) + self.assertEqual(x.index, e15.index) + + e16 = mb.LinearExpr.affine(x, 1.0, 0.0) + self.assertIsInstance(e16, mb.Variable) + self.assertEqual(x.index, e16.index) + + e17 = -x + self.assertIsInstance(e17, mb.AffineExpr) + self.assertEqual(str(e17), "(-x)") + + e18 = mb.LinearExpr.affine(x, 1.0, -2.0) + self.assertIsInstance(e18, mb.AffineExpr) + self.assertEqual(str(e18), "(x - 2)") + + e19 = mb.LinearExpr.weighted_sum([1, x, 1], [1, 1, -2]) + self.assertIsInstance(e19, mb.AffineExpr) + self.assertEqual(str(e19), "(x - 1)") + + e20 = mb.LinearExpr.affine(x, -2.0, 0.0) + self.assertIsInstance(e20, mb.AffineExpr) + self.assertEqual(str(e20), "(-2 * x)") + + e21 = mb.LinearExpr.weighted_sum([1, x, 1], [1, 2, -1]) + self.assertIsInstance(e21, mb.AffineExpr) + self.assertEqual(str(e21), "(2 * x)") + + c1 = x == 2 + self.assertEqual(str(c1), "x == 2") + + c2 = -x == 3 + self.assertEqual(str(c2), "-x == 3") + + c3 = x + y == 3 + self.assertEqual(str(c3), "(x + y) == 3") + + c4 = -x + y == 3 + self.assertEqual(str(c4), "(-x + y) == 3") + + c5 = x - y == 3 + self.assertEqual(str(c5), "(x - y) == 3") + def test_variables(self): model = mb.Model() x = model.new_int_var(0.0, 4.0, "x") @@ -294,6 +376,10 @@ ENDATA self.assertEqual(1.0, x.lower_bound) self.assertEqual(3.0, x.upper_bound) self.assertTrue(x.is_integral) + n1 = model.new_int_var(0, 4) + self.assertEqual(n1.name, "variable#1") + n2 = model.new_int_var(0, 4, None) + self.assertEqual(n2.name, "variable#2") # Tests the equality operator. y = model.new_int_var(0.0, 4.0, "y") @@ -377,14 +463,6 @@ ENDATA status = solver.solve(model) self.assertEqual(mb.SolveStatus.OPTIMAL, status) - def test_vareqvar(self): - model = mb.Model() - x = model.new_int_var(0.0, 4.0, "x") - y = model.new_int_var(0.0, 4.0, "y") - ct = x == y - self.assertEqual(ct.left.index, x.index) - self.assertEqual(ct.right.index, y.index) - def test_create_false_ct(self): # Create the model. model = mb.Model() @@ -421,10 +499,14 @@ ENDATA class InternalHelperTest(absltest.TestCase): + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + def test_anonymous_variables(self): helper = mb.Model().helper index = helper.add_var() - variable = mb.Variable(helper, index, None, None, None) + variable = mb.Variable(helper, index) self.assertEqual(variable.name, f"variable#{index}") def test_anonymous_constraints(self): @@ -452,178 +534,180 @@ class LinearBaseTest(parameterized.TestCase): dict( testcase_name="x[0]", expr=lambda x, y: x[0], - expected_repr="x[0]", + expected_str="x[0]", ), dict( testcase_name="x[1]", expr=lambda x, y: x[1], - expected_repr="x[1]", + expected_str="x[1]", ), dict( testcase_name="x[2]", expr=lambda x, y: x[2], - expected_repr="x[2]", + expected_str="x[2]", ), dict( testcase_name="y[0]", expr=lambda x, y: y[0], - expected_repr="y[0]", + expected_str="y[0]", ), dict( testcase_name="y[4]", expr=lambda x, y: y[4], - expected_repr="y[4]", + expected_str="y[4]", ), # Sum dict( testcase_name="x[0] + 5", expr=lambda x, y: x[0] + 5, - expected_repr="x[0] + 5.0", + expected_str="x[0] + 5", ), dict( testcase_name="x[0] - 5", expr=lambda x, y: x[0] - 5, - expected_repr="x[0] - 5.0", + expected_str="x[0] - 5", ), dict( testcase_name="5 - x[0]", expr=lambda x, y: 5 - x[0], - expected_repr="-x[0] + 5.0", + expected_str="-x[0] + 5", ), dict( testcase_name="5 + x[0]", expr=lambda x, y: 5 + x[0], - expected_repr="x[0] + 5.0", + expected_str="x[0] + 5", ), dict( testcase_name="x[0] + y[0]", expr=lambda x, y: x[0] + y[0], - expected_repr="x[0] + y[0]", + expected_str="x[0] + y[0]", ), dict( testcase_name="x[0] + y[0] + 5", expr=lambda x, y: x[0] + y[0] + 5, - expected_repr="x[0] + y[0] + 5.0", + expected_str="x[0] + y[0] + 5", ), dict( testcase_name="5 + x[0] + y[0]", expr=lambda x, y: 5 + x[0] + y[0], - expected_repr="x[0] + y[0] + 5.0", + expected_str="x[0] + y[0] + 5", ), dict( testcase_name="5 + x[0] - x[0]", expr=lambda x, y: 5 + x[0] - x[0], - expected_repr="5.0", + expected_str="5", ), dict( testcase_name="5 + x[0] - y[0]", expr=lambda x, y: 5 + x[0] - y[0], - expected_repr="x[0] - y[0] + 5.0", + expected_str="x[0] - y[0] + 5", ), dict( testcase_name="x.sum()", expr=lambda x, y: x.sum(), - expected_repr="x[0] + x[1] + x[2]", + expected_str="x[0] + x[1] + x[2]", ), dict( testcase_name="x.add(y, fill_value=0).sum() + 5", expr=lambda x, y: x.add(y, fill_value=0).sum() + 5, - expected_repr="x[0] + x[1] + x[2] + y[0] + y[1] + ... + 5.0", + expected_str="x[0] + x[1] + x[2] + y[0] + y[1] + ... + 5", ), dict( testcase_name="sum(x, y + 5)", expr=lambda x, y: mb.LinearExpr.sum([x.sum(), y.sum() + 5]), - expected_repr="x[0] + x[1] + x[2] + y[0] + y[1] + ... + 5.0", + expected_str="x[0] + x[1] + x[2] + y[0] + y[1] + ... + 5", ), # Product dict( testcase_name="- x.sum()", expr=lambda x, y: -x.sum(), - expected_repr="-x[0] - x[1] - x[2]", + expected_str="-x[0] - x[1] - x[2]", ), dict( testcase_name="5 - x.sum()", expr=lambda x, y: 5 - x.sum(), - expected_repr="-x[0] - x[1] - x[2] + 5.0", + expected_str="-x[0] - x[1] - x[2] + 5", ), dict( - testcase_name="x.sum() / 2.0", - expr=lambda x, y: x.sum() / 2.0, - expected_repr="0.5 * x[0] + 0.5 * x[1] + 0.5 * x[2]", + testcase_name="x.sum() / 2", + expr=lambda x, y: x.sum() / 2, + expected_str="0.5 * x[0] + 0.5 * x[1] + 0.5 * x[2]", ), dict( testcase_name="(3 * x).sum()", expr=lambda x, y: (3 * x).sum(), - expected_repr="3.0 * x[0] + 3.0 * x[1] + 3.0 * x[2]", + expected_str="3 * x[0] + 3 * x[1] + 3 * x[2]", ), dict( testcase_name="(x * 3).sum()", expr=lambda x, y: (x * 3).sum(), - expected_repr="3.0 * x[0] + 3.0 * x[1] + 3.0 * x[2]", + expected_str="3 * x[0] + 3 * x[1] + 3 * x[2]", ), dict( testcase_name="x.sum() * 3", expr=lambda x, y: x.sum() * 3, - expected_repr="3.0 * x[0] + 3.0 * x[1] + 3.0 * x[2]", + expected_str="3 * x[0] + 3 * x[1] + 3 * x[2]", ), dict( testcase_name="3 * x.sum()", expr=lambda x, y: 3 * x.sum(), - expected_repr="3.0 * x[0] + 3.0 * x[1] + 3.0 * x[2]", + expected_str="3 * x[0] + 3 * x[1] + 3 * x[2]", ), dict( testcase_name="0 * x.sum() + y.sum()", expr=lambda x, y: 0 * x.sum() + y.sum(), - expected_repr="y[0] + y[1] + y[2] + y[3] + y[4]", + expected_str="y[0] + y[1] + y[2] + y[3] + y[4]", ), # LinearExpression dict( - testcase_name="_as_flat_linear_expression(x.sum())", - expr=lambda x, y: mb._as_flat_linear_expression(x.sum()), - expected_repr="x[0] + x[1] + x[2]", + testcase_name="FlatExpr(x.sum())", + expr=lambda x, y: mbh.FlatExpr(x.sum()), + expected_str="x[0] + x[1] + x[2]", ), dict( - testcase_name=( - "_as_flat_linear_expression(_as_flat_linear_expression(x.sum()))" - ), + testcase_name="FlatExpr(FlatExpr(x.sum()))", # pylint: disable=g-long-lambda - expr=lambda x, y: mb._as_flat_linear_expression( - mb._as_flat_linear_expression(x.sum()) - ), - expected_repr="x[0] + x[1] + x[2]", + expr=lambda x, y: mbh.FlatExpr(mbh.FlatExpr(x.sum())), + expected_str="x[0] + x[1] + x[2]", ), dict( - testcase_name="""_as_flat_linear_expression(sum([ - _as_flat_linear_expression(x.sum()), - _as_flat_linear_expression(x.sum()), + testcase_name="""FlatExpr(sum([ + FlatExpr(x.sum()), + FlatExpr(x.sum()), ]))""", # pylint: disable=g-long-lambda - expr=lambda x, y: mb._as_flat_linear_expression( + expr=lambda x, y: mbh.FlatExpr( sum( [ - mb._as_flat_linear_expression(x.sum()), - mb._as_flat_linear_expression(x.sum()), + mbh.FlatExpr(x.sum()), + mbh.FlatExpr(x.sum()), ] ) ), - expected_repr="2.0 * x[0] + 2.0 * x[1] + 2.0 * x[2]", + expected_str="2 * x[0] + 2 * x[1] + 2 * x[2]", ), ) - def test_repr(self, expr, expected_repr): + def test_str(self, expr, expected_str): x = self.x y = self.y - self.assertEqual(repr(expr(x, y)), expected_repr) + self.assertEqual(str(mbh.FlatExpr(expr(x, y))), expected_str) class LinearBaseErrorsTest(absltest.TestCase): + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + def test_unknown_linear_type(self): - with self.assertRaisesRegex(TypeError, r"Unrecognized linear expression"): + with self.assertRaises(TypeError): class UnknownLinearType(mb.LinearExpr): - pass - mb._as_flat_linear_expression(UnknownLinearType()) + def __init__(self): + mb.LinearExpr.__init__(self) + + mbh.FlatExpr(UnknownLinearType()) def test_division_by_zero(self): with self.assertRaises(ZeroDivisionError): @@ -632,7 +716,7 @@ class LinearBaseErrorsTest(absltest.TestCase): print(x / 0) def test_boolean_expression(self): - with self.assertRaisesRegex(NotImplementedError, r"Cannot use a LinearExpr"): + with self.assertRaisesRegex(NotImplementedError, r"instance as a Boolean"): model = mb.Model() x = model.new_var_series(name="x", index=pd.Index(range(1))) bool(x.sum()) @@ -688,28 +772,26 @@ class BoundedLinearBaseTest(parameterized.TestCase): lambda lhs, rhs: lhs >= rhs, ), ) - def test_repr(self, lhs, rhs, op): + def test_str(self, lhs, rhs, op): x = self.x y = self.y l: mb.LinearExprT = lhs(x, y) r: mb.LinearExprT = rhs(x, y) result = op(l, r) if isinstance(l, mb.LinearExpr) or isinstance(r, mb.LinearExpr): - self.assertIsInstance(result, mb._BoundedLinearExpr) - self.assertIn("=", repr(result), msg="is one of ==, <=, or >=") + self.assertIsInstance(result, mbh.BoundedLinearExpression) + self.assertIn("=", str(result), msg="is one of ==, <=, or >=") else: self.assertIsInstance(result, bool) def test_doublesided_bounded_expressions(self): x = self.x - self.assertEqual( - "0.0 <= x[0] <= 1.0", repr(mb.BoundedLinearExpression(x[0], 0, 1)) - ) + self.assertEqual("0 <= x[0] <= 1", str(mb.BoundedLinearExpression(x[0], 0, 1))) def test_free_bounded_expressions(self): self.assertEqual( - "x[0] free", - repr(mb.BoundedLinearExpression(self.x[0], -math.inf, math.inf)), + "-inf <= x[0] <= inf", + str(mb.BoundedLinearExpression(self.x[0], -math.inf, math.inf)), ) def test_var_eq_var_as_bool(self): @@ -734,8 +816,20 @@ class BoundedLinearBaseTest(parameterized.TestCase): class BoundedLinearBaseErrorsTest(absltest.TestCase): + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + + def test_single_var_bounded_linear_expression_as_bool(self): + with self.assertRaisesRegex( + NotImplementedError, "Evaluating a BoundedLinearExpression" + ): + model = mb.Model() + x = model.new_bool_var(name="x") + bool(mb.BoundedLinearExpression(x, 0, 1)) + def test_bounded_linear_expression_as_bool(self): - with self.assertRaisesRegex(NotImplementedError, "Boolean value"): + with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"): model = mb.Model() x = model.new_var_series(name="x", index=pd.Index(range(1))) bool(mb.BoundedLinearExpression(x, 0, 1)) @@ -743,6 +837,10 @@ class BoundedLinearBaseErrorsTest(absltest.TestCase): class ModelBuilderErrorsTest(absltest.TestCase): + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + def test_new_var_series_errors(self): with self.assertRaisesRegex(TypeError, r"Non-index object"): model = mb.Model() @@ -896,7 +994,7 @@ class ModelBuilderVariablesTest(parameterized.TestCase): self.assertLen(variables, len(index)) self.assertLen(set(variables), len(index)) for i in index: - self.assertEqual(repr(variables[i]), f"test_variable[{i}]") + self.assertEqual(variables[i].name, f"test_variable[{i}]") @parameterized.product( index=_variable_indices, bounds=_bounds, is_integer=_is_integer @@ -1440,7 +1538,7 @@ class ModelBuilderLinearConstraintsTest(parameterized.TestCase): for expr, expr_term in zip(linear_constraint_expressions, expr_terms): self.assertDictEqual(build_dict(expr), expr_term) self.assertSequenceAlmostEqual( - [expr._offset for expr in linear_constraint_expressions], + [expr.offset for expr in linear_constraint_expressions], expression_offsets, ) @@ -1473,19 +1571,22 @@ class ModelBuilderObjectiveTest(parameterized.TestCase): def assertLinearExpressionAlmostEqual( self, - expr1: mb._LinearExpression, - expr2: mb._LinearExpression, + expr1: mbh.LinearExpr, + expr2: mbh.LinearExpr, ) -> None: """Test that the two linear expressions are almost equal.""" - self.assertEqual(len(expr1.variable_indices), len(expr2.variable_indices)) - if len(expr1.variable_indices) > 0: # pylint: disable=g-explicit-length-test - self.assertSequenceEqual(expr1.variable_indices, expr2.variable_indices) + flat_expr1 = mbh.FlatExpr(expr1) + flat_expr2 = mbh.FlatExpr(expr2) + self.assertEqual(len(flat_expr1.vars), len(flat_expr2.vars)) + if len(flat_expr1.vars) > 0: # pylint: disable=g-explicit-length-test + self.assertSequenceEqual(flat_expr1.vars, flat_expr2.vars) self.assertSequenceAlmostEqual( - expr1.coefficients, expr2.coefficients, places=5 + flat_expr1.coeffs, flat_expr2.coeffs, places=5 ) else: - self.assertEmpty(expr2.coefficients) - self.assertAlmostEqual(expr1._offset, expr2._offset) + self.assertEmpty(flat_expr1.coeffs) + self.assertEmpty(flat_expr2.coeffs) + self.assertAlmostEqual(flat_expr1.offset, flat_expr2.offset) @parameterized.product( expression=_expressions, @@ -1501,7 +1602,7 @@ class ModelBuilderObjectiveTest(parameterized.TestCase): model = mb.Model() x = model.new_var_series(name="x", index=variable_indices) y = model.new_var_series(name="y", index=variable_indices) - objective_expression = mb._as_flat_linear_expression(expression(x, y)) + objective_expression = expression(x, y) if is_maximize: model.maximize(objective_expression) else: @@ -1515,14 +1616,14 @@ class ModelBuilderObjectiveTest(parameterized.TestCase): model = mb.Model() x = model.new_var_series(name="x", index=pd.Index(range(3))) old_objective_expression = 1 - new_objective_expression = mb._as_flat_linear_expression(x.sum() - 2.3) + new_objective_expression = x.sum() - 2.3 # Set and check for old objective. model.maximize(old_objective_expression) - got_objective_expression = model.objective_expression() - for var_coeff in got_objective_expression.coefficients: - self.assertAlmostEqual(var_coeff, 0) - self.assertAlmostEqual(got_objective_expression._offset, 1) + flat_got_objective_expression = mbh.FlatExpr(model.objective_expression()) + self.assertEmpty(flat_got_objective_expression.vars) + self.assertEmpty(flat_got_objective_expression.coeffs) + self.assertAlmostEqual(flat_got_objective_expression.offset, 1) # Set to a new objective and check that it is different. model.minimize(new_objective_expression) @@ -1543,7 +1644,7 @@ class ModelBuilderObjectiveTest(parameterized.TestCase): model = mb.Model() x = model.new_var_series(name="x", index=variable_indices) y = model.new_var_series(name="y", index=variable_indices) - objective_expression = mb._as_flat_linear_expression(expression(x, y)) + objective_expression = mbh.FlatExpr(expression(x, y)) model.minimize(objective_expression) got_objective_expression = model.objective_expression() self.assertLinearExpressionAlmostEqual( @@ -1562,7 +1663,7 @@ class ModelBuilderObjectiveTest(parameterized.TestCase): model = mb.Model() x = model.new_var_series(name="x", index=variable_indices) y = model.new_var_series(name="y", index=variable_indices) - objective_expression = mb._as_flat_linear_expression(expression(x, y)) + objective_expression = mbh.FlatExpr(expression(x, y)) model.maximize(objective_expression) got_objective_expression = model.objective_expression() self.assertLinearExpressionAlmostEqual( @@ -1572,6 +1673,10 @@ class ModelBuilderObjectiveTest(parameterized.TestCase): class ModelBuilderProtoTest(absltest.TestCase): + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + def test_export_to_proto(self): expected = linear_solver_pb2.MPModelProto() text_format.Parse( @@ -1935,6 +2040,11 @@ class SolverTest(parameterized.TestCase): class ModelBuilderExamplesTest(absltest.TestCase): + + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + def test_simple_problem(self): # max 5x1 + 4x2 + 3x3 # s.t 2x1 + 3x2 + x3 <= 5 diff --git a/ortools/linear_solver/wrappers/model_builder_helper.cc b/ortools/linear_solver/wrappers/model_builder_helper.cc index 1a43475404..688350f93c 100644 --- a/ortools/linear_solver/wrappers/model_builder_helper.cc +++ b/ortools/linear_solver/wrappers/model_builder_helper.cc @@ -14,6 +14,7 @@ #include "ortools/linear_solver/wrappers/model_builder_helper.h" #include +#include #include #include #include @@ -23,6 +24,8 @@ #include "absl/log/check.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "ortools/base/helpers.h" #include "ortools/base/logging.h" #include "ortools/base/options.h" @@ -51,6 +54,9 @@ #include "ortools/xpress/environment.h" namespace operations_research { +namespace mb { + +// ModelBuilderHelper. void ModelBuilderHelper::OverwriteModel( const ModelBuilderHelper& other_helper) { @@ -514,7 +520,8 @@ SolveStatus MPSolverResponseStatusToSolveStatus(MPSolverResponseStatus s) { } } // namespace -ModelSolverHelper::ModelSolverHelper(const std::string& solver_name) { +ModelSolverHelper::ModelSolverHelper(const std::string& solver_name) + : evaluator_(this) { if (solver_name.empty()) return; MPSolver::OptimizationProblemType parsed_type; if (!MPSolver::ParseSolverType(solver_name, &parsed_type)) { @@ -709,6 +716,13 @@ double ModelSolverHelper::variable_value(int var_index) const { return response_.value().variable_value(var_index); } +double ModelSolverHelper::expression_value(LinearExpr* expr) const { + if (!has_response()) return 0.0; + evaluator_.Clear(); + evaluator_.AddToProcess(expr, 1.0); + return evaluator_.Evaluate(); +} + double ModelSolverHelper::reduced_cost(int var_index) const { if (!has_response()) return 0.0; if (var_index >= response_.value().reduced_cost_size()) return 0.0; @@ -768,4 +782,564 @@ void ModelSolverHelper::SetSolverSpecificParameters( void ModelSolverHelper::EnableOutput(bool enabled) { solver_output_ = enabled; } +// Expressions. +LinearExpr* LinearExpr::Term(LinearExpr* expr, double coeff) { + return new AffineExpr(expr, coeff, 0.0); +} + +LinearExpr* LinearExpr::Affine(LinearExpr* expr, double coeff, + double constant) { + if (coeff == 1.0 && constant == 0.0) return expr; + return new AffineExpr(expr, coeff, constant); +} + +LinearExpr* LinearExpr::AffineCst(double value, double coeff, double constant) { + return new FixedValue(value * coeff + constant); +} + +LinearExpr* LinearExpr::Constant(double value) { return new FixedValue(value); } + +LinearExpr* LinearExpr::Add(LinearExpr* expr) { + return new SumArray({this, expr}, 0.0); +} + +LinearExpr* LinearExpr::AddFloat(double cst) { + if (cst == 0.0) return this; + return new AffineExpr(this, 1.0, cst); +} + +LinearExpr* LinearExpr::Sub(LinearExpr* expr) { + return new WeightedSumArray({this, expr}, {1, -1}, 0.0); +} + +LinearExpr* LinearExpr::SubFloat(double cst) { + if (cst == 0.0) return this; + return new AffineExpr(this, 1.0, -cst); +} + +LinearExpr* LinearExpr::RSubFloat(double cst) { + return new AffineExpr(this, -1.0, cst); +} + +LinearExpr* LinearExpr::MulFloat(double cst) { + if (cst == 0.0) return new FixedValue(0.0); + if (cst == 1.0) return this; + return new AffineExpr(this, cst, 0.0); +} + +LinearExpr* LinearExpr::Neg() { return new AffineExpr(this, -1, 0); } + +// Expression visitors. + +void ExprVisitor::AddToProcess(const LinearExpr* expr, double coeff) { + to_process_.push_back(std::make_pair(expr, coeff)); +} + +void ExprVisitor::AddConstant(double constant) { offset_ += constant; } + +void ExprVisitor::Clear() { + to_process_.clear(); + offset_ = 0.0; +} + +void ExprFlattener::AddVarCoeff(const Variable* var, double coeff) { + canonical_terms_[var] += coeff; +} + +double ExprFlattener::Flatten(std::vector* vars, + std::vector* coeffs) { + while (!to_process_.empty()) { + const auto [expr, coeff] = to_process_.back(); + to_process_.pop_back(); + expr->Visit(*this, coeff); + } + + vars->clear(); + coeffs->clear(); + for (const auto& [var, coeff] : canonical_terms_) { + if (coeff == 0.0) continue; + vars->push_back(var); + coeffs->push_back(coeff); + } + + return offset_; +} + +void ExprEvaluator::AddVarCoeff(const Variable* var, double coeff) { + offset_ += coeff * helper_->variable_value(var->index()); +} + +double ExprEvaluator::Evaluate() { + offset_ = 0.0; + while (!to_process_.empty()) { + const auto [expr, coeff] = to_process_.back(); + to_process_.pop_back(); + expr->Visit(*this, coeff); + } + return offset_; +} + +FlatExpr::FlatExpr(const LinearExpr* expr) { + ExprFlattener lin; + lin.AddToProcess(expr, 1.0); + offset_ = lin.Flatten(&vars_, &coeffs_); +} + +FlatExpr::FlatExpr(const LinearExpr* pos, const LinearExpr* neg) { + ExprFlattener lin; + lin.AddToProcess(pos, 1.0); + lin.AddToProcess(neg, -1.0); + offset_ = lin.Flatten(&vars_, &coeffs_); +} + +FlatExpr::FlatExpr(const std::vector& vars, + const std::vector& coeffs, double offset) + : vars_(vars), coeffs_(coeffs), offset_(offset) {} + +FlatExpr::FlatExpr(double offset) : offset_(offset) {} + +std::vector FlatExpr::VarIndices() const { + std::vector var_indices; + var_indices.reserve(vars_.size()); + for (const Variable* var : vars_) { + var_indices.push_back(var->index()); + } + return var_indices; +} + +void FlatExpr::Visit(ExprVisitor& lin, double c) const { + for (int i = 0; i < vars_.size(); ++i) { + lin.AddVarCoeff(vars_[i], coeffs_[i] * c); + } + lin.AddConstant(offset_ * c); +} + +std::string FlatExpr::ToString() const { + if (vars_.empty()) { + return absl::StrCat(offset_); + } + std::string s; + int num_printed = 0; + for (int i = 0; i < vars_.size(); ++i) { + DCHECK_NE(coeffs_[i], 0.0); + ++num_printed; + if (num_printed > 5) { + absl::StrAppend(&s, " + ..."); + break; + } + if (num_printed == 1) { + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, vars_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, "-", vars_[i]->ToString()); + } else { + absl::StrAppend(&s, coeffs_[i], " * ", vars_[i]->ToString()); + } + } else { + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, " + ", vars_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, " - ", vars_[i]->ToString()); + } else if (coeffs_[i] > 0.0) { + absl::StrAppend(&s, " + ", coeffs_[i], " * ", vars_[i]->ToString()); + } else { + absl::StrAppend(&s, " - ", -coeffs_[i], " * ", vars_[i]->ToString()); + } + } + } + // If there are no terms, just print the offset. + if (num_printed == 0) { + return absl::StrCat(offset_); + } + + // If there is an offset, print it. + if (offset_ != 0.0) { + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + return s; +} + +std::string FlatExpr::DebugString() const { + std::string s = absl::StrCat( + "FlatExpr(", + absl::StrJoin(vars_, ", ", [](std::string* out, const Variable* expr) { + absl::StrAppend(out, expr->DebugString()); + })); + if (offset_ != 0.0) { + absl::StrAppend(&s, ", offset=", offset_); + } + absl::StrAppend(&s, ")"); + return s; +} + +void FixedValue::Visit(ExprVisitor& lin, double c) const { + lin.AddConstant(value_ * c); +} + +std::string FixedValue::ToString() const { return absl::StrCat(value_); } + +std::string FixedValue::DebugString() const { + return absl::StrCat("FixedValue(", value_, ")"); +} + +WeightedSumArray::WeightedSumArray(const std::vector& exprs, + const std::vector& coeffs, + double offset) + : exprs_(exprs.begin(), exprs.end()), + coeffs_(coeffs.begin(), coeffs.end()), + offset_(offset) {} + +void WeightedSumArray::Visit(ExprVisitor& lin, double c) const { + for (int i = 0; i < exprs_.size(); ++i) { + lin.AddToProcess(exprs_[i], coeffs_[i] * c); + } + lin.AddConstant(offset_ * c); +} + +std::string WeightedSumArray::ToString() const { + if (exprs_.empty()) { + return absl::StrCat(offset_); + } + std::string s = "("; + bool first_printed = true; + for (int i = 0; i < exprs_.size(); ++i) { + if (coeffs_[i] == 0.0) continue; + if (first_printed) { + first_printed = false; + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, exprs_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, "-", exprs_[i]->ToString()); + } else { + absl::StrAppend(&s, coeffs_[i], " * ", exprs_[i]->ToString()); + } + } else { + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, " + ", exprs_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, " - ", exprs_[i]->ToString()); + } else if (coeffs_[i] > 0.0) { + absl::StrAppend(&s, " + ", coeffs_[i], " * ", exprs_[i]->ToString()); + } else { + absl::StrAppend(&s, " - ", -coeffs_[i], " * ", exprs_[i]->ToString()); + } + } + } + // If there are no terms, just print the offset. + if (first_printed) { + return absl::StrCat(offset_); + } + + // If there is an offset, print it. + if (offset_ != 0.0) { + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string WeightedSumArray::DebugString() const { + return absl::StrCat("WeightedSumArray([", + absl::StrJoin(exprs_, ", ", + [](std::string* out, const LinearExpr* e) { + absl::StrAppend(out, e->DebugString()); + }), + "], [", absl::StrJoin(coeffs_, "], "), offset_, ")"); +} + +AffineExpr::AffineExpr(LinearExpr* expr, double coeff, double offset) + : expr_(expr), coeff_(coeff), offset_(offset) {} + +void AffineExpr::Visit(ExprVisitor& lin, double c) const { + lin.AddToProcess(expr_, c * coeff_); + lin.AddConstant(offset_ * c); +} + +std::string AffineExpr::ToString() const { + std::string s = "("; + if (coeff_ == 1.0) { + absl::StrAppend(&s, expr_->ToString()); + } else if (coeff_ == -1.0) { + absl::StrAppend(&s, "-", expr_->ToString()); + } else { + absl::StrAppend(&s, coeff_, " * ", expr_->ToString()); + } + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", offset_); + } else if (offset_ < 0.0) { + absl::StrAppend(&s, " - ", -offset_); + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string AffineExpr::DebugString() const { + return absl::StrCat("AffineExpr(expr=", expr_->DebugString(), + ", coeff=", coeff_, ", offset=", offset_, ")"); +} +BoundedLinearExpression* LinearExpr::Eq(LinearExpr* rhs) { + return new BoundedLinearExpression(this, rhs, 0.0, 0.0); +} + +BoundedLinearExpression* LinearExpr::EqCst(double rhs) { + return new BoundedLinearExpression(this, rhs, rhs); +} + +BoundedLinearExpression* LinearExpr::Le(LinearExpr* rhs) { + return new BoundedLinearExpression( + this, rhs, -std::numeric_limits::infinity(), 0.0); +} + +BoundedLinearExpression* LinearExpr::LeCst(double rhs) { + return new BoundedLinearExpression( + this, -std::numeric_limits::infinity(), rhs); +} + +BoundedLinearExpression* LinearExpr::Ge(LinearExpr* rhs) { + return new BoundedLinearExpression(this, rhs, 0.0, + std::numeric_limits::infinity()); +} + +BoundedLinearExpression* LinearExpr::GeCst(double rhs) { + return new BoundedLinearExpression(this, rhs, + std::numeric_limits::infinity()); +} + +bool VariableComparator::operator()(const Variable* lhs, + const Variable* rhs) const { + return lhs->index() < rhs->index(); +} + +Variable::Variable(ModelBuilderHelper* helper, int index) + : helper_(helper), index_(index) {} + +Variable::Variable(ModelBuilderHelper* helper, double lb, double ub, + bool is_integral) + : helper_(helper) { + index_ = helper_->AddVar(); + helper_->SetVarLowerBound(index_, lb); + helper_->SetVarUpperBound(index_, ub); + helper_->SetVarIntegrality(index_, is_integral); +} + +Variable::Variable(ModelBuilderHelper* helper, double lb, double ub, + bool is_integral, const std::string& name) + : helper_(helper) { + index_ = helper_->AddVar(); + helper_->SetVarLowerBound(index_, lb); + helper_->SetVarUpperBound(index_, ub); + helper_->SetVarIntegrality(index_, is_integral); + helper_->SetVarName(index_, name); +} + +Variable::Variable(ModelBuilderHelper* helper, int64_t lb, int64_t ub, + bool is_integral) + : helper_(helper) { + index_ = helper_->AddVar(); + helper_->SetVarLowerBound(index_, lb); + helper_->SetVarUpperBound(index_, ub); + helper_->SetVarIntegrality(index_, is_integral); +} + +Variable::Variable(ModelBuilderHelper* helper, int64_t lb, int64_t ub, + bool is_integral, const std::string& name) + : helper_(helper) { + index_ = helper_->AddVar(); + helper_->SetVarLowerBound(index_, lb); + helper_->SetVarUpperBound(index_, ub); + helper_->SetVarIntegrality(index_, is_integral); + helper_->SetVarName(index_, name); +} + +std::string Variable::ToString() const { + if (!helper_->VarName(index_).empty()) { + return helper_->VarName(index_); + } else { + return absl::StrCat("Variable(", index_, ")"); + } +} + +std::string Variable::DebugString() const { + return absl::StrCat("Variable(index=", index_, + ", lb=", helper_->VarLowerBound(index_), + ", ub=", helper_->VarUpperBound(index_), + ", is_integral=", helper_->VarIsIntegral(index_), + ", name=\'", helper_->VarName(index_), "')"); +} + +std::string Variable::name() const { + const std::string& var_name = helper_->VarName(index_); + if (!var_name.empty()) return var_name; + return absl::StrCat("variable#", index_); +} + +void Variable::SetName(const std::string& name) { + helper_->SetVarName(index_, name); +} + +double Variable::lower_bounds() const { return helper_->VarLowerBound(index_); } + +void Variable::SetLowerBound(double lb) { + helper_->SetVarLowerBound(index_, lb); +} + +double Variable::upper_bound() const { return helper_->VarUpperBound(index_); } + +void Variable::SetUpperBound(double ub) { + helper_->SetVarUpperBound(index_, ub); +} + +bool Variable::is_integral() const { return helper_->VarIsIntegral(index_); } + +void Variable::SetIsIntegral(bool is_integral) { + helper_->SetVarIntegrality(index_, is_integral); +} + +double Variable::objective_coefficient() const { + return helper_->VarObjectiveCoefficient(index_); +} + +void Variable::SetObjectiveCoefficient(double coeff) { + helper_->SetVarObjectiveCoefficient(index_, coeff); +} + +BoundedLinearExpression::BoundedLinearExpression(const LinearExpr* expr, + double lower_bound, + double upper_bound) { + FlatExpr flat_expr(expr); + vars_ = flat_expr.vars(); + coeffs_ = flat_expr.coeffs(); + lower_bound_ = lower_bound - flat_expr.offset(); + upper_bound_ = upper_bound - flat_expr.offset(); +} + +BoundedLinearExpression::BoundedLinearExpression(const LinearExpr* pos, + const LinearExpr* neg, + double lower_bound, + double upper_bound) { + FlatExpr flat_expr(pos, neg); + vars_ = flat_expr.vars(); + coeffs_ = flat_expr.coeffs(); + lower_bound_ = lower_bound - flat_expr.offset(); + upper_bound_ = upper_bound - flat_expr.offset(); +} + +BoundedLinearExpression::BoundedLinearExpression(const LinearExpr* expr, + int64_t lower_bound, + int64_t upper_bound) { + FlatExpr flat_expr(expr); + vars_ = flat_expr.vars(); + coeffs_ = flat_expr.coeffs(); + lower_bound_ = lower_bound - flat_expr.offset(); + upper_bound_ = upper_bound - flat_expr.offset(); +} + +BoundedLinearExpression::BoundedLinearExpression(const LinearExpr* pos, + const LinearExpr* neg, + int64_t lower_bound, + int64_t upper_bound) { + FlatExpr flat_expr(pos, neg); + vars_ = flat_expr.vars(); + coeffs_ = flat_expr.coeffs(); + lower_bound_ = lower_bound - flat_expr.offset(); + upper_bound_ = upper_bound - flat_expr.offset(); +} + +double BoundedLinearExpression::lower_bound() const { return lower_bound_; } +double BoundedLinearExpression::upper_bound() const { return upper_bound_; } +const std::vector& BoundedLinearExpression::vars() const { + return vars_; +} +const std::vector& BoundedLinearExpression::coeffs() const { + return coeffs_; +} +std::string BoundedLinearExpression::ToString() const { + std::string s; + if (vars_.empty()) { + s = absl::StrCat(0.0); + } else if (vars_.size() == 1) { + const std::string var_name = vars_[0]->ToString(); + if (coeffs_[0] == 1) { + s = var_name; + } else if (coeffs_[0] == -1) { + s = absl::StrCat("-", var_name); + } else { + s = absl::StrCat(coeffs_[0], " * ", var_name); + } + } else { + s = "("; + for (int i = 0; i < vars_.size(); ++i) { + const std::string var_name = vars_[i]->ToString(); + if (i == 0) { + if (coeffs_[i] == 1) { + absl::StrAppend(&s, var_name); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, "-", var_name); + } else { + absl::StrAppend(&s, coeffs_[i], " * ", var_name); + } + } else { + if (coeffs_[i] == 1) { + absl::StrAppend(&s, " + ", var_name); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, " - ", var_name); + } else if (coeffs_[i] > 1) { + absl::StrAppend(&s, " + ", coeffs_[i], " * ", var_name); + } else { + absl::StrAppend(&s, " - ", -coeffs_[i], " * ", var_name); + } + } + } + absl::StrAppend(&s, ")"); + } + if (lower_bound_ == upper_bound_) { + return absl::StrCat(s, " == ", lower_bound_); + } else if (lower_bound_ == std::numeric_limits::min()) { + if (upper_bound_ == std::numeric_limits::max()) { + return absl::StrCat("True (unbounded expr ", s, ")"); + } else { + return absl::StrCat(s, " <= ", upper_bound_); + } + } else if (upper_bound_ == std::numeric_limits::max()) { + return absl::StrCat(s, " >= ", lower_bound_); + } else { + return absl::StrCat(lower_bound_, " <= ", s, " <= ", upper_bound_); + } +} + +std::string BoundedLinearExpression::DebugString() const { + return absl::StrCat("BoundedLinearExpression(vars=[", + absl::StrJoin(vars_, ", ", + [](std::string* out, const Variable* var) { + absl::StrAppend(out, var->DebugString()); + }), + "], coeffs=[", absl::StrJoin(coeffs_, ", "), + "], lower_bound=", lower_bound_, + ", upper_bound=", upper_bound_, ")"); +} + +bool BoundedLinearExpression::CastToBool(bool* result) const { + const bool is_zero = lower_bound_ == 0.0 && upper_bound_ == 0.0; + if (is_zero) { + if (vars_.empty()) { + *result = true; + return true; + } else if (vars_.size() == 2 && coeffs_[0] + coeffs_[1] == 0 && + std::abs(coeffs_[0]) == 1) { + *result = false; + return true; + } + } + return false; +} + +} // namespace mb } // namespace operations_research diff --git a/ortools/linear_solver/wrappers/model_builder_helper.h b/ortools/linear_solver/wrappers/model_builder_helper.h index 797c10abf3..5b42e97fb8 100644 --- a/ortools/linear_solver/wrappers/model_builder_helper.h +++ b/ortools/linear_solver/wrappers/model_builder_helper.h @@ -15,18 +15,314 @@ #define OR_TOOLS_LINEAR_SOLVER_WRAPPERS_MODEL_BUILDER_HELPER_H_ #include +#include #include -#include #include #include +#include #include +#include "absl/container/btree_map.h" +#include "absl/container/fixed_array.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "ortools/linear_solver/linear_solver.pb.h" #include "ortools/linear_solver/model_exporter.h" -#include "ortools/util/logging.h" #include "ortools/util/solve_interrupter.h" namespace operations_research { +namespace mb { + +// Base implementation of linear expressions. + +class BoundedLinearExpression; +class FlatExpr; +class ExprVisitor; +class LinearExpr; +class ModelBuilderHelper; +class ModelSolverHelper; +class Variable; + +// A linear expression that can be either integer or floating point. +class LinearExpr { + public: + virtual ~LinearExpr() = default; + virtual void Visit(ExprVisitor& /*lin*/, double /*c*/) const = 0; + virtual std::string ToString() const = 0; + virtual std::string DebugString() const = 0; + + static LinearExpr* Term(LinearExpr* expr, double coeff); + static LinearExpr* Affine(LinearExpr* expr, double coeff, double constant); + static LinearExpr* AffineCst(double value, double coeff, double constant); + static LinearExpr* Constant(double value); + + LinearExpr* Add(LinearExpr* expr); + LinearExpr* AddFloat(double cst); + LinearExpr* Sub(LinearExpr* expr); + LinearExpr* SubFloat(double cst); + LinearExpr* RSubFloat(double cst); + LinearExpr* MulFloat(double cst); + LinearExpr* Neg(); + + BoundedLinearExpression* Eq(LinearExpr* rhs); + BoundedLinearExpression* EqCst(double rhs); + BoundedLinearExpression* Ge(LinearExpr* rhs); + BoundedLinearExpression* GeCst(double rhs); + BoundedLinearExpression* Le(LinearExpr* rhs); + BoundedLinearExpression* LeCst(double rhs); +}; + +// Compare the indices of variables. +struct VariableComparator { + bool operator()(const Variable* lhs, const Variable* rhs) const; +}; + +// A visitor class to parse a floating point linear expression. +class ExprVisitor { + public: + virtual ~ExprVisitor() = default; + void AddToProcess(const LinearExpr* expr, double coeff); + void AddConstant(double constant); + virtual void AddVarCoeff(const Variable* var, double coeff) = 0; + void Clear(); + + protected: + std::vector> to_process_; + double offset_ = 0; +}; + +class ExprFlattener : public ExprVisitor { + public: + ~ExprFlattener() override = default; + void AddVarCoeff(const Variable* var, double coeff) override; + double Flatten(std::vector* vars, + std::vector* coeffs); + + private: + absl::btree_map canonical_terms_; +}; + +class ExprEvaluator : public ExprVisitor { + public: + explicit ExprEvaluator(ModelSolverHelper* helper) : helper_(helper) {} + ~ExprEvaluator() override = default; + void AddVarCoeff(const Variable* var, double coeff) override; + double Evaluate(); + + private: + ModelSolverHelper* helper_; +}; + +// A flat linear expression sum(vars[i] * coeffs[i]) + offset +class FlatExpr : public LinearExpr { + public: + explicit FlatExpr(const LinearExpr* expr); + // Flatten pos - neg. + FlatExpr(const LinearExpr* pos, const LinearExpr* neg); + FlatExpr(const std::vector&, const std::vector&, + double); + explicit FlatExpr(double offset); + const std::vector& vars() const { return vars_; } + std::vector VarIndices() const; + const std::vector& coeffs() const { return coeffs_; } + double offset() const { return offset_; } + + void Visit(ExprVisitor& lin, double c) const override; + std::string ToString() const override; + std::string DebugString() const override; + + private: + std::vector vars_; + std::vector coeffs_; + double offset_; +}; + +// A class to hold a sum of linear expressions, and optional integer and +// double offsets. +class SumArray : public LinearExpr { + public: + explicit SumArray(const std::vector& exprs, double offset) + : exprs_(exprs.begin(), exprs.end()), offset_(offset) {} + ~SumArray() override = default; + + void Visit(ExprVisitor& lin, double c) const override { + for (int i = 0; i < exprs_.size(); ++i) { + lin.AddToProcess(exprs_[i], c); + } + if (offset_ != 0.0) { + lin.AddConstant(offset_ * c); + } + } + + std::string ToString() const override { + if (exprs_.empty()) { + if (offset_ != 0.0) { + return absl::StrCat(offset_); + } + } + std::string s = "("; + for (int i = 0; i < exprs_.size(); ++i) { + if (i > 0) { + absl::StrAppend(&s, " + "); + } + absl::StrAppend(&s, exprs_[i]->ToString()); + } + if (offset_ != 0.0) { + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + absl::StrAppend(&s, ")"); + return s; + } + + std::string DebugString() const override { + std::string s = absl::StrCat( + "SumArray(", + absl::StrJoin(exprs_, ", ", [](std::string* out, LinearExpr* expr) { + absl::StrAppend(out, expr->DebugString()); + })); + if (offset_ != 0.0) { + absl::StrAppend(&s, ", offset=", offset_); + } + absl::StrAppend(&s, ")"); + return s; + } + + private: + const absl::FixedArray exprs_; + const double offset_; +}; + +// A class to hold a weighted sum of floating point linear expressions. +class WeightedSumArray : public LinearExpr { + public: + WeightedSumArray(const std::vector& exprs, + const std::vector& coeffs, double offset); + ~WeightedSumArray() override = default; + + void Visit(ExprVisitor& lin, double c) const override; + std::string ToString() const override; + std::string DebugString() const override; + + private: + const absl::FixedArray exprs_; + const absl::FixedArray coeffs_; + double offset_; +}; + +// A class to hold linear_expr * a = b. +class AffineExpr : public LinearExpr { + public: + AffineExpr(LinearExpr* expr, double coeff, double offset); + ~AffineExpr() override = default; + + void Visit(ExprVisitor& lin, double c) const override; + + std::string ToString() const override; + std::string DebugString() const override; + + LinearExpr* expression() const { return expr_; } + double coefficient() const { return coeff_; } + double offset() const { return offset_; } + + private: + LinearExpr* expr_; + double coeff_; + double offset_; +}; + +// A class to hold a fixed value. +class FixedValue : public LinearExpr { + public: + explicit FixedValue(double value) : value_(value) {} + ~FixedValue() override = default; + + void Visit(ExprVisitor& lin, double c) const override; + + std::string ToString() const override; + std::string DebugString() const override; + + private: + double value_; +}; + +// A class to hold a variable index. +class Variable : public LinearExpr { + public: + Variable(ModelBuilderHelper* helper, int index); + Variable(ModelBuilderHelper* helper, double lb, double ub, bool is_integral); + Variable(ModelBuilderHelper* helper, double lb, double ub, bool is_integral, + const std::string& name); + Variable(ModelBuilderHelper* helper, int64_t lb, int64_t ub, + bool is_integral); + Variable(ModelBuilderHelper* helper, int64_t lb, int64_t ub, bool is_integral, + const std::string& name); + ~Variable() override {} + + ModelBuilderHelper* helper() const { return helper_; } + int index() const { return index_; } + std::string name() const; + void SetName(const std::string& name); + double lower_bounds() const; + void SetLowerBound(double lb); + double upper_bound() const; + void SetUpperBound(double ub); + bool is_integral() const; + void SetIsIntegral(bool is_integral); + double objective_coefficient() const; + void SetObjectiveCoefficient(double coeff); + + void Visit(ExprVisitor& lin, double c) const override { + lin.AddVarCoeff(this, c); + } + + std::string ToString() const override; + + std::string DebugString() const override; + + bool operator<(const Variable& other) const { return index_ < other.index_; } + + protected: + ModelBuilderHelper* helper_; + int index_; +}; + +template +H AbslHashValue(H h, const Variable* i) { + return H::combine(std::move(h), i->index()); +} + +// A class to hold a linear expression with bounds. +class BoundedLinearExpression { + public: + BoundedLinearExpression(const LinearExpr* expr, double lower_bound, + double upper_bound); + BoundedLinearExpression(const LinearExpr* pos, const LinearExpr* neg, + double lower_bound, double upper_bound); + BoundedLinearExpression(const LinearExpr* expr, int64_t lower_bound, + int64_t upper_bound); + BoundedLinearExpression(const LinearExpr* pos, const LinearExpr* neg, + int64_t lower_bound, int64_t upper_bound); + + ~BoundedLinearExpression() = default; + + double lower_bound() const; + double upper_bound() const; + const std::vector& vars() const; + const std::vector& coeffs() const; + std::string ToString() const; + std::string DebugString() const; + bool CastToBool(bool* result) const; + + private: + std::vector vars_; + std::vector coeffs_; + double lower_bound_; + double upper_bound_; +}; // The arguments of the functions defined below must follow these rules // to be wrapped by SWIG correctly: @@ -189,6 +485,7 @@ class ModelSolverHelper { double objective_value() const; double best_objective_bound() const; double variable_value(int var_index) const; + double expression_value(LinearExpr* expr) const; double reduced_cost(int var_index) const; double dual_value(int ct_index) const; double activity(int ct_index); @@ -216,8 +513,10 @@ class ModelSolverHelper { std::optional model_of_last_solve_; std::vector activities_; bool solver_output_ = false; + mutable ExprEvaluator evaluator_; }; +} // namespace mb } // namespace operations_research #endif // OR_TOOLS_LINEAR_SOLVER_WRAPPERS_MODEL_BUILDER_HELPER_H_ diff --git a/ortools/port/proto_utils.h b/ortools/port/proto_utils.h index 2eb25f7731..aacc755f4c 100644 --- a/ortools/port/proto_utils.h +++ b/ortools/port/proto_utils.h @@ -31,7 +31,7 @@ namespace operations_research { template std::string ProtobufDebugString(const P& message) { #if defined(__PORTABLE_PLATFORM__) - return message.GetTypeName(); + return std::string(message.GetTypeName()); #else // defined(__PORTABLE_PLATFORM__) return message.DebugString(); #endif // !defined(__PORTABLE_PLATFORM__) @@ -40,7 +40,7 @@ std::string ProtobufDebugString(const P& message) { template std::string ProtobufShortDebugString(const P& message) { #if defined(__PORTABLE_PLATFORM__) - return message.GetTypeName(); + return std::string(message.GetTypeName()); #else // defined(__PORTABLE_PLATFORM__) return message.ShortDebugString(); #endif // !defined(__PORTABLE_PLATFORM__) diff --git a/ortools/python/BUILD.bazel b/ortools/python/BUILD.bazel index 015491c399..56b9e6c65d 100644 --- a/ortools/python/BUILD.bazel +++ b/ortools/python/BUILD.bazel @@ -23,7 +23,7 @@ py_binary( "//ortools/graph/python:linear_sum_assignment.so", "//ortools/graph/python:max_flow.so", "//ortools/graph/python:min_cost_flow.so", - "//ortools/sat/python:swig_helper.so", + "//ortools/sat/python:cp_model_helper.so", ], tags = ["manual"], deps = [ @@ -32,7 +32,7 @@ py_binary( "//ortools/sat/colab:flags", "//ortools/sat/colab:visualization", "//ortools/sat/python:cp_model", - "//ortools/sat/python:cp_model_helper", + "//ortools/sat/python:cp_model_numbers", requirement("notebook"), requirement("svgwrite"), requirement("plotly"), diff --git a/ortools/python/setup.py.in b/ortools/python/setup.py.in index 22ac324970..d8db87fcba 100644 --- a/ortools/python/setup.py.in +++ b/ortools/python/setup.py.in @@ -106,7 +106,7 @@ setup( '@PYTHON_PROJECT@.sat':['*.pyi'], '@PYTHON_PROJECT@.sat.colab':['*.pyi', 'py.typed'], '@PYTHON_PROJECT@.sat.python':[ - '$', + '$', '*.pyi', 'py.typed' ], diff --git a/ortools/sat/2d_orthogonal_packing.cc b/ortools/sat/2d_orthogonal_packing.cc index 83b7f9bbfe..adbaeec63b 100644 --- a/ortools/sat/2d_orthogonal_packing.cc +++ b/ortools/sat/2d_orthogonal_packing.cc @@ -26,6 +26,7 @@ #include "absl/numeric/bits.h" #include "absl/random/distributions.h" #include "absl/types/span.h" +#include "ortools/base/constant_divisor.h" #include "ortools/base/logging.h" #include "ortools/sat/2d_packing_brute_force.h" #include "ortools/sat/integer_base.h" @@ -404,16 +405,16 @@ void OrthogonalPackingInfeasibilityDetector::GetAllCandidatesForKForDff2( candidates.Set(i); } for (int i = 1; i <= sqrt_bb_size; i++) { - const QuickSmallDivision div(i); + const ::util::math::ConstantDivisor div(i); if (i > 1) { - candidates.Set(div.DivideByDivisor(bb_size.value())); + candidates.Set(bb_size.value() / div); } for (int k = 0; k < sizes.size(); k++) { IntegerValue size = sizes[k]; if (2 * size > bb_size && size < bb_size) { - candidates.Set(div.DivideByDivisor(bb_size.value() - size.value() + 1)); + candidates.Set((bb_size.value() - size.value() + 1) / div); } else if (2 * size < bb_size) { - candidates.Set(div.DivideByDivisor(size.value())); + candidates.Set(size.value() / div); } } } diff --git a/ortools/sat/2d_orthogonal_packing.h b/ortools/sat/2d_orthogonal_packing.h index 65d0d67279..f1de4a2979 100644 --- a/ortools/sat/2d_orthogonal_packing.h +++ b/ortools/sat/2d_orthogonal_packing.h @@ -22,6 +22,7 @@ #include "absl/log/check.h" #include "absl/random/bit_gen_ref.h" #include "absl/types/span.h" +#include "ortools/base/constant_divisor.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/synchronization.h" #include "ortools/util/bitset.h" @@ -281,10 +282,7 @@ class RoundingDualFeasibleFunction { public: // `max_x` must fit in a uint16_t and `k` in [0, max_x/2]. RoundingDualFeasibleFunction(IntegerValue max_x, IntegerValue k) - : div_(k.value()), - max_x_(max_x), - c_k_(div_.DivideByDivisor(max_x_.value())), - k_(k) { + : div_(k.value()), max_x_(max_x), c_k_(max_x_.value() / div_), k_(k) { DCHECK_GT(k, 0); DCHECK_LE(2 * k, max_x_); DCHECK_LE(max_x_, std::numeric_limits::max()); @@ -296,11 +294,11 @@ class RoundingDualFeasibleFunction { DCHECK_LE(x, max_x_); if (2 * x > max_x_) { - return 2 * (c_k_ - div_.DivideByDivisor(max_x_.value() - x.value())); + return 2 * (c_k_ - (max_x_.value() - x.value()) / div_); } else if (2 * x == max_x_) { return c_k_; } else { - return 2 * div_.DivideByDivisor(x.value()); + return 2 * (x.value() / div_); } } @@ -309,7 +307,7 @@ class RoundingDualFeasibleFunction { IntegerValue LowestInverse(IntegerValue y) const; private: - const QuickSmallDivision div_; + const ::util::math::ConstantDivisor div_; const IntegerValue max_x_; const IntegerValue c_k_; const IntegerValue k_; diff --git a/ortools/sat/2d_orthogonal_packing_testing.cc b/ortools/sat/2d_orthogonal_packing_testing.cc index 7d8dfc6cf5..7445d3c32d 100644 --- a/ortools/sat/2d_orthogonal_packing_testing.cc +++ b/ortools/sat/2d_orthogonal_packing_testing.cc @@ -178,13 +178,12 @@ std::vector MakeItemsFromRectangles( return ranges; } -std::vector -GenerateItemsRectanglesWithNoPairwiseConflict( +std::vector GenerateItemsRectanglesWithNoPairwiseConflict( absl::Span rectangles, double slack_factor, absl::BitGenRef random) { const std::vector range_items = MakeItemsFromRectangles(rectangles, slack_factor, random); - std::vector items; + std::vector items; items.reserve(rectangles.size()); for (int i = 0; i < range_items.size(); ++i) { const RectangleInRange& rec = range_items[i]; @@ -201,13 +200,13 @@ GenerateItemsRectanglesWithNoPairwiseConflict( return items; } -std::vector +std::vector GenerateItemsRectanglesWithNoPairwisePropagation(int num_rectangles, double slack_factor, absl::BitGenRef random) { const std::vector rectangles = GenerateNonConflictingRectangles(num_rectangles, random); - std::vector items = + std::vector items = GenerateItemsRectanglesWithNoPairwiseConflict(rectangles, slack_factor, random); bool done = false; diff --git a/ortools/sat/2d_orthogonal_packing_testing.h b/ortools/sat/2d_orthogonal_packing_testing.h index 3c2abdd52b..b926fe4e0c 100644 --- a/ortools/sat/2d_orthogonal_packing_testing.h +++ b/ortools/sat/2d_orthogonal_packing_testing.h @@ -38,12 +38,11 @@ std::vector MakeItemsFromRectangles( absl::Span rectangles, double slack_factor, absl::BitGenRef random); -std::vector -GenerateItemsRectanglesWithNoPairwiseConflict( +std::vector GenerateItemsRectanglesWithNoPairwiseConflict( absl::Span rectangles, double slack_factor, absl::BitGenRef random); -std::vector +std::vector GenerateItemsRectanglesWithNoPairwisePropagation(int num_rectangles, double slack_factor, absl::BitGenRef random); diff --git a/ortools/sat/2d_try_edge_propagator.cc b/ortools/sat/2d_try_edge_propagator.cc index db93efb5e0..95af23535f 100644 --- a/ortools/sat/2d_try_edge_propagator.cc +++ b/ortools/sat/2d_try_edge_propagator.cc @@ -29,8 +29,8 @@ #include "ortools/sat/diffn_util.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" @@ -39,8 +39,7 @@ namespace sat { int TryEdgeRectanglePropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - x_.WatchAllTasks(id); - y_.WatchAllTasks(id); + helper_.WatchAllBoxes(id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); return id; } @@ -57,7 +56,7 @@ TryEdgeRectanglePropagator::~TryEdgeRectanglePropagator() { } void TryEdgeRectanglePropagator::PopulateActiveBoxRanges() { - const int num_boxes = x_.NumTasks(); + const int num_boxes = helper_.NumBoxes(); placed_boxes_.resize(num_boxes); active_box_ranges_.resize(num_boxes); is_active_.resize(num_boxes); @@ -68,22 +67,20 @@ void TryEdgeRectanglePropagator::PopulateActiveBoxRanges() { changed_mandatory_.clear(); changed_item_.clear(); for (int box = 0; box < num_boxes; ++box) { - const bool inactive = (x_.SizeMin(box) == 0 || y_.SizeMin(box) == 0 || - !x_.IsPresent(box) || !y_.IsPresent(box)); + bool inactive = !helper_.IsPresent(box); + RectangleInRange rec; + if (!inactive) { + rec = helper_.GetItemRangeForSizeMin(box); + if (rec.x_size == 0 || rec.y_size == 0) { + inactive = true; + } + } is_active_[box] = !inactive; if (inactive) { is_in_cache_[box] = false; has_mandatory_region_.Set(box, false); continue; } - const RectangleInRange rec = { - .box_index = box, - .bounding_area = {.x_min = x_.StartMin(box), - .x_max = x_.StartMax(box) + x_.SizeMin(box), - .y_min = y_.StartMin(box), - .y_max = y_.StartMax(box) + y_.SizeMin(box)}, - .x_size = x_.SizeMin(box), - .y_size = y_.SizeMin(box)}; if (is_in_cache_[box] && rec == active_box_ranges_[box]) { DCHECK(mandatory_regions_[box] == rec.GetMandatoryRegion()); DCHECK(has_mandatory_region_[box] == @@ -135,8 +132,10 @@ bool TryEdgeRectanglePropagator::CanPlace( } bool TryEdgeRectanglePropagator::Propagate() { - if (!x_.SynchronizeAndSetTimeDirection(x_is_forward_)) return false; - if (!y_.SynchronizeAndSetTimeDirection(y_is_forward_)) return false; + if (!helper_.SynchronizeAndSetDirection( + x_is_forward_after_swap_, y_is_forward_after_swap_, swap_x_and_y_)) { + return false; + } num_calls_++; @@ -358,8 +357,8 @@ bool TryEdgeRectanglePropagator::ExplainAndPropagate( found_propagations) { for (const auto& [box_index, new_x_min] : found_propagations) { const RectangleInRange& box = active_box_ranges_[box_index]; - x_.ClearReason(); - y_.ClearReason(); + helper_.ClearReason(); + const std::vector minimum_problem_with_propagator = GetMinimumProblemWithPropagation( box_index, new_x_min.has_value() @@ -374,61 +373,56 @@ bool TryEdgeRectanglePropagator::ExplainAndPropagate( const RectangleInRange& box_reason = active_box_ranges_[j]; const int b = box_reason.box_index; - x_.AddStartMinReason(b, box_reason.bounding_area.x_min); - y_.AddStartMinReason(b, box_reason.bounding_area.y_min); + helper_.AddLeftMinReason(b, box_reason.bounding_area.x_min); + helper_.AddBottomMinReason(b, box_reason.bounding_area.y_min); if (j != box_index || !new_x_min.has_value()) { // We don't need to add to the reason the x_max for the box we are // pushing the x_min, except if we found a conflict. - x_.AddStartMaxReason( + helper_.AddLeftMaxReason( b, box_reason.bounding_area.x_max - box_reason.x_size); } - y_.AddStartMaxReason(b, - box_reason.bounding_area.y_max - box_reason.y_size); + helper_.AddBottomMaxReason( + b, box_reason.bounding_area.y_max - box_reason.y_size); - x_.AddSizeMinReason(b); - y_.AddSizeMinReason(b); - - x_.AddPresenceReason(b); - y_.AddPresenceReason(b); + helper_.AddSizeMinReason(b); + helper_.AddPresenceReason(b); } - x_.ImportOtherReasons(y_); if (new_x_min.has_value()) { num_propagations_++; - if (!x_.IncreaseStartMin(box.box_index, *new_x_min)) { + if (!helper_.IncreaseLeftMin(box_index, *new_x_min)) { return false; } } else { num_conflicts_++; - return x_.ReportConflict(); + return helper_.ReportConflict(); } } return true; } -void CreateAndRegisterTryEdgePropagator(SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, +void CreateAndRegisterTryEdgePropagator(NoOverlap2DConstraintHelper* helper, Model* model, GenericLiteralWatcher* watcher) { TryEdgeRectanglePropagator* try_edge_propagator = - new TryEdgeRectanglePropagator(true, true, x, y, model); + new TryEdgeRectanglePropagator(true, true, false, helper, model); watcher->SetPropagatorPriority(try_edge_propagator->RegisterWith(watcher), 5); model->TakeOwnership(try_edge_propagator); TryEdgeRectanglePropagator* try_edge_propagator_mirrored = - new TryEdgeRectanglePropagator(false, true, x, y, model); + new TryEdgeRectanglePropagator(false, true, false, helper, model); watcher->SetPropagatorPriority( try_edge_propagator_mirrored->RegisterWith(watcher), 5); model->TakeOwnership(try_edge_propagator_mirrored); TryEdgeRectanglePropagator* try_edge_propagator_swap = - new TryEdgeRectanglePropagator(true, true, y, x, model); + new TryEdgeRectanglePropagator(true, true, true, helper, model); watcher->SetPropagatorPriority( try_edge_propagator_swap->RegisterWith(watcher), 5); model->TakeOwnership(try_edge_propagator_swap); TryEdgeRectanglePropagator* try_edge_propagator_swap_mirrored = - new TryEdgeRectanglePropagator(false, true, y, x, model); + new TryEdgeRectanglePropagator(false, true, true, helper, model); watcher->SetPropagatorPriority( try_edge_propagator_swap_mirrored->RegisterWith(watcher), 5); model->TakeOwnership(try_edge_propagator_swap_mirrored); diff --git a/ortools/sat/2d_try_edge_propagator.h b/ortools/sat/2d_try_edge_propagator.h index aac469198f..3ac69f4d14 100644 --- a/ortools/sat/2d_try_edge_propagator.h +++ b/ortools/sat/2d_try_edge_propagator.h @@ -22,8 +22,8 @@ #include "ortools/sat/diffn_util.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" #include "ortools/util/bitset.h" @@ -35,22 +35,21 @@ namespace sat { // try to find the leftmost valid position that is compatible with all the // other boxes. If none is found, it will propagate a conflict. Otherwise, if // it is different from the current x_min, it will propagate the new x_min. -void CreateAndRegisterTryEdgePropagator(SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, +void CreateAndRegisterTryEdgePropagator(NoOverlap2DConstraintHelper* helper, Model* model, GenericLiteralWatcher* watcher); // Exposed for testing. class TryEdgeRectanglePropagator : public PropagatorInterface { public: - TryEdgeRectanglePropagator(bool x_is_forward, bool y_is_forward, - SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, Model* model) - : x_(*x), - y_(*y), + TryEdgeRectanglePropagator(bool x_is_forward_after_swap, + bool y_is_forward_after_swap, bool swap_x_and_y, + NoOverlap2DConstraintHelper* helper, Model* model) + : helper_(*helper), shared_stats_(model->GetOrCreate()), - x_is_forward_(x_is_forward), - y_is_forward_(y_is_forward) {} + x_is_forward_after_swap_(x_is_forward_after_swap), + y_is_forward_after_swap_(y_is_forward_after_swap), + swap_x_and_y_(swap_x_and_y) {} ~TryEdgeRectanglePropagator() override; @@ -90,11 +89,11 @@ class TryEdgeRectanglePropagator : public PropagatorInterface { private: void PopulateActiveBoxRanges(); - SchedulingConstraintHelper& x_; - SchedulingConstraintHelper& y_; + NoOverlap2DConstraintHelper& helper_; SharedStatistics* shared_stats_; - bool x_is_forward_; - bool y_is_forward_; + bool x_is_forward_after_swap_; + bool y_is_forward_after_swap_; + bool swap_x_and_y_; std::vector cached_y_hint_; std::vector potential_x_positions_; diff --git a/ortools/sat/2d_try_edge_propagator_test.cc b/ortools/sat/2d_try_edge_propagator_test.cc index 5c0b0317fe..1dfe81e8e0 100644 --- a/ortools/sat/2d_try_edge_propagator_test.cc +++ b/ortools/sat/2d_try_edge_propagator_test.cc @@ -31,6 +31,7 @@ #include "ortools/sat/integer_base.h" #include "ortools/sat/intervals.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" namespace operations_research { namespace sat { @@ -92,9 +93,10 @@ void CheckConflict(const RectangleInRange& box_to_propagate, class TryEdgeRectanglePropagatorForTest : public TryEdgeRectanglePropagator { public: TryEdgeRectanglePropagatorForTest(bool x_is_forward, bool y_is_forward, - SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, Model* model) - : TryEdgeRectanglePropagator(x_is_forward, y_is_forward, x, y, model) {} + NoOverlap2DConstraintHelper* helper, + Model* model) + : TryEdgeRectanglePropagator(x_is_forward, y_is_forward, false, helper, + model) {} bool ExplainAndPropagate( const std::vector>>& @@ -127,19 +129,14 @@ class TryEdgeRectanglePropagatorForTest : public TryEdgeRectanglePropagator { } private: - static SchedulingConstraintHelper* GetHelperFromModel(Model* model) { - return model->GetOrCreate()->GetOrCreateHelper({}); - } - Model model_; IntervalsRepository* repository_ = model_.GetOrCreate(); std::vector>> propagations_; }; -std::pair -CreateHelper(Model* model, - absl::Span active_box_ranges) { +NoOverlap2DConstraintHelper* CreateHelper( + Model* model, absl::Span active_box_ranges) { std::vector x_intervals; std::vector y_intervals; for (const RectangleInRange& active_box_range : active_box_ranges) { @@ -160,10 +157,8 @@ CreateHelper(Model* model, x_intervals.push_back(x_interval); y_intervals.push_back(y_interval); } - return { - model->GetOrCreate()->GetOrCreateHelper(x_intervals), - model->GetOrCreate()->GetOrCreateHelper( - y_intervals)}; + return model->GetOrCreate()->GetOrCreate2DHelper( + x_intervals, y_intervals); } TEST(TryEdgeRectanglePropagatorTest, Simple) { @@ -195,10 +190,9 @@ TEST(TryEdgeRectanglePropagatorTest, Simple) { { Model model; - auto [x_helper, y_helper] = CreateHelper(&model, active_box_ranges); + auto* helper = CreateHelper(&model, active_box_ranges); - TryEdgeRectanglePropagatorForTest propagator(true, true, x_helper, y_helper, - &model); + TryEdgeRectanglePropagatorForTest propagator(true, true, helper, &model); propagator.Propagate(); EXPECT_THAT(propagator.propagations(), UnorderedElementsAre(Pair(2, IntegerValue(5)))); @@ -209,10 +203,9 @@ TEST(TryEdgeRectanglePropagatorTest, Simple) { active_box_ranges[2].bounding_area.x_min = 0; active_box_ranges[2].bounding_area.x_max = 5; Model model; - auto [x_helper, y_helper] = CreateHelper(&model, active_box_ranges); + auto* helper = CreateHelper(&model, active_box_ranges); - TryEdgeRectanglePropagatorForTest propagator(true, true, x_helper, y_helper, - &model); + TryEdgeRectanglePropagatorForTest propagator(true, true, helper, &model); propagator.Propagate(); EXPECT_THAT(propagator.propagations(), UnorderedElementsAre(Pair(2, std::nullopt))); @@ -233,10 +226,9 @@ TEST(TryEdgeRectanglePropagatorTest, NoConflictForFeasible) { const std::vector input_in_range = MakeItemsFromRectangles(rectangles, 0.6, bit_gen); Model model; - auto [x_helper, y_helper] = CreateHelper(&model, input_in_range); + auto* helper = CreateHelper(&model, input_in_range); - TryEdgeRectanglePropagatorForTest propagator(true, true, x_helper, y_helper, - &model); + TryEdgeRectanglePropagatorForTest propagator(true, true, helper, &model); propagator.Propagate(); EXPECT_THAT(propagator.propagations(), Each(Pair(_, Not(Eq(std::nullopt))))); @@ -272,10 +264,9 @@ TEST(TryEdgeRectanglePropagatorTest, ValidatePropagationsWithConflicts) { const std::vector input_in_range = MakeItemsFromRectangles(rectangles, 0.6, bit_gen); Model model; - auto [x_helper, y_helper] = CreateHelper(&model, input_in_range); + auto* helper = CreateHelper(&model, input_in_range); - TryEdgeRectanglePropagatorForTest propagator(true, true, x_helper, y_helper, - &model); + TryEdgeRectanglePropagatorForTest propagator(true, true, helper, &model); propagator.Propagate(); } } diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 4be4a14b2f..54aa6756e2 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -268,6 +268,7 @@ cc_library( srcs = ["feasibility_jump.cc"], hdrs = ["feasibility_jump.h"], deps = [ + ":combine_solutions", ":constraint_violation", ":cp_model_cc_proto", ":cp_model_checker", @@ -1444,6 +1445,7 @@ cc_library( "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -1539,6 +1541,7 @@ cc_library( ":sat_inprocessing", ":sat_parameters_cc_proto", ":sat_solver", + ":scheduling_helpers", ":synchronization", ":util", "//ortools/base", @@ -1646,6 +1649,43 @@ cc_library( hdrs = ["intervals.h"], deps = [ ":clause", + ":integer", + ":integer_base", + ":integer_expr", + ":linear_constraint", + ":model", + ":no_overlap_2d_helper", + ":sat_base", + ":sat_solver", + ":scheduling_helpers", + "//ortools/base:strong_vector", + "//ortools/util:strong_integers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "intervals_test", + size = "small", + srcs = ["intervals_test.cc"], + deps = [ + ":integer", + ":integer_base", + ":intervals", + ":model", + ":sat_base", + "//ortools/base:gmock_main", + ], +) + +cc_library( + name = "scheduling_helpers", + srcs = ["scheduling_helpers.cc"], + hdrs = ["scheduling_helpers.h"], + deps = [ ":implied_bounds", ":integer", ":integer_base", @@ -1669,10 +1709,24 @@ cc_library( ], ) +cc_library( + name = "no_overlap_2d_helper", + srcs = ["no_overlap_2d_helper.cc"], + hdrs = ["no_overlap_2d_helper.h"], + deps = [ + ":diffn_util", + ":integer", + ":integer_base", + ":model", + ":scheduling_helpers", + "@com_google_absl//absl/types:span", + ], +) + cc_test( - name = "intervals_test", + name = "scheduling_helpers_test", size = "small", - srcs = ["intervals_test.cc"], + srcs = ["scheduling_helpers_test.cc"], deps = [ ":integer", ":integer_base", @@ -1681,6 +1735,7 @@ cc_test( ":model", ":sat_base", ":sat_solver", + ":scheduling_helpers", "//ortools/base:gmock_main", ], ) @@ -1992,6 +2047,7 @@ cc_library( ":intervals", ":model", ":sat_base", + ":scheduling_helpers", "//ortools/util:strong_integers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", @@ -2013,6 +2069,7 @@ cc_test( ":precedences", ":sat_base", ":sat_solver", + ":scheduling_helpers", ":timetable", "//ortools/base", "//ortools/base:gmock_main", @@ -2032,6 +2089,7 @@ cc_library( ":integer_base", ":intervals", ":model", + ":scheduling_helpers", "//ortools/base:iterator_adaptors", "//ortools/util:strong_integers", "@com_google_absl//absl/log:check", @@ -2100,6 +2158,7 @@ cc_library( ":integer_base", ":intervals", ":model", + ":scheduling_helpers", ":synchronization", ":theta_tree", ":util", @@ -2132,6 +2191,7 @@ cc_test( ":sat_base", ":sat_parameters_cc_proto", ":sat_solver", + ":scheduling_helpers", "//ortools/base", "//ortools/base:gmock_main", "//ortools/util:strong_integers", @@ -2601,6 +2661,7 @@ cc_library( ":linear_constraint_manager", ":model", ":sat_base", + ":scheduling_helpers", ":util", "//ortools/base:stl_util", "//ortools/base:strong_vector", @@ -2924,7 +2985,7 @@ cc_library( hdrs = ["diffn_util.h"], deps = [ ":integer_base", - ":intervals", + ":scheduling_helpers", ":util", "//ortools/base", "//ortools/base:stl_util", @@ -2957,6 +3018,7 @@ cc_library( ":integer_base", ":synchronization", ":util", + "//ortools/base:constant_divisor", "//ortools/util:bitset", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -3068,8 +3130,8 @@ cc_library( ":diffn_util", ":integer", ":integer_base", - ":intervals", ":model", + ":no_overlap_2d_helper", ":synchronization", ":util", "//ortools/algorithms:set_cover_heuristics", @@ -3093,6 +3155,7 @@ cc_test( ":integer_base", ":intervals", ":model", + ":no_overlap_2d_helper", "//ortools/base:gmock_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/random", @@ -3111,7 +3174,7 @@ cc_test( ":integer_base", ":util", "//ortools/base", - "//ortools/base:gmock_main", + "//ortools/base:gmock", "//ortools/graph:connected_components", "//ortools/graph:strongly_connected_components", "//ortools/util:saturated_arithmetic", @@ -3125,6 +3188,7 @@ cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_benchmark//:benchmark", + "@com_google_fuzztest//fuzztest:fuzztest_gtest_main", ], ) @@ -3142,18 +3206,19 @@ cc_library( ":integer_base", ":integer_expr", ":intervals", - ":linear_constraint", ":model", + ":no_overlap_2d_helper", ":sat_base", ":sat_parameters_cc_proto", + ":scheduling_helpers", ":synchronization", ":timetable", ":util", + "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", "//ortools/util:time_limit", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:bits", @@ -3707,6 +3772,7 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_checker", + ":model", ":synchronization", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", diff --git a/ortools/sat/all_different.cc b/ortools/sat/all_different.cc index 25ff6c989d..00d619246a 100644 --- a/ortools/sat/all_different.cc +++ b/ortools/sat/all_different.cc @@ -83,8 +83,9 @@ std::function AllDifferentOnBounds( } std::function AllDifferentOnBounds( - const std::vector& vars) { - return [=](Model* model) { + absl::Span vars) { + return [=, vars = std::vector(vars.begin(), vars.end())]( + Model* model) { if (vars.empty()) return; std::vector expressions; expressions.reserve(vars.size()); diff --git a/ortools/sat/all_different.h b/ortools/sat/all_different.h index 44575e5d3b..743ffc8ebf 100644 --- a/ortools/sat/all_different.h +++ b/ortools/sat/all_different.h @@ -45,7 +45,7 @@ std::function AllDifferentBinary( // this will not remove already taken values from inside a domain, but it will // propagates more the domain bounds. std::function AllDifferentOnBounds( - const std::vector& vars); + absl::Span vars); std::function AllDifferentOnBounds( const std::vector& expressions); diff --git a/ortools/sat/circuit.cc b/ortools/sat/circuit.cc index 46961bc39a..402e516b7c 100644 --- a/ortools/sat/circuit.cc +++ b/ortools/sat/circuit.cc @@ -646,10 +646,9 @@ std::function ExactlyOnePerRowAndPerColumn( }; } -void LoadSubcircuitConstraint(int num_nodes, const std::vector& tails, - const std::vector& heads, - const std::vector& literals, - Model* model, +void LoadSubcircuitConstraint(int num_nodes, absl::Span tails, + absl::Span heads, + absl::Span literals, Model* model, bool multiple_subcircuit_through_zero) { const int num_arcs = tails.size(); CHECK_GT(num_arcs, 0); diff --git a/ortools/sat/circuit.h b/ortools/sat/circuit.h index 8f8c50b1b9..97443a562c 100644 --- a/ortools/sat/circuit.h +++ b/ortools/sat/circuit.h @@ -244,10 +244,9 @@ int ReindexArcs(IntContainer* tails, IntContainer* heads, // This just wraps CircuitPropagator. See the comment there to see what this // does. Note that any nodes with no outgoing or no incoming arc will cause the // problem to be UNSAT. One can call ReindexArcs() first to ignore such nodes. -void LoadSubcircuitConstraint(int num_nodes, const std::vector& tails, - const std::vector& heads, - const std::vector& literals, - Model* model, +void LoadSubcircuitConstraint(int num_nodes, absl::Span tails, + absl::Span heads, + absl::Span literals, Model* model, bool multiple_subcircuit_through_zero = false); // TODO(user): Change to a sparse API like for the function above. diff --git a/ortools/sat/combine_solutions.cc b/ortools/sat/combine_solutions.cc index f8efaec5fc..679074675e 100644 --- a/ortools/sat/combine_solutions.cc +++ b/ortools/sat/combine_solutions.cc @@ -24,6 +24,7 @@ #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/sat/cp_model_checker.h" +#include "ortools/sat/model.h" #include "ortools/sat/synchronization.h" namespace operations_research { @@ -75,5 +76,25 @@ std::optional> FindCombinedSolution( return std::nullopt; } +PushedSolutionPointers PushAndMaybeCombineSolution( + SharedResponseManager* response_manager, const CpModelProto& model_proto, + absl::Span new_solution, const std::string& solution_info, + absl::Span base_solution, Model* model) { + PushedSolutionPointers result = {nullptr, nullptr}; + result.pushed_solution = + response_manager->NewSolution(new_solution, solution_info, model); + if (!base_solution.empty()) { + std::string combined_solution_info = solution_info; + std::optional> combined_solution = + FindCombinedSolution(model_proto, new_solution, base_solution, + response_manager, &combined_solution_info); + if (combined_solution.has_value()) { + result.improved_solution = response_manager->NewSolution( + combined_solution.value(), combined_solution_info, model); + } + } + return result; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/combine_solutions.h b/ortools/sat/combine_solutions.h index 3bd91b1465..7106e9939e 100644 --- a/ortools/sat/combine_solutions.h +++ b/ortools/sat/combine_solutions.h @@ -15,12 +15,14 @@ #define OR_TOOLS_SAT_COMBINE_SOLUTIONS_H_ #include +#include #include #include #include #include "absl/types/span.h" #include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/model.h" #include "ortools/sat/synchronization.h" namespace operations_research { @@ -34,6 +36,21 @@ std::optional> FindCombinedSolution( absl::Span base_solution, const SharedResponseManager* response_manager, std::string* solution_info); +// This is equivalent to calling SharedResponseManager::NewSolution() then, if +// `base_solution` is non-empty, trying to find a combined solution and calling +// SharedResponseManager::NewSolution() again if an improved solution is found. +struct PushedSolutionPointers { + std::shared_ptr::Solution> + pushed_solution; + // nullptr if no improvement was found. + std::shared_ptr::Solution> + improved_solution; +}; +PushedSolutionPointers PushAndMaybeCombineSolution( + SharedResponseManager* response_manager, const CpModelProto& model_proto, + absl::Span new_solution, const std::string& solution_info, + absl::Span base_solution = {}, Model* model = nullptr); + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index a38b84ac6f..844a3111fd 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -58,8 +58,8 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, const ReservoirConstraintProto& reservoir = reservoir_ct->reservoir(); const int num_events = reservoir.time_exprs_size(); - // The encoding will create a circuit constraint and on integer variable per - // events representing the level a that event time. + // The encoding will create a circuit constraint, and one integer variable per + // event (representing the level at that event time). CircuitConstraintProto* circuit = context->working_model->add_constraints()->mutable_circuit(); @@ -70,16 +70,94 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, std::vector level_vars(num_events); for (int i = 0; i < num_events; ++i) { level_vars[i] = context->NewIntVar(Domain(var_min, var_max)); + if (context->HintIsLoaded()) { + // The hint of active events is set later. + context->SetNewVariableHint(level_vars[i], 0); + } + } + + // The hints of the active events, in the order they should appear in the + // circuit. The hints are collected first, and sorted later. + struct ReservoirEventHint { + int index; // In the reservoir constraint. + int64_t time; + int64_t level_change; + }; + std::vector active_event_hints; + bool has_complete_hint = false; + if (context->HintIsLoaded()) { + has_complete_hint = true; + for (int i = 0; i < num_events && has_complete_hint; ++i) { + if (context->VarHasSolutionHint( + PositiveRef(reservoir.active_literals(i)))) { + if (context->LiteralSolutionHint(reservoir.active_literals(i))) { + const std::optional time_hint = + context->GetExpressionSolutionHint(reservoir.time_exprs(i)); + const std::optional change_hint = + context->GetExpressionSolutionHint(reservoir.level_changes(i)); + if (time_hint.has_value() && change_hint.has_value()) { + active_event_hints.push_back( + {i, time_hint.value(), change_hint.value()}); + } else { + has_complete_hint = false; + } + } + } else { + has_complete_hint = false; + } + } + } + // Update the `level_vars` hints by computing the level at each active event. + if (has_complete_hint) { + std::sort(active_event_hints.begin(), active_event_hints.end(), + [](const ReservoirEventHint& a, const ReservoirEventHint& b) { + return a.time < b.time; + }); + int64_t current_level = 0; + for (int i = 0; i < active_event_hints.size(); ++i) { + int j = i; + // Adjust the order of the events occurring at the same time, in the + // circuit, so that, at each node, the level is between `var_min` and + // `var_max`. For instance, if e1 = {t, +1} and e2 = {t, -1}, and if + // `current_level` = 0, `var_min` = -1 and `var_max` = 0, then e2 must + // occur before e1. + while (j < active_event_hints.size() && + active_event_hints[j].time == active_event_hints[i].time && + (current_level + active_event_hints[j].level_change < var_min || + current_level + active_event_hints[j].level_change > var_max)) { + ++j; + } + if (j < active_event_hints.size() && + active_event_hints[j].time == active_event_hints[i].time) { + if (i != j) std::swap(active_event_hints[i], active_event_hints[j]); + current_level += active_event_hints[i].level_change; + context->UpdateVarSolutionHint(level_vars[active_event_hints[i].index], + current_level); + } else { + has_complete_hint = false; + break; + } + } } // For the corner case where all events are absent, we need a potential // self-arc on the start/end circuit node. { + const int all_inactive = context->NewBoolVar("reservoir expansion"); circuit->add_tails(num_events); circuit->add_heads(num_events); - circuit->add_literals(context->NewBoolVar("reservoir expansion")); + circuit->add_literals(all_inactive); + if (has_complete_hint) { + context->SetNewVariableHint(all_inactive, active_event_hints.empty()); + } } + // The index of each event in `active_event_hints`, or -1 if the event's + // "active" hint is false. + std::vector active_event_hint_index(num_events, -1); + for (int i = 0; i < active_event_hints.size(); ++i) { + active_event_hint_index[active_event_hints[i].index] = i; + } for (int i = 0; i < num_events; ++i) { if (!reservoir.active_literals().empty()) { // Add self arc to represent absence. @@ -96,6 +174,11 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, circuit->add_tails(num_events); circuit->add_heads(i); circuit->add_literals(start_var); + if (has_complete_hint) { + context->SetNewVariableHint(start_var, + !active_event_hints.empty() && + active_event_hints.front().index == i); + } // Add enforced linear for demand. { @@ -112,9 +195,15 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, } // Circuit ends at i, no extra constraint there. + const int end_var = context->NewBoolVar("reservoir expansion"); circuit->add_tails(i); circuit->add_heads(num_events); - circuit->add_literals(context->NewBoolVar("reservoir expansion")); + circuit->add_literals(end_var); + if (has_complete_hint) { + context->SetNewVariableHint(end_var, + !active_event_hints.empty() && + active_event_hints.back().index == i); + } } for (int j = 0; j < num_events; ++j) { @@ -134,6 +223,13 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, circuit->add_tails(i); circuit->add_heads(j); circuit->add_literals(arc_i_j); + if (has_complete_hint) { + const int hint_i_index = active_event_hint_index[i]; + const int hint_j_index = active_event_hint_index[j]; + context->SetNewVariableHint(arc_i_j, + hint_i_index != -1 && hint_j_index != -1 && + hint_j_index == hint_i_index + 1); + } // Add enforced linear for time. { @@ -333,6 +429,19 @@ void ExpandReservoir(ConstraintProto* reservoir_ct, PresolveContext* context) { // not(active) => new_var == 0. context->AddImplyInDomain(NegatedRef(active), new_var, Domain(0)); + + if (context->HintIsLoaded() && + context->VarHasSolutionHint(PositiveRef(active))) { + if (context->LiteralSolutionHint(active)) { + const std::optional demand_hint = + context->GetExpressionSolutionHint(demand); + if (demand_hint.has_value()) { + context->SetNewVariableHint(new_var, demand_hint.value()); + } + } else { + context->SetNewVariableHint(new_var, 0); + } + } } } sum->add_domain(reservoir.min_level()); @@ -652,10 +761,14 @@ void ExpandInverse(ConstraintProto* ct, PresolveContext* context) { const int r_j = f_inverse[j]; int r_j_i; if (context->HasVarValueEncoding(r_j, i, &r_j_i)) { - context->InsertVarValueEncoding(r_j_i, f_i, j); + if (!context->InsertVarValueEncoding(r_j_i, f_i, j)) { + return; + } } else { const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j); - context->InsertVarValueEncoding(f_i_j, r_j, i); + if (!context->InsertVarValueEncoding(f_i_j, r_j, i)) { + return; + } } } } @@ -1759,13 +1872,15 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, // selected or not. Enforce an exactly one between them. BoolArgumentProto* exactly_one = context->working_model->add_constraints()->mutable_exactly_one(); + int exactly_one_hint_sum = 0; std::optional table_is_active_literal = std::nullopt; // Process enforcement literals. if (ct->enforcement_literal().size() == 1) { table_is_active_literal = ct->enforcement_literal(0); } else if (ct->enforcement_literal().size() > 1) { - table_is_active_literal = context->NewBoolVar("table expansion"); + table_is_active_literal = + context->NewBoolVarWithConjunction(ct->enforcement_literal()); // Adds table_is_active <=> and(enforcement_literals). BoolArgumentProto* bool_or = @@ -1776,8 +1891,14 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, bool_or->add_literals(NegatedRef(lit)); } } + if (table_is_active_literal.has_value()) { + const int inactive_lit = NegatedRef(table_is_active_literal.value()); + exactly_one->add_literals(inactive_lit); + exactly_one_hint_sum += context->LiteralSolutionHintIs(inactive_lit, true); + } - int64_t num_reused_variables = 0; + int num_reused_variables = 0; + std::vector tuples_with_new_variable; std::vector tuple_literals(compressed_table.size()); for (int i = 0; i < compressed_table.size(); ++i) { bool create_new_var = true; @@ -1794,13 +1915,38 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, create_new_var = false; tuple_literals[i] = context->GetOrCreateVarValueEncoding(vars[var_index], v); + exactly_one_hint_sum += context->SolutionHint(vars[var_index]) == v; break; } if (create_new_var) { tuple_literals[i] = context->NewBoolVar("table expansion"); + tuples_with_new_variable.push_back(i); } exactly_one->add_literals(tuple_literals[i]); } + // Set the hint of the `tuple_literals` for which new variables were created. + // If the existing `tuple_literals` hints do not sum to 1, set the hint of the + // first tuple which can be selected to true, and the others to false. A tuple + // T can be selected if, for each variable v, the hint of v is in the set of + // values T[v] (an empty set means "any value"). + for (const int i : tuples_with_new_variable) { + if (exactly_one_hint_sum >= 1) { + context->SetNewVariableHint(tuple_literals[i], false); + continue; + } + bool tuple_literal_hint = true; + for (int var_index = 0; var_index < num_vars; ++var_index) { + const auto& values = compressed_table[i][var_index]; + if (!values.empty() && + std::find(values.begin(), values.end(), + context->SolutionHint(vars[var_index])) == values.end()) { + tuple_literal_hint = false; + break; + } + } + context->SetNewVariableHint(tuple_literals[i], tuple_literal_hint); + exactly_one_hint_sum += tuple_literal_hint; + } if (num_reused_variables > 0) { context->UpdateRuleStats("table: reused literals"); } @@ -1826,10 +1972,6 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, table_is_active_literal, context); } - if (table_is_active_literal.has_value()) { - exactly_one->add_literals(NegatedRef(table_is_active_literal.value())); - } - context->UpdateRuleStats("table: expanded positive constraint"); } diff --git a/ortools/sat/cp_model_expand_test.cc b/ortools/sat/cp_model_expand_test.cc index 4a665dac92..e0313da12c 100644 --- a/ortools/sat/cp_model_expand_test.cc +++ b/ortools/sat/cp_model_expand_test.cc @@ -380,6 +380,44 @@ TEST(ReservoirExpandTest, FalseActive) { EXPECT_EQ(OPTIMAL, response.status()); } +TEST(ReservoirExpandTest, ExpandReservoirUsingCircuitPreservesSolutionHint) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + reservoir { + max_level: 2 + time_exprs { offset: 1 } + time_exprs { offset: 1 } + time_exprs { offset: 1 } + time_exprs { offset: 2 } + time_exprs { offset: 3 } + level_changes: { offset: -1 } + level_changes: { offset: -1 } + level_changes: { offset: 2 } + level_changes: { offset: -2 } + level_changes: { offset: 1 } + active_literals: 0 + active_literals: 0 + active_literals: 0 + active_literals: 1 + active_literals: 0 + } + } + solution_hint { + vars: [ 0, 1 ] + values: [ 1, 0 ] + } + )pb"); + + SatParameters params; + params.set_expand_reservoir_using_circuit(true); + params.set_log_search_progress(true); + params.set_debug_crash_if_presolve_breaks_hint(true); + CpSolverResponse response = SolveWithParameters(initial_model, params); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); +} + TEST(IntModExpandTest, FzTest) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { name: 'x' domain: 50 domain: 60 } diff --git a/ortools/sat/cp_model_postsolve.cc b/ortools/sat/cp_model_postsolve.cc index 54adbba102..f1e88031a9 100644 --- a/ortools/sat/cp_model_postsolve.cc +++ b/ortools/sat/cp_model_postsolve.cc @@ -415,8 +415,8 @@ void PostsolveResponse(const int64_t num_variables_in_original_model, void FillTightenedDomainInResponse(const CpModelProto& original_model, const CpModelProto& mapping_proto, - const std::vector& postsolve_mapping, - const std::vector& search_domains, + absl::Span postsolve_mapping, + absl::Span search_domains, CpSolverResponse* response, SolverLogger* logger) { // The [0, num_vars) part will contain the tightened domains. diff --git a/ortools/sat/cp_model_postsolve.h b/ortools/sat/cp_model_postsolve.h index efa5c36726..207bd0f3bb 100644 --- a/ortools/sat/cp_model_postsolve.h +++ b/ortools/sat/cp_model_postsolve.h @@ -53,8 +53,8 @@ void PostsolveResponse(int64_t num_variables_in_original_model, // tightened_variables field for more information on the caveats. void FillTightenedDomainInResponse(const CpModelProto& original_model, const CpModelProto& mapping_proto, - const std::vector& postsolve_mapping, - const std::vector& search_domains, + absl::Span postsolve_mapping, + absl::Span search_domains, CpSolverResponse* response, SolverLogger* logger); diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index bc2ddf8ac8..dec8eb2539 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -42,7 +42,9 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "google/protobuf/arena.h" #include "google/protobuf/repeated_field.h" +#include "google/protobuf/repeated_ptr_field.h" #include "google/protobuf/text_format.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" @@ -1804,9 +1806,19 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { linear_for_true); context_->CanonicalizeLinearConstraint(constraint_for_false); context_->CanonicalizeLinearConstraint(constraint_for_true); - context_->UpdateRuleStats("int_prod: boolean affine term"); - context_->UpdateNewConstraintsVariableUsage(); - return RemoveConstraint(ct); + if (PossibleIntegerOverflow(*context_->working_model, + linear_for_false->vars(), + linear_for_false->coeffs()) || + PossibleIntegerOverflow(*context_->working_model, + linear_for_true->vars(), + linear_for_true->coeffs())) { + context_->working_model->mutable_constraints()->RemoveLast(); + context_->working_model->mutable_constraints()->RemoveLast(); + } else { + context_->UpdateRuleStats("int_prod: boolean affine term"); + context_->UpdateNewConstraintsVariableUsage(); + return RemoveConstraint(ct); + } } } @@ -5836,9 +5848,8 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { // Filter absent boxes. int new_size = 0; - std::vector bounding_boxes, fixed_boxes; + std::vector bounding_boxes, fixed_boxes, non_fixed_bounding_boxes; std::vector non_fixed_boxes; - std::vector active_boxes; absl::flat_hash_set fixed_item_indexes; for (int i = 0; i < proto.x_intervals_size(); ++i) { const int x_interval_index = proto.x_intervals(i); @@ -5877,7 +5888,6 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { IntegerValue(context_->EndMax(x_interval_index)), IntegerValue(context_->StartMin(y_interval_index)), IntegerValue(context_->EndMax(y_interval_index))}); - active_boxes.push_back(new_size); if (context_->IntervalIsConstant(x_interval_index) && context_->IntervalIsConstant(y_interval_index) && context_->SizeMax(x_interval_index) > 0 && @@ -5885,6 +5895,7 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { fixed_boxes.push_back(bounding_boxes.back()); fixed_item_indexes.insert(new_size); } else { + non_fixed_bounding_boxes.push_back(bounding_boxes.back()); non_fixed_boxes.push_back( {.box_index = new_size, .bounding_area = bounding_boxes.back(), @@ -5909,15 +5920,25 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { } } - const CompactVectorVector components = GetOverlappingRectangleComponents( - bounding_boxes, absl::MakeSpan(active_boxes)); - // The result of GetOverlappingRectangleComponents() omit singleton components - // thus to check whether a graph is fully connected we must check also the - // size of the unique component. - const bool is_fully_connected = - (components.size() == 1 && components[0].size() == active_boxes.size()) || - (active_boxes.size() <= 1); - if (!is_fully_connected) { + if (new_size < initial_num_boxes) { + context_->UpdateRuleStats("no_overlap_2d: removed inactive boxes"); + ct->mutable_no_overlap_2d()->mutable_x_intervals()->Truncate(new_size); + ct->mutable_no_overlap_2d()->mutable_y_intervals()->Truncate(new_size); + } + + if (new_size == 0) { + context_->UpdateRuleStats("no_overlap_2d: no boxes"); + return RemoveConstraint(ct); + } + + if (new_size == 1) { + context_->UpdateRuleStats("no_overlap_2d: only one box"); + return RemoveConstraint(ct); + } + + const CompactVectorVector components = + GetOverlappingRectangleComponents(bounding_boxes); + if (components.size() > 1) { for (int i = 0; i < components.size(); ++i) { absl::Span boxes = components[i]; if (boxes.size() <= 1) continue; @@ -5962,22 +5983,6 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { return RemoveConstraint(ct); } - if (new_size < initial_num_boxes) { - context_->UpdateRuleStats("no_overlap_2d: removed inactive boxes"); - ct->mutable_no_overlap_2d()->mutable_x_intervals()->Truncate(new_size); - ct->mutable_no_overlap_2d()->mutable_y_intervals()->Truncate(new_size); - } - - if (new_size == 0) { - context_->UpdateRuleStats("no_overlap_2d: no boxes"); - return RemoveConstraint(ct); - } - - if (new_size == 1) { - context_->UpdateRuleStats("no_overlap_2d: only one box"); - return RemoveConstraint(ct); - } - // We check if the fixed boxes are not overlapping so downstream code can // assume it to be true. if (!FindPartialRectangleIntersections(fixed_boxes).empty()) { @@ -5985,7 +5990,7 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { "Two fixed boxes in no_overlap_2d overlap"); } - if (fixed_boxes.size() == active_boxes.size()) { + if (non_fixed_bounding_boxes.empty()) { context_->UpdateRuleStats("no_overlap_2d: all boxes are fixed"); return RemoveConstraint(ct); } @@ -6035,6 +6040,36 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { return RemoveConstraint(ct); } } + // If the non-fixed boxes are disjoint but connected by fixed boxes, we can + // split the constraint and duplicate the fixed boxes. To avoid duplicating + // too many fixed boxes, we do this after we we applied the presolve reducing + // their number to as few as possible. + const CompactVectorVector non_fixed_components = + GetOverlappingRectangleComponents(non_fixed_bounding_boxes); + if (non_fixed_components.size() > 1) { + for (int i = 0; i < non_fixed_components.size(); ++i) { + // Note: we care about components of size 1 because they might be + // overlapping with the fixed boxes. + absl::Span indexes = non_fixed_components[i]; + + NoOverlap2DConstraintProto* new_no_overlap_2d = + context_->working_model->add_constraints()->mutable_no_overlap_2d(); + for (const int idx : indexes) { + const int b = non_fixed_boxes[idx].box_index; + new_no_overlap_2d->add_x_intervals(proto.x_intervals(b)); + new_no_overlap_2d->add_y_intervals(proto.y_intervals(b)); + } + for (const int b : fixed_item_indexes) { + new_no_overlap_2d->add_x_intervals(proto.x_intervals(b)); + new_no_overlap_2d->add_y_intervals(proto.y_intervals(b)); + } + } + context_->UpdateNewConstraintsVariableUsage(); + context_->UpdateRuleStats( + "no_overlap_2d: split into disjoint components duplicating fixed " + "boxes"); + return RemoveConstraint(ct); + } RunPropagatorsForConstraint(*ct); return new_size < initial_num_boxes; } @@ -9295,8 +9330,9 @@ void CpModelPresolver::DetectDuplicateColumns() { if (rep_to_dups[var].empty()) continue; // Since columns are the same, we can introduce a new variable = sum all - // columns. Note that we shouldn't have any overflow here by the - // precondition on our variable domains. + // columns. Note that the linear expression will not overflow, but the + // overflow check also requires that max_sum < int_max/2, which might + // happen. // // In the corner case where there is a lot of holes in the domain, and the // sum domain is too complex, we skip. Hopefully this should be rare. @@ -9318,7 +9354,10 @@ void CpModelPresolver::DetectDuplicateColumns() { } const int new_var = context_->NewIntVarWithDefinition( domain, definition, /*append_constraint_to_mapping_model=*/true); - CHECK_NE(new_var, -1); + if (new_var == -1) { + context_->UpdateRuleStats("TODO duplicate: possible overflow"); + continue; + } var_to_remove.push_back(var); CHECK_EQ(var_to_rep[var], -1); @@ -13359,7 +13398,7 @@ bool ImportModelWithBasicPresolveIntoContext(const CpModelProto& in_model, } bool ImportModelAndDomainsWithBasicPresolveIntoContext( - const CpModelProto& in_model, const std::vector& domains, + const CpModelProto& in_model, absl::Span domains, std::function active_constraints, PresolveContext* context) { CHECK_EQ(domains.size(), in_model.variables_size()); ModelCopy copier(context); @@ -14357,19 +14396,25 @@ void ApplyVariableMapping(absl::Span mapping, } // Move the variable definitions. - std::vector new_variables; + google::protobuf::RepeatedPtrField + new_variables_storage; + google::protobuf::RepeatedPtrField* new_variables; + if (proto->GetArena() == nullptr) { + new_variables = &new_variables_storage; + } else { + new_variables = google::protobuf::Arena::Create< + google::protobuf::RepeatedPtrField>( + proto->GetArena()); + } for (int i = 0; i < mapping.size(); ++i) { const int image = mapping[i]; if (image < 0) continue; - if (image >= new_variables.size()) { - new_variables.resize(image + 1, IntegerVariableProto()); + while (image >= new_variables->size()) { + new_variables->Add(); } - new_variables[image].Swap(proto->mutable_variables(i)); - } - proto->clear_variables(); - for (IntegerVariableProto& proto_ref : new_variables) { - proto->add_variables()->Swap(&proto_ref); + (*new_variables)[image].Swap(proto->mutable_variables(i)); } + proto->mutable_variables()->Swap(new_variables); // Check that all variables have a non-empty domain. for (const IntegerVariableProto& v : proto->variables()) { diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index ac6feda71e..64d6496352 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -498,7 +498,7 @@ bool ImportModelWithBasicPresolveIntoContext(const CpModelProto& in_model, // Same as ImportModelWithBasicPresolveIntoContext() except that variable // domains are read from domains. bool ImportModelAndDomainsWithBasicPresolveIntoContext( - const CpModelProto& in_model, const std::vector& domains, + const CpModelProto& in_model, absl::Span domains, std::function active_constraints, PresolveContext* context); // Copies the non constraint, non variables part of the model. diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 22a7864a1a..afb1129c7f 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1551,8 +1551,9 @@ class LnsSolver : public SubSolver { if (absl::MakeSpan(solution_values) != absl::MakeSpan(base_response.solution())) { new_solution = true; - shared_->response->NewSolution(solution_values, solution_info, - /*model=*/nullptr); + PushAndMaybeCombineSolution( + shared_->response, shared_->model_proto, solution_values, + solution_info, base_response.solution(), /*model=*/nullptr); } } if (!neighborhood.is_reduced && @@ -1564,19 +1565,6 @@ class LnsSolver : public SubSolver { } } - if (new_solution && !base_response.solution().empty()) { - std::string combined_solution_info = solution_info; - std::optional> combined_solution = - FindCombinedSolution(shared_->model_proto, solution_values, - base_response.solution(), shared_->response, - &combined_solution_info); - if (combined_solution.has_value()) { - shared_->response->NewSolution(combined_solution.value(), - combined_solution_info, - /*model=*/nullptr); - } - } - generator_->AddSolveData(data); if (VLOG_IS_ON(2) && display_lns_info) { diff --git a/ortools/sat/cp_model_solver_fuzz.cc b/ortools/sat/cp_model_solver_fuzz.cc deleted file mode 100644 index 791201ce8a..0000000000 --- a/ortools/sat/cp_model_solver_fuzz.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2010-2025 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "absl/log/check.h" -#include "gtest/gtest.h" // IWYU pragma: keep -#include "ortools/base/fuzztest.h" -#include "ortools/base/path.h" // IWYU pragma: keep -#include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/cp_model_solver.h" -#include "tools/cpp/runfiles/runfiles.h" - -namespace operations_research::sat { - -namespace { - -std::string GetTestDataDir() { - return file::JoinPathRespectAbsolute(::testing::SrcDir(), - "_main/ortools/sat/fuzz_testdata"); -} - -void Solve(const CpModelProto& proto) { - SatParameters params; - params.set_max_time_in_seconds(4.0); - params.set_debug_crash_if_presolve_breaks_hint(true); - - // Enable all fancy heuristics. - params.set_linearization_level(2); - params.set_use_try_edge_reasoning_in_no_overlap_2d(true); - params.set_exploit_all_precedences(true); - params.set_use_hard_precedences_in_cumulative(true); - params.set_max_num_intervals_for_timetable_edge_finding(1000); - params.set_use_overload_checker_in_cumulative(true); - params.set_use_strong_propagation_in_disjunctive(true); - params.set_use_timetable_edge_finding_in_cumulative(true); - params.set_max_pairs_pairwise_reasoning_in_no_overlap_2d(50000); - params.set_use_timetabling_in_no_overlap_2d(true); - params.set_use_energetic_reasoning_in_no_overlap_2d(true); - params.set_use_area_energetic_reasoning_in_no_overlap_2d(true); - params.set_use_conservative_scale_overload_checker(true); - params.set_use_dual_scheduling_heuristics(true); - - const CpSolverResponse response = - operations_research::sat::SolveWithParameters(proto, params); - - params.set_cp_model_presolve(false); - const CpSolverResponse response_no_presolve = - operations_research::sat::SolveWithParameters(proto, params); - - CHECK_EQ(response.status() == CpSolverStatus::MODEL_INVALID, - response_no_presolve.status() == CpSolverStatus::MODEL_INVALID) - << "Model being invalid should not depend on presolve"; - - if (response.status() == CpSolverStatus::MODEL_INVALID) { - return; - } - - if (response.status() == CpSolverStatus::UNKNOWN || - response_no_presolve.status() == CpSolverStatus::UNKNOWN) { - return; - } - - CHECK_EQ(response.status() == CpSolverStatus::INFEASIBLE, - response_no_presolve.status() == CpSolverStatus::INFEASIBLE) - << "Presolve should not change feasibility"; -} - -// Fuzzing repeats solve() 100 times, and timeout after 600s. -// With a time limit of 4s, we should be fine. -FUZZ_TEST(CpModelProtoFuzzer, Solve) - .WithDomains(/*proto:*/ fuzztest::Arbitrary()) - .WithSeeds([]() { - return fuzztest::ReadFilesFromDirectory(GetTestDataDir()); - }); - -} // namespace -} // namespace operations_research::sat diff --git a/ortools/sat/cumulative_energy.cc b/ortools/sat/cumulative_energy.cc index cb8b2a48a4..36e5d617c0 100644 --- a/ortools/sat/cumulative_energy.cc +++ b/ortools/sat/cumulative_energy.cc @@ -30,8 +30,8 @@ #include "ortools/sat/diffn_util.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" #include "ortools/sat/model.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/theta_tree.h" #include "ortools/sat/util.h" @@ -77,7 +77,7 @@ CumulativeEnergyConstraint::CumulativeEnergyConstraint( void CumulativeEnergyConstraint::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - helper_->WatchAllTasks(id, watcher); + helper_->WatchAllTasks(id); watcher->SetPropagatorPriority(id, 2); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); } @@ -430,7 +430,7 @@ CumulativeDualFeasibleEnergyConstraint:: void CumulativeDualFeasibleEnergyConstraint::RegisterWith( GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - helper_->WatchAllTasks(id, watcher); + helper_->WatchAllTasks(id); watcher->SetPropagatorPriority(id, 3); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); } diff --git a/ortools/sat/cumulative_energy_test.cc b/ortools/sat/cumulative_energy_test.cc index 57f4188ff3..2252474527 100644 --- a/ortools/sat/cumulative_energy_test.cc +++ b/ortools/sat/cumulative_energy_test.cc @@ -43,6 +43,7 @@ #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/util/strong_integers.h" namespace operations_research { @@ -107,9 +108,8 @@ bool SolveUsingConstraint(const EnergyInstance& instance) { const AffineExpression capacity( model.Add(ConstantIntegerVariable(instance.capacity))); - SchedulingConstraintHelper* helper = - new SchedulingConstraintHelper(intervals, &model); - model.TakeOwnership(helper); + auto* repo = model.GetOrCreate(); + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper(intervals); SchedulingDemandHelper* demands_helper = new SchedulingDemandHelper({}, helper, &model); demands_helper->OverrideLinearizedEnergies(energies); @@ -298,9 +298,8 @@ bool TestOverloadCheckerPropagation( EXPECT_TRUE(precedences->Propagate()); // Propagator responsible for filtering the capacity variable. - SchedulingConstraintHelper* helper = - new SchedulingConstraintHelper(interval_vars, &model); - model.TakeOwnership(helper); + auto* repo = model.GetOrCreate(); + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper(interval_vars); SchedulingDemandHelper* demands_helper = new SchedulingDemandHelper(demands, helper, &model); model.TakeOwnership(demands_helper); @@ -408,9 +407,8 @@ TEST(OverloadCheckerTest, OptionalTaskPropagatedToAbsent) { const IntervalVariable i1 = model.Add(NewOptionalInterval(0, 10, /*size=*/8, is_present)); - SchedulingConstraintHelper* helper = - new SchedulingConstraintHelper({i1, i2}, &model); - model.TakeOwnership(helper); + auto* repo = model.GetOrCreate(); + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({i1, i2}); const AffineExpression cte(IntegerValue(2)); SchedulingDemandHelper* demands_helper = new SchedulingDemandHelper({cte, cte}, helper, &model); @@ -429,9 +427,8 @@ TEST(OverloadCheckerTest, OptionalTaskMissedPropagationCase) { const IntervalVariable i2 = model.Add(NewOptionalInterval(0, 10, /*size=*/8, is_present)); - SchedulingConstraintHelper* helper = - new SchedulingConstraintHelper({i1, i2}, &model); - model.TakeOwnership(helper); + auto* repo = model.GetOrCreate(); + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({i1, i2}); const AffineExpression cte(IntegerValue(2)); SchedulingDemandHelper* demands_helper = new SchedulingDemandHelper({cte, cte}, helper, &model); @@ -513,9 +510,8 @@ bool TestIsAfterCumulative(absl::Span tasks, EXPECT_TRUE(precedences->Propagate()); // Propagator responsible for filtering the capacity variable. - SchedulingConstraintHelper* helper = - new SchedulingConstraintHelper(interval_vars, &model); - model.TakeOwnership(helper); + auto* repo = model.GetOrCreate(); + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper(interval_vars); SchedulingDemandHelper* demands_helper = new SchedulingDemandHelper(demands, helper, &model); model.TakeOwnership(demands_helper); diff --git a/ortools/sat/cuts.cc b/ortools/sat/cuts.cc index 41b2f4b150..fbc65d13c7 100644 --- a/ortools/sat/cuts.cc +++ b/ortools/sat/cuts.cc @@ -2382,7 +2382,7 @@ IntegerValue SumOfAllDiffLowerBounder::SumOfMinDomainValues() { int count = 0; IntegerValue sum = 0; for (const IntegerValue value : min_values_) { - sum += value; + sum = CapAddI(sum, value); if (++count >= expr_mins_.size()) return sum; } return sum; @@ -2439,6 +2439,8 @@ void TryToGenerateAllDiffCut( std::string max_suffix; const IntegerValue required_max_sum = -negated_diff_maxes.GetBestLowerBound(max_suffix); + if (required_max_sum == std::numeric_limits::max()) continue; + DCHECK_LE(required_min_sum, required_max_sum); if (sum < ToDouble(required_min_sum) - kMinCutViolation || sum > ToDouble(required_max_sum) + kMinCutViolation) { LinearConstraintBuilder cut(model, required_min_sum, required_max_sum); @@ -2462,7 +2464,7 @@ void TryToGenerateAllDiffCut( } // namespace CutGenerator CreateAllDifferentCutGenerator( - const std::vector& exprs, Model* model) { + absl::Span exprs, Model* model) { CutGenerator result; IntegerTrail* integer_trail = model->GetOrCreate(); @@ -2474,37 +2476,38 @@ CutGenerator CreateAllDifferentCutGenerator( gtl::STLSortAndRemoveDuplicates(&result.vars); Trail* trail = model->GetOrCreate(); - result.generate_cuts = [exprs, integer_trail, trail, - model](LinearConstraintManager* manager) { - // These cuts work at all levels but the generator adds too many cuts on - // some instances and degrade the performance so we only use it at level - // 0. - if (trail->CurrentDecisionLevel() > 0) return true; - const auto& lp_values = manager->LpValues(); - std::vector> sorted_exprs; - for (const AffineExpression expr : exprs) { - if (integer_trail->LevelZeroLowerBound(expr) == - integer_trail->LevelZeroUpperBound(expr)) { - continue; - } - sorted_exprs.push_back(std::make_pair(expr.LpValue(lp_values), expr)); - } + result.generate_cuts = + [exprs = std::vector(exprs.begin(), exprs.end()), + integer_trail, trail, model](LinearConstraintManager* manager) { + // These cuts work at all levels but the generator adds too many cuts on + // some instances and degrade the performance so we only use it at level + // 0. + if (trail->CurrentDecisionLevel() > 0) return true; + const auto& lp_values = manager->LpValues(); + std::vector> sorted_exprs; + for (const AffineExpression expr : exprs) { + if (integer_trail->LevelZeroLowerBound(expr) == + integer_trail->LevelZeroUpperBound(expr)) { + continue; + } + sorted_exprs.push_back(std::make_pair(expr.LpValue(lp_values), expr)); + } - TopNCuts top_n_cuts(5); - std::sort(sorted_exprs.begin(), sorted_exprs.end(), - [](std::pair& a, - const std::pair& b) { - return a.first < b.first; - }); - TryToGenerateAllDiffCut(sorted_exprs, *integer_trail, lp_values, top_n_cuts, - model); - // Other direction. - std::reverse(sorted_exprs.begin(), sorted_exprs.end()); - TryToGenerateAllDiffCut(sorted_exprs, *integer_trail, lp_values, top_n_cuts, - model); - top_n_cuts.TransferToManager(manager); - return true; - }; + TopNCuts top_n_cuts(5); + std::sort(sorted_exprs.begin(), sorted_exprs.end(), + [](std::pair& a, + const std::pair& b) { + return a.first < b.first; + }); + TryToGenerateAllDiffCut(sorted_exprs, *integer_trail, lp_values, + top_n_cuts, model); + // Other direction. + std::reverse(sorted_exprs.begin(), sorted_exprs.end()); + TryToGenerateAllDiffCut(sorted_exprs, *integer_trail, lp_values, + top_n_cuts, model); + top_n_cuts.TransferToManager(manager); + return true; + }; VLOG(2) << "Created all_diff cut generator of size: " << exprs.size(); return result; } @@ -2572,9 +2575,10 @@ double ComputeContribution( } } // namespace -CutGenerator CreateLinMaxCutGenerator( - const IntegerVariable target, const std::vector& exprs, - const std::vector& z_vars, Model* model) { +CutGenerator CreateLinMaxCutGenerator(const IntegerVariable target, + absl::Span exprs, + absl::Span z_vars, + Model* model) { CutGenerator result; std::vector x_vars; result.vars = {target}; @@ -2591,55 +2595,58 @@ CutGenerator CreateLinMaxCutGenerator( result.vars.insert(result.vars.end(), x_vars.begin(), x_vars.end()); IntegerTrail* integer_trail = model->GetOrCreate(); - result.generate_cuts = [x_vars, z_vars, target, num_exprs, exprs, - integer_trail, - model](LinearConstraintManager* manager) { - const auto& lp_values = manager->LpValues(); - util_intops::StrongVector variable_partition( - lp_values.size(), -1); - util_intops::StrongVector - variable_partition_contrib(lp_values.size(), - std::numeric_limits::infinity()); - for (int expr_index = 0; expr_index < num_exprs; ++expr_index) { - for (const IntegerVariable var : x_vars) { - const double contribution = ComputeContribution( - var, z_vars, exprs, lp_values, *integer_trail, expr_index); - const double prev_contribution = variable_partition_contrib[var]; - if (contribution < prev_contribution) { - variable_partition[var] = expr_index; - variable_partition_contrib[var] = contribution; + result.generate_cuts = + [x_vars, + z_vars = std::vector(z_vars.begin(), z_vars.end()), + target, num_exprs, + exprs = std::vector(exprs.begin(), exprs.end()), + integer_trail, model](LinearConstraintManager* manager) { + const auto& lp_values = manager->LpValues(); + util_intops::StrongVector variable_partition( + lp_values.size(), -1); + util_intops::StrongVector + variable_partition_contrib(lp_values.size(), + std::numeric_limits::infinity()); + for (int expr_index = 0; expr_index < num_exprs; ++expr_index) { + for (const IntegerVariable var : x_vars) { + const double contribution = ComputeContribution( + var, z_vars, exprs, lp_values, *integer_trail, expr_index); + const double prev_contribution = variable_partition_contrib[var]; + if (contribution < prev_contribution) { + variable_partition[var] = expr_index; + variable_partition_contrib[var] = contribution; + } + } } - } - } - LinearConstraintBuilder cut(model, /*lb=*/IntegerValue(0), - /*ub=*/kMaxIntegerValue); - double violation = lp_values[target]; - cut.AddTerm(target, IntegerValue(-1)); + LinearConstraintBuilder cut(model, /*lb=*/IntegerValue(0), + /*ub=*/kMaxIntegerValue); + double violation = lp_values[target]; + cut.AddTerm(target, IntegerValue(-1)); - for (const IntegerVariable xi_var : x_vars) { - const int input_index = variable_partition[xi_var]; - const LinearExpression& expr = exprs[input_index]; - const IntegerValue coeff = GetCoefficientOfPositiveVar(xi_var, expr); - if (coeff != IntegerValue(0)) { - cut.AddTerm(xi_var, coeff); - } - violation -= ToDouble(coeff) * lp_values[xi_var]; - } - for (int expr_index = 0; expr_index < num_exprs; ++expr_index) { - const IntegerVariable z_var = z_vars[expr_index]; - const IntegerValue z_coeff = MPlusCoefficient( - x_vars, exprs, variable_partition, expr_index, *integer_trail); - if (z_coeff != IntegerValue(0)) { - cut.AddTerm(z_var, z_coeff); - } - violation -= ToDouble(z_coeff) * lp_values[z_var]; - } - if (violation > 1e-2) { - manager->AddCut(cut.Build(), "LinMax"); - } - return true; - }; + for (const IntegerVariable xi_var : x_vars) { + const int input_index = variable_partition[xi_var]; + const LinearExpression& expr = exprs[input_index]; + const IntegerValue coeff = GetCoefficientOfPositiveVar(xi_var, expr); + if (coeff != IntegerValue(0)) { + cut.AddTerm(xi_var, coeff); + } + violation -= ToDouble(coeff) * lp_values[xi_var]; + } + for (int expr_index = 0; expr_index < num_exprs; ++expr_index) { + const IntegerVariable z_var = z_vars[expr_index]; + const IntegerValue z_coeff = MPlusCoefficient( + x_vars, exprs, variable_partition, expr_index, *integer_trail); + if (z_coeff != IntegerValue(0)) { + cut.AddTerm(z_var, z_coeff); + } + violation -= ToDouble(z_coeff) * lp_values[z_var]; + } + if (violation > 1e-2) { + manager->AddCut(cut.Build(), "LinMax"); + } + return true; + }; return result; } @@ -2734,7 +2741,7 @@ CutGenerator CreateMaxAffineCutGenerator( } CutGenerator CreateCliqueCutGenerator( - const std::vector& base_variables, Model* model) { + absl::Span base_variables, Model* model) { // Filter base_variables to only keep the one with a literal view, and // do the conversion. std::vector variables; diff --git a/ortools/sat/cuts.h b/ortools/sat/cuts.h index 9a47f09b2c..a51dcb835e 100644 --- a/ortools/sat/cuts.h +++ b/ortools/sat/cuts.h @@ -654,7 +654,7 @@ CutGenerator CreateSquareCutGenerator(AffineExpression y, AffineExpression x, // cuts of the form described above if they are violated by lp solution. Note // that all the fixed variables are ignored while generating cuts. CutGenerator CreateAllDifferentCutGenerator( - const std::vector& exprs, Model* model); + absl::Span exprs, Model* model); // Consider the Lin Max constraint with d expressions and n variables in the // form: target = max {exprs[k] = Sum (wki * xi + bk)}. k in {1,..,d}. @@ -693,9 +693,10 @@ CutGenerator CreateAllDifferentCutGenerator( // // Note: This cut generator requires all expressions to contain only positive // vars. -CutGenerator CreateLinMaxCutGenerator( - IntegerVariable target, const std::vector& exprs, - const std::vector& z_vars, Model* model); +CutGenerator CreateLinMaxCutGenerator(IntegerVariable target, + absl::Span exprs, + absl::Span z_vars, + Model* model); // Helper for the affine max constraint. // @@ -718,7 +719,7 @@ CutGenerator CreateMaxAffineCutGenerator( // create a generator that will returns constraint of the form "at_most_one" // between such literals. CutGenerator CreateCliqueCutGenerator( - const std::vector& base_variables, Model* model); + absl::Span base_variables, Model* model); // Utility class for the AllDiff cut generator. class SumOfAllDiffLowerBounder { @@ -727,6 +728,7 @@ class SumOfAllDiffLowerBounder { void Add(const AffineExpression& expr, int num_expr, const IntegerTrail& integer_trail); + // Return int_max if the sum overflows. IntegerValue SumOfMinDomainValues(); IntegerValue SumOfDifferentMins(); IntegerValue GetBestLowerBound(std::string& suffix); diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index a3650cedcd..9240323d8a 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -26,7 +26,6 @@ #include #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/numeric/bits.h" #include "absl/types/span.h" @@ -40,10 +39,11 @@ #include "ortools/sat/integer_base.h" #include "ortools/sat/integer_expr.h" #include "ortools/sat/intervals.h" -#include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/timetable.h" #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/strong_integers.h" @@ -144,43 +144,23 @@ void AddDiffnCumulativeRelationOnX(SchedulingConstraintHelper* x, } } -// This function will fill the helper why the two boxes always overlap on that -// dimension. -void ClearAndAddMandatoryOverlapReason(int box1, int box2, - SchedulingConstraintHelper* helper) { - helper->ClearReason(); - helper->AddPresenceReason(box1); - helper->AddPresenceReason(box2); - helper->AddReasonForBeingBefore(box1, box2); - helper->AddReasonForBeingBefore(box2, box1); -} - -bool ClearAndAddTwoBoxesConflictReason(int box1, int box2, - SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y) { - ClearAndAddMandatoryOverlapReason(box1, box2, x); - ClearAndAddMandatoryOverlapReason(box1, box2, y); - x->ImportOtherReasons(*y); - return x->ReportConflict(); -} - } // namespace void AddNonOverlappingRectangles(const std::vector& x, const std::vector& y, Model* model) { IntervalsRepository* repository = model->GetOrCreate(); - SchedulingConstraintHelper* x_helper = repository->GetOrCreateHelper(x); - SchedulingConstraintHelper* y_helper = repository->GetOrCreateHelper(y); + NoOverlap2DConstraintHelper* no_overlap_helper = + repository->GetOrCreate2DHelper(x, y); NonOverlappingRectanglesDisjunctivePropagator* constraint = - new NonOverlappingRectanglesDisjunctivePropagator(x_helper, y_helper, + new NonOverlappingRectanglesDisjunctivePropagator(no_overlap_helper, model); constraint->Register(/*fast_priority=*/3, /*slow_priority=*/4); model->TakeOwnership(constraint); RectanglePairwisePropagator* pairwise_propagator = - new RectanglePairwisePropagator(x_helper, y_helper, model); + new RectanglePairwisePropagator(no_overlap_helper, model); GenericLiteralWatcher* const watcher = model->GetOrCreate(); watcher->SetPropagatorPriority(pairwise_propagator->RegisterWith(watcher), 4); @@ -192,6 +172,9 @@ void AddNonOverlappingRectangles(const std::vector& x, params.use_energetic_reasoning_in_no_overlap_2d(); if (add_cumulative_relaxation) { + SchedulingConstraintHelper* x_helper = &no_overlap_helper->x_helper(); + SchedulingConstraintHelper* y_helper = &no_overlap_helper->y_helper(); + // We must first check if the cumulative relaxation is possible. bool some_boxes_are_only_optional_on_x = false; bool some_boxes_are_only_optional_on_y = false; @@ -222,7 +205,7 @@ void AddNonOverlappingRectangles(const std::vector& x, if (params.use_area_energetic_reasoning_in_no_overlap_2d()) { NonOverlappingRectanglesEnergyPropagator* energy_constraint = - new NonOverlappingRectanglesEnergyPropagator(x_helper, y_helper, model); + new NonOverlappingRectanglesEnergyPropagator(no_overlap_helper, model); GenericLiteralWatcher* const watcher = model->GetOrCreate(); watcher->SetPropagatorPriority(energy_constraint->RegisterWith(watcher), 5); @@ -230,7 +213,7 @@ void AddNonOverlappingRectangles(const std::vector& x, } if (params.use_try_edge_reasoning_in_no_overlap_2d()) { - CreateAndRegisterTryEdgePropagator(x_helper, y_helper, model, watcher); + CreateAndRegisterTryEdgePropagator(no_overlap_helper, model, watcher); } } @@ -259,9 +242,8 @@ NonOverlappingRectanglesEnergyPropagator:: bool NonOverlappingRectanglesEnergyPropagator::Propagate() { // TODO(user): double-check/revisit the algo for box of variable sizes. - const int num_boxes = x_.NumTasks(); - if (!x_.SynchronizeAndSetTimeDirection(true)) return false; - if (!y_.SynchronizeAndSetTimeDirection(true)) return false; + const int num_boxes = helper_.NumBoxes(); + if (!helper_.SynchronizeAndSetDirection(true, true, false)) return false; Rectangle bounding_box = {.x_min = std::numeric_limits::max(), .x_max = std::numeric_limits::min(), @@ -270,22 +252,11 @@ bool NonOverlappingRectanglesEnergyPropagator::Propagate() { std::vector active_box_ranges; active_box_ranges.reserve(num_boxes); for (int box = 0; box < num_boxes; ++box) { - if (x_.SizeMin(box) == 0 || y_.SizeMin(box) == 0) continue; - if (!x_.IsPresent(box) || !y_.IsPresent(box)) continue; - - bounding_box.x_min = std::min(bounding_box.x_min, x_.StartMin(box)); - bounding_box.x_max = std::max(bounding_box.x_max, x_.EndMax(box)); - bounding_box.y_min = std::min(bounding_box.y_min, y_.StartMin(box)); - bounding_box.y_max = std::max(bounding_box.y_max, y_.EndMax(box)); - - active_box_ranges.push_back(RectangleInRange{ - .box_index = box, - .bounding_area = {.x_min = x_.StartMin(box), - .x_max = x_.StartMax(box) + x_.SizeMin(box), - .y_min = y_.StartMin(box), - .y_max = y_.StartMax(box) + y_.SizeMin(box)}, - .x_size = x_.SizeMin(box), - .y_size = y_.SizeMin(box)}); + if (!helper_.IsPresent(box)) continue; + RectangleInRange rec = helper_.GetItemRangeForSizeMin(box); + if (rec.x_size == 0 || rec.y_size == 0) continue; + bounding_box.GrowToInclude(rec.bounding_area); + active_box_ranges.push_back(std::move(rec)); } if (active_box_ranges.size() < 2) { @@ -351,7 +322,7 @@ bool NonOverlappingRectanglesEnergyPropagator::Propagate() { if (best_explanation_size == 2) { num_conflicts_two_boxes_++; } - BuildAndReportEnergyTooLarge(generalized_explanation); + helper_.ReportConflictFromInfeasibleBoxRanges(generalized_explanation); return false; } @@ -465,12 +436,13 @@ NonOverlappingRectanglesEnergyPropagator::GeneralizeExplanation( } const RectangleInRange& range = conflict.items_for_opp[items[i].index]; const RectangleInRange item_in_zero_level_range = { - .bounding_area = {.x_min = x_.LevelZeroStartMin(range.box_index), - .x_max = x_.LevelZeroStartMax(range.box_index) + - range.x_size, - .y_min = y_.LevelZeroStartMin(range.box_index), - .y_max = y_.LevelZeroStartMax(range.box_index) + - range.y_size}, + .bounding_area = + {.x_min = helper_.x_helper().LevelZeroStartMin(range.box_index), + .x_max = helper_.x_helper().LevelZeroStartMax(range.box_index) + + range.x_size, + .y_min = helper_.y_helper().LevelZeroStartMin(range.box_index), + .y_max = helper_.y_helper().LevelZeroStartMax(range.box_index) + + range.y_size}, .x_size = range.x_size, .y_size = range.y_size}; // There is no point trying to intersect less the item with the rectangle @@ -512,39 +484,10 @@ NonOverlappingRectanglesEnergyPropagator::GeneralizeExplanation( int NonOverlappingRectanglesEnergyPropagator::RegisterWith( GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - x_.WatchAllTasks(id); - y_.WatchAllTasks(id); + helper_.WatchAllBoxes(id); return id; } -bool NonOverlappingRectanglesEnergyPropagator::BuildAndReportEnergyTooLarge( - absl::Span ranges) { - if (ranges.size() == 2) { - num_conflicts_two_boxes_++; - return ClearAndAddTwoBoxesConflictReason(ranges[0].box_index, - ranges[1].box_index, &x_, &y_); - } - x_.ClearReason(); - y_.ClearReason(); - for (const auto& range : ranges) { - const int b = range.box_index; - - x_.AddStartMinReason(b, range.bounding_area.x_min); - y_.AddStartMinReason(b, range.bounding_area.y_min); - - x_.AddStartMaxReason(b, range.bounding_area.x_max - range.x_size); - y_.AddStartMaxReason(b, range.bounding_area.y_max - range.y_size); - - x_.AddSizeMinReason(b); - y_.AddSizeMinReason(b); - - x_.AddPresenceReason(b); - y_.AddPresenceReason(b); - } - x_.ImportOtherReasons(y_); - return x_.ReportConflict(); -} - namespace { // We want for different propagation to reuse as much as possible the same @@ -605,64 +548,15 @@ void SplitDisjointBoxes(const SchedulingConstraintHelper& x, } } -// This function assumes that the left and right boxes overlap on the second -// dimension, and that left cannot be after right. -// It checks and pushes the lower bound of the right box and the upper bound -// of the left box if need. -// -// If y is not null, it import the mandatory reason for the overlap on y in -// the x helper. -bool LeftBoxBeforeRightBoxOnFirstDimension(int left, int right, - SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y) { - // left box2 pushes right box2. - const IntegerValue left_end_min = x->EndMin(left); - if (left_end_min > x->StartMin(right)) { - x->ClearReason(); - x->AddPresenceReason(left); - x->AddPresenceReason(right); - x->AddReasonForBeingBefore(left, right); - x->AddEndMinReason(left, left_end_min); - if (y != nullptr) { - // left and right must overlap on y. - ClearAndAddMandatoryOverlapReason(left, right, y); - // Propagate with the complete reason. - x->ImportOtherReasons(*y); - } - RETURN_IF_FALSE(x->IncreaseStartMin(right, left_end_min)); - } - - // right box2 pushes left box2. - const IntegerValue right_start_max = x->StartMax(right); - if (right_start_max < x->EndMax(left)) { - x->ClearReason(); - x->AddPresenceReason(left); - x->AddPresenceReason(right); - x->AddReasonForBeingBefore(left, right); - x->AddStartMaxReason(right, right_start_max); - if (y != nullptr) { - // left and right must overlap on y. - ClearAndAddMandatoryOverlapReason(left, right, y); - // Propagate with the complete reason. - x->ImportOtherReasons(*y); - } - RETURN_IF_FALSE(x->DecreaseEndMax(left, right_start_max)); - } - - return true; -} - } // namespace // Note that x_ and y_ must be initialized with enough intervals when passed // to the disjunctive propagators. NonOverlappingRectanglesDisjunctivePropagator:: - NonOverlappingRectanglesDisjunctivePropagator(SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, - Model* model) - : global_x_(*x), - global_y_(*y), - x_(x->NumTasks(), model), + NonOverlappingRectanglesDisjunctivePropagator( + NoOverlap2DConstraintHelper* helper, Model* model) + : helper_(helper), + x_(helper->NumBoxes(), model), watcher_(model->GetOrCreate()), time_limit_(model->GetOrCreate()), overload_checker_(&x_), @@ -671,7 +565,8 @@ NonOverlappingRectanglesDisjunctivePropagator:: forward_not_last_(true, &x_), backward_not_last_(false, &x_), forward_edge_finding_(true, &x_), - backward_edge_finding_(false, &x_) {} + backward_edge_finding_(false, &x_), + disjunctive_with_two_items_(&x_) {} NonOverlappingRectanglesDisjunctivePropagator:: ~NonOverlappingRectanglesDisjunctivePropagator() = default; @@ -680,8 +575,7 @@ void NonOverlappingRectanglesDisjunctivePropagator::Register( int fast_priority, int slow_priority) { fast_id_ = watcher_->Register(this); watcher_->SetPropagatorPriority(fast_id_, fast_priority); - global_x_.WatchAllTasks(fast_id_); - global_y_.WatchAllTasks(fast_id_); + helper_->WatchAllBoxes(fast_id_); // This propagator is the one making sure our propagation is complete, so // we do need to make sure it is called again if it modified some bounds. @@ -689,14 +583,11 @@ void NonOverlappingRectanglesDisjunctivePropagator::Register( const int slow_id = watcher_->Register(this); watcher_->SetPropagatorPriority(slow_id, slow_priority); - global_x_.WatchAllTasks(slow_id); - global_y_.WatchAllTasks(slow_id); + helper_->WatchAllBoxes(slow_id); } bool NonOverlappingRectanglesDisjunctivePropagator:: - FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - bool fast_propagation, SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y) { + FindBoxesThatMustOverlapAHorizontalLineAndPropagate(bool fast_propagation) { // When they are many fixed box that we know do not overlap, we compute // the bounding box of the others, and we can exclude all boxes outside this // region. This can help, especially for some LNS neighborhood. @@ -707,10 +598,13 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: // push_back() can be slow as it might not be inlined, so we manage directly // our "boxes" in boxes_data[0 .. num_boxes], with a memory that is always big // enough. - indexed_boxes_.resize(y->NumTasks()); + indexed_boxes_.resize(helper_->NumBoxes()); int num_boxes = 0; IndexedInterval* boxes_data = indexed_boxes_.data(); + SchedulingConstraintHelper* x = &helper_->x_helper(); + SchedulingConstraintHelper* y = &helper_->y_helper(); + // Compute relevant boxes, the one with a mandatory part on y. Because we will // need to sort it this way, we consider them by increasing start max. const auto temp = y->TaskByIncreasingNegatedStartMax(); @@ -744,9 +638,7 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: if (fixed_boxes[box]) { ++num_fixed; } else { - const bool is_fixed = x->StartIsFixed(box) && x->EndIsFixed(box) && - y->StartIsFixed(box) && y->EndIsFixed(box); - if (is_fixed) { + if (helper_->IsFixed(box)) { // We will "check it" below, so it will be checked next time. fixed_boxes.Set(box); } @@ -788,7 +680,7 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: if (boxes.size() < 2) return true; // Optim: Abort if all rectangle can be fixed to their mandatory y + - // minimium x position without any overlap. + // minimum x position without any overlap. // // This is guaranteed to be O(N log N) whereas the algo below is O(N ^ 2). // @@ -885,7 +777,7 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: // In that case, we can use simpler algorithms. // Note that this case happens frequently (~30% of all calls to this // method according to our tests). - RETURN_IF_FALSE(PropagateOnXWhenOnlyTwoBoxes()); + RETURN_IF_FALSE(disjunctive_with_two_items_.Propagate()); } else { RETURN_IF_FALSE(overload_checker_.Propagate()); RETURN_IF_FALSE(forward_detectable_precedences_.Propagate()); @@ -908,15 +800,14 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: // - large problem with many 1000s boxes, but with only a small subset that is // not fixed (mainly coming from LNS). bool NonOverlappingRectanglesDisjunctivePropagator::Propagate() { - if (!global_x_.SynchronizeAndSetTimeDirection(true)) return false; - if (!global_y_.SynchronizeAndSetTimeDirection(true)) return false; + if (!helper_->SynchronizeAndSetDirection(true, true, false)) return false; // If we are "diving" we maintain the set of fixed boxes for which we know // that they are not overlapping. const bool backtrack_since_last_call = !rev_is_in_dive_; watcher_->SetUntilNextBacktrack(&rev_is_in_dive_); if (backtrack_since_last_call) { - const int num_tasks = global_x_.NumTasks(); + const int num_tasks = helper_->NumBoxes(); already_checked_fixed_boxes_.ClearAndResize(num_tasks); } @@ -924,47 +815,20 @@ bool NonOverlappingRectanglesDisjunctivePropagator::Propagate() { // mode. So we will not redo some propagation in slow mode that was already // done by the fast mode. const bool fast_propagation = watcher_->GetCurrentId() == fast_id_; - RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - fast_propagation, &global_x_, &global_y_)); + RETURN_IF_FALSE( + FindBoxesThatMustOverlapAHorizontalLineAndPropagate(fast_propagation)); // We can actually swap dimensions to propagate vertically. - RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - fast_propagation, &global_y_, &global_x_)); + if (!helper_->SynchronizeAndSetDirection(true, true, true)) return false; + RETURN_IF_FALSE( + FindBoxesThatMustOverlapAHorizontalLineAndPropagate(fast_propagation)); return true; } -// Specialized propagation on only two boxes that must intersect with the -// given y_line_for_reason. -bool NonOverlappingRectanglesDisjunctivePropagator:: - PropagateOnXWhenOnlyTwoBoxes() { - if (!x_.IsPresent(0) || !x_.IsPresent(1)) return true; - - // For each direction and each order, we test if the boxes can be disjoint. - const int state = - (x_.EndMin(0) <= x_.StartMax(1)) + 2 * (x_.EndMin(1) <= x_.StartMax(0)); - switch (state) { - case 0: { // Conflict. - ClearAndAddMandatoryOverlapReason(0, 1, &x_); - // Note that the secondary helper is set on x. - return x_.ReportConflict(); - } - case 1: { // b1 is left of b2. - return LeftBoxBeforeRightBoxOnFirstDimension(0, 1, &x_, /*y=*/nullptr); - } - case 2: { // b2 is left of b1. - return LeftBoxBeforeRightBoxOnFirstDimension(1, 0, &x_, /*y=*/nullptr); - } - default: { // Nothing to deduce. - return true; - } - } -} - int RectanglePairwisePropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - global_x_.WatchAllTasks(id); - global_y_.WatchAllTasks(id); + helper_->WatchAllBoxes(id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); return id; } @@ -982,8 +846,7 @@ RectanglePairwisePropagator::~RectanglePairwisePropagator() { } bool RectanglePairwisePropagator::Propagate() { - if (!global_x_.SynchronizeAndSetTimeDirection(true)) return false; - if (!global_y_.SynchronizeAndSetTimeDirection(true)) return false; + if (!helper_->SynchronizeAndSetDirection(true, true, false)) return false; num_calls_++; @@ -991,11 +854,10 @@ bool RectanglePairwisePropagator::Propagate() { vertical_zero_area_boxes_.clear(); point_zero_area_boxes_.clear(); non_zero_area_boxes_.clear(); - for (int b = 0; b < global_x_.NumTasks(); ++b) { - if (!global_x_.IsPresent(b) || !global_y_.IsPresent(b)) continue; - const IntegerValue x_size_max = global_x_.SizeMax(b); - const IntegerValue y_size_max = global_y_.SizeMax(b); - ItemForPairwiseRestriction* box; + for (int b = 0; b < helper_->NumBoxes(); ++b) { + if (!helper_->IsPresent(b)) continue; + const auto [x_size_max, y_size_max] = helper_->GetBoxSizesMax(b); + ItemWithVariableSize* box; if (x_size_max == 0) { if (y_size_max == 0) { box = &point_zero_area_boxes_.emplace_back(); @@ -1007,15 +869,7 @@ bool RectanglePairwisePropagator::Propagate() { } else { box = &non_zero_area_boxes_.emplace_back(); } - *box = ItemForPairwiseRestriction{.index = b, - .x = {.start_min = global_x_.StartMin(b), - .start_max = global_x_.StartMax(b), - .end_min = global_x_.EndMin(b), - .end_max = global_x_.EndMax(b)}, - .y = {.start_min = global_y_.StartMin(b), - .start_max = global_y_.StartMax(b), - .end_min = global_y_.EndMin(b), - .end_max = global_y_.EndMax(b)}}; + *box = helper_->GetItemWithVariableSize(b); } std::vector restrictions; @@ -1041,7 +895,7 @@ bool RectanglePairwisePropagator::Propagate() { } bool RectanglePairwisePropagator::FindRestrictionsAndPropagateConflict( - absl::Span items, + absl::Span items, std::vector* restrictions) { const int max_pairs = params_->max_pairs_pairwise_reasoning_in_no_overlap_2d(); @@ -1059,8 +913,8 @@ bool RectanglePairwisePropagator::FindRestrictionsAndPropagateConflict( } bool RectanglePairwisePropagator::FindRestrictionsAndPropagateConflict( - absl::Span items1, - absl::Span items2, + absl::Span items1, + absl::Span items2, std::vector* restrictions) { const int max_pairs = params_->max_pairs_pairwise_reasoning_in_no_overlap_2d(); @@ -1079,30 +933,14 @@ bool RectanglePairwisePropagator::FindRestrictionsAndPropagateConflict( bool RectanglePairwisePropagator::PropagateTwoBoxes( const PairwiseRestriction& restriction) { - const int box1 = restriction.first_index; - const int box2 = restriction.second_index; - switch (restriction.type) { - case PairwiseRestriction::PairwiseRestrictionType::CONFLICT: - num_pairwise_conflicts_++; - return ClearAndAddTwoBoxesConflictReason(box1, box2, &global_x_, - &global_y_); - case PairwiseRestriction::PairwiseRestrictionType::FIRST_LEFT_OF_SECOND: - num_pairwise_propagations_++; - return LeftBoxBeforeRightBoxOnFirstDimension(box1, box2, &global_x_, - &global_y_); - case PairwiseRestriction::PairwiseRestrictionType::FIRST_RIGHT_OF_SECOND: - num_pairwise_propagations_++; - return LeftBoxBeforeRightBoxOnFirstDimension(box2, box1, &global_x_, - &global_y_); - case PairwiseRestriction::PairwiseRestrictionType::FIRST_BELOW_SECOND: - num_pairwise_propagations_++; - return LeftBoxBeforeRightBoxOnFirstDimension(box1, box2, &global_y_, - &global_x_); - case PairwiseRestriction::PairwiseRestrictionType::FIRST_ABOVE_SECOND: - num_pairwise_propagations_++; - return LeftBoxBeforeRightBoxOnFirstDimension(box2, box1, &global_y_, - &global_x_); + if (restriction.type == + PairwiseRestriction::PairwiseRestrictionType::CONFLICT) { + num_pairwise_conflicts_++; + } else { + num_pairwise_propagations_++; } + return helper_->PropagateRelativePosition( + restriction.first_index, restriction.second_index, restriction.type); } #undef RETURN_IF_FALSE diff --git a/ortools/sat/diffn.h b/ortools/sat/diffn.h index 590d900869..0d1628b94a 100644 --- a/ortools/sat/diffn.h +++ b/ortools/sat/diffn.h @@ -24,11 +24,14 @@ #include "ortools/sat/diffn_util.h" #include "ortools/sat/disjunctive.h" #include "ortools/sat/integer.h" -#include "ortools/sat/intervals.h" +#include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" +#include "ortools/util/bitset.h" #include "ortools/util/time_limit.h" namespace operations_research { @@ -37,11 +40,9 @@ namespace sat { // Propagates using a box energy reasoning. class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { public: - NonOverlappingRectanglesEnergyPropagator(SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, + NonOverlappingRectanglesEnergyPropagator(NoOverlap2DConstraintHelper* helper, Model* model) - : x_(*x), - y_(*y), + : helper_(*helper), random_(model->GetOrCreate()), shared_stats_(model->GetOrCreate()), orthogonal_packing_checker_(*random_, shared_stats_) {} @@ -65,8 +66,7 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { bool BuildAndReportEnergyTooLarge(absl::Span ranges); - SchedulingConstraintHelper& x_; - SchedulingConstraintHelper& y_; + NoOverlap2DConstraintHelper& helper_; ModelRandomGenerator* random_; SharedStatistics* shared_stats_; OrthogonalPackingInfeasibilityDetector orthogonal_packing_checker_; @@ -98,22 +98,18 @@ class NonOverlappingRectanglesDisjunctivePropagator : public PropagatorInterface { public: // The slow_propagators select which disjunctive algorithms to propagate. - NonOverlappingRectanglesDisjunctivePropagator(SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, - Model* model); + NonOverlappingRectanglesDisjunctivePropagator( + NoOverlap2DConstraintHelper* helper, Model* model); ~NonOverlappingRectanglesDisjunctivePropagator() override; bool Propagate() final; void Register(int fast_priority, int slow_priority); private: - bool PropagateOnXWhenOnlyTwoBoxes(); bool FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - bool fast_propagation, SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y); + bool fast_propagation); - SchedulingConstraintHelper& global_x_; - SchedulingConstraintHelper& global_y_; + NoOverlap2DConstraintHelper* helper_; SchedulingConstraintHelper x_; GenericLiteralWatcher* watcher_; @@ -143,6 +139,7 @@ class NonOverlappingRectanglesDisjunctivePropagator DisjunctiveNotLast backward_not_last_; DisjunctiveEdgeFinding forward_edge_finding_; DisjunctiveEdgeFinding backward_edge_finding_; + DisjunctiveWithTwoItems disjunctive_with_two_items_; NonOverlappingRectanglesDisjunctivePropagator( const NonOverlappingRectanglesDisjunctivePropagator&) = delete; @@ -153,10 +150,8 @@ class NonOverlappingRectanglesDisjunctivePropagator // Propagator that compares the boxes pairwise. class RectanglePairwisePropagator : public PropagatorInterface { public: - RectanglePairwisePropagator(SchedulingConstraintHelper* x, - SchedulingConstraintHelper* y, Model* model) - : global_x_(*x), - global_y_(*y), + RectanglePairwisePropagator(NoOverlap2DConstraintHelper* helper, Model* model) + : helper_(helper), shared_stats_(model->GetOrCreate()), params_(model->GetOrCreate()) {} @@ -172,18 +167,17 @@ class RectanglePairwisePropagator : public PropagatorInterface { // Return false if a conflict is found. bool FindRestrictionsAndPropagateConflict( - absl::Span items, + absl::Span items, std::vector* restrictions); bool FindRestrictionsAndPropagateConflict( - absl::Span items1, - absl::Span items2, + absl::Span items1, + absl::Span items2, std::vector* restrictions); bool PropagateTwoBoxes(const PairwiseRestriction& restriction); - SchedulingConstraintHelper& global_x_; - SchedulingConstraintHelper& global_y_; + NoOverlap2DConstraintHelper* helper_; SharedStatistics* shared_stats_; const SatParameters* params_; @@ -191,10 +185,10 @@ class RectanglePairwisePropagator : public PropagatorInterface { int64_t num_pairwise_conflicts_ = 0; int64_t num_pairwise_propagations_ = 0; - std::vector non_zero_area_boxes_; - std::vector horizontal_zero_area_boxes_; - std::vector vertical_zero_area_boxes_; - std::vector point_zero_area_boxes_; + std::vector non_zero_area_boxes_; + std::vector horizontal_zero_area_boxes_; + std::vector vertical_zero_area_boxes_; + std::vector point_zero_area_boxes_; }; } // namespace sat diff --git a/ortools/sat/diffn_cuts.cc b/ortools/sat/diffn_cuts.cc index 8944f2861e..85ff82f4d1 100644 --- a/ortools/sat/diffn_cuts.cc +++ b/ortools/sat/diffn_cuts.cc @@ -38,6 +38,7 @@ #include "ortools/sat/linear_constraint_manager.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/util.h" #include "ortools/util/strong_integers.h" @@ -330,8 +331,10 @@ CutGenerator CreateNoOverlap2dEnergyCutGenerator( if (!y_demands_helper->CacheAllEnergyValues()) return true; const int num_rectangles = x_helper->NumTasks(); - std::vector active_rectangles; - std::vector cached_rectangles(num_rectangles); + std::vector active_rectangles_indexes; + active_rectangles_indexes.reserve(num_rectangles); + std::vector active_rectangles; + active_rectangles.reserve(num_rectangles); for (int rect = 0; rect < num_rectangles; ++rect) { if (y_helper->IsAbsent(rect) || y_helper->IsAbsent(rect)) continue; // We do not consider rectangles controlled by 2 different unassigned @@ -345,26 +348,30 @@ CutGenerator CreateNoOverlap2dEnergyCutGenerator( // here, but for now this code is not in the hot spot, so better be // defensive and only do connected components on really disjoint // rectangles. - Rectangle& rectangle = cached_rectangles[rect]; + active_rectangles_indexes.push_back(rect); + Rectangle& rectangle = active_rectangles.emplace_back(); rectangle.x_min = x_helper->StartMin(rect); rectangle.x_max = x_helper->EndMax(rect); rectangle.y_min = y_helper->StartMin(rect); rectangle.y_max = y_helper->EndMax(rect); - - active_rectangles.push_back(rect); } if (active_rectangles.size() <= 1) return true; const CompactVectorVector components = - GetOverlappingRectangleComponents(cached_rectangles, - absl::MakeSpan(active_rectangles)); + GetOverlappingRectangleComponents(active_rectangles); // Forward pass. No need to do a backward pass. + std::vector rectangles; for (int i = 0; i < components.size(); ++i) { - absl::Span rectangles = components[i]; - if (rectangles.size() <= 1) continue; + absl::Span indexes = components[i]; + if (indexes.size() <= 1) continue; + rectangles.clear(); + rectangles.reserve(indexes.size()); + for (const int index : indexes) { + rectangles.push_back(active_rectangles_indexes[index]); + } GenerateNoOverlap2dEnergyCut(energies, rectangles, "NoOverlap2dXEnergy", model, manager, x_helper, y_helper, y_demands_helper); @@ -581,9 +588,11 @@ CutGenerator CreateNoOverlap2dCompletionTimeCutGenerator( if (!y_helper->SynchronizeAndSetTimeDirection(true)) return false; const int num_rectangles = x_helper->NumTasks(); - std::vector active_rectangles; + std::vector active_rectangles_indexes; + active_rectangles_indexes.reserve(num_rectangles); + std::vector active_rectangles; + active_rectangles.reserve(num_rectangles); std::vector cached_areas(num_rectangles); - std::vector cached_rectangles(num_rectangles); for (int rect = 0; rect < num_rectangles; ++rect) { if (!y_helper->IsPresent(rect) || !y_helper->IsPresent(rect)) continue; @@ -594,23 +603,28 @@ CutGenerator CreateNoOverlap2dCompletionTimeCutGenerator( // here, but for now this code is not in the hot spot, so better be // defensive and only do connected components on really disjoint // rectangles. - Rectangle& rectangle = cached_rectangles[rect]; + active_rectangles_indexes.push_back(rect); + Rectangle& rectangle = active_rectangles.emplace_back(); rectangle.x_min = x_helper->StartMin(rect); rectangle.x_max = x_helper->EndMax(rect); rectangle.y_min = y_helper->StartMin(rect); rectangle.y_max = y_helper->EndMax(rect); - - active_rectangles.push_back(rect); } if (active_rectangles.size() <= 1) return true; const CompactVectorVector components = - GetOverlappingRectangleComponents(cached_rectangles, - absl::MakeSpan(active_rectangles)); + GetOverlappingRectangleComponents(active_rectangles); + std::vector rectangles; for (int i = 0; i < components.size(); ++i) { - absl::Span rectangles = components[i]; - if (rectangles.size() <= 1) continue; + absl::Span indexes = components[i]; + if (indexes.size() <= 1) continue; + + rectangles.clear(); + rectangles.reserve(indexes.size()); + for (const int index : indexes) { + rectangles.push_back(active_rectangles_indexes[index]); + } auto generate_cuts = [product_decomposer, manager, model, &rectangles]( absl::string_view cut_name, diff --git a/ortools/sat/diffn_cuts.h b/ortools/sat/diffn_cuts.h index 77476b1b4d..c49718d467 100644 --- a/ortools/sat/diffn_cuts.h +++ b/ortools/sat/diffn_cuts.h @@ -21,8 +21,8 @@ #include "ortools/sat/cuts.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" #include "ortools/sat/model.h" +#include "ortools/sat/scheduling_helpers.h" namespace operations_research { namespace sat { diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index f22c48f255..a15c84d3d0 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -45,7 +45,6 @@ #include "ortools/graph/connected_components.h" #include "ortools/graph/strongly_connected_components.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" #include "ortools/sat/util.h" #include "ortools/util/fixed_shape_binary_tree.h" #include "ortools/util/integer_pq.h" @@ -107,21 +106,11 @@ absl::InlinedVector Rectangle::RegionDifference( } CompactVectorVector GetOverlappingRectangleComponents( - absl::Span rectangles, - absl::Span active_rectangles) { - if (active_rectangles.empty()) return {}; - - std::vector rectangles_to_process; - std::vector rectangles_index; - rectangles_to_process.reserve(active_rectangles.size()); - rectangles_index.reserve(active_rectangles.size()); - for (const int r : active_rectangles) { - rectangles_to_process.push_back(rectangles[r]); - rectangles_index.push_back(r); - } + absl::Span rectangles) { + if (rectangles.empty()) return {}; std::vector> intersections = - FindPartialRectangleIntersections(rectangles_to_process); + FindPartialRectangleIntersections(rectangles); const int num_intersections = intersections.size(); intersections.reserve(num_intersections * 2 + 1); for (int i = 0; i < num_intersections; ++i) { @@ -136,10 +125,9 @@ CompactVectorVector GetOverlappingRectangleComponents( CompactVectorVector result; for (int i = 0; i < components.size(); ++i) { absl::Span component = components[i]; - if (component.size() == 1) continue; result.Add({}); for (const int r : component) { - result.AppendToLastVector(rectangles_index[r]); + result.AppendToLastVector(r); } } return result; @@ -591,8 +579,8 @@ std::vector GetIntervalArticulationPoints( namespace { bool IsZeroOrPowerOfTwo(int value) { return (value & (value - 1)) == 0; } -void AppendPairwiseRestriction(const ItemForPairwiseRestriction& item1, - const ItemForPairwiseRestriction& item2, +void AppendPairwiseRestriction(const ItemWithVariableSize& item1, + const ItemWithVariableSize& item2, std::vector* result) { const int state = // box1 can be left of box2. @@ -660,9 +648,8 @@ void AppendPairwiseRestriction(const ItemForPairwiseRestriction& item1, } } // namespace -void AppendPairwiseRestrictions( - absl::Span items, - std::vector* result) { +void AppendPairwiseRestrictions(absl::Span items, + std::vector* result) { for (int i1 = 0; i1 + 1 < items.size(); ++i1) { for (int i2 = i1 + 1; i2 < items.size(); ++i2) { AppendPairwiseRestriction(items[i1], items[i2], result); @@ -671,8 +658,8 @@ void AppendPairwiseRestrictions( } void AppendPairwiseRestrictions( - absl::Span items, - absl::Span other_items, + absl::Span items, + absl::Span other_items, std::vector* result) { for (int i1 = 0; i1 < items.size(); ++i1) { for (int i2 = 0; i2 < other_items.size(); ++i2) { diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index b143855659..a7a924f0d5 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -33,7 +33,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/util.h" #include "ortools/util/strong_integers.h" @@ -145,12 +145,8 @@ inline double CenterToCenterLInfinityDistance(const Rectangle& a, // Creates a graph when two nodes are connected iff their rectangles overlap. // Then partition into connected components. -// -// This method removes all singleton components. It will modify the -// active_rectangle span in place. CompactVectorVector GetOverlappingRectangleComponents( - absl::Span rectangles, - absl::Span active_rectangles); + absl::Span rectangles); // Visible for testing. The algo is in O(n^4) so shouldn't be used directly. // Returns true if there exist a bounding box with too much energy. @@ -262,7 +258,7 @@ void GetOverlappingIntervalComponents( std::vector GetIntervalArticulationPoints( std::vector* intervals); -struct ItemForPairwiseRestriction { +struct ItemWithVariableSize { int index; struct Interval { IntegerValue start_min; @@ -295,15 +291,14 @@ struct PairwiseRestriction { // Find pair of items that are either in conflict or could have their range // shrinked to avoid conflict. -void AppendPairwiseRestrictions( - absl::Span items, - std::vector* result); +void AppendPairwiseRestrictions(absl::Span items, + std::vector* result); // Same as above, but test `items` against `other_items` and append the // restrictions found to `result`. void AppendPairwiseRestrictions( - absl::Span items, - absl::Span other_items, + absl::Span items, + absl::Span other_items, std::vector* result); // This class is used by the no_overlap_2d constraint to maintain the envelope diff --git a/ortools/sat/diffn_util_test.cc b/ortools/sat/diffn_util_test.cc index ce10b81d7e..7838ca0123 100644 --- a/ortools/sat/diffn_util_test.cc +++ b/ortools/sat/diffn_util_test.cc @@ -72,55 +72,47 @@ TEST(CenterToCenterDistanceTest, BasicTest) { EXPECT_EQ(CenterToCenterLInfinityDistance(a, b), 4.0); } -TEST(GetOverlappingRectangleComponentsTest, NoComponents) { - EXPECT_TRUE(GetOverlappingRectangleComponents({}, {}).empty()); +TEST(GetOverlappingRectangleComponentsTest, Disconnected) { + EXPECT_TRUE(GetOverlappingRectangleComponents({}).empty()); IntegerValue zero(0); IntegerValue two(2); IntegerValue four(4); - EXPECT_TRUE(GetOverlappingRectangleComponents( - {{zero, two, zero, two}, {two, four, two, four}}, {}) - .empty()); - std::vector first = {0}; - EXPECT_TRUE(GetOverlappingRectangleComponents( - {{zero, two, zero, two}, {two, four, two, four}}, - absl::MakeSpan(first)) - .empty()); - std::vector both = {0, 1}; - EXPECT_TRUE(GetOverlappingRectangleComponents( - {{zero, two, zero, two}, {two, four, two, four}}, - absl::MakeSpan(both)) - .empty()); - EXPECT_TRUE(GetOverlappingRectangleComponents( - {{zero, two, zero, two}, {two, four, zero, two}}, - absl::MakeSpan(both)) - .empty()); - EXPECT_TRUE(GetOverlappingRectangleComponents( - {{zero, two, zero, two}, {zero, two, two, four}}, - absl::MakeSpan(both)) - .empty()); + EXPECT_THAT( + GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {two, four, two, four}}) + .AsVectorOfSpan(), + UnorderedElementsAre(UnorderedElementsAre(0), UnorderedElementsAre(1))); + EXPECT_THAT( + GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {two, four, zero, two}}) + .AsVectorOfSpan(), + UnorderedElementsAre(UnorderedElementsAre(0), UnorderedElementsAre(1))); + EXPECT_THAT( + GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {zero, two, two, four}}) + .AsVectorOfSpan(), + UnorderedElementsAre(UnorderedElementsAre(0), UnorderedElementsAre(1))); } TEST(GetOverlappingRectangleComponentsTest, ComponentAndActive) { - EXPECT_TRUE(GetOverlappingRectangleComponents({}, {}).empty()); + EXPECT_TRUE(GetOverlappingRectangleComponents({}).empty()); IntegerValue zero(0); IntegerValue one(1); IntegerValue two(2); IntegerValue three(3); IntegerValue four(4); - std::vector all = {0, 1, 2}; - const auto& components = GetOverlappingRectangleComponents( - {{zero, two, zero, two}, {zero, two, one, three}, {zero, two, two, four}}, - absl::MakeSpan(all)); - ASSERT_EQ(1, components.size()); - EXPECT_EQ(3, components[0].size()); - - std::vector only_two = {0, 2}; - EXPECT_TRUE(GetOverlappingRectangleComponents({{zero, two, zero, two}, + EXPECT_THAT(GetOverlappingRectangleComponents({{zero, two, zero, two}, {zero, two, one, three}, - {zero, two, two, four}}, - absl::MakeSpan(only_two)) - .empty()); + {zero, two, two, four}}) + .AsVectorOfSpan(), + UnorderedElementsAre(UnorderedElementsAre(0, 1, 2))); + + EXPECT_THAT( + GetOverlappingRectangleComponents( + {{zero, two, zero, two}, {zero, two, two, four}}) + .AsVectorOfSpan(), + UnorderedElementsAre(UnorderedElementsAre(0), UnorderedElementsAre(1))); } TEST(AnalyzeIntervalsTest, Random) { @@ -995,7 +987,7 @@ bool GraphsDefineSameConnectedComponents( return components1 == components2; } -bool HasCycles(const std::vector>& graph) { +bool HasCycles(absl::Span> graph) { std::vector> view; for (const auto& [a, b] : graph) { if (view.size() <= std::max(a, b)) view.resize(std::max(a, b) + 1); @@ -1114,7 +1106,7 @@ TEST(FindPartialIntersections, Random) { } void CheckFuzzedRectangles( - const std::vector>& tuples) { + absl::Span> tuples) { std::vector rectangles; rectangles.reserve(tuples.size()); for (const auto& [x_min, x_size, y_min, y_size] : tuples) { @@ -1181,7 +1173,7 @@ TEST(FindPairwiseRestrictionsTest, Random) { const int num_rectangles = absl::Uniform(random, 1, 20); const std::vector rectangles = GenerateNonConflictingRectangles(num_rectangles, random); - const std::vector items = + const std::vector items = GenerateItemsRectanglesWithNoPairwiseConflict( rectangles, absl::Uniform(random, 0, 1.0), random); std::vector results; @@ -1198,7 +1190,7 @@ void BM_FindPairwiseRestrictions(benchmark::State& state) { // In the vast majority of the cases the propagator doesn't find any pairwise // condition to propagate. Thus we choose to benchmark for this particular // case. - const std::vector items = + const std::vector items = GenerateItemsRectanglesWithNoPairwisePropagation( state.range(0), state.range(1) / 100.0, random); std::vector results; diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index 634838e9e8..bc0823d813 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -267,9 +267,23 @@ bool DisjunctiveWithTwoItems::Propagate() { // interval forced absence? Same for the start-max. int task_before = 0; int task_after = 1; - if (helper_->StartMax(0) < helper_->EndMin(1)) { + + const bool task_0_before_task_1 = helper_->StartMax(0) < helper_->EndMin(1); + const bool task_1_before_task_0 = helper_->StartMax(1) < helper_->EndMin(0); + + if (task_0_before_task_1 && task_1_before_task_0 && + helper_->IsPresent(task_before) && helper_->IsPresent(task_after)) { + helper_->ClearReason(); + helper_->AddPresenceReason(task_before); + helper_->AddPresenceReason(task_after); + helper_->AddReasonForBeingBefore(task_before, task_after); + helper_->AddReasonForBeingBefore(task_after, task_before); + return helper_->ReportConflict(); + } + + if (task_0_before_task_1) { // Task 0 must be before task 1. - } else if (helper_->StartMax(1) < helper_->EndMin(0)) { + } else if (task_1_before_task_0) { // Task 1 must be before task 0. std::swap(task_before, task_after); } else { @@ -320,7 +334,8 @@ int DisjunctiveWithTwoItems::RegisterWith(GenericLiteralWatcher* watcher) { template CombinedDisjunctive::CombinedDisjunctive(Model* model) - : helper_(model->GetOrCreate()) { + : helper_(model->GetOrCreate()->GetOrCreateHelper( + model->GetOrCreate()->AllIntervals())) { task_to_disjunctives_.resize(helper_->NumTasks()); auto* watcher = model->GetOrCreate(); diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index fac7329355..9b26ab6a68 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -258,15 +258,6 @@ class DisjunctiveDetectablePrecedences : public PropagatorInterface { PropagationStatistics stats_; }; -// Singleton model class which is just a SchedulingConstraintHelper will all -// the intervals. -class AllIntervalsHelper : public SchedulingConstraintHelper { - public: - explicit AllIntervalsHelper(Model* model) - : SchedulingConstraintHelper( - model->GetOrCreate()->AllIntervals(), model) {} -}; - // This propagates the same things as DisjunctiveDetectablePrecedences, except // that it only sort the full set of intervals once and then work on a combined // set of disjunctives. @@ -282,7 +273,7 @@ class CombinedDisjunctive : public PropagatorInterface { bool Propagate() final; private: - AllIntervalsHelper* helper_; + SchedulingConstraintHelper* helper_; std::vector> task_to_disjunctives_; std::vector task_is_added_; std::vector task_sets_; diff --git a/ortools/sat/feasibility_jump.cc b/ortools/sat/feasibility_jump.cc index fb87414cca..71373e19e2 100644 --- a/ortools/sat/feasibility_jump.cc +++ b/ortools/sat/feasibility_jump.cc @@ -37,6 +37,7 @@ #include "absl/types/span.h" #include "ortools/algorithms/binary_search.h" #include "ortools/base/logging.h" +#include "ortools/sat/combine_solutions.h" #include "ortools/sat/constraint_violation.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" @@ -232,6 +233,7 @@ void FeasibilityJumpSolver::ResetCurrentSolution( const double range_ratio = params_.feasibility_jump_var_perburbation_range_ratio(); std::vector& solution = state_->solution; + state_->base_solution = nullptr; // Resize the solution if needed. solution.resize(num_variables); @@ -387,6 +389,7 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { solution = shared_response_->SolutionsRepository() .GetRandomBiasedSolution(random_); state_->solution = solution->variable_values; + state_->base_solution = solution; ++state_->num_solutions_imported; } else { if (!first_time) { @@ -489,9 +492,17 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { if (DoSomeLinearIterations() && DoSomeGeneralIterations()) { // Checks for infeasibility induced by the non supported constraints. if (SolutionIsFeasible(linear_model_->model_proto(), state_->solution)) { - shared_response_->NewSolution( - state_->solution, absl::StrCat(name(), "_", state_->options.name(), - "(", OneLineStats(), ")")); + auto pointers = PushAndMaybeCombineSolution( + shared_response_, linear_model_->model_proto(), state_->solution, + absl::StrCat(name(), "_", state_->options.name(), "(", + OneLineStats(), ")"), + state_->base_solution == nullptr + ? absl::Span() + : state_->base_solution->variable_values, + /*model=*/nullptr); + // If we pushed a new solution, we use it as a new "base" so that we + // will have a smaller delta on the next solution we find. + state_->base_solution = pointers.pushed_solution; } else { shared_response_->LogMessage(name(), "infeasible solution. Aborting."); model_is_supported_ = false; diff --git a/ortools/sat/feasibility_jump.h b/ortools/sat/feasibility_jump.h index 91d030c1ab..d9d1945d7a 100644 --- a/ortools/sat/feasibility_jump.h +++ b/ortools/sat/feasibility_jump.h @@ -301,6 +301,8 @@ struct LsState { // constraint weighted by these weights. std::vector solution; std::vector weights; + std::shared_ptr::Solution> + base_solution; // Depending on the options, we use an exponentially decaying constraint // weight like for SAT activities. diff --git a/ortools/sat/fuzz_testdata/AtMostOneModel b/ortools/sat/fuzz_testdata/AtMostOneModel new file mode 100644 index 0000000000..27e246669d --- /dev/null +++ b/ortools/sat/fuzz_testdata/AtMostOneModel @@ -0,0 +1,15 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 50 ] } +constraints { at_most_one { literals: [ 0, 1, 2 ] } } +constraints { + linear { + vars: [ 0, 1, 2, 3 ] + coeffs: [ 2, 3, 4, -1 ] + domain: [ 0, 10 ] + } +} diff --git a/ortools/sat/fuzz_testdata/AutomatonModel b/ortools/sat/fuzz_testdata/AutomatonModel new file mode 100644 index 0000000000..9a5254cb9b --- /dev/null +++ b/ortools/sat/fuzz_testdata/AutomatonModel @@ -0,0 +1,34 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: 0 domain: 1 } +variables { domain: 0 domain: 1 } +variables { domain: 0 domain: 1 } +variables { domain: 0 domain: 1 } +constraints { + automaton { + final_states: 3 + transition_tail: 0 + transition_tail: 0 + transition_tail: 1 + transition_tail: 2 + transition_tail: 1 + transition_tail: 2 + transition_head: 1 + transition_head: 2 + transition_head: 1 + transition_head: 2 + transition_head: 3 + transition_head: 3 + transition_label: 0 + transition_label: 1 + transition_label: 0 + transition_label: 1 + transition_label: 1 + transition_label: 0 + exprs { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + exprs { vars: 3 coeffs: 1 } + } +} diff --git a/ortools/sat/fuzz_testdata/CircuitModel b/ortools/sat/fuzz_testdata/CircuitModel new file mode 100644 index 0000000000..8c9a7bf323 --- /dev/null +++ b/ortools/sat/fuzz_testdata/CircuitModel @@ -0,0 +1,17 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 0, 1 ] } # 0->1 +variables { domain: [ 0, 1 ] } # 1->2 +variables { domain: [ 0, 1 ] } # 1->0 +variables { domain: [ 0, 1 ] } # 2->0 +variables { domain: [ 0, 1 ] } # 2->2 +variables { domain: [ 0, 1 ] } # 0->2 +variables { domain: [ 0, 1 ] } # 2->1 +constraints { + routes { + tails: [ 0, 1, 1, 2, 2, 0, 2 ] + heads: [ 1, 2, 0, 0, 2, 2, 1 ] + literals: [ 0, 1, 2, 3, 4, 5, 6 ] + } +} diff --git a/ortools/sat/fuzz_testdata/DiophantineModel b/ortools/sat/fuzz_testdata/DiophantineModel new file mode 100644 index 0000000000..a8cba0a425 --- /dev/null +++ b/ortools/sat/fuzz_testdata/DiophantineModel @@ -0,0 +1,17 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 1, 10000000 ] } +variables { domain: [ 0, 10000000 ] } +variables { domain: [ 0, 10000000 ] } +constraints { + linear { + vars: [ 0, 1, 2 ] + coeffs: [ 31013, -41014, -51015 ] + domain: [ 0, 0 ] + } +} +objective { + vars: [ 0 ] + coeffs: [ 1 ] +} diff --git a/ortools/sat/fuzz_testdata/ElementModel b/ortools/sat/fuzz_testdata/ElementModel new file mode 100644 index 0000000000..99f98c8134 --- /dev/null +++ b/ortools/sat/fuzz_testdata/ElementModel @@ -0,0 +1,17 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 0, 5 ] } +variables { domain: [ 0, 7 ] } +constraints { + element { + linear_index { vars: 0 coeffs: 1 } + linear_target { vars: 1 coeffs: 1 } + exprs { offset: 1 } + exprs { offset: 2 } + exprs { offset: 3 } + exprs { offset: 4 } + exprs { offset: 5 } + exprs { offset: 6 } + } +} diff --git a/ortools/sat/fuzz_testdata/ExactlyOneModel b/ortools/sat/fuzz_testdata/ExactlyOneModel new file mode 100644 index 0000000000..74218fabdc --- /dev/null +++ b/ortools/sat/fuzz_testdata/ExactlyOneModel @@ -0,0 +1,17 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +variables { domain: [ -100, 100 ] } +constraints { exactly_one { literals: [ 0, 1, 2, 3, 4 ] } } +constraints { + linear { + vars: [ 0, 1, 3, 4, 5 ] + coeffs: [ 1, 7, -2, 4, 1 ] + domain: [ 10, 10 ] + } +} diff --git a/ortools/sat/fuzz_testdata/IntProdModel b/ortools/sat/fuzz_testdata/IntProdModel new file mode 100644 index 0000000000..76705062fe --- /dev/null +++ b/ortools/sat/fuzz_testdata/IntProdModel @@ -0,0 +1,19 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { + domain: [ 10, 12 ] +} +variables { + domain: [ 2, 2 ] +} +variables { + domain: [ 0, 100 ] +} +constraints { + int_prod { + target { vars: 2 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 0 coeffs: 1 } + } +} diff --git a/ortools/sat/fuzz_testdata/InvalidProblem b/ortools/sat/fuzz_testdata/InvalidProblem deleted file mode 100644 index 4aa225e1db..0000000000 --- a/ortools/sat/fuzz_testdata/InvalidProblem +++ /dev/null @@ -1,5 +0,0 @@ -# proto-file: ortools/sat/cp_model.proto -# proto-message: operations_research.sat.CpModelProto - -variables { -} diff --git a/ortools/sat/fuzz_testdata/InverseModel b/ortools/sat/fuzz_testdata/InverseModel new file mode 100644 index 0000000000..85c9225a9d --- /dev/null +++ b/ortools/sat/fuzz_testdata/InverseModel @@ -0,0 +1,17 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 0, 3 ] } +variables { domain: [ 0, 3 ] } +variables { domain: [ 0, 3 ] } +variables { domain: [ 0, 3 ] } +variables { domain: [ 0, 3 ] } +variables { domain: [ 0, 3 ] } +variables { domain: [ 0, 3 ] } +variables { domain: [ 0, 3 ] } +constraints { + inverse { + f_direct: [ 0, 2, 4, 6 ], + f_inverse: [ 1, 3, 5, 7 ] + } +} diff --git a/ortools/sat/fuzz_testdata/LinMaxModel b/ortools/sat/fuzz_testdata/LinMaxModel new file mode 100644 index 0000000000..ce56cec764 --- /dev/null +++ b/ortools/sat/fuzz_testdata/LinMaxModel @@ -0,0 +1,14 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +variables { domain: [ 0, 1 ] } +constraints { + lin_max { + target { vars: 0 coeffs: 2 } + exprs { vars: 1 coeffs: 1 offset: 1 } + exprs { vars: 2 coeffs: 4 offset: 1 } + exprs { offset: 1 } + } +} diff --git a/ortools/sat/fuzz_testdata/NoOverlap2DOptimization b/ortools/sat/fuzz_testdata/NoOverlap2DOptimization index d444c9fb5e..db846c7b19 100644 --- a/ortools/sat/fuzz_testdata/NoOverlap2DOptimization +++ b/ortools/sat/fuzz_testdata/NoOverlap2DOptimization @@ -2,27 +2,21 @@ # proto-message: operations_research.sat.CpModelProto variables: { - name: "x_0" domain: [ 0, 80 ] } variables: { - name: "y_0" domain: [ 0, 40 ] } variables: { - name: "x_1" domain: [ 0, 80 ] } variables: { - name: "y_1" domain: [ 0, 60 ] } variables: { - name: "x_2" domain: [ 0, 90 ] } variables: { - name: "y_2" domain: [ 0, 50 ] } variables: { domain: [ 1, 1 ] } @@ -39,7 +33,6 @@ constraints: { } } constraints: { - name: "x_interval_0" enforcement_literal: 6 interval: { start: { vars: 0 coeffs: 1 } @@ -48,7 +41,6 @@ constraints: { } } constraints: { - name: "y_interval_0" enforcement_literal: 6 interval: { start: { vars: 1 coeffs: 1 } @@ -57,7 +49,6 @@ constraints: { } } constraints: { - name: "x_interval_1" enforcement_literal: 6 interval: { start: { vars: 2 coeffs: 1 } @@ -66,7 +57,6 @@ constraints: { } } constraints: { - name: "y_interval_1" enforcement_literal: 6 interval: { start: { vars: 3 coeffs: 1 } @@ -75,7 +65,6 @@ constraints: { } } constraints: { - name: "x_interval_2" enforcement_literal: 6 interval: { start: { vars: 4 coeffs: 1 } @@ -84,7 +73,6 @@ constraints: { } } constraints: { - name: "y_interval_2" enforcement_literal: 6 interval: { start: { vars: 5 coeffs: 1 } diff --git a/ortools/sat/fuzz_testdata/PureSatProblem b/ortools/sat/fuzz_testdata/PureSatProblem index 8d80724abd..2fb9866524 100644 --- a/ortools/sat/fuzz_testdata/PureSatProblem +++ b/ortools/sat/fuzz_testdata/PureSatProblem @@ -1,7 +1,6 @@ # proto-file: ortools/sat/cp_model.proto # proto-message: operations_research.sat.CpModelProto -name: "Random 3-SAT" variables { domain: 0 domain: 1 diff --git a/ortools/sat/fuzz_testdata/PureSatProblemWithLimit b/ortools/sat/fuzz_testdata/PureSatProblemWithLimit index 057e77833b..0ec5cc0fc7 100644 --- a/ortools/sat/fuzz_testdata/PureSatProblemWithLimit +++ b/ortools/sat/fuzz_testdata/PureSatProblemWithLimit @@ -1,7 +1,6 @@ # proto-file: ortools/sat/cp_model.proto # proto-message: operations_research.sat.CpModelProto -name: "Random 3-SAT" variables { domain: 0 domain: 1 diff --git a/ortools/sat/fuzz_testdata/ReservoirModel b/ortools/sat/fuzz_testdata/ReservoirModel new file mode 100644 index 0000000000..9acc7cbc30 --- /dev/null +++ b/ortools/sat/fuzz_testdata/ReservoirModel @@ -0,0 +1,18 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { domain: [ 0, 2 ] } +variables { domain: [ 0, 2 ] } +variables { domain: [ 1, 1 ] } +variables { domain: [ 0, 1 ] } +constraints { + reservoir { + time_exprs: { vars: 0 coeffs: 1 } + time_exprs: { vars: 1 coeffs: 1 } + level_changes: { offset: -1 } + level_changes: { offset: 1 } + active_literals: [ 2, 3 ] + min_level: 0 + max_level: 2 + } +} diff --git a/ortools/sat/fuzz_testdata/SimpleOptionalIntervalFeasible b/ortools/sat/fuzz_testdata/SimpleOptionalIntervalFeasible index 161b4cc41c..a3e26b240c 100644 --- a/ortools/sat/fuzz_testdata/SimpleOptionalIntervalFeasible +++ b/ortools/sat/fuzz_testdata/SimpleOptionalIntervalFeasible @@ -50,7 +50,7 @@ constraints { interval { start { vars: 1 coeffs: 1 } end { vars: 3 coeffs: 1 } - size { vars: 2 offset: 2 } + size { vars: 2 coeffs: 1 offset: 2 } } } constraints { diff --git a/ortools/sat/fuzz_testdata/SolutionHintBasicTest b/ortools/sat/fuzz_testdata/SolutionHintBasicTest index 159c9460a5..cf32fe066d 100644 --- a/ortools/sat/fuzz_testdata/SolutionHintBasicTest +++ b/ortools/sat/fuzz_testdata/SolutionHintBasicTest @@ -1,7 +1,6 @@ # proto-file: ortools/sat/cp_model.proto # proto-message: operations_research.sat.CpModelProto -name: "Random 3-SAT" variables { domain: 0 domain: 1 diff --git a/ortools/sat/fuzz_testdata/SolutionHintEnumerateTest b/ortools/sat/fuzz_testdata/SolutionHintEnumerateTest index 3cf22243ff..d31cf002ec 100644 --- a/ortools/sat/fuzz_testdata/SolutionHintEnumerateTest +++ b/ortools/sat/fuzz_testdata/SolutionHintEnumerateTest @@ -2,12 +2,10 @@ # proto-message: operations_research.sat.CpModelProto variables { - name: "x" domain: 0 domain: 10 } variables { - name: "y" domain: 0 domain: 10 } diff --git a/ortools/sat/fuzz_testdata/TableProblem b/ortools/sat/fuzz_testdata/TableProblem new file mode 100644 index 0000000000..463d98f6c8 --- /dev/null +++ b/ortools/sat/fuzz_testdata/TableProblem @@ -0,0 +1,15 @@ +# proto-file: ortools/sat/cp_model.proto +# proto-message: operations_research.sat.CpModelProto + +variables { + domain: [ 0, 4 ] +} +variables { + domain: [ 0, 4 ] +} +constraints { + table { + vars: [ 0, 1 ] + values: [ 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4 ] + } +} diff --git a/ortools/sat/fuzz_testdata/TrivialLinearTranslatedModel b/ortools/sat/fuzz_testdata/TrivialLinearTranslatedModel deleted file mode 100644 index 683c7e6072..0000000000 --- a/ortools/sat/fuzz_testdata/TrivialLinearTranslatedModel +++ /dev/null @@ -1,42 +0,0 @@ -# proto-file: ortools/sat/cp_model.proto -# proto-message: operations_research.sat.CpModelProto - -variables { - domain: -10 - domain: 10 -} -variables { - domain: -10 - domain: 10 -} -variables { - domain: -4611686018427387903 - domain: 4611686018427387903 -} -constraints { - linear { - vars: 0 - vars: 1 - coeffs: 1 - coeffs: 1 - domain: -4611686018427387903 - domain: 4611686018427387903 - } -} -constraints { - linear { - vars: 0 - vars: 1 - vars: 2 - coeffs: 1 - coeffs: 2 - coeffs: -1 - domain: 0 - domain: 0 - } -} -objective { - vars: -3 - scaling_factor: -1 - coeffs: 1 -} diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 916e1b5a2e..91d22b090c 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -1600,7 +1600,8 @@ inline std::function Equality(IntegerVariable v, int64_t value) { inline std::function Implication( absl::Span enforcement_literals, IntegerLiteral i) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); + auto* sat_solver = model->GetOrCreate(); + auto* integer_trail = model->GetOrCreate(); if (i.bound <= integer_trail->LowerBound(i.var)) { // Always true! nothing to do. } else if (i.bound > integer_trail->UpperBound(i.var)) { @@ -1609,7 +1610,7 @@ inline std::function Implication( for (const Literal literal : enforcement_literals) { clause.push_back(literal.Negated()); } - model->Add(ClauseConstraint(clause)); + sat_solver->AddClauseDuringSearch(clause); } else { // TODO(user): Double check what happen when we associate a trivially // true or false literal. @@ -1618,7 +1619,7 @@ inline std::function Implication( for (const Literal literal : enforcement_literals) { clause.push_back(literal.Negated()); } - model->Add(ClauseConstraint(clause)); + sat_solver->AddClauseDuringSearch(clause); } }; } diff --git a/ortools/sat/integer_base.h b/ortools/sat/integer_base.h index 42fa3aa7f8..95ce2ddff6 100644 --- a/ortools/sat/integer_base.h +++ b/ortools/sat/integer_base.h @@ -109,28 +109,6 @@ inline bool AtMinOrMaxInt64I(IntegerValue t) { return AtMinOrMaxInt64(t.value()); } -// Helper for dividing several small integers by the same value. Note that there -// is no point using this class is the divisor is a compile-time constant, since -// the compiler should be smart enough to do this automatically. -// Building a `QuickSmallDivision` object costs an integer division, but each -// call to `DivideByDivisor` will only do an integer multiplication and a shift. -// -// This class always return the exact value of the division for all possible -// values of `dividend` and `divisor`. -class QuickSmallDivision { - public: - explicit QuickSmallDivision(uint16_t divisor) - : inverse_((1ull << 48) / divisor + 1) {} - - uint16_t DivideByDivisor(uint16_t dividend) const { - return static_cast((inverse_ * static_cast(dividend)) >> - 48); - } - - private: - uint64_t inverse_; -}; - // Returns dividend - FloorRatio(dividend, divisor) * divisor; // // This function is around the same speed than the computation above, but it @@ -405,6 +383,9 @@ struct ValueLiteralPair { std::ostream& operator<<(std::ostream& os, const ValueLiteralPair& p); +DEFINE_STRONG_INDEX_TYPE(IntervalVariable); +const IntervalVariable kNoIntervalVariable(-1); + // ============================================================================ // Implementation. // ============================================================================ diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index 5073873200..853b6f8151 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -445,7 +445,7 @@ inline std::function WeightedSumLowerOrEqual( // Weighted sum >= constant. template inline std::function WeightedSumGreaterOrEqual( - const std::vector& vars, const VectorInt& coefficients, + absl::Span vars, const VectorInt& coefficients, int64_t lower_bound) { // We just negate everything and use an <= constraints. std::vector negated_coeffs(coefficients.begin(), coefficients.end()); diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 148e0d93a1..e93f99992f 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -47,6 +47,7 @@ #include "ortools/sat/sat_inprocessing.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" #include "ortools/util/strong_integers.h" @@ -730,7 +731,7 @@ std::function DisjunctivePrecedenceSearchHeuristic( } } - // TODO(Fdid): Also compare the second part of the precedence in + // TODO(user): Also compare the second part of the precedence in // PrecedenceIsBetter() and not just the interval before? if (best_helper == nullptr || PrecedenceIsBetter(helper, a, best_helper, best_before)) { @@ -743,13 +744,21 @@ std::function DisjunctivePrecedenceSearchHeuristic( } if (best_helper != nullptr) { + // If one of the task presence is undecided, start by making it present. + for (const int t : {best_before, best_after}) { + if (!best_helper->IsPresent(t)) { + VLOG(2) << "Presence: " << best_helper->TaskDebugString(t); + return BooleanOrIntegerLiteral(best_helper->PresenceLiteral(t)); + } + } + VLOG(2) << "New disjunctive precedence: " << best_helper->TaskDebugString(best_before) << " " << best_helper->TaskDebugString(best_after); - const IntervalVariable a = best_helper->IntervalVariables()[best_before]; - const IntervalVariable b = best_helper->IntervalVariables()[best_after]; - repo->CreateDisjunctivePrecedenceLiteral(a, b); - return BooleanOrIntegerLiteral(repo->GetPrecedenceLiteral(a, b)); + const auto a = best_helper->GetIntervalDefinition(best_before); + const auto b = best_helper->GetIntervalDefinition(best_after); + return BooleanOrIntegerLiteral( + repo->GetOrCreateDisjunctivePrecedenceLiteral(a, b)); } return BooleanOrIntegerLiteral(); @@ -874,9 +883,8 @@ std::function CumulativePrecedenceSearchHeuristic( CHECK_LT(helper->StartMin(t), helper->EndMin(s)); // skip if we already have a literal created and assigned to false. - const IntervalVariable a = helper->IntervalVariables()[s]; - const IntervalVariable b = helper->IntervalVariables()[t]; - const LiteralIndex existing = repo->GetPrecedenceLiteral(a, b); + const LiteralIndex existing = repo->GetPrecedenceLiteral( + helper->Ends()[s], helper->Starts()[t]); if (existing != kNoLiteralIndex) { // It shouldn't be able to be true here otherwise we will have s and // t disjoint. @@ -899,7 +907,8 @@ std::function CumulativePrecedenceSearchHeuristic( } // It shouldn't be able to fail since s can be before t. - CHECK(repo->CreatePrecedenceLiteral(a, b)); + CHECK(repo->CreatePrecedenceLiteral(helper->Ends()[s], + helper->Starts()[t])); } // Branch on that precedence. @@ -950,10 +959,11 @@ std::function CumulativePrecedenceSearchHeuristic( if (best_helper != nullptr) { VLOG(2) << "New precedence: " << best_helper->TaskDebugString(best_before) << " " << best_helper->TaskDebugString(best_after); - const IntervalVariable a = best_helper->IntervalVariables()[best_before]; - const IntervalVariable b = best_helper->IntervalVariables()[best_after]; - repo->CreatePrecedenceLiteral(a, b); - return BooleanOrIntegerLiteral(repo->GetPrecedenceLiteral(a, b)); + const AffineExpression end_a = best_helper->Ends()[best_before]; + const AffineExpression start_b = best_helper->Starts()[best_after]; + repo->CreatePrecedenceLiteral(end_a, start_b); + return BooleanOrIntegerLiteral( + repo->GetPrecedenceLiteral(end_a, start_b)); } return BooleanOrIntegerLiteral(); diff --git a/ortools/sat/integer_test.cc b/ortools/sat/integer_test.cc index 6b31cc2e05..ef95eb7668 100644 --- a/ortools/sat/integer_test.cc +++ b/ortools/sat/integer_test.cc @@ -1227,18 +1227,6 @@ TEST(IntegerTrailTest, AppendNewBounds) { var, IntegerValue(9)))); } -TEST(FastDivisionTest, AllPossibleValues) { - for (int i = 1; i <= std::numeric_limits::max(); ++i) { - const QuickSmallDivision div(i); - for (int j = 0; j <= std::numeric_limits::max(); ++j) { - const uint16_t result = div.DivideByDivisor(j); - const uint16_t j_rounded_to_lowest_multiple = result * i; - CHECK_LE(j_rounded_to_lowest_multiple, j); - CHECK_GT(j_rounded_to_lowest_multiple + i, j); - } - } -} - static void BM_FloorRatio(benchmark::State& state) { IntegerValue divisor(654676436498); IntegerValue dividend(45454655155444); diff --git a/ortools/sat/intervals.cc b/ortools/sat/intervals.cc index 115ef50ca1..e96d2575c5 100644 --- a/ortools/sat/intervals.cc +++ b/ortools/sat/intervals.cc @@ -13,30 +13,23 @@ #include "ortools/sat/intervals.h" -#include -#include -#include -#include +#include #include #include #include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" #include "absl/meta/type_traits.h" -#include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "ortools/base/logging.h" #include "ortools/base/strong_vector.h" -#include "ortools/sat/implied_bounds.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/integer_expr.h" #include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" -#include "ortools/sat/precedences.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" -#include "ortools/util/sort.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/util/strong_integers.h" namespace operations_research { @@ -85,40 +78,63 @@ IntervalVariable IntervalsRepository::CreateInterval(AffineExpression start, void IntervalsRepository::CreateDisjunctivePrecedenceLiteral( IntervalVariable a, IntervalVariable b) { - if (disjunctive_precedences_.contains({a, b})) return; + GetOrCreateDisjunctivePrecedenceLiteral( + IntervalDefinition{.start = Start(a), + .end = End(a), + .size = Size(a), + .is_present = IsOptional(a) + ? std::optional(PresenceLiteral(a)) + : std::nullopt}, + IntervalDefinition{.start = Start(b), + .end = End(b), + .size = Size(b), + .is_present = IsOptional(b) + ? std::optional(PresenceLiteral(b)) + : std::nullopt}); +} + +LiteralIndex IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteral( + const IntervalDefinition& a, const IntervalDefinition& b) { + auto it = disjunctive_precedences_.find({a, b}); + if (it != disjunctive_precedences_.end()) return it->second.Index(); std::vector enforcement_literals; - if (IsOptional(a)) enforcement_literals.push_back(PresenceLiteral(a)); - if (IsOptional(b)) enforcement_literals.push_back(PresenceLiteral(b)); + if (a.is_present.has_value()) { + enforcement_literals.push_back(a.is_present.value()); + } + if (b.is_present.has_value()) { + enforcement_literals.push_back(b.is_present.value()); + } + if (sat_solver_->CurrentDecisionLevel() == 0) { int new_size = 0; for (const Literal l : enforcement_literals) { // We can ignore always absent interval, and skip the literal of the // interval that are now always present. if (assignment_.LiteralIsTrue(l)) continue; - if (assignment_.LiteralIsFalse(l)) return; + if (assignment_.LiteralIsFalse(l)) return kNoLiteralIndex; enforcement_literals[new_size++] = l; } enforcement_literals.resize(new_size); } - const AffineExpression start_a = Start(a); - const AffineExpression end_a = End(a); - const AffineExpression start_b = Start(b); - const AffineExpression end_b = End(b); - - // task_a is always before task_b ? - if (integer_trail_->UpperBound(start_a) < integer_trail_->LowerBound(end_b)) { - AddConditionalAffinePrecedence(enforcement_literals, end_a, start_b, - model_); - return; + // task_a is currently before task_b ? + // Lets not create a literal that will be propagated right away. + if (integer_trail_->UpperBound(a.start) < integer_trail_->LowerBound(b.end)) { + if (sat_solver_->CurrentDecisionLevel() == 0) { + AddConditionalAffinePrecedence(enforcement_literals, a.end, b.start, + model_); + } + return kNoLiteralIndex; } - // task_b is always before task_a ? - if (integer_trail_->UpperBound(start_b) < integer_trail_->LowerBound(end_a)) { - AddConditionalAffinePrecedence(enforcement_literals, end_b, start_a, - model_); - return; + // task_b is before task_a ? + if (integer_trail_->UpperBound(b.start) < integer_trail_->LowerBound(a.end)) { + if (sat_solver_->CurrentDecisionLevel() == 0) { + AddConditionalAffinePrecedence(enforcement_literals, b.end, a.start, + model_); + } + return kNoLiteralIndex; } // Create a new literal. @@ -128,16 +144,18 @@ void IntervalsRepository::CreateDisjunctivePrecedenceLiteral( disjunctive_precedences_.insert({{b, a}, a_before_b.Negated()}); // Also insert it in precedences. - // TODO(user): also add the reverse like start_b + 1 <= end_a if negated? - precedences_.insert({{end_a, start_b}, a_before_b}); - precedences_.insert({{end_b, start_a}, a_before_b.Negated()}); + if (enforcement_literals.empty()) { + // TODO(user): also add the reverse like start_b + 1 <= end_a if negated? + precedences_.insert({{a.end, b.start}, a_before_b}); + precedences_.insert({{b.end, a.start}, a_before_b.Negated()}); + } enforcement_literals.push_back(a_before_b); - AddConditionalAffinePrecedence(enforcement_literals, end_a, start_b, model_); + AddConditionalAffinePrecedence(enforcement_literals, a.end, b.start, model_); enforcement_literals.pop_back(); enforcement_literals.push_back(a_before_b.Negated()); - AddConditionalAffinePrecedence(enforcement_literals, end_b, start_a, model_); + AddConditionalAffinePrecedence(enforcement_literals, b.end, a.start, model_); enforcement_literals.pop_back(); // Force the value of boolean_var in case the precedence is not active. This @@ -145,12 +163,12 @@ void IntervalsRepository::CreateDisjunctivePrecedenceLiteral( for (const Literal l : enforcement_literals) { implications_->AddBinaryClause(l, a_before_b); } + + return a_before_b; } -bool IntervalsRepository::CreatePrecedenceLiteral(IntervalVariable a, - IntervalVariable b) { - const AffineExpression x = End(a); - const AffineExpression y = Start(b); +bool IntervalsRepository::CreatePrecedenceLiteral(AffineExpression x, + AffineExpression y) { if (precedences_.contains({x, y})) return false; // We want l => x <= y and not(l) => x > y <=> y + 1 <= x @@ -177,9 +195,7 @@ bool IntervalsRepository::CreatePrecedenceLiteral(IntervalVariable a, } LiteralIndex IntervalsRepository::GetPrecedenceLiteral( - IntervalVariable a, IntervalVariable b) const { - const AffineExpression x = End(a); - const AffineExpression y = Start(b); + AffineExpression x, AffineExpression y) const { const auto it = precedences_.find({x, y}); if (it != precedences_.end()) return it->second.Index(); return kNoLiteralIndex; @@ -193,9 +209,25 @@ SchedulingConstraintHelper* IntervalsRepository::GetOrCreateHelper( bool register_as_disjunctive_helper) { const auto it = helper_repository_.find(variables); if (it != helper_repository_.end()) return it->second; + std::vector starts; + std::vector ends; + std::vector sizes; + std::vector reason_for_presence; - SchedulingConstraintHelper* helper = - new SchedulingConstraintHelper(variables, model_); + for (const IntervalVariable i : variables) { + if (IsOptional(i)) { + reason_for_presence.push_back(PresenceLiteral(i).Index()); + } else { + reason_for_presence.push_back(kNoLiteralIndex); + } + sizes.push_back(Size(i)); + starts.push_back(Start(i)); + ends.push_back(End(i)); + } + SchedulingConstraintHelper* helper = new SchedulingConstraintHelper( + std::move(starts), std::move(ends), std::move(sizes), + std::move(reason_for_presence), model_); + helper->RegisterWith(model_->GetOrCreate()); helper_repository_[variables] = helper; model_->TakeOwnership(helper); if (register_as_disjunctive_helper) { @@ -204,6 +236,21 @@ SchedulingConstraintHelper* IntervalsRepository::GetOrCreateHelper( return helper; } +NoOverlap2DConstraintHelper* IntervalsRepository::GetOrCreate2DHelper( + const std::vector& x_variables, + const std::vector& y_variables) { + const auto it = + no_overlap_2d_helper_repository_.find({x_variables, y_variables}); + if (it != no_overlap_2d_helper_repository_.end()) return it->second; + + NoOverlap2DConstraintHelper* helper = new NoOverlap2DConstraintHelper( + GetOrCreateHelper(x_variables), GetOrCreateHelper(y_variables), model_); + helper->RegisterWith(model_->GetOrCreate()); + no_overlap_2d_helper_repository_[{x_variables, y_variables}] = helper; + model_->TakeOwnership(helper); + return helper; +} + SchedulingDemandHelper* IntervalsRepository::GetOrCreateDemandHelper( SchedulingConstraintHelper* helper, absl::Span demands) { @@ -226,1045 +273,5 @@ void IntervalsRepository::InitAllDecomposedEnergies() { } } -SchedulingConstraintHelper::SchedulingConstraintHelper( - const std::vector& tasks, Model* model) - : model_(model), - trail_(model->GetOrCreate()), - sat_solver_(model->GetOrCreate()), - integer_trail_(model->GetOrCreate()), - watcher_(model->GetOrCreate()), - precedence_relations_(model->GetOrCreate()), - interval_variables_(tasks), - capacity_(tasks.size()), - cached_size_min_(new IntegerValue[capacity_]), - cached_start_min_(new IntegerValue[capacity_]), - cached_end_min_(new IntegerValue[capacity_]), - cached_negated_start_max_(new IntegerValue[capacity_]), - cached_negated_end_max_(new IntegerValue[capacity_]), - cached_shifted_start_min_(new IntegerValue[capacity_]), - cached_negated_shifted_end_max_(new IntegerValue[capacity_]) { - starts_.clear(); - ends_.clear(); - minus_ends_.clear(); - minus_starts_.clear(); - sizes_.clear(); - reason_for_presence_.clear(); - - auto* repository = model->GetOrCreate(); - for (const IntervalVariable i : tasks) { - if (repository->IsOptional(i)) { - reason_for_presence_.push_back(repository->PresenceLiteral(i).Index()); - } else { - reason_for_presence_.push_back(kNoLiteralIndex); - } - sizes_.push_back(repository->Size(i)); - starts_.push_back(repository->Start(i)); - ends_.push_back(repository->End(i)); - minus_starts_.push_back(repository->Start(i).Negated()); - minus_ends_.push_back(repository->End(i).Negated()); - } - - RegisterWith(model->GetOrCreate()); - InitSortedVectors(); - if (!SynchronizeAndSetTimeDirection(true)) { - model->GetOrCreate()->NotifyThatModelIsUnsat(); - } -} - -SchedulingConstraintHelper::SchedulingConstraintHelper(int num_tasks, - Model* model) - : model_(model), - trail_(model->GetOrCreate()), - sat_solver_(model->GetOrCreate()), - integer_trail_(model->GetOrCreate()), - precedence_relations_(model->GetOrCreate()), - capacity_(num_tasks), - cached_size_min_(new IntegerValue[capacity_]), - cached_start_min_(new IntegerValue[capacity_]), - cached_end_min_(new IntegerValue[capacity_]), - cached_negated_start_max_(new IntegerValue[capacity_]), - cached_negated_end_max_(new IntegerValue[capacity_]), - cached_shifted_start_min_(new IntegerValue[capacity_]), - cached_negated_shifted_end_max_(new IntegerValue[capacity_]) { - starts_.resize(num_tasks); - CHECK_EQ(NumTasks(), num_tasks); -} - -bool SchedulingConstraintHelper::Propagate() { - recompute_all_cache_ = true; - for (const int id : propagator_ids_) watcher_->CallOnNextPropagate(id); - return true; -} - -bool SchedulingConstraintHelper::IncrementalPropagate( - const std::vector& watch_indices) { - for (const int t : watch_indices) recompute_cache_.Set(t); - for (const int id : propagator_ids_) watcher_->CallOnNextPropagate(id); - return true; -} - -void SchedulingConstraintHelper::RegisterWith(GenericLiteralWatcher* watcher) { - const int id = watcher->Register(this); - const int num_tasks = starts_.size(); - for (int t = 0; t < num_tasks; ++t) { - watcher->WatchIntegerVariable(sizes_[t].var, id, t); - watcher->WatchIntegerVariable(starts_[t].var, id, t); - watcher->WatchIntegerVariable(ends_[t].var, id, t); - } - watcher->SetPropagatorPriority(id, 0); -} - -bool SchedulingConstraintHelper::UpdateCachedValues(int t) { - if (IsAbsent(t)) return true; - - IntegerValue smin = integer_trail_->LowerBound(starts_[t]); - IntegerValue smax = integer_trail_->UpperBound(starts_[t]); - IntegerValue emin = integer_trail_->LowerBound(ends_[t]); - IntegerValue emax = integer_trail_->UpperBound(ends_[t]); - - // We take the max for the corner case where the size of an optional interval - // is used elsewhere and has a domain with negative value. - // - // TODO(user): maybe we should just disallow size with a negative domain, but - // is is harder to enforce if we have a linear expression for size. - IntegerValue dmin = - std::max(IntegerValue(0), integer_trail_->LowerBound(sizes_[t])); - IntegerValue dmax = integer_trail_->UpperBound(sizes_[t]); - - // Detect first if we have a conflict using the relation start + size = end. - if (dmax < 0) { - ClearReason(); - AddSizeMaxReason(t, dmax); - return PushTaskAbsence(t); - } - if (smin + dmin - emax > 0) { - ClearReason(); - AddStartMinReason(t, smin); - AddSizeMinReason(t, dmin); - AddEndMaxReason(t, emax); - return PushTaskAbsence(t); - } - if (smax + dmax - emin < 0) { - ClearReason(); - AddStartMaxReason(t, smax); - AddSizeMaxReason(t, dmax); - AddEndMinReason(t, emin); - return PushTaskAbsence(t); - } - - // Sometimes, for optional interval with non-optional bounds, this propagation - // give tighter bounds. We always consider the value assuming - // the interval is present. - // - // Note that this is also useful in case not everything was propagated. Note - // also that since there is no conflict, we reach the fix point in one pass. - smin = std::max(smin, emin - dmax); - smax = std::min(smax, emax - dmin); - dmin = std::max(dmin, emin - smax); - emin = std::max(emin, smin + dmin); - emax = std::min(emax, smax + dmax); - - if (emin != cached_end_min_[t]) { - recompute_energy_profile_ = true; - } - - // We might only want to do that if the value changed, but I am not sure it - // is worth the test. - recompute_by_start_max_ = true; - recompute_by_end_min_ = true; - - cached_start_min_[t] = smin; - cached_end_min_[t] = emin; - cached_negated_start_max_[t] = -smax; - cached_negated_end_max_[t] = -emax; - cached_size_min_[t] = dmin; - - // Note that we use the cached value here for EndMin()/StartMax(). - const IntegerValue new_shifted_start_min = emin - dmin; - if (new_shifted_start_min != cached_shifted_start_min_[t]) { - recompute_energy_profile_ = true; - recompute_shifted_start_min_ = true; - cached_shifted_start_min_[t] = new_shifted_start_min; - } - const IntegerValue new_negated_shifted_end_max = -(smax + dmin); - if (new_negated_shifted_end_max != cached_negated_shifted_end_max_[t]) { - recompute_negated_shifted_end_max_ = true; - cached_negated_shifted_end_max_[t] = new_negated_shifted_end_max; - } - return true; -} - -bool SchedulingConstraintHelper::ResetFromSubset( - const SchedulingConstraintHelper& other, absl::Span tasks) { - current_time_direction_ = other.current_time_direction_; - - const int num_tasks = tasks.size(); - interval_variables_.resize(num_tasks); - starts_.resize(num_tasks); - ends_.resize(num_tasks); - minus_ends_.resize(num_tasks); - minus_starts_.resize(num_tasks); - sizes_.resize(num_tasks); - reason_for_presence_.resize(num_tasks); - for (int i = 0; i < num_tasks; ++i) { - const int t = tasks[i]; - interval_variables_[i] = other.interval_variables_[t]; - starts_[i] = other.starts_[t]; - ends_[i] = other.ends_[t]; - minus_ends_[i] = other.minus_ends_[t]; - minus_starts_[i] = other.minus_starts_[t]; - sizes_[i] = other.sizes_[t]; - reason_for_presence_[i] = other.reason_for_presence_[t]; - } - - InitSortedVectors(); - return SynchronizeAndSetTimeDirection(true); -} - -void SchedulingConstraintHelper::InitSortedVectors() { - const int num_tasks = starts_.size(); - - recompute_all_cache_ = true; - recompute_cache_.Resize(num_tasks); - for (int t = 0; t < num_tasks; ++t) { - recompute_cache_.Set(t); - } - - // Make sure all the cached_* arrays can hold enough data. - CHECK_LE(num_tasks, capacity_); - - task_by_increasing_start_min_.resize(num_tasks); - task_by_increasing_end_min_.resize(num_tasks); - task_by_increasing_negated_start_max_.resize(num_tasks); - task_by_decreasing_end_max_.resize(num_tasks); - task_by_increasing_shifted_start_min_.resize(num_tasks); - task_by_negated_shifted_end_max_.resize(num_tasks); - for (int t = 0; t < num_tasks; ++t) { - task_by_increasing_start_min_[t].task_index = t; - task_by_increasing_end_min_[t].task_index = t; - task_by_increasing_negated_start_max_[t].task_index = t; - task_by_decreasing_end_max_[t].task_index = t; - - task_by_increasing_shifted_start_min_[t].task_index = t; - task_by_increasing_shifted_start_min_[t].presence_lit = - reason_for_presence_[t]; - task_by_negated_shifted_end_max_[t].task_index = t; - task_by_negated_shifted_end_max_[t].presence_lit = reason_for_presence_[t]; - } - - recompute_by_start_max_ = true; - recompute_by_end_min_ = true; - recompute_energy_profile_ = true; - recompute_shifted_start_min_ = true; - recompute_negated_shifted_end_max_ = true; -} - -void SchedulingConstraintHelper::SetTimeDirection(bool is_forward) { - if (current_time_direction_ != is_forward) { - current_time_direction_ = is_forward; - - std::swap(starts_, minus_ends_); - std::swap(ends_, minus_starts_); - - std::swap(task_by_increasing_start_min_, task_by_decreasing_end_max_); - std::swap(task_by_increasing_end_min_, - task_by_increasing_negated_start_max_); - std::swap(recompute_by_end_min_, recompute_by_start_max_); - std::swap(task_by_increasing_shifted_start_min_, - task_by_negated_shifted_end_max_); - - recompute_energy_profile_ = true; - std::swap(cached_start_min_, cached_negated_end_max_); - std::swap(cached_end_min_, cached_negated_start_max_); - std::swap(cached_shifted_start_min_, cached_negated_shifted_end_max_); - std::swap(recompute_shifted_start_min_, recompute_negated_shifted_end_max_); - } -} - -bool SchedulingConstraintHelper::SynchronizeAndSetTimeDirection( - bool is_forward) { - SetTimeDirection(is_forward); - - // If there was any backtracks since the last time this was called, we - // recompute our cache. - if (sat_solver_->num_backtracks() != saved_num_backtracks_) { - recompute_all_cache_ = true; - saved_num_backtracks_ = sat_solver_->num_backtracks(); - } - - if (recompute_all_cache_) { - for (int t = 0; t < recompute_cache_.size(); ++t) { - if (!UpdateCachedValues(t)) return false; - } - } else { - for (const int t : recompute_cache_) { - if (!UpdateCachedValues(t)) return false; - } - } - recompute_cache_.ClearAll(); - recompute_all_cache_ = false; - return true; -} - -// TODO(user): be more precise when we know a and b are in disjunction. -// we really just need start_b > start_a, or even >= if duration is non-zero. -IntegerValue SchedulingConstraintHelper::GetCurrentMinDistanceBetweenTasks( - int a, int b, bool add_reason_if_after) { - const AffineExpression before = ends_[a]; - const AffineExpression after = starts_[b]; - if (before.var == kNoIntegerVariable || before.coeff != 1 || - after.var == kNoIntegerVariable || after.coeff != 1) { - return kMinIntegerValue; - } - - // We take the max of the level zero offset and the one coming from a - // conditional precedence at true. - const IntegerValue conditional_offset = - precedence_relations_->GetConditionalOffset(before.var, after.var); - const IntegerValue known = integer_trail_->LevelZeroLowerBound(after.var) - - integer_trail_->LevelZeroUpperBound(before.var); - const IntegerValue offset = std::max(conditional_offset, known); - - const IntegerValue needed_offset = before.constant - after.constant; - const IntegerValue distance = offset - needed_offset; - if (add_reason_if_after && distance >= 0 && known < conditional_offset) { - for (const Literal l : precedence_relations_->GetConditionalEnforcements( - before.var, after.var)) { - literal_reason_.push_back(l.Negated()); - } - } - return distance; -} - -// Note that we could call this at a positive level to propagate any literal -// associated to task a before task b. However we only call this for task that -// are in detectable precedence, which means the normal precedence or linear -// propagator should have already propagated that Boolean too. -bool SchedulingConstraintHelper::PropagatePrecedence(int a, int b) { - CHECK(IsPresent(a)); - CHECK(IsPresent(b)); - CHECK_EQ(trail_->CurrentDecisionLevel(), 0); - - const AffineExpression before = ends_[a]; - const AffineExpression after = starts_[b]; - if (after.coeff != 1) return true; - if (before.coeff != 1) return true; - if (after.var == kNoIntegerVariable) return true; - if (before.var == kNoIntegerVariable) return true; - const IntegerValue offset = before.constant - after.constant; - if (precedence_relations_->Add(before.var, after.var, offset)) { - VLOG(2) << "new relation " << TaskDebugString(a) - << " <= " << TaskDebugString(b); - - // TODO(user): Adding new constraint during propagation might not be the - // best idea as it can create some complication. - AddWeightedSumLowerOrEqual({}, {before.var, after.var}, - {int64_t{1}, int64_t{-1}}, -offset.value(), - model_); - if (model_->GetOrCreate()->ModelIsUnsat()) return false; - } - return true; -} - -absl::Span -SchedulingConstraintHelper::TaskByIncreasingStartMin() { - for (TaskTime& ref : task_by_increasing_start_min_) { - ref.time = StartMin(ref.task_index); - } - IncrementalSort(task_by_increasing_start_min_.begin(), - task_by_increasing_start_min_.end()); - return task_by_increasing_start_min_; -} - -absl::Span -SchedulingConstraintHelper::TaskByIncreasingEndMin() { - if (!recompute_by_end_min_) return task_by_increasing_end_min_; - for (TaskTime& ref : task_by_increasing_end_min_) { - ref.time = EndMin(ref.task_index); - } - IncrementalSort(task_by_increasing_end_min_.begin(), - task_by_increasing_end_min_.end()); - recompute_by_end_min_ = false; - return task_by_increasing_end_min_; -} - -absl::Span -SchedulingConstraintHelper::TaskByIncreasingNegatedStartMax() { - if (!recompute_by_start_max_) return task_by_increasing_negated_start_max_; - for (TaskTime& ref : task_by_increasing_negated_start_max_) { - ref.time = cached_negated_start_max_[ref.task_index]; - } - IncrementalSort(task_by_increasing_negated_start_max_.begin(), - task_by_increasing_negated_start_max_.end()); - recompute_by_start_max_ = false; - return task_by_increasing_negated_start_max_; -} - -absl::Span -SchedulingConstraintHelper::TaskByDecreasingEndMax() { - for (TaskTime& ref : task_by_decreasing_end_max_) { - ref.time = EndMax(ref.task_index); - } - IncrementalSort(task_by_decreasing_end_max_.begin(), - task_by_decreasing_end_max_.end(), std::greater()); - return task_by_decreasing_end_max_; -} - -absl::Span -SchedulingConstraintHelper::TaskByIncreasingShiftedStartMin() { - if (recompute_shifted_start_min_) { - recompute_shifted_start_min_ = false; - bool is_sorted = true; - IntegerValue previous = kMinIntegerValue; - for (CachedTaskBounds& ref : task_by_increasing_shifted_start_min_) { - ref.time = ShiftedStartMin(ref.task_index); - is_sorted = is_sorted && ref.time >= previous; - previous = ref.time; - } - if (is_sorted) return task_by_increasing_shifted_start_min_; - IncrementalSort(task_by_increasing_shifted_start_min_.begin(), - task_by_increasing_shifted_start_min_.end()); - } - return task_by_increasing_shifted_start_min_; -} - -// TODO(user): Avoid recomputing it if nothing changed. -const std::vector& -SchedulingConstraintHelper::GetEnergyProfile() { - if (energy_profile_.empty()) { - const int num_tasks = NumTasks(); - for (int t = 0; t < num_tasks; ++t) { - energy_profile_.push_back( - {cached_shifted_start_min_[t], t, /*is_first=*/true}); - energy_profile_.push_back({cached_end_min_[t], t, /*is_first=*/false}); - } - } else { - if (!recompute_energy_profile_) return energy_profile_; - for (ProfileEvent& ref : energy_profile_) { - const int t = ref.task; - if (ref.is_first) { - ref.time = cached_shifted_start_min_[t]; - } else { - ref.time = cached_end_min_[t]; - } - } - } - IncrementalSort(energy_profile_.begin(), energy_profile_.end()); - recompute_energy_profile_ = false; - return energy_profile_; -} - -// Produces a relaxed reason for StartMax(before) < EndMin(after). -void SchedulingConstraintHelper::AddReasonForBeingBefore(int before, - int after) { - AddOtherReason(before); - AddOtherReason(after); - - // The reason will be a linear expression greater than a value. Note that all - // coeff must be positive, and we will use the variable lower bound. - std::vector vars; - std::vector coeffs; - - // Reason for StartMax(before). - const IntegerValue smax_before = StartMax(before); - if (smax_before >= integer_trail_->UpperBound(starts_[before])) { - if (starts_[before].var != kNoIntegerVariable) { - vars.push_back(NegationOf(starts_[before].var)); - coeffs.push_back(starts_[before].coeff); - } - } else { - if (ends_[before].var != kNoIntegerVariable) { - vars.push_back(NegationOf(ends_[before].var)); - coeffs.push_back(ends_[before].coeff); - } - if (sizes_[before].var != kNoIntegerVariable) { - vars.push_back(sizes_[before].var); - coeffs.push_back(sizes_[before].coeff); - } - } - - // Reason for EndMin(after); - const IntegerValue emin_after = EndMin(after); - if (emin_after <= integer_trail_->LowerBound(ends_[after])) { - if (ends_[after].var != kNoIntegerVariable) { - vars.push_back(ends_[after].var); - coeffs.push_back(ends_[after].coeff); - } - } else { - if (starts_[after].var != kNoIntegerVariable) { - vars.push_back(starts_[after].var); - coeffs.push_back(starts_[after].coeff); - } - if (sizes_[after].var != kNoIntegerVariable) { - vars.push_back(sizes_[after].var); - coeffs.push_back(sizes_[after].coeff); - } - } - - DCHECK_LT(smax_before, emin_after); - const IntegerValue slack = emin_after - smax_before - 1; - integer_trail_->AppendRelaxedLinearReason(slack, coeffs, vars, - &integer_reason_); -} - -bool SchedulingConstraintHelper::PushIntegerLiteral(IntegerLiteral lit) { - CHECK(other_helper_ == nullptr); - return integer_trail_->Enqueue(lit, literal_reason_, integer_reason_); -} - -bool SchedulingConstraintHelper::PushIntegerLiteralIfTaskPresent( - int t, IntegerLiteral lit) { - if (IsAbsent(t)) return true; - AddOtherReason(t); - ImportOtherReasons(); - if (IsOptional(t)) { - return integer_trail_->ConditionalEnqueue( - PresenceLiteral(t), lit, &literal_reason_, &integer_reason_); - } - return integer_trail_->Enqueue(lit, literal_reason_, integer_reason_); -} - -// We also run directly the precedence propagator for this variable so that when -// we push an interval start for example, we have a chance to push its end. -bool SchedulingConstraintHelper::PushIntervalBound(int t, IntegerLiteral lit) { - if (!PushIntegerLiteralIfTaskPresent(t, lit)) return false; - if (IsAbsent(t)) return true; - if (!UpdateCachedValues(t)) return false; - recompute_cache_.Clear(t); - return true; -} - -bool SchedulingConstraintHelper::IncreaseStartMin(int t, IntegerValue value) { - if (starts_[t].var == kNoIntegerVariable) { - if (value > starts_[t].constant) return PushTaskAbsence(t); - return true; - } - return PushIntervalBound(t, starts_[t].GreaterOrEqual(value)); -} - -bool SchedulingConstraintHelper::IncreaseEndMin(int t, IntegerValue value) { - if (ends_[t].var == kNoIntegerVariable) { - if (value > ends_[t].constant) return PushTaskAbsence(t); - return true; - } - return PushIntervalBound(t, ends_[t].GreaterOrEqual(value)); -} - -bool SchedulingConstraintHelper::DecreaseEndMax(int t, IntegerValue value) { - if (ends_[t].var == kNoIntegerVariable) { - if (value < ends_[t].constant) return PushTaskAbsence(t); - return true; - } - return PushIntervalBound(t, ends_[t].LowerOrEqual(value)); -} - -bool SchedulingConstraintHelper::PushLiteral(Literal l) { - integer_trail_->EnqueueLiteral(l, literal_reason_, integer_reason_); - return true; -} - -bool SchedulingConstraintHelper::PushTaskAbsence(int t) { - if (IsAbsent(t)) return true; - if (!IsOptional(t)) return ReportConflict(); - - AddOtherReason(t); - - if (IsPresent(t)) { - literal_reason_.push_back(Literal(reason_for_presence_[t]).Negated()); - return ReportConflict(); - } - ImportOtherReasons(); - integer_trail_->EnqueueLiteral(Literal(reason_for_presence_[t]).Negated(), - literal_reason_, integer_reason_); - return true; -} - -bool SchedulingConstraintHelper::PushTaskPresence(int t) { - DCHECK_NE(reason_for_presence_[t], kNoLiteralIndex); - DCHECK(!IsPresent(t)); - - AddOtherReason(t); - - if (IsAbsent(t)) { - literal_reason_.push_back(Literal(reason_for_presence_[t])); - return ReportConflict(); - } - ImportOtherReasons(); - integer_trail_->EnqueueLiteral(Literal(reason_for_presence_[t]), - literal_reason_, integer_reason_); - return true; -} - -bool SchedulingConstraintHelper::ReportConflict() { - ImportOtherReasons(); - return integer_trail_->ReportConflict(literal_reason_, integer_reason_); -} - -void SchedulingConstraintHelper::WatchAllTasks(int id, bool watch_max_side) { - // In all cases, we watch presence literals since this class is not waked up - // when those changes. - const int num_tasks = starts_.size(); - for (int t = 0; t < num_tasks; ++t) { - if (!IsPresent(t) && !IsAbsent(t)) { - watcher_->WatchLiteral(Literal(reason_for_presence_[t]), id); - } - } - - // If everything is watched, it is slighlty more efficient to enqueue the - // propagator when the helper Propagate() is called. This result in less - // entries in our watched lists. - if (watch_max_side) { - propagator_ids_.push_back(id); - return; - } - - // We only watch "min" side. - for (int t = 0; t < num_tasks; ++t) { - watcher_->WatchLowerBound(starts_[t], id); - watcher_->WatchLowerBound(ends_[t], id); - watcher_->WatchLowerBound(sizes_[t], id); - } -} - -void SchedulingConstraintHelper::AddOtherReason(int t) { - if (other_helper_ == nullptr || already_added_to_other_reasons_[t]) return; - already_added_to_other_reasons_[t] = true; - const int mapped_t = map_to_other_helper_[t]; - other_helper_->AddStartMaxReason(mapped_t, event_for_other_helper_); - other_helper_->AddEndMinReason(mapped_t, event_for_other_helper_ + 1); -} - -void SchedulingConstraintHelper::ImportOtherReasons() { - if (other_helper_ != nullptr) ImportOtherReasons(*other_helper_); -} - -void SchedulingConstraintHelper::ImportOtherReasons( - const SchedulingConstraintHelper& other_helper) { - literal_reason_.insert(literal_reason_.end(), - other_helper.literal_reason_.begin(), - other_helper.literal_reason_.end()); - integer_reason_.insert(integer_reason_.end(), - other_helper.integer_reason_.begin(), - other_helper.integer_reason_.end()); -} - -std::string SchedulingConstraintHelper::TaskDebugString(int t) const { - return absl::StrCat("t=", t, " is_present=", - (IsPresent(t) ? "1" - : IsAbsent(t) ? "0" - : "?"), - " size=[", SizeMin(t).value(), ",", SizeMax(t).value(), - "]", " start=[", StartMin(t).value(), ",", - StartMax(t).value(), "]", " end=[", EndMin(t).value(), - ",", EndMax(t).value(), "]"); -} - -IntegerValue SchedulingConstraintHelper::GetMinOverlap(int t, - IntegerValue start, - IntegerValue end) const { - return std::min(std::min(end - start, SizeMin(t)), - std::min(EndMin(t) - start, end - StartMax(t))); -} - -IntegerValue ComputeEnergyMinInWindow( - IntegerValue start_min, IntegerValue start_max, IntegerValue end_min, - IntegerValue end_max, IntegerValue size_min, IntegerValue demand_min, - absl::Span filtered_energy, - IntegerValue window_start, IntegerValue window_end) { - if (window_end <= window_start) return IntegerValue(0); - - // Returns zero if the interval do not necessarily overlap. - if (end_min <= window_start) return IntegerValue(0); - if (start_max >= window_end) return IntegerValue(0); - const IntegerValue window_size = window_end - window_start; - const IntegerValue simple_energy_min = - demand_min * std::min({end_min - window_start, window_end - start_max, - size_min, window_size}); - if (filtered_energy.empty()) return simple_energy_min; - - IntegerValue result = kMaxIntegerValue; - for (const auto [lit, fixed_size, fixed_demand] : filtered_energy) { - const IntegerValue alt_end_min = std::max(end_min, start_min + fixed_size); - const IntegerValue alt_start_max = - std::min(start_max, end_max - fixed_size); - const IntegerValue energy_min = - fixed_demand * - std::min({alt_end_min - window_start, window_end - alt_start_max, - fixed_size, window_size}); - result = std::min(result, energy_min); - } - if (result == kMaxIntegerValue) return simple_energy_min; - return std::max(simple_energy_min, result); -} - -SchedulingDemandHelper::SchedulingDemandHelper( - absl::Span demands, - SchedulingConstraintHelper* helper, Model* model) - : integer_trail_(model->GetOrCreate()), - product_decomposer_(model->GetOrCreate()), - sat_solver_(model->GetOrCreate()), - assignment_(model->GetOrCreate()->Assignment()), - demands_(demands.begin(), demands.end()), - helper_(helper) { - const int num_tasks = helper->NumTasks(); - linearized_energies_.resize(num_tasks); - decomposed_energies_.resize(num_tasks); - cached_energies_min_.resize(num_tasks, kMinIntegerValue); - cached_energies_max_.resize(num_tasks, kMaxIntegerValue); - energy_is_quadratic_.resize(num_tasks, false); - - // We try to init decomposed energies. This is needed for the cuts that are - // created after we call InitAllDecomposedEnergies(). - InitDecomposedEnergies(); -} - -void SchedulingDemandHelper::InitDecomposedEnergies() { - // For the special case were demands is empty. - const int num_tasks = helper_->NumTasks(); - if (demands_.size() != num_tasks) return; - for (int t = 0; t < num_tasks; ++t) { - const AffineExpression size = helper_->Sizes()[t]; - const AffineExpression demand = demands_[t]; - decomposed_energies_[t] = product_decomposer_->TryToDecompose(size, demand); - } -} - -IntegerValue SchedulingDemandHelper::SimpleEnergyMin(int t) const { - if (demands_.empty()) return kMinIntegerValue; - return CapProdI(DemandMin(t), helper_->SizeMin(t)); -} - -IntegerValue SchedulingDemandHelper::LinearEnergyMin(int t) const { - if (!linearized_energies_[t].has_value()) return kMinIntegerValue; - return linearized_energies_[t]->Min(*integer_trail_); -} - -IntegerValue SchedulingDemandHelper::DecomposedEnergyMin(int t) const { - if (decomposed_energies_[t].empty()) return kMinIntegerValue; - IntegerValue result = kMaxIntegerValue; - for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { - if (assignment_.LiteralIsTrue(lit)) { - return fixed_size * fixed_demand; - } - if (assignment_.LiteralIsFalse(lit)) continue; - result = std::min(result, fixed_size * fixed_demand); - } - DCHECK_NE(result, kMaxIntegerValue); - return result; -} - -IntegerValue SchedulingDemandHelper::SimpleEnergyMax(int t) const { - if (demands_.empty()) return kMaxIntegerValue; - return CapProdI(DemandMax(t), helper_->SizeMax(t)); -} - -IntegerValue SchedulingDemandHelper::LinearEnergyMax(int t) const { - if (!linearized_energies_[t].has_value()) return kMaxIntegerValue; - return linearized_energies_[t]->Max(*integer_trail_); -} - -IntegerValue SchedulingDemandHelper::DecomposedEnergyMax(int t) const { - if (decomposed_energies_[t].empty()) return kMaxIntegerValue; - IntegerValue result = kMinIntegerValue; - for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { - if (assignment_.LiteralIsTrue(lit)) { - return fixed_size * fixed_demand; - } - if (assignment_.LiteralIsFalse(lit)) continue; - result = std::max(result, fixed_size * fixed_demand); - } - DCHECK_NE(result, kMinIntegerValue); - return result; -} - -bool SchedulingDemandHelper::CacheAllEnergyValues() { - const int num_tasks = cached_energies_min_.size(); - const bool is_at_level_zero = sat_solver_->CurrentDecisionLevel() == 0; - for (int t = 0; t < num_tasks; ++t) { - // Try to reduce the size of the decomposed energy vector. - if (is_at_level_zero) { - int new_size = 0; - for (int i = 0; i < decomposed_energies_[t].size(); ++i) { - if (assignment_.LiteralIsFalse(decomposed_energies_[t][i].literal)) { - continue; - } - decomposed_energies_[t][new_size++] = decomposed_energies_[t][i]; - } - decomposed_energies_[t].resize(new_size); - } - - cached_energies_min_[t] = std::max( - {SimpleEnergyMin(t), LinearEnergyMin(t), DecomposedEnergyMin(t)}); - if (cached_energies_min_[t] <= kMinIntegerValue) return false; - energy_is_quadratic_[t] = - decomposed_energies_[t].empty() && !demands_.empty() && - !integer_trail_->IsFixed(demands_[t]) && !helper_->SizeIsFixed(t); - cached_energies_max_[t] = std::min( - {SimpleEnergyMax(t), LinearEnergyMax(t), DecomposedEnergyMax(t)}); - if (cached_energies_max_[t] >= kMaxIntegerValue) return false; - } - - return true; -} - -IntegerValue SchedulingDemandHelper::DemandMin(int t) const { - DCHECK_LT(t, demands_.size()); - return integer_trail_->LowerBound(demands_[t]); -} - -IntegerValue SchedulingDemandHelper::DemandMax(int t) const { - DCHECK_LT(t, demands_.size()); - return integer_trail_->UpperBound(demands_[t]); -} - -bool SchedulingDemandHelper::DemandIsFixed(int t) const { - return integer_trail_->IsFixed(demands_[t]); -} - -bool SchedulingDemandHelper::DecreaseEnergyMax(int t, IntegerValue value) { - if (value < EnergyMin(t)) { - if (helper_->IsOptional(t)) { - return helper_->PushTaskAbsence(t); - } else { - return helper_->ReportConflict(); - } - } else if (!decomposed_energies_[t].empty()) { - for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { - if (fixed_size * fixed_demand > value) { - if (assignment_.LiteralIsTrue(lit)) return helper_->ReportConflict(); - if (assignment_.LiteralIsFalse(lit)) continue; - if (!helper_->PushLiteral(lit.Negated())) return false; - } - } - } else if (linearized_energies_[t].has_value() && - linearized_energies_[t]->vars.size() == 1) { - const LinearExpression& e = linearized_energies_[t].value(); - const AffineExpression affine_energy(e.vars[0], e.coeffs[0], e.offset); - const IntegerLiteral deduction = affine_energy.LowerOrEqual(value); - if (!helper_->PushIntegerLiteralIfTaskPresent(t, deduction)) { - return false; - } - } else { - // TODO(user): Propagate if possible. - VLOG(3) << "Cumulative energy missed propagation"; - } - return true; -} - -void SchedulingDemandHelper::AddDemandMinReason(int t) { - DCHECK_LT(t, demands_.size()); - if (demands_[t].var != kNoIntegerVariable) { - helper_->MutableIntegerReason()->push_back( - integer_trail_->LowerBoundAsLiteral(demands_[t].var)); - } -} - -void SchedulingDemandHelper::AddDemandMinReason(int t, - IntegerValue min_demand) { - DCHECK_LT(t, demands_.size()); - if (demands_[t].var != kNoIntegerVariable) { - helper_->MutableIntegerReason()->push_back( - demands_[t].GreaterOrEqual(min_demand)); - } -} - -void SchedulingDemandHelper::AddEnergyMinReason(int t) { - // We prefer these reason in order. - const IntegerValue value = cached_energies_min_[t]; - if (DecomposedEnergyMin(t) >= value) { - auto* reason = helper_->MutableLiteralReason(); - const int old_size = reason->size(); - for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { - if (assignment_.LiteralIsTrue(lit)) { - reason->resize(old_size); - reason->push_back(lit.Negated()); - return; - } else if (fixed_size * fixed_demand < value && - assignment_.LiteralIsFalse(lit)) { - reason->push_back(lit); - } - } - } else if (SimpleEnergyMin(t) >= value) { - AddDemandMinReason(t); - helper_->AddSizeMinReason(t); - } else { - DCHECK_GE(LinearEnergyMin(t), value); - for (const IntegerVariable var : linearized_energies_[t]->vars) { - helper_->MutableIntegerReason()->push_back( - integer_trail_->LowerBoundAsLiteral(var)); - } - } -} - -bool SchedulingDemandHelper::AddLinearizedDemand( - int t, LinearConstraintBuilder* builder) const { - if (helper_->IsPresent(t)) { - if (!decomposed_energies_[t].empty()) { - for (const LiteralValueValue& entry : decomposed_energies_[t]) { - if (!builder->AddLiteralTerm(entry.literal, entry.right_value)) { - return false; - } - } - } else { - builder->AddTerm(demands_[t], IntegerValue(1)); - } - } else if (!helper_->IsAbsent(t)) { - return builder->AddLiteralTerm(helper_->PresenceLiteral(t), DemandMin(t)); - } - return true; -} - -void SchedulingDemandHelper::OverrideLinearizedEnergies( - absl::Span energies) { - const int num_tasks = energies.size(); - DCHECK_EQ(num_tasks, helper_->NumTasks()); - linearized_energies_.resize(num_tasks); - for (int t = 0; t < num_tasks; ++t) { - linearized_energies_[t] = energies[t]; - if (DEBUG_MODE) { - for (const IntegerValue coeff : linearized_energies_[t]->coeffs) { - DCHECK_GE(coeff, 0); - } - } - } -} - -std::vector SchedulingDemandHelper::FilteredDecomposedEnergy( - int index) { - if (decomposed_energies_[index].empty()) return {}; - if (sat_solver_->CurrentDecisionLevel() == 0) { - // CacheAllEnergyValues has already filtered false literals. - return decomposed_energies_[index]; - } - - // Scan and filter false literals. - std::vector result; - for (const auto& e : decomposed_energies_[index]) { - if (assignment_.LiteralIsFalse(e.literal)) continue; - result.push_back(e); - } - return result; -} - -void SchedulingDemandHelper::OverrideDecomposedEnergies( - const std::vector>& energies) { - DCHECK_EQ(energies.size(), helper_->NumTasks()); - decomposed_energies_ = energies; -} - -IntegerValue SchedulingDemandHelper::EnergyMinInWindow( - int t, IntegerValue window_start, IntegerValue window_end) { - return ComputeEnergyMinInWindow( - helper_->StartMin(t), helper_->StartMax(t), helper_->EndMin(t), - helper_->EndMax(t), helper_->SizeMin(t), DemandMin(t), - FilteredDecomposedEnergy(t), window_start, window_end); -} - -// Since we usually ask way less often for the reason, we redo the computation -// here. -void SchedulingDemandHelper::AddEnergyMinInWindowReason( - int t, IntegerValue window_start, IntegerValue window_end) { - const IntegerValue actual_energy_min = - EnergyMinInWindow(t, window_start, window_end); - if (actual_energy_min == 0) return; - - // Return simple reason right away if there is no decomposition or the simple - // energy is enough. - const IntegerValue start_max = helper_->StartMax(t); - const IntegerValue end_min = helper_->EndMin(t); - const IntegerValue min_overlap = - helper_->GetMinOverlap(t, window_start, window_end); - const IntegerValue simple_energy_min = DemandMin(t) * min_overlap; - if (simple_energy_min == actual_energy_min) { - AddDemandMinReason(t); - helper_->AddSizeMinReason(t); - helper_->AddStartMaxReason(t, start_max); - helper_->AddEndMinReason(t, end_min); - return; - } - - // TODO(user): only include the one we need? - const IntegerValue start_min = helper_->StartMin(t); - const IntegerValue end_max = helper_->EndMax(t); - DCHECK(!decomposed_energies_[t].empty()); - helper_->AddStartMinReason(t, start_min); - helper_->AddStartMaxReason(t, start_max); - helper_->AddEndMinReason(t, end_min); - helper_->AddEndMaxReason(t, end_max); - - auto* literal_reason = helper_->MutableLiteralReason(); - const int old_size = literal_reason->size(); - - DCHECK(!decomposed_energies_[t].empty()); - for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { - // Should be the same in most cases. - if (assignment_.LiteralIsTrue(lit)) { - literal_reason->resize(old_size); - literal_reason->push_back(lit.Negated()); - return; - } - if (assignment_.LiteralIsFalse(lit)) { - const IntegerValue alt_em = std::max(end_min, start_min + fixed_size); - const IntegerValue alt_sm = std::min(start_max, end_max - fixed_size); - const IntegerValue energy_min = - fixed_demand * - std::min({alt_em - window_start, window_end - alt_sm, fixed_size}); - if (energy_min >= actual_energy_min) continue; - literal_reason->push_back(lit); - } - } -} - -void AddIntegerVariableFromIntervals(SchedulingConstraintHelper* helper, - Model* model, - std::vector* vars) { - IntegerEncoder* encoder = model->GetOrCreate(); - for (int t = 0; t < helper->NumTasks(); ++t) { - if (helper->Starts()[t].var != kNoIntegerVariable) { - vars->push_back(helper->Starts()[t].var); - } - if (helper->Sizes()[t].var != kNoIntegerVariable) { - vars->push_back(helper->Sizes()[t].var); - } - if (helper->Ends()[t].var != kNoIntegerVariable) { - vars->push_back(helper->Ends()[t].var); - } - if (helper->IsOptional(t) && !helper->IsAbsent(t) && - !helper->IsPresent(t)) { - const Literal l = helper->PresenceLiteral(t); - IntegerVariable view = kNoIntegerVariable; - if (!encoder->LiteralOrNegationHasView(l, &view)) { - view = model->Add(NewIntegerVariableFromLiteral(l)); - } - vars->push_back(view); - } - } -} - -void AppendVariablesFromCapacityAndDemands( - const AffineExpression& capacity, SchedulingDemandHelper* demands_helper, - Model* model, std::vector* vars) { - auto* integer_trail = model->GetOrCreate(); - for (const AffineExpression& demand_expr : demands_helper->Demands()) { - if (!integer_trail->IsFixed(demand_expr)) { - vars->push_back(demand_expr.var); - } - } - IntegerEncoder* encoder = model->GetOrCreate(); - for (const auto& product : demands_helper->DecomposedEnergies()) { - for (const auto& lit_val_val : product) { - IntegerVariable view = kNoIntegerVariable; - if (!encoder->LiteralOrNegationHasView(lit_val_val.literal, &view)) { - view = model->Add(NewIntegerVariableFromLiteral(lit_val_val.literal)); - } - vars->push_back(view); - } - } - - if (!integer_trail->IsFixed(capacity)) { - vars->push_back(capacity.var); - } -} - } // namespace sat } // namespace operations_research diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index cd38ae8a6a..bc8b4862f0 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -16,38 +16,27 @@ #include #include -#include #include -#include #include #include -#include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" #include "ortools/sat/clause.h" -#include "ortools/sat/implied_bounds.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" -#include "ortools/sat/precedences.h" +#include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" -#include "ortools/util/bitset.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/util/strong_integers.h" namespace operations_research { namespace sat { -DEFINE_STRONG_INDEX_TYPE(IntervalVariable); -const IntervalVariable kNoIntervalVariable(-1); - -class SchedulingConstraintHelper; -class SchedulingDemandHelper; - // This class maintains a set of intervals which correspond to three integer // variables (start, end and size). It automatically registers with the // PrecedencesPropagator the relation between the bounds of each interval and @@ -138,6 +127,10 @@ class IntervalsRepository { const std::vector& variables, bool register_as_disjunctive_helper = false); + NoOverlap2DConstraintHelper* GetOrCreate2DHelper( + const std::vector& x_variables, + const std::vector& y_variables); + // Returns a SchedulingDemandHelper corresponding to the given helper and // demands. Note that the order of interval in the helper and the order of // demands must be the compatible. @@ -157,18 +150,18 @@ class IntervalsRepository { // If such literal already exists this returns it. void CreateDisjunctivePrecedenceLiteral(IntervalVariable a, IntervalVariable b); + LiteralIndex GetOrCreateDisjunctivePrecedenceLiteral( + const IntervalDefinition& a, const IntervalDefinition& b); - // Creates a literal l <=> start_b >= end_a. + // Creates a literal l <=> y >= x. // Returns true if such literal is "non-trivial" and was created. - // Note that this ignore the optionality of a or b, it just creates a literal - // comparing the two affine expression. - bool CreatePrecedenceLiteral(IntervalVariable a, IntervalVariable b); + bool CreatePrecedenceLiteral(AffineExpression x, AffineExpression y); - // Returns a literal l <=> start_b >= end_a if it exist or kNoLiteralIndex + // Returns a literal l <=> y >= x if it exist or kNoLiteralIndex // otherwise. This could be the one created by // CreateDisjunctivePrecedenceLiteral() or CreatePrecedenceLiteral(). - LiteralIndex GetPrecedenceLiteral(IntervalVariable a, - IntervalVariable b) const; + LiteralIndex GetPrecedenceLiteral(AffineExpression x, + AffineExpression y) const; const std::vector& AllDisjunctiveHelpers() const { @@ -211,6 +204,10 @@ class IntervalsRepository { absl::flat_hash_map, SchedulingConstraintHelper*> helper_repository_; + absl::flat_hash_map< + std::pair, std::vector>, + NoOverlap2DConstraintHelper*> + no_overlap_2d_helper_repository_; absl::flat_hash_map< std::pair>, SchedulingDemandHelper*> @@ -221,7 +218,8 @@ class IntervalsRepository { // Note that for normal precedences, we use directly the affine expression so // that if many intervals share the same start, we don't re-create Booleans // for no reason. - absl::flat_hash_map, Literal> + absl::flat_hash_map, + Literal> disjunctive_precedences_; absl::flat_hash_map, Literal> precedences_; @@ -231,713 +229,6 @@ class IntervalsRepository { std::vector cumulative_helpers_; }; -// An helper struct to sort task by time. This is used by the -// SchedulingConstraintHelper but also by many scheduling propagators to sort -// tasks. -struct TaskTime { - int task_index; - IntegerValue time; - bool operator<(TaskTime other) const { return time < other.time; } - bool operator>(TaskTime other) const { return time > other.time; } -}; - -// We have some free space in TaskTime. -// We stick the presence_lit to save an indirection in some algo. -// -// TODO(user): Experiment caching more value. In particular -// TaskByIncreasingShiftedStartMin() could tie break task for better heuristics? -struct CachedTaskBounds { - int task_index; - LiteralIndex presence_lit; - IntegerValue time; - bool operator<(CachedTaskBounds other) const { return time < other.time; } - bool operator>(CachedTaskBounds other) const { return time > other.time; } -}; - -// Helper class shared by the propagators that manage a given list of tasks. -// -// One of the main advantage of this class is that it allows to share the -// vectors of tasks sorted by various criteria between propagator for a faster -// code. -class SchedulingConstraintHelper : public PropagatorInterface { - public: - // All the functions below refer to a task by its index t in the tasks - // vector given at construction. - SchedulingConstraintHelper(const std::vector& tasks, - Model* model); - - // Temporary constructor. - // The class will not be usable until ResetFromSubset() is called. - // - // TODO(user): Remove this. It is a hack because the disjunctive class needs - // to fetch the maximum possible number of task at construction. - SchedulingConstraintHelper(int num_tasks, Model* model); - - // This is a propagator so we can "cache" all the intervals relevant - // information. This gives good speedup. Note however that the info is stale - // except if a bound was pushed by this helper or if this was called. We run - // it at the highest priority, so that will mostly be the case at the - // beginning of each Propagate() call of the classes using this. - bool Propagate() final; - bool IncrementalPropagate(const std::vector& watch_indices) final; - void RegisterWith(GenericLiteralWatcher* watcher); - - // Resets the class to the same state as if it was constructed with - // the given subset of tasks from other. - ABSL_MUST_USE_RESULT bool ResetFromSubset( - const SchedulingConstraintHelper& other, absl::Span tasks); - - // Returns the number of task. - int NumTasks() const { return starts_.size(); } - - // Make sure the cached values are up to date. Also sets the time direction to - // either forward/backward. This will impact all the functions below. This - // MUST be called at the beginning of all Propagate() call that uses this - // helper. - void SetTimeDirection(bool is_forward); - bool CurrentTimeIsForward() const { return current_time_direction_; } - ABSL_MUST_USE_RESULT bool SynchronizeAndSetTimeDirection(bool is_forward); - - // Helpers for the current bounds on the current task time window. - // [ (size-min) ... (size-min) ] - // ^ ^ ^ ^ - // start-min end-min start-max end-max - // - // Note that for tasks with variable durations, we don't necessarily have - // duration-min between the XXX-min and XXX-max value. - // - // Remark: We use cached values for most of these function as this is faster. - // In practice, the cache will almost always be up to date, but not in corner - // cases where pushing the start of one task will change values for many - // others. This is fine as the new values will be picked up as we reach the - // propagation fixed point. - IntegerValue SizeMin(int t) const { return cached_size_min_[t]; } - IntegerValue SizeMax(int t) const { - // This one is "rare" so we don't cache it. - return integer_trail_->UpperBound(sizes_[t]); - } - IntegerValue StartMin(int t) const { return cached_start_min_[t]; } - IntegerValue EndMin(int t) const { return cached_end_min_[t]; } - IntegerValue StartMax(int t) const { return -cached_negated_start_max_[t]; } - IntegerValue EndMax(int t) const { return -cached_negated_end_max_[t]; } - - IntegerValue LevelZeroStartMin(int t) const { - return integer_trail_->LevelZeroLowerBound(starts_[t]); - } - IntegerValue LevelZeroStartMax(int t) const { - return integer_trail_->LevelZeroUpperBound(starts_[t]); - } - IntegerValue LevelZeroEndMax(int t) const { - return integer_trail_->LevelZeroUpperBound(ends_[t]); - } - - // In the presence of tasks with a variable size, we do not necessarily - // have start_min + size_min = end_min, we can instead have a situation - // like: - // | |<--- size-min --->| - // ^ ^ ^ - // start-min | end-min - // | - // We define the "shifted start min" to be the right most time such that - // we known that we must have min-size "energy" to the right of it if the - // task is present. Using it in our scheduling propagators allows to propagate - // more in the presence of tasks with variable size (or optional task - // where we also do not necessarily have start_min + size_min = end_min. - // - // To explain this shifted start min, one must use the AddEnergyAfterReason(). - IntegerValue ShiftedStartMin(int t) const { - return cached_shifted_start_min_[t]; - } - - // As with ShiftedStartMin(), we can compute the shifted end max (that is - // start_max + size_min. - IntegerValue ShiftedEndMax(int t) const { - return -cached_negated_shifted_end_max_[t]; - } - - bool StartIsFixed(int t) const; - bool EndIsFixed(int t) const; - bool SizeIsFixed(int t) const; - - // Returns true if the corresponding fact is known for sure. A normal task is - // always present. For optional task for which the presence is still unknown, - // both of these function will return false. - bool IsOptional(int t) const; - bool IsPresent(int t) const; - bool IsAbsent(int t) const; - - // Same if one already have the presence LiteralIndex of a task. - bool IsOptional(LiteralIndex lit) const; - bool IsPresent(LiteralIndex lit) const; - bool IsAbsent(LiteralIndex lit) const; - - // Return a value so that End(a) + dist <= Start(b). - // Returns kMinInterValue if we don't have any such relation. - IntegerValue GetCurrentMinDistanceBetweenTasks( - int a, int b, bool add_reason_if_after = false); - - // We detected a precedence between two tasks. - // If we are at level zero, we might want to add the constraint. - // If we are at positive level, we might want to propagate the associated - // precedence literal if it exists. - bool PropagatePrecedence(int a, int b); - - // Return the minimum overlap of interval i with the time window [start..end]. - // - // Note: this is different from the mandatory part of an interval. - IntegerValue GetMinOverlap(int t, IntegerValue start, IntegerValue end) const; - - // Returns a string with the current task bounds. - std::string TaskDebugString(int t) const; - - // Sorts and returns the tasks in corresponding order at the time of the call. - // Note that we do not mean strictly-increasing/strictly-decreasing, there - // will be duplicate time values in these vectors. - // - // TODO(user): we could merge the first loop of IncrementalSort() with the - // loop that fill TaskTime.time at each call. - absl::Span TaskByIncreasingStartMin(); - absl::Span TaskByDecreasingEndMax(); - - absl::Span TaskByIncreasingNegatedStartMax(); - absl::Span TaskByIncreasingEndMin(); - - absl::Span TaskByIncreasingShiftedStartMin(); - - // Returns a sorted vector where each task appear twice, the first occurrence - // is at size (end_min - size_min) and the second one at (end_min). - // - // This is quite usage specific. - struct ProfileEvent { - IntegerValue time; - int task; - bool is_first; - - bool operator<(const ProfileEvent& other) const { - if (time == other.time) { - if (task == other.task) return is_first > other.is_first; - return task < other.task; - } - return time < other.time; - } - }; - const std::vector& GetEnergyProfile(); - - // Functions to clear and then set the current reason. - void ClearReason(); - void AddPresenceReason(int t); - void AddAbsenceReason(int t); - void AddSizeMinReason(int t); - void AddSizeMinReason(int t, IntegerValue lower_bound); - void AddSizeMaxReason(int t, IntegerValue upper_bound); - void AddStartMinReason(int t, IntegerValue lower_bound); - void AddStartMaxReason(int t, IntegerValue upper_bound); - void AddEndMinReason(int t, IntegerValue lower_bound); - void AddEndMaxReason(int t, IntegerValue upper_bound); - void AddShiftedEndMaxReason(int t, IntegerValue upper_bound); - - void AddEnergyAfterReason(int t, IntegerValue energy_min, IntegerValue time); - void AddEnergyMinInIntervalReason(int t, IntegerValue min, IntegerValue max); - - // Adds the reason why task "before" must be before task "after". - // That is StartMax(before) < EndMin(after). - void AddReasonForBeingBefore(int before, int after); - - // It is also possible to directly manipulates the underlying reason vectors - // that will be used when pushing something. - std::vector* MutableLiteralReason() { return &literal_reason_; } - std::vector* MutableIntegerReason() { - return &integer_reason_; - } - - // Push something using the current reason. Note that IncreaseStartMin() will - // also increase the end-min, and DecreaseEndMax() will also decrease the - // start-max. - // - // Important: IncreaseStartMin() and DecreaseEndMax() can be called on an - // optional interval whose presence is still unknown and push a bound - // conditioned on its presence. The functions will do the correct thing - // depending on whether or not the start_min/end_max are optional variables - // whose presence implies the interval presence. - ABSL_MUST_USE_RESULT bool IncreaseStartMin(int t, IntegerValue value); - ABSL_MUST_USE_RESULT bool IncreaseEndMin(int t, IntegerValue value); - ABSL_MUST_USE_RESULT bool DecreaseEndMax(int t, IntegerValue value); - ABSL_MUST_USE_RESULT bool PushLiteral(Literal l); - ABSL_MUST_USE_RESULT bool PushTaskAbsence(int t); - ABSL_MUST_USE_RESULT bool PushTaskPresence(int t); - ABSL_MUST_USE_RESULT bool PushIntegerLiteral(IntegerLiteral lit); - ABSL_MUST_USE_RESULT bool ReportConflict(); - ABSL_MUST_USE_RESULT bool PushIntegerLiteralIfTaskPresent(int t, - IntegerLiteral lit); - - // Returns the underlying affine expressions. - absl::Span IntervalVariables() const { - return interval_variables_; - } - absl::Span Starts() const { return starts_; } - absl::Span Ends() const { return ends_; } - absl::Span Sizes() const { return sizes_; } - - Literal PresenceLiteral(int index) const { - DCHECK(IsOptional(index)); - return Literal(reason_for_presence_[index]); - } - - // Registers the given propagator id to be called if any of the tasks - // in this class change. Note that we do not watch size max though. - void WatchAllTasks(int id, bool watch_max_side = true); - - // Manages the other helper (used by the diffn constraint). - // - // For each interval appearing in a reason on this helper, another reason - // will be added. This other reason specifies that on the other helper, the - // corresponding interval overlaps 'event'. - void SetOtherHelper(SchedulingConstraintHelper* other_helper, - absl::Span map_to_other_helper, - IntegerValue event) { - CHECK(other_helper != nullptr); - other_helper_ = other_helper; - map_to_other_helper_ = map_to_other_helper; - event_for_other_helper_ = event; - } - - bool HasOtherHelper() const { return other_helper_ != nullptr; } - - void ClearOtherHelper() { other_helper_ = nullptr; } - - // Adds to this helper reason all the explanation of the other helper. - // This checks that other_helper_ is null. - // - // This is used in the 2D energetic reasoning in the diffn constraint. - void ImportOtherReasons(const SchedulingConstraintHelper& other_helper); - - // TODO(user): Change the propagation loop code so that we don't stop - // pushing in the middle of the propagation as more advanced propagator do - // not handle this correctly. - bool InPropagationLoop() const { return integer_trail_->InPropagationLoop(); } - - int CurrentDecisionLevel() const { return trail_->CurrentDecisionLevel(); } - - private: - // Tricky: when a task is optional, it is possible it size min is negative, - // but we know that if a task is present, its size should be >= 0. So in the - // reason, when we need the size_min and it is currently negative, we can just - // ignore it and use zero instead. - AffineExpression NegatedSizeOrZero(int t) { - if (integer_trail_->LowerBound(sizes_[t]) <= 0) { - return AffineExpression(0); - } - return sizes_[t].Negated(); - } - - // Generic reason for a <= upper_bound, given that a = b + c in case the - // current upper bound of a is not good enough. - void AddGenericReason(const AffineExpression& a, IntegerValue upper_bound, - const AffineExpression& b, const AffineExpression& c); - - void InitSortedVectors(); - ABSL_MUST_USE_RESULT bool UpdateCachedValues(int t); - - // Internal function for IncreaseStartMin()/DecreaseEndMax(). - bool PushIntervalBound(int t, IntegerLiteral lit); - - // This will be called on any interval that is part of a reason or - // a bound push. Since the last call to ClearReason(), for each unique - // t, we will add once to other_helper_ the reason for t containing - // the point event_for_other_helper_. - void AddOtherReason(int t); - - // Import the reasons on the other helper into this helper. - void ImportOtherReasons(); - - Model* model_; - Trail* trail_; - SatSolver* sat_solver_; - IntegerTrail* integer_trail_; - GenericLiteralWatcher* watcher_; - PrecedenceRelations* precedence_relations_; - - // The current direction of time, true for forward, false for backward. - bool current_time_direction_ = true; - - // All the underlying variables of the tasks. - // The vectors are indexed by the task index t. - std::vector interval_variables_; - std::vector starts_; - std::vector ends_; - std::vector sizes_; - std::vector reason_for_presence_; - - // The negation of the start/end variable so that SetTimeDirection() - // can do its job in O(1) instead of calling NegationOf() on each entry. - std::vector minus_starts_; - std::vector minus_ends_; - - // This is used to detect when we need to invalidate the cache. - int64_t saved_num_backtracks_ = 0; - - // The caches of all relevant interval values. - // These are initially of size capacity and never resized. - // - // TODO(user): Because of std::swap() in SetTimeDirection, we cannot mark - // most of them as "const" and as a result we loose some performance since - // the address need to be re-fetched on most access. - const int capacity_; - const std::unique_ptr cached_size_min_; - std::unique_ptr cached_start_min_; - std::unique_ptr cached_end_min_; - std::unique_ptr cached_negated_start_max_; - std::unique_ptr cached_negated_end_max_; - std::unique_ptr cached_shifted_start_min_; - std::unique_ptr cached_negated_shifted_end_max_; - - // Sorted vectors returned by the TasksBy*() functions. - std::vector task_by_increasing_start_min_; - std::vector task_by_decreasing_end_max_; - - bool recompute_by_start_max_ = true; - bool recompute_by_end_min_ = true; - std::vector task_by_increasing_negated_start_max_; - std::vector task_by_increasing_end_min_; - - // Sorted vector returned by GetEnergyProfile(). - bool recompute_energy_profile_ = true; - std::vector energy_profile_; - - // This one is the most commonly used, so we optimized a bit more its - // computation by detecting when there is nothing to do. - std::vector task_by_increasing_shifted_start_min_; - std::vector task_by_negated_shifted_end_max_; - bool recompute_shifted_start_min_ = true; - bool recompute_negated_shifted_end_max_ = true; - - // If recompute_cache_[t] is true, then we need to update all the cached - // value for the task t in SynchronizeAndSetTimeDirection(). - bool recompute_all_cache_ = true; - Bitset64 recompute_cache_; - - // Reason vectors. - std::vector literal_reason_; - std::vector integer_reason_; - - // Optional 'proxy' helper used in the diffn constraint. - SchedulingConstraintHelper* other_helper_ = nullptr; - absl::Span map_to_other_helper_; - IntegerValue event_for_other_helper_; - std::vector already_added_to_other_reasons_; - - // List of watcher to "wake-up" each time one of the task bounds changes. - std::vector propagator_ids_; -}; - -// Helper class for cumulative constraint to wrap demands and expose concept -// like energy. -// -// In a cumulative constraint, an interval always has a size and a demand, but -// it can also have a set of "selector" literals each associated with a fixed -// size / fixed demands. This allows more precise energy estimation. -// -// TODO(user): Cache energy min and reason for the non O(1) cases. -class SchedulingDemandHelper { - public: - // Hack: this can be called with and empty demand vector as long as - // OverrideEnergies() is called to define the energies. - SchedulingDemandHelper(absl::Span demands, - SchedulingConstraintHelper* helper, Model* model); - - // When defined, the interval will consume this much demand during its whole - // duration. Some propagator only relies on the "energy" and thus never uses - // this. - IntegerValue DemandMin(int t) const; - IntegerValue DemandMax(int t) const; - IntegerValue LevelZeroDemandMin(int t) const { - return integer_trail_->LevelZeroLowerBound(demands_[t]); - } - bool DemandIsFixed(int t) const; - void AddDemandMinReason(int t); - void AddDemandMinReason(int t, IntegerValue min_demand); - const std::vector& Demands() const { return demands_; } - - // Adds the linearized demand (either the affine demand expression, or the - // demand part of the decomposed energy if present) to the builder. - // It returns false and do not add any term to the builder.if any literal - // involved has no integer view. - ABSL_MUST_USE_RESULT bool AddLinearizedDemand( - int t, LinearConstraintBuilder* builder) const; - - // The "energy" is usually size * demand, but in some non-conventional usage - // it might have a more complex formula. In all case, the energy is assumed - // to be only consumed during the interval duration. - // - // Returns false if the energy can overflow and was not computed. - // - // IMPORTANT: One must call CacheAllEnergyValues() for the values to be - // updated. TODO(user): this is error prone, maybe we should revisit. But if - // there is many alternatives, we don't want to rescan the list more than a - // linear number of time per propagation. - // - // TODO(user): Add more complex EnergyMinBefore(time) once we also support - // expressing the interval as a set of alternatives. - // - // At level 0, it will filter false literals from decomposed energies. - bool CacheAllEnergyValues(); - IntegerValue EnergyMin(int t) const { return cached_energies_min_[t]; } - IntegerValue EnergyMax(int t) const { return cached_energies_max_[t]; } - bool EnergyIsQuadratic(int t) const { return energy_is_quadratic_[t]; } - void AddEnergyMinReason(int t); - - // Returns the energy min in [start, end]. - // - // Note(user): These functions are not in O(1) if the decomposition is used, - // so we have to be careful in not calling them too often. - IntegerValue EnergyMinInWindow(int t, IntegerValue window_start, - IntegerValue window_end); - void AddEnergyMinInWindowReason(int t, IntegerValue window_start, - IntegerValue window_end); - - // Important: This might not do anything depending on the representation of - // the energy we have. - ABSL_MUST_USE_RESULT bool DecreaseEnergyMax(int t, IntegerValue value); - - // Different optional representation of the energy of an interval. - // - // Important: first value is size, second value is demand. - const std::vector>& DecomposedEnergies() - const { - return decomposed_energies_; - } - - // Visible for testing. - void OverrideLinearizedEnergies(absl::Span energies); - void OverrideDecomposedEnergies( - const std::vector>& energies); - // Returns the decomposed energy terms compatible with the current literal - // assignment. It must not be used to create reasons if not at level 0. - // It returns en empty vector if the decomposed energy is not available. - // - // Important: first value is size, second value is demand. - std::vector FilteredDecomposedEnergy(int index); - - // Init all decomposed energies. It needs probing to be finished. This happens - // after the creation of the helper. - void InitDecomposedEnergies(); - - private: - IntegerValue SimpleEnergyMin(int t) const; - IntegerValue LinearEnergyMin(int t) const; - IntegerValue SimpleEnergyMax(int t) const; - IntegerValue LinearEnergyMax(int t) const; - IntegerValue DecomposedEnergyMin(int t) const; - IntegerValue DecomposedEnergyMax(int t) const; - - IntegerTrail* integer_trail_; - ProductDecomposer* product_decomposer_; - SatSolver* sat_solver_; // To get the current propagation level. - const VariablesAssignment& assignment_; - std::vector demands_; - SchedulingConstraintHelper* helper_; - - // Cached value of the energies, as it can be a bit costly to compute. - std::vector cached_energies_min_; - std::vector cached_energies_max_; - std::vector energy_is_quadratic_; - - // A representation of the energies as a set of alternative. - // If subvector is empty, we don't have this representation. - std::vector> decomposed_energies_; - - // A representation of the energies as a set of linear expression. - // If the optional is not set, we don't have this representation. - std::vector> linearized_energies_; -}; - -// ============================================================================= -// Utilities -// ============================================================================= - -IntegerValue ComputeEnergyMinInWindow( - IntegerValue start_min, IntegerValue start_max, IntegerValue end_min, - IntegerValue end_max, IntegerValue size_min, IntegerValue demand_min, - absl::Span filtered_energy, - IntegerValue window_start, IntegerValue window_end); - -// ============================================================================= -// SchedulingConstraintHelper inlined functions. -// ============================================================================= - -inline bool SchedulingConstraintHelper::StartIsFixed(int t) const { - return integer_trail_->IsFixed(starts_[t]); -} - -inline bool SchedulingConstraintHelper::EndIsFixed(int t) const { - return integer_trail_->IsFixed(ends_[t]); -} - -inline bool SchedulingConstraintHelper::SizeIsFixed(int t) const { - return integer_trail_->IsFixed(sizes_[t]); -} - -inline bool SchedulingConstraintHelper::IsOptional(int t) const { - return reason_for_presence_[t] != kNoLiteralIndex; -} - -inline bool SchedulingConstraintHelper::IsPresent(int t) const { - if (reason_for_presence_[t] == kNoLiteralIndex) return true; - return trail_->Assignment().LiteralIsTrue(Literal(reason_for_presence_[t])); -} - -inline bool SchedulingConstraintHelper::IsAbsent(int t) const { - if (reason_for_presence_[t] == kNoLiteralIndex) return false; - return trail_->Assignment().LiteralIsFalse(Literal(reason_for_presence_[t])); -} - -inline bool SchedulingConstraintHelper::IsOptional(LiteralIndex lit) const { - return lit != kNoLiteralIndex; -} - -inline bool SchedulingConstraintHelper::IsPresent(LiteralIndex lit) const { - if (lit == kNoLiteralIndex) return true; - return trail_->Assignment().LiteralIsTrue(Literal(lit)); -} - -inline bool SchedulingConstraintHelper::IsAbsent(LiteralIndex lit) const { - if (lit == kNoLiteralIndex) return false; - return trail_->Assignment().LiteralIsFalse(Literal(lit)); -} - -inline void SchedulingConstraintHelper::ClearReason() { - integer_reason_.clear(); - literal_reason_.clear(); - if (other_helper_) { - other_helper_->ClearReason(); - already_added_to_other_reasons_.assign(NumTasks(), false); - } -} - -inline void SchedulingConstraintHelper::AddPresenceReason(int t) { - DCHECK(IsPresent(t)); - AddOtherReason(t); - if (reason_for_presence_[t] != kNoLiteralIndex) { - literal_reason_.push_back(Literal(reason_for_presence_[t]).Negated()); - } -} - -inline void SchedulingConstraintHelper::AddAbsenceReason(int t) { - DCHECK(IsAbsent(t)); - AddOtherReason(t); - if (reason_for_presence_[t] != kNoLiteralIndex) { - literal_reason_.push_back(Literal(reason_for_presence_[t])); - } -} - -inline void SchedulingConstraintHelper::AddSizeMinReason(int t) { - AddSizeMinReason(t, SizeMin(t)); -} - -inline void SchedulingConstraintHelper::AddGenericReason( - const AffineExpression& a, IntegerValue upper_bound, - const AffineExpression& b, const AffineExpression& c) { - if (integer_trail_->UpperBound(a) <= upper_bound) { - if (a.var != kNoIntegerVariable) { - integer_reason_.push_back(a.LowerOrEqual(upper_bound)); - } - return; - } - CHECK_NE(a.var, kNoIntegerVariable); - - // Here we assume that the upper_bound on a comes from the bound on b + c. - const IntegerValue slack = upper_bound - integer_trail_->UpperBound(b) - - integer_trail_->UpperBound(c); - CHECK_GE(slack, 0); - if (b.var == kNoIntegerVariable && c.var == kNoIntegerVariable) return; - if (b.var == kNoIntegerVariable) { - integer_reason_.push_back(c.LowerOrEqual(upper_bound - b.constant)); - } else if (c.var == kNoIntegerVariable) { - integer_reason_.push_back(b.LowerOrEqual(upper_bound - c.constant)); - } else { - integer_trail_->AppendRelaxedLinearReason( - slack, {b.coeff, c.coeff}, {NegationOf(b.var), NegationOf(c.var)}, - &integer_reason_); - } -} - -inline void SchedulingConstraintHelper::AddSizeMinReason( - int t, IntegerValue lower_bound) { - AddOtherReason(t); - DCHECK(!IsAbsent(t)); - if (lower_bound <= 0) return; - AddGenericReason(sizes_[t].Negated(), -lower_bound, minus_ends_[t], - starts_[t]); -} - -inline void SchedulingConstraintHelper::AddSizeMaxReason( - int t, IntegerValue upper_bound) { - AddOtherReason(t); - DCHECK(!IsAbsent(t)); - AddGenericReason(sizes_[t], upper_bound, ends_[t], minus_starts_[t]); -} - -inline void SchedulingConstraintHelper::AddStartMinReason( - int t, IntegerValue lower_bound) { - AddOtherReason(t); - DCHECK(!IsAbsent(t)); - AddGenericReason(minus_starts_[t], -lower_bound, minus_ends_[t], sizes_[t]); -} - -inline void SchedulingConstraintHelper::AddStartMaxReason( - int t, IntegerValue upper_bound) { - AddOtherReason(t); - DCHECK(!IsAbsent(t)); - AddGenericReason(starts_[t], upper_bound, ends_[t], NegatedSizeOrZero(t)); -} - -inline void SchedulingConstraintHelper::AddEndMinReason( - int t, IntegerValue lower_bound) { - AddOtherReason(t); - DCHECK(!IsAbsent(t)); - AddGenericReason(minus_ends_[t], -lower_bound, minus_starts_[t], - NegatedSizeOrZero(t)); -} - -inline void SchedulingConstraintHelper::AddEndMaxReason( - int t, IntegerValue upper_bound) { - AddOtherReason(t); - DCHECK(!IsAbsent(t)); - AddGenericReason(ends_[t], upper_bound, starts_[t], sizes_[t]); -} - -inline void SchedulingConstraintHelper::AddShiftedEndMaxReason( - int t, IntegerValue upper_bound) { - AddStartMaxReason(t, upper_bound - SizeMin(t)); -} - -inline void SchedulingConstraintHelper::AddEnergyAfterReason( - int t, IntegerValue energy_min, IntegerValue time) { - if (StartMin(t) >= time) { - AddStartMinReason(t, time); - } else { - AddEndMinReason(t, time + energy_min); - } - AddSizeMinReason(t, energy_min); -} - -inline void SchedulingConstraintHelper::AddEnergyMinInIntervalReason( - int t, IntegerValue time_min, IntegerValue time_max) { - const IntegerValue energy_min = SizeMin(t); - CHECK_LE(time_min + energy_min, time_max); - if (StartMin(t) >= time_min) { - AddStartMinReason(t, time_min); - } else { - AddEndMinReason(t, time_min + energy_min); - } - if (EndMax(t) <= time_max) { - AddEndMaxReason(t, time_max); - } else { - AddStartMaxReason(t, time_max - energy_min); - } - AddSizeMinReason(t, energy_min); -} - // ============================================================================= // Model based functions. // ============================================================================= diff --git a/ortools/sat/intervals_test.cc b/ortools/sat/intervals_test.cc index 5fc3e32df9..ac944efa79 100644 --- a/ortools/sat/intervals_test.cc +++ b/ortools/sat/intervals_test.cc @@ -15,15 +15,11 @@ #include -#include - #include "gtest/gtest.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" -#include "ortools/sat/sat_solver.h" namespace operations_research { namespace sat { @@ -46,230 +42,11 @@ TEST(IntervalsRepositoryTest, Precedences) { repo->CreateDisjunctivePrecedenceLiteral(a, b); repo->CreateDisjunctivePrecedenceLiteral(a, b); - EXPECT_NE(kNoLiteralIndex, repo->GetPrecedenceLiteral(a, b)); - EXPECT_EQ(Literal(repo->GetPrecedenceLiteral(a, b)), - Literal(repo->GetPrecedenceLiteral(b, a)).Negated()); -} - -TEST(SchedulingConstraintHelperTest, PushConstantBoundWithOptionalIntervals) { - Model model; - auto* repo = model.GetOrCreate(); - - const AffineExpression start(IntegerValue(0)); - const AffineExpression size(IntegerValue(10)); - const AffineExpression end(IntegerValue(10)); - - Literal presence2 = Literal(model.Add(NewBooleanVariable()), true); - IntervalVariable inter1 = - repo->CreateInterval(start, end, size, kNoLiteralIndex, false); - IntervalVariable inter2 = - repo->CreateInterval(start, end, size, presence2.Index(), false); - - SchedulingConstraintHelper helper({inter1, inter2}, &model); - - EXPECT_TRUE(helper.IncreaseStartMin(1, IntegerValue(20))); - EXPECT_FALSE(model.Get(Value(presence2))); -} - -TEST(SchedulingDemandHelperTest, EnergyInWindow) { - Model model; - - const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); - const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); - const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); - const IntervalVariable inter = - model.GetOrCreate()->CreateInterval( - start, end, size, kNoLiteralIndex, false); - - const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); - - SchedulingConstraintHelper helper({inter}, &model); - SchedulingDemandHelper demands_helper({demand}, &helper, &model); - demands_helper.CacheAllEnergyValues(); - EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); - - const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); - const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); - demands_helper.OverrideDecomposedEnergies( - {{{alt1, IntegerValue(2), IntegerValue(4)}, - {alt2, IntegerValue(4), IntegerValue(2)}}}); - demands_helper.CacheAllEnergyValues(); - EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(8)); - - EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 8, 2)); - EXPECT_EQ(8, demands_helper.EnergyMinInWindow(0, 0, 10)); - EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 2, 10)); - EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 0, 8)); - EXPECT_EQ(4, demands_helper.EnergyMinInWindow(0, 0, 9)); -} - -TEST(SchedulingDemandHelperTest, EnergyInWindowTakeIntoAccountWindowSize) { - Model model; - - const AffineExpression start(model.Add(NewIntegerVariable(0, 4))); - const AffineExpression size(model.Add(NewIntegerVariable(6, 10))); - const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); - const IntervalVariable inter = - model.GetOrCreate()->CreateInterval( - start, end, size, kNoLiteralIndex, false); - - const AffineExpression demand(model.Add(NewIntegerVariable(6, 10))); - - SchedulingConstraintHelper helper({inter}, &model); - SchedulingDemandHelper demands_helper({demand}, &helper, &model); - demands_helper.CacheAllEnergyValues(); - - const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); - const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); - demands_helper.OverrideDecomposedEnergies( - {{{alt1, IntegerValue(8), IntegerValue(6)}, - {alt2, IntegerValue(6), IntegerValue(8)}}}); - demands_helper.CacheAllEnergyValues(); - EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(48)); - - EXPECT_EQ(6, demands_helper.EnergyMinInWindow(0, 5, 6)); -} - -TEST(SchedulingDemandHelperTest, LinearizedDemandWithAffineExpression) { - Model model; - - const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); - const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); - const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); - const IntervalVariable inter = - model.GetOrCreate()->CreateInterval( - start, end, size, kNoLiteralIndex, false); - - const AffineExpression demand( - AffineExpression(model.Add(NewIntegerVariable(2, 10)), 2, 5)); - - SchedulingConstraintHelper helper({inter}, &model); - SchedulingDemandHelper demands_helper({demand}, &helper, &model); - demands_helper.CacheAllEnergyValues(); - - LinearConstraintBuilder builder(&model); - ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "2*X3 + 5"); -} - -TEST(SchedulingDemandHelperTest, LinearizedDemandWithDecomposedEnergy) { - Model model; - - const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); - const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); - const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); - const IntervalVariable inter = - model.GetOrCreate()->CreateInterval( - start, end, size, kNoLiteralIndex, false); - - const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); - - SchedulingConstraintHelper helper({inter}, &model); - SchedulingDemandHelper demands_helper({demand}, &helper, &model); - demands_helper.CacheAllEnergyValues(); - EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); - - const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); - const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); - model.GetOrCreate()->AssociateToIntegerEqualValue( - alt1, var1, IntegerValue(1)); - - const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); - const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); - model.GetOrCreate()->AssociateToIntegerEqualValue( - alt2, var2, IntegerValue(1)); - demands_helper.OverrideDecomposedEnergies( - {{{alt1, IntegerValue(2), IntegerValue(4)}, - {alt2, IntegerValue(4), IntegerValue(2)}}}); - demands_helper.CacheAllEnergyValues(); - LinearConstraintBuilder builder(&model); - ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "4*X4 2*X5"); -} - -TEST(SchedulingDemandHelperTest, FilteredDecomposedEnergy) { - Model model; - SatSolver* sat_solver = model.GetOrCreate(); - IntegerEncoder* encoder = model.GetOrCreate(); - - const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); - const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); - const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); - const IntervalVariable inter = - model.GetOrCreate()->CreateInterval( - start, end, size, kNoLiteralIndex, false); - - const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); - - SchedulingConstraintHelper helper({inter}, &model); - SchedulingDemandHelper demands_helper({demand}, &helper, &model); - demands_helper.CacheAllEnergyValues(); - EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); - - const std::vector no_energy; - EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), no_energy); - - const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); - const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); - encoder->AssociateToIntegerEqualValue(alt1, var1, IntegerValue(1)); - - const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); - const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); - encoder->AssociateToIntegerEqualValue(alt2, var2, IntegerValue(1)); - const std::vector energy = { - {alt1, IntegerValue(2), IntegerValue(4)}, - {alt2, IntegerValue(4), IntegerValue(2)}}; - demands_helper.OverrideDecomposedEnergies({energy}); - demands_helper.CacheAllEnergyValues(); - EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), energy); - - EXPECT_EQ(sat_solver->EnqueueDecisionAndBackjumpOnConflict(alt1.Negated()), - 0); - const std::vector filtered_energy = { - {alt2, IntegerValue(4), IntegerValue(2)}}; - EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), filtered_energy); - EXPECT_EQ(demands_helper.DecomposedEnergies()[0], energy); -} - -TEST(SchedulingDemandHelperTest, FilteredDecomposedEnergyWithFalseLiteral) { - Model model; - IntegerEncoder* encoder = model.GetOrCreate(); - - const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); - const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); - const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); - const IntervalVariable inter = - model.GetOrCreate()->CreateInterval( - start, end, size, kNoLiteralIndex, false); - - const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); - - SchedulingConstraintHelper helper({inter}, &model); - SchedulingDemandHelper demands_helper({demand}, &helper, &model); - demands_helper.CacheAllEnergyValues(); - EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); - - const std::vector no_energy; - EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), no_energy); - - const Literal alt1 = encoder->GetFalseLiteral(); - const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); - model.GetOrCreate()->AssociateToIntegerEqualValue( - alt1, var1, IntegerValue(1)); - - const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); - const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); - encoder->AssociateToIntegerEqualValue(alt2, var2, IntegerValue(1)); - const std::vector energy = { - {alt1, IntegerValue(2), IntegerValue(4)}, - {alt2, IntegerValue(4), IntegerValue(2)}}; - demands_helper.OverrideDecomposedEnergies({energy}); - demands_helper.CacheAllEnergyValues(); - const std::vector filtered_energy = { - {alt2, IntegerValue(4), IntegerValue(2)}}; - EXPECT_EQ(demands_helper.DecomposedEnergies()[0], filtered_energy); - EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), filtered_energy); - EXPECT_EQ(0, model.GetOrCreate()->CurrentDecisionLevel()); + EXPECT_NE(kNoLiteralIndex, + repo->GetPrecedenceLiteral(repo->End(a), repo->Start(b))); + EXPECT_EQ(Literal(repo->GetPrecedenceLiteral(repo->End(a), repo->Start(b))), + Literal(repo->GetPrecedenceLiteral(repo->End(b), repo->Start(a))) + .Negated()); } } // namespace diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index 9053d7cdcd..4c8a3a40a8 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -109,6 +109,16 @@ bool EnforcementPropagator::Propagate(Trail* /*trail*/) { } } rev_stack_size_ = static_cast(untrail_stack_.size()); + + // Compute the enforcement status of any constraint added at a positive level. + // This is only needed until we are back to level zero. + for (const EnforcementId id : ids_to_fix_until_next_root_level_) { + ChangeStatus(id, DebugStatus(id)); + } + if (trail_.CurrentDecisionLevel() == 0) { + ids_to_fix_until_next_root_level_.clear(); + } + return true; } @@ -225,6 +235,14 @@ EnforcementId EnforcementPropagator::Register( } } } + + // Tricky: if we added something at a positive level, and its status is + // not CANNOT_PROPAGATE, then we might need to fix it on backtrack. + if (trail_.CurrentDecisionLevel() > 0 && + statuses_[id] != EnforcementStatus::CANNOT_PROPAGATE) { + ids_to_fix_until_next_root_level_.push_back(id); + } + return id; } diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index fe2a960321..fd7d375e2d 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -140,6 +140,8 @@ class EnforcementPropagator : public SatPropagator { std::vector temp_literals_; std::vector temp_reason_; + + std::vector ids_to_fix_until_next_root_level_; }; // Helper class to decide on the constraint propagation order. diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 54500f3e57..7c21c7cf71 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -964,7 +964,12 @@ void AddCumulativeRelaxation(const AffineExpression& capacity, // // TODO(user): In some cases, we could have only one task that can be // first. - if (ProdOverflow(std::max(-min_of_starts, max_of_ends), + IntegerValue max_for_overflow_check = std::max(-min_of_starts, max_of_ends); + if (makespan.has_value()) { + max_for_overflow_check = std::max( + max_for_overflow_check, integer_trail->UpperBound(makespan.value())); + } + if (ProdOverflow(max_for_overflow_check, integer_trail->UpperBound(capacity))) { return; } diff --git a/ortools/sat/no_overlap_2d_helper.cc b/ortools/sat/no_overlap_2d_helper.cc new file mode 100644 index 0000000000..84edaab71b --- /dev/null +++ b/ortools/sat/no_overlap_2d_helper.cc @@ -0,0 +1,208 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/no_overlap_2d_helper.h" + +#include + +#include "absl/types/span.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/scheduling_helpers.h" + +namespace operations_research { +namespace sat { + +bool NoOverlap2DConstraintHelper::SynchronizeAndSetDirection( + bool x_is_forward_after_swap, bool y_is_forward_after_swap, + bool swap_x_and_y) { + if (axes_are_swapped_ != swap_x_and_y) { + std::swap(x_helper_, y_helper_); + axes_are_swapped_ = !axes_are_swapped_; + } + if (!x_helper_->SynchronizeAndSetTimeDirection(x_is_forward_after_swap)) + return false; + if (!y_helper_->SynchronizeAndSetTimeDirection(y_is_forward_after_swap)) + return false; + return true; +} + +RectangleInRange NoOverlap2DConstraintHelper::GetItemRangeForSizeMin( + int index) const { + return RectangleInRange{ + .box_index = index, + .bounding_area = {.x_min = x_helper_->StartMin(index), + .x_max = x_helper_->StartMax(index) + + x_helper_->SizeMin(index), + .y_min = y_helper_->StartMin(index), + .y_max = y_helper_->StartMax(index) + + y_helper_->SizeMin(index)}, + .x_size = x_helper_->SizeMin(index), + .y_size = y_helper_->SizeMin(index)}; +} + +ItemWithVariableSize NoOverlap2DConstraintHelper::GetItemWithVariableSize( + int index) const { + return ItemWithVariableSize{.index = index, + .x = {.start_min = x_helper_->StartMin(index), + .start_max = x_helper_->StartMax(index), + .end_min = x_helper_->EndMin(index), + .end_max = x_helper_->EndMax(index)}, + .y = {.start_min = y_helper_->StartMin(index), + .start_max = y_helper_->StartMax(index), + .end_min = y_helper_->EndMin(index), + .end_max = y_helper_->EndMax(index)}}; +} + +namespace { +void ClearAndAddMandatoryOverlapReason(int box1, int box2, + SchedulingConstraintHelper* y) { + y->ClearReason(); + y->AddPresenceReason(box1); + y->AddPresenceReason(box2); + y->AddReasonForBeingBefore(box1, box2); + y->AddReasonForBeingBefore(box2, box1); +} +} // namespace + +bool NoOverlap2DConstraintHelper::ReportConflictFromTwoBoxes(int box1, + int box2) { + ClearAndAddMandatoryOverlapReason(box1, box2, x_helper_); + ClearAndAddMandatoryOverlapReason(box1, box2, y_helper_); + x_helper_->ImportOtherReasons(*y_helper_); + return x_helper_->ReportConflict(); +} + +bool NoOverlap2DConstraintHelper::ReportConflictFromInfeasibleBoxRanges( + absl::Span ranges) { + if (ranges.size() == 2) { + return ReportConflictFromTwoBoxes(ranges[0].box_index, ranges[1].box_index); + } + x_helper_->ClearReason(); + y_helper_->ClearReason(); + for (const auto& range : ranges) { + const int b = range.box_index; + + x_helper_->AddStartMinReason(b, range.bounding_area.x_min); + y_helper_->AddStartMinReason(b, range.bounding_area.y_min); + + x_helper_->AddStartMaxReason(b, range.bounding_area.x_max - range.x_size); + y_helper_->AddStartMaxReason(b, range.bounding_area.y_max - range.y_size); + + x_helper_->AddSizeMinReason(b); + y_helper_->AddSizeMinReason(b); + + x_helper_->AddPresenceReason(b); + y_helper_->AddPresenceReason(b); + } + x_helper_->ImportOtherReasons(*y_helper_); + return x_helper_->ReportConflict(); +} + +namespace { +// This function assumes that the left and right boxes overlap on the second +// dimension, and that left cannot be after right. +// It checks and pushes the lower bound of the right box and the upper bound +// of the left box if need. +// +// If y is not null, it import the mandatory reason for the overlap on y in +// the x helper. +bool LeftBoxBeforeRightBoxOnFirstDimension(int left, int right, + SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y) { + // left box2 pushes right box2. + const IntegerValue left_end_min = x->EndMin(left); + if (left_end_min > x->StartMin(right)) { + x->ClearReason(); + x->AddPresenceReason(left); + x->AddPresenceReason(right); + x->AddReasonForBeingBefore(left, right); + x->AddEndMinReason(left, left_end_min); + // left and right must overlap on y. + ClearAndAddMandatoryOverlapReason(left, right, y); + // Propagate with the complete reason. + x->ImportOtherReasons(*y); + if (!x->IncreaseStartMin(right, left_end_min)) return false; + } + + // right box2 pushes left box2. + const IntegerValue right_start_max = x->StartMax(right); + if (right_start_max < x->EndMax(left)) { + x->ClearReason(); + x->AddPresenceReason(left); + x->AddPresenceReason(right); + x->AddReasonForBeingBefore(left, right); + x->AddStartMaxReason(right, right_start_max); + // left and right must overlap on y. + ClearAndAddMandatoryOverlapReason(left, right, y); + // Propagate with the complete reason. + x->ImportOtherReasons(*y); + if (!x->DecreaseEndMax(left, right_start_max)) return false; + } + + return true; +} + +} // namespace + +bool NoOverlap2DConstraintHelper::PropagateRelativePosition( + int first, int second, PairwiseRestriction::PairwiseRestrictionType type) { + switch (type) { + case PairwiseRestriction::PairwiseRestrictionType::CONFLICT: + return ReportConflictFromTwoBoxes(first, second); + case PairwiseRestriction::PairwiseRestrictionType::FIRST_LEFT_OF_SECOND: + return LeftBoxBeforeRightBoxOnFirstDimension(first, second, x_helper_, + y_helper_); + case PairwiseRestriction::PairwiseRestrictionType::FIRST_RIGHT_OF_SECOND: + return LeftBoxBeforeRightBoxOnFirstDimension(second, first, x_helper_, + y_helper_); + case PairwiseRestriction::PairwiseRestrictionType::FIRST_BELOW_SECOND: + return LeftBoxBeforeRightBoxOnFirstDimension(first, second, y_helper_, + x_helper_); + case PairwiseRestriction::PairwiseRestrictionType::FIRST_ABOVE_SECOND: + return LeftBoxBeforeRightBoxOnFirstDimension(second, first, y_helper_, + x_helper_); + } +} + +bool NoOverlap2DConstraintHelper::Propagate() { + for (const int id : propagators_watching_) { + watcher_->CallOnNextPropagate(id); + } + if (!x_helper_->Propagate() || !y_helper_->Propagate()) return false; + return true; +} + +void NoOverlap2DConstraintHelper::RegisterWith(GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + const int num_boxes = NumBoxes(); + for (int b = 0; b < num_boxes; ++b) { + if (x_helper_->IsOptional(b)) { + watcher->WatchLiteral(x_helper_->PresenceLiteral(b), id); + } + if (y_helper_->IsOptional(b)) { + watcher->WatchLiteral(y_helper_->PresenceLiteral(b), id); + } + watcher->WatchIntegerVariable(x_helper_->Sizes()[b].var, id); + watcher->WatchIntegerVariable(x_helper_->Starts()[b].var, id); + watcher->WatchIntegerVariable(x_helper_->Ends()[b].var, id); + watcher->WatchIntegerVariable(y_helper_->Sizes()[b].var, id); + watcher->WatchIntegerVariable(y_helper_->Starts()[b].var, id); + watcher->WatchIntegerVariable(y_helper_->Ends()[b].var, id); + } + watcher->SetPropagatorPriority(id, 0); +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/no_overlap_2d_helper.h b/ortools/sat/no_overlap_2d_helper.h new file mode 100644 index 0000000000..66abfb4923 --- /dev/null +++ b/ortools/sat/no_overlap_2d_helper.h @@ -0,0 +1,187 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_NO_OVERLAP_2D_HELPER_H_ +#define OR_TOOLS_SAT_NO_OVERLAP_2D_HELPER_H_ + +#include +#include + +#include "absl/types/span.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/model.h" +#include "ortools/sat/scheduling_helpers.h" + +namespace operations_research { +namespace sat { + +// Helper class shared by the propagators that handle no_overlap_2d constraints. +// +// Having a helper class like this one makes much easier to do in-processing and +// to share pre-computed data between the two propagators. +class NoOverlap2DConstraintHelper : public PropagatorInterface { + public: + NoOverlap2DConstraintHelper(SchedulingConstraintHelper* x_helper, + SchedulingConstraintHelper* y_helper, + Model* model) + : axes_are_swapped_(false), + x_helper_(x_helper), + y_helper_(y_helper), + watcher_(model->GetOrCreate()) {} + + void RegisterWith(GenericLiteralWatcher* watcher); + + bool SynchronizeAndSetDirection(bool x_is_forward_after_swap, + bool y_is_forward_after_swap, + bool swap_x_and_y); + + bool IsOptional(int index) const { + return x_helper_->IsOptional(index) || y_helper_->IsOptional(index); + } + + bool IsPresent(int index) const { + return x_helper_->IsPresent(index) && y_helper_->IsPresent(index); + } + + bool IsFixed(int index) const { + return x_helper_->StartIsFixed(index) && x_helper_->EndIsFixed(index) && + y_helper_->StartIsFixed(index) && y_helper_->EndIsFixed(index); + } + + std::pair GetBoxSizesMax(int index) const { + return {x_helper_->SizeMax(index), y_helper_->SizeMax(index)}; + } + + void ClearReason() { + x_helper_->ClearReason(); + y_helper_->ClearReason(); + } + + void WatchAllBoxes(int id) { propagators_watching_.push_back(id); } + + // Propagate a relationship between two boxes (ie., first must be to the left + // of the second, etc). + bool PropagateRelativePosition( + int first, int second, PairwiseRestriction::PairwiseRestrictionType type); + + // Returns a "fixed size projection" of the item of the item `index`. More + // precisely, returns item of index `index` with its sizes fixed to their + // minimum value alongside a bounding box that contains the item. + RectangleInRange GetItemRangeForSizeMin(int index) const; + + // Returns a {start_min, start_max, end_min, end_max} view of the item of + // the index `index`. + ItemWithVariableSize GetItemWithVariableSize(int index) const; + + // If there is no possible placement for the two mandatory boxes (they will + // necessarily overlap), call this function to report this as a conflict. + // Returns true. + bool ReportConflictFromTwoBoxes(int box1, int box2); + + // Reports a conflict due to a (potentially relaxed) infeasible subproblem of + // the no_overlap_2d constraint. More concretely, this function takes a list + // of fixed-size rectangles and their placement domains in `ranges` that + // satisfy: + // - the problem of placing all the rectangles in their domain is + // infeasible; + // - the x and y sizes of each box in `ranges` are smaller or equal than + // the corresponding current minimum sizes of the boxes; + // - for each range in `ranges`, range.box_index.bounding_box is fully + // contained inside GetItemRangeForSizeMin(range.box_index).bounding_box. + // In other terms, each element is infeasible in a domain at least as + // large as the current one. + bool ReportConflictFromInfeasibleBoxRanges( + absl::Span ranges); + + void AddXSizeMinReason(int index) { x_helper_->AddSizeMinReason(index); } + void AddYSizeMinReason(int index) { y_helper_->AddSizeMinReason(index); } + void AddSizeMinReason(int index) { + AddXSizeMinReason(index); + AddYSizeMinReason(index); + } + + // Push the explanation that the left edge of this box is to the right of the + // vertical line x=lower_bound. + // + // | => | + // | => \/ + // | => +---+ + // | => | | + // | => +---+ + // | + void AddLeftMinReason(int index, IntegerValue lower_bound) { + x_helper_->AddStartMinReason(index, lower_bound); + } + + // Push the explanation that the left edge of this box is to the left of the + // vertical line x=upper_bound. + // + // | <= | + // \/ <= | + // +---------|---+ + // | <= | | + // | <= | | + // +---------|---| + // | + void AddLeftMaxReason(int index, IntegerValue upper_bound) { + x_helper_->AddStartMaxReason(index, upper_bound); + } + + // Push the explanation that the bottom edge of this box is to the top of the + // horizontal line y=lower_bound. + void AddBottomMinReason(int index, IntegerValue lower_bound) { + y_helper_->AddStartMinReason(index, lower_bound); + } + + // Push the explanation that the bottom edge of this box is to the bottom of + // the vertical line y=upper_bound. + void AddBottomMaxReason(int index, IntegerValue upper_bound) { + y_helper_->AddStartMaxReason(index, upper_bound); + } + + void AddPresenceReason(int index) { + x_helper_->AddPresenceReason(index); + y_helper_->AddPresenceReason(index); + } + + bool IncreaseLeftMin(int index, IntegerValue new_lower_bound) { + x_helper_->ImportOtherReasons(*y_helper_); + return x_helper_->IncreaseStartMin(index, new_lower_bound); + } + + bool ReportConflict() { + x_helper_->ImportOtherReasons(*y_helper_); + return x_helper_->ReportConflict(); + } + + int NumBoxes() const { return x_helper_->NumTasks(); } + + bool Propagate() override; + + SchedulingConstraintHelper& x_helper() const { return *x_helper_; } + SchedulingConstraintHelper& y_helper() const { return *y_helper_; } + + private: + bool axes_are_swapped_; + SchedulingConstraintHelper* x_helper_; + SchedulingConstraintHelper* y_helper_; + GenericLiteralWatcher* watcher_; + std::vector propagators_watching_; +}; + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_NO_OVERLAP_2D_HELPER_H_ diff --git a/ortools/sat/parameters_validation.cc b/ortools/sat/parameters_validation.cc index d7e469c7bb..c9ed2efb56 100644 --- a/ortools/sat/parameters_validation.cc +++ b/ortools/sat/parameters_validation.cc @@ -182,17 +182,19 @@ std::string ValidateParameters(const SatParameters& params) { } } - const auto strategies = GetNamedParameters(params); - for (const std::string& subsolver : params.subsolvers()) { - if (subsolver == "core_or_no_lp") continue; // Used by fz free search. - if (!strategies.contains(subsolver)) { - return absl::StrCat("subsolver \'", subsolver, "\' is not valid"); + if (!params.subsolvers().empty() || !params.extra_subsolvers().empty()) { + const auto strategies = GetNamedParameters(params); + for (const std::string& subsolver : params.subsolvers()) { + if (subsolver == "core_or_no_lp") continue; // Used by fz free search. + if (!strategies.contains(subsolver)) { + return absl::StrCat("subsolver \'", subsolver, "\' is not valid"); + } } - } - for (const std::string& subsolver : params.extra_subsolvers()) { - if (!strategies.contains(subsolver)) { - return absl::StrCat("subsolver \'", subsolver, "\' is not valid"); + for (const std::string& subsolver : params.extra_subsolvers()) { + if (!strategies.contains(subsolver)) { + return absl::StrCat("subsolver \'", subsolver, "\' is not valid"); + } } } diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 0596120cd1..48739a5aa1 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -2440,9 +2440,24 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( const auto& it = reified_precedences_cache_.find(key); if (it != reified_precedences_cache_.end()) return it->second; - const int result = NewBoolVar("reified precedence"); + const int result = NewBoolVar(""); reified_precedences_cache_[key] = result; + // Take care of hints. + if (hint_is_loaded_) { + std::optional time_i_hint = GetExpressionSolutionHint(time_i); + std::optional time_j_hint = GetExpressionSolutionHint(time_j); + std::optional active_i_hint = GetRefSolutionHint(active_i); + std::optional active_j_hint = GetRefSolutionHint(active_j); + if (time_i_hint.has_value() && time_j_hint.has_value() && + active_i_hint.has_value() && active_j_hint.has_value()) { + const bool reified_hint = (active_i_hint.value() != 0) && + (active_j_hint.value() != 0) && + (time_i_hint.value() <= time_j_hint.value()); + SetNewVariableHint(result, reified_hint); + } + } + // result => (time_i <= time_j) && active_i && active_j. ConstraintProto* const lesseq = working_model->add_constraints(); lesseq->add_enforcement_literal(result); diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index 43e9d3e735..cbf1be2bdb 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -676,6 +676,17 @@ class PresolveContext { return RefIsPositive(ref) ? var_hint : -var_hint; } + std::optional GetExpressionSolutionHint( + const LinearExpressionProto& expr) { + int64_t result = expr.offset(); + for (int i = 0; i < expr.vars().size(); ++i) { + if (expr.coeffs(i) == 0) continue; + if (!VarHasSolutionHint(expr.vars(i))) return std::nullopt; + result += expr.coeffs(i) * SolutionHint(expr.vars(i)); + } + return result; + } + void UpdateRefSolutionHint(int ref, int hint) { UpdateVarSolutionHint(PositiveRef(ref), RefIsPositive(ref) ? hint : -hint); } diff --git a/ortools/sat/presolve_context_test.cc b/ortools/sat/presolve_context_test.cc index 921d7e4844..7f4f1181ff 100644 --- a/ortools/sat/presolve_context_test.cc +++ b/ortools/sat/presolve_context_test.cc @@ -681,10 +681,15 @@ TEST(PresolveContextTest, ReifiedConstraintCache) { variables { domain: [ 0, 1 ] } variables { domain: [ 0, 10 ] } variables { domain: [ 0, 10 ] } + solution_hint { + vars: [ 0, 1, 2, 3 ] + values: [ 1, 1, 5, 7 ] + } )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); context.UpdateNewConstraintsVariableUsage(); + context.LoadSolutionHint(); LinearExpressionProto expr1; expr1.add_vars(2); expr1.add_coeffs(1); @@ -705,6 +710,7 @@ TEST(PresolveContextTest, ReifiedConstraintCache) { // 2 x (2 implications , 2 enforced linear) + bool_or. ASSERT_EQ(9, working_model.constraints_size()); EXPECT_THAT(working_model.constraints(8), ::testing::EqualsProto(bool_or)); + EXPECT_TRUE(context.DebugTestHintFeasibility()); } TEST(PresolveContextTest, ExploitFixedDomainOverflow) { diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index b721e9e3c4..e5b96a3274 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -17,6 +17,11 @@ load("@pip_deps//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@rules_python//python:defs.bzl", "py_library", "py_test") +cc_library( + name = "linear_expr_doc", + hdrs = ["linear_expr_doc.h"], +) + cc_library( name = "linear_expr", srcs = ["linear_expr.cc"], @@ -34,11 +39,12 @@ cc_library( ) pybind_extension( - name = "swig_helper", - srcs = ["swig_helper.cc"], + name = "cp_model_helper", + srcs = ["cp_model_helper.cc"], visibility = ["//visibility:public"], deps = [ ":linear_expr", + ":linear_expr_doc", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:cp_model_utils", "//ortools/sat:sat_parameters_cc_proto", @@ -49,10 +55,10 @@ pybind_extension( ) py_test( - name = "swig_helper_test", - srcs = ["swig_helper_test.py"], + name = "cp_model_helper_test", + srcs = ["cp_model_helper_test.py"], deps = [ - ":swig_helper", + ":cp_model_helper", "//ortools/sat:cp_model_py_pb2", "//ortools/sat:sat_parameters_py_pb2", "//ortools/util/python:sorted_interval_list", @@ -61,21 +67,21 @@ py_test( ) py_library( - name = "cp_model_helper", - srcs = ["cp_model_helper.py"], + name = "cp_model_numbers", + srcs = ["cp_model_numbers.py"], visibility = ["//visibility:public"], deps = [ - ":swig_helper", + ":cp_model_helper", requirement("numpy"), "@com_google_protobuf//:protobuf_python", ], ) py_test( - name = "cp_model_helper_test", - srcs = ["cp_model_helper_test.py"], + name = "cp_model_numbers_test", + srcs = ["cp_model_numbers_test.py"], deps = [ - ":cp_model_helper", + ":cp_model_numbers", requirement("absl-py"), ], ) @@ -86,7 +92,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":cp_model_helper", - ":swig_helper", + ":cp_model_numbers", requirement("numpy"), requirement("pandas"), "//ortools/sat:cp_model_py_pb2", @@ -100,6 +106,8 @@ py_test( srcs = ["cp_model_test.py"], deps = [ ":cp_model", + ":cp_model_helper", requirement("absl-py"), + requirement("numpy"), ], ) diff --git a/ortools/sat/python/CMakeLists.txt b/ortools/sat/python/CMakeLists.txt index 3c0602249c..6e91bfdefb 100644 --- a/ortools/sat/python/CMakeLists.txt +++ b/ortools/sat/python/CMakeLists.txt @@ -11,26 +11,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -pybind11_add_module(swig_helper_pybind11 MODULE swig_helper.cc) -set_target_properties(swig_helper_pybind11 PROPERTIES - LIBRARY_OUTPUT_NAME "swig_helper") +pybind11_add_module(cp_model_helper_pybind11 MODULE cp_model_helper.cc) +set_target_properties(cp_model_helper_pybind11 PROPERTIES + LIBRARY_OUTPUT_NAME "cp_model_helper") # note: macOS is APPLE and also UNIX ! if(APPLE) - set_target_properties(swig_helper_pybind11 PROPERTIES + set_target_properties(cp_model_helper_pybind11 PROPERTIES SUFFIX ".so" INSTALL_RPATH "@loader_path;@loader_path/../../../${PYTHON_PROJECT}/.libs") elseif(UNIX) - set_target_properties(swig_helper_pybind11 PROPERTIES + set_target_properties(cp_model_helper_pybind11 PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/../../../${PYTHON_PROJECT}/.libs") endif() -target_link_libraries(swig_helper_pybind11 PRIVATE +target_link_libraries(cp_model_helper_pybind11 PRIVATE ${PROJECT_NAMESPACE}::ortools pybind11_native_proto_caster protobuf::libprotobuf) -target_include_directories(swig_helper_pybind11 PRIVATE ${protobuf_SOURCE_DIR}) -add_library(${PROJECT_NAMESPACE}::swig_helper_pybind11 ALIAS swig_helper_pybind11) +target_include_directories(cp_model_helper_pybind11 PRIVATE ${protobuf_SOURCE_DIR}) +add_library(${PROJECT_NAMESPACE}::cp_model_helper_pybind11 ALIAS cp_model_helper_pybind11) if(BUILD_TESTING) file(GLOB PYTHON_SRCS "*_test.py") diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index f2be7c0f28..c54f42ea34 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -67,13 +67,16 @@ import pandas as pd from ortools.sat import cp_model_pb2 from ortools.sat import sat_parameters_pb2 from ortools.sat.python import cp_model_helper as cmh -from ortools.sat.python import swig_helper +from ortools.sat.python import cp_model_numbers as cmn from ortools.util.python import sorted_interval_list # Import external types. Domain = sorted_interval_list.Domain -LinearExpr = swig_helper.LinearExpr -BoundedLinearExpression = swig_helper.BoundedLinearExpression +BoundedLinearExpression = cmh.BoundedLinearExpression +FlatFloatExpr = cmh.FlatFloatExpr +FlatIntExpr = cmh.FlatIntExpr +LinearExpr = cmh.LinearExpr +NotBooleanVariable = cmh.NotBooleanVariable # The classes below allow linear expressions to be expressed naturally with the @@ -152,8 +155,8 @@ NumberTypes = ( np.double, ) -LiteralT = Union[swig_helper.Literal, IntegralT, bool] -BoolVarT = swig_helper.Literal +LiteralT = Union[cmh.Literal, IntegralT, bool] +BoolVarT = cmh.Literal VariableT = Union["IntVar", IntegralT] # We need to add 'IntVar' for pytype. @@ -215,7 +218,7 @@ def short_expr_name( return str(e) -class IntVar(swig_helper.BaseIntVar): +class IntVar(cmh.BaseIntVar): """An integer variable. An IntVar is an object that can take on any integer value within defined @@ -246,10 +249,10 @@ class IntVar(swig_helper.BaseIntVar): # case 2: # model is a CpModelProto, domain is an index (int), and name is None. if isinstance(domain, IntegralTypes) and name is None: - swig_helper.BaseIntVar.__init__(self, int(domain), is_boolean) + cmh.BaseIntVar.__init__(self, int(domain), is_boolean) self.__var = model.variables[domain] else: - swig_helper.BaseIntVar.__init__(self, len(model.variables), is_boolean) + cmh.BaseIntVar.__init__(self, len(model.variables), is_boolean) self.__var = model.variables.add() self.__var.domain.extend( cast(sorted_interval_list.Domain, domain).flattened_intervals() @@ -384,12 +387,12 @@ class Constraint: self. """ for lit in expand_generator_or_tuple(boolvar): - if (cmh.is_boolean(lit) and lit) or ( + if (cmn.is_boolean(lit) and lit) or ( isinstance(lit, IntegralTypes) and lit == 1 ): # Always true. Do nothing. pass - elif (cmh.is_boolean(lit) and not lit) or ( + elif (cmn.is_boolean(lit) and not lit) or ( isinstance(lit, IntegralTypes) and lit == 0 ): self.__constraint.enforcement_literal.append( @@ -397,7 +400,7 @@ class Constraint: ) else: self.__constraint.enforcement_literal.append( - cast(swig_helper.Literal, lit).index + cast(cmh.Literal, lit).index ) return self @@ -584,7 +587,7 @@ def object_is_a_true_literal(literal: LiteralT) -> bool: if isinstance(literal, IntVar): proto = literal.proto return len(proto.domain) == 2 and proto.domain[0] == 1 and proto.domain[1] == 1 - if isinstance(literal, swig_helper.NotBooleanVariable): + if isinstance(literal, cmh.NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 0 and proto.domain[1] == 0 if isinstance(literal, IntegralTypes): @@ -597,7 +600,7 @@ def object_is_a_false_literal(literal: LiteralT) -> bool: if isinstance(literal, IntVar): proto = literal.proto return len(proto.domain) == 2 and proto.domain[0] == 0 and proto.domain[1] == 0 - if isinstance(literal, swig_helper.NotBooleanVariable): + if isinstance(literal, cmh.NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 1 and proto.domain[1] == 1 if isinstance(literal, IntegralTypes): @@ -807,26 +810,21 @@ class CpModel: ) -> Constraint: """Adds the constraint: `linear_expr` in `domain`.""" if isinstance(linear_expr, LinearExpr): - flat_expr = swig_helper.CanonicalIntExpression(linear_expr) - if not flat_expr.ok: + ble = BoundedLinearExpression(linear_expr, domain) + if not ble.ok: raise TypeError( - "linear expression contains floating point coefficients or" - f" constants: {linear_expr}" + "Cannot add a linear expression containing floating point" + f" coefficients or constants: {type(linear_expr).__name__!r}" ) - return self.add( - BoundedLinearExpression( - flat_expr.vars, flat_expr.coeffs, flat_expr.offset, domain - ) - ) + return self.add(ble) if isinstance(linear_expr, IntegralTypes): if not domain.contains(int(linear_expr)): return self.add_bool_or([]) # Evaluate to false. else: return self.add_bool_and([]) # Evaluate to true. raise TypeError( - f"not supported: CpModel.add_linear_expression_in_domain({linear_expr} " - f" {type(linear_expr)} {linear_expr.is_integer()} {domain} " - f"{type(domain)}" + "not supported:" + f" CpModel.add_linear_expression_in_domain({type(linear_expr).__name__!r})" ) def add(self, ct: Union[BoundedLinearExpression, bool, np.bool_]) -> Constraint: @@ -849,16 +847,16 @@ class CpModel: model_ct.linear.coeffs.extend(ct.coeffs) model_ct.linear.domain.extend( [ - cmh.capped_subtraction(x, ct.offset) + cmn.capped_subtraction(x, ct.offset) for x in ct.bounds.flattened_intervals() ] ) return result - if ct and cmh.is_boolean(ct): + if ct and cmn.is_boolean(ct): return self.add_bool_or([True]) - if not ct and cmh.is_boolean(ct): + if not ct and cmn.is_boolean(ct): return self.add_bool_or([]) # Evaluate to false. - raise TypeError("not supported: CpModel.add(" + str(ct) + ")") + raise TypeError(f"not supported: CpModel.add({type(ct).__name__!r})") # General Integer Constraints. @@ -1028,7 +1026,7 @@ class CpModel: arity: int = len(expressions) for one_tuple in tuples_list: if len(one_tuple) != arity: - raise TypeError("Tuple " + str(one_tuple) + " has the wrong arity") + raise TypeError(f"Tuple {one_tuple!r} has the wrong arity") # duck-typing (no explicit type checks here) try: @@ -1037,7 +1035,7 @@ class CpModel: except ValueError as ex: raise TypeError( "add_xxx_assignment: Not an integer or does not fit in an int64_t:" - f" {ex.args}" + f" {type(ex.args).__name__!r}" ) from ex return ct @@ -1148,7 +1146,7 @@ class CpModel: model_ct.automaton.final_states.append(v) for t in transition_triples: if len(t) != 3: - raise TypeError("Tuple " + str(t) + " has the wrong arity (!= 3)") + raise TypeError(f"Tuple {t!r} has the wrong arity (!= 3)") model_ct.automaton.transition_tail.append(t[0]) model_ct.automaton.transition_label.append(t[1]) model_ct.automaton.transition_head.append(t[2]) @@ -2078,14 +2076,16 @@ class CpModel: return arg.index if isinstance(arg, IntegralTypes): return self.get_or_make_index_from_constant(arg) - raise TypeError("NotSupported: model.get_or_make_index(" + str(arg) + ")") + raise TypeError( + f"NotSupported: model.get_or_make_index({type(arg).__name__!r})" + ) def get_or_make_boolean_index(self, arg: LiteralT) -> int: """Returns an index from a boolean expression.""" if isinstance(arg, IntVar): self.assert_is_boolean_variable(arg) return arg.index - if isinstance(arg, swig_helper.NotBooleanVariable): + if isinstance(arg, cmh.NotBooleanVariable): self.assert_is_boolean_variable(arg.negated()) return arg.index if isinstance(arg, IntegralTypes): @@ -2093,15 +2093,19 @@ class CpModel: return self.get_or_make_index_from_constant(1) if arg == ~True: # -2 return self.get_or_make_index_from_constant(0) - arg = cmh.assert_is_zero_or_one(arg) + arg = cmn.assert_is_zero_or_one(arg) return self.get_or_make_index_from_constant(arg) - if cmh.is_boolean(arg): + if cmn.is_boolean(arg): return self.get_or_make_index_from_constant(int(arg)) - raise TypeError(f"not supported: model.get_or_make_boolean_index({arg})") + raise TypeError( + "not supported:" f" model.get_or_make_boolean_index({type(arg).__name__!r})" + ) def get_interval_index(self, arg: IntervalVar) -> int: if not isinstance(arg, IntervalVar): - raise TypeError(f"NotSupported: model.get_interval_index({arg})") + raise TypeError( + f"NotSupported: model.get_interval_index({type(arg).__name__!r})" + ) return arg.index def get_or_make_index_from_constant(self, value: IntegralT) -> int: @@ -2132,14 +2136,8 @@ class CpModel: result.offset = int(linear_expr) * mult return result - if isinstance(linear_expr, IntVar): - result.vars.append(self.get_or_make_index(linear_expr)) - result.coeffs.append(mult) - return result - - flat_expr = swig_helper.CanonicalIntExpression(linear_expr) - if not flat_expr.ok: - raise ValueError(f"Failed to parse linear expression: {linear_expr}") + # Raises TypeError if linear_expr is not an integer. + flat_expr = cmh.FlatIntExpr(linear_expr) result.offset = flat_expr.offset for var in flat_expr.vars: result.vars.append(var.index) @@ -2155,7 +2153,7 @@ class CpModel: self.__model.objective.scaling_factor = 1.0 elif isinstance(obj, LinearExpr): if obj.is_integer(): - int_obj = swig_helper.CanonicalIntExpression(obj) + int_obj = cmh.FlatIntExpr(obj) for var in int_obj.vars: self.__model.objective.vars.append(var.index) if minimize: @@ -2168,14 +2166,16 @@ class CpModel: for c in int_obj.coeffs: self.__model.objective.coeffs.append(-c) else: - float_obj = swig_helper.CanonicalFloatExpression(obj) + float_obj = cmh.FlatFloatExpr(obj) for var in float_obj.vars: self.__model.floating_point_objective.vars.append(var.index) self.__model.floating_point_objective.coeffs.extend(float_obj.coeffs) self.__model.floating_point_objective.maximize = not minimize self.__model.floating_point_objective.offset = float_obj.offset else: - raise TypeError("TypeError: " + str(obj) + " is not a valid objective") + raise TypeError( + f"TypeError: {type(obj).__name__!r} is not a valid objective" + ) def minimize(self, obj: ObjLinearExprT): """Sets the objective of the model to minimize(obj).""" @@ -2229,11 +2229,11 @@ class CpModel: def model_stats(self) -> str: """Returns a string containing some model statistics.""" - return swig_helper.CpSatHelper.model_stats(self.__model) + return cmh.CpSatHelper.model_stats(self.__model) def validate(self) -> str: """Returns a string indicating that the model is invalid.""" - return swig_helper.CpSatHelper.validate_model(self.__model) + return cmh.CpSatHelper.validate_model(self.__model) def export_to_file(self, file: str) -> bool: """Write the model as a protocol buffer to 'file'. @@ -2246,7 +2246,7 @@ class CpModel: Returns: True if the model was correctly written. """ - return swig_helper.CpSatHelper.write_model_to_file(self.__model, file) + return cmh.CpSatHelper.write_model_to_file(self.__model, file) @overload def add_hint(self, var: IntVar, value: int) -> None: ... @@ -2285,9 +2285,13 @@ class CpModel: if isinstance(x, IntVar): var = self.__model.variables[x.index] if len(var.domain) != 2 or var.domain[0] < 0 or var.domain[1] > 1: - raise TypeError("TypeError: " + str(x) + " is not a boolean variable") - elif not isinstance(x, swig_helper.NotBooleanVariable): - raise TypeError("TypeError: " + str(x) + " is not a boolean variable") + raise TypeError( + f"TypeError: {type(x).__name__!r} is not a boolean variable" + ) + elif not isinstance(x, cmh.NotBooleanVariable): + raise TypeError( + f"TypeError: {type(x).__name__!r} is not a boolean variable" + ) # Compatibility with pre PEP8 # pylint: disable=invalid-name @@ -2398,13 +2402,13 @@ class CpSolver: """ def __init__(self) -> None: - self.__response_wrapper: Optional[swig_helper.ResponseWrapper] = None + self.__response_wrapper: Optional[cmh.ResponseWrapper] = None self.parameters: sat_parameters_pb2.SatParameters = ( sat_parameters_pb2.SatParameters() ) self.log_callback: Optional[Callable[[str], None]] = None self.best_bound_callback: Optional[Callable[[float], None]] = None - self.__solve_wrapper: Optional[swig_helper.SolveWrapper] = None + self.__solve_wrapper: Optional[cmh.SolveWrapper] = None self.__lock: threading.Lock = threading.Lock() def solve( @@ -2414,7 +2418,7 @@ class CpSolver: ) -> cp_model_pb2.CpSolverStatus: """Solves a problem and passes each solution to the callback if not null.""" with self.__lock: - self.__solve_wrapper = swig_helper.SolveWrapper() + self.__solve_wrapper = cmh.SolveWrapper() self.__solve_wrapper.set_parameters(self.parameters) if solution_callback is not None: @@ -2570,7 +2574,7 @@ class CpSolver: return self._checked_response.solution_info() @property - def _checked_response(self) -> swig_helper.ResponseWrapper: + def _checked_response(self) -> cmh.ResponseWrapper: """Checks solve() has been called, and returns a response wrapper.""" if self.__response_wrapper is None: raise RuntimeError("solve() has not been called.") @@ -2694,7 +2698,7 @@ class CpSolver: # pylint: enable=invalid-name -class CpSolverSolutionCallback(swig_helper.SolutionCallback): +class CpSolverSolutionCallback(cmh.SolutionCallback): """Solution callback. This class implements a callback that will be called at each new solution @@ -2709,7 +2713,7 @@ class CpSolverSolutionCallback(swig_helper.SolutionCallback): """ def __init__(self) -> None: - swig_helper.SolutionCallback.__init__(self) + cmh.SolutionCallback.__init__(self) def OnSolutionCallback(self) -> None: """Proxy for the same method in snake case.""" @@ -2945,7 +2949,7 @@ def _convert_to_integral_series_and_validate_index( else: raise ValueError("index does not match") else: - raise TypeError(f"invalid type={type(value_or_series)}") + raise TypeError(f"invalid type={type(value_or_series).__name__!r}") def _convert_to_linear_expr_series_and_validate_index( @@ -2972,7 +2976,7 @@ def _convert_to_linear_expr_series_and_validate_index( else: raise ValueError("index does not match") else: - raise TypeError(f"invalid type={type(value_or_series)}") + raise TypeError(f"invalid type={type(value_or_series).__name__!r}") def _convert_to_literal_series_and_validate_index( @@ -2999,4 +3003,4 @@ def _convert_to_literal_series_and_validate_index( else: raise ValueError("index does not match") else: - raise TypeError(f"invalid type={type(value_or_series)}") + raise TypeError(f"invalid type={type(value_or_series).__name__!r}") diff --git a/ortools/sat/python/swig_helper.cc b/ortools/sat/python/cp_model_helper.cc similarity index 53% rename from ortools/sat/python/swig_helper.cc rename to ortools/sat/python/cp_model_helper.cc index 0b523fa66f..14b0b05be2 100644 --- a/ortools/sat/python/swig_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -11,25 +11,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This file wraps the swig_helper.h classes in python using pybind11. -#include "ortools/sat/swig_helper.h" - #include #include #include #include +#include #include +#include "absl/functional/any_invocable.h" +#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/python/linear_expr.h" +#include "ortools/sat/python/linear_expr_doc.h" +#include "ortools/sat/swig_helper.h" #include "ortools/util/sorted_interval_list.h" +#include "pybind11/attr.h" #include "pybind11/cast.h" #include "pybind11/functional.h" #include "pybind11/gil.h" +#include "pybind11/numpy.h" #include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" @@ -39,7 +44,7 @@ namespace operations_research::sat::python { using ::py::arg; -void throw_error(PyObject* py_exception, const std::string& message) { +void ThrowError(PyObject* py_exception, const std::string& message) { PyErr_SetString(py_exception, message.c_str()); throw py::error_already_set(); } @@ -146,9 +151,9 @@ class ResponseWrapper { IntExprVisitor visitor; int64_t value; if (!visitor.Evaluate(expr, response_, &value)) { - throw_error(PyExc_TypeError, - absl::StrCat("Failed to evaluate linear expression: ", - expr->DebugString())); + ThrowError(PyExc_ValueError, + absl::StrCat("Failed to evaluate linear expression: ", + expr->DebugString())); } return value; } @@ -161,80 +166,21 @@ class ResponseWrapper { const CpSolverResponse response_; }; -const char* kLinearExprClassDoc = R"doc(Holds an integer linear expression. - - A linear expression is built from integer constants and variables. - For example, `x + 2 * (y - z + 1)`. - - Linear expressions are used in CP-SAT models in constraints and in the - objective: - - * You can define linear constraints as in: - - ``` - model.add(x + 2 * y <= 5) - model.add(sum(array_of_vars) == 5) - ``` - - * In CP-SAT, the objective is a linear expression: - - ``` - model.minimize(x + 2 * y + z) - ``` - - * For large arrays, using the LinearExpr class is faster that using the python - `sum()` function. You can create constraints and the objective from lists of - linear expressions or coefficients as follows: - - ``` - model.minimize(cp_model.LinearExpr.sum(expressions)) - model.add(cp_model.LinearExpr.weighted_sum(expressions, coefficients) >= 0) - ```)doc"; - -const char* kLiteralClassDoc = R"doc( - Holds a Boolean literal. - - A literal is a Boolean variable or its negation. - - Literals are used in CP-SAT models in constraints and in the - objective: - - * You can define literal as in: - - ``` - b1 = model.new_bool_var() - b2 = model.new_bool_var() - # Simple Boolean constraint. - model.add_bool_or(b1, b2.negated()) - # We can use the ~ operator to negate a literal. - model.add_bool_or(b1, ~b2) - # Enforcement literals must be literals. - x = model.new_int_var(0, 10, 'x') - model.add(x == 5).only_enforced_if(~b1) - ``` - - * Literals can be used directly in linear constraints or in the objective: - - ``` - model.minimize(b1 + 2 * ~b2) - ```)doc"; - // Checks that the result is not null and throws an error if it is. BoundedLinearExpression* CheckBoundedLinearExpression( BoundedLinearExpression* result, LinearExpr* lhs, LinearExpr* rhs = nullptr) { - if (result == nullptr) { + if (!result->ok()) { if (rhs == nullptr) { - throw_error(PyExc_TypeError, - absl::StrCat("Linear constraints only accept integer values " - "and coefficients: ", - lhs->DebugString())); + ThrowError(PyExc_TypeError, + absl::StrCat("Linear constraints only accept integer values " + "and coefficients: ", + lhs->DebugString())); } else { - throw_error( - PyExc_TypeError, - absl::StrCat("Linear constraints only accept integer values " - "and coefficients: ", - lhs->DebugString(), " and ", rhs->DebugString())); + ThrowError(PyExc_TypeError, + absl::StrCat("Linear constraints only accept integer values " + "and coefficients: ", + lhs->DebugString(), " and ", rhs->DebugString())); } } return result; @@ -242,17 +188,223 @@ BoundedLinearExpression* CheckBoundedLinearExpression( void RaiseIfNone(LinearExpr* expr) { if (expr == nullptr) { - throw_error(PyExc_TypeError, - "Linear constraints do not accept None as argument."); + ThrowError(PyExc_TypeError, + "Linear constraints do not accept None as argument."); } } -PYBIND11_MODULE(swig_helper, m) { +void ProcessExprArg(const py::handle& arg, + absl::AnyInvocable on_linear_expr, + absl::AnyInvocable on_int_constant, + absl::AnyInvocable on_float_constant) { + if (py::isinstance(arg)) { + on_linear_expr(arg.cast()); + } else if (py::isinstance(arg)) { + on_int_constant(arg.cast()); + } else if (py::isinstance(arg)) { + on_float_constant(arg.cast()); + } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer")) { + if (getattr(arg, "is_integer")().cast()) { + on_int_constant(arg.cast()); + } else { + on_float_constant(arg.cast()); + } + } else { + py::type objtype = py::type::of(arg); + const std::string type_name = objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, + absl::StrCat("LinearExpr::sum() only accept linear " + "expressions and constants as argument: '", + absl::CEscape(type_name), "'")); + } +} + +void ProcessConstantArg(const py::handle& arg, + absl::AnyInvocable on_int_constant, + absl::AnyInvocable on_float_constant) { + if (py::isinstance(arg)) { + on_int_constant(arg.cast()); + } else if (py::isinstance(arg)) { + on_float_constant(arg.cast()); + } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer")) { + if (getattr(arg, "is_integer")().cast()) { + on_int_constant(arg.cast()); + } else { + on_float_constant(arg.cast()); + } + } else { + py::type objtype = py::type::of(arg); + const std::string type_name = objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, + absl::StrCat("LinearExpr::weighted_sum() only accept constants " + "as coefficients: '", + absl::CEscape(type_name), "'")); + } +} + +LinearExpr* SumArguments(py::args expressions) { + std::vector linear_exprs; + int64_t int_offset = 0; + double float_offset = 0.0; + bool has_floats = false; + + const auto process_arg = [&](const py::handle& arg) -> void { + ProcessExprArg( + arg, [&](LinearExpr* expr) { linear_exprs.push_back(expr); }, + [&](int64_t value) { int_offset += value; }, + [&](double value) { + if (value != 0.0) { + float_offset += value; + has_floats = true; + } + }); + }; + + if (expressions.size() == 1 && py::isinstance(expressions[0])) { + // Normal list or tuple argument. + py::sequence elements = expressions[0].cast(); + linear_exprs.reserve(elements.size()); + for (const py::handle& arg : elements) { + process_arg(arg); + } + } else { // Direct sum(x, y, 3, ..) without []. + linear_exprs.reserve(expressions.size()); + for (const py::handle arg : expressions) { + process_arg(arg); + } + } + + // If there are floats, we add the int offset to the float offset. + if (has_floats) { + float_offset += static_cast(int_offset); + int_offset = 0; + } + + if (linear_exprs.empty()) { + if (has_floats) { + return new FloatConstant(float_offset); + } else { + return new IntConstant(int_offset); + } + } else if (linear_exprs.size() == 1) { + if (has_floats) { + if (float_offset == 0.0) { + return linear_exprs[0]; + } else { + return new FloatAffine(linear_exprs[0], 1.0, float_offset); + } + } else if (int_offset != 0) { + return new IntAffine(linear_exprs[0], 1, int_offset); + } else { + return linear_exprs[0]; + } + } else { + if (has_floats) { + return new SumArray(linear_exprs, 0, float_offset); + } else { + return new SumArray(linear_exprs, int_offset, 0.0); + } + } +} + +LinearExpr* WeightedSumArguments(py::sequence expressions, + py::sequence coefficients) { + if (expressions.size() != coefficients.size()) { + ThrowError(PyExc_ValueError, + absl::StrCat("LinearExpr::weighted_sum() requires the same " + "number of arguments and coefficients: ", + expressions.size(), " != ", coefficients.size())); + } + + std::vector linear_exprs; + std::vector int_coeffs; + std::vector float_coeffs; + linear_exprs.reserve(expressions.size()); + int_coeffs.reserve(expressions.size()); + float_coeffs.reserve(expressions.size()); + int64_t int_offset = 0; + double float_offset = 0.0; + bool has_floats = false; + + for (int i = 0; i < expressions.size(); ++i) { + auto on_expr = [&](LinearExpr* expr) { + ProcessConstantArg( + coefficients[i], + [&](int64_t value) { + if (value == 0) return; + linear_exprs.push_back(expr); + int_coeffs.push_back(value); + float_coeffs.push_back(static_cast(value)); + }, + [&](double value) { + if (value == 0.0) return; + linear_exprs.push_back(expr); + float_coeffs.push_back(value); + has_floats = true; + }); + }; + auto on_int = [&](int64_t expr_value) { + if (expr_value == 0) return; + ProcessConstantArg( + coefficients[i], + [&](int64_t coeff_value) { int_offset += coeff_value * expr_value; }, + [&](double coeff_value) { + has_floats = true; + float_offset += coeff_value * static_cast(expr_value); + }); + }; + auto on_float = [&](double expr_value) { + if (expr_value == 0.0) return; + has_floats = true; + ProcessConstantArg( + coefficients[i], + [&](int64_t coeff_value) { + float_offset += static_cast(coeff_value) * expr_value; + }, + [&](double coeff_value) { + if (coeff_value == 0.0) return; + float_offset += coeff_value * expr_value; + }); + }; + ProcessExprArg(expressions[i], std::move(on_expr), std::move(on_int), + std::move(on_float)); + } + + // Correct the float offset if there are int offsets. + if (has_floats) { + float_offset += static_cast(int_offset); + int_offset = 0; + } + + if (linear_exprs.empty()) { + if (has_floats) { + return new FloatConstant(float_offset); + } else { + return new IntConstant(int_offset); + } + } else if (linear_exprs.size() == 1) { + if (has_floats) { + return new FloatAffine(linear_exprs[0], float_coeffs[0], float_offset); + } else if (int_offset != 0 || int_coeffs[0] != 1) { + return new IntAffine(linear_exprs[0], int_coeffs[0], int_offset); + } else { + return linear_exprs[0]; + } + } else { + if (has_floats) { + return new FloatWeightedSum(linear_exprs, float_coeffs, float_offset); + } else { + return new IntWeightedSum(linear_exprs, int_coeffs, int_offset); + } + } +} + +PYBIND11_MODULE(cp_model_helper, m) { pybind11_protobuf::ImportNativeProtoCasters(); py::module::import("ortools.util.python.sorted_interval_list"); - // We keep the CamelCase name for the SolutionCallback class to be compatible - // with the pre PEP8 python code. + // We keep the CamelCase name for the SolutionCallback class to be + // compatible with the pre PEP8 python code. py::class_(m, "SolutionCallback") .def(py::init<>()) .def("OnSolutionCallback", &SolutionCallback::OnSolutionCallback) @@ -279,9 +431,9 @@ PYBIND11_MODULE(swig_helper, m) { IntExprVisitor visitor; int64_t value; if (!visitor.Evaluate(expr, callback.Response(), &value)) { - throw_error(PyExc_TypeError, - absl::StrCat("Failed to evaluate linear expression: ", - expr->DebugString())); + ThrowError(PyExc_ValueError, + absl::StrCat("Failed to evaluate linear expression: ", + expr->DebugString())); } return value; }, @@ -300,7 +452,6 @@ PYBIND11_MODULE(swig_helper, m) { "Returns the Boolean value of a literal after solve."); py::class_(m, "ResponseWrapper") - .def(py::init()) .def("best_objective_bound", &ResponseWrapper::BestObjectiveBound) .def("boolean_value", &ResponseWrapper::BooleanValue, arg("lit")) .def("boolean_value", &ResponseWrapper::FixedBooleanValue, arg("lit")) @@ -358,122 +509,45 @@ PYBIND11_MODULE(swig_helper, m) { .def_static("write_model_to_file", &CpSatHelper::WriteModelToFile, arg("model_proto"), arg("filename")); - py::class_(m, "ExprOrValue") - .def(py::init()) // Needs to be before the double init. - .def(py::init()) - .def(py::init()) - .def_readonly("double_value", &ExprOrValue::double_value) - .def_readonly("expr", &ExprOrValue::expr) - .def_readonly("int_value", &ExprOrValue::int_value); - - py::implicitly_convertible(); - py::implicitly_convertible(); - py::implicitly_convertible(); - - py::class_(m, "LinearExpr", kLinearExprClassDoc) - // We make sure to keep the order of the overloads: LinearExpr* before - // ExprOrValue as this is faster to parse and type check. - .def_static("sum", (&LinearExpr::Sum), arg("exprs"), + py::class_(m, "LinearExpr", + DOC(operations_research, sat, python, LinearExpr)) + .def_static("sum", &SumArguments, py::return_value_policy::automatic, + "Returns the sum(expressions).", py::keep_alive<0, 1>()) + .def_static("weighted_sum", &WeightedSumArguments, arg("expressions"), + arg("coefficients"), + "Returns the sum of (expressions[i] * coefficients[i])", py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("sum", &LinearExpr::MixedSum, arg("exprs"), - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "weighted_sum", - [](const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.size() != coeffs.size()) { - throw_error( - PyExc_ValueError, - "The number of expressions and coefficients must match."); - } - return LinearExpr::WeightedSumInt(exprs, coeffs); - }, - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "weighted_sum", - [](const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.size() != coeffs.size()) { - throw_error( - PyExc_ValueError, - "The number of expressions and coefficients must match."); - } - return LinearExpr::WeightedSumFloat(exprs, coeffs); - }, - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "weighted_sum", - [](const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.size() != coeffs.size()) { - throw_error( - PyExc_ValueError, - "The number of expressions and coefficients must match."); - } - return LinearExpr::MixedWeightedSumInt(exprs, coeffs); - }, - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "weighted_sum", - [](const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.size() != coeffs.size()) { - throw_error( - PyExc_ValueError, - "The number of expressions and coefficients must match."); - } - return LinearExpr::MixedWeightedSumFloat(exprs, coeffs); - }, - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - // Make sure to keep the order of the overloads: int before float as an - // an integer value will be silently converted to a float. .def_static("term", &LinearExpr::TermInt, arg("expr").none(false), - arg("coeff"), "Returns expr * coeff.", + arg("coeff"), + DOC(operations_research, sat, python, LinearExpr, TermInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def_static("term", &LinearExpr::TermFloat, arg("expr").none(false), - arg("coeff"), "Returns expr * coeff.", + arg("coeff"), + DOC(operations_research, sat, python, LinearExpr, TermFloat), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def_static("affine", &LinearExpr::AffineInt, arg("expr").none(false), - arg("coeff"), arg("offset"), "Returns expr * coeff + offset.", + arg("coeff"), arg("offset"), + DOC(operations_research, sat, python, LinearExpr, AffineInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("affine", &LinearExpr::AffineFloat, arg("expr").none(false), - arg("coeff"), arg("offset"), "Returns expr * coeff + offset.", - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("constant", &LinearExpr::ConstantInt, arg("value"), - "Returns a constant linear expression.", - py::return_value_policy::automatic) - .def_static("constant", &LinearExpr::ConstantFloat, arg("value"), - "Returns a constant linear expression.", - py::return_value_policy::automatic) + .def_static( + "affine", &LinearExpr::AffineFloat, arg("expr").none(false), + arg("coeff"), arg("offset"), + DOC(operations_research, sat, python, LinearExpr, AffineFloat), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static( + "constant", &LinearExpr::ConstantInt, arg("value"), + DOC(operations_research, sat, python, LinearExpr, ConstantInt), + py::return_value_policy::automatic) + .def_static( + "constant", &LinearExpr::ConstantFloat, arg("value"), + DOC(operations_research, sat, python, LinearExpr, ConstantFloat), + py::return_value_policy::automatic) // Pre PEP8 compatibility layer. - .def_static("Sum", &LinearExpr::Sum, arg("exprs"), - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("Sum", &LinearExpr::MixedSum, arg("exprs"), - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "WeightedSum", - [](const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.size() != coeffs.size()) { - throw_error( - PyExc_ValueError, - "The number of expressions and coefficients must match."); - } - return LinearExpr::MixedWeightedSumInt(exprs, coeffs); - }, - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "WeightedSum", - [](const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.size() != coeffs.size()) { - throw_error( - PyExc_ValueError, - "The number of expressions and coefficients must match."); - } - return LinearExpr::MixedWeightedSumFloat(exprs, coeffs); - }, - py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("Sum", &SumArguments, py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("WeightedSum", &WeightedSumArguments, arg("expressions"), + arg("coefficients"), py::return_value_policy::automatic, + py::keep_alive<0, 1>()) .def_static("Term", &LinearExpr::TermInt, arg("expr").none(false), arg("coeff"), "Returns expr * coeff.", py::return_value_policy::automatic, py::keep_alive<0, 1>()) @@ -483,41 +557,57 @@ PYBIND11_MODULE(swig_helper, m) { // Methods. .def("__str__", &LinearExpr::ToString) .def("__repr__", &LinearExpr::DebugString) - .def("is_integer", &LinearExpr::IsInteger) + .def("is_integer", &LinearExpr::IsInteger, + DOC(operations_research, sat, python, LinearExpr, IsInteger)) // Operators. - // Note that we keep the 3 APIS (expr, int, double) instead of using an - // ExprOrValue argument as this is more efficient. + // Note that we keep the 3 APIS (expr, int, double) instead of using + // an ExprOrValue argument as this is more efficient. .def("__add__", &LinearExpr::Add, arg("other").none(false), py::return_value_policy::automatic, py::keep_alive<0, 1>(), + DOC(operations_research, sat, python, LinearExpr, Add), py::keep_alive<0, 2>()) .def("__add__", &LinearExpr::AddInt, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__add__", &LinearExpr::AddFloat, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddFloat), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__radd__", &LinearExpr::AddInt, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__radd__", &LinearExpr::AddFloat, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__sub__", &LinearExpr::Sub, arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, Sub), py::return_value_policy::automatic, py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def("__sub__", &LinearExpr::SubInt, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, SubInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__sub__", &LinearExpr::SubFloat, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, SubFloat), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__rsub__", &LinearExpr::RSubInt, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, RSubInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__rsub__", &LinearExpr::RSubFloat, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, RSubFloat), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__mul__", &LinearExpr::MulInt, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__mul__", &LinearExpr::MulFloat, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulFloat), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__rmul__", &LinearExpr::MulInt, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulInt), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__rmul__", &LinearExpr::MulFloat, arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulFloat), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__neg__", &LinearExpr::Neg, py::return_value_policy::automatic, + DOC(operations_research, sat, python, LinearExpr, Neg), py::keep_alive<0, 1>()) .def( "__eq__", @@ -525,6 +615,7 @@ PYBIND11_MODULE(swig_helper, m) { RaiseIfNone(rhs); return CheckBoundedLinearExpression(lhs->Eq(rhs), lhs, rhs); }, + DOC(operations_research, sat, python, LinearExpr, Eq), py::return_value_policy::automatic, py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( @@ -532,11 +623,12 @@ PYBIND11_MODULE(swig_helper, m) { [](LinearExpr* lhs, int64_t rhs) { if (rhs == std::numeric_limits::max() || rhs == std::numeric_limits::min()) { - throw_error(PyExc_ArithmeticError, - "== INT_MIN or INT_MAX is not supported"); + ThrowError(PyExc_ValueError, + "== INT_MIN or INT_MAX is not supported"); } return CheckBoundedLinearExpression(lhs->EqCst(rhs), lhs); }, + DOC(operations_research, sat, python, LinearExpr, EqCst), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def( "__ne__", @@ -544,6 +636,7 @@ PYBIND11_MODULE(swig_helper, m) { RaiseIfNone(rhs); return CheckBoundedLinearExpression(lhs->Ne(rhs), lhs, rhs); }, + DOC(operations_research, sat, python, LinearExpr, Ne), py::return_value_policy::automatic, py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( @@ -551,6 +644,7 @@ PYBIND11_MODULE(swig_helper, m) { [](LinearExpr* lhs, int64_t rhs) { return CheckBoundedLinearExpression(lhs->NeCst(rhs), lhs); }, + DOC(operations_research, sat, python, LinearExpr, NeCst), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def( "__le__", @@ -558,16 +652,18 @@ PYBIND11_MODULE(swig_helper, m) { RaiseIfNone(rhs); return CheckBoundedLinearExpression(lhs->Le(rhs), lhs, rhs); }, + DOC(operations_research, sat, python, LinearExpr, Le), py::return_value_policy::automatic, py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( "__le__", [](LinearExpr* lhs, int64_t rhs) { if (rhs == std::numeric_limits::min()) { - throw_error(PyExc_ArithmeticError, "<= INT_MIN is not supported"); + ThrowError(PyExc_ArithmeticError, "<= INT_MIN is not supported"); } return CheckBoundedLinearExpression(lhs->LeCst(rhs), lhs); }, + DOC(operations_research, sat, python, LinearExpr, LeCst), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def( "__lt__", @@ -575,16 +671,18 @@ PYBIND11_MODULE(swig_helper, m) { RaiseIfNone(rhs); return CheckBoundedLinearExpression(lhs->Lt(rhs), lhs, rhs); }, + DOC(operations_research, sat, python, LinearExpr, Lt), py::return_value_policy::automatic, py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( "__lt__", [](LinearExpr* lhs, int64_t rhs) { if (rhs == std::numeric_limits::min()) { - throw_error(PyExc_ArithmeticError, "< INT_MIN is not supported"); + ThrowError(PyExc_ArithmeticError, "< INT_MIN is not supported"); } return CheckBoundedLinearExpression(lhs->LtCst(rhs), lhs); }, + DOC(operations_research, sat, python, LinearExpr, LtCst), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def( "__ge__", @@ -592,16 +690,18 @@ PYBIND11_MODULE(swig_helper, m) { RaiseIfNone(rhs); return CheckBoundedLinearExpression(lhs->Ge(rhs), lhs, rhs); }, + DOC(operations_research, sat, python, LinearExpr, Ge), py::return_value_policy::automatic, py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( "__ge__", [](LinearExpr* lhs, int64_t rhs) { if (rhs == std::numeric_limits::max()) { - throw_error(PyExc_ArithmeticError, ">= INT_MAX is not supported"); + ThrowError(PyExc_ArithmeticError, ">= INT_MAX is not supported"); } return CheckBoundedLinearExpression(lhs->GeCst(rhs), lhs); }, + DOC(operations_research, sat, python, LinearExpr, GeCst), py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def( "__gt__", @@ -609,127 +709,140 @@ PYBIND11_MODULE(swig_helper, m) { RaiseIfNone(rhs); return CheckBoundedLinearExpression(lhs->Gt(rhs), lhs, rhs); }, + DOC(operations_research, sat, python, LinearExpr, Gt), py::return_value_policy::automatic, py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) .def( "__gt__", [](LinearExpr* lhs, int64_t rhs) { if (rhs == std::numeric_limits::max()) { - throw_error(PyExc_ArithmeticError, "> INT_MAX is not supported"); + ThrowError(PyExc_ArithmeticError, "> INT_MAX is not supported"); } return CheckBoundedLinearExpression(lhs->GtCst(rhs), lhs); }, + DOC(operations_research, sat, python, LinearExpr, GtCst), py::return_value_policy::automatic, py::keep_alive<0, 1>()) // Disable other operators as they are not supported. .def("__div__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error(PyExc_NotImplementedError, - "calling / on a linear expression is not supported, " - "please use CpModel.add_division_equality"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling / on a linear expression is not supported, " + "please use CpModel.add_division_equality"); }) .def("__truediv__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error(PyExc_NotImplementedError, - "calling // on a linear expression is not supported, " - "please use CpModel.add_division_equality"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling // on a linear expression is not supported, " + "please use CpModel.add_division_equality"); }) .def("__mod__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error(PyExc_NotImplementedError, - "calling %% on a linear expression is not supported, " - "please use CpModel.add_modulo_equality"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling %% on a linear expression is not supported, " + "please use CpModel.add_modulo_equality"); }) .def("__pow__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error(PyExc_NotImplementedError, - "calling ** on a linear expression is not supported, " - "please use CpModel.add_multiplication_equality"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling ** on a linear expression is not supported, " + "please use CpModel.add_multiplication_equality"); }) .def("__lshift__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error( - PyExc_NotImplementedError, - "calling left shift on a linear expression is not supported"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling left shift on a linear expression is not " + "supported"); }) .def("__rshift__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error( - PyExc_NotImplementedError, - "calling right shift on a linear expression is not supported"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling right shift on a linear expression is " + "not supported"); }) .def("__and__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error(PyExc_NotImplementedError, - "calling and on a linear expression is not supported"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling and on a linear expression is not supported"); }) .def("__or__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error(PyExc_NotImplementedError, - "calling or on a linear expression is not supported"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling or on a linear expression is not supported"); }) .def("__xor__", - [](LinearExpr* /*self*/, ExprOrValue /*other*/) { - throw_error(PyExc_NotImplementedError, - "calling xor on a linear expression is not supported"); + [](LinearExpr* /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling xor on a linear expression is not supported"); }) .def("__abs__", [](LinearExpr* /*self*/) { - throw_error( + ThrowError( PyExc_NotImplementedError, "calling abs() on a linear expression is not supported, " "please use CpModel.add_abs_equality"); }) .def("__bool__", [](LinearExpr* /*self*/) { - throw_error(PyExc_NotImplementedError, - "Evaluating a LinearExpr instance as a Boolean is " - "not supported."); + ThrowError(PyExc_NotImplementedError, + "Evaluating a LinearExpr instance as a Boolean is " + "not supported."); }); // Expose Internal classes, mostly for testing. - py::class_(m, "CanonicalFloatExpression") - .def(py::init()) - .def_property_readonly("vars", &CanonicalFloatExpression::vars) - .def_property_readonly("coeffs", &CanonicalFloatExpression::coeffs) - .def_property_readonly("offset", &CanonicalFloatExpression::offset); + py::class_( + m, "FlatFloatExpr", DOC(operations_research, sat, python, FlatFloatExpr)) + .def(py::init(), py::keep_alive<1, 2>()) + .def_property_readonly("vars", &FlatFloatExpr::vars) + .def_property_readonly("coeffs", &FlatFloatExpr::coeffs) + .def_property_readonly("offset", &FlatFloatExpr::offset); - py::class_(m, "CanonicalIntExpression") - .def(py::init()) - .def_property_readonly("vars", &CanonicalIntExpression::vars) - .def_property_readonly("coeffs", &CanonicalIntExpression::coeffs) - .def_property_readonly("offset", &CanonicalIntExpression::offset) - .def_property_readonly("ok", &CanonicalIntExpression::ok); + py::class_( + m, "FlatIntExpr", DOC(operations_research, sat, python, FlatIntExpr)) + .def(py::init([](LinearExpr* expr) { + FlatIntExpr* result = new FlatIntExpr(expr); + if (!result->ok()) { + ThrowError( + PyExc_TypeError, + absl::StrCat("Tried to build a FlatIntExpr from a linear " + "expression with " + "floating point coefficients or constants: ", + expr->DebugString())); + } + return result; + }), + py::keep_alive<1, 2>()) + .def_property_readonly("vars", &FlatIntExpr::vars) + .def_property_readonly("coeffs", &FlatIntExpr::coeffs) + .def_property_readonly("offset", &FlatIntExpr::offset) + .def_property_readonly("ok", &FlatIntExpr::ok); - py::class_(m, "FloatAffine") - .def(py::init()) + py::class_( + m, "FloatAffine", DOC(operations_research, sat, python, FloatAffine)) + .def(py::init(), py::keep_alive<1, 2>()) .def_property_readonly("expression", &FloatAffine::expression) .def_property_readonly("coefficient", &FloatAffine::coefficient) .def_property_readonly("offset", &FloatAffine::offset); - py::class_(m, "IntAffine") - .def(py::init()) + py::class_( + m, "IntAffine", DOC(operations_research, sat, python, IntAffine)) + .def(py::init(), py::keep_alive<1, 2>()) .def_property_readonly("expression", &IntAffine::expression) .def_property_readonly("coefficient", &IntAffine::coefficient) .def_property_readonly("offset", &IntAffine::offset); - py::class_(m, "Literal", kLiteralClassDoc) - .def_property_readonly("index", &Literal::index, - "The index of the literal in the model.") + py::class_( + m, "Literal", DOC(operations_research, sat, python, Literal)) + .def_property_readonly( + "index", &Literal::index, + DOC(operations_research, sat, python, Literal, index)) .def("negated", &Literal::negated, - R"doc( - Returns the negation of a literal (a Boolean variable or its negation). - - This method implements the logical negation of a Boolean variable. - It is only valid if the variable has a Boolean domain (0 or 1). - - Note that this method is nilpotent: `x.negated().negated() == x`. - )doc") + DOC(operations_research, sat, python, Literal, negated)) .def("__invert__", &Literal::negated, - "Returns the negation of the current literal.") + DOC(operations_research, sat, python, Literal, negated)) .def("__bool__", [](Literal* /*self*/) { - throw_error(PyExc_NotImplementedError, - "Evaluating a Literal as a Boolean valueis " - "not supported."); + ThrowError(PyExc_NotImplementedError, + "Evaluating a Literal as a Boolean valueis " + "not supported."); }) // PEP8 Compatibility. .def("Not", &Literal::negated) @@ -744,86 +857,95 @@ PYBIND11_MODULE(swig_helper, m) { // object. That means memory of the negated variable is onwed by the C++ // layer, but a reference is kept in python to link the lifetime of the // negated variable to the base variable. - py::class_(m, "BaseIntVar") + py::class_( + m, "BaseIntVar", DOC(operations_research, sat, python, BaseIntVar)) .def(py::init()) // Integer variable. .def(py::init()) // Potential Boolean variable. - .def_property_readonly("index", &BaseIntVar::index, - "The index of the variable in the model.") - .def_property_readonly("is_boolean", &BaseIntVar::is_boolean, - "Whether the variable is Boolean.") + .def_property_readonly( + "index", &BaseIntVar::index, + DOC(operations_research, sat, python, BaseIntVar, index)) + .def_property_readonly( + "is_boolean", &BaseIntVar::is_boolean, + DOC(operations_research, sat, python, BaseIntVar, is_boolean)) .def("__str__", &BaseIntVar::ToString) .def("__repr__", &BaseIntVar::DebugString) .def( "negated", [](BaseIntVar* self) { if (!self->is_boolean()) { - throw_error(PyExc_TypeError, - "negated() is only supported for Boolean variables."); + ThrowError(PyExc_TypeError, + "negated() is only supported for Boolean variables."); } return self->negated(); }, - "Returns the negation of the current Boolean variable.", + DOC(operations_research, sat, python, BaseIntVar, negated), py::return_value_policy::reference_internal) .def( "__invert__", [](BaseIntVar* self) { if (!self->is_boolean()) { - throw_error(PyExc_TypeError, - "negated() is only supported for Boolean variables."); + ThrowError(PyExc_TypeError, + "negated() is only supported for Boolean variables."); } return self->negated(); }, - "Returns the negation of the current Boolean variable.", + DOC(operations_research, sat, python, BaseIntVar, negated), py::return_value_policy::reference_internal) // PEP8 Compatibility. .def( "Not", [](BaseIntVar* self) { if (!self->is_boolean()) { - throw_error(PyExc_TypeError, - "negated() is only supported for Boolean variables."); + ThrowError(PyExc_TypeError, + "negated() is only supported for Boolean variables."); } return self->negated(); }, py::return_value_policy::reference_internal); // Memory management: - // - Do we need a reference_internal (that add a py::keep_alive<1, 0>() rule) + // - Do we need a reference_internal (that add a py::keep_alive<1, 0>() + // rule) // or just a reference ? - py::class_(m, "NotBooleanVariable") - .def(py::init()) - .def_property_readonly("index", &NotBooleanVariable::index, - "The index of the variable in the model.") + py::class_( + m, "NotBooleanVariable", + DOC(operations_research, sat, python, NotBooleanVariable)) + .def_property_readonly( + "index", &NotBooleanVariable::index, + DOC(operations_research, sat, python, NotBooleanVariable, index)) .def("__str__", &NotBooleanVariable::ToString) .def("__repr__", &NotBooleanVariable::DebugString) .def("negated", &NotBooleanVariable::negated, - "Returns the negation of the current Boolean variable.", + DOC(operations_research, sat, python, NotBooleanVariable, negated), py::return_value_policy::reference_internal) .def("__invert__", &NotBooleanVariable::negated, - "Returns the negation of the current Boolean variable.", + DOC(operations_research, sat, python, NotBooleanVariable, negated), py::return_value_policy::reference_internal) .def("Not", &NotBooleanVariable::negated, "Returns the negation of the current Boolean variable.", py::return_value_policy::reference_internal); - py::class_(m, "BoundedLinearExpression") - .def(py::init, std::vector, - int64_t, Domain>()) + py::class_( + m, "BoundedLinearExpression", + DOC(operations_research, sat, python, BoundedLinearExpression)) + .def(py::init(), py::keep_alive<1, 2>()) + .def(py::init(), + py::keep_alive<1, 2>(), py::keep_alive<1, 3>()) .def_property_readonly("bounds", &BoundedLinearExpression::bounds) .def_property_readonly("vars", &BoundedLinearExpression::vars) .def_property_readonly("coeffs", &BoundedLinearExpression::coeffs) .def_property_readonly("offset", &BoundedLinearExpression::offset) + .def_property_readonly("ok", &BoundedLinearExpression::ok) .def("__str__", &BoundedLinearExpression::ToString) .def("__repr__", &BoundedLinearExpression::DebugString) .def("__bool__", [](const BoundedLinearExpression& self) { bool result; if (self.CastToBool(&result)) return result; - throw_error(PyExc_NotImplementedError, - absl::StrCat("Evaluating a BoundedLinearExpression '", - self.ToString(), - "'instance as a Boolean is " - "not supported.") - .c_str()); + ThrowError(PyExc_NotImplementedError, + absl::StrCat("Evaluating a BoundedLinearExpression '", + self.ToString(), + "'instance as a Boolean is " + "not supported.")); return false; }); } // NOLINT(readability/fn_size) diff --git a/ortools/sat/python/cp_model_helper_test.py b/ortools/sat/python/cp_model_helper_test.py index ce87d7a402..c73f6cbd2b 100644 --- a/ortools/sat/python/cp_model_helper_test.py +++ b/ortools/sat/python/cp_model_helper_test.py @@ -12,85 +12,357 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ortools.sat.python.cp_model_helper.""" +"""Unit tests for ortools.sat.python.cmh.""" + +import sys from absl.testing import absltest -import numpy as np -from ortools.sat.python import cp_model_helper + +from google.protobuf import text_format +from ortools.sat import cp_model_pb2 +from ortools.sat import sat_parameters_pb2 +from ortools.sat.python import cp_model_helper as cmh + + +class Callback(cmh.SolutionCallback): + + def __init__(self): + cmh.SolutionCallback.__init__(self) + self.__solution_count = 0 + + def OnSolutionCallback(self): + print("New Solution") + self.__solution_count += 1 + + def solution_count(self): + return self.__solution_count + + +class BestBoundCallback: + + def __init__(self): + self.best_bound: float = 0.0 + + def new_best_bound(self, bb: float): + self.best_bound = bb + + +class TestIntVar(cmh.BaseIntVar): + + def __init__(self, index: int, name: str, is_boolean: bool = False) -> None: + cmh.BaseIntVar.__init__(self, index, is_boolean) + self._name = name + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return self._name class CpModelHelperTest(absltest.TestCase): - def test_is_boolean(self): - print("test_is_boolean") - self.assertTrue(cp_model_helper.is_boolean(True)) - self.assertTrue(cp_model_helper.is_boolean(False)) - self.assertFalse(cp_model_helper.is_boolean(1)) - self.assertFalse(cp_model_helper.is_boolean(0)) - self.assertTrue(cp_model_helper.is_boolean(np.bool_(1))) - self.assertTrue(cp_model_helper.is_boolean(np.bool_(0))) + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() - def testto_capped_int64(self): - print("testto_capped_int64") - self.assertEqual( - cp_model_helper.to_capped_int64(cp_model_helper.INT_MAX), - cp_model_helper.INT_MAX, - ) - self.assertEqual( - cp_model_helper.to_capped_int64(cp_model_helper.INT_MAX + 1), - cp_model_helper.INT_MAX, - ) - self.assertEqual( - cp_model_helper.to_capped_int64(cp_model_helper.INT_MIN), - cp_model_helper.INT_MIN, - ) - self.assertEqual( - cp_model_helper.to_capped_int64(cp_model_helper.INT_MIN - 1), - cp_model_helper.INT_MIN, - ) - self.assertEqual(cp_model_helper.to_capped_int64(15), 15) + def testVariableDomain(self): + model_string = """ + variables { domain: [ -10, 10 ] } + variables { domain: [ -5, -5, 3, 6 ] } + """ + model = cp_model_pb2.CpModelProto() + self.assertTrue(text_format.Parse(model_string, model)) - def testcapped_subtraction(self): - print("testcapped_subtraction") - self.assertEqual(cp_model_helper.capped_subtraction(10, 5), 5) - self.assertEqual( - cp_model_helper.capped_subtraction(cp_model_helper.INT_MIN, 5), - cp_model_helper.INT_MIN, - ) - self.assertEqual( - cp_model_helper.capped_subtraction(cp_model_helper.INT_MIN, -5), - cp_model_helper.INT_MIN, - ) - self.assertEqual( - cp_model_helper.capped_subtraction(cp_model_helper.INT_MAX, 5), - cp_model_helper.INT_MAX, - ) - self.assertEqual( - cp_model_helper.capped_subtraction(cp_model_helper.INT_MAX, -5), - cp_model_helper.INT_MAX, - ) - self.assertEqual( - cp_model_helper.capped_subtraction(2, cp_model_helper.INT_MIN), - cp_model_helper.INT_MAX, - ) - self.assertEqual( - cp_model_helper.capped_subtraction(2, cp_model_helper.INT_MAX), - cp_model_helper.INT_MIN, - ) - self.assertRaises( - OverflowError, - cp_model_helper.capped_subtraction, - cp_model_helper.INT_MAX, - cp_model_helper.INT_MAX, - ) - self.assertRaises( - OverflowError, - cp_model_helper.capped_subtraction, - cp_model_helper.INT_MIN, - cp_model_helper.INT_MIN, - ) - self.assertRaises(TypeError, cp_model_helper.capped_subtraction, 5, "dummy") - self.assertRaises(TypeError, cp_model_helper.capped_subtraction, "dummy", 5) + d0 = cmh.CpSatHelper.variable_domain(model.variables[0]) + d1 = cmh.CpSatHelper.variable_domain(model.variables[1]) + + self.assertEqual(d0.flattened_intervals(), [-10, 10]) + self.assertEqual(d1.flattened_intervals(), [-5, -5, 3, 6]) + + def testSimpleSolve(self): + model_string = """ + variables { domain: -10 domain: 10 } + variables { domain: -10 domain: 10 } + variables { domain: -461168601842738790 domain: 461168601842738790 } + constraints { + linear { + vars: 0 + vars: 1 + coeffs: 1 + coeffs: 1 + domain: -9223372036854775808 + domain: 9223372036854775807 + } + } + constraints { + linear { + vars: 0 + vars: 1 + vars: 2 + coeffs: 1 + coeffs: 2 + coeffs: -1 + domain: 0 + domain: 9223372036854775807 + } + } + objective { + vars: -3 + coeffs: 1 + scaling_factor: -1 + }""" + model = cp_model_pb2.CpModelProto() + self.assertTrue(text_format.Parse(model_string, model)) + + solve_wrapper = cmh.SolveWrapper() + response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(30.0, response_wrapper.objective_value()) + + def testSimpleSolveWithCore(self): + model_string = """ + variables { domain: -10 domain: 10 } + variables { domain: -10 domain: 10 } + variables { domain: -461168601842738790 domain: 461168601842738790 } + constraints { + linear { + vars: 0 + vars: 1 + coeffs: 1 + coeffs: 1 + domain: -9223372036854775808 + domain: 9223372036854775807 + } + } + constraints { + linear { + vars: 0 + vars: 1 + vars: 2 + coeffs: 1 + coeffs: 2 + coeffs: -1 + domain: 0 + domain: 9223372036854775807 + } + } + objective { + vars: -3 + coeffs: 1 + scaling_factor: -1 + }""" + model = cp_model_pb2.CpModelProto() + self.assertTrue(text_format.Parse(model_string, model)) + + parameters = sat_parameters_pb2.SatParameters(optimize_with_core=True) + + solve_wrapper = cmh.SolveWrapper() + solve_wrapper.set_parameters(parameters) + response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(30.0, response_wrapper.objective_value()) + + def testSimpleSolveWithProtoApi(self): + model = cp_model_pb2.CpModelProto() + x = model.variables.add() + x.domain.extend([-10, 10]) + y = model.variables.add() + y.domain.extend([-10, 10]) + obj_var = model.variables.add() + obj_var.domain.extend([-461168601842738790, 461168601842738790]) + ct = model.constraints.add() + ct.linear.vars.extend([0, 1, 2]) + ct.linear.coeffs.extend([1, 2, -1]) + ct.linear.domain.extend([0, 0]) + model.objective.vars.append(-3) + model.objective.coeffs.append(1) + model.objective.scaling_factor = -1 + + solve_wrapper = cmh.SolveWrapper() + response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(30.0, response_wrapper.objective_value()) + self.assertEqual(30.0, response_wrapper.best_objective_bound()) + + def testSolutionCallback(self): + model_string = """ + variables { domain: 0 domain: 5 } + variables { domain: 0 domain: 5 } + constraints { + linear { vars: 0 vars: 1 coeffs: 1 coeffs: 1 domain: 6 domain: 6 } } + """ + model = cp_model_pb2.CpModelProto() + self.assertTrue(text_format.Parse(model_string, model)) + + solve_wrapper = cmh.SolveWrapper() + callback = Callback() + solve_wrapper.add_solution_callback(callback) + params = sat_parameters_pb2.SatParameters() + params.enumerate_all_solutions = True + solve_wrapper.set_parameters(params) + response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + + self.assertEqual(5, callback.solution_count()) + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + + def testBestBoundCallback(self): + model_string = """ + variables { domain: 0 domain: 1 } + variables { domain: 0 domain: 1 } + variables { domain: 0 domain: 1 } + variables { domain: 0 domain: 1 } + constraints { bool_or { literals: [0, 1, 2, 3] } } + objective { + vars: [0, 1, 2, 3] + coeffs: [3, 2, 4, 5] + offset: 0.6 + } + """ + model = cp_model_pb2.CpModelProto() + self.assertTrue(text_format.Parse(model_string, model)) + + solve_wrapper = cmh.SolveWrapper() + best_bound_callback = BestBoundCallback() + solve_wrapper.add_best_bound_callback(best_bound_callback.new_best_bound) + params = sat_parameters_pb2.SatParameters() + params.num_workers = 1 + params.linearization_level = 2 + params.log_search_progress = True + solve_wrapper.set_parameters(params) + response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + + self.assertEqual(2.6, best_bound_callback.best_bound) + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + + def testModelStats(self): + model_string = """ + variables { domain: -10 domain: 10 } + variables { domain: -10 domain: 10 } + variables { domain: -1000 domain: 1000 } + constraints { + linear { + vars: 0 + vars: 1 + coeffs: 1 + coeffs: 1 + domain: -1000 + domain: 1000 + } + } + constraints { + linear { + vars: 0 + vars: 1 + vars: 2 + coeffs: 1 + coeffs: 2 + coeffs: -1 + domain: 0 + domain: 1000 + } + } + objective { + vars: -3 + coeffs: 1 + scaling_factor: -1 + } + name: 'testModelStats' + """ + model = cp_model_pb2.CpModelProto() + self.assertTrue(text_format.Parse(model_string, model)) + stats = cmh.CpSatHelper.model_stats(model) + self.assertTrue(stats) + + def testIntLinExpr(self): + x = TestIntVar(0, "x") + self.assertTrue(x.is_integer()) + self.assertIsInstance(x, cmh.BaseIntVar) + self.assertIsInstance(x, cmh.LinearExpr) + e1 = x + 2 + self.assertTrue(e1.is_integer()) + self.assertEqual(str(e1), "(x + 2)") + e2 = 3 + x + self.assertTrue(e2.is_integer()) + self.assertEqual(str(e2), "(x + 3)") + y = TestIntVar(1, "y") + e3 = y * 5 + self.assertTrue(e3.is_integer()) + self.assertEqual(str(e3), "(5 * y)") + e4 = -2 * y + self.assertTrue(e4.is_integer()) + self.assertEqual(str(e4), "(-2 * y)") + e5 = x - 1 + self.assertTrue(e5.is_integer()) + self.assertEqual(str(e5), "(x - 1)") + e6 = x - 2 * y + self.assertTrue(e6.is_integer()) + self.assertEqual(str(e6), "(x - (2 * y))") + z = TestIntVar(2, "z", True) + e7 = -z + self.assertTrue(e7.is_integer()) + self.assertEqual(str(e7), "(-z)") + not_z = ~z + self.assertTrue(not_z.is_integer()) + self.assertEqual(str(not_z), "not(z)") + self.assertEqual(not_z.index, -3) + + e8 = cmh.LinearExpr.sum([x, y, z]) + self.assertEqual(str(e8), "(x + y + z)") + e9 = cmh.LinearExpr.sum([x, y, z, 11]) + self.assertEqual(str(e9), "(x + y + z + 11)") + e10 = cmh.LinearExpr.weighted_sum([x, y, z], [1, 2, 3]) + self.assertEqual(str(e10), "(x + 2 * y + 3 * z)") + e11 = cmh.LinearExpr.weighted_sum([x, y, z, 5], [1, 2, 3, -1]) + self.assertEqual(str(e11), "(x + 2 * y + 3 * z - 5)") + + def testFloatLinExpr(self): + x = TestIntVar(0, "x") + self.assertTrue(x.is_integer()) + self.assertIsInstance(x, TestIntVar) + self.assertIsInstance(x, cmh.LinearExpr) + e1 = x + 2.5 + self.assertFalse(e1.is_integer()) + self.assertEqual(str(e1), "(x + 2.5)") + e2 = 3.1 + x + self.assertFalse(e2.is_integer()) + self.assertEqual(str(e2), "(x + 3.1)") + y = TestIntVar(1, "y") + e3 = y * 5.2 + self.assertFalse(e3.is_integer()) + self.assertEqual(str(e3), "(5.2 * y)") + e4 = -2.25 * y + self.assertFalse(e4.is_integer()) + self.assertEqual(str(e4), "(-2.25 * y)") + e5 = x - 1.1 + self.assertFalse(e5.is_integer()) + self.assertEqual(str(e5), "(x - 1.1)") + e6 = x + 2.4 * y + self.assertFalse(e6.is_integer()) + self.assertEqual(str(e6), "(x + (2.4 * y))") + e7 = x - 2.4 * y + self.assertFalse(e7.is_integer()) + self.assertEqual(str(e7), "(x - (2.4 * y))") + + z = TestIntVar(2, "z") + e8 = cmh.LinearExpr.sum([x, y, z, -2]) + self.assertTrue(e8.is_integer()) + self.assertEqual(str(e8), "(x + y + z - 2)") + e9 = cmh.LinearExpr.sum([x, y, z, 1.5]) + self.assertFalse(e9.is_integer()) + self.assertEqual(str(e9), "(x + y + z + 1.5)") + e10 = cmh.LinearExpr.weighted_sum([x, y, z], [1.0, 2.25, 5.5]) + self.assertFalse(e10.is_integer()) + self.assertEqual(str(e10), "(x + 2.25 * y + 5.5 * z)") + e11 = cmh.LinearExpr.weighted_sum([x, y, z, 1.5], [1.0, 2.25, 5.5, -1]) + self.assertFalse(e11.is_integer()) + self.assertEqual(str(e11), "(x + 2.25 * y + 5.5 * z - 1.5)") + e12 = (x + 2) * 3.1 + self.assertFalse(e12.is_integer()) + self.assertEqual(str(e12), "(3.1 * (x + 2))") if __name__ == "__main__": diff --git a/ortools/sat/python/cp_model_helper.py b/ortools/sat/python/cp_model_numbers.py similarity index 65% rename from ortools/sat/python/cp_model_helper.py rename to ortools/sat/python/cp_model_numbers.py index fe92fb50b4..26b7928df5 100644 --- a/ortools/sat/python/cp_model_helper.py +++ b/ortools/sat/python/cp_model_numbers.py @@ -14,14 +14,12 @@ """helpers methods for the cp_model module.""" import numbers -from typing import Any, Union +from typing import Any import numpy as np INT_MIN = -9223372036854775808 # hardcoded to be platform independent. INT_MAX = 9223372036854775807 -INT32_MIN = -2147483648 -INT32_MAX = 2147483647 def is_boolean(x: Any) -> bool: @@ -33,33 +31,6 @@ def is_boolean(x: Any) -> bool: return False -def is_zero(x: Any) -> bool: - """Checks if the x is 0 or 0.0.""" - if isinstance(x, numbers.Integral): - return int(x) == 0 - if isinstance(x, numbers.Real): - return float(x) == 0.0 - return False - - -def is_one(x: Any) -> bool: - """Checks if x is 1 or 1.0.""" - if isinstance(x, numbers.Integral): - return int(x) == 1 - if isinstance(x, numbers.Real): - return float(x) == 1.0 - return False - - -def is_minus_one(x: Any) -> bool: - """Checks if x is -1 or -1.0 .""" - if isinstance(x, numbers.Integral): - return int(x) == -1 - if isinstance(x, numbers.Real): - return float(x) == -1.0 - return False - - def assert_is_zero_or_one(x: Any) -> int: """Asserts that x is 0 or 1 and returns it as an int.""" if not isinstance(x, numbers.Integral): @@ -70,15 +41,6 @@ def assert_is_zero_or_one(x: Any) -> int: return x_as_int -def assert_is_a_number(x: Any) -> Union[int, float]: - """Asserts that x is a number and returns it casted to an int or a float.""" - if isinstance(x, numbers.Integral): - return int(x) - if isinstance(x, numbers.Real): - return float(x) - raise TypeError(f"Not a number: {x} of type {type(x)}") - - def to_capped_int64(v: int) -> int: """Restrict v within [INT_MIN..INT_MAX] range.""" if v > INT_MAX: diff --git a/ortools/sat/python/cp_model_numbers_test.py b/ortools/sat/python/cp_model_numbers_test.py new file mode 100644 index 0000000000..7fbc9e2a2a --- /dev/null +++ b/ortools/sat/python/cp_model_numbers_test.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from absl.testing import absltest +import numpy as np + +from ortools.sat.python import cp_model_numbers as cmn + + +class CpModelNumbersTest(absltest.TestCase): + + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + + def test_is_boolean(self): + self.assertTrue(cmn.is_boolean(True)) + self.assertTrue(cmn.is_boolean(False)) + self.assertFalse(cmn.is_boolean(1)) + self.assertFalse(cmn.is_boolean(0)) + self.assertTrue(cmn.is_boolean(np.bool_(1))) + self.assertTrue(cmn.is_boolean(np.bool_(0))) + + def testto_capped_int64(self): + self.assertEqual(cmn.to_capped_int64(cmn.INT_MAX), cmn.INT_MAX) + self.assertEqual(cmn.to_capped_int64(cmn.INT_MAX + 1), cmn.INT_MAX) + self.assertEqual(cmn.to_capped_int64(cmn.INT_MIN), cmn.INT_MIN) + self.assertEqual(cmn.to_capped_int64(cmn.INT_MIN - 1), cmn.INT_MIN) + self.assertEqual(cmn.to_capped_int64(15), 15) + + def testcapped_subtraction(self): + self.assertEqual(cmn.capped_subtraction(10, 5), 5) + self.assertEqual(cmn.capped_subtraction(cmn.INT_MIN, 5), cmn.INT_MIN) + self.assertEqual(cmn.capped_subtraction(cmn.INT_MIN, -5), cmn.INT_MIN) + self.assertEqual(cmn.capped_subtraction(cmn.INT_MAX, 5), cmn.INT_MAX) + self.assertEqual(cmn.capped_subtraction(cmn.INT_MAX, -5), cmn.INT_MAX) + self.assertEqual(cmn.capped_subtraction(2, cmn.INT_MIN), cmn.INT_MAX) + self.assertEqual(cmn.capped_subtraction(2, cmn.INT_MAX), cmn.INT_MIN) + self.assertRaises( + OverflowError, cmn.capped_subtraction, cmn.INT_MAX, cmn.INT_MAX + ) + self.assertRaises( + OverflowError, cmn.capped_subtraction, cmn.INT_MIN, cmn.INT_MIN + ) + self.assertRaises(TypeError, cmn.capped_subtraction, 5, "dummy") + self.assertRaises(TypeError, cmn.capped_subtraction, "dummy", 5) + + +if __name__ == "__main__": + absltest.main() diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index be9839bc77..2d6220c275 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -13,14 +13,16 @@ # limitations under the License. import itertools +import sys import time from absl.testing import absltest +import numpy as np import pandas as pd from ortools.sat import cp_model_pb2 from ortools.sat.python import cp_model -from ortools.sat.python import swig_helper +from ortools.sat.python import cp_model_helper as cmh class SolutionCounter(cp_model.CpSolverSolutionCallback): @@ -151,8 +153,11 @@ class BestBoundTimeCallback: class CpModelTest(absltest.TestCase): + def tearDown(self) -> None: + super().tearDown() + sys.stdout.flush() + def testCreateIntegerVariable(self) -> None: - print("testCreateIntegerVariable") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") self.assertEqual("x", str(x)) @@ -176,7 +181,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual("5", str(cst)) def testLiteral(self) -> None: - print("testLiteral") model = cp_model.CpModel() x = model.new_bool_var("x") self.assertEqual("x", str(x)) @@ -198,7 +202,6 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, z.__invert__) def testNegation(self) -> None: - print("testNegation") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") b = model.new_bool_var("b") @@ -211,7 +214,6 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, x.negated) def testEqualityOverload(self) -> None: - print("testEqualityOverload") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(0, 5, "y") @@ -219,7 +221,6 @@ class CpModelTest(absltest.TestCase): self.assertNotEqual(x, y) def testLinear(self) -> None: - print("testLinear") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -231,7 +232,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-5, solver.value(y)) def testLinearConstraint(self) -> None: - print("testLinear") model = cp_model.CpModel() model.add_linear_constraint(5, 0, 10) model.add_linear_constraint(-1, 0, 10) @@ -242,7 +242,6 @@ class CpModelTest(absltest.TestCase): self.assertEmpty(model.proto.constraints[1].bool_or.literals) def testLinearNonEqual(self) -> None: - print("testLinearNonEqual") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -254,7 +253,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(cp_model.INT_MAX, ct.linear.domain[3]) def testEq(self) -> None: - print("testEq") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") ct = model.add(x == 2).proto @@ -265,7 +263,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(2, ct.linear.domain[1]) def testGe(self) -> None: - print("testGe") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") ct = model.add(x >= 2).proto @@ -276,7 +273,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(cp_model.INT_MAX, ct.linear.domain[1]) def testGt(self) -> None: - print("testGt") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") ct = model.add(x > 2).proto @@ -287,7 +283,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(cp_model.INT_MAX, ct.linear.domain[1]) def testLe(self) -> None: - print("testLe") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") ct = model.add(x <= 2).proto @@ -298,7 +293,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(2, ct.linear.domain[1]) def testLt(self) -> None: - print("testLt") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") ct = model.add(x < 2).proto @@ -309,7 +303,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(1, ct.linear.domain[1]) def testEqVar(self) -> None: - print("testEqVar") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -323,7 +316,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(2, ct.linear.domain[1]) def testGeVar(self) -> None: - print("testGeVar") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -338,7 +330,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(cp_model.INT_MAX, ct.linear.domain[1]) def testGtVar(self) -> None: - print("testGeVar") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -353,7 +344,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(cp_model.INT_MAX, ct.linear.domain[1]) def testLeVar(self) -> None: - print("testLeVar") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -368,7 +358,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(1, ct.linear.domain[1]) def testLtVar(self) -> None: - print("testLtVar") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -383,7 +372,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(0, ct.linear.domain[1]) def testLinearNonEqualWithConstant(self) -> None: - print("testLinearNonEqualWithConstant") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -396,7 +384,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(cp_model.INT_MAX, ct.linear.domain[3]) def testLinearWithEnforcement(self) -> None: - print("testLinearWithEnforcement") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -416,7 +403,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(2, model.proto.constraints[2].enforcement_literal[1]) def testConstraintWithName(self) -> None: - print("testConstraintWithName") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -424,7 +410,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual("test_constraint", ct.name) def testNaturalApiMinimize(self) -> None: - print("testNaturalApiMinimize") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -438,7 +423,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-10.0, solver.objective_value) def testNaturalApiMaximizeFloat(self) -> None: - print("testNaturalApiMaximizeFloat") model = cp_model.CpModel() x = model.new_bool_var("x") y = model.new_int_var(0, 10, "y") @@ -451,7 +435,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(16.1, solver.objective_value) def testNaturalApiMaximizeComplex(self) -> None: - print("testNaturalApiMaximizeComplex") model = cp_model.CpModel() x1 = model.new_bool_var("x1") x2 = model.new_bool_var("x1") @@ -476,7 +459,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(8, solver.objective_value) def testNaturalApiMaximize(self) -> None: - print("testNaturalApiMaximize") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -489,7 +471,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(17, solver.objective_value) def testMinimizeConstant(self) -> None: - print("testMinimizeConstant") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") model.add(x >= -1) @@ -499,7 +480,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(10, solver.objective_value) def testMaximizeConstant(self) -> None: - print("testMinimizeConstant") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") model.add(x >= -1) @@ -509,7 +489,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(5, solver.objective_value) def testAddTrue(self) -> None: - print("testAddTrue") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") model.add(3 >= -1) @@ -519,7 +498,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-10, solver.value(x)) def testAddFalse(self) -> None: - print("testAddFalse") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") model.add(3 <= -1) @@ -528,9 +506,8 @@ class CpModelTest(absltest.TestCase): self.assertEqual("INFEASIBLE", solver.status_name(solver.solve(model))) def testSum(self) -> None: - print("testSum") model = cp_model.CpModel() - x = [model.new_int_var(0, 2, "x%i" % i) for i in range(100)] + x = [model.new_int_var(0, 2, f"x{i}") for i in range(100)] model.add(sum(x) <= 1) model.maximize(x[99]) solver = cp_model.CpSolver() @@ -539,10 +516,170 @@ class CpModelTest(absltest.TestCase): for i in range(100): self.assertEqual(solver.value(x[i]), 1 if i == 99 else 0) - def testSumWithApi(self) -> None: - print("testSumWithApi") + def testSumParsing(self) -> None: model = cp_model.CpModel() - x = [model.new_int_var(0, 2, "x%i" % i) for i in range(100)] + x = [model.new_int_var(0, 2, f"x{i}") for i in range(5)] + s1 = cp_model.LinearExpr.sum(x) + self.assertTrue(s1.is_integer()) + flat_s1 = cp_model.FlatIntExpr(s1) + self.assertLen(flat_s1.vars, 5) + self.assertEqual(0, flat_s1.offset) + + s2 = cp_model.LinearExpr.sum(x[0], x[2], x[4]) + self.assertTrue(s2.is_integer()) + flat_s2 = cp_model.FlatIntExpr(s2) + self.assertLen(flat_s2.vars, 3) + self.assertEqual(0, flat_s2.offset) + + s3 = cp_model.LinearExpr.sum(x[0], x[2], 2, x[4], -4) + self.assertTrue(s3.is_integer()) + flat_s3 = cp_model.FlatIntExpr(s3) + self.assertLen(flat_s3.vars, 3) + self.assertEqual(-2, flat_s3.offset) + + s4 = cp_model.LinearExpr.sum(x[0], x[2], 2.5) + self.assertFalse(s4.is_integer()) + flat_s4 = cp_model.FlatFloatExpr(s4) + self.assertLen(flat_s4.vars, 2) + self.assertEqual(2.5, flat_s4.offset) + + s5 = cp_model.LinearExpr.sum(x[0], x[2], 2, 1.5) + self.assertFalse(s5.is_integer()) + flat_s5 = cp_model.FlatFloatExpr(s5) + self.assertLen(flat_s5.vars, 2) + self.assertEqual(3.5, flat_s5.offset) + self.assertEqual(str(s5), "(x0 + x2 + 3.5)") + + s5b = cp_model.LinearExpr.sum(x[0], x[2], 2, -2.5) + self.assertFalse(s5b.is_integer()) + self.assertEqual(str(s5b), "(x0 + x2 - 0.5)") + flat_s5b = cp_model.FlatFloatExpr(s5b) + self.assertLen(flat_s5b.vars, 2) + self.assertEqual(-0.5, flat_s5b.offset) + + s6 = cp_model.LinearExpr.sum(x[0], x[2], np.int8(-1), np.int64(-4)) + self.assertTrue(s6.is_integer()) + flat_s6 = cp_model.FlatIntExpr(s6) + self.assertLen(flat_s6.vars, 2) + self.assertEqual(-5, flat_s6.offset) + + s7 = cp_model.LinearExpr.sum(x[0], x[2], np.float64(2.0), np.float32(1.5)) + self.assertFalse(s7.is_integer()) + flat_s7 = cp_model.FlatFloatExpr(s7) + self.assertLen(flat_s7.vars, 2) + self.assertEqual(3.5, flat_s7.offset) + + s8 = cp_model.LinearExpr.sum(x[0], 3) + self.assertTrue(s8.is_integer()) + self.assertIsInstance(s8, cmh.IntAffine) + self.assertEqual(s8.expression, x[0]) + self.assertEqual(s8.coefficient, 1) + self.assertEqual(s8.offset, 3) + + s9 = cp_model.LinearExpr.sum(x[0], -2.1) + self.assertFalse(s9.is_integer()) + self.assertIsInstance(s9, cmh.FloatAffine) + self.assertEqual(s9.expression, x[0]) + self.assertEqual(s9.coefficient, 1.0) + self.assertEqual(s9.offset, -2.1) + self.assertEqual(str(s9), "(x0 - 2.1)") + + s10 = cp_model.LinearExpr.sum(x[0], 1, -1) + self.assertTrue(s10.is_integer()) + self.assertIsInstance(s10, cp_model.IntVar) + self.assertEqual(s10, x[0]) + + s11 = cp_model.LinearExpr.sum(x[0]) + self.assertTrue(s11.is_integer()) + self.assertIsInstance(s11, cp_model.IntVar) + self.assertEqual(s11, x[0]) + + s12 = cp_model.LinearExpr.sum(x[0], -x[2], -3) + self.assertEqual(str(s12), "(x0 + (-x2) - 3)") + self.assertEqual( + repr(s12), + "SumArray(x0(0..2), IntAffine(expr=x2(0..2), coeff=-1, offset=0)," + " int_offset=-3)", + ) + flat_int_s12 = cp_model.FlatIntExpr(s12) + self.assertEqual(str(flat_int_s12), "(x0 - x2 - 3)") + self.assertEqual( + repr(flat_int_s12), + "FlatIntExpr([x0(0..2), x2(0..2)], [1, -1], -3)", + ) + flat_float_s12 = cp_model.FlatFloatExpr(s12) + self.assertEqual(str(flat_float_s12), "(x0 - x2 - 3)") + self.assertEqual( + repr(flat_float_s12), + "FlatFloatExpr([x0(0..2), x2(0..2)], [1, -1], -3)", + ) + + class FakeNpDTypeA: + + def __init__(self): + self.dtype = 2 + pass + + def __str__(self): + return "FakeNpDTypeA" + + class FakeNpDTypeB: + + def __init__(self): + self.is_integer = False + pass + + def __str__(self): + return "FakeNpDTypeB" + + with self.assertRaises(TypeError): + cp_model.LinearExpr.sum(x[0], x[2], "foo") + + with self.assertRaises(TypeError): + cp_model.LinearExpr.sum(x[0], x[2], FakeNpDTypeA()) + + with self.assertRaises(TypeError): + cp_model.LinearExpr.sum(x[0], x[2], FakeNpDTypeB()) + + def testWeightedSumParsing(self) -> None: + model = cp_model.CpModel() + x = [model.new_int_var(0, 2, f"x{i}") for i in range(5)] + c = [1, -2, 2, 3, 0.0] + float_c = [1, -1.0, 2, 3, 0.0] + + s1 = cp_model.LinearExpr.weighted_sum(x, c) + self.assertTrue(s1.is_integer()) + flat_s1 = cp_model.FlatIntExpr(s1) + self.assertLen(flat_s1.vars, 4) + self.assertEqual(0, flat_s1.offset) + + s2 = cp_model.LinearExpr.weighted_sum(x, float_c) + self.assertFalse(s2.is_integer()) + flat_s2 = cp_model.FlatFloatExpr(s2) + self.assertLen(flat_s2.vars, 4) + self.assertEqual(0, flat_s2.offset) + + s3 = cp_model.LinearExpr.weighted_sum(x + [2], c + [-1]) + self.assertTrue(s3.is_integer()) + flat_s3 = cp_model.FlatIntExpr(s3) + self.assertLen(flat_s3.vars, 4) + self.assertEqual(-2, flat_s3.offset) + + s4 = cp_model.LinearExpr.weighted_sum(x + [2], float_c + [-1.0]) + self.assertFalse(s4.is_integer()) + flat_s4 = cp_model.FlatFloatExpr(s4) + self.assertLen(flat_s4.vars, 4) + self.assertEqual(-2, flat_s4.offset) + + s5 = cp_model.LinearExpr.weighted_sum(x + [np.int16(2)], c + [-1]) + self.assertTrue(s5.is_integer()) + flat_s5 = cp_model.FlatIntExpr(s5) + self.assertLen(flat_s5.vars, 4) + self.assertEqual(-2, flat_s5.offset) + + def testSumWithApi(self) -> None: + model = cp_model.CpModel() + x = [model.new_int_var(0, 2, f"x{i}") for i in range(100)] self.assertEqual(cp_model.LinearExpr.sum([x[0]]), x[0]) self.assertEqual(cp_model.LinearExpr.sum([x[0], 0]), x[0]) self.assertEqual(cp_model.LinearExpr.sum([x[0], 0.0]), x[0]) @@ -559,9 +696,8 @@ class CpModelTest(absltest.TestCase): self.assertEqual(solver.value(x[i]), 1 if i == 99 else 0) def testWeightedSum(self) -> None: - print("testWeightedSum") model = cp_model.CpModel() - x = [model.new_int_var(0, 2, "x%i" % i) for i in range(100)] + x = [model.new_int_var(0, 2, f"x{i}") for i in range(100)] c = [2] * 100 model.add(cp_model.LinearExpr.weighted_sum(x, c) <= 3) model.maximize(x[99]) @@ -585,35 +721,31 @@ class CpModelTest(absltest.TestCase): cp_model.LinearExpr.WeightedSum([x[0]], [1.1, 2.2]) def testAllDifferent(self) -> None: - print("testAllDifferent") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] model.add_all_different(x) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) self.assertLen(model.proto.constraints[0].all_diff.exprs, 5) def testAllDifferentGen(self) -> None: - print("testAllDifferentGen") model = cp_model.CpModel() - model.add_all_different(model.new_int_var(0, 4, "x%i" % i) for i in range(5)) + model.add_all_different(model.new_int_var(0, 4, f"x{i}") for i in range(5)) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) self.assertLen(model.proto.constraints[0].all_diff.exprs, 5) def testAllDifferentList(self) -> None: - print("testAllDifferentList") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] model.add_all_different(x[0], x[1], x[2], x[3], x[4]) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) self.assertLen(model.proto.constraints[0].all_diff.exprs, 5) def testElement(self) -> None: - print("testElement") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] model.add_element(x[0], [x[1], 2, 4, x[2]], x[4]) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -624,9 +756,8 @@ class CpModelTest(absltest.TestCase): model.add_element(x[0], [], x[4]) def testFixedElement(self) -> None: - print("testFixedElement") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(4)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(4)] model.add_element(1, [x[0], 2, 4, x[2]], x[3]) self.assertLen(model.proto.variables, 4) self.assertLen(model.proto.constraints, 1) @@ -636,9 +767,8 @@ class CpModelTest(absltest.TestCase): self.assertEqual([2, 2], model.proto.constraints[0].linear.domain) def testAffineElement(self) -> None: - print("testAffineElement") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] model.add_element(x[0] + 1, [2 * x[1] - 2, 2, 4, x[2]], x[4] - 1) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -657,7 +787,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-2, expr0.offset) def testCircuit(self) -> None: - print("testCircuit") model = cp_model.CpModel() x = [model.new_bool_var(f"x{i}") for i in range(5)] arcs: list[tuple[int, int, cp_model.LiteralT]] = [ @@ -673,7 +802,6 @@ class CpModelTest(absltest.TestCase): model.add_circuit([]) def testMultipleCircuit(self) -> None: - print("testMultipleCircuit") model = cp_model.CpModel() x = [model.new_bool_var(f"x{i}") for i in range(5)] arcs: list[tuple[int, int, cp_model.LiteralT]] = [ @@ -689,9 +817,8 @@ class CpModelTest(absltest.TestCase): model.add_multiple_circuit([]) def testAllowedAssignments(self) -> None: - print("testAllowedAssignments") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] model.add_allowed_assignments( x, [(0, 1, 2, 3, 4), (4, 3, 2, 1, 1), (0, 0, 0, 0, 0)] ) @@ -711,9 +838,8 @@ class CpModelTest(absltest.TestCase): ) def testForbiddenAssignments(self) -> None: - print("testForbiddenAssignments") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] model.add_forbidden_assignments( x, [(0, 1, 2, 3, 4), (4, 3, 2, 1, 1), (0, 0, 0, 0, 0)] ) @@ -736,9 +862,8 @@ class CpModelTest(absltest.TestCase): ) def testAutomaton(self) -> None: - print("testAutomaton") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] model.add_automaton(x, 0, [2, 3], [(0, 0, 0), (0, 1, 1), (1, 2, 2), (2, 3, 3)]) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -773,10 +898,9 @@ class CpModelTest(absltest.TestCase): model.add_automaton(x, 0, [2, 3], []) def testInverse(self) -> None: - print("testInverse") model = cp_model.CpModel() - x = [model.new_int_var(0, 4, "x%i" % i) for i in range(5)] - y = [model.new_int_var(0, 4, "y%i" % i) for i in range(5)] + x = [model.new_int_var(0, 4, f"x{i}") for i in range(5)] + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] model.add_inverse(x, y) self.assertLen(model.proto.variables, 10) self.assertLen(model.proto.constraints, 1) @@ -784,10 +908,9 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints[0].inverse.f_inverse, 5) def testMaxEquality(self) -> None: - print("testMaxEquality") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") - y = [model.new_int_var(0, 4, "y%i" % i) for i in range(5)] + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] model.add_max_equality(x, y) self.assertLen(model.proto.variables, 6) self.assertLen(model.proto.constraints, 1) @@ -796,10 +919,9 @@ class CpModelTest(absltest.TestCase): self.assertEqual(1, model.proto.constraints[0].lin_max.target.coeffs[0]) def testMinEquality(self) -> None: - print("testMinEquality") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") - y = [model.new_int_var(0, 4, "y%i" % i) for i in range(5)] + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] model.add_min_equality(x, y) self.assertLen(model.proto.variables, 6) self.assertLen(model.proto.constraints[0].lin_max.exprs, 5) @@ -807,10 +929,9 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-1, model.proto.constraints[0].lin_max.target.coeffs[0]) def testMinEqualityList(self) -> None: - print("testMinEqualityList") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") - y = [model.new_int_var(0, 4, "y%i" % i) for i in range(5)] + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] model.add_min_equality(x, [y[0], y[2], y[1], y[3]]) self.assertLen(model.proto.variables, 6) self.assertLen(model.proto.constraints[0].lin_max.exprs, 4) @@ -818,10 +939,9 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-1, model.proto.constraints[0].lin_max.target.coeffs[0]) def testMinEqualityTuple(self) -> None: - print("testMinEqualityTuple") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") - y = [model.new_int_var(0, 4, "y%i" % i) for i in range(5)] + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] model.add_min_equality(x, (y[0], y[2], y[1], y[3])) self.assertLen(model.proto.variables, 6) self.assertLen(model.proto.constraints[0].lin_max.exprs, 4) @@ -829,10 +949,9 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-1, model.proto.constraints[0].lin_max.target.coeffs[0]) def testMinEqualityGenerator(self) -> None: - print("testMinEqualityGenerator") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") - y = [model.new_int_var(0, 4, "y%i" % i) for i in range(5)] + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] model.add_min_equality(x, (z for z in y)) self.assertLen(model.proto.variables, 6) self.assertLen(model.proto.constraints[0].lin_max.exprs, 5) @@ -840,7 +959,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-1, model.proto.constraints[0].lin_max.target.coeffs[0]) def testMinEqualityWithConstant(self) -> None: - print("testMinEqualityWithConstant") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") y = model.new_int_var(0, 4, "y") @@ -857,7 +975,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(-3, lin_max.exprs[1].offset) def testAbs(self) -> None: - print("testAbs") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") y = model.new_int_var(-5, 5, "y") @@ -884,7 +1001,6 @@ class CpModelTest(absltest.TestCase): self.assertTrue(passed) def testDivision(self) -> None: - print("testDivision") model = cp_model.CpModel() x = model.new_int_var(0, 10, "x") y = model.new_int_var(0, 50, "y") @@ -911,7 +1027,6 @@ class CpModelTest(absltest.TestCase): self.assertTrue(passed) def testModulo(self) -> None: - print("testModulo") model = cp_model.CpModel() x = model.new_int_var(0, 10, "x") y = model.new_int_var(0, 50, "y") @@ -938,10 +1053,9 @@ class CpModelTest(absltest.TestCase): self.assertTrue(passed) def testMultiplicationEquality(self) -> None: - print("testMultiplicationEquality") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") - y = [model.new_int_var(0, 4, "y%i" % i) for i in range(5)] + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] model.add_multiplication_equality(x, y) self.assertLen(model.proto.variables, 6) self.assertLen(model.proto.constraints, 1) @@ -949,7 +1063,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(0, model.proto.constraints[0].int_prod.target.vars[0]) def testImplication(self) -> None: - print("testImplication") model = cp_model.CpModel() x = model.new_bool_var("x") y = model.new_bool_var("y") @@ -962,9 +1075,8 @@ class CpModelTest(absltest.TestCase): self.assertEqual(y.index, model.proto.constraints[0].bool_or.literals[0]) def testBoolOr(self) -> None: - print("testBoolOr") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] model.add_bool_or(x) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -978,9 +1090,8 @@ class CpModelTest(absltest.TestCase): model.add_bool_or([y, False]) def testBoolOrListOrGet(self) -> None: - print("testBoolOrListOrGet") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] model.add_bool_or(x) model.add_bool_or(True, x[0], x[2]) model.add_bool_or(False, x[0]) @@ -993,9 +1104,8 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints[3].bool_or.literals, 4) def testAtLeastOne(self) -> None: - print("testAtLeastOne") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] model.add_at_least_one(x) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -1007,9 +1117,8 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, model.add_at_least_one, [y, False]) def testAtMostOne(self) -> None: - print("testAtMostOne") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] model.add_at_most_one(x) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -1021,9 +1130,8 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, model.add_at_most_one, [y, False]) def testExactlyOne(self) -> None: - print("testExactlyOne") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] model.add_exactly_one(x) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -1035,9 +1143,8 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, model.add_exactly_one, [y, False]) def testBoolAnd(self) -> None: - print("testBoolAnd") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] model.add_bool_and(x) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) @@ -1048,25 +1155,22 @@ class CpModelTest(absltest.TestCase): self.assertEqual(5, model.proto.constraints[1].bool_and.literals[2]) def testBoolXOr(self) -> None: - print("testBoolXOr") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] model.add_bool_xor(x) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) self.assertLen(model.proto.constraints[0].bool_xor.literals, 5) def testMapDomain(self) -> None: - print("testMapDomain") model = cp_model.CpModel() - x = [model.new_bool_var("x%i" % i) for i in range(5)] + x = [model.new_bool_var(f"x{i}") for i in range(5)] y = model.new_int_var(0, 10, "y") model.add_map_domain(y, x, 2) self.assertLen(model.proto.variables, 6) self.assertLen(model.proto.constraints, 10) def testInterval(self) -> None: - print("testInterval") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") y = model.new_int_var(0, 3, "y") @@ -1083,7 +1187,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(str(end_expr), "(x + 2)") def testRebuildFromLinearExpressionProto(self) -> None: - print("testRebuildFromLinearExpressionProto") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") y = model.new_int_var(0, 1, "y") @@ -1102,7 +1205,7 @@ class CpModelTest(absltest.TestCase): proto.coeffs.append(2) proto.offset = 2 expr = cp_model.rebuild_from_linear_expression_proto(model.proto, proto) - canonical_expr = swig_helper.CanonicalIntExpression(expr) + canonical_expr = cmh.FlatIntExpr(expr) self.assertEqual(canonical_expr.vars[0], x) self.assertEqual(canonical_expr.vars[1], y) self.assertEqual(canonical_expr.coeffs[0], 1) @@ -1112,13 +1215,11 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, canonical_expr.vars[0].negated) def testAbsentInterval(self) -> None: - print("testInterval") model = cp_model.CpModel() i = model.new_optional_interval_var(1, 0, 1, False, "") self.assertEqual(0, i.index) def testOptionalInterval(self) -> None: - print("testOptionalInterval") model = cp_model.CpModel() b = model.new_bool_var("b") x = model.new_int_var(0, 4, "x") @@ -1139,7 +1240,6 @@ class CpModelTest(absltest.TestCase): model.new_optional_interval_var(1, 2, 3, b + 1, "x") def testNoOverlap(self) -> None: - print("testNoOverlap") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") y = model.new_int_var(0, 3, "y") @@ -1153,7 +1253,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(1, ct.proto.no_overlap.intervals[1]) def testNoOverlap2D(self) -> None: - print("testNoOverlap2D") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") y = model.new_int_var(0, 3, "y") @@ -1170,7 +1269,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(0, ct.proto.no_overlap_2d.y_intervals[1]) def testCumulative(self) -> None: - print("testCumulative") model = cp_model.CpModel() intervals = [ model.new_interval_var( @@ -1190,7 +1288,6 @@ class CpModelTest(absltest.TestCase): model.add_cumulative([intervals[0], 3], [2, 3], 3) def testGetOrMakeIndexFromConstant(self) -> None: - print("testGetOrMakeIndexFromConstant") model = cp_model.CpModel() self.assertEqual(0, model.get_or_make_index_from_constant(3)) self.assertEqual(0, model.get_or_make_index_from_constant(3)) @@ -1201,7 +1298,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(3, model_var.domain[1]) def testStr(self) -> None: - print("testStr") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") self.assertEqual(str(x == 2), "x == 2") @@ -1225,8 +1321,27 @@ class CpModelTest(absltest.TestCase): self.assertEqual(str(x != y), "(x - y) != 0") self.assertEqual( "0 <= x <= 10", - str(cp_model.BoundedLinearExpression([x], [1], 0, cp_model.Domain(0, 10))), + str(cp_model.BoundedLinearExpression(x, cp_model.Domain(0, 10))), ) + e1 = 2 * cp_model.LinearExpr.sum([x, y]) + flat_e1 = cmh.FlatIntExpr(e1) + self.assertEqual(str(e1), "(2 * (x + y))") + self.assertEqual(flat_e1.vars, [x, y]) + self.assertEqual(flat_e1.coeffs, [2, 2]) + self.assertEqual(flat_e1.offset, 0) + repeat_flat_e1 = cmh.FlatIntExpr(flat_e1 + 3) + self.assertEqual(repeat_flat_e1.vars, [x, y]) + self.assertEqual(repeat_flat_e1.coeffs, [2, 2]) + self.assertEqual(repeat_flat_e1.offset, 3) + float_flat_e1 = cmh.FlatFloatExpr(flat_e1) + self.assertEqual(float_flat_e1.vars, [x, y]) + self.assertEqual(float_flat_e1.coeffs, [2.0, 2.0]) + self.assertEqual(float_flat_e1.offset, 0.0) + repeat_float_flat_e1 = cmh.FlatFloatExpr(float_flat_e1 - 2.5) + self.assertEqual(repeat_float_flat_e1.vars, [x, y]) + self.assertEqual(repeat_float_flat_e1.coeffs, [2.0, 2.0]) + self.assertEqual(repeat_float_flat_e1.offset, -2.5) + b = model.new_bool_var("b") self.assertEqual(str(cp_model.LinearExpr.term(b.negated(), 3)), "(3 * not(b))") @@ -1234,7 +1349,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(str(i), "i") def testRepr(self) -> None: - print("testRepr") model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") y = model.new_int_var(0, 3, "y") @@ -1265,19 +1379,16 @@ class CpModelTest(absltest.TestCase): ) def testDisplayBounds(self) -> None: - print("testDisplayBounds") self.assertEqual("10..20", cp_model.display_bounds([10, 20])) self.assertEqual("10", cp_model.display_bounds([10, 10])) self.assertEqual("10..15, 20..30", cp_model.display_bounds([10, 15, 20, 30])) def testShortName(self) -> None: - print("testShortName") model = cp_model.CpModel() model.proto.variables.add(domain=[5, 10]) self.assertEqual("[5..10]", cp_model.short_name(model.proto, 0)) def testIntegerExpressionErrors(self) -> None: - print("testIntegerExpressionErrors") model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") y = model.new_int_var(0, 3, "y") @@ -1297,14 +1408,12 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, x.__mul__, "dummy") def testModelErrors(self) -> None: - print("testModelErrors") model = cp_model.CpModel() self.assertRaises(TypeError, model.add, "dummy") self.assertRaises(TypeError, model.get_or_make_index, "dummy") self.assertRaises(TypeError, model.minimize, "dummy") def testSolverErrors(self) -> None: - print("testSolverErrors") model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") y = model.new_int_var(-10, 10, "y") @@ -1317,7 +1426,6 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, model.add_bool_or, [x, y]) def testHasObjectiveMinimize(self) -> None: - print("testHasObjectiveMinimizs") model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") y = model.new_int_var(-10, 10, "y") @@ -1327,7 +1435,6 @@ class CpModelTest(absltest.TestCase): self.assertTrue(model.has_objective()) def testHasObjectiveMaximize(self) -> None: - print("testHasObjectiveMaximizs") model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") y = model.new_int_var(-10, 10, "y") @@ -1337,7 +1444,6 @@ class CpModelTest(absltest.TestCase): self.assertTrue(model.has_objective()) def testSearchForAllSolutions(self) -> None: - print("testSearchForAllSolutions") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1351,7 +1457,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(5, solution_counter.solution_count) def testSolveWithSolutionCallback(self) -> None: - print("testSolveWithSolutionCallback") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1365,7 +1470,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(6, solution_sum.sum) def testBestBoundCallback(self) -> None: - print("testBestBoundCallback") model = cp_model.CpModel() x0 = model.new_bool_var("x0") x1 = model.new_bool_var("x1") @@ -1384,7 +1488,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(2.6, best_bound_callback.best_bound) def testValue(self) -> None: - print("testValue") model = cp_model.CpModel() x = model.new_int_var(0, 10, "x") y = model.new_int_var(0, 10, "y") @@ -1397,7 +1500,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(solver.value(2), 2) def testBooleanValue(self) -> None: - print("testBooleanValue") model = cp_model.CpModel() x = model.new_bool_var("x") y = model.new_bool_var("y") @@ -1419,7 +1521,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(solver.boolean_value(0), False) def testUnsupportedOperators(self) -> None: - print("testUnsupportedOperators") model = cp_model.CpModel() x = model.new_int_var(0, 10, "x") y = model.new_int_var(0, 10, "y") @@ -1435,7 +1536,6 @@ class CpModelTest(absltest.TestCase): print("passed2") def testIsLiteralTrueFalse(self) -> None: - print("testIsLiteralTrueFalse") model = cp_model.CpModel() x = model.new_constant(0) self.assertFalse(cp_model.object_is_a_true_literal(x)) @@ -1448,7 +1548,6 @@ class CpModelTest(absltest.TestCase): self.assertFalse(cp_model.object_is_a_false_literal(True)) def testSolveMinimizeWithSolutionCallback(self) -> None: - print("testSolveMinimizeWithSolutionCallback") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1459,11 +1558,9 @@ class CpModelTest(absltest.TestCase): solution_obj = SolutionObjective() status = solver.solve(model, solution_obj) self.assertEqual(cp_model.OPTIMAL, status) - print("obj = ", solution_obj.obj) self.assertEqual(11, solution_obj.obj) def testSolutionValue(self) -> None: - print("testSolutionValue") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") b = model.new_bool_var("b") @@ -1483,7 +1580,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual([True, False, True], solution_recorder.bool_var_values) def testSolutionHinting(self) -> None: - print("testSolutionHinting") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1498,7 +1594,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(4, solver.value(y)) def testSolutionHintingWithBooleans(self) -> None: - print("testSolutionHintingWithBooleans") model = cp_model.CpModel() x = model.new_bool_var("x") y = model.new_bool_var("y") @@ -1513,7 +1608,6 @@ class CpModelTest(absltest.TestCase): self.assertFalse(solver.boolean_value(y)) def testStats(self) -> None: - print("testStats") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1530,7 +1624,6 @@ class CpModelTest(absltest.TestCase): self.assertGreater(solver.wall_time, 0.0) def testSearchStrategy(self) -> None: - print("testSearchStrategy") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1556,7 +1649,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(cp_model.SELECT_MAX_VALUE, strategy.domain_reduction_strategy) def testModelAndResponseStats(self) -> None: - print("testStats") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1569,7 +1661,6 @@ class CpModelTest(absltest.TestCase): self.assertTrue(solver.response_stats()) def testValidateModel(self) -> None: - print("testValidateModel") model = cp_model.CpModel() x = model.new_int_var(0, 5, "x") y = model.new_int_var(0, 5, "y") @@ -1578,7 +1669,6 @@ class CpModelTest(absltest.TestCase): self.assertFalse(model.validate()) def testValidateModelWithOverflow(self) -> None: - print("testValidateModel") model = cp_model.CpModel() x = model.new_int_var(0, cp_model.INT_MAX, "x") y = model.new_int_var(0, 10, "y") @@ -1587,7 +1677,6 @@ class CpModelTest(absltest.TestCase): self.assertTrue(model.validate()) def testCopyModel(self) -> None: - print("testCopyModel") model = cp_model.CpModel() b = model.new_bool_var("b") x = model.new_int_var(0, 4, "x") @@ -1625,7 +1714,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(12, interval_ct.size.offset) def testCustomLog(self) -> None: - print("testCustomLog") model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") y = model.new_int_var(-10, 10, "y") @@ -1644,7 +1732,6 @@ class CpModelTest(absltest.TestCase): self.assertRegex(log_callback.log, ".*log_to_stdout.*") def testIssue2762(self) -> None: - print("testIssue2762") model = cp_model.CpModel() x = [model.new_bool_var("a"), model.new_bool_var("b")] @@ -1652,9 +1739,8 @@ class CpModelTest(absltest.TestCase): model.add((x[0] != 0) or (x[1] != 0)) def testModelError(self) -> None: - print("TestModelError") model = cp_model.CpModel() - x = [model.new_int_var(0, -2, "x%i" % i) for i in range(100)] + x = [model.new_int_var(0, -2, f"x{i}") for i in range(100)] model.add(sum(x) <= 1) solver = cp_model.CpSolver() solver.parameters.log_search_progress = True @@ -1662,7 +1748,6 @@ class CpModelTest(absltest.TestCase): self.assertEqual(solver.solution_info(), 'var #0 has no domain(): name: "x0"') def testIntVarSeries(self) -> None: - print("testIntVarSeries") df = pd.DataFrame([1, -1, 1], columns=["coeffs"]) model = cp_model.CpModel() x = model.new_int_var_series( @@ -1676,7 +1761,6 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, x.apply, lambda x: ~x) def testBoolVarSeries(self) -> None: - print("testBoolVarSeries") df = pd.DataFrame([1, -1, 1], columns=["coeffs"]) model = cp_model.CpModel() x = model.new_bool_var_series(name="x", index=df.index) @@ -1692,7 +1776,6 @@ class CpModelTest(absltest.TestCase): self.assertTrue((solution.values == [False, True, False]).all()) def testFixedSizeIntervalVarSeries(self) -> None: - print("testFixedSizeIntervalVarSeries") df = pd.DataFrame([2, 4, 6], columns=["size"]) model = cp_model.CpModel() starts = model.new_int_var_series( @@ -1718,7 +1801,6 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints, 7) def testIntervalVarSeries(self) -> None: - print("testIntervalVarSeries") df = pd.DataFrame([2, 4, 6], columns=["size"]) model = cp_model.CpModel() starts = model.new_int_var_series( @@ -1768,7 +1850,6 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints, 13) def testCompareWithNone(self) -> None: - print("testCompareWithNone") model = cp_model.CpModel() x = model.new_int_var(0, 10, "x") self.assertRaises(TypeError, x.__eq__, None) @@ -1779,7 +1860,6 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, x.__ge__, None) def testIssue4376SatModel(self) -> None: - print("testIssue4376SatModel") letters: str = "BCFLMRT" def symbols_from_string(text: str) -> list[int]: @@ -1890,8 +1970,6 @@ TRFM""" self.assertLess(time.time(), solution_callback.last_time + 5.0) def testIssue4376MinimizeModel(self) -> None: - print("testIssue4376MinimizeModel") - model = cp_model.CpModel() jobs = [ @@ -1995,7 +2073,6 @@ TRFM""" ) def testIssue4434(self) -> None: - print("testIssue4434") model = cp_model.CpModel() i = model.NewIntVar(0, 10, "i") j = model.NewIntVar(0, 10, "j") diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc index 2a89d14420..c705a0bdca 100644 --- a/ortools/sat/python/linear_expr.cc +++ b/ortools/sat/python/linear_expr.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "ortools/util/fp_roundtrip_conv.h" @@ -35,134 +36,6 @@ bool LinearExpr::IsInteger() const { return lin.ProcessAll(); } -LinearExpr* LinearExpr::Sum(const std::vector& exprs) { - if (exprs.empty()) { - return new IntConstant(0); - } else if (exprs.size() == 1) { - return exprs[0]; - } else { - return new SumArray(exprs); - } -} - -LinearExpr* LinearExpr::MixedSum(const std::vector& exprs) { - std::vector lin_exprs; - int64_t int_offset = 0; - double double_offset = 0.0; - - for (const ExprOrValue& choice : exprs) { - if (choice.expr != nullptr) { - lin_exprs.push_back(choice.expr); - } else { - int_offset += choice.int_value; - double_offset += choice.double_value; - } - } - - // Special case: if there is only one term, return it. - if (int_offset == 0 && double_offset == 0.0 && lin_exprs.size() == 1) { - return lin_exprs[0]; - } - - // Special case: if there is no double offset, return an integer expression. - if (double_offset == 0.0) { - if (lin_exprs.empty()) { - return new IntConstant(int_offset); - } else if (lin_exprs.size() == 1) { - return new IntAffine(lin_exprs[0], 1, int_offset); - } else { - return new SumArray(lin_exprs, int_offset); - } - } else { // General floating point case. - double_offset += static_cast(int_offset); - if (lin_exprs.empty()) { - return new FloatConstant(double_offset); - } else if (lin_exprs.size() == 1) { - return new FloatAffine(lin_exprs[0], 1.0, double_offset); - } else { - return new SumArray(lin_exprs, 0, double_offset); - } - } -} - -LinearExpr* LinearExpr::WeightedSumInt(const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.empty()) return new IntConstant(0); - if (exprs.size() == 1) { - return new IntAffine(exprs[0], coeffs[0], 0); - } - return new IntWeightedSum(exprs, coeffs, 0); -} - -LinearExpr* LinearExpr::WeightedSumFloat(const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.empty()) return new FloatConstant(0.0); - if (exprs.size() == 1) { - return new FloatAffine(exprs[0], coeffs[0], 0.0); - } - return new FloatWeightedSum(exprs, coeffs, 0.0); -} - -LinearExpr* LinearExpr::MixedWeightedSumInt( - const std::vector& exprs, const std::vector& coeffs) { - std::vector lin_exprs; - std::vector lin_coeffs; - int64_t int_cst = 0; - double double_cst = 0.0; - for (int i = 0; i < exprs.size(); ++i) { - if (exprs[i].expr != nullptr) { - lin_exprs.push_back(exprs[i].expr); - lin_coeffs.push_back(coeffs[i]); - } else { - int_cst += coeffs[i] * exprs[i].int_value; - double_cst += coeffs[i] * exprs[i].double_value; - } - } - - if (double_cst != 0.0) { - double_cst += static_cast(int_cst); - if (lin_exprs.empty()) return new FloatConstant(double_cst); - if (lin_exprs.size() == 1) { - return new FloatAffine(lin_exprs[0], static_cast(lin_coeffs[0]), - double_cst); - } - std::vector lin_coeffs_double; - lin_coeffs_double.reserve(lin_coeffs.size()); - for (int64_t coeff : lin_coeffs) { - lin_coeffs_double.push_back(static_cast(coeff)); - } - return new FloatWeightedSum(lin_exprs, lin_coeffs_double, double_cst); - } - - if (lin_exprs.empty()) return new IntConstant(int_cst); - if (lin_exprs.size() == 1) { - return new IntAffine(lin_exprs[0], lin_coeffs[0], int_cst); - } - return new IntWeightedSum(lin_exprs, lin_coeffs, int_cst); -} - -LinearExpr* LinearExpr::MixedWeightedSumFloat( - const std::vector& exprs, const std::vector& coeffs) { - std::vector lin_exprs; - std::vector lin_coeffs; - double cst = 0.0; - for (int i = 0; i < exprs.size(); ++i) { - if (exprs[i].expr != nullptr) { - lin_exprs.push_back(exprs[i].expr); - lin_coeffs.push_back(coeffs[i]); - } else { - cst += coeffs[i] * - (exprs[i].double_value + static_cast(exprs[i].int_value)); - } - } - - if (lin_exprs.empty()) return new FloatConstant(cst); - if (lin_exprs.size() == 1) { - return new FloatAffine(lin_exprs[0], lin_coeffs[0], cst); - } - return new FloatWeightedSum(lin_exprs, lin_coeffs, cst); -} - LinearExpr* LinearExpr::TermInt(LinearExpr* expr, int64_t coeff) { return new IntAffine(expr, coeff, 0); } @@ -269,17 +142,146 @@ double FloatExprVisitor::Process(const LinearExpr* expr, return offset_; } -CanonicalFloatExpression::CanonicalFloatExpression(LinearExpr* expr) { +FlatFloatExpr::FlatFloatExpr(LinearExpr* expr) { FloatExprVisitor lin; offset_ = lin.Process(expr, &vars_, &coeffs_); } -CanonicalIntExpression::CanonicalIntExpression(LinearExpr* expr) { +void FlatFloatExpr::VisitAsFloat(FloatExprVisitor& lin, double c) const { + for (int i = 0; i < vars_.size(); ++i) { + lin.AddVarCoeff(vars_[i], coeffs_[i] * c); + } + lin.AddConstant(offset_ * c); +} + +std::string FlatFloatExpr::ToString() const { + if (vars_.empty()) { + return absl::StrCat(RoundTripDoubleFormat(offset_)); + } + + std::string s = "("; + bool first_printed = true; + for (int i = 0; i < vars_.size(); ++i) { + if (coeffs_[i] == 0.0) continue; + if (first_printed) { + first_printed = false; + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, vars_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, "-", vars_[i]->ToString()); + } else { + absl::StrAppend(&s, RoundTripDoubleFormat(coeffs_[i]), " * ", + vars_[i]->ToString()); + } + } else { + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, " + ", vars_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, " - ", vars_[i]->ToString()); + } else if (coeffs_[i] > 0.0) { + absl::StrAppend(&s, " + ", RoundTripDoubleFormat(coeffs_[i]), " * ", + vars_[i]->ToString()); + } else { + absl::StrAppend(&s, " - ", RoundTripDoubleFormat(-coeffs_[i]), " * ", + vars_[i]->ToString()); + } + } + } + // If there are no terms, just print the offset. + if (first_printed) { + return absl::StrCat(RoundTripDoubleFormat(offset_)); + } + + // If there is an offset, print it. + if (offset_ != 0.0) { + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", RoundTripDoubleFormat(offset_)); + } else { + absl::StrAppend(&s, " - ", RoundTripDoubleFormat(-offset_)); + } + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string FlatFloatExpr::DebugString() const { + return absl::StrCat("FlatFloatExpr([", + absl::StrJoin(vars_, ", ", + [](std::string* out, const LinearExpr* e) { + absl::StrAppend(out, e->DebugString()); + }), + "], [", + absl::StrJoin(coeffs_, ", ", + [](std::string* out, double coeff) { + absl::StrAppend( + out, RoundTripDoubleFormat(coeff)); + }), + "], ", RoundTripDoubleFormat(offset_), ")"); +} + +FlatIntExpr::FlatIntExpr(LinearExpr* expr) { IntExprVisitor lin; lin.AddToProcess(expr, 1); ok_ = lin.Process(&vars_, &coeffs_, &offset_); } +std::string FlatIntExpr::ToString() const { + if (vars_.empty()) { + return absl::StrCat(offset_); + } + + std::string s = "("; + bool first_printed = true; + for (int i = 0; i < vars_.size(); ++i) { + if (coeffs_[i] == 0) continue; + if (first_printed) { + first_printed = false; + if (coeffs_[i] == 1) { + absl::StrAppend(&s, vars_[i]->ToString()); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, "-", vars_[i]->ToString()); + } else { + absl::StrAppend(&s, coeffs_[i], " * ", vars_[i]->ToString()); + } + } else { + if (coeffs_[i] == 1) { + absl::StrAppend(&s, " + ", vars_[i]->ToString()); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, " - ", vars_[i]->ToString()); + } else if (coeffs_[i] > 1) { + absl::StrAppend(&s, " + ", coeffs_[i], " * ", vars_[i]->ToString()); + } else { + absl::StrAppend(&s, " - ", -coeffs_[i], " * ", vars_[i]->ToString()); + } + } + } + // If there are no terms, just print the offset. + if (first_printed) { + return absl::StrCat(offset_); + } + + // If there is an offset, print it. + if (offset_ != 0) { + if (offset_ > 0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string FlatIntExpr::DebugString() const { + return absl::StrCat( + "FlatIntExpr([", + absl::StrJoin(vars_, ", ", + [](std::string* out, const BaseIntVar* var) { + absl::StrAppend(out, var->DebugString()); + }), + "], [", absl::StrJoin(coeffs_, ", "), "], ", offset_, ")"); +} + void FloatConstant::VisitAsFloat(FloatExprVisitor& lin, double c) const { lin.AddConstant(value_ * c); } @@ -290,18 +292,88 @@ std::string FloatConstant::DebugString() const { return absl::StrCat("FloatConstant(", RoundTripDoubleFormat(value_), ")"); } -FloatWeightedSum::FloatWeightedSum(const std::vector& exprs, - double offset) +SumArray::SumArray(const std::vector& exprs, int64_t int_offset, + double double_offset) : exprs_(exprs.begin(), exprs.end()), - coeffs_(exprs.size(), 1), - offset_(offset) {} + int_offset_(int_offset), + double_offset_(double_offset) { + DCHECK(int_offset_ == 0 || double_offset_ == 0.0); + DCHECK_GE(exprs_.size(), 2); +} + +bool SumArray::VisitAsInt(IntExprVisitor& lin, int64_t c) const { + if (double_offset_ != 0.0) return false; + for (int i = 0; i < exprs_.size(); ++i) { + lin.AddToProcess(exprs_[i], c); + } + lin.AddConstant(int_offset_ * c); + return true; +} + +void SumArray::VisitAsFloat(FloatExprVisitor& lin, double c) const { + for (int i = 0; i < exprs_.size(); ++i) { + lin.AddToProcess(exprs_[i], c); + } + if (int_offset_ != 0) { + lin.AddConstant(int_offset_ * c); + } else if (double_offset_ != 0.0) { + lin.AddConstant(double_offset_ * c); + } +} + +std::string SumArray::ToString() const { + DCHECK(!exprs_.empty()); + + std::string s = "("; + for (int i = 0; i < exprs_.size(); ++i) { + if (i > 0) { + absl::StrAppend(&s, " + "); + } + absl::StrAppend(&s, exprs_[i]->ToString()); + } + if (double_offset_ != 0.0) { + if (double_offset_ > 0.0) { + absl::StrAppend(&s, " + ", double_offset_); + } else { + absl::StrAppend(&s, " - ", -double_offset_); + } + } + if (int_offset_ != 0) { + if (int_offset_ > 0) { + absl::StrAppend(&s, " + ", int_offset_); + } else { + absl::StrAppend(&s, " - ", -int_offset_); + } + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string SumArray::DebugString() const { + std::string s = absl::StrCat( + "SumArray(", + absl::StrJoin(exprs_, ", ", [](std::string* out, LinearExpr* expr) { + absl::StrAppend(out, expr->DebugString()); + })); + if (int_offset_ != 0) { + absl::StrAppend(&s, ", int_offset=", int_offset_); + } + if (double_offset_ != 0.0) { + absl::StrAppend(&s, + ", double_offset=", RoundTripDoubleFormat(double_offset_)); + } + absl::StrAppend(&s, ")"); + return s; +} FloatWeightedSum::FloatWeightedSum(const std::vector& exprs, const std::vector& coeffs, double offset) : exprs_(exprs.begin(), exprs.end()), coeffs_(coeffs.begin(), coeffs.end()), - offset_(offset) {} + offset_(offset) { + DCHECK_GE(exprs_.size(), 2); +} void FloatWeightedSum::VisitAsFloat(FloatExprVisitor& lin, double c) const { for (int i = 0; i < exprs_.size(); ++i) { @@ -311,9 +383,6 @@ void FloatWeightedSum::VisitAsFloat(FloatExprVisitor& lin, double c) const { } std::string FloatWeightedSum::ToString() const { - if (exprs_.empty()) { - return absl::StrCat(offset_); - } std::string s = "("; bool first_printed = true; for (int i = 0; i < exprs_.size(); ++i) { @@ -325,7 +394,8 @@ std::string FloatWeightedSum::ToString() const { } else if (coeffs_[i] == -1.0) { absl::StrAppend(&s, "-", exprs_[i]->ToString()); } else { - absl::StrAppend(&s, coeffs_[i], " * ", exprs_[i]->ToString()); + absl::StrAppend(&s, RoundTripDoubleFormat(coeffs_[i]), " * ", + exprs_[i]->ToString()); } } else { if (coeffs_[i] == 1.0) { @@ -333,23 +403,25 @@ std::string FloatWeightedSum::ToString() const { } else if (coeffs_[i] == -1.0) { absl::StrAppend(&s, " - ", exprs_[i]->ToString()); } else if (coeffs_[i] > 0.0) { - absl::StrAppend(&s, " + ", coeffs_[i], " * ", exprs_[i]->ToString()); + absl::StrAppend(&s, " + ", RoundTripDoubleFormat(coeffs_[i]), " * ", + exprs_[i]->ToString()); } else { - absl::StrAppend(&s, " - ", -coeffs_[i], " * ", exprs_[i]->ToString()); + absl::StrAppend(&s, " - ", RoundTripDoubleFormat(-coeffs_[i]), " * ", + exprs_[i]->ToString()); } } } // If there are no terms, just print the offset. if (first_printed) { - return absl::StrCat(offset_); + return absl::StrCat(RoundTripDoubleFormat(offset_)); } // If there is an offset, print it. if (offset_ != 0.0) { if (offset_ > 0.0) { - absl::StrAppend(&s, " + ", offset_); + absl::StrAppend(&s, " + ", RoundTripDoubleFormat(offset_)); } else { - absl::StrAppend(&s, " - ", -offset_); + absl::StrAppend(&s, " - ", RoundTripDoubleFormat(-offset_)); } } absl::StrAppend(&s, ")"); @@ -362,7 +434,35 @@ std::string FloatWeightedSum::DebugString() const { [](std::string* out, const LinearExpr* e) { absl::StrAppend(out, e->DebugString()); }), - "], [", absl::StrJoin(coeffs_, "], "), offset_, ")"); + "], [", + absl::StrJoin(coeffs_, ", ", + [](std::string* out, double coeff) { + absl::StrAppend( + out, RoundTripDoubleFormat(coeff)); + }), + RoundTripDoubleFormat(offset_), ")"); +} + +IntWeightedSum::IntWeightedSum(const std::vector& exprs, + const std::vector& coeffs, + int64_t offset) + : exprs_(exprs.begin(), exprs.end()), + coeffs_(coeffs.begin(), coeffs.end()), + offset_(offset) {} + +void IntWeightedSum::VisitAsFloat(FloatExprVisitor& lin, double c) const { + for (int i = 0; i < exprs_.size(); ++i) { + lin.AddToProcess(exprs_[i], coeffs_[i] * c); + } + lin.AddConstant(offset_ * c); +} + +bool IntWeightedSum::VisitAsInt(IntExprVisitor& lin, int64_t c) const { + for (int i = 0; i < exprs_.size(); ++i) { + lin.AddToProcess(exprs_[i], coeffs_[i] * c); + } + lin.AddConstant(offset_ * c); + return true; } std::string IntWeightedSum::ToString() const { @@ -451,142 +551,98 @@ std::string FloatAffine::DebugString() const { return absl::StrCat("FloatAffine(expr=", expr_->DebugString(), ", coeff=", coeff_, ", offset=", offset_, ")"); } + +IntAffine::IntAffine(LinearExpr* expr, int64_t coeff, int64_t offset) + : expr_(expr), coeff_(coeff), offset_(offset) {} + +bool IntAffine::VisitAsInt(IntExprVisitor& lin, int64_t c) const { + lin.AddToProcess(expr_, c * coeff_); + lin.AddConstant(offset_ * c); + return true; +} + +void IntAffine::VisitAsFloat(FloatExprVisitor& lin, double c) const { + lin.AddToProcess(expr_, c * coeff_); + lin.AddConstant(offset_ * c); +} + +std::string IntAffine::ToString() const { + std::string s = "("; + if (coeff_ == 1) { + absl::StrAppend(&s, expr_->ToString()); + } else if (coeff_ == -1) { + absl::StrAppend(&s, "-", expr_->ToString()); + } else { + absl::StrAppend(&s, coeff_, " * ", expr_->ToString()); + } + if (offset_ > 0) { + absl::StrAppend(&s, " + ", offset_); + } else if (offset_ < 0) { + absl::StrAppend(&s, " - ", -offset_); + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string IntAffine::DebugString() const { + return absl::StrCat("IntAffine(expr=", expr_->DebugString(), + ", coeff=", coeff_, ", offset=", offset_, ")"); +} + BoundedLinearExpression* LinearExpr::Eq(LinearExpr* rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - lin.AddToProcess(rhs, -1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; - return new BoundedLinearExpression(vars, coeffs, offset, Domain(0)); + return new BoundedLinearExpression(this, rhs, Domain(0)); } BoundedLinearExpression* LinearExpr::EqCst(int64_t rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; - return new BoundedLinearExpression(vars, coeffs, offset, Domain(rhs)); + return new BoundedLinearExpression(this, Domain(rhs)); } BoundedLinearExpression* LinearExpr::Ne(LinearExpr* rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - lin.AddToProcess(rhs, -1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; - return new BoundedLinearExpression(vars, coeffs, offset, - Domain(0).Complement()); + return new BoundedLinearExpression(this, rhs, Domain(0).Complement()); } BoundedLinearExpression* LinearExpr::NeCst(int64_t rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; - return new BoundedLinearExpression(vars, coeffs, offset, - Domain(rhs).Complement()); + return new BoundedLinearExpression(this, Domain(rhs).Complement()); } BoundedLinearExpression* LinearExpr::Le(LinearExpr* rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - lin.AddToProcess(rhs, -1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, Domain(std::numeric_limits::min(), 0)); + this, rhs, Domain(std::numeric_limits::min(), 0)); } BoundedLinearExpression* LinearExpr::LeCst(int64_t rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, Domain(std::numeric_limits::min(), rhs)); + this, Domain(std::numeric_limits::min(), rhs)); } BoundedLinearExpression* LinearExpr::Lt(LinearExpr* rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - lin.AddToProcess(rhs, -1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, Domain(std::numeric_limits::min(), -1)); + this, rhs, Domain(std::numeric_limits::min(), -1)); } BoundedLinearExpression* LinearExpr::LtCst(int64_t rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, - Domain(std::numeric_limits::min(), rhs - 1)); + this, Domain(std::numeric_limits::min(), rhs - 1)); } BoundedLinearExpression* LinearExpr::Ge(LinearExpr* rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - lin.AddToProcess(rhs, -1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, Domain(0, std::numeric_limits::max())); + this, rhs, Domain(0, std::numeric_limits::max())); } BoundedLinearExpression* LinearExpr::GeCst(int64_t rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, Domain(rhs, std::numeric_limits::max())); + this, Domain(rhs, std::numeric_limits::max())); } BoundedLinearExpression* LinearExpr::Gt(LinearExpr* rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - lin.AddToProcess(rhs, -1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, Domain(1, std::numeric_limits::max())); + this, rhs, Domain(1, std::numeric_limits::max())); } BoundedLinearExpression* LinearExpr::GtCst(int64_t rhs) { - IntExprVisitor lin; - lin.AddToProcess(this, 1); - std::vector vars; - std::vector coeffs; - int64_t offset; - if (!lin.Process(&vars, &coeffs, &offset)) return nullptr; return new BoundedLinearExpression( - vars, coeffs, offset, - Domain(rhs + 1, std::numeric_limits::max())); + this, Domain(rhs + 1, std::numeric_limits::max())); } void IntExprVisitor::AddToProcess(const LinearExpr* expr, int64_t coeff) { @@ -645,10 +701,23 @@ BaseIntVar::BaseIntVar(int index, bool is_boolean) : index_(index), negated_(is_boolean ? new NotBooleanVariable(this) : nullptr) {} -BoundedLinearExpression::BoundedLinearExpression( - const std::vector& vars, - const std::vector& coeffs, int64_t offset, const Domain& bounds) - : vars_(vars), coeffs_(coeffs), offset_(offset), bounds_(bounds) {} +BoundedLinearExpression::BoundedLinearExpression(const LinearExpr* expr, + const Domain& bounds) + : bounds_(bounds) { + IntExprVisitor lin; + lin.AddToProcess(expr, 1); + ok_ = lin.Process(&vars_, &coeffs_, &offset_); +} + +BoundedLinearExpression::BoundedLinearExpression(const LinearExpr* pos, + const LinearExpr* neg, + const Domain& bounds) + : bounds_(bounds) { + IntExprVisitor lin; + lin.AddToProcess(pos, 1); + lin.AddToProcess(neg, -1); + ok_ = lin.Process(&vars_, &coeffs_, &offset_); +} const Domain& BoundedLinearExpression::bounds() const { return bounds_; } const std::vector& BoundedLinearExpression::vars() const { @@ -659,7 +728,10 @@ const std::vector& BoundedLinearExpression::coeffs() const { } int64_t BoundedLinearExpression::offset() const { return offset_; } +bool BoundedLinearExpression::ok() const { return ok_; } + std::string BoundedLinearExpression::ToString() const { + if (!ok_) return "Invalid BoundedLinearExpression"; std::string s; if (vars_.empty()) { s = absl::StrCat(offset_); @@ -731,6 +803,7 @@ std::string BoundedLinearExpression::ToString() const { } std::string BoundedLinearExpression::DebugString() const { + if (!ok_) return "Invalid BoundedLinearExpression"; return absl::StrCat( "BoundedLinearExpression(vars=[", absl::StrJoin(vars_, ", ", @@ -742,6 +815,7 @@ std::string BoundedLinearExpression::DebugString() const { } bool BoundedLinearExpression::CastToBool(bool* result) const { + if (!ok_) return false; const bool is_zero = bounds_.IsFixed() && bounds_.FixedValue() == 0; const Domain complement = bounds_.Complement(); const bool is_all_but_zero = diff --git a/ortools/sat/python/linear_expr.h b/ortools/sat/python/linear_expr.h index b5165cc5a3..d8f118d4e8 100644 --- a/ortools/sat/python/linear_expr.h +++ b/ortools/sat/python/linear_expr.h @@ -23,15 +23,13 @@ #include "absl/container/fixed_array.h" #include "absl/log/check.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "ortools/sat/cp_model.pb.h" -#include "ortools/util/fp_roundtrip_conv.h" #include "ortools/util/sorted_interval_list.h" namespace operations_research::sat::python { class BoundedLinearExpression; -class CanonicalFloatExpression; +class FlatFloatExpr; class FloatExprVisitor; class LinearExpr; class IntExprVisitor; @@ -39,19 +37,41 @@ class LinearExpr; class BaseIntVar; class NotBooleanVariable; -// A class to hold a pointer to a linear expression or a constant. -struct ExprOrValue { - explicit ExprOrValue(LinearExpr* e) : expr(e) {} - explicit ExprOrValue(double v) : double_value(v) {} - explicit ExprOrValue(int64_t v) : int_value(v) {} - - LinearExpr* expr = nullptr; - double double_value = 0.0; - int64_t int_value = 0; -}; - -// Interface for a linear expression that can be either integer or floating -// point. +/** + * A class to hold an integer or floating point linear expression. + * + * A linear expression is built from (integer or floating point) constants and + * variables. For example, `x + 2 * (y - z + 1)`. + * + * Linear expressions are used in CP-SAT models in constraints and in the + * objective. + * + * Note that constraints only accept linear expressions with integral + * coefficients and constants. On the other hand, The objective can be a linear + * expression with floating point coefficients and constants. + * + * You can define linear constraints as in: + * + * ``` + * model.add(x + 2 * y <= 5) + * model.add(sum(array_of_vars) == 5) + * ``` + * + * - In CP-SAT, the objective is a linear expression: + * + * ``` + * model.minimize(x + 2 * y + z) + * ``` + * + * - For large arrays, using the LinearExpr class is faster that using the + * python `sum()` function. You can create constraints and the objective from + * lists of linear expressions or coefficients as follows: + * + * ``` + * model.minimize(cp_model.LinearExpr.sum(expressions)) + * model.add(cp_model.LinearExpr.weighted_sum(expressions, coefficients) >= 0) + * ``` + */ class LinearExpr { public: virtual ~LinearExpr() = default; @@ -61,91 +81,74 @@ class LinearExpr { virtual std::string ToString() const = 0; virtual std::string DebugString() const = 0; - // Returns a new LinearExpr that is the sum of the given expressions. - static LinearExpr* Sum(const std::vector& exprs); - // Returns a new LinearExpr that is the sum of the given expressions or - // constants. - static LinearExpr* MixedSum(const std::vector& exprs); - // Returns the sum(exprs[i] * coeffs[i]). - static LinearExpr* WeightedSumInt(const std::vector& exprs, - const std::vector& coeffs); - // Returns the sum(exprs[i] * coeffs[i]). - static LinearExpr* WeightedSumFloat(const std::vector& exprs, - const std::vector& coeffs); - // Returns the sum(exprs[i] * coeffs[i]). - static LinearExpr* MixedWeightedSumInt(const std::vector& exprs, - const std::vector& coeffs); - // Returns the sum(exprs[i] * coeffs[i]). - static LinearExpr* MixedWeightedSumFloat( - const std::vector& exprs, const std::vector& coeffs); - // returns expr * coeff. + /// Returns expr * coeff. static LinearExpr* TermInt(LinearExpr* expr, int64_t coeff); - // returns expr * coeff. + /// Returns expr * coeff. static LinearExpr* TermFloat(LinearExpr* expr, double coeff); - // returns expr * coeff + offset. + /// Returns expr * coeff + offset. static LinearExpr* AffineInt(LinearExpr* expr, int64_t coeff, int64_t offset); - // returns expr * coeff + offset. + /// Returns expr * coeff + offset. static LinearExpr* AffineFloat(LinearExpr* expr, double coeff, double offset); - // Returns a new LinearExpr that is the given constant. + /// Returns a new LinearExpr that holds the given constant. static LinearExpr* ConstantInt(int64_t value); - // Returns a new LinearExpr that is the given constant. + /// Returns a new LinearExpr that holds the given constant. static LinearExpr* ConstantFloat(double value); - // return (this) + (expr). + /// Returns (this) + (expr). LinearExpr* Add(LinearExpr* expr); - // return (this) + (cst). + /// Returns (this) + (cst). LinearExpr* AddInt(int64_t cst); - // return (this) + (cst). + /// Returns (this) + (cst). LinearExpr* AddFloat(double cst); - // return (this) - (expr). + /// Returns (this) - (expr). LinearExpr* Sub(LinearExpr* expr); - // return (this) - (cst). + /// Returns (this) - (cst). LinearExpr* SubInt(int64_t cst); - // return (this) - (cst). + /// Returns (this) - (cst). LinearExpr* SubFloat(double cst); - // return (cst) - (this). + /// Returns (cst) - (this). LinearExpr* RSubInt(int64_t cst); - // return (cst) - (this). + /// Returns (cst) - (this). LinearExpr* RSubFloat(double cst); - // return (this) * (cst). + /// Returns (this) * (cst). LinearExpr* MulInt(int64_t cst); - // return (this) * (cst). + /// Returns (this) * (cst). LinearExpr* MulFloat(double cst); - // return -(this). + /// Returns -(this). LinearExpr* Neg(); - // returns (this) == (rhs). + /// Returns (this) == (rhs). BoundedLinearExpression* Eq(LinearExpr* rhs); - // returns (this) == (rhs). + /// Returns (this) == (rhs). BoundedLinearExpression* EqCst(int64_t rhs); - // returns (this) != (rhs). + /// Returns (this) != (rhs). BoundedLinearExpression* Ne(LinearExpr* rhs); - // returns (this) != (rhs). + /// Returns (this) != (rhs). BoundedLinearExpression* NeCst(int64_t rhs); - // returns (this) >= (rhs). + /// Returns (this) >= (rhs). BoundedLinearExpression* Ge(LinearExpr* rhs); - // returns (this) >= (rhs). + /// Returns (this) >= (rhs). BoundedLinearExpression* GeCst(int64_t rhs); - // returns (this) <= (rhs). + /// Returns (this) <= (rhs). BoundedLinearExpression* Le(LinearExpr* rhs); - // returns (this) <= (rhs). + /// Returns (this) <= (rhs). BoundedLinearExpression* LeCst(int64_t rhs); - // returns (this) < (rhs). + /// Returns (this) < (rhs). BoundedLinearExpression* Lt(LinearExpr* rhs); - // returns (this) < (rhs). + /// Returns (this) < (rhs). BoundedLinearExpression* LtCst(int64_t rhs); - // returns (this) > (rhs). + /// Returns (this) > (rhs). BoundedLinearExpression* Gt(LinearExpr* rhs); - // returns (this) > (rhs). + /// Returns (this) > (rhs). BoundedLinearExpression* GtCst(int64_t rhs); }; -// Compare the indices of variables. +/// Compare the indices of variables. struct BaseIntVarComparator { bool operator()(const BaseIntVar* lhs, const BaseIntVar* rhs) const; }; -// A visitor class to process a floating point linear expression. +/// A visitor class to process a floating point linear expression. class FloatExprVisitor { public: void AddToProcess(const LinearExpr* expr, double coeff); @@ -161,21 +164,39 @@ class FloatExprVisitor { double offset_ = 0; }; -// A class to build a canonical floating point linear expression. -class CanonicalFloatExpression { +/** + * A flattened and optimized floating point linear expression. + * + * It flattens the linear expression passed to the constructor to a sum of + * products of variables and coefficients plus an offset. It can be used to + * cache complex expressions as parsing them is only done once. + */ +class FlatFloatExpr : public LinearExpr { public: - explicit CanonicalFloatExpression(LinearExpr* expr); + /// Builds a flattened floating point linear expression from the given + /// expression. + explicit FlatFloatExpr(LinearExpr* expr); + /// Returns the array of variables of the flattened expression. const std::vector& vars() const { return vars_; } + /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const { return coeffs_; } + /// Returns the offset of the flattened expression. double offset() const { return offset_; } + void VisitAsFloat(FloatExprVisitor& lin, double c) const override; + std::string ToString() const override; + std::string DebugString() const override; + bool VisitAsInt(IntExprVisitor& /*lin*/, int64_t /*c*/) const override { + return false; + } + private: std::vector vars_; std::vector coeffs_; double offset_ = 0; }; -// A visitor class to process an integer linear expression. +/// A visitor class to process an integer linear expression. class IntExprVisitor { public: void AddToProcess(const LinearExpr* expr, int64_t coeff); @@ -194,15 +215,45 @@ class IntExprVisitor { int64_t offset_ = 0; }; -// A class to build a canonical integer linear expression. -class CanonicalIntExpression { +/** + * A flattened and optimized integer linear expression. + * + * It flattens the linear expression passed to the constructor to a sum of + * products of variables and coefficients plus an offset. It can be used to + * cache complex expressions as parsing them is only done once. + */ +class FlatIntExpr : public LinearExpr { public: - explicit CanonicalIntExpression(LinearExpr* expr); + /// Builds a flattened integer linear expression from the given + /// expression. + explicit FlatIntExpr(LinearExpr* expr); + /// Returns the array of variables of the flattened expression. const std::vector& vars() const { return vars_; } + /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const { return coeffs_; } + /// Returns the offset of the flattened expression. int64_t offset() const { return offset_; } + /// Returns true if the expression is integer. bool ok() const { return ok_; } + void VisitAsFloat(FloatExprVisitor& lin, double c) const override { + for (int i = 0; i < vars_.size(); ++i) { + lin.AddVarCoeff(vars_[i], coeffs_[i] * c); + } + lin.AddConstant(offset_ * c); + } + + bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override { + for (int i = 0; i < vars_.size(); ++i) { + lin.AddVarCoeff(vars_[i], coeffs_[i] * c); + } + lin.AddConstant(offset_ * c); + return true; + } + + std::string ToString() const override; + std::string DebugString() const override; + private: std::vector vars_; std::vector coeffs_; @@ -210,88 +261,20 @@ class CanonicalIntExpression { bool ok_ = true; }; -// A class to hold a sum of linear expressions, and optional integer and -// double offsets (at most one of them can be non-zero, this is DCHECKed). +/** + * A class to hold a sum of linear expressions, and optional integer and + * double offsets (at most one of them can be non-zero, this is DCHECKed). + */ class SumArray : public LinearExpr { public: explicit SumArray(const std::vector& exprs, - int64_t int_offset = 0, double double_offset = 0.0) - : exprs_(exprs.begin(), exprs.end()), - int_offset_(int_offset), - double_offset_(double_offset) { - DCHECK(int_offset_ == 0 || double_offset_ == 0.0); - } + int64_t int_offset = 0, double double_offset = 0.0); ~SumArray() override = default; - bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override { - if (double_offset_ != 0.0) return false; - for (int i = 0; i < exprs_.size(); ++i) { - lin.AddToProcess(exprs_[i], c); - } - lin.AddConstant(int_offset_ * c); - return true; - } - - void VisitAsFloat(FloatExprVisitor& lin, double c) const override { - for (int i = 0; i < exprs_.size(); ++i) { - lin.AddToProcess(exprs_[i], c); - } - if (int_offset_ != 0) { - lin.AddConstant(int_offset_ * c); - } else if (double_offset_ != 0.0) { - lin.AddConstant(double_offset_ * c); - } - } - - std::string ToString() const override { - if (exprs_.empty()) { - if (double_offset_ != 0.0) { - return absl::StrCat(RoundTripDoubleFormat(double_offset_)); - } else { - return absl::StrCat(int_offset_); - } - } - std::string s = "("; - for (int i = 0; i < exprs_.size(); ++i) { - if (i > 0) { - absl::StrAppend(&s, " + "); - } - absl::StrAppend(&s, exprs_[i]->ToString()); - } - if (double_offset_ != 0.0) { - if (double_offset_ > 0.0) { - absl::StrAppend(&s, " + ", double_offset_); - } else { - absl::StrAppend(&s, " - ", -double_offset_); - } - } - if (int_offset_ != 0) { - if (int_offset_ > 0) { - absl::StrAppend(&s, " + ", int_offset_); - } else { - absl::StrAppend(&s, " - ", -int_offset_); - } - } - absl::StrAppend(&s, ")"); - return s; - } - - std::string DebugString() const override { - std::string s = absl::StrCat( - "SumArray(", - absl::StrJoin(exprs_, ", ", [](std::string* out, LinearExpr* expr) { - absl::StrAppend(out, expr->DebugString()); - })); - if (int_offset_ != 0) { - absl::StrAppend(&s, ", int_offset=", int_offset_); - } - if (double_offset_ != 0.0) { - absl::StrAppend( - &s, ", double_offset=", RoundTripDoubleFormat(double_offset_)); - } - absl::StrAppend(&s, ")"); - return s; - } + void VisitAsFloat(FloatExprVisitor& lin, double c) const override; + bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override; + std::string ToString() const override; + std::string DebugString() const override; private: const absl::FixedArray exprs_; @@ -299,10 +282,9 @@ class SumArray : public LinearExpr { const double double_offset_; }; -// A class to hold a weighted sum of floating point linear expressions. +/// A class to hold a weighted sum of floating point linear expressions. class FloatWeightedSum : public LinearExpr { public: - FloatWeightedSum(const std::vector& exprs, double offset); FloatWeightedSum(const std::vector& exprs, const std::vector& coeffs, double offset); ~FloatWeightedSum() override = default; @@ -321,30 +303,15 @@ class FloatWeightedSum : public LinearExpr { double offset_; }; -// A class to hold a weighted sum of integer linear expressions. +/// A class to hold a weighted sum of integer linear expressions. class IntWeightedSum : public LinearExpr { public: IntWeightedSum(const std::vector& exprs, - const std::vector& coeffs, int64_t offset) - : exprs_(exprs.begin(), exprs.end()), - coeffs_(coeffs.begin(), coeffs.end()), - offset_(offset) {} + const std::vector& coeffs, int64_t offset); ~IntWeightedSum() override = default; - void VisitAsFloat(FloatExprVisitor& lin, double c) const override { - for (int i = 0; i < exprs_.size(); ++i) { - lin.AddToProcess(exprs_[i], coeffs_[i] * c); - } - lin.AddConstant(offset_ * c); - } - - bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override { - for (int i = 0; i < exprs_.size(); ++i) { - lin.AddToProcess(exprs_[i], coeffs_[i] * c); - } - lin.AddConstant(offset_ * c); - return true; - } + void VisitAsFloat(FloatExprVisitor& lin, double c) const override; + bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override; std::string ToString() const override; std::string DebugString() const override; @@ -355,7 +322,7 @@ class IntWeightedSum : public LinearExpr { int64_t offset_; }; -// A class to hold float_exr * a = b. +/// A class to hold linear_expr * a = b (a and b are floating point numbers). class FloatAffine : public LinearExpr { public: FloatAffine(LinearExpr* expr, double coeff, double offset); @@ -378,46 +345,17 @@ class FloatAffine : public LinearExpr { double offset_; }; -// A class to hold int_exr * a = b. +/// A class to hold linear_expr * a = b (a and b are integers). class IntAffine : public LinearExpr { public: - IntAffine(LinearExpr* expr, int64_t coeff, int64_t offset) - : expr_(expr), coeff_(coeff), offset_(offset) {} + IntAffine(LinearExpr* expr, int64_t coeff, int64_t offset); ~IntAffine() override = default; - bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override { - lin.AddToProcess(expr_, c * coeff_); - lin.AddConstant(offset_ * c); - return true; - } + bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override; + void VisitAsFloat(FloatExprVisitor& lin, double c) const override; - void VisitAsFloat(FloatExprVisitor& lin, double c) const override { - lin.AddToProcess(expr_, c * coeff_); - lin.AddConstant(offset_ * c); - } - - std::string ToString() const override { - std::string s = "("; - if (coeff_ == 1) { - absl::StrAppend(&s, expr_->ToString()); - } else if (coeff_ == -1) { - absl::StrAppend(&s, "-", expr_->ToString()); - } else { - absl::StrAppend(&s, coeff_, " * ", expr_->ToString()); - } - if (offset_ > 0) { - absl::StrAppend(&s, " + ", offset_); - } else if (offset_ < 0) { - absl::StrAppend(&s, " - ", -offset_); - } - absl::StrAppend(&s, ")"); - return s; - } - - std::string DebugString() const override { - return absl::StrCat("IntAffine(expr=", expr_->DebugString(), - ", coeff=", coeff_, ", offset=", offset_, ")"); - } + std::string ToString() const override; + std::string DebugString() const override; LinearExpr* expression() const { return expr_; } int64_t coefficient() const { return coeff_; } @@ -429,7 +367,7 @@ class IntAffine : public LinearExpr { int64_t offset_; }; -// A class to hold a constant. +/// A class to hold a floating point constant as a linear expression. class FloatConstant : public LinearExpr { public: explicit FloatConstant(double value) : value_(value) {} @@ -446,20 +384,21 @@ class FloatConstant : public LinearExpr { double value_; }; -// A class to hold a constant. +/// A class to hold an integer constant as a linear expression. class IntConstant : public LinearExpr { public: explicit IntConstant(int64_t value) : value_(value) {} ~IntConstant() override = default; - bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override { - lin.AddConstant(value_ * c); - return true; - } void VisitAsFloat(FloatExprVisitor& lin, double c) const override { lin.AddConstant(value_ * c); } + bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override { + lin.AddConstant(value_ * c); + return true; + } + std::string ToString() const override { return absl::StrCat(value_); } std::string DebugString() const override { @@ -470,15 +409,57 @@ class IntConstant : public LinearExpr { int64_t value_; }; -// A Boolean literal (a Boolean variable or its negation). +/** + * A class to hold a Boolean literal. + * + * A literal is a Boolean variable or its negation. + * + * Literals are used in CP-SAT models in constraints and in the objective. + * + * - You can define literal as in: + * + * ``` + * b1 = model.new_bool_var() + * b2 = model.new_bool_var() + * # Simple Boolean constraint. + * model.add_bool_or(b1, b2.negated()) + * # We can use the ~ operator to negate a literal. + * model.add_bool_or(b1, ~b2) + * # Enforcement literals must be literals. + * x = model.new_int_var(0, 10, 'x') + * model.add(x == 5).only_enforced_if(~b1) + * ``` + * + * - Literals can be used directly in linear constraints or in the objective: + * + * ``` + * model.minimize(b1 + 2 * ~b2) + * ``` + */ class Literal : public LinearExpr { public: ~Literal() override = default; + /// Returns the index of the current literal. virtual int index() const = 0; + + /** + * Returns the negation of a literal (a Boolean variable or its negation). + * + * This method implements the logical negation of a Boolean variable. + * It is only valid if the variable has a Boolean domain (0 or 1). + * + * Note that this method is nilpotent: `x.negated().negated() == x`. + * + * Returns: + * The negation of the current literal. + */ virtual Literal* negated() const = 0; }; -// A class to hold a variable index. +/** + * A class to hold a variable index. It is the base class for Integer + * variables. + */ class BaseIntVar : public Literal { public: explicit BaseIntVar(int index) : index_(index), negated_(nullptr) { @@ -514,8 +495,10 @@ class BaseIntVar : public Literal { ", is_boolean=", negated_ != nullptr, ")"); } + /// Returns the negation of the current variable. Literal* negated() const override { return negated_; } + /// Returns true if the variable has a Boolean domain (0 or 1). bool is_boolean() const { return negated_ != nullptr; } bool operator<(const BaseIntVar& other) const { @@ -532,14 +515,21 @@ H AbslHashValue(H h, const BaseIntVar* i) { return H::combine(std::move(h), i->index()); } -// A class to hold a negated variable index. +/// A class to hold a negated variable index. class NotBooleanVariable : public Literal { public: explicit NotBooleanVariable(BaseIntVar* var) : var_(var) {} ~NotBooleanVariable() override = default; + /// Returns the index of the current literal. int index() const override { return -var_->index() - 1; } + /** + * Returns the negation of the current literal, that is the original Boolean + * variable. + */ + Literal* negated() const override { return var_; } + bool VisitAsInt(IntExprVisitor& lin, int64_t c) const override { lin.AddVarCoeff(var_, -c); lin.AddConstant(c); @@ -555,8 +545,6 @@ class NotBooleanVariable : public Literal { return absl::StrCat("not(", var_->ToString(), ")"); } - Literal* negated() const override { return var_; } - std::string DebugString() const override { return absl::StrCat("NotBooleanVariable(index=", var_->index(), ")"); } @@ -565,28 +553,38 @@ class NotBooleanVariable : public Literal { BaseIntVar* const var_; }; -// A class to hold a linear expression with bounds. +/// A class to hold a linear expression with bounds. class BoundedLinearExpression { public: - BoundedLinearExpression(const std::vector& vars, - const std::vector& coeffs, int64_t offset, + /// Creates a BoundedLinearExpression representing `expr in domain`. + BoundedLinearExpression(const LinearExpr* expr, const Domain& bounds); + + /// Creates a BoundedLinearExpression representing `pos - neg in domain`. + BoundedLinearExpression(const LinearExpr* pos, const LinearExpr* neg, const Domain& bounds); ~BoundedLinearExpression() = default; + /// Returns the bounds constraining the expression passed to the constructor. const Domain& bounds() const; + /// Returns the array of variables of the flattened expression. const std::vector& vars() const; + /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const; + /// Returns the offset of the flattened expression. int64_t offset() const; + /// Returns true if the bounded linear expression is valid. + bool ok() const; std::string ToString() const; std::string DebugString() const; bool CastToBool(bool* result) const; private: - const std::vector vars_; - const std::vector coeffs_; + std::vector vars_; + std::vector coeffs_; int64_t offset_; const Domain bounds_; + bool ok_ = true; }; } // namespace operations_research::sat::python diff --git a/ortools/sat/python/linear_expr_doc.h b/ortools/sat/python/linear_expr_doc.h new file mode 100644 index 0000000000..e26ddff411 --- /dev/null +++ b/ortools/sat/python/linear_expr_doc.h @@ -0,0 +1,887 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_PYTHON_LINEAR_EXPR_DOC_H_ +#define OR_TOOLS_SAT_PYTHON_LINEAR_EXPR_DOC_H_ + +// NOLINTBEGIN + +/* + This file contains docstrings for use in the Python bindings. + Do not edit! They were automatically extracted by pybind11_mkdoc. + */ + +#define __EXPAND(x) x +#define __COUNT(_1, _2, _3, _4, _5, _6, _7, COUNT, ...) COUNT +#define __VA_SIZE(...) __EXPAND(__COUNT(__VA_ARGS__, 7, 6, 5, 4, 3, 2, 1)) +#define __CAT1(a, b) a##b +#define __CAT2(a, b) __CAT1(a, b) +#define __DOC1(n1) __doc_##n1 +#define __DOC2(n1, n2) __doc_##n1##_##n2 +#define __DOC3(n1, n2, n3) __doc_##n1##_##n2##_##n3 +#define __DOC4(n1, n2, n3, n4) __doc_##n1##_##n2##_##n3##_##n4 +#define __DOC5(n1, n2, n3, n4, n5) __doc_##n1##_##n2##_##n3##_##n4##_##n5 +#define __DOC6(n1, n2, n3, n4, n5, n6) \ + __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6 +#define __DOC7(n1, n2, n3, n4, n5, n6, n7) \ + __doc_##n1##_##n2##_##n3##_##n4##_##n5##_##n6##_##n7 +#define DOC(...) \ + __EXPAND(__EXPAND(__CAT2(__DOC, __VA_SIZE(__VA_ARGS__)))(__VA_ARGS__)) + +#if defined(__GNUG__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#endif + +static const char* __doc_operations_research_sat_python_AbslHashValue = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar = + R"doc(A class to hold a variable index. It is the base class for Integer +variables.)doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_2 = + R"doc(A class to hold a variable index. It is the base class for Integer +variables.)doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVarComparator = + R"doc(Compare the indices of variables.)doc"; + +static const char* + __doc_operations_research_sat_python_BaseIntVarComparator_operator_call = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_BaseIntVar = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BaseIntVar_BaseIntVar_2 = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_DebugString = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BaseIntVar_VisitAsFloat = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_VisitAsInt = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_index = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_index_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_is_boolean = + R"doc(Returns true if the variable has a Boolean domain (0 or 1).)doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_negated = + R"doc(Returns the negation of the current variable.)doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_negated_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_BaseIntVar_operator_lt = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression = + R"doc(A class to hold a linear expression with bounds.)doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_2 = + R"doc(A class to hold a linear expression with bounds.)doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_BoundedLinearExpression = + R"doc(Creates a BoundedLinearExpression representing `expr in domain`.)doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_BoundedLinearExpression_2 = + R"doc(Creates a BoundedLinearExpression representing `pos - neg in domain`.)doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_CastToBool = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_DebugString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_bounds = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_bounds_2 = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_coeffs = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_coeffs_2 = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_offset = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_offset_2 = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_ok = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_ok_2 = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_vars = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_BoundedLinearExpression_vars_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_ExprOrValue = + R"doc(A class to hold a pointer to a linear expression or a constant.)doc"; + +static const char* + __doc_operations_research_sat_python_ExprOrValue_ExprOrValue = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_ExprOrValue_ExprOrValue_2 = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_ExprOrValue_ExprOrValue_3 = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_ExprOrValue_double_value = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_ExprOrValue_expr = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_ExprOrValue_int_value = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr = + R"doc(A flattened and optimized floating point linear expression. + +It can be used to cache complex expressions as parsing them is only +done once.)doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_2 = + R"doc(A flattened and optimized floating point linear expression. + +It can be used to cache complex expressions as parsing them is only +done once.)doc"; + +static const char* + __doc_operations_research_sat_python_FlatFloatExpr_DebugString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FlatFloatExpr_FlatFloatExpr = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FlatFloatExpr_VisitAsFloat = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FlatFloatExpr_VisitAsInt = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_coeffs = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_coeffs_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_offset = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_offset_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_vars = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatFloatExpr_vars_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr = + R"doc(A flattened and optimized integer linear expression. + +It can be used to cache complex expressions as parsing them is only +done once.)doc"; + +static const char* + __doc_operations_research_sat_python_FlatIntExpr_DebugString = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FlatIntExpr_FlatIntExpr = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FlatIntExpr_VisitAsFloat = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_VisitAsInt = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_coeffs = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_coeffs_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_offset = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_offset_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_ok = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_ok_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_vars = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FlatIntExpr_vars_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine = + R"doc(A class to hold linear_expr * a = b (a and b are floating point +numbers).)doc"; + +static const char* + __doc_operations_research_sat_python_FloatAffine_DebugString = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatAffine_FloatAffine = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatAffine_VisitAsFloat = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine_VisitAsInt = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine_coeff = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatAffine_coefficient = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine_expr = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine_expression = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine_offset = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatAffine_offset_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatConstant = + R"doc(A class to hold a floating point constant as a linear expression.)doc"; + +static const char* + __doc_operations_research_sat_python_FloatConstant_DebugString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatConstant_FloatConstant = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatConstant_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatConstant_VisitAsFloat = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatConstant_VisitAsInt = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatConstant_value = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatExprVisitor = + R"doc(A visitor class to process a floating point linear expression.)doc"; + +static const char* __doc_operations_research_sat_python_FloatExprVisitor_2 = + R"doc(A visitor class to process a floating point linear expression.)doc"; + +static const char* + __doc_operations_research_sat_python_FloatExprVisitor_AddConstant = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatExprVisitor_AddToProcess = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatExprVisitor_AddVarCoeff = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatExprVisitor_Process = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatExprVisitor_canonical_terms = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatExprVisitor_offset = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatExprVisitor_to_process = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatWeightedSum = + R"doc(A class to hold a weighted sum of floating point linear expressions. +*/)doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_DebugString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_FloatWeightedSum = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_FloatWeightedSum_2 = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_VisitAsFloat = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_VisitAsInt = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_coeffs = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_FloatWeightedSum_exprs = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_FloatWeightedSum_offset = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine = + R"doc(A class to hold linear_expr * a = b (a and b are integers).)doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_DebugString = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_IntAffine = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_ToString = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_VisitAsFloat = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_VisitAsInt = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_coeff = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_coefficient = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_expr = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_expression = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_offset = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntAffine_offset_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntConstant = + R"doc(A class to hold an integer constant as a linear expression.)doc"; + +static const char* + __doc_operations_research_sat_python_IntConstant_DebugString = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntConstant_IntConstant = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntConstant_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntConstant_VisitAsFloat = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntConstant_VisitAsInt = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntConstant_value = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntExprVisitor = + R"doc(A visitor class to process an integer linear expression.)doc"; + +static const char* __doc_operations_research_sat_python_IntExprVisitor_2 = + R"doc(A visitor class to process an integer linear expression.)doc"; + +static const char* + __doc_operations_research_sat_python_IntExprVisitor_AddConstant = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntExprVisitor_AddToProcess = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntExprVisitor_AddVarCoeff = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntExprVisitor_Evaluate = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntExprVisitor_Process = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntExprVisitor_ProcessAll = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntExprVisitor_canonical_terms = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntExprVisitor_offset = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntExprVisitor_to_process = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntWeightedSum = + R"doc(A class to hold a weighted sum of integer linear expressions.)doc"; + +static const char* + __doc_operations_research_sat_python_IntWeightedSum_DebugString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntWeightedSum_IntWeightedSum = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntWeightedSum_ToString = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntWeightedSum_VisitAsFloat = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_IntWeightedSum_VisitAsInt = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntWeightedSum_coeffs = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntWeightedSum_exprs = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntWeightedSum_offset = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr = + R"doc(A class to hold an integer or floating point linear expression. + +A linear expression is built from (integer or floating point) +constants and variables. For example, `x + 2 * (y - z + 1)`. + +Linear expressions are used in CP-SAT models in constraints and in the +objective. + +Note that constraints only accept linear expressions with integral +coefficients and constants. On the other hand, The objective can be a +linear expression with floating point coefficients and constants. + +You can define linear constraints as in: + +``` +model.add(x + 2 * y <= 5) +model.add(sum(array_of_vars) == 5) +``` + +- In CP-SAT, the objective is a linear expression: + +``` +model.minimize(x + 2 * y + z) +``` + +- For large arrays, using the LinearExpr class is faster that using +the python `sum()` function. You can create constraints and the +objective from lists of linear expressions or coefficients as follows: + +``` +model.minimize(cp_model.LinearExpr.sum(expressions)) +model.add(cp_model.LinearExpr.weighted_sum(expressions, coefficients) >= 0) +```)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_2 = + R"doc(A class to hold an integer or floating point linear expression. + +A linear expression is built from (integer or floating point) +constants and variables. For example, `x + 2 * (y - z + 1)`. + +Linear expressions are used in CP-SAT models in constraints and in the +objective. + +Note that constraints only accept linear expressions with integral +coefficients and constants. On the other hand, The objective can be a +linear expression with floating point coefficients and constants. + +You can define linear constraints as in: + +``` +model.add(x + 2 * y <= 5) +model.add(sum(array_of_vars) == 5) +``` + +- In CP-SAT, the objective is a linear expression: + +``` +model.minimize(x + 2 * y + z) +``` + +- For large arrays, using the LinearExpr class is faster that using +the python `sum()` function. You can create constraints and the +objective from lists of linear expressions or coefficients as follows: + +``` +model.minimize(cp_model.LinearExpr.sum(expressions)) +model.add(cp_model.LinearExpr.weighted_sum(expressions, coefficients) >= 0) +```)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_3 = + R"doc(A class to hold an integer or floating point linear expression. + +A linear expression is built from (integer or floating point) +constants and variables. For example, `x + 2 * (y - z + 1)`. + +Linear expressions are used in CP-SAT models in constraints and in the +objective. + +Note that constraints only accept linear expressions with integral +coefficients and constants. On the other hand, The objective can be a +linear expression with floating point coefficients and constants. + +You can define linear constraints as in: + +``` +model.add(x + 2 * y <= 5) +model.add(sum(array_of_vars) == 5) +``` + +- In CP-SAT, the objective is a linear expression: + +``` +model.minimize(x + 2 * y + z) +``` + +- For large arrays, using the LinearExpr class is faster that using +the python `sum()` function. You can create constraints and the +objective from lists of linear expressions or coefficients as follows: + +``` +model.minimize(cp_model.LinearExpr.sum(expressions)) +model.add(cp_model.LinearExpr.weighted_sum(expressions, coefficients) >= 0) +```)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Add = + R"doc(Returns (this) + (expr).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_AddFloat = + R"doc(Returns (this) + (cst).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_AddInt = + R"doc(Returns (this) + (cst).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_AffineFloat = + R"doc(Returns expr * coeff + offset.)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_AffineInt = + R"doc(Returns expr * coeff + offset.)doc"; + +static const char* + __doc_operations_research_sat_python_LinearExpr_ConstantFloat = + R"doc(Returns a new LinearExpr that is the given constant.)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_ConstantInt = + R"doc(Returns a new LinearExpr that is the given constant.)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_DebugString = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Eq = + R"doc(Returns (this) == (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_EqCst = + R"doc(Returns (this) == (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Ge = + R"doc(Returns (this) >= (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_GeCst = + R"doc(Returns (this) >= (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Gt = + R"doc(Returns (this) > (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_GtCst = + R"doc(Returns (this) > (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_IsInteger = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Le = + R"doc(Returns (this) <= (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_LeCst = + R"doc(Returns (this) <= (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Lt = + R"doc(Returns (this) < (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_LtCst = + R"doc(Returns (this) < (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_MixedSum = + R"doc(Returns a new LinearExpr that is the sum of the given expressions or +constants.)doc"; + +static const char* + __doc_operations_research_sat_python_LinearExpr_MixedWeightedSumFloat = + R"doc(Returns the sum(exprs[i] * coeffs[i]).)doc"; + +static const char* + __doc_operations_research_sat_python_LinearExpr_MixedWeightedSumInt = + R"doc(Returns the sum(exprs[i] * coeffs[i]).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_MulFloat = + R"doc(Returns (this) * (cst).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_MulInt = + R"doc(Returns (this) * (cst).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Ne = + R"doc(Returns (this) != (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_NeCst = + R"doc(Returns (this) != (rhs).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Neg = + R"doc(Returns -(this).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_RSubFloat = + R"doc(Returns (cst) - (this).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_RSubInt = + R"doc(Returns (cst) - (this).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Sub = + R"doc(Returns (this) - (expr).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_SubFloat = + R"doc(Returns (this) - (cst).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_SubInt = + R"doc(Returns (this) - (cst).)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_Sum = + R"doc(Returns a new LinearExpr that is the sum of the given expressions.)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_TermFloat = + R"doc(Returns expr * coeff.)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_TermInt = + R"doc(Returns expr * coeff.)doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_LinearExpr_VisitAsFloat = R"doc()doc"; + +static const char* __doc_operations_research_sat_python_LinearExpr_VisitAsInt = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_LinearExpr_WeightedSumFloat = + R"doc(Returns the sum(exprs[i] * coeffs[i]).)doc"; + +static const char* + __doc_operations_research_sat_python_LinearExpr_WeightedSumInt = + R"doc(Returns the sum(exprs[i] * coeffs[i]).)doc"; + +static const char* __doc_operations_research_sat_python_Literal = + R"doc(A class to hold a Boolean literal. + +A literal is a Boolean variable or its negation. + +Literals are used in CP-SAT models in constraints and in the +objective. + +- You can define literal as in: + +``` +b1 = model.new_bool_var() +b2 = model.new_bool_var() +# Simple Boolean constraint. +model.add_bool_or(b1, b2.negated()) +# We can use the ~ operator to negate a literal. +model.add_bool_or(b1, ~b2) +# Enforcement literals must be literals. +x = model.new_int_var(0, 10, 'x') +model.add(x == 5).only_enforced_if(~b1) +``` + +- Literals can be used directly in linear constraints or in the +objective: + +``` +model.minimize(b1 + 2 * ~b2) +```)doc"; + +static const char* __doc_operations_research_sat_python_Literal_index = + R"doc(Returns the index of the current literal.)doc"; + +static const char* __doc_operations_research_sat_python_Literal_negated = + R"doc(Returns the negation of a literal (a Boolean variable or its +negation). + +This method implements the logical negation of a Boolean variable. It +is only valid if the variable has a Boolean domain (0 or 1). + +Note that this method is nilpotent: `x.negated().negated() == x`. + +Returns: The negation of the current literal.)doc"; + +static const char* __doc_operations_research_sat_python_NotBooleanVariable = + R"doc(A class to hold a negated variable index.)doc"; + +static const char* __doc_operations_research_sat_python_NotBooleanVariable_2 = + R"doc(A class to hold a negated variable index.)doc"; + +static const char* + __doc_operations_research_sat_python_NotBooleanVariable_DebugString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_NotBooleanVariable_NotBooleanVariable = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_NotBooleanVariable_ToString = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_NotBooleanVariable_VisitAsFloat = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_NotBooleanVariable_VisitAsInt = + R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_NotBooleanVariable_index = R"doc()doc"; + +static const char* + __doc_operations_research_sat_python_NotBooleanVariable_negated = + R"doc(Returns the negation of the current literal, that is the original +Boolean variable.)doc"; + +static const char* __doc_operations_research_sat_python_NotBooleanVariable_var = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray = + R"doc(A class to hold a sum of linear expressions, and optional integer and +double offsets (at most one of them can be non-zero, this is +DCHECKed).)doc"; + +static const char* __doc_operations_research_sat_python_SumArray_DebugString = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray_SumArray = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray_ToString = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray_VisitAsFloat = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray_VisitAsInt = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray_double_offset = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray_exprs = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_SumArray_int_offset = + R"doc()doc"; + +#if defined(__GNUG__) +#pragma GCC diagnostic pop +#endif + +// NOLINTEND + +#endif // OR_TOOLS_SAT_PYTHON_LINEAR_EXPR_DOC_H_ diff --git a/ortools/sat/python/swig_helper_test.py b/ortools/sat/python/swig_helper_test.py deleted file mode 100644 index c3232c3ac2..0000000000 --- a/ortools/sat/python/swig_helper_test.py +++ /dev/null @@ -1,362 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2010-2025 Google LLC -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for ortools.sat.python.swig_helper.""" - -from absl.testing import absltest -from google.protobuf import text_format -from ortools.sat import cp_model_pb2 -from ortools.sat import sat_parameters_pb2 -from ortools.sat.python import swig_helper - - -class Callback(swig_helper.SolutionCallback): - - def __init__(self): - swig_helper.SolutionCallback.__init__(self) - self.__solution_count = 0 - - def OnSolutionCallback(self): - print("New Solution") - self.__solution_count += 1 - - def solution_count(self): - return self.__solution_count - - -class BestBoundCallback: - - def __init__(self): - self.best_bound: float = 0.0 - - def new_best_bound(self, bb: float): - self.best_bound = bb - - -class TestIntVar(swig_helper.BaseIntVar): - - def __init__(self, index: int, name: str, is_boolean: bool = False) -> None: - swig_helper.BaseIntVar.__init__(self, index, is_boolean) - self._name = name - - def __str__(self) -> str: - return self._name - - def __repr__(self) -> str: - return self._name - - -class SwigHelperTest(absltest.TestCase): - - def testVariableDomain(self): - model_string = """ - variables { domain: [ -10, 10 ] } - variables { domain: [ -5, -5, 3, 6 ] } - """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) - - d0 = swig_helper.CpSatHelper.variable_domain(model.variables[0]) - d1 = swig_helper.CpSatHelper.variable_domain(model.variables[1]) - - self.assertEqual(d0.flattened_intervals(), [-10, 10]) - self.assertEqual(d1.flattened_intervals(), [-5, -5, 3, 6]) - - def testSimpleSolve(self): - model_string = """ - variables { domain: -10 domain: 10 } - variables { domain: -10 domain: 10 } - variables { domain: -461168601842738790 domain: 461168601842738790 } - constraints { - linear { - vars: 0 - vars: 1 - coeffs: 1 - coeffs: 1 - domain: -9223372036854775808 - domain: 9223372036854775807 - } - } - constraints { - linear { - vars: 0 - vars: 1 - vars: 2 - coeffs: 1 - coeffs: 2 - coeffs: -1 - domain: 0 - domain: 9223372036854775807 - } - } - objective { - vars: -3 - coeffs: 1 - scaling_factor: -1 - }""" - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) - - solve_wrapper = swig_helper.SolveWrapper() - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) - self.assertEqual(30.0, response_wrapper.objective_value()) - - def testSimpleSolveWithCore(self): - model_string = """ - variables { domain: -10 domain: 10 } - variables { domain: -10 domain: 10 } - variables { domain: -461168601842738790 domain: 461168601842738790 } - constraints { - linear { - vars: 0 - vars: 1 - coeffs: 1 - coeffs: 1 - domain: -9223372036854775808 - domain: 9223372036854775807 - } - } - constraints { - linear { - vars: 0 - vars: 1 - vars: 2 - coeffs: 1 - coeffs: 2 - coeffs: -1 - domain: 0 - domain: 9223372036854775807 - } - } - objective { - vars: -3 - coeffs: 1 - scaling_factor: -1 - }""" - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) - - parameters = sat_parameters_pb2.SatParameters(optimize_with_core=True) - - solve_wrapper = swig_helper.SolveWrapper() - solve_wrapper.set_parameters(parameters) - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) - self.assertEqual(30.0, response_wrapper.objective_value()) - - def testSimpleSolveWithProtoApi(self): - model = cp_model_pb2.CpModelProto() - x = model.variables.add() - x.domain.extend([-10, 10]) - y = model.variables.add() - y.domain.extend([-10, 10]) - obj_var = model.variables.add() - obj_var.domain.extend([-461168601842738790, 461168601842738790]) - ct = model.constraints.add() - ct.linear.vars.extend([0, 1, 2]) - ct.linear.coeffs.extend([1, 2, -1]) - ct.linear.domain.extend([0, 0]) - model.objective.vars.append(-3) - model.objective.coeffs.append(1) - model.objective.scaling_factor = -1 - - solve_wrapper = swig_helper.SolveWrapper() - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) - self.assertEqual(30.0, response_wrapper.objective_value()) - self.assertEqual(30.0, response_wrapper.best_objective_bound()) - - def testSolutionCallback(self): - model_string = """ - variables { domain: 0 domain: 5 } - variables { domain: 0 domain: 5 } - constraints { - linear { vars: 0 vars: 1 coeffs: 1 coeffs: 1 domain: 6 domain: 6 } } - """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) - - solve_wrapper = swig_helper.SolveWrapper() - callback = Callback() - solve_wrapper.add_solution_callback(callback) - params = sat_parameters_pb2.SatParameters() - params.enumerate_all_solutions = True - solve_wrapper.set_parameters(params) - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - - self.assertEqual(5, callback.solution_count()) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) - - def testBestBoundCallback(self): - model_string = """ - variables { domain: 0 domain: 1 } - variables { domain: 0 domain: 1 } - variables { domain: 0 domain: 1 } - variables { domain: 0 domain: 1 } - constraints { bool_or { literals: [0, 1, 2, 3] } } - objective { - vars: [0, 1, 2, 3] - coeffs: [3, 2, 4, 5] - offset: 0.6 - } - """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) - - solve_wrapper = swig_helper.SolveWrapper() - best_bound_callback = BestBoundCallback() - solve_wrapper.add_best_bound_callback(best_bound_callback.new_best_bound) - params = sat_parameters_pb2.SatParameters() - params.num_workers = 1 - params.linearization_level = 2 - params.log_search_progress = True - solve_wrapper.set_parameters(params) - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - - self.assertEqual(2.6, best_bound_callback.best_bound) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) - - def testModelStats(self): - model_string = """ - variables { domain: -10 domain: 10 } - variables { domain: -10 domain: 10 } - variables { domain: -1000 domain: 1000 } - constraints { - linear { - vars: 0 - vars: 1 - coeffs: 1 - coeffs: 1 - domain: -1000 - domain: 1000 - } - } - constraints { - linear { - vars: 0 - vars: 1 - vars: 2 - coeffs: 1 - coeffs: 2 - coeffs: -1 - domain: 0 - domain: 1000 - } - } - objective { - vars: -3 - coeffs: 1 - scaling_factor: -1 - } - name: 'testModelStats' - """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) - stats = swig_helper.CpSatHelper.model_stats(model) - self.assertTrue(stats) - - def testIntLinExpr(self): - x = TestIntVar(0, "x") - self.assertTrue(x.is_integer()) - self.assertIsInstance(x, swig_helper.BaseIntVar) - self.assertIsInstance(x, swig_helper.LinearExpr) - e1 = x + 2 - self.assertTrue(e1.is_integer()) - self.assertEqual(str(e1), "(x + 2)") - e2 = 3 + x - self.assertTrue(e2.is_integer()) - self.assertEqual(str(e2), "(x + 3)") - y = TestIntVar(1, "y") - e3 = y * 5 - self.assertTrue(e3.is_integer()) - self.assertEqual(str(e3), "(5 * y)") - e4 = -2 * y - self.assertTrue(e4.is_integer()) - self.assertEqual(str(e4), "(-2 * y)") - e5 = x - 1 - self.assertTrue(e5.is_integer()) - self.assertEqual(str(e5), "(x - 1)") - e6 = x - 2 * y - self.assertTrue(e6.is_integer()) - self.assertEqual(str(e6), "(x - (2 * y))") - z = TestIntVar(2, "z", True) - e7 = -z - self.assertTrue(e7.is_integer()) - self.assertEqual(str(e7), "(-z)") - not_z = ~z - self.assertTrue(not_z.is_integer()) - self.assertEqual(str(not_z), "not(z)") - self.assertEqual(not_z.index, -3) - - e8 = swig_helper.LinearExpr.sum([x, y, z]) - self.assertEqual(str(e8), "(x + y + z)") - e9 = swig_helper.LinearExpr.sum([x, y, z, 11]) - self.assertEqual(str(e9), "(x + y + z + 11)") - e10 = swig_helper.LinearExpr.weighted_sum([x, y, z], [1, 2, 3]) - self.assertEqual(str(e10), "(x + 2 * y + 3 * z)") - e11 = swig_helper.LinearExpr.weighted_sum([x, y, z, 5], [1, 2, 3, -1]) - self.assertEqual(str(e11), "(x + 2 * y + 3 * z - 5)") - - def testFloatLinExpr(self): - x = TestIntVar(0, "x") - self.assertTrue(x.is_integer()) - self.assertIsInstance(x, TestIntVar) - self.assertIsInstance(x, swig_helper.LinearExpr) - e1 = x + 2.5 - self.assertFalse(e1.is_integer()) - self.assertEqual(str(e1), "(x + 2.5)") - e2 = 3.1 + x - self.assertFalse(e2.is_integer()) - self.assertEqual(str(e2), "(x + 3.1)") - y = TestIntVar(1, "y") - e3 = y * 5.2 - self.assertFalse(e3.is_integer()) - self.assertEqual(str(e3), "(5.2 * y)") - e4 = -2.2 * y - self.assertFalse(e4.is_integer()) - self.assertEqual(str(e4), "(-2.2 * y)") - e5 = x - 1.1 - self.assertFalse(e5.is_integer()) - self.assertEqual(str(e5), "(x - 1.1)") - e6 = x + 2.4 * y - self.assertFalse(e6.is_integer()) - self.assertEqual(str(e6), "(x + (2.4 * y))") - e7 = x - 2.4 * y - self.assertFalse(e7.is_integer()) - self.assertEqual(str(e7), "(x - (2.4 * y))") - - z = TestIntVar(2, "z") - e8 = swig_helper.LinearExpr.sum([x, y, z, -2]) - self.assertTrue(e8.is_integer()) - self.assertEqual(str(e8), "(x + y + z - 2)") - e9 = swig_helper.LinearExpr.sum([x, y, z, 1.5]) - self.assertFalse(e9.is_integer()) - self.assertEqual(str(e9), "(x + y + z + 1.5)") - e10 = swig_helper.LinearExpr.weighted_sum([x, y, z], [1.0, 2.2, 3.3]) - self.assertFalse(e10.is_integer()) - self.assertEqual(str(e10), "(x + 2.2 * y + 3.3 * z)") - e11 = swig_helper.LinearExpr.weighted_sum([x, y, z, 1.5], [1.0, 2.2, 3.3, -1]) - self.assertFalse(e11.is_integer()) - self.assertEqual(str(e11), "(x + 2.2 * y + 3.3 * z - 1.5)") - e12 = (x + 2) * 3.1 - self.assertFalse(e12.is_integer()) - self.assertEqual(str(e12), "(3.1 * (x + 2))") - - -if __name__ == "__main__": - absltest.main() diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index 2b9e883f31..5fedf35f83 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -486,10 +486,12 @@ message SatParameters { // solver. Note that contrary to the precedence encoding, this easily support // variable demands. // - // WARNING: with this encoding, the constraint take a slighlty different - // meaning. The level must be within the reservoir for any permutation of the - // events. So we cannot have +100 and -100 at the same time if the maximum - // level is 10 (as autorized by the reservoir constraint). + // WARNING: with this encoding, the constraint takes a slightly different + // meaning. There must exist a permutation of the events occurring at the same + // time such that the level is within the reservoir after each of these events + // (in this permuted order). So we cannot have +100 and -100 at the same time + // if the level must be between 0 and 10 (as authorized by the reservoir + // constraint). optional bool expand_reservoir_using_circuit = 288 [default = false]; // Encore cumulative with fixed demands and capacity as a reservoir diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 4ccffe9e46..4d2a7f6059 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -480,7 +480,7 @@ void SatSolver::SaveDebugAssignment() { } } -void SatSolver::LoadDebugSolution(const std::vector& solution) { +void SatSolver::LoadDebugSolution(absl::Span solution) { debug_assignment_.Resize(num_variables_.value()); for (BooleanVariable var(0); var < num_variables_; ++var) { if (!debug_assignment_.VariableIsAssigned(var)) continue; @@ -521,7 +521,7 @@ bool SatSolver::ClauseIsValidUnderDebugAssignment( } bool SatSolver::PBConstraintIsValidUnderDebugAssignment( - const std::vector& cst, const Coefficient rhs) const { + absl::Span cst, const Coefficient rhs) const { Coefficient sum(0.0); for (LiteralWithCoeff term : cst) { if (term.literal.Variable() >= debug_assignment_.NumberOfVariables()) { @@ -2193,7 +2193,7 @@ void SatSolver::ComputeFirstUIPConflict( } } -void SatSolver::ComputeUnionOfReasons(const std::vector& input, +void SatSolver::ComputeUnionOfReasons(absl::Span input, std::vector* literals) { tmp_mark_.ClearAndResize(num_variables_); literals->clear(); diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 51d205d4ad..be776ba35e 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -444,7 +444,7 @@ class SatSolver { // debug mode, and after this is called, all the learned clauses are tested to // satisfy this saved assignment. void SaveDebugAssignment(); - void LoadDebugSolution(const std::vector& solution); + void LoadDebugSolution(absl::Span solution); void SetDratProofHandler(DratProofHandler* drat_proof_handler) { drat_proof_handler_ = drat_proof_handler; @@ -523,7 +523,7 @@ class SatSolver { bool ClauseIsValidUnderDebugAssignment( absl::Span clause) const; bool PBConstraintIsValidUnderDebugAssignment( - const std::vector& cst, Coefficient rhs) const; + absl::Span cst, Coefficient rhs) const; // Logs the given status if parameters_.log_search_progress() is true. // Also returns it. @@ -658,7 +658,7 @@ class SatSolver { // Fills literals with all the literals in the reasons of the literals in the // given input. The output vector will have no duplicates and will not contain // the literals already present in the input. - void ComputeUnionOfReasons(const std::vector& input, + void ComputeUnionOfReasons(absl::Span input, std::vector* literals); // Do the full pseudo-Boolean constraint analysis. This calls multiple @@ -898,8 +898,9 @@ inline std::function BooleanLinearConstraint( inline std::function CardinalityConstraint( int64_t lower_bound, int64_t upper_bound, - const std::vector& literals) { - return [=](Model* model) { + absl::Span literals) { + return [=, literals = std::vector(literals.begin(), literals.end())]( + Model* model) { std::vector cst; cst.reserve(literals.size()); for (int i = 0; i < literals.size(); ++i) { @@ -965,8 +966,9 @@ inline std::function Equality(Literal a, Literal b) { // r <=> (at least one literal is true). This is a reified clause. inline std::function ReifiedBoolOr( - const std::vector& literals, Literal r) { - return [=](Model* model) { + absl::Span literals, Literal r) { + return [=, literals = std::vector(literals.begin(), literals.end())]( + Model* model) { std::vector clause; for (const Literal l : literals) { model->Add(Implication(l, r)); // l => r. diff --git a/ortools/sat/scheduling_helpers.cc b/ortools/sat/scheduling_helpers.cc new file mode 100644 index 0000000000..37eaccfe84 --- /dev/null +++ b/ortools/sat/scheduling_helpers.cc @@ -0,0 +1,1078 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/scheduling_helpers.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/meta/type_traits.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "ortools/base/logging.h" +#include "ortools/base/strong_vector.h" +#include "ortools/sat/implied_bounds.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/integer_expr.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/sort.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { + +SchedulingConstraintHelper::SchedulingConstraintHelper( + std::vector starts, std::vector ends, + std::vector sizes, + std::vector reason_for_presence, Model* model) + : model_(model), + trail_(model->GetOrCreate()), + sat_solver_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + watcher_(model->GetOrCreate()), + precedence_relations_(model->GetOrCreate()), + starts_(std::move(starts)), + ends_(std::move(ends)), + sizes_(std::move(sizes)), + reason_for_presence_(std::move(reason_for_presence)), + capacity_(starts_.size()), + cached_size_min_(new IntegerValue[capacity_]), + cached_start_min_(new IntegerValue[capacity_]), + cached_end_min_(new IntegerValue[capacity_]), + cached_negated_start_max_(new IntegerValue[capacity_]), + cached_negated_end_max_(new IntegerValue[capacity_]), + cached_shifted_start_min_(new IntegerValue[capacity_]), + cached_negated_shifted_end_max_(new IntegerValue[capacity_]) { + minus_ends_.clear(); + minus_starts_.clear(); + DCHECK_EQ(starts_.size(), ends_.size()); + DCHECK_EQ(starts_.size(), sizes_.size()); + DCHECK_EQ(starts_.size(), reason_for_presence_.size()); + + for (int i = 0; i < starts_.size(); ++i) { + minus_ends_.push_back(ends_[i].Negated()); + minus_starts_.push_back(starts_[i].Negated()); + } + + InitSortedVectors(); + if (!SynchronizeAndSetTimeDirection(true)) { + model->GetOrCreate()->NotifyThatModelIsUnsat(); + } +} + +SchedulingConstraintHelper::SchedulingConstraintHelper(int num_tasks, + Model* model) + : model_(model), + trail_(model->GetOrCreate()), + sat_solver_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + precedence_relations_(model->GetOrCreate()), + capacity_(num_tasks), + cached_size_min_(new IntegerValue[capacity_]), + cached_start_min_(new IntegerValue[capacity_]), + cached_end_min_(new IntegerValue[capacity_]), + cached_negated_start_max_(new IntegerValue[capacity_]), + cached_negated_end_max_(new IntegerValue[capacity_]), + cached_shifted_start_min_(new IntegerValue[capacity_]), + cached_negated_shifted_end_max_(new IntegerValue[capacity_]) { + starts_.resize(num_tasks); + CHECK_EQ(NumTasks(), num_tasks); +} + +bool SchedulingConstraintHelper::Propagate() { + recompute_all_cache_ = true; + for (const int id : propagator_ids_) watcher_->CallOnNextPropagate(id); + return true; +} + +bool SchedulingConstraintHelper::IncrementalPropagate( + const std::vector& watch_indices) { + for (const int t : watch_indices) recompute_cache_.Set(t); + for (const int id : propagator_ids_) watcher_->CallOnNextPropagate(id); + return true; +} + +void SchedulingConstraintHelper::RegisterWith(GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + const int num_tasks = starts_.size(); + for (int t = 0; t < num_tasks; ++t) { + watcher->WatchIntegerVariable(sizes_[t].var, id, t); + watcher->WatchIntegerVariable(starts_[t].var, id, t); + watcher->WatchIntegerVariable(ends_[t].var, id, t); + } + watcher->SetPropagatorPriority(id, 0); +} + +bool SchedulingConstraintHelper::UpdateCachedValues(int t) { + if (IsAbsent(t)) return true; + + IntegerValue smin = integer_trail_->LowerBound(starts_[t]); + IntegerValue smax = integer_trail_->UpperBound(starts_[t]); + IntegerValue emin = integer_trail_->LowerBound(ends_[t]); + IntegerValue emax = integer_trail_->UpperBound(ends_[t]); + + // We take the max for the corner case where the size of an optional interval + // is used elsewhere and has a domain with negative value. + // + // TODO(user): maybe we should just disallow size with a negative domain, but + // is is harder to enforce if we have a linear expression for size. + IntegerValue dmin = + std::max(IntegerValue(0), integer_trail_->LowerBound(sizes_[t])); + IntegerValue dmax = integer_trail_->UpperBound(sizes_[t]); + + // Detect first if we have a conflict using the relation start + size = end. + if (dmax < 0) { + ClearReason(); + AddSizeMaxReason(t, dmax); + return PushTaskAbsence(t); + } + if (smin + dmin - emax > 0) { + ClearReason(); + AddStartMinReason(t, smin); + AddSizeMinReason(t, dmin); + AddEndMaxReason(t, emax); + return PushTaskAbsence(t); + } + if (smax + dmax - emin < 0) { + ClearReason(); + AddStartMaxReason(t, smax); + AddSizeMaxReason(t, dmax); + AddEndMinReason(t, emin); + return PushTaskAbsence(t); + } + + // Sometimes, for optional interval with non-optional bounds, this propagation + // give tighter bounds. We always consider the value assuming + // the interval is present. + // + // Note that this is also useful in case not everything was propagated. Note + // also that since there is no conflict, we reach the fix point in one pass. + smin = std::max(smin, emin - dmax); + smax = std::min(smax, emax - dmin); + dmin = std::max(dmin, emin - smax); + emin = std::max(emin, smin + dmin); + emax = std::min(emax, smax + dmax); + + if (emin != cached_end_min_[t]) { + recompute_energy_profile_ = true; + } + + // We might only want to do that if the value changed, but I am not sure it + // is worth the test. + recompute_by_start_max_ = true; + recompute_by_end_min_ = true; + + cached_start_min_[t] = smin; + cached_end_min_[t] = emin; + cached_negated_start_max_[t] = -smax; + cached_negated_end_max_[t] = -emax; + cached_size_min_[t] = dmin; + + // Note that we use the cached value here for EndMin()/StartMax(). + const IntegerValue new_shifted_start_min = emin - dmin; + if (new_shifted_start_min != cached_shifted_start_min_[t]) { + recompute_energy_profile_ = true; + recompute_shifted_start_min_ = true; + cached_shifted_start_min_[t] = new_shifted_start_min; + } + const IntegerValue new_negated_shifted_end_max = -(smax + dmin); + if (new_negated_shifted_end_max != cached_negated_shifted_end_max_[t]) { + recompute_negated_shifted_end_max_ = true; + cached_negated_shifted_end_max_[t] = new_negated_shifted_end_max; + } + return true; +} + +bool SchedulingConstraintHelper::ResetFromSubset( + const SchedulingConstraintHelper& other, absl::Span tasks) { + current_time_direction_ = other.current_time_direction_; + + const int num_tasks = tasks.size(); + starts_.resize(num_tasks); + ends_.resize(num_tasks); + minus_ends_.resize(num_tasks); + minus_starts_.resize(num_tasks); + sizes_.resize(num_tasks); + reason_for_presence_.resize(num_tasks); + for (int i = 0; i < num_tasks; ++i) { + const int t = tasks[i]; + starts_[i] = other.starts_[t]; + ends_[i] = other.ends_[t]; + minus_ends_[i] = other.minus_ends_[t]; + minus_starts_[i] = other.minus_starts_[t]; + sizes_[i] = other.sizes_[t]; + reason_for_presence_[i] = other.reason_for_presence_[t]; + } + + InitSortedVectors(); + return SynchronizeAndSetTimeDirection(true); +} + +void SchedulingConstraintHelper::InitSortedVectors() { + const int num_tasks = starts_.size(); + + recompute_all_cache_ = true; + recompute_cache_.Resize(num_tasks); + for (int t = 0; t < num_tasks; ++t) { + recompute_cache_.Set(t); + } + + // Make sure all the cached_* arrays can hold enough data. + CHECK_LE(num_tasks, capacity_); + + task_by_increasing_start_min_.resize(num_tasks); + task_by_increasing_end_min_.resize(num_tasks); + task_by_increasing_negated_start_max_.resize(num_tasks); + task_by_decreasing_end_max_.resize(num_tasks); + task_by_increasing_shifted_start_min_.resize(num_tasks); + task_by_negated_shifted_end_max_.resize(num_tasks); + for (int t = 0; t < num_tasks; ++t) { + task_by_increasing_start_min_[t].task_index = t; + task_by_increasing_end_min_[t].task_index = t; + task_by_increasing_negated_start_max_[t].task_index = t; + task_by_decreasing_end_max_[t].task_index = t; + + task_by_increasing_shifted_start_min_[t].task_index = t; + task_by_increasing_shifted_start_min_[t].presence_lit = + reason_for_presence_[t]; + task_by_negated_shifted_end_max_[t].task_index = t; + task_by_negated_shifted_end_max_[t].presence_lit = reason_for_presence_[t]; + } + + recompute_by_start_max_ = true; + recompute_by_end_min_ = true; + recompute_energy_profile_ = true; + recompute_shifted_start_min_ = true; + recompute_negated_shifted_end_max_ = true; +} + +void SchedulingConstraintHelper::SetTimeDirection(bool is_forward) { + if (current_time_direction_ != is_forward) { + current_time_direction_ = is_forward; + + std::swap(starts_, minus_ends_); + std::swap(ends_, minus_starts_); + + std::swap(task_by_increasing_start_min_, task_by_decreasing_end_max_); + std::swap(task_by_increasing_end_min_, + task_by_increasing_negated_start_max_); + std::swap(recompute_by_end_min_, recompute_by_start_max_); + std::swap(task_by_increasing_shifted_start_min_, + task_by_negated_shifted_end_max_); + + recompute_energy_profile_ = true; + std::swap(cached_start_min_, cached_negated_end_max_); + std::swap(cached_end_min_, cached_negated_start_max_); + std::swap(cached_shifted_start_min_, cached_negated_shifted_end_max_); + std::swap(recompute_shifted_start_min_, recompute_negated_shifted_end_max_); + } +} + +bool SchedulingConstraintHelper::SynchronizeAndSetTimeDirection( + bool is_forward) { + SetTimeDirection(is_forward); + + // If there was any backtracks since the last time this was called, we + // recompute our cache. + if (sat_solver_->num_backtracks() != saved_num_backtracks_) { + recompute_all_cache_ = true; + saved_num_backtracks_ = sat_solver_->num_backtracks(); + } + + if (recompute_all_cache_) { + for (int t = 0; t < recompute_cache_.size(); ++t) { + if (!UpdateCachedValues(t)) return false; + } + } else { + for (const int t : recompute_cache_) { + if (!UpdateCachedValues(t)) return false; + } + } + recompute_cache_.ClearAll(); + recompute_all_cache_ = false; + return true; +} + +// TODO(user): be more precise when we know a and b are in disjunction. +// we really just need start_b > start_a, or even >= if duration is non-zero. +IntegerValue SchedulingConstraintHelper::GetCurrentMinDistanceBetweenTasks( + int a, int b, bool add_reason_if_after) { + const AffineExpression before = ends_[a]; + const AffineExpression after = starts_[b]; + if (before.var == kNoIntegerVariable || before.coeff != 1 || + after.var == kNoIntegerVariable || after.coeff != 1) { + return kMinIntegerValue; + } + + // We take the max of the level zero offset and the one coming from a + // conditional precedence at true. + const IntegerValue conditional_offset = + precedence_relations_->GetConditionalOffset(before.var, after.var); + const IntegerValue known = integer_trail_->LevelZeroLowerBound(after.var) - + integer_trail_->LevelZeroUpperBound(before.var); + const IntegerValue offset = std::max(conditional_offset, known); + + const IntegerValue needed_offset = before.constant - after.constant; + const IntegerValue distance = offset - needed_offset; + if (add_reason_if_after && distance >= 0 && known < conditional_offset) { + for (const Literal l : precedence_relations_->GetConditionalEnforcements( + before.var, after.var)) { + literal_reason_.push_back(l.Negated()); + } + } + return distance; +} + +// Note that we could call this at a positive level to propagate any literal +// associated to task a before task b. However we only call this for task that +// are in detectable precedence, which means the normal precedence or linear +// propagator should have already propagated that Boolean too. +bool SchedulingConstraintHelper::PropagatePrecedence(int a, int b) { + CHECK(IsPresent(a)); + CHECK(IsPresent(b)); + CHECK_EQ(trail_->CurrentDecisionLevel(), 0); + + const AffineExpression before = ends_[a]; + const AffineExpression after = starts_[b]; + if (after.coeff != 1) return true; + if (before.coeff != 1) return true; + if (after.var == kNoIntegerVariable) return true; + if (before.var == kNoIntegerVariable) return true; + const IntegerValue offset = before.constant - after.constant; + if (precedence_relations_->Add(before.var, after.var, offset)) { + VLOG(2) << "new relation " << TaskDebugString(a) + << " <= " << TaskDebugString(b); + + // TODO(user): Adding new constraint during propagation might not be the + // best idea as it can create some complication. + AddWeightedSumLowerOrEqual({}, {before.var, after.var}, + {int64_t{1}, int64_t{-1}}, -offset.value(), + model_); + if (model_->GetOrCreate()->ModelIsUnsat()) return false; + } + return true; +} + +absl::Span +SchedulingConstraintHelper::TaskByIncreasingStartMin() { + for (TaskTime& ref : task_by_increasing_start_min_) { + ref.time = StartMin(ref.task_index); + } + IncrementalSort(task_by_increasing_start_min_.begin(), + task_by_increasing_start_min_.end()); + return task_by_increasing_start_min_; +} + +absl::Span +SchedulingConstraintHelper::TaskByIncreasingEndMin() { + if (!recompute_by_end_min_) return task_by_increasing_end_min_; + for (TaskTime& ref : task_by_increasing_end_min_) { + ref.time = EndMin(ref.task_index); + } + IncrementalSort(task_by_increasing_end_min_.begin(), + task_by_increasing_end_min_.end()); + recompute_by_end_min_ = false; + return task_by_increasing_end_min_; +} + +absl::Span +SchedulingConstraintHelper::TaskByIncreasingNegatedStartMax() { + if (!recompute_by_start_max_) return task_by_increasing_negated_start_max_; + for (TaskTime& ref : task_by_increasing_negated_start_max_) { + ref.time = cached_negated_start_max_[ref.task_index]; + } + IncrementalSort(task_by_increasing_negated_start_max_.begin(), + task_by_increasing_negated_start_max_.end()); + recompute_by_start_max_ = false; + return task_by_increasing_negated_start_max_; +} + +absl::Span +SchedulingConstraintHelper::TaskByDecreasingEndMax() { + for (TaskTime& ref : task_by_decreasing_end_max_) { + ref.time = EndMax(ref.task_index); + } + IncrementalSort(task_by_decreasing_end_max_.begin(), + task_by_decreasing_end_max_.end(), std::greater()); + return task_by_decreasing_end_max_; +} + +absl::Span +SchedulingConstraintHelper::TaskByIncreasingShiftedStartMin() { + if (recompute_shifted_start_min_) { + recompute_shifted_start_min_ = false; + bool is_sorted = true; + IntegerValue previous = kMinIntegerValue; + for (CachedTaskBounds& ref : task_by_increasing_shifted_start_min_) { + ref.time = ShiftedStartMin(ref.task_index); + is_sorted = is_sorted && ref.time >= previous; + previous = ref.time; + } + if (is_sorted) return task_by_increasing_shifted_start_min_; + IncrementalSort(task_by_increasing_shifted_start_min_.begin(), + task_by_increasing_shifted_start_min_.end()); + } + return task_by_increasing_shifted_start_min_; +} + +// TODO(user): Avoid recomputing it if nothing changed. +const std::vector& +SchedulingConstraintHelper::GetEnergyProfile() { + if (energy_profile_.empty()) { + const int num_tasks = NumTasks(); + for (int t = 0; t < num_tasks; ++t) { + energy_profile_.push_back( + {cached_shifted_start_min_[t], t, /*is_first=*/true}); + energy_profile_.push_back({cached_end_min_[t], t, /*is_first=*/false}); + } + } else { + if (!recompute_energy_profile_) return energy_profile_; + for (ProfileEvent& ref : energy_profile_) { + const int t = ref.task; + if (ref.is_first) { + ref.time = cached_shifted_start_min_[t]; + } else { + ref.time = cached_end_min_[t]; + } + } + } + IncrementalSort(energy_profile_.begin(), energy_profile_.end()); + recompute_energy_profile_ = false; + return energy_profile_; +} + +// Produces a relaxed reason for StartMax(before) < EndMin(after). +void SchedulingConstraintHelper::AddReasonForBeingBefore(int before, + int after) { + AddOtherReason(before); + AddOtherReason(after); + + // The reason will be a linear expression greater than a value. Note that all + // coeff must be positive, and we will use the variable lower bound. + std::vector vars; + std::vector coeffs; + + // Reason for StartMax(before). + const IntegerValue smax_before = StartMax(before); + if (smax_before >= integer_trail_->UpperBound(starts_[before])) { + if (starts_[before].var != kNoIntegerVariable) { + vars.push_back(NegationOf(starts_[before].var)); + coeffs.push_back(starts_[before].coeff); + } + } else { + if (ends_[before].var != kNoIntegerVariable) { + vars.push_back(NegationOf(ends_[before].var)); + coeffs.push_back(ends_[before].coeff); + } + if (sizes_[before].var != kNoIntegerVariable) { + vars.push_back(sizes_[before].var); + coeffs.push_back(sizes_[before].coeff); + } + } + + // Reason for EndMin(after); + const IntegerValue emin_after = EndMin(after); + if (emin_after <= integer_trail_->LowerBound(ends_[after])) { + if (ends_[after].var != kNoIntegerVariable) { + vars.push_back(ends_[after].var); + coeffs.push_back(ends_[after].coeff); + } + } else { + if (starts_[after].var != kNoIntegerVariable) { + vars.push_back(starts_[after].var); + coeffs.push_back(starts_[after].coeff); + } + if (sizes_[after].var != kNoIntegerVariable) { + vars.push_back(sizes_[after].var); + coeffs.push_back(sizes_[after].coeff); + } + } + + DCHECK_LT(smax_before, emin_after); + const IntegerValue slack = emin_after - smax_before - 1; + integer_trail_->AppendRelaxedLinearReason(slack, coeffs, vars, + &integer_reason_); +} + +bool SchedulingConstraintHelper::PushIntegerLiteral(IntegerLiteral lit) { + CHECK(other_helper_ == nullptr); + return integer_trail_->Enqueue(lit, literal_reason_, integer_reason_); +} + +bool SchedulingConstraintHelper::PushIntegerLiteralIfTaskPresent( + int t, IntegerLiteral lit) { + if (IsAbsent(t)) return true; + AddOtherReason(t); + ImportOtherReasons(); + if (IsOptional(t)) { + return integer_trail_->ConditionalEnqueue( + PresenceLiteral(t), lit, &literal_reason_, &integer_reason_); + } + return integer_trail_->Enqueue(lit, literal_reason_, integer_reason_); +} + +// We also run directly the precedence propagator for this variable so that when +// we push an interval start for example, we have a chance to push its end. +bool SchedulingConstraintHelper::PushIntervalBound(int t, IntegerLiteral lit) { + if (!PushIntegerLiteralIfTaskPresent(t, lit)) return false; + if (IsAbsent(t)) return true; + if (!UpdateCachedValues(t)) return false; + recompute_cache_.Clear(t); + return true; +} + +bool SchedulingConstraintHelper::IncreaseStartMin(int t, IntegerValue value) { + if (starts_[t].var == kNoIntegerVariable) { + if (value > starts_[t].constant) return PushTaskAbsence(t); + return true; + } + return PushIntervalBound(t, starts_[t].GreaterOrEqual(value)); +} + +bool SchedulingConstraintHelper::IncreaseEndMin(int t, IntegerValue value) { + if (ends_[t].var == kNoIntegerVariable) { + if (value > ends_[t].constant) return PushTaskAbsence(t); + return true; + } + return PushIntervalBound(t, ends_[t].GreaterOrEqual(value)); +} + +bool SchedulingConstraintHelper::DecreaseEndMax(int t, IntegerValue value) { + if (ends_[t].var == kNoIntegerVariable) { + if (value < ends_[t].constant) return PushTaskAbsence(t); + return true; + } + return PushIntervalBound(t, ends_[t].LowerOrEqual(value)); +} + +bool SchedulingConstraintHelper::PushLiteral(Literal l) { + integer_trail_->EnqueueLiteral(l, literal_reason_, integer_reason_); + return true; +} + +bool SchedulingConstraintHelper::PushTaskAbsence(int t) { + if (IsAbsent(t)) return true; + if (!IsOptional(t)) return ReportConflict(); + + AddOtherReason(t); + + if (IsPresent(t)) { + literal_reason_.push_back(Literal(reason_for_presence_[t]).Negated()); + return ReportConflict(); + } + ImportOtherReasons(); + integer_trail_->EnqueueLiteral(Literal(reason_for_presence_[t]).Negated(), + literal_reason_, integer_reason_); + return true; +} + +bool SchedulingConstraintHelper::PushTaskPresence(int t) { + DCHECK_NE(reason_for_presence_[t], kNoLiteralIndex); + DCHECK(!IsPresent(t)); + + AddOtherReason(t); + + if (IsAbsent(t)) { + literal_reason_.push_back(Literal(reason_for_presence_[t])); + return ReportConflict(); + } + ImportOtherReasons(); + integer_trail_->EnqueueLiteral(Literal(reason_for_presence_[t]), + literal_reason_, integer_reason_); + return true; +} + +bool SchedulingConstraintHelper::ReportConflict() { + ImportOtherReasons(); + return integer_trail_->ReportConflict(literal_reason_, integer_reason_); +} + +void SchedulingConstraintHelper::WatchAllTasks(int id, bool watch_max_side) { + // In all cases, we watch presence literals since this class is not waked up + // when those changes. + const int num_tasks = starts_.size(); + for (int t = 0; t < num_tasks; ++t) { + if (!IsPresent(t) && !IsAbsent(t)) { + watcher_->WatchLiteral(Literal(reason_for_presence_[t]), id); + } + } + + // If everything is watched, it is slighlty more efficient to enqueue the + // propagator when the helper Propagate() is called. This result in less + // entries in our watched lists. + if (watch_max_side) { + propagator_ids_.push_back(id); + return; + } + + // We only watch "min" side. + for (int t = 0; t < num_tasks; ++t) { + watcher_->WatchLowerBound(starts_[t], id); + watcher_->WatchLowerBound(ends_[t], id); + watcher_->WatchLowerBound(sizes_[t], id); + } +} + +void SchedulingConstraintHelper::AddOtherReason(int t) { + if (other_helper_ == nullptr || already_added_to_other_reasons_[t]) return; + already_added_to_other_reasons_[t] = true; + const int mapped_t = map_to_other_helper_[t]; + other_helper_->AddStartMaxReason(mapped_t, event_for_other_helper_); + other_helper_->AddEndMinReason(mapped_t, event_for_other_helper_ + 1); +} + +void SchedulingConstraintHelper::ImportOtherReasons() { + if (other_helper_ != nullptr) ImportOtherReasons(*other_helper_); +} + +void SchedulingConstraintHelper::ImportOtherReasons( + const SchedulingConstraintHelper& other_helper) { + literal_reason_.insert(literal_reason_.end(), + other_helper.literal_reason_.begin(), + other_helper.literal_reason_.end()); + integer_reason_.insert(integer_reason_.end(), + other_helper.integer_reason_.begin(), + other_helper.integer_reason_.end()); +} + +std::string SchedulingConstraintHelper::TaskDebugString(int t) const { + return absl::StrCat("t=", t, " is_present=", + (IsPresent(t) ? "1" + : IsAbsent(t) ? "0" + : "?"), + " size=[", SizeMin(t).value(), ",", SizeMax(t).value(), + "]", " start=[", StartMin(t).value(), ",", + StartMax(t).value(), "]", " end=[", EndMin(t).value(), + ",", EndMax(t).value(), "]"); +} + +IntegerValue SchedulingConstraintHelper::GetMinOverlap(int t, + IntegerValue start, + IntegerValue end) const { + return std::min(std::min(end - start, SizeMin(t)), + std::min(EndMin(t) - start, end - StartMax(t))); +} + +IntegerValue ComputeEnergyMinInWindow( + IntegerValue start_min, IntegerValue start_max, IntegerValue end_min, + IntegerValue end_max, IntegerValue size_min, IntegerValue demand_min, + absl::Span filtered_energy, + IntegerValue window_start, IntegerValue window_end) { + if (window_end <= window_start) return IntegerValue(0); + + // Returns zero if the interval do not necessarily overlap. + if (end_min <= window_start) return IntegerValue(0); + if (start_max >= window_end) return IntegerValue(0); + const IntegerValue window_size = window_end - window_start; + const IntegerValue simple_energy_min = + demand_min * std::min({end_min - window_start, window_end - start_max, + size_min, window_size}); + if (filtered_energy.empty()) return simple_energy_min; + + IntegerValue result = kMaxIntegerValue; + for (const auto [lit, fixed_size, fixed_demand] : filtered_energy) { + const IntegerValue alt_end_min = std::max(end_min, start_min + fixed_size); + const IntegerValue alt_start_max = + std::min(start_max, end_max - fixed_size); + const IntegerValue energy_min = + fixed_demand * + std::min({alt_end_min - window_start, window_end - alt_start_max, + fixed_size, window_size}); + result = std::min(result, energy_min); + } + if (result == kMaxIntegerValue) return simple_energy_min; + return std::max(simple_energy_min, result); +} + +SchedulingDemandHelper::SchedulingDemandHelper( + absl::Span demands, + SchedulingConstraintHelper* helper, Model* model) + : integer_trail_(model->GetOrCreate()), + product_decomposer_(model->GetOrCreate()), + sat_solver_(model->GetOrCreate()), + assignment_(model->GetOrCreate()->Assignment()), + demands_(demands.begin(), demands.end()), + helper_(helper) { + const int num_tasks = helper->NumTasks(); + linearized_energies_.resize(num_tasks); + decomposed_energies_.resize(num_tasks); + cached_energies_min_.resize(num_tasks, kMinIntegerValue); + cached_energies_max_.resize(num_tasks, kMaxIntegerValue); + energy_is_quadratic_.resize(num_tasks, false); + + // We try to init decomposed energies. This is needed for the cuts that are + // created after we call InitAllDecomposedEnergies(). + InitDecomposedEnergies(); +} + +void SchedulingDemandHelper::InitDecomposedEnergies() { + // For the special case were demands is empty. + const int num_tasks = helper_->NumTasks(); + if (demands_.size() != num_tasks) return; + for (int t = 0; t < num_tasks; ++t) { + const AffineExpression size = helper_->Sizes()[t]; + const AffineExpression demand = demands_[t]; + decomposed_energies_[t] = product_decomposer_->TryToDecompose(size, demand); + } +} + +IntegerValue SchedulingDemandHelper::SimpleEnergyMin(int t) const { + if (demands_.empty()) return kMinIntegerValue; + return CapProdI(DemandMin(t), helper_->SizeMin(t)); +} + +IntegerValue SchedulingDemandHelper::LinearEnergyMin(int t) const { + if (!linearized_energies_[t].has_value()) return kMinIntegerValue; + return linearized_energies_[t]->Min(*integer_trail_); +} + +IntegerValue SchedulingDemandHelper::DecomposedEnergyMin(int t) const { + if (decomposed_energies_[t].empty()) return kMinIntegerValue; + IntegerValue result = kMaxIntegerValue; + for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { + if (assignment_.LiteralIsTrue(lit)) { + return fixed_size * fixed_demand; + } + if (assignment_.LiteralIsFalse(lit)) continue; + result = std::min(result, fixed_size * fixed_demand); + } + DCHECK_NE(result, kMaxIntegerValue); + return result; +} + +IntegerValue SchedulingDemandHelper::SimpleEnergyMax(int t) const { + if (demands_.empty()) return kMaxIntegerValue; + return CapProdI(DemandMax(t), helper_->SizeMax(t)); +} + +IntegerValue SchedulingDemandHelper::LinearEnergyMax(int t) const { + if (!linearized_energies_[t].has_value()) return kMaxIntegerValue; + return linearized_energies_[t]->Max(*integer_trail_); +} + +IntegerValue SchedulingDemandHelper::DecomposedEnergyMax(int t) const { + if (decomposed_energies_[t].empty()) return kMaxIntegerValue; + IntegerValue result = kMinIntegerValue; + for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { + if (assignment_.LiteralIsTrue(lit)) { + return fixed_size * fixed_demand; + } + if (assignment_.LiteralIsFalse(lit)) continue; + result = std::max(result, fixed_size * fixed_demand); + } + DCHECK_NE(result, kMinIntegerValue); + return result; +} + +bool SchedulingDemandHelper::CacheAllEnergyValues() { + const int num_tasks = cached_energies_min_.size(); + const bool is_at_level_zero = sat_solver_->CurrentDecisionLevel() == 0; + for (int t = 0; t < num_tasks; ++t) { + // Try to reduce the size of the decomposed energy vector. + if (is_at_level_zero) { + int new_size = 0; + for (int i = 0; i < decomposed_energies_[t].size(); ++i) { + if (assignment_.LiteralIsFalse(decomposed_energies_[t][i].literal)) { + continue; + } + decomposed_energies_[t][new_size++] = decomposed_energies_[t][i]; + } + decomposed_energies_[t].resize(new_size); + } + + cached_energies_min_[t] = std::max( + {SimpleEnergyMin(t), LinearEnergyMin(t), DecomposedEnergyMin(t)}); + if (cached_energies_min_[t] <= kMinIntegerValue) return false; + energy_is_quadratic_[t] = + decomposed_energies_[t].empty() && !demands_.empty() && + !integer_trail_->IsFixed(demands_[t]) && !helper_->SizeIsFixed(t); + cached_energies_max_[t] = std::min( + {SimpleEnergyMax(t), LinearEnergyMax(t), DecomposedEnergyMax(t)}); + if (cached_energies_max_[t] >= kMaxIntegerValue) return false; + } + + return true; +} + +IntegerValue SchedulingDemandHelper::DemandMin(int t) const { + DCHECK_LT(t, demands_.size()); + return integer_trail_->LowerBound(demands_[t]); +} + +IntegerValue SchedulingDemandHelper::DemandMax(int t) const { + DCHECK_LT(t, demands_.size()); + return integer_trail_->UpperBound(demands_[t]); +} + +bool SchedulingDemandHelper::DemandIsFixed(int t) const { + return integer_trail_->IsFixed(demands_[t]); +} + +bool SchedulingDemandHelper::DecreaseEnergyMax(int t, IntegerValue value) { + if (value < EnergyMin(t)) { + if (helper_->IsOptional(t)) { + return helper_->PushTaskAbsence(t); + } else { + return helper_->ReportConflict(); + } + } else if (!decomposed_energies_[t].empty()) { + for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { + if (fixed_size * fixed_demand > value) { + if (assignment_.LiteralIsTrue(lit)) return helper_->ReportConflict(); + if (assignment_.LiteralIsFalse(lit)) continue; + if (!helper_->PushLiteral(lit.Negated())) return false; + } + } + } else if (linearized_energies_[t].has_value() && + linearized_energies_[t]->vars.size() == 1) { + const LinearExpression& e = linearized_energies_[t].value(); + const AffineExpression affine_energy(e.vars[0], e.coeffs[0], e.offset); + const IntegerLiteral deduction = affine_energy.LowerOrEqual(value); + if (!helper_->PushIntegerLiteralIfTaskPresent(t, deduction)) { + return false; + } + } else { + // TODO(user): Propagate if possible. + VLOG(3) << "Cumulative energy missed propagation"; + } + return true; +} + +void SchedulingDemandHelper::AddDemandMinReason(int t) { + DCHECK_LT(t, demands_.size()); + if (demands_[t].var != kNoIntegerVariable) { + helper_->MutableIntegerReason()->push_back( + integer_trail_->LowerBoundAsLiteral(demands_[t].var)); + } +} + +void SchedulingDemandHelper::AddDemandMinReason(int t, + IntegerValue min_demand) { + DCHECK_LT(t, demands_.size()); + if (demands_[t].var != kNoIntegerVariable) { + helper_->MutableIntegerReason()->push_back( + demands_[t].GreaterOrEqual(min_demand)); + } +} + +void SchedulingDemandHelper::AddEnergyMinReason(int t) { + // We prefer these reason in order. + const IntegerValue value = cached_energies_min_[t]; + if (DecomposedEnergyMin(t) >= value) { + auto* reason = helper_->MutableLiteralReason(); + const int old_size = reason->size(); + for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { + if (assignment_.LiteralIsTrue(lit)) { + reason->resize(old_size); + reason->push_back(lit.Negated()); + return; + } else if (fixed_size * fixed_demand < value && + assignment_.LiteralIsFalse(lit)) { + reason->push_back(lit); + } + } + } else if (SimpleEnergyMin(t) >= value) { + AddDemandMinReason(t); + helper_->AddSizeMinReason(t); + } else { + DCHECK_GE(LinearEnergyMin(t), value); + for (const IntegerVariable var : linearized_energies_[t]->vars) { + helper_->MutableIntegerReason()->push_back( + integer_trail_->LowerBoundAsLiteral(var)); + } + } +} + +bool SchedulingDemandHelper::AddLinearizedDemand( + int t, LinearConstraintBuilder* builder) const { + if (helper_->IsPresent(t)) { + if (!decomposed_energies_[t].empty()) { + for (const LiteralValueValue& entry : decomposed_energies_[t]) { + if (!builder->AddLiteralTerm(entry.literal, entry.right_value)) { + return false; + } + } + } else { + builder->AddTerm(demands_[t], IntegerValue(1)); + } + } else if (!helper_->IsAbsent(t)) { + return builder->AddLiteralTerm(helper_->PresenceLiteral(t), DemandMin(t)); + } + return true; +} + +void SchedulingDemandHelper::OverrideLinearizedEnergies( + absl::Span energies) { + const int num_tasks = energies.size(); + DCHECK_EQ(num_tasks, helper_->NumTasks()); + linearized_energies_.resize(num_tasks); + for (int t = 0; t < num_tasks; ++t) { + linearized_energies_[t] = energies[t]; + if (DEBUG_MODE) { + for (const IntegerValue coeff : linearized_energies_[t]->coeffs) { + DCHECK_GE(coeff, 0); + } + } + } +} + +std::vector SchedulingDemandHelper::FilteredDecomposedEnergy( + int index) { + if (decomposed_energies_[index].empty()) return {}; + if (sat_solver_->CurrentDecisionLevel() == 0) { + // CacheAllEnergyValues has already filtered false literals. + return decomposed_energies_[index]; + } + + // Scan and filter false literals. + std::vector result; + for (const auto& e : decomposed_energies_[index]) { + if (assignment_.LiteralIsFalse(e.literal)) continue; + result.push_back(e); + } + return result; +} + +void SchedulingDemandHelper::OverrideDecomposedEnergies( + const std::vector>& energies) { + DCHECK_EQ(energies.size(), helper_->NumTasks()); + decomposed_energies_ = energies; +} + +IntegerValue SchedulingDemandHelper::EnergyMinInWindow( + int t, IntegerValue window_start, IntegerValue window_end) { + return ComputeEnergyMinInWindow( + helper_->StartMin(t), helper_->StartMax(t), helper_->EndMin(t), + helper_->EndMax(t), helper_->SizeMin(t), DemandMin(t), + FilteredDecomposedEnergy(t), window_start, window_end); +} + +// Since we usually ask way less often for the reason, we redo the computation +// here. +void SchedulingDemandHelper::AddEnergyMinInWindowReason( + int t, IntegerValue window_start, IntegerValue window_end) { + const IntegerValue actual_energy_min = + EnergyMinInWindow(t, window_start, window_end); + if (actual_energy_min == 0) return; + + // Return simple reason right away if there is no decomposition or the simple + // energy is enough. + const IntegerValue start_max = helper_->StartMax(t); + const IntegerValue end_min = helper_->EndMin(t); + const IntegerValue min_overlap = + helper_->GetMinOverlap(t, window_start, window_end); + const IntegerValue simple_energy_min = DemandMin(t) * min_overlap; + if (simple_energy_min == actual_energy_min) { + AddDemandMinReason(t); + helper_->AddSizeMinReason(t); + helper_->AddStartMaxReason(t, start_max); + helper_->AddEndMinReason(t, end_min); + return; + } + + // TODO(user): only include the one we need? + const IntegerValue start_min = helper_->StartMin(t); + const IntegerValue end_max = helper_->EndMax(t); + DCHECK(!decomposed_energies_[t].empty()); + helper_->AddStartMinReason(t, start_min); + helper_->AddStartMaxReason(t, start_max); + helper_->AddEndMinReason(t, end_min); + helper_->AddEndMaxReason(t, end_max); + + auto* literal_reason = helper_->MutableLiteralReason(); + const int old_size = literal_reason->size(); + + DCHECK(!decomposed_energies_[t].empty()); + for (const auto [lit, fixed_size, fixed_demand] : decomposed_energies_[t]) { + // Should be the same in most cases. + if (assignment_.LiteralIsTrue(lit)) { + literal_reason->resize(old_size); + literal_reason->push_back(lit.Negated()); + return; + } + if (assignment_.LiteralIsFalse(lit)) { + const IntegerValue alt_em = std::max(end_min, start_min + fixed_size); + const IntegerValue alt_sm = std::min(start_max, end_max - fixed_size); + const IntegerValue energy_min = + fixed_demand * + std::min({alt_em - window_start, window_end - alt_sm, fixed_size}); + if (energy_min >= actual_energy_min) continue; + literal_reason->push_back(lit); + } + } +} + +void AddIntegerVariableFromIntervals(SchedulingConstraintHelper* helper, + Model* model, + std::vector* vars) { + IntegerEncoder* encoder = model->GetOrCreate(); + for (int t = 0; t < helper->NumTasks(); ++t) { + if (helper->Starts()[t].var != kNoIntegerVariable) { + vars->push_back(helper->Starts()[t].var); + } + if (helper->Sizes()[t].var != kNoIntegerVariable) { + vars->push_back(helper->Sizes()[t].var); + } + if (helper->Ends()[t].var != kNoIntegerVariable) { + vars->push_back(helper->Ends()[t].var); + } + if (helper->IsOptional(t) && !helper->IsAbsent(t) && + !helper->IsPresent(t)) { + const Literal l = helper->PresenceLiteral(t); + IntegerVariable view = kNoIntegerVariable; + if (!encoder->LiteralOrNegationHasView(l, &view)) { + view = model->Add(NewIntegerVariableFromLiteral(l)); + } + vars->push_back(view); + } + } +} + +void AppendVariablesFromCapacityAndDemands( + const AffineExpression& capacity, SchedulingDemandHelper* demands_helper, + Model* model, std::vector* vars) { + auto* integer_trail = model->GetOrCreate(); + for (const AffineExpression& demand_expr : demands_helper->Demands()) { + if (!integer_trail->IsFixed(demand_expr)) { + vars->push_back(demand_expr.var); + } + } + IntegerEncoder* encoder = model->GetOrCreate(); + for (const auto& product : demands_helper->DecomposedEnergies()) { + for (const auto& lit_val_val : product) { + IntegerVariable view = kNoIntegerVariable; + if (!encoder->LiteralOrNegationHasView(lit_val_val.literal, &view)) { + view = model->Add(NewIntegerVariableFromLiteral(lit_val_val.literal)); + } + vars->push_back(view); + } + } + + if (!integer_trail->IsFixed(capacity)) { + vars->push_back(capacity.var); + } +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/scheduling_helpers.h b/ortools/sat/scheduling_helpers.h new file mode 100644 index 0000000000..dccbcfa06f --- /dev/null +++ b/ortools/sat/scheduling_helpers.h @@ -0,0 +1,785 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_SCHEDULING_HELPERS_H_ +#define OR_TOOLS_SAT_SCHEDULING_HELPERS_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "ortools/sat/implied_bounds.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/util/bitset.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { + +// An helper struct to sort task by time. This is used by the +// SchedulingConstraintHelper but also by many scheduling propagators to sort +// tasks. +struct TaskTime { + int task_index; + IntegerValue time; + bool operator<(TaskTime other) const { return time < other.time; } + bool operator>(TaskTime other) const { return time > other.time; } +}; + +// We have some free space in TaskTime. +// We stick the presence_lit to save an indirection in some algo. +// +// TODO(user): Experiment caching more value. In particular +// TaskByIncreasingShiftedStartMin() could tie break task for better heuristics? +struct CachedTaskBounds { + int task_index; + LiteralIndex presence_lit; + IntegerValue time; + bool operator<(CachedTaskBounds other) const { return time < other.time; } + bool operator>(CachedTaskBounds other) const { return time > other.time; } +}; + +struct IntervalDefinition { + AffineExpression start; + AffineExpression end; + AffineExpression size; + std::optional is_present; + + template + friend H AbslHashValue(H h, const IntervalDefinition& i) { + return H::combine(std::move(h), i.start, i.end, i.size, i.is_present); + } + + bool operator==(const IntervalDefinition& other) const { + return start == other.start && end == other.end && size == other.size && + is_present == other.is_present; + } +}; + +// Helper class shared by the propagators that manage a given list of tasks. +// +// One of the main advantage of this class is that it allows to share the +// vectors of tasks sorted by various criteria between propagator for a faster +// code. It is also helpful to allow in-processing: the intervals that are +// handled by this class are not necessarily the same as the ones in the model. +class SchedulingConstraintHelper : public PropagatorInterface { + public: + // All the functions below refer to a task by its index t in the tasks + // vector given at construction. + SchedulingConstraintHelper(std::vector starts, + std::vector ends, + std::vector sizes, + std::vector reason_for_presence, + Model* model); + + // Temporary constructor. + // The class will not be usable until ResetFromSubset() is called. + // + // TODO(user): Remove this. It is a hack because the disjunctive class needs + // to fetch the maximum possible number of task at construction. + SchedulingConstraintHelper(int num_tasks, Model* model); + + // This is a propagator so we can "cache" all the intervals relevant + // information. This gives good speedup. Note however that the info is stale + // except if a bound was pushed by this helper or if this was called. We run + // it at the highest priority, so that will mostly be the case at the + // beginning of each Propagate() call of the classes using this. + bool Propagate() final; + bool IncrementalPropagate(const std::vector& watch_indices) final; + void RegisterWith(GenericLiteralWatcher* watcher); + + // Resets the class to the same state as if it was constructed with + // the given subset of tasks from other. + ABSL_MUST_USE_RESULT bool ResetFromSubset( + const SchedulingConstraintHelper& other, absl::Span tasks); + + // Returns the number of task. + int NumTasks() const { return starts_.size(); } + + // Make sure the cached values are up to date. Also sets the time direction to + // either forward/backward. This will impact all the functions below. This + // MUST be called at the beginning of all Propagate() call that uses this + // helper. + void SetTimeDirection(bool is_forward); + bool CurrentTimeIsForward() const { return current_time_direction_; } + ABSL_MUST_USE_RESULT bool SynchronizeAndSetTimeDirection(bool is_forward); + + // Helpers for the current bounds on the current task time window. + // [ (size-min) ... (size-min) ] + // ^ ^ ^ ^ + // start-min end-min start-max end-max + // + // Note that for tasks with variable durations, we don't necessarily have + // duration-min between the XXX-min and XXX-max value. + // + // Remark: We use cached values for most of these function as this is faster. + // In practice, the cache will almost always be up to date, but not in corner + // cases where pushing the start of one task will change values for many + // others. This is fine as the new values will be picked up as we reach the + // propagation fixed point. + IntegerValue SizeMin(int t) const { return cached_size_min_[t]; } + IntegerValue SizeMax(int t) const { + // This one is "rare" so we don't cache it. + return integer_trail_->UpperBound(sizes_[t]); + } + IntegerValue StartMin(int t) const { return cached_start_min_[t]; } + IntegerValue EndMin(int t) const { return cached_end_min_[t]; } + IntegerValue StartMax(int t) const { return -cached_negated_start_max_[t]; } + IntegerValue EndMax(int t) const { return -cached_negated_end_max_[t]; } + + IntegerValue LevelZeroStartMin(int t) const { + return integer_trail_->LevelZeroLowerBound(starts_[t]); + } + IntegerValue LevelZeroStartMax(int t) const { + return integer_trail_->LevelZeroUpperBound(starts_[t]); + } + IntegerValue LevelZeroEndMax(int t) const { + return integer_trail_->LevelZeroUpperBound(ends_[t]); + } + + // In the presence of tasks with a variable size, we do not necessarily + // have start_min + size_min = end_min, we can instead have a situation + // like: + // | |<--- size-min --->| + // ^ ^ ^ + // start-min | end-min + // | + // We define the "shifted start min" to be the right most time such that + // we known that we must have min-size "energy" to the right of it if the + // task is present. Using it in our scheduling propagators allows to propagate + // more in the presence of tasks with variable size (or optional task + // where we also do not necessarily have start_min + size_min = end_min. + // + // To explain this shifted start min, one must use the AddEnergyAfterReason(). + IntegerValue ShiftedStartMin(int t) const { + return cached_shifted_start_min_[t]; + } + + // As with ShiftedStartMin(), we can compute the shifted end max (that is + // start_max + size_min. + IntegerValue ShiftedEndMax(int t) const { + return -cached_negated_shifted_end_max_[t]; + } + + bool StartIsFixed(int t) const; + bool EndIsFixed(int t) const; + bool SizeIsFixed(int t) const; + + // Returns true if the corresponding fact is known for sure. A normal task is + // always present. For optional task for which the presence is still unknown, + // both of these function will return false. + bool IsOptional(int t) const; + bool IsPresent(int t) const; + bool IsAbsent(int t) const; + + // Same if one already have the presence LiteralIndex of a task. + bool IsOptional(LiteralIndex lit) const; + bool IsPresent(LiteralIndex lit) const; + bool IsAbsent(LiteralIndex lit) const; + + // Return a value so that End(a) + dist <= Start(b). + // Returns kMinInterValue if we don't have any such relation. + IntegerValue GetCurrentMinDistanceBetweenTasks( + int a, int b, bool add_reason_if_after = false); + + // We detected a precedence between two tasks. + // If we are at level zero, we might want to add the constraint. + // If we are at positive level, we might want to propagate the associated + // precedence literal if it exists. + bool PropagatePrecedence(int a, int b); + + // Return the minimum overlap of interval i with the time window [start..end]. + // + // Note: this is different from the mandatory part of an interval. + IntegerValue GetMinOverlap(int t, IntegerValue start, IntegerValue end) const; + + // Returns a string with the current task bounds. + std::string TaskDebugString(int t) const; + + // Sorts and returns the tasks in corresponding order at the time of the call. + // Note that we do not mean strictly-increasing/strictly-decreasing, there + // will be duplicate time values in these vectors. + // + // TODO(user): we could merge the first loop of IncrementalSort() with the + // loop that fill TaskTime.time at each call. + absl::Span TaskByIncreasingStartMin(); + absl::Span TaskByDecreasingEndMax(); + + absl::Span TaskByIncreasingNegatedStartMax(); + absl::Span TaskByIncreasingEndMin(); + + absl::Span TaskByIncreasingShiftedStartMin(); + + // Returns a sorted vector where each task appear twice, the first occurrence + // is at size (end_min - size_min) and the second one at (end_min). + // + // This is quite usage specific. + struct ProfileEvent { + IntegerValue time; + int task; + bool is_first; + + bool operator<(const ProfileEvent& other) const { + if (time == other.time) { + if (task == other.task) return is_first > other.is_first; + return task < other.task; + } + return time < other.time; + } + }; + const std::vector& GetEnergyProfile(); + + // Functions to clear and then set the current reason. + void ClearReason(); + void AddPresenceReason(int t); + void AddAbsenceReason(int t); + void AddSizeMinReason(int t); + void AddSizeMinReason(int t, IntegerValue lower_bound); + void AddSizeMaxReason(int t, IntegerValue upper_bound); + void AddStartMinReason(int t, IntegerValue lower_bound); + void AddStartMaxReason(int t, IntegerValue upper_bound); + void AddEndMinReason(int t, IntegerValue lower_bound); + void AddEndMaxReason(int t, IntegerValue upper_bound); + void AddShiftedEndMaxReason(int t, IntegerValue upper_bound); + + void AddEnergyAfterReason(int t, IntegerValue energy_min, IntegerValue time); + void AddEnergyMinInIntervalReason(int t, IntegerValue min, IntegerValue max); + + // Adds the reason why task "before" must be before task "after". + // That is StartMax(before) < EndMin(after). + void AddReasonForBeingBefore(int before, int after); + + // It is also possible to directly manipulates the underlying reason vectors + // that will be used when pushing something. + std::vector* MutableLiteralReason() { return &literal_reason_; } + std::vector* MutableIntegerReason() { + return &integer_reason_; + } + + // Push something using the current reason. Note that IncreaseStartMin() will + // also increase the end-min, and DecreaseEndMax() will also decrease the + // start-max. + // + // Important: IncreaseStartMin() and DecreaseEndMax() can be called on an + // optional interval whose presence is still unknown and push a bound + // conditioned on its presence. The functions will do the correct thing + // depending on whether or not the start_min/end_max are optional variables + // whose presence implies the interval presence. + ABSL_MUST_USE_RESULT bool IncreaseStartMin(int t, IntegerValue value); + ABSL_MUST_USE_RESULT bool IncreaseEndMin(int t, IntegerValue value); + ABSL_MUST_USE_RESULT bool DecreaseEndMax(int t, IntegerValue value); + ABSL_MUST_USE_RESULT bool PushLiteral(Literal l); + ABSL_MUST_USE_RESULT bool PushTaskAbsence(int t); + ABSL_MUST_USE_RESULT bool PushTaskPresence(int t); + ABSL_MUST_USE_RESULT bool PushIntegerLiteral(IntegerLiteral lit); + ABSL_MUST_USE_RESULT bool ReportConflict(); + ABSL_MUST_USE_RESULT bool PushIntegerLiteralIfTaskPresent(int t, + IntegerLiteral lit); + + absl::Span Starts() const { return starts_; } + absl::Span Ends() const { return ends_; } + absl::Span Sizes() const { return sizes_; } + + IntervalDefinition GetIntervalDefinition(int index) const { + return IntervalDefinition{ + .start = starts_[index], + .end = ends_[index], + .size = sizes_[index], + .is_present = (reason_for_presence_[index] == kNoLiteralIndex + ? std::optional() + : Literal(reason_for_presence_[index]))}; + } + + Literal PresenceLiteral(int index) const { + DCHECK(IsOptional(index)); + return Literal(reason_for_presence_[index]); + } + + // Registers the given propagator id to be called if any of the tasks + // in this class change. Note that we do not watch size max though. + void WatchAllTasks(int id, bool watch_max_side = true); + + // Manages the other helper (used by the diffn constraint). + // + // For each interval appearing in a reason on this helper, another reason + // will be added. This other reason specifies that on the other helper, the + // corresponding interval overlaps 'event'. + void SetOtherHelper(SchedulingConstraintHelper* other_helper, + absl::Span map_to_other_helper, + IntegerValue event) { + CHECK(other_helper != nullptr); + other_helper_ = other_helper; + map_to_other_helper_ = map_to_other_helper; + event_for_other_helper_ = event; + } + + bool HasOtherHelper() const { return other_helper_ != nullptr; } + + void ClearOtherHelper() { other_helper_ = nullptr; } + + // Adds to this helper reason all the explanation of the other helper. + // This checks that other_helper_ is null. + // + // This is used in the 2D energetic reasoning in the diffn constraint. + void ImportOtherReasons(const SchedulingConstraintHelper& other_helper); + + // TODO(user): Change the propagation loop code so that we don't stop + // pushing in the middle of the propagation as more advanced propagator do + // not handle this correctly. + bool InPropagationLoop() const { return integer_trail_->InPropagationLoop(); } + + int CurrentDecisionLevel() const { return trail_->CurrentDecisionLevel(); } + + private: + // Tricky: when a task is optional, it is possible it size min is negative, + // but we know that if a task is present, its size should be >= 0. So in the + // reason, when we need the size_min and it is currently negative, we can just + // ignore it and use zero instead. + AffineExpression NegatedSizeOrZero(int t) { + if (integer_trail_->LowerBound(sizes_[t]) <= 0) { + return AffineExpression(0); + } + return sizes_[t].Negated(); + } + + // Generic reason for a <= upper_bound, given that a = b + c in case the + // current upper bound of a is not good enough. + void AddGenericReason(const AffineExpression& a, IntegerValue upper_bound, + const AffineExpression& b, const AffineExpression& c); + + void InitSortedVectors(); + ABSL_MUST_USE_RESULT bool UpdateCachedValues(int t); + + // Internal function for IncreaseStartMin()/DecreaseEndMax(). + bool PushIntervalBound(int t, IntegerLiteral lit); + + // This will be called on any interval that is part of a reason or + // a bound push. Since the last call to ClearReason(), for each unique + // t, we will add once to other_helper_ the reason for t containing + // the point event_for_other_helper_. + void AddOtherReason(int t); + + // Import the reasons on the other helper into this helper. + void ImportOtherReasons(); + + Model* model_; + Trail* trail_; + SatSolver* sat_solver_; + IntegerTrail* integer_trail_; + GenericLiteralWatcher* watcher_; + PrecedenceRelations* precedence_relations_; + + // The current direction of time, true for forward, false for backward. + bool current_time_direction_ = true; + + // All the underlying variables of the tasks. + // The vectors are indexed by the task index t. + std::vector starts_; + std::vector ends_; + std::vector sizes_; + std::vector reason_for_presence_; + + // The negation of the start/end variable so that SetTimeDirection() + // can do its job in O(1) instead of calling NegationOf() on each entry. + std::vector minus_starts_; + std::vector minus_ends_; + + // This is used to detect when we need to invalidate the cache. + int64_t saved_num_backtracks_ = 0; + + // The caches of all relevant interval values. + // These are initially of size capacity and never resized. + // + // TODO(user): Because of std::swap() in SetTimeDirection, we cannot mark + // most of them as "const" and as a result we loose some performance since + // the address need to be re-fetched on most access. + const int capacity_; + const std::unique_ptr cached_size_min_; + std::unique_ptr cached_start_min_; + std::unique_ptr cached_end_min_; + std::unique_ptr cached_negated_start_max_; + std::unique_ptr cached_negated_end_max_; + std::unique_ptr cached_shifted_start_min_; + std::unique_ptr cached_negated_shifted_end_max_; + + // Sorted vectors returned by the TasksBy*() functions. + std::vector task_by_increasing_start_min_; + std::vector task_by_decreasing_end_max_; + + bool recompute_by_start_max_ = true; + bool recompute_by_end_min_ = true; + std::vector task_by_increasing_negated_start_max_; + std::vector task_by_increasing_end_min_; + + // Sorted vector returned by GetEnergyProfile(). + bool recompute_energy_profile_ = true; + std::vector energy_profile_; + + // This one is the most commonly used, so we optimized a bit more its + // computation by detecting when there is nothing to do. + std::vector task_by_increasing_shifted_start_min_; + std::vector task_by_negated_shifted_end_max_; + bool recompute_shifted_start_min_ = true; + bool recompute_negated_shifted_end_max_ = true; + + // If recompute_cache_[t] is true, then we need to update all the cached + // value for the task t in SynchronizeAndSetTimeDirection(). + bool recompute_all_cache_ = true; + Bitset64 recompute_cache_; + + // Reason vectors. + std::vector literal_reason_; + std::vector integer_reason_; + + // Optional 'proxy' helper used in the diffn constraint. + SchedulingConstraintHelper* other_helper_ = nullptr; + absl::Span map_to_other_helper_; + IntegerValue event_for_other_helper_; + std::vector already_added_to_other_reasons_; + + // List of watcher to "wake-up" each time one of the task bounds changes. + std::vector propagator_ids_; +}; + +// Helper class for cumulative constraint to wrap demands and expose concept +// like energy. +// +// In a cumulative constraint, an interval always has a size and a demand, but +// it can also have a set of "selector" literals each associated with a fixed +// size / fixed demands. This allows more precise energy estimation. +// +// TODO(user): Cache energy min and reason for the non O(1) cases. +class SchedulingDemandHelper { + public: + // Hack: this can be called with and empty demand vector as long as + // OverrideEnergies() is called to define the energies. + SchedulingDemandHelper(absl::Span demands, + SchedulingConstraintHelper* helper, Model* model); + + // When defined, the interval will consume this much demand during its whole + // duration. Some propagator only relies on the "energy" and thus never uses + // this. + IntegerValue DemandMin(int t) const; + IntegerValue DemandMax(int t) const; + IntegerValue LevelZeroDemandMin(int t) const { + return integer_trail_->LevelZeroLowerBound(demands_[t]); + } + bool DemandIsFixed(int t) const; + void AddDemandMinReason(int t); + void AddDemandMinReason(int t, IntegerValue min_demand); + const std::vector& Demands() const { return demands_; } + + // Adds the linearized demand (either the affine demand expression, or the + // demand part of the decomposed energy if present) to the builder. + // It returns false and do not add any term to the builder.if any literal + // involved has no integer view. + ABSL_MUST_USE_RESULT bool AddLinearizedDemand( + int t, LinearConstraintBuilder* builder) const; + + // The "energy" is usually size * demand, but in some non-conventional usage + // it might have a more complex formula. In all case, the energy is assumed + // to be only consumed during the interval duration. + // + // Returns false if the energy can overflow and was not computed. + // + // IMPORTANT: One must call CacheAllEnergyValues() for the values to be + // updated. TODO(user): this is error prone, maybe we should revisit. But if + // there is many alternatives, we don't want to rescan the list more than a + // linear number of time per propagation. + // + // TODO(user): Add more complex EnergyMinBefore(time) once we also support + // expressing the interval as a set of alternatives. + // + // At level 0, it will filter false literals from decomposed energies. + bool CacheAllEnergyValues(); + IntegerValue EnergyMin(int t) const { return cached_energies_min_[t]; } + IntegerValue EnergyMax(int t) const { return cached_energies_max_[t]; } + bool EnergyIsQuadratic(int t) const { return energy_is_quadratic_[t]; } + void AddEnergyMinReason(int t); + + // Returns the energy min in [start, end]. + // + // Note(user): These functions are not in O(1) if the decomposition is used, + // so we have to be careful in not calling them too often. + IntegerValue EnergyMinInWindow(int t, IntegerValue window_start, + IntegerValue window_end); + void AddEnergyMinInWindowReason(int t, IntegerValue window_start, + IntegerValue window_end); + + // Important: This might not do anything depending on the representation of + // the energy we have. + ABSL_MUST_USE_RESULT bool DecreaseEnergyMax(int t, IntegerValue value); + + // Different optional representation of the energy of an interval. + // + // Important: first value is size, second value is demand. + const std::vector>& DecomposedEnergies() + const { + return decomposed_energies_; + } + + // Visible for testing. + void OverrideLinearizedEnergies(absl::Span energies); + void OverrideDecomposedEnergies( + const std::vector>& energies); + // Returns the decomposed energy terms compatible with the current literal + // assignment. It must not be used to create reasons if not at level 0. + // It returns en empty vector if the decomposed energy is not available. + // + // Important: first value is size, second value is demand. + std::vector FilteredDecomposedEnergy(int index); + + // Init all decomposed energies. It needs probing to be finished. This happens + // after the creation of the helper. + void InitDecomposedEnergies(); + + private: + IntegerValue SimpleEnergyMin(int t) const; + IntegerValue LinearEnergyMin(int t) const; + IntegerValue SimpleEnergyMax(int t) const; + IntegerValue LinearEnergyMax(int t) const; + IntegerValue DecomposedEnergyMin(int t) const; + IntegerValue DecomposedEnergyMax(int t) const; + + IntegerTrail* integer_trail_; + ProductDecomposer* product_decomposer_; + SatSolver* sat_solver_; // To get the current propagation level. + const VariablesAssignment& assignment_; + std::vector demands_; + SchedulingConstraintHelper* helper_; + + // Cached value of the energies, as it can be a bit costly to compute. + std::vector cached_energies_min_; + std::vector cached_energies_max_; + std::vector energy_is_quadratic_; + + // A representation of the energies as a set of alternative. + // If subvector is empty, we don't have this representation. + std::vector> decomposed_energies_; + + // A representation of the energies as a set of linear expression. + // If the optional is not set, we don't have this representation. + std::vector> linearized_energies_; +}; + +// ============================================================================= +// Utilities +// ============================================================================= + +IntegerValue ComputeEnergyMinInWindow( + IntegerValue start_min, IntegerValue start_max, IntegerValue end_min, + IntegerValue end_max, IntegerValue size_min, IntegerValue demand_min, + absl::Span filtered_energy, + IntegerValue window_start, IntegerValue window_end); + +// ============================================================================= +// SchedulingConstraintHelper inlined functions. +// ============================================================================= + +inline bool SchedulingConstraintHelper::StartIsFixed(int t) const { + return integer_trail_->IsFixed(starts_[t]); +} + +inline bool SchedulingConstraintHelper::EndIsFixed(int t) const { + return integer_trail_->IsFixed(ends_[t]); +} + +inline bool SchedulingConstraintHelper::SizeIsFixed(int t) const { + return integer_trail_->IsFixed(sizes_[t]); +} + +inline bool SchedulingConstraintHelper::IsOptional(int t) const { + return reason_for_presence_[t] != kNoLiteralIndex; +} + +inline bool SchedulingConstraintHelper::IsPresent(int t) const { + if (reason_for_presence_[t] == kNoLiteralIndex) return true; + return trail_->Assignment().LiteralIsTrue(Literal(reason_for_presence_[t])); +} + +inline bool SchedulingConstraintHelper::IsAbsent(int t) const { + if (reason_for_presence_[t] == kNoLiteralIndex) return false; + return trail_->Assignment().LiteralIsFalse(Literal(reason_for_presence_[t])); +} + +inline bool SchedulingConstraintHelper::IsOptional(LiteralIndex lit) const { + return lit != kNoLiteralIndex; +} + +inline bool SchedulingConstraintHelper::IsPresent(LiteralIndex lit) const { + if (lit == kNoLiteralIndex) return true; + return trail_->Assignment().LiteralIsTrue(Literal(lit)); +} + +inline bool SchedulingConstraintHelper::IsAbsent(LiteralIndex lit) const { + if (lit == kNoLiteralIndex) return false; + return trail_->Assignment().LiteralIsFalse(Literal(lit)); +} + +inline void SchedulingConstraintHelper::ClearReason() { + integer_reason_.clear(); + literal_reason_.clear(); + if (other_helper_) { + other_helper_->ClearReason(); + already_added_to_other_reasons_.assign(NumTasks(), false); + } +} + +inline void SchedulingConstraintHelper::AddPresenceReason(int t) { + DCHECK(IsPresent(t)); + AddOtherReason(t); + if (reason_for_presence_[t] != kNoLiteralIndex) { + literal_reason_.push_back(Literal(reason_for_presence_[t]).Negated()); + } +} + +inline void SchedulingConstraintHelper::AddAbsenceReason(int t) { + DCHECK(IsAbsent(t)); + AddOtherReason(t); + if (reason_for_presence_[t] != kNoLiteralIndex) { + literal_reason_.push_back(Literal(reason_for_presence_[t])); + } +} + +inline void SchedulingConstraintHelper::AddSizeMinReason(int t) { + AddSizeMinReason(t, SizeMin(t)); +} + +inline void SchedulingConstraintHelper::AddGenericReason( + const AffineExpression& a, IntegerValue upper_bound, + const AffineExpression& b, const AffineExpression& c) { + if (integer_trail_->UpperBound(a) <= upper_bound) { + if (a.var != kNoIntegerVariable) { + integer_reason_.push_back(a.LowerOrEqual(upper_bound)); + } + return; + } + CHECK_NE(a.var, kNoIntegerVariable); + + // Here we assume that the upper_bound on a comes from the bound on b + c. + const IntegerValue slack = upper_bound - integer_trail_->UpperBound(b) - + integer_trail_->UpperBound(c); + CHECK_GE(slack, 0); + if (b.var == kNoIntegerVariable && c.var == kNoIntegerVariable) return; + if (b.var == kNoIntegerVariable) { + integer_reason_.push_back(c.LowerOrEqual(upper_bound - b.constant)); + } else if (c.var == kNoIntegerVariable) { + integer_reason_.push_back(b.LowerOrEqual(upper_bound - c.constant)); + } else { + integer_trail_->AppendRelaxedLinearReason( + slack, {b.coeff, c.coeff}, {NegationOf(b.var), NegationOf(c.var)}, + &integer_reason_); + } +} + +inline void SchedulingConstraintHelper::AddSizeMinReason( + int t, IntegerValue lower_bound) { + AddOtherReason(t); + DCHECK(!IsAbsent(t)); + if (lower_bound <= 0) return; + AddGenericReason(sizes_[t].Negated(), -lower_bound, minus_ends_[t], + starts_[t]); +} + +inline void SchedulingConstraintHelper::AddSizeMaxReason( + int t, IntegerValue upper_bound) { + AddOtherReason(t); + DCHECK(!IsAbsent(t)); + AddGenericReason(sizes_[t], upper_bound, ends_[t], minus_starts_[t]); +} + +inline void SchedulingConstraintHelper::AddStartMinReason( + int t, IntegerValue lower_bound) { + AddOtherReason(t); + DCHECK(!IsAbsent(t)); + AddGenericReason(minus_starts_[t], -lower_bound, minus_ends_[t], sizes_[t]); +} + +inline void SchedulingConstraintHelper::AddStartMaxReason( + int t, IntegerValue upper_bound) { + AddOtherReason(t); + DCHECK(!IsAbsent(t)); + AddGenericReason(starts_[t], upper_bound, ends_[t], NegatedSizeOrZero(t)); +} + +inline void SchedulingConstraintHelper::AddEndMinReason( + int t, IntegerValue lower_bound) { + AddOtherReason(t); + DCHECK(!IsAbsent(t)); + AddGenericReason(minus_ends_[t], -lower_bound, minus_starts_[t], + NegatedSizeOrZero(t)); +} + +inline void SchedulingConstraintHelper::AddEndMaxReason( + int t, IntegerValue upper_bound) { + AddOtherReason(t); + DCHECK(!IsAbsent(t)); + AddGenericReason(ends_[t], upper_bound, starts_[t], sizes_[t]); +} + +inline void SchedulingConstraintHelper::AddShiftedEndMaxReason( + int t, IntegerValue upper_bound) { + AddStartMaxReason(t, upper_bound - SizeMin(t)); +} + +inline void SchedulingConstraintHelper::AddEnergyAfterReason( + int t, IntegerValue energy_min, IntegerValue time) { + if (StartMin(t) >= time) { + AddStartMinReason(t, time); + } else { + AddEndMinReason(t, time + energy_min); + } + AddSizeMinReason(t, energy_min); +} + +inline void SchedulingConstraintHelper::AddEnergyMinInIntervalReason( + int t, IntegerValue time_min, IntegerValue time_max) { + const IntegerValue energy_min = SizeMin(t); + CHECK_LE(time_min + energy_min, time_max); + if (StartMin(t) >= time_min) { + AddStartMinReason(t, time_min); + } else { + AddEndMinReason(t, time_min + energy_min); + } + if (EndMax(t) <= time_max) { + AddEndMaxReason(t, time_max); + } else { + AddStartMaxReason(t, time_max - energy_min); + } + AddSizeMinReason(t, energy_min); +} + +// Cuts helpers. +void AddIntegerVariableFromIntervals(SchedulingConstraintHelper* helper, + Model* model, + std::vector* vars); + +void AppendVariablesFromCapacityAndDemands( + const AffineExpression& capacity, SchedulingDemandHelper* demands_helper, + Model* model, std::vector* vars); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_SCHEDULING_HELPERS_H_ diff --git a/ortools/sat/scheduling_helpers_test.cc b/ortools/sat/scheduling_helpers_test.cc new file mode 100644 index 0000000000..526b195d08 --- /dev/null +++ b/ortools/sat/scheduling_helpers_test.cc @@ -0,0 +1,255 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/scheduling_helpers.h" + +#include + +#include "gtest/gtest.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/linear_constraint.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { +namespace { + +TEST(SchedulingConstraintHelperTest, PushConstantBoundWithOptionalIntervals) { + Model model; + auto* repo = model.GetOrCreate(); + + const AffineExpression start(IntegerValue(0)); + const AffineExpression size(IntegerValue(10)); + const AffineExpression end(IntegerValue(10)); + + Literal presence2 = Literal(model.Add(NewBooleanVariable()), true); + IntervalVariable inter1 = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + IntervalVariable inter2 = + repo->CreateInterval(start, end, size, presence2.Index(), false); + + SchedulingConstraintHelper* helper = + repo->GetOrCreateHelper({inter1, inter2}); + + EXPECT_TRUE(helper->IncreaseStartMin(1, IntegerValue(20))); + EXPECT_FALSE(model.Get(Value(presence2))); +} + +TEST(SchedulingDemandHelperTest, EnergyInWindow) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + auto* repo = model.GetOrCreate(); + const IntervalVariable inter = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({inter}); + SchedulingDemandHelper demands_helper({demand}, helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + demands_helper.OverrideDecomposedEnergies( + {{{alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}}); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(8)); + + EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 8, 2)); + EXPECT_EQ(8, demands_helper.EnergyMinInWindow(0, 0, 10)); + EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 2, 10)); + EXPECT_EQ(0, demands_helper.EnergyMinInWindow(0, 0, 8)); + EXPECT_EQ(4, demands_helper.EnergyMinInWindow(0, 0, 9)); +} + +TEST(SchedulingDemandHelperTest, EnergyInWindowTakeIntoAccountWindowSize) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 4))); + const AffineExpression size(model.Add(NewIntegerVariable(6, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + auto* repo = model.GetOrCreate(); + const IntervalVariable inter = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(6, 10))); + + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({inter}); + SchedulingDemandHelper demands_helper({demand}, helper, &model); + demands_helper.CacheAllEnergyValues(); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + demands_helper.OverrideDecomposedEnergies( + {{{alt1, IntegerValue(8), IntegerValue(6)}, + {alt2, IntegerValue(6), IntegerValue(8)}}}); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(48)); + + EXPECT_EQ(6, demands_helper.EnergyMinInWindow(0, 5, 6)); +} + +TEST(SchedulingDemandHelperTest, LinearizedDemandWithAffineExpression) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + auto* repo = model.GetOrCreate(); + const IntervalVariable inter = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand( + AffineExpression(model.Add(NewIntegerVariable(2, 10)), 2, 5)); + + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({inter}); + SchedulingDemandHelper demands_helper({demand}, helper, &model); + demands_helper.CacheAllEnergyValues(); + + LinearConstraintBuilder builder(&model); + ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "2*X3 + 5"); +} + +TEST(SchedulingDemandHelperTest, LinearizedDemandWithDecomposedEnergy) { + Model model; + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + auto* repo = model.GetOrCreate(); + const IntervalVariable inter = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({inter}); + SchedulingDemandHelper demands_helper({demand}, helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); + model.GetOrCreate()->AssociateToIntegerEqualValue( + alt1, var1, IntegerValue(1)); + + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); + model.GetOrCreate()->AssociateToIntegerEqualValue( + alt2, var2, IntegerValue(1)); + demands_helper.OverrideDecomposedEnergies( + {{{alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}}); + demands_helper.CacheAllEnergyValues(); + LinearConstraintBuilder builder(&model); + ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); + EXPECT_EQ(builder.BuildExpression().DebugString(), "4*X4 2*X5"); +} + +TEST(SchedulingDemandHelperTest, FilteredDecomposedEnergy) { + Model model; + SatSolver* sat_solver = model.GetOrCreate(); + IntegerEncoder* encoder = model.GetOrCreate(); + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + auto* repo = model.GetOrCreate(); + const IntervalVariable inter = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({inter}); + SchedulingDemandHelper demands_helper({demand}, helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const std::vector no_energy; + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), no_energy); + + const Literal alt1 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); + encoder->AssociateToIntegerEqualValue(alt1, var1, IntegerValue(1)); + + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); + encoder->AssociateToIntegerEqualValue(alt2, var2, IntegerValue(1)); + const std::vector energy = { + {alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}; + demands_helper.OverrideDecomposedEnergies({energy}); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), energy); + + EXPECT_EQ(sat_solver->EnqueueDecisionAndBackjumpOnConflict(alt1.Negated()), + 0); + const std::vector filtered_energy = { + {alt2, IntegerValue(4), IntegerValue(2)}}; + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), filtered_energy); + EXPECT_EQ(demands_helper.DecomposedEnergies()[0], energy); +} + +TEST(SchedulingDemandHelperTest, FilteredDecomposedEnergyWithFalseLiteral) { + Model model; + IntegerEncoder* encoder = model.GetOrCreate(); + + const AffineExpression start(model.Add(NewIntegerVariable(0, 10))); + const AffineExpression size(model.Add(NewIntegerVariable(2, 10))); + const AffineExpression end(model.Add(NewIntegerVariable(0, 10))); + auto* repo = model.GetOrCreate(); + const IntervalVariable inter = + repo->CreateInterval(start, end, size, kNoLiteralIndex, false); + + const AffineExpression demand(model.Add(NewIntegerVariable(2, 10))); + + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper({inter}); + SchedulingDemandHelper demands_helper({demand}, helper, &model); + demands_helper.CacheAllEnergyValues(); + EXPECT_EQ(demands_helper.EnergyMin(0), IntegerValue(4)); + + const std::vector no_energy; + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), no_energy); + + const Literal alt1 = encoder->GetFalseLiteral(); + const IntegerVariable var1(model.Add(NewIntegerVariable(0, 1))); + model.GetOrCreate()->AssociateToIntegerEqualValue( + alt1, var1, IntegerValue(1)); + + const Literal alt2 = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var2(model.Add(NewIntegerVariable(0, 1))); + encoder->AssociateToIntegerEqualValue(alt2, var2, IntegerValue(1)); + const std::vector energy = { + {alt1, IntegerValue(2), IntegerValue(4)}, + {alt2, IntegerValue(4), IntegerValue(2)}}; + demands_helper.OverrideDecomposedEnergies({energy}); + demands_helper.CacheAllEnergyValues(); + const std::vector filtered_energy = { + {alt2, IntegerValue(4), IntegerValue(2)}}; + EXPECT_EQ(demands_helper.DecomposedEnergies()[0], filtered_energy); + EXPECT_EQ(demands_helper.FilteredDecomposedEnergy(0), filtered_energy); + EXPECT_EQ(0, model.GetOrCreate()->CurrentDecisionLevel()); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index a392a5a4c7..2ea2419e93 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -619,10 +619,12 @@ void SharedResponseManager::FillObjectiveValuesInResponse( response->set_gap_integral(gap_integral_); } -void SharedResponseManager::NewSolution( - absl::Span solution_values, const std::string& solution_info, - Model* model) { +std::shared_ptr::Solution> +SharedResponseManager::NewSolution(absl::Span solution_values, + const std::string& solution_info, + Model* model) { absl::MutexLock mutex_lock(&mutex_); + std::shared_ptr::Solution> ret; // For SAT problems, we add the solution to the solution pool for retrieval // later. @@ -631,7 +633,7 @@ void SharedResponseManager::NewSolution( solution.variable_values.assign(solution_values.begin(), solution_values.end()); solution.info = solution_info; - solutions_.Add(solution); + ret = solutions_.Add(solution); } else { const int64_t objective_value = ComputeInnerObjective(*objective_or_null_, solution_values); @@ -642,12 +644,12 @@ void SharedResponseManager::NewSolution( solution_values.end()); solution.rank = objective_value; solution.info = solution_info; - solutions_.Add(solution); + ret = solutions_.Add(solution); // Ignore any non-strictly improving solution. - if (objective_value > inner_objective_upper_bound_) return; + if (objective_value > inner_objective_upper_bound_) return ret; - // Our inner_objective_lower_bound_ should be a globaly valid bound, until + // Our inner_objective_lower_bound_ should be a globally valid bound, until // the problem become infeasible (i.e the lb > ub) in which case the bound // is no longer globally valid. Here, because we have a strictly improving // solution, we shouldn't be in the infeasible setting yet. @@ -758,6 +760,8 @@ void SharedResponseManager::NewSolution( CHECK_OK(file::SetTextProto(file, response, file::Defaults())); } #endif // __PORTABLE_PLATFORM__ + + return ret; } bool SharedResponseManager::ProblemIsSolved() const { diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index 613fcf79bc..afd6be261a 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -121,7 +121,9 @@ class SharedSolutionRepository { // right away. One must call Synchronize for this to happen. In order to be // deterministic, this will keep all solutions until Synchronize() is called, // so we need to be careful not to generate too many solutions at once. - void Add(Solution solution); + // + // Returns a shared pointer to the solution that was stored in the repository. + std::shared_ptr Add(Solution solution); // Updates the current pool of solution with the one recently added. Note that // we use a stable ordering of solutions, so the final pool will be @@ -355,9 +357,11 @@ class SharedResponseManager { // Reads the new solution from the response and update our state. For an // optimization problem, we only do something if the solution is strictly - // improving. - void NewSolution(absl::Span solution_values, - const std::string& solution_info, Model* model = nullptr); + // improving. Returns a shared pointer to the solution that was potentially + // stored in the repository. + std::shared_ptr::Solution> + NewSolution(absl::Span solution_values, + const std::string& solution_info, Model* model = nullptr); // Changes the solution to reflect the fact that the "improving" problem is // infeasible. This means that if we have a solution, we have proven @@ -903,15 +907,17 @@ SharedSolutionRepository::GetRandomBiasedSolution( } template -void SharedSolutionRepository::Add(Solution solution) { - if (num_solutions_to_keep_ <= 0) return; +std::shared_ptr::Solution> +SharedSolutionRepository::Add(Solution solution) { std::shared_ptr solution_ptr = std::make_shared(std::move(solution)); + if (num_solutions_to_keep_ <= 0) return std::move(solution_ptr); { absl::MutexLock mutex_lock(&mutex_); ++num_added_; - new_solutions_.push_back(std::move(solution_ptr)); + new_solutions_.push_back(solution_ptr); } + return solution_ptr; } template diff --git a/ortools/sat/table.cc b/ortools/sat/table.cc index 505450567f..09d69129ce 100644 --- a/ortools/sat/table.cc +++ b/ortools/sat/table.cc @@ -29,9 +29,12 @@ namespace sat { std::function LiteralTableConstraint( absl::Span> literal_tuples, - const std::vector& line_literals) { - return [=, literal_tuples = std::vector>( - literal_tuples.begin(), literal_tuples.end())](Model* model) { + absl::Span line_literals) { + return [=, + line_literals = + std::vector(line_literals.begin(), line_literals.end()), + literal_tuples = std::vector>( + literal_tuples.begin(), literal_tuples.end())](Model* model) { CHECK_EQ(literal_tuples.size(), line_literals.size()); const int num_tuples = line_literals.size(); if (num_tuples == 0) return; diff --git a/ortools/sat/table.h b/ortools/sat/table.h index b2665bbb62..176e2559a4 100644 --- a/ortools/sat/table.h +++ b/ortools/sat/table.h @@ -30,7 +30,7 @@ namespace sat { // literal_tuples matrix is true. std::function LiteralTableConstraint( absl::Span> literal_tuples, - const std::vector& line_literals); + absl::Span line_literals); } // namespace sat } // namespace operations_research diff --git a/ortools/sat/timetable.cc b/ortools/sat/timetable.cc index 42e4f708d5..84d6adf08c 100644 --- a/ortools/sat/timetable.cc +++ b/ortools/sat/timetable.cc @@ -24,6 +24,7 @@ #include "ortools/sat/intervals.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/util/strong_integers.h" namespace operations_research { diff --git a/ortools/sat/timetable.h b/ortools/sat/timetable.h index 6b31a07163..ba937c0844 100644 --- a/ortools/sat/timetable.h +++ b/ortools/sat/timetable.h @@ -19,9 +19,9 @@ #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/util/strong_integers.h" namespace operations_research { diff --git a/ortools/sat/timetable_edgefinding.cc b/ortools/sat/timetable_edgefinding.cc index bd05a3901b..f9ee14890f 100644 --- a/ortools/sat/timetable_edgefinding.cc +++ b/ortools/sat/timetable_edgefinding.cc @@ -14,14 +14,15 @@ #include "ortools/sat/timetable_edgefinding.h" #include +#include #include #include "absl/log/check.h" #include "ortools/base/iterator_adaptors.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" -#include "ortools/sat/intervals.h" #include "ortools/sat/model.h" +#include "ortools/sat/scheduling_helpers.h" #include "ortools/util/strong_integers.h" namespace operations_research { @@ -152,6 +153,9 @@ void TimeTableEdgeFinding::BuildTimeTable() { bool TimeTableEdgeFinding::TimeTableEdgeFindingPass() { if (!demands_->CacheAllEnergyValues()) return true; + IntegerValue earliest_start_min = std::numeric_limits::max(); + IntegerValue latest_end_max = std::numeric_limits::min(); + IntegerValue maximum_demand_min = IntegerValue(0); // Initialize the data structures and build the free parts. // -------------------------------------------------------- for (int t = 0; t < num_tasks_; ++t) { @@ -161,6 +165,10 @@ bool TimeTableEdgeFinding::TimeTableEdgeFindingPass() { const IntegerValue demand_min = demands_->DemandMin(t); IntegerValue mandatory_energy(0); + earliest_start_min = std::min(earliest_start_min, helper_->StartMin(t)); + latest_end_max = std::max(latest_end_max, helper_->EndMax(t)); + maximum_demand_min = std::max(maximum_demand_min, demand_min); + if (start_max >= end_min) { size_free_[t] = helper_->SizeMin(t); } else { @@ -174,6 +182,12 @@ bool TimeTableEdgeFinding::TimeTableEdgeFindingPass() { DCHECK_GE(energy_free_[t], 0); } + if (AtMinOrMaxInt64I(CapProdI(CapSubI(latest_end_max, earliest_start_min), + maximum_demand_min))) { + // Avoid possible overflow. + return true; + } + // TODO(user): Is it possible to have a 'higher' mandatory profile using // the min energy instead of the demand_min * size_min? How can we incorporate // this extra energy in the mandatory profile ? diff --git a/ortools/sat/timetable_edgefinding.h b/ortools/sat/timetable_edgefinding.h index 5842878a89..2d153f1e55 100644 --- a/ortools/sat/timetable_edgefinding.h +++ b/ortools/sat/timetable_edgefinding.h @@ -20,6 +20,7 @@ #include "ortools/sat/integer_base.h" #include "ortools/sat/intervals.h" #include "ortools/sat/model.h" +#include "ortools/sat/scheduling_helpers.h" namespace operations_research { namespace sat { diff --git a/ortools/sat/timetable_test.cc b/ortools/sat/timetable_test.cc index 60247053ac..10733ba526 100644 --- a/ortools/sat/timetable_test.cc +++ b/ortools/sat/timetable_test.cc @@ -35,6 +35,7 @@ #include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" +#include "ortools/sat/scheduling_helpers.h" namespace operations_research { namespace sat { @@ -95,8 +96,8 @@ bool TestTimeTablingPropagation(absl::Span tasks, // Propagate properly the other bounds of the intervals. EXPECT_TRUE(precedences->Propagate()); - SchedulingConstraintHelper* helper = model.TakeOwnership( - new SchedulingConstraintHelper(interval_vars, &model)); + auto* repo = model.GetOrCreate(); + SchedulingConstraintHelper* helper = repo->GetOrCreateHelper(interval_vars); SchedulingDemandHelper* demands_helper = model.TakeOwnership(new SchedulingDemandHelper(demands, helper, &model)); diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 7f818af0b3..921b3029ae 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -66,6 +66,8 @@ class IdentityMap { template class CompactVectorVector { public: + using value_type = V; + // Size of the "key" space, always in [0, size()). size_t size() const; bool empty() const; diff --git a/ortools/sat/var_domination.cc b/ortools/sat/var_domination.cc index 3e03f3a448..91b0b655a3 100644 --- a/ortools/sat/var_domination.cc +++ b/ortools/sat/var_domination.cc @@ -25,6 +25,7 @@ #include #include +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -227,7 +228,7 @@ bool VarDomination::EndFirstPhase() { // complexity is borned by this number times the number of entries in the // constraints. Still we should in most situation be a lot lower than that. const int kMaxInitialSize = 50; - std::vector cropped_vars; + absl::btree_set cropped_vars; util_intops::StrongVector is_cropped( num_vars_with_negation_, false); @@ -261,12 +262,12 @@ bool VarDomination::EndFirstPhase() { buffer_.push_back(x); if (new_size >= kMaxInitialSize) { is_cropped[var] = true; - cropped_vars.push_back(var); + cropped_vars.insert(var); } } } else { is_cropped[var] = true; - cropped_vars.push_back(var); + cropped_vars.insert(var); for (int i = 0; i < 200; ++i) { const IntegerVariable x = to_scan[i]; if (var_sig & ~block_down_signatures_[x]) continue; // !included. diff --git a/ortools/util/fp_roundtrip_conv.h b/ortools/util/fp_roundtrip_conv.h index 57b8046474..f16a4c2851 100644 --- a/ortools/util/fp_roundtrip_conv.h +++ b/ortools/util/fp_roundtrip_conv.h @@ -26,7 +26,7 @@ namespace operations_research { -// True if the plateform supports `double` to std::to_chars(). +// True if the platform supports `double` to std::to_chars(). // // std::to_chars() for double is not yet supported on Emscripten, Android and // iOS; they only implement std::to_chars() for integers.