Update thread_pool code (#4890)

This commit is contained in:
Guillaume Chatelet
2025-10-21 18:29:06 +01:00
committed by Corentin Le Molgat
parent 6969f23df4
commit d8d50bae68
9 changed files with 143 additions and 88 deletions

View File

@@ -535,8 +535,12 @@ cc_library(
srcs = ["threadpool.cc"],
hdrs = ["threadpool.h"],
deps = [
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/base:nullability",
"@abseil-cpp//absl/functional:any_invocable",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/synchronization",
],
)

View File

@@ -13,84 +13,120 @@
#include "ortools/base/threadpool.h"
#include <functional>
#include <mutex>
#include <optional>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
namespace operations_research {
void RunWorker(void* data) {
ThreadPool* const thread_pool = reinterpret_cast<ThreadPool*>(data);
std::function<void()> work = thread_pool->GetNextTask();
while (work != nullptr) {
work();
work = thread_pool->GetNextTask();
}
// It is a common error to call ThreadPool(workitems.size()), which
// crashes when workitems is empty. Prevent those crashes by creating at
// least one thread.
//
// For the same reason, a negative `num_threads` is also clamped to 1:
// `max_threads_` is an unsigned size_t, so letting a negative value
// through the implicit conversion would silently become a huge thread cap
// that CHECK_GT below cannot detect.
ThreadPool::ThreadPool(int num_threads)
    : max_threads_(num_threads <= 0 ? 1 : num_threads) {
  CHECK_GT(max_threads_, 0u);  // Invariant relied upon by SpawnThread().
  // Spawn a single thread to handle work by default.
  absl::MutexLock lock(mutex_);
  SpawnThread();
}
ThreadPool::ThreadPool(int num_threads) : num_workers_(num_threads) {}
ThreadPool::ThreadPool(absl::string_view /*prefix*/, int num_threads)
: num_workers_(num_threads) {}
ThreadPool::ThreadPool(absl::string_view prefix, int num_threads)
: ThreadPool(num_threads) {}
ThreadPool::~ThreadPool() {
if (started_) {
std::unique_lock<std::mutex> mutex_lock(mutex_);
waiting_to_finish_ = true;
mutex_lock.unlock();
condition_.notify_all();
for (int i = 0; i < num_workers_; ++i) {
all_workers_[i].join();
// Make threads finish up by setting stopping_. Ensure all threads waiting see
// this change by signalling their condvar.
{
absl::MutexLock l(mutex_);
stopping_ = true;
for (Waiter* absl_nonnull waiter : waiters_) {
waiter->cv.Signal();
}
// Wait until the queue is empty. This implies no new threads will be
// spawned, and all existing threads are exiting.
auto queue_empty = [this]() ABSL_SHARED_LOCKS_REQUIRED(mutex_) {
return queue_.empty();
};
mutex_.Await(absl::Condition(&queue_empty));
}
// Join and delete all threads. Because the queue is empty, we know no new
// threads will be added to threads_.
for (auto& worker : threads_) {
worker.join();
}
}
void ThreadPool::SetQueueCapacity(int capacity) {
CHECK_GT(capacity, num_workers_);
CHECK(!started_);
queue_capacity_ = capacity;
void ThreadPool::SpawnThread() {
CHECK_LE(threads_.size(), max_threads_);
threads_.emplace_back([this] { RunWorker(); });
}
void ThreadPool::StartWorkers() {
started_ = true;
for (int i = 0; i < num_workers_; ++i) {
all_workers_.push_back(std::thread(&RunWorker, this));
void ThreadPool::RunWorker() {
{
absl::MutexLock lock(mutex_);
++running_threads_;
}
}
std::function<void()> ThreadPool::GetNextTask() {
std::unique_lock<std::mutex> lock(mutex_);
for (;;) {
if (!tasks_.empty()) {
std::function<void()> task = tasks_.front();
tasks_.pop_front();
if (tasks_.size() < queue_capacity_ && waiting_for_capacity_) {
waiting_for_capacity_ = false;
capacity_condition_.notify_all();
}
return task;
}
if (waiting_to_finish_) {
return nullptr;
} else {
condition_.wait(lock);
while (true) {
std::optional<absl::AnyInvocable<void() &&>> item = DequeueWork();
if (!item.has_value()) { // Requesting to stop the worker thread.
break;
}
DCHECK(item);
std::move (*item)();
}
return nullptr;
}
void ThreadPool::Schedule(std::function<void()> closure) {
std::unique_lock<std::mutex> lock(mutex_);
while (tasks_.size() >= queue_capacity_) {
waiting_for_capacity_ = true;
capacity_condition_.wait(lock);
void ThreadPool::SignalWaiter() {
DCHECK(!queue_.empty());
if (waiters_.empty()) {
// If there are no waiters, try spawning a new thread to pick up work.
if (running_threads_ == threads_.size() && threads_.size() < max_threads_) {
SpawnThread();
}
} else {
// If there are waiters we wake the last inserted waiter. Note that we can
// signal this waiter multiple times. This is not only ok but it is crucial
// to reduce spurious wakeups.
waiters_.back()->cv.Signal();
}
tasks_.push_back(closure);
if (started_) {
lock.unlock();
condition_.notify_all();
}
// Pops the oldest callback off the queue, blocking while the queue is
// empty. Returns std::nullopt once the pool is stopping and the queue has
// drained — the signal for the calling worker thread to exit.
std::optional<absl::AnyInvocable<void() &&>> ThreadPool::DequeueWork() {
// Wait for queue to be not-empty
absl::MutexLock m(mutex_);
while (queue_.empty() && !stopping_) {
// Register this thread as a waiter for the duration of one Wait() call.
// `self` lives on this stack frame; it is pushed and erased under mutex_,
// and Wait() reacquires mutex_ before returning, so SignalWaiter() can
// never observe a dangling pointer.
Waiter self;
waiters_.push_back(&self);
// Wait() atomically releases mutex_ and blocks until signalled; spurious
// wakeups are harmless because the loop re-tests its condition.
self.cv.Wait(&mutex_);
waiters_.erase(absl::c_find(waiters_, &self));
}
if (queue_.empty()) {
// The loop only exits with an empty queue when stopping_ was set.
DCHECK(stopping_);
return std::nullopt;
}
absl::AnyInvocable<void() &&> result = std::move(queue_.front());
queue_.pop_front();
if (!queue_.empty()) {
// More work remains: wake another waiter (or spawn a thread) so the
// remaining items can be processed concurrently.
SignalWaiter();
}
// std::move is required: AnyInvocable is move-only and must be moved into
// the optional return value.
return std::move(result);
}
// Enqueues `callback` to be run by a worker thread. Thread-safe; the queue
// is unbounded, so producers only block briefly to acquire the mutex.
// Calling this after the destructor has started is a bug (DCHECK'd); in
// release builds the late callback is dropped rather than enqueued.
void ThreadPool::Schedule(absl::AnyInvocable<void() &&> callback) {
  absl::MutexLock m(mutex_);
  DCHECK(!stopping_) << "Callback added after destructor started";
  if (ABSL_PREDICT_TRUE(!stopping_)) {
    queue_.push_back(std::move(callback));
    // Wake an idle worker — or spawn one — to pick up the new item.
    SignalWaiter();
  }
}
} // namespace operations_research

View File

@@ -14,39 +14,62 @@
#ifndef OR_TOOLS_BASE_THREADPOOL_H_
#define OR_TOOLS_BASE_THREADPOOL_H_
#include <condition_variable> // NOLINT
#include <functional>
#include <list>
#include <mutex> // NOLINT
#include <string>
#include <cstddef>
#include <deque>
#include <optional>
#include <thread> // NOLINT
#include <vector>
#include "absl/base/nullability.h"
#include "absl/base/thread_annotations.h"
#include "absl/functional/any_invocable.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
namespace operations_research {
// A fixed-capacity pool of lazily-spawned worker threads executing
// absl::AnyInvocable callbacks (implementation in threadpool.cc).
//
// NOTE(review): this span is a diff rendering that interleaves PRE- and
// POST-refactor declarations; as shown it would not compile (e.g. `mutex_`
// is declared twice). Confirm against the real header.
class ThreadPool {
public:
// Creates a pool running at most `num_threads` workers; the .cc treats a
// value of 0 (or, after the fix, any non-positive value) as 1.
explicit ThreadPool(int num_threads);
// `prefix` is accepted for API compatibility; the visible .cc delegates to
// the single-argument constructor and ignores it.
ThreadPool(absl::string_view prefix, int num_threads);
// Drains the queue, then joins all worker threads.
~ThreadPool();
// NOTE(review): pre-refactor API — the new .cc spawns workers on demand,
// and callers in this commit no longer invoke StartWorkers().
void StartWorkers();
// NOTE(review): pre-refactor std::function-based overload of Schedule.
void Schedule(std::function<void()> closure);
// NOTE(review): pre-refactor worker entry point.
std::function<void()> GetNextTask();
// NOTE(review): pre-refactor bounded-queue knob; the new queue is unbounded.
void SetQueueCapacity(int capacity);
// Enqueues `callback` for execution on some worker thread.
void Schedule(absl::AnyInvocable<void() &&> callback);
private:
// ---- NOTE(review): pre-refactor state (std::mutex-era), superseded by the
// absl::Mutex-based members further down. ----
const int num_workers_;
std::list<std::function<void()>> tasks_;
std::mutex mutex_;
std::condition_variable condition_;
std::condition_variable capacity_condition_;
bool waiting_to_finish_ = false;
bool waiting_for_capacity_ = false;
bool started_ = false;
// NOTE(review): 2e9 is a double literal narrowed to int.
int queue_capacity_ = 2e9;
std::vector<std::thread> all_workers_;
// ---- Post-refactor state. ----
// Waiter for a single thread.
struct Waiter {
absl::CondVar cv; // signalled when there is work to do
};
// Spawn a single new worker thread.
//
// REQUIRES: threads_.size() < max_threads_
void SpawnThread() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Worker-thread main loop: repeatedly DequeueWork() and run the result
// until DequeueWork() returns nullopt.
void RunWorker();
// Removes the oldest element from the queue and returns it. Causes the
// current thread to wait for producers if the queue is empty. Returns
// an empty `std::optional` if the thread pool is shutting down.
std::optional<absl::AnyInvocable<void() &&>> DequeueWork()
ABSL_LOCKS_EXCLUDED(mutex_);
// Signals a waiter if there is one, or spawns a thread to try to add a new
// waiter.
//
// REQUIRES: !queue_.empty()
void SignalWaiter() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
mutable absl::Mutex mutex_;
// NOTE(review): appears unused in the visible .cc — confirm before keeping.
absl::CondVar wait_nonfull_ ABSL_GUARDED_BY(mutex_);
// Idle workers blocked in DequeueWork(); the most recently added waiter is
// signalled first (see SignalWaiter) to reduce spurious wakeups.
std::vector<Waiter* absl_nonnull> waiters_ ABSL_GUARDED_BY(mutex_);
// Upper bound on the number of worker threads (>= 1).
const size_t max_threads_;
// Pending callbacks, oldest first.
// NOTE(review): accessed under mutex_ in the .cc but missing an
// ABSL_GUARDED_BY(mutex_) annotation — confirm.
std::deque<absl::AnyInvocable<void() &&>> queue_;
// Set by the destructor; workers exit once this is set and queue_ drains.
bool stopping_ ABSL_GUARDED_BY(mutex_) = false;
// Count of workers that have started running (incremented in RunWorker).
size_t running_threads_ ABSL_GUARDED_BY(mutex_) = 0;
// All spawned workers; joined in the destructor.
std::vector<std::thread> threads_ ABSL_GUARDED_BY(mutex_);
};
} // namespace operations_research
#endif // OR_TOOLS_BASE_THREADPOOL_H_

View File

@@ -206,7 +206,6 @@ BidirectionalDijkstra<GraphType, DistanceType>::BidirectionalDijkstra(
distances_[dir].assign(num_nodes, infinity());
parent_arc_[dir].assign(num_nodes, -1);
}
search_threads_.StartWorkers();
}
template <typename GraphType, typename DistanceType>

View File

@@ -557,7 +557,6 @@ GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
{
ThreadPool search_threads(2);
search_threads.StartWorkers();
for (const Direction dir : {FORWARD, BACKWARD}) {
search_threads.Schedule([this, dir, &sub_arc_lengths]() {
RunHalfConstrainedShortestPathOnDag(

View File

@@ -761,7 +761,6 @@ void ComputeManyToManyShortestPathsWithMultipleThreads(
graph.num_nodes());
{
std::unique_ptr<ThreadPool> pool(new ThreadPool(num_threads));
pool->StartWorkers();
for (int i = 0; i < unique_sources.size(); ++i) {
pool->Schedule(absl::bind_front(
&internal::ComputeOneToManyOnGraph<GraphType>, &graph, &arc_lengths,

View File

@@ -1155,7 +1155,6 @@ void MPSolver::SolveLazyMutableRequest(LazyMutableCopy<MPModelRequest> request,
// the user. They shouldn't matter for polling, but for solving we might
// e.g. use a larger stack.
ThreadPool thread_pool(/*num_threads=*/1);
thread_pool.StartWorkers();
thread_pool.Schedule(polling_func);
// Make sure the interruption notification didn't arrive while waiting to

View File

@@ -51,9 +51,7 @@ class GoogleThreadPoolScheduler : public Scheduler {
public:
GoogleThreadPoolScheduler(int num_threads)
: num_threads_(num_threads),
threadpool_(std::make_unique<ThreadPool>("pdlp", num_threads)) {
threadpool_->StartWorkers();
}
threadpool_(std::make_unique<ThreadPool>("pdlp", num_threads)) {}
int num_threads() const override { return num_threads_; };
std::string info_string() const override { return "google_threadpool"; };
@@ -79,7 +77,7 @@ class EigenThreadPoolScheduler : public Scheduler {
public:
EigenThreadPoolScheduler(int num_threads)
: num_threads_(num_threads),
eigen_threadpool_(std::make_unique<Eigen::ThreadPool>(num_threads)) {}
g3_threadpool_(std::make_unique<Eigen::ThreadPool>(num_threads)) {}
int num_threads() const override { return num_threads_; };
std::string info_string() const override { return "eigen_threadpool"; };
@@ -87,7 +85,7 @@ class EigenThreadPoolScheduler : public Scheduler {
absl::AnyInvocable<void(int)> do_func) override {
Eigen::Barrier eigen_barrier(end - start);
for (int i = start; i < end; ++i) {
eigen_threadpool_->Schedule([&, i]() {
g3_threadpool_->Schedule([&, i]() {
do_func(i);
eigen_barrier.Notify();
});
@@ -97,7 +95,7 @@ class EigenThreadPoolScheduler : public Scheduler {
private:
const int num_threads_;
std::unique_ptr<Eigen::ThreadPool> eigen_threadpool_ = nullptr;
std::unique_ptr<Eigen::ThreadPool> g3_threadpool_ = nullptr;
};
// Makes a scheduler of a given type.

View File

@@ -142,7 +142,6 @@ void DeterministicLoop(std::vector<std::unique_ptr<SubSolver>>& subsolvers,
std::vector<double> timing;
to_run.reserve(batch_size);
ThreadPool pool(num_threads);
pool.StartWorkers();
for (int batch_index = 0;; ++batch_index) {
VLOG(2) << "Starting deterministic batch of size " << batch_size;
SynchronizeAll(subsolvers);
@@ -214,7 +213,6 @@ void NonDeterministicLoop(std::vector<std::unique_ptr<SubSolver>>& subsolvers,
};
ThreadPool pool(num_threads);
pool.StartWorkers();
// The lambda below are using little space, but there is no reason
// to create millions of them, so we use the blocking nature of