From c7fb1740d96e0d78d0e774a33464487373d0450e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Kr=C3=BCger?= Date: Thu, 14 Mar 2024 16:54:48 +0100 Subject: [PATCH] implemented seed pass-through --- alma/experiment.py | 40 ++++++++++---------- alma/plan.py | 94 +++++++++++++++++++++++++++------------------- 2 files changed, 77 insertions(+), 57 deletions(-) diff --git a/alma/experiment.py b/alma/experiment.py index 7696916..3632194 100644 --- a/alma/experiment.py +++ b/alma/experiment.py @@ -6,15 +6,19 @@ import threading import concurrent.futures as concfut import os import time +import random +import sys from . import batch from . import plan + def execute(exp_file): dispatcher = load(exp_file) dispatcher.start() dispatcher.join() + def load(exp_file): exp_plan = plan.Plan(exp_file, multiprocessing.Lock()) @@ -32,9 +36,10 @@ def load(exp_file): num_workers = os.cpu_count() else: num_workers = int(exp_obj["workers"]) - + return Dispatcher(exp_mod, exp_plan, num_workers) + class Dispatcher (threading.Thread): def __init__(self, exp_mod, exp_plan, num_workers): threading.Thread.__init__(self) @@ -46,36 +51,35 @@ class Dispatcher (threading.Thread): self.__exp_mod = exp_mod for i in range(self.__num_workers): - self.__workers.append(Worker(exp_mod, exp_plan, i)) + self.__workers.append(Worker(exp_mod, + exp_plan, + i)) def run(self): for worker in self.__workers: worker.start() - + def wait_to_continue(workers, stop_called): - any_worker_alive = lambda: any(map(lambda w: w.is_alive(), workers)) + def any_worker_alive(): any(map(lambda w: w.is_alive(), workers)) while any_worker_alive() and not stop_called.is_set(): time.sleep(0) - waiter = threading.Thread(target=wait_to_continue, args=(self.__workers, self.__stop_called)) - + waiter.start() waiter.join() - + if self.__stop_called.is_set(): for worker in self.__workers: worker.terminate() - - for worker in self.__workers: + for worker in self.__workers: worker.join() self.__done() - def stop(self): self.__stop_called.set() @@ -95,31 +99,29 @@ class Dispatcher (threading.Thread): class Worker (multiprocessing.Process): def __init__(self, exp_mod, exp_plan, id): multiprocessing.Process.__init__(self) - + self.__exp_mod = exp_mod self.__exp_plan = exp_plan self.__id = id def run(self): instance = self.__exp_plan.next() - print(instance) - while instance != None: + + while instance is not None: instance_state = self.__exp_plan.load_instance_state(instance) - + self.__exp_mod.run(instance, lambda data: self.__exp_plan.save_instance_state( instance, data - ), + ), instance_state, worker_id=self.__id) - + self.__exp_plan.done_with(instance) instance = self.__exp_plan.next() - + def terminate(self): self.__exp_plan.delete() multiprocessing.Process.terminate(self) - - diff --git a/alma/plan.py b/alma/plan.py index 58af20c..d35ad13 100644 --- a/alma/plan.py +++ b/alma/plan.py @@ -3,9 +3,12 @@ import json import os import multiprocessing import threading +import random +import sys from . import batch + class Plan: def __init__(self, experiment=None, lock=None): self.experiment = None @@ -13,15 +16,14 @@ class Plan: self.pending_instances = [] self.assigned_instances = [] self.instance_states = {} - + self.__instance_id_counter = 0 - self.__lock = threading.Lock() if lock == None else lock + self.__lock = threading.Lock() if lock is None else lock if experiment: self.create(experiment) - def create(self, experiment): self.experiment = pl.Path(experiment).resolve() self.__set_file() @@ -40,62 +42,65 @@ class Plan: with self.__lock: self.__update_file() - def __create_content(self, iterations_left = None): + def __create_content(self, iterations_left=None): content = {} with open(self.experiment, "r") as expf: exp_obj = json.loads(expf.read()) instances = batch.load(pl.Path(exp_obj["batch"])) - - if iterations_left == None: + + if iterations_left is None: + if "iterations" in exp_obj: iterations_left = exp_obj["iterations"] - 1 + if "seed" in exp_obj: + random.seed(exp_obj["seed"]) + else: iterations_left = 0 - content["pending"] = instances content["iterations_left"] = iterations_left return content - def __set_file(self): - if self.experiment == None: + if self.experiment is None: self.file = None else: exp_path = pl.Path(self.experiment) - self.file = exp_path.parent / (exp_path.stem + ".plan") + self.file = exp_path.parent / (exp_path.stem + ".plan") def __load(self): self.pending_instances = [] self.assigned_instances = [] - + if not self.file.is_file(): return with open(self.file, "r") as pfile: content = json.loads(pfile.read()) - + if "assigned" in content: self.assigned_instances = content["assigned"] self.__instance_id_counter = max(map(lambda i: i["id"], self.assigned_instances)) + 1 - + if "pending" in content: self.pending_instances = content["pending"] if "iterations_left" in content: self.iterations_left = content["iterations_left"] - + if "instance_states" in content: self.instance_states = content["instance_states"] - - + + if "rand_state" in content: + random.setstate(self.__arr2tup(content["rand_state"])) + def __is_finished(self): return False if self.file.is_file() else True - def next(self): @@ -107,25 +112,26 @@ class Plan: self.__load_next_iteration() else: return None - + next_instance = self.pending_instances.pop() next_instance["id"] = self.__instance_id_counter + next_instance["seed"] = random.randint(0, sys.maxsize) self.__instance_id_counter += 1 self.assigned_instances.append(next_instance) - + self.__update_file() - + return next_instance - + def done_with(self, instance): - + with self.__lock: self.__load() self.assigned_instances = list(filter(lambda i: i["id"] != instance["id"], - self.assigned_instances )) - + self.assigned_instances)) + if str(instance["id"]) in self.instance_states: self.instance_states.pop(str(instance["id"])) @@ -133,13 +139,13 @@ class Plan: def __update_file(self): content = {} - + all_done = True content["iterations_left"] = self.iterations_left - + content["instance_states"] = self.instance_states - + if len(self.assigned_instances) > 0: content["assigned"] = self.assigned_instances all_done = False @@ -147,7 +153,9 @@ class Plan: if len(self.pending_instances) > 0: content["pending"] = self.pending_instances all_done = False - + + content["rand_state"] = random.getstate() + if all_done: if self.iterations_left > 0: self.__load_next_iteration() @@ -156,9 +164,9 @@ class Plan: else: self.__write_content(content) - def __load_next_iteration(self): + def __load_next_iteration(self): content = self.__create_content(self.iterations_left - 1) - + self.pending_instances = content["pending"] self.iterations_left = content["iterations_left"] @@ -168,26 +176,36 @@ class Plan: with open(self.file, "w") as pfile: pfile.write(json.dumps(content)) - + + def __serialize_rand_state(self): + return json.dumps(random.getstate()) + + def __arr2tup(self, arr): + for i, e in enumerate(arr): + if type(e) is list: + arr[i] = self.__arr2tup(e) + + return tuple(arr) + def save_instance_state(self, instance, data): - + with self.__lock: self.__load() - + self.instance_states[str(instance["id"])] = data - + self.__update_file() - + def load_instance_state(self, instance): - + with self.__lock: self.__load() - + if str(instance["id"]) in self.instance_states: return self.instance_states[str(instance["id"])] else: return "" - + def delete(self): with self.__lock: self.__load()