Source code for NEDAS.utils.parallel

import os
import sys
from functools import wraps
from typing import Any, TypeVar, Callable, ParamSpec, Sequence
import time
from concurrent.futures import ProcessPoolExecutor, process
import threading
import traceback
import numpy as np

class Comm:
    """
    Communicator class supporting both serial and MPI programs.

    When the python program is started with MPI environment, for example::

        $ mpirun -n 10 python -m mpi4py program.py

    A communicator can be obtained from the mpi4py package:

    >>> from mpi4py import MPI
    >>> comm = MPI.COMM_WORLD

    However, when the program is run in serial (no MPI launcher variable is
    found in the environment, or ``mpi4py`` cannot be imported), a
    ``DummyComm`` is wrapped instead, so calling code can use one interface
    in both cases. Any attribute not defined on this class is looked up on
    the wrapped communicator (see ``__getattr__``).

    Attributes:
        parallel_io (bool): If netCDF4.Dataset is built with parallel I/O support.
        mpi_ready (bool): True only when ``mpi4py`` was imported and
            ``MPI.COMM_WORLD`` is the wrapped communicator.
    """
    parallel_io: bool
    mpi_ready: bool = False

    def __init__(self):
        # Environment variables whose presence indicates the program was
        # started by an MPI launcher.
        launcher_vars = ('PMI_SIZE', 'OMPI_UNIVERSE_SIZE')
        launched_by_mpi = any(name in os.environ for name in launcher_vars)

        mpi_module = None
        if launched_by_mpi:
            try:
                from mpi4py import MPI  # type: ignore
            except ImportError:
                print("Warning: MPI environment found but 'mpi4py' module is not installed. Falling back to serial program for now.", flush=True)
            else:
                mpi_module = MPI

        if mpi_module is None:
            # serial program (or mpi4py missing): use the dummy communicator
            self._MPI = None
            self._comm = DummyComm()
        else:
            self._MPI = mpi_module
            self._comm = mpi_module.COMM_WORLD
            self.mpi_ready = True

        self.parallel_io = self.check_parallel_io()

        # filename -> MPI window used as a lock, so that only one processor
        # accesses a given file at a time
        self._locks = {}

    def __getattr__(self, attr):
        """
        Delegate unknown attributes to the wrapped communicator.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: If ``_comm`` itself is requested before it was
                set, or if the wrapped communicator lacks ``attr`` too.
        """
        # Asking for '_comm' here means it was never assigned; bail out
        # instead of recursing.
        if attr == '_comm':
            raise AttributeError("_comm not initialized")
        wrapped = self.__dict__.get('_comm')
        if wrapped is not None:
            missing = object()
            found = getattr(wrapped, attr, missing)
            if found is not missing:
                return found
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'")

    def _is_serial(self):
        """Return True when no real MPI communicator is in use."""
        return self._MPI is None or isinstance(self._comm, DummyComm)

    def init_file_lock(self, filename):
        """
        Initialize file locks for thread-safe I/O.

        Does nothing in serial mode, for an empty ``filename``, or when a
        lock for ``filename`` already exists.

        Args:
            filename (str): Path to the file.
        """
        if self._is_serial() or not filename:
            return
        if filename in self._locks:
            return
        # Only rank 0 contributes the one-byte lock memory; every other rank
        # passes None when creating the window.
        lock_mem = np.zeros(1, dtype='B') if self.Get_rank() == 0 else None
        self._locks[filename] = self._MPI.Win.Create(lock_mem, comm=self._comm)

    def check_parallel_io(self) -> bool:
        """
        Check if netCDF4 is built with parallel I/O support.

        The check tries to open ``dummy.nc`` in the current directory for
        writing with ``parallel=True``; any exception (including a missing
        ``netCDF4`` module) counts as "not supported".

        Returns:
            bool: True if netCDF4 module support parallel I/O mode.
        """
        # NOTE(review): on success the probe file 'dummy.nc' is left in the
        # working directory - confirm whether that is intended.
        supported = False
        try:
            from netCDF4 import Dataset
            with Dataset('dummy.nc', mode='w', parallel=True):
                supported = True
        except Exception:
            supported = False
        return supported

    def cleanup_file_locks(self):
        """
        Free every lock window and forget all locks.

        Failures are reported by printing and never propagate: a window that
        fails to free produces a warning and the loop continues.
        """
        try:
            for path, window in self._locks.items():
                try:
                    window.Free()
                except Exception as exc:
                    print(f"Rank {self.Get_rank()}: warning freeing win for {path}: {exc}", flush=True)
            self._locks.clear()
        except Exception as exc:
            print(f"Rank {self.Get_rank()}: error cleaning locks: {exc}", file=sys.stderr, flush=True)

    def acquire_file_lock(self, filename):
        """
        Block until this rank holds the lock for ``filename``.

        Does nothing in serial mode. Otherwise the lock byte on rank 0 is
        atomically replaced with 1 and its previous value inspected; a
        previous value of 0 means the lock was free and is now held.

        Args:
            filename (str): Path whose lock was created by ``init_file_lock``.
        """
        if self._is_serial():
            return
        assert filename in self._locks, f"Comm: file lock for {filename} not initialized"
        window = self._locks[filename]
        poll_interval = 0.1  # seconds between attempts, can make this configurable
        while True:
            previous = np.zeros(1, dtype='B')
            taken = np.array([1], dtype='B')
            window.Lock(0, self._MPI.LOCK_EXCLUSIVE)
            window.Fetch_and_op(taken, previous, 0, 0, self._MPI.REPLACE)
            window.Unlock(0)
            if previous[0] == 0:
                return
            time.sleep(poll_interval)

    def release_file_lock(self, filename):
        """
        Release the lock for ``filename`` by writing 0 to the lock byte.

        Does nothing in serial mode or when no lock exists for ``filename``.

        Args:
            filename (str): Path whose lock should be released.
        """
        if self._is_serial():
            return
        if filename not in self._locks:
            return
        window = self._locks[filename]
        free = np.array([0], dtype='B')
        window.Lock(0, self._MPI.LOCK_EXCLUSIVE)
        window.Put(free, 0, 0)
        window.Unlock(0)

    def finalize(self):
        """Clean up MPI resources cleanly to avoid hangs on exit."""
        # nothing to do for serial/DummyComm
        if self._is_serial():
            return
        self.cleanup_file_locks()
        self._comm.Barrier()
        self._locks = {}
        # With _MPI cleared, the lock methods above become no-ops.
        self._MPI = None
class DummyComm:
    """
    Dummy communicator for python without mpi.

    Stands in for an MPI communicator in a single-process program: there is
    exactly one rank (0), collectives hand the input straight back, and
    point-to-point messages are kept in an in-memory dict keyed by tag.

    Attributes:
        size (int): Number of ranks, always 1.
        rank (int): Rank of this process, always 0.
        buf (dict): Messages stored by ``send`` and read by ``recv``, keyed by tag.
    """

    def __init__(self):
        self.size = 1
        self.rank = 0
        self.buf = {}

    def Get_size(self):
        """Return the number of ranks (1)."""
        return self.size

    def Get_rank(self):
        """Return the rank of this process (0)."""
        return self.rank

    def Barrier(self):
        """Do nothing; there are no other ranks to wait for."""
        pass

    def Abort(self, code: int):
        """Print an abort notice and exit the interpreter with ``code``."""
        print(f"\nAbort({code}) on rank 0: application called MPI_Abort.")
        sys.exit(code)

    def Split(self, color=0, key=0):
        """Return this communicator itself; ``color`` and ``key`` are ignored."""
        return self

    def bcast(self, obj, root=0):
        """Return ``obj`` unchanged."""
        return obj

    def send(self, obj, dest, tag=0):
        """
        Store ``obj`` under ``tag`` for a later ``recv``.

        ``dest`` is ignored. ``tag`` defaults to 0 so the call can be made
        without a tag, pairing with the same default in ``recv``.
        """
        self.buf[tag] = obj

    def recv(self, source, tag=0):
        """
        Return the object last stored under ``tag``.

        ``source`` is ignored.

        Raises:
            KeyError: If nothing was sent with ``tag``.
        """
        # NOTE(review): the message stays in the buffer after being read,
        # so a second recv with the same tag returns it again - confirm
        # whether callers rely on that before switching to pop().
        return self.buf[tag]

    def allgather(self, obj):
        """Return a one-element list holding ``obj``."""
        return [obj]

    def gather(self, obj, root=0):
        """Return ``obj`` unchanged."""
        # NOTE(review): unlike allgather this returns the bare object, not
        # [obj]; left as is because callers may depend on it - confirm.
        return obj

    def allreduce(self, obj, op=None):
        """
        Return ``obj`` unchanged.

        ``op`` is accepted and ignored, so a call that names a reduction
        operation also works here; with one rank every reduction of a single
        contribution is that contribution.
        """
        return obj

    def reduce(self, obj, root=0, *, op=None):
        """
        Return ``obj`` unchanged.

        ``root`` is ignored. ``op`` is keyword-only and ignored; see
        ``allreduce``.
        """
        return obj
T = TypeVar("T") # represents the return type of a func P = ParamSpec("P") # represents the parameter list of a func
def by_rank(comm: Comm, rank: int) -> Callable[[Callable[P, T]], Callable[P, T|None]]:
    """
    Decorator for func() to be run only by the given rank in comm.

    On the chosen rank the wrapped function is called and its result
    returned. If it raises, the traceback is printed to stderr and
    ``comm.Abort(1)`` is called. Every other rank skips the call and gets
    ``None``.

    Args:
        comm (Comm): Communicator used to look up the caller's rank.
        rank (int): The rank that should execute the function.

    Returns:
        Callable: A decorator producing the rank-restricted function.
    """
    def decorator(func: Callable[P, T]) -> Callable[P, T|None]:
        @wraps(func)
        def guarded(*args: P.args, **kwargs: P.kwargs) -> T|None:
            # Ranks other than the chosen one do nothing.
            if comm.Get_rank() != rank:
                return None
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                trace = traceback.format_exc()
                print(f"\nPID {rank} raised {type(exc).__name__}: {exc}\n{trace}", file=sys.stderr, flush=True)
                comm.Abort(1)
            # Only reached if Abort returns.
            return None
        return guarded
    return decorator
def bcast_by_root(comm: Comm) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator for func() to be run only by rank 0 in comm, and result of
    func() is then broadcasted to all other ranks.

    Rank 0 calls the function and records either its return value or the
    text of the exception it raised. That record is broadcast from rank 0,
    so every rank sees the same outcome: on failure each rank calls
    ``comm.Abort(1)``, otherwise each rank returns the broadcast value.

    Args:
        comm (Comm): Communicator used for the rank check and the broadcast.

    Returns:
        Callable: A decorator producing the broadcasting function.
    """
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def broadcasting(*args: P.args, **kwargs: P.kwargs) -> T:
            outcome: dict[str, Any] = {'return': None, 'error': None}
            is_root = comm.Get_rank() == 0
            if is_root:
                try:
                    outcome['return'] = func(*args, **kwargs)
                except Exception as exc:
                    trace = traceback.format_exc()
                    print(f"\nPID 0 raised {type(exc).__name__}: {exc}\n{trace}", file=sys.stderr)
                    outcome['error'] = str(exc)
            # All ranks take part in the broadcast, including on failure, so
            # that they all learn about the error.
            outcome = comm.bcast(outcome, root=0)
            failed = outcome['error'] is not None
            if failed:
                comm.Abort(1)
            return outcome['return']
        return broadcasting
    return decorator
def distribute_tasks(comm: Comm, tasks: np.ndarray|Sequence, load: np.ndarray|Sequence|None=None) -> dict[int, list]:
    """
    Divide a list of task indices and assign a subset to each rank in comm

    Tasks keep their order; each rank receives one contiguous slice of
    ``tasks``, with slice boundaries chosen so that the cumulative load is
    split as evenly as possible.

    Args:
        comm (Comm): MPI communicator
        tasks (ArrayLike): List of task indices (to be distributed over the processors)
        load (np.ndarray, optional): Amount of workload for each task element
            The default is None, we will let tasks have equal workload

    Returns:
        dict: Dictionary {rank: subset} with one entry for every rank in
        comm; each subset is a slice of ``tasks``. When ``tasks`` is empty
        every rank gets an empty slice.

    Raises:
        ValueError: If ``load`` does not have one entry per task.
    """
    nproc = comm.Get_size()  # number of processors
    ntask = len(tasks)       # number of tasks

    # assume equal load between tasks if not specified
    if load is None:
        load = np.ones(ntask)

    # make sure load has right length
    _load = np.array(load)
    if _load.size != ntask:
        raise ValueError(f'Length of task load = {_load.size} not equal to ntask = {ntask}')

    # Nothing to distribute: the search below would index into an empty
    # cumulative-load array, so hand every rank an empty slice instead.
    if ntask == 0:
        return {r: tasks[0:0] for r in range(nproc)}

    # normalize to get load distribution function
    _load = _load / np.sum(_load)

    # cumulative load distribution, rounded to 5 decimals
    cum_load = np.round(np.cumsum(_load), decimals=5)

    # given the load distribution function, we assign load to processors
    # by evenly divide the distribution into nproc parts
    # this is done by searching for r/nproc in the cumulative load for rank r
    # task_id holds the start/end index of task for each rank in a sequence
    task_id = np.zeros(nproc+1, dtype=int)

    target_cum_load = np.arange(nproc)/nproc  # we want even distribution of load

    tol = 0.1/nproc  # allow some tolerance for rounding error in comparing cum_load to target_cum_load
    ind1 = np.searchsorted(cum_load+tol, target_cum_load, side='right')
    ind2 = np.searchsorted(cum_load-tol, target_cum_load, side='right')

    # choose between ind1,ind2, whoever gives best match between cum_load[ind?] and target_cum_load
    task_id[0:-1] = np.where(np.abs(cum_load[ind1-1]-target_cum_load) < np.abs(cum_load[ind2-1]-target_cum_load), ind1, ind2)

    # make sure the two end points are right
    task_id[0] = 0
    task_id[-1] = ntask

    # dict for each rank r -> its own task list given start/end index
    return {r: tasks[task_id[r]:task_id[r+1]] for r in range(nproc)}
[docs] class OfflineScheduler: """ An offline scheduler class for queuing and running multiple jobs on available workers (group of processors). The jobs are submitted by one processor with the scheduler, while the job.run code is calling subprocess to be run on the worker """ def __init__(self, c, nworker: int, walltime: int|None=None, check_dt: float=0.1, debug: bool=False) -> None: self.nworker = nworker self.available_workers = list(range(nworker)) self.walltime = walltime self.check_dt = check_dt self.debug = debug self.jobs = {} self.queue_open = True self.running_jobs = [] self.pending_jobs = [] self.completed_jobs = [] self.error_jobs = {} self.njob = 0 self.c = c self.executor = ProcessPoolExecutor( max_workers=nworker, initializer=os.setpgrp, )
[docs] def submit_job(self, name: str, job: Callable, *args, **kwargs) -> None: """ Submit a job to the scheduler, hold info in jobs dict. Args: name (str): unique name to identify this job job (Callable): callable with is_running and kill methods ``*args``, ``**kwargs``: passed into job() """ self.jobs[name] = {'worker_id':None, 'start_time':None, 'job':job, 'args': args, 'kwargs': kwargs, 'future':None } self.pending_jobs.append(name) self.njob += 1 self.c.debug_message = f"Scheduler: Job {name} added: {job.__name__}, args={args}, kwargs={kwargs})"
[docs] def monitor_job_queue(self) -> None: """ Monitor the available_workers and pending_jobs, assign a job to a worker if possible Monitor the running_jobs for jobs that are finished, kill jobs that exceed walltime, and move the finished jobs to completed_jobs """ self.c.total_tasks = self.njob + 1 while self.queue_open and len(self.completed_jobs) < self.njob: # assign pending job to available workers while self.available_workers and self.pending_jobs and self.queue_open: worker_id = self.available_workers.pop(0) name = self.pending_jobs.pop(0) info = self.jobs[name] info['worker_id'] = worker_id info['start_time'] = time.time() try: info['future'] = self.executor.submit(info['job'], *info['args'], worker_id=worker_id, **info['kwargs']) self.running_jobs.append(name) self.c.debug_message = f"Scheduler: Job {name} started by worker {worker_id}" except (process.BrokenProcessPool, RuntimeError): return # if there are completed jobs, free up their workers names = [name for name in self.running_jobs if self.jobs[name]['future'].done()] for name in names: # catch errors from job try: self.jobs[name]['future'].result() except Exception as e: tb = traceback.format_exc() self.c.debug_message = f'Scheduler: Job {name} raised {type(e).__name__}: {e}\n{tb}' self.error_jobs[name] = tb #return # #if exit right away and don't wait for other jobs to finish, uncomment this self.running_jobs.remove(name) self.completed_jobs.append(name) self.available_workers.append(self.jobs[name]['worker_id']) self.c.debug_message = f"Scheduler: Job {name} completed" # kill jobs that exceed walltime if self.walltime is not None: for name in self.running_jobs: elapsed_time = time.time() - self.jobs[name]['start_time'] if elapsed_time > self.walltime: self.jobs[name]['future'].cancel() self.running_jobs.remove(name) self.available_workers.append(self.jobs[name]['worker_id']) e = RuntimeError(f'Scheduler: Job {name} exceeds walltime ({self.walltime}s)') self.error_jobs[name] = e 
self.completed_jobs.append(name) # log the progress info and let context handle the messaging self.c.current_task = len(self.completed_jobs) self.c.message = f"{len(self.completed_jobs)}/{self.njob} jobs done, {len(self.running_jobs)} running" time.sleep(self.check_dt) self.c.message = f"all {self.njob} jobs done"
[docs] def start_queue(self): """ Start the job queue, and wait for jobs to complete """ try: monitor_thread = threading.Thread(target=self.monitor_job_queue) monitor_thread.daemon = True monitor_thread.start() monitor_thread.join() finally: self.queue_open = False self.shutdown()
[docs] def shutdown(self): # determine if we need to kill workers immediately kill = (len(self.error_jobs) > 0) # shutdown the process pool workers self.executor.shutdown(wait=not kill, cancel_futures=kill) # raise errors within jobs if there are any if self.error_jobs: error_details = "\n".join([f"ERROR: Job {job}: {error}" for job, error in self.error_jobs.items()]) raise RuntimeError(f'Scheduler: there are jobs with errors:\n{error_details}')