Source code for NEDAS.config.config

from code import interact
import os
import inspect
from typing import Literal, Any
import yaml
import dateutil.parser
from datetime import datetime, timezone
import NEDAS
from NEDAS.utils import progress
from .parse_config import parse_config

[docs] class Config: """ Class to manage the configuration for running the NEDAS analysis. Configuration entries are described in details in :doc:`config_file`. Args: config_file (str, optional): Path to the configuration file. parse_args (bool, optional): If true, parse command line arguments to collect configuration. Default is False. **kwargs: Additional key-value pairs to be passed to parse_config. Can be used to override values in the config file. """ work_dir: str directories: dict[str, str] python_env: str|None io_mode: Literal['online', 'offline'] job_submit: dict|None # parallel scheme nproc: int nproc_mem: int nproc_rec: int nproc_util: int pid: int pid_mem: int pid_rec: int pid_show: int # avail in context # experiment design parameters nens: int run_preproc: bool run_forecast: bool run_analysis: bool run_postproc: bool run_diagnose: bool step: str|None # runtime logging options call_stack: list[dict]|None debug: bool timer: bool quiet: bool interactive: bool|None is_notebook: bool|None call_stack_max_level: int|None cols: int anchor: int tabspace: int progress_bar_width: int # time control time: datetime # avail in context time_start: datetime time_end: datetime time_analysis_start: datetime time_analysis_end: datetime cycle_period: float forecast_period: float obs_time_steps: list[float] obs_time_scale: float state_time_steps: list[float] state_time_scale: float # some definitions grid_def: dict state_def: dict|None model_def: dict|None obs_def: dict|None dataset_def: dict|None shuffle_obs: bool z_coords_from: Literal['mean', 'member'] interp_method: str perturb: dict|None # more details in assimilation algorithm scheme: str niter: int iter: int # avail in context resolution_level: list[int] character_length: list[float] localize_scale_fac: list[float] obs_err_scale_fac: list[float] assimilator_def: dict updator_def: dict covariance_def: dict inflation_def: dict localization_def: dict transform_def: dict alignment: dict|None diag: dict|None def __init__(self, config_file: str|None=None, parse_args: bool=False, **kwargs): # parse the yaml config file to obtain the values code_dir = os.path.dirname(inspect.getfile(self.__class__)) config_dict = parse_config(code_dir, config_file, parse_args, **kwargs) # replace placeholders in dir paths with actual values config_dict['work_dir'] = os.path.abspath(config_dict['work_dir']) self.work_dir = config_dict['work_dir'] self.nedas_root = NEDAS.__path__[0] config_dict = self._parse_directories(config_dict) # check a few attributes, setting default values if not specified in yaml file config_dict = self._check_time_scheme(config_dict) config_dict = self._check_parallel_scheme(config_dict) # set current iteration to 0 if undefined if 'iter' not in config_dict or config_dict['iter'] is None: config_dict['iter'] = 0 # set the attributes self.__dict__.update(config_dict) def _parse_directories(self, data: Any) -> Any: """ Parse the directories or file names defined in :code:`data` and replace the placeholders {work_dir} and {nedas_root} with the actual values. """ if isinstance(data, dict): return {key: self._parse_directories(value) for key, value in data.items()} elif isinstance(data, list): return [self._parse_directories(element) for element in data] elif isinstance(data, str): return data.replace('{work_dir}', self.work_dir).replace('{nedas_root}', self.nedas_root) else: return data def _check_time_scheme(self, config_dict: dict) -> dict: """ Initialize the time variables for the analysis. Checks if the mandatory :code:`time_*` entries are defined in the config file. If :code:`time` is not set, set it to :code:`time_start` by default. YAML file recognizes 2001-01-01T00:00:00 format and convert directly to datetime object. If time is a formatted string, will try to parse it using dateutil.parser. """ # check if mandatory time keys are defined in config file for key in ['time', 'time_start', 'time_end', 'time_analysis_start', 'time_analysis_end']: if key not in config_dict: raise KeyError(f"'{key}' is missing in config file") if isinstance(config_dict[key], str): try: config_dict[key] = dateutil.parser.parse(config_dict[key]) except Exception: raise ValueError(f"Failed to convert string {key}={config_dict[key]} to datetime") # add default tzinfo if config_dict[key] and config_dict[key].tzinfo is None: config_dict[key] = config_dict[key].replace(tzinfo=timezone.utc) if config_dict['time'] is None: # initialize current time to start time, if not available config_dict['time'] = config_dict['time_start'].replace() if config_dict['time_analysis_start'] is None: # initialize analysis start time if not available config_dict['time_analysis_start'] = config_dict['time_start'].replace() if config_dict['time_analysis_end'] is None: # initialize analysis end time if not available config_dict['time_analysis_end'] = config_dict['time_end'].replace() return config_dict def _check_parallel_scheme(self, config_dict: dict) -> dict: """ Check the number of processors for parallelization """ # nproc is the total number of processpors # if not defined, set to 1 (serial program) by default if 'nproc' not in config_dict or config_dict['nproc'] is None: config_dict['nproc'] = 1 # In parallel schemes, the communicator is divided into mem/rec groups # nproc_mem and nproc_rec are the number of groups in each direction # set default values if they are not defined if 'nproc_mem' not in config_dict or config_dict['nproc_mem'] is None: config_dict['nproc_mem'] = config_dict['nproc'] # check if division works if config_dict['nproc'] % config_dict['nproc_mem'] != 0: raise ValueError(f"nproc={config_dict['nproc']} is not evenly divided by nproc_mem={config_dict['nproc_mem']}") config_dict['nproc_rec'] = int(config_dict['nproc']/config_dict['nproc_mem']) # nproc_util (optional) is nproc to use for utility functions if 'nproc_util' not in config_dict or config_dict['nproc_util'] is None: config_dict['nproc_util'] = config_dict['nproc'] return config_dict
[docs] def dump_yaml(self, config_file: str): """ Dump the current configuration to a YAML file. Args: config_file (str): Path to the output configuration file. """ with open(config_file, 'w') as f: yaml.dump(self.__dict__, f, sort_keys=False)
[docs] def summary(self) -> str: """ Return a comprehensive summary of the NEDAS configuration. """ # Format the active flags for the workflow workflow = [] if self.run_preproc: workflow.append("Preprocess") if self.run_analysis: workflow.append("Analysis") if self.run_postproc: workflow.append("Postprocess") if self.run_forecast: workflow.append("Forecast") if self.run_diagnose: workflow.append("Diagnose") workflow_str = " -> ".join(workflow) if workflow else "None" fcst_str = f"{self.forecast_period}h" if hasattr(self, 'forecast_period') and self.forecast_period else "N/A" js = self.job_submit or {} loc = self.localization_def or {} h_loc = loc.get('horizontal', {}).get('type', 'N/A') v_loc = loc.get('vertical', {}).get('type', 'N/A') inf = self.inflation_def or {} inf_str = f"{inf.get('type', 'None')} (coef: {inf.get('coef', 1.0)}, adaptive: {inf.get('adaptive', False)})" state_vars = [f"{d.get('name')} ({d.get('model_src')})" for d in (self.state_def or [])] obs_vars = [f"{d.get('name')} ({d.get('dataset_src')})" for d in (self.obs_def or [])] # Construct the summary block summary_text = f""" CONFIGURATION SUMMARY {'='*21} Directories: Work Dir: {self.work_dir} NEDAS Root: {self.nedas_root} Time Configuration: Current Time: {self.time} Experiment: [{self.time_start}] to [{self.time_end}] Analysis: [{self.time_analysis_start}] to [{self.time_analysis_end}] Periods: Cycle: {self.cycle_period}h | Forecast: {fcst_str} Parallel Scheme: Total Procs: {self.nproc} Decomposition: {self.nproc_mem} (mem) x {self.nproc_rec} (rec) Procs for utility funcs: {self.nproc_util} Host: {js.get('host', 'local')} Scheduler: {js.get('scheduler', 'None')} | Project: {js.get('project', 'N/A')} Queue/Mode: {js.get('queue', 'N/A')} | Parallel mode: {js.get('parallel_mode', 'serial')} Analysis Scheme: General: Scheme: {self.scheme} | Ensemble Size: {self.nens} | IO: {self.io_mode} Grid Type: {self.grid_def.get('type', 'N/A') if self.grid_def else 'N/A'} Iteration: {self.iter + 1} of {self.niter} (Outer Loops) Assimilator: Type: {self.assimilator_def.get('type') if self.assimilator_def else 'None'} Updator: Type: {self.updator_def.get('type') if self.updator_def else 'None'} Inflation: {inf_str} Localization: H: {h_loc} | V: {v_loc} | T: {loc.get('temporal', {}).get('type', 'N/A')} Multiscale: Resolution Levels: {self.resolution_level} | Character Lengths: {self.character_length} Localization Factor: {self.localize_scale_fac} | Obs Err Factor: {self.obs_err_scale_fac} Definitions: Models Used: {", ".join(self.model_def.keys()) if self.model_def else 'None'} Datasets: {", ".join(self.dataset_def.keys()) if self.dataset_def else 'None'} State Vector: {', '.join(state_vars) if state_vars else 'None'} Observations: {', '.join(obs_vars) if obs_vars else 'None'} Workflow Status: Active Steps: {workflow_str} Debug Mode: {self.debug} | Timer: {self.timer} | Interactive: {self.interactive} """ return summary_text