"""NEDAS.core.perturb: random perturbation of model fields (PerturbField) and the top-level perturbation manager (Perturbation)."""

import os
import numpy as np
from datetime import datetime
from typing import Callable, Any
from NEDAS.utils.conversion import ensure_list, dt1h
from NEDAS.utils import random_perturb, spatial_operation, parallel
from NEDAS.grid import GridType, RegularGrid
from .context import Context

class PerturbField:
    """
    Generates and applies perturbation on given 2D field(s)

    The scheme is configured from keyword arguments (typically one ``perturb``
    record of the config); see ``parse_perturb_opts`` for the accepted options.
    """
    # GridType is quoted so the annotation is a forward reference and is not
    # evaluated when the class body is executed.
    grid: 'GridType'
    mask: np.ndarray
    perturb_methods: dict[str, Callable]
    perturb_type: str
    # class-level defaults; parse_perturb_opts always rebinds fresh per-instance objects
    other_opts: list[str] = []
    params: dict[str, dict[str, Any]]= {}

    def __init__(self, **kwargs) -> None:
        """
        Args:
            **kwargs: perturbation options. Required keys: ``grid`` (the analysis grid),
                ``type`` and ``variable``. Optional ``seed`` (int); all other keys are
                interpreted by ``parse_perturb_opts``.
        """
        # get seed; if not specified (or None) draw a random seed from system entropy
        # (entropy is only consumed when actually needed)
        seed = kwargs.get('seed')
        if seed is None:
            seed = int.from_bytes(os.urandom(4), 'little')
        assert isinstance(seed, int), f"seed {seed} invalid"
        # set the *global* numpy random seed; presumably the random_perturb generators
        # draw from np.random. NOTE(review): instances created with the same seed
        # reproduce the same random sequence - confirm that is intended across members.
        np.random.seed(seed)

        self.grid = kwargs['grid']

        self.perturb_methods = {
            'gaussian': self.perturb_random_gaussian,
            'powerlaw': self.perturb_random_powerlaw,
            'displace': self.perturb_random_displace,
        }

        # parse kwargs and init the perturbation parameters
        self.parse_perturb_opts(**kwargs)

    def parse_perturb_opts(self, **kwargs) -> None:
        """
        Parse perturbation options into ``perturb_type``, ``other_opts`` and ``params``.

        ``kwargs['type']`` is a comma-separated string: the main option
        (gaussian/powerlaw/displace) followed by additional options, e.g. ``'gaussian,exp'``.

        ``kwargs['variable']`` is one variable name or a list of names (multivariate
        scheme). The parameters ``amp``, ``hcorr``, ``tcorr`` and ``powerlaw`` (whichever
        are given) hold one entry per variable. Each entry may itself be a list with one
        value per scale (multiscale approach: one perturbation is generated per scale
        and they are added together); scalar entries are treated as a single scale.

        The caller's lists are never modified: ``self.params`` holds independent copies.
        Result: ``self.params[vname] = {'nscale': int, key: [value per scale], ...}``.

        Raises:
            NotImplementedError: if the main perturbation type is unknown.
            ValueError: if no parameter is given, or list lengths are inconsistent.
        """
        opts = kwargs['type'].split(',')
        self.perturb_type = opts[0]
        if self.perturb_type not in self.perturb_methods:
            raise NotImplementedError(f"Perturbation type: '{self.perturb_type}' is not implemented")
        self.other_opts = opts[1:]

        key_list = [key for key in ('amp', 'hcorr', 'tcorr', 'powerlaw') if key in kwargs]

        # a list of variables can be specified if running a multivariate perturbation scheme;
        # rectify variable and parameters to per-variable lists. New list objects are built
        # so that the caller's config (shared between members) is never mutated.
        if isinstance(kwargs['variable'], list):
            variable_list = list(kwargs['variable'])
            values = {key: kwargs[key] for key in key_list}
        else:
            variable_list = [kwargs['variable']]
            values = {key: [kwargs[key]] for key in key_list}
        nv = len(variable_list)  # number of variables

        if nv > 0 and not key_list:
            raise ValueError("perturb options: none of 'amp', 'hcorr', 'tcorr', 'powerlaw' is specified")

        for key in key_list:
            values[key] = list(values[key]) if isinstance(values[key], list) else [values[key]]
            # check for mismatch in list length
            if len(values[key]) != nv:
                raise ValueError(f"perturb option: {key} has {len(values[key])} entries, but {nv} variables are specified")

        # get perturbation parameters for each variable
        self.params = {}
        for v, vname in enumerate(variable_list):
            # the first available parameter decides the number of scales
            first = values[key_list[0]][v]
            nscale = len(first) if isinstance(first, list) else 1
            rec = {'nscale': nscale}
            for key in key_list:
                val = values[key][v]
                # copy lists; wrap scalars so every parameter is a per-scale list
                rec[key] = list(val) if isinstance(val, list) else [val]
                # check that all keys have the same number of scales
                if len(rec[key]) != nscale:
                    raise ValueError(f"perturb option: {key} has different number of entries from {key_list[0]}, check config")
            self.params[vname] = rec

    def generate_perturb(self, grid: 'GridType',
                        fields: dict[str, np.ndarray],
                        prev_perturb: dict[str, Any],
                        dt: float=1,
                        n: int=0,) -> dict[str, np.ndarray]:
        """
        Generate random perturbations for the given 2D fields

        Args:
            grid (GridType): Grid object describing the 2d domain
            fields (dict[str, np.ndarray]): the input fields
            prev_perturb (dict[str, Any]): previous perturbation data, dict[str, None] if unavailable
                (a missing key is treated like None)
            dt (float): interval (hours) between time steps
            n (int): current time step index

        Returns:
            dict[str, np.ndarray]: the generated perturbations. Shape is (nscale,)+fld.shape,
            or (nscale, 2, ny, nx) for the 'displace' type (one displacement vector field per scale).
        """
        perturb = {}
        for vname, rec in self.params.items():
            fld = fields[vname]
            assert grid.x.shape == fld.shape[-2:], f"input fields[{vname}] dimension mismatch with grid"

            pp = prev_perturb.get(vname)
            # at the first time step, continue the perturbation carried over from before
            if pp is not None and n == 0:
                perturb[vname] = pp
                continue

            ns = rec['nscale']
            method = self.perturb_methods[self.perturb_type]
            if self.perturb_type == 'displace':
                perturb[vname] = np.zeros((ns, 2) + fld.shape[-2:])
            else:
                perturb[vname] = np.zeros((ns,) + fld.shape)

            # loop over scale s and generate perturbation
            for s in range(ns):
                if self.perturb_type == 'displace':
                    # a single displacement vector field moves all components of fld together
                    perturb[vname][s] = method(rec, s)
                else:
                    # draw an independent random field for each 2d component of fld
                    for ind in np.ndindex(fld.shape[:-2]):
                        perturb[vname][(s,) + ind] = method(rec, s)

                # make perturb temporally correlated by blending with prev_perturb
                if pp is not None:
                    perturb[vname][s] = self.make_correlated_perturb(pp[s], perturb[vname][s], rec['tcorr'][s] / dt)

        if 'press_wind_relate' in self.other_opts:
            perturb = self.make_wind_perturb_from_press(perturb)

        return perturb

    def add_perturb(self, fields: dict[str, np.ndarray], perturb: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
        """
        Add perturbations to each field.

        Args:
            fields: the fields to perturb (updated in place where numpy allows, and returned)
            perturb: perturbations as returned by generate_perturb
            **kwargs: optional ``bounds=(vmin, vmax)``; either bound may be None (unbounded)

        Returns:
            dict[str, np.ndarray]: the perturbed fields
        """
        for vname, rec in self.params.items():
            for s in range(rec['nscale']):
                if self.perturb_type == 'displace':
                    fields[vname] = spatial_operation.warp(self.grid, fields[vname], perturb[vname][s,0,...], perturb[vname][s,1,...])
                elif 'exp' in self.other_opts:
                    # multiplicative lognormal perturbation. Assuming the perturbation has
                    # standard deviation amp, E[exp(p - amp**2/2)] = 1, so the mean is kept.
                    fields[vname] *= np.exp(perturb[vname][s,...] - 0.5*rec['amp'][s]**2)
                else:
                    # just add the gaussian perturbations
                    fields[vname] += perturb[vname][s,...]

            # respect value bounds after perturbing
            if kwargs.get('bounds') is not None:
                vmin, vmax = kwargs['bounds']
                if vmin is not None:
                    fields[vname] = np.maximum(fields[vname], vmin)
                if vmax is not None:
                    fields[vname] = np.minimum(fields[vname], vmax)
        return fields

    def perturb_random_gaussian(self, rec: dict[str, Any], s: int) -> np.ndarray:
        """ Generate a random perturbation using the Gaussian random field method (scale s of rec) """
        grid = self.grid
        assert isinstance(grid, RegularGrid), f"perturbation by random_field_gaussian only support RegularGrid, {grid}"
        # hcorr is passed in grid units (divided by grid.dx)
        p = random_perturb.random_field_gaussian(grid.nx, grid.ny, rec['amp'][s], rec['hcorr'][s]/grid.dx)
        return p

    def perturb_random_powerlaw(self, rec: dict[str, Any], s: int) -> np.ndarray:
        """ Generate a random perturbation using the powerlaw method (scale s of rec) """
        grid = self.grid
        assert isinstance(grid, RegularGrid), "perturbation by random_field_powerlaw only support RegularGrid"
        p = random_perturb.random_field_powerlaw(grid.nx, grid.ny, rec['amp'][s], rec['powerlaw'][s])
        return p

    def perturb_random_displace(self, rec: dict[str, Any], s: int) -> np.ndarray:
        """ Generate a random perturbation using the displacement method (returns a vector field [du, dv]) """
        du, dv = random_perturb.random_displacement(self.grid, self.grid.mask, rec['amp'][s], rec['hcorr'][s]/self.grid.dx)
        return np.array([du, dv])

    def make_correlated_perturb(self, prev_perturb: np.ndarray, perturb: np.ndarray, corr: float,
                                autocorr: float = 0.75) -> np.ndarray:
        """
        Create perturbations that are correlated in time.

        Args:
            prev_perturb: perturbation at the previous time step
            perturb: new, independently drawn perturbation
            corr: temporal correlation scale in number of time steps (tcorr / dt)
            autocorr: correlation reached after ``corr`` time steps (default 0.75)

        Returns:
            sqrt(1-alpha**2)*perturb + alpha*prev_perturb with alpha = autocorr**(1/corr).
            If both inputs are independent with equal variance, the variance is preserved.
            For corr <= 0 (no temporal correlation) a copy of ``perturb`` is returned.
        """
        if corr <= 0:
            return np.copy(perturb)
        alpha = autocorr**(1.0 / corr)
        return np.sqrt(1-alpha**2) * perturb + alpha * prev_perturb

    def make_wind_perturb_from_press(self, perturb: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """
        Legacy option in TOPAZ prsflg==1,2 options in force_perturb program, reproduced here.
        Expecting the vnames 'atmos_surf_press' for pressure field and 'atmos_surf_velocity' for the wind field.
        Derive the wind perturbation from pressure perturbations, so that they are in wind-pressure relation (prsflg==1)
        Additionally, derived wind perturbations are rescaled to match the specified amp (prsflg==2)
        """
        press_to_wind = random_perturb.get_velocity_from_press
        for vname in ['atmos_surf_velocity', 'atmos_surf_press']:
            assert vname in self.params, f'{vname} not in variable list, cannot run press_wind_relate option'

        scale_wind = ('scale_wind' in self.other_opts)
        press_rec = self.params['atmos_surf_press']
        wind_rec = self.params['atmos_surf_velocity']
        for s in range(press_rec['nscale']):
            pres_pert = perturb['atmos_surf_press'][s]
            perturb['atmos_surf_velocity'][s] = press_to_wind(self.grid, pres_pert, scale_wind,
                                                             press_rec['amp'][s], press_rec['hcorr'][s],
                                                             wind_rec['amp'][s])
        return perturb

class Perturbation:
    """
    Perturbation top-level manager

    Distributes (perturbation record, ensemble member) tasks among MPI ranks and, when
    called, generates perturbations with PerturbField, applies them to the model fields
    and saves the perturbation state for the next cycle.
    """
    # class-level defaults; __init__ rebinds per-instance values
    nfld: int = 0
    task_list: dict[int, list[dict]] = {}
    perturb: dict[str, Any] = {}

    def __init__(self, c: Context):
        # per-instance state: otherwise load_perturb_data would mutate the class-level
        # ``perturb`` dict, which is shared by every Perturbation instance
        self.nfld = 0
        self.perturb = {}
        # distribute perturbation items among MPI ranks
        self.task_list = parallel.bcast_by_root(c.comm)(self.distribute_perturb_tasks)(c)
        # go through the opts to count how many fields will be perturbed (for showing progress)
        self.count_num_fields(c)

    def distribute_perturb_tasks(self, c: Context) -> dict[int, list[dict]]:
        """ Build one task per (perturb record, member) and distribute them over ranks """
        # each task is a shallow copy of the record plus its member index
        task_list_full = [{**perturb_rec, 'member': mem_id}
                          for perturb_rec in ensure_list(c.config.perturb)
                          for mem_id in range(c.nens)]
        return parallel.distribute_tasks(c.comm, task_list_full)

    def count_num_fields(self, c: Context):
        """ Count the fields this rank will perturb (for showing progress), added to self.nfld """
        for rec in self.task_list[c.pid]:
            model = c.models[rec['model_src']]
            variable_list = ensure_list(rec['variable'])
            # same time stepping as __call__: the largest dt among the variables
            dt = max(model.variables[v].dt for v in variable_list)
            nstep = int(c.config.cycle_period / dt) + 1
            # same levels as __call__: those of the first variable
            nlevel = sum(1 for _ in model.variables[variable_list[0]].levels)
            self.nfld += nstep * nlevel

    def __call__(self, c: Context) -> None:
        """ Generate and apply the perturbations for all tasks assigned to this rank """
        if c.config.io_mode == 'offline':
            self.prepare_perturb_dir(c)

        # show progress from the first rank that has tasks (rank 0 if no rank has any)
        c.pid_show = next((pid for pid, lst in self.task_list.items() if len(lst) > 0), 0)

        # go through the tasks
        c.total_tasks = self.nfld + 1
        fld_id = 0
        for rec in self.task_list[c.pid]:
            pfield = PerturbField(**rec, grid=c.grid)
            model = c.models[rec['model_src']]  # model class object
            member = rec['member']
            variable_list = ensure_list(rec['variable'])

            # check if previous perturb is available from past cycles
            self.load_perturb_data(c, **rec)

            # get number of time steps for this set of variables
            # perturbation will be generated for all time steps if variable is available
            dt = max(model.variables[v].dt for v in variable_list)
            nstep = int(c.config.cycle_period / dt) + 1
            for n in range(nstep):
                t = c.time + n * dt * dt1h

                # TODO: perturbation for each k level is drawn independently, can be improved
                # by introducing a vertical correlation length scale, or using EOF modes.
                # NOTE(review): self.perturb is reassigned at every level and passed as
                # prev_perturb to the next one, so levels are chained (and at n==0 reuse
                # the previous perturbation) rather than independent - confirm intent.
                # Note: assuming all variables in the list have the same k levels
                for k in model.variables[variable_list[0]].levels:
                    fld_id += 1
                    c.debug_message = f"perturbing mem{member+1:03} {variable_list} at {t} level {k}"
                    c.current_task = fld_id

                    fields = self.collect_fields(c, t, k, **rec)
                    self.perturb = pfield.generate_perturb(c.grid, fields, prev_perturb=self.perturb, dt=dt, n=n)
                    fields = pfield.add_perturb(fields, self.perturb, **rec)
                    self.output_perturbed_fields(c, fields, t, k, **rec)

            self.save_perturb_data(c, **rec)
        c.comm.Barrier()

    def prepare_perturb_dir(self, c):
        """ Prepare and clear the directory where perturbation data will be stored (offline mode) """
        import shlex
        assert c.config.io_mode == 'offline', f"prepare_perturb_dir only needed in offline io mode"
        # clean up perturb files in current cycle dir
        for rec in ensure_list(c.config.perturb):
            path = c.fs.forecast_dir(c.time, rec['model_src'])
            perturb_dir = os.path.join(path, 'perturb')
            if c.pid == 0:
                # quote the path so whitespace/special characters cannot change what rm deletes
                qdir = shlex.quote(perturb_dir)
                c.run_job(f"rm -rf {qdir}; mkdir -p {qdir}")
            c.comm.Barrier()

    def save_perturb_data(self, c: Context, **rec):
        """ Save a copy of perturbation data, for use by the next analysis cycle """
        path = None
        if c.config.io_mode == 'offline':
            path = os.path.join(c.fs.forecast_dir(c.time, rec['model_src']), 'perturb')
        for vname in ensure_list(rec['variable']):
            data = self.perturb[vname]
            assert data is not None
            c.io.save_ndarray(c, f"{vname}_mem{rec['member']+1:03d}", data, path)

    def load_perturb_data(self, c: Context, **rec):
        """
        Load the perturbation data saved by a previous cycle into self.perturb[vname].
        Whatever c.io.load_ndarray returns is stored; presumably None when no data
        exists, which generate_perturb treats as "no previous perturbation".
        """
        path = None
        if c.config.io_mode == 'offline':
            path = os.path.join(c.fs.forecast_dir(c.time, rec['model_src']), 'perturb')
        for vname in ensure_list(rec['variable']):
            data = c.io.load_ndarray(c, f"{vname}_mem{rec['member']+1:03d}", path)
            self.perturb[vname] = data

    def collect_fields(self, c: Context, t: datetime, k: int, **rec) -> dict[str, np.ndarray]:
        """ Collect all model fields to be perturbed, converted to the analysis grid """
        variable_list = ensure_list(rec['variable'])
        model = c.models[rec['model_src']]

        # set up grids
        vname = variable_list[0]  # note: all variables in the list shall have same dt and k levels
        c.io.call_method(c, 'current', model.read_grid, name=vname, time=t, k=k, **rec)
        model.grid.set_destination_grid(c.grid)
        c.grid.set_destination_grid(model.grid)

        # collect model variable fields
        fields = {}
        for vname in variable_list:
            # read variable from model state
            fld = c.io.call_method(c, 'current', model.read_var, name=vname, time=t, k=k, **rec)
            # convert to analysis grid
            fields[vname] = model.grid.convert(fld, is_vector=model.variables[vname].is_vector)
        return fields

    def output_perturbed_fields(self, c: Context, fields: dict[str, np.ndarray], t: datetime, k:int, **rec) -> None:
        """ Write the perturbed fields (or apply displacements) back to the model state """
        variable_list = ensure_list(rec['variable'])
        model = c.models[rec['model_src']]

        if rec['type'].split(',')[0]=='displace' and hasattr(model, 'displace'):
            # use model internal method to apply displacement perturbations directly
            displace_method = getattr(model, 'displace')
            c.io.call_method(c, 'current', displace_method, self.perturb, time=t, k=k, **rec)
        else:
            # convert from analysis grid to model grid, and
            # write the perturbed variables back to model state files
            for vname in variable_list:
                fld = c.grid.convert(fields[vname], is_vector=model.variables[vname].is_vector)
                c.io.call_method(c, 'current', model.write_var, fld, name=vname, time=t, k=k, **rec)