Source code for NEDAS.core.obs_info

import numpy as np
from NEDAS.utils.conversion import dt1h, ensure_list
from .types import ErrorModel, ObsRecord
from .context import Context

class ObsInfo:
    """
    Manages the metadata, indexing and memory allocation for the observation sequences

    Attributes:
        records (dict[int, ObsRecord]): dictionary mapping obs_rec_id to the corresponding obs record
        variables (list[str]): unique variable names in the observations, in the order they first appear in obs_def
        err_types (list[str]): unique error model types used in the observations, in the order they first appear
    """
    records: dict[int, ObsRecord]
    variables: list[str]
    err_types: list[str]

    def __init__(self, c: Context):
        """
        Parse the configuration to generate the observation info object.

        Populates ``records``, ``variables`` and ``err_types`` from
        ``c.config.obs_def``, then fills in default error cross-correlations.

        Args:
            c (Context): the runtime context object.
        """
        self.records = {}

        # Dict keys are used as insertion-ordered sets: a plain set of str has a
        # hash-seed dependent iteration order, which would make the list indices
        # built below differ from one process to another.
        variables = {}
        err_types = {}

        # loop through variables in obs_def
        for vrec in ensure_list(c.config.obs_def):
            vname = vrec['name']
            variables[vname] = None

            # a missing or empty 'err' entry means "use all defaults"
            if 'err' not in vrec or vrec['err'] is None:
                vrec['err'] = {}
            assert isinstance(vrec.get('err'), dict), f"obs_def: {vname}: expect 'err' to be a dictionary"
            err_types[vrec['err'].get('type', 'normal')] = None

            self.add_obs_record(c, vrec)

        # convert to lists, for later indexing
        self.variables = list(variables)
        self.err_types = list(err_types)

        self.complete_err_cross_corr_matrix()

        # NOTE(review): the second assignment replaces the first unless
        # Context.debug_message is a property whose setter emits the message - confirm.
        c.debug_message = f"number of unique observation records = {len(self.records)}"
        c.debug_message = f"observation variables: {self.variables}"

    def add_obs_record(self, c: Context, vrec: dict):
        """
        Add observation records for one obs_def entry.

        One record is created for every entry of ``c.config.obs_time_steps``;
        each record is stored in ``self.records`` under the next free integer id.

        Args:
            c (Context): the runtime context object
            vrec (dict): the observation record defining its properties. The keys
                'name', 'dataset_src', 'model_src', 'hroi', 'vroi' and 'troi' are
                required; 'err', 'nobs', 'obs_window_min', 'obs_window_max' and
                'impact_on_state' are optional.

        Raises:
            AssertionError: if the variable is not defined by the source dataset.
            KeyError: if a required key is missing from ``vrec``.
        """
        vname = vrec['name']
        dataset = c.datasets[vrec['dataset_src']]
        variables = dataset.variables
        assert vname in variables, 'variable '+vname+' not defined in '+vrec['dataset_src']+'.dataset.variables'
        var_info = variables[vname]

        # parse impact of obs on each state variable, default is 1.0 on all
        # variables unless set by obs_def record
        impact_on_state = {state_name: 1.0 for state_name in c.state.info.variables}
        if vrec.get('impact_on_state') is not None:
            impact_on_state.update(vrec['impact_on_state'])

        # tolerate a missing/None 'err' entry when this method is called directly
        err_opts = vrec.get('err')
        if err_opts is None:
            err_opts = {}

        # loop through time steps in obs window
        time_steps = c.time + np.array(c.config.obs_time_steps)*dt1h
        for time in time_steps:
            err = ErrorModel(
                type=err_opts.get('type', 'normal'),
                std=err_opts.get('std', 1.),
                hcorr=err_opts.get('hcorr', 0.),
                vcorr=err_opts.get('vcorr', 0.),
                tcorr=err_opts.get('tcorr', 0.),
                cross_corr=err_opts.get('cross_corr', {}),
            )
            rec = ObsRecord(
                name=vname,
                dataset_src=vrec['dataset_src'],
                model_src=vrec['model_src'],
                nobs=vrec.get('nobs', 0),  # for synthetic observation use only, real obs will count nobs later in prepare_obs
                obs_window_min=vrec.get('obs_window_min', dataset.obs_window_min),
                obs_window_max=vrec.get('obs_window_max', dataset.obs_window_max),
                dtype=var_info.dtype,
                is_vector=var_info.is_vector,
                units=var_info.units,
                z_units=var_info.z_units,
                time=time,
                dt=0,
                err=err,
                # horizontal roi is scaled by the localize_scale_fac entry for c.iter
                hroi=vrec['hroi'] * c.config.localize_scale_fac[c.iter],
                vroi=vrec['vroi'],
                troi=vrec['troi'],
                # each record gets its own copy so later edits do not leak between records
                impact_on_state=dict(impact_on_state),
            )
            # Allocate the id per record: computing it once before the loop made
            # every time step overwrite the same entry.
            self.records[len(self.records)] = rec

    def complete_err_cross_corr_matrix(self):
        """
        Go through the obs error cross correlation matrix again to fill in the default values.

        For every record and every name in ``self.variables``, a missing entry
        defaults to 1.0 for the record's own variable and 0.0 otherwise. Integer
        values supplied by the user are stored as floats.

        Raises:
            TypeError: if a record's ``err.cross_corr`` is not a dictionary, or
                one of its values is not a real number.
        """
        for obs_rec in self.records.values():
            cross_corr = obs_rec.err.cross_corr
            if not isinstance(cross_corr, dict):
                raise TypeError(f"obs_def: {obs_rec.name} has err.cross_corr defined as {obs_rec.err.cross_corr}, expecting a dictionary")

            for vname in self.variables:
                if vname not in cross_corr:
                    cross_corr[vname] = 1.0 if vname == obs_rec.name else 0.0
                    continue

                value = cross_corr[vname]
                # bool is a subclass of int, but True/False is not a meaningful correlation
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    raise TypeError(f"obs_def: {obs_rec.name} has err.cross_corr.{vname} defined as {obs_rec.err.cross_corr[vname]}, expecting a float")
                if isinstance(value, int):
                    # config files commonly spell 1.0 / 0.0 as the integers 1 / 0
                    cross_corr[vname] = float(value)
# def write_obs_info(self, binfile): # with open(binfile.replace('.bin','.dat'), 'wt') as f: # f.write('{} {}\n'.format(self.info['nobs'], self.info['nens'])) # for rec in self.info['obs_seq'].values(): # f.write('{} {} {} {} {} {} {} {} {} {} {} {} {} {}\n'.format(rec['name'], rec['dataset_src'], rec['model_src'], rec['dtype'], int(rec['is_vector']), rec['units'], rec['z_units'], rec['x'], rec['y'], rec['z'], t2h(rec['time']), rec['pos'])) # # read obs_info from the dat file # def read_obs_info(self, binfile): # with open(binfile.replace('.bin','.dat'), 'r') as f: # lines = f.readlines() # ss = lines[0].split() # self.info = {'nobs':int(ss[0]), 'nens':int(ss[1]), 'obs_seq':{}} # # following lines of obs records # obs_id = 0 # for lin in lines[1:]: # ss = lin.split() # rec = {'name': ss[0], # 'dataset_src': ss[1], # 'model_src': ss[2], # 'dtype': ss[3], # 'is_vector': bool(int(ss[4])), # 'units': ss[5], # 'z_units':ss[6], # 'err_type': ss[7], # 'err': float(ss[8]), # 'x': float(ss[9]), # 'y': float(ss[10]), # 'z': float(ss[11]), # 'time': h2t(float(ss[12])), # 'pos': int(ss[13]), } # self.info['obs_seq'][obs_id] = rec # obs_id += 1