From 4d91cf5d6b4cd7be2b846ebe9d22add0fedee8a8 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Tue, 25 Jun 2019 13:11:03 +0200 Subject: [PATCH] implement correct behavior for the threadpool in the CP-SAT LNS --- ortools/base/threadpool.cc | 19 ++++++++++++++++++- ortools/base/threadpool.h | 8 ++++++-- ortools/sat/lns.h | 2 +- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/ortools/base/threadpool.cc b/ortools/base/threadpool.cc index 3fb7efd1da..506d601dfc 100644 --- a/ortools/base/threadpool.cc +++ b/ortools/base/threadpool.cc @@ -13,6 +13,8 @@ #include "ortools/base/threadpool.h" +#include "ortools/base/logging.h" + namespace operations_research { void RunWorker(void* data) { ThreadPool* const thread_pool = reinterpret_cast(data); @@ -24,7 +26,7 @@ void RunWorker(void* data) { } ThreadPool::ThreadPool(const std::string& prefix, int num_workers) - : num_workers_(num_workers), waiting_to_finish_(false), started_(false) {} + : num_workers_(num_workers) {} ThreadPool::~ThreadPool() { if (started_) { @@ -38,6 +40,12 @@ ThreadPool::~ThreadPool() { } } +void ThreadPool::SetQueueCapacity(int capacity) { + CHECK_GT(capacity, num_workers_); + CHECK(!started_); + queue_capacity_ = capacity; +} + void ThreadPool::StartWorkers() { started_ = true; for (int i = 0; i < num_workers_; ++i) { @@ -51,6 +59,10 @@ std::function ThreadPool::GetNextTask() { if (!tasks_.empty()) { std::function task = tasks_.front(); tasks_.pop_front(); + if (tasks_.size() < queue_capacity_ && waiting_for_capacity_) { + waiting_for_capacity_ = false; + capacity_condition_.notify_all(); + } return task; } if (waiting_to_finish_) { @@ -64,10 +76,15 @@ std::function ThreadPool::GetNextTask() { void ThreadPool::Schedule(std::function closure) { std::unique_lock lock(mutex_); + while (tasks_.size() >= queue_capacity_) { + waiting_for_capacity_ = true; + capacity_condition_.wait(lock); + } tasks_.push_back(closure); if (started_) { lock.unlock(); condition_.notify_all(); } } + } // namespace operations_research diff --git a/ortools/base/threadpool.h b/ortools/base/threadpool.h index bfeb502414..82d2b492c3 100644 --- a/ortools/base/threadpool.h +++ b/ortools/base/threadpool.h @@ -31,14 +31,18 @@ class ThreadPool { void StartWorkers(); void Schedule(std::function closure); std::function GetNextTask(); + void SetQueueCapacity(int capacity); private: const int num_workers_; std::list> tasks_; std::mutex mutex_; std::condition_variable condition_; - bool waiting_to_finish_; - bool started_; + std::condition_variable capacity_condition_; + bool waiting_to_finish_ = false; + bool waiting_for_capacity_ = false; + bool started_ = false; + int queue_capacity_ = 2000000000; std::vector all_workers_; }; } // namespace operations_research diff --git a/ortools/sat/lns.h b/ortools/sat/lns.h index 557a429f18..31a33bc95d 100644 --- a/ortools/sat/lns.h +++ b/ortools/sat/lns.h @@ -159,7 +159,7 @@ inline void NonDeterministicOptimizeWithLNS( // to create millions of them, so we use the blocking nature of // pool.Schedule() when the queue capacity is set. ThreadPool pool("Parallel_LNS", num_threads); - // pool.SetQueueCapacity(10 * num_threads); + pool.SetQueueCapacity(10 * num_threads); pool.StartWorkers(); while (!synchronize_and_maybe_stop()) { pool.Schedule([&generate_and_solve, seed]() { generate_and_solve(seed); });