Source code for NEDAS.assim_tools.assimilators.batch

import copy
from abc import abstractmethod
import numpy as np
from NEDAS.utils.parallel import distribute_tasks
from NEDAS.core import Context, Assimilator

class BatchAssimilator(Assimilator):
    """
    Assimilator that processes the local states partition by partition.

    The domain is divided into partitions (tiles), observations are assigned to the
    partitions they may influence, partitions are distributed over processors according
    to their workload, and a local analysis is run for every location in each partition.
    The local analysis itself is provided by derived classes through ``local_analysis``.
    """

    def init_partitions(self, c: Context) -> list:
        """
        Generate spatial partitioning of the domain.

        The domain is divided into (roughly) square tiles. The workload on each tile is
        uneven since there are masked points, so about ``3 * c.config.nproc_mem`` tiles
        are created so that they can be distributed according to their load
        (number of unmasked points).

        Args:
            c (Context): runtime context; ``c.grid`` and ``c.config.nproc_mem`` are read.

        Returns:
            list: one entry per partition; the position in the list is the ``par_id``.

            - If ``c.grid.x`` is a 2D array, each entry is a tuple
              ``(istart, iend, di, jstart, jend, dj)`` of indices for slicing the domain.
              Regular slicing is more efficient than fancy indexing.
            - Otherwise each entry is an array of indices of the grid points located
              inside the tile. Such an array can be empty if no grid point falls
              inside the tile.
        """
        ntile = c.config.nproc_mem * 3

        if len(c.grid.x.shape) == 2:
            ny, nx = c.grid.x.shape
            # tile size (number of grid points in each direction)
            nx_tile = max(int(np.round(np.sqrt(nx * ny / ntile))), 1)
            return [
                (i, min(i + nx_tile, nx), 1,   # istart, iend, di
                 j, min(j + nx_tile, ny), 1)   # jstart, jend, dj
                for j in range(0, ny, nx_tile)
                for i in range(0, nx, nx_tile)
            ]

        # divide the domain into square tiles, similar to the 2D array case, but collect
        # the grid points inside each tile and return their indices
        if c.grid.Ly == 0:
            # for 1D grid, just divide into equal sections, no y dimension
            Dx = c.grid.Lx / ntile
            return [
                np.where(np.logical_and(c.grid.x >= x, c.grid.x < x + Dx))[0]
                for x in np.arange(c.grid.xmin, c.grid.xmax, Dx)
            ]

        # for 2D grid, find number of tiles in each direction according to aspect ratio
        ntile_y = max(int(np.sqrt(ntile * c.grid.Ly / c.grid.Lx)), 1)
        ntile_x = max(ntile // ntile_y, 1)
        Dx = c.grid.Lx / ntile_x
        Dy = c.grid.Ly / ntile_y
        tile_x = np.arange(c.grid.xmin, c.grid.xmax, Dx)
        tile_y = np.arange(c.grid.ymin, c.grid.ymax, Dy)
        return [
            np.where(np.logical_and(
                np.logical_and(c.grid.x >= x, c.grid.x < x + Dx),
                np.logical_and(c.grid.y >= y, c.grid.y < y + Dy)))[0]
            for y in tile_y
            for x in tile_x
        ]

    def assign_obs(self, c: Context) -> dict:
        """
        Assign the observation sequence to each partition par_id.

        Each ``pid_rec`` screens its own subset of observation records
        (``c.obs.obs_rec_list[c.pid_rec]``), then the results from all ``pid_rec`` are
        gathered through ``c.comm_rec.allgather`` to form the complete dictionary.

        Args:
            c (Context): runtime context.

        Returns:
            dict: ``obs_inds[obs_rec_id][par_id]`` is the array of indices of the
            observations in record ``obs_rec_id`` assigned to partition ``par_id``.
        """
        # screen horizontally for obs inside hroi of each partition
        obs_inds_pid = {
            obs_rec_id: self.assign_obs_to_tiles(c, c.state, c.obs, obs_rec_id)
            for obs_rec_id in c.obs.obs_rec_list[c.pid_rec]
        }

        # merge the pieces found by the different pid_rec
        obs_inds = {}
        for entry in c.comm_rec.allgather(obs_inds_pid):
            obs_inds.update(entry)
        return obs_inds

    def assign_obs_to_tiles(self, c, state, obs, obs_rec_id):
        """
        Find the observations of one record that are relevant to each partition.

        Observations within the bounding box of a partition plus a halo region of width
        hroi are assigned to that partition. Although this includes some observations
        near the corners that are not within hroi of any grid point, this is favorable
        for the efficiency in finding the subset.

        Args:
            c (Context): runtime context; ``c.grid`` is used for coordinates and distances.
            state: object holding ``partitions`` (as generated by ``init_partitions``).
            obs: object holding ``info.records`` and ``obs_seq``.
            obs_rec_id: id of the observation record to process.

        Returns:
            dict: ``obs_inds[par_id]`` is the array of indices (into the obs sequence of
            this record) of the observations assigned to partition ``par_id``.
        """
        hroi = obs.info.records[obs_rec_id].hroi
        xo = np.array(obs.obs_seq[obs_rec_id]['x'])  # obs x,y
        yo = np.array(obs.obs_seq[obs_rec_id]['y'])
        is_2d_array = len(c.grid.x.shape) == 2

        obs_inds = {}
        for par_id, partition in enumerate(state.partitions):
            # find bounding box for this partition
            if is_2d_array:
                ist, ied, _, jst, jed, _ = partition
                xmin, xmax = c.grid.x[0, ist], c.grid.x[0, ied-1]
                ymin, ymax = c.grid.y[jst, 0], c.grid.y[jed-1, 0]
            else:
                x = c.grid.x[partition]
                y = c.grid.y[partition]
                if x.size == 0:
                    # a tile can contain no grid point; min()/max() of an empty array
                    # would raise ValueError, and there is nothing to update in such
                    # a partition, so no observation is assigned to it
                    obs_inds[par_id] = np.empty(0, dtype=np.intp)
                    continue
                xmin, xmax, ymin, ymax = x.min(), x.max(), y.min(), y.max()

            # half widths and center of the bounding box
            Dx = 0.5 * (xmax - xmin)
            Dy = 0.5 * (ymax - ymin)
            xc = xmin + Dx
            yc = ymin + Dy

            # distance is checked separately in the x and y directions
            within_x = c.grid.distance(xc, xo, yc, yc, p=1) <= Dx+hroi
            within_y = c.grid.distance(xc, xc, yc, yo, p=1) <= Dy+hroi
            obs_inds[par_id] = np.where(np.logical_and(within_x, within_y))[0]
        return obs_inds

    @staticmethod
    def _partition_mask(c, partition):
        """Return the part of ``c.grid.mask`` covered by one entry of the partitions list."""
        if len(c.grid.x.shape) == 2:
            ist, ied, di, jst, jed, dj = partition
            return c.grid.mask[jst:jed:dj, ist:ied:di]
        return c.grid.mask[partition]

    @staticmethod
    def _report_progress(c, debug_message, ntask_done):
        """Count ``ntask_done`` more finished tasks on ``c`` and refresh its progress messages."""
        c.debug_message = debug_message
        c.current_task += ntask_done
        c.message = f"completed {c.current_task}/{c.total_tasks} state variables."

    def distribute_partitions(self, c: Context):
        """
        Distribute the list of par_id to the processors according to workload.

        The workload of a partition is the number of unmasked grid points times the
        number of observations assigned to it (each factor being at least 1).

        Args:
            c (Context): runtime context; ``c.state.partitions``, ``c.grid.mask``,
                ``c.obs.obs_inds`` and ``c.comm_mem`` are read.

        Returns:
            The result of ``distribute_tasks(c.comm_mem, par_list_full, workload)``.
        """
        par_list_full = np.arange(len(c.state.partitions))

        # number of unmasked grid points in each tile
        nlpts_loc = np.array([
            np.sum((~self._partition_mask(c, partition)).astype(int))
            for partition in c.state.partitions
        ])

        # number of observations within the hroi of each tile,
        # sum over the len of obs_inds over all obs_rec_ids
        nlobs_loc = np.array([
            np.sum([len(c.obs.obs_inds[r][p]) for r in c.obs.info.records.keys()])
            for p in par_list_full
        ])

        workload = np.maximum(nlpts_loc, 1) * np.maximum(nlobs_loc, 1)
        par_list = distribute_tasks(c.comm_mem, par_list_full, workload)
        return par_list

    def assimilation_algorithm(self, c: Context):
        """
        Batch assimilation solves the matrix version EnKF analysis for each local state,
        the local states in each partition are processed in parallel.

        For every partition in ``c.state.par_list[c.pid_mem]`` the local state and obs
        data are packed, ``local_analysis`` is called for each location that has
        observations with non-zero horizontal localization factor, and the local state
        data is unpacked into ``c.state.state_post``. Progress is reported through
        ``c.current_task``, ``c.total_tasks``, ``c.message`` and ``c.debug_message``.

        Args:
            c (Context): runtime context; it is updated in place.
        """
        c.message = 'preparing...'
        c.state.state_post = copy.deepcopy(c.state.state_prior)
        # TODO: obs_prior is not updated to obs_post by the filter
        #c.obs.lobs_post = copy.deepcopy(c.obs.lobs_prior)

        # pid with the most obs in its task list will show progress message
        obs_count = np.array([
            np.sum([len(c.obs.obs_inds[r][p])
                    for r in c.obs.info.records.keys()
                    for p in lst])
            for lst in c.state.par_list.values()
        ])
        c.pid_show = np.argsort(obs_count)[-1]

        # count number of tasks: one per unmasked grid point in the partitions on pid_mem
        c.total_tasks = sum(
            int(np.sum((~self._partition_mask(c, c.state.partitions[par_id])).astype(int)))
            for par_id in c.state.par_list[c.pid_mem]
        )

        # now the actual work starts, loop through partitions stored on pid_mem
        c.current_task = 0
        for par_id in c.state.par_list[c.pid_mem]:
            state_data = c.state.pack_local_state_data(c, par_id, c.state.state_prior, c.state.state_z)
            nloc = state_data['state_prior'].shape[-1]
            # skip forward if the partition is empty
            if nloc == 0:
                continue

            obs_data = c.obs.pack_local_obs_data(c, par_id, c.obs.lobs, c.obs.lobs_prior)
            nlobs = obs_data['x'].size
            # if there is no obs to assimilate, update progress message and skip that partition
            if nlobs == 0:
                self._report_progress(c, f"processed partition {par_id:7} (which is empty)", nloc)
                continue

            # loop through the unmasked grid points in the partition
            for loc_id in range(nloc):
                # state variable metadata for this location
                state_x = state_data['x'][loc_id]
                state_y = state_data['y'][loc_id]

                # filter out obs outside the hroi in each direction first (using L1 norm to speed up)
                obs_rec_id = obs_data['obs_rec_id']
                hroi = obs_data['hroi'][obs_rec_id]
                hdist = c.grid.distance(state_x, obs_data['x'], state_y, obs_data['y'], p=1)
                ind = np.where(hdist <= hroi)[0]

                # compute horizontal localization factor (using L2 norm for distance)
                obs_rec_id = obs_data['obs_rec_id'][ind]
                hroi = obs_data['hroi'][obs_rec_id]
                hdist = c.grid.distance(state_x, obs_data['x'][ind], state_y, obs_data['y'][ind], p=2)
                hlfactor = c.localization_funcs['horizontal'](hdist, hroi)
                ind1 = np.where(hlfactor > 0)[0]
                ind = ind[ind1]
                hlfactor = hlfactor[ind1]

                # if all obs has no impact on state, just skip to next location
                if len(ind1) == 0:
                    self._report_progress(
                        c,
                        f"processed partition {par_id:7} grid point {loc_id} (all local obs outside hroi)",
                        1)
                    continue

                self.local_analysis(c, loc_id, ind, hlfactor, state_data, obs_data)

                # add progress message
                self._report_progress(c, f"processed partition {par_id:7} grid point {loc_id}", 1)

            c.state.unpack_local_state_data(c, par_id, c.state.state_post, state_data)
            #c.obs.unpack_local_obs_data(c, par_id, c.obs.lobs, c.obs.lobs_post, obs_data)

    @abstractmethod
    def local_analysis(self, c, loc_id, ind, hlfactor, state_data, obs_data):
        """
        Local analysis scheme for each model state variable (grid point),
        to be implemented by derived classes.

        Args:
            c (Context): runtime context.
            loc_id (int): index of the location within the partition's local state data.
            ind (np.ndarray): indices (into the local obs data) of the observations
                with non-zero horizontal localization factor at this location.
            hlfactor (np.ndarray): horizontal localization factors of those observations.
            state_data (dict): local state data from ``c.state.pack_local_state_data``.
            obs_data (dict): local obs data from ``c.obs.pack_local_obs_data``.
        """
        ...