Source code for NEDAS.assim_tools.assimilators.serial

import copy
from abc import abstractmethod
import numpy as np
from NEDAS.utils.parallel import bcast_by_root, distribute_tasks
from NEDAS.core import Context, Assimilator

class SerialAssimilator(Assimilator):
    """
    Subclass for serial assimilation algorithms.

    Observations are processed one at a time; for each observation the
    locally stored state and the locally stored observation priors are
    updated. Concrete subclasses must provide :meth:`obs_increment`,
    :meth:`update_local_state` and :meth:`update_local_obs`.
    """
    assim_mode = 'serial'

    def init_partitions(self, c: Context) -> list:
        """
        Generate spatial partitioning of the domain.

        Args:
            c (Context): Runtime context; reads ``c.grid.x`` and
                ``c.config.nproc_mem``.

        Returns:
            list: ``nproc_mem`` partitions. For a 2-D grid each entry is a
            tuple ``(ist, ied, di, jst, jed, dj)`` for slicing; otherwise
            each entry is a 1-D array of point indices.
        """
        nproc = c.config.nproc_mem

        if len(c.grid.x.shape) == 2:
            ny, nx = c.grid.x.shape

            # The domain is divided into tiles, each formed by nproc_mem
            # elements, and each element is stored on a different pid_mem.
            # For each pid, its loc points cover the entire domain with
            # some spacing.
            #
            # List the possible factorings nproc_mem = nx_intv * ny_intv and
            # pick the last one, which is the most 'square', so that the
            # interval is relatively even in both directions for each pid.
            # Integer floor division keeps the factor exact (the candidate
            # i always divides nproc_mem evenly here).
            max_factor = int(np.ceil(np.sqrt(nproc)))
            factor_pairs = [(i, nproc // i)
                            for i in range(1, max_factor + 1)
                            if nproc % i == 0]
            nx_intv, ny_intv = factor_pairs[-1]

            # a list of (ist, ied, di, jst, jed, dj) for slicing
            # note: we have nproc_mem entries in the list
            return [(i, nx, nx_intv, j, ny, ny_intv)
                    for j in range(ny_intv)
                    for i in range(nx_intv)]

        # Not a 2-D grid: just divide the list of points into nproc_mem
        # parts, each part spanning the entire domain.
        npoints = c.grid.x.size
        return [np.arange(i, npoints, nproc) for i in np.arange(nproc)]

    def assign_obs(self, c: Context):
        """
        Assign observation indices to partitions.

        For every observation record handled by this ``pid_rec`` the full
        index range is split with ``distribute_tasks`` over ``c.comm_mem``;
        the per-record results are then gathered across ``c.comm_rec``.

        Args:
            c (Context): Runtime context.

        Returns:
            dict: ``obs_inds[obs_rec_id][par_id]`` -> indices of the
            observations assigned to partition ``par_id``.
        """
        obs_inds_pid = {}
        for obs_rec_id in c.obs.obs_rec_list[c.pid_rec]:
            full_inds = np.arange(c.obs.obs_seq[obs_rec_id]['obs'].shape[-1])

            # locality doesn't matter, we just divide obs_rec into
            # nproc_mem parts
            inds = distribute_tasks(c.comm_mem, full_inds)
            obs_inds_pid[obs_rec_id] = {
                par_id: inds[par_id] for par_id in range(c.config.nproc_mem)
            }

        # Now each pid_rec has figured out obs_inds for its own list of
        # obs_rec_ids; gather them from the different pid_rec to form the
        # complete obs_inds dict.
        obs_inds = {}
        for entry in c.comm_rec.allgather(obs_inds_pid):
            obs_inds.update(entry)
        return obs_inds

    def distribute_partitions(self, c: Context):
        """
        Assign partitions to processors.

        Args:
            c (Context): Runtime context; reads ``c.config.nproc_mem``.

        Returns:
            dict: ``{pid: [pid]}`` for each of the ``nproc_mem`` pids, i.e.
            each partition is assigned to the pid with the same id.
        """
        # just assign each partition to each pid, pid==par_id
        return {p: [p] for p in range(c.config.nproc_mem)}

    def assimilation_algorithm(self, c: Context):
        """
        Implementation of the serial assimilation algorithm.

        Notes:
            Serial assimilation goes through the list of observations one
            by one; for each obs the nearby state variables are updated one
            by one. So each update is a scalar problem, which is solved in
            2 steps: obs_increment, update_ensemble.

        Args:
            c (Context): Runtime context. ``c.state.state_post`` and
                ``c.obs.lobs_post`` are (re)created from deep copies of the
                priors and then filled through the ``unpack_local_*`` calls.
        """
        c.message = 'preparing...'
        c.state.state_post = copy.deepcopy(c.state.state_prior)
        c.obs.lobs_post = copy.deepcopy(c.obs.lobs_prior)

        par_id = c.pid_mem
        state_data = c.state.pack_local_state_data(c, par_id, c.state.state_prior, c.state.state_z)
        obs_data = c.obs.pack_local_obs_data(c, par_id, c.obs.lobs, c.obs.lobs_prior)
        obs_list = bcast_by_root(c.comm)(c.obs.global_obs_list)(c)

        # ens-complete pre transforms (probit)
        self.transform_ens_state_forward(state_data)
        self.transform_ens_obs_forward(obs_data)

        # go through the entire obs list, indexed by p, one scalar obs at a time
        c.total_tasks = len(obs_list)
        for p, (obs_rec_id, _, owner_pid, i) in enumerate(obs_list):
            c.debug_message = f"Processing observation obs_rec_id={obs_rec_id:2}, i={i}"
            # NOTE(review): the message is formatted from c.current_task
            # before it is set to p below, so it reports the value left by
            # the previous iteration - confirm this ordering is intended.
            c.message = f"completed {c.current_task}/{c.total_tasks} observations."
            c.current_task = p

            # 1. if the pid owns this obs, broadcast it to all pid
            if c.pid_mem == owner_pid:
                # collect obs info
                obs_p = {'prior': obs_data['obs_prior'][:, i]}
                for key in ('obs', 'x', 'y', 'z', 't', 'err_std'):
                    obs_p[key] = obs_data[key][i]
                for key in ('hroi', 'vroi', 'troi', 'impact_on_state'):
                    obs_p[key] = obs_data[key][obs_rec_id]
                # mark this obs as used
                obs_data['used'][i] = True
            else:
                obs_p = None
            obs_p = c.comm_mem.bcast(obs_p, root=owner_pid)

            # skip observations with missing prior members or missing value
            if np.isnan(obs_p['prior']).any() or np.isnan(obs_p['obs']):
                continue

            # compute obs-space increment
            obs_incr = self.obs_increment(obs_p['prior'], obs_p['obs'], obs_p['err_std'])

            # localization radii of this obs and the localization functions,
            # passed in the same order to both update calls below
            roi = (obs_p['hroi'], obs_p['vroi'], obs_p['troi'])
            local_funcs = (c.localization_funcs['horizontal'],
                           c.localization_funcs['vertical'],
                           c.localization_funcs['temporal'])

            # 2. all pid update their own locally stored state:
            state_h_dist = c.grid.distance(obs_p['x'], state_data['x'],
                                           obs_p['y'], state_data['y'], p=2)
            state_v_dist = np.abs(obs_p['z'] - state_data['z'])
            state_t_dist = np.abs(obs_p['t'] - state_data['t'])
            self.update_local_state(state_data['state_prior'], obs_p['prior'], obs_incr,
                                    state_h_dist, state_v_dist, state_t_dist,
                                    *roi, *local_funcs)

            # 3. all pid update their own locally stored obs:
            obs_h_dist = c.grid.distance(obs_p['x'], obs_data['x'],
                                         obs_p['y'], obs_data['y'], p=2)
            obs_v_dist = np.abs(obs_p['z'] - obs_data['z'])
            obs_t_dist = np.abs(obs_p['t'] - obs_data['t'])
            self.update_local_obs(obs_data['obs_prior'], obs_data['used'],
                                  obs_p['prior'], obs_incr,
                                  obs_h_dist, obs_v_dist, obs_t_dist,
                                  *roi, *local_funcs)

        # ens-complete inverse transforms (probit)
        self.transform_ens_state_backward(state_data)
        self.transform_ens_obs_backward(obs_data)

        c.state.unpack_local_state_data(c, par_id, c.state.state_post, state_data)
        c.obs.unpack_local_obs_data(c, par_id, c.obs.lobs, c.obs.lobs_post, obs_data)

    @abstractmethod
    def obs_increment(self, obs_prior, obs, obs_err) -> np.ndarray:
        """
        Compute observation-space analysis increments.

        Args:
            obs_prior (np.ndarray): Observation priors, 1-D float array of length nens
            obs (float): The real observation value
            obs_err (float): Observation error std

        Returns:
            ndarray: observation-space analysis increments
        """

    @abstractmethod
    def update_local_state(self, state_prior, obs_prior, obs_incr,
                           state_h_dist, state_v_dist, state_t_dist,
                           hroi, vroi, troi,
                           h_local_func, v_local_func, t_local_func) -> None:
        """
        Update the local state vector with the analysis increments.

        Args:
            state_prior (np.ndarray): Local state vector, shape (nens, nfld, nloc)
            obs_prior (np.ndarray): Observation priors, shape (nens,)
            obs_incr (np.ndarray): Analysis increments, shape (nens,)
            state_h_dist: Horizontal distance between the obs and the local state points
            state_v_dist: Absolute vertical (z) difference between the obs and the local state
            state_t_dist: Absolute time (t) difference between the obs and the local state
            hroi: Horizontal localization radius of the obs
            vroi: Vertical localization radius of the obs
            troi: Temporal localization radius of the obs
            h_local_func: Horizontal localization function
            v_local_func: Vertical localization function
            t_local_func: Temporal localization function
        """

    @abstractmethod
    def update_local_obs(self, obs_data, used, obs_prior, obs_incr,
                         h_dist, v_dist, t_dist,
                         hroi, vroi, troi,
                         h_local_func, v_local_func, t_local_func) -> None:
        """
        Update the local observations with analysis increments.

        Args:
            obs_data (np.ndarray): obs prior ensemble, shape (nens, nlobs)
            used (np.ndarray): boolean mask of already-assimilated obs
            obs_prior (np.ndarray): Priors of the obs being assimilated, shape (nens,)
            obs_incr (np.ndarray): Analysis increments, shape (nens,)
            h_dist: Horizontal distance between the obs and the local obs
            v_dist: Absolute vertical (z) difference between the obs and the local obs
            t_dist: Absolute time (t) difference between the obs and the local obs
            hroi: Horizontal localization radius of the obs
            vroi: Vertical localization radius of the obs
            troi: Temporal localization radius of the obs
            h_local_func: Horizontal localization function
            v_local_func: Vertical localization function
            t_local_func: Temporal localization function
        """

    def transform_ens_state_forward(self, state_data):
        """Hook called on the packed local state before the obs loop; no-op by default."""
        pass

    def transform_ens_state_backward(self, state_data):
        """Hook called on the packed local state after the obs loop; no-op by default."""
        pass

    def transform_ens_obs_forward(self, obs_data):
        """Hook called on the packed local obs before the obs loop; no-op by default."""
        pass

    def transform_ens_obs_backward(self, obs_data):
        """Hook called on the packed local obs after the obs loop; no-op by default."""
        pass