Source code for NEDAS.core.inflation

from abc import ABC, abstractmethod
from typing import Literal
import numpy as np
from .context import Context

class Inflation(ABC):
    """
    Base class for inflating the ensemble members (covariance inflation).

    Concrete subclasses implement :meth:`adaptive_prior_inflation`,
    :meth:`adaptive_post_inflation` and :meth:`apply_inflation`; this base
    class decides when they are invoked (see :meth:`__call__`) and provides
    the observation-space statistics used by adaptive schemes.

    Attributes:
        coef: inflation coefficient (default 1.0).
        adaptive: if True, the adaptive_*_inflation hook is run before
            apply_inflation.
        prior: if True, inflation is performed for the ``'prior'`` flag.
        post: if True, inflation is performed for the ``'post'`` flag.
    """

    def __init__(self, coef: float = 1.0, adaptive: bool = False,
                 prior: bool = False, post: bool = False):
        self.coef = coef
        self.adaptive = adaptive
        self.prior = prior
        self.post = post

    def __call__(self, c: Context, flag: Literal['prior', 'post']) -> None:
        """
        Perform the covariance inflation method.

        Args:
            c: the runtime context.
            flag: which stage is being inflated, ``'prior'`` or ``'post'``.
                Nothing happens if the matching ``self.prior`` / ``self.post``
                switch is off, or if ``flag`` is neither of the two values.

        Raises:
            AssertionError: if adaptive inflation is requested and the
                required observation ensembles fail :meth:`validate_obs_ens`.
        """
        if flag == 'prior' and self.prior:
            if self.adaptive:
                self._require_valid_obs_ens(c, c.obs.obs_prior, 'obs_prior')
                self.adaptive_prior_inflation(c)
            self.apply_inflation(c, flag)

        if flag == 'post' and self.post:
            if self.adaptive:
                # posterior adaptive inflation needs both prior and posterior
                # observation ensembles
                self._require_valid_obs_ens(c, c.obs.obs_prior, 'obs_prior')
                self._require_valid_obs_ens(c, c.obs.obs_post, 'obs_post')
                self.adaptive_post_inflation(c)
            self.apply_inflation(c, flag)

    def _require_valid_obs_ens(self, c: Context, obs_ens: dict, name: str) -> None:
        """Raise AssertionError if ``obs_ens`` (``c.obs.<name>``) is incomplete."""
        # An explicit raise (rather than an ``assert`` statement) keeps this
        # check active when Python runs with -O; the exception type and the
        # message are the same as the assert would have produced.
        if not self.validate_obs_ens(c, obs_ens):
            raise AssertionError(
                f"obs.{name} is corrupted, cannot compute obs_space_stats for adaptive inflation."
            )

    def validate_obs_ens(self, c: Context, obs_ens: dict) -> bool:
        """
        Check if the obs_ens has all members and records.

        Returns:
            True if ``obs_ens`` is a dict holding a ``numpy.ndarray`` under
            every ``(mem_id, obs_rec_id)`` key expected for
            ``c.mem_list[c.pid_mem]`` and ``c.obs.obs_rec_list[c.pid_rec]``;
            False otherwise.
        """
        if not isinstance(obs_ens, dict):
            return False
        for obs_rec_id in c.obs.obs_rec_list[c.pid_rec]:
            for mem_id in c.mem_list[c.pid_mem]:
                # a missing key yields None, which is not an ndarray either
                if not isinstance(obs_ens.get((mem_id, obs_rec_id)), np.ndarray):
                    return False
        return True

    def _obs_ens_mean(self, c: Context, obs_ens: dict, obs_rec_id, shape):
        """Ensemble mean of one obs record: local member sum, allreduce, divide by c.nens."""
        local_sum = np.zeros(shape)
        for mem_id in c.mem_list[c.pid_mem]:
            local_sum += obs_ens[mem_id, obs_rec_id]
        # combine the partial sums held by the different pids into the total
        # NOTE(review): relies on comm_mem.allreduce defaulting to a sum - confirm
        return c.comm_mem.allreduce(local_sum) / c.nens

    def _obs_ens_variance(self, c: Context, obs_ens: dict, obs_rec_id, mean, shape):
        """Ensemble variance of one obs record around ``mean``, normalized by c.nens - 1."""
        local_pert2 = np.zeros(shape)
        for mem_id in c.mem_list[c.pid_mem]:
            local_pert2 += (obs_ens[mem_id, obs_rec_id] - mean)**2
        return c.comm_mem.allreduce(local_pert2) / (c.nens - 1)

    def obs_space_stats(self, c: Context) -> dict:
        """
        Compute observation-space statistics.

        The statistics are accumulated over every obs record in
        ``c.obs.obs_rec_list[c.pid_rec]``. Quantities involving the analysis
        (``omaamb``, ``amb2``, ``vara``) are only accumulated when
        ``c.obs.obs_post`` is non-empty; otherwise they stay at 0.0.

        Returns:
            dict with keys:
              - ``total_nobs``: number of obs values (vector obs count twice)
              - ``omb2``: sum of squared obs-minus-background differences
              - ``omaamb``: sum of (obs-minus-analysis) times
                (analysis-minus-background)
              - ``amb2``: sum of squared analysis-minus-background differences
              - ``varo``: sum of obs error variances
              - ``varb``: sum of obs_prior (background) ensemble variances
              - ``vara``: sum of obs_post (analysis) ensemble variances
        """
        stats = {
            'total_nobs': 0,
            'omb2': 0.0,     # obs-minus-background differences squared
            'omaamb': 0.0,   # (obs-minus-analysis) * (analysis-minus-background)
            'amb2': 0.0,     # analysis-minus-background diff squared
            'varo': 0.0,     # obs err variance
            'varb': 0.0,     # obs_prior (background) ensemble variances
            'vara': 0.0,     # obs_post (analysis) ensemble variances
        }

        # go through each obs record
        for obs_rec_id in c.obs.obs_rec_list[c.pid_rec]:
            obs_rec = c.obs.info.records[obs_rec_id]
            nobs = obs_rec.nobs
            has_post = bool(c.obs.obs_post)

            # vector obs carry two components per observation
            if obs_rec.is_vector:
                nv = 2
                shape = (nv, nobs)
            else:
                nv = 1
                shape = (nobs,)

            # 1. ensemble means (prior first, then post, as collective calls
            #    must be issued in the same order on every processor)
            mean_obs_prior = self._obs_ens_mean(c, c.obs.obs_prior, obs_rec_id, shape)
            mean_obs_post = None
            if has_post:
                mean_obs_post = self._obs_ens_mean(c, c.obs.obs_post, obs_rec_id, shape)

            # 2. ensemble variances
            variance_obs_prior = self._obs_ens_variance(
                c, c.obs.obs_prior, obs_rec_id, mean_obs_prior, shape)
            variance_obs_post = None
            if has_post:
                variance_obs_post = self._obs_ens_variance(
                    c, c.obs.obs_post, obs_rec_id, mean_obs_post, shape)

            # 3. accumulate the statistics for this record
            obs_value = c.obs.obs_seq[obs_rec_id]['obs']
            stats['total_nobs'] += nv * nobs
            stats['omb2'] += np.sum((obs_value - mean_obs_prior)**2)
            # err_std is counted once per component, hence the factor nv
            stats['varo'] += np.sum(c.obs.obs_seq[obs_rec_id]['err_std']**2) * nv
            stats['varb'] += np.sum(variance_obs_prior)
            if has_post:
                stats['amb2'] += np.sum((mean_obs_post - mean_obs_prior)**2)
                stats['omaamb'] += np.sum((obs_value - mean_obs_post)*(mean_obs_post - mean_obs_prior))
                stats['vara'] += np.sum(variance_obs_post)

        return stats

    @abstractmethod
    def adaptive_prior_inflation(self, c: Context):
        """Hook run before apply_inflation for the 'prior' flag when self.adaptive is set."""
        pass

    @abstractmethod
    def adaptive_post_inflation(self, c: Context):
        """Hook run before apply_inflation for the 'post' flag when self.adaptive is set."""
        pass

    @abstractmethod
    def apply_inflation(self, c: Context, flag: Literal['prior', 'post']):
        """Apply the inflation for the given flag ('prior' or 'post')."""
        pass