"""This module is used to perturb the forcing variables for the next forecast cycle.
This module is specifically designed for neXtSIM-DG in NEDAS.
The design follows the perturbation strategy in TOPAZ4, where
the perturbation is temporally correlated as an AR1 process.
Parameters of the perturbation are read from the yaml file.
The perturbation is applied to the forcing variables in the forcing files.
"""
from datetime import datetime, timedelta, timezone
from dateutil.relativedelta import relativedelta # type: ignore
import os
import threading
import typing
import cftime # type: ignore
import netCDF4 # type: ignore
import numpy as np
import pyproj # type: ignore
from NEDAS.utils.conversion import t2s
from NEDAS.grid import RegularGrid
from NEDAS.models.nextsim.dg.perturb import gen_perturb, apply_perturb, pres_adjusted_wind_perturb, apply_AR1_perturb
from NEDAS.models.nextsim.dg import slicing_nc
# Polar stereographic projection (proj='stere', lat_0=90) shared by all forcing data:
# both the TOPAZ and ERA5 data are projected onto the same grid at the moment.
_proj:pyproj.Proj = pyproj.Proj(proj='stere', a=6378273, b=6356889.448910593, lat_0=90., lon_0=-45., lat_ts=60.)
# Module-level lock used to serialise the netCDF reads and writes done in this module.
# threading.Lock is not re-entrant: code that already holds it must not call
# another function that acquires it again.
thread_lock = threading.Lock()
[docs]
def get_fname_daterange(current_date: datetime, initial_date:str, interval:str, forcing_file_date_format:str) -> tuple[str, str]:
    """Infer the date range of the forcing file that contains the current time.

    The forcing files are treated as consecutive windows of length ``interval``,
    the first of which starts at ``initial_date``. The window
    ``[start, end)`` that contains ``current_date`` is returned.

    Parameters
    ----------
    current_date : datetime
        current date; it is compared with a UTC-aware datetime, so it must be
        timezone-aware (a naive datetime raises ``TypeError`` in the comparison)
    initial_date : str
        forcing start date, formatted as ``forcing_file_date_format``
    interval : str
        forcing interval: a positive integer followed by one of
        ``y`` (years), ``m`` (months) or ``d`` (days), e.g. ``1m``
    forcing_file_date_format : str
        forcing file date format expressed in strftime format

    Returns
    -------
    tuple[str, str]
        start date and end date of the forcing file for current time,
        both formatted with ``forcing_file_date_format``

    Raises
    ------
    ValueError
        if the interval unit is unknown, the interval count is not a positive
        integer, or ``current_date`` is earlier than ``initial_date``
    """
    # Validate the interval first: a zero or negative step would never move
    # the window forward and the search loop below would not terminate.
    keywords: dict[str, str] = {'y': 'years', 'm': 'months', 'd': 'days'}
    unit = keywords.get(interval[-1:])
    if unit is None:
        raise ValueError(f'Unknown forcing interval "{interval}"; '
                         f'the interval must end with one of {sorted(keywords)}')
    count = int(interval[:-1])
    if count <= 0:
        raise ValueError(f'Forcing interval "{interval}" must be a positive number of {unit}')

    # Parse the initial date; the file dates are interpreted as UTC
    start_date: datetime = datetime.strptime(initial_date, forcing_file_date_format)
    start_date = start_date.replace(tzinfo=timezone.utc)
    # An explicit raise (instead of assert) keeps the check active under `python -O`
    if current_date < start_date:
        raise ValueError(f'Current time {current_date} is earlier than the initial forcing date {initial_date}')

    # Step window by window until the current date falls inside [start_date, end_date)
    step = relativedelta(**{unit: count})  # type: ignore
    end_date = start_date + step
    while end_date <= current_date:
        start_date, end_date = end_date, end_date + step

    # Format the dates back to strings
    return (start_date.strftime(forcing_file_date_format),
            end_date.strftime(forcing_file_date_format))
[docs]
def get_time_from_nc(fname:str, time_varname:str, time_units_name:str, time: datetime, next_time: datetime, debug:bool=False) -> tuple[np.ndarray, list[datetime]]:
    """Find the time steps of a forcing file that bracket one forecast cycle.

    The exact ``time`` and ``next_time`` are not searched for. Instead, the
    nearest time steps are located from the spacing of the first two entries
    of the time axis, and the range is widened by one step on each side
    (clamped to the file), so that the perturbed forcing file stays small
    while still covering the whole cycle.

    Parameters
    ----------
    fname : str
        forcing file name
    time_varname : str
        time variable name in the forcing file
    time_units_name : str
        variable name that gives time units in the forcing file
    time : datetime
        current time
    next_time : datetime
        time at the end of the forecast cycle
    debug : bool
        if True, print the selected file time range and the forecast time range

    Returns
    -------
    tuple[np.ndarray, list]
        indices into the time axis and the corresponding times
        for the next forecast cycle
    """
    with thread_lock, netCDF4.Dataset(fname, 'r') as f:
        time_units = f[time_units_name].units
        # the first two entries give the origin and the spacing of the time axis
        t_first, t_second = (
            cftime.num2date(f[time_varname][k], units=time_units, only_use_cftime_datetimes=False)
            for k in (0, 1)
        )
        time_step: timedelta = t_second - t_first
        start_time: datetime = t_first.replace(tzinfo=timezone.utc)
        # index of the last time step in the current file
        last: int = len(f[time_varname][:]) - 1

        def bracket(moment: datetime, shift: int) -> int:
            # nearest time step, moved by `shift` steps and clamped to the file
            nearest = int(np.rint((moment - start_time) / time_step))
            return max(0, min(nearest + shift, last))

        # extend the range by one step on each side to ensure all time steps are included
        it0: int = bracket(time, -1)
        it1: int = bracket(next_time, 1)
        # NOTE(review): unlike the calls above, this one omits
        # only_use_cftime_datetimes=False, so the entries may be cftime
        # datetime objects rather than datetime.datetime -- confirm intended.
        file_time: list[datetime] = [
            cftime.num2date(f[time_varname][it], time_units)
            for it in range(it0, it1 + 1)
        ]

    if debug:
        print (f'file: {fname}; 'f'file time: {file_time[0]} to {file_time[-1]},'
               f'forecast time: {time} to {next_time}')
    return np.arange(it0, it1 + 1), file_time
[docs]
def get_time_index(fname:str, time_varname:str, time_units_name:str, time:datetime) -> int:
    """Get the index of ``time`` on the time axis of a netcdf file.

    The index is computed from the first entry of the time axis and the
    spacing between its first two entries, then rounded to the nearest
    integer. It is not clamped to the length of the time axis.

    Parameters
    ----------
    fname : str
        netcdf file name
    time_varname : str
        time variable name in the file
    time_units_name : str
        variable name that gives time units in the file
    time : datetime
        time to locate

    Returns
    -------
    int
        index of the time step nearest to ``time``
    """
    with thread_lock, netCDF4.Dataset(fname, 'r') as f:
        time_units = f[time_units_name].units
        t_first, t_second = (
            cftime.num2date(f[time_varname][k], units=time_units, only_use_cftime_datetimes=False)
            for k in (0, 1)
        )
    # the spacing is taken before the timezone is attached to the origin
    time_step: timedelta = t_second - t_first
    elapsed: timedelta = time - t_first.replace(tzinfo=timezone.utc)
    return int(np.rint(elapsed / time_step))
[docs]
def get_prev_time_from_nc(fname:str, time_varname:str, time_units_name:str, itime:int) -> datetime:
    """Get the time stored one step before ``itime`` in the netcdf file.

    When ``itime`` is the first index (or smaller), the first time of the
    file is returned instead.

    Parameters
    ----------
    fname : str
        forcing file name
    time_varname : str
        time variable name in the forcing file
    time_units_name : str
        variable name that gives time units in the forcing file
    itime : int
        current time index in the forcing file

    Returns
    -------
    datetime
        the previous time, tagged with the UTC timezone
    """
    # step back by one, but never before the start of the file
    it = itime - 1 if itime > 0 else 0
    with thread_lock, netCDF4.Dataset(fname, 'r') as f:
        time_units: str = f[time_units_name].units
        stamp: datetime = cftime.num2date(f[time_varname][it], units=time_units,
                                          only_use_cftime_datetimes=False)
    return stamp.replace(tzinfo=timezone.utc)
[docs]
def read_var(fname:str, varnames:list[str], itime: int) -> np.ma.MaskedArray:
    """Read one time slice of one or more variables from a netcdf file.

    Parameters
    ----------
    fname : str
        forcing file name
    varnames : list[str]
        list of variable names
    itime : int
        time index in the forcing file

    Returns
    -------
    np.ma.MaskedArray
        the slices of all variables combined into one masked array,
        with the variables along the first axis in the order of ``varnames``
    """
    with thread_lock, netCDF4.Dataset(fname, 'r') as f:
        # one slice per variable, in the order requested
        slices: list[np.ndarray] = [f[vname][itime] for vname in varnames]
    return np.ma.array(slices)
[docs]
def write_var(fname:str, varnames: list[str], data: np.ndarray, itime: int) -> None:
    """Write one time slice of the perturbed variables back to the forcing file.

    Parameters
    ----------
    fname : str
        forcing file name; the file must already exist
    varnames : list[str]
        list of variable names; they are assumed to exist in the file
    data : np.ndarray
        data to write; ``data[k]`` is written to the variable ``varnames[k]``
    itime : int
        time index in the forcing file
    """
    # the file is opened for update only, it is never created here
    assert os.path.exists(fname), f'{fname} does not exist; Please copy the forcing file to the correct path first.'
    with thread_lock, netCDF4.Dataset(fname, 'r+') as f:
        for k, vname in enumerate(varnames):
            f[vname][itime] = data[k]
        # push the changes to disk before the file is closed
        f.sync()
[docs]
def geostrophic_perturb(fname:str, grid:RegularGrid, options:dict, itime:int, pert:np.ndarray, varname:str) -> None:
    """Perturb the horizontal 2D wind fields consistently with a pressure perturbation.

    Nothing is done unless ``options['do_adjust']`` is set and ``varname``
    is the pressure variable named by ``options['pres_name']``. Otherwise the
    wind perturbations are derived from ``pert`` with
    ``pres_adjusted_wind_perturb``, applied to the two wind components read
    from the forcing file, and the result is written back to the file.
    If ``options['wind_amp_name']`` is not the string ``'None'``, the wind
    speed ``sqrt(u**2 + v**2)`` is written to that variable as well.

    Parameters
    ----------
    fname : str
        forcing file name
    grid : RegularGrid
        grid of the forcing file
    options : dict
        perturbation options for the geostrophic_wind_adjust section of atmosphere forcing from yaml file
    itime : int
        current time index in the forcing file
    pert : np.ndarray
        perturbation array
    varname : str
        name of the variable to be used to perturb wind fields

    Returns
    -------
    None
    """
    # only the pressure perturbation drives the wind adjustment
    if not (options['do_adjust'] and options['pres_name'] == varname):
        return

    # wind perturbations that follow the geostrophic balance
    pert_u, pert_v = pres_adjusted_wind_perturb(
        grid,
        float(options['pres_pert_amp']),
        float(options['wind_pert_amp']),
        float(options['hcorr']),
        pert,
    )

    uname: str = options['u_name']
    vname: str = options['v_name']
    # read both components first, then perturb them in the same order
    raw_wind = [read_var(fname, [name], itime) for name in (uname, vname)]
    u, v = (
        apply_perturb(grid, component, increment, options['type'])
        for component, increment in zip(raw_wind, (pert_u, pert_v))
    )

    # the wind speed is stored in the file only when a variable name is configured
    if options['wind_amp_name'] != 'None':
        write_var(fname, [options['wind_amp_name']], np.sqrt(u**2 + v**2), itime)
    write_var(fname, [uname], u, itime)
    write_var(fname, [vname], v, itime)
[docs]
def get_forcing_filename(forcing_file_options:dict, i_ens:int, time:datetime) -> str:
    """Get the forcing file name based on the current time and the forcing file format.

    The file name template ``forcing_file_options['format']`` is filled in with
    the keywords ``start`` and ``end`` (date range of the forcing file that
    contains ``time``) and ``i`` (ensemble index). A template may use any
    subset of these three keywords.

    Parameters
    ----------
    forcing_file_options : dict
        forcing file options in the `file` section of the subsections of `perturb` section from the yaml file;
        the keys ``format``, ``initial_date``, ``interval`` and ``datetime_format`` are read
    i_ens : int
        ensemble index
    time : datetime
        current time

    Returns
    -------
    str
        forcing file name

    Raises
    ------
    RuntimeError
        if the template contains a keyword other than ``start``, ``end`` and ``i``
    """
    # the format of the forcing file name
    file_format: str = forcing_file_options['format']
    # the date of the first forcing file
    forcing_file_initial_date: str = forcing_file_options['initial_date']
    # the length of each forcing file
    forcing_file_interval: str = forcing_file_options['interval']
    # the strftime format of the dates in the forcing file name
    forcing_file_date_format: str = forcing_file_options['datetime_format']

    # date range of the forcing file that contains the current time
    forcing_start_date, forcing_end_date = get_fname_daterange(
        time, forcing_file_initial_date, forcing_file_interval, forcing_file_date_format)

    # str.format ignores unused keyword arguments, so passing all three keywords
    # serves both the "start"+"end" and the "start"+"end"+"i" templates.
    # A KeyError therefore always means an unsupported keyword, and retrying
    # with fewer keywords could never succeed.
    try:
        return file_format.format(i=i_ens, start=forcing_start_date, end=forcing_end_date)
    except KeyError as err:
        raise RuntimeError('Currently, we only supports keyword of 1. "start"+"end",'
                           '2. "start"+"end"+"i".'
                           'See the example yaml file for more information. '
                           'Modified the code if you have other requirements.') from err
[docs]
def _perturbation_for(member_dir: str, varname_f: str, time_f, prev_time, grid, options: dict, i: int) -> np.ndarray:
    """Load the stored perturbation of one variable at ``time_f``, or generate and store it.

    A perturbation already saved for ``time_f`` is reused as is. Otherwise a new
    random perturbation is generated; when ``prev_time`` differs from ``time_f``
    it is combined with the perturbation saved for ``prev_time`` through
    ``apply_AR1_perturb``. The new perturbation is saved as a ``.npy`` file in ``member_dir``.
    """
    pert_fname: str = os.path.join(member_dir, f'perturb_{varname_f}_{t2s(time_f)}.npy')
    if os.path.exists(pert_fname):
        return np.load(pert_fname)

    # convert the horizontal correlation length scale to grid points
    hcorr = np.rint(float(options['hcorr'][i])/grid.dx)
    if prev_time != time_f:
        # the previous perturbation is loaded first, it must already exist
        pert_prev: np.ndarray = np.load(
            os.path.join(member_dir, f'perturb_{varname_f}_{t2s(prev_time)}.npy'))
        # random perturbation for the current time step, correlated in time as AR1
        pert_new: np.ndarray = gen_perturb(grid, options['type'][i],
                                           float(options['amp'][i]), hcorr)
        pert = apply_AR1_perturb(pert_new, float(options['tcorr'][i]), pert_prev)
    else:
        # no earlier time step: start from a fresh random perturbation
        pert = gen_perturb(grid, options['type'][i], float(options['amp'][i]), hcorr)

    # save the perturbation for the next time step
    np.save(pert_fname, pert)
    return pert


def perturb_forcing(forcing_options:dict, file_options:dict, i_ens: int, time: datetime, next_time:datetime, debug=False) -> None:
    """Perturb the forcing variables of one ensemble member for the next forecast cycle.

    For every forcing component (e.g., atmosphere or ocean) present in both
    ``forcing_options`` and ``file_options``, the time steps covering the
    forecast cycle are copied from the source forcing file into the member's
    forcing file. Each configured variable is then perturbed, limited to its
    lower and upper bounds and written back, time step by time step.
    The perturbations are stored as ``.npy`` files under
    ``<forcing_options['path']>/ensemble_<i_ens>`` so that later time steps
    and later cycles can continue from them.

    Parameters
    ----------
    forcing_options : dict
        perturbation options from the yaml file; ``path`` is the directory of the
        perturbation files, and each component provides a ``variables`` section
        (``names``, ``type``, ``amp``, ``hcorr``, ``tcorr``, ``lower_bounds``,
        ``upper_bounds``); the atmosphere component also provides ``geostrophic_wind_adjust``
    file_options : dict
        forcing file options in the corresponding subsection of the `file` section from the yaml file
        e.g., info in the file/forcing/atmosphere is used in the perturb/forcing/atmosphere section
        Before calling this function, one must add the following keys to the file_options dictionary:
        - fname: the exact filename of the perturbed forcing file under absolute path
        The keys ``fname_src``, ``time_name``, ``time_units_name``, ``lon_name``
        and ``lat_name`` of each component are read as well.
    i_ens : int
        ensemble index
    time : datetime
        current time as the beginning of the forecast cycle
    next_time : datetime
        end time of the next forecast cycle
    debug : bool
        passed on to ``get_time_from_nc``
    """
    # directory of the perturbation files of this member; create it if it does not exist
    pert_path: str = forcing_options['path']
    member_dir: str = os.path.join(pert_path, f'ensemble_{i_ens}')
    os.makedirs(member_dir, exist_ok=True)

    for forcing_name, forcing_options_comp in forcing_options.items():
        # entries without a matching file section (such as `path`) are not components
        if forcing_name not in file_options:
            continue
        file_options_comp: dict = file_options[forcing_name]
        # forcing file of this member and the source file it is copied from
        fname: str = file_options_comp['fname']
        fname_src: str = file_options_comp['fname_src']
        time_name: str = file_options_comp['time_name']
        time_units_name: str = file_options_comp['time_units_name']

        # time steps of the source file that cover the forecast cycle
        time_index, time_array = get_time_from_nc(fname_src, time_name, time_units_name,
                                                  time, next_time, debug)
        # time step just before the first selected one
        prev_time: datetime = get_prev_time_from_nc(fname_src, time_name, time_units_name,
                                                    time_index[0])

        # the file name is unchanged, but only the selected time slices
        # of the original forcing file are copied
        with thread_lock:
            slicing_nc.copy_time_sliced_nc_file(fname_src, fname, time_index,
                                                time_name, time_array[0])

        # grid object for geometric information
        with thread_lock, netCDF4.Dataset(fname, 'r') as f:
            x, y = _proj(f[file_options_comp['lon_name']],
                         f[file_options_comp['lat_name']])
            grid = RegularGrid(_proj, x, y)

        # options for perturbing the forcing variables
        options: dict = forcing_options_comp['variables']
        for itime, time_f in enumerate(time_array):
            for i, varname in enumerate(options['names']):
                if typing.TYPE_CHECKING:
                    assert type(varname) == str, 'variable name must be a string'
                # variable name as used in the saved .npy filename
                varname_f: str = varname.replace("/", "_")
                pert = _perturbation_for(member_dir, varname_f, time_f, prev_time,
                                         grid, options, i)

                # vector fields list their components separated by ';'
                varname_list: list[str] = varname.split(';')
                data: np.ndarray = read_var(fname, varname_list, itime)
                data = apply_perturb(grid, data, pert, options['type'][i])

                # apply lower and upper bounds; the string 'None' means unbounded
                lower = options['lower_bounds'][i]
                lb = -np.inf if lower == 'None' else float(lower)
                upper = options['upper_bounds'][i]
                ub = np.inf if upper == 'None' else float(upper)
                data = np.minimum(np.maximum(data, lb), ub)

                # write the perturbed variable back to the forcing file
                write_var(fname, varname_list, data, itime)

                # wind perturbations of the atmosphere forcing follow the pressure perturbations
                if forcing_name == 'atmosphere':
                    geostrophic_perturb(fname, grid,
                                        forcing_options_comp['geostrophic_wind_adjust'],
                                        itime, pert, varname)
            # the next time step is correlated with the one just processed
            prev_time = time_f