rosa/inc/alsk/executor/utility/staticpool.h
2021-05-10 18:14:24 +02:00

162 lines
3.4 KiB
C++

#ifndef ALSK_ALSK_EXECUTOR_UTILITY_STATICPOOL_H
#define ALSK_ALSK_EXECUTOR_UTILITY_STATICPOOL_H
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <future>
#include <list>
#include <mutex>
#include <thread>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include <tmp/traits.h>
namespace alsk {
namespace exec {
namespace util {
/**
 * @brief Fixed-size pool of worker threads, each owning its own task queue.
 *
 * Tasks are submitted to a specific worker by index (see run()). When a thread
 * belonging to this pool calls wait(), it first drains its own queue inline, so
 * sub-tasks it queued on its own worker are executed rather than blocking it.
 *
 * Copying or moving a pool never shares or transfers threads: the new pool is
 * simply configured with the same number of workers as the source (the source
 * keeps its own workers).
 */
struct StaticPool {
	using Task = std::function<void()>;
	using TaskInfo = std::tuple<Task, std::promise<void>>;

	/// Per-worker state: the thread, its pending tasks and their synchronisation.
	struct ThreadInfo {
		std::atomic_bool running{false};
		std::thread thread;
		std::list<TaskInfo> tasks;
		std::mutex mutex;
		std::condition_variable cv;

		ThreadInfo() {}
		// Transfers nothing: it only exists so std::vector<ThreadInfo>::resize
		// compiles (mutex, condition_variable and atomic are not movable).
		// config() always resizes an empty vector, so no live element is moved.
		ThreadInfo(ThreadInfo&&) {}
	};

private:
	std::vector<ThreadInfo> _threads;
	std::unordered_map<std::thread::id, std::reference_wrapper<ThreadInfo>> _threadFromId;

public:
	StaticPool() {}
	/// Creates a new pool with as many workers as @p o (threads are not shared).
	StaticPool(StaticPool const& o) {
		config(o._threads.size());
	}
	/// Same as the copy constructor: @p o is left running and unchanged.
	StaticPool(StaticPool&& o) {
		config(o._threads.size());
	}
	~StaticPool() {
		terminate();
	}

	/// Restarts this pool with as many workers as @p o.
	StaticPool const& operator=(StaticPool const& o) {
		if(this == &o) return *this;
		config(o._threads.size());
		return *this;
	}
	/// Same as copy assignment: @p o is left running and unchanged.
	StaticPool const& operator=(StaticPool&& o) {
		if(this == &o) return *this;
		config(o._threads.size());
		return *this;
	}

	/**
	 * @brief Stops the current workers, then starts @p cores new ones.
	 *
	 * Tasks still queued on the previous workers are discarded; their futures
	 * become ready with a broken_promise error. With @p cores == 0 the pool is
	 * left without workers.
	 */
	void config(unsigned int cores) {
		terminate();
		if(cores == 0) return;
		_threads.resize(cores);
		for(auto& threadInfo: _threads) {
			threadInfo.running = true;
			threadInfo.thread = std::thread{[this, &threadInfo] { worker(threadInfo); }};
			_threadFromId.emplace(threadInfo.thread.get_id(), threadInfo);
		}
	}

	/**
	 * @brief Queues @p task on worker @p i.
	 * @param i index of the target worker; must be lower than the configured
	 *          worker count (not checked).
	 * @param task nullary callable returning void.
	 * @return a future that becomes ready once the task has run; if the task
	 *         threw, the future holds that exception.
	 */
	template<typename F, typename R = tmp::invoke_result_t<F>, std::enable_if_t<std::is_same<R, void>{}>* = nullptr>
	std::future<void> run(std::size_t i, F&& task) {
		ThreadInfo& threadInfo = _threads[i];
		std::future<void> future;
		{
			std::lock_guard<std::mutex> lg{threadInfo.mutex};
			threadInfo.tasks.emplace_back(std::forward<F>(task), std::promise<void>{});
			future = std::get<1>(threadInfo.tasks.back()).get_future();
		}
		threadInfo.cv.notify_one();
		return future;
	}

	/**
	 * @brief Blocks until every future in @p futures is ready, then empties it.
	 *
	 * If the caller is one of this pool's workers, its own queue is drained
	 * inline first. After all futures have completed and @p futures has been
	 * cleared, the first exception stored in any of them (task failure or
	 * broken promise) is rethrown.
	 */
	template<typename Futures>
	void wait(Futures& futures) {
		auto const found = _threadFromId.find(std::this_thread::get_id());
		if(found != _threadFromId.end()) {
			while(tryProcessOne(found->second));
		}

		std::exception_ptr firstError;
		for(auto& future: futures) {
			try {
				future.get(); // blocks until ready, rethrows a stored exception
			} catch(...) {
				if(!firstError) firstError = std::current_exception();
			}
		}
		futures.clear();
		if(firstError) std::rethrow_exception(firstError);
	}

protected:
	/// Asks every worker to stop, joins them and forgets them.
	void terminate() {
		for(auto& threadInfo: _threads) {
			{
				std::lock_guard<std::mutex> lg{threadInfo.mutex};
				threadInfo.running = false;
			}
			threadInfo.cv.notify_all();
			// A thread may never have been started if config() threw part-way
			// (e.g. std::thread construction failed); join() would then throw.
			if(threadInfo.thread.joinable()) threadInfo.thread.join();
		}
		_threads.clear();
		_threadFromId.clear();
	}

	/// Worker thread loop: runs queued tasks until the worker is asked to stop.
	void worker(ThreadInfo& threadInfo) {
		auto const ready = [&]{ return !threadInfo.running || !threadInfo.tasks.empty(); };
		for(;;) {
			TaskInfo taskInfo;
			{
				std::unique_lock<std::mutex> lk{threadInfo.mutex};
				threadInfo.cv.wait(lk, ready);
				// A stop request wins over pending tasks, which are discarded.
				if(!threadInfo.running) return;
				taskInfo = std::move(threadInfo.tasks.front());
				threadInfo.tasks.pop_front();
			}
			process(taskInfo);
		}
	}

	/// Runs one queued task of @p threadInfo on the calling thread.
	/// @return false if the queue was empty.
	bool tryProcessOne(ThreadInfo& threadInfo) {
		TaskInfo taskInfo;
		{
			std::unique_lock<std::mutex> lk{threadInfo.mutex};
			if(threadInfo.tasks.empty()) return false;
			taskInfo = std::move(threadInfo.tasks.front());
			threadInfo.tasks.pop_front();
		}
		process(taskInfo);
		return true;
	}

	/// Runs a task and fulfils its promise. An exception thrown by the task is
	/// stored in the promise: letting it escape a worker thread would call
	/// std::terminate.
	void process(TaskInfo& taskInfo) {
		auto& promise = std::get<1>(taskInfo);
		try {
			std::get<0>(taskInfo)();
		} catch(...) {
			promise.set_exception(std::current_exception());
			return;
		}
		promise.set_value();
	}
};
}
}
}
#endif