from __future__ import annotations
import os
import sys
import shutil
import copy
import time
from typing import get_args, Callable, TYPE_CHECKING
from functools import wraps
import numpy as np
from datetime import datetime, timedelta
from pyproj import Proj
from NEDAS import __version__
from NEDAS.utils import parallel, progress
from NEDAS.config import Config
from NEDAS import grid, models, datasets, assim_tools, io_backends, job_submitters
from .file_system import FileSystem
from .types import ProcIDMem, MemID, ParallelMode
if TYPE_CHECKING:
from . import Model, Dataset, IOBackend, JobSubmitter, State, Obs, Transform, Inflation, Assimilator, Updator
[docs]
class Context:
    """
    Runtime context manages the generation and interaction of dynamic objects in runtime

    The annotations below declare the instance attributes. Most of them are
    assigned by ``__init__`` and the ``set_*`` / ``update_assim_tools`` methods
    of this class.
    """
    # --- runtime environment and logging ---
    interactive: bool  # result of check_interactive(), assigned in __init__
    is_notebook: bool  # result of check_notebook(), assigned in __init__
    debug: bool  # copied from config.debug in __init__
    # --- parallel communicators (assigned in set_comm) ---
    comm: parallel.Comm  # communicator created by parallel.Comm()
    comm_rec: parallel.Comm  # comm.Split(pid_mem, pid_rec), or comm itself when MPI is not ready
    comm_mem: parallel.Comm  # comm.Split(pid_rec, pid_mem), or comm itself when MPI is not ready
    pid_show: int  # rank whose messages print_1p shows; __init__ sets it to 0
    progress: progress.Progress  # created in set_logging
    # --- helper objects (assigned in __init__) ---
    fs: FileSystem  # FileSystem(config)
    io: IOBackend  # io_backends.get_io_backend(config.io_mode)
    jsub: JobSubmitter  # job_submitters.get_job_submitter(**config.job_submit)
    # --- ensemble ---
    nens: int  # ensemble size, copied from config.nens
    mem_list: dict[ProcIDMem, list[MemID]]  # result of distribute_mem_tasks, broadcast by root
    # --- analysis grid ---
    grid: grid.GridType  # assigned in set_grid, re-assigned in update_assim_tools
    grid_orig: grid.GridType  # the grid as assigned by set_grid
    # --- time and iteration pointers (initialized from config in __init__) ---
    time: datetime
    iter: int
    # --- model and dataset instances, keyed by name (set_models / set_datasets) ---
    models: dict[str, Model]
    datasets: dict[str, Dataset]
    # --- assimilation components (assigned in update_assim_tools) ---
    assimilator: Assimilator
    updator: Updator
    transform_funcs: list[Transform]
    localization_funcs: dict[str, Callable]
    inflation_func: Inflation
    # --- not assigned anywhere in this class; per the note at the end of
    # __init__ these are created by the scheme class at runtime ---
    state: State
    obs: Obs
def __init__(self, config: Config|None=None,
             config_file: str|None=None,
             parse_args: bool=False,
             **kwargs) -> None:
    """
    Build the runtime context from a configuration.

    Args:
        config (Config, optional): A ready-made configuration. When it is a
            :code:`Config` instance it is used as is and the other arguments are ignored.
        config_file (str, optional): Passed to :code:`Config` when no instance is given.
        parse_args (bool, optional): Passed to :code:`Config` when no instance is given.
        **kwargs: Extra keyword arguments passed to :code:`Config`.
    """
    # anything that is not already a Config instance triggers a fresh Config
    if not isinstance(config, Config):
        config = Config(config_file=config_file, parse_args=parse_args, **kwargs)
    self.config = config

    # logging: the three flags must exist before set_logging() reads them
    self.debug = self.config.debug
    self.interactive = self.check_interactive()
    self.is_notebook = self.check_notebook()
    self.set_logging()

    # current time pointer; the prev_time / next_time properties are derived from it
    self.time = self.config.time

    # current iteration
    self.iter = self.config.iter

    # rank that shows progress messages (root process by default),
    # and the cache of the previously printed message
    self.pid_show = 0
    self._prev_msg = ''

    # ensemble size
    self.nens = self.config.nens

    # parallel communicator (serial or MPI), then the member-to-processor
    # assignment, computed through the broadcast-by-root wrapper
    self.set_comm()
    distribute = parallel.bcast_by_root(self.comm)(self.distribute_mem_tasks)
    self.mem_list = distribute()

    # helper class instances
    self.fs = FileSystem(self.config)
    self.io = io_backends.get_io_backend(self.config.io_mode)
    submit_opts = self.config.job_submit or {}
    self.jsub = job_submitters.get_job_submitter(**submit_opts)

    # analysis grid, then the model and observation dataset objects
    self.set_grid()
    self.set_models()
    self.set_datasets()

    # further living objects (state, obs, other components) are created
    # by the scheme class __init__ and methods at runtime
[docs]
def check_interactive(self) -> bool:
"""
If the runtime environment supports interactive output (with ANSI escape code).
"""
# if debug mode is on, disable interactive output
# since we need to show a lot of debug messages
if self.config.debug:
return False
# if interactive option is explicitly set in config, use it
if self.config.interactive is not None:
return self.config.interactive
# otherwise, check if a tty is present, if so, should support interactive output.
return os.isatty(sys.stdout.fileno())
[docs]
def check_notebook(self) -> bool:
"""
If the runtime environment is a jupyter notebook
"""
if self.config.is_notebook is not None:
return self.config.is_notebook
return "ipykernel" in sys.modules
[docs]
def get_cols(self) -> int:
# if cols is explicitly set in config, use it
if self.config.cols is not None:
return self.config.cols
# in jupyter notebook environment, just use a large number
if self.is_notebook:
return 800
# otherwise, get real time terminal size
cols, _ = shutil.get_terminal_size()
return cols
[docs]
def set_logging(self) -> None:
    """
    Create the progress tracker and store it in :code:`self.progress`.

    The display options are taken from this context (interactive,
    is_notebook, column count) and from the configuration.
    """
    cfg = self.config
    self.progress = progress.Progress(
        interactive=self.interactive,
        is_notebook=self.is_notebook,
        cols=self.get_cols(),
        debug=cfg.debug,
        call_stack=cfg.call_stack,
        call_stack_max_level=cfg.call_stack_max_level,
        anchor=cfg.anchor,
        tabspace=cfg.tabspace,
        progress_bar_width=cfg.progress_bar_width,
    )
[docs]
def distribute_mem_tasks(self) -> dict[int, list[int]]:
    """
    Distribute mem_id across processors

    Returns:
        The result of :code:`parallel.distribute_tasks` for the member ids
        ``0 .. nens-1`` over the :code:`comm_mem` communicator.
    """
    # every member id is one task
    all_members = list(range(self.nens))
    return parallel.distribute_tasks(self.comm_mem, all_members)
[docs]
def update_assim_tools(self):
    """
    Refresh the grid and the assimilation components for the current iteration.

    The grid is re-derived from :code:`grid_orig` at the resolution level
    configured for :code:`self.iter`. The assimilator, updator, localization,
    inflation and transform components are then re-created from this context.
    """
    # resolution level configured for the current iteration
    level = self.config.resolution_level[self.iter]
    self.grid = self.grid_orig.change_resolution_level(level)

    # every getter receives the context itself, so they run after the grid
    # has been updated and in this fixed order
    tools = assim_tools
    self.assimilator = tools.assimilators.get_assimilator(self)
    self.updator = tools.updators.get_updator(self)
    self.localization_funcs = tools.localization.get_localization_funcs(self)
    self.inflation_func = tools.inflation.get_inflation_func(self)
    self.transform_funcs = tools.transforms.get_transform_funcs(self)
@property
def prev_time(self) -> datetime:
"""
Previous analysis time. Automatically updated when self.time changes.
Returns:
datetime: Previous analysis time.
"""
if self.time > self.config.time_start:
return self.time - self.config.cycle_period * timedelta(hours=1)
else:
return self.time
@property
def next_time(self) -> datetime:
"""
Next analysis time. Automatically updated when self.time changes.
Returns:
datetime: Next analysis time.
"""
return self.time + self.config.cycle_period * timedelta(hours=1)
[docs]
def set_comm(self) -> None:
    """
    Initialize the communicator and, under MPI, split it into member and record groups.

    Sets :code:`comm`, :code:`pid`, :code:`pid_mem`, :code:`pid_rec`, :code:`comm_mem` and :code:`comm_rec`.
    When the MPI environment is not ready, both ranks are 0 and both sub-communicators are :code:`comm` itself.
    Otherwise the communicator size must match the :code:`nproc` config option, and the rank is
    decomposed as ``pid = pid_rec * nproc_mem + pid_mem`` before splitting.
    See :mod:`NEDAS.utils.parallel` module for more details.

    Raises:
        RuntimeError: if the communicator size differs from :code:`nproc`.
    """
    # the communicator could be of size 1 for a serial program
    comm = parallel.Comm()
    self.comm = comm
    size = comm.Get_size()
    self.pid = comm.Get_rank()  # current processor id

    if comm.mpi_ready:
        # validate the mpi environment against the configuration
        if size != self.config.nproc:
            raise RuntimeError(f"Config nproc={self.config.nproc} does not match with MPI COMM size={size}.")

        # split comm so that nproc_mem * nproc_rec == nproc
        self.pid_rec, self.pid_mem = divmod(self.pid, self.config.nproc_mem)
        self.comm_mem = comm.Split(self.pid_rec, self.pid_mem)
        self.comm_rec = comm.Split(self.pid_mem, self.pid_rec)
    else:
        # program is not called from mpirun: nothing to split
        self.pid_mem, self.pid_rec = 0, 0
        self.comm_mem, self.comm_rec = comm, comm
[docs]
def set_grid(self) -> None:
    """
    Initialize the analysis grid based on the configuration.

    If :code:`grid_def['type']` is 'custom', a regular grid is created from the
    ``proj``, ``xmin``, ``xmax``, ``ymin``, ``ymax`` and ``dx`` entries; any other
    entries are forwarded to :code:`Grid.regular_grid` as keyword options.
    The mask starts out all False. If ``mask`` names a model and that model has
    a ``prepare_mask`` method, its result replaces the mask.

    Otherwise :code:`grid_def['type']` is a model name: the model is created with
    its :code:`model_def` settings and its ``grid`` attribute becomes the analysis grid.

    Sets :code:`self.grid` and :code:`self.grid_orig`.

    Raises:
        KeyError: if the model named by ``type`` is not in :code:`model_def`.
        TypeError: if that model has no ``grid`` attribute of a valid grid type.
    """
    grid_def = self.config.grid_def

    if grid_def['type'] == 'custom':
        if 'proj' in grid_def and grid_def['proj'] is not None:
            proj = Proj(grid_def['proj'])
        else:
            proj = None
        xmin, xmax = grid_def['xmin'], grid_def['xmax']
        ymin, ymax = grid_def['ymin'], grid_def['ymax']
        dx = grid_def['dx']

        # everything that is not consumed here goes to regular_grid as options
        known_keys = {'type', 'proj', 'xmin', 'xmax', 'ymin', 'ymax', 'dx', 'mask'}
        other_opts = {k: v for k, v in grid_def.items() if k not in known_keys}
        self.grid = grid.Grid.regular_grid(proj, xmin, xmax, ymin, ymax, dx, **other_opts)

        # default mask: no invalid grid points
        self.grid.mask = np.full((self.grid.ny, self.grid.nx), False, dtype=bool)

        if 'mask' in grid_def and grid_def['mask'] is not None:
            # let the named model provide the mask, if it knows how to
            model_name = grid_def['mask']
            ModelClass = models.get_model_class(model_name)
            model = ModelClass(context=self)
            prepare_mask = getattr(model, 'prepare_mask', None)
            if prepare_mask is not None:
                self.grid.mask = prepare_mask(self.grid)

    else:
        # get analysis grid from model module
        model_def = self.config.model_def
        model_name = grid_def['type']
        if model_def is None or model_name not in model_def:
            raise KeyError(f"'{model_name}' not defined in config file model_def section")

        # a model listed without settings has None as its value; treat that
        # as "no options", the same way set_models does
        model_opts = model_def[model_name] or {}
        ModelClass = models.get_model_class(model_name)
        model = ModelClass(context=self, **model_opts)

        # a missing attribute is reported by the TypeError below instead of
        # escaping as a bare AttributeError
        model_grid = getattr(model, 'grid', None)
        if not isinstance(model_grid, get_args(grid.GridType)):
            raise TypeError(f"Model {model_name} does not have a valid grid attribute.")
        self.grid = model_grid

    # remember the grid as configured here. This binds the same object,
    # it does not copy it.
    self.grid_orig = self.grid
[docs]
def set_models(self) -> None:
"""
Initialize model instances based on :code:`model_def[model_name]` settings.
Store the model instances in :code:`models[model_name]`.
"""
self.models = {}
if self.config.model_def is None:
return
for model_name, kwargs in self.config.model_def.items():
#instantiate the model class
ModelClass = models.get_model_class(model_name)
self.models[model_name] = ModelClass(context=self, **(kwargs or {}))
[docs]
def set_datasets(self) -> None:
"""
Initialize dataset instances based on :code:`dataset_def[dataset_name]` settings.
Store the dataset instances in :code:`datasets[dataset_name]`.
"""
self.datasets = {}
if self.config.dataset_def is None:
return
for dataset_name, kwargs in self.config.dataset_def.items():
DatasetClass = datasets.get_dataset_class(dataset_name)
self.datasets[dataset_name] = DatasetClass(context=self, **(kwargs or {}))
@property
def total_tasks(self):
return self.progress.node['total_tasks']
@total_tasks.setter
def total_tasks(self, value: int):
self.progress.node['total_tasks'] = value
self.debug_message = f"total tasks: {value}"
@property
def current_task(self):
return self.progress.node['current_task']
@current_task.setter
def current_task(self, value: int):
self.progress.node['current_task'] = value
self.progress.set_flag('running')
status = self.progress.update()
self.print_1p(status)
@property
def message(self):
return ''
@message.setter
def message(self, msg: str):
self.progress.node['message'] = msg
@property
def debug_message(self):
return ''
@debug_message.setter
def debug_message(self, msg: str):
if self.config.quiet or not self.debug:
return
# show the debug message with PID info.
# this uses the print since all PID ranks shall print its own message
print(f"PID {self.pid:>4}: {msg}", flush=True)
[docs]
def timer(self, func: Callable):
"""
Decorator to count the elapsed time for a function at runtime.
"""
if not self.config.timer:
return func
@wraps(func)
def wrapper(*args, **kwargs):
t0 = time.time()
try:
return func(*args, **kwargs)
finally:
t1 = time.time()
self.progress.node['elapsed_time'] = t1 - t0
return wrapper
[docs]
def logger(self, func_name: str):
"""
Decorator to register the func in call stack and show runtime messages.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
# register the function in call stack
status = self.progress.push(func_name)
self.print_1p(status)
self.progress.set_flag('waiting')
# execute the function with timer
result = self.timer(func)(*args, **kwargs)
self.progress.set_flag('done')
return result
except (KeyboardInterrupt, SystemExit):
raise
except Exception:
self.progress.set_flag('error')
raise
finally:
status = self.progress.pop()
self.print_1p(status)
return wrapper
return decorator
[docs]
def print_1p(self, msg: str):
"""
Customized print function for showing runtime message.
Only the processor with PID = self.pid_show will show the message,
this avoids the redundancy if all processors are showing the same message.
"""
if self.config.quiet or not msg:
return
if self.comm.Get_rank() != self.pid_show:
return
self._prev_msg = progress.print_with_cache(msg, self._prev_msg)
[docs]
def log_event(self, msg: str, flag=''):
self.print_1p(self.progress.log(msg, flag))
[docs]
def show_greeting(self) -> None:
    """
    Show the block-character banner, followed by the package version, through print_1p.
    """
    # the banner lines below belong to the string literal: their exact
    # characters and their position at the start of the line are what gets printed
    greeting_msg = f"""
█▄ █ █▀▀▀ █▀▀▄ ▄▀▀▄ ▄▀▀▀
█ ▀▄█ █▀▀ █ █ █▀▀█ ▀▀▀█
▀ ▀ ▀▀▀▀ ▀▀▀ ▀ ▀ ▀▀▀ version {__version__}
"""
    self.print_1p(greeting_msg)
[docs]
def show_summary(self) -> None:
summary_text = self.config.summary()
self.print_1p(summary_text)
[docs]
def dump_config(self, config_file: str) -> None:
"""
Dumps a snapshot of the current state to a yaml config file.
The original config object remains unchanged in memory.
"""
# make a copy of the config object for dumping
tmp_config = copy.copy(self.config)
# inject runtime state to the temporary config
for rt_state in ['time', 'iter', 'pid_show', 'interactive', 'is_notebook']:
val = getattr(self, rt_state)
setattr(tmp_config, rt_state, val)
setattr(tmp_config, 'call_stack', self.progress.call_stack)
setattr(tmp_config, 'cols', self.progress.fmt.cols)
# save the config to yaml file
tmp_config.dump_yaml(config_file)
[docs]
def run_job(self, commands: str,
parallel_mode: ParallelMode='serial',
nproc: int=1,
offset: int=0,
**kwargs) -> None:
"""
The user-facing method for running command on a computer.
It re-configures the existing job submitter with runtime arguments and execute the command.
Args:
commands (str): Shell commands to be dispatched by job submitter
parallel_mode (ParallelMode, optional): parallel mode ('serial', 'mpi', 'openmp'), default is 'serial'
nproc (int, optional): number of processors (default is 1)
offset (int, optional): offset in full list of processors (default is 0)
**kwargs: other keyword arguments to update the job submitter configuration
"""
# update the state of the job submitter for this specific task
self.jsub.parallel_mode = parallel_mode
self.jsub.nproc = nproc
self.jsub.offset = offset
for key, value in kwargs.items():
if value and hasattr(self.jsub, key):
setattr(self.jsub, key, value)
# dispatch the command
self.jsub.run(commands)