sync: cp-sat bug fixes, stringview, fix strong int missing stl code, update graph

This commit is contained in:
Laurent Perron
2025-04-01 06:29:45 +02:00
parent d51bce0df9
commit 4af14d7d4f
19 changed files with 505 additions and 556 deletions

View File

@@ -455,6 +455,58 @@ namespace std {
template <typename StrongIntName, typename ValueType>
struct hash<util_intops::StrongInt<StrongIntName, ValueType>>
: util_intops::StrongInt<StrongIntName, ValueType>::Hasher {};
template <typename TagType, typename NativeType>
struct numeric_limits<util_intops::StrongInt<TagType, NativeType>> {
private:
using StrongIntT = util_intops::StrongInt<TagType, NativeType>;
public:
// NOLINTBEGIN(google3-readability-class-member-naming)
static constexpr bool is_specialized = true;
static constexpr bool is_signed = numeric_limits<NativeType>::is_signed;
static constexpr bool is_integer = numeric_limits<NativeType>::is_integer;
static constexpr bool is_exact = numeric_limits<NativeType>::is_exact;
static constexpr bool has_infinity = numeric_limits<NativeType>::has_infinity;
static constexpr bool has_quiet_NaN =
numeric_limits<NativeType>::has_quiet_NaN;
static constexpr bool has_signaling_NaN =
numeric_limits<NativeType>::has_signaling_NaN;
static constexpr float_denorm_style has_denorm =
numeric_limits<NativeType>::has_denorm;
static constexpr bool has_denorm_loss =
numeric_limits<NativeType>::has_denorm_loss;
static constexpr float_round_style round_style =
numeric_limits<NativeType>::round_style;
static constexpr bool is_iec559 = numeric_limits<NativeType>::is_iec559;
static constexpr bool is_bounded = numeric_limits<NativeType>::is_bounded;
static constexpr bool is_modulo = numeric_limits<NativeType>::is_modulo;
static constexpr int digits = numeric_limits<NativeType>::digits;
static constexpr int digits10 = numeric_limits<NativeType>::digits10;
static constexpr int max_digits10 = numeric_limits<NativeType>::max_digits10;
static constexpr int radix = numeric_limits<NativeType>::radix;
static constexpr int min_exponent = numeric_limits<NativeType>::min_exponent;
static constexpr int min_exponent10 =
numeric_limits<NativeType>::min_exponent10;
static constexpr int max_exponent = numeric_limits<NativeType>::max_exponent;
static constexpr int max_exponent10 =
numeric_limits<NativeType>::max_exponent10;
static constexpr bool traps = numeric_limits<NativeType>::traps;
static constexpr bool tinyness_before =
numeric_limits<NativeType>::tinyness_before;
// NOLINTEND(google3-readability-class-member-naming)
static constexpr StrongIntT(min)() { return StrongIntT(numeric_limits<NativeType>::min()); }
static constexpr StrongIntT lowest() { return StrongIntT(numeric_limits<NativeType>::min()); }
static constexpr StrongIntT(max)() { return StrongIntT(numeric_limits<NativeType>::max()); }
static constexpr StrongIntT epsilon() { return StrongIntT(numeric_limits<NativeType>::epsilon()); }
static constexpr StrongIntT round_error() { return StrongIntT(numeric_limits<NativeType>::round_error()); }
static constexpr StrongIntT infinity() { return StrongIntT(numeric_limits<NativeType>::infinity()); }
static constexpr StrongIntT quiet_NaN() { return StrongIntT(numeric_limits<NativeType>::quiet_NaN()); }
static constexpr StrongIntT signaling_NaN() { return StrongIntT(numeric_limits<NativeType>::signaling_NaN()); }
static constexpr StrongIntT denorm_min() { return StrongIntT(numeric_limits<NativeType>::denorm_min()); }
};
} // namespace std
#endif // OR_TOOLS_BASE_STRONG_INT_H_

View File

@@ -766,6 +766,8 @@ cc_test(
":io",
"//ortools/base:dump_vars",
"//ortools/base:gmock_main",
"//ortools/base:intops",
"//ortools/base:strong_vector",
"//ortools/util:flat_matrix",
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/log:check",

View File

@@ -81,9 +81,6 @@ PathWithLength ConstrainedShortestPathsOnDag(
// Advanced API.
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
struct GraphPathWithLength {
double length = 0.0;
// The returned arc indices points into the `arcs_with_length` passed to the
@@ -97,9 +94,6 @@ struct GraphPathWithLength {
// computations efficiently on the given DAG (on which resources do not change).
// `GraphType` can use one of the interfaces defined in `util/graph/graph.h`.
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
class ConstrainedShortestPathsOnDagWrapper {
public:
using NodeIndex = typename GraphType::NodeIndex;
@@ -285,9 +279,6 @@ std::vector<T> GetInversePermutation(const std::vector<T>& permutation);
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
ConstrainedShortestPathsOnDagWrapper<GraphType>::
ConstrainedShortestPathsOnDagWrapper(
const GraphType* graph, const std::vector<double>* arc_lengths,
@@ -543,9 +534,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
GraphType>::RunConstrainedShortestPathOnDag() {
if (source_is_destination_.has_value()) {
@@ -664,9 +652,6 @@ GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void ConstrainedShortestPathsOnDagWrapper<GraphType>::
RunHalfConstrainedShortestPathOnDag(
const GraphType& reverse_graph, absl::Span<const double> arc_lengths,
@@ -792,9 +777,6 @@ void ConstrainedShortestPathsOnDagWrapper<GraphType>::
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
typename GraphType::ArcIndex
ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
const GraphType& graph, absl::Span<const double> arc_lengths,
@@ -879,9 +861,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<typename GraphType::ArcIndex>
ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
const int best_label_index,
@@ -901,9 +880,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
}
template <typename GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<typename GraphType::NodeIndex>
ConstrainedShortestPathsOnDagWrapper<GraphType>::NodePathImpliedBy(
absl::Span<const ArcIndex> arc_path, const GraphType& graph) const {

View File

@@ -15,9 +15,7 @@
#define OR_TOOLS_GRAPH_DAG_SHORTEST_PATH_H_
#include <cmath>
#if __cplusplus >= 202002L
#include <concepts>
#endif
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>
@@ -82,39 +80,17 @@ std::vector<PathWithLength> KShortestPathsOnDag(
// -----------------------------------------------------------------------------
// Advanced API.
// -----------------------------------------------------------------------------
// This concept only enforces the standard graph API needed for all algorithms
// on DAGs. One could add the requirement of being a DAG wihtin this concept
// (which is done before running the algorithm).
#if __cplusplus >= 202002L
template <class GraphType>
concept DagGraphType = requires(GraphType graph) {
{ typename GraphType::NodeIndex{} };
{ typename GraphType::ArcIndex{} };
{ graph.num_nodes() } -> std::same_as<typename GraphType::NodeIndex>;
{ graph.num_arcs() } -> std::same_as<typename GraphType::ArcIndex>;
{ graph.OutgoingArcs(typename GraphType::NodeIndex{}) };
{
graph.Tail(typename GraphType::ArcIndex{})
} -> std::same_as<typename GraphType::NodeIndex>;
{
graph.Head(typename GraphType::ArcIndex{})
} -> std::same_as<typename GraphType::NodeIndex>;
{ graph.Build() };
};
#endif
// A wrapper that holds the memory needed to run many shortest path computations
// efficiently on the given DAG. One call of `RunShortestPathOnDag()` has time
// complexity O(|E| + |V|) and space complexity O(|V|).
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
// `ArcLengthContainer` can be any container of doubles.
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
class ShortestPathsOnDagWrapper {
public:
using NodeIndex = typename GraphType::NodeIndex;
using ArcIndex = typename GraphType::ArcIndex;
using ArcLengths = ArcLengthContainer;
// IMPORTANT: All arguments must outlive the class.
//
@@ -138,7 +114,7 @@ class ShortestPathsOnDagWrapper {
// so will obviously invalidate the result API of the last shortest path run,
// which could return an upper bound, junk, or crash.
ShortestPathsOnDagWrapper(const GraphType* graph,
const std::vector<double>* arc_lengths,
const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order);
// Computes the shortest path to all reachable nodes from the given sources.
@@ -151,7 +127,9 @@ class ShortestPathsOnDagWrapper {
const std::vector<NodeIndex>& reached_nodes() const { return reached_nodes_; }
// Returns the length of the shortest path from `node`'s source to `node`.
double LengthTo(NodeIndex node) const { return length_from_sources_[node]; }
double LengthTo(NodeIndex node) const {
return length_from_sources_[static_cast<size_t>(node)];
}
std::vector<double> LengthTo() const { return length_from_sources_; }
// Returns the list of all the arcs in the shortest path from `node`'s
@@ -164,12 +142,12 @@ class ShortestPathsOnDagWrapper {
// Accessors to the underlying graph and arc lengths.
const GraphType& graph() const { return *graph_; }
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
private:
static constexpr double kInf = std::numeric_limits<double>::infinity();
const GraphType* const graph_;
const std::vector<double>* const arc_lengths_;
const ArcLengths* const arc_lengths_;
absl::Span<const NodeIndex> const topological_order_;
// Data about the last call of the RunShortestPathOnDag() function.
@@ -185,14 +163,12 @@ class ShortestPathsOnDagWrapper {
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
// IMPORTANT: Only use if `path_count > 1` (k > 1) otherwise use
// `ShortestPathsOnDagWrapper`.
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
class KShortestPathsOnDagWrapper {
public:
using NodeIndex = typename GraphType::NodeIndex;
using ArcIndex = typename GraphType::ArcIndex;
using ArcLengths = ArcLengthContainer;
// IMPORTANT: All arguments must outlive the class.
//
@@ -216,7 +192,7 @@ class KShortestPathsOnDagWrapper {
// so will obviously invalidate the result API of the last shortest path run,
// which could return an upper bound, junk, or crash.
KShortestPathsOnDagWrapper(const GraphType* graph,
const std::vector<double>* arc_lengths,
const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order,
int path_count);
@@ -244,14 +220,14 @@ class KShortestPathsOnDagWrapper {
// Accessors to the underlying graph and arc lengths.
const GraphType& graph() const { return *graph_; }
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
int path_count() const { return path_count_; }
private:
static constexpr double kInf = std::numeric_limits<double>::infinity();
const GraphType* const graph_;
const std::vector<double>* const arc_lengths_;
const ArcLengths* const arc_lengths_;
absl::Span<const NodeIndex> const topological_order_;
const int path_count_;
@@ -269,10 +245,7 @@ class KShortestPathsOnDagWrapper {
std::vector<NodeIndex> reached_nodes_;
};
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
absl::Status TopologicalOrderIsValid(
const GraphType& graph,
absl::Span<const typename GraphType::NodeIndex> topological_order);
@@ -286,9 +259,6 @@ absl::Status TopologicalOrderIsValid(
// (2) assign into an index rather than with push_back
// (3) return by absl::Span (or return a copy) with known size.
template <typename GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
absl::Span<const typename GraphType::ArcIndex> arc_path,
const GraphType& graph) {
@@ -303,47 +273,47 @@ std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void CheckNodeIsValid(typename GraphType::NodeIndex node,
const GraphType& graph) {
CHECK_GE(node, 0) << "Node must be nonnegative. Input value: " << node;
CHECK_GE(node, typename GraphType::NodeIndex(0))
<< "Node must be nonnegative. Input value: " << node;
CHECK_LT(node, graph.num_nodes())
<< "Node must be a valid node. Input value: " << node
<< ". Number of nodes in the input graph: " << graph.num_nodes();
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
absl::Status TopologicalOrderIsValid(
const GraphType& graph,
absl::Span<const typename GraphType::NodeIndex> topological_order) {
using NodeIndex = typename GraphType::NodeIndex;
using ArcIndex = typename GraphType::ArcIndex;
const NodeIndex num_nodes = graph.num_nodes();
if (topological_order.size() != num_nodes) {
if (topological_order.size() != static_cast<size_t>(num_nodes)) {
return absl::InvalidArgumentError(absl::StrFormat(
"topological_order.size() = %i, != graph.num_nodes() = %i",
"topological_order.size() = %i, != graph.num_nodes() = %v",
topological_order.size(), num_nodes));
}
std::vector<NodeIndex> inverse_topology(num_nodes, -1);
for (NodeIndex node = 0; node < topological_order.size(); ++node) {
if (inverse_topology[topological_order[node]] >= 0) {
std::vector<NodeIndex> inverse_topology(static_cast<size_t>(num_nodes),
GraphType::kNilNode);
for (NodeIndex node(0); node < num_nodes; ++node) {
if (inverse_topology[static_cast<size_t>(
topological_order[static_cast<size_t>(node)])] !=
GraphType::kNilNode) {
return absl::InvalidArgumentError(
absl::StrFormat("node % i appears twice in topological order",
topological_order[node]));
absl::StrFormat("node %v appears twice in topological order",
topological_order[static_cast<size_t>(node)]));
}
inverse_topology[topological_order[node]] = node;
inverse_topology[static_cast<size_t>(
topological_order[static_cast<size_t>(node)])] = node;
}
for (NodeIndex tail = 0; tail < num_nodes; ++tail) {
for (NodeIndex tail(0); tail < num_nodes; ++tail) {
for (const ArcIndex arc : graph.OutgoingArcs(tail)) {
const NodeIndex head = graph.Head(arc);
if (inverse_topology[tail] >= inverse_topology[head]) {
if (inverse_topology[static_cast<size_t>(tail)] >=
inverse_topology[static_cast<size_t>(head)]) {
return absl::InvalidArgumentError(absl::StrFormat(
"arc (%i, %i) is inconsistent with topological order", tail, head));
"arc (%v, %v) is inconsistent with topological order", tail, head));
}
}
}
@@ -353,21 +323,20 @@ absl::Status TopologicalOrderIsValid(
// -----------------------------------------------------------------------------
// ShortestPathsOnDagWrapper implementation.
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
const GraphType* graph, const std::vector<double>* arc_lengths,
template <class GraphType, typename ArcLengths>
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ShortestPathsOnDagWrapper(
const GraphType* graph, const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order)
: graph_(graph),
arc_lengths_(arc_lengths),
topological_order_(topological_order) {
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
CHECK(graph_ != nullptr);
CHECK(arc_lengths_ != nullptr);
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
#ifndef NDEBUG
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
graph_->num_arcs());
for (const double arc_length : *arc_lengths_) {
CHECK(arc_length != -kInf && !std::isnan(arc_length))
<< absl::StrFormat("length cannot be -inf nor NaN");
@@ -378,16 +347,13 @@ ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
// Memory allocation is done here and only once in order to avoid reallocation
// at each call of `RunShortestPathOnDag()` for better performance.
length_from_sources_.resize(graph_->num_nodes(), kInf);
incoming_shortest_path_arc_.resize(graph_->num_nodes(), -1);
reached_nodes_.reserve(graph_->num_nodes());
length_from_sources_.resize(num_nodes, kInf);
incoming_shortest_path_arc_.resize(num_nodes, GraphType::kNilArc);
reached_nodes_.reserve(num_nodes);
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
template <class GraphType, typename ArcLengths>
void ShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunShortestPathOnDag(
absl::Span<const NodeIndex> sources) {
// Caching the vector addresses allow to not fetch it on each access.
const absl::Span<double> length_from_sources =
@@ -398,7 +364,7 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
// performance, so it only makes sense for nodes that are reachable from at
// least one source, the other ones will contain junk.
for (const NodeIndex node : reached_nodes_) {
length_from_sources[node] = kInf;
length_from_sources[static_cast<size_t>(node)] = kInf;
}
DCHECK(std::all_of(length_from_sources.begin(), length_from_sources.end(),
[](double l) { return l == kInf; }));
@@ -406,11 +372,12 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
for (const NodeIndex source : sources) {
CheckNodeIsValid(source, *graph_);
length_from_sources[source] = 0.0;
length_from_sources[static_cast<size_t>(source)] = 0.0;
}
for (const NodeIndex tail : topological_order_) {
const double length_to_tail = length_from_sources[tail];
const double length_to_tail =
length_from_sources[static_cast<size_t>(tail)];
// Stop exploring a node as soon as its length to all sources is +inf.
if (length_to_tail == kInf) {
continue;
@@ -418,37 +385,35 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
reached_nodes_.push_back(tail);
for (const ArcIndex arc : graph_->OutgoingArcs(tail)) {
const NodeIndex head = graph_->Head(arc);
DCHECK(arc_lengths[arc] != -kInf);
const double length_to_head = arc_lengths[arc] + length_to_tail;
if (length_to_head < length_from_sources[head]) {
length_from_sources[head] = length_to_head;
incoming_shortest_path_arc_[head] = arc;
DCHECK(arc_lengths[static_cast<size_t>(arc)] != -kInf);
const double length_to_head =
arc_lengths[static_cast<size_t>(arc)] + length_to_tail;
if (length_to_head < length_from_sources[static_cast<size_t>(head)]) {
length_from_sources[static_cast<size_t>(head)] = length_to_head;
incoming_shortest_path_arc_[static_cast<size_t>(head)] = arc;
}
}
}
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
bool ShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
template <class GraphType, typename ArcLengths>
bool ShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
NodeIndex node) const {
CheckNodeIsValid(node, *graph_);
return length_from_sources_[node] < kInf;
return length_from_sources_[static_cast<size_t>(node)] < kInf;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<typename GraphType::ArcIndex>
ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathTo(
NodeIndex node) const {
CHECK(IsReachable(node));
std::vector<ArcIndex> arc_path;
NodeIndex current_node = node;
for (int i = 0; i < graph_->num_nodes(); ++i) {
ArcIndex current_arc = incoming_shortest_path_arc_[current_node];
if (current_arc == -1) {
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
ArcIndex current_arc =
incoming_shortest_path_arc_[static_cast<size_t>(current_node)];
if (current_arc == GraphType::kNilArc) {
break;
}
arc_path.push_back(current_arc);
@@ -458,12 +423,10 @@ ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
return arc_path;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<typename GraphType::NodeIndex>
ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathTo(
NodeIndex node) const {
const std::vector<typename GraphType::ArcIndex> arc_path = ArcPathTo(node);
if (arc_path.empty()) {
return {node};
@@ -474,12 +437,9 @@ ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
// -----------------------------------------------------------------------------
// KShortestPathsOnDagWrapper implementation.
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
const GraphType* graph, const std::vector<double>* arc_lengths,
template <class GraphType, typename ArcLengths>
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::KShortestPathsOnDagWrapper(
const GraphType* graph, const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order, const int path_count)
: graph_(graph),
arc_lengths_(arc_lengths),
@@ -487,10 +447,12 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
path_count_(path_count) {
CHECK(graph_ != nullptr);
CHECK(arc_lengths_ != nullptr);
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
CHECK_GT(path_count_, 0) << "path_count must be greater than 0";
#ifndef NDEBUG
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
graph_->num_arcs());
for (const double arc_length : *arc_lengths_) {
CHECK(arc_length != -kInf && !std::isnan(arc_length))
<< absl::StrFormat("length cannot be -inf nor NaN");
@@ -501,9 +463,9 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
// TODO(b/332475713): Optimize if reverse graph is already provided in
// `GraphType`.
const int num_arcs = graph_->num_arcs();
const ArcIndex num_arcs = graph_->num_arcs();
reverse_graph_ = GraphType(graph_->num_nodes(), num_arcs);
for (ArcIndex arc_index = 0; arc_index < num_arcs; ++arc_index) {
for (ArcIndex arc_index(0); arc_index < num_arcs; ++arc_index) {
reverse_graph_.AddArc(graph->Head(arc_index), graph->Tail(arc_index));
}
std::vector<ArcIndex> permutation;
@@ -511,7 +473,7 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
arc_indices_.resize(permutation.size());
if (!permutation.empty()) {
for (int i = 0; i < permutation.size(); ++i) {
arc_indices_[permutation[i]] = i;
arc_indices_[static_cast<size_t>(permutation[i])] = ArcIndex(i);
}
}
@@ -521,19 +483,16 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
incoming_shortest_paths_arc_.resize(path_count_);
incoming_shortest_paths_index_.resize(path_count_);
for (int k = 0; k < path_count_; ++k) {
lengths_from_sources_[k].resize(graph_->num_nodes(), kInf);
incoming_shortest_paths_arc_[k].resize(graph_->num_nodes(), -1);
incoming_shortest_paths_index_[k].resize(graph_->num_nodes(), -1);
lengths_from_sources_[k].resize(num_nodes, kInf);
incoming_shortest_paths_arc_[k].resize(num_nodes, GraphType::kNilArc);
incoming_shortest_paths_index_[k].resize(num_nodes, -1);
}
is_source_.resize(graph_->num_nodes(), false);
reached_nodes_.reserve(graph_->num_nodes());
is_source_.resize(num_nodes, false);
reached_nodes_.reserve(num_nodes);
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
template <class GraphType, typename ArcLengths>
void KShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunKShortestPathOnDag(
absl::Span<const NodeIndex> sources) {
// Caching the vector addresses allow to not fetch it on each access.
const absl::Span<const double> arc_lengths = *arc_lengths_;
@@ -544,9 +503,9 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
// least one source, the other ones will contain junk.
for (const NodeIndex node : reached_nodes_) {
is_source_[node] = false;
is_source_[static_cast<size_t>(node)] = false;
for (int k = 0; k < path_count_; ++k) {
lengths_from_sources_[k][node] = kInf;
lengths_from_sources_[k][static_cast<size_t>(node)] = kInf;
}
}
reached_nodes_.clear();
@@ -560,14 +519,14 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
for (const NodeIndex source : sources) {
CheckNodeIsValid(source, *graph_);
is_source_[source] = true;
is_source_[static_cast<size_t>(source)] = true;
}
struct IncomingArcPath {
double path_length = 0.0;
ArcIndex arc_index = 0;
ArcIndex arc_index = ArcIndex(0);
double arc_length = 0.0;
NodeIndex from = 0;
NodeIndex from = NodeIndex(0);
int path_index = 0;
bool operator<(const IncomingArcPath& other) const {
@@ -580,18 +539,19 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
auto comp = std::greater<IncomingArcPath>();
for (const NodeIndex to : topological_order_) {
min_heap.clear();
if (is_source_[to]) {
min_heap.push_back({.arc_index = -1});
if (is_source_[static_cast<size_t>(to)]) {
min_heap.push_back({.arc_index = GraphType::kNilArc});
}
for (const ArcIndex reverse_arc_index : reverse_graph_.OutgoingArcs(to)) {
const ArcIndex arc_index = arc_indices.empty()
? reverse_arc_index
: arc_indices[reverse_arc_index];
const ArcIndex arc_index =
arc_indices.empty()
? reverse_arc_index
: arc_indices[static_cast<size_t>(reverse_arc_index)];
const NodeIndex from = graph_->Tail(arc_index);
const double arc_length = arc_lengths[arc_index];
const double arc_length = arc_lengths[static_cast<size_t>(arc_index)];
DCHECK(arc_length != -kInf);
const double path_length =
lengths_from_sources_.front()[from] + arc_length;
lengths_from_sources_.front()[static_cast<size_t>(from)] + arc_length;
if (path_length == kInf) {
continue;
}
@@ -608,17 +568,21 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
for (int k = 0; k < path_count_; ++k) {
std::pop_heap(min_heap.begin(), min_heap.end(), comp);
IncomingArcPath& incoming_arc_path = min_heap.back();
lengths_from_sources_[k][to] = incoming_arc_path.path_length;
incoming_shortest_paths_arc_[k][to] = incoming_arc_path.arc_index;
incoming_shortest_paths_index_[k][to] = incoming_arc_path.path_index;
if (incoming_arc_path.arc_index != -1 &&
lengths_from_sources_[k][static_cast<size_t>(to)] =
incoming_arc_path.path_length;
incoming_shortest_paths_arc_[k][static_cast<size_t>(to)] =
incoming_arc_path.arc_index;
incoming_shortest_paths_index_[k][static_cast<size_t>(to)] =
incoming_arc_path.path_index;
if (incoming_arc_path.arc_index != GraphType::kNilArc &&
incoming_arc_path.path_index < path_count_ - 1 &&
lengths_from_sources_[incoming_arc_path.path_index + 1]
[incoming_arc_path.from] < kInf) {
[static_cast<size_t>(incoming_arc_path.from)] <
kInf) {
++incoming_arc_path.path_index;
incoming_arc_path.path_length =
lengths_from_sources_[incoming_arc_path.path_index]
[incoming_arc_path.from] +
[static_cast<size_t>(incoming_arc_path.from)] +
incoming_arc_path.arc_length;
std::push_heap(min_heap.begin(), min_heap.end(), comp);
} else {
@@ -631,25 +595,22 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
}
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
bool KShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
template <class GraphType, typename ArcLengths>
bool KShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
NodeIndex node) const {
CheckNodeIsValid(node, *graph_);
return lengths_from_sources_.front()[node] < kInf;
return lengths_from_sources_.front()[static_cast<size_t>(node)] < kInf;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
template <class GraphType, typename ArcLengths>
std::vector<double>
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::LengthsTo(
NodeIndex node) const {
std::vector<double> lengths_to;
lengths_to.reserve(path_count_);
for (int k = 0; k < path_count_; ++k) {
const double length_to = lengths_from_sources_[k][node];
const double length_to =
lengths_from_sources_[k][static_cast<size_t>(node)];
if (length_to == kInf) {
break;
}
@@ -658,30 +619,30 @@ std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
return lengths_to;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<std::vector<typename GraphType::ArcIndex>>
KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathsTo(
NodeIndex node) const {
std::vector<std::vector<ArcIndex>> arc_paths;
arc_paths.reserve(path_count_);
for (int k = 0; k < path_count_; ++k) {
if (lengths_from_sources_[k][node] == kInf) {
if (lengths_from_sources_[k][static_cast<size_t>(node)] == kInf) {
break;
}
std::vector<ArcIndex> arc_path;
int current_path_index = k;
NodeIndex current_node = node;
for (int i = 0; i < graph_->num_nodes(); ++i) {
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
ArcIndex current_arc =
incoming_shortest_paths_arc_[current_path_index][current_node];
if (current_arc == -1) {
incoming_shortest_paths_arc_[current_path_index]
[static_cast<size_t>(current_node)];
if (current_arc == GraphType::kNilArc) {
break;
}
arc_path.push_back(current_arc);
current_path_index =
incoming_shortest_paths_index_[current_path_index][current_node];
incoming_shortest_paths_index_[current_path_index]
[static_cast<size_t>(current_node)];
current_node = graph_->Tail(current_arc);
}
absl::c_reverse(arc_path);
@@ -690,12 +651,10 @@ KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
return arc_paths;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<std::vector<typename GraphType::NodeIndex>>
KShortestPathsOnDagWrapper<GraphType>::NodePathsTo(NodeIndex node) const {
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathsTo(
NodeIndex node) const {
const std::vector<std::vector<ArcIndex>> arc_paths = ArcPathsTo(node);
std::vector<std::vector<NodeIndex>> node_paths(arc_paths.size());
for (int k = 0; k < arc_paths.size(); ++k) {

View File

@@ -29,6 +29,8 @@
#include "gtest/gtest.h"
#include "ortools/base/dump_vars.h"
#include "ortools/base/gmock.h"
#include "ortools/base/strong_int.h"
#include "ortools/base/strong_vector.h"
#include "ortools/graph/graph.h"
#include "ortools/graph/graph_io.h"
#include "ortools/util/flat_matrix.h"
@@ -75,7 +77,7 @@ TEST(TopologicalOrderIsValidTest, ValidateTopologicalOrder) {
TEST(ShortestPathOnDagTest, EmptyGraph) {
EXPECT_DEATH(ShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0),
"num_nodes\\(\\) > 0");
"num_nodes > 0");
}
TEST(ShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
@@ -89,7 +91,7 @@ TEST(ShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
EXPECT_DEATH(
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
/*source=*/3, /*destination=*/1),
"num_nodes\\(\\)");
"num_nodes");
}
TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
@@ -103,7 +105,7 @@ TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseTooLarge) {
EXPECT_DEATH(
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
/*source=*/0, /*destination=*/3),
"num_nodes\\(\\)");
"num_nodes");
}
TEST(ShortestPathOnDagTest, Cycle) {
@@ -287,6 +289,37 @@ TEST(ShortestPathOnDagTest, UpdateCost) {
/*node_path=*/ElementsAre(source, b, destination)));
}
DEFINE_STRONG_INT_TYPE(NodeIndex, int32_t);
DEFINE_STRONG_INT_TYPE(ArcIndex, int32_t);
TEST(ShortestPathsOnDagWrapperTest, StrongIndices) {
const NodeIndex source_1(0);
const NodeIndex source_2(1);
const NodeIndex destination(2);
const NodeIndex num_nodes(3);
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
/*arc_capacity=*/ArcIndex(2));
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
ArcLengths arc_lengths;
graph.AddArc(source_1, destination);
arc_lengths.push_back(-6.0);
graph.AddArc(source_2, destination);
arc_lengths.push_back(3.0);
const std::vector<NodeIndex> topological_order = {source_2, source_1,
destination};
ShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
shortest_path_on_dag(&graph, &arc_lengths, topological_order);
shortest_path_on_dag.RunShortestPathOnDag({source_1, source_2});
EXPECT_TRUE(shortest_path_on_dag.IsReachable(destination));
EXPECT_THAT(shortest_path_on_dag.LengthTo(destination), -6.0);
EXPECT_THAT(shortest_path_on_dag.ArcPathTo(destination),
ElementsAre(ArcIndex(0)));
EXPECT_THAT(shortest_path_on_dag.NodePathTo(destination),
ElementsAre(source_1, destination));
}
TEST(ShortestPathsOnDagWrapperTest, MultipleSources) {
const int source_1 = 0;
const int source_2 = 1;
@@ -634,7 +667,7 @@ TEST(KShortestPathOnDagTest, EmptyGraph) {
EXPECT_DEATH(
KShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0, /*path_count=*/2),
"num_nodes\\(\\) > 0");
"num_nodes > 0");
}
TEST(KShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
@@ -648,7 +681,7 @@ TEST(KShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
EXPECT_DEATH(
KShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
/*source=*/3, /*destination=*/1, /*path_count=*/2),
"num_nodes\\(\\)");
"num_nodes");
}
TEST(KShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
@@ -936,6 +969,38 @@ TEST(KShortestPathOnDagTest, UpdateCost) {
/*node_path=*/ElementsAre(source, a, destination))));
}
TEST(KShortestPathsOnDagWrapperTest, StrongIndices) {
const NodeIndex source_1(0);
const NodeIndex source_2(1);
const NodeIndex destination(2);
const NodeIndex num_nodes(3);
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
/*arc_capacity=*/ArcIndex(2));
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
ArcLengths arc_lengths;
graph.AddArc(source_1, destination);
arc_lengths.push_back(-6.0);
graph.AddArc(source_2, destination);
arc_lengths.push_back(3.0);
const std::vector<NodeIndex> topological_order = {source_2, source_1,
destination};
const int path_count = 2;
KShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
shortest_paths_on_dag(&graph, &arc_lengths, topological_order,
path_count);
shortest_paths_on_dag.RunKShortestPathOnDag({source_1, source_2});
EXPECT_TRUE(shortest_paths_on_dag.IsReachable(destination));
EXPECT_THAT(shortest_paths_on_dag.LengthsTo(destination),
ElementsAre(-6.0, 3.0));
EXPECT_THAT(shortest_paths_on_dag.ArcPathsTo(destination),
ElementsAre(ElementsAre(ArcIndex(0)), ElementsAre(ArcIndex(1))));
EXPECT_THAT(shortest_paths_on_dag.NodePathsTo(destination),
ElementsAre(ElementsAre(source_1, destination),
ElementsAre(source_2, destination)));
}
TEST(KShortestPathsOnDagWrapperTest, MultipleSources) {
const int source_1 = 0;
const int source_2 = 1;

View File

@@ -286,8 +286,12 @@ class BaseGraph {
// Constants that will never be a valid node or arc.
// They are the maximum possible node and arc capacity.
static const NodeIndexType kNilNode;
static const ArcIndexType kNilArc;
static_assert(std::numeric_limits<NodeIndexType>::is_specialized);
static constexpr NodeIndexType kNilNode =
std::numeric_limits<NodeIndexType>::max();
static_assert(std::numeric_limits<ArcIndexType>::is_specialized);
static constexpr ArcIndexType kNilArc =
std::numeric_limits<ArcIndexType>::max();
protected:
// Functions commented when defined because they are implementation details.
@@ -590,6 +594,26 @@ class SVector {
} // namespace internal
// Graph traits, to allow algorithms to manipulate graphs as adjacency lists.
// This works with any graph type, and any object that has:
// - a size() method returning the number of nodes.
// - an operator[] method taking a node index and returning a range of neighbour
// node indices.
// One common example is using `std::vector<std::vector<int>>` to represent
// adjacency lists.
template <typename Graph>
struct GraphTraits {
private:
// The type of the range returned by `operator[]`.
using NeighborRangeType = std::decay_t<
decltype(std::declval<Graph>()[std::declval<Graph>().size()])>;
public:
// The index type for nodes of the graph.
using NodeIndex =
std::decay_t<decltype(*(std::declval<NeighborRangeType>().begin()))>;
};
// Basic graph implementation without reverse arc. This class also serves as a
// documentation for the generic graph interface (minus the part related to
// reverse arcs).
@@ -1121,18 +1145,6 @@ BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::AllForwardArcs()
return IntegerRange<ArcIndexType>(ArcIndexType(0), num_arcs_);
}
template <typename NodeIndexType, typename ArcIndexType,
bool HasNegativeReverseArcs>
const NodeIndexType
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilNode =
std::numeric_limits<NodeIndexType>::max();
template <typename NodeIndexType, typename ArcIndexType,
bool HasNegativeReverseArcs>
const ArcIndexType
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilArc =
std::numeric_limits<ArcIndexType>::max();
template <typename NodeIndexType, typename ArcIndexType,
bool HasNegativeReverseArcs>
NodeIndexType BaseGraph<NodeIndexType, ArcIndexType,

View File

@@ -57,6 +57,8 @@ class BeginEndWrapper {
Iterator begin() const { return begin_; }
Iterator end() const { return end_; }
// Available only if `Iterator` is a random access iterator.
size_t size() const { return end_ - begin_; }
bool empty() const { return begin() == end(); }
@@ -127,8 +129,6 @@ class IntegerRangeIterator
#endif
{
public:
// TODO(b/385094969): This should be `IntegerType` for integers,
// `IntegerType:value_type` for strong signed integer types.
using difference_type = ptrdiff_t;
using value_type = IntegerType;
@@ -210,7 +210,7 @@ class IntegerRangeIterator
friend difference_type operator-(const IntegerRangeIterator l,
const IntegerRangeIterator r) {
return l.index_ - r.index_;
return static_cast<difference_type>(l.index_ - r.index_);
}
private:
@@ -248,9 +248,7 @@ class ChasingIterator
#endif
{
public:
// TODO(b/385094969): This should be `IntegerType` for integers,
// `IntegerType:value_type` for strong signed integer types.
using difference_type = std::ptrdiff_t;
using difference_type = ptrdiff_t;
using value_type = IndexT;
ChasingIterator() : index_(sentinel), next_(nullptr) {}

View File

@@ -30,6 +30,7 @@
#ifndef UTIL_GRAPH_TOPOLOGICALSORTER_H__
#define UTIL_GRAPH_TOPOLOGICALSORTER_H__
#include <cstddef>
#include <functional>
#include <limits>
#include <queue>
@@ -84,7 +85,9 @@ namespace graph {
// FastTopologicalSort(util::StaticGraph<>::FromArcs(num_nodes, arcs)));
//
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
absl::StatusOr<
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
FastTopologicalSort(const AdjacencyLists& adj);
// Finds a cycle in the directed graph given as argument: nodes are dense
// integers in 0..num_nodes-1, and (directed) arcs are pairs of nodes
@@ -93,7 +96,9 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
// if the cycle 1->4->3->1 exists.
// If the graph is acyclic, returns an empty vector.
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj);
absl::StatusOr<
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
FindCycleInGraph(const AdjacencyLists& adj);
} // namespace graph
@@ -615,38 +620,38 @@ std::vector<T> StableTopologicalSortOrDie(
}
template <class AdjacencyLists>
absl::StatusOr<std::vector<int>> FastTopologicalSort(
const AdjacencyLists& adj) {
const size_t num_nodes = adj.size();
if (num_nodes > std::numeric_limits<int>::max()) {
return absl::InvalidArgumentError("More than kint32max nodes");
absl::StatusOr<std::vector<typename GraphTraits<AdjacencyLists>::NodeIndex>>
FastTopologicalSort(const AdjacencyLists& adj) {
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
return absl::InvalidArgumentError(
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
}
std::vector<int> indegree(num_nodes, 0);
std::vector<int> topo_order;
topo_order.reserve(num_nodes);
for (int from = 0; from < num_nodes; ++from) {
for (const int head : adj[from]) {
// We cast to unsigned int to test "head < 0 || head ≥ num_nodes" with a
// single test. Microbenchmarks showed a ~1% overall performance gain.
if (static_cast<uint32_t>(head) >= num_nodes) {
const NodeIndex num_nodes(adj.size());
std::vector<NodeIndex> indegree(static_cast<size_t>(num_nodes), NodeIndex(0));
std::vector<NodeIndex> topo_order;
topo_order.reserve(static_cast<size_t>(num_nodes));
for (NodeIndex from(0); from < num_nodes; ++from) {
for (const NodeIndex head : adj[from]) {
if (!(NodeIndex(0) <= head && head < num_nodes)) {
return absl::InvalidArgumentError(
absl::StrFormat("Invalid arc in adj[%d]: %d (num_nodes=%d)", from,
absl::StrFormat("Invalid arc in adj[%v]: %v (num_nodes=%v)", from,
head, num_nodes));
}
// NOTE(user): We could detect self-arcs here (head == from) and exit
// early, but microbenchmarks show a 2 to 4% slow-down if we do it, so we
// simply rely on self-arcs being detected as cycles in the topo sort.
++indegree[head];
++indegree[static_cast<size_t>(head)];
}
}
for (int i = 0; i < num_nodes; ++i) {
if (!indegree[i]) topo_order.push_back(i);
for (NodeIndex i(0); i < num_nodes; ++i) {
if (!indegree[static_cast<size_t>(i)]) topo_order.push_back(i);
}
size_t num_visited = 0;
while (num_visited < topo_order.size()) {
const int from = topo_order[num_visited++];
for (const int head : adj[from]) {
if (!--indegree[head]) topo_order.push_back(head);
const NodeIndex from = topo_order[num_visited++];
for (const NodeIndex head : adj[from]) {
if (!--indegree[static_cast<size_t>(head)]) topo_order.push_back(head);
}
}
if (topo_order.size() < static_cast<size_t>(num_nodes)) {
@@ -656,77 +661,99 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(
}
template <class AdjacencyLists>
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj) {
const size_t num_nodes = adj.size();
if (num_nodes > std::numeric_limits<int>::max()) {
absl::StatusOr<
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
FindCycleInGraph(const AdjacencyLists& adj) {
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
return absl::InvalidArgumentError(
absl::StrFormat("Too many nodes: adj.size()=%d", adj.size()));
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
}
const NodeIndex num_nodes(adj.size());
// First pass to validate that inputs are valid.
for (NodeIndex node(0); node < NodeIndex(node); ++node) {
for (const NodeIndex head : adj[node]) {
if (head >= num_nodes) {
return absl::InvalidArgumentError(
absl::StrFormat("Invalid child %v in adj[%v]", head, node));
}
}
}
// To find a cycle, we start a DFS from each yet-unvisited node and
// try to find a cycle, if we don't find it then we know for sure that
// no cycle is reachable from any of the explored nodes (so, we don't
// explore them in later DFSs).
std::vector<bool> no_cycle_reachable_from(num_nodes, false);
std::vector<bool> no_cycle_reachable_from(static_cast<size_t>(num_nodes),
false);
// The DFS stack will contain a chain of nodes, from the root of the
// DFS to the current leaf.
struct DfsState {
int node;
NodeIndex node;
// Points at the first child node that we did *not* yet look at.
int adj_list_index;
explicit DfsState(int _node) : node(_node), adj_list_index(0) {}
decltype(adj[NodeIndex(0)].begin()) children;
decltype(adj[NodeIndex(0)].end()) children_end;
explicit DfsState(NodeIndex _node,
const decltype(adj[NodeIndex(0)])& neighbours)
: node(_node),
children(neighbours.begin()),
children_end(neighbours.end()) {}
};
std::vector<DfsState> dfs_stack;
std::vector<bool> in_cur_stack(num_nodes, false);
for (int start_node = 0; start_node < static_cast<int>(num_nodes);
std::vector<bool> visited(static_cast<size_t>(num_nodes), false);
for (NodeIndex start_node(0); start_node < NodeIndex(num_nodes);
++start_node) {
if (no_cycle_reachable_from[start_node]) continue;
if (no_cycle_reachable_from[static_cast<size_t>(start_node)]) continue;
// Start the DFS.
dfs_stack.push_back(DfsState(start_node));
in_cur_stack[start_node] = true;
visited[static_cast<size_t>(start_node)] = true;
dfs_stack.push_back(DfsState(start_node, adj[start_node]));
while (!dfs_stack.empty()) {
DfsState* cur_state = &dfs_stack.back();
if (static_cast<size_t>(cur_state->adj_list_index) >=
adj[cur_state->node].size()) {
no_cycle_reachable_from[cur_state->node] = true;
in_cur_stack[cur_state->node] = false;
DfsState* const cur_state = &dfs_stack.back();
while (
cur_state->children != cur_state->children_end &&
no_cycle_reachable_from[static_cast<size_t>(*cur_state->children)]) {
++cur_state->children;
}
if (cur_state->children == cur_state->children_end) {
no_cycle_reachable_from[static_cast<size_t>(cur_state->node)] = true;
dfs_stack.pop_back();
continue;
}
// Look at the current child, and increase the current state's
// adj_list_index.
// TODO(user): Caching adj[cur_state->node] in a local stack to improve
// locality and so that the [] operator is called exactly once per node.
const int child = adj[cur_state->node][cur_state->adj_list_index++];
if (static_cast<size_t>(child) >= num_nodes) {
return absl::InvalidArgumentError(absl::StrFormat(
"Invalid child %d in adj[%d]", child, cur_state->node));
}
if (no_cycle_reachable_from[child]) continue;
if (in_cur_stack[child]) {
const NodeIndex child = *cur_state->children;
// At that point the child is either:
// - visited and all finalized (all its children are visited). We know
// that it's not part of a cycle, otherwise we'd already have
// returned.
// - visited and not finalized (some of its children are not visited).
// That means that we've reached it again from a child, so we've found
// a cycle.
// - not visited. We push it on the stack and explore it.
if (no_cycle_reachable_from[static_cast<size_t>(child)]) continue;
if (visited[static_cast<size_t>(child)]) {
// We detected a cycle! It corresponds to the tail end of dfs_stack,
// in reverse order, until we find "child".
int cycle_start = dfs_stack.size() - 1;
size_t cycle_start = dfs_stack.size() - 1;
while (dfs_stack[cycle_start].node != child) --cycle_start;
const int cycle_size = dfs_stack.size() - cycle_start;
std::vector<int> cycle(cycle_size);
for (int c = 0; c < cycle_size; ++c) {
const size_t cycle_size = dfs_stack.size() - cycle_start;
std::vector<NodeIndex> cycle(cycle_size);
for (size_t c = 0; c < cycle_size; ++c) {
cycle[c] = dfs_stack[cycle_start + c].node;
}
return cycle;
}
// Push the child onto the stack.
dfs_stack.push_back(DfsState(child));
in_cur_stack[child] = true;
// Verify that its adjacency list seems valid.
if (adj[child].size() > std::numeric_limits<int>::max()) {
return absl::InvalidArgumentError(absl::StrFormat(
"Invalid adj[%d].size() = %d", child, adj[child].size()));
}
dfs_stack.push_back(DfsState(child, adj[child]));
visited[static_cast<size_t>(child)] = true;
}
}
// If we're here, then all the DFS stopped, and there is no cycle.
return std::vector<int>{};
return std::vector<NodeIndex>{};
}
} // namespace graph

View File

@@ -346,7 +346,7 @@ void LoadGurobiFunctions(DynamicLibrary* gurobi_dynamic_library) {
std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
std::vector<std::string> potential_paths;
const std::vector<std::string> kGurobiVersions = {
const std::vector<absl::string_view> kGurobiVersions = {
"1201", "1200", "1103", "1102", "1101", "1100", "1003",
"1002", "1001", "1000", "952", "951", "950", "911",
"910", "903", "902", "811", "801", "752"};
@@ -355,8 +355,8 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
// Look for libraries pointed by GUROBI_HOME first.
const char* gurobi_home_from_env = getenv("GUROBI_HOME");
if (gurobi_home_from_env != nullptr) {
for (const std::string& version : kGurobiVersions) {
const std::string lib = version.substr(0, version.size() - 1);
for (const absl::string_view version : kGurobiVersions) {
const absl::string_view lib = version.substr(0, version.size() - 1);
#if defined(_MSC_VER) // Windows
potential_paths.push_back(
absl::StrCat(gurobi_home_from_env, "\\bin\\gurobi", lib, ".dll"));
@@ -376,8 +376,8 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
}
// Search for canonical places.
for (const std::string& version : kGurobiVersions) {
const std::string lib = version.substr(0, version.size() - 1);
for (const absl::string_view version : kGurobiVersions) {
const absl::string_view lib = version.substr(0, version.size() - 1);
#if defined(_MSC_VER) // Windows
potential_paths.push_back(absl::StrCat("C:\\Program Files\\gurobi", version,
"\\win64\\bin\\gurobi", lib,
@@ -407,7 +407,7 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
}
#if defined(__GNUC__) // path in linux64 gurobi/optimizer docker image.
for (const std::string& version :
for (const absl::string_view version :
{"12.0.1", "12.0.0", "11.0.3", "11.0.2", "11.0.1", "11.0.0", "10.0.3",
"10.0.2", "10.0.1", "10.0.0", "9.5.2", "9.5.1", "9.5.0"}) {
potential_paths.push_back(
@@ -418,7 +418,7 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
}
absl::Status LoadGurobiDynamicLibrary(
std::vector<std::string> potential_paths) {
std::vector<absl::string_view> potential_paths) {
static std::once_flag gurobi_loading_done;
static absl::Status gurobi_load_status;
static DynamicLibrary gurobi_library;
@@ -431,7 +431,7 @@ absl::Status LoadGurobiDynamicLibrary(
GurobiDynamicLibraryPotentialPaths();
potential_paths.insert(potential_paths.end(), canonical_paths.begin(),
canonical_paths.end());
for (const std::string& path : potential_paths) {
for (const absl::string_view path : potential_paths) {
if (gurobi_library.TryToLoad(path)) {
LOG(INFO) << "Found the Gurobi library in '" << path << ".";
break;

View File

@@ -18,6 +18,7 @@
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "ortools/base/dynamic_library.h"
#include "ortools/base/logging.h"
@@ -52,7 +53,7 @@ bool GurobiIsCorrectlyInstalled();
// Successive calls are no-op.
//
// Note that it does not check if a token license can be grabbed.
absl::Status LoadGurobiDynamicLibrary(std::vector<std::string> potential_paths);
absl::Status LoadGurobiDynamicLibrary(std::vector<absl::string_view> potential_paths);
// The list of #define and extern std::function<> below is generated directly
// from gurobi_c.h via parse_header.py

View File

@@ -20,12 +20,13 @@
#include "absl/flags/usage.h"
#include "absl/log/globals.h"
#include "absl/log/initialize.h"
#include "absl/strings/string_view.h"
#include "ortools/gurobi/environment.h"
#include "ortools/sat/cp_model_solver.h"
#include "ortools/sat/cp_model_solver_helpers.h"
namespace operations_research {
void CppBridge::InitLogging(const std::string& usage) {
void CppBridge::InitLogging(absl::string_view usage) {
absl::SetProgramUsageMessage(usage);
absl::InitializeLog();
}
@@ -41,7 +42,7 @@ void CppBridge::SetFlags(const CppFlags& flags) {
absl::SetFlag(&FLAGS_cp_model_dump_response, flags.cp_model_dump_response);
}
bool CppBridge::LoadGurobiSharedLibrary(const std::string& full_library_path) {
bool CppBridge::LoadGurobiSharedLibrary(absl::string_view full_library_path) {
return LoadGurobiDynamicLibrary({full_library_path}).ok();
}

View File

@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "ortools/base/logging.h"
#include "ortools/base/version.h"
#include "ortools/sat/cp_model_solver_helpers.h"
@@ -86,7 +87,7 @@ class CppBridge {
*
* This must be called once before any other library from OR-Tools are used.
*/
static void InitLogging(const std::string& usage);
static void InitLogging(absl::string_view usage);
/**
* Shutdown the C++ logging layer.
@@ -111,7 +112,7 @@ class CppBridge {
* You need to pass the full path, including the shared library file.
* It returns true if the library was found and correctly loaded.
*/
static bool LoadGurobiSharedLibrary(const std::string& full_library_path);
static bool LoadGurobiSharedLibrary(absl::string_view full_library_path);
/**
* Delete a temporary C++ byte array.

View File

@@ -26,11 +26,11 @@ import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.AbstractMap;
import java.util.Objects;
/** Load native libraries needed for using ortools-java. */
@@ -144,28 +144,28 @@ public class Loader {
URI resourceURI = getNativeResourceURI();
Path tempPath = unpackNativeResources(resourceURI);
// libraries order does matter <LibraryName, isMandatory> !
List<Map.Entry<String,Boolean>> dlls = Arrays.asList(
(new AbstractMap.SimpleEntry("zlib1", true)),
(new AbstractMap.SimpleEntry("abseil_dll", true)),
(new AbstractMap.SimpleEntry("re2", true)),
(new AbstractMap.SimpleEntry("libutf8_validity", true)),
(new AbstractMap.SimpleEntry("libprotobuf", true)),
(new AbstractMap.SimpleEntry("highs", false)),
(new AbstractMap.SimpleEntry("libscip", false)),
(new AbstractMap.SimpleEntry("ortools", true)),
(new AbstractMap.SimpleEntry("jniortools", true)));
List<Map.Entry<String, Boolean>> dlls =
Arrays.asList((new AbstractMap.SimpleEntry("zlib1", true)),
(new AbstractMap.SimpleEntry("abseil_dll", true)),
(new AbstractMap.SimpleEntry("re2", true)),
(new AbstractMap.SimpleEntry("libutf8_validity", true)),
(new AbstractMap.SimpleEntry("libprotobuf", true)),
(new AbstractMap.SimpleEntry("highs", false)),
(new AbstractMap.SimpleEntry("libscip", false)),
(new AbstractMap.SimpleEntry("ortools", true)),
(new AbstractMap.SimpleEntry("jniortools", true)));
for (Map.Entry<String,Boolean> dll : dlls) {
for (Map.Entry<String, Boolean> dll : dlls) {
try {
//System.out.println("System.load(" + dll.getKey() + ")");
// System.out.println("System.load(" + dll.getKey() + ")");
System.load(tempPath.resolve(RESOURCE_PATH)
.resolve(System.mapLibraryName(dll.getKey()))
.toAbsolutePath()
.toString());
} catch (UnsatisfiedLinkError e) {
System.out.println("System.load(" + dll.getKey() + ") failed!");
if(dll.getValue()) {
throw new RuntimeException(e);
if (dll.getValue()) {
throw new RuntimeException(e);
}
}
}

View File

@@ -126,8 +126,6 @@ CP solver built on top of the SAT solver:
Propagation algorithms for the cumulative scheduling constraint.
* [cumulative_energy.h](../sat/cumulative_energy.h):
Propagation algorithms for a more general cumulative constraint.
* [theta_tree.h](../sat/theta_tree.h):
Data structure used in the cumulative/disjunctive propagation algorithm.
### Packing constraints

View File

@@ -1876,46 +1876,18 @@ TEST(LinMaxExpansionTest, GoldenTest) {
variables { domain: 0 domain: 1 }
constraints {}
constraints {
linear {
vars: 0
vars: 1
coeffs: 1
coeffs: -2
domain: -1
domain: 9223372036854775806
}
linear { vars: 0 vars: 1 coeffs: 1 coeffs: -2 domain: -1 domain: 5 }
}
constraints {
linear {
vars: 0
vars: 2
coeffs: 1
coeffs: -1
domain: -4
domain: 9223372036854775803
}
linear { vars: 0 vars: 2 coeffs: 1 coeffs: -1 domain: -4 domain: 5 }
}
constraints {
enforcement_literal: 3
linear {
vars: 0
vars: 1
coeffs: 1
coeffs: -2
domain: -9223372036854775808
domain: -1
}
linear { vars: 0 vars: 1 coeffs: 1 coeffs: -2 domain: -10 domain: -1 }
}
constraints {
enforcement_literal: -4
linear {
vars: 0
vars: 2
coeffs: 1
coeffs: -1
domain: -9223372036854775808
domain: -4
}
linear { vars: 0 vars: 2 coeffs: 1 coeffs: -1 domain: -6 domain: -4 }
}
)pb");
EXPECT_THAT(initial_model, testing::EqualsProto(expected_model));

View File

@@ -2657,10 +2657,18 @@ bool PresolveContext::CanonicalizeLinearConstraint(ConstraintProto* ct) {
const bool result = CanonicalizeLinearExpressionInternal(
ct->enforcement_literal(), ct->mutable_linear(), &offset, &tmp_terms_,
this);
if (offset != 0) {
FillDomainInProto(
ReadDomainFromProto(ct->linear()).AdditionWith(Domain(-offset)),
ct->mutable_linear());
const auto [min_activity, max_activity] = ComputeMinMaxActivity(ct->linear());
const Domain implied = Domain(min_activity, max_activity);
const Domain original_domain =
ReadDomainFromProto(ct->linear()).AdditionWith(Domain(-offset));
const Domain tight_domain = implied.IntersectionWith(original_domain);
if (tight_domain.IsEmpty()) {
// Canonicalization is not the right place to handle unsat constraints.
// Let's just replace the domain by one that is overflow-safe.
const Domain bad_domain = Domain(implied.Max() + 1, implied.Max() + 2);
FillDomainInProto(bad_domain, ct->mutable_linear());
} else {
FillDomainInProto(tight_domain, ct->mutable_linear());
}
return result;
}

View File

@@ -1053,7 +1053,7 @@ TEST(PresolveContextTest, CanonicalizeLinearConstraint) {
linear {
vars: [ 0, 1, 2 ]
coeffs: [ -2, 2, -2 ]
domain: [ 0, 1000 ]
domain: [ 0, 16 ]
}
)pb");
EXPECT_THAT(working_model.constraints(0), testing::EqualsProto(expected));

View File

@@ -575,7 +575,7 @@ PYBIND11_MODULE(set_cover, m) {
return {cleared.begin(), cleared.end()};
});
m.def("clear_random_subsets",
[](const std::vector<BaseInt>& focus, BaseInt num_subsets,
[](absl::Span<const BaseInt> focus, BaseInt num_subsets,
SetCoverInvariant* inv) -> std::vector<BaseInt> {
const std::vector<SubsetIndex> cleared = ClearRandomSubsets(
VectorIntToVectorSubsetIndex(focus), num_subsets, inv);

View File

@@ -101,62 +101,66 @@ class ThetaLambdaTree {
// allows to keep the same memory for each call.
void Reset(int num_events);
// Recomputes the values of internal nodes of the tree from the values in the
// leaves. We enable batching modifications to the tree by providing
// DelayedXXX() methods that run in O(1), but those methods do not
// update internal nodes. This breaks tree invariants, so that GetXXX()
// methods will not reflect modifications made to events.
// RecomputeTreeForDelayedOperations() restores those invariants in O(n).
// Thus, batching operations can be done by first doing calls to DelayedXXX()
// methods, then calling RecomputeTreeForDelayedOperations() once.
void RecomputeTreeForDelayedOperations();
// Makes event present and updates its initial envelope and min/max energies.
// The initial_envelope must be >= ThetaLambdaTreeNegativeInfinity().
// This updates the tree in O(log n).
void AddOrUpdateEvent(int event, IntegerType initial_envelope,
IntegerType energy_min, IntegerType energy_max);
// Delayed version of AddOrUpdateEvent(),
// see RecomputeTreeForDelayedOperations().
void DelayedAddOrUpdateEvent(int event, IntegerType initial_envelope,
IntegerType energy_min, IntegerType energy_max);
IntegerType energy_min, IntegerType energy_max) {
DCHECK_LE(0, energy_min);
DCHECK_LE(energy_min, energy_max);
const int node = GetLeafFromEvent(event);
tree_[node] = {.envelope = initial_envelope + energy_min,
.envelope_opt = initial_envelope + energy_max,
.sum_of_energy_min = energy_min,
.max_of_energy_delta = energy_max - energy_min};
RefreshNode(node);
}
// Adds event to the lambda part of the tree only.
// This will leave GetEnvelope() unchanged, only GetOptionalEnvelope() can
// be affected. This is done by setting envelope to IntegerTypeMinimumValue(),
// be affected, by setting envelope to std::numeric_limits<>::min(),
// energy_min to 0, and initial_envelope_opt and energy_max to the parameters.
// This updates the tree in O(log n).
void AddOrUpdateOptionalEvent(int event, IntegerType initial_envelope_opt,
IntegerType energy_max);
// Delayed version of AddOrUpdateOptionalEvent(),
// see RecomputeTreeForDelayedOperations().
void DelayedAddOrUpdateOptionalEvent(int event,
IntegerType initial_envelope_opt,
IntegerType energy_max);
IntegerType energy_max) {
DCHECK_LE(0, energy_max);
const int node = GetLeafFromEvent(event);
tree_[node] = {.envelope = std::numeric_limits<IntegerType>::min(),
.envelope_opt = initial_envelope_opt + energy_max,
.sum_of_energy_min = IntegerType{0},
.max_of_energy_delta = energy_max};
RefreshNode(node);
}
// Makes event absent, compute the new envelope in O(log n).
void RemoveEvent(int event);
// Delayed version of RemoveEvent(), see RecomputeTreeForDelayedOperations().
void DelayedRemoveEvent(int event);
void RemoveEvent(int event) {
const int node = GetLeafFromEvent(event);
tree_[node] = {.envelope = std::numeric_limits<IntegerType>::min(),
.envelope_opt = std::numeric_limits<IntegerType>::min(),
.sum_of_energy_min = IntegerType{0},
.max_of_energy_delta = IntegerType{0}};
RefreshNode(node);
}
// Returns the maximum envelope using all the energy_min in O(1).
// If theta is empty, returns ThetaLambdaTreeNegativeInfinity().
IntegerType GetEnvelope() const;
// If theta is empty, returns std::numeric_limits<>::min().
IntegerType GetEnvelope() const { return tree_[1].envelope; }
// Returns the maximum envelope using the energy min of all task but
// one and the energy max of the last one in O(1).
// If theta and lambda are empty, returns ThetaLambdaTreeNegativeInfinity().
IntegerType GetOptionalEnvelope() const;
// If theta and lambda are empty, returns std::numeric_limits<>::min().
IntegerType GetOptionalEnvelope() const { return tree_[1].envelope_opt; }
// Computes the maximum event s.t. GetEnvelopeOf(event) > envelope_max.
// There must be such an event, i.e. GetEnvelope() > envelope_max.
// This finds the maximum event e such that
// initial_envelope(e) + sum_{e' >= e} energy_min(e') > target_envelope.
// This operation is O(log n).
int GetMaxEventWithEnvelopeGreaterThan(IntegerType target_envelope) const;
int GetMaxEventWithEnvelopeGreaterThan(IntegerType target_envelope) const {
DCHECK_LT(target_envelope, tree_[1].envelope);
IntegerType unused;
return GetEventFromLeaf(
GetMaxLeafWithEnvelopeGreaterThan(1, target_envelope, &unused));
}
// Returns initial_envelope(event) + sum_{event' >= event} energy_min(event'),
// in time O(log n).
@@ -181,7 +185,14 @@ class ThetaLambdaTree {
// This operation is O(log n).
void GetEventsWithOptionalEnvelopeGreaterThan(
IntegerType target_envelope, int* critical_event, int* optional_event,
IntegerType* available_energy) const;
IntegerType* available_energy) const {
int critical_leaf;
int optional_leaf;
GetLeavesWithOptionalEnvelopeGreaterThan(target_envelope, &critical_leaf,
&optional_leaf, available_energy);
*critical_event = GetEventFromLeaf(critical_leaf);
*optional_event = GetEventFromLeaf(optional_leaf);
}
// Getters.
IntegerType EnergyMin(int event) const {
@@ -196,10 +207,36 @@ class ThetaLambdaTree {
IntegerType max_of_energy_delta;
};
TreeNode ComposeTreeNodes(const TreeNode& left, const TreeNode& right);
TreeNode ComposeTreeNodes(const TreeNode& left, const TreeNode& right) {
return {
.envelope =
std::max(right.envelope, left.envelope + right.sum_of_energy_min),
.envelope_opt =
std::max(right.envelope_opt,
right.sum_of_energy_min +
std::max(left.envelope_opt,
left.envelope + right.max_of_energy_delta)),
.sum_of_energy_min = left.sum_of_energy_min + right.sum_of_energy_min,
.max_of_energy_delta =
std::max(right.max_of_energy_delta, left.max_of_energy_delta)};
}
int GetLeafFromEvent(int event) const;
int GetEventFromLeaf(int leaf) const;
int GetLeafFromEvent(int event) const {
DCHECK_LE(0, event);
DCHECK_LT(event, num_events_);
// Keeping the ordering of events is important, so the first set of events
// must be mapped to the set of leaves at depth d, and the second set of
// events must be mapped to the set of leaves at depth d-1.
const int r = power_of_two_ + event;
return r < 2 * num_leaves_ ? r : r - num_leaves_;
}
int GetEventFromLeaf(int leaf) const {
DCHECK_GE(leaf, num_leaves_);
DCHECK_LT(leaf, 2 * num_leaves_);
const int r = leaf - power_of_two_;
return r >= 0 ? r : r + num_leaves_;
}
// Propagates the change of leaf energies and envelopes towards the root.
void RefreshNode(int node);
@@ -225,32 +262,12 @@ class ThetaLambdaTree {
int num_leaves_;
int power_of_two_;
// A bool used in debug mode, to check that sequences of delayed operations
// are ended by Reset() or RecomputeTreeForDelayedOperations().
bool leaf_nodes_have_delayed_operations_ = false;
// Envelopes and energies of nodes.
std::vector<TreeNode> tree_;
};
template <typename IntegerType>
typename ThetaLambdaTree<IntegerType>::TreeNode
ThetaLambdaTree<IntegerType>::ComposeTreeNodes(const TreeNode& left,
const TreeNode& right) {
return {std::max(right.envelope, left.envelope + right.sum_of_energy_min),
std::max(right.envelope_opt,
right.sum_of_energy_min +
std::max(left.envelope_opt,
left.envelope + right.max_of_energy_delta)),
left.sum_of_energy_min + right.sum_of_energy_min,
std::max(right.max_of_energy_delta, left.max_of_energy_delta)};
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::Reset(int num_events) {
#ifndef NDEBUG
leaf_nodes_have_delayed_operations_ = false;
#endif
// Because the algorithm needs to access a node sibling (i.e. node_index ^ 1),
// our tree will always have an even number of leaves, just large enough to
// fit our number of events. And at least 2 for the empty tree case.
@@ -258,9 +275,11 @@ void ThetaLambdaTree<IntegerType>::Reset(int num_events) {
num_leaves_ = std::max(2, num_events + (num_events & 1));
const int num_nodes = 2 * num_leaves_;
tree_.assign(num_nodes, TreeNode{std::numeric_limits<IntegerType>::min(),
std::numeric_limits<IntegerType>::min(),
IntegerType{0}, IntegerType{0}});
tree_.assign(num_nodes,
TreeNode{.envelope = std::numeric_limits<IntegerType>::min(),
.envelope_opt = std::numeric_limits<IntegerType>::min(),
.sum_of_energy_min = IntegerType{0},
.max_of_energy_delta = IntegerType{0}});
// If num_leaves is not a power or two, the last depth of the tree will not be
// full, and the array will look like:
@@ -270,147 +289,8 @@ void ThetaLambdaTree<IntegerType>::Reset(int num_events) {
}
}
template <typename IntegerType>
int ThetaLambdaTree<IntegerType>::GetLeafFromEvent(int event) const {
DCHECK_LE(0, event);
DCHECK_LT(event, num_events_);
// Keeping the ordering of events is important, so the first set of events
// must be mapped to the set of leaves at depth d, and the second set of
// events must be mapped to the set of leaves at depth d-1.
const int r = power_of_two_ + event;
return r < 2 * num_leaves_ ? r : r - num_leaves_;
}
template <typename IntegerType>
int ThetaLambdaTree<IntegerType>::GetEventFromLeaf(int leaf) const {
DCHECK_GE(leaf, num_leaves_);
DCHECK_LT(leaf, 2 * num_leaves_);
const int r = leaf - power_of_two_;
return r >= 0 ? r : r + num_leaves_;
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::RecomputeTreeForDelayedOperations() {
#ifndef NDEBUG
leaf_nodes_have_delayed_operations_ = false;
#endif
// Only recompute internal nodes.
const int last_internal_node = tree_.size() / 2 - 1;
for (int node = last_internal_node; node >= 1; --node) {
const int right = 2 * node + 1;
const int left = 2 * node;
tree_[node] = ComposeTreeNodes(tree_[left], tree_[right]);
}
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::DelayedAddOrUpdateEvent(
int event, IntegerType initial_envelope, IntegerType energy_min,
IntegerType energy_max) {
#ifndef NDEBUG
leaf_nodes_have_delayed_operations_ = true;
#endif
DCHECK_LE(0, energy_min);
DCHECK_LE(energy_min, energy_max);
const int node = GetLeafFromEvent(event);
tree_[node] = {initial_envelope + energy_min, initial_envelope + energy_max,
energy_min, energy_max - energy_min};
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::AddOrUpdateEvent(
int event, IntegerType initial_envelope, IntegerType energy_min,
IntegerType energy_max) {
DCHECK(!leaf_nodes_have_delayed_operations_);
DCHECK_LE(0, energy_min);
DCHECK_LE(energy_min, energy_max);
const int node = GetLeafFromEvent(event);
tree_[node] = {initial_envelope + energy_min, initial_envelope + energy_max,
energy_min, energy_max - energy_min};
RefreshNode(node);
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::AddOrUpdateOptionalEvent(
int event, IntegerType initial_envelope_opt, IntegerType energy_max) {
DCHECK(!leaf_nodes_have_delayed_operations_);
DCHECK_LE(0, energy_max);
const int node = GetLeafFromEvent(event);
tree_[node] = {std::numeric_limits<IntegerType>::min(),
initial_envelope_opt + energy_max, IntegerType{0}, energy_max};
RefreshNode(node);
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::DelayedAddOrUpdateOptionalEvent(
int event, IntegerType initial_envelope_opt, IntegerType energy_max) {
#ifndef NDEBUG
leaf_nodes_have_delayed_operations_ = true;
#endif
DCHECK_LE(0, energy_max);
const int node = GetLeafFromEvent(event);
tree_[node] = {std::numeric_limits<IntegerType>::min(),
initial_envelope_opt + energy_max, IntegerType{0}, energy_max};
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::RemoveEvent(int event) {
DCHECK(!leaf_nodes_have_delayed_operations_);
const int node = GetLeafFromEvent(event);
tree_[node] = {std::numeric_limits<IntegerType>::min(),
std::numeric_limits<IntegerType>::min(), IntegerType{0},
IntegerType{0}};
RefreshNode(node);
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::DelayedRemoveEvent(int event) {
#ifndef NDEBUG
leaf_nodes_have_delayed_operations_ = true;
#endif
const int node = GetLeafFromEvent(event);
tree_[node] = {std::numeric_limits<IntegerType>::min(),
std::numeric_limits<IntegerType>::min(), IntegerType{0},
IntegerType{0}};
}
template <typename IntegerType>
IntegerType ThetaLambdaTree<IntegerType>::GetEnvelope() const {
DCHECK(!leaf_nodes_have_delayed_operations_);
return tree_[1].envelope;
}
template <typename IntegerType>
IntegerType ThetaLambdaTree<IntegerType>::GetOptionalEnvelope() const {
DCHECK(!leaf_nodes_have_delayed_operations_);
return tree_[1].envelope_opt;
}
template <typename IntegerType>
int ThetaLambdaTree<IntegerType>::GetMaxEventWithEnvelopeGreaterThan(
IntegerType target_envelope) const {
DCHECK(!leaf_nodes_have_delayed_operations_);
DCHECK_LT(target_envelope, tree_[1].envelope);
IntegerType unused;
return GetEventFromLeaf(
GetMaxLeafWithEnvelopeGreaterThan(1, target_envelope, &unused));
}
template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::GetEventsWithOptionalEnvelopeGreaterThan(
IntegerType target_envelope, int* critical_event, int* optional_event,
IntegerType* available_energy) const {
DCHECK(!leaf_nodes_have_delayed_operations_);
int critical_leaf;
int optional_leaf;
GetLeavesWithOptionalEnvelopeGreaterThan(target_envelope, &critical_leaf,
&optional_leaf, available_energy);
*critical_event = GetEventFromLeaf(critical_leaf);
*optional_event = GetEventFromLeaf(optional_leaf);
}
template <typename IntegerType>
IntegerType ThetaLambdaTree<IntegerType>::GetEnvelopeOf(int event) const {
DCHECK(!leaf_nodes_have_delayed_operations_);
const int leaf = GetLeafFromEvent(event);
IntegerType envelope = tree_[leaf].envelope;
for (int node = leaf; node > 1; node >>= 1) {
@@ -434,7 +314,6 @@ void ThetaLambdaTree<IntegerType>::RefreshNode(int node) {
template <typename IntegerType>
int ThetaLambdaTree<IntegerType>::GetMaxLeafWithEnvelopeGreaterThan(
int node, IntegerType target_envelope, IntegerType* extra) const {
DCHECK(!leaf_nodes_have_delayed_operations_);
DCHECK_LT(target_envelope, tree_[node].envelope);
while (node < num_leaves_) {
const int left = node << 1;
@@ -454,7 +333,6 @@ int ThetaLambdaTree<IntegerType>::GetMaxLeafWithEnvelopeGreaterThan(
template <typename IntegerType>
int ThetaLambdaTree<IntegerType>::GetLeafWithMaxEnergyDelta(int node) const {
DCHECK(!leaf_nodes_have_delayed_operations_);
const IntegerType delta_node = tree_[node].max_of_energy_delta;
while (node < num_leaves_) {
const int left = node << 1;
@@ -474,7 +352,6 @@ template <typename IntegerType>
void ThetaLambdaTree<IntegerType>::GetLeavesWithOptionalEnvelopeGreaterThan(
IntegerType target_envelope, int* critical_leaf, int* optional_leaf,
IntegerType* available_energy) const {
DCHECK(!leaf_nodes_have_delayed_operations_);
DCHECK_LT(target_envelope, tree_[1].envelope_opt);
int node = 1;
while (node < num_leaves_) {