"""
Spectral grid setup and FFT transforms for the QG model.
Faithfully translates the Fortran transform_tools.f90 staggered-grid scheme.
Conventions
-----------
Spectral arrays: shape (..., nky, nkx) where nkx=2*kmax+1, nky=kmax+1
Physical arrays: shape (..., nx, ny) where nx=ny=2*(kmax+1)
kx index: kxv[j] = j - kmax in [-kmax, kmax]
ky index: kyv[j] = j in [0, kmax]
Staggered-grid dealiasing
--------------------------
Following Vallis (c1991) / Smith (c1998), each spec->phys transform packs
the straight physical field into Re(output) and the half-grid-shifted
(staggered) field into Im(output). Products are computed separately on
the two grids (ir_prod) to avoid aliasing, then transformed back.
FFT conventions (from fft_fftw3_1pe.f90)
-----------------------------------------
Fortran fft(x, dirn=-1) -> FFTW_BACKWARD, scale=1 -> np.fft.ifft2(x)*N²
Fortran fft(x, dirn=+1) -> FFTW_FORWARD, scale=1/N² -> np.fft.fft2(x)/N²
"""
from typing import Any
import numpy as np
try:
import jax.numpy as jnp # type: ignore[import-untyped]
_JAX = True
except ImportError:
jnp = np
_JAX = False
# ---------------------------------------------------------------------------
# Grid setup
# ---------------------------------------------------------------------------
[docs]
def setup_spectral_grid(kmax: int) -> dict[str, Any]:
    """Build the wavenumber grids and index tables used by the transforms.

    Parameters
    ----------
    kmax : int
        Largest resolved wavenumber; kx runs over [-kmax, kmax] and ky over
        [0, kmax].

    Returns
    -------
    dict
        kmax, nx, ny, nkx, nky : sizes, with nx = ny = 2*(kmax+1),
            nkx = 2*kmax+1 and nky = kmax+1
        kxv, kyv : 1-D float wavenumber vectors
        ``kx_``, ``ky_``, ``ksqd_`` : 2-D grids of shape (nky, nkx); the
            (kx=0, ky=0) entry of ``ksqd_`` is set to 0.1 rather than 0
        kxup, kyup : integer indices into the nx-by-ny FFT array for the
            stored half-plane
        kxdn, kydn : integer indices for the mirrored (conjugate) half-plane
        exx : staggered-grid phase factor exp(i*pi*(kx+ky)/nx), (nky, nkx)
        sgn : checkerboard sign (-1)^(m+n), shape (nx, ny)
        filter_mask : 1 for retained modes, 0 otherwise, shape (nky, nkx)
        lin2kx, lin2ky : int32 column/row indices of the retained modes in
            row-major order
        nmask : number of retained modes
    """
    nx = ny = 2 * (kmax + 1)
    nkx, nky = 2 * kmax + 1, kmax + 1

    # Integer wavenumbers drive the index tables; float copies drive the math.
    kx_int = np.arange(-kmax, kmax + 1, dtype=int)
    ky_int = np.arange(0, kmax + 1, dtype=int)
    kxv = kx_int.astype(np.float64)    # (nkx,)
    kyv = ky_int.astype(np.float64)    # (nky,)

    # Default 'xy' meshgrid indexing yields arrays of shape (nky, nkx).
    kx_, ky_ = np.meshgrid(kxv, kyv)
    ksqd_ = kx_**2 + ky_**2
    ksqd_[0, kmax] = 0.1    # keep 1/ksqd_ finite at (kx=0, ky=0)

    # Zero-based positions of each mode inside the full nx-by-ny FFT array
    # (the Fortran original uses the 1-based offset kmax + 2).
    offset = kmax + 1
    kxup = offset + kx_int      # values in [1, nkx]
    kyup = offset + ky_int      # values in [kmax+1, nx-1]
    kxdn = offset - kx_int      # values in [1, nkx]
    kydn = offset - ky_int      # values in [1, kmax+1]

    # Half-grid phase shift: exp(i*pi*(kx+ky)/nx), shape (nky, nkx).
    exx = np.exp(1j * np.pi * (kx_ + ky_) / nx)

    # Checkerboard of +1 / -1 according to the parity of m + n.
    parity = (np.arange(nx)[:, np.newaxis] + np.arange(ny)[np.newaxis, :]) % 2
    sgn = np.where(parity == 0, 1.0, -1.0)

    # Isotropic de-aliasing: drop modes with ksqd_ >= (8/9)*(kmax+1)^2, and
    # drop ky = 0, kx <= 0, which is recovered from conjugate symmetry.
    cutoff_sq = (8.0 / 9.0) * (kmax + 1) ** 2
    filter_mask = np.where(ksqd_ >= cutoff_sq, 0.0, 1.0)
    filter_mask[0, :kmax + 1] = 0.0

    # Row-major list of the retained (ky, kx) index pairs.
    ky_idx, kx_idx = np.nonzero(filter_mask > 0)
    lin2ky = ky_idx.astype(np.int32)
    lin2kx = kx_idx.astype(np.int32)
    nmask = int(ky_idx.size)

    return {
        'kmax': kmax, 'nx': nx, 'ny': ny, 'nkx': nkx, 'nky': nky,
        'kxv': kxv, 'kyv': kyv, 'kx_': kx_, 'ky_': ky_, 'ksqd_': ksqd_,
        'kxup': kxup, 'kyup': kyup, 'kxdn': kxdn, 'kydn': kydn,
        'exx': exx, 'sgn': sgn,
        'filter_mask': filter_mask,
        'lin2kx': lin2kx, 'lin2ky': lin2ky, 'nmask': nmask,
    }
# ---------------------------------------------------------------------------
# Filter construction
# ---------------------------------------------------------------------------
[docs]
def make_filter(g, filter_type='hyperviscous', filter_exp=8.0, k_cut=None,
                dealiasing='isotropic', filt_tune=1.0):
    """Build the full spectral filter (de-aliasing + small-scale damping).

    Parameters match the Fortran Init_filter in qg_init_tools.f90.

    Parameters
    ----------
    g : dict
        Grid dictionary; the keys read here are 'kmax', 'nx', 'nky', 'nkx',
        'ksqd_', 'kx_' and 'ky_'.
    filter_type : str
        'hyperviscous' or 'exp_cutoff'. Any other value leaves the filter at
        1 inside the de-aliased region.
    filter_exp : float
        Exponent applied to the scaled wavenumber in either damping form.
    k_cut : float or None
        Wavenumber below which 'exp_cutoff' applies no damping. Required for
        'exp_cutoff'. It must be smaller than the de-aliasing radius,
        otherwise the decay scale (kmax_da - k_cut) is zero or negative.
    dealiasing : str
        'isotropic' or 'orszag'. Any other value applies no de-aliasing
        truncation.
    filt_tune : float
        Multiplier on the hyperviscous damping term.

    Returns
    -------
    numpy.ndarray
        Real array of shape (nky, nkx). Entries are 0 for removed modes
        (ky = 0 with kx <= 0, and de-aliased modes) and in (0, 1] otherwise.

    Raises
    ------
    ValueError
        If filter_type is 'exp_cutoff' and k_cut is None.
    """
    kmax = g['kmax']
    nx = g['nx']
    ksqd_ = g['ksqd_']
    kx_ = g['kx_']
    ky_ = g['ky_']

    filt = np.ones((g['nky'], g['nkx']), dtype=np.float64)
    filt[0, :kmax + 1] = 0.0    # conjugate-symmetry side

    if dealiasing == 'isotropic':
        filt[ksqd_ >= (8.0 / 9.0) * (kmax + 1) ** 2] = 0.0
        kmax_da = np.sqrt(8.0 / 9.0) * (kmax + 1)
    elif dealiasing == 'orszag':
        filt[(np.abs(kx_) + np.abs(ky_)) >= (4.0 / 3.0) * (kmax + 1)] = 0.0
        kmax_da = np.sqrt(8.0 / 9.0) * (kmax + 1)
    else:    # 'none'
        kmax_da = kmax

    # Only modes that survived the truncation above may be given a non-zero
    # weight; everything else must stay exactly 0.
    active = filt > 0.0

    if filter_type == 'hyperviscous':
        filt[active] = 1.0 / (
            1.0 + filt_tune * (4 * np.pi / nx)
            * (ksqd_[active] / kmax_da**2) ** filter_exp
        )
    elif filter_type == 'exp_cutoff':
        if k_cut is None:
            raise ValueError('exp_cutoff filter requires k_cut')
        # Active modes below the cutoff keep their initial weight of 1.
        # Restricting the damping to active modes at or above the cutoff
        # keeps the zeroed entries at 0, and avoids raising a negative base
        # to a fractional power.
        damped = active & ~(ksqd_ < k_cut**2)
        # Clamp at 0 so rounding in sqrt cannot produce a tiny negative base.
        excess = np.maximum(np.sqrt(ksqd_[damped]) - k_cut, 0.0)
        filt[damped] = np.exp(-(excess / (kmax_da - k_cut)) ** filter_exp)
    # 'none': keep filt = 1 inside de-aliasing region

    return filt
# ---------------------------------------------------------------------------
# Spectral <-> physical transforms
# ---------------------------------------------------------------------------
[docs]
def spec2grid_cc(wf, g):
    """Spectral to physical transform with staggered-grid packing.

    Translates Fortran Spec2grid_cc2/3 from transform_tools.f90.

    Parameters
    ----------
    wf : complex array, shape (..., nky, nkx)
        Spectral field. It is not modified: the NumPy path works on a copy
        and the JAX path uses functional updates. When JAX is available the
        input is first converted with ``jnp.asarray``, so NumPy arrays are
        accepted in that mode as well.
    g : dict
        Grid dictionary; the keys read here are 'kmax', 'nx', 'ny', 'kxup',
        'kyup', 'kxdn', 'kydn', 'exx' and 'sgn'.

    Returns
    -------
    complex array, shape (..., nx, ny)
        Real part = physical field on the straight grid, imaginary part =
        physical field on the staggered (half-shifted) grid.
    """
    kmax = g['kmax']
    nx = g['nx']
    ny = g['ny']
    exx = g['exx']
    sgn = g['sgn']
    xp = jnp if _JAX else np

    # The JAX path needs an array exposing `.at`; a plain ndarray does not.
    wavefield = xp.asarray(wf) if _JAX else wf.copy()
    batch = wavefield.shape[:-2]

    # Enforce Hermitian symmetry along ky=0 for kx < 0.
    # Fortran: wavefield(1:kmax+1, 1) = conjg(wavefield(nkx:kmax+1:-1, 1))
    # 0-based: entries 0..kmax are set from entries 2*kmax..kmax. Slicing
    # `kmax:` and then reversing selects exactly those entries and, unlike
    # `2*kmax:kmax-1:-1`, is not empty when kmax == 0 (stop would be -1).
    ky0 = wavefield[..., 0, :]
    mirrored = xp.conj(ky0[..., kmax:][..., ::-1])
    if _JAX:
        ky0 = ky0.at[..., :kmax + 1].set(mirrored)  # type: ignore[union-attr]
        wavefield = wavefield.at[..., 0, :].set(ky0)  # type: ignore[union-attr]
    else:
        wavefield[..., 0, :kmax + 1] = mirrored

    # Sparse nx-by-ny array that receives the spectral coefficients.
    physfield = xp.zeros(batch + (nx, ny), dtype=np.complex128)

    # Open-mesh insertion indices, broadcasting to (nkx, nky).
    ix_up, iy_up = xp.ix_(g['kxup'], g['kyup'])
    ix_dn, iy_dn = xp.ix_(g['kxdn'], g['kydn'])

    # exx and wf are laid out (nky, nkx); insertion needs (nkx, nky).
    wf_t = wavefield.swapaxes(-2, -1)    # (..., nkx, nky)
    shifted = exx.T * wf_t               # coefficients for the staggered grid
    plus_ = wf_t + 1j * shifted
    minus_ = xp.conj(wf_t - 1j * shifted)

    # Write the "up" half-plane first: the "dn" indices overlap it on the
    # ky=0 line, so the order of the two writes is kept as in the original.
    if _JAX:
        physfield = physfield.at[..., ix_up, iy_up].set(plus_)  # type: ignore[union-attr]
        physfield = physfield.at[..., ix_dn, iy_dn].set(minus_)  # type: ignore[union-attr]
    else:
        physfield[..., ix_up, iy_up] = plus_
        physfield[..., ix_dn, iy_dn] = minus_

    # Fortran fft(x, -1) = FFTW_BACKWARD (no normalization)
    #                    = np.fft.ifft2(x) * N²
    N2 = nx * ny
    physfield = xp.fft.ifft2(physfield, axes=(-2, -1)) * N2
    return sgn * physfield
[docs]
def grid2spec(pf, g):
    """Transform a staggered-packed physical field to spectral space.

    Translates Fortran Grid2spec_cc2/3 from transform_tools.f90.

    Parameters
    ----------
    pf : complex array, shape (..., nx, ny)
        Packed physical field.
    g : dict
        Grid dictionary; the keys read here are 'nx', 'ny', 'sgn', 'exx',
        'kxup', 'kyup', 'kxdn' and 'kydn'.

    Returns
    -------
    complex array, shape (..., nky, nkx)
    """
    xp = jnp if _JAX else np
    n_points = g['nx'] * g['ny']

    # Fortran fft(x, +1) = FFTW_FORWARD / N² = np.fft.fft2(x) / N²
    spectrum = xp.fft.fft2(g['sgn'] * pf, axes=(-2, -1)) / n_points

    # Pull out the stored half-plane and its mirrored counterpart; both
    # selections have shape (..., nkx, nky).
    rows_up, cols_up = np.ix_(g['kxup'], g['kyup'])
    rows_dn, cols_dn = np.ix_(g['kxdn'], g['kydn'])
    upper = spectrum[..., rows_up, cols_up]
    lower = spectrum[..., rows_dn, cols_dn]

    # Recombine the two selections. The second group of terms is multiplied
    # by conj(exx), the conjugate of the staggered-grid phase factor.
    total = upper + lower
    unshifted = np.real(total) + 1j * np.imag(upper - lower)
    shifted = np.imag(total) + 1j * np.real(-upper + lower)
    phase = np.conj(g['exx'].T)    # (nkx, nky)
    wavefield = 0.25 * (unshifted + shifted * phase)

    # Back to the (..., nky, nkx) spectral layout.
    return wavefield.swapaxes(-2, -1)
[docs]
def spec2grid(wf, g):
    """Transform a spectral field to the straight physical grid only.

    Parameters
    ----------
    wf : complex array, shape (..., nky, nkx)
    g : dict
        Grid dictionary, passed through to ``spec2grid_cc``.

    Returns
    -------
    real array, shape (..., nx, ny)
        Real part of the packed ``spec2grid_cc`` output; the staggered-grid
        field held in the imaginary part is dropped.
    """
    packed = spec2grid_cc(wf, g)
    return np.real(packed)
# ---------------------------------------------------------------------------
# Physical-space products (de-aliased via staggered grid)
# ---------------------------------------------------------------------------
[docs]
def ir_prod(f, g_field):
    """Multiply two staggered-packed physical fields grid by grid.

    Parameters
    ----------
    f, g_field : complex arrays, shape (..., nx, ny)
        Packed fields with Re = straight grid and Im = staggered grid.

    Returns
    -------
    complex array, shape (..., nx, ny)
        Re holds Re(f)*Re(g_field) and Im holds Im(f)*Im(g_field), so the two
        grids never mix.
    """
    re_f, im_f = np.real(f), np.imag(f)
    re_g, im_g = np.real(g_field), np.imag(g_field)
    straight = re_f * re_g
    staggered = 1j * im_f * im_g
    return straight + staggered
[docs]
def ir_pwr(f, pwr):
    """Raise a staggered-packed field to a power, grid by grid.

    The real (straight-grid) and imaginary (staggered-grid) parts are each
    raised to ``pwr`` separately and repacked. Intended for pwr = 2 and 0.5.
    """
    straight = np.real(f) ** pwr
    staggered = np.imag(f) ** pwr
    return straight + 1j * staggered
# ---------------------------------------------------------------------------
# Jacobian
# ---------------------------------------------------------------------------
[docs]
def jacob(fk, gk, g):
    """Jacobian J(f, g) of two spectral fields, returned in spectral space.

    J(f,g) = df/dx * dg/dy - df/dy * dg/dx

    The four derivatives are formed spectrally (multiplication by i*kx or
    i*ky), transformed with ``spec2grid_cc``, multiplied with ``ir_prod``
    and transformed back with ``grid2spec``.

    Parameters
    ----------
    fk, gk : complex arrays, shape (..., nky, nkx)
    g : dict
        Grid dictionary; 'kx_' and 'ky_' are read here and the whole dict is
        passed on to the transforms.

    Returns
    -------
    complex array, shape (..., nky, nkx)
    """
    ikx = 1j * g['kx_']
    iky = 1j * g['ky_']

    # Physical-space derivatives, each packed as straight + i*staggered.
    f_x = spec2grid_cc(ikx * fk, g)
    g_y = spec2grid_cc(iky * gk, g)
    f_y = spec2grid_cc(iky * fk, g)
    g_x = spec2grid_cc(ikx * gk, g)

    jac_phys = ir_prod(f_x, g_y) - ir_prod(f_y, g_x)
    return grid2spec(jac_phys, g)