Source code for NEDAS.assim_tools.updators.additive

import numpy as np
from NEDAS.core import Context, Updator

class AdditiveUpdator(Updator):
    """
    Updator that treats the analysis as an additive increment: the difference
    between posterior and prior fields is added onto the model variables.
    """

    def compute_increment(self, c: Context):
        """
        Additive updator: the increment of every field is simply its
        posterior minus its prior.

        The result is stored in ``self.increment``, a dict keyed by
        ``(mem_id, rec_id)`` with the same keys as ``c.state.fields_prior``.

        Inputs:
        - c: context object
        """
        self.increment = {}

        for mem_id, rec_id in c.state.fields_prior.keys():
            field_info = c.state.info.fields[rec_id]
            prior = c.state.fields_prior[mem_id, rec_id]
            posterior = c.state.fields_post[mem_id, rec_id]

            # undo the misc transforms on both fields before differencing
            # NOTE(review): the backward transforms run in the same order as
            # c.transform_funcs -- confirm this is intended if the transforms
            # do not commute
            for func in c.transform_funcs:
                prior = func.backward_state(c, field_info, prior)
                posterior = func.backward_state(c, field_info, posterior)

            # keep the increment for this member / record
            self.increment[mem_id, rec_id] = posterior - prior

    def update_files(self, c, mem_id, rec_id):
        """
        Update a single field ``rec_id`` of member ``mem_id`` in the model
        restart file by adding the stored increment to the current variable.

        Derived classes can override this for other update methods.
        ``compute_increment`` must have filled ``self.increment`` beforehand.

        Inputs:
        - c: context object
        - mem_id: member index
        - rec_id: record index

        Raises:
        - RuntimeError: if the shape of the variable read from the model
          matches neither ``model.grid.x`` nor ``model.grid.x_elem``
        """
        rec = c.state.info.fields[rec_id].asdict()
        model = c.models[rec['model_src']]

        # current model variable, as read by the model's own reader
        var_prior = c.io.call_method(c, 'current', model.read_var, member=mem_id, **rec)

        # bring the increment from the analysis grid onto the model grid
        native_grid = model.grid
        c.grid.set_destination_grid(native_grid)
        incr = c.grid.convert(
            self.increment[mem_id, rec_id],
            is_vector=rec['is_vector'],
            method='linear',
        )

        # vector fields carry one extra leading dimension, which is ignored
        # when comparing against the grid shape
        spatial_shape = var_prior.shape[1:] if rec['is_vector'] else var_prior.shape

        if spatial_shape != native_grid.x.shape:
            if spatial_shape != native_grid.x_elem.shape:
                raise RuntimeError(f"mismatch in field prior {var_prior.shape} with increment {incr.shape}")
            # variable shaped like x_elem: average the increment over the
            # entries indexed by each row of tri.triangles
            incr = np.mean(incr[..., native_grid.tri.triangles], axis=-1)

        updated = var_prior + incr

        # TODO: temporary solution for nan values due to interpolation:
        # wherever the sum is nan, fall back to the prior value
        # (a stricter alternative would be to raise an error instead)
        nan_index = np.nonzero(np.isnan(updated))
        updated[nan_index] = var_prior[nan_index]

        # write the posterior variable to restart file
        c.io.call_method(c, 'current', model.write_var, updated, member=mem_id, comm=c.comm, **rec)