# Source code for alphadia.search.selection.fft

"""Fast Fourier Transform operations for signal processing."""

import numba as nb
import numpy as np
from numba.extending import overload
from rocket_fft import pocketfft
from rocket_fft.overloads import (
    decrease_shape,
    get_fct,
    increase_shape,
    ndshape_and_axes,
    resize,
    zeropad_or_crop,
)


class NumbaContextOnly(Exception):
    """Raised when a numba-only placeholder function is called from plain Python.

    The FFT helpers in this module are Python stubs whose real implementations
    are registered with ``numba.extending.overload``; the stubs raise this
    exception so that accidental calls outside jitted code fail loudly.
    """


def _rfft2(x: np.ndarray, s: None | tuple = None) -> np.ndarray:
    """Compute the 2D real-to-complex FFT of a real array (numba only).

    This Python function is only a placeholder: its implementation is
    registered with ``numba.extending.overload`` below and is available
    inside numba-compiled code only.

    Parameters
    ----------
    x : np.ndarray
        dtype = np.float32, ndim = 2, containing the input data.
    s : tuple | None
        Tuple of integers giving the shape that ``x`` is zero-padded or
        cropped to (via ``zeropad_or_crop``) before the transform. This is
        the transform shape, not in general the output shape.

    Returns
    -------
    np.ndarray
        dtype = np.complex64, ndim = 2, containing the 2D real-to-complex FFT
        of the (padded/cropped) input. The last axis is shortened by
        ``decrease_shape`` (presumably to ``n // 2 + 1``, as in
        ``numpy.fft.rfft2``).

    Raises
    ------
    NumbaContextOnly
        When called outside a numba-compiled function.

    .. note::
        This function should only be used in a numba context as it relies on
        numba overloads.
    """
    raise NumbaContextOnly(
        "This function should only be used in a numba context as it relies on numbas overloads."
    )


# NOTE(review): numba's ``overload`` documents compiler options via
# ``jit_options={...}``; a plain ``fastmath=True`` keyword may be stored as
# template metadata instead of reaching the compiler — confirm.
@overload(_rfft2, fastmath=True)
def _(x, s=None):
    # Only 2D float32 arrays get an implementation; returning None tells numba
    # this overload does not match the given argument types.
    if not isinstance(x, nb.types.Array):
        return None
    if x.ndim != 2:
        return None
    if x.dtype != nb.types.float32:
        return None

    def funcx_impl(x, s=None):
        # Transform over the last two axes.
        s, axes = ndshape_and_axes(x, s, (-2, -1))
        # Zero-pad or crop the input to the requested transform shape ``s``.
        x = zeropad_or_crop(x, s, axes, nb.types.float32)
        # Half-spectrum output shape along the last transformed axis.
        shape = decrease_shape(x.shape, axes)
        out = np.empty(shape, dtype=nb.types.complex64)
        # Scale factor for norm=None on a forward transform (presumably 1,
        # i.e. "backward" normalisation — confirm against rocket_fft).
        fct = get_fct(x, axes, None, True)
        # Arguments: input, output, axes, forward=True, scale, and 1
        # (presumably the number of worker threads).
        pocketfft.numba_r2c(x, out, axes, True, fct, 1)
        return out

    return funcx_impl


def _irfft2(x: np.ndarray, s: None | tuple = None) -> np.ndarray:
    """Compute the 2D complex-to-real inverse FFT of a complex array (numba only).

    This Python function is only a placeholder: its implementation is
    registered with ``numba.extending.overload`` below and is available
    inside numba-compiled code only.

    Parameters
    ----------
    x : np.ndarray
        dtype = np.complex64, ndim = 2, containing the input (half-spectrum)
        data.
    s : tuple | None
        Tuple of integers containing the shape of the output array. It is
        applied to the output shape via ``resize`` and is also passed to
        ``zeropad_or_crop`` for the input. If None, the output shape is
        derived from ``x`` via ``increase_shape`` (presumably
        ``2 * (m - 1)`` along the last axis, as in ``numpy.fft.irfft2``).

    Returns
    -------
    np.ndarray
        dtype = np.float32, ndim = 2, containing the 2D complex-to-real FFT of
        the input array.

    Raises
    ------
    NumbaContextOnly
        When called outside a numba-compiled function.

    .. note::
        This function should only be used in a numba context as it relies on
        numba overloads.
    """
    raise NumbaContextOnly(
        "This function should only be used in a numba context as it relies on numbas overloads."
    )


# NOTE(review): see the ``fastmath`` remark on the ``_rfft2`` overload — the
# keyword may not be forwarded as a jit option.
@overload(_irfft2, fastmath=True)
def _(x, s=None):
    # Only 2D complex64 arrays get an implementation; returning None tells
    # numba this overload does not match the given argument types.
    if not isinstance(x, nb.types.Array):
        return None
    if x.ndim != 2:
        return None
    if x.dtype != nb.types.complex64:
        return None

    def funcx_impl(x, s=None):
        # Transform over the last two axes.
        s, axes = ndshape_and_axes(x, s, (-2, -1))
        xin = zeropad_or_crop(x, s, axes, nb.types.complex64)
        # Real output shape: expand the half-spectrum axis, then apply ``s``.
        shape = increase_shape(x.shape, axes)
        shape = resize(shape, x, s, axes)
        out = np.empty(shape, dtype=nb.types.float32)
        # The inverse scale factor is computed from the real output shape.
        fct = get_fct(out, axes, None, False)
        # Arguments: input, output, axes, forward=False, scale, and 1
        # (presumably the number of worker threads).
        pocketfft.numba_c2r(xin, out, axes, False, fct, 1)
        return out

    return funcx_impl


def convolve_fourier(dense, kernel):
    """Filter the last two axes of ``dense`` with a 2D ``kernel`` via FFT (numba only).

    This Python function is only a placeholder: its implementation is
    registered with ``numba.extending.overload`` and is available inside
    numba-compiled code only.

    The implementation multiplies the real FFT of every 2D slice of ``dense``
    with the real FFT of ``kernel`` (zero-padded or cropped to the slice
    shape). This is a circular convolution, i.e. it wraps around the edges.
    The result is then circularly shifted by ``ceil(k0 / 2)`` and
    ``ceil(k1 / 2)`` along the last two axes. Any 2D kernel can be used, for
    example a gaussian.

    Parameters
    ----------
    dense : np.ndarray
        float32 array of shape (n_scans, n_frames), (a, n_scans, n_frames) or
        (a, b, n_scans, n_frames). Every 2D slice over the last two axes is
        filtered with the same kernel.
    kernel : np.ndarray
        float32 array of shape (k0, k1).

    Returns
    -------
    np.ndarray
        Array with the same shape and dtype as ``dense`` containing the
        filtered data.

    Raises
    ------
    NumbaContextOnly
        When called outside a numba-compiled function.
    """
    raise NumbaContextOnly(
        "This function should only be used in a numba context as it relies on numbas overloads."
    )


@overload(convolve_fourier, fastmath=True)
def _(dense, kernel):
    """Numba typing function selecting the implementation of ``convolve_fourier``.

    Returns an implementation closure when ``kernel`` is a 2D array and
    ``dense`` is a 2D, 3D or 4D array. For any other argument types it returns
    None (no matching implementation), including ``dense`` with more than 4
    dimensions. Because ``_rfft2`` only has an implementation for float32
    input, both arrays must be float32 for compilation to succeed.

    Every implementation:

    1. computes the rFFT of ``kernel``, zero-padded or cropped to the shape of
       the last two axes of ``dense``, once and outside any loop;
    2. for each 2D slice of ``dense``, multiplies the slice's rFFT by that
       filter and inverts it, which is a circular convolution;
    3. writes the result ``layer`` into ``out`` with four quadrant copies.
       Assuming ``layer`` has the slice's shape, these copies are equivalent
       to ``out[..., r, c] == layer[(r + ceil(k0/2)) % n0, (c + ceil(k1/2)) % n1]``.

    NOTE(review): ``fastmath=True`` is passed as a plain keyword; numba's
    ``overload`` documents compiler options via ``jit_options={...}``, so
    this flag may not reach the compiler — confirm.
    """
    if not isinstance(dense, nb.types.Array):
        return None
    if not isinstance(kernel, nb.types.Array):
        return None
    if kernel.ndim != 2:
        return None
    if dense.ndim < 2:
        return None
    if dense.ndim == 2:
        def funcx_impl(dense, kernel):
            k0, k1 = kernel.shape
            # Parsed as (-k) // 2 == -ceil(k / 2): negative offsets driving the
            # circular shift below.
            # NOTE(review): for odd k this shifts by k // 2 + 1, one more than
            # the kernel's centre index k // 2 — confirm the intended alignment.
            delta0, delta1 = -k0 // 2, -k1 // 2
            out = np.zeros_like(dense)
            fourier_filter = _rfft2(kernel, dense.shape)
            # NOTE(review): no ``s`` is passed to ``_irfft2``, so the last-axis
            # length is inferred from the half spectrum. For an odd-length last
            # axis this is presumably one short of dense.shape[-1], and the
            # quadrant copies would then fail with a shape mismatch — confirm.
            layer = _irfft2(_rfft2(dense) * fourier_filter)
            # Circular shift of ``layer`` written as four quadrant copies.
            out[delta0:, delta1:] = layer[:-delta0, :-delta1]
            out[:delta0, delta1:] = layer[-delta0:, :-delta1]
            out[delta0:, :delta1] = layer[:-delta0, -delta1:]
            out[:delta0, :delta1] = layer[-delta0:, -delta1:]
            return out
        return funcx_impl
    if dense.ndim == 3:
        def funcx_impl(dense, kernel):
            k0, k1 = kernel.shape
            # Negative offsets: -ceil(k / 2) (see the 2D case for review notes).
            delta0, delta1 = -k0 // 2, -k1 // 2
            out = np.zeros_like(dense)
            # Filter spectrum computed once and reused for every 2D slice.
            fourier_filter = _rfft2(kernel, dense.shape[-2:])
            for i in range(dense.shape[0]):
                layer = _irfft2(_rfft2(dense[i]) * fourier_filter)
                # Circular shift of this slice as four quadrant copies.
                out[i, delta0:, delta1:] = layer[:-delta0, :-delta1]
                out[i, :delta0, delta1:] = layer[-delta0:, :-delta1]
                out[i, delta0:, :delta1] = layer[:-delta0, -delta1:]
                out[i, :delta0, :delta1] = layer[-delta0:, -delta1:]
            return out
        return funcx_impl
    if dense.ndim == 4:
        def funcx_impl(dense, kernel):
            k0, k1 = kernel.shape
            # Negative offsets: -ceil(k / 2) (see the 2D case for review notes).
            delta0, delta1 = -k0 // 2, -k1 // 2
            out = np.zeros_like(dense)
            # Filter spectrum computed once and reused for every 2D slice.
            fourier_filter = _rfft2(kernel, dense.shape[-2:])
            for i in range(dense.shape[0]):
                for j in range(dense.shape[1]):
                    layer = _irfft2(_rfft2(dense[i, j]) * fourier_filter)
                    # Circular shift of this slice as four quadrant copies.
                    out[i, j, delta0:, delta1:] = layer[:-delta0, :-delta1]
                    out[i, j, :delta0, delta1:] = layer[-delta0:, :-delta1]
                    out[i, j, delta0:, :delta1] = layer[:-delta0, -delta1:]
                    out[i, j, :delta0, :delta1] = layer[-delta0:, -delta1:]
            return out
        return funcx_impl
    # Any other ``dense.ndim`` (> 4) falls through and implicitly returns
    # None, so numba reports no matching implementation.