"""Wavelet class and functions.
This module contains the wavelet class and functions to generate wavelets.
.. dropdown:: Terms of use
.. code-block:: text
Copyright (C) 2023 Léonard Seydoux.
This program is free software: you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 3 of the License, or (at your
option) any later version.
This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import typing as T
try:
import cupy as xp # type: ignore
except ImportError:
import numpy as xp
import numpy as np
[docs]
def gaussian_window(
    x: xp.ndarray,
    width: T.Union[float, T.Sequence[float], xp.ndarray],
) -> xp.ndarray:
    """Evaluate a Gaussian window ``exp(-(x / width) ** 2)``.

    A whole bank of windows is produced in one call when ``width`` is a
    one-dimensional sequence: it is then reshaped into a column so that the
    division by ``width`` broadcasts against ``x`` as an outer product.
    Scalar widths and widths with two or more dimensions are used as given.

    Parameters
    ----------
    x : :class:`numpy.ndarray` or :class:`cupy.ndarray`
        Input variable, in the same units as the width.
    width : float or sequence of float or array
        Window width, in the same units as the input variable. When a
        one-dimensional array is provided, one window is returned per
        element.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`cupy.ndarray`
        The Gaussian window, with a maximum of 1 where ``x`` is zero. For a
        one-dimensional ``width`` and a one-dimensional ``x``, the result
        has shape ``(len(width), len(x))``.
    """
    # Convert the inputs to arrays of the active backend (numpy or cupy).
    samples = xp.array(x)
    scale = xp.array(width)

    if scale.ndim == 1:
        # One row per width: a column vector broadcasts against the samples.
        scale = scale[:, None]

    normalized = samples / scale
    return xp.exp(-(normalized**2))
[docs]
def complex_morlet(
    x: xp.ndarray,
    center: T.Union[float, T.Sequence[float], xp.ndarray],
    width: T.Union[float, T.Sequence[float], xp.ndarray],
) -> xp.ndarray:
    """Evaluate a complex Morlet wavelet.

    The wavelet is the product of a Gaussian window of temporal width
    ``width`` and a complex exponential ``exp(2j * pi * center * x)``
    oscillating at the center frequency.

    A filter bank is produced in one call when ``width`` and ``center`` are
    vectors of the same size: every non-scalar argument receives a new axis
    right after its first one, so that it broadcasts against the time vector
    as an outer product.

    Arguments
    ---------
    x: :class:`numpy.ndarray` or :class:`cupy.ndarray`
        Time vector in seconds.
    center: float or :class:`numpy.ndarray` or :class:`cupy.ndarray`.
        Center frequency in Hertz.
    width: float or :class:`numpy.ndarray` or :class:`cupy.ndarray`.
        Temporal signal width in seconds.

    Returns
    -------
    :class:`numpy.ndarray` or :class:`cupy.ndarray`
        The complex Morlet wavelet in the time domain. When the center and
        width arguments are vectors, the result is a matrix with shape
        ``(len(width), len(x))``.

    Raises
    ------
    AssertionError
        If both ``width`` and ``center`` are non-scalar and their shapes
        differ (only checked when assertions are enabled).
    """
    # Convert every input to an array of the active backend (numpy or cupy).
    time = xp.array(x)
    sigma = xp.array(width)
    frequency = xp.array(center)

    # Non-scalar parameters get an extra axis to broadcast against the time.
    if sigma.shape:
        sigma = sigma[:, None]
    if frequency.shape:
        frequency = frequency[:, None]

    # Shapes only need to agree when both parameters are non-scalar.
    if sigma.shape and frequency.shape:
        assert (
            sigma.shape == frequency.shape
        ), f"Shape for widths {sigma.shape} and centers {frequency.shape} differ."

    envelope = gaussian_window(time, sigma)
    carrier = xp.exp(2j * xp.pi * frequency * time)
    return envelope * carrier
[docs]
class ComplexMorletBank:
    """Complex Morlet filter bank.

    The bank contains ``octaves * resolution`` wavelets. Their center
    frequencies are the Nyquist frequency scaled by negative powers of two,
    and their temporal widths are the quality factor divided by the center
    frequency.
    """

    def __init__(
        self,
        bins: int,
        octaves: int = 8,
        resolution: int = 1,
        quality: float = 4.0,
        normalize_wavelet: T.Optional[str] = None,
        sampling_rate: float = 1.0,
    ):
        """Filter bank creation.

        This function creates the filter bank in the time domain, and obtains
        it in the frequency domain with a fast Fourier transform.

        Parameters
        ----------
        bins: int
            Number of bins in the time domain. The filter bank will be
            symmetric around the center of the time vector.
        octaves: int
            Number of octaves in the frequency domain.
        resolution: int, optional
            Number of filters per octaves (default 1).
        quality: float, optional
            Filter bank quality factor (constant, default 4).
        normalize_wavelet: str, optional
            Normalization applied to every wavelet of the bank. With
            ``None`` (default) the wavelets are left as generated. With
            ``"L1"`` each wavelet is divided by the sum of its absolute
            values, and with ``"L2"`` by the square root of the sum of its
            squared absolute values. The divisor is stored in the
            ``norm_factor`` attribute, which only exists when a
            normalization is requested.
        sampling_rate: float, optional
            Sampling rate of the signal (default 1).

        Raises
        ------
        AttributeError
            If ``normalize_wavelet`` is not ``None``, ``"L1"`` or ``"L2"``.
        """
        # Reject unsupported options up front, before any array is built.
        if normalize_wavelet not in (None, "L1", "L2"):
            raise AttributeError(
                f"'normalize_wavelet' has no attribute {normalize_wavelet}. "
                "Supported are normalization by the 'L1'- and 'L2'-norm."
            )

        # Parameters read back by the properties, __len__ and __repr__.
        self.bins = bins
        self.octaves = octaves
        self.resolution = resolution
        self.quality = quality
        self.sampling_rate = sampling_rate

        # Generate the filter bank in the time domain.
        self.wavelets = complex_morlet(self.times, self.centers, self.widths)

        # Optionally normalize every wavelet (row) of the bank.
        if normalize_wavelet is not None:
            magnitude = xp.abs(self.wavelets)
            if normalize_wavelet == "L1":
                norm = magnitude.sum(axis=1)
            else:
                norm = xp.sqrt((magnitude**2).sum(axis=1))
            # Keep a trailing axis so the division broadcasts over the bins.
            self.norm_factor = norm[:, xp.newaxis]
            self.wavelets /= self.norm_factor

        # Obtain the filter bank in the frequency domain.
        self.spectra = xp.fft.fft(self.wavelets)

        # Number of wavelets actually generated.
        self.size = self.wavelets.shape[0]

    def __repr__(self) -> str:
        """Representation of the filter bank."""
        return (
            f"ComplexMorletBank(bins={self.bins}, octaves={self.octaves}, "
            f"resolution={self.resolution}, quality={self.quality}, "
            f"sampling_rate={self.sampling_rate}, len={len(self)})"
        )

    def __len__(self) -> int:
        """Length of the filter bank (octaves times filters per octave)."""
        return self.octaves * self.resolution

    @staticmethod
    def _to_numpy(array):
        """Return ``array`` as is, or copied to the host when cupy is used."""
        if xp.__name__ == "cupy":
            return xp.asnumpy(array)  # type: ignore
        return array

    @property
    def times(self) -> np.ndarray:
        """Wavelet bank symmetric time vector in seconds."""
        duration = self.bins / self.sampling_rate
        return self._to_numpy(xp.linspace(-0.5, 0.5, num=self.bins) * duration)

    @property
    def frequencies(self) -> np.ndarray:
        """Wavelet bank frequency vector in Hertz."""
        return self._to_numpy(xp.linspace(0, self.sampling_rate, self.bins))

    @property
    def nyquist(self) -> float:
        """Nyquist frequency in Hertz (half of the sampling rate)."""
        return self.sampling_rate / 2

    @property
    def shape(self) -> tuple:
        """Filter bank shape, as (number of filters, number of bins)."""
        return len(self), self.bins

    @property
    def ratios(self) -> np.ndarray:
        """Wavelet bank ratios (exponents of the power-of-two scales)."""
        ratios = xp.linspace(self.octaves, 0.0, self.shape[0], endpoint=False)
        # Reverse and negate: the exponents run from the smallest magnitude
        # down to ``-octaves``.
        return self._to_numpy(-ratios[::-1])

    @property
    def scales(self) -> np.ndarray:
        """Wavelet bank scaling factors."""
        # ``ratios`` is already a host array: no backend conversion needed.
        return 2**self.ratios

    @property
    def centers(self) -> np.ndarray:
        """Wavelet bank center frequencies."""
        return self.scales * self.nyquist

    @property
    def widths(self) -> np.ndarray:
        """Wavelet bank temporal widths."""
        return self.quality / self.centers