kaira.losses.audio.MultiResolutionSTFTLoss

- class kaira.losses.audio.MultiResolutionSTFTLoss(fft_sizes=[512, 1024, 2048], hop_sizes=[128, 256, 512], win_lengths=[512, 1024, 2048], window='hann')
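Bases: BaseLoss
Multi-Resolution STFT Loss Module.
This module computes an STFT-based loss at multiple resolutions, one for each (FFT size, hop size, window length) configuration. Short windows resolve fine temporal detail and long windows resolve fine frequency detail, so combining several resolutions gives better time-frequency coverage than any single STFT.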
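As a rough illustration of the idea, the sketch below computes a multi-resolution STFT loss in plain PyTorch. It assumes a common formulation (spectral convergence plus an L1 distance between log magnitudes, averaged over resolutions) and a Hann window; Kaira's actual terms, weighting and window handling may differ. The helpers `stft_magnitude` and `multi_resolution_stft_loss` are hypothetical and are not part of the Kaira API.

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch, NOT Kaira's implementation. Assumes the loss is
# spectral convergence + L1 log-magnitude distance, averaged over resolutions.


def stft_magnitude(x: torch.Tensor, fft_size: int, hop_size: int, win_length: int) -> torch.Tensor:
    """For a (batch, samples) input, return magnitudes of shape (batch, fft_size // 2 + 1, frames)."""
    window = torch.hann_window(win_length, device=x.device, dtype=x.dtype)
    spec = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # Clamp away exact zeros so that log() below stays finite.
    return spec.abs().clamp(min=1e-7)


def multi_resolution_stft_loss(
    x: torch.Tensor,
    target: torch.Tensor,
    fft_sizes=(512, 1024, 2048),
    hop_sizes=(128, 256, 512),
    win_lengths=(512, 1024, 2048),
) -> torch.Tensor:
    if not (len(fft_sizes) == len(hop_sizes) == len(win_lengths)):
        raise ValueError("fft_sizes, hop_sizes and win_lengths must have equal length")
    total = x.new_zeros(())
    # Each (fft_size, hop_size, win_length) triple is one resolution.
    for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths):
        x_mag = stft_magnitude(x, n_fft, hop, win)
        y_mag = stft_magnitude(target, n_fft, hop, win)
        # Spectral convergence: relative Frobenius-norm error of the magnitudes.
        sc = torch.linalg.norm(y_mag - x_mag) / torch.linalg.norm(y_mag)
        # Log-magnitude term: mean absolute difference of log spectrograms.
        log_mag = F.l1_loss(torch.log(x_mag), torch.log(y_mag))
        total = total + sc + log_mag
    # Average over the number of resolutions.
    return total / len(fft_sizes)


# Usage: a batch of two 1-second clips at an assumed 16 kHz sample rate.
torch.manual_seed(0)
target = torch.randn(2, 16000)
prediction = target + 0.1 * torch.randn(2, 16000)
print(multi_resolution_stft_loss(prediction, target).item())
```

For identical inputs both terms are exactly zero, so the loss is zero. Every step (`torch.stft`, `abs`, `log`) is differentiable, so the result can be backpropagated like any other training loss. Because `torch.stft` pads in reflect mode by default (`center=True`), each input must be longer than half of the largest FFT size.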
Methods
__init__([fft_sizes, hop_sizes, win_lengths, window]) – Initialize the MultiResolutionSTFTLoss module.
forward(x, target) – Forward pass through the MultiResolutionSTFTLoss module.
- __init__(fft_sizes=[512, 1024, 2048], hop_sizes=[128, 256, 512], win_lengths=[512, 1024, 2048], window='hann')
Initialize the MultiResolutionSTFTLoss module.
- Parameters:
fft_sizes (list) – List of FFT sizes for each resolution. Default is [512, 1024, 2048].
hop_sizes (list) – List of hop sizes for each resolution. Default is [128, 256, 512].
win_lengths (list) – List of window lengths for each resolution. Default is [512, 1024, 2048].
window (str) – Window function type. Default is 'hann'.
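Reading the three lists position by position gives one (fft_size, hop_size, win_length) triple per resolution. To see what the defaults mean in physical units, the hypothetical helper `describe_resolutions` below converts each triple into the number of one-sided frequency bins, the bin spacing, the hop duration and the window duration. The 16 kHz sample rate is an assumption (the class takes no sample rate), and the `win_length <= fft_size` check assumes the values are passed to an STFT such as `torch.stft`, which imposes that constraint.

```python
def describe_resolutions(fft_sizes, hop_sizes, win_lengths, sample_rate):
    """Hypothetical helper: (freq_bins, bin_spacing_hz, hop_ms, window_ms) per resolution."""
    if not (len(fft_sizes) == len(hop_sizes) == len(win_lengths)):
        raise ValueError("all three lists must have the same length")
    rows = []
    for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths):
        # Assumed constraint, as in torch.stft: the window cannot exceed the FFT size.
        if win > n_fft:
            raise ValueError(f"win_length {win} exceeds fft_size {n_fft}")
        freq_bins = n_fft // 2 + 1             # one-sided spectrum of a real signal
        bin_spacing_hz = sample_rate / n_fft   # frequency resolution
        hop_ms = 1000 * hop / sample_rate      # time step between frames
        window_ms = 1000 * win / sample_rate   # analysis window duration
        rows.append((freq_bins, bin_spacing_hz, hop_ms, window_ms))
    return rows


# Default configuration at an assumed 16 kHz sample rate.
for row in describe_resolutions([512, 1024, 2048], [128, 256, 512], [512, 1024, 2048], sample_rate=16000):
    print(row)
# (257, 31.25, 8.0, 32.0)
# (513, 15.625, 16.0, 64.0)
# (1025, 7.8125, 32.0, 128.0)
```

The defaults double the FFT size, hop and window at each step, so the finest resolution tracks 8 ms changes with 31.25 Hz bins while the coarsest resolves 7.8125 Hz bins over 128 ms windows.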
- forward(x: Tensor, target: Tensor) → Tensor
Forward pass through the MultiResolutionSTFTLoss module.
- Parameters:
x (torch.Tensor) – The input audio tensor.
target (torch.Tensor) – The target audio tensor.
- Returns:
The multi-resolution STFT loss.
- Return type: