# Source code for NEDAS.core.dataset

import os
import glob
import inspect
import numpy as np
from typing import Callable, Any
from abc import ABC, abstractmethod
from datetime import datetime
from NEDAS.config import parse_config
from .types import VarName, VarDesc, LevelID
from .context import Context

class Dataset(ABC):
    """
    Dataset class (template for specific dataset sources).

    Concrete dataset sources subclass this template, populate ``variables``
    (and optionally ``obs_operator``) and implement ``read_obs_from_file``.
    Observation sequences are exchanged either through files or through the
    in-memory store ``memory``, depending on ``self.c.config.io_mode``
    ('offline' or 'online').
    """
    dataset_name: str
    variables: dict[VarName, VarDesc] = {}
    obs_operator: dict[VarName, Callable] = {}
    # NOTE(review): this dict is defined on the class, so it is shared by every
    # instance and subclass that does not rebind ``memory`` -- confirm that the
    # sharing is intended before relying on per-dataset isolation.
    # Layout: memory[tstr][tag + mstr][name] = observation sequence.
    memory: dict = {}
    _c: Context

    def __init__(self,
                 context: Context | None = None,
                 config_file: str | None = None,
                 parse_args: bool = False,
                 **kwargs) -> None:
        """
        Set up the context and load the dataset configuration.

        Args:
            context: Context to attach to; a default ``Context()`` is created
                when omitted.
            config_file: Passed through to ``parse_config``.
            parse_args: Passed through to ``parse_config``.
            **kwargs: Passed through to ``parse_config``.
        """
        # prepare context
        if context is not None:
            assert isinstance(context, Context), f"{context} is not a Context object"
            self._c = context
        else:
            self._c = Context()  # use default context if not specified

        # parse dataset config file and obtain a list of attributes;
        # the dataset name is the name of the directory holding the subclass code
        code_dir = os.path.dirname(inspect.getfile(self.__class__))
        self.dataset_name = os.path.basename(code_dir)
        config_dict = parse_config(code_dir, config_file, parse_args, **kwargs)
        for key, value in config_dict.items():
            setattr(self, key, value)

        # fall back to defaults for attributes the config did not provide
        if not hasattr(self, 'dataset_dir'):
            self.dataset_dir = '.'
        if not hasattr(self, 'obs_window_min'):
            self.obs_window_min = 0
        if not hasattr(self, 'obs_window_max'):
            self.obs_window_max = 0

    @property
    def c(self) -> Context:
        """The context this dataset is attached to."""
        return self._c

    def parse_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Parse the input kwargs to pinpoint a specific file/variable.

        Missing entries are filled in with defaults; ``kwargs`` is updated in
        place and also returned.

        Args:
            kwargs: Keyword arguments received by a read/write method.

        Returns:
            The same dict, guaranteed to contain 'path', 'name', 'time',
            'units', 'model', 'grid', 'mask', 'z', 'nobs', 'obs_window_min'
            and 'obs_window_max'.

        Raises:
            IndexError: If 'name' is not given and no variables are defined.
            AssertionError: If 'name' is not a known variable, or 'time' is
                given but is not a datetime object.
        """
        # args to pinpoint a certain observation record, used by read_obs, etc.
        if 'path' not in kwargs:
            kwargs['path'] = self.dataset_dir

        if 'name' not in kwargs:
            if not self.variables:
                raise IndexError(
                    f"{self.__class__.__name__}: no variables are defined, "
                    "cannot choose a default 'name'")
            # default to the first variable defined by the dataset
            kwargs['name'] = next(iter(self.variables))
        assert kwargs['name'] in self.variables, f"'{kwargs['name']}' is not defined in dataset variables"

        kwargs.setdefault('time', None)
        if kwargs['time'] is not None:
            assert isinstance(kwargs['time'], datetime), "kwargs 'time' is not a datetime object"

        if 'units' not in kwargs:
            kwargs['units'] = self.variables[kwargs['name']].units

        # other args, set default values if not specified
        # model (Model): model class instance
        # grid (GridType): target analysis grid
        # mask (np.ndarray): target analysis grid mask (True if grid point is not part of the state)
        # z (dict[LevelID, np.ndarray]): z coordinates at each level k on analysis grid
        # TODO: These should be available from the context
        for key in ('model', 'grid', 'mask', 'z'):
            kwargs.setdefault(key, None)

        # nobs (int): number of observations
        # obs_window_min (int)
        # obs_window_max (int)
        # TODO: maybe setting them both to 0 is incorrect? need to iterate from min to max...
        for key in ('nobs', 'obs_window_min', 'obs_window_max'):
            kwargs.setdefault(key, 0)

        return kwargs

    def get_mstr(self, member: int | None) -> str:
        """Return the member suffix ('_mem001' for member 0), or '' if member is None."""
        return f'_mem{member+1:03d}' if member is not None else ''

    def get_tstr(self, time: datetime | None) -> str:
        """Return ``time`` formatted as 'YYYYmmdd_HHMM'; ``time`` must not be None."""
        assert time is not None, 'missing time'
        return f"{time:%Y%m%d_%H%M}"

    def _memory_key(self, kwargs: dict[str, Any]) -> str:
        """Build the second-level key of ``memory`` (tag + member suffix) from kwargs."""
        if 'tag' not in kwargs:
            raise KeyError(
                f"{self.__class__.__name__}: kwargs 'tag' is required to access memory")
        # 'member' is optional: a missing member means no member suffix
        return kwargs['tag'] + self.get_mstr(kwargs.get('member'))

    def generate_obs_network(self, **kwargs) -> dict[str, np.ndarray]:
        """
        Generate a random observing network for use in synthetic observation experiments.

        Args:
            **kwargs

        Raises:
            NotImplementedError: Always, unless overridden by the subclass.
        """
        raise NotImplementedError(f"'generate_obs_network' is not implemented for {self.__class__.__name__}")

    def read_obs(self, **kwargs) -> dict[str, np.ndarray]:
        """
        Read an observation sequence from file or memory according to io_mode.

        Raises:
            ValueError: If ``self.c.config.io_mode`` is neither 'offline' nor 'online'.
        """
        io_mode = self.c.config.io_mode
        if io_mode == 'offline':
            return self.read_obs_from_file(**kwargs)
        if io_mode == 'online':
            return self.read_obs_from_memory(**kwargs)
        raise ValueError(f"Unknown io_mode: {io_mode}")

    @abstractmethod
    def read_obs_from_file(self, **kwargs) -> dict[str, np.ndarray]:
        """
        Return observation sequence matching the given kwargs.

        The base implementation returns an empty sequence with the keys
        'obs', 't', 'z', 'y', 'x' and 'err_std'.
        """
        obs_seq = {
            'obs': np.array([]),
            't': np.array([]),
            'z': np.array([]),
            'y': np.array([]),
            'x': np.array([]),
            'err_std': np.array([]),
        }
        return obs_seq

    def read_obs_from_memory(self, **kwargs) -> dict[str, np.ndarray]:
        """
        Return the observation sequence stored in memory for the given kwargs.

        Uses 'time', 'tag', optional 'member' and 'name' (defaulted by
        ``parse_kwargs``) to locate the entry.

        Raises:
            KeyError: If 'tag' is missing or the requested entry is not in memory.
            AssertionError: If 'time' is missing.
        """
        kwargs = self.parse_kwargs(kwargs)
        tstr = self.get_tstr(kwargs['time'])
        key = self._memory_key(kwargs)
        name = kwargs['name']

        cls_name = self.__class__.__name__
        if tstr not in self.memory:
            raise KeyError(f"{cls_name}: '{tstr}' not found in memory")
        if key not in self.memory[tstr]:
            raise KeyError(f"{cls_name}: '{key}' not found in memory['{tstr}']")
        if name not in self.memory[tstr][key]:
            raise KeyError(f"{cls_name}: '{name}' not found in memory['{tstr}']['{key}']")
        return self.memory[tstr][key][name]

    def write_obs(self, seq: dict, **kwargs) -> None:
        """
        Write an observation sequence to file or memory according to io_mode.

        Raises:
            ValueError: If ``self.c.config.io_mode`` is neither 'offline' nor 'online'.
        """
        io_mode = self.c.config.io_mode
        if io_mode == 'offline':
            self.write_obs_to_file(seq, **kwargs)
        elif io_mode == 'online':
            self.write_obs_to_memory(seq, **kwargs)
        else:
            raise ValueError(f"Unknown io_mode: {io_mode}")

    def write_obs_to_file(self, seq: dict, **kwargs) -> None:
        """Write an observation sequence to file; a no-op unless overridden."""
        pass

    def write_obs_to_memory(self, seq: dict, **kwargs) -> None:
        """
        Store an observation sequence in memory under [time][tag+member][name].

        Raises:
            KeyError: If 'tag' is missing.
            AssertionError: If 'time' is missing.
        """
        kwargs = self.parse_kwargs(kwargs)
        key = self._memory_key(kwargs)
        tstr = self.get_tstr(kwargs['time'])
        name = kwargs['name']

        # create memory dict entries if they do not exist yet
        self.memory.setdefault(tstr, {}).setdefault(key, {})[name] = seq

    def save_memory(self, tag: str, time: datetime | None = None, path: str | None = None) -> None:
        """
        Dump in-memory observation sequences to .npy files (online mode only).

        Files go to ``<path>/memory/dataset/<dataset_name>/<tstr>/<key>/<name>.npy``.

        Args:
            tag: Only entries whose key starts with this tag are saved.
            time: Save only this time; all times in memory if None.
            path: Root directory; defaults to ``self.c.config.work_dir``.
        """
        if self.c.config.io_mode == 'offline':
            return
        if path is None:
            path = self.c.config.work_dir

        times_to_save = [self.get_tstr(time)] if time is not None else list(self.memory)
        for tstr in times_to_save:
            if tstr not in self.memory:
                continue
            for key, entries in self.memory[tstr].items():
                if not key.startswith(tag):
                    continue
                savedir = os.path.join(path, 'memory', 'dataset', self.dataset_name, tstr, key)
                self.c.fs.make_dir(savedir)
                for name, seq in entries.items():
                    savefile = os.path.join(savedir, f'{name}.npy')
                    # wrap the sequence in a 0-d object array so load_memory
                    # can recover it with .item()
                    np.save(savefile, np.array(seq, dtype=object))

    def load_memory(self, tag: str, time: datetime | None = None, path: str | None = None) -> None:
        """
        Load observation sequences written by ``save_memory`` back into memory
        (online mode only).

        Args:
            tag: Only entries whose key starts with this tag are loaded.
            time: Load only this time; all saved times if None.
            path: Root directory; defaults to ``self.c.config.work_dir``.
        """
        if self.c.config.io_mode == 'offline':
            return
        if path is None:
            path = self.c.config.work_dir

        tstr_pattern = self.get_tstr(time) if time is not None else '????????_????'
        # escape the literal parts so that glob metacharacters in the work
        # directory, dataset name or tag are not interpreted as wildcards
        base_dir = os.path.join(path, 'memory', 'dataset', self.dataset_name)
        search_path = os.path.join(glob.escape(base_dir), tstr_pattern,
                                   glob.escape(tag) + '*', '*.npy')

        for savefile in glob.glob(search_path):
            # extract tstr, key and name from .../<tstr>/<key>/<name>.npy
            key_dir = os.path.dirname(savefile)
            tstr = os.path.basename(os.path.dirname(key_dir))
            key = os.path.basename(key_dir)
            name = os.path.splitext(os.path.basename(savefile))[0]

            # load data to memory
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load
            # memory dumps from a trusted work directory.
            self.memory.setdefault(tstr, {}).setdefault(key, {})[name] = \
                np.load(savefile, allow_pickle=True).item()