more work on dags

This commit is contained in:
Laurent Perron
2025-03-21 06:55:16 -07:00
parent 8375a92d58
commit b3d6ef66d2
6 changed files with 312 additions and 415 deletions

View File

@@ -16,6 +16,7 @@
#include <cmath>
#include <limits>
#include <optional>
#include <vector>
#include "absl/algorithm/container.h"
@@ -214,6 +215,9 @@ class ConstrainedShortestPathsOnDagWrapper {
absl::Span<const NodeIndex> destinations_;
const int num_resources_;
// Set to a node if and only if this node is in both `sources_` and
// `destinations_`.
std::optional<NodeIndex> source_is_destination_ = std::nullopt;
// Data about *reachable* sub-graphs split in two for bidirectional search.
// Reachable nodes are nodes that can be reached given the resources
// constraints, i.e., for each resource, the sum of the minimum resource to
@@ -334,13 +338,15 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::
<< absl::StrFormat(
"max_resource cannot be negative not +inf nor NaN");
}
std::vector<bool> is_source(graph->num_nodes(), false);
for (const NodeIndex source : sources) {
is_source[source] = true;
}
for (const NodeIndex destination : destinations) {
CHECK(!is_source[destination])
<< "A node cannot be both a source and destination";
}
std::vector<bool> is_source(graph->num_nodes(), false);
for (const NodeIndex source : sources) {
is_source[source] = true;
}
for (const NodeIndex destination : destinations) {
if (is_source[destination]) {
source_is_destination_ = destination;
return;
}
}
@@ -542,6 +548,10 @@ template <class GraphType>
#endif
GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
GraphType>::RunConstrainedShortestPathOnDag() {
if (source_is_destination_.has_value()) {
return {
.length = 0, .arc_path = {}, .node_path = {*source_is_destination_}};
}
// Assign lengths on sub-relevant graphs.
std::vector<double> sub_arc_lengths[2];
for (const Direction dir : {FORWARD, BACKWARD}) {
@@ -837,9 +847,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
for (int label_to_index = first_label_to;
label_to_index < first_label_to + num_labels_to; ++label_to_index) {
const double length_to = backward_lengths[label_to_index];
if (arc_length + length_to >= best_label_pair.length) {
continue;
}
for (int label_from_index = first_label_from;
label_from_index < first_label_from + num_labels_from;
++label_from_index) {

View File

@@ -90,6 +90,45 @@ TEST(ConstrainedShortestPathOnDagTest, SimpleGraph) {
/*node_path=*/ElementsAre(source, b, destination)));
}
TEST(ConstrainedShortestPathOnDagTest, DiamondGraph) {
const std::vector<ArcWithLengthAndResources> arcs_with_length_and_resources =
{{.from = 0, .to = 1, .length = 1.0, .resources = {1.0}},
{.from = 0, .to = 2, .length = 1.0, .resources = {1.0}},
{.from = 0, .to = 3, .length = 1.0, .resources = {1.0}},
{.from = 0, .to = 4, .length = 1.0, .resources = {1.0}},
{.from = 0, .to = 5, .length = 1.0, .resources = {1.0}},
{.from = 0, .to = 6, .length = 1.0, .resources = {1.0}},
{.from = 1, .to = 2, .length = -1.0, .resources = {1.0}},
{.from = 1, .to = 7, .length = -1.0, .resources = {0.0}},
{.from = 2, .to = 3, .length = -1.0, .resources = {1.0}},
{.from = 2, .to = 7, .length = -1.0, .resources = {0.0}},
{.from = 3, .to = 4, .length = -1.0, .resources = {1.0}},
{.from = 3, .to = 7, .length = -1.0, .resources = {0.0}},
{.from = 4, .to = 5, .length = 1.0, .resources = {1.0}},
{.from = 4, .to = 7, .length = 1.0, .resources = {0.0}},
{.from = 5, .to = 6, .length = -1.0, .resources = {1.0}},
{.from = 5, .to = 7, .length = -1.0, .resources = {0.0}},
{.from = 6, .to = 7, .length = -1.0, .resources = {0.0}}};
EXPECT_THAT(ConstrainedShortestPathsOnDag(8, arcs_with_length_and_resources,
0, 7, /*max_resources=*/{3.0}),
FieldsAre(/*length=*/-2.0, /*arc_path=*/ElementsAre(0, 6, 8, 11),
/*node_path=*/ElementsAre(0, 1, 2, 3, 7)));
}
TEST(ConstrainedShortestPathOnDagTest, GraphWithNoArcs) {
EXPECT_THAT(ConstrainedShortestPathsOnDag(
/*num_nodes=*/1, /*arcs_with_length_and_resources=*/{},
/*source=*/0, /*destination=*/0, /*max_resources=*/{7.0}),
FieldsAre(/*length=*/0, /*arc_path=*/IsEmpty(),
/*node_path=*/ElementsAre(0)));
EXPECT_THAT(ConstrainedShortestPathsOnDag(
/*num_nodes=*/2, /*arcs_with_length_and_resources=*/{},
/*source=*/0, /*destination=*/1, /*max_resources=*/{7.0}),
FieldsAre(/*length=*/kInf, /*arc_path=*/IsEmpty(),
/*node_path=*/IsEmpty()));
}
TEST(ConstrainedShortestPathOnDagTest, SimpleGraphTwoPaths) {
const int source = 0;
const int destination = 1;
@@ -818,17 +857,6 @@ TEST(ConstrainedShortestPathOnDagTest, NegativeMaxResource) {
"negative");
}
TEST(ConstrainedShortestPathOnDagTest, SourceIsDestination) {
const int source = 0;
const int num_nodes = 1;
EXPECT_DEATH(
ConstrainedShortestPathsOnDag(
num_nodes, /*arcs_with_length_and_resources=*/{}, source, source,
/*max_resources=*/{0.0}),
"source and destination");
}
TEST(ConstrainedShortestPathsOnDagWrapperTest, ValidateTopologicalOrder) {
const int source = 0;
const int destination = 1;

View File

@@ -366,7 +366,6 @@ ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
CHECK(graph_ != nullptr);
CHECK(arc_lengths_ != nullptr);
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
CHECK_GT(graph_->num_arcs(), 0) << "The graph is empty: it has no arcs";
#ifndef NDEBUG
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
for (const double arc_length : *arc_lengths_) {
@@ -489,7 +488,6 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
CHECK(graph_ != nullptr);
CHECK(arc_lengths_ != nullptr);
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
CHECK_GT(graph_->num_arcs(), 0) << "The graph is empty: it has no arcs";
CHECK_GT(path_count_, 0) << "path_count must be greater than 0";
#ifndef NDEBUG
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());

View File

@@ -78,12 +78,6 @@ TEST(ShortestPathOnDagTest, EmptyGraph) {
"num_nodes\\(\\) > 0");
}
TEST(ShortestPathOnDagTest, NoArcGraph) {
EXPECT_DEATH(ShortestPathsOnDag(/*num_nodes=*/1, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0),
"num_arcs\\(\\) > 0");
}
TEST(ShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
EXPECT_DEATH(
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
@@ -137,6 +131,17 @@ TEST(ShortestPathOnDagTest, SimpleGraph) {
/*node_path=*/ElementsAre(source, a, destination)));
}
TEST(ShortestPathOnDagTest, GraphsWithNoArcs) {
EXPECT_THAT(ShortestPathsOnDag(/*num_nodes=*/1, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0),
FieldsAre(/*length=*/0, /*arc_path=*/IsEmpty(),
/*node_path=*/ElementsAre(0)));
EXPECT_THAT(ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/1),
FieldsAre(/*length=*/kInf, /*arc_path=*/IsEmpty(),
/*node_path=*/IsEmpty()));
}
TEST(ShortestPathOnDagTest, SourceIsDestination) {
const int source = 0;
const int destination = 1;
@@ -632,13 +637,6 @@ TEST(KShortestPathOnDagTest, EmptyGraph) {
"num_nodes\\(\\) > 0");
}
TEST(KShortestPathOnDagTest, NoArcGraph) {
EXPECT_DEATH(
KShortestPathsOnDag(/*num_nodes=*/1, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0, /*path_count=*/2),
"num_arcs\\(\\) > 0");
}
TEST(KShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
EXPECT_DEATH(
KShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
@@ -689,6 +687,19 @@ TEST(KShortestPathOnDagTest, OnlyHasOnePath) {
/*node_path=*/ElementsAre(source, a, destination))));
}
TEST(KShortestPathOnDagTest, GraphsWithNoArcs) {
EXPECT_THAT(
KShortestPathsOnDag(/*num_nodes=*/1, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0, /*path_count=*/2),
ElementsAre(FieldsAre(/*length=*/0, /*arc_path=*/IsEmpty(),
/*node_path=*/ElementsAre(0))));
EXPECT_THAT(
KShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/1, /*path_count=*/2),
ElementsAre(FieldsAre(/*length=*/kInf, /*arc_path=*/IsEmpty(),
/*node_path=*/IsEmpty())));
}
TEST(KShortestPathOnDagTest, SourceIsDestination) {
const int source = 0;
const int destination = 1;

View File

@@ -290,6 +290,66 @@ class BaseGraph {
bool const_capacities_;
};
// An iterator that wraps an arc iterator and retrieves a property of the arc.
// The property to retrieve is specified by a `Graph` member function taking an
// `ArcIndex` parameter. For example, `ArcHeadIterator` retrieves the head of an
// arc with `&Graph::Head`.
template <typename Graph, typename ArcIterator, typename PropertyT,
PropertyT (Graph::*property)(typename Graph::ArcIndex) const>
class ArcPropertyIterator
#if __cplusplus < 202002L
: public std::iterator<std::input_iterator_tag, PropertyT>
#endif
{
public:
using value_type = PropertyT;
// TODO(b/385094969): This should be `NodeIndex` for integers,
// `NodeIndex::value_type` for strong signed integer types.
using difference_type = std::ptrdiff_t;
ArcPropertyIterator() = default;
ArcPropertyIterator(const Graph& graph, ArcIterator arc_it)
: arc_it_(std::move(arc_it)), graph_(&graph) {}
value_type operator*() const { return (graph_->*property)(*arc_it_); }
ArcPropertyIterator& operator++() {
++arc_it_;
return *this;
}
ArcPropertyIterator operator++(int) {
auto tmp = *this;
++arc_it_;
return tmp;
}
friend bool operator==(const ArcPropertyIterator& l,
const ArcPropertyIterator& r) {
return l.arc_it_ == r.arc_it_;
}
friend bool operator!=(const ArcPropertyIterator& l,
const ArcPropertyIterator& r) {
return !(l == r);
}
private:
ArcIterator arc_it_;
const Graph* graph_;
};
// An iterator that iterates on the heads of the arcs of another iterator.
template <typename Graph, typename ArcIterator>
using ArcHeadIterator =
ArcPropertyIterator<Graph, ArcIterator, typename Graph::NodeIndex,
&Graph::Head>;
// An iterator that iterates on the opposite arcs of another iterator.
template <typename Graph, typename ArcIterator>
using ArcOppositeArcIterator =
ArcPropertyIterator<Graph, ArcIterator, typename Graph::ArcIndex,
&Graph::OppositeArc>;
// Basic graph implementation without reverse arc. This class also serves as a
// documentation for the generic graph interface (minus the part related to
// reverse arcs).
@@ -352,9 +412,15 @@ class ListGraph : public BaseGraph<NodeIndexType, ArcIndexType, false> {
void Build() { Build(nullptr); }
void Build(std::vector<ArcIndexType>* permutation);
// Returns the tail/head of a valid arc.
NodeIndexType Tail(ArcIndexType arc) const;
NodeIndexType Head(ArcIndexType arc) const;
// Do not use directly.
class OutgoingArcIterator;
class OutgoingHeadIterator;
struct OutgoingArcIteratorTag {};
using OutgoingArcIterator =
ChasingIterator<ArcIndexType, Base::kNilArc, OutgoingArcIteratorTag>;
using OutgoingHeadIterator = ArcHeadIterator<ListGraph, OutgoingArcIterator>;
// Graph jargon: the "degree" of a node is its number of arcs. The out-degree
// is the number of outgoing arcs. The in-degree is the number of incoming
@@ -366,22 +432,30 @@ class ListGraph : public BaseGraph<NodeIndexType, ArcIndexType, false> {
// Allows to iterate over the forward arcs that verify Tail(arc) == node.
// This is meant to be used as:
// for (const ArcIndex arc : graph.OutgoingArcs(node)) { ... }
BeginEndWrapper<OutgoingArcIterator> OutgoingArcs(NodeIndexType node) const;
BeginEndWrapper<OutgoingArcIterator> OutgoingArcs(NodeIndexType node) const {
DCHECK(Base::IsNodeValid(node));
return {OutgoingArcIterator(start_[node], next_.data()),
OutgoingArcIterator()};
}
// Advanced usage. Same as OutgoingArcs(), but allows to restart the iteration
// from an already known outgoing arc of the given node. If `from` is
// `kNilArc`, an empty range is returned.
BeginEndWrapper<OutgoingArcIterator> OutgoingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const;
NodeIndexType node, ArcIndexType from) const {
DCHECK(Base::IsNodeValid(node));
if (from == Base::kNilArc) return {};
DCHECK_EQ(Tail(from), node);
return {OutgoingArcIterator(from, next_.data()), OutgoingArcIterator()};
}
// This loops over the heads of the OutgoingArcs(node). It is just a more
// convenient way to achieve this. Moreover this interface is used by some
// graph algorithms.
BeginEndWrapper<OutgoingHeadIterator> operator[](NodeIndexType node) const;
// Returns the tail/head of a valid arc.
NodeIndexType Tail(ArcIndexType arc) const;
NodeIndexType Head(ArcIndexType arc) const;
BeginEndWrapper<OutgoingHeadIterator> operator[](NodeIndexType node) const {
return {OutgoingHeadIterator(*this, OutgoingArcs(node).begin()),
OutgoingHeadIterator()};
}
void ReserveNodes(NodeIndexType bound) override;
void ReserveArcs(ArcIndexType bound) override;
@@ -508,16 +582,26 @@ class ReverseArcListGraph
}
}
NodeIndexType Head(ArcIndexType arc) const;
NodeIndexType Tail(ArcIndexType arc) const;
// Returns the opposite arc of a given arc. That is the reverse arc of the
// given forward arc or the forward arc of a given reverse arc.
ArcIndexType OppositeArc(ArcIndexType arc) const;
// Do not use directly. See instead the arc iteration functions below.
struct OutgoingArcIteratorTag {};
using OutgoingArcIterator =
ChasingIterator<ArcIndexType, Base::kNilArc, OutgoingArcIteratorTag>;
struct OppositeIncomingArcIteratorTag {};
using OppositeIncomingArcIterator =
ChasingIterator<ArcIndexType, Base::kNilArc,
OppositeIncomingArcIteratorTag>;
class OutgoingOrOppositeIncomingArcIterator;
class OppositeIncomingArcIterator;
class IncomingArcIterator;
class OutgoingArcIterator;
class OutgoingHeadIterator;
using OutgoingHeadIterator =
ArcHeadIterator<ReverseArcListGraph, OutgoingArcIterator>;
using IncomingArcIterator =
ArcOppositeArcIterator<ReverseArcListGraph, OppositeIncomingArcIterator>;
// ReverseArcListGraph<>::OutDegree() and ::InDegree() work in O(degree).
ArcIndexType OutDegree(NodeIndexType node) const;
@@ -530,30 +614,62 @@ class ReverseArcListGraph
// The StartingFrom() version are similar, but restart the iteration from a
// given arc position (which must be valid in the iteration context), or
// `kNilArc`, in which case an empty range is returned.
BeginEndWrapper<OutgoingArcIterator> OutgoingArcs(NodeIndexType node) const;
BeginEndWrapper<IncomingArcIterator> IncomingArcs(NodeIndexType node) const;
BeginEndWrapper<OutgoingArcIterator> OutgoingArcs(NodeIndexType node) const {
DCHECK(Base::IsNodeValid(node));
return {OutgoingArcIterator(start_[node], next_.data()),
OutgoingArcIterator()};
}
BeginEndWrapper<OutgoingArcIterator> OutgoingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const {
DCHECK(Base::IsNodeValid(node));
if (from == Base::kNilArc) return {};
DCHECK_GE(from, 0);
DCHECK_EQ(Tail(from), node);
return {OutgoingArcIterator(from, next_.data()), OutgoingArcIterator()};
}
BeginEndWrapper<IncomingArcIterator> IncomingArcs(NodeIndexType node) const {
return {IncomingArcIterator(*this, OppositeIncomingArcs(node).begin()),
IncomingArcIterator()};
}
BeginEndWrapper<IncomingArcIterator> IncomingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const {
DCHECK(Base::IsNodeValid(node));
if (from == Base::kNilArc) return {};
return {
IncomingArcIterator(
*this,
OppositeIncomingArcsStartingFrom(node, OppositeArc(from)).begin()),
IncomingArcIterator()};
}
BeginEndWrapper<OutgoingOrOppositeIncomingArcIterator>
OutgoingOrOppositeIncomingArcs(NodeIndexType node) const;
BeginEndWrapper<OppositeIncomingArcIterator> OppositeIncomingArcs(
NodeIndexType node) const;
BeginEndWrapper<OutgoingArcIterator> OutgoingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const;
BeginEndWrapper<IncomingArcIterator> IncomingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const;
NodeIndexType node) const {
DCHECK(Base::IsNodeValid(node));
return {OppositeIncomingArcIterator(reverse_start_[node], next_.data()),
OppositeIncomingArcIterator()};
}
BeginEndWrapper<OppositeIncomingArcIterator> OppositeIncomingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const {
DCHECK(Base::IsNodeValid(node));
if (from == Base::kNilArc) return {};
DCHECK_LT(from, 0);
DCHECK_EQ(Tail(from), node);
return {OppositeIncomingArcIterator(from, next_.data()),
OppositeIncomingArcIterator()};
}
BeginEndWrapper<OutgoingOrOppositeIncomingArcIterator>
OutgoingOrOppositeIncomingArcsStartingFrom(NodeIndexType node,
ArcIndexType from) const;
BeginEndWrapper<OppositeIncomingArcIterator> OppositeIncomingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const;
// This loops over the heads of the OutgoingArcs(node). It is just a more
// convenient way to achieve this. Moreover this interface is used by some
// graph algorithms.
BeginEndWrapper<OutgoingHeadIterator> operator[](NodeIndexType node) const;
NodeIndexType Head(ArcIndexType arc) const;
NodeIndexType Tail(ArcIndexType arc) const;
void ReserveNodes(NodeIndexType bound) override;
void ReserveArcs(ArcIndexType bound) override;
void AddNode(NodeIndexType node);
@@ -602,16 +718,23 @@ class ReverseArcStaticGraph
}
}
// Deprecated.
class OutgoingOrOppositeIncomingArcIterator;
class OppositeIncomingArcIterator;
class IncomingArcIterator;
class OutgoingArcIterator;
ArcIndexType OppositeArc(ArcIndexType arc) const;
// TODO(user): support Head() and Tail() before Build(), like StaticGraph<>.
NodeIndexType Head(ArcIndexType arc) const;
NodeIndexType Tail(ArcIndexType arc) const;
// ReverseArcStaticGraph<>::OutDegree() and ::InDegree() work in O(1).
ArcIndexType OutDegree(NodeIndexType node) const;
ArcIndexType InDegree(NodeIndexType node) const;
// Deprecated.
class OutgoingOrOppositeIncomingArcIterator;
using OppositeIncomingArcIterator = IntegerRangeIterator<ArcIndexType>;
using IncomingArcIterator =
ArcOppositeArcIterator<ReverseArcStaticGraph,
OppositeIncomingArcIterator>;
using OutgoingArcIterator = IntegerRangeIterator<ArcIndexType>;
IntegerRange<ArcIndexType> OutgoingArcs(NodeIndexType node) const {
return IntegerRange<ArcIndexType>(start_[node], DirectArcLimit(node));
}
@@ -635,12 +758,24 @@ class ReverseArcStaticGraph
limit);
}
BeginEndWrapper<IncomingArcIterator> IncomingArcs(NodeIndexType node) const;
BeginEndWrapper<IncomingArcIterator> IncomingArcs(NodeIndexType node) const {
const auto opposite_incoming_arcs = OppositeIncomingArcs(node);
return {IncomingArcIterator(*this, opposite_incoming_arcs.begin()),
IncomingArcIterator(*this, opposite_incoming_arcs.end())};
}
BeginEndWrapper<IncomingArcIterator> IncomingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const {
DCHECK(Base::IsNodeValid(node));
const auto opposite_incoming_arcs = OppositeIncomingArcsStartingFrom(
node, from == Base::kNilArc ? Base::kNilArc : OppositeArc(from));
return {IncomingArcIterator(*this, opposite_incoming_arcs.begin()),
IncomingArcIterator(*this, opposite_incoming_arcs.end())};
}
BeginEndWrapper<OutgoingOrOppositeIncomingArcIterator>
OutgoingOrOppositeIncomingArcs(NodeIndexType node) const;
BeginEndWrapper<IncomingArcIterator> IncomingArcsStartingFrom(
NodeIndexType node, ArcIndexType from) const;
BeginEndWrapper<OutgoingOrOppositeIncomingArcIterator>
OutgoingOrOppositeIncomingArcsStartingFrom(NodeIndexType node,
ArcIndexType from) const;
@@ -650,11 +785,6 @@ class ReverseArcStaticGraph
// graph algorithms.
absl::Span<const NodeIndexType> operator[](NodeIndexType node) const;
ArcIndexType OppositeArc(ArcIndexType arc) const;
// TODO(user): support Head() and Tail() before Build(), like StaticGraph<>.
NodeIndexType Head(ArcIndexType arc) const;
NodeIndexType Tail(ArcIndexType arc) const;
void ReserveArcs(ArcIndexType bound) override;
void AddNode(NodeIndexType node);
ArcIndexType AddArc(NodeIndexType tail, NodeIndexType head);
@@ -989,7 +1119,7 @@ void BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::
(*v)[i] = sum;
sum += temp;
}
DCHECK(sum == num_arcs_);
DCHECK_EQ(sum, num_arcs_);
(*v)[num_nodes_] = sum; // Sentinel.
}
@@ -1109,17 +1239,6 @@ void BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::
// ListGraph implementation ----------------------------------------------------
DEFINE_RANGE_BASED_ARC_ITERATION(ListGraph, Outgoing);
template <typename NodeIndexType, typename ArcIndexType>
BeginEndWrapper<
typename ListGraph<NodeIndexType, ArcIndexType>::OutgoingHeadIterator>
ListGraph<NodeIndexType, ArcIndexType>::operator[](NodeIndexType node) const {
return BeginEndWrapper<OutgoingHeadIterator>(
OutgoingHeadIterator(*this, node),
OutgoingHeadIterator(*this, node, Base::kNilArc));
}
template <typename NodeIndexType, typename ArcIndexType>
NodeIndexType ListGraph<NodeIndexType, ArcIndexType>::Tail(
ArcIndexType arc) const {
@@ -1130,7 +1249,7 @@ NodeIndexType ListGraph<NodeIndexType, ArcIndexType>::Tail(
template <typename NodeIndexType, typename ArcIndexType>
NodeIndexType ListGraph<NodeIndexType, ArcIndexType>::Head(
ArcIndexType arc) const {
DCHECK(IsArcValid(arc));
DCHECK(IsArcValid(arc)) << arc;
return head_[arc];
}
@@ -1188,80 +1307,6 @@ void ListGraph<NodeIndexType, ArcIndexType>::Build(
}
}
template <typename NodeIndexType, typename ArcIndexType>
class ListGraph<NodeIndexType, ArcIndexType>::OutgoingArcIterator {
public:
OutgoingArcIterator(const ListGraph& graph, NodeIndexType node)
: graph_(&graph), index_(graph.start_[node]) {
DCHECK(graph.IsNodeValid(node));
}
OutgoingArcIterator(const ListGraph& graph, NodeIndexType node,
ArcIndexType arc)
: graph_(&graph), index_(arc) {
DCHECK(graph.IsNodeValid(node));
DCHECK(arc == Base::kNilArc || graph.Tail(arc) == node);
}
bool Ok() const { return index_ != Base::kNilArc; }
ArcIndexType Index() const { return index_; }
void Next() {
DCHECK(Ok());
index_ = graph_->next_[index_];
}
DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator);
private:
const ListGraph* graph_;
ArcIndexType index_;
};
template <typename NodeIndexType, typename ArcIndexType>
class ListGraph<NodeIndexType, ArcIndexType>::OutgoingHeadIterator {
public:
using iterator_category = std::input_iterator_tag;
using difference_type = ptrdiff_t;
using pointer = const NodeIndexType*;
using reference = const NodeIndexType&;
using value_type = NodeIndexType;
OutgoingHeadIterator(const ListGraph& graph, NodeIndexType node)
: graph_(&graph), index_(graph.start_[node]) {
DCHECK(graph.IsNodeValid(node));
}
OutgoingHeadIterator(const ListGraph& graph, NodeIndexType node,
ArcIndexType arc)
: graph_(&graph), index_(arc) {
DCHECK(graph.IsNodeValid(node));
DCHECK(arc == Base::kNilArc || graph.Tail(arc) == node);
}
bool Ok() const { return index_ != Base::kNilArc; }
NodeIndexType Index() const { return graph_->Head(index_); }
void Next() {
DCHECK(Ok());
index_ = graph_->next_[index_];
}
bool operator!=(
const typename ListGraph<
NodeIndexType, ArcIndexType>::OutgoingHeadIterator& other) const {
return index_ != other.index_;
}
NodeIndexType operator*() const { return Index(); }
OutgoingHeadIterator& operator++() {
Next();
return *this;
}
OutgoingHeadIterator operator++(int) {
auto tmp = *this;
Next();
return *this;
}
private:
const ListGraph* graph_;
ArcIndexType index_;
};
// StaticGraph implementation --------------------------------------------------
template <typename NodeIndexType, typename ArcIndexType>
@@ -1422,49 +1467,21 @@ void StaticGraph<NodeIndexType, ArcIndexType>::Build(
}
}
// TODO(b/385094969): Remove this class.
template <typename NodeIndexType, typename ArcIndexType>
class StaticGraph<NodeIndexType, ArcIndexType>::OutgoingArcIterator {
public:
OutgoingArcIterator(const OutgoingArcIterator&) = default;
OutgoingArcIterator& operator=(const OutgoingArcIterator&) = default;
OutgoingArcIterator(const StaticGraph& graph, NodeIndexType node)
: index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {}
OutgoingArcIterator(const StaticGraph& graph, NodeIndexType node,
ArcIndexType arc)
: limit_(graph.DirectArcLimit(node)) {
index_ = arc == Base::kNilArc ? limit_ : arc;
DCHECK_GE(arc, graph.start_[node]);
}
bool Ok() const { return index_ != limit_; }
ArcIndexType Index() const { return index_; }
void Next() {
DCHECK(Ok());
index_++;
}
private:
ArcIndexType index_;
ArcIndexType limit_;
};
// ReverseArcListGraph implementation ------------------------------------------
DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Outgoing);
DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, Incoming);
DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph,
OutgoingOrOppositeIncoming);
DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcListGraph, OppositeIncoming);
template <typename NodeIndexType, typename ArcIndexType>
BeginEndWrapper<typename ReverseArcListGraph<
NodeIndexType, ArcIndexType>::OutgoingHeadIterator>
ReverseArcListGraph<NodeIndexType, ArcIndexType>::operator[](
NodeIndexType node) const {
const auto outgoing_arcs = OutgoingArcs(node);
// Note: `BeginEndWrapper` is a borrowed range (`std::ranges::borrowed_range`)
// so copying begin/end is safe.
return BeginEndWrapper<OutgoingHeadIterator>(
OutgoingHeadIterator(*this, node),
OutgoingHeadIterator(*this, node, Base::kNilArc));
OutgoingHeadIterator(*this, outgoing_arcs.begin()),
OutgoingHeadIterator(*this, outgoing_arcs.end()));
}
template <typename NodeIndexType, typename ArcIndexType>
@@ -1553,90 +1570,6 @@ void ReverseArcListGraph<NodeIndexType, ArcIndexType>::Build(
}
}
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcListGraph<NodeIndexType, ArcIndexType>::OutgoingArcIterator {
public:
OutgoingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node)
: graph_(&graph), index_(graph.start_[node]) {
DCHECK(graph.IsNodeValid(node));
}
OutgoingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node,
ArcIndexType arc)
: graph_(&graph), index_(arc) {
DCHECK(graph.IsNodeValid(node));
DCHECK(arc == Base::kNilArc || arc >= 0);
DCHECK(arc == Base::kNilArc || graph.Tail(arc) == node);
}
bool Ok() const { return index_ != Base::kNilArc; }
ArcIndexType Index() const { return index_; }
void Next() {
DCHECK(Ok());
index_ = graph_->next_[index_];
}
DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingArcIterator);
private:
const ReverseArcListGraph* graph_;
ArcIndexType index_;
};
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcListGraph<NodeIndexType,
ArcIndexType>::OppositeIncomingArcIterator {
public:
OppositeIncomingArcIterator(const ReverseArcListGraph& graph,
NodeIndexType node)
: next_(graph.next_.data()), index_(graph.reverse_start_[node]) {
DCHECK(graph.IsNodeValid(node));
}
OppositeIncomingArcIterator(const ReverseArcListGraph& graph,
NodeIndexType node, ArcIndexType arc)
: next_(graph.next_.data()), index_(arc) {
DCHECK(graph.IsNodeValid(node));
DCHECK(arc == Base::kNilArc || arc < 0);
DCHECK(arc == Base::kNilArc || graph.Tail(arc) == node);
}
bool Ok() const { return index_ != Base::kNilArc; }
ArcIndexType Index() const { return index_; }
void Next() {
DCHECK(Ok());
index_ = next_[index_];
}
DEFINE_STL_ITERATOR_FUNCTIONS(OppositeIncomingArcIterator);
protected:
const ArcIndexType* next_;
ArcIndexType index_;
};
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcListGraph<NodeIndexType, ArcIndexType>::IncomingArcIterator
: public OppositeIncomingArcIterator {
public:
IncomingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node)
: OppositeIncomingArcIterator(graph, node), graph_(&graph) {}
IncomingArcIterator(const ReverseArcListGraph& graph, NodeIndexType node,
ArcIndexType arc)
: OppositeIncomingArcIterator(
graph, node,
arc == Base::kNilArc ? Base::kNilArc : graph.OppositeArc(arc)),
graph_(&graph) {}
// We overwrite OppositeIncomingArcIterator::Index() here.
ArcIndexType Index() const {
return this->index_ == Base::kNilArc ? Base::kNilArc
: graph_->OppositeArc(this->index_);
}
DEFINE_STL_ITERATOR_FUNCTIONS(IncomingArcIterator);
private:
const ReverseArcListGraph* graph_;
};
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcListGraph<NodeIndexType,
ArcIndexType>::OutgoingOrOppositeIncomingArcIterator {
@@ -1676,37 +1609,8 @@ class ReverseArcListGraph<NodeIndexType,
NodeIndexType node_;
};
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcListGraph<NodeIndexType, ArcIndexType>::OutgoingHeadIterator {
public:
OutgoingHeadIterator(const ReverseArcListGraph& graph, NodeIndexType node)
: graph_(&graph), index_(graph.start_[node]) {
DCHECK(graph.IsNodeValid(node));
}
OutgoingHeadIterator(const ReverseArcListGraph& graph, NodeIndexType node,
ArcIndexType arc)
: graph_(&graph), index_(arc) {
DCHECK(graph.IsNodeValid(node));
DCHECK(arc == Base::kNilArc || arc >= 0);
DCHECK(arc == Base::kNilArc || graph.Tail(arc) == node);
}
bool Ok() const { return index_ != Base::kNilArc; }
ArcIndexType Index() const { return graph_->Head(index_); }
void Next() {
DCHECK(Ok());
index_ = graph_->next_[index_];
}
DEFINE_STL_ITERATOR_FUNCTIONS(OutgoingHeadIterator);
private:
const ReverseArcListGraph* graph_;
ArcIndexType index_;
};
// ReverseArcStaticGraph implementation ----------------------------------------
DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph, Incoming);
DEFINE_RANGE_BASED_ARC_ITERATION(ReverseArcStaticGraph,
OutgoingOrOppositeIncoming);
@@ -1833,93 +1737,6 @@ void ReverseArcStaticGraph<NodeIndexType, ArcIndexType>::Build(
}
}
// TODO(b/385094969): Remove this class.
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcStaticGraph<NodeIndexType, ArcIndexType>::OutgoingArcIterator {
public:
OutgoingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node)
: index_(graph.start_[node]), limit_(graph.DirectArcLimit(node)) {}
OutgoingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node,
ArcIndexType arc)
: limit_(graph.DirectArcLimit(node)) {
index_ = arc == Base::kNilArc ? limit_ : arc;
DCHECK_GE(arc, graph.start_[node]);
}
bool Ok() const { return index_ != limit_; }
ArcIndexType Index() const { return index_; }
void Next() {
DCHECK(Ok());
index_++;
}
private:
ArcIndexType index_;
const ArcIndexType limit_;
};
// TODO(b/385094969): Remove this class.
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcStaticGraph<NodeIndexType,
ArcIndexType>::OppositeIncomingArcIterator {
public:
OppositeIncomingArcIterator(const ReverseArcStaticGraph& graph,
NodeIndexType node)
: limit_(graph.ReverseArcLimit(node)),
index_(graph.reverse_start_[node]) {
DCHECK(graph.IsNodeValid(node));
DCHECK_LE(index_, limit_);
}
OppositeIncomingArcIterator(const ReverseArcStaticGraph& graph,
NodeIndexType node, ArcIndexType arc)
: limit_(graph.ReverseArcLimit(node)) {
index_ = arc == Base::kNilArc ? limit_ : arc;
DCHECK(graph.IsNodeValid(node));
DCHECK_GE(index_, graph.reverse_start_[node]);
DCHECK_LE(index_, limit_);
}
bool Ok() const { return index_ != limit_; }
ArcIndexType Index() const { return index_; }
void Next() {
DCHECK(Ok());
index_++;
}
DEFINE_STL_ITERATOR_FUNCTIONS(OppositeIncomingArcIterator);
protected:
const ArcIndexType limit_;
ArcIndexType index_;
};
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcStaticGraph<NodeIndexType, ArcIndexType>::IncomingArcIterator
: public OppositeIncomingArcIterator {
public:
IncomingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node)
: OppositeIncomingArcIterator(graph, node), graph_(graph) {}
IncomingArcIterator(const ReverseArcStaticGraph& graph, NodeIndexType node,
ArcIndexType arc)
: OppositeIncomingArcIterator(graph, node,
arc == Base::kNilArc
? Base::kNilArc
: (arc == graph.ReverseArcLimit(node)
? graph.ReverseArcLimit(node)
: graph.OppositeArc(arc))),
graph_(graph) {}
ArcIndexType Index() const {
return this->index_ == this->limit_ ? this->limit_
: graph_.OppositeArc(this->index_);
}
DEFINE_STL_ITERATOR_FUNCTIONS(IncomingArcIterator);
private:
const ReverseArcStaticGraph& graph_;
};
template <typename NodeIndexType, typename ArcIndexType>
class ReverseArcStaticGraph<
NodeIndexType, ArcIndexType>::OutgoingOrOppositeIncomingArcIterator {
@@ -2091,25 +1908,6 @@ class CompleteBipartiteGraph
ArcIndexType from) const;
IntegerRange<NodeIndexType> operator[](NodeIndexType node) const;
// Deprecated interface.
class OutgoingArcIterator {
public:
OutgoingArcIterator(const CompleteBipartiteGraph& graph, NodeIndexType node)
: index_(static_cast<ArcIndexType>(graph.right_nodes_) * node),
limit_(node >= graph.left_nodes_
? index_
: static_cast<ArcIndexType>(graph.right_nodes_) *
(node + 1)) {}
bool Ok() const { return index_ < limit_; }
ArcIndexType Index() const { return index_; }
void Next() { index_++; }
private:
ArcIndexType index_;
const ArcIndexType limit_;
};
private:
const NodeIndexType left_nodes_;
const NodeIndexType right_nodes_;

View File

@@ -37,13 +37,24 @@ namespace util {
// And a client will use it like this:
//
// for (const ArcIndex arc : graph.OutgoingArcs(node)) { ... }
//
// Note that `BeginEndWrapper` is conceptually a borrowed range as per the C++
// standard (`std::ranges::borrowed_range`):
// "The concept borrowed_range defines the requirements of a range such that a
// function can take it by value and return iterators obtained from it without
// danger of dangling". We cannot `static_assert` this property though as
// `std::ranges` is prohibited in google3.
template <typename Iterator>
class BeginEndWrapper {
public:
using const_iterator = Iterator;
using value_type = typename std::iterator_traits<Iterator>::value_type;
// If `Iterator` is default-constructible, an empty range.
BeginEndWrapper() = default;
BeginEndWrapper(Iterator begin, Iterator end) : begin_(begin), end_(end) {}
Iterator begin() const { return begin_; }
Iterator end() const { return end_; }
size_t size() const { return end_ - begin_; }
@@ -51,8 +62,8 @@ class BeginEndWrapper {
bool empty() const { return begin() == end(); }
private:
const Iterator begin_;
const Iterator end_;
Iterator begin_;
Iterator end_;
};
// Inline wrapper methods, to make the client code even simpler.
@@ -227,6 +238,50 @@ class IntegerRange : public BeginEndWrapper<IntegerRangeIterator<IntegerType>> {
}
};
// A helper class for implementing list graph iterators: This does pointer
// chasing on `next` until `sentinel` is found. `Tag` allows distinguishing
// different iterators with the same index type and sentinel.
template <typename IndexT, IndexT sentinel, typename Tag>
class ChasingIterator
#if __cplusplus < 202002L
: public std::iterator<std::input_iterator_tag, IndexT>
#endif
{
public:
// TODO(b/385094969): This should be `IntegerType` for integers,
// `IntegerType:value_type` for strong signed integer types.
using difference_type = std::ptrdiff_t;
using value_type = IndexT;
ChasingIterator() : index_(sentinel), next_(nullptr) {}
ChasingIterator(IndexT index, const IndexT* next)
: index_(index), next_(next) {}
IndexT operator*() const { return index_; }
ChasingIterator& operator++() {
index_ = next_[index_];
return *this;
}
ChasingIterator operator++(int) {
auto tmp = *this;
index_ = next_[index_];
return tmp;
}
friend bool operator==(const ChasingIterator& l, const ChasingIterator& r) {
return l.index_ == r.index_;
}
friend bool operator!=(const ChasingIterator& l, const ChasingIterator& r) {
return l.index_ != r.index_;
}
private:
IndexT index_;
const IndexT* next_;
};
} // namespace util
#endif // UTIL_GRAPH_ITERATORS_H_