"""
QGModel: Python implementation of the multi-layer quasi-geostrophic spectral model.
Translates qg_driver.f90 / qg_run_tools.f90 / qg_init_tools.f90.
Physics
-------
PV equation (spectral):
    ∂q/∂t + J(ψ, q) + β·∂ψ/∂x + (mean-flow terms) = (dissipation) + (forcing)
PV-streamfunction relation:
q = -\\|k\\|²ψ + S·ψ (S = tridiagonal stratification operator, \\|k\\| = wavenumber magnitude)
Time integration: leapfrog with Robert (Asselin) filter, adaptive timestep.
Array conventions
-----------------
Spectral: shape (nz, nky, nkx) — nky = kmax+1, nkx = 2*kmax+1
Physical: shape (nz, ny, nx) — nx = ny = 2*(kmax+1)
Wavenumber grids: shape (nky, nkx)
Usage
-----
::
    m = QGModel(kmax=63, nz=4, F=50.0, beta=1.5, ...)
    m.initialize()
    for _ in range(1000):
        m.step()
"""
from __future__ import annotations
import numpy as np
from dataclasses import dataclass, field
from typing import Any, Optional
from .spectral import (
setup_spectral_grid, make_filter,
spec2grid_cc, grid2spec, spec2grid,
ir_prod, ir_pwr, jacob,
)
from .strat import (
get_psiq, get_vmodes, get_layer_depths, get_tripint,
layer2mode, mode2layer,
)
from .numerics import march, ran, ran_reset, ring_integral
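# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------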
@dataclass
class QGModel:
    """Multi-layer QG spectral model (Python / NumPy / JAX).

    Parameters mirror the Fortran input namelist (qg_params.f90).

    Construct the model with the namelist-style parameters below, call
    :meth:`initialize` to build the spectral grid and the prognostic state,
    then advance it with :meth:`step` or :meth:`run`.  Every attribute whose
    name starts with an underscore, as well as ``q``, ``q_o``, ``rhs``,
    ``psi`` and ``psi_o``, is ``None`` until :meth:`initialize` has run.
    """

    # Resolution
    kmax: int = 63
    nz: int = 1
    # Fundamental scales
    beta: float = 1.5
    F: float = 0.0  # f²L²/[(2π)²g'H₀]
    Fe: float = 0.0  # free-surface parameter
    # Mean flow
    uscale: float = 0.0
    vscale: float = 0.0
    # Stratification
    strat_type: str = 'linear'
    deltc: float = 0.2  # thermocline thickness (exp/stc profile)
    surface_bc: str = 'rigid_lid'
    # Time stepping
    dt: float = 0.0
    dt_max: float = 0.0  # hard ceiling on adapt_dt; 0 = uncapped
    adapt_dt: bool = True
    dt_tune: float = 1.5
    dt_step: int = 10
    robert: float = 0.01
    # Filters / de-aliasing
    filter_type: str = 'hyperviscous'
    filter_exp: float = 8.0
    k_cut: float = 0.0
    dealiasing: str = 'isotropic'
    filt_tune: float = 1.0
    # Dissipation
    bot_drag: float = 0.0
    top_drag: float = 0.0
    therm_drag: float = 0.0
    quad_drag: float = 0.0
    qd_angle: float = 0.0
    # Markovian forcing
    use_forcing: bool = False
    norm_forcing: bool = False
    forc_coef: float = 0.0
    forc_corr: float = 0.0
    kf_min: float = 0.0
    kf_max: float = 0.0
    # Topography
    use_topo: bool = False
    # Tracer
    use_tracer: bool = False
    # Control
    linear: bool = False
    idum: int = -7
    pi: float = field(default=np.pi, init=False, repr=False)

    # ---- Internal state (populated by initialize()) ----
    _g: Optional[dict[str, Any]] = field(default=None, init=False, repr=False)
    _psiq: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _filt: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _ubar: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _vbar: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _qbarx: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _qbary: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _shearu: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _shearv: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _dz: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _rho: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _hb: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _force_o: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    _toposhift: Optional[np.ndarray] = field(default=None, init=False, repr=False)

    # Prognostic fields
    q: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    q_o: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    rhs: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    psi: Optional[np.ndarray] = field(default=None, init=False, repr=False)
    psi_o: Optional[np.ndarray] = field(default=None, init=False, repr=False)

    # Counters
    cntr: int = field(default=0, init=False)
    time: float = field(default=0.0, init=False)
    _call_q: int = field(default=0, init=False)

    # ---------------------------------------------------------------
    # Internal guards
    # ---------------------------------------------------------------
    def _require_grid(self, caller: str) -> dict:
        """Return the spectral grid dict, raising if the model is uninitialised.

        Raises
        ------
        RuntimeError
            If :meth:`initialize` has not been called yet.  A real exception
            is used instead of ``assert`` so the check survives ``python -O``.
        """
        if self._g is None:
            raise RuntimeError(f'call initialize() before {caller}()')
        return self._g

    # ---------------------------------------------------------------
    # Initialisation
    # ---------------------------------------------------------------
    def initialize(self, psi_init=None, dz=None, rho=None,
                   ubar=None, vbar=None, hb=None) -> None:
        """Set up spectral grid, stratification, and initial fields.

        Parameters
        ----------
        psi_init : ndarray (nz, nky, nkx) complex, optional
            Initial streamfunction in spectral space.  Defaults to zeros.
        dz : (nz,) layer thicknesses (should sum to 1), optional
            Defaults to ``nz`` equal layers.
        rho : (nz,) layer densities, optional
            Defaults to ones.
        ubar, vbar : (nz,) mean zonal / meridional velocity profiles, optional
            Default to zeros.
        hb : (nky, nkx) complex bottom topography in spectral space, optional

        Notes
        -----
        Resets the random-number state via ``ran_reset(idum)`` and the
        ``_call_q`` counter.  ``cntr`` and ``time`` are left untouched.  When
        ``adapt_dt`` is set and ``dt`` is still 0, an initial timestep is
        computed from the initial state.
        """
        g = setup_spectral_grid(self.kmax)
        self._g = g
        nz = self.nz
        nky = int(g['nky'])
        nkx = int(g['nkx'])

        # Default vertical grid: equal-thickness layers of unit density.
        if dz is None:
            dz = np.ones(nz) / nz
        if rho is None:
            rho = np.ones(nz)
        self._dz = np.asarray(dz, dtype=np.float64)
        self._rho = np.asarray(rho, dtype=np.float64)

        # Stratification operator (only meaningful with more than one layer).
        drho = np.diff(self._rho)
        if nz > 1 and len(drho) > 0:
            self._psiq = get_psiq(self._dz, drho, self.F, self.Fe, self.surface_bc)
        else:
            self._psiq = None

        # Mean flow
        self._ubar = (np.zeros(nz) if ubar is None
                      else np.asarray(ubar, dtype=np.float64))
        self._vbar = (np.zeros(nz) if vbar is None
                      else np.asarray(vbar, dtype=np.float64))
        self._shearu = np.zeros(nz)
        self._shearv = np.zeros(nz)
        if nz > 1 and self._psiq is not None:
            # shearu = tri2mat(psiq) * ubar  (Fortran qg_driver line 61)
            psiq_mat = _build_trimat(self._psiq, nz)
            self._shearu = psiq_mat @ self._ubar
            self._shearv = psiq_mat @ self._vbar

        # PV gradient from mean flow + beta
        self._qbarx = self._shearv[:nz]
        self._qbary = -self._shearu[:nz] + self.beta

        # Filter
        self._filt = make_filter(
            g,
            filter_type=self.filter_type,
            filter_exp=self.filter_exp,
            k_cut=self.k_cut if self.k_cut > 0 else None,
            dealiasing=self.dealiasing,
            filt_tune=self.filt_tune,
        )

        # Topography: stored even when use_topo is False; it is only applied
        # in _get_rhs() when use_topo is set.
        self._hb = hb
        if hb is not None:
            kx_ = g['kx_']
            ky_ = g['ky_']
            self._toposhift = (
                -1j * self._ubar[nz - 1] * (kx_ * hb)
                - 1j * self._vbar[nz - 1] * (ky_ * hb)
            )
        else:
            self._toposhift = None

        # Forcing state
        if self.use_forcing:
            self._force_o = np.zeros((nky, nkx), dtype=complex)
        else:
            self._force_o = None
        ran_reset(self.idum)

        # Initialise prognostic fields
        if psi_init is not None:
            psi_arr = np.asarray(psi_init, dtype=complex)
        else:
            psi_arr = np.zeros((nz, nky, nkx), dtype=complex)
        self.psi = psi_arr
        self.psi_o = psi_arr.copy()
        q_arr = self.get_pv(psi_arr)
        self.q = q_arr
        self.q_o = q_arr.copy()
        self.rhs = self._get_rhs()

        if self.adapt_dt and self.dt == 0.0:
            self._update_dt()
        self._call_q = 0

    # ---------------------------------------------------------------
    # PV operations
    # ---------------------------------------------------------------
    def get_pv(self, psi):
        """Compute PV q from streamfunction ψ (spectral).

        q = -\\|k\\|²ψ + S·ψ

        Translates Fortran Get_pv in qg_run_tools.f90.

        Parameters
        ----------
        psi : (nz, nky, nkx) or (nv, nky, nkx) for surf_buoy
            Only the first ``nz`` levels enter the single-layer branch and
            the ``-|k|²ψ`` term.

        Returns
        -------
        ndarray of shape (nz, nky, nkx)

        Raises
        ------
        RuntimeError
            If :meth:`initialize` has not been called.
        """
        g = self._require_grid('get_pv')
        nz = self.nz
        ksqd_ = g['ksqd_']  # (nky, nkx)
        psiq = self._psiq   # (nv, 3): sub-, main-, super-diagonal coefficients

        if nz > 1 and psiq is not None:
            # Vertical stretching term S·ψ, applied row by row.
            stretch = np.zeros((nz,) + ksqd_.shape, dtype=complex)
            # The level coupled to the top row through psiq[0, 0] is the
            # bottom level for a periodic column and the top level otherwise.
            top = nz - 1 if self.surface_bc == 'periodic' else 0
            stretch[0] = (psiq[0, 0] * psi[top]
                          + psiq[0, 1] * psi[0]
                          + psiq[0, 2] * psi[1])
            # interior
            for iz in range(1, nz - 1):
                stretch[iz] = (psiq[iz, 0] * psi[iz - 1]
                               + psiq[iz, 1] * psi[iz]
                               + psiq[iz, 2] * psi[iz + 1])
            # bottom; the psi[0] term is the periodic wrap-around and relies
            # on psiq[nz-1, 2] being 0 for rigid_lid.
            stretch[nz - 1] = (psiq[nz - 1, 0] * psi[nz - 2]
                               + psiq[nz - 1, 1] * psi[nz - 1]
                               + psiq[nz - 1, 2] * psi[0])
        else:
            # Slice to nz levels so the shape always matches the -|k|²ψ term.
            stretch = -self.F * psi[:nz]

        return -ksqd_[np.newaxis, :, :] * psi[:nz] + stretch

    def invert_pv(self):
        """Invert PV q → ψ by solving the tridiagonal system per wavenumber.

        Translates Fortran Invert_pv in qg_run_tools.f90.

        Returns
        -------
        ndarray of shape (nz, nky, nkx), complex.  Wavenumbers that are not
        addressed by the grid's ``lin2ky``/``lin2kx`` index lists (or, for
        ``nz == 1``, the k = 0 mode) are left at zero.

        Raises
        ------
        RuntimeError
            If :meth:`initialize` has not been called.
        """
        g = self._require_grid('invert_pv')
        assert self.q is not None
        nz = self.nz
        ksqd_ = g['ksqd_']
        psi = np.zeros((nz,) + ksqd_.shape, dtype=complex)

        if nz == 1:
            # Barotropic: psi = -q / (F + k²), with the k = 0 mode pinned to 0.
            # A masked divide never evaluates the k = 0 entry, so F = 0 does
            # not trigger a divide-by-zero warning.
            denom = self.F + ksqd_
            psi[0] = np.divide(
                -self.q[0], denom,
                out=np.zeros(ksqd_.shape, dtype=complex),
                where=(ksqd_ != 0),
            )
            return psi

        # Only the multi-layer path needs the tridiagonal solvers.
        from .numerics import tridiag_vec, tridiag_cyc_vec

        psiq = self._psiq
        assert psiq is not None
        lin2kx = g['lin2kx']
        lin2ky = g['lin2ky']

        # Per-wavenumber main diagonal: psiq[:, 1] - k²  -> (nmask, nz)
        ksqd_flat = ksqd_[lin2ky, lin2kx]                       # (nmask,)
        diag = psiq[:nz, 1][np.newaxis, :] - ksqd_flat[:, np.newaxis]
        sub = psiq[1:nz, 0]                                     # (nz-1,)
        sup = psiq[0:nz - 1, 2]                                 # (nz-1,)

        # RHS: q at the indexed wavenumbers, shape (nmask, nz)
        q_flat = self.q[:, lin2ky, lin2kx].T

        if self.surface_bc == 'periodic':
            corner_lo = psiq[0, 0]
            corner_hi = psiq[nz - 1, 2]
            psi_flat = tridiag_cyc_vec(q_flat, sub, diag, sup, corner_lo, corner_hi)
        else:
            psi_flat = tridiag_vec(q_flat, sub, diag, sup)

        # Scatter all layers back onto the (nky, nkx) grid at once.
        psi[:, lin2ky, lin2kx] = psi_flat.T
        return psi

    # ---------------------------------------------------------------
    # RHS (advection + dissipation + forcing)
    # ---------------------------------------------------------------
    def _get_rhs(self):
        """Compute the right-hand side of the PV equation.

        Translates Fortran Get_rhs subroutine in qg_driver.f90.

        The result is the sum of (when enabled) nonlinear advection,
        quadratic bottom drag, mean-flow / beta terms, topographic phase
        shift, Ekman drag, Markovian forcing and thermal drag, multiplied by
        the spectral filter.

        Returns
        -------
        rhs of shape (nz, nky, nkx).
        """
        assert self._g is not None
        assert self._filt is not None
        assert self._dz is not None
        assert self.psi is not None and self.q is not None and self.psi_o is not None
        assert self._qbarx is not None and self._qbary is not None
        assert self._ubar is not None and self._vbar is not None

        g = self._g
        nz = self.nz
        kx_ = g['kx_']
        ky_ = g['ky_']
        ksqd_ = g['ksqd_']
        filt = self._filt  # (nky, nkx)
        rhs = np.zeros((nz,) + kx_.shape, dtype=complex)

        if not self.linear:
            # Velocities and PV gradients in physical space (staggered-packed)
            ug = spec2grid_cc(-1j * ky_[np.newaxis] * self.psi, g)
            vg = spec2grid_cc(1j * kx_[np.newaxis] * self.psi, g)
            q_work = self.q.copy()
            if self.use_topo and self._hb is not None:
                q_work[nz - 1] += self._hb
            qxg = spec2grid_cc(1j * kx_[np.newaxis] * q_work, g)
            qyg = spec2grid_cc(1j * ky_[np.newaxis] * q_work, g)

            # Advection: -J(ψ, q) = -(u·∇q) = -(ug*dq/dx + vg*dq/dy)
            rhs = -grid2spec(ir_prod(ug, qxg) + ir_prod(vg, qyg), g)

            # Quadratic bottom drag (needs the grid velocities, hence only
            # available in the nonlinear branch).
            if self.quad_drag != 0.0:
                ub = ug[nz - 1]
                vb = vg[nz - 1]
                sin_a = np.sin(self.qd_angle)
                cos_a = np.cos(self.qd_angle)
                speed = ir_pwr(ir_pwr(ub, 2.0) + ir_pwr(vb, 2.0), 0.5)
                # Products of two grid fields go through ir_prod, exactly as
                # the advection term above does; a plain complex `*` is only
                # used for real scalar factors.
                # NOTE(review): assumes ir_prod accepts 2-D slices like
                # ir_pwr does here — confirm against spectral.py.
                qdrag = self.quad_drag * filt * (
                    1j * kx_ * grid2spec(
                        ir_prod(speed, ub * sin_a + vb * cos_a), g)
                    - 1j * ky_ * grid2spec(
                        ir_prod(speed, ub * cos_a - vb * sin_a), g)
                )
                rhs[nz - 1] -= qdrag

        # Mean-flow / beta-plane linear terms
        #   rhs -= ik_y*(-qbarx*ψ + vbar*q) + ik_x*(qbary*ψ + ubar*q)
        if np.any(self._qbarx != 0):
            rhs -= 1j * (ky_[np.newaxis] * (
                -self._qbarx[:, np.newaxis, np.newaxis] * self.psi
                + self._vbar[:, np.newaxis, np.newaxis] * self.q
            ))
        if np.any(self._qbary != 0):
            rhs -= 1j * (kx_[np.newaxis] * (
                self._qbary[:, np.newaxis, np.newaxis] * self.psi
                + self._ubar[:, np.newaxis, np.newaxis] * self.q
            ))

        # Topographic phase shifting
        if self.use_topo and self._toposhift is not None:
            rhs[nz - 1] += self._toposhift

        # Bottom / top Ekman drag.
        # Use current psi (not time-lagged psi_o) so the drag acts as a
        # semi-implicit term: ψ_{n+1} = ψ_{n-1}/(1+2*dt*drag) — unconditionally
        # stable, avoiding the leapfrog computational-mode growth that occurs
        # when drag*k²*dt >> 1 near the de-aliasing cutoff.
        if self.bot_drag != 0.0:
            rhs[nz - 1] += self.bot_drag * ksqd_ * self.psi[nz - 1]
        if self.top_drag != 0.0:
            rhs[0] += self.top_drag * ksqd_ * self.psi[0]

        # Markovian stochastic forcing on top layer
        if self.use_forcing:
            rhs[0] += self._markovian(
                self.kf_min, self.kf_max, self.forc_coef, self.forc_corr,
                self.norm_forcing, self.psi[0]
            )

        # Thermal drag (uses the time-lagged streamfunction psi_o, except in
        # the free-surface branch).
        if self.therm_drag != 0.0 and self.F != 0.0:
            if nz == 1:
                rhs += self.therm_drag * self.F * self.psi_o
            elif self.Fe != 0.0:
                rhs -= self.therm_drag * (self.q + ksqd_[np.newaxis] * self.psi)
            else:
                # Only the two uppermost layers are coupled here.
                rhs[0] -= (self.therm_drag * self.F
                           * (self.psi_o[1] - self.psi_o[0]) / self._dz[0])
                rhs[1] -= (self.therm_drag * self.F
                           * (self.psi_o[0] - self.psi_o[1]) / self._dz[1])

        return filt[np.newaxis] * rhs

    # ---------------------------------------------------------------
    # Markovian forcing
    # ---------------------------------------------------------------
    def _markovian(self, kf_min, kf_max, amp, lam, normalize, field):
        """Random Markovian (red-noise) forcing in spectral space.

        Translates Fortran Markovian in qg_run_tools.f90.

        Parameters
        ----------
        kf_min, kf_max : float
            Forcing band; modes with kf_min² < k² <= kf_max² are forced.
        amp : float
            Amplitude of the random increment.
        lam : float
            Memory coefficient; the previous forcing is weighted by ``lam``
            and the new noise by ``sqrt(1 - lam²)``.
        normalize : bool
            If true, rescale the returned forcing by
            ``gamma = -2·Σ(conj(field)·forcing)/amp`` whenever |gamma| > 1e-5.
        field : (nky, nkx) complex
            Spectral field used for the normalisation.

        Returns
        -------
        forcing array (nky, nkx).  The un-normalised forcing is stored in
        ``_force_o`` for the next call.
        """
        assert self._g is not None
        assert self._force_o is not None
        g = self._g
        ksqd_ = g['ksqd_']
        nkx = int(g['nkx'])
        nky = int(g['nky'])

        mask = (ksqd_ > kf_min**2) & (ksqd_ <= kf_max**2)
        noise_phase = 2.0 * np.pi * ran(nkx, nky).T  # (nky, nkx)
        noise = amp * np.sqrt(1.0 - lam**2) * np.exp(1j * noise_phase)

        frc_o = np.where(mask, lam * self._force_o + noise, self._force_o)
        forc = frc_o.copy()
        # amp == 0 would make gamma 0/0; skipping normalisation then yields
        # the same (unchanged) forcing without a divide-by-zero warning.
        if normalize and amp != 0:
            gamma = -2.0 * np.sum(np.conj(field) * forc) / amp
            if abs(gamma) > 1e-5:
                forc = forc / gamma

        self._force_o = frc_o
        return forc

    # ---------------------------------------------------------------
    # Adaptive timestep
    # ---------------------------------------------------------------
    def _update_dt(self) -> None:
        """Set dt based on CFL condition.

        Translates Fortran adapt_dt logic in qg_driver.f90:

            dt = dt_tune * 2π / (kmax * sqrt(max(zsq(psi), beta, 1)))

        where zsq = 2 * sum(dz * k⁴ * |ψ|²) (twice the enstrophy).  zsq is
        taken as 0 when the model state is not available yet.  A positive
        ``dt_max`` caps the result.
        """
        if self.psi is not None and self._g is not None and self._dz is not None:
            ksqd_ = self._g['ksqd_']
            zsq = 2.0 * float(
                np.sum(self._dz[:, np.newaxis, np.newaxis]
                       * ksqd_[np.newaxis] ** 2
                       * np.abs(self.psi) ** 2)
            )
        else:
            zsq = 0.0

        denom = np.sqrt(max(zsq, self.beta, 1.0))
        self.dt = self.dt_tune * 2.0 * self.pi / (self.kmax * denom)
        if self.dt_max > 0.0:
            self.dt = min(self.dt, self.dt_max)

    # ---------------------------------------------------------------
    # Time integration
    # ---------------------------------------------------------------
    def step(self) -> None:
        """Advance the model one time step.

        Translates the main time loop body in qg_driver.f90: optionally
        refresh ``dt`` (every ``dt_step`` steps, including the first), march
        PV forward, filter it, invert for ψ, and recompute the RHS.

        Raises
        ------
        RuntimeError
            If :meth:`initialize` has not been called.
        """
        self._require_grid('step')
        assert self.q is not None and self.q_o is not None
        assert self.rhs is not None and self._filt is not None
        assert self.psi is not None

        # Adaptive timestep update (cntr == 0 satisfies the modulo test too).
        if self.adapt_dt and self.cntr % self.dt_step == 0:
            self._update_dt()

        # March PV forward
        q_new, q_o_new, self._call_q = march(
            self.q, self.q_o, self.rhs, self.dt, self.robert, self._call_q
        )
        self.q = self._filt[np.newaxis] * q_new
        self.q_o = q_o_new

        # Save psi for the time-lagged (thermal) drag, then invert
        self.psi_o = self.psi.copy()
        self.psi = self.invert_pv()

        # New RHS
        self.rhs = self._get_rhs()
        self.time += self.dt
        self.cntr += 1

    def run(self, n_steps) -> None:
        """Run the model for n_steps time steps."""
        for _ in range(n_steps):
            self.step()

    # ---------------------------------------------------------------
    # Convenience: physical-space fields
    # ---------------------------------------------------------------
    def get_psi_grid(self):
        """Return streamfunction in physical space, shape (nz, ny, nx)."""
        g = self._require_grid('get_psi_grid')
        assert self.psi is not None
        return spec2grid(self.psi, g)

    def get_q_grid(self):
        """Return PV in physical space, shape (nz, ny, nx)."""
        g = self._require_grid('get_q_grid')
        assert self.q is not None
        return spec2grid(self.q, g)

    def get_u_grid(self):
        """Return zonal velocity (u = -∂ψ/∂y), shape (nz, ny, nx)."""
        g = self._require_grid('get_u_grid')
        assert self.psi is not None
        return spec2grid(-1j * g['ky_'][np.newaxis] * self.psi, g)

    def get_v_grid(self):
        """Return meridional velocity (v = ∂ψ/∂x), shape (nz, ny, nx)."""
        g = self._require_grid('get_v_grid')
        assert self.psi is not None
        return spec2grid(1j * g['kx_'][np.newaxis] * self.psi, g)

    @property
    def nx(self):
        """Number of physical grid points in x; 2*(kmax+1) before initialize()."""
        return self._g['nx'] if self._g else 2 * (self.kmax + 1)

    @property
    def ny(self):
        """Number of physical grid points in y (same as ``nx``)."""
        return self.nx

    @property
    def nkx(self):
        """Number of zonal wavenumbers; 2*kmax+1 before initialize()."""
        return self._g['nkx'] if self._g else 2 * self.kmax + 1

    @property
    def nky(self):
        """Number of meridional wavenumbers; kmax+1 before initialize()."""
        return self._g['nky'] if self._g else self.kmax + 1
# ---------------------------------------------------------------------------
# Helper referenced in strat.py (avoids circular import)
# ---------------------------------------------------------------------------
def _build_trimat(psiq, nz):
    """Assemble the dense nz×nz matrix described by tridiagonal coefficients.

    Column 1 of ``psiq`` supplies the main diagonal, column 0 (rows 1..nz-1)
    the sub-diagonal and column 2 (rows 0..nz-2) the super-diagonal.  Corner
    entries psiq[0, 0] and psiq[nz-1, 2] are not placed in the matrix.
    """
    main = np.diag(psiq[:nz, 1])
    if nz <= 1:
        # A single level has no off-diagonal neighbours.
        return main
    lower = np.diag(psiq[1:nz, 0], k=-1)
    upper = np.diag(psiq[:nz - 1, 2], k=1)
    return main + lower + upper