Source code for NEDAS.core.model

import os
import glob
import inspect
from typing import Generic, TypeVar, Any
from abc import ABC, abstractmethod
import numpy as np
from datetime import datetime
from NEDAS.config import parse_config
from NEDAS.grid import GridType
from .types import IOMode, VarName, VarDesc, LevelID, EnsRunStrategy
from .context import Context

GridT = TypeVar("GridT", bound=GridType)

[docs] class Model(Generic[GridT], ABC): """ Class for configuring and running a model """ model_name: str io_mode: IOMode variables: dict[VarName, VarDesc] grid: GridT z: dict[LevelID, np.ndarray] mask: np.ndarray ens_init_dir: str|None truth_dir: str|None ens_run_strategy: EnsRunStrategy nproc_per_run: int = 1 nproc_per_util: int = 1 walltime: int|None = None run_process = None run_status: str = 'pending' restart_dir: str forecast_period: int memory: dict _c: Context def __init__(self, context: Context|None=None, io_mode: IOMode|None=None, config_file: str|None=None, parse_args: bool=False, **kwargs) -> None: # prepare context if context is not None: assert isinstance(context, Context), f"{context} is not a Context object" self._c = context else: self._c = Context() # use default context if not specified # determine io_mode if io_mode: self.io_mode = io_mode else: self.io_mode = self._c.config.io_mode # parse model config file and obtain a list of attributes # get a list of values from default.yml and update with kwargs, save to config_dict code_dir = os.path.dirname(inspect.getfile(self.__class__)) self.model_name = os.path.basename(code_dir) config_dict = parse_config(code_dir, config_file, parse_args, **kwargs) for key, value in config_dict.items(): setattr(self, key, value) @property def c(self) -> Context: return self._c
[docs] def parse_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: # args that pinpoints a certain model state variable if 'path' not in kwargs: kwargs['path'] = self.c.config.work_dir # default path is work_dir if 'member' not in kwargs: kwargs['member'] = None if kwargs['member'] is not None: assert kwargs['member'] >= 0, f"member index should be >= 0, got {kwargs['member']}" if 'name' not in kwargs: kwargs['name'] = list(self.variables.keys())[0] # if not specified, use first variable listed assert kwargs['name'] in self.variables, f"'{kwargs['name']}' is not defined in model variables" if 'time' not in kwargs: kwargs['time'] = None if kwargs['time'] is not None: assert isinstance(kwargs['time'], datetime), "kwargs 'time' is expected to be a datetime object'" levels = list(self.variables[kwargs['name']].levels) if 'k' not in kwargs: kwargs['k'] = levels[0] # set to the first level if not specified assert kwargs['k'] in levels, f"level {kwargs['k']} is not available for variable {kwargs['name']}" if 'units' not in kwargs: kwargs['units'] = self.variables[kwargs['name']].units # some other runtime args need to be initialized for methods for key in ['restart_dir', 'forecast_period']: if key not in kwargs: kwargs[key] = None return kwargs
[docs] def get_mstr(self, member): return f'_mem{member+1:03d}' if member is not None else ''
[docs] def get_tstr(self, time): assert time is not None, 'missing time' return f"{time:%Y%m%d_%H%M}"
[docs] @abstractmethod def read_grid(self, **kwargs) -> None: """ Read the grid information from the model output. Args: **kwargs: Keyword arguments for reading the grid. """ ...
[docs] def read_var(self, **kwargs) -> np.ndarray: """ Read a variable from the model output. Args: **kwargs: Keyword arguments for reading the variable. Returns: np.ndarray: The read variable. """ if self.io_mode == 'offline': return self.read_var_from_file(**kwargs) elif self.io_mode == 'online': return self.read_var_from_memory(**kwargs) else: raise ValueError(f"Unknown io_mode {self.io_mode}")
[docs] def read_var_from_file(self, **kwargs) -> np.ndarray: raise NotImplementedError
[docs] def read_var_from_memory(self, **kwargs): kwargs = self.parse_kwargs(kwargs) tstr = self.get_tstr(kwargs['time']) tag = kwargs['tag'] mstr = self.get_mstr(kwargs['member']) key = tag+mstr name = kwargs['name'] if tstr not in self.memory: raise KeyError(f"{self.__class__.__name__}: '{tstr}' not found in memory") if key not in self.memory[tstr]: raise KeyError(f"{self.__class__.__name__}: '{key}' not found in memory['{tstr}']") if name not in self.memory[tstr][key]: raise KeyError(f"{self.__class__.__name__}: '{name}' not found in memory['{tstr}']['{key}']") return self.memory[tstr][key][name]
[docs] def write_var(self, var, **kwargs) -> None: """ Write a variable to the model output. Args: var (np.ndarray): The variable to write. **kwargs: Keyword arguments for writing the variable. """ if self.io_mode == 'offline': self.write_var_to_file(var, **kwargs) elif self.io_mode == 'online': self.write_var_to_memory(var, **kwargs) else: raise ValueError(f"Unknown io_mode {self.io_mode}")
[docs] def write_var_to_file(self, var, **kwargs): raise NotImplementedError
[docs] def write_var_to_memory(self, var, **kwargs): kwargs = self.parse_kwargs(kwargs) tag = kwargs['tag'] mstr = self.get_mstr(kwargs['member']) tstr = self.get_tstr(kwargs['time']) key = tag+mstr name = kwargs['name'] # create memory dict entry if not yet if tstr not in self.memory: self.memory[tstr] = {} if key not in self.memory[tstr]: self.memory[tstr][key] = {} self.memory[tstr][key][name] = var
[docs] def save_memory(self, tag: str, time: datetime|None=None, path: str|None=None) -> None: if self.io_mode == 'offline': return if path is None: path = self.c.config.work_dir times_to_save = [self.get_tstr(time)] if time is not None else self.memory.keys() for tstr in times_to_save: if tstr not in self.memory: continue for key in self.memory[tstr]: if not key.startswith(tag): continue savedir = os.path.join(path, 'memory', 'model', self.model_name, tstr, key) self.c.fs.make_dir(savedir) for name in self.memory[tstr][key]: savefile = os.path.join(savedir, f'{name}.npy') np.save(savefile, self.memory[tstr][key][name])
[docs] def load_memory(self, tag: str, time: datetime|None=None, path: str|None=None) -> None: if self.io_mode == 'offline': return if path is None: path = self.c.config.work_dir tstr_pattern = self.get_tstr(time) if time is not None else '????????_????' search_path = os.path.join(path, 'memory', 'model', self.model_name, tstr_pattern, f'{tag}*', '*.npy') for savefile in glob.glob(search_path): # extract tstr and key tstr = os.path.basename(os.path.dirname(os.path.dirname(savefile))) key = os.path.basename(os.path.dirname(savefile)) name = os.path.splitext(os.path.basename(savefile))[0] # load data to memory if tstr not in self.memory: self.memory[tstr] = {} if key not in self.memory[tstr]: self.memory[tstr][key] = {} self.memory[tstr][key][name] = np.load(savefile)
[docs] @abstractmethod def z_coords(self, **kwargs) -> np.ndarray: """ Get the vertical coordinates of the model. Args: **kwargs: Keyword arguments for getting the vertical coordinates. Returns: np.ndarray: The vertical coordinates. """ ...
[docs] @abstractmethod def preprocess(self, *args, **kwargs) -> None: """ Preprocess the model data. Args: **kwargs: Keyword arguments for preprocessing. """ ...
[docs] @abstractmethod def postprocess(self, *args, **kwargs) -> None: """ Postprocess the model data. Args: **kwargs: Keyword arguments for postprocessing. """ ...
[docs] @abstractmethod def run(self, *args, **kwargs) -> None: """ Run the model forward in time. Args: *args: Arguments **kwargs: Keyword arguments Keyword Args: time (datetime): current time when forecast starts restart_dir (str): directory where restart files are located forecast_period (int): forecast period in hours If self.ens_run_strategy == 'batch', the method will run all ensemble members in one go, expect additional kwargs['nens'] to be the ensemble size. If self.ens_run_strategy == 'scheduler', the method runs a single member indexed by kwargs['member'], and kwargs['worker_id'] is the pid assigned by the scheduler to run this method. """ ...
[docs] def generate_truth(self, *args, **kwargs) -> None: """ Generate truth (nature run) model states. Use for running synthetic observation experiments. """ raise NotImplementedError(f"'generate_truth' is not implemented for {self.__class__.__name__}")
[docs] def generate_init_ensemble(self, *args, **kwargs) -> None: """ Generate initial perturbed model states for ensemble forecasts. Args: nens (int): ensemble size **kwargs """ raise NotImplementedError(f"'generate_init_ensemble' is not implemented for {self.__class__.__name__}")