Source code for NEDAS.core.updator
import os
import inspect
from abc import ABC, abstractmethod
import numpy as np
from NEDAS.config import parse_config
from .context import Context
from .types import MemID, FieldRecordID
[docs]
class Updator(ABC):
"""
Base class for updators of the model restart files
"""
increment: dict = {}
def __init__(self, c: Context):
# get updator parameters from config file
code_dir = os.path.dirname(inspect.getfile(self.__class__))
config_dict = parse_config(code_dir, parse_args=False, **c.config.updator_def)
for key, value in config_dict.items():
setattr(self, key, value)
[docs]
def update(self, c: Context) -> None:
"""
Top-level routine to apply the analysis increments to the original model
restart files (as initial conditions for the next forecast)
"""
pid_mem_show = [p for p,lst in c.mem_list.items() if len(lst)>0][0]
pid_rec_show = [p for p,lst in c.state.rec_list.items() if len(lst)>0][0]
c.pid_show = pid_rec_show * c.config.nproc_mem + pid_mem_show
# compute analysis increments
self.compute_increment(c)
# if in offline mode, initialize file locks for async io
if c.config.io_mode == 'offline':
self.init_all_file_locks(c)
# process the fields, each processor goes through its own subset of
# mem_id,rec_id simultaneously
# but need to keep every rank in sync to coordinate multiprocess file access
nm_max = np.max([len(lst) for _,lst in c.mem_list.items()])
nr_max = np.max([len(lst) for _,lst in c.state.rec_list.items()])
c.total_tasks = nr_max * nm_max
for r in range(nr_max):
for m in range(nm_max):
pid_active = ( m < len(c.mem_list[c.pid_mem]) and r < len(c.state.rec_list[c.pid_rec]) )
if pid_active:
mem_id = c.mem_list[c.pid_mem][m]
rec_id = c.state.rec_list[c.pid_rec][r]
rec = c.state.info.fields[rec_id].asdict()
debug_msg = f"update_restartfile mem{mem_id+1:03} '{rec['name']:20}' {rec['time']} k={rec['k']}"
# apply the increment to restart files (use io backend)
self.update_files(c, mem_id, rec_id)
else:
debug_msg = f"waiting"
c.debug_message = debug_msg
c.current_task = m*nr_max+r
c.comm.Barrier()
c.comm.cleanup_file_locks()
[docs]
def init_all_file_locks(self, c: Context) -> None:
"""
Prepare file locks for asynchronous io, needed for blocking write (e.g. in netcdf without parallel support)
"""
# get file names for async io
files = []
for mem_id in c.mem_list[c.pid_mem]:
for rec_id in c.state.rec_list[c.pid_rec]:
rec = c.state.info.fields[rec_id].asdict()
model = c.models[rec['model_src']]
file = c.io.call_method(c, 'current', getattr(model, 'filename'), member=mem_id, **rec)
if file:
files.append(file)
# collect files from all pids
all_files = c.comm.allgather(files)
# flatten and filter to unique files
unique_files = {f for sublist in all_files for f in sublist if f}
# create the file locks
for file in unique_files:
c.comm.init_file_lock(file)
c.comm.Barrier()
[docs]
@abstractmethod
def compute_increment(self, c: Context) -> None:
...
[docs]
@abstractmethod
def update_files(self, c: Context, mem_id: MemID, rec_id: FieldRecordID) -> None:
...