Source code for NEDAS.core.diag
import importlib
from NEDAS.utils.conversion import ensure_list
from NEDAS.utils.parallel import bcast_by_root, distribute_tasks
from .context import Context
from .types import ProcID
[docs]
class Diagnostics:
    """
    Manage the diagnostics methods listed in ``c.config.diag``.

    Each entry of ``c.config.diag`` is a dict whose ``'method'`` key names a
    module under ``NEDAS.diag``. That module must provide ``run(c, **rec)``
    and may optionally provide ``get_task_list(c, **rec)`` (to expand the
    entry into several tasks) and ``get_file_list(c, **rec)`` (files that
    need a lock for collective i/o).
    """
    # tasks assigned to each processor id; each task is a dict holding at
    # least a 'method' key and is passed as keyword arguments to run()
    task_list: dict[ProcID, list]

    def __init__(self, c: Context) -> None:
        """
        Build and distribute the diagnostics tasks, then set up file locks.

        Args:
            c: The runtime context. ``c.comm`` and ``c.config.diag`` are read;
               ``c.pid_show`` is set.
        """
        # get task list for each rank; built by distribute_diag_tasks and
        # shared with all ranks through bcast_by_root (presumably computed on
        # the root rank only and broadcast -- see bcast_by_root)
        self.task_list = bcast_by_root(c.comm)(self.distribute_diag_tasks)(c)

        # the processor with the most work load shows progress messages;
        # max() keeps the first such processor on ties and, unlike taking
        # element [0] of a filtered list, still works when no rank has a task
        c.pid_show = max(self.task_list, key=lambda pid: len(self.task_list[pid]))

        # init file locks for collective i/o
        self.init_file_locks(c)

    def __call__(self, c: Context) -> None:
        """
        Run every diagnostics task assigned to this processor (``c.pid``).

        Before each task, ``c.total_tasks``, ``c.current_task`` and
        ``c.debug_message`` are updated for progress reporting. After the
        local tasks finish, all ranks meet at ``c.comm.Barrier()`` and then
        ``c.comm.cleanup_file_locks()`` is called.
        """
        my_tasks = self.task_list[c.pid]
        c.total_tasks = len(my_tasks)
        for task_id, rec in enumerate(my_tasks):
            c.debug_message = f"running diagnostics '{rec['method']}'"
            c.current_task = task_id
            # perform the diag task
            self._import_method_module(rec).run(c, **rec)
        # wait for every rank before cleaning up the shared file locks
        c.comm.Barrier()
        c.comm.cleanup_file_locks()

    def distribute_diag_tasks(self, c: Context):
        """
        Build the full task list and distribute it among mpi ranks.

        Each entry of ``c.config.diag`` is expanded through its method
        module's ``get_task_list(c, **rec)`` when that function exists;
        otherwise the entry itself is a single task.

        Args:
            c: The runtime context; ``c.config.diag`` and ``c.comm`` are read.

        Returns:
            The per-rank task lists produced by ``distribute_tasks``.
        """
        task_list_full = []
        for rec in ensure_list(c.config.diag):
            module = self._import_method_module(rec)
            if hasattr(module, 'get_task_list'):
                # module returns a list of tasks to be done by the processors
                task_list_full.extend(module.get_task_list(c, **rec))
            else:
                task_list_full.append(rec)
        # collected full list of tasks is evenly distributed across the mpi communicator
        return distribute_tasks(c.comm, task_list_full)

    def init_file_locks(self, c: Context):
        """
        Create the file locks needed for collective i/o by the diagnostics.

        For every entry of ``c.config.diag`` whose method module defines
        ``get_file_list(c, **rec)``, ``c.comm.init_file_lock(file)`` is called
        for each returned file. Entries whose module has no ``get_file_list``
        are skipped.

        Args:
            c: The runtime context; ``c.config.diag`` and ``c.comm`` are used.
        """
        for rec in ensure_list(c.config.diag):
            module = self._import_method_module(rec)
            if not hasattr(module, 'get_file_list'):
                continue
            for file in module.get_file_list(c, **rec):
                # create the file lock across mpi ranks for this file
                c.comm.init_file_lock(file)

    @staticmethod
    def _import_method_module(rec: dict):
        """Import and return the module ``NEDAS.diag.<rec['method']>``."""
        return importlib.import_module(f"NEDAS.diag.{rec['method']}")