From ac94d79b38a30149b6d839f3d548dfa67e9376d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Kr=C3=BCger?= Date: Tue, 1 Dec 2020 18:33:21 +0100 Subject: [PATCH] fixed threading/multiprocessing --- al2/experiment.py | 69 +++++++++++++++++++++++++++++++++++------------ al2/plan.py | 4 +-- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/al2/experiment.py b/al2/experiment.py index 6b06649..91a98fc 100644 --- a/al2/experiment.py +++ b/al2/experiment.py @@ -5,6 +5,7 @@ import multiprocessing import threading import concurrent.futures as concfut import os +import time from . import batch from . import plan @@ -20,42 +21,76 @@ def load(exp_file): with open(exp_file) as efile: exp_obj = json.loads(efile.read()) exp_obj["load"] = pl.Path(exp_obj["load"]) - + exp_mod = impmach.SourceFileLoader(exp_obj["load"].stem, str(exp_obj["load"])).load_module() - return Dispatcher(exp_mod.run, exp_plan, os.cpu_count()) + num_workers = 1 + + if "workers" in exp_obj: + if exp_obj["workers"] == "all": + num_workers = os.cpu_count() + else: + num_workers = int(exp_obj["workers"]) + + return Dispatcher(exp_mod, exp_plan, num_workers) class Dispatcher (threading.Thread): - def __init__(self, exp_func, exp_plan, num_workers): + def __init__(self, exp_mod, exp_plan, num_workers): threading.Thread.__init__(self) - - self.__exp_func = exp_func - self.__plan = exp_plan self.__num_workers = num_workers self.__workers = [] + self.__stop_called = threading.Event() for i in range(self.__num_workers): - self.__workers.append(multiprocessing.Process(target=self.__run_exp, - args=(self.__exp_func, - self.__plan))) + self.__workers.append(Worker(exp_mod, exp_plan)) def run(self): for worker in self.__workers: worker.start() - for worker in self.__workers: + def wait_to_continue(workers, stop_called): + any_worker_alive = any(map(lambda w: w.is_alive(), workers)) + + while any_worker_alive and not stop_called.is_set(): + time.sleep(0) + + waiter = threading.Thread(target=wait_to_continue, + args=(self.__workers, + self.__stop_called)) + + waiter.start() + waiter.join() + + if self.__stop_called.is_set(): + for worker in self.__workers: + worker.terminate() + + for worker in self.__workers: worker.join() - @staticmethod - def __run_exp(exp_func, exp_plan): - instance = exp_plan.next() + def stop(self): + self.__stop_called.set() + +class Worker (multiprocessing.Process): + def __init__(self, exp_mod, exp_plan): + multiprocessing.Process.__init__(self) + + self.__exp_mod = exp_mod + self.__exp_plan = exp_plan + + def run(self): + instance = self.__exp_plan.next() + + while instance != None: + self.__exp_mod.run(instance) + self.__exp_plan.done_with(instance) - while instance != None: - exp_func(instance) + instance = self.__exp_plan.next() - exp_plan.done_with(instance) + def terminate(self): + self.__exp_plan.delete() + multiprocessing.Process.terminate(self) - instance = exp_plan.next() diff --git a/al2/plan.py b/al2/plan.py index 82a8911..f5f0956 100644 --- a/al2/plan.py +++ b/al2/plan.py @@ -104,8 +104,8 @@ class Plan: elif self.file.is_file(): self.file.unlink() - def __del__(self): - + #def __del__(self): + def delete(self): with self.__lock: self.__load()