Source code for NEDAS.assim_tools.inflation.RTPP

import numpy as np
from NEDAS.core import Context, Inflation

class RTPPInflation(Inflation):
    """Posterior inflation by relaxing analysis perturbations toward the prior perturbations.

    The relaxation coefficient ``self.coef`` blends prior and posterior
    ensemble perturbations around the posterior mean::

        post = post_mean + coef*(prior - prior_mean) + (1 - coef)*(post - post_mean)

    so ``coef = 0`` leaves the posterior ensemble unchanged and ``coef = 1``
    replaces the posterior perturbations with the prior ones.
    """

    def adaptive_prior_inflation(self, c: Context):
        """Relaxation is not defined for the prior ensemble.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Relaxation method is only implemented for posterior ensemble")

    def adaptive_post_inflation(self, c: Context):
        """Adaptive covariance relaxation method (Ying and Zhang 2015, QJRMS)

        Reads the observation-space statistics from ``self.obs_space_stats(c)``
        and sets ``self.coef`` from them. The coefficient falls back to 0 when
        there are fewer than 3 observations or when the accumulated posterior
        variance ``vara`` is exactly zero, and is otherwise limited to the
        range [-1, 2]. On the normal path a summary of the statistics is also
        stored in ``c.message``.

        Args:
            c: runtime context; ``c.debug`` enables logging via ``c.log_event``.
        """
        stats = self.obs_space_stats(c)
        nobs = stats['total_nobs']

        if nobs < 3:
            if c.debug:
                c.log_event("insufficient nobs to establish statistics, setting self.coef=0", flag='warning')
            self.coef = 0.
            return

        if stats['vara'] == 0:
            # coef=0 is the neutral value for relaxation (posterior unchanged);
            # it also avoids the division by vara below.
            if c.debug:
                c.log_event("vara=0 detected, skipping with coef=0 (no relaxation)", flag='warning')
            self.coef = 0.
            return

        # normalize the accumulated statistics by the number of observations
        varb = stats['varb'] / nobs
        vara = stats['vara'] / nobs
        varo = stats['varo'] / nobs
        omb2 = stats['omb2'] / nobs
        omaamb = stats['omaamb'] / nobs
        amb2 = stats['amb2'] / nobs

        beta = np.sqrt(varb / vara)
        # clip at zero so the square root stays real
        lamb = np.sqrt(max(0.0, (omb2 - varo - amb2) / vara))

        summary = (
            f"varb = {varb}, vara = {vara}, varo={varo}; "
            f"omb2 = {omb2}, omaamb = {omaamb}, amb2={amb2}; "
            f"beta = {beta}, lambda = {lamb}"
        )
        if c.debug:
            c.log_event(summary, flag='stats')

        if beta <= 1:
            # guards the (beta - 1) denominator; no relaxation is applied
            self.coef = 0
        else:
            self.coef = (lamb - 1) / (beta - 1)

        # keep the coefficient within [-1, 2]
        if self.coef > 2:
            self.coef = 2
        if self.coef < -1:
            self.coef = -1

        c.message = summary

    def apply_inflation(self, c: Context, flag: str):
        """Relax every locally held posterior field toward its prior perturbation.

        For each ``(mem_id, rec_id)`` handled by this processor, the posterior
        field in ``c.state.fields_post`` is replaced by::

            post_mean + coef*(prior - prior_mean) + (1 - coef)*(post - post_mean)

        where the mean fields are read through ``c.io.read_field``. All
        processors synchronize on ``c.comm.Barrier()`` before returning.

        Args:
            c: runtime context providing the member/record lists, the state
                fields, the I/O handler and the communicator.
            flag: accepted for interface compatibility; not used by this method.
        """
        # first processor ids that own a non-empty member list / record list
        pid_mem_show = [p for p, lst in c.mem_list.items() if len(lst) > 0][0]
        pid_rec_show = [p for p, lst in c.state.rec_list.items() if len(lst) > 0][0]
        c.pid_show = pid_rec_show * c.config.nproc_mem + pid_mem_show

        if c.debug:
            c.log_event(f'relaxing to prior ensemble perturbations with coef={self.coef}', flag='info')

        # process the fields, each processor goes through its own subset of
        # mem_id,rec_id simultaneously
        nm = len(c.mem_list[c.pid_mem])
        nr = len(c.state.rec_list[c.pid_rec])
        c.total_tasks = nm * nr

        for r, rec_id in enumerate(c.state.rec_list[c.pid_rec]):
            # read the mean field with rec_id
            fld_prior_mean = c.io.read_field(c, 'prior_mean', rec_id, mem_id=0)
            fld_post_mean = c.io.read_field(c, 'post_mean', rec_id, mem_id=0)

            for m, mem_id in enumerate(c.mem_list[c.pid_mem]):
                c.debug_message = f"relax_to_prior_perturb mem{mem_id+1:03}"
                c.current_task = m * nr + r

                # inflate the ensemble perturbations by relaxing to prior perturbations
                fld_prior = c.state.fields_prior[mem_id, rec_id]
                fld_post = c.state.fields_post[mem_id, rec_id]

                # Store the result back into the state: rebinding a local name
                # alone would leave the posterior ensemble unchanged.
                c.state.fields_post[mem_id, rec_id] = (
                    fld_post_mean
                    + self.coef * (fld_prior - fld_prior_mean)
                    + (1. - self.coef) * (fld_post - fld_post_mean)
                )

        c.comm.Barrier()