implement correct behavior for the threadpool in the CP-SAT LNS

This commit is contained in:
Laurent Perron
2019-06-25 13:11:03 +02:00
parent 7b73d767ad
commit 4d91cf5d6b
3 changed files with 25 additions and 4 deletions

View File

@@ -13,6 +13,8 @@
#include "ortools/base/threadpool.h"
#include "ortools/base/logging.h"
namespace operations_research {
void RunWorker(void* data) {
ThreadPool* const thread_pool = reinterpret_cast<ThreadPool*>(data);
@@ -24,7 +26,7 @@ void RunWorker(void* data) {
}
ThreadPool::ThreadPool(const std::string& prefix, int num_workers)
: num_workers_(num_workers), waiting_to_finish_(false), started_(false) {}
: num_workers_(num_workers) {}
ThreadPool::~ThreadPool() {
if (started_) {
@@ -38,6 +40,12 @@ ThreadPool::~ThreadPool() {
}
}
void ThreadPool::SetQueueCapacity(int capacity) {
CHECK_GT(capacity, num_workers_);
CHECK(!started_);
queue_capacity_ = capacity;
}
void ThreadPool::StartWorkers() {
started_ = true;
for (int i = 0; i < num_workers_; ++i) {
@@ -51,6 +59,10 @@ std::function<void()> ThreadPool::GetNextTask() {
if (!tasks_.empty()) {
std::function<void()> task = tasks_.front();
tasks_.pop_front();
if (tasks_.size() < queue_capacity_ && waiting_for_capacity_) {
waiting_for_capacity_ = false;
capacity_condition_.notify_all();
}
return task;
}
if (waiting_to_finish_) {
@@ -64,10 +76,15 @@ std::function<void()> ThreadPool::GetNextTask() {
void ThreadPool::Schedule(std::function<void()> closure) {
std::unique_lock<std::mutex> lock(mutex_);
while (tasks_.size() >= queue_capacity_) {
waiting_for_capacity_ = true;
capacity_condition_.wait(lock);
}
tasks_.push_back(closure);
if (started_) {
lock.unlock();
condition_.notify_all();
}
}
} // namespace operations_research

View File

@@ -31,14 +31,18 @@ class ThreadPool {
void StartWorkers();
void Schedule(std::function<void()> closure);
std::function<void()> GetNextTask();
void SetQueueCapacity(int capacity);
private:
const int num_workers_;
std::list<std::function<void()>> tasks_;
std::mutex mutex_;
std::condition_variable condition_;
bool waiting_to_finish_;
bool started_;
std::condition_variable capacity_condition_;
bool waiting_to_finish_ = false;
bool waiting_for_capacity_ = false;
bool started_ = false;
int queue_capacity_ = 2000000000;
std::vector<std::thread> all_workers_;
};
} // namespace operations_research

View File

@@ -159,7 +159,7 @@ inline void NonDeterministicOptimizeWithLNS(
// to create millions of them, so we use the blocking nature of
// pool.Schedule() when the queue capacity is set.
ThreadPool pool("Parallel_LNS", num_threads);
// pool.SetQueueCapacity(10 * num_threads);
pool.SetQueueCapacity(10 * num_threads);
pool.StartWorkers();
while (!synchronize_and_maybe_stop()) {
pool.Schedule([&generate_and_solve, seed]() { generate_and_solve(seed); });