import os
import numpy as np
from datetime import datetime
from typing import Callable, Any
from NEDAS.utils.conversion import ensure_list, dt1h
from NEDAS.utils import random_perturb, spatial_operation, parallel
from NEDAS.grid import GridType, RegularGrid
from .context import Context
class PerturbField:
    """
    Generates and applies perturbation on given 2D field(s)

    The scheme is configured through keyword arguments (see
    ``parse_perturb_opts``); the random draws themselves are delegated to
    ``NEDAS.utils.random_perturb``.

    Attributes:
        grid: grid object on which the perturbations are generated.
        mask: declared for type checkers only; not assigned by this class.
        perturb_methods: maps the perturbation type name to its generator method.
        perturb_type: main perturbation type ('gaussian', 'powerlaw' or 'displace').
        other_opts: extra option flags that followed the main type in ``type``.
        params: per-variable parameters; ``params[vname]['nscale']`` plus one list
            of length nscale for each configured key among amp/hcorr/tcorr/powerlaw.
    """
    grid: GridType
    mask: np.ndarray
    perturb_methods: dict[str, Callable]
    perturb_type: str
    # class-level defaults; both are rebound per instance in parse_perturb_opts
    other_opts: list[str] = []
    params: dict[str, dict[str, Any]] = {}

    def __init__(self, **kwargs) -> None:
        """
        Seed the random number generator and parse the perturbation options.

        Keyword Args:
            grid: (required) grid object used to generate the perturbations.
            seed (int, optional): random seed; drawn from system entropy if absent.
            type, variable, amp, hcorr, tcorr, powerlaw: see ``parse_perturb_opts``.
        """
        # get seed, if not specified get a random seed from system entropy
        seed = kwargs.get('seed', int.from_bytes(os.urandom(4), 'little'))
        assert isinstance(seed, int), f"seed {seed} invalid"

        # set the random seed (note: this is numpy's global random state)
        np.random.seed(seed)

        self.grid = kwargs['grid']

        self.perturb_methods = {
            'gaussian': self.perturb_random_gaussian,
            'powerlaw': self.perturb_random_powerlaw,
            'displace': self.perturb_random_displace,
        }

        # parse kwargs and init the perturbation parameters
        self.parse_perturb_opts(**kwargs)

    def parse_perturb_opts(self, **kwargs) -> None:
        """
        Parse the perturbation options into ``perturb_type``, ``other_opts`` and ``params``.

        Keyword Args:
            type (str): main option (gaussian/powerlaw/displace), optionally followed
                by additional options, all separated by ','.
            variable (str | list[str]): name(s) of the variable(s) to perturb.
            amp, hcorr, tcorr, powerlaw (optional): parameter(s) of the scheme.
                If ``variable`` is a list, each must have one entry per variable;
                an entry may itself be a list (multiscale approach, one value per scale).

        Raises:
            NotImplementedError: if the main perturbation type is unknown.
            ValueError: if the number of parameter entries does not match the number
                of variables, or the number of scales differs between parameters.

        The input containers are never modified; parameters are copied into new lists.
        """
        # perturb['type'] string format:
        # main option (gaussian/powerlaw/displace) followed by , then additional options separated by ,
        opts = kwargs['type'].split(',')
        self.perturb_type = opts[0]
        if self.perturb_type not in self.perturb_methods:
            raise NotImplementedError(f"Perturbation type: '{self.perturb_type}' is not implemented")
        self.other_opts = list(opts[1:])

        # parameters that were actually specified
        key_list = [key for key in ('amp', 'hcorr', 'tcorr', 'powerlaw') if key in kwargs]

        # a list of variables can be specified if running a multivariate perturbation scheme
        # rectify variable and parameters to be lists (one entry per variable)
        variable = kwargs['variable']
        if isinstance(variable, list):
            variable_list = variable
            per_variable = {
                key: kwargs[key] if isinstance(kwargs[key], list) else [kwargs[key]]
                for key in key_list
            }
        else:
            # single variable: whatever was given (scalar or list of scales) is its entry
            variable_list = [variable]
            per_variable = {key: [kwargs[key]] for key in key_list}
        nv = len(variable_list)  # number of variables

        # check for mismatch in list length
        for key in key_list:
            if len(per_variable[key]) != nv:
                raise ValueError(f"perturb option: {key} has {len(per_variable[key])} entries, but {nv} variables are specified")

        # get perturbation parameters for each variable
        params: dict[str, dict[str, Any]] = {}
        for v, vname in enumerate(variable_list):
            # in multiscale approach, a list of parameters can be specified for a variable;
            # one separate perturbation will be generated for each, then they will be added together
            scales = {}
            for key in key_list:
                value = per_variable[key][v]
                # make a (new) list even if only one value for the key
                scales[key] = list(value) if isinstance(value, list) else [value]

            # number of scales is set by the first parameter (1 if none is given)
            nscale = len(scales[key_list[0]]) if key_list else 1

            # check if all keys are lists with same len
            for key in key_list[1:]:
                if len(scales[key]) != nscale:
                    raise ValueError(f"perturb option: {key} has different number of entries from {key_list[0]}, check config")

            params[vname] = {'nscale': nscale, **scales}
        self.params = params

    def generate_perturb(self, grid: GridType,
                         fields: dict[str, np.ndarray],
                         prev_perturb: dict[str, Any],
                         dt: float=1,
                         n: int=0,) -> dict[str, np.ndarray]:
        """
        Generate random perturbations for the given 2D fields

        Args:
            grid (GridType): Grid object describing the 2d domain
            fields (dict[str, np.ndarray]): the input fields
            prev_perturb (dict[str, Any]): previous perturbation data, dict[str, None] if unavailable
            dt (float): interval (hours) between time steps
            n (int): current time step index

        Returns:
            dict[str, np.ndarray]: the generated perturbations. For each variable the
            leading axis is the scale index; in 'displace' mode the shape is
            (nscale, 2, ny, nx), otherwise (nscale,) + field shape.
        """
        perturb = {}
        method = self.perturb_methods[self.perturb_type]
        for vname, rec in self.params.items():
            fld = fields[vname]
            assert grid.x.shape == fld.shape[-2:], f"input fields[{vname}] dimension mismatch with grid"
            ns = rec['nscale']
            pp = prev_perturb[vname]

            # at the first time step an available previous perturbation is reused as is
            if pp is not None and n == 0:
                perturb[vname] = pp
                continue

            if self.perturb_type == 'displace':
                perturb[vname] = np.zeros((ns, 2) + fld.shape[-2:])
            else:
                perturb[vname] = np.zeros((ns,) + fld.shape)

            # loop over scale s and generate perturbation
            for s in range(ns):
                if self.perturb_type == 'displace':
                    # a single displacement vector field per scale, whatever the
                    # number of components in fld (the draw has shape (2, ny, nx))
                    perturb[vname][s] = method(rec, s)
                else:
                    # draw a random field for each 2d field component in fields
                    for ind in np.ndindex(fld.shape[:-2]):
                        perturb[vname][(s,) + ind] = method(rec, s)

                # make perturb temporally correlated by blending with prev_perturb
                # (requires 'tcorr' to be configured for this variable)
                if pp is not None:
                    perturb[vname][s] = self.make_correlated_perturb(pp[s], perturb[vname][s], rec['tcorr'][s] / dt)

        if 'press_wind_relate' in self.other_opts:
            perturb = self.make_wind_perturb_from_press(perturb)

        return perturb

    def add_perturb(self, fields: dict[str, np.ndarray], perturb: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
        """
        Add perturbations to each field

        Args:
            fields (dict[str, np.ndarray]): fields to perturb; updated and returned.
            perturb (dict[str, np.ndarray]): perturbations from ``generate_perturb``.
            **kwargs: if ``bounds`` (vmin, vmax) is given, perturbed fields are
                clipped to that range; other keys are ignored.

        Returns:
            dict[str, np.ndarray]: the perturbed fields (same dict as ``fields``).
        """
        for vname, rec in self.params.items():
            for s in range(rec['nscale']):
                if self.perturb_type == 'displace':
                    fields[vname] = spatial_operation.warp(self.grid, fields[vname], perturb[vname][s,0,...], perturb[vname][s,1,...])
                elif 'exp' in self.other_opts:
                    # add lognormal perturbations: for a Gaussian p with standard
                    # deviation amp, E[exp(p)] = exp(amp**2/2), so subtracting
                    # 0.5*amp**2 keeps the multiplicative factor at unit mean
                    fields[vname] *= np.exp(perturb[vname][s,...] - 0.5*rec['amp'][s]**2)
                else:
                    # just add the gaussian perturbations
                    fields[vname] += perturb[vname][s,...]

            # respect value bounds after perturbing
            if 'bounds' in kwargs:
                vmin, vmax = kwargs['bounds']
                fields[vname] = np.minimum(np.maximum(fields[vname], vmin), vmax)
        return fields

    def perturb_random_gaussian(self, rec: dict[str, Any], s: int) -> np.ndarray:
        """ Generate a random perturbation using the Gaussian random field method (scale index s) """
        grid = self.grid
        assert isinstance(grid, RegularGrid), f"perturbation by random_field_gaussian only support RegularGrid, {grid}"
        # horizontal correlation length is passed in units of grid spacing
        return random_perturb.random_field_gaussian(grid.nx, grid.ny, rec['amp'][s], rec['hcorr'][s]/grid.dx)

    def perturb_random_powerlaw(self, rec: dict[str, Any], s: int) -> np.ndarray:
        """ Generate a random perturbation using the powerlaw method (scale index s) """
        grid = self.grid
        assert isinstance(grid, RegularGrid), "perturbation by random_field_powerlaw only support RegularGrid"
        return random_perturb.random_field_powerlaw(grid.nx, grid.ny, rec['amp'][s], rec['powerlaw'][s])

    def perturb_random_displace(self, rec: dict[str, Any], s: int) -> np.ndarray:
        """ Generate a random perturbation using the displacement method (returns a vector field) """
        du, dv = random_perturb.random_displacement(self.grid, self.grid.mask, rec['amp'][s], rec['hcorr'][s]/self.grid.dx)
        return np.array([du, dv])

    def make_correlated_perturb(self, prev_perturb: np.ndarray, perturb: np.ndarray, corr: float) -> np.ndarray:
        """
        Create perturbations that are correlated in time

        Args:
            prev_perturb (np.ndarray): perturbation from the previous time step.
            perturb (np.ndarray): newly drawn perturbation.
            corr (float): correlation time scale in number of time steps (must be nonzero).

        Returns:
            np.ndarray: ``sqrt(1-alpha**2) * perturb + alpha * prev_perturb`` where
            ``alpha = 0.75**(1/corr)``.
        """
        autocorr = 0.75
        alpha = autocorr**(1.0 / corr)
        return np.sqrt(1-alpha**2) * perturb + alpha * prev_perturb

    def make_wind_perturb_from_press(self, perturb: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """
        Legacy option in TOPAZ prsflg==1,2 options in force_perturb program, reproduced here.
        Expecting the vnames 'atmos_surf_press' for pressure field and 'atmos_surf_velocity' for the wind field.
        Derive the wind perturbation from pressure perturbations, so that they are in wind-pressure relation (prsflg==1)
        Additionally, derived wind perturbations are rescaled to match the specified amp (prsflg==2)

        Args:
            perturb (dict[str, np.ndarray]): perturbations; the 'atmos_surf_velocity'
                entry is overwritten in place, scale by scale.

        Returns:
            dict[str, np.ndarray]: the same dict with the wind perturbation replaced.
        """
        velocity_from_press = random_perturb.get_velocity_from_press
        for vname in ['atmos_surf_velocity', 'atmos_surf_press']:
            assert vname in self.params, f'{vname} not in variable list, cannot run press_wind_relate option'

        press_rec = self.params['atmos_surf_press']
        wind_rec = self.params['atmos_surf_velocity']
        scale_wind = ('scale_wind' in self.other_opts)
        for s in range(press_rec['nscale']):
            pres_pert = perturb['atmos_surf_press'][s]
            pres_amp = press_rec['amp'][s]
            pres_hcorr = press_rec['hcorr'][s]
            wind_amp = wind_rec['amp'][s]
            perturb['atmos_surf_velocity'][s] = velocity_from_press(self.grid, pres_pert, scale_wind, pres_amp, pres_hcorr, wind_amp)
        return perturb
class Perturbation:
    """
    Perturbation top-level manager

    Distributes the configured perturbation items (one task per item and
    ensemble member) among the processors, then for each task generates the
    perturbations with ``PerturbField``, applies them to the model fields
    and writes the perturbed fields back through the model interface.

    Attributes:
        nfld: number of 2D fields this processor will perturb (for showing progress).
        task_list: tasks (perturbation record + 'member' index) for each processor id.
        perturb: most recent perturbation data, keyed by variable name.
    """
    # class-level defaults, shadowed by per-instance values in __init__
    nfld: int = 0
    task_list: dict[int, list[dict]] = {}
    perturb: dict[str, Any] = {}

    def __init__(self, c: Context):
        """
        Args:
            c (Context): runtime context (config, communicator, models, io, ...).
        """
        # per-instance state, so that load_perturb_data never writes into the
        # dict shared at class level
        self.nfld = 0
        self.perturb = {}

        # distribute perturbation items among MPI ranks
        self.task_list = parallel.bcast_by_root(c.comm)(self.distribute_perturb_tasks)(c)

        # go through the opts to count how many fields will be perturbed (for showing progress)
        self.count_num_fields(c)

    def distribute_perturb_tasks(self, c: Context) -> dict[int, list[dict]]:
        """ Build one task per (perturbation record, ensemble member) and distribute them among processors """
        task_list_full = [
            {**perturb_rec, 'member': mem_id}
            for perturb_rec in ensure_list(c.config.perturb)
            for mem_id in range(c.nens)
        ]
        return parallel.distribute_tasks(c.comm, task_list_full)

    def count_num_fields(self, c: Context):
        """
        Count the fields (time steps x levels, over all tasks of this processor)
        that ``__call__`` will perturb, and store the result in ``self.nfld``.
        """
        nfld = 0
        for rec in self.task_list[c.pid]:
            model = c.models[rec['model_src']]
            variable_list = ensure_list(rec['variable'])
            # same time step and level selection as in __call__
            dt = max(model.variables[v].dt for v in variable_list)
            nstep = int(c.config.cycle_period / dt) + 1
            nlevel = sum(1 for _ in model.variables[variable_list[0]].levels)
            nfld += nstep * nlevel
        self.nfld = nfld

    def __call__(self, c: Context) -> None:
        """ Run all perturbation tasks assigned to this processor """
        if c.config.io_mode == 'offline':
            self.prepare_perturb_dir(c)

        # the first processor that has tasks is the one showing progress
        # (left untouched if no processor has any task)
        pid_show = next((pid for pid, lst in self.task_list.items() if len(lst) > 0), None)
        if pid_show is not None:
            c.pid_show = pid_show

        # go through the tasks
        c.total_tasks = self.nfld+1
        fld_id = 0
        for rec in self.task_list[c.pid]:
            p = PerturbField(**rec, grid=c.grid)
            model = c.models[rec['model_src']]  # model class object
            member = rec['member']
            variable_list = ensure_list(rec['variable'])

            # check if previous perturb is available from past cycles
            self.load_perturb_data(c, **rec)

            # get number of time steps for this set of variables
            # perturbation will be generated for all time steps if variable is available
            dt = max([model.variables[v].dt for v in variable_list])
            nstep = int(c.config.cycle_period / dt) + 1

            for n in range(nstep):
                t = c.time + n * dt * dt1h

                # TODO: perturbation for each k level is drawn independently, can be improved
                #       by introducing a vertical correlation length scale, or using EOF modes.
                # Note: assuming all variables in the list have the same k levels
                for k in model.variables[variable_list[0]].levels:
                    fld_id += 1
                    c.debug_message = f"perturbing mem{member+1:03} {variable_list} at {t} level {k}"
                    c.current_task = fld_id

                    fields = self.collect_fields(c, t, k, **rec)
                    self.perturb = p.generate_perturb(c.grid, fields, prev_perturb=self.perturb, dt=dt, n=n)
                    fields = p.add_perturb(fields, self.perturb, **rec)
                    self.output_perturbed_fields(c, fields, t, k, **rec)

            self.save_perturb_data(c, **rec)

        # processors may hold different numbers of tasks: synchronize once at the end
        c.comm.Barrier()

    def prepare_perturb_dir(self, c: Context):
        """ Prepare and clear the directory where perturbation data will be stored (offline mode) """
        assert c.config.io_mode == 'offline', "prepare_perturb_dir only needed in offline io mode"

        # clean up perturb files in current cycle dir, once per model source;
        # ensure_list for consistency with distribute_perturb_tasks (a single
        # record may be given instead of a list)
        model_names = dict.fromkeys(rec['model_src'] for rec in ensure_list(c.config.perturb))
        for model_name in model_names:
            path = c.fs.forecast_dir(c.time, model_name)
            perturb_dir = os.path.join(path, 'perturb')
            if c.pid == 0:
                c.run_job(f"rm -rf {perturb_dir}; mkdir -p {perturb_dir}")
        # all processors wait until the directories are ready
        c.comm.Barrier()

    def save_perturb_data(self, c: Context, **rec):
        """ Save a copy of perturbation data, for use by the next analysis cycle """
        path = None
        if c.config.io_mode == 'offline':
            path = os.path.join(c.fs.forecast_dir(c.time, rec['model_src']), 'perturb')
        for vname in ensure_list(rec['variable']):
            data = self.perturb[vname]
            assert data is not None
            c.io.save_ndarray(c, f"{vname}_mem{rec['member']+1:03d}", data, path)

    def load_perturb_data(self, c: Context, **rec):
        """ Load the perturbation data of each variable in rec into ``self.perturb`` """
        path = None
        if c.config.io_mode == 'offline':
            path = os.path.join(c.fs.forecast_dir(c.time, rec['model_src']), 'perturb')
        for vname in ensure_list(rec['variable']):
            # generate_perturb treats a None entry as "no previous perturbation"
            self.perturb[vname] = c.io.load_ndarray(c, f"{vname}_mem{rec['member']+1:03d}", path)

    def collect_fields(self, c: Context, t: datetime, k: int, **rec) -> dict[str, np.ndarray]:
        """
        Collect all model fields to be perturbed

        Args:
            c (Context): runtime context.
            t (datetime): time of the fields.
            k (int): vertical level index.
            **rec: the perturbation task record ('variable', 'model_src', ...).

        Returns:
            dict[str, np.ndarray]: fields on the analysis grid, keyed by variable name.
        """
        variable_list = ensure_list(rec['variable'])
        model = c.models[rec['model_src']]

        # set up grids
        vname = variable_list[0]  # note: all variables in the list shall have same dt and k levels
        c.io.call_method(c, 'current', model.read_grid, name=vname, time=t, k=k, **rec)
        model.grid.set_destination_grid(c.grid)
        c.grid.set_destination_grid(model.grid)

        # collect model variable fields
        fields = {}
        for vname in variable_list:
            # read variable from model state
            fld = c.io.call_method(c, 'current', model.read_var, name=vname, time=t, k=k, **rec)
            # convert to analysis grid
            fields[vname] = model.grid.convert(fld, is_vector=model.variables[vname].is_vector)
        return fields

    def output_perturbed_fields(self, c: Context, fields: dict[str, np.ndarray], t: datetime, k:int, **rec) -> None:
        """
        Write the perturbed fields back to the model state

        If the perturbation type is 'displace' and the model provides a
        ``displace`` method, the perturbation data is handed to that method;
        otherwise the perturbed fields are converted to the model grid and
        written with ``model.write_var``.
        """
        variable_list = ensure_list(rec['variable'])
        model = c.models[rec['model_src']]

        if rec['type'].split(',')[0] == 'displace' and hasattr(model, 'displace'):
            # use model internal method to apply displacement perturbations directly
            c.io.call_method(c, 'current', model.displace, self.perturb, time=t, k=k, **rec)
            return

        # convert from analysis grid to model grid, and
        # write the perturbed variables back to model state files
        for vname in variable_list:
            fld = c.grid.convert(fields[vname], is_vector=model.variables[vname].is_vector)
            c.io.call_method(c, 'current', model.write_var, fld, name=vname, time=t, k=k, **rec)