import os
import struct
from typing import Callable
import numpy as np
from NEDAS.utils.conversion import type_dic, type_size
from NEDAS.core.io_backend import IOBackend
from NEDAS.core.context import Context
[docs]
class OfflineIO(IOBackend):
"""
Offline IO backend using restart files to hold model state (a pause-restart strategy)
"""
io_mode = 'offline'
[docs]
def binfile_name(self, c: Context, tag: str) -> str:
"""
Name of the binary file that stores the state data.
Args:
c (Context): the runtime context
tag (str): which version of the state
Returns:
str: file name
"""
analysis_dir = c.fs.analysis_dir(c.time, c.iter)
return os.path.join(analysis_dir, f'fields_{tag}.bin')
[docs]
def prepare_fields_storage(self, c: Context, tag: str):
binfile = self.binfile_name(c, tag)
if c.pid == 0:
# create the .bin file
with open(binfile, 'wb') as f:
pass
# write state_info to the accompanying .dat file
c.state.info.write_to_file(binfile)
c.comm.Barrier()
[docs]
def read_field(self, c: Context, tag: str, rec_id: int, mem_id: int) -> np.ndarray:
"""
Read a field from cache or binary file
"""
self.validate_tag(tag)
# check if it is available in cache
if hasattr(c.state, f"fields_{tag}"):
if c.state and rec_id in c.state.rec_list[c.pid_rec] and mem_id in c.mem_list[c.pid_mem]:
fields = getattr(c.state, f"fields_{tag}")
return fields[mem_id, rec_id]
# otherwise, read it from binfile
rec = c.state.info.fields[rec_id]
nv = 2 if rec.is_vector else 1
fld_shape = (2,)+c.state.info.shape if rec.is_vector else c.state.info.shape
fld_size = np.sum((~c.grid.mask).astype(int))
binfile = self.binfile_name(c, tag)
with open(binfile, 'rb') as f:
f.seek(mem_id*c.state.info.size + rec.pos)
fld_ = np.array(struct.unpack((nv*fld_size*type_dic[rec.dtype]),
f.read(nv*fld_size*type_size[rec.dtype])))
fld = np.full(fld_shape, np.nan)
if rec.is_vector:
fld[:, ~c.grid.mask] = fld_.reshape((2, -1))
else:
fld[~c.grid.mask] = fld_
return fld
[docs]
def write_field(self, fld: np.ndarray, c: Context, tag: str, rec_id: int, mem_id: int) -> None:
"""
Write a field to a binary file
"""
# only write to binfile if the field is owned by the pid_mem
# for ensemble mean every pid_mem receives a copy from allreduce, but only root need to write it.
if mem_id not in c.mem_list[c.pid_mem]:
return
self.validate_tag(tag)
rec = c.state.info.fields[rec_id]
fld_shape = (2,)+c.state.info.shape if rec.is_vector else c.state.info.shape
assert fld.shape == fld_shape, f'fld shape incorrect: expected {fld_shape}, got {fld.shape}'
if rec.is_vector:
fld_ = fld[:, ~c.grid.mask].flatten()
else:
fld_ = fld[~c.grid.mask]
binfile = self.binfile_name(c, tag)
with open(binfile, 'r+b') as f:
f.seek(mem_id*c.state.info.size + rec.pos)
f.write(struct.pack(fld_.size*type_dic[rec.dtype], *fld_))
[docs]
def call_method(self, c: Context, tag: str, method: Callable, *args, **kwargs):
self.validate_tag(tag)
# if path is already specified, directly call the method
if 'path' in kwargs and kwargs['path'] is not None:
return method(*args, **kwargs)
# otherwise, use additional info from kwargs to form the path
model_name = kwargs['model_src']
model = c.models[model_name]
if tag in ['raw', 'current', 'post', 'z']:
path = c.fs.forecast_dir(c.time, model_name)
elif tag == 'prior':
if kwargs['time'] == c.time:
path = c.fs.forecast_dir(c.prev_time, model_name)
else:
path = c.fs.forecast_dir(c.time, model_name)
elif tag == 'truth':
path = model.truth_dir
else:
raise ValueError(f"tag '{tag}' not supported in io.call_method")
# make sure path exists
if path:
c.fs.make_dir(path)
kwargs['path'] = path
return method(*args, **kwargs)