diff --git a/ortools/graph/BUILD.bazel b/ortools/graph/BUILD.bazel index b2c004240c..d0973f4a4b 100644 --- a/ortools/graph/BUILD.bazel +++ b/ortools/graph/BUILD.bazel @@ -93,6 +93,7 @@ cc_library( "//ortools/base:iterator_adaptors", "//ortools/base:threadpool", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], @@ -121,6 +122,7 @@ cc_library( "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:vector_or_function", + "@com_google_absl//absl/types:span", ], ) @@ -167,8 +169,9 @@ cc_library( hdrs = ["one_tree_lower_bound.h"], deps = [ ":christofides", + ":graph", ":minimum_spanning_tree", - "//ortools/base:types", + "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", ], ) @@ -215,7 +218,10 @@ cc_library( ":shortest_paths", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -265,7 +271,6 @@ cc_test( "//ortools/base:path", "//ortools/linear_solver", "//ortools/util:file_util", - "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/random", "@com_google_absl//absl/strings:str_format", "@com_google_benchmark//:benchmark", @@ -363,10 +368,11 @@ cc_library( "//ortools/base:adjustable_priority_queue", "//ortools/base:int_type", "//ortools/base:strong_vector", - "//ortools/base:types", "//ortools/util:saturated_arithmetic", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", ], ) @@ -390,7 +396,6 @@ cc_library( cc_library( name = "dag_constrained_shortest_path", - testonly = True, srcs = ["dag_constrained_shortest_path.cc"], hdrs = ["dag_constrained_shortest_path.h"], deps = [ @@ -407,6 +412,21 @@ cc_library( ], ) +cc_library( + name = "rooted_tree", + hdrs = ["rooted_tree.h"], + deps = [ + "//ortools/base:status_macros", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + # From util/graph cc_library( name = "connected_components", diff --git a/ortools/graph/CMakeLists.txt b/ortools/graph/CMakeLists.txt index 2b17e37e7b..109cf30b01 100644 --- a/ortools/graph/CMakeLists.txt +++ b/ortools/graph/CMakeLists.txt @@ -31,6 +31,7 @@ list(REMOVE_ITEM _SRCS ${CMAKE_CURRENT_SOURCE_DIR}/multi_dijkstra_test.cc ${CMAKE_CURRENT_SOURCE_DIR}/one_tree_lower_bound_test.cc ${CMAKE_CURRENT_SOURCE_DIR}/perfect_matching_test.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rooted_tree_test.cc ${CMAKE_CURRENT_SOURCE_DIR}/shortest_paths_benchmarks.cc ${CMAKE_CURRENT_SOURCE_DIR}/shortest_paths_test.cc ${CMAKE_CURRENT_SOURCE_DIR}/solve_flow_model.cc @@ -54,6 +55,6 @@ target_link_libraries(${NAME} PRIVATE absl::strings absl::str_format protobuf::libprotobuf - ${PROJECT_NAMESPACE}::${PROJECT_NAME}_proto + ${PROJECT_NAMESPACE}::ortools_proto $<$:Coin::Cbc>) #add_library(${PROJECT_NAMESPACE}::graph ALIAS ${NAME}) diff --git a/ortools/graph/christofides.h b/ortools/graph/christofides.h index 0b24e1088f..38f0c907bf 100644 --- a/ortools/graph/christofides.h +++ b/ortools/graph/christofides.h @@ -27,13 +27,14 @@ #define OR_TOOLS_GRAPH_CHRISTOFIDES_H_ #include +#include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "ortools/base/logging.h" -#include "ortools/base/types.h" #include "ortools/graph/eulerian_path.h" #include "ortools/graph/graph.h" #include "ortools/graph/minimum_spanning_tree.h" diff --git a/ortools/graph/cliques.cc b/ortools/graph/cliques.cc index 5f2506b6f5..ed3ccc1b93 100644 --- a/ortools/graph/cliques.cc +++ b/ortools/graph/cliques.cc @@ -188,7 +188,9 @@ class FindAndEliminate { public: FindAndEliminate(std::function graph, int node_count, std::function&)> callback) - : graph_(graph), node_count_(node_count), callback_(callback) {} + : graph_(std::move(graph)), + node_count_(node_count), + callback_(std::move(callback)) {} bool GraphCallback(int node1, int node2) { if (visited_.find( @@ -233,13 +235,13 @@ void FindCliques(std::function graph, int node_count, } bool stop = false; - Search(graph, callback, initial_candidates.get(), 0, node_count, &actual, - &stop); + Search(std::move(graph), std::move(callback), initial_candidates.get(), 0, + node_count, &actual, &stop); } void CoverArcsByCliques(std::function graph, int node_count, std::function&)> callback) { - FindAndEliminate cache(graph, node_count, callback); + FindAndEliminate cache(std::move(graph), node_count, std::move(callback)); std::unique_ptr initial_candidates(new int[node_count]); std::vector actual; @@ -256,8 +258,8 @@ void CoverArcsByCliques(std::function graph, int node_count, } bool stop = false; - Search(cached_graph, cached_callback, initial_candidates.get(), 0, node_count, - &actual, &stop); + Search(std::move(cached_graph), std::move(cached_callback), + initial_candidates.get(), 0, node_count, &actual, &stop); } } // namespace operations_research diff --git a/ortools/graph/cliques.h b/ortools/graph/cliques.h index 67703b7895..7901566bf2 100644 --- a/ortools/graph/cliques.h +++ b/ortools/graph/cliques.h @@ -24,6 +24,7 @@ #ifndef OR_TOOLS_GRAPH_CLIQUES_H_ #define OR_TOOLS_GRAPH_CLIQUES_H_ +#include #include #include #include diff --git a/ortools/graph/dag_constrained_shortest_path.h b/ortools/graph/dag_constrained_shortest_path.h index e060bddbcd..ddb22ada99 100644 --- a/ortools/graph/dag_constrained_shortest_path.h +++ b/ortools/graph/dag_constrained_shortest_path.h @@ -14,8 +14,6 @@ #ifndef OR_TOOLS_GRAPH_DAG_CONSTRAINED_SHORTEST_PATH_H_ #define OR_TOOLS_GRAPH_DAG_CONSTRAINED_SHORTEST_PATH_H_ -#include - #include #include #include @@ -165,6 +163,7 @@ class ConstrainedShortestPathsOnDagWrapper { std::vector& lengths_from_sources, std::vector>& resources_from_sources, std::vector& incoming_arc_indices_from_sources, + std::vector& incoming_label_indices_from_sources, std::vector& first_label, std::vector& num_labels); // Returns the arc index linking two nodes from each pass forming the best @@ -184,12 +183,9 @@ class ConstrainedShortestPathsOnDagWrapper { // `sources` (if `direction` iS FORWARD) or `destinations` (if `direction` is // BACKWARD) and ends in node represented by `best_label_index`. std::vector ArcPathTo( - int best_label_index, const GraphType& reverse_graph, - absl::Span arc_lengths, - absl::Span lengths_from_sources, + int best_label_index, absl::Span incoming_arc_indices_from_sources, - absl::Span first_label, - absl::Span num_labels) const; + absl::Span incoming_label_indices_from_sources) const; // Returns the list of all the nodes implied by a given `arc_path`. std::vector NodePathImpliedBy(absl::Span arc_path, @@ -257,6 +253,7 @@ class ConstrainedShortestPathsOnDagWrapper { std::vector lengths_from_sources_[2]; std::vector> resources_from_sources_[2]; std::vector incoming_arc_indices_from_sources_[2]; + std::vector incoming_label_indices_from_sources_[2]; std::vector node_first_label_[2]; std::vector node_num_labels_[2]; }; @@ -560,6 +557,8 @@ PathWithLength ConstrainedShortestPathsOnDagWrapper< /*resources_from_sources=*/resources_from_sources_[dir], /*incoming_arc_indices_from_sources=*/ incoming_arc_indices_from_sources_[dir], + /*incoming_label_indices_from_sources=*/ + incoming_label_indices_from_sources_[dir], /*first_label=*/node_first_label_[dir], /*num_labels=*/node_num_labels_[dir]); }); @@ -608,13 +607,10 @@ PathWithLength ConstrainedShortestPathsOnDagWrapper< for (const Direction dir : {FORWARD, BACKWARD}) { for (const ArcIndex sub_arc_index : ArcPathTo( /*best_label_index=*/best_label_pair.label_index[dir], - /*reverse_graph=*/sub_reverse_graph_[dir], - /*arc_lengths=*/sub_arc_lengths[dir], - /*lengths_from_sources=*/lengths_from_sources_[dir], /*incoming_arc_indices_from_sources=*/ incoming_arc_indices_from_sources_[dir], - /*first_label=*/node_first_label_[dir], - /*num_labels=*/node_num_labels_[dir])) { + /*incoming_label_indices_from_sources=*/ + incoming_label_indices_from_sources_[dir])) { const ArcIndex arc_index = sub_full_arc_indices_[dir][sub_arc_index]; if (arc_index == -1) { break; @@ -634,6 +630,7 @@ PathWithLength ConstrainedShortestPathsOnDagWrapper< resources_from_sources_[dir][r].clear(); } incoming_arc_indices_from_sources_[dir].clear(); + incoming_label_indices_from_sources_[dir].clear(); } return {.length = best_label_pair.length, .arc_path = arc_path, @@ -654,6 +651,7 @@ void ConstrainedShortestPathsOnDagWrapper:: std::vector& lengths_from_sources, std::vector>& resources_from_sources, std::vector& incoming_arc_indices_from_sources, + std::vector& incoming_label_indices_from_sources, std::vector& first_label, std::vector& num_labels) { // Initialize source node. const NodeIndex source_node = reverse_graph.num_nodes() - 1; @@ -664,10 +662,12 @@ void ConstrainedShortestPathsOnDagWrapper:: resources_from_sources[r].push_back(0); } incoming_arc_indices_from_sources.push_back(-1); + incoming_label_indices_from_sources.push_back(-1); std::vector lengths_to; std::vector> resources_to(num_resources_); std::vector incoming_arc_indices_to; + std::vector incoming_label_indices_to; std::vector label_indices_to; std::vector resources(num_resources_); for (NodeIndex to = 0; to < source_node; ++to) { @@ -676,6 +676,7 @@ void ConstrainedShortestPathsOnDagWrapper:: resources_to[r].clear(); } incoming_arc_indices_to.clear(); + incoming_label_indices_to.clear(); for (const ArcIndex reverse_arc_index : reverse_graph.OutgoingArcs(to)) { const NodeIndex from = reverse_graph.Head(reverse_arc_index); const double arc_length = arc_lengths[reverse_arc_index]; @@ -703,6 +704,7 @@ void ConstrainedShortestPathsOnDagWrapper:: resources_to[r].push_back(resources[r]); } incoming_arc_indices_to.push_back(reverse_arc_index); + incoming_label_indices_to.push_back(label_index); } } // Sort labels lexicographically with lengths then resources. @@ -753,6 +755,8 @@ void ConstrainedShortestPathsOnDagWrapper:: } incoming_arc_indices_from_sources.push_back( incoming_arc_indices_to[label_i_index]); + incoming_label_indices_from_sources.push_back( + incoming_label_indices_to[label_i_index]); ++num_labels_to; if (lengths_from_sources.size() >= max_num_created_labels) { return; @@ -857,37 +861,19 @@ template #endif std::vector ConstrainedShortestPathsOnDagWrapper::ArcPathTo( - const int best_label_index, const GraphType& reverse_graph, - absl::Span arc_lengths, - absl::Span lengths_from_sources, + const int best_label_index, absl::Span incoming_arc_indices_from_sources, - absl::Span first_label, absl::Span num_labels) const { - if (best_label_index == -1) { - return {}; - } + absl::Span incoming_label_indices_from_sources) const { int current_label_index = best_label_index; std::vector arc_path; - for (int i = 0; i < reverse_graph.num_nodes(); ++i) { - const ArcIndex current_arc_index = - incoming_arc_indices_from_sources[current_label_index]; - if (current_arc_index == -1) { + for (int i = 0; i < graph_->num_nodes(); ++i) { + if (current_label_index == -1) { break; } - arc_path.push_back(current_arc_index); - const NodeIndex sub_node = reverse_graph.Head(current_arc_index); - const double current_length = lengths_from_sources[current_label_index]; - for (int label_index = first_label[sub_node]; - label_index < first_label[sub_node] + num_labels[sub_node]; - ++label_index) { - if (std::abs(lengths_from_sources[label_index] + - arc_lengths[current_arc_index] - current_length) <= - kTolerance) { - current_label_index = label_index; - break; - } - } + arc_path.push_back(incoming_arc_indices_from_sources[current_label_index]); + current_label_index = + incoming_label_indices_from_sources[current_label_index]; } - CHECK_EQ(incoming_arc_indices_from_sources[current_label_index], -1); return arc_path; } diff --git a/ortools/graph/dag_shortest_path.cc b/ortools/graph/dag_shortest_path.cc index 754c5e2506..5b487c98cc 100644 --- a/ortools/graph/dag_shortest_path.cc +++ b/ortools/graph/dag_shortest_path.cc @@ -27,9 +27,9 @@ namespace operations_research { namespace { - using GraphType = util::StaticGraph<>; - using NodeIndex = GraphType::NodeIndex; - using ArcIndex = GraphType::ArcIndex; +using GraphType = util::StaticGraph<>; +using NodeIndex = GraphType::NodeIndex; +using ArcIndex = GraphType::ArcIndex; struct ShortestPathOnDagProblem { GraphType graph; @@ -44,7 +44,7 @@ ShortestPathOnDagProblem ReadProblem( std::vector arc_lengths; arc_lengths.reserve(arcs_with_length.size()); for (const auto& arc : arcs_with_length) { - graph.AddArc(arc.tail, arc.head); + graph.AddArc(arc.from, arc.to); arc_lengths.push_back(arc.length); } std::vector permutation; @@ -58,17 +58,18 @@ ShortestPathOnDagProblem ReadProblem( } } - const absl::StatusOr> topological_order = + absl::StatusOr> topological_order = util::graph::FastTopologicalSort(graph); CHECK_OK(topological_order) << "arcs_with_length form a cycle."; - return ShortestPathOnDagProblem{.graph = graph, - .arc_lengths = arc_lengths, - .original_arc_indices = original_arc_indices, - .topological_order = *topological_order}; + return ShortestPathOnDagProblem{ + .graph = std::move(graph), + .arc_lengths = std::move(arc_lengths), + .original_arc_indices = std::move(original_arc_indices), + .topological_order = std::move(topological_order).value()}; } -void GetOriginalArcPath(const std::vector& original_arc_indices, +void GetOriginalArcPath(absl::Span original_arc_indices, std::vector& arc_path) { if (original_arc_indices.empty()) { return; @@ -98,7 +99,7 @@ PathWithLength ShortestPathsOnDag( GetOriginalArcPath(problem.original_arc_indices, arc_path); return PathWithLength{ .length = shortest_path_on_dag.LengthTo(destination), - .arc_path = arc_path, + .arc_path = std::move(arc_path), .node_path = shortest_path_on_dag.NodePathTo(destination)}; } diff --git a/ortools/graph/dag_shortest_path.h b/ortools/graph/dag_shortest_path.h index 712e5e6321..cd1e2c85f2 100644 --- a/ortools/graph/dag_shortest_path.h +++ b/ortools/graph/dag_shortest_path.h @@ -49,11 +49,11 @@ namespace operations_research { // Basic API. // ----------------------------------------------------------------------------- -// `tail` and `head` should both be in [0, num_nodes) +// `from` and `to` should both be in [0, num_nodes). // If the length is +inf, then the arc should not be used. struct ArcWithLength { - int tail = 0; - int head = 0; + int from = 0; + int to = 0; double length = 0.0; }; @@ -368,13 +368,13 @@ ShortestPathsOnDagWrapper::ShortestPathsOnDagWrapper( CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes"; CHECK_GT(graph_->num_arcs(), 0) << "The graph is empty: it has no arcs"; #ifndef NDEBUG - CHECK_EQ(arc_lengths_->size(), graph_->num_arcs()); - for (const double arc_length : *arc_lengths_) { + CHECK_EQ(arc_lengths_->size(), graph_->num_arcs()); + for (const double arc_length : *arc_lengths_) { CHECK(arc_length != -kInf && !std::isnan(arc_length)) - << absl::StrFormat("length cannot be -inf nor NaN"); - } - CHECK_OK(TopologicalOrderIsValid(*graph_, topological_order_)) - << "Invalid topological order"; + << absl::StrFormat("length cannot be -inf nor NaN"); + } + CHECK_OK(TopologicalOrderIsValid(*graph_, topological_order_)) + << "Invalid topological order"; #endif // Memory allocation is done here and only once in order to avoid reallocation @@ -507,15 +507,15 @@ KShortestPathsOnDagWrapper::KShortestPathsOnDagWrapper( reverse_graph_ = GraphType(graph_->num_nodes(), num_arcs); for (ArcIndex arc_index = 0; arc_index < num_arcs; ++arc_index) { reverse_graph_.AddArc(graph->Head(arc_index), graph->Tail(arc_index)); - } + } std::vector permutation; reverse_graph_.Build(&permutation); arc_indices_.resize(permutation.size()); if (!permutation.empty()) { for (int i = 0; i < permutation.size(); ++i) { arc_indices_[permutation[i]] = i; + } } - } // Memory allocation is done here and only once in order to avoid reallocation // at each call of `RunKShortestPathOnDag()` for better performance. @@ -526,10 +526,10 @@ KShortestPathsOnDagWrapper::KShortestPathsOnDagWrapper( lengths_from_sources_[k].resize(graph_->num_nodes(), kInf); incoming_shortest_paths_arc_[k].resize(graph_->num_nodes(), -1); incoming_shortest_paths_index_[k].resize(graph_->num_nodes(), -1); - } + } is_source_.resize(graph_->num_nodes(), false); reached_nodes_.reserve(graph_->num_nodes()); - } +} template #if __cplusplus >= 202002L diff --git a/ortools/graph/ebert_graph.h b/ortools/graph/ebert_graph.h index d31345f89f..e3541a995f 100644 --- a/ortools/graph/ebert_graph.h +++ b/ortools/graph/ebert_graph.h @@ -177,7 +177,6 @@ #include "absl/strings/str_cat.h" #include "ortools/base/logging.h" -#include "ortools/base/types.h" #include "ortools/util/permutation.h" #include "ortools/util/zvector.h" diff --git a/ortools/graph/hamiltonian_path.h b/ortools/graph/hamiltonian_path.h index f3d4f29e78..e5c77d6ef5 100644 --- a/ortools/graph/hamiltonian_path.h +++ b/ortools/graph/hamiltonian_path.h @@ -77,21 +77,19 @@ // Keywords: Traveling Salesman, Hamiltonian Path, Dynamic Programming, // Held, Karp. -#include #include #include #include #include #include -#include #include #include #include #include +#include "absl/types/span.h" #include "ortools/base/logging.h" -#include "ortools/base/types.h" #include "ortools/util/bitset.h" #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/vector_or_function.h" @@ -562,7 +560,7 @@ class HamiltonianPathSolver { std::vector ComputePath(CostType cost, NodeSet set, int end); // Returns true if the path covers all nodes, and its cost is equal to cost. - bool PathIsValid(const std::vector& path, CostType cost); + bool PathIsValid(absl::Span path, CostType cost); // Cost function used to build Hamiltonian paths. MatrixOrFunction cost_; @@ -765,7 +763,7 @@ std::vector HamiltonianPathSolver::ComputePath( template bool HamiltonianPathSolver::PathIsValid( - const std::vector& path, CostType cost) { + absl::Span path, CostType cost) { NodeSet coverage(0); for (int node : path) { coverage = coverage.AddElement(node); diff --git a/ortools/graph/k_shortest_paths.h b/ortools/graph/k_shortest_paths.h index 1c017647c2..89011bc989 100644 --- a/ortools/graph/k_shortest_paths.h +++ b/ortools/graph/k_shortest_paths.h @@ -22,6 +22,10 @@ // | Yen | No | No | (Un)directed | Yes | // // +// Loopless path: path not going through the same node more than once. Also +// called simple path. +// +// // Design choices // ============== // @@ -60,8 +64,11 @@ #include "absl/algorithm/container.h" #include "absl/base/optimization.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "ortools/base/logging.h" #include "ortools/graph/bounded_dijkstra.h" #include "ortools/graph/ebert_graph.h" #include "ortools/graph/shortest_paths.h" @@ -302,16 +309,25 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, std::priority_queue> variant_path_queue; - for (; k > 0; --k) { + // One path has already been generated (the shortest one). Only k-1 more + // paths need to be generated. + for (; k > 1; --k) { + VLOG(1) << "k: " << k; + // Generate variant paths from the last shortest path. const absl::Span last_shortest_path = absl::MakeSpan(paths.paths.back()); // TODO(user): think about adding parallelism for this loop to improve - // running times. + // running times. This is not a priority as long as the algorithm is + // faster than the one in `shortest_paths.h`. for (int spur_node_position = 0; spur_node_position < last_shortest_path.size() - 1; ++spur_node_position) { + VLOG(4) << " spur_node_position: " << spur_node_position; + VLOG(4) << " last_shortest_path: " + << absl::StrJoin(last_shortest_path, " - ") << " (" + << last_shortest_path.size() << ")"; if (spur_node_position > 0) { DCHECK_NE(last_shortest_path[spur_node_position], source); } @@ -342,18 +358,34 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, // of the path in the search for the next shortest path. More // precisely, in that case, avoid the arc from the spur node to the // next node in the path. - if (previous_path.size() < spur_node_position) continue; + if (previous_path.size() <= root_path.length()) continue; const bool has_same_prefix_as_root_path = std::equal( root_path.begin(), root_path.end(), previous_path.begin(), previous_path.begin() + root_path.length()); - if (has_same_prefix_as_root_path) { - const ArcIndex after_spur_node_arc = - internal::FindArcIndex(graph, previous_path[spur_node_position], - previous_path[spur_node_position + 1]); - arc_lengths_for_detour[after_spur_node_arc] = - internal::kDisconnectedDistance; + if (!has_same_prefix_as_root_path) continue; + + const ArcIndex after_spur_node_arc = + internal::FindArcIndex(graph, previous_path[spur_node_position], + previous_path[spur_node_position + 1]); + VLOG(4) << " after_spur_node_arc: " << graph.Tail(after_spur_node_arc) + << " - " << graph.Head(after_spur_node_arc) << " (" << source + << " - " << destination << ")"; + arc_lengths_for_detour[after_spur_node_arc] = + internal::kDisconnectedDistance; + } + // Ensure that the path computed from the new weights is loopless by + // "removing" the nodes of the root path from the graph (by tweaking the + // weights, again). The previous operation only disallows the arc from the + // spur node (at the end of the root path) to the next node in the + // previously found paths. + for (int node_position = 0; node_position < spur_node_position; + ++node_position) { + for (const int arc : graph.OutgoingArcs(root_path[node_position])) { + arc_lengths_for_detour[arc] = internal::kDisconnectedDistance; } } + VLOG(3) << " arc_lengths_for_detour: " + << absl::StrJoin(arc_lengths_for_detour, " - "); // Generate a new candidate path from the spur node to the destination // without using the forbidden arcs. @@ -366,11 +398,16 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, // Node unreachable after some arcs are forbidden. continue; } + VLOG(2) << " detour_path: " + << absl::StrJoin(std::get<0>(detour_path), " - ") << " (" + << std::get<0>(detour_path).size() + << "): " << std::get<1>(detour_path); std::vector spur_path = std::move(std::get<0>(detour_path)); if (ABSL_PREDICT_FALSE(spur_path.empty())) continue; #ifndef NDEBUG CHECK_EQ(root_path.back(), spur_path.front()); + CHECK_EQ(spur_node, spur_path.front()); if (spur_path.size() == 1) { CHECK_EQ(spur_path.front(), destination); @@ -386,6 +423,22 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, }); CHECK(root_path_leads_to_spur_path); } + + // Ensure the forbidden arc is not present in any previously generated + // path. + for (absl::Span previous_path : paths.paths) { + if (previous_path.size() <= spur_node_position + 1) continue; + const bool has_same_prefix_as_root_path = std::equal( + root_path.begin(), root_path.end(), previous_path.begin(), + previous_path.begin() + root_path.length()); + if (has_same_prefix_as_root_path) { + CHECK_NE(spur_path[1], previous_path[spur_node_position + 1]) + << "Forbidden arc " << previous_path[spur_node_position] + << " - " << previous_path[spur_node_position + 1] + << " is present in the spur path " + << absl::StrJoin(spur_path, " - "); + } + } #endif // !defined(NDEBUG) // Assemble the new path. @@ -397,6 +450,15 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, DCHECK_EQ(new_path.front(), source); DCHECK_EQ(new_path.back(), destination); +#ifndef NDEBUG + // Ensure the assembled path is loopless, i.e. no node is repeated. + { + absl::flat_hash_set visited_nodes(new_path.begin(), + new_path.end()); + CHECK_EQ(visited_nodes.size(), new_path.size()); + } +#endif // !defined(NDEBUG) + // Ensure the new path is not one of the previously known ones. This // operation is required, as there are two sources of paths from the // source to the destination: @@ -419,6 +481,13 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, const PathDistance path_length = internal::ComputePathLength(graph, arc_lengths, new_path); + VLOG(5) << " New potential path generated: " + << absl::StrJoin(new_path, " - ") << " (" << new_path.size() + << ")"; + VLOG(5) << " Root: " << absl::StrJoin(root_path, " - ") << " (" + << root_path.size() << ")"; + VLOG(5) << " Spur: " << absl::StrJoin(spur_path, " - ") << " (" + << spur_path.size() << ")"; variant_path_queue.emplace( /*priority=*/path_length, /*path=*/std::move(new_path)); } @@ -431,6 +500,9 @@ KShortestPaths YenKShortestPaths(const GraphType& graph, const internal::PathWithPriority& next_shortest_path = variant_path_queue.top(); + VLOG(5) << "> New path generated: " + << absl::StrJoin(next_shortest_path.path(), " - ") << " (" + << next_shortest_path.path().size() << ")"; paths.paths.emplace_back(next_shortest_path.path()); paths.distances.push_back(next_shortest_path.priority()); variant_path_queue.pop(); diff --git a/ortools/graph/k_shortest_paths_test.cc b/ortools/graph/k_shortest_paths_test.cc index f1b4808fbf..f36e743ff4 100644 --- a/ortools/graph/k_shortest_paths_test.cc +++ b/ortools/graph/k_shortest_paths_test.cc @@ -13,17 +13,32 @@ #include "ortools/graph/k_shortest_paths.h" +#include +#include +#include +#include +#include +#include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/random/distributions.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "benchmark/benchmark.h" #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/graph/graph.h" +#include "ortools/graph/io.h" #include "ortools/graph/shortest_paths.h" namespace operations_research { namespace { using testing::ElementsAre; +using testing::IsEmpty; +using testing::UnorderedElementsAreArray; using util::StaticGraph; TEST(KShortestPathsYenDeathTest, EmptyGraph) { @@ -164,8 +179,261 @@ TEST(KShortestPathsYenTest, HasTwoPathsWithLongerPath) { EXPECT_THAT(paths.distances, ElementsAre(4, 30)); } -// TODO(user): randomized tests? Check validity with exhaustive -// exploration/IP formulation? +TEST(KShortestPathsYenTest, ReturnsTheRightNumberOfPaths) { + StaticGraph<> graph; + graph.AddArc(0, 1); + graph.AddArc(0, 2); + graph.AddArc(0, 3); + graph.AddArc(1, 2); + graph.AddArc(3, 2); + + (void)graph.Build(); + std::vector lengths{1, 1, 1, 1, 1}; + + const KShortestPaths paths = YenKShortestPaths(graph, lengths, /*source=*/0, + /*destination=*/2, /*k=*/2); + EXPECT_THAT(paths.paths, + ElementsAre(std::vector{0, 2}, std::vector{0, 1, 2})); + EXPECT_THAT(paths.distances, ElementsAre(1, 2)); +} + +namespace internal { + +template +Graph GenerateUniformGraph(URBG&& urbg, const NodeIndexType num_nodes, + const ArcIndexType num_edges) { + // TODO(user): make these utility functions so they can be reused. + const auto pick_one_node = [&urbg, num_nodes]() -> NodeIndexType { + const NodeIndexType node = absl::Uniform(urbg, 0, num_nodes); + CHECK_GE(node, 0); + CHECK_LT(node, num_nodes); + return node; + }; + const auto pick_two_distinct_nodes = + [&pick_one_node]() -> std::pair { + const NodeIndexType src = pick_one_node(); + NodeIndexType dst; + do { + dst = pick_one_node(); + } while (src == dst); + CHECK_NE(src, dst); + return {src, dst}; + }; + + // Determine the maximum number of arcs in the graph. + const ArcIndexType max_num_arcs = IsDirected + ? (num_nodes * (num_nodes - 1)) + : (num_nodes * (num_nodes - 1)) / 2; + + // Build a random graph (and not multigraph) with `num_arcs` or `max_num_arcs` + // arcs, whichever is lower. The set is useful to ensure the graph does not + // contain the same arc more than once (the result would be a multigraph). + // TODO(user): this is an awful way to generate a complete graph. + StaticGraph<> graph; + graph.AddNode(num_nodes - 1); + + std::set> arcs; + for (ArcIndexType i = 0; i < std::min(num_edges, max_num_arcs); ++i) { + NodeIndexType src, dst; + std::tie(src, dst) = pick_two_distinct_nodes(); + if (arcs.contains({src, dst})) continue; + if (IsDirected && arcs.contains({dst, src})) continue; + + arcs.insert({src, dst}); + graph.AddArc(src, dst); + + if (IsDirected) { + arcs.insert({dst, src}); + graph.AddArc(dst, src); + } + } + + // No need to keep the permutation when building, as there are no associated + // attributes such as lengths in this function. + graph.Build(nullptr); + + return graph; +} + +} // namespace internal + +// Generates a random (un)directed graph with `num_nodes` nodes and up to +// `num_arcs` arcs / `num_edges` edges, following a uniform probability +// distribution. `urbg` is a source of randomness, such as an `std::mt19937` +// object. +// +// If the number of arcs that is requested is too large compared to the number +// of nodes (i.e. greater than the maximum number of arcs for a directed or +// undirected graph with the specified number of node), this function returns a +// complete graph. +template , + typename URBG> +Graph GenerateUniformGraph(URBG&& urbg, const NodeIndexType num_nodes, + const ArcIndexType num_edges) { + return internal::GenerateUniformGraph( + urbg, num_nodes, num_edges); +} +template , + typename URBG> +Graph GenerateUniformDirectedGraph(URBG&& urbg, const NodeIndexType num_nodes, + const ArcIndexType num_arcs) { + return internal::GenerateUniformGraph( + urbg, num_nodes, num_arcs); +} + +TEST(KShortestPathsYenTest, RandomTest) { + std::mt19937 random(12345); + constexpr int kNumGraphs = 10; + constexpr int kNumQueriesPerGraph = 10; + constexpr int kNumNodes = 10; + constexpr int kNumArcs = 3 * kNumNodes; + // TODO(user): when supported, also test negative weights. + constexpr int kMinLength = 0; + constexpr int kMaxLength = 1'000; + + const auto pick_one_node = [&random]() -> int { + int node = absl::Uniform(random, 0, kNumNodes); + CHECK_GE(node, 0); + CHECK_LT(node, kNumNodes); + return node; + }; + const auto pick_two_distinct_nodes = + [&pick_one_node]() -> std::pair { + int src = pick_one_node(); + int dst; + do { + dst = pick_one_node(); + } while (src == dst); + CHECK_NE(src, dst); + return {src, dst}; + }; + + const auto format_path = [](std::string* out, const std::vector& path) { + absl::StrAppend(out, absl::StrJoin(path, " - ")); + }; + + for (int graph_iter = 0; graph_iter < kNumGraphs; ++graph_iter) { + (void)graph_iter; + + StaticGraph<> graph = + GenerateUniformDirectedGraph(random, kNumNodes, kNumArcs); + std::vector lengths; + for (int i = 0; i < graph.num_arcs(); ++i) { + lengths.push_back(absl::Uniform(random, kMinLength, kMaxLength)); + } + + // Run random queries, with one source and one destination per query. + for (int q = 0; q < kNumQueriesPerGraph; ++q) { + int src, dst; + std::tie(src, dst) = pick_two_distinct_nodes(); + + // Determine the set of simple paths between these nodes by brute force. + // (Simple in the sense that the path does not contain loops.) + // + // Basic idea: graph traversal from the source node until the destination + // node, not stopping until the whole graph is searched. + // + // This loop always finishes, even if the two nodes are not connected: + // at some point, there will be no tentative path left. In case of a loop + // in the graph, the tested paths will not contain loops. + std::set> brute_force_paths; + std::vector> tentative_paths{{src}}; + while (!tentative_paths.empty()) { + std::vector partial_path = tentative_paths.front(); + tentative_paths.erase(tentative_paths.begin()); + + const int last_node = partial_path.back(); + for (const int next_arc : graph.OutgoingArcs(last_node)) { + const int next_node = graph.Head(next_arc); + ASSERT_NE(last_node, next_node); + + if (absl::c_find(partial_path, next_node) != partial_path.end()) { + // To avoid loops (both in the path and at run time), ensure that + // the path does not go through `next_node`. Otherwise, there would + // be a loop in path, going at least twice through `next_node`. + continue; + } + + std::vector new_path = partial_path; + new_path.push_back(next_node); + + if (next_node == dst) { + brute_force_paths.emplace(std::move(new_path)); + } else { + tentative_paths.emplace_back(std::move(new_path)); + } + } + } + ASSERT_THAT(tentative_paths, IsEmpty()); + + // Maybe the procedure fails to find paths because none exist, which is + // possible with random graphs (i.e. the graph is disconnected, with `src` + // and `dst` in distinct connected components). + if (brute_force_paths.empty()) continue; + + // Use the algorithm-under-test to generate as many paths as possible. + const KShortestPaths yen_paths = + YenKShortestPaths(graph, lengths, src, dst, + /*k=*/brute_force_paths.size()); + + // The two sets of paths must correspond. + EXPECT_THAT(brute_force_paths, UnorderedElementsAreArray(yen_paths.paths)) + << "[" << util::GraphToString(graph, util::PRINT_GRAPH_ARCS) + << "] Brute-force paths: [" + << absl::StrJoin(brute_force_paths, ", ", format_path) + << "] Yen paths: [" + << absl::StrJoin(yen_paths.paths, ", ", format_path) << "]"; + } + } +} + +void BM_Yen(benchmark::State& state) { + const int num_nodes = state.range(0); + // Use half the maximum number of arcs, so that the graph is a bit sparse. + const int num_arcs = num_nodes * (num_nodes - 1) / 4; + // TODO(user): when supported, also benchmark negative weights + // (separately?). + constexpr int kMinLength = 0; + constexpr int kMaxLength = 1'000; + + std::mt19937 random(12345); + const auto pick_one_node = [&random, num_nodes]() -> int { + int node = absl::Uniform(random, 0, num_nodes); + CHECK_GE(node, 0); + CHECK_LT(node, num_nodes); + return node; + }; + const auto pick_two_distinct_nodes = + [&pick_one_node]() -> std::pair { + int src = pick_one_node(); + int dst; + do { + dst = pick_one_node(); + } while (src == dst); + CHECK_NE(src, dst); + return {src, dst}; + }; + + StaticGraph<> graph = + GenerateUniformDirectedGraph(random, num_nodes, num_arcs); + std::vector lengths; + for (int i = 0; i < graph.num_arcs(); ++i) { + lengths.push_back(absl::Uniform(random, kMinLength, kMaxLength)); + } + + for (auto unused : state) { + int src, dst; + std::tie(src, dst) = pick_two_distinct_nodes(); + YenKShortestPaths(graph, lengths, src, dst, /*k=*/10); + } +} + +BENCHMARK(BM_Yen)->Range(10, 1'000); } // namespace } // namespace operations_research diff --git a/ortools/graph/linear_assignment.h b/ortools/graph/linear_assignment.h index 989160b091..dba3e28b73 100644 --- a/ortools/graph/linear_assignment.h +++ b/ortools/graph/linear_assignment.h @@ -205,9 +205,9 @@ #include #include "absl/flags/declare.h" +#include "absl/flags/flag.h" #include "absl/strings/str_format.h" #include "ortools/base/logging.h" -#include "ortools/base/types.h" #include "ortools/graph/ebert_graph.h" #include "ortools/util/permutation.h" #include "ortools/util/zvector.h" diff --git a/ortools/graph/max_flow.h b/ortools/graph/max_flow.h index 750cdba1ae..a5d961e539 100644 --- a/ortools/graph/max_flow.h +++ b/ortools/graph/max_flow.h @@ -123,14 +123,13 @@ #ifndef OR_TOOLS_GRAPH_MAX_FLOW_H_ #define OR_TOOLS_GRAPH_MAX_FLOW_H_ -#include #include #include +#include #include #include "absl/strings/string_view.h" #include "ortools/base/logging.h" -#include "ortools/base/types.h" #include "ortools/graph/ebert_graph.h" #include "ortools/graph/flow_problem.pb.h" #include "ortools/graph/graph.h" diff --git a/ortools/graph/min_cost_flow.h b/ortools/graph/min_cost_flow.h index 3664bda5fb..270d2c6157 100644 --- a/ortools/graph/min_cost_flow.h +++ b/ortools/graph/min_cost_flow.h @@ -168,15 +168,12 @@ #ifndef OR_TOOLS_GRAPH_MIN_COST_FLOW_H_ #define OR_TOOLS_GRAPH_MIN_COST_FLOW_H_ -#include #include #include #include #include #include "absl/strings/string_view.h" -#include "ortools/base/logging.h" -#include "ortools/base/types.h" #include "ortools/graph/ebert_graph.h" #include "ortools/graph/graph.h" #include "ortools/util/stats.h" diff --git a/ortools/graph/minimum_spanning_tree.h b/ortools/graph/minimum_spanning_tree.h index 1aad1fb4ec..ed909e5d6e 100644 --- a/ortools/graph/minimum_spanning_tree.h +++ b/ortools/graph/minimum_spanning_tree.h @@ -14,15 +14,13 @@ #ifndef OR_TOOLS_GRAPH_MINIMUM_SPANNING_TREE_H_ #define OR_TOOLS_GRAPH_MINIMUM_SPANNING_TREE_H_ -#include +#include #include #include "absl/types/span.h" #include "ortools/base/adjustable_priority_queue-inl.h" #include "ortools/base/adjustable_priority_queue.h" -#include "ortools/base/types.h" #include "ortools/graph/connected_components.h" -#include "ortools/util/vector_or_function.h" namespace operations_research { diff --git a/ortools/graph/multi_dijkstra.h b/ortools/graph/multi_dijkstra.h index 16c9e44f3d..9009d686fb 100644 --- a/ortools/graph/multi_dijkstra.h +++ b/ortools/graph/multi_dijkstra.h @@ -52,7 +52,6 @@ #include "absl/container/flat_hash_map.h" #include "ortools/base/map_util.h" -#include "ortools/base/types.h" namespace operations_research { diff --git a/ortools/graph/one_tree_lower_bound.h b/ortools/graph/one_tree_lower_bound.h index dce6895072..71e60eccc1 100644 --- a/ortools/graph/one_tree_lower_bound.h +++ b/ortools/graph/one_tree_lower_bound.h @@ -121,8 +121,6 @@ #ifndef OR_TOOLS_GRAPH_ONE_TREE_LOWER_BOUND_H_ #define OR_TOOLS_GRAPH_ONE_TREE_LOWER_BOUND_H_ -#include - #include #include #include @@ -131,8 +129,9 @@ #include #include "absl/types/span.h" -#include "ortools/base/types.h" +#include "ortools/base/logging.h" #include "ortools/graph/christofides.h" +#include "ortools/graph/graph.h" #include "ortools/graph/minimum_spanning_tree.h" namespace operations_research { diff --git a/ortools/graph/perfect_matching.cc b/ortools/graph/perfect_matching.cc index 726d2ad00d..d8a1ecd5a8 100644 --- a/ortools/graph/perfect_matching.cc +++ b/ortools/graph/perfect_matching.cc @@ -22,7 +22,11 @@ #include #include -#include "absl/memory/memory.h" +#include "absl/base/log_severity.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "ortools/base/logging.h" #include "ortools/util/saturated_arithmetic.h" namespace operations_research { diff --git a/ortools/graph/perfect_matching.h b/ortools/graph/perfect_matching.h index c06e876673..2f1b504641 100644 --- a/ortools/graph/perfect_matching.h +++ b/ortools/graph/perfect_matching.h @@ -28,21 +28,16 @@ #include #include -#include #include #include #include #include "absl/base/attributes.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "ortools/base/adjustable_priority_queue-inl.h" #include "ortools/base/adjustable_priority_queue.h" #include "ortools/base/int_type.h" #include "ortools/base/logging.h" -#include "ortools/base/macros.h" #include "ortools/base/strong_vector.h" -#include "ortools/base/types.h" namespace operations_research { diff --git a/ortools/graph/rooted_tree.h b/ortools/graph/rooted_tree.h new file mode 100644 index 0000000000..c0f68d0179 --- /dev/null +++ b/ortools/graph/rooted_tree.h @@ -0,0 +1,802 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Find paths and compute path distances between nodes on a rooted tree. +// +// A tree is a connected undirected graph with no cycles. A rooted tree is a +// directed graph derived from a tree, where a node is designated as the root, +// and then all edges are directed towards the root. +// +// This file provides the class RootedTree, which stores a rooted tree on dense +// integer nodes a single vector, and a function RootedTreeFromGraph(), which +// converts the adjacency list of a an undirected tree to a RootedTree. + +#ifndef OR_TOOLS_GRAPH_ROOTED_TREE_H_ +#define OR_TOOLS_GRAPH_ROOTED_TREE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "ortools/base/logging.h" +#include "ortools/base/status_builder.h" +#include "ortools/base/status_macros.h" + +namespace operations_research { + +// A tree is an undirected graph with no cycles, n nodes, and n-1 undirected +// edges. Consequently, a tree is connected. Given a tree on the nodes [0..n), +// a RootedTree picks any node to be the root, and then converts all edges into +// (directed) arcs pointing at the root. Each node has one outgoing edge, so we +// can store the adjacency list of this directed view of the graph as a single +// vector of integers with length equal to the number of nodes. At the root +// index, we store RootedTree::kNullParent=-1, and at every other index, we +// store the next node towards the root (the parent in the tree). +// +// This class is templated on the NodeType, which must be an integer type, e.g., +// int or int32_t (signed and unsigned types both work). +// +// The following operations are supported: +// * Path from node to root in O(path length to root) +// * Lowest Common Ancestor (LCA) of two nodes in O(path length between nodes) +// * Depth of all nodes in O(num nodes) +// * Topological sort in O(num nodes) +// * Path between any two nodes in O(path length between nodes) +// +// Users can provide a vector of arc lengths (indexed by source) to get: +// * Distance from node to root in O(path length to root) +// * Distance from all nodes to root in O(num nodes) +// * Distance between any two nodes in O(path length between nodes) +// +// Operations on rooted trees are generally more efficient than on adjacency +// list representations because the entire tree is in one contiguous allocation. +// There is also an asymptotic advantage on path finding problems. +// +// Two methods for finding the LCA are provided. The first requires the depth of +// every node ahead of time. The second requires a workspace of n bools, all +// starting at false. These values are modified and restored to false when the +// LCA computation finishes. In both cases, if the depths/workspace allocation +// is an O(n) precomputation, then the LCA runs in O(path length). +// Non-asymptotically, the depth method requires more precomputation, but the +// LCA is faster and does not require the user to manage mutable state (i.e., +// may be better for multi-threaded computation). +// +// An operation that is missing is bulk LCA, see +// https://en.wikipedia.org/wiki/Tarjan%27s_off-line_lowest_common_ancestors_algorithm. +template +class RootedTree { + public: + static constexpr NodeType kNullParent = static_cast(-1); + // Like the constructor but checks that the tree is valid. Uses O(num nodes) + // temporary space with O(log(n)) allocations. + // + // If the input is cyclic, an InvalidArgument error will be returned with + // "cycle" as a substring. Further, if error_cycle is not null, it will be + // cleared and then set to contain the cycle. We will not modify error cycle + // or return an error message containing the string cycle if there is no + // cycle. The cycle output will always begin with its smallest element. + static absl::StatusOr Create( + NodeType root, std::vector parents, + std::vector* error_cycle = nullptr, + std::vector* topological_order = nullptr); + + // Like Create(), but data is not validated (UB on bad input). + explicit RootedTree(NodeType root, std::vector parents) + : root_(root), parents_(std::move(parents)) {} + + // The root node of this rooted tree. + NodeType root() const { return root_; } + + // The number of nodes in this rooted tree. + NodeType num_nodes() const { return static_cast(parents_.size()); } + + // A vector that holds the parent of each non root node, and kNullParent at + // the root. + absl::Span parents() const { return parents_; } + + // Returns the path from `node` to `root()` as a vector of nodes starting with + // `node`. + std::vector PathToRoot(NodeType node) const; + + // Returns the path from `root()` to `node` as a vector of nodes starting with + // `node`. + std::vector PathFromRoot(NodeType node) const; + + // Returns the sum of the arc lengths of the arcs in the path from `start` to + // `root()`. + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + double DistanceToRoot(NodeType start, + absl::Span arc_lengths) const; + + // Returns the path from `start` to `root()` as a vector of nodes starting + // with `start`, and the sum of the arc lengths of the arcs in the path. + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + std::pair> DistanceAndPathToRoot( + NodeType start, absl::Span arc_lengths) const; + + // Returns the path from `start` to `end` as a vector of nodes starting with + // `start` and ending with `end`. + // + // `lca` is the lowest common ancestor of `start` and `end`. This can be + // computed using LowestCommonAncestorByDepth() or + // LowestCommonAncestorByDepth(), both defined on this class. + // + // Runs in time O(path length). + std::vector Path(NodeType start, NodeType end, NodeType lca) const; + + // Returns the sum of the arc lengths of the arcs in the path from `start` to + // `end`. + // + // `lca` is the lowest common ancestor of `start` and `end`. This can be + // computed using LowestCommonAncestorByDepth() or + // LowestCommonAncestorByDepth(), both defined on this class. + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + // + // Runs in time O(number of edges connecting start to end). + double Distance(NodeType start, NodeType end, NodeType lca, + absl::Span arc_lengths) const; + + // Returns the path from `start` to `end` as a vector of nodes starting with + // `start`, and the sum of the arc lengths of the arcs in the path. + // + // `lca` is the lowest common ancestor of `start` and `end`. This can be + // computed using LowestCommonAncestorByDepth() or + // LowestCommonAncestorByDepth(), both defined on this class. + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + // + // Runs in time O(number of edges connecting start to end). + std::pair> DistanceAndPath( + NodeType start, NodeType end, NodeType lca, + absl::Span arc_lengths) const; + + // Given a path of nodes, returns the sum of the length of the arcs connecting + // them. + // + // `path` must be a list of nodes in the tree where + // path[i+1] == parents()[path[i]]. + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + double DistanceOfPath(absl::Span path, + absl::Span arc_lengths) const; + + // Returns a topological ordering of the nodes where the root is first and + // every other node appears after its parent. + std::vector TopologicalSort() const; + + // Returns the distance of every node from `root()`, if the edge leaving node + // i has length costs[i]. + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + // + // If you already have a topological order, prefer + // `AllDistances(absl::Span costs, + // absl::Span& topological_order)`. + template + std::vector AllDistancesToRoot(absl::Span arc_lengths) const; + + // Returns the distance from every node to root(). + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + // + // `topological_order` must have size equal to `num_nodes()` and start with + // `root()`, or else we CHECK fail. It can be any topological over nodes when + // the orientation of the arcs from rooting the tree is reversed. + template + std::vector AllDistancesToRoot( + absl::Span arc_lengths, + absl::Span topological_order) const; + + // Returns the distance (arcs to move over) from every node to the root. + // + // If you already have a topological order, prefer + // AllDepths(absl::Span). + std::vector AllDepths() const { + return AllDepths(TopologicalSort()); + } + + // Returns the distance (arcs to move over) from every node to the root. + // + // `topological_order` must have size equal to `num_nodes()` and start with + // `root()`, or else we CHECK fail. It can be any topological over nodes when + // the orientation of the arcs from rooting the tree is reversed. + std::vector AllDepths( + absl::Span topological_order) const; + + // Returns the lowest common ancestor of n1 and n2. + // + // `depths` must have size equal to `num_nodes()`, or else we CHECK fail. + // Values must be the distance of each node to the root in arcs (see + // AllDepths()). + NodeType LowestCommonAncestorByDepth(NodeType n1, NodeType n2, + absl::Span depths) const; + + // Returns the lowest common ancestor of n1 and n2. + // + // `visited_workspace` must be a vector with num_nodes() size, or else we + // CHECK fail. All values of `visited_workspace` should be false. It will be + // modified and then restored to its starting state. + NodeType LowestCommonAncestorBySearch( + NodeType n1, NodeType n2, std::vector& visited_workspace) const; + + // Modifies the tree in place to change the root. Runs in + // O(path length from root() to new_root). + void Evert(NodeType new_root); + + private: + static_assert(std::is_integral_v, + "NodeType must be an integral type."); + static_assert(sizeof(NodeType) <= sizeof(std::size_t), + "NodeType cannot be larger than size_t, because NodeType is " + "used to index into std::vector."); + + // Returns the number of nodes appended. + NodeType AppendToPath(NodeType start, NodeType end, + std::vector& path) const; + + // Returns the number of nodes appended. + NodeType ReverseAppendToPath(NodeType start, NodeType end, + std::vector& path) const; + + // Like AllDistancestoRoot(), but the input arc_lengths is mutated to hold + // the output, instead of just returning the output as a new vector. + template + void AllDistancesToRootInPlace( + absl::Span topological_order, + absl::Span arc_lengths_in_distances_out) const; + + // Returns the cost of the path from start to end. + // + // end must be either equal to an or ancestor of start in the tree (otherwise + // DCHECK/UB). + // + // `arc_lengths[i]` is the length of the arc from node i to `parents()[i]`. + // `arc_lengths` must have size equal to `num_nodes()` or else we CHECK fail. + // The value of `arc_lengths[root()]` is unused. + double DistanceOfUpwardPath(NodeType start, NodeType end, + absl::Span arc_lengths) const; + + int root_; + std::vector parents_; // kNullParent=-1 if root +}; + +//////////////////////////////////////////////////////////////////////////////// +// Graph API +//////////////////////////////////////////////////////////////////////////////// + +// Converts an adjacency list representation of an undirected tree into a rooted +// tree. +// +// Graph must meet the API defined in ortools/graph/graph.h, e.g., StaticGraph +// or ListGraph. Note that these are directed graph APIs, so they must have both +// forward and backward arcs for each edge in the tree. +// +// Graph must be a tree when viewed as an undirected graph. +// +// If topological_order is not null, it is set to a vector with one entry for +// each node giving a topological ordering over the nodes of the graph, with the +// root first. +// +// If depths is not null, it is set to a vector with one entry for each node, +// giving the depth in the tree of that node (the root has depth zero). +template +absl::StatusOr> RootedTreeFromGraph( + typename Graph::NodeType root, const Graph& graph, + std::vector* topological_order = nullptr, + std::vector* depths = nullptr); + +//////////////////////////////////////////////////////////////////////////////// +// Template implementations +//////////////////////////////////////////////////////////////////////////////// + +namespace internal { + +template +bool IsValidParent(const NodeType node, const NodeType num_tree_nodes) { + return node == RootedTree::kNullParent || + (node >= NodeType{0} && node < num_tree_nodes); +} + +template +absl::Status IsValidNode(const NodeType node, const NodeType num_tree_nodes) { + if (node < NodeType{0} || node >= num_tree_nodes) { + return util::InvalidArgumentErrorBuilder() + << "nodes must be in [0.." << num_tree_nodes + << "), found bad node: " << node; + } + return absl::OkStatus(); +} + +template +std::vector ExtractCycle(absl::Span parents, + const NodeType node_in_cycle) { + std::vector cycle; + cycle.push_back(node_in_cycle); + for (NodeType i = parents[node_in_cycle]; i != node_in_cycle; + i = parents[i]) { + CHECK_NE(i, RootedTree::kNullParent) + << "node_in_cycle: " << node_in_cycle + << " not in cycle, reached the root"; + cycle.push_back(i); + CHECK_LE(cycle.size(), parents.size()) + << "node_in_cycle: " << node_in_cycle + << " not in cycle, just (transitively) leads to a cycle"; + } + absl::c_rotate(cycle, absl::c_min_element(cycle)); + cycle.push_back(cycle[0]); + return cycle; +} + +template +std::string CycleErrorMessage(absl::Span cycle) { + CHECK_GT(cycle.size(), 0); + const NodeType start = cycle[0]; + std::string cycle_string; + if (cycle.size() > 10) { + cycle_string = absl::StrCat( + absl::StrJoin(absl::MakeConstSpan(cycle).subspan(0, 8), ", "), + ", ..., ", start); + } else { + cycle_string = absl::StrJoin(cycle, ", "); + } + return absl::StrCat("found cycle of size: ", cycle.size(), + " with nodes: ", cycle_string); +} + +// Every element of parents must be in {kNullParent} union [0..parents.size()), +// otherwise UB. +template +std::vector CheckForCycle(absl::Span parents, + std::vector* topological_order) { + const NodeType n = static_cast(parents.size()); + if (topological_order != nullptr) { + topological_order->clear(); + topological_order->reserve(n); + } + std::vector visited(n); + std::vector dfs_stack; + for (NodeType i = 0; i < n; ++i) { + if (visited[i]) { + continue; + } + NodeType next = i; + while (next != RootedTree::kNullParent && !visited[next]) { + dfs_stack.push_back(next); + if (dfs_stack.size() > n) { + if (topological_order != nullptr) { + topological_order->clear(); + } + return ExtractCycle(parents, next); + } + next = parents[next]; + DCHECK(IsValidParent(next, n)) << "next: " << next << ", n: " << n; + } + absl::c_reverse(dfs_stack); + for (const NodeType j : dfs_stack) { + visited[j] = true; + if (topological_order != nullptr) { + topological_order->push_back(j); + } + } + dfs_stack.clear(); + } + return {}; +} + +} // namespace internal + +template +NodeType RootedTree::AppendToPath(const NodeType start, + const NodeType end, + std::vector& path) const { + NodeType num_new = 0; + for (NodeType node = start; node != end; node = parents_[node]) { + DCHECK_NE(node, kNullParent); + path.push_back(node); + num_new++; + } + path.push_back(end); + return num_new + 1; +} + +template +NodeType RootedTree::ReverseAppendToPath( + NodeType start, NodeType end, std::vector& path) const { + NodeType result = AppendToPath(start, end, path); + std::reverse(path.end() - result, path.end()); + return result; +} + +template +double RootedTree::DistanceOfUpwardPath( + const NodeType start, const NodeType end, + absl::Span arc_lengths) const { + CHECK_EQ(num_nodes(), arc_lengths.size()); + double distance = 0.0; + for (NodeType next = start; next != end; next = parents_[next]) { + DCHECK_NE(next, root_); + distance += arc_lengths[next]; + } + return distance; +} + +template +absl::StatusOr> RootedTree::Create( + const NodeType root, std::vector parents, + std::vector* error_cycle, + std::vector* topological_order) { + const NodeType num_nodes = static_cast(parents.size()); + RETURN_IF_ERROR(internal::IsValidNode(root, num_nodes)) << "invalid root"; + if (parents[root] != kNullParent) { + return util::InvalidArgumentErrorBuilder() + << "root should have no parent (-1), but found parent of: " + << parents[root]; + } + for (NodeType i = 0; i < num_nodes; ++i) { + if (i == root) { + continue; + } + RETURN_IF_ERROR(internal::IsValidNode(parents[i], num_nodes)) + << "invalid value for parent of node: " << i; + } + std::vector cycle = + internal::CheckForCycle(absl::MakeConstSpan(parents), topological_order); + if (!cycle.empty()) { + std::string error_message = + internal::CycleErrorMessage(absl::MakeConstSpan(cycle)); + if (error_cycle != nullptr) { + *error_cycle = std::move(cycle); + } + return absl::InvalidArgumentError(std::move(error_message)); + } + return RootedTree(root, std::move(parents)); +} + +template +std::vector RootedTree::PathToRoot( + const NodeType node) const { + std::vector path; + for (NodeType next = node; next != root_; next = parents_[next]) { + path.push_back(next); + } + path.push_back(root_); + return path; +} + +template +std::vector RootedTree::PathFromRoot( + const NodeType node) const { + std::vector result = PathToRoot(node); + absl::c_reverse(result); + return result; +} + +template +std::vector RootedTree::TopologicalSort() const { + std::vector result; + const std::vector cycle = + internal::CheckForCycle(absl::MakeConstSpan(parents_), &result); + CHECK(cycle.empty()) << internal::CycleErrorMessage( + absl::MakeConstSpan(cycle)); + return result; +} + +template +double RootedTree::DistanceToRoot( + const NodeType start, absl::Span arc_lengths) const { + return DistanceOfUpwardPath(start, root_, arc_lengths); +} + +template +std::pair> +RootedTree::DistanceAndPathToRoot( + const NodeType start, absl::Span arc_lengths) const { + CHECK_EQ(num_nodes(), arc_lengths.size()); + double distance = 0.0; + std::vector path; + for (NodeType next = start; next != root_; next = parents_[next]) { + path.push_back(next); + distance += arc_lengths[next]; + } + path.push_back(root_); + return {distance, path}; +} + +template +std::vector RootedTree::Path(const NodeType start, + const NodeType end, + const NodeType lca) const { + std::vector result; + if (start == end) { + result.push_back(start); + return result; + } + if (start == lca) { + ReverseAppendToPath(end, lca, result); + return result; + } + if (end == lca) { + AppendToPath(start, lca, result); + return result; + } + AppendToPath(start, lca, result); + result.pop_back(); // Don't include the LCA twice + ReverseAppendToPath(end, lca, result); + return result; +} + +template +double RootedTree::Distance( + const NodeType start, const NodeType end, const NodeType lca, + absl::Span arc_lengths) const { + return DistanceOfUpwardPath(start, lca, arc_lengths) + + DistanceOfUpwardPath(end, lca, arc_lengths); +} + +template +std::pair> RootedTree::DistanceAndPath( + const NodeType start, const NodeType end, const NodeType lca, + absl::Span arc_lengths) const { + std::vector path = Path(start, end, lca); + const double dist = DistanceOfPath(path, arc_lengths); + return {dist, std::move(path)}; +} + +template +double RootedTree::DistanceOfPath( + absl::Span path, + absl::Span arc_lengths) const { + CHECK_EQ(num_nodes(), arc_lengths.size()); + double distance = 0.0; + for (int i = 0; i + 1 < path.size(); ++i) { + if (parents_[path[i]] != path[i + 1]) { + distance += arc_lengths[path[i]]; + } else if (parents_[path[i + 1]] == path[i]) { + distance += arc_lengths[path[i + 1]]; + } else { + LOG(FATAL) << "bad edge in path from " << path[i] << " to " + << path[i + 1]; + } + } + return distance; +} + +template +NodeType RootedTree::LowestCommonAncestorByDepth( + const NodeType n1, const NodeType n2, + absl::Span depths) const { + CHECK_EQ(num_nodes(), depths.size()); + const NodeType n = num_nodes(); + CHECK_OK(internal::IsValidNode(n1, n)); + CHECK_OK(internal::IsValidNode(n2, n)); + CHECK_EQ(depths.size(), n); + if (n1 == root_ || n2 == root_) { + return root_; + } + if (n1 == n2) { + return n1; + } + NodeType next1 = n1; + NodeType next2 = n2; + while (depths[next1] > depths[next2]) { + next1 = parents_[next1]; + } + while (depths[next2] > depths[next1]) { + next2 = parents_[next2]; + } + while (next1 != next2) { + next1 = parents_[next1]; + next2 = parents_[next2]; + } + return next1; +} + +template +NodeType RootedTree::LowestCommonAncestorBySearch( + const NodeType n1, const NodeType n2, + std::vector& visited_workspace) const { + const NodeType n = num_nodes(); + CHECK_OK(internal::IsValidNode(n1, n)); + CHECK_OK(internal::IsValidNode(n2, n)); + CHECK_EQ(visited_workspace.size(), n); + if (n1 == root_ || n2 == root_) { + return root_; + } + if (n1 == n2) { + return n1; + } + NodeType next1 = n1; + NodeType next2 = n2; + visited_workspace[n1] = true; + visited_workspace[n2] = true; + NodeType lca = kNullParent; + NodeType lca_distance = + 1; // used only for cleanup purposes, can over estimate + while (true) { + lca_distance++; + if (next1 != root_) { + next1 = parents_[next1]; + if (visited_workspace[next1]) { + lca = next1; + break; + } + } + if (next2 != root_) { + visited_workspace[next1] = true; + next2 = parents_[next2]; + if (visited_workspace[next2]) { + lca = next2; + break; + } + visited_workspace[next2] = true; + } + } + CHECK_OK(internal::IsValidNode(lca, n)); + auto cleanup = [this, lca_distance, &visited_workspace](NodeType next) { + for (NodeType i = 0; i < lca_distance && next != kNullParent; ++i) { + visited_workspace[next] = false; + next = parents_[next]; + } + }; + cleanup(n1); + cleanup(n2); + return lca; +} + +template +void RootedTree::Evert(const NodeType new_root) { + NodeType previous_node = kNullParent; + for (NodeType node = new_root; node != kNullParent;) { + NodeType next_node = parents_[node]; + parents_[node] = previous_node; + previous_node = node; + node = next_node; + } + root_ = new_root; +} + +template +template +void RootedTree::AllDistancesToRootInPlace( + absl::Span topological_order, + absl::Span arc_lengths_in_distances_out) const { + CHECK_EQ(num_nodes(), arc_lengths_in_distances_out.size()); + CHECK_EQ(num_nodes(), topological_order.size()); + if (!topological_order.empty()) { + CHECK_EQ(topological_order[0], root_); + } + for (const NodeType node : topological_order) { + if (parents_[node] == kNullParent) { + arc_lengths_in_distances_out[node] = 0; + } else { + arc_lengths_in_distances_out[node] += + arc_lengths_in_distances_out[parents_[node]]; + } + } +} + +template +std::vector RootedTree::AllDepths( + absl::Span topological_order) const { + std::vector arc_length_in_distance_out(num_nodes(), 1); + AllDistancesToRootInPlace(topological_order, + absl::MakeSpan(arc_length_in_distance_out)); + return arc_length_in_distance_out; +} + +template +template +std::vector RootedTree::AllDistancesToRoot( + absl::Span arc_lengths) const { + return AllDistancesToRoot(arc_lengths, TopologicalSort()); +} + +template +template +std::vector RootedTree::AllDistancesToRoot( + absl::Span arc_lengths, + absl::Span topological_order) const { + std::vector distances(arc_lengths.begin(), arc_lengths.end()); + AllDistancesToRootInPlace(topological_order, absl::MakeSpan(distances)); + return distances; +} + +template +absl::StatusOr> RootedTreeFromGraph( + const typename Graph::NodeIndex root, const Graph& graph, + std::vector* const topological_order, + std::vector* const depths) { + using NodeIndex = typename Graph::NodeIndex; + const NodeIndex num_nodes = graph.num_nodes(); + RETURN_IF_ERROR(internal::IsValidNode(root, num_nodes)) + << "invalid root node"; + if (topological_order != nullptr) { + topological_order->clear(); + topological_order->reserve(num_nodes); + topological_order->push_back(root); + } + if (depths != nullptr) { + depths->clear(); + depths->resize(num_nodes, 0); + } + std::vector tree(num_nodes, RootedTree::kNullParent); + auto visited = [&tree, root](const NodeIndex node) { + if (node == root) { + return true; + } + return tree[node] != RootedTree::kNullParent; + }; + std::vector must_search_children = {root}; + while (!must_search_children.empty()) { + NodeIndex next = must_search_children.back(); + must_search_children.pop_back(); + for (const NodeIndex neighbor : graph[next]) { + if (visited(neighbor)) { + if (tree[next] == neighbor) { + continue; + } else { + // NOTE: this will also catch nodes with self loops. + return util::InvalidArgumentErrorBuilder() + << "graph has cycle containing arc from " << next << " to " + << neighbor; + } + } + tree[neighbor] = next; + if (topological_order != nullptr) { + topological_order->push_back(neighbor); + } + if (depths != nullptr) { + (*depths)[neighbor] = (*depths)[next] + 1; + } + must_search_children.push_back(neighbor); + } + } + for (NodeIndex i = 0; i < num_nodes; ++i) { + if (!visited(i)) { + return util::InvalidArgumentErrorBuilder() + << "graph is not connected, no path to " << i; + } + } + return RootedTree(root, tree); +} + +} // namespace operations_research + +#endif // OR_TOOLS_GRAPH_ROOTED_TREE_H_ diff --git a/ortools/graph/rooted_tree_test.cc b/ortools/graph/rooted_tree_test.cc new file mode 100644 index 0000000000..d2160c2457 --- /dev/null +++ b/ortools/graph/rooted_tree_test.cc @@ -0,0 +1,653 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/graph/rooted_tree.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/graph/graph.h" + +namespace operations_research { +namespace { + +using ::testing::AnyOf; +using ::testing::DoubleEq; +using ::testing::Each; +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; +using ::testing::IsFalse; +using ::testing::Pair; +using ::testing::SizeIs; +using ::testing::status::StatusIs; + +//////////////////////////////////////////////////////////////////////////////// +// RootedTree Tests +//////////////////////////////////////////////////////////////////////////////// + +template +class RootedTreeTest : public testing::Test { + public: + static constexpr T kNullParent = RootedTree::kNullParent; +}; + +TYPED_TEST_SUITE_P(RootedTreeTest); + +TYPED_TEST_P(RootedTreeTest, CreateFailsRootOutOfBoundsInvalidArgument) { + using Node = TypeParam; + const Node root = 5; + std::vector parents = {0, this->kNullParent}; + EXPECT_THAT(RootedTree::Create(root, parents), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("root"))); +} + +TYPED_TEST_P(RootedTreeTest, CreateFailsRootHasParentInvalidArgument) { + using Node = TypeParam; + const Node root = 0; + std::vector parents = {1, 0}; + EXPECT_THAT(RootedTree::Create(root, parents), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("root"))); +} + +TYPED_TEST_P(RootedTreeTest, CreateFailsExtraRootInvalidArgument) { + using Node = TypeParam; + const Node root = 1; + std::vector parents = {this->kNullParent, this->kNullParent}; + EXPECT_THAT( + RootedTree::Create(root, parents), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parent"))); +} + +TYPED_TEST_P(RootedTreeTest, CreateFailsBadParentInvalidArgument) { + using Node = TypeParam; + const Node root = 1; + std::vector parents = {3, this->kNullParent}; + EXPECT_THAT( + RootedTree::Create(root, parents), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("parent"))); +} + +TYPED_TEST_P(RootedTreeTest, CreateFailsIsolatedCycleInvalidArgument) { + using Node = TypeParam; + const Node root = 3; + std::vector parents = {1, 2, 0, this->kNullParent, 3}; + EXPECT_THAT(RootedTree::Create(root, parents), + StatusIs(absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("cycle"), HasSubstr("0, 1, 2")))); + std::vector cycle; + EXPECT_THAT(RootedTree::Create(root, parents, &cycle), + StatusIs(absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("cycle"), HasSubstr("0, 1, 2")))); + EXPECT_THAT(cycle, ElementsAre(0, 1, 2, 0)); +} + +TYPED_TEST_P(RootedTreeTest, CreateFailsPathLeadsToCycleInvalidArgument) { + using Node = TypeParam; + const Node root = 3; + std::vector parents = {1, 2, 0, this->kNullParent, 0}; + EXPECT_THAT(RootedTree::Create(root, parents), + StatusIs(absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("cycle"), HasSubstr("0, 1, 2")))); + std::vector cycle; + EXPECT_THAT(RootedTree::Create(root, parents, &cycle), + StatusIs(absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("cycle"), HasSubstr("0, 1, 2")))); + EXPECT_THAT(cycle, ElementsAre(0, 1, 2, 0)); +} + +TYPED_TEST_P(RootedTreeTest, CreatePathFailsLongCycleErrorIsTruncated) { + using Node = TypeParam; + const Node root = 50; + std::vector parents(51); + parents[root] = this->kNullParent; + for (Node i = 0; i < Node{50}; ++i) { + parents[i] = (i + 1) % Node{50}; + } + EXPECT_THAT(RootedTree::Create(root, parents), + StatusIs(absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("cycle"), + HasSubstr("0, 1, 2, 3, 4, 5, 6, 7, ..., 0")))); + std::vector cycle; + EXPECT_THAT(RootedTree::Create(root, parents, &cycle), + StatusIs(absl::StatusCode::kInvalidArgument, + AllOf(HasSubstr("cycle"), + HasSubstr("0, 1, 2, 3, 4, 5, 6, 7, ..., 0")))); + std::vector expected_cycle(50); + absl::c_iota(expected_cycle, 0); + expected_cycle.push_back(0); + EXPECT_THAT(cycle, ElementsAreArray(expected_cycle)); +} + +TYPED_TEST_P(RootedTreeTest, PathToRoot) { + using Node = TypeParam; + // 1 + // / | + // 0 3 + // | + // 2 + const Node root = 1; + std::vector parents = {1, this->kNullParent, 3, 1}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + EXPECT_THAT(tree.PathToRoot(0), ElementsAre(0, 1)); + EXPECT_THAT(tree.PathToRoot(1), ElementsAre(1)); + EXPECT_THAT(tree.PathToRoot(2), ElementsAre(2, 3, 1)); + EXPECT_THAT(tree.PathToRoot(3), ElementsAre(3, 1)); +} + +TYPED_TEST_P(RootedTreeTest, DistanceToRoot) { + using Node = TypeParam; + // 1 + // / | + // 0 3 + // | + // 2 + const int root = 1; + std::vector parents = {1, this->kNullParent, 3, 1}; + std::vector arc_lengths = {1, 0, 10, 100}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + EXPECT_DOUBLE_EQ(tree.DistanceToRoot(0, arc_lengths), 1); + EXPECT_DOUBLE_EQ(tree.DistanceToRoot(1, arc_lengths), 0); + EXPECT_DOUBLE_EQ(tree.DistanceToRoot(2, arc_lengths), 110); + EXPECT_DOUBLE_EQ(tree.DistanceToRoot(3, arc_lengths), 100); +} + +TYPED_TEST_P(RootedTreeTest, DistanceAndPathToRoot) { + using Node = TypeParam; + // 1 + // / | + // 0 3 + // | + // 2 + const Node root = 1; + std::vector parents = {1, this->kNullParent, 3, 1}; + std::vector arc_lengths = {1, 0, 10, 100}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + EXPECT_THAT(tree.DistanceAndPathToRoot(0, arc_lengths), + Pair(DoubleEq(1.0), ElementsAre(0, 1))); + EXPECT_THAT(tree.DistanceAndPathToRoot(1, arc_lengths), + Pair(DoubleEq(0.0), ElementsAre(1))); + EXPECT_THAT(tree.DistanceAndPathToRoot(2, arc_lengths), + Pair(DoubleEq(110.0), ElementsAre(2, 3, 1))); + EXPECT_THAT(tree.DistanceAndPathToRoot(3, arc_lengths), + Pair(DoubleEq(100.0), ElementsAre(3, 1))); +} + +TYPED_TEST_P(RootedTreeTest, TopologicalSort) { + using Node = TypeParam; + // 1 + // / | + // 0 3 + // | + // 2 + const Node root = 1; + std::vector parents = {1, this->kNullParent, 3, 1}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + + EXPECT_THAT(tree.TopologicalSort(), + AnyOf(ElementsAre(1, 0, 3, 2), ElementsAre(1, 3, 2, 0), + ElementsAre(1, 3, 0, 2))); +} + +TYPED_TEST_P(RootedTreeTest, AllDistancesToRoot) { + using Node = TypeParam; + // 1 + // / | + // 0 3 + // | + // 2 + const int root = 1; + std::vector parents = {1, this->kNullParent, 3, 1}; + const std::vector arc_lengths = {1, 0, 10, 100}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + EXPECT_THAT(tree.AllDistancesToRoot(arc_lengths), + ElementsAre(1.0, 0.0, 110.0, 100.0)); +} + +TYPED_TEST_P(RootedTreeTest, AllDepths) { + using Node = TypeParam; + // 1 + // / | + // 0 3 + // | + // 2 + const Node root = 1; + std::vector parents = {1, this->kNullParent, 3, 1}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + EXPECT_THAT(tree.AllDepths(), ElementsAre(1, 0, 2, 1)); +} + +TYPED_TEST_P(RootedTreeTest, LCAByDepth) { + using Node = TypeParam; + // 4 + // / + // 1 + // / | + // 0 3 + // | + // 2 + const int root = 4; + std::vector parents = {1, 4, 3, 1, this->kNullParent}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + const std::vector depths = {2, 1, 3, 2, 0}; + ASSERT_THAT(tree.AllDepths(), ElementsAreArray(depths)); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(0, 0, depths), 0); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(0, 1, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(0, 2, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(0, 3, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(0, 4, depths), 4); + + EXPECT_EQ(tree.LowestCommonAncestorByDepth(1, 0, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(1, 1, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(1, 2, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(1, 3, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(1, 4, depths), 4); + + EXPECT_EQ(tree.LowestCommonAncestorByDepth(2, 0, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(2, 1, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(2, 2, depths), 2); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(2, 3, depths), 3); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(2, 4, depths), 4); + + EXPECT_EQ(tree.LowestCommonAncestorByDepth(3, 0, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(3, 1, depths), 1); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(3, 2, depths), 3); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(3, 3, depths), 3); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(3, 4, depths), 4); + + EXPECT_EQ(tree.LowestCommonAncestorByDepth(4, 0, depths), 4); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(4, 1, depths), 4); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(4, 2, depths), 4); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(4, 3, depths), 4); + EXPECT_EQ(tree.LowestCommonAncestorByDepth(4, 4, depths), 4); +} + +TYPED_TEST_P(RootedTreeTest, LCAByBySearch) { + using Node = TypeParam; + // 4 + // / + // 1 + // / | + // 0 3 + // | + // 2 + const Node root = 4; + std::vector parents = {1, 4, 3, 1, this->kNullParent}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + std::vector ws(5, false); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(0, 0, ws), 0); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(0, 1, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(0, 2, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(0, 3, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(0, 4, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + + EXPECT_EQ(tree.LowestCommonAncestorBySearch(1, 0, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(1, 1, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(1, 2, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(1, 3, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(1, 4, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + + EXPECT_EQ(tree.LowestCommonAncestorBySearch(2, 0, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(2, 1, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(2, 2, ws), 2); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(2, 3, ws), 3); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(2, 4, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + + EXPECT_EQ(tree.LowestCommonAncestorBySearch(3, 0, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(3, 1, ws), 1); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(3, 2, ws), 3); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(3, 3, ws), 3); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(3, 4, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + + EXPECT_EQ(tree.LowestCommonAncestorBySearch(4, 0, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(4, 1, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(4, 2, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(4, 3, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); + EXPECT_EQ(tree.LowestCommonAncestorBySearch(4, 4, ws), 4); + ASSERT_THAT(ws, AllOf(SizeIs(5), Each(IsFalse()))); +} + +TYPED_TEST_P(RootedTreeTest, Path) { + using Node = TypeParam; + // 4 + // / + // 1 + // / | + // 0 3 + // | + // 2 + const Node root = 4; + std::vector parents = {1, 4, 3, 1, this->kNullParent}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + std::vector depths = {2, 1, 3, 2, 0}; + ASSERT_THAT(tree.AllDepths(), ElementsAreArray(depths)); + auto path = [&tree, &depths](Node start, Node end) { + const Node lca = tree.LowestCommonAncestorByDepth(start, end, depths); + return tree.Path(start, end, lca); + }; + EXPECT_THAT(path(0, 0), ElementsAre(0)); + EXPECT_THAT(path(0, 1), ElementsAre(0, 1)); + EXPECT_THAT(path(0, 2), ElementsAre(0, 1, 3, 2)); + EXPECT_THAT(path(0, 3), ElementsAre(0, 1, 3)); + EXPECT_THAT(path(0, 4), ElementsAre(0, 1, 4)); + + EXPECT_THAT(path(1, 0), ElementsAre(1, 0)); + EXPECT_THAT(path(1, 1), ElementsAre(1)); + EXPECT_THAT(path(1, 2), ElementsAre(1, 3, 2)); + EXPECT_THAT(path(1, 3), ElementsAre(1, 3)); + EXPECT_THAT(path(1, 4), ElementsAre(1, 4)); + + EXPECT_THAT(path(2, 0), ElementsAre(2, 3, 1, 0)); + EXPECT_THAT(path(2, 1), ElementsAre(2, 3, 1)); + EXPECT_THAT(path(2, 2), ElementsAre(2)); + EXPECT_THAT(path(2, 3), ElementsAre(2, 3)); + EXPECT_THAT(path(2, 4), ElementsAre(2, 3, 1, 4)); + + EXPECT_THAT(path(3, 0), ElementsAre(3, 1, 0)); + EXPECT_THAT(path(3, 1), ElementsAre(3, 1)); + EXPECT_THAT(path(3, 2), ElementsAre(3, 2)); + EXPECT_THAT(path(3, 3), ElementsAre(3)); + EXPECT_THAT(path(3, 4), ElementsAre(3, 1, 4)); + + EXPECT_THAT(path(4, 0), ElementsAre(4, 1, 0)); + EXPECT_THAT(path(4, 1), ElementsAre(4, 1)); + EXPECT_THAT(path(4, 2), ElementsAre(4, 1, 3, 2)); + EXPECT_THAT(path(4, 3), ElementsAre(4, 1, 3)); + EXPECT_THAT(path(4, 4), ElementsAre(4)); +} + +TYPED_TEST_P(RootedTreeTest, Distance) { + using Node = TypeParam; + // 4 + // / + // 1 + // / | + // 0 3 + // | + // 2 + const int root = 4; + std::vector parents = {1, 4, 3, 1, this->kNullParent}; + std::vector arc_lengths = {1.0, 10.0, 100.0, 1000.0, 0.0}; + ASSERT_OK_AND_ASSIGN(const auto tree, + RootedTree::Create(root, parents)); + std::vector depths = {2, 1, 3, 2, 0}; + ASSERT_THAT(tree.AllDepths(), ElementsAreArray(depths)); + auto dist = [&tree, &depths, &arc_lengths](Node start, Node end) { + const Node lca = tree.LowestCommonAncestorByDepth(start, end, depths); + return tree.Distance(start, end, lca, arc_lengths); + }; + EXPECT_DOUBLE_EQ(dist(0, 0), 0.0); + EXPECT_DOUBLE_EQ(dist(0, 1), 1.0); + EXPECT_DOUBLE_EQ(dist(0, 2), 1101.0); + EXPECT_DOUBLE_EQ(dist(0, 3), 1001.0); + EXPECT_DOUBLE_EQ(dist(0, 4), 11.0); + + EXPECT_DOUBLE_EQ(dist(1, 0), 1.0); + EXPECT_DOUBLE_EQ(dist(1, 1), 0.0); + EXPECT_DOUBLE_EQ(dist(1, 2), 1100.0); + EXPECT_DOUBLE_EQ(dist(1, 3), 1000.0); + EXPECT_DOUBLE_EQ(dist(1, 4), 10.0); + + EXPECT_DOUBLE_EQ(dist(2, 0), 1101.0); + EXPECT_DOUBLE_EQ(dist(2, 1), 1100.0); + EXPECT_DOUBLE_EQ(dist(2, 2), 0.0); + EXPECT_DOUBLE_EQ(dist(2, 3), 100.0); + EXPECT_DOUBLE_EQ(dist(2, 4), 1110.0); + + EXPECT_DOUBLE_EQ(dist(3, 0), 1001.0); + EXPECT_DOUBLE_EQ(dist(3, 1), 1000.0); + EXPECT_DOUBLE_EQ(dist(3, 2), 100.0); + EXPECT_DOUBLE_EQ(dist(3, 3), 0.0); + EXPECT_DOUBLE_EQ(dist(3, 4), 1010.0); + + EXPECT_DOUBLE_EQ(dist(4, 0), 11.0); + EXPECT_DOUBLE_EQ(dist(4, 1), 10.0); + EXPECT_DOUBLE_EQ(dist(4, 2), 1110.0); + EXPECT_DOUBLE_EQ(dist(4, 3), 1010.0); + EXPECT_DOUBLE_EQ(dist(4, 4), 0.0); +} + +TYPED_TEST_P(RootedTreeTest, EvertChangeRoot) { + using Node = TypeParam; + // Starting graph, with root 2: + // 0 -> 1 -> 2 + // | | | + // 3 4 5 + // + // Evert: change the root to 0 + // + // 0 <- 1 <- 2 + // | | | + // 3 4 5 + const Node root = 2; + const std::vector parents = {1, 2, this->kNullParent, 0, 1, 2}; + ASSERT_OK_AND_ASSIGN(auto tree, RootedTree::Create(root, parents)); + tree.Evert(0); + EXPECT_EQ(tree.root(), 0); + EXPECT_THAT(tree.parents(), ElementsAre(this->kNullParent, 0, 1, 0, 1, 2)); +} + +TYPED_TEST_P(RootedTreeTest, EvertSameRoot) { + using Node = TypeParam; + const Node root = 1; + const std::vector parents = {1, this->kNullParent, 1}; + ASSERT_OK_AND_ASSIGN(auto tree, RootedTree::Create(root, parents)); + tree.Evert(1); + EXPECT_EQ(tree.root(), 1); + EXPECT_THAT(tree.parents(), ElementsAre(1, this->kNullParent, 1)); +} + +TYPED_TEST_P(RootedTreeTest, RootedTreeFromGraphSuccessNoExtraOutputs) { + using Node = TypeParam; + // 4 + // / + // 1 + // / | + // 0 3 + // | + // 2 + util::ListGraph graph; + graph.AddNode(4); + for (auto [n1, n2] : + std::vector>{{0, 1}, {1, 4}, {1, 3}, {3, 2}}) { + graph.AddArc(n1, n2); + graph.AddArc(n2, n1); + } + const Node root = 4; + std::vector* topo = nullptr; + std::vector* depth = nullptr; + ASSERT_OK_AND_ASSIGN(const RootedTree tree, + RootedTreeFromGraph(root, graph, topo, depth)); + EXPECT_EQ(tree.root(), 4); + EXPECT_THAT(tree.parents(), ElementsAre(1, 4, 3, 1, this->kNullParent)); + EXPECT_EQ(topo, nullptr); + EXPECT_EQ(depth, nullptr); +} + +TYPED_TEST_P(RootedTreeTest, RootedTreeFromGraphSuccessAllExtraOutputs) { + using Node = TypeParam; + // 4 + // / + // 1 + // / | + // 0 3 + // | + // 2 + util::ListGraph graph; + graph.AddNode(4); + for (auto [n1, n2] : + std::vector>{{0, 1}, {1, 4}, {1, 3}, {3, 2}}) { + graph.AddArc(n1, n2); + graph.AddArc(n2, n1); + } + const Node root = 4; + std::vector topo; + std::vector depth; + ASSERT_OK_AND_ASSIGN(const RootedTree tree, + RootedTreeFromGraph(root, graph, &topo, &depth)); + EXPECT_EQ(tree.root(), 4); + EXPECT_THAT(tree.parents(), ElementsAre(1, 4, 3, 1, this->kNullParent)); + EXPECT_THAT(topo, + AnyOf(ElementsAre(4, 1, 0, 3, 2), ElementsAre(4, 1, 3, 0, 2), + ElementsAre(4, 1, 3, 2, 0))); + EXPECT_THAT(depth, ElementsAre(2, 1, 3, 2, 0)); +} + +TYPED_TEST_P(RootedTreeTest, RootedTreeFromGraphBadRootInvalidArgument) { + using Node = TypeParam; + util::ListGraph graph; + graph.AddNode(2); + graph.AddArc(0, 1); + graph.AddArc(1, 0); + const Node root = 4; + EXPECT_THAT( + RootedTreeFromGraph(root, graph, nullptr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("invalid root"))); +} + +TYPED_TEST_P(RootedTreeTest, RootedTreeFromGraphSelfCycleInvalidArgument) { + using Node = TypeParam; + util::ListGraph graph; + graph.AddNode(2); + graph.AddArc(0, 1); + graph.AddArc(1, 0); + graph.AddArc(1, 1); + const Node root = 0; + EXPECT_THAT(RootedTreeFromGraph(root, graph, nullptr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("cycle"))); +} + +TYPED_TEST_P(RootedTreeTest, RootedTreeFromGraphHasCycleInvalidArgument) { + using Node = TypeParam; + util::ListGraph graph; + graph.AddNode(3); + graph.AddArc(0, 1); + graph.AddArc(1, 0); + graph.AddArc(1, 2); + graph.AddArc(2, 1); + graph.AddArc(2, 0); + graph.AddArc(0, 2); + const Node root = 0; + EXPECT_THAT(RootedTreeFromGraph(root, graph, nullptr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("cycle"))); +} + +TYPED_TEST_P(RootedTreeTest, RootedTreeFromGraphNotConnectedInvalidArgument) { + using Node = TypeParam; + util::ListGraph graph; + graph.AddNode(4); + graph.AddArc(0, 1); + graph.AddArc(1, 0); + graph.AddArc(2, 3); + graph.AddArc(3, 2); + const Node root = 0; + EXPECT_THAT( + RootedTreeFromGraph(root, graph, nullptr, nullptr), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("not connected"))); +} + +REGISTER_TYPED_TEST_SUITE_P( + RootedTreeTest, CreateFailsRootOutOfBoundsInvalidArgument, + CreateFailsRootHasParentInvalidArgument, + CreateFailsExtraRootInvalidArgument, CreateFailsBadParentInvalidArgument, + CreateFailsIsolatedCycleInvalidArgument, + CreateFailsPathLeadsToCycleInvalidArgument, + CreatePathFailsLongCycleErrorIsTruncated, PathToRoot, DistanceToRoot, + DistanceAndPathToRoot, TopologicalSort, AllDistancesToRoot, AllDepths, + LCAByDepth, LCAByBySearch, Path, Distance, EvertChangeRoot, EvertSameRoot, + RootedTreeFromGraphSuccessNoExtraOutputs, + RootedTreeFromGraphSuccessAllExtraOutputs, + RootedTreeFromGraphBadRootInvalidArgument, + RootedTreeFromGraphSelfCycleInvalidArgument, + RootedTreeFromGraphHasCycleInvalidArgument, + RootedTreeFromGraphNotConnectedInvalidArgument); + +using NodeTypes = + ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(AllRootedTreeTests, RootedTreeTest, NodeTypes); + +//////////////////////////////////////////////////////////////////////////////// +// Benchmarks +//////////////////////////////////////////////////////////////////////////////// + +std::vector RandomTreeRootedZero(int num_nodes) { + absl::BitGen bit_gen; + std::vector nodes(num_nodes); + absl::c_iota(nodes, 0); + std::shuffle(nodes.begin() + 1, nodes.end(), bit_gen); + std::vector result(num_nodes); + result[0] = -1; + for (int i = 1; i < num_nodes; ++i) { + int target = absl::Uniform(bit_gen, 0, i); + result[i] = target; + } + return result; +} + +void BM_RootedTreeShortestPath(benchmark::State& state) { + const int num_nodes = state.range(0); + std::vector random_tree_data = RandomTreeRootedZero(num_nodes); + for (auto s : state) { + const RootedTree tree = + RootedTree::Create(0, random_tree_data).value(); + std::vector path = tree.PathToRoot(num_nodes - 1); + CHECK_GE(path.size(), 2); + } +} + +BENCHMARK(BM_RootedTreeShortestPath)->Arg(100)->Arg(10'000)->Arg(1'000'000); + +} // namespace +} // namespace operations_research diff --git a/ortools/graph/samples/BUILD.bazel b/ortools/graph/samples/BUILD.bazel index 915ea2b68b..e943771d66 100644 --- a/ortools/graph/samples/BUILD.bazel +++ b/ortools/graph/samples/BUILD.bazel @@ -37,6 +37,16 @@ code_sample_cc(name = "dag_shortest_path_sequential") code_sample_cc(name = "dag_simple_shortest_path") +code_sample_cc(name = "dag_multiple_shortest_paths_one_to_all") + +code_sample_cc(name = "dag_multiple_shortest_paths_sequential") + +code_sample_cc(name = "dag_simple_multiple_shortest_paths") + +code_sample_cc(name = "dag_constrained_shortest_path_sequential") + +code_sample_cc(name = "dag_simple_constrained_shortest_path") + code_sample_cc(name = "dijkstra_all_pairs_shortest_paths") code_sample_cc(name = "dijkstra_directed") @@ -47,6 +57,10 @@ code_sample_cc(name = "dijkstra_sequential") code_sample_cc(name = "dijkstra_undirected") +code_sample_cc(name = "root_a_tree") + +code_sample_cc(name = "rooted_tree_paths") + code_sample_java(name = "SimpleMaxFlowProgram") code_sample_cc_py(name = "simple_max_flow_program") diff --git a/ortools/graph/samples/code_samples.bzl b/ortools/graph/samples/code_samples.bzl index f7b3700973..cda07eb817 100644 --- a/ortools/graph/samples/code_samples.bzl +++ b/ortools/graph/samples/code_samples.bzl @@ -14,6 +14,7 @@ """Helper macro to compile and test code samples.""" load("@pip_deps//:requirements.bzl", "requirement") +load("@rules_python//python:defs.bzl", "py_binary", "py_test") def code_sample_cc(name): native.cc_binary( @@ -26,11 +27,13 @@ def code_sample_cc(name): "//ortools/graph:assignment", "//ortools/graph:bounded_dijkstra", "//ortools/graph:bfs", + "//ortools/graph:dag_constrained_shortest_path", "//ortools/graph:dag_shortest_path", "//ortools/graph:ebert_graph", "//ortools/graph:linear_assignment", "//ortools/graph:max_flow", "//ortools/graph:min_cost_flow", + "//ortools/graph:rooted_tree", "@com_google_absl//absl/random", ], ) @@ -47,17 +50,19 @@ def code_sample_cc(name): "//ortools/graph:assignment", "//ortools/graph:bounded_dijkstra", "//ortools/graph:bfs", + "//ortools/graph:dag_constrained_shortest_path", "//ortools/graph:dag_shortest_path", "//ortools/graph:ebert_graph", "//ortools/graph:linear_assignment", "//ortools/graph:max_flow", "//ortools/graph:min_cost_flow", + "//ortools/graph:rooted_tree", "@com_google_absl//absl/random", ], ) def code_sample_py(name): - native.py_binary( + py_binary( name = name + "_py3", srcs = [name + ".py"], main = name + ".py", @@ -72,7 +77,7 @@ def code_sample_py(name): srcs_version = "PY3", ) - native.py_test( + py_test( name = name + "_py_test", size = "small", srcs = [name + ".py"], diff --git a/ortools/graph/samples/dag_constrained_shortest_path_sequential.cc b/ortools/graph/samples/dag_constrained_shortest_path_sequential.cc new file mode 100644 index 0000000000..289c32eeb6 --- /dev/null +++ b/ortools/graph/samples/dag_constrained_shortest_path_sequential.cc @@ -0,0 +1,138 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// [START imports] +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "ortools/base/init_google.h" +#include "ortools/graph/dag_constrained_shortest_path.h" +#include "ortools/graph/dag_shortest_path.h" +#include "ortools/graph/graph.h" +// [END imports] + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, true); + + // [START graph] + // Create a graph with n + 2 nodes, indexed from 0: + // * Node n is `source` + // * Node n+1 is `dest` + // * Nodes M = [0, 1, ..., n-1] are in the middle. + // + // There is a single resource constraints with limit 1. + // + // The graph has 3 * n - 1 arcs (with weights and both resources): + // * (source -> i) with weight 100 and no resource use for i in M + // * (i -> dest) with weight 100 and no resource use for i in M + // * (i -> (i+1)) with weight 1 and resource use of 1 for i = 0, ..., n-2 + // + // Every path [source, i, dest] for i in M is a constrained shortest path from + // source to dest with weight 200. + const int n = 5; + const int source = n; + const int dest = n + 1; + const int num_arcs = 3 * n - 1; + util::StaticGraph<> graph; + // There are 3 types of arcs: (1) source to M, (2) M to dest, and (3) within + // M. This vector stores all of them, first of type (1), then type (2), + // then type (3). The arcs are ordered by i in M within each type. + std::vector weights(num_arcs); + // Resources are first indexed by resource, then by arc. + std::vector> resources(1, std::vector(num_arcs)); + + for (int i = 0; i < n; ++i) { + graph.AddArc(source, i); + weights[i] = 100.0; + resources[0][i] = 0.0; + } + for (int i = 0; i < n; ++i) { + graph.AddArc(i, dest); + weights[n + i] = 100.0; + resources[0][n + i] = 0.0; + } + for (int i = 0; i + 1 < n; ++i) { + graph.AddArc(i, i + 1); + weights[2 * n + i] = 1.0; + resources[0][2 * n + i] = 1.0; + } + + // Static graph reorders the arcs at Build() time, use permutation to get from + // the old ordering to the new one. + std::vector permutation; + graph.Build(&permutation); + util::Permute(permutation, &weights); + util::Permute(permutation, &resources[0]); + // [END graph] + + // [START first-path] + // A reusable shortest path calculator. + // We need a topological order. For this structured graph, we find it by hand + // instead of using util::graph::FastTopologicalSort(). + std::vector topological_order = {source}; + for (int32_t i = 0; i < n; ++i) { + topological_order.push_back(i); + } + topological_order.push_back(dest); + + const std::vector sources = {source}; + const std::vector destinations = {dest}; + const std::vector max_resources = {1.0}; + + operations_research::ConstrainedShortestPathsOnDagWrapper> + constrained_shortest_path_on_dag(&graph, &weights, &resources, + topological_order, sources, destinations, + &max_resources); + operations_research::PathWithLength initial_constrained_shortest_path = + constrained_shortest_path_on_dag.RunConstrainedShortestPathOnDag(); + + std::cout << "Initial distance: " << initial_constrained_shortest_path.length + << std::endl; + std::cout << "Initial path: " + << absl::StrJoin(initial_constrained_shortest_path.node_path, ", ") + << std::endl; + // [END first-path] + + // [START more-paths] + // Now, we make a single arc from source to M free, and a single arc from M + // to dest free, and resolve. If the free edge from the source hits before + // the free edge to the dest in M, we use both, walking through M. Otherwise, + // we use only one free arc. + std::vector> fast_paths = {{2, 3}, {8, 1}, {3, 7}}; + for (const auto [free_from_source, free_to_dest] : fast_paths) { + weights[permutation[free_from_source]] = 0; + weights[permutation[n + free_to_dest]] = 0; + + operations_research::PathWithLength constrained_shortest_path = + constrained_shortest_path_on_dag.RunConstrainedShortestPathOnDag(); + std::cout << "source -> " << free_from_source << " and " << free_to_dest + << " -> dest are now free" << std::endl; + std::string label = absl::StrCat("_", free_from_source, "_", free_to_dest); + std::cout << "Distance" << label << ": " << constrained_shortest_path.length + << std::endl; + std::cout << "Path" << label << ": " + << absl::StrJoin(constrained_shortest_path.node_path, ", ") + << std::endl; + + // Restore the old weights + weights[permutation[free_from_source]] = 100; + weights[permutation[n + free_to_dest]] = 100; + } + // [END more-paths] + return 0; +} diff --git a/ortools/graph/samples/dag_multiple_shortest_paths_one_to_all.cc b/ortools/graph/samples/dag_multiple_shortest_paths_one_to_all.cc new file mode 100644 index 0000000000..a0a4a9ebd2 --- /dev/null +++ b/ortools/graph/samples/dag_multiple_shortest_paths_one_to_all.cc @@ -0,0 +1,87 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_join.h" +#include "ortools/base/init_google.h" +#include "ortools/base/status_macros.h" +#include "ortools/graph/dag_shortest_path.h" +#include "ortools/graph/graph.h" +#include "ortools/graph/topologicalsorter.h" + +namespace { + +absl::Status Main() { + util::StaticGraph<> graph; + std::vector weights; + graph.AddArc(0, 1); + weights.push_back(2.0); + graph.AddArc(0, 2); + weights.push_back(5.0); + graph.AddArc(1, 4); + weights.push_back(1.0); + graph.AddArc(2, 4); + weights.push_back(-3.0); + graph.AddArc(3, 4); + weights.push_back(0.0); + + // Static graph reorders the arcs at Build() time, use permutation to get + // from the old ordering to the new one. + std::vector permutation; + graph.Build(&permutation); + util::Permute(permutation, &weights); + + // We need a topological order. We can find it by hand on this small graph, + // e.g., {0, 1, 2, 3, 4}, but we demonstrate how to compute one instead. + ASSIGN_OR_RETURN(const std::vector topological_order, + util::graph::FastTopologicalSort(graph)); + + operations_research::KShortestPathsOnDagWrapper> + shortest_paths_on_dag(&graph, &weights, topological_order, + /*path_count=*/2); + const int source = 0; + shortest_paths_on_dag.RunKShortestPathOnDag({source}); + + // For each node other than 0, print its distance and the shortest path. + for (int node = 1; node < 5; ++node) { + std::cout << "Node " << node << ":\n"; + if (!shortest_paths_on_dag.IsReachable(node)) { + std::cout << "\tNo path to node " << node << std::endl; + continue; + } + const std::vector lengths = shortest_paths_on_dag.LengthsTo(node); + const std::vector> paths = + shortest_paths_on_dag.NodePathsTo(node); + for (int path_index = 0; path_index < lengths.size(); ++path_index) { + std::cout << "\t#" << (path_index + 1) << " shortest path to node " + << node << " has length: " << lengths[path_index] << std::endl; + std::cout << "\t#" << (path_index + 1) << " shortest path to node " + << node << " is: " << absl::StrJoin(paths[path_index], ", ") + << std::endl; + } + } + return absl::OkStatus(); +} + +} // namespace + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, true); + QCHECK_OK(Main()); + return 0; +} diff --git a/ortools/graph/samples/dag_multiple_shortest_paths_sequential.cc b/ortools/graph/samples/dag_multiple_shortest_paths_sequential.cc new file mode 100644 index 0000000000..cab25be1fe --- /dev/null +++ b/ortools/graph/samples/dag_multiple_shortest_paths_sequential.cc @@ -0,0 +1,135 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// [START imports] +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "ortools/base/init_google.h" +#include "ortools/graph/dag_shortest_path.h" +#include "ortools/graph/graph.h" +// [END imports] + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, true); + + // [START graph] + // Create a graph with n + 2 nodes, indexed from 0: + // * Node n is `source` + // * Node n+1 is `dest` + // * Nodes M = [0, 1, ..., n-1] are in the middle. + // + // The graph has 3 * n - 1 arcs (with weights): + // * (source -> i) with weight 100 + i for i in M + // * (i -> dest) with weight 100 + i for i in M + // * (i -> (i+1)) with weight 10 for i = 0, ..., n-2 + const int n = 10; + const int source = n; + const int dest = n + 1; + util::StaticGraph<> graph; + // There are 3 types of arcs: (1) source to M, (2) M to dest, and (3) within + // M. This vector stores all of them, first of type (1), then type (2), + // then type (3). The arcs are ordered by i in M within each type. + std::vector weights(3 * n - 1); + + for (int i = 0; i < n; ++i) { + graph.AddArc(source, i); + weights[i] = 100.0 + i; + } + for (int i = 0; i < n; ++i) { + graph.AddArc(i, dest); + weights[n + i] = 100.0 + i; + } + for (int i = 0; i + 1 < n; ++i) { + graph.AddArc(i, i + 1); + weights[2 * n + i] = 10.0; + } + + // Static graph reorders the arcs at Build() time, use permutation to get from + // the old ordering to the new one. + std::vector permutation; + graph.Build(&permutation); + util::Permute(permutation, &weights); + // [END graph] + + // [START first-path] + // A reusable shortest path calculator. + // We need a topological order. For this structured graph, we find it by hand + // instead of using util::graph::FastTopologicalSort(). + std::vector topological_order = {source}; + for (int32_t i = 0; i < n; ++i) { + topological_order.push_back(i); + } + topological_order.push_back(dest); + + operations_research::KShortestPathsOnDagWrapper> + shortest_paths_on_dag(&graph, &weights, topological_order, + /*path_count=*/2); + shortest_paths_on_dag.RunKShortestPathOnDag({source}); + + const std::vector initial_lengths = + shortest_paths_on_dag.LengthsTo(dest); + const std::vector> initial_paths = + shortest_paths_on_dag.NodePathsTo(dest); + + std::cout << "No free arcs" << std::endl; + for (int path_index = 0; path_index < initial_lengths.size(); ++path_index) { + std::cout << "\t#" << (path_index + 1) + << " shortest path has length: " << initial_lengths[path_index] + << std::endl; + std::cout << "\t#" << (path_index + 1) << " shortest path is: " + << absl::StrJoin(initial_paths[path_index], ", ") << std::endl; + } + // [END first-path] + + // [START more-paths] + // Now, we make a single arc from source to M free, and a single arc from M + // to dest free, and resolve. If the free edge from the source hits before + // the free edge to the dest in M, we use both, walking through M. Otherwise, + // we use only one free arc. + std::vector> fast_paths = { + {2, 4}, {8, 1}, {3, 3}, {0, 0}}; + for (const auto [free_from_source, free_to_dest] : fast_paths) { + weights[permutation[free_from_source]] = 0; + weights[permutation[n + free_to_dest]] = 0; + + shortest_paths_on_dag.RunKShortestPathOnDag({source}); + std::cout << "source -> " << free_from_source << " and " << free_to_dest + << " -> dest are now free" << std::endl; + std::string label = + absl::StrCat(" (", free_from_source, ", ", free_to_dest, ")"); + + const std::vector lengths = shortest_paths_on_dag.LengthsTo(dest); + const std::vector> paths = + shortest_paths_on_dag.NodePathsTo(dest); + + for (int path_index = 0; path_index < lengths.size(); ++path_index) { + std::cout << "\t#" << (path_index + 1) << " shortest path" << label + << " has length: " << lengths[path_index] << std::endl; + std::cout << "\t#" << (path_index + 1) << " shortest path" << label + << " is: " << absl::StrJoin(paths[path_index], ", ") + << std::endl; + } + + // Restore the old weights + weights[permutation[free_from_source]] = 100 + free_from_source; + weights[permutation[n + free_to_dest]] = 100 + free_to_dest; + } + // [END more-paths] + return 0; +} diff --git a/ortools/graph/samples/dag_simple_constrained_shortest_path.cc b/ortools/graph/samples/dag_simple_constrained_shortest_path.cc new file mode 100644 index 0000000000..071846da2a --- /dev/null +++ b/ortools/graph/samples/dag_simple_constrained_shortest_path.cc @@ -0,0 +1,47 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/strings/str_join.h" +#include "ortools/base/init_google.h" +#include "ortools/graph/dag_constrained_shortest_path.h" +#include "ortools/graph/dag_shortest_path.h" + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, true); + + // The input graph, encoded as a list of arcs with distances. + std::vector arcs = { + {.from = 0, .to = 1, .length = 5, .resources = {1, 2}}, + {.from = 0, .to = 2, .length = 4, .resources = {3, 2}}, + {.from = 0, .to = 2, .length = 1, .resources = {2, 3}}, + {.from = 1, .to = 3, .length = -3, .resources = {8, 0}}, + {.from = 2, .to = 3, .length = 0, .resources = {3, 1}}}; + const int num_nodes = 4; + const std::vector max_resources = {6, 3}; + + const int source = 0; + const int destination = 3; + const operations_research::PathWithLength path_with_length = + operations_research::ConstrainedShortestPathsOnDag( + num_nodes, arcs, source, destination, max_resources); + + // Print to length of the path and then the nodes in the path. + std::cout << "Constrained shortest path length: " << path_with_length.length + << std::endl; + std::cout << "Constrained shortest path nodes: " + << absl::StrJoin(path_with_length.node_path, ", ") << std::endl; + return 0; +} diff --git a/ortools/graph/samples/dag_simple_multiple_shortest_paths.cc b/ortools/graph/samples/dag_simple_multiple_shortest_paths.cc new file mode 100644 index 0000000000..33f9e7291a --- /dev/null +++ b/ortools/graph/samples/dag_simple_multiple_shortest_paths.cc @@ -0,0 +1,47 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/strings/str_join.h" +#include "ortools/base/init_google.h" +#include "ortools/graph/dag_shortest_path.h" + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, true); + + // The input graph, encoded as a list of arcs with distances. + std::vector arcs = { + {.from = 0, .to = 1, .length = 2}, {.from = 0, .to = 2, .length = 5}, + {.from = 0, .to = 3, .length = 4}, {.from = 1, .to = 4, .length = 1}, + {.from = 2, .to = 4, .length = -3}, {.from = 3, .to = 4, .length = 0}}; + const int num_nodes = 5; + + const int source = 0; + const int destination = 4; + const int path_count = 2; + const std::vector paths_with_length = + operations_research::KShortestPathsOnDag(num_nodes, arcs, source, + destination, path_count); + + for (int path_index = 0; path_index < paths_with_length.size(); + ++path_index) { + std::cout << "#" << (path_index + 1) << " shortest path has length: " + << paths_with_length[path_index].length << std::endl; + std::cout << "#" << (path_index + 1) << " shortest path is: " + << absl::StrJoin(paths_with_length[path_index].node_path, ", ") + << std::endl; + } + return 0; +} diff --git a/ortools/graph/samples/dag_simple_shortest_path.cc b/ortools/graph/samples/dag_simple_shortest_path.cc index 1ade5d59fe..23c9e1a2ba 100644 --- a/ortools/graph/samples/dag_simple_shortest_path.cc +++ b/ortools/graph/samples/dag_simple_shortest_path.cc @@ -23,11 +23,11 @@ int main(int argc, char** argv) { // The input graph, encoded as a list of arcs with distances. std::vector arcs = { - {.tail = 0, .head = 2, .length = 5}, - {.tail = 0, .head = 3, .length = 4}, - {.tail = 1, .head = 3, .length = 1}, - {.tail = 2, .head = 4, .length = -3}, - {.tail = 3, .head = 4, .length = 0}}; + {.from = 0, .to = 2, .length = 5}, + {.from = 0, .to = 3, .length = 4}, + {.from = 1, .to = 3, .length = 1}, + {.from = 2, .to = 4, .length = -3}, + {.from = 3, .to = 4, .length = 0}}; const int num_nodes = 5; const int source = 0; diff --git a/ortools/graph/samples/root_a_tree.cc b/ortools/graph/samples/root_a_tree.cc new file mode 100644 index 0000000000..18e569902b --- /dev/null +++ b/ortools/graph/samples/root_a_tree.cc @@ -0,0 +1,86 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_join.h" +#include "ortools/base/init_google.h" +#include "ortools/base/status_macros.h" +#include "ortools/graph/graph.h" +#include "ortools/graph/rooted_tree.h" + +namespace { + +absl::Status Main() { + // Make an undirected tree as a graph using ListGraph (add the arcs in each + // direction). + const int32_t num_nodes = 5; + std::vector> arcs = { + {0, 1}, {1, 2}, {2, 3}, {1, 4}}; + util::ListGraph<> graph(num_nodes, 2 * static_cast(arcs.size())); + for (const auto [s, t] : arcs) { + graph.AddArc(s, t); + graph.AddArc(t, s); + } + + // Root the tree from 2. Save the depth of each node and topological ordering + int root = 2; + std::vector topological_order; + std::vector depth; + ASSIGN_OR_RETURN(const operations_research::RootedTree tree, + operations_research::RootedTreeFromGraph( + root, graph, &topological_order, &depth)); + + // Parents are: + // 0 -> 1 + // 1 -> 2 + // 2 is root (returns -1) + // 3 -> 2 + // 4 -> 1 + std::cout << "Parents:" << std::endl; + for (int i = 0; i < num_nodes; ++i) { + std::cout << " " << i << " -> " << tree.parents()[i] << std::endl; + } + + // Depths are: + // 0: 2 + // 1: 1 + // 2: 0 + // 3: 1 + // 4: 2 + std::cout << "Depths:" << std::endl; + for (int i = 0; i < num_nodes; ++i) { + std::cout << " " << i << " -> " << depth[i] << std::endl; + } + + // Many possible topological orders, including: + // [2, 1, 0, 4, 3] + // all starting with 2. + std::cout << "Topological order: " << absl::StrJoin(topological_order, ", ") + << std::endl; + + return absl::OkStatus(); +} + +} // namespace + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, true); + QCHECK_OK(Main()); + return 0; +} diff --git a/ortools/graph/samples/rooted_tree_paths.cc b/ortools/graph/samples/rooted_tree_paths.cc new file mode 100644 index 0000000000..a4908f31cc --- /dev/null +++ b/ortools/graph/samples/rooted_tree_paths.cc @@ -0,0 +1,58 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_join.h" +#include "ortools/base/init_google.h" +#include "ortools/base/status_macros.h" +#include "ortools/graph/rooted_tree.h" + +namespace { + +absl::Status Main() { + // Make an rooted tree on 5 nodes with root 2 and the parental args: + // 0 -> 1 + // 1 -> 2 + // 2 is root + // 3 -> 2 + // 4 -> 1 + ASSIGN_OR_RETURN( + const operations_research::RootedTree tree, + operations_research::RootedTree::Create(2, {1, 2, -1, 2, 1})); + + // Precompute this for LCA computations below. + const std::vector depths = tree.AllDepths(); + + // Find the path between every pair of nodes in the tree. + for (int s = 0; s < 5; ++s) { + for (int t = 0; t < 5; ++t) { + int lca = tree.LowestCommonAncestorByDepth(s, t, depths); + const std::vector path = tree.Path(s, t, lca); + std::cout << s << " -> " << t << " [" << absl::StrJoin(path, ", ") << "]" + << std::endl; + } + } + return absl::OkStatus(); +} + +} // namespace + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, true); + QCHECK_OK(Main()); + return 0; +}