import os
import sys
import signal
import tempfile
import inspect
from typing import Callable
from abc import ABC, abstractmethod
from NEDAS.job_submitters.hpc import HPCJobSubmitter
from NEDAS.utils.parallel import OfflineScheduler
from NEDAS.datasets.synthetic import SyntheticObs
from NEDAS.config import Config
from NEDAS.core.context import Context
from NEDAS.core.types import EnsRunStrategy, IOTag
[docs]
class Scheme(ABC):
    """
    Runtime scheme base class.

    The Scheme coordinates all runtime generation and manipulation of objects.
    Subclasses describe their workflow by implementing ``run_all``; this base
    class handles configuration/context setup, the environment check at start
    up, dispatching of individual steps, and running per-member ensemble tasks.
    """
    config: Config
    # True when config.io_mode == 'online' (set in __init__)
    online_mode: bool
    # switched to True in __init__ when any dataset in the context is a SyntheticObs
    use_synthetic_obs: bool = False
    # step name -> whether run_step should elevate that step to an mpi environment
    # in offline mode; steps not listed here are treated as not needing mpi
    steps_need_mpi: dict[str, bool] = {}
    _context: Context|None = None
    # the active OfflineScheduler while offline ensemble tasks are running, else None
    scheduler: OfflineScheduler|None = None

    def __init__(self, config: Config|None=None,
                 config_file: str|None=None,
                 parse_args: bool=False,
                 **kwargs) -> None:
        """
        Build the configuration and runtime context, then install exit signal handlers.

        Args:
            config: an existing Config object; when given it is used as is.
            config_file: passed to ``Config`` when ``config`` is not given.
            parse_args: passed to ``Config`` when ``config`` is not given.
            **kwargs: extra keyword arguments passed to ``Config`` when ``config`` is not given.
        """
        # parse configuration and generate context
        # (identity check: a Config object is used even if it happens to evaluate as falsy)
        if config is not None:
            self.config = config
        else:
            self.config = Config(config_file=config_file, parse_args=parse_args, **kwargs)
        self.c = Context(self.config)

        # check if io mode is online:
        self.online_mode = (self.config.io_mode == 'online')

        # check if one or more of the datasets is synthetic type
        # (only ever switches the flag on, so a class-level default is left alone otherwise)
        if any(isinstance(dataset, SyntheticObs) for dataset in self.c.datasets.values()):
            self.use_synthetic_obs = True

        # remember which process created the scheme, so that the signal handler
        # only runs the shutdown procedure there
        self._main_pid = os.getpid()
        for sig in (signal.SIGTERM, signal.SIGHUP):
            signal.signal(sig, self._handle_exit_signal)

    def _handle_exit_signal(self, signum, frame):
        """ The 'trap' handler. This runs when the OS sends kill signal. """
        # only the main process should run the shutdown procedure
        if os.getpid() != self._main_pid:
            return

        self.c.log_event(f"{self.__class__.__name__}: Received signal {signum}. Cleaning up...", flag='info')

        # 1. kill the scheduler worker if it exists
        if self.scheduler:
            # `_processes` is a private attribute of the scheduler's executor; tolerate it
            # being missing or None so that step 2 below is still reached.
            # NOTE(review): presumably a pid-keyed mapping of a process pool -- confirm.
            executor = getattr(self.scheduler, 'executor', None)
            processes = getattr(executor, '_processes', None) or {}
            worker_pids = list(processes.keys())
            self.c.log_event(f"OfflineScheduler: Cleaning up worker processes {worker_pids}", flag='info')
            for pid in worker_pids:
                try:
                    os.killpg(pid, signal.SIGKILL)
                except Exception:
                    # best effort: keep going with the remaining workers
                    pass

        # 2. kill the entire main process group
        try:
            os.killpg(0, signal.SIGKILL)
        except Exception:
            # best effort: nothing more can be done from inside the handler
            pass

    @property
    def c(self):
        """ The runtime context, with lazy initialization """
        if self._context is None:
            self._context = Context(self.config)
        return self._context

    @c.setter
    def c(self, value: Context):
        """ Allows refreshing the context from subprocess returns """
        self._context = value

    def __call__(self) -> None:
        """
        The entry point that handles the environment check and starts the engine.

        In online mode ``run_all`` is run directly when mpi is ready or nproc==1,
        otherwise the whole scheme is dispatched through ``external_call``.
        In offline mode ``run_all`` is run directly.

        Raises:
            RuntimeError: in offline mode, if an mpi environment with more than
                one process is detected.
        """
        self.c.show_greeting()
        self.c.print_1p("\nINITIALIZING...\n")
        self.c.show_summary()

        # Environment check:
        # 1. Online mode: (requires mpi environment if nproc>1)
        if self.online_mode:
            if self.c.comm.mpi_ready or self.config.nproc == 1:
                # we are already inside the mpi environment, proceed
                self.c.logger(self.__class__.__name__)(self.run_all)()
            else:
                # if not, we will dispatch the whole scheme itself to a job submitter.
                if self.c.debug:
                    self.c.log_event(f"run_all: config.nproc={self.config.nproc}, elevating to a mpi-enabled environment...", flag='info')
                self.external_call(step='run_all', parallel_mode='mpi', nproc=self.config.nproc)

        # 2. offline mode (manual dispatch per step)
        else:
            # check if we are accidentally inside an mpi environment already
            comm_size = self.c.comm.Get_size()
            if self.c.comm.mpi_ready and comm_size > 1:
                raise RuntimeError(f"Running in offline mode, but an mpi environment comm.size={comm_size} is detected. "
                                   "The main program should be run in serial.")
            # in offline io mode, each step in run_all will decide how to dispatch itself
            self.c.logger(self.__class__.__name__)(self.run_all)()

    @abstractmethod
    def run_all(self):
        """
        A scheme must implement a run_all method to describe the workflow.
        """
        ...

    def external_call(self, step: str|None=None, **kwargs):
        """
        Run the scheme from an external call.

        Saves the current context to a temporary config file in ``config.work_dir``,
        then submits a job (through ``self.c.run_job``) that runs the script file
        defining this scheme's class with that config file and, when given,
        the ``--step`` option.

        Args:
            step: name of the step to run; also used as the job name.
            **kwargs: extra job options; they override the defaults built here.
        """
        script_file = os.path.abspath(inspect.getfile(self.__class__))

        # create a temporary config yaml file to hold c, and pass into program through runtime arg
        # the file is removed when the with-block exits, so the job must run inside it
        with tempfile.NamedTemporaryFile(dir=self.config.work_dir,
                                         prefix='config-',
                                         suffix='.yml') as tmp_config_file:
            self.c.dump_config(tmp_config_file.name)
            if self.config.debug:
                self.c.log_event(f"config file: {tmp_config_file.name}", flag='info')

            # build run commands for running this scheme's script
            # NOTE(review): JOB_EXECUTE looks like a placeholder resolved by the job submitter -- confirm
            commands = ""
            if self.config.python_env:
                commands = f". {self.config.python_env}; "
            commands += f"JOB_EXECUTE {sys.executable} {script_file} -c {tmp_config_file.name}"
            if step:
                commands += f" --step {step}"
            if self.config.debug:
                self.c.log_event(f"running commands: '{commands}'", flag='info')

            # build job options (later entries take precedence over earlier ones)
            job_opts = {
                **(self.config.job_submit or {}),
                'job_name': step,
                'run_dir': self.c.fs.cycle_dir(self.c.time),
                'nproc': self.config.nproc,
                'debug': self.config.debug,
                **kwargs,
            }

            # run job
            self.c.run_job(commands, **job_opts)

    def run_step(self, step: str) -> None:
        """
        Manages how to run a specified step in the workflow.

        Args:
            step: name of a method of this scheme.

        Raises:
            NotImplementedError: if the scheme has no attribute named ``step``.
        """
        if not hasattr(self, step):
            raise NotImplementedError(f"Step '{step}' is not implemented for {self.__class__.__name__}")

        # in offline mode, run_step starts in serial
        # if the step requires mpi for nproc>1, make an external call
        # (a step that is not listed in steps_need_mpi is run in place)
        if not self.online_mode and self.steps_need_mpi.get(step, False):
            if self.config.nproc > 1 and not self.c.comm.mpi_ready:
                if self.c.debug:
                    self.c.log_event(f"{step}: config.nproc={self.config.nproc}, elevating to a mpi-enabled environment...", flag='info')
                self.external_call(step, parallel_mode='mpi', nproc=self.config.nproc)
                return

        # otherwise, just call the step func
        stepfunc = getattr(self, step)
        if step == 'run_all':
            func_name = self.__class__.__name__
        else:
            func_name = f'Running {step} step'
        self.c.logger(func_name)(stepfunc)()

    def run_ensemble_tasks(self, strategy: EnsRunStrategy,
                           tag: IOTag,
                           task_name: str,
                           func: Callable,
                           **opts) -> None:
        """
        Run ``func`` for the ensemble using the given strategy.

        Args:
            strategy: 'batch' (one call handles the entire ensemble) or
                'scheduler' (one call per member; scheduled within mpi in online
                mode, or through an OfflineScheduler otherwise).
            tag: io tag passed on to ``self.c.io.call_method``.
            task_name: label used in debug messages and job names.
            func: the task function passed on to ``self.c.io.call_method``.
            **opts: options passed on to the task.

        Raises:
            ValueError: if ``strategy`` is not one of the known strategies.
        """
        if strategy == 'batch':
            self._run_ensemble_tasks_batch(tag, task_name, func, **opts)
        elif strategy == 'scheduler':
            if self.online_mode:
                self._run_ensemble_tasks_online(tag, task_name, func, **opts)
            else:
                self._run_ensemble_tasks_offline_scheduler(tag, task_name, func, **opts)
        else:
            raise ValueError(f"Unknown ensemble run strategy '{strategy}'")

    def _run_ensemble_tasks_batch(self, tag: IOTag, task_name: str, func: Callable, **opts) -> None:
        """ Run the task once; func should handle the entire ensemble in one go. """
        self.c.debug_message = f"running {task_name} in batch mode..."
        # make sure nens is defined in opts
        opts['nens'] = self.c.nens
        self.c.io.call_method(self.c, tag, func, **opts)

    def _run_ensemble_tasks_online(self, tag: IOTag, task_name: str, func: Callable, **opts) -> None:
        """ Run the task for each member assigned to this rank, then synchronize. """
        # scheduling internally within mpi environment
        # using the mem_list (member lists distributed on pid ranks by comm)
        my_members = self.c.mem_list[self.c.pid_mem]
        self.c.total_tasks = len(my_members)
        for m, mem_id in enumerate(my_members):
            opts['member'] = mem_id
            self.c.debug_message = f"running {task_name} for mem{mem_id+1:03}"
            self.c.current_task = m
            self.c.io.call_method(self.c, tag, func, **opts)
        self.c.comm.Barrier()

    def _run_ensemble_tasks_offline_scheduler(self, tag: IOTag, task_name: str, func: Callable, **opts) -> None:
        """
        Run the task for every member through an OfflineScheduler.

        Reads ``opts['nproc']`` (required), and optionally ``opts['total_nproc']``
        (defaults to ``config.nproc``) and ``opts['walltime']``.
        """
        # setup an offline scheduler to distribute tasks
        # get number of available workers to initialize the scheduler
        total_nproc = opts.get('total_nproc', self.config.nproc)
        if opts['nproc'] > 1 and isinstance(self.c.jsub, HPCJobSubmitter) and not self.c.jsub.in_job_allocation:
            # the scheduling is then delegated to HPC's scheduler (each task submitted as a separate job)
            # here, the offline scheduler should just submit all tasks at once
            nworker = self.c.nens
        else:
            assert total_nproc >= opts['nproc'], f"requested nproc ({opts['nproc']}) exceeds available total_nproc {total_nproc}"
            nworker = total_nproc // opts['nproc']

        # initialize the scheduler
        self.c.debug_message = f"running {task_name} in offline scheduler: nworker={nworker}"
        self.scheduler = OfflineScheduler(self.c, nworker, opts.get('walltime'), debug=self.config.debug)

        # submit jobs
        for mem_id in range(self.c.nens):
            job_opts = {
                **opts,
                'member': mem_id,
                'debug': self.config.debug,
            }
            self.scheduler.submit_job(f"{task_name}_mem{mem_id+1:03}", self.c.io.call_method, self.c, tag, func, **job_opts)

        try:
            self.scheduler.start_queue()
        finally:
            # always release the scheduler, and clear the reference the signal handler looks at
            self.scheduler.shutdown()
            self.scheduler = None
        self.c.comm.Barrier()