# Source code for NEDAS.core.assimilator

import os
import inspect
from abc import ABC, abstractmethod
import numpy as np
from NEDAS.config import parse_config
from NEDAS.utils.parallel import bcast_by_root
from .context import Context
from .types import ObsRecordID, PartitionID, ProcIDMem

class Assimilator(ABC):
    """
    Abstract base class for batch assimilation algorithms.

    The :meth:`assimilate` method drives the whole analysis step: prior
    inflation, partitioning of the grid, transposing the data to
    ensemble-complete layout, running the assimilation algorithm, transposing
    back to field-complete layout, writing the posterior and posterior
    inflation.

    Subclasses must implement :meth:`init_partitions`, :meth:`assign_obs`,
    :meth:`distribute_partitions` and :meth:`assimilation_algorithm`.
    """

    def __init__(self, c: Context) -> None:
        """
        Read the assimilator parameters and store them as instance attributes.

        Args:
            c (Context): the runtime context object; ``c.config.assimilator_def``
                is passed to ``parse_config`` as keyword arguments.
        """
        # get parameters from config file; the directory handed to parse_config
        # is the one containing the source file of the concrete subclass
        code_dir = os.path.dirname(inspect.getfile(self.__class__))
        config_dict = parse_config(code_dir, parse_args=False, **c.config.assimilator_def)

        # every parsed entry becomes an attribute of this assimilator
        for key, value in config_dict.items():
            setattr(self, key, value)

    def assimilate(self, c: Context) -> None:
        """
        Main method to run the batch assimilation algorithm

        Args:
            c (Context): the runtime context object. Its ``state`` and ``obs``
                members are updated in place by the individual steps.
        """
        # prior inflation step
        c.logger('Prior inflation')(c.inflation_func)(c, 'prior')

        # transpose c.state.fields_prior to ensemble-complete c.state.state_prior
        self.partition_grid(c)
        c.logger('Transpose to ensemble-complete')(self.transpose_to_ensemble_complete)(c)

        # the core assimilation algorithm
        # assimilates c.obs.obs_seq into c.state.state_prior to get c.state.state_post
        c.logger('Assimilation algorithm')(self.assimilation_algorithm)(c)

        # transpose c.state.state_post back to field-complete c.state.fields_post
        c.logger('Transpose back to field-complete')(self.transpose_to_field_complete)(c)

        # output the post state
        # TODO: which version of posterior to output? ideally the inflated one?
        # algorithmically clean way is to output the intermediate versions as well and
        # let the files be input/output to the inflation func
        # but this is too much IO overhead.
        c.logger('Output posterior ensemble members')(c.state.output_state)(c, 'post')
        c.logger('Output posterior ensemble mean')(c.state.output_ens_mean)(c, 'post')

        if not c.obs.obs_post:
            # for batch filters the obs_post needs to be computed
            # (TODO: they can be updated along with the state, as an alternative)
            c.logger('Prepare obs from posterior state')(c.obs.prepare_obs_from_state)(c, 'post')

        # posterior inflation
        c.logger('Posterior inflation')(c.inflation_func)(c, 'post')

    def partition_grid(self, c: Context) -> None:
        """
        Partition the analysis grid into several parts and distribute the
        workload over the mpi ranks.

        Stores the results in ``c.state.partitions``, ``c.obs.obs_inds`` and
        ``c.state.par_list``. Each step is wrapped with ``bcast_by_root``
        (presumably evaluated on the root rank and broadcast to the others --
        confirm in NEDAS.utils.parallel).

        Args:
            c (Context): the runtime context object
        """
        c.state.partitions = bcast_by_root(c.comm)(self.init_partitions)(c)
        # note: the obs assignment uses the c.comm_mem communicator,
        # the other two steps use c.comm
        c.obs.obs_inds = bcast_by_root(c.comm_mem)(self.assign_obs)(c)
        c.state.par_list = bcast_by_root(c.comm)(self.distribute_partitions)(c)

    @abstractmethod
    def init_partitions(self, c: Context) -> list:
        """
        Generate spatial partitioning of the domain

        Args:
            c (Context): the runtime context object

        Returns:
            list: the partitions, stored by :meth:`partition_grid` in ``c.state.partitions``
        """
        ...

    @abstractmethod
    def assign_obs(self, c: Context) -> dict[ObsRecordID, dict[PartitionID, np.ndarray]]:
        """
        Assign the observation sequence to each partition par_id

        Args:
            c (Context): the runtime context object

        Returns:
            dict[ObsRecordID, dict[PartitionID, np.ndarray]]: Indices in the full obs_seq
            for the subset of obs that belongs to partition par_id
        """
        ...

    @abstractmethod
    def distribute_partitions(self, c: Context) -> dict[ProcIDMem, list[PartitionID]]:
        """
        Distribute partitions across processors

        Args:
            c (Context): the runtime context object

        Returns:
            dict[ProcIDMem, list[PartitionID]]: the partition ids handled by each
            processor, stored by :meth:`partition_grid` in ``c.state.par_list``
        """
        ...

    def transpose_to_ensemble_complete(self, c: Context) -> None:
        """
        Communicate among mpi ranks and transpose the locally-stored state/obs
        chunks to ensemble-complete

        Sets ``c.state.state_prior``, ``c.state.state_z``, ``c.obs.lobs`` and
        ``c.obs.lobs_prior`` from ``c.state.fields_prior``, ``c.state.fields_z``,
        ``c.obs.obs_seq`` and ``c.obs.obs_prior`` respectively.

        Args:
            c (Context): the runtime context object
        """
        c.state.state_prior = c.logger('Transpose prior state')(
            c.state.transpose_to_ensemble_complete)(c, c.state.fields_prior)
        c.state.state_z = c.logger('Transpose z coordinates')(
            c.state.transpose_to_ensemble_complete)(c, c.state.fields_z)
        c.obs.lobs = c.logger('Transpose obs sequences')(
            c.obs.transpose_obs_seq)(c, c.obs.obs_seq)
        c.obs.lobs_prior = c.logger('Transpose obs prior ensemble')(
            c.obs.transpose_to_ensemble_complete)(c, c.obs.obs_prior)

        # Disabled debug dump of the transposed arrays.
        # NOTE: it uses bare `state` / `obs` names that are not defined in this
        # method (and `state.z_state`, while the attribute set above is
        # `c.state.state_z`), so it needs updating before being re-enabled.
        # if c.debug:
        #     np.save(os.path.join(self.analysis_dir, f'state_prior.{c.pid_mem}.{c.pid_rec}.npy'), state.state_prior)
        #     np.save(os.path.join(self.analysis_dir, f'z_state.{c.pid_mem}.{c.pid_rec}.npy'), state.z_state)
        #     np.save(os.path.join(self.analysis_dir, f'lobs.{c.pid_mem}.{c.pid_rec}.npy'), obs.lobs)
        #     np.save(os.path.join(self.analysis_dir, f'lobs_prior.{c.pid_mem}.{c.pid_rec}.npy'), obs.lobs_prior)

    def transpose_to_field_complete(self, c: Context) -> None:
        """
        Communicate among mpi ranks and transpose the locally-stored state/obs
        chunks back to field-complete

        Sets ``c.state.fields_post`` from ``c.state.state_post``. When
        ``c.obs.lobs_post`` is non-empty, also sets ``c.obs.obs_post`` from it.

        Args:
            c (Context): the runtime context object
        """
        c.state.fields_post = c.logger('Transpose posterior state back')(
            c.state.transpose_to_field_complete)(c, c.state.state_post)

        if c.obs.lobs_post:
            # TODO there is a bug here, in transpose seq[:, ind] out of bound
            c.obs.obs_post = c.logger('Transpose obs posterior ensemble back')(
                c.obs.transpose_to_field_complete)(c, c.obs.lobs_post)

    @abstractmethod
    def assimilation_algorithm(self, c: Context) -> None:
        """
        The main assimilation algorithm will be implemented by subclasses

        Called by :meth:`assimilate` after the data has been transposed to
        ensemble-complete; expected to assimilate ``c.obs.obs_seq`` into
        ``c.state.state_prior`` to produce ``c.state.state_post``.

        Args:
            c (Context): the runtime context object
        """
        ...