import numpy as np
from NEDAS.core import Context, Inflation
# Relaxation-to-prior-perturbation (RTPP) inflation scheme.
class RTPPInflation(Inflation):
    """Inflation by relaxing posterior perturbations toward prior perturbations.

    The relaxation coefficient ``self.coef`` is estimated adaptively from
    observation-space statistics in :meth:`adaptive_post_inflation` and then
    used by :meth:`apply_inflation` to blend prior and posterior ensemble
    perturbations around the posterior mean.
    """

    def adaptive_prior_inflation(self, c: Context):
        """Reject adaptive inflation of the prior ensemble.

        Raises:
            NotImplementedError: always; the relaxation method is only
                implemented for the posterior ensemble.
        """
        raise NotImplementedError("Relaxation method is only implemented for posterior ensemble")

    def adaptive_post_inflation(self, c: Context):
        """Adaptive covariance relaxation method (Ying and Zhang 2015, QJRMS).

        Reads the statistics returned by ``self.obs_space_stats(c)`` (keys
        ``total_nobs``, ``varb``, ``vara``, ``varo``, ``omb2``, ``omaamb``,
        ``amb2``), averages them over ``total_nobs`` and sets ``self.coef``:

        * ``beta = sqrt(varb / vara)``
        * ``lambda = sqrt(max(0, (omb2 - varo - amb2) / vara))``
        * ``coef = (lambda - 1) / (beta - 1)``, limited to ``[-1, 2]``;
          ``coef = 0`` when ``beta <= 1``.

        ``self.coef`` is set to 0 (no relaxation) when fewer than 3
        observations are available or when ``vara`` is exactly zero. On the
        normal path, ``c.message`` receives a summary of the statistics.

        NOTE(review): presumably ``varb``/``vara``/``varo`` are the prior,
        posterior and observation-error variances in observation space and
        ``omb2``/``amb2`` are squared innovation/increment sums -- confirm
        against ``obs_space_stats``.
        """
        stats = self.obs_space_stats(c)
        nobs = stats['total_nobs']

        if nobs < 3:
            if c.debug:
                c.log_event("insufficient nobs to establish statistics, setting self.coef=0", flag='warning')
            self.coef = 0.
            return

        if stats['vara'] == 0:
            # vara is the denominator of both beta and lambda below; fall back
            # to coef=0, which leaves the posterior perturbations unchanged.
            if c.debug:
                c.log_event("vara=0 detected, skipping with coef=0 (no inflation)", flag='warning')
            self.coef = 0.
            return

        # per-observation averages of the accumulated statistics
        varb = stats['varb'] / nobs
        vara = stats['vara'] / nobs
        varo = stats['varo'] / nobs
        omb2 = stats['omb2'] / nobs
        omaamb = stats['omaamb'] / nobs
        amb2 = stats['amb2'] / nobs

        beta = np.sqrt(varb/vara)
        # clip at zero so the square root stays real
        lamb = np.sqrt(max(0.0, (omb2-varo-amb2)/vara))

        summary = f"varb = {varb}, vara = {vara}, varo={varo}; omb2 = {omb2}, omaamb = {omaamb}, amb2={amb2}; beta = {beta}, lambda = {lamb}"
        if c.debug:
            c.log_event(summary, flag='stats')

        if beta <= 1:
            # (beta - 1) <= 0: the coefficient formula is not usable, so do not relax
            self.coef = 0
        else:
            self.coef = (lamb - 1) / (beta - 1)

        # keep the coefficient within [-1, 2]
        if self.coef > 2:
            self.coef = 2
        if self.coef < -1:
            self.coef = -1

        c.message = summary

    def apply_inflation(self, c: Context, flag: str):
        """Relax the posterior ensemble perturbations toward the prior ones.

        For every (member, record) pair owned by this processor the posterior
        field is replaced by::

            post_mean + coef * (prior - prior_mean) + (1 - coef) * (post - post_mean)

        and stored back in ``c.state.fields_post``.

        Args:
            c: runtime context providing the state, I/O, logging and the
                communicator.
            flag: not used by this method.
                NOTE(review): presumably kept for compatibility with the
                ``Inflation`` interface -- confirm.
        """
        # first processor ids that own a non-empty member list / record list
        pid_mem_show = [p for p, lst in c.mem_list.items() if len(lst) > 0][0]
        pid_rec_show = [p for p, lst in c.state.rec_list.items() if len(lst) > 0][0]
        c.pid_show = pid_rec_show * c.config.nproc_mem + pid_mem_show

        if c.debug:
            c.log_event(f'relaxing to prior ensemble perturbations with coef={self.coef}', flag='info')

        # process the fields, each processor goes through its own subset of
        # mem_id,rec_id simultaneously
        mem_ids = c.mem_list[c.pid_mem]
        rec_ids = c.state.rec_list[c.pid_rec]
        nm = len(mem_ids)
        nr = len(rec_ids)
        c.total_tasks = nm*nr

        for r, rec_id in enumerate(rec_ids):
            # read the mean field with rec_id
            fld_prior_mean = c.io.read_field(c, 'prior_mean', rec_id, mem_id=0)
            fld_post_mean = c.io.read_field(c, 'post_mean', rec_id, mem_id=0)

            for m, mem_id in enumerate(mem_ids):
                c.debug_message = f"relax_to_prior_perturb mem{mem_id+1:03}"
                c.current_task = m*nr+r

                # inflate the ensemble perturbations by relaxing to prior perturbations
                fld_prior = c.state.fields_prior[mem_id, rec_id]
                fld_post = c.state.fields_post[mem_id, rec_id]

                # write the result back into the state: rebinding only the
                # local name would silently discard the relaxed field
                c.state.fields_post[mem_id, rec_id] = (
                    fld_post_mean
                    + self.coef*(fld_prior - fld_prior_mean)
                    + (1.-self.coef)*(fld_post - fld_post_mean)
                )

        # wait for all processors to finish their share of fields
        c.comm.Barrier()