"""NEDAS.utils.fft_lib: 2D FFT helpers (pyFFTW when available, numpy.fft otherwise) and wavenumber utilities."""

import numpy as np

try:
    # #fft implementation using FFTW
    import pyfftw  #type: ignore

    def fft2(f: np.ndarray) -> np.ndarray:
        """
        Forward 2D FFT of a real field, computed with pyFFTW.
        (When pyFFTW cannot be imported, np.fft.fft2 is used instead.)

        Args:
            f (np.ndarray):
                Input field, of shape (..., ny, nx), of float32 type;
                the transform is applied over the last two dimensions.

        Returns:
            np.ndarray: Spectrum with the same shape as f, of complex64 type.
        """
        nx = f.shape[-1]
        nhalf = nx//2 + 1     # kx columns produced by the real-to-complex transform
        ntail = nx - nhalf    # remaining kx columns that have to be rebuilt

        # plan a real-to-complex transform over the last two axes
        real_buf = pyfftw.empty_aligned(f.shape, dtype='float32')
        half_buf = pyfftw.empty_aligned(f.shape[:-1] + (nhalf,), dtype='complex64')
        plan = pyfftw.FFTW(real_buf, half_buf, axes=(-2, -1))
        half = plan(f)

        # output keeps the full shape of f (keeps dimension info)
        spec = np.zeros(f.shape, dtype='complex64')

        # columns delivered directly by the transform
        spec[..., :nhalf] = half

        # remaining columns are the conjugate of the delivered ones, taken in
        # reverse kx order; ntail equals nhalf-2 for even nx and nhalf-1 for odd nx
        mirror = np.conj(half[..., ntail:0:-1])
        # ky = 0 row maps onto itself, every other row j maps onto row ny-j
        spec[..., 0, nhalf:] = mirror[..., 0, :]
        spec[..., 1:, nhalf:] = mirror[..., :0:-1, :]

        return spec

    def ifft2(fh: np.ndarray) -> np.ndarray:
        """
        Inverse 2D FFT implemented by pyFFTW. If pyFFTW is not available, will switch to np.fft.ifft2

        Args:
            fh (np.ndarray):
                Input field, of shape (..., nky, nkx), of complex64 type,
                the last two dimensions will be inverse transformed by 2D FFT.
                Only the first nkx//2+1 columns of the last dimension are
                passed to the transform.

        Returns:
            np.ndarray: Output field in physical space, of float32 type,
            with the same shape as fh.
        """
        # number of kx columns fed to the complex-to-real transform
        nup = fh.shape[-1]//2 + 1

        # prepare fftw plan
        b = pyfftw.empty_aligned(fh.shape[:-1] + (nup,), dtype='complex64')
        a = pyfftw.empty_aligned(fh.shape, dtype='float32')
        fft_obj = pyfftw.FFTW(b, a, axes=(-2, -1), direction='FFTW_BACKWARD')

        # perform the ifft2
        # The slice is taken along the LAST axis (fh.shape[-1]); indexing with
        # fh.shape[1] is only right for 2D input and picks the wrong length
        # when fh has leading dimensions.
        # NOTE(review): pyfftw's FFTW.__call__ defaults to normalise_idft=True,
        # so the result should already be scaled by 1/(ny*nx) -- confirm
        # against the installed pyFFTW version.
        f = fft_obj(fh[..., 0:nup])
        return f

except ImportError:
    #print("Warning: pyFFTW not found in your environment, will use numpy.fft instead.", flush=True)

[docs] def fft2(f: np.ndarray) -> np.ndarray: return np.fft.fft2(f)
[docs] def ifft2(fh: np.ndarray) -> np.ndarray: return np.real(np.fft.ifft2(fh))
def fftwn(n):
    """
    Wavenumber sequence corresponding to the FFT output in 1 dimension.

    Args:
        n (int): The size of the dimension

    Returns:
        np.ndarray: The sequence of wavenumber (0,1,2,...-2,-1) for this dimension
    """
    # count of non-negative wavenumbers 0..nup-1
    nup = int(np.ceil((n + 1) / 2))
    # for even n the highest wavenumber is listed once, on the positive side,
    # so the negative part starts one step later than for odd n
    neg_start = 2 - nup if n % 2 == 0 else 1 - nup
    return np.concatenate((np.arange(0, nup), np.arange(neg_start, 0)))
def get_wn(fld):
    """
    Generates a meshgrid of wavenumbers corresponding to the input field dimensions.

    Args:
        fld (np.ndarray):
            Input field, the last two dimensions are the horizontal directions (ny, nx)

    Returns:
        tuple: ``(wnx, wny)``, two np.ndarray with the same shape as fld holding
        the wavenumbers along the last and second-to-last dimension, relative
        to the domain size (nx or ny, whichever is larger).
    """
    ny, nx = fld.shape[-2:]
    wnx = np.zeros(fld.shape)
    wny = np.zeros(fld.shape)
    nup = int(max(nx, ny))

    # fftwn(n)[k] is the wavenumber stored at position k, so each sequence can
    # be broadcast along its own axis instead of being written one index at a time
    wnx[...] = fftwn(nx) * nup / nx
    wny[...] = (fftwn(ny) * nup / ny)[:, np.newaxis]
    return wnx, wny