Source code for NEDAS.core.obs

import numpy as np
from datetime import datetime
from NEDAS.utils.conversion import t2h, ensure_list
from NEDAS.utils.parallel import bcast_by_root, distribute_tasks
from NEDAS.datasets.synthetic import SyntheticObs
from .context import Context
from .types import LevelID, Levels, ProcID, ProcIDRec, PartitionID, ObsRecordID, ObsSeq, ObsEns, LocalObsEns, LocalObsSeq
from .obs_info import ObsInfo

[docs] class Obs: """ Class for handling observations. The observation has dimensions: variable, time, z, y, x Since the observation network is typically irregular, we store the obs record for each variable in a 1d sequence, with coordinates (t,z,y,x), and size nobs To parallelize workload, we distribute each obs record over all the processors for batch assimilation mode, each pid stores the list of local obs within the hroi of its tiles, with size nlobs (number of local obs) for serial mode, each pid stores a non-overlapping subset of the obs list, here 'local' obs (in storage sense) is broadcast to all pid before computing its update to the state/obs near that obs. The hroi is separately defined for each obs record. For very large hroi, the serial mode is more parallel efficient option, since in batch mode the same obs may need to be stored in multiple pids To compare to the observation, obs_prior simulated by the model needs to be computed, they have dimension [nens, nlobs], indexed by (mem_id, obs_id) """ obs_rec_list: dict[ProcIDRec, list[ObsRecordID]] obs_inds: dict # will be created by assimilator.assign_obs() obs_seq: ObsSeq # will be created by self.prepare_obs() obs_prior: ObsEns # will be created by self.prepare_obs_from_state() lobs: LocalObsSeq # will be created by self.transpose_to_ensemble_complete() lobs_prior: LocalObsEns lobs_post: LocalObsEns # will be created by assimilator.assimilate() obs_post: ObsEns # will be created by self.transpose_to_field_complete() data: dict # will be created by self.pack_obs_data, for use in assimilator.assimilate() def __init__(self, c: Context): self.info = bcast_by_root(c.comm)(ObsInfo)(c) self.obs_rec_list = bcast_by_root(c.comm)(self.distribute_obs_tasks)(c) self.obs_inds = {} self.obs_seq = {} self.obs_prior = {} self.lobs = {} self.lobs_prior = {} self.lobs_post = {} self.obs_post = {} self.data = {}
[docs] def distribute_obs_tasks(self, c: Context): """ Distribute obs_rec_id across processors Args: c (Context): the runtime context object. Returns: dict: Dictionary {pid_rec (int): list[obs_rec_id (int)]} """ obs_rec_list_full = [i for i in self.info.records.keys()] obs_rec_size = np.array([2 if r.is_vector else 1 for i,r in self.info.records.items()]) obs_rec_list = distribute_tasks(c.comm_rec, obs_rec_list_full, obs_rec_size) return obs_rec_list
[docs] def get_ref_z(self, c: Context, model_name: str, time: datetime) -> dict[LevelID, np.ndarray]: """ Get the reference z coords at level k on the analysis grid from a model, to be used in dataset modules for generating generate_obs_network or superobing/thinning in read_obs. Args: c (Context): the runtime context model_name (str): the model name time (datetime): the time of the model state Returns: dict[LevelID, np.ndarray]: the z coords field at each level """ if c.config.z_coords_from == 'mean': # will read the ensemble mean z coords as reference ztag = 'z_mean' elif c.config.z_coords_from == 'member': # will just read z coords from the first member as reference ztag = 'z' else: raise ValueError("unknown config.z_coords_from: {c.config.z_coords_from}") # get the full list of vertical level ids k_list = list(set([r.k for r in c.state.info.fields.values() if r.time==time and r.model_src==model_name])) z = {} for k in k_list: # get the rec_id corresponding to this level # there can be multiple records (different state variables), we only need to take the first one rec_id_find = [i for i,r in c.state.info.fields.items() if r.time==time and r.model_src==model_name and r.k==k] if len(rec_id_find) == 0: raise RuntimeError(f"no record in state.info.fields found for vertical level id {k}") rec_id = rec_id_find[0] rec = c.state.info.fields[rec_id] # read the z field with (mem_id=0, rec_id) from state.fields_z and add to z dict z_fld = c.io.read_field(c, ztag, rec_id, mem_id=0) z[k] = z_fld[0, ...] if rec.is_vector else z_fld return z
[docs] def state_to_obs(self, c: Context, tag: str, **kwargs) -> np.ndarray: """ Compute the corresponding obs value given the state variable(s), namely the "obs_prior" This function includes several ways to compute the obs_prior: 1, If obs_name is one of the variables provided by the model_src module, then model_src.read_var shall be able to provide the obs field defined on model native grid. Then we convert the obs field to the analysis grid and do vertical interpolation. 2, If obs_name is one of the variables provided by obs.obs_operator, we call it to obtain the obs seq. Typically the obs_operator performs more complex computation, such as path integration, radiative transfer model, etc. (slowest) Args: c (Context): the runtime context object tag (str): 'prior' or 'post', or 'truth' if generating synthetic obs **kwargs: Additional parameters - member: int, member index; or None if dealing with synthetic obs - name: str, obs variable name - time: datetime obj, time of the obs window - is_vector: bool, if True the obs is a vector measurement - dataset_src: str, dataset source module name providing the obs - model_src: str, model source module name providing the state - x, y, z, t: np.array, coordinates from obs_seq Returns: np.ndarray: Values corresponding to the obs_seq but from the state identified by kwargs """ obs_name = kwargs['name'] obs_x = np.array(kwargs['x']) obs_y = np.array(kwargs['y']) obs_z = np.array(kwargs['z']) dataset = c.datasets[kwargs['dataset_src']] model = c.models[kwargs['model_src']] if obs_name in model.variables: # option 1: ------------------- # obs_name is one of the variables provided by the model.read_var # then we just need to collect the 3D variable and interpolate in x,y,z nobs = len(obs_x) seq = np.full((2, nobs), np.nan) if kwargs['is_vector'] else np.full(nobs, np.nan) levels = model.variables[obs_name].levels fp, zp, dzp = None, None, None # previous layer no yet available for first level for k in levels: # get model fld and z values at level k kwargs['k'] = k fld, zfld = self.get_model_fld_z_on_grid(c, tag, **kwargs) # interpolate fld to the obs location in x,y,z f, z = self.horizontal_interp(c, fld, zfld, kwargs['is_vector'], obs_x, obs_y) seq, fp, zp, dzp = self.vertical_interp(seq, k, levels, f, fp, z, zp, dzp, obs_z) elif kwargs['name'] in dataset.obs_operator: # option 2: --------------------- # if dataset module provides an obs_operator, we use it to compute obs seq operator = dataset.obs_operator[kwargs['name']] # get the obs seq from operator seq = c.io.call_method(c, tag, operator, model=model, grid=c.grid, mask=c.grid.mask, **kwargs) else: raise ValueError(f"unable to obtain obs prior for '{kwargs['name']}'") return seq
[docs] def get_model_fld_z_on_grid(self, c: Context, tag: str, **kwargs) -> tuple[np.ndarray, np.ndarray]: """ Get obs variable field and z coords at level k and convert to c.grid """ model = c.models[kwargs['model_src']] if kwargs['name'] in [r['name'] for r in ensure_list(c.config.state_def)] and tag != 'truth': # the obs variable is one of the state variables # we can find its corresponding rec_id and call io.read_field to get it rec_id_found = [i for i,r in c.state.info.fields.items() if r.name==kwargs['name'] and r.time==kwargs['time'] and r.k==kwargs['k']] if len(rec_id_found) == 0: raise RuntimeError(f"field '{kwargs['name']}' at t={kwargs['time']} k={kwargs['k']} not found in state.info.fields") rec_id = rec_id_found[0] fld = c.io.read_field(c, tag, rec_id, kwargs['member']) zfld = c.io.read_field(c, 'z', rec_id, kwargs['member']) else: # otherwise, we get the field from by calling model.read_var model_fld = c.io.call_method(c, tag, model.read_var, **kwargs) model_z = c.io.call_method(c, 'z', model.z_coords, **kwargs) # convert the model fields to the analysis c.grid model.grid.set_destination_grid(c.grid) fld = model.grid.convert(model_fld, is_vector=kwargs['is_vector'], method=c.config.interp_method) z_ = model.grid.convert(model_z, is_vector=False, method=c.config.interp_method) zfld = np.array([z_, z_]) if kwargs['is_vector'] else z_ return fld, zfld
[docs] def horizontal_interp(self, c: Context, fld: np.ndarray, zfld: np.ndarray, is_vector: bool, obs_x: np.ndarray, obs_y: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Interpolate fld and zfld horizontally to the obs_x,obs_y locations""" if is_vector: z1 = c.grid.interp(zfld[0, ...], obs_x, obs_y, method=c.config.interp_method) z = np.array([z1, z1]) f1 = c.grid.interp(fld[0, ...], obs_x, obs_y, method=c.config.interp_method) f2 = c.grid.interp(fld[1, ...], obs_x, obs_y, method=c.config.interp_method) f = np.array([f1, f2]) else: z = c.grid.interp(zfld, obs_x, obs_y, ) f = c.grid.interp(fld, obs_x, obs_y, method=c.config.interp_method) return f, z
[docs] def vertical_interp(self, seq: np.ndarray, k: int, levels: Levels, f: np.ndarray, fp: np.ndarray|None, z: np.ndarray, zp: np.ndarray|None, dzp: np.ndarray|None, obs_z: np.ndarray) -> tuple: """ Interpolate f(k) with z(k) coords vertically to the obs z locations. Vertical interp to obs_z, take ocean depth as example:: : - - - - : ------------z[k-2] ------------------------ k-1 - - - - f[k-1], fp }dzp prevous layer ---------- z[k-1], zp -------------------- k - - - - v[k], f }dz current layer ---------- z[k], z -------------------- k+1 - - - - v[k+1] layer thickness of the current level k is denoted as z, for the previous level as zp; the variable f are considered layer averages, so they are defined at layer centers. """ # index of current layer in all levels # we use index instead of k directly, since k is allowed to be either increasing or decreasing i = list(levels).index(k) dz = z # layer thickness for the first level if i == 0: # the first level: constant f from z0 to dz/2 inds = (obs_z >= np.minimum(0, 0.5*dz)) & (obs_z < np.maximum(0, 0.5*dz)) seq[..., inds] = f[..., inds] if i > 0: dz = z - zp # layer thickness for level k # in between levels: linear interp between fp and f assert fp is not None assert zp is not None assert dzp is not None z_fp = zp - 0.5*dzp z_f = zp + 0.5*dz inds = (obs_z >= np.minimum(z_fp, z_f)) & (obs_z < np.maximum(z_fp, z_f)) # there can be collapsed layers if z_f=z_fp zdiff = z_f - z_fp collapsed = (zdiff == 0) zdiff = np.where(collapsed, 1, zdiff) fi = ((z_f - obs_z)*fp + (obs_z - z_fp)*f) / zdiff fi = np.where(collapsed, fp, fi) seq[..., inds] = fi[..., inds] if i == len(levels)-1: # the last level: constant f from z-dz/2 to z inds = (obs_z >= np.minimum(z-0.5*dz, z)) & (obs_z <= np.maximum(z-0.5*dz, z)) seq[..., inds] = f[..., inds] # save a copy of the current layer as 'previous' layer for use in next k fp, zp, dzp = f.copy(), z.copy(), dz.copy() return seq, fp, zp, dzp
[docs] def validate_seq_shape(self, seq: np.ndarray, is_vector: bool) -> None: """ Validate the shape of an observation sequence. Allowed shape: (nobs,) for scalar obs seq; (2, nobs) for vector obs seq. """ if not isinstance(seq, np.ndarray): raise TypeError(f"obs sequence must be a numpy array, got {type(seq)}") shape = seq.shape if is_vector: if len(shape) != 2: raise ValueError(f"vector obs sequence must have shape (2, nobs), got {shape}") if shape[0] != 2: raise ValueError(f"vector obs sequence first dimension must be 2, got {shape[0]}") else: if len(shape) != 1: raise ValueError(f"scalar obs sequence must have shape (nobs,), got {shape}")
[docs] def collect_obs_seq(self, c: Context) -> ObsSeq: """ Process the obs in parallel, read dataset files and convert to obs_seq which contains obs value, coordinates and other info Since this is the actual obs (1 copy), only 1 processor needs to do the work Argss: c (Context): The runtime context object. Returns: ObsSeq: observation sequence. Dictionary {obs_rec_id (int): record} where each record is a dictionary {key: np.ndarray}, the mandatory keys are 'obs' the observed values (measurements) 'x', 'y', 'z', 't' the coordinates for each measurement 'err_std' the uncertainties for each measurement there can be other optional keys provided by read_obs() but we don't use them """ c.debug_message = 'read observation sequence from datasets' c.progress.set_flag('running') # get obs_seq from dataset module, each pid_rec gets its own workload as a subset of obs_rec_list obs_seq = {} for obs_rec_id in self.obs_rec_list[c.pid_rec]: obs_rec = self.info.records[obs_rec_id] # load the dataset module dataset = c.datasets[obs_rec.dataset_src] if obs_rec.name not in dataset.variables: raise ValueError(f"variable '{obs_rec.name}' not defined in dataset.{obs_rec.dataset_src}.variables") model = c.models[obs_rec.model_src] ref_z = self.get_ref_z(c, obs_rec.model_src, obs_rec.time) if isinstance(dataset, SyntheticObs): #using synthetic observation # generate synthetic obs network seq = dataset.generate_obs_network(model=model, grid=c.grid, mask=c.grid.mask, z=ref_z, **obs_rec.asdict(), tag='truth') # compute obs values seq['obs'] = self.state_to_obs(c, 'truth', member=None, **obs_rec.asdict(), **seq) # perturb with obs err # TODO: only support normal err_type here seq['obs'] += np.random.normal(0, 1, seq['obs'].shape) * obs_rec.err.std else: # read dataset files and obtain obs sequence seq = dataset.read_obs(model=model, grid=c.grid, mask=c.grid.mask, z=ref_z, **obs_rec.asdict(), tag='raw') self.validate_seq_shape(seq['obs'], obs_rec.is_vector) if c.pid_mem == 0: c.debug_message = f"number of '{obs_rec.name}' obs from '{obs_rec.dataset_src}': {seq['obs'].shape[-1]}" # misc. transform here for transform_func in c.transform_funcs: seq = transform_func.forward_obs(c, obs_rec, seq) obs_seq[obs_rec_id] = seq obs_rec.nobs = seq['obs'].shape[-1] # update nobs in obs_rec c.io.call_method(c, 'raw', dataset.write_obs, seq, **obs_rec.asdict(), member=None) # output obs sequence for debugging if c.debug and c.pid_mem == 0: for obs_rec_id, rec in obs_seq.items(): c.io.save_debug_data(c, f'obs_seq.rec{obs_rec_id}', rec, path=c.fs.analysis_dir(c.time, c.iter)) return obs_seq
[docs] def prepare_obs(self, c: Context) -> None: self.obs_seq = bcast_by_root(c.comm_mem)(self.collect_obs_seq)(c)
[docs] def prepare_obs_from_state(self, c: Context, tag: str) -> None: """ Compute the obs priors in parallel, run state_to_obs to obtain obs_prior_seq Args: c (Context): the runtime context object tag (str): 'prior' or 'post' ensemble model states """ mem_list = c.mem_list pid_mem_show = [p for p,lst in mem_list.items() if len(lst)>0][0] pid_rec_show = [p for p,lst in self.obs_rec_list.items() if len(lst)>0][0] c.pid_show = pid_rec_show * c.config.nproc_mem + pid_mem_show # process the obs, each proc gets its own workload as a subset of # all proc goes through their own task list simultaneously nr = len(self.obs_rec_list[c.pid_rec]) nm = len(mem_list[c.pid_mem]) c.total_tasks = nr * nm for m, mem_id in enumerate(mem_list[c.pid_mem]): for r, obs_rec_id in enumerate(self.obs_rec_list[c.pid_rec]): # this is the obs record to process obs_rec = self.info.records[obs_rec_id] dataset = c.datasets[obs_rec.dataset_src] c.debug_message = f"obs_prior mem{mem_id+1:03} {obs_rec.name:20}" c.current_task = m*nr+r seq = {} # need the coordinates for transform later for key in ['x', 'y', 'z', 't', 'err_std']: seq[key] = self.obs_seq[obs_rec_id][key] # obtain obs_prior values from model state seq['obs'] = self.state_to_obs(c, tag, member=mem_id, **obs_rec.asdict(), **self.obs_seq[obs_rec_id]) # misc. transform here for transform_func in c.transform_funcs: seq = transform_func.forward_obs(c, obs_rec, seq) c.io.call_method(c, tag, dataset.write_obs, seq, **obs_rec.asdict(), member=mem_id) # collect obs ensemble data to the local memory getattr(self, f"obs_{tag}")[mem_id, obs_rec_id] = seq['obs'] c.comm.Barrier() # output the obs sequeneces for debugging if c.debug: for key, seq in getattr(self, f"obs_{tag}").items(): mem_id, obs_rec_id = key file = f'obs_{tag}.rec{obs_rec_id}.mem{mem_id:03}' c.io.save_debug_data(c, file, {f'obs_{tag}':seq}, path=c.fs.analysis_dir(c.time, c.iter))
[docs] def global_obs_list(self, c: Context) -> list[tuple[ObsRecordID, int|None, ProcID, int]]: # form the global list of obs (in serial mode the main loop is over this list) n_obs_rec = len(self.info.records) i = {} # location in full obs vector on owner pid for owner_pid in range(c.config.nproc_mem): i[owner_pid] = 0 obs_list = [] for obs_rec_id in range(n_obs_rec): obs_rec = self.info.records[obs_rec_id] v_list = [0, 1] if obs_rec.is_vector else [None] for owner_pid in self.obs_inds[obs_rec_id].keys(): for _ in self.obs_inds[obs_rec_id][owner_pid]: for v in v_list: obs_list.append((obs_rec_id, v, owner_pid, i[owner_pid])) i[owner_pid] += 1 if getattr(c, 'shuffle_obs', False): np.random.shuffle(obs_list) # randomize the order of obs (this is optional) return obs_list
[docs] def transpose_obs_seq(self, c: Context, input_obs: ObsSeq) -> LocalObsSeq: """ Transpose the obs sequence from field-complete to ensemble-complete Args: c (Context): the runtime context input_obs (ObsSeq): obs_seq from process_all_obs(), dict[obs_rec_id, dict[key, np.array]] Returns, LocalObsSeq: the lobs dict[obs_rec_id, dict[par_id, dict[key, np.array]]], key = 'obs','x','y','z','t'... """ mem_list = c.mem_list nproc_mem = c.config.nproc_mem pid_mem_show = [p for p,lst in mem_list.items() if len(lst)>0][0] pid_rec_show = [p for p,lst in self.obs_rec_list.items() if len(lst)>0][0] c.pid_show = pid_rec_show * nproc_mem + pid_mem_show # Step 1: transpose to ensemble-complete by exchanging mem_id, par_id in comm_mem # input_obs -> tmp_obs tmp_obs = {} # local obs at intermediate stage nr = len(self.obs_rec_list[c.pid_rec]) nm_max = np.max([len(lst) for p,lst in mem_list.items()]) c.total_tasks = nr * nm_max for r, obs_rec_id in enumerate(self.obs_rec_list[c.pid_rec]): # all pid goes through their own mem_list simultaneously for m in range(nm_max): mem_id = None seq = None if m < len(mem_list[c.pid_mem]): mem_id = mem_list[c.pid_mem][m] status = f"processing mem{mem_id+1:03} obs_rec{obs_rec_id}" if mem_id else "waiting" c.debug_message = f"transposing obs: {status}" c.current_task = r*nm_max+m # prepare the obs seq for sending if not at the end of mem_list if m < len(mem_list[c.pid_mem]): mem_id = mem_list[c.pid_mem][m] if mem_id == 0: # this is the obs seq, just let mem_id=0 send it seq = input_obs[obs_rec_id].copy() # the collective send/recv follows the same idea under state.transpose_field_to_state # 1) receive lobs_seq from src_pid, for src_pid<pid first for src_pid in range(0, c.pid_mem): if m < len(mem_list[src_pid]): src_mem_id = mem_list[src_pid][m] if src_mem_id == 0: tmp_obs[obs_rec_id] = c.comm_mem.recv(source=src_pid, tag=m) # 2) send my obs chunk to a list of dst_pid, send to dst_pid>=pid first # then cycle back to send to dst_pid<pid. i.e. the dst_pid sequence is # [pid, pid+1, ..., nproc-1, 0, 1, ..., pid-1] if m < len(mem_list[c.pid_mem]): for dst_pid in np.mod(np.arange(nproc_mem)+c.pid_mem, nproc_mem): if mem_id == 0: # this is the obs seq with keys 'obs','err_std','x','y','z','t' # assemble the lobs_seq dict with same keys but subset obs_inds # do this for each par_id to get the full lobs_seq lobs_seq = {} for par_id in c.state.par_list[dst_pid]: lobs_seq[par_id] = {} inds = self.obs_inds[obs_rec_id][par_id] assert seq is not None for key in ('obs', 'err_std', 'x', 'y', 'z', 't'): lobs_seq[par_id][key] = seq[key][..., inds] if dst_pid == c.pid_mem: # pid already stores the lobs_seq, just copy tmp_obs[obs_rec_id] = lobs_seq else: # send lobs_seq to dst_pid's lobs c.comm_mem.send(lobs_seq, dest=dst_pid, tag=m) # 3) finish receiving lobs_seq from src_pid, for src_pid>pid now for src_pid in range(c.pid_mem+1, nproc_mem): if m < len(mem_list[src_pid]): src_mem_id = mem_list[src_pid][m] if src_mem_id == 0: tmp_obs[obs_rec_id] = c.comm_mem.recv(source=src_pid, tag=m) c.comm.Barrier() # Step 2: collect all obs records (all obs_rec_ids) on pid_rec # tmp_obs -> output_obs output_obs = {} for entry in c.comm_rec.allgather(tmp_obs): for key, data in entry.items(): output_obs[key] = data c.comm.Barrier() return output_obs
[docs] def transpose_to_ensemble_complete(self, c: Context, input_obs: ObsEns) -> LocalObsEns: """ Transpose obs from field-complete to ensemble-complete Step 1, Within comm_mem, send the subset of input_obs with mem_id and par_id from the source proc (src_pid) to the destination proc (dst_pid), store the result in tmp_obs with all the mem_id (ensemble-complete) Step 2, Gather all obs_rec_id within comm_rec, so that each pid_rec will have the entire obs record for assimilation Args: c (Context): the runtime context input_obs (ObsEns): obs_prior from process_all_obs_priors(), dict[(mem_id, obs_rec_id), np.array]; Returns, LocalObsEns: the lobs_prior dict[(mem_id, obs_rec_id), dict[par_id, np.array]] """ mem_list = c.mem_list nproc_mem = c.config.nproc_mem pid_mem_show = [p for p,lst in mem_list.items() if len(lst)>0][0] pid_rec_show = [p for p,lst in self.obs_rec_list.items() if len(lst)>0][0] c.pid_show = pid_rec_show * nproc_mem + pid_mem_show c.debug_message = 'transpose obs prior ensemble to local obs priors' # Step 1: transpose to ensemble-complete by exchanging mem_id, par_id in comm_mem # input_obs -> tmp_obs tmp_obs = {} # local obs at intermediate stage nr = len(self.obs_rec_list[c.pid_rec]) nm_max = np.max([len(lst) for p,lst in mem_list.items()]) c.total_tasks = nr * nm_max for r, obs_rec_id in enumerate(self.obs_rec_list[c.pid_rec]): # all pid goes through their own mem_list simultaneously for m in range(nm_max): mem_id = None seq = None # prepare the obs seq for sending if not at the end of mem_list if m < len(mem_list[c.pid_mem]): mem_id = mem_list[c.pid_mem][m] seq = input_obs[mem_id, obs_rec_id].copy() status = f"processing mem{mem_id+1:03} obs_rec{obs_rec_id}" if mem_id else "waiting" c.debug_message = f"transposing obs: {status}" c.current_task = r*nm_max+m # the collective send/recv follows the same idea under state.transpose_field_to_state # 1) receive lobs_seq from src_pid, for src_pid<pid first for src_pid in range(0, c.pid_mem): if m < len(mem_list[src_pid]): src_mem_id = mem_list[src_pid][m] tmp_obs[src_mem_id, obs_rec_id] = c.comm_mem.recv(source=src_pid, tag=m) # 2) send my obs chunk to a list of dst_pid, send to dst_pid>=pid first # then cycle back to send to dst_pid<pid. i.e. the dst_pid sequence is # [pid, pid+1, ..., nproc-1, 0, 1, ..., pid-1] if m < len(mem_list[c.pid_mem]): for dst_pid in np.mod(np.arange(nproc_mem)+c.pid_mem, nproc_mem): # this is the obs prior seq for mem_id, obs_rec_id # for each par_id, assemble the subset lobs_seq using obs_inds lobs_seq = {} for par_id in c.state.par_list[dst_pid]: inds = self.obs_inds[obs_rec_id][par_id] assert seq is not None lobs_seq[par_id] = seq[..., inds] if dst_pid == c.pid_mem: # pid already stores the lobs_seq, just copy tmp_obs[mem_id, obs_rec_id] = lobs_seq else: # send lobs_seq to dst_pid c.comm_mem.send(lobs_seq, dest=dst_pid, tag=m) # 3) finish receiving lobs_seq from src_pid, for src_pid>pid now for src_pid in range(c.pid_mem+1, nproc_mem): if m < len(mem_list[src_pid]): src_mem_id = mem_list[src_pid][m] tmp_obs[src_mem_id, obs_rec_id] = c.comm_mem.recv(source=src_pid, tag=m) c.comm.Barrier() # Step 2: collect all obs records (all obs_rec_ids) on pid_rec # tmp_obs -> output_obs output_obs = {} for entry in c.comm_rec.allgather(tmp_obs): for key, data in entry.items(): output_obs[key] = data c.comm.Barrier() return output_obs
[docs] def transpose_to_field_complete(self, c: Context, lobs: LocalObsEns) -> ObsEns: """ Transpose obs from ensemble-complete to field-complete Args: c (Context): the runtime context lobs (LocalObsEns): ensemble-complete local obs Returns: ObsEns: field-complete obs_seq ensemble """ mem_list = c.mem_list nproc_mem = c.config.nproc_mem pid_mem_show = [p for p,lst in mem_list.items() if len(lst)>0][0] pid_rec_show = [p for p,lst in self.obs_rec_list.items() if len(lst)>0][0] c.pid_show = pid_rec_show * nproc_mem + pid_mem_show c.debug_message = 'obs post sequences: ' c.debug_message = 'transpose local obs to obs' obs_seq = {} nr = len(self.obs_rec_list[c.pid_rec]) nm_max = np.max([len(lst) for p,lst in mem_list.items()]) c.total_tasks = nr * nm_max for r, obs_rec_id in enumerate(self.obs_rec_list[c.pid_rec]): # all pid goes through their own mem_list simultaneously for m in range(nm_max): mem_id = None seq = None if m < len(mem_list[c.pid_mem]): mem_id = mem_list[c.pid_mem][m] rec = self.info.records[obs_rec_id] # prepare an empty obs_seq for receiving if not at the end of mem_list if rec.is_vector: seq = np.full((2, rec.nobs), np.nan) else: seq = np.full((rec.nobs,), np.nan) status = f"processing mem{mem_id+1:03} obs_rec{obs_rec_id}" if mem_id else "waiting" c.debug_message = f"transposing obs: {status}" c.current_task = r*nm_max+m # this is just the reverse of transpose_obs_to_lobs # we take the exact steps, but swap send and recv operations here # # 1) send my lobs to dst_pid, for dst_pid<pid first for dst_pid in range(0, c.pid_mem): if m < len(mem_list[dst_pid]): dst_mem_id = mem_list[dst_pid][m] c.comm_mem.send(lobs[dst_mem_id, obs_rec_id], dest=dst_pid, tag=m) # 2) receive fld_chk from a list of src_pid, from src_pid>=pid first # because they wait to send stuff before able to receive themselves, # cycle back to receive from src_pid<pid then. if m < len(mem_list[c.pid_mem]): assert mem_id is not None for src_pid in np.mod(np.arange(nproc_mem)+c.pid_mem, nproc_mem): if src_pid == c.pid_mem: # pid already stores the lobs_seq, just copy lobs_seq = lobs[mem_id, obs_rec_id].copy() else: # send lobs_seq to dst_pid lobs_seq = c.comm_mem.recv(source=src_pid, tag=m) # unpack the lobs_seq to form a complete seq for par_id in c.state.par_list[src_pid]: inds = self.obs_inds[obs_rec_id][par_id] assert seq is not None seq[..., inds] = lobs_seq[par_id] obs_seq[mem_id, obs_rec_id] = seq # 3) finish sending lobs_seq to dst_pid, for dst_pid>pid now for dst_pid in range(c.pid_mem+1, nproc_mem): if m < len(mem_list[dst_pid]): dst_mem_id = mem_list[dst_pid][m] c.comm_mem.send(lobs[dst_mem_id, obs_rec_id], dest=dst_pid, tag=m) c.comm.Barrier() return obs_seq
[docs] def pack_local_obs_data(self, c: Context, par_id: PartitionID, lobs: LocalObsSeq, lobs_prior: LocalObsEns) -> dict: """pack lobs and lobs_prior into arrays for the jitted functions""" n_obs_rec = len(self.info.records) # number of obs records n_state_var = len(c.state.info.variables) # number of state variable names # filter out obs with nan in obs_prior, valid index stored as subset of local_inds nlobs = 0 # number of local obs on partition self.valid = {} for obs_rec_id in range(n_obs_rec): obs_rec = self.info.records[obs_rec_id] v_list = [0, 1] if obs_rec.is_vector else [None] values = np.stack([lobs_prior[m, obs_rec_id][par_id][v, :].flatten() for m in range(c.nens) for v in v_list], axis=0) no_nan_mask = ~np.isnan(values).any(axis=0) self.valid[obs_rec_id] = np.where(no_nan_mask)[0].tolist() nlobs += len(self.valid[obs_rec_id]) * len(v_list) data = {} data['obs_rec_id'] = np.zeros(nlobs, dtype=int) data['obs'] = np.full(nlobs, np.nan) data['x'] = np.full(nlobs, np.nan) data['y'] = np.full(nlobs, np.nan) data['z'] = np.full(nlobs, np.nan) data['t'] = np.full(nlobs, np.nan) data['err_std'] = np.full(nlobs, np.nan) data['obs_prior'] = np.full((c.nens, nlobs), np.nan) data['used'] = np.full(nlobs, False) data['hroi'] = np.ones(n_obs_rec) data['vroi'] = np.ones(n_obs_rec) data['troi'] = np.ones(n_obs_rec) data['impact_on_state'] = np.ones((n_obs_rec, n_state_var)) i = 0 for obs_rec_id in range(n_obs_rec): obs_rec = self.info.records[obs_rec_id] v_list = [0, 1] if obs_rec.is_vector else [None] data['hroi'][obs_rec_id] = obs_rec.hroi data['vroi'][obs_rec_id] = obs_rec.vroi data['troi'][obs_rec_id] = obs_rec.troi for state_var_id in range(len(c.state.info.variables)): state_vname = c.state.info.variables[state_var_id] data['impact_on_state'][obs_rec_id, state_var_id] = obs_rec.impact_on_state[state_vname] valid = self.valid[obs_rec_id] local_inds = self.obs_inds[obs_rec_id][par_id] d = len(local_inds[valid]) # append obs and obs prior records to the full array for v in v_list: data['obs_rec_id'][i:i+d] = obs_rec_id data['obs'][i:i+d] = np.squeeze(lobs[obs_rec_id][par_id]['obs'][v, valid]) data['x'][i:i+d] = lobs[obs_rec_id][par_id]['x'][valid] data['y'][i:i+d] = lobs[obs_rec_id][par_id]['y'][valid] data['z'][i:i+d] = lobs[obs_rec_id][par_id]['z'][valid].astype(np.float32) data['t'][i:i+d] = np.array([t2h(t) for t in lobs[obs_rec_id][par_id]['t'][valid]]) data['err_std'][i:i+d] = lobs[obs_rec_id][par_id]['err_std'][valid] for m in range(c.nens): data['obs_prior'][m, i:i+d] = np.squeeze(lobs_prior[m, obs_rec_id][par_id][v, valid].copy()) i += d return data
[docs] def unpack_local_obs_data(self, c: Context, par_id: PartitionID, lobs: LocalObsSeq, lobs_prior: LocalObsEns, data: dict) -> None: """unpack data and write back to the original lobs_prior dict""" n_obs_rec = len(self.info.records) i = 0 for obs_rec_id in range(n_obs_rec): obs_rec = self.info.records[obs_rec_id] valid = self.valid[obs_rec_id] local_inds = self.obs_inds[obs_rec_id][par_id] d = len(local_inds[valid]) v_list = [0, 1] if obs_rec.is_vector else [None] for v in v_list: for m in range(c.nens): lobs_prior[m, obs_rec_id][par_id][v, valid] = data['obs_prior'][m, i:i+d] i += d