Source code for NEDAS.assim_tools.inflation.multiplicative

import numpy as np
from NEDAS.core.inflation import Inflation

class MultiplicativeInflation(Inflation):
    """Multiplicative ensemble inflation with a single scalar coefficient.

    ``adaptive_prior_inflation`` / ``adaptive_post_inflation`` estimate
    ``self.coef`` from the statistics returned by ``self.obs_space_stats(c)``;
    ``apply_inflation`` then scales each member's departure from the ensemble
    mean by that coefficient.
    """

    def _use_unit_coef(self, c, message, warning=None):
        """Fall back to ``coef = 1`` (no inflation) and record the reason on ``c``.

        Args:
            c: context object; ``c.message`` is set, and ``c.log_event`` is
                called when ``c.debug`` is true and ``warning`` is given.
            message: text stored in ``c.message``.
            warning: optional text logged with ``flag='warning'``.
        """
        if warning is not None and c.debug:
            c.log_event(warning, flag='warning')
        self.coef = 1.
        c.message = message

    def adaptive_prior_inflation(self, c):
        """Compute prior inflate coef by obs-space statistics (Desroziers et al. 2005).

        Sets ``self.coef = sqrt((omb2 - varo) / varb)`` where each statistic is
        the corresponding entry of ``self.obs_space_stats(c)`` divided by
        ``total_nobs``. Falls back to ``coef = 1`` (no inflation) when there
        are fewer than 3 observations, when ``varb`` is zero, or when the
        ratio is negative, so that a NaN/inf coefficient is never produced.

        Args:
            c: context object; ``debug_message`` and ``message`` are written,
                ``debug`` and ``log_event`` are used for diagnostics.
        """
        c.debug_message = "adaptive prior inflation"
        stats = self.obs_space_stats(c)
        nobs = stats['total_nobs']

        if nobs < 3:
            self._use_unit_coef(
                c,
                "nobs too small, setting coef=1.",
                warning="insufficient nobs to establish statistics, setting inflate_coef=1",
            )
            return

        # Same guard as the posterior variant: a zero denominator would give
        # an inf/NaN coefficient.
        if stats['varb'] == 0:
            self._use_unit_coef(
                c,
                "varb=0 detected, setting coef=1.",
                warning="varb=0 detected, skipping with coef=1 (no inflation)",
            )
            return

        varb = stats['varb'] / nobs
        varo = stats['varo'] / nobs
        omb2 = stats['omb2'] / nobs
        if c.debug:
            c.log_event(f"varb = {varb}, varo={varo}; omb2 = {omb2}", flag='stats')

        ratio = (omb2 - varo) / varb
        if ratio < 0:
            # sqrt of a negative ratio is NaN, which would wipe out the
            # ensemble when applied; leave the ensemble unchanged instead.
            self._use_unit_coef(
                c, f"omb2 = {omb2}, varo={varo}, varb = {varb}; ratio<0, setting coef=1."
            )
            return

        self.coef = np.sqrt(ratio)
        c.message = f"varb = {varb}, varo={varo}; omb2 = {omb2}; coef = {self.coef}"

    def adaptive_post_inflation(self, c):
        """Compute posterior inflate coef by obs-space statistics (Desroziers et al. 2005).

        Sets ``self.coef = sqrt((omb2 - varo - amb2) / vara)`` where each
        statistic is the corresponding entry of ``self.obs_space_stats(c)``
        divided by ``total_nobs``. Falls back to ``coef = 1`` (no inflation)
        when there are fewer than 3 observations, when ``vara`` is zero, or
        when the ratio is negative.

        Args:
            c: context object; ``debug_message`` and ``message`` are written,
                ``debug`` and ``log_event`` are used for diagnostics.
        """
        c.debug_message = "adaptive posterior inflation"
        stats = self.obs_space_stats(c)
        nobs = stats['total_nobs']

        if nobs < 3:
            self._use_unit_coef(
                c,
                "nobs too small, setting coef=1.",
                warning="insufficient nobs to establish statistics, setting inflate_coef=1",
            )
            return

        if stats['vara'] == 0:
            self._use_unit_coef(
                c,
                "vara=0 detected, setting coef=1.",
                warning="vara=0 detected, skipping with coef=1 (no inflation)",
            )
            return

        varb = stats['varb'] / nobs
        vara = stats['vara'] / nobs
        varo = stats['varo'] / nobs
        omb2 = stats['omb2'] / nobs
        omaamb = stats['omaamb'] / nobs
        amb2 = stats['amb2'] / nobs
        if c.debug:
            c.log_event(
                f"varb = {varb}, vara = {vara}, varo={varo}; omb2 = {omb2}, omaamb = {omaamb}, amb2={amb2}",
                flag='stats',
            )

        ratio = (omb2 - varo - amb2) / vara
        if ratio < 0:
            self._use_unit_coef(
                c, f"omb2 = {omb2}, varo={varo}, amb2 = {amb2}; ratio<0, setting coef=1."
            )
            return

        self.coef = np.sqrt(ratio)
        c.message = f"varb = {varb}, vara = {vara}, varo={varo}; coef = {self.coef}"

    def apply_inflation(self, c, flag):
        """Scale ensemble perturbations about the mean by ``self.coef``.

        For every record owned by this process the mean field is read once
        via ``c.io.read_field``, and every locally owned member is replaced
        in place by ``coef * (field - mean) + mean``.

        Args:
            c: context object providing ``state``, ``mem_list``, ``io``,
                ``comm``, ``config`` and the ``pid_mem`` / ``pid_rec`` ids.
            flag: ``'prior'`` or ``'post'``; selects ``c.state.fields_<flag>``
                and the ``<flag>_mean`` field.

        Raises:
            ValueError: if ``flag`` is neither ``'prior'`` nor ``'post'``.
        """
        if flag not in ('prior', 'post'):
            raise ValueError(f"Unknown flag {flag}, should be prior or post")
        fields = getattr(c.state, f"fields_{flag}")

        # The first process id (in each dimension) that owns any work is the
        # one whose progress gets shown.
        pid_mem_show = [p for p, lst in c.mem_list.items() if len(lst) > 0][0]
        pid_rec_show = [p for p, lst in c.state.rec_list.items() if len(lst) > 0][0]
        c.pid_show = pid_rec_show * c.config.nproc_mem + pid_mem_show
        c.debug_message = f'inflating {flag} ensemble with multiplicative coef={self.coef}'

        # Each processor goes through its own subset of (mem_id, rec_id).
        mem_ids = c.mem_list[c.pid_mem]
        rec_ids = c.state.rec_list[c.pid_rec]
        nm = len(mem_ids)
        c.total_tasks = nm * len(rec_ids)

        for r, rec_id in enumerate(rec_ids):
            # Read the mean field for this record once and reuse it for all
            # members.
            fields_mean = c.io.read_field(c, f"{flag}_mean", rec_id, mem_id=0)
            for m, mem_id in enumerate(mem_ids):
                c.debug_message = f"inflating mem{mem_id+1:03}"
                # Records are the outer loop, so the running task index is
                # r*nm + m; this keeps the counter increasing monotonically.
                c.current_task = r * nm + m
                # Inflate the ensemble perturbation by coef.
                fields[mem_id, rec_id] = (
                    self.coef * (fields[mem_id, rec_id] - fields_mean) + fields_mean
                )

        # NOTE(review): barrier placed once after all local work; the
        # flattened source does not show its nesting level - confirm.
        c.comm.Barrier()