// Copyright 2010-2025 Google LLC // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // This is a simplistic insertion-ordered map. It behaves similarly to an STL // map, but only implements a small subset of the map's methods. Internally, we // just keep a map and a list going in parallel. // // This class provides no thread safety guarantees, beyond what you would // normally see with std::list. // // Iterators point into the list and should be stable in the face of // mutations, except for an iterator pointing to an element that was just // deleted. // // This class supports heterogeneous lookups. // #ifndef ORTOOLS_BASE_LINKED_HASH_MAP_H_ #define ORTOOLS_BASE_LINKED_HASH_MAP_H_ #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/container/internal/common.h" #include "ortools/base/logging.h" namespace gtl { // This holds a list of pair items. This list is what gets // traversed, and it's iterators from this list that we return from // begin/end/find. // // We also keep a set for find. Since std::list is a // doubly-linked list, the iterators should remain stable. template ::hasher, typename KeyEq = typename absl::flat_hash_set::key_equal, typename Alloc = std::allocator>> class linked_hash_map { using KeyArgImpl = absl::container_internal::KeyArg< absl::container_internal::IsTransparent::value && absl::container_internal::IsTransparent::value>; // Alias used for heterogeneous lookup functions. // `key_arg` evaluates to `K` when the functors are transparent and to // `key_type` otherwise. It permits template argument deduction on `K` for the // transparent case. template using key_arg = typename KeyArgImpl::template type; public: using key_type = Key; using mapped_type = Value; using hasher = KeyHash; using key_equal = KeyEq; using value_type = std::pair; using allocator_type = Alloc; using difference_type = ptrdiff_t; private: using ListType = std::list; template class Wrapped { template static const K& ToKey(const K& k) { return k; } static const key_type& ToKey(typename ListType::const_iterator it) { return it->first; } static const key_type& ToKey(typename ListType::iterator it) { return it->first; } Fn fn_; friend linked_hash_map; public: using is_transparent = void; Wrapped() = default; explicit Wrapped(Fn fn) : fn_(std::move(fn)) {} template auto operator()(Args&&... args) const -> decltype(this->fn_(ToKey(args)...)) { return fn_(ToKey(args)...); } }; using SetType = absl::flat_hash_set, Wrapped, Alloc>; class NodeHandle { public: using key_type = linked_hash_map::key_type; using mapped_type = linked_hash_map::mapped_type; using allocator_type = linked_hash_map::allocator_type; constexpr NodeHandle() noexcept = default; NodeHandle(NodeHandle&& nh) noexcept = default; ~NodeHandle() = default; NodeHandle& operator=(NodeHandle&& node) noexcept = default; bool empty() const noexcept { return list_.empty(); } explicit operator bool() const noexcept { return !empty(); } allocator_type get_allocator() const { return list_.get_allocator(); } const key_type& key() const { return list_.front().first; } mapped_type& mapped() { return list_.front().second; } void swap(NodeHandle& nh) noexcept { list_.swap(nh.list_); } private: friend linked_hash_map; explicit NodeHandle(ListType list) : list_(std::move(list)) {} ListType list_; }; template struct InsertReturnType { Iterator position; bool inserted; NodeType node; }; public: using iterator = typename ListType::iterator; using const_iterator = typename ListType::const_iterator; using reverse_iterator = typename ListType::reverse_iterator; using const_reverse_iterator = typename ListType::const_reverse_iterator; using reference = typename ListType::reference; using const_reference = typename ListType::const_reference; using size_type = typename ListType::size_type; using pointer = typename std::allocator_traits::pointer; using const_pointer = typename std::allocator_traits::const_pointer; using node_type = NodeHandle; using insert_return_type = InsertReturnType; linked_hash_map() = default; explicit linked_hash_map(size_t bucket_count, const hasher& hash = hasher(), const key_equal& eq = key_equal(), const allocator_type& alloc = allocator_type()) : set_(bucket_count, Wrapped(hash), Wrapped(eq), alloc), list_(alloc) {} linked_hash_map(size_t bucket_count, const hasher& hash, const allocator_type& alloc) : linked_hash_map(bucket_count, hash, key_equal(), alloc) {} linked_hash_map(size_t bucket_count, const allocator_type& alloc) : linked_hash_map(bucket_count, hasher(), key_equal(), alloc) {} explicit linked_hash_map(const allocator_type& alloc) : linked_hash_map(0, hasher(), key_equal(), alloc) {} template linked_hash_map(InputIt first, InputIt last, size_t bucket_count = 0, const hasher& hash = hasher(), const key_equal& eq = key_equal(), const allocator_type& alloc = allocator_type()) : linked_hash_map(bucket_count, hash, eq, alloc) { insert(first, last); } template linked_hash_map(InputIt first, InputIt last, size_t bucket_count, const hasher& hash, const allocator_type& alloc) : linked_hash_map(first, last, bucket_count, hash, key_equal(), alloc) {} template linked_hash_map(InputIt first, InputIt last, size_t bucket_count, const allocator_type& alloc) : linked_hash_map(first, last, bucket_count, hasher(), key_equal(), alloc) {} template linked_hash_map(InputIt first, InputIt last, const allocator_type& alloc) : linked_hash_map(first, last, /*bucket_count=*/0, hasher(), key_equal(), alloc) {} linked_hash_map(std::initializer_list init, size_t bucket_count = 0, const hasher& hash = hasher(), const key_equal& eq = key_equal(), const allocator_type& alloc = allocator_type()) : linked_hash_map(init.begin(), init.end(), bucket_count, hash, eq, alloc) {} linked_hash_map(std::initializer_list init, size_t bucket_count, const hasher& hash, const allocator_type& alloc) : linked_hash_map(init, bucket_count, hash, key_equal(), alloc) {} linked_hash_map(std::initializer_list init, size_t bucket_count, const allocator_type& alloc) : linked_hash_map(init, bucket_count, hasher(), key_equal(), alloc) {} linked_hash_map(std::initializer_list init, const allocator_type& alloc) : linked_hash_map(init, /*bucket_count=*/0, hasher(), key_equal(), alloc) {} linked_hash_map(const linked_hash_map& other) : linked_hash_map(other.bucket_count(), other.hash_function(), other.key_eq(), other.get_allocator()) { CopyFrom(other); } linked_hash_map(const linked_hash_map& other, const allocator_type& alloc) : linked_hash_map(other.bucket_count(), other.hash_function(), other.key_eq(), alloc) { CopyFrom(other); } linked_hash_map(linked_hash_map&& other) noexcept : set_(std::move(other.set_)), list_(std::move(other.list_)) { // Since the list and set must agree for other to end up "valid", // explicitly clear them. other.set_.clear(); other.list_.clear(); } linked_hash_map(linked_hash_map&& other, const allocator_type& alloc) : linked_hash_map(0, other.hash_function(), other.key_eq(), alloc) { if (get_allocator() == other.get_allocator()) { *this = std::move(other); } else { CopyFrom(std::move(other)); } } linked_hash_map& operator=(const linked_hash_map& other) { if (this == &other) return *this; // Make a new set, with other's hash/eq/alloc. set_ = SetType(other.bucket_count(), other.set_.hash_function(), other.set_.key_eq(), other.get_allocator()); // Copy the list, with other's allocator. list_ = ListType(other.get_allocator()); CopyFrom(other); return *this; } linked_hash_map& operator=(linked_hash_map&& other) noexcept { // underlying containers will handle progagate_on_container_move details set_ = std::move(other.set_); list_ = std::move(other.list_); other.set_.clear(); other.list_.clear(); return *this; } linked_hash_map& operator=(std::initializer_list values) { clear(); insert(values.begin(), values.end()); return *this; } // Derive size_ from set_, as list::size might be O(N). size_type size() const { return set_.size(); } size_type max_size() const noexcept { return ~size_type{}; } bool empty() const { return set_.empty(); } // Iteration is list-like, in insertion order. // These are all forwarded. iterator begin() { return list_.begin(); } iterator end() { return list_.end(); } const_iterator begin() const { return list_.begin(); } const_iterator end() const { return list_.end(); } const_iterator cbegin() const { return list_.cbegin(); } const_iterator cend() const { return list_.cend(); } reverse_iterator rbegin() { return list_.rbegin(); } reverse_iterator rend() { return list_.rend(); } const_reverse_iterator rbegin() const { return list_.rbegin(); } const_reverse_iterator rend() const { return list_.rend(); } const_reverse_iterator crbegin() const { return list_.crbegin(); } const_reverse_iterator crend() const { return list_.crend(); } reference front() { return list_.front(); } reference back() { return list_.back(); } const_reference front() const { return list_.front(); } const_reference back() const { return list_.back(); } void pop_front() { erase(begin()); } void pop_back() { erase(std::prev(end())); } ABSL_ATTRIBUTE_REINITIALIZES void clear() { set_.clear(); list_.clear(); } void reserve(size_t n) { set_.reserve(n); } size_t capacity() const { return set_.capacity(); } size_t bucket_count() const { return set_.bucket_count(); } float load_factor() const { return set_.load_factor(); } hasher hash_function() const { return set_.hash_function().fn_; } key_equal key_eq() const { return set_.key_eq().fn_; } allocator_type get_allocator() const { return list_.get_allocator(); } template size_type erase(const key_arg& key) { auto found = set_.find(key); if (found == set_.end()) return 0; auto list_it = *found; // Erase set entry first since it refers to the list element. set_.erase(found); list_.erase(list_it); return 1; } iterator erase(const_iterator position) { auto found = set_.find(position); CHECK(*found == position) << "Inconsistent iterator for set and list, " "or the iterator is invalid."; set_.erase(found); return list_.erase(position); } iterator erase(iterator position) { return erase(static_cast(position)); } iterator erase(iterator first, iterator last) { while (first != last) first = erase(first); return first; } iterator erase(const_iterator first, const_iterator last) { while (first != last) first = erase(first); if (first == end()) return end(); return *set_.find(first); } template iterator find(const key_arg& key) { auto found = set_.find(key); if (found == set_.end()) return end(); return *found; } template const_iterator find(const key_arg& key) const { auto found = set_.find(key); if (found == set_.end()) return end(); return *found; } template size_type count(const key_arg& key) const { return contains(key) ? 1 : 0; } template bool contains(const key_arg& key) const { return set_.contains(key); } template mapped_type& at(const key_arg& key) { auto it = find(key); if (ABSL_PREDICT_FALSE(it == end())) { LOG(FATAL) << "linked_hash_map::at failed bounds check"; } return it->second; } template const mapped_type& at(const key_arg& key) const { return const_cast(this)->at(key); } template std::pair equal_range(const key_arg& key) { auto iter = set_.find(key); if (iter == set_.end()) return {end(), end()}; return {*iter, std::next(*iter)}; } template std::pair equal_range( const key_arg& key) const { auto iter = set_.find(key); if (iter == set_.end()) return {end(), end()}; return {*iter, std::next(*iter)}; } template mapped_type& operator[](const key_arg& key) { return LazyEmplaceInternal(key).first->second; } template mapped_type& operator[](key_arg&& key) { return LazyEmplaceInternal(std::forward(key)).first->second; } std::pair insert(const value_type& v) { return InsertInternal(v); } std::pair insert(value_type&& v) { // NOLINT(build/c++11) return InsertInternal(std::move(v)); } iterator insert(const_iterator, const value_type& v) { return insert(v).first; } iterator insert(const_iterator, value_type&& v) { return insert(std::move(v)).first; } void insert(std::initializer_list ilist) { insert(ilist.begin(), ilist.end()); } template void insert(InputIt first, InputIt last) { for (; first != last; ++first) insert(*first); } insert_return_type insert(node_type&& node) { if (!node) return {end(), false, node_type()}; auto itr = find(node.key()); if (itr != end()) return {itr, false, std::move(node)}; list_.splice(list_.end(), node.list_); set_.insert(--list_.end()); return {--list_.end(), true, node_type()}; } iterator insert(const_iterator, node_type&& node) { return insert(std::move(node)).first; } // The last two template parameters ensure that both arguments are rvalues // (lvalue arguments are handled by the overloads below). This is necessary // for supporting bitfield arguments. // // union { int n : 1; }; // linked_hash_map m; // m.insert_or_assign(n, n); template std::pair insert_or_assign(key_arg&& k, V&& v) { return InsertOrAssignInternal(std::forward(k), std::forward(v)); } template std::pair insert_or_assign(key_arg&& k, const V& v) { return InsertOrAssignInternal(std::forward(k), v); } template std::pair insert_or_assign(const key_arg& k, V&& v) { return InsertOrAssignInternal(k, std::forward(v)); } template std::pair insert_or_assign(const key_arg& k, const V& v) { return InsertOrAssignInternal(k, v); } template iterator insert_or_assign(const_iterator, key_arg&& k, V&& v) { return insert_or_assign(std::forward(k), std::forward(v)).first; } template iterator insert_or_assign(const_iterator, key_arg&& k, const V& v) { return insert_or_assign(std::forward(k), v).first; } template iterator insert_or_assign(const_iterator, const key_arg& k, V&& v) { return insert_or_assign(k, std::forward(v)).first; } template iterator insert_or_assign(const_iterator, const key_arg& k, const V& v) { return insert_or_assign(k, v).first; } template std::pair emplace(Args&&... args) { ListType node_donor; auto list_iter = node_donor.emplace(node_donor.end(), std::forward(args)...); auto ins = set_.insert(list_iter); if (!ins.second) return {*ins.first, false}; list_.splice(list_.end(), node_donor, list_iter); return {list_iter, true}; } template iterator try_emplace(const_iterator, key_arg&& k, Args&&... args) { return try_emplace(std::forward(k), std::forward(args)...).first; } template iterator emplace_hint(const_iterator, Args&&... args) { return emplace(std::forward(args)...).first; } template std::pair try_emplace(key_arg&& key, Args&&... args) { return LazyEmplaceInternal(std::forward>(key), std::forward(args)...); } template void merge(linked_hash_map& src) { auto itr = src.list_.begin(); while (itr != src.list_.end()) { if (contains(itr->first)) { ++itr; } else { insert(src.extract(itr++)); } } } template void merge(linked_hash_map&& src) { merge(src); } node_type extract(const_iterator position) { set_.erase(position->first); ListType extracted_node_list; extracted_node_list.splice(extracted_node_list.end(), list_, position); return node_type(std::move(extracted_node_list)); } template , int> = 0> node_type extract(const key_arg& key) { auto it = find(key); return it == end() ? node_type() : extract(const_iterator{it}); } template std::pair try_emplace(const key_arg& key, Args&&... args) { return LazyEmplaceInternal(key, std::forward(args)...); } void swap(linked_hash_map& other) { using std::swap; swap(set_, other.set_); swap(list_, other.list_); } friend bool operator==(const linked_hash_map& a, const linked_hash_map& b) { if (a.size() != b.size()) return false; const linked_hash_map* outer = &a; const linked_hash_map* inner = &b; if (outer->capacity() > inner->capacity()) std::swap(outer, inner); for (const value_type& elem : *outer) { auto it = inner->find(elem.first); if (it == inner->end()) return false; if (it->second != elem.second) return false; } return true; } friend bool operator!=(const linked_hash_map& a, const linked_hash_map& b) { return !(a == b); } void rehash(size_t n) { set_.rehash(n); } private: template void CopyFrom(Other&& other) { for (auto& elem : other.list_) { set_.insert(list_.insert(list_.end(), std::move(elem))); } DCHECK_EQ(set_.size(), list_.size()) << "Set and list are inconsistent."; } template std::pair InsertInternal(U&& pair) { // NOLINT(build/c++11) auto iter = set_.find(pair.first); if (iter != set_.end()) return {*iter, false}; auto list_iter = list_.insert(list_.end(), std::forward(pair)); auto inserted = set_.insert(list_iter); DCHECK(inserted.second); return {list_iter, true}; } template std::pair InsertOrAssignInternal(K&& k, V&& v) { auto iter = set_.find(k); if (iter != set_.end()) { (*iter)->second = std::forward(v); return {*iter, false}; } return LazyEmplaceInternal(std::forward(k), std::forward(v)); } template std::pair LazyEmplaceInternal(K&& key, Args&&... args) { bool constructed = false; auto set_iter = set_.lazy_emplace(key, [this, &constructed, &key, &args...](auto ctor) { auto list_iter = list_.emplace(list_.end(), std::piecewise_construct, std::forward_as_tuple(std::forward(key)), std::forward_as_tuple(std::forward(args)...)); constructed = true; ctor(list_iter); }); return {*set_iter, constructed}; } // The set component, used for speedy lookups. SetType set_; // The list component, used for maintaining insertion order. ListType list_; }; } // namespace gtl #endif // ORTOOLS_BASE_LINKED_HASH_MAP_H_