import numpy as np
from NEDAS.utils.conversion import t2h, h2t, dt1h
from NEDAS.utils.parallel import distribute_tasks, bcast_by_root
from .context import Context
from .types import ProcIDMem, ProcIDRec, FieldRecordID, PartitionID, FieldRecord, FieldEns, StateEns
from .state_info import StateInfo
class State:
    """
    The State class manages the state variables for the assimilation system.

    The analysis is performed on a regular grid.
    The entire state has dimensions: member, variable, time, z, y, x
    indexed by: mem_id, v, t, k, j, i
    with size: nens, nv, nt, nz, ny, nx

    To parallelize the workload, we group the dimensions into 3 indices:

    - mem_id indexes the ensemble members.
    - rec_id indexes the unique 2D fields with (v, t, k); since nz and nt may
      vary for different variables, we stack these dimensions into the
      'record' dimension with size nrec.
    - par_id indexes the spatial partitions, which are subsets of the 2D grid
      given by (ist, ied, di, jst, jed, dj): for a complete field fld[j, i],
      the processor with par_id stores fld[jst:jed:dj, ist:ied:di] locally.
      (When c.grid.x is not 2D, self.partitions[par_id] is used directly as
      an index into the field instead.)

    The entire state is distributed across the memory of many processors;
    at any moment, a processor only stores a subset of the state in its memory:
    either all the mem_id, rec_id but only a subset of par_id (we call this
    ensemble-complete), or all the par_id but a subset of mem_id, rec_id
    (we call this field-complete).

    It is easier to perform i/o and pre/post processing on the field-complete
    state, while it is easier to run assimilation algorithms on the
    ensemble-complete state.
    """
    info: StateInfo
    rec_list: dict[ProcIDRec, list[FieldRecordID]]
    partitions: list  # will be created by assimilator.partition_grid()
    par_list: dict[ProcIDMem, list[PartitionID]]
    fields_prior: FieldEns  # will be created by self.prepare_state()
    fields_z: FieldEns
    state_prior: StateEns  # will be created by self.transpose_to_ensemble_complete()
    state_z: StateEns
    state_post: StateEns  # will be created by assimilator.assimilate()
    fields_post: FieldEns  # will be created by self.transpose_to_field_complete()
    data: dict  # will be created by self.pack_local_state_data(), for use in assimilator.assimilate()

    def __init__(self, c: Context):
        """
        Set up the state bookkeeping and empty containers.

        Args:
            c (Context): the runtime context; the StateInfo construction and the
                record distribution are both wrapped with bcast_by_root(c.comm).
        """
        self.info = bcast_by_root(c.comm)(StateInfo)(c)
        # relies on self.info being set on the line above
        self.rec_list = bcast_by_root(c.comm)(self.distribute_state_tasks)(c)
        self.partitions = []
        self.par_list = {}
        self.fields_prior = {}
        self.fields_z = {}
        self.state_prior = {}
        self.state_z = {}
        self.state_post = {}
        self.fields_post = {}
        self.data = {}

    def distribute_state_tasks(self, c: Context) -> dict[ProcIDRec, list[FieldRecordID]]:
        """
        Distribute rec_id across the processors of c.comm_rec.

        Vector records are given a workload weight of 2 (one per component),
        scalar records a weight of 1.

        Args:
            c (Context): the runtime context

        Returns:
            dict[ProcIDRec, list[FieldRecordID]]: the list of rec_id for each pid_rec,
            as returned by distribute_tasks.
        """
        records = self.info.fields
        rec_list_full = list(records)
        rec_size = np.array([2 if rec.is_vector else 1 for rec in records.values()])
        return distribute_tasks(c.comm_rec, rec_list_full, rec_size)

    def prepare_state(self, c: Context) -> None:
        """
        Main method to collect fields from the models to form the complete
        state (field-complete, distributed).

        Steps: collect the prior fields and their z coordinates, output the
        reference z coordinates, then output the prior ensemble members and
        the prior ensemble mean.
        """
        c.logger('Collect prior fields')(self.collect_prior_fields)(c)
        # TODO: collect scalar variables here once collect_scalar_variables is implemented
        c.logger('Collect reference z coords')(self.output_ref_z)(c)
        # compute and save the prior ensemble and mean fields
        c.logger('Output prior ensemble members')(self.output_state)(c, 'prior')
        c.logger('Output prior ensemble mean')(self.output_ens_mean)(c, 'prior')

    def collect_prior_fields(self, c: Context) -> None:
        """
        Collect fields from the prior model state, convert them to the analysis
        grid and preprocess them (coarse-graining, transforms, etc.).

        Each processor handles its own subset of (mem_id, rec_id), given by
        c.mem_list[c.pid_mem] and self.rec_list[c.pid_rec].

        Args:
            c (Context): the runtime context

        Returns:
            None. Results are stored on the instance:
            self.fields_prior[mem_id, rec_id] is the field defined on c.grid
            (with a leading axis of size 2 for vector records, masked grid points
            set to NaN), and self.fields_z[mem_id, rec_id] holds the z coordinates
            corresponding to that field (same shape as the field).
            c.pid_show is also set.
        """
        pid_mem_show = [p for p, lst in c.mem_list.items() if len(lst) > 0][0]
        pid_rec_show = [p for p, lst in self.rec_list.items() if len(lst) > 0][0]
        # pid_show has some workload, it will print progress messages
        c.pid_show = pid_rec_show * c.config.nproc_mem + pid_mem_show

        # each proc gets its own workload as a subset of mem_id, rec_id;
        # all pids go through their own task lists simultaneously
        mem_ids = c.mem_list[c.pid_mem]
        rec_ids = self.rec_list[c.pid_rec]
        nr = len(rec_ids)
        c.total_tasks = len(mem_ids) * nr
        for m, mem_id in enumerate(mem_ids):
            for r, rec_id in enumerate(rec_ids):
                rec = self.info.fields[rec_id]
                c.debug_message = f"prepare_state mem{mem_id+1:03} '{rec.name:20}' {rec.time} k={rec.k}"
                c.current_task = m * nr + r
                model_name = rec.model_src
                model = c.models[model_name]

                # read the model field and convert it to the analysis grid
                model_fld = c.io.call_method(c, 'current', model.read_var, member=mem_id, **rec.asdict())
                model.grid.set_destination_grid(c.grid)
                fld = model.grid.convert(model_fld, is_vector=rec.is_vector, method='linear', coarse_grain=True)
                # set masked grid points to NaN (both components for vectors)
                if rec.is_vector:
                    fld[:, c.grid.mask] = np.nan
                else:
                    fld[c.grid.mask] = np.nan
                # misc. transforms can be added here
                for transform_func in c.transform_funcs:
                    fld = transform_func.forward_state(c, rec, fld)
                self.fields_prior[mem_id, rec_id] = fld

                # read the z coords for this (mem_id, rec_id) and convert them the same way;
                # vector records get one copy of z per component
                model_z = c.io.call_method(c, 'current', model.z_coords, member=mem_id, **rec.asdict())
                z = model.grid.convert(model_z, is_vector=False, method='linear', coarse_grain=True)
                if rec.is_vector:
                    self.fields_z[mem_id, rec_id] = np.array([z, z])
                else:
                    self.fields_z[mem_id, rec_id] = z
        c.comm.Barrier()

        # additional output for debugging
        if c.debug:
            c.io.save_debug_data(c, f"fields_prior_{c.pid_mem}_{c.pid_rec}", self.fields_prior, path=c.fs.analysis_dir(c.time, c.iter))

    def collect_scalar_variables(self, c):
        """Placeholder for collecting scalar state variables; currently does nothing."""
        # TODO: implement scalars here for simultaneous state parameter estimation (SSPE)
        pass

    def output_state(self, c: Context, tag: str, mem_id_out: int|None=None, rec_id_out: int|None=None) -> None:
        """
        Parallel output of the fields to the state file.

        Each processor writes its own subset of (mem_id, rec_id) taken from
        self.fields_{tag}.

        Args:
            c (Context): the runtime context obj
            tag (str): which version of the state this is: 'prior', 'post' or 'z' coords
            mem_id_out (int, optional): member id to be output; if None, all available ids are output.
            rec_id_out (int, optional): record id to be output; if None, all available ids are output.
        """
        c.io.prepare_fields_storage(c, tag)
        fields = getattr(self, f"fields_{tag}")
        mem_ids = c.mem_list[c.pid_mem]
        rec_ids = self.rec_list[c.pid_rec]
        nr = len(rec_ids)
        c.total_tasks = len(mem_ids) * nr
        for m, mem_id in enumerate(mem_ids):
            if mem_id_out is not None and mem_id != mem_id_out:
                continue
            for r, rec_id in enumerate(rec_ids):
                if rec_id_out is not None and rec_id != rec_id_out:
                    continue
                rec = self.info.fields[rec_id]
                c.debug_message = f"saving field: mem{mem_id+1:03} '{rec.name:20}' {rec.time} k={rec.k}"
                c.current_task = m * nr + r
                c.io.write_field(fields[mem_id, rec_id], c, tag, rec_id, mem_id)
        c.comm.Barrier()

    def output_ens_mean(self, c: Context, tag: str) -> None:
        """
        Compute the ensemble mean of fields stored distributively on all pid_mem
        and output it under the tag '{tag}_mean'.

        Each pid_mem sums its locally stored members; the partial sums are
        combined with an allreduce over c.comm_mem, divided by c.nens, and
        written with mem_id=0.

        Args:
            c (Context): the runtime context obj
            tag (str): which version of the fields to average: 'prior', 'post' or 'z';
                the fields are read from self.fields_{tag}.
        """
        fields = getattr(self, f"fields_{tag}")
        c.io.prepare_fields_storage(c, f"{tag}_mean")
        rec_ids = self.rec_list[c.pid_rec]
        c.total_tasks = len(rec_ids)
        for r, rec_id in enumerate(rec_ids):
            rec = self.info.fields[rec_id]
            c.debug_message = f"saving mean field '{rec.name:20}' {rec.time} k={rec.k}"
            c.current_task = r
            # initialize a zero field with the right dimensions for rec_id
            fld_shape = ((2,) + self.info.shape) if rec.is_vector else self.info.shape
            sum_fld_pid = np.zeros(fld_shape)
            # sum over all fields locally stored on this pid
            for mem_id in c.mem_list[c.pid_mem]:
                sum_fld_pid += fields[mem_id, rec_id]
            # sum the partial sums from all pid_mem to get the total sum
            # TODO: reduce is expensive if only sparse pids hold state in memory, so at runtime
            # try to populate comm_mem with members as much as possible.
            sum_fld = c.comm_mem.allreduce(sum_fld_pid)
            mean_fld = sum_fld / c.nens
            # NOTE(review): every pid_mem holds the same mean after allreduce and calls write_field
            # here; confirm c.io.write_field tolerates/deduplicates these identical writes.
            c.io.write_field(mean_fld, c, f"{tag}_mean", rec_id, mem_id=0)
        c.comm.Barrier()

    def output_ref_z(self, c: Context) -> None:
        """
        Output the reference z coordinates used for the observations.

        Depending on c.config.z_coords_from:
            'member': output the z coords of the first member (mem_id=0); TOPAZ
                uses this, and it is kept for backward compatibility.
            'mean': output the ensemble-mean z coords (the default choice).
        Any other value outputs nothing here.
        """
        if c.config.z_coords_from == 'member':
            self.output_state(c, 'z', mem_id_out=0)
        elif c.config.z_coords_from == 'mean':
            self.output_ens_mean(c, 'z')

    def pack_field_chunk(self, c: Context, fld, is_vector, dst_pid) -> dict:
        """
        Cut a field-complete field into the chunks owned by dst_pid.

        Args:
            c (Context): the runtime context
            fld (np.ndarray): a field on c.grid, with a leading component axis if is_vector
            is_vector (bool): whether fld has a leading component axis
            dst_pid: the destination pid_mem; its partitions are self.par_list[dst_pid]

        Returns:
            dict: par_id -> array of the unmasked points of fld in that partition
            (1D, or with a leading component axis for vector fields). The arrays
            are produced by boolean indexing, so they are copies of fld's data.
        """
        fld_chk = {}
        regular_grid = len(c.grid.x.shape) == 2
        for par_id in self.par_list[dst_pid]:
            if regular_grid:
                # slice for this par_id
                istart, iend, di, jstart, jend, dj = self.partitions[par_id]
                # save the unmasked points in the slice to fld_chk for this par_id
                mask_chk = c.grid.mask[jstart:jend:dj, istart:iend:di]
                if is_vector:
                    fld_chk[par_id] = fld[:, jstart:jend:dj, istart:iend:di][:, ~mask_chk]
                else:
                    fld_chk[par_id] = fld[jstart:jend:dj, istart:iend:di][~mask_chk]
            else:
                inds = self.partitions[par_id]
                mask_chk = c.grid.mask[inds]
                if is_vector:
                    fld_chk[par_id] = fld[:, inds][:, ~mask_chk]
                else:
                    fld_chk[par_id] = fld[inds][~mask_chk]
        return fld_chk

    def unpack_field_chunk(self, c, fld, fld_chk, src_pid):
        """
        Write the chunks produced by pack_field_chunk for src_pid back into fld, in place.

        Args:
            c (Context): the runtime context
            fld (np.ndarray): the field-complete array to fill
            fld_chk (dict): par_id -> unmasked values for each partition of src_pid
            src_pid: the pid_mem whose partitions (self.par_list[src_pid]) fld_chk covers
        """
        regular_grid = len(c.grid.x.shape) == 2
        for par_id in self.par_list[src_pid]:
            if regular_grid:
                istart, iend, di, jstart, jend, dj = self.partitions[par_id]
                mask_chk = c.grid.mask[jstart:jend:dj, istart:iend:di]
                # the basic slice is a view of fld, so the masked assignment writes into fld
                fld[..., jstart:jend:dj, istart:iend:di][..., ~mask_chk] = fld_chk[par_id]
            else:
                inds = self.partitions[par_id]
                mask_chk = c.grid.mask[inds]
                fld[..., inds[~mask_chk]] = fld_chk[par_id]

    @staticmethod
    def _cyclic_pids(pid: int, nproc: int) -> list[int]:
        """Return the pid order [pid, pid+1, ..., nproc-1, 0, 1, ..., pid-1]."""
        return [(pid + k) % nproc for k in range(nproc)]

    def transpose_to_ensemble_complete(self, c: Context, fields: FieldEns) -> StateEns:
        """
        Send chunks of the fields owned by a pid to the other pids, so that the
        field-complete fields get transposed into the ensemble-complete state,
        with keys (mem_id, rec_id) pointing to the partitions in par_list.

        Args:
            c (Context): the runtime context
            fields (FieldEns): the locally stored field-complete fields with a subset of mem_id, rec_id

        Returns:
            StateEns: the locally stored ensemble-complete field chunks on partitions,
            dict[(mem_id, rec_id), dict[par_id, fld_chk]]
        """
        state = {}
        rec_ids = self.rec_list[c.pid_rec]
        nr = len(rec_ids)
        nm_max = max(len(lst) for lst in c.mem_list.values())
        c.total_tasks = nr * nm_max
        mem_list_own = c.mem_list[c.pid_mem]
        nproc = c.config.nproc_mem
        for r, rec_id in enumerate(rec_ids):
            # all pids go through their own mem_list simultaneously
            for m in range(nm_max):
                has_own = m < len(mem_list_own)
                status = f"processing mem{mem_list_own[m]+1:03} rec{rec_id}" if has_own else "waiting"
                c.debug_message = f"transposing field: {status}"
                c.current_task = r * nm_max + m

                # prepare the fld for sending if not at the end of mem_list;
                # no copy is needed since pack_field_chunk only reads fld
                fld = None
                mem_id = None
                rec = None
                if has_own:
                    mem_id = mem_list_own[m]
                    rec = self.info.fields[rec_id]
                    fld = fields[mem_id, rec_id]

                # - for each source pid_mem (src_pid) with fields[mem_id, rec_id],
                #   send chunks of fld to each destination pid_mem (dst_pid) for its partitions in par_list
                # - every pid needs to send/recv to/from every pid, so we use a cyclic
                #   choreography here to prevent deadlock
                # 1) receive fld_chk from src_pid, for src_pid<pid first
                for src_pid in range(0, c.pid_mem):
                    if m < len(c.mem_list[src_pid]):
                        src_mem_id = c.mem_list[src_pid][m]
                        state[src_mem_id, rec_id] = c.comm_mem.recv(source=src_pid, tag=m)

                # 2) send my fld chunks to a list of dst_pid, to dst_pid>=pid first,
                #    because they wait to receive before being able to send their own;
                #    when finished with dst_pid>=pid, cycle back to send to dst_pid<pid,
                #    i.e., dst_pid list = [pid, pid+1, ..., nproc-1, 0, 1, ..., pid-1]
                if has_own:
                    assert isinstance(rec, FieldRecord), f"{rec} is not a FieldRecord"
                    for dst_pid in self._cyclic_pids(c.pid_mem, nproc):
                        fld_chk = self.pack_field_chunk(c, fld, rec.is_vector, dst_pid)
                        if dst_pid == c.pid_mem:
                            # same pid, so just write to state
                            state[mem_id, rec_id] = fld_chk
                        else:
                            # send fld_chk to dst_pid's state
                            c.comm_mem.send(fld_chk, dest=dst_pid, tag=m)

                # 3) finish receiving fld_chk from src_pid, for src_pid>pid now
                for src_pid in range(c.pid_mem + 1, nproc):
                    if m < len(c.mem_list[src_pid]):
                        src_mem_id = c.mem_list[src_pid][m]
                        state[src_mem_id, rec_id] = c.comm_mem.recv(source=src_pid, tag=m)
        c.comm.Barrier()
        return state

    def transpose_to_field_complete(self, c: Context, state: StateEns) -> FieldEns:
        """
        Transpose the ensemble-complete state back to field-complete fields.

        Note: entries of `state` that are sent to other pids are deleted from
        `state` to free memory, so the caller's dict is modified in place.

        Args:
            c (Context): the runtime context
            state (StateEns): the locally stored ensemble-complete field chunks for a subset of par_id

        Returns:
            FieldEns: the locally stored field-complete fields for a subset of mem_id, rec_id;
            points not covered by any received chunk remain NaN.
        """
        fields = {}
        rec_ids = self.rec_list[c.pid_rec]
        nr = len(rec_ids)
        nm_max = max(len(lst) for lst in c.mem_list.values())
        c.total_tasks = nr * nm_max
        mem_list_own = c.mem_list[c.pid_mem]
        nproc = c.config.nproc_mem
        for r, rec_id in enumerate(rec_ids):
            # all pids go through their own mem_list simultaneously
            for m in range(nm_max):
                has_own = m < len(mem_list_own)
                status = f"processing mem{mem_list_own[m]+1:03} rec{rec_id}" if has_own else "waiting"
                c.debug_message = f"transposing field: {status}"
                c.current_task = r * nm_max + m

                # prepare an empty (NaN) fld for receiving if not at the end of mem_list
                mem_id = None
                fld = None
                if has_own:
                    mem_id = mem_list_own[m]
                    rec = self.info.fields[rec_id]
                    fld_shape = ((2,) + c.grid.x.shape) if rec.is_vector else c.grid.x.shape
                    fld = np.full(fld_shape, np.nan)
                    fields[mem_id, rec_id] = fld

                # this is just the reverse of transpose_to_ensemble_complete:
                # we take the exact same steps, but swap the send and recv operations
                #
                # 1) send my fld_chk to dst_pid, for dst_pid<pid first
                for dst_pid in range(0, c.pid_mem):
                    if m < len(c.mem_list[dst_pid]):
                        dst_mem_id = c.mem_list[dst_pid][m]
                        c.comm_mem.send(state[dst_mem_id, rec_id], dest=dst_pid, tag=m)
                        del state[dst_mem_id, rec_id]  # free up memory

                # 2) receive fld_chk from a list of src_pid, from src_pid>=pid first,
                #    because they wait to send before being able to receive themselves;
                #    then cycle back to receive from src_pid<pid.
                if has_own:
                    assert mem_id is not None
                    assert fld is not None
                    for src_pid in self._cyclic_pids(c.pid_mem, nproc):
                        if src_pid == c.pid_mem:
                            # same pid, so just read fld_chk from state (unpack only reads it)
                            fld_chk = state[mem_id, rec_id]
                        else:
                            # receive fld_chk from src_pid's state
                            fld_chk = c.comm_mem.recv(source=src_pid, tag=m)
                        # unpack the fld_chk to form a complete field
                        self.unpack_field_chunk(c, fld, fld_chk, src_pid)

                # 3) finish sending fld_chk to dst_pid, for dst_pid>pid now
                for dst_pid in range(c.pid_mem + 1, nproc):
                    if m < len(c.mem_list[dst_pid]):
                        dst_mem_id = c.mem_list[dst_pid][m]
                        c.comm_mem.send(state[dst_mem_id, rec_id], dest=dst_pid, tag=m)
                        del state[dst_mem_id, rec_id]  # free up memory
        c.comm.Barrier()
        return fields

    def pack_local_state_data(self, c: Context, par_id: PartitionID, state_prior: StateEns, state_z: StateEns) -> dict:
        """
        Pack the ensemble-complete state on partition par_id into arrays that are
        more easily handled by jitted functions.

        Args:
            c (Context): the runtime context
            par_id (PartitionID): the partition to pack
            state_prior (StateEns): ensemble-complete state chunks, keyed (mem_id, rec_id) then par_id
            state_z (StateEns): the matching z coordinate chunks

        Returns:
            dict with keys:
                'x', 'y': coordinates of the unmasked grid points in the partition (length nloc)
                'field_ids': list of (rec_id, v); v is the component (0 or 1) for vector
                    records, None for scalar records (length nfld)
                't': record time converted with t2h (length nfld)
                'z': ensemble-mean z coordinates, shape (nfld, nloc)
                'var_id': index of rec.name in self.info.variables (length nfld)
                'err_type': index of rec.err_type in self.info.err_types (length nfld)
                'state_prior': shape (c.nens, nfld, nloc)
        """
        data = {}
        # x, y coordinates of the unmasked local state points in this partition
        if len(c.grid.x.shape) == 2:  # regular grid
            ist, ied, di, jst, jed, dj = self.partitions[par_id]
            msk = c.grid.mask[jst:jed:dj, ist:ied:di]
            data['x'] = c.grid.x[jst:jed:dj, ist:ied:di][~msk]
            data['y'] = c.grid.y[jst:jed:dj, ist:ied:di][~msk]
        else:
            inds = self.partitions[par_id]
            msk = c.grid.mask[inds]
            data['x'] = c.grid.x[inds][~msk]
            data['y'] = c.grid.y[inds][~msk]

        # one entry per scalar record, two (components 0 and 1) per vector record
        data['field_ids'] = []
        for rec_id in self.rec_list[c.pid_rec]:
            rec = self.info.fields[rec_id]
            v_list = [0, 1] if rec.is_vector else [None]
            for v in v_list:
                data['field_ids'].append((rec_id, v))

        nfld = len(data['field_ids'])
        nloc = len(data['x'])
        data['t'] = np.full(nfld, np.nan)
        data['z'] = np.zeros((nfld, nloc))
        data['var_id'] = np.full(nfld, 0)
        data['err_type'] = np.full(nfld, 0)
        data['state_prior'] = np.full((c.nens, nfld, nloc), np.nan)
        for n, (rec_id, v) in enumerate(data['field_ids']):
            rec = self.info.fields[rec_id]
            data['t'][n] = t2h(rec.time)
            data['err_type'][n] = self.info.err_types.index(rec.err_type)
            data['var_id'][n] = self.info.variables.index(rec.name)
            for m in range(c.nens):
                # for scalar records v is None: chk[None, :] adds a leading axis that squeeze removes
                # NOTE(review): z is cast to float32 before being averaged into a float64 array; confirm intended
                data['z'][n, :] += np.squeeze(state_z[m, rec_id][par_id][v, :]).astype(np.float32) / c.nens  # ens mean z
                # assignment into the preallocated array copies the values
                data['state_prior'][m, n, :] = np.squeeze(state_prior[m, rec_id][par_id][v, :])
        return data

    def unpack_local_state_data(self, c: Context, par_id: PartitionID, state_prior: StateEns, data: dict) -> None:
        """
        Unpack data['state_prior'] (as laid out by pack_local_state_data) and
        write it back into the state_prior chunks for par_id, in place.
        """
        for m in range(c.nens):
            for n, (rec_id, v) in enumerate(data['field_ids']):
                # for scalar records v is None: chk[None, :] is a view with a leading
                # axis of size 1, so the assignment still writes into the stored chunk
                state_prior[m, rec_id][par_id][v, :] = data['state_prior'][m, n, :]