Source code for NEDAS.grid.grid_2d_base

import os
import inspect
from functools import cached_property
from typing import Optional
from abc import ABC, abstractmethod
import numpy as np
import shapefile
from pyproj import Proj, Geod
import matplotlib
from NEDAS.utils.graphics import draw_line, draw_patch, arrowhead_xy, draw_reference_vector_legend

[docs] class Grid2DBase(ABC): """ Base class to handle 2D fields defined on regular grids or unstructured meshes. Args: proj (pyproj.Proj, custom func, None): Projection function mapping from longitude,latitude to x,y coordinates. If None, a default Mercator projection will be used. x (np.ndarray): X-coordinates for each grid point. y (np.ndarray): Y-coordinates for each grid point. bounds (list, optional): Grid boundary limits, [xmin, xmax, ymin, ymax], all float numbers. If not specified, will use min/max value of the coordinates. regular (bool, optional): Whether grid is regular or unstructured. Default is True (regular grid). cyclic_dim (str, optional): Cyclic dimension(s): ``'x'``, ``'y'``, ``'xy'``, or ``None`` if noncyclic. distance_type (str, optional): Type of distance functions: `cartesian` (default) or `spherical`. dst_grid (GridBase, optional): Destination grid object to convert to. Attributes: proj (pyproj.Proj or custom function): Projection function. proj_name (str): Name of the projection, empty if not available. bounds (list): Grid boundary limits [xmin, xmax, ymin, ymax]. mask (np.ndarray): Mask (bool) for points that are not participating the analysis,same shape as :code:`x`, default is all False. 
""" x: np.ndarray y: np.ndarray x_elem: np.ndarray y_elem: np.ndarray dx: float dy: float Lx: float Ly: float mask: np.ndarray regular: bool cyclic_dim: str|None def __init__(self, proj, x, y, bounds=None, cyclic_dim=None, distance_type='cartesian', dst_grid=None): assert x.shape == y.shape, "x, y shape does not match" if proj is None: self.proj = Proj('+proj=stere') # default projection else: self.proj = proj # name of the projection if hasattr(proj, 'name'): self.proj_name = proj.name else: self.proj_name = '' # proj info, ellps is used in Geod for distance calculation self.proj_ellps = 'WGS84' self.proj_lon0 = 0 self.proj_lat0 = 0 if hasattr(proj, 'definition'): for e in proj.definition.split(): es = e.split('=') if es[0]=='ellps': self.proj_ellps = es[1] if es[0]=='lat_0': self.proj_lat0 = np.float32(es[1]) if es[0]=='lon_0': self.proj_lon0 = np.float32(es[1]) # coordinates and properties of the 2D grid self.x = x self.y = y self.cyclic_dim = cyclic_dim # internally we use -180:180 convention for longitude if self.proj_name == 'longlat': self.x = np.mod(self.x + 180., 360.) - 180. # boundary corners of the grid if bounds is not None: self.xmin, self.xmax, self.ymin, self.ymax = bounds else: self.xmin = np.min(self.x) self.xmax = np.max(self.x) self.ymin = np.min(self.y) self.ymax = np.max(self.y) self.distance_type = distance_type self.mask = np.full(self.x.shape, False) self._dst_grid = None if dst_grid is not None: self.set_destination_grid(dst_grid) def __eq__(self, other): if not isinstance(other, Grid2DBase): return False if self.proj != other.proj: return False if self.x.shape != other.x.shape: return False if not np.allclose(self.x, other.x): return False if not np.allclose(self.y, other.y): return False return True @property def mfx(self): """ Map scaling factors in x direction (mfx), since on the projection plane dx is not exactly the distance on Earth. The mfx is defined as ratio between dx and the actual distance. 
""" if self.proj_name == 'longlat': # long/lat grid doesn't have units in meters, so will not use map factors return np.ones(self.x.shape) else: # map factor: ratio of (dx, dy) to their actual distances on the earth. geod = Geod(ellps=self.proj_ellps) lon, lat = self.proj(self.x, self.y, inverse=True) lon1x, lat1x = self.proj(self.x+self.dx, self.y, inverse=True) _,_,gcdx = geod.inv(lon, lat, lon1x, lat1x) return self.dx / gcdx @property def mfy(self): """ Map scaling factors in y direction (mfy), since on the projection plane dy is not exactly the distance on Earth. The mfy is defined as ratio between dy and the actual distance. """ if self.proj_name == 'longlat': # long/lat grid doesn't have units in meters, so will not use map factors return np.ones(self.x.shape) else: # map factor: ratio of (dx, dy) to their actual distances on the earth. geod = Geod(ellps=self.proj_ellps) lon, lat = self.proj(self.x, self.y, inverse=True) lon1y, lat1y = self.proj(self.x, self.y+self.dy, inverse=True) _,_,gcdy = geod.inv(lon, lat, lon1y, lat1y) return self.dy / gcdy @property def dst_grid(self): """ Destination grid for convert, interp, rotate_vector methods once specified a dst_grid, the setter will compute corresponding rotation_matrix and interp_weights """ return self._dst_grid @dst_grid.setter def dst_grid(self, grid): assert isinstance(grid, Grid2DBase), "dst_grid should be a Grid instance" if grid == self.dst_grid: # the same grid is set before return self._dst_grid = grid # rotation of vector field from self.proj to dst_grid.proj self._set_rotation_matrix() # prepare indices and weights for interpolation # when dst_grid is set, these info are prepared and stored to avoid recalculating # too many times, when applying the same interp to a lot of flds x, y = self._proj_from(grid.x, grid.y) inside, indices, vertices, in_coords, nearest = self.find_index(x, y) self.interp_inside = inside self.interp_indices = indices self.interp_vertices = vertices self.interp_nearest = 
nearest self.interp_weights = self._interp_weights(inside, vertices, in_coords) # prepare indices for coarse-graining x, y = self._proj_to(self.x, self.y) inside, _, _, _, nearest = grid.find_index(x, y) self.coarsen_inside = inside self.coarsen_nearest = nearest if not self.regular: # for irregular mesh, find indices for elements too x, y = self._proj_to(self.x_elem, self.y_elem) inside, _, _, _, nearest = grid.find_index(x, y) self.coarsen_inside_elem = inside self.coarsen_nearest_elem = nearest
def set_destination_grid(self, grid):
    """
    Set the destination grid object to convert to.

    Assigns ``grid`` to ``self.dst_grid``; method form of that assignment.

    Args:
        grid: The destination grid object.
    """
    setattr(self, 'dst_grid', grid)
def _wrap_cyclic_xy(self, x_, y_): """ When interpolating for point x_,y_, if the coordinates falls outside of the domain, we wrap around and make then inside again, if the boundary condition is cyclic (self.cyclic_dim) Args: x_, y_ (np.array): x, y coordinates of a grid Returns: Same as input but x, y values are wrapped to be within the boundaries again. """ if self.cyclic_dim is not None: for d in self.cyclic_dim: if d=='x': x_ = np.mod(x_ - self.xmin, self.Lx) + self.xmin elif d=='y': y_ = np.mod(y_ - self.ymin, self.Ly) + self.ymin return x_, y_
@abstractmethod
def find_index(self, x_, y_) -> tuple[np.ndarray, np.ndarray|None, np.ndarray, np.ndarray, np.ndarray]:
    """
    Find indices of `self.x`, `self.y` corresponding to the given `x_`, `y_`.

    Abstract method: there is no implementation in this base class.

    Args:
        x_ (float or np.ndarray): x-coordinates of target point(s).
        y_ (float or np.ndarray): y-coordinates of target point(s).

    Returns:
        A 5-tuple ``(inside, indices, vertices, in_coords, nearest)``:

        inside (np.ndarray of bool):
            Boolean array of shape `(x_.size,)` indicating whether each `x_`, `y_` point
            lies inside the grid.
        indices (np.ndarray of int or None):
            Indices of grid elements containing the input points.

            - For regular grids, this is `None` since vertices suffice to locate the grid box.
            - For unstructured meshes, these are indices into `tri.triangles`, from `tri_finder`.
        vertices (np.ndarray of int):
            Array of shape `(inside_size, n)`, where `n = 4` for regular grid boxes or `n = 3`
            for mesh triangles. These are indices into `self.x`, `self.y` (flattened) for the
            vertices of the grid element that each point falls in.
        in_coords (np.ndarray of float):
            Array of shape `(inside_size, n)` giving internal coordinates of each point within
            the containing element. Used to compute interpolation weights.
        nearest (np.ndarray of int):
            Array of shape `(inside_size,)` with indices of the grid nodes closest to each point.

    Notes:
        - This function assumes `self.x`, `self.y` define either a regular or triangular grid.
        - Internal coordinates are used for interpolation and vary in dimension based on the
          grid type.
    """
    ...
def _proj_to(self, x, y): """ Transform coordinates from self.proj to dst_grid.proj """ if self.dst_grid is None: raise ValueError("dst_grid is not set, cannot project to dst_grid coordinates") if self.dst_grid.proj != self.proj: lon, lat = self.proj(x, y, inverse=True) x, y = self.dst_grid.proj(lon, lat) x, y = self.dst_grid._wrap_cyclic_xy(x, y) return x, y def _proj_from(self, x, y): """ transform coordinates from dst_grid.proj to self.proj """ if self.dst_grid is None: raise ValueError("dst_grid is not set, cannot project from dst_grid coordinates") if self.dst_grid.proj != self.proj: lon, lat = self.dst_grid.proj(x, y, inverse=True) x, y = self.proj(lon, lat) x, y = self._wrap_cyclic_xy(x, y) return x, y def _set_rotation_matrix(self): """ setting the rotation matrix for converting vector fields from self to dst_grid Note: self.rotate_matrix is rotating two unit vectors with different angles, not the classic rotation matrix """ if self.dst_grid is None: raise ValueError("dst_grid is not set, cannot set rotation matrix") self.rotate_matrix = np.zeros((4,)+self.x.shape) if self.proj != self.dst_grid.proj: # self.x,y corresponding coordinates in dst_proj, call them x,y x, y = self._proj_to(self.x, self.y) # find small increments in x,y due to small changes in self.x,y in dst_proj eps = 0.1 * self.dx # grid spacing is specified in Grid object xu, yu = self._proj_to(self.x + eps, self.y ) # move a bit in x dirn xv, yv = self._proj_to(self.x , self.y + eps) # move a bit in y dirn np.seterr(invalid='ignore') # will get nan at poles due to singularity, fill_pole_void takes care later dxu = xu-x dyu = yu-y dxv = xv-x dyv = yv-y hu = np.hypot(dxu, dyu) hv = np.hypot(dxv, dyv) self.rotate_matrix[0, :] = dxu/hu # rotation of x self.rotate_matrix[2, :] = dyu/hu self.rotate_matrix[1, :] = dxv/hv # rotation of v self.rotate_matrix[3, :] = dyv/hv else: # if no change in proj, we can skip the calculation self.rotate_matrix[0, :] = 1. self.rotate_matrix[1, :] = 0. 
self.rotate_matrix[2, :] = 0. self.rotate_matrix[3, :] = 1.
def rotate_vectors(self, vec_fld):
    """
    Apply ``self.rotate_matrix`` to a vector field.

    Args:
        vec_fld (np.array): The input vector field, shape (2, self.x.shape)

    Returns:
        np.ndarray: The vector field rotated to the dst_grid, as a new array
        stacking the rotated u and v components. The input is not modified.
    """
    rot = self.rotate_matrix
    u_in, v_in = vec_fld[0, :], vec_fld[1, :]
    rotated_u = rot[0, :] * u_in + rot[1, :] * v_in
    rotated_v = rot[2, :] * u_in + rot[3, :] * v_in
    return np.array([rotated_u, rotated_v])
@abstractmethod
def _interp_weights(self, inside, vertices, in_coords) -> np.ndarray:
    """
    Compute interpolation weights (abstract, no implementation in this base class).

    The ``dst_grid`` setter calls this with the ``inside``, ``vertices`` and
    ``in_coords`` outputs of ``find_index`` and stores the returned array as
    ``self.interp_weights``.
    """
    ...
@abstractmethod
def interp(self, fld, x=None, y=None, method='linear') -> np.ndarray:
    """
    Interpolation of 2D field data (fld) from one grid (self) to another (dst_grid or given x,y).

    This can be used for grid refining (low->high resolution) or grid thinning
    (high->low resolution). This also converts between different grid geometries.
    Abstract method: there is no implementation in this base class.

    Args:
        fld (np.array): Input field defined on self, should have same shape as self.x
        x,y (float or np.array): Optional;
            If x,y are specified, the function computes the weights and applies them to fld.
            If x,y are None, the self.dst_grid.x,y are used. Since their interp_weights are
            precalculated by the dst_grid setter, it will be efficient to run interp for
            many different input flds quickly.
        method (str): Interpolation method, can be 'nearest' or 'linear'

    Returns:
        The interpolated field defined on the destination grid
    """
    ...
@abstractmethod
def coarsen(self, fld) -> np.ndarray:
    """
    Coarse-grain a field from self onto the (lower resolution) dst_grid.

    Coarse-graining is sometimes needed when the dst_grid is at lower resolution than self.
    Since many points of self.x,y fall in one dst_grid box/element, it is better to
    average them to represent the field on the low-res grid, instead of interpolating
    only from the nearest points, which would cause representation errors.
    Abstract method: there is no implementation in this base class.

    Args:
        fld (np.array): Input field to perform coarse-graining on, it is defined on self.

    Returns:
        The coarse-grained field defined on self.dst_grid.
    """
    ...
def convert(self, fld, is_vector=False, method='linear', coarse_grain=False):
    """
    Main method to convert from self.proj, x, y to dst_grid coordinate systems.

    Notes:
        1. if is_vector, rotate vectors from self.proj to dst_grid.proj
        2.1 interpolate fld components from self.x,y to dst_grid.x,y
        2.2 if dst_grid is low-res, coarse_grain=True will perform coarse-graining;
            wherever the coarse-grained field is not NaN it replaces the interpolated value
        3. if dst_grid equals self, the input fld is returned as is

    Args:
        fld (np.array): Input field to perform conversion on.
        is_vector (bool, optional): If False (default) the input fld is a scalar field,
            otherwise the input fld is a vector field with first dimension of size 2 (u, v).
        method (str, optional): Interpolation method, 'linear' (default) or 'nearest'
        coarse_grain (bool, optional): If True, the coarse-graining will be applied using
            self.coarsen(). The default is False.

    Returns:
        The converted field defined on the destination grid self.dst_grid.

    Raises:
        ValueError: If no destination grid has been set.
    """
    if self.dst_grid is None:
        raise ValueError("dst_grid not set for convert")

    if not (self.dst_grid != self):
        # already on the destination grid, nothing to do
        return fld

    if not is_vector:
        # scalar field, just interpolate
        fld_out = self.interp(fld, method=method)
        if coarse_grain:
            # coarse-graining if more points fall in one grid
            fld_coarse = self.coarsen(fld)
            ind = ~np.isnan(fld_coarse)
            fld_out[ind] = fld_coarse[ind]
        return fld_out

    assert fld.shape[0] == 2, "vector field should have first dim==2, for u,v component"

    # vector field needs to rotate to dst_grid.proj before interp
    fld = self.rotate_vectors(fld)

    fld_out = np.full((2,)+self.dst_grid.x.shape, np.nan)
    for i in range(2):
        # interp each component: u, v
        fld_out[i, :] = self.interp(fld[i, :], method=method)
        if coarse_grain:
            # coarse-graining if more points fall in one grid
            fld_coarse = self.coarsen(fld[i, :])
            ind = ~np.isnan(fld_coarse)
            fld_out[i, ind] = fld_coarse[ind]
    return fld_out
def distance(self, ref_x, x, ref_y, y, p=2, type='cartesian'):
    """
    Compute distance for points (x,y) to the reference point

    Args:
        ref_x, ref_y (float): reference point x,y coordinates
        x, y (np.array): points whose distance to the reference points will be computed
        p (int, optional): Minkowski p-norm order, 1 or 2; default is 2.
            Only used for the cartesian distance.
        type (str, optional): distance type, 'cartesian' (default) or 'spherical'

    Returns:
        Distances between x,y and the reference point ref_x, ref_y.
        For 'cartesian' the distance is in the units of x,y and takes the shorter way
        around along cyclic dimensions; for 'spherical' it is the great-circle distance
        on a sphere of radius 6371000.0.

    Raises:
        NotImplementedError: If ``p`` is not 1 or 2 (cartesian distance).
        ValueError: If ``type`` is not a known distance type.
    """
    if type == 'cartesian':
        # normal cartesian distances in x and y
        dist_x = np.abs(x - ref_x)
        if self.cyclic_dim is not None and 'x' in self.cyclic_dim:
            # on a cyclic axis the other way around may be shorter
            dist_x = np.minimum(dist_x, self.Lx - dist_x)
        dist_y = np.abs(y - ref_y)
        if self.cyclic_dim is not None and 'y' in self.cyclic_dim:
            dist_y = np.minimum(dist_y, self.Ly - dist_y)
        if p == 1:
            return dist_x + dist_y    # Manhattan distance, order 1
        if p == 2:
            return np.hypot(dist_x, dist_y)    # Euclidean distance, order 2
        raise NotImplementedError(f"grid.distance: p-norm order {p} is not implemented for 2D grid")

    if type == 'spherical':
        # compute spherical distance on Earth instead
        reflon, reflat = self.proj(ref_x, ref_y, inverse=True)
        lon, lat = self.proj(x, y, inverse=True)
        RE = 6371000.0
        invrad = np.pi / 180.
        rlon1 = np.atleast_1d(reflon) * invrad
        rlat1 = np.atleast_1d(reflat) * invrad
        rlon2 = np.atleast_1d(lon) * invrad
        rlat2 = np.atleast_1d(lat) * invrad
        # spherical law of cosines, from m_spherdist.F90 in enkf-topaz;
        # clip guards against round-off pushing cos_d slightly outside [-1, 1]
        cos_d = np.sin(rlat1) * np.sin(rlat2) + np.cos(rlat1) * np.cos(rlat2) * np.cos(rlon1 - rlon2)
        # np.arccos is available in every numpy version (np.acos only exists in numpy>=2.0)
        return RE * np.arccos(np.clip(cos_d, -1, 1))

    raise ValueError(f"unknown distance type '{type}'")
def _collect_shape_data(self, shapes): """ This collects the x,y coordinates from shapes read from .shp files for later plotting filter the points not inside the grid domain """ data = {'xy':[], 'parts':[]} for shape in shapes: if len(shape.points) > 0: xy = [] inside = [] lon, lat = [np.array(x) for x in zip(*shape.points)] x, y = self.proj(lon, lat) inside = np.logical_and(np.logical_and(x >= self.xmin, x <= self.xmax), np.logical_and(y >= self.ymin, y <= self.ymax)) # when showing global maps, the lines leave the domain and re-enter # from the other side, the cross-over lines are visible on the plot # temporary solution: make a pause when lines wrap around cut meridian # lines work fine now but filled patches do not if self.proj_name in ['longlat', 'tripolar', 'bipolar']: x[~inside] = np.nan y[~inside] = np.nan xy = [(x[i], y[i]) for i in range(x.size)] # if any point in the polygon lies inside the grid, need to plot it. if any(inside): data['xy'].append(xy) data['parts'].append(shape.parts) return data
@cached_property
def land_data(self):
    """
    Prepare data to show the land area (cached after the first access).

    The coastlines are read from ``ne_50m_coastline.shp``, looked up in the directory
    of the file that defines ``self.__class__``. The shp file is downloaded
    from https://www.naturalearthdata.com

    Returns:
        dict: Output of ``self._collect_shape_data`` for the coastline shapes.
    """
    path = os.path.dirname(inspect.getfile(self.__class__))
    # read everything up front so the shapefile handles are closed right away
    with shapefile.Reader(os.path.join(path, 'ne_50m_coastline.shp')) as sf:
        shapes: list[shapefile.Shape] = sf.shapes()  # type: ignore

    # Some cosmetic tweaks of the shapefile for some Canadian coastlines:
    # merge neighbouring segments into one shape and empty the merged-in ones
    shapes[1200].points = shapes[1200].points + shapes[1199].points[1:]
    shapes[1199].points = []
    shapes[1230].points = shapes[1230].points + shapes[1229].points[1:] + shapes[1228].points[1:] + shapes[1227].points[1:]
    shapes[1229].points = []
    shapes[1228].points = []
    shapes[1227].points = []
    shapes[1233].points = shapes[1233].points + shapes[1234].points
    shapes[1234].points = []

    return self._collect_shape_data(shapes)
@cached_property
def river_data(self):
    """
    Prepare data to show river features (cached after the first access).

    The rivers are read from ``ne_50m_rivers.shp``, looked up in the directory
    of the file that defines ``self.__class__``.

    Returns:
        dict: Output of ``self._collect_shape_data`` for the river shapes.
    """
    path = os.path.dirname(inspect.getfile(self.__class__))
    # read everything up front so the shapefile handles are closed right away
    with shapefile.Reader(os.path.join(path, 'ne_50m_rivers.shp')) as sf:
        shapes = sf.shapes()
    return self._collect_shape_data(shapes)
@cached_property
def lake_data(self):
    """
    Prepare data to show lake features (cached after the first access).

    The lakes are read from ``ne_50m_lakes.shp``, looked up in the directory
    of the file that defines ``self.__class__``.

    Returns:
        dict: Output of ``self._collect_shape_data`` for the lake shapes.
    """
    path = os.path.dirname(inspect.getfile(self.__class__))
    # read everything up front so the shapefile handles are closed right away
    with shapefile.Reader(os.path.join(path, 'ne_50m_lakes.shp')) as sf:
        shapes = sf.shapes()
    return self._collect_shape_data(shapes)
def llgrid_xy(self, dlon: float, dlat: float):
    """
    Prepare a lon/lat grid to plot as reference lines

    Meridians are drawn every ``dlon`` degrees starting at -180 and parallels every
    ``dlat`` degrees starting at -90, both sampled every 0.1 degree. Points outside the
    grid bounds are set to NaN and lines entirely outside are dropped.
    The spacings are also stored as ``self.dlon`` and ``self.dlat``.

    Args:
        dlon, dlat: spacing of lon/lat grid in degrees

    Returns:
        list: One list of (x, y) tuples per reference line.
    """
    self.dlon = dlon
    self.dlat = dlat

    def clipped_line(lon, lat):
        # project one line and blank out the part outside the grid bounds
        x, y = self.proj(lon, lat)
        within_x = np.logical_and(x >= self.xmin, x <= self.xmax)
        within_y = np.logical_and(y >= self.ymin, y <= self.ymax)
        inside = np.logical_and(within_x, within_y)
        x[~inside] = np.nan
        y[~inside] = np.nan
        return list(zip(x, y)), inside.any()

    lines = []

    # meridians: constant longitude, latitude sampled from pole to pole
    for lon_ref in np.arange(-180, 180, dlon):
        lat = np.arange(-89.9, 90, 0.1)
        lon = np.full(lat.size, lon_ref, dtype=float)
        line, visible = clipped_line(lon, lat)
        if visible:
            lines.append(line)

    # parallels: constant latitude, longitude sampled around the globe
    for lat_ref in np.arange(-90, 90+dlat, dlat):
        lon = np.arange(-180., 180., 0.1)
        lat = np.full(lon.size, lat_ref, dtype=float)
        line, visible = clipped_line(lon, lat)
        if visible:
            lines.append(line)

    return lines
@abstractmethod
def plot_field(self, ax, fld, vmin=None, vmax=None, cmap='viridis', **kwargs):
    """
    Plot a scalar field using pcolor/tripcolor

    Abstract method: there is no implementation in this base class.

    Args:
        ax (matplotlib.pyplot.Axes): Axes handle for plotting
        fld (np.array): The scalar field for plotting
        vmin, vmax (float, optional): The minimum and maximum value range for the colormap,
            if not specified (None) the np.min, np.max of the input fld will be used.
        cmap (matplotlib colormap, or str, optional): Colormap used in the plot,
            default is 'viridis'
        **kwargs: Additional keyword arguments; their use is up to the implementation.
    """
    ...
def plot_vectors(self, ax, vec_fld, V=None, L=None, spacing=0.5, num_steps=10,
                 linecolor='k', linewidth=1,
                 showref=False, ref_xy=(0.95, 0.95), refcolor='w', ref_units='',
                 showhead=True, headwidth=0.1, headlength=0.3):
    """
    Plot vector fields (improved version of matplotlib quiver)

    Trajectories are started on a regular lattice of seed points and integrated forward
    in ``num_steps`` sub-steps through the interpolated velocity field; each trajectory
    is drawn as a line, optionally ending with an arrow head.

    Args:
        ax (matplotlib.pyplot.Axes): Axes handle for plotting
        vec_fld (np.array): The vector field for plotting, shape ``(2,) + self.x.shape``
        V (float, optional): Velocity scale, typical velocity value in vec_fld units.
            If not specified (None) a typical value 0.33*nanmax(abs(vec_fld[0,:])) will be used.
        L (float, optional): Length scale, how long in x,y units do vectors with velocity V
            show in the plot. If not specified (None), 0.05 times the x-extent of the grid
            points (max(self.x) - min(self.x)) will be used.
        spacing (float, optional): Distance between vectors in both directions is given by
            spacing*L. Default is 0.5. This controls the density of vectors in the plot.
            You can provide a tuple (float, float) for spacings in (x, y) if you want them
            to be set differently.
        num_steps (int, optional): Default is 10. If num_steps=1, straight vectors (as in
            quiver) will be displayed. num_steps>1 lets you display curved trajectories,
            at each sub-step the velocity is re-interpolated at the new position along the
            trajectories. As num_steps get larger the trajectories are more detailed.
        linecolor (str or matplotlib color, optional): Line color for the vector lines,
            default is 'k'
        linewidth (float, optional): Line width for the vector lines, default is 1.
        showref (bool, optional): If True, show a legend box with a reference vector (size L)
            inside. Default is False.
        ref_xy (tuple, optional): The x,y relative coordinates (0-1) for the reference vector
            box, default is upper right corner.
        refcolor (str or matplotlib color, optional): Background color for the reference
            vector box, default is 'w' (white).
        ref_units (str, optional): Units to be included in the reference vector box,
            default is ''.
        showhead (bool, optional): If True (default), show the arrow head of the vectors
        headwidth (float, optional): Width of arrow heads relative to L, default is 0.1.
        headlength (float, optional): Length of arrow heads relative to L, default is 0.3.
    """
    assert vec_fld.shape == (2,)+self.x.shape, "vector field shape mismatch with x,y"
    x = self.x
    y = self.y
    u = vec_fld[0,:]
    v = vec_fld[1,:]

    # set typical L, V if not defined
    if V is None:
        V = 0.33 * np.nanmax(np.abs(u))
    if L is None:
        L = 0.05 * (np.max(x) - np.min(x))

    # start trajectories on a regular grid with spacing d
    if isinstance(spacing, tuple):
        d = (spacing[0]*L, spacing[1]*L)
    else:
        d = (spacing*L, spacing*L)
    # time step such that a velocity of V travels the length L over num_steps sub-steps
    dt = L / V / num_steps
    xo, yo = np.mgrid[x.min()+0.5*d[0]:x.max():d[0], y.min()+0.5*d[1]:y.max():d[1]]
    npoints = xo.flatten().shape[0]
    # trajectory positions: one row per seed point, one column per (sub-)step
    xtraj = np.full((npoints, num_steps+1,), np.nan)
    ytraj = np.full((npoints, num_steps+1,), np.nan)
    leng = np.zeros(npoints)
    xtraj[:, 0] = xo.flatten()
    ytraj[:, 0] = yo.flatten()
    for t in range(num_steps):
        # find velocity ut,vt at traj position for step t
        ut = self.interp(u, xtraj[:,t], ytraj[:,t])
        vt = self.interp(v, xtraj[:,t], ytraj[:,t])

        # velocity should be in physical units, to plot the right length on projection
        # we use the map factors to scale distance units
        ut = ut * self.interp(self.mfx, xtraj[:,t], ytraj[:,t])
        vt = vt * self.interp(self.mfy, xtraj[:,t], ytraj[:,t])

        # update traj position
        xtraj[:, t+1] = xtraj[:, t] + ut * dt
        ytraj[:, t+1] = ytraj[:, t] + vt * dt

        # update length (accumulated path length of each trajectory)
        leng = leng + np.hypot(ut, vt) * dt

    # plot the vector lines
    hl = headlength * L
    hw = headwidth * L
    for i in range(xtraj.shape[0]):
        # plot trajectory at one output location
        ax.plot(xtraj[i, :], ytraj[i, :], color=linecolor, linewidth=linewidth, zorder=4)

        # add vector head if traj is long and straight enough:
        # path longer than the head, and shorter than 1.6 times the start-to-end distance
        dist = np.hypot(xtraj[i,0]-xtraj[i,-1], ytraj[i,0]-ytraj[i,-1])
        if showhead and hl < leng[i] < 1.6*dist:
            ax.fill(*arrowhead_xy(xtraj[i,-1], xtraj[i,-2], ytraj[i,-1],ytraj[i,-2], hw, hl), color=linecolor, zorder=5)

    # add reference vector
    if showref:
        xr = self.xmin + L*1.3 + ref_xy[0] * (self.Lx - L*2.6)
        yr = self.ymin + L + ref_xy[1] * (self.Ly - L*1.5)
        draw_reference_vector_legend(ax, xr, yr, V, L, hw, hl, refcolor, linecolor, ref_units)
    self.set_xylim(ax)
def plot_scatter(self, ax, fld, vmin=None, vmax=None, nlevels=20, cmap='viridis',
                 markersize=10, x=None, y=None, L=None, is_vector=False, **kwargs):
    """
    Same as plot_field/vectors, but showing individual scattered points instead.
    This is more suitable for plotting observations in space.

    Args:
        ax (matplotlib.pyplot.Axes): Axes handle for plotting
        fld (np.array): Values at the points; same shape as ``x`` for a scalar field,
            or ``(2,) + x.shape`` for a vector field.
        vmin, vmax (float, optional): Value range; defaults to nanmin/nanmax of fld.
            For vectors, vmax is the velocity that is drawn with length L.
        nlevels (int, optional): Number of discrete color levels, default is 20.
        cmap (matplotlib colormap, or str, optional): Colormap, default is 'viridis'.
        markersize (float, optional): Marker size passed to ``ax.scatter``, default is 10.
        x, y (np.array, optional): Point coordinates, default is self.x, self.y.
        L (float, optional): Length scale for vectors, default is 0.05*self.Lx.
        is_vector (bool, optional): If True, draw vectors instead of colored markers.
        **kwargs: For vectors the keys 'refcolor', 'units', 'ref_xy', 'linecolor',
            'linewidth' and 'showref' are read; for scalars all kwargs are passed
            on to ``ax.scatter``.
    """
    if x is None:
        x = self.x
    if y is None:
        y = self.y
    if vmin is None:
        vmin = np.nanmin(fld)
    if vmax is None:
        vmax = np.nanmax(fld)
    if L is None:
        L = 0.05 * self.Lx
    dv = (vmax - vmin) / nlevels

    if is_vector:
        assert fld.shape[0] == 2, "vector field should have first dim==2"
        assert fld.shape[1:] == x.shape, "vector field shape does not match with grid"
        V = vmax
        hl, hw = 0.3 * L, 0.15 * L
        linecolor = kwargs.get('linecolor', 'k')
        linewidth = kwargs.get('linewidth', 1)

        # start and end point of each vector, stacked along the first axis
        d = fld * L / V
        xtraj, ytraj = np.array([x, x + d[0,...]]), np.array([y, y + d[1,...]])

        for idx in np.ndindex(x.shape):
            # idx is a tuple; build full index tuples so that each point is selected
            # individually, also when x has more than one dimension
            start, end = (0,) + idx, (1,) + idx
            segment = (slice(None),) + idx
            ax.plot(xtraj[segment], ytraj[segment], color=linecolor, linewidth=linewidth, zorder=5)
            dist = np.hypot(xtraj[start]-xtraj[end], ytraj[start]-ytraj[end])
            # only draw the head when the vector is long enough to carry it
            if hl < 1.6*dist:
                ax.fill(*arrowhead_xy(xtraj[end], xtraj[start], ytraj[end], ytraj[start], hw, hl), color=linecolor, zorder=5)

        if kwargs.get('showref', True):
            refcolor = kwargs.get('refcolor', 'w')
            ref_units = kwargs.get('units', '')
            ref_xy = kwargs.get('ref_xy', (0.95, 0.95))
            xr = self.xmin + L*1.3 + ref_xy[0] * (self.Lx - L*2.6)
            yr = self.ymin + L + ref_xy[1] * (self.Ly - L*1.5)
            draw_reference_vector_legend(ax, xr, yr, V, L, hw, hl, refcolor, linecolor, ref_units)

    else:
        assert fld.shape == x.shape
        # skip missing values, clamp the rest to [vmin, vmax]
        msk = ~np.isnan(fld)
        v = np.array(fld[msk])
        vbound = np.maximum(np.minimum(v, vmax), vmin)

        if isinstance(cmap, str):
            cmap = matplotlib.colormaps[cmap]  # type: ignore
        # discrete palette with nlevels+1 RGB colors sampled from the colormap
        palette = np.array([cmap(frac)[0:3] for frac in np.linspace(0, 1, nlevels+1)])

        if dv == 0:
            # constant field: a single color level, avoids 0/0 in the index
            cind = np.zeros(vbound.shape, dtype=int)
        else:
            cind = ((vbound - vmin) / dv).astype(int)

        ax.scatter(x[msk], y[msk], markersize, color=palette[cind], **kwargs)

    self.set_xylim(ax)
def plot_land(self, ax, color=None, linecolor='k', linewidth=1,
              showriver=False, rivercolor='c',
              showgrid=True, dlon=20, dlat=5):
    """
    Shows the map (coastline, rivers, lakes) and lon/lat grid for reference

    Args:
        ax (matplotlib.pyplot.Axes): Axes handle for plotting
        color (matplotlib color, optional): Face color of the landmass polygon,
            default is None (transparent).
        linecolor (matplotlib color, optional): Line color of the coastline,
            default is 'k' (black). If None, the coastline is not drawn.
        linewidth (float, optional): Line width of the coastline, default is 1.
        showriver (bool, optional): If True, show the rivers and lakes over the landmass.
            Default is False.
        rivercolor (matplotlib color, optional): Color of the rivers and lakes,
            default is 'c' (cyan).
        showgrid (bool, optional): If True (default), show the reference lat/lon grid.
        dlon (float, optional): The interval of longitude lines in the reference grid.
            Default is 20 degrees.
        dlat (float, optional): The interval in latitude lines in the reference grid.
            Default is 5 degrees.
    """
    fill_land = color is not None
    outline_land = linecolor is not None

    # plot the coastline to indicate land area
    if fill_land:
        draw_patch(ax, self.land_data, color=color, zorder=3)
    if outline_land:
        draw_line(ax, self.land_data, linecolor=linecolor, linewidth=linewidth, linestyle='-', zorder=8)

    # rivers as lines, lakes as filled patches
    if showriver:
        draw_line(ax, self.river_data, linecolor=rivercolor, linewidth=0.5, linestyle='-', zorder=1)
        draw_patch(ax, self.lake_data, color=rivercolor, zorder=1)

    # add reference lonlat grid on map
    if showgrid:
        grid_style = {'color': 'k', 'linewidth': 0.5, 'linestyle': ':', 'zorder': 4}
        for line in self.llgrid_xy(dlon, dlat):
            ax.plot(*zip(*line), **grid_style)

    self.set_xylim(ax)
def set_xylim(self, ax):
    """
    Set the extent of the plot to the grid bounds.

    Args:
        ax (matplotlib.pyplot.Axes): Axes handle whose x/y limits are set to
            (self.xmin, self.xmax) and (self.ymin, self.ymax).
    """
    x_range = (self.xmin, self.xmax)
    y_range = (self.ymin, self.ymax)
    ax.set_xlim(*x_range)
    ax.set_ylim(*y_range)