implement correct behavior for the threadpool in the CP-SAT LNS
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
|
||||
#include "ortools/base/threadpool.h"
|
||||
|
||||
#include "ortools/base/logging.h"
|
||||
|
||||
namespace operations_research {
|
||||
void RunWorker(void* data) {
|
||||
ThreadPool* const thread_pool = reinterpret_cast<ThreadPool*>(data);
|
||||
@@ -24,7 +26,7 @@ void RunWorker(void* data) {
|
||||
}
|
||||
|
||||
ThreadPool::ThreadPool(const std::string& prefix, int num_workers)
|
||||
: num_workers_(num_workers), waiting_to_finish_(false), started_(false) {}
|
||||
: num_workers_(num_workers) {}
|
||||
|
||||
ThreadPool::~ThreadPool() {
|
||||
if (started_) {
|
||||
@@ -38,6 +40,12 @@ ThreadPool::~ThreadPool() {
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPool::SetQueueCapacity(int capacity) {
|
||||
CHECK_GT(capacity, num_workers_);
|
||||
CHECK(!started_);
|
||||
queue_capacity_ = capacity;
|
||||
}
|
||||
|
||||
void ThreadPool::StartWorkers() {
|
||||
started_ = true;
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
@@ -51,6 +59,10 @@ std::function<void()> ThreadPool::GetNextTask() {
|
||||
if (!tasks_.empty()) {
|
||||
std::function<void()> task = tasks_.front();
|
||||
tasks_.pop_front();
|
||||
if (tasks_.size() < queue_capacity_ && waiting_for_capacity_) {
|
||||
waiting_for_capacity_ = false;
|
||||
capacity_condition_.notify_all();
|
||||
}
|
||||
return task;
|
||||
}
|
||||
if (waiting_to_finish_) {
|
||||
@@ -64,10 +76,15 @@ std::function<void()> ThreadPool::GetNextTask() {
|
||||
|
||||
void ThreadPool::Schedule(std::function<void()> closure) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (tasks_.size() >= queue_capacity_) {
|
||||
waiting_for_capacity_ = true;
|
||||
capacity_condition_.wait(lock);
|
||||
}
|
||||
tasks_.push_back(closure);
|
||||
if (started_) {
|
||||
lock.unlock();
|
||||
condition_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace operations_research
|
||||
|
||||
@@ -31,14 +31,18 @@ class ThreadPool {
|
||||
void StartWorkers();
|
||||
void Schedule(std::function<void()> closure);
|
||||
std::function<void()> GetNextTask();
|
||||
void SetQueueCapacity(int capacity);
|
||||
|
||||
private:
|
||||
const int num_workers_;
|
||||
std::list<std::function<void()>> tasks_;
|
||||
std::mutex mutex_;
|
||||
std::condition_variable condition_;
|
||||
bool waiting_to_finish_;
|
||||
bool started_;
|
||||
std::condition_variable capacity_condition_;
|
||||
bool waiting_to_finish_ = false;
|
||||
bool waiting_for_capacity_ = false;
|
||||
bool started_ = false;
|
||||
int queue_capacity_ = 2000000000;
|
||||
std::vector<std::thread> all_workers_;
|
||||
};
|
||||
} // namespace operations_research
|
||||
|
||||
@@ -159,7 +159,7 @@ inline void NonDeterministicOptimizeWithLNS(
|
||||
// to create millions of them, so we use the blocking nature of
|
||||
// pool.Schedule() when the queue capacity is set.
|
||||
ThreadPool pool("Parallel_LNS", num_threads);
|
||||
// pool.SetQueueCapacity(10 * num_threads);
|
||||
pool.SetQueueCapacity(10 * num_threads);
|
||||
pool.StartWorkers();
|
||||
while (!synchronize_and_maybe_stop()) {
|
||||
pool.Schedule([&generate_and_solve, seed]() { generate_and_solve(seed); });
|
||||
|
||||
Reference in New Issue
Block a user