from typing import Optional
import copy
import numpy as np
import matplotlib
from NEDAS.grid.grid_2d_base import Grid2DBase
[docs]
class RegularGrid(Grid2DBase):
"""
Regular 2D grid class
Args:
pole_dim (str, optional):
`None` (default); set to :code:`'x'` or :code:`'y'` if that dimension has poles
pole_index (list[int], optional):
`None` (default); list or tuple of the pole indices along `pole_dim` (`0` for the first and `-1` for the last grid line)
neighbors (np.ndarray, optional):
`None` (default), for regular grid with special geometry (e.g. tripolar ocean grid),
`neighbors` stores the j,i indices of the 4 neighbors (east, north, west, and south) of each grid point.
Since specifying neighbors already takes care of cyclic boundary conditions, `cyclic_dim` will be discarded if `neighbors` is set.
"""
nx: int
ny: int
pole_dim: str|None
pole_index: list[int]|None
def __init__(self, proj, x, y, bounds=None, cyclic_dim=None, distance_type='cartesian',
             pole_dim=None, pole_index=None, neighbors=None, dst_grid=None,):
    """
    Set up a regular grid and derive its size and spacing from the coordinates.

    Args:
        proj, x, y, bounds, cyclic_dim, distance_type, dst_grid:
            Forwarded unchanged to the base class constructor.
        pole_dim (str, optional): Stored as `self.pole_dim`.
        pole_index (optional): Stored as `self.pole_index`.
        neighbors (np.ndarray, optional): Stored as `self.neighbors`; when set
            together with a cyclic dimension, the cyclic dimension is discarded.
        dst_grid (optional): Destination grid; when `None` the grid uses itself.
    """
    super().__init__(proj, x, y, bounds, cyclic_dim, distance_type, dst_grid)
    self.regular = True
    self.pole_dim, self.pole_index, self.neighbors = pole_dim, pole_index, neighbors

    # an explicit neighbor table replaces the cyclic boundary handling
    if not (self.neighbors is None or self.cyclic_dim is None):
        print('neighbors already implemented, discarding cyclic_dim')
        self.cyclic_dim = None

    # x is laid out as [ny, nx]
    self.nx, self.ny = self.x.shape[1], self.x.shape[0]

    # grid spacing from the coordinate extent, then the domain lengths
    self.dx = (self.xmax - self.xmin) / (self.nx - 1)
    self.dy = (self.ymax - self.ymin) / (self.ny - 1)
    self.Lx = self.nx * self.dx
    self.Ly = self.ny * self.dy
    self.npoints = self.nx * self.ny

    # reset, then register the destination grid (the grid itself by default)
    self._dst_grid = None
    self.set_destination_grid(self if dst_grid is None else dst_grid)
[docs]
def change_resolution_level(self, nlevel):
    """
    Generate a new grid with changed resolution.

    Args:
        nlevel (int):
            Positive number, downsample grid abs(nlevel) times, each time doubling the grid spacing;
            Negative number, upsample grid abs(nlevel) times, each time halving the grid spacing

    Returns:
        A new grid object with changed resolution. When `nlevel` is 0 the
        grid itself is returned (no copy is made).

    Raises:
        AssertionError: if the new grid would have fewer than 2 points in x or y.

    Note:
        As a side effect the destination grid of `self` is set to the new
        grid (it is used to convert the mask) and is not restored afterwards.
    """
    if nlevel == 0:
        return self

    # create a new grid object with x,y at new resolution level
    # NOTE(review): the destination grid is cleared first, presumably so that
    # it is not deep-copied along with self -- confirm
    self._dst_grid = None
    new_grid = copy.deepcopy(self)

    fac = 2**nlevel
    new_grid.dx = self.dx * fac
    new_grid.dy = self.dy * fac
    new_grid.nx = int(np.round(self.Lx / new_grid.dx))
    new_grid.ny = int(np.round(self.Ly / new_grid.dy))
    assert min(new_grid.nx, new_grid.ny) > 1, "Grid.change_resolution_level: new resolution too low, try smaller nlevel"

    # keep the point count consistent with the new nx, ny (the deep copy
    # would otherwise still carry the count of the original grid)
    new_grid.npoints = new_grid.nx * new_grid.ny

    new_grid.x, new_grid.y = np.meshgrid(
        self.xmin + np.arange(new_grid.nx) * new_grid.dx,
        self.ymin + np.arange(new_grid.ny) * new_grid.dy,
    )

    # coarsen the mask
    self.set_destination_grid(new_grid)
    new_grid.mask = self.convert(self.mask, method='nearest').astype(bool)
    return new_grid
[docs]
def find_index(self, x_, y_):
    """
    Locate query points inside the boxes of the regular grid.

    Args:
        x_, y_ (array-like): coordinates of the query points (flattened internally).

    Returns:
        tuple: `(inside, indices, vertices, in_coords, nearest)`

        - inside (bool array, one entry per query point): whether the point
          falls within the searched coordinate range
        - indices: always `None` for a regular grid
        - vertices (int array, [n_inside, 4]): flattened index `j*nx+i` of
          the box vertices p1..p4 around each inside point
        - in_coords (float array, [n_inside, 2]): fractional position of the
          point inside its box along x and y
        - nearest (int array, [n_inside]): flattened index of the vertex
          closest to each inside point
    """
    qx = np.array(x_).flatten()
    qy = np.array(y_).flatten()

    # lon: pyproj.Proj works only for lon=-180:180
    if self.proj_name == 'longlat':
        qx = np.mod(qx + 180., 360.) - 180.

    # account for cyclic dims: points that drop "outside" wrap around
    qx, qy = self._wrap_cyclic_xy(qx, qy)

    # sort the grid coordinates to monotonically increasing order;
    # xs, ys are the sorted coordinates, icol, jrow their original grid index
    row_x = self.x[0, :]
    col_y = self.y[:, 0]
    icol = np.argsort(row_x)
    xs = row_x[icol]
    jrow = np.argsort(col_y)
    ys = col_y[jrow]

    # pad cyclic dimensions with one extra grid point for the wrap-around
    cyclic_axes = () if self.cyclic_dim is None else self.cyclic_dim
    for axis in cyclic_axes:
        if axis == 'x':
            if xs[0] + self.Lx not in xs:
                xs = np.append(xs, xs[0] + self.Lx)
                icol = np.append(icol, icol[0])
        elif axis == 'y':
            if ys[0] + self.Ly not in ys:
                ys = np.append(ys, ys[0] + self.Ly)
                jrow = np.append(jrow, jrow[0])

    # with a neighbor table the search range grows by one grid box on each side;
    # the padded entries carry out-of-range indices (first-1 and last+1)
    use_neighbors = self.neighbors is not None
    if use_neighbors:
        xs = np.concatenate(([xs[0] - self.dx], xs, [xs[-1] + self.dx]))
        icol = np.concatenate(([icol[0] - 1], icol, [icol[-1] + 1]))
        ys = np.concatenate(([ys[0] - self.dy], ys, [ys[-1] + self.dy]))
        jrow = np.concatenate(([jrow[0] - 1], jrow, [jrow[-1] + 1]))

    # pi, pj: position in the padded arrays just right of each query point;
    # only positions strictly inside the padded range are kept
    pi = np.searchsorted(xs, qx, side='right')
    pj = np.searchsorted(ys, qy, side='right')
    inside = (pi > 0) & (pi < len(xs)) & (pj > 0) & (pj < len(ys))
    pi, pj = pi[inside], pj[inside]

    # vertices (p1, p2, p3, p4) of the rectangular grid box;
    # p3 is the point found by the search index (pj,pi), and the internal
    # coordinates (in_x, in_y) pinpoint the location inside the rectangle
    # with values in the range [0, 1)
    # (pj,pi-1)   p4----+------p3 (pj,pi)
    #              |    |      |
    #              +in_x*------+
    #              |  in_y     |
    # (pj-1,pi-1) p1----+------p2 (pj-1,pi)
    indices = None  # for regular grid, the element indices are not used

    if use_neighbors:
        table = self.neighbors
        east, north, west, south = 0, 1, 2, 3

        def step(direction, jj, ii):
            # j,i index of the neighbor of (jj,ii) in the given direction
            return table[0, direction, jj, ii], table[1, direction, jj, ii]

        def blank():
            return np.zeros(pj.shape, dtype=int)

        j1, i1 = jrow[pj - 1], icol[pi - 1]
        j2, i2 = blank(), blank()
        j3, i3 = blank(), blank()
        j4, i4 = blank(), blank()

        # each mask is evaluated right before its case because j1, i1 are
        # updated in place by the preceding cases

        # p1 is a valid grid point: use it as the anchor
        sel = (j1 >= 0) & (i1 >= 0)
        j2[sel], i2[sel] = step(east, j1[sel], i1[sel])
        j3[sel], i3[sel] = step(north, j2[sel], i2[sel])
        j4[sel], i4[sel] = step(north, j1[sel], i1[sel])

        # p1 is off-grid in i only: p2 is the anchor
        sel = (j1 >= 0) & (i1 < 0)
        j2[sel], i2[sel] = jrow[pj - 1][sel], icol[pi][sel]
        j1[sel], i1[sel] = step(west, j2[sel], i2[sel])
        j3[sel], i3[sel] = step(north, j2[sel], i2[sel])
        j4[sel], i4[sel] = step(west, j3[sel], i3[sel])

        # p1 is off-grid in both j and i: p3 is the anchor
        sel = (j1 < 0) & (i1 < 0)
        j3[sel], i3[sel] = jrow[pj][sel], icol[pi][sel]
        j2[sel], i2[sel] = step(south, j3[sel], i3[sel])
        j4[sel], i4[sel] = step(west, j3[sel], i3[sel])
        j1[sel], i1[sel] = step(west, j2[sel], i2[sel])

        # p1 is off-grid in j only: p4 is the anchor
        sel = (j1 < 0) & (i1 >= 0)
        j4[sel], i4[sel] = jrow[pj][sel], icol[pi - 1][sel]
        j1[sel], i1[sel] = step(south, j4[sel], i4[sel])
        j3[sel], i3[sel] = step(east, j4[sel], i4[sel])
        j2[sel], i2[sel] = step(east, j1[sel], i1[sel])
    else:
        # plain rectangle: vertices come straight from the sorted index arrays
        j_low, j_high = jrow[pj - 1], jrow[pj]
        i_low, i_high = icol[pi - 1], icol[pi]
        j1, i1 = j_low, i_low
        j2, i2 = j_low, i_high
        j3, i3 = j_high, i_high
        j4, i4 = j_high, i_low

    corner_ji = ((j1, i1), (j2, i2), (j3, i3), (j4, i4))

    # flattened grid index of every vertex
    vertices = np.zeros(pi.shape + (4,), dtype=int)
    for k, (jv, iv) in enumerate(corner_ji):
        vertices[:, k] = jv * self.nx + iv

    # internal coordinates inside rectangles
    in_coords = np.zeros(pi.shape + (2,), dtype=np.float64)
    in_coords[:, 0] = (qx[inside] - xs[pi - 1]) / (xs[pi] - xs[pi - 1])
    in_coords[:, 1] = (qy[inside] - ys[pj - 1]) / (ys[pj] - ys[pj - 1])

    # grid point nearest to each query point: pick the vertex of the quadrant
    # the point falls in (quadrants listed in p1, p2, p3, p4 order)
    left, right = in_coords[:, 0] < 0.5, in_coords[:, 0] >= 0.5
    lower, upper = in_coords[:, 1] < 0.5, in_coords[:, 1] >= 0.5
    quadrants = (left & lower, right & lower, right & upper, left & upper)
    j_near = np.zeros(pj.shape, dtype=int)
    i_near = np.zeros(pj.shape, dtype=int)
    for quad, (jv, iv) in zip(quadrants, corner_ji):
        j_near[quad], i_near[quad] = jv[quad], iv[quad]
    nearest = j_near * self.nx + i_near

    return inside, indices, vertices, in_coords, nearest
def _fill_pole_void(self, fld):
    """
    Overwrite the pole lines of a field with the mean of the adjacent line.

    If rotation of vectors (or other reasons) generates nan at the poles,
    the void is filled from the surrounding values for each pole defined by
    `self.pole_dim` and `self.pole_index`. Only pole indices `0` (first
    line) and `-1` (last line) are acted on.

    Args:
        fld (np.ndarray): field to fix; modified in place.

    Returns:
        np.ndarray: the same `fld` object.
    """
    # build the index selecting line k along the pole dimension
    if self.pole_dim == 'x':
        def line(k):
            return (slice(None), k)
    elif self.pole_dim == 'y':
        def line(k):
            return (k, slice(None))
    else:
        return fld

    # (pole line, adjacent interior line) for the first and the last edge
    edge_pairs = ((0, 1), (-1, -2))
    for pole in self.pole_index or []:
        for edge, inner in edge_pairs:
            if pole == edge:
                fld[line(edge)] = np.mean(fld[line(inner)])
    return fld
[docs]
def rotate_vectors(self, vec_fld):
    """
    Rotate a vector field with the base-class implementation, then fill the pole lines.

    Args:
        vec_fld: vector field whose first axis holds the two components.

    Returns:
        The rotated field, with both components passed through `_fill_pole_void`.
    """
    rotated = super().rotate_vectors(vec_fld)
    # patch each of the two vector components separately
    for comp in (0, 1):
        rotated[comp, ...] = self._fill_pole_void(rotated[comp, ...])
    return rotated
[docs]
def get_corners(self, fld):
    """
    Given fld defined on a regular grid, obtain its value on the 4 corners/vertices
    of every grid cell.

    Args:
        fld (np.ndarray): 2D field with the same shape as `self.x`.

    Returns:
        np.ndarray: array of shape `fld.shape + (4,)`; for cell `[a, b]` the
        last axis holds the node values at `[a, b]`, `[a, b+1]`, `[a+1, b+1]`
        and `[a+1, b]` of the (one larger in each dimension) node array.

    Raises:
        AssertionError: if `fld.shape` differs from `self.x.shape`.
    """
    assert fld.shape == self.x.shape, "fld shape does not match x,y"
    n0, n1 = fld.shape

    def extrapolate(a, b, c):
        # 2nd-order polynomial extrapolation from three consecutive lines
        return 3*a - 3*b + c

    # node values: one more node than cells along each dimension
    nodes = np.zeros((n0 + 1, n1 + 1))

    # interior nodes: average of the four surrounding cells
    nodes[1:n0, 1:n1] = 0.25*(fld[1:, 1:] + fld[1:, :-1] + fld[:-1, 1:] + fld[:-1, :-1])

    # border nodes: extrapolate along the first axis, then along the second
    # (the second pass also fills the four corner nodes)
    nodes[0, :] = extrapolate(nodes[1, :], nodes[2, :], nodes[3, :])
    nodes[n0, :] = extrapolate(nodes[n0-1, :], nodes[n0-2, :], nodes[n0-3, :])
    nodes[:, 0] = extrapolate(nodes[:, 1], nodes[:, 2], nodes[:, 3])
    nodes[:, n1] = extrapolate(nodes[:, n1-1], nodes[:, n1-2], nodes[:, n1-3])

    # gather the four nodes around each cell into a new trailing dimension
    return np.stack(
        (nodes[:n0, :n1], nodes[:n0, 1:], nodes[1:, 1:], nodes[1:, :n1]),
        axis=-1,
    )
def _interp_weights(self, inside, vertices, in_coords):
    """
    Compute bilinear interpolation weights from the outputs of find_index.

    The weights (summing to 1 for each point) are given to the four box
    vertices according to the fractional position of the point inside the
    box, as specified by `in_coords`.

    Args:
        inside: from the output of self.find_index (not used here)
        vertices: from the output of self.find_index; only its shape is used
        in_coords: from the output of self.find_index

    Returns:
        interp_weights (np.array): interpolation weights, same shape as `vertices`
    """
    u, v = in_coords[:, 0], in_coords[:, 1]
    one_minus_u, one_minus_v = 1 - u, 1 - v

    # one weight column per vertex, in p1, p2, p3, p4 order
    columns = (
        one_minus_u * one_minus_v,
        u * one_minus_v,
        u * v,
        one_minus_u * v,
    )
    interp_weights = np.zeros(vertices.shape)
    for k, column in enumerate(columns):
        interp_weights[:, k] = column
    return interp_weights
[docs]
def interp(self, fld, x=None, y=None, method='linear'):
    """
    Interpolate a field from this grid to target points.

    Args:
        fld (np.ndarray): field with the same shape as `self.x`.
        x, y (array-like, optional): target coordinates. If either is `None`,
            the precomputed lookup for `self.dst_grid` is used and the result
            is shaped like `self.dst_grid.x`.
        method (str): `'linear'` (default, weighted sum over the four box
            vertices) or `'nearest'` (value at the nearest grid point).

    Returns:
        np.ndarray: interpolated values shaped like the target `x`; points
        not inside the grid are NaN.

    Raises:
        ValueError: if `self.dst_grid` is None, or `fld` does not match the grid shape.
        NotImplementedError: for any other `method`.
    """
    if self.dst_grid is None:
        raise ValueError("dst_grid not set for interpolation")

    if x is None or y is None:
        # reuse the lookup precalculated for self.dst_grid
        inside, indices, vertices, nearest, weights = (
            self.interp_inside,
            self.interp_indices,
            self.interp_vertices,
            self.interp_nearest,
            self.interp_weights,
        )
        x = self.dst_grid.x
    else:
        # build the lookup on the fly for the requested points
        inside, indices, vertices, in_coords, nearest = self.find_index(x, y)
        weights = self._interp_weights(inside, vertices, in_coords)

    target = np.array(x)
    fld_interp = np.full(target.size, np.nan)

    if fld.shape != self.x.shape:
        raise ValueError(f"field shape {fld.shape} does not match grid shape {self.x.shape}")

    flat_fld = fld.flatten()
    if method == 'nearest':
        # value at the grid point closest to each target point
        fld_interp[inside] = flat_fld[nearest]
    elif method == 'linear':
        # weighted sum over the four vertices of the enclosing box
        fld_interp[inside] = np.einsum('nj,nj->n', flat_fld[vertices], weights)
    else:
        raise NotImplementedError(f"interp method {method} is not yet available")

    return fld_interp.reshape(target.shape)
# #utility function for coarse-graining (high->low resolution)
[docs]
def coarsen(self, fld):
    """
    Coarse-grain a field (high -> low resolution) onto `self.dst_grid`.

    Every non-NaN source point selected by `self.coarsen_inside` is added to
    the destination point given by `self.coarsen_nearest`, and the sums are
    turned into averages.

    Args:
        fld (np.ndarray): field with the same shape as `self.x`.

    Returns:
        np.ndarray: averaged field shaped like `self.dst_grid.x`; destination
        points that received fewer than two values are NaN.

    Raises:
        ValueError: if `self.dst_grid` is None, or `fld` does not match the grid shape.
    """
    if self.dst_grid is None:
        raise ValueError("dst_grid not set for coarse-graining")
    if fld.shape != self.x.shape:
        raise ValueError(f"field shape {fld.shape} does not match grid shape {self.x.shape}")

    # precomputed lookup: which source points fall in dst_grid and where
    inside = self.coarsen_inside
    nearest = self.coarsen_nearest

    n_dst = self.dst_grid.x.flatten().shape
    totals = np.zeros(n_dst)
    count = np.zeros(n_dst)

    # accumulate the non-nan source values onto their destination point
    values = fld.flatten()[inside]
    keep = ~np.isnan(values)
    np.add.at(totals, nearest[keep], values[keep])
    np.add.at(count, nearest[keep], 1)

    # average where more than one value landed; a single nearby point is
    # not coarse-grained and stays nan
    enough = count > 1
    fld_coarse = np.full(n_dst, np.nan)
    fld_coarse[enough] = totals[enough] / count[enough]
    return fld_coarse.reshape(self.dst_grid.x.shape)
[docs]
def plot_field(self, ax, fld, vmin=None, vmax=None, cmap='viridis', **kwargs):
    """
    Draw a 2D field on the grid with `ax.pcolor`.

    Args:
        ax: axes object providing `pcolor`.
        fld (np.ndarray): field to plot.
        vmin, vmax (optional): color range; default to the nan-ignoring
            minimum / maximum of `fld`.
        cmap (optional): colormap; a string is looked up in `matplotlib.colormaps`.
        **kwargs: passed on to `ax.pcolor`.

    Returns:
        The object returned by `ax.pcolor`.
    """
    vmin = np.nanmin(fld) if vmin is None else vmin
    vmax = np.nanmax(fld) if vmax is None else vmax

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps[cmap]  # type: ignore

    xx, yy = self.x, self.y

    # in case of lon convention 0:360, reorder columns so that x is monotonic
    if self.proj_name == 'longlat':
        order = np.argsort(xx[0, :])
        xx = np.take(xx, order, axis=1)
        fld = np.take(fld, order, axis=1)

    mesh = ax.pcolor(xx, yy, fld, vmin=vmin, vmax=vmax, cmap=cmap, **kwargs)
    self.set_xylim(ax)
    return mesh