Source code for NEDAS.models.wrf.restart_io

import sys
import os
import argparse
import threading
import numpy as np
from netCDF4 import Dataset
from NEDAS.utils.parallel import distribute_tasks, Comm

# Per-axis metadata for decoding WRF chunk files, in (i, j, k) order:
# (unstaggered dim name, staggered dim name, global grid-size attribute,
#  unstaggered patch (start, end) attributes, staggered patch (start, end) attributes,
#  error-message prefix if the axis is mandatory, or None if it may be absent)
_WRF_AXES = (
    ('west_east', 'west_east_stag', 'WEST-EAST_GRID_DIMENSION',
     ('WEST-EAST_PATCH_START_UNSTAG', 'WEST-EAST_PATCH_END_UNSTAG'),
     ('WEST-EAST_PATCH_START_STAG', 'WEST-EAST_PATCH_END_STAG'),
     "west_east dimension not found in "),
    ('south_north', 'south_north_stag', 'SOUTH-NORTH_GRID_DIMENSION',
     ('SOUTH-NORTH_PATCH_START_UNSTAG', 'SOUTH-NORTH_PATCH_END_UNSTAG'),
     ('SOUTH-NORTH_PATCH_START_STAG', 'SOUTH-NORTH_PATCH_END_STAG'),
     "south_north dimension not found in "),
    ('bottom_top', 'bottom_top_stag', 'BOTTOM-TOP_GRID_DIMENSION',
     ('BOTTOM-TOP_PATCH_START_UNSTAG', 'BOTTOM-TOP_PATCH_END_UNSTAG'),
     ('BOTTOM-TOP_PATCH_START_STAG', 'BOTTOM-TOP_PATCH_END_STAG'),
     None),
)


def _axis_extent(f, dims, axis, vname):
    """Return ``(global length, dimension name)`` of one axis of ``vname``.

    The grid-size attribute is treated as the staggered length; an
    unstaggered dimension is one shorter. An optional axis that is absent
    yields ``(1, 'surface')``; a mandatory one raises ValueError.
    """
    unstag, stag, size_attr, _unstag_attrs, _stag_attrs, missing_msg = axis
    if unstag in dims:
        return f.getncattr(size_attr) - 1, unstag
    if stag in dims:
        return f.getncattr(size_attr), stag
    if missing_msg is None:
        return 1, 'surface'
    raise ValueError(missing_msg+vname)


def _axis_patch(f, dims, axis, vname):
    """Return the ``(start, end)`` slice bounds of this chunk along one axis.

    The start attribute is shifted by -1 and the end attribute is used as-is,
    turning the attribute pair into Python slice bounds. An optional axis that
    is absent yields ``(0, 1)``; a mandatory one raises ValueError.
    """
    unstag, stag, _size_attr, unstag_attrs, stag_attrs, missing_msg = axis
    if unstag in dims:
        start_attr, end_attr = unstag_attrs
    elif stag in dims:
        start_attr, end_attr = stag_attrs
    elif missing_msg is None:
        return 0, 1
    else:
        raise ValueError(missing_msg+vname)
    return f.getncattr(start_attr) - 1, f.getncattr(end_attr)


def read_chunks(filename, chk_list, var_list):
    """Read the local chunks of a WRF restart file split across processors.

    Parameters
    ----------
    filename : str
        Prefix of the chunk files; chunk ``c`` is ``filename + '_%04d' % c``.
        Global dimensions are always taken from ``filename + '_0000'``.
    chk_list : iterable of int
        Chunk ids to be read by this processor.
    var_list : list of str
        Variable names to read.

    Returns
    -------
    var_dims : dict
        ``vname -> (ni, nj, nk, dim_i, dim_j, dim_k)``; a variable with no
        vertical dimension gets ``nk = 1`` and ``dim_k = 'surface'``.
    chunks : dict
        ``vname -> {chk_id: (i1, i2, j1, j2, k1, k2, data)}`` where ``data`` is
        ``f[vname][0, ...]`` (index 0 of the leading axis, presumably Time).

    Raises
    ------
    ValueError
        If a variable lacks a west_east or south_north dimension.
    """
    var_dims = {}
    chunks = {}

    # global grid sizes and dimension names come from the first chunk file
    with Dataset(filename+'_0000', 'r') as f:
        for vname in var_list:
            dims = f[vname].dimensions
            extents = [_axis_extent(f, dims, axis, vname) for axis in _WRF_AXES]
            (ni, dim_i), (nj, dim_j), (nk, dim_k) = extents
            var_dims[vname] = (ni, nj, nk, dim_i, dim_j, dim_k)
            chunks[vname] = {}

    # each processor reads only its own subset of chunk files
    for chk_id in chk_list:
        with Dataset(filename+f'_{chk_id:04d}', 'r') as f:
            for vname in var_list:
                dims = f[vname].dimensions
                bounds = []
                for axis in _WRF_AXES:
                    bounds.extend(_axis_patch(f, dims, axis, vname))
                chunks[vname][chk_id] = (*bounds, f[vname][0, ...])

    return var_dims, chunks
def write_chunks(filename, chk_list, var_list, chunks):
    """Write chunk data back into existing per-processor restart files.

    Parameters
    ----------
    filename : str
        Prefix of the chunk files; chunk ``c`` is ``filename + '_%04d' % c``.
        The files are opened in ``'r+'`` mode, so they must already exist.
    chk_list : iterable of int
        Chunk ids handled by this processor.
    var_list : list of str
        Variables to write.
    chunks : dict
        ``vname -> {chk_id: (i1, i2, j1, j2, k1, k2, data)}``; only ``data`` is
        written, into index 0 of the variable's leading axis.
    """
    for chk_id in chk_list:
        chunk_path = filename+f'_{chk_id:04d}'
        with Dataset(chunk_path, 'r+') as nc:
            for vname in var_list:
                (_i1, _i2, _j1, _j2, _k1, _k2, data) = chunks[vname][chk_id]
                nc[vname][0, ...] = data
def transpose_chunks_to_fields(comm, chk_list_pid, var_list_pid, var_dims, chunks):
    """Exchange chunks so that each processor assembles its own full fields.

    On entry every processor holds, for all variables, the chunks it read
    itself. On exit each processor holds the assembled arrays for the
    variables assigned to it in ``var_list_pid``.

    The exchange runs one variable slot ``v`` at a time, using message tag
    ``v``. Within a slot a processor:

    1. sends to lower ranks;
    2. receives from every rank, starting with itself and wrapping around;
    3. sends to higher ranks.

    This ordering lets the blocking send/recv pairs complete. A rank owning
    fewer than the maximum number of variables still acts as a sender in the
    slots where it owns nothing.

    Parameters
    ----------
    comm : communicator
        Must provide ``Get_rank``, ``Get_size``, ``send(obj, dest=, tag=)``
        and ``recv(source=, tag=)``.
    chk_list_pid : dict
        ``rank -> list of chunk ids`` read by that rank.
    var_list_pid : dict
        ``rank -> list of variable names`` assigned to that rank.
    var_dims : dict
        ``vname -> (ni, nj, nk, dim_i, dim_j, dim_k)``.
    chunks : dict
        ``vname -> {chk_id: (i1, i2, j1, j2, k1, k2, data)}`` for the local chunks.

    Returns
    -------
    dict
        ``vname -> ndarray`` of shape ``(nk, nj, ni)`` for each variable
        assigned to this rank. Cells not covered by any chunk remain NaN.
    """
    fields = {}
    pid = comm.Get_rank()
    nproc = comm.Get_size()
    my_vars = var_list_pid[pid]

    def send_slot(v, dst_pids):
        """Send this rank's chunks of the variable each destination owns in slot v."""
        for dst_pid in dst_pids:
            if v < len(var_list_pid[dst_pid]):
                dst_vname = var_list_pid[dst_pid][v]
                comm.send(chunks[dst_vname], dest=dst_pid, tag=v)

    nv_max = max(len(lst) for lst in var_list_pid.values())
    for v in range(nv_max):
        # Ranks that own no variable in this slot must still send to the others.
        # Asserting here would crash them and leave receivers blocked.
        send_slot(v, range(pid))

        if v < len(my_vars):
            vname = my_vars[v]
            ni, nj, nk, _, _, _ = var_dims[vname]
            field = np.full((nk, nj, ni), np.nan)
            # receive order: self first, then pid+1 ... nproc-1, 0 ... pid-1
            for offset in range(nproc):
                src_pid = (pid + offset) % nproc
                if src_pid == pid:
                    chks = chunks[vname]
                else:
                    chks = comm.recv(source=src_pid, tag=v)
                # unpack the source's chunks into the full field
                for chk_id in chk_list_pid[src_pid]:
                    i1, i2, j1, j2, k1, k2, chk = chks[chk_id]
                    field[k1:k2, j1:j2, i1:i2] = chk
            fields[vname] = field

        send_slot(v, range(pid + 1, nproc))

    return fields
def transpose_fields_to_chunks():
    """Redistribute full fields back into per-processor chunks (not implemented).

    This is the inverse of ``transpose_chunks_to_fields`` and has not been
    written yet. It raises instead of silently returning ``None``, which a
    caller could mistake for a result.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("transpose_fields_to_chunks is not implemented yet")
def read_fields_bin(comm, filename, var_list, var_dims):
    """Read joined fields from a flat binary file (not implemented).

    No binary on-disk layout (variable order, dtype, byte order) is defined in
    this module yet. Rather than returning an undefined name and failing with
    an unrelated NameError, this raises explicitly.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("read_fields_bin: binary restart format is not implemented yet")
def write_fields_bin(comm, filename, var_list, var_dims, fields):
    """Write joined fields to a flat binary file (not implemented).

    No binary on-disk layout is defined yet. This raises before touching the
    file system or the communicator, so an existing ``filename`` is never
    truncated into an empty file that looks like valid output.

    TODO: define the layout (variable order, dtype, byte order), have rank 0
    create/size the file, then write each rank's fields at their offsets with
    seek()/write() after a barrier.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("write_fields_bin: binary restart format is not implemented yet")
def read_fields_nc(comm, filename, var_list, var_dims):
    """Read joined fields from per-variable netCDF files.

    Each variable ``vname`` is read from ``filename + '_' + vname + '.nc'``,
    the same naming that ``write_fields_nc`` uses.

    Parameters
    ----------
    comm : communicator
        Unused; kept for interface symmetry with the other readers/writers.
    filename : str
        Prefix of the per-variable files.
    var_list : list of str
        Variables to read (e.g. those assigned to this processor).
    var_dims : dict
        ``vname -> (ni, nj, nk, dim_i, dim_j, dim_k)``, used to validate shapes.

    Returns
    -------
    dict
        ``vname -> array`` of shape ``(nk, nj, ni)``.

    Raises
    ------
    ValueError
        If a stored variable's shape does not match ``var_dims``.
    """
    fields = {}
    for vname in var_list:
        ni, nj, nk, _, _, _ = var_dims[vname]
        with Dataset(filename+'_'+vname+'.nc', 'r') as f:
            # slicing copies the data into memory, so it stays valid after close
            data = f[vname][...]
        if tuple(data.shape) != (nk, nj, ni):
            raise ValueError(f"{vname}: expected shape {(nk, nj, ni)}, found {tuple(data.shape)}")
        fields[vname] = data
    return fields
def write_fields_nc(comm, filename, var_list, var_dims, fields):
    """Write each joined field into its own netCDF file.

    For every ``vname`` in ``var_list`` a file ``filename + '_' + vname + '.nc'``
    is created (overwriting any existing one). It holds the dimensions named
    in ``var_dims[vname]`` and a float variable ``vname`` laid out as
    ``(dim_k, dim_j, dim_i)``. ``comm`` is accepted but not used.
    """
    for vname in var_list:
        with Dataset(filename+'_'+vname+'.nc', 'w') as f:
            ni, nj, nk, dim_i, dim_j, dim_k = var_dims[vname]
            for dim_name, dim_size in ((dim_i, ni), (dim_j, nj), (dim_k, nk)):
                if dim_name not in f.dimensions:
                    f.createDimension(dim_name, dim_size)
            f.createVariable(vname, float, (dim_k, dim_j, dim_i))
            f[vname][...] = fields[vname]
def _build_parser():
    """Build the command-line parser for the restart join/split tool."""
    parser = argparse.ArgumentParser(description='Program to join chunks of wrf restart files into per-variable netCDF files, or reverse (split, not implemented yet)')
    parser.add_argument('mode', choices=['join', 'split'], help='operation mode: join or split')
    parser.add_argument('filename', help='wrf restart file name: wrfrst_d01_<time>')
    # nargs='?' is needed for a positional argument's default to take effect
    parser.add_argument('nchk', nargs='?', default=1, type=int, help='number of processors (chunks)')
    return parser


def main(argv=None):
    """Join chunked WRF restart files ``<filename>_NNNN`` in parallel.

    Chunks are read by their assigned ranks, exchanged so that each rank
    holds full fields for its assigned variables, and written through
    ``write_fields_nc`` with the prefix ``<dirname(filename)>/restart``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; defaults to ``sys.argv[1:]``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.mode == 'split':
        # fail loudly instead of exiting successfully without doing anything
        parser.error("mode 'split' is not implemented yet")
    if args.nchk < 1:
        parser.error('nchk must be a positive integer')

    comm = Comm()
    pid = comm.Get_rank()

    chk_list = np.arange(args.nchk)
    chk_list_pid = distribute_tasks(comm, chk_list)

    var_list = ['U_1', 'V_1', 'W_1', 'T', 'PH_1', 'P', 'MU_1', 'QVAPOR', 'QCLOUD']
    var_list_pid = distribute_tasks(comm, var_list)

    full_file = os.path.join(os.path.dirname(args.filename), 'restart')

    var_dims, chunks = read_chunks(args.filename, chk_list_pid[pid], var_list)
    fields = transpose_chunks_to_fields(comm, chk_list_pid, var_list_pid, var_dims, chunks)
    write_fields_nc(comm, full_file, var_list_pid[pid], var_dims, fields)


if __name__ == '__main__':
    main()