kaira.losses.audio.MultiResolutionSTFTLoss

Inheritance diagram for MultiResolutionSTFTLoss

class kaira.losses.audio.MultiResolutionSTFTLoss(fft_sizes=[512, 1024, 2048], hop_sizes=[128, 256, 512], win_lengths=[512, 1024, 2048], window='hann')

Bases: BaseLoss

Multi-Resolution STFT Loss Module.

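This module computes an STFT loss at several resolutions and combines the results. Each resolution is defined by one entry of fft_sizes, hop_sizes, and win_lengths. Short FFTs localize transients precisely in time, while long FFTs resolve fine frequency detail, so the combined loss covers the time-frequency plane better than any single resolution.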
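A minimal PyTorch sketch of the idea follows. It assumes the per-resolution loss is the spectral-convergence plus log-magnitude formulation popularized by Parallel WaveGAN, and that the per-resolution losses are averaged. This page documents neither Kaira's per-resolution loss nor its reduction, so treat this as an illustration, not Kaira's implementation. The functions `stft_magnitude` and `multi_resolution_stft_loss` are hypothetical and not part of Kaira. The sketch also hardcodes a Hann window, which is the documented default for `window`.

```python
import torch
import torch.nn.functional as F


def stft_magnitude(x, fft_size, hop_size, win_length):
    # Hypothetical helper (not Kaira API): magnitude STFT with a Hann window.
    window = torch.hann_window(win_length, device=x.device, dtype=x.dtype)
    spec = torch.stft(x, n_fft=fft_size, hop_length=hop_size,
                      win_length=win_length, window=window,
                      return_complex=True)
    # Clamp so the log-magnitude term below never sees log(0).
    return spec.abs().clamp(min=1e-7)


def multi_resolution_stft_loss(x, target,
                               fft_sizes=(512, 1024, 2048),
                               hop_sizes=(128, 256, 512),
                               win_lengths=(512, 1024, 2048)):
    # Hypothetical sketch (not Kaira API); x and target have shape (batch, samples).
    if not (len(fft_sizes) == len(hop_sizes) == len(win_lengths)):
        raise ValueError("fft_sizes, hop_sizes and win_lengths must have the same length")
    total = x.new_zeros(())
    for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths):
        mag_x = stft_magnitude(x, n_fft, hop, win)
        mag_t = stft_magnitude(target, n_fft, hop, win)
        # Spectral convergence: relative Frobenius-norm error of the magnitudes.
        sc = torch.linalg.norm(mag_t - mag_x) / torch.linalg.norm(mag_t)
        # Log STFT magnitude loss: mean absolute error of the log magnitudes.
        log_mag = F.l1_loss(torch.log(mag_x), torch.log(mag_t))
        total = total + sc + log_mag
    # Assumed reduction: average over resolutions.
    return total / len(fft_sizes)
```

Averaging rather than summing keeps the loss scale independent of the number of resolutions. Because the two differ only by a constant factor, they give the same gradient direction.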

Methods

  • __init__ – Initialize the MultiResolutionSTFTLoss module.

  • forward – Forward pass through the MultiResolutionSTFTLoss module.

__init__(fft_sizes=[512, 1024, 2048], hop_sizes=[128, 256, 512], win_lengths=[512, 1024, 2048], window='hann')

Initialize the MultiResolutionSTFTLoss module.

Parameters:
  • fft_sizes (list) – List of FFT sizes for each resolution. Default is [512, 1024, 2048].

  • hop_sizes (list) – List of hop sizes for each resolution. Default is [128, 256, 512].

  • win_lengths (list) – List of window lengths for each resolution. Default is [512, 1024, 2048].

  • window (str) – Window function type. Default is ‘hann’.
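The three lists are read element-wise: entry i of fft_sizes, hop_sizes, and win_lengths together defines resolution i. The short pure-Python sketch below tabulates the STFT grid that the default resolutions produce for one second of 16 kHz audio. It assumes a centered STFT, as `torch.stft` computes with `center=True`. Larger FFT sizes give more, narrower frequency bins, while larger hops give fewer, coarser time frames. The helper `stft_grid` and the 16 kHz sample rate are illustrative assumptions, not part of Kaira.

```python
# Hypothetical helper: size of the STFT grid for one resolution,
# assuming a centered STFT (frames = 1 + num_samples // hop_size).
def stft_grid(num_samples, fft_size, hop_size, sample_rate):
    bins = fft_size // 2 + 1                  # one-sided frequency bins
    frames = 1 + num_samples // hop_size      # number of time frames
    bin_spacing_hz = sample_rate / fft_size   # frequency step between bins
    hop_ms = 1000 * hop_size / sample_rate    # time step between frames
    return bins, frames, bin_spacing_hz, hop_ms


# Default resolutions of MultiResolutionSTFTLoss, paired element-wise.
fft_sizes = [512, 1024, 2048]
hop_sizes = [128, 256, 512]
for fft_size, hop_size in zip(fft_sizes, hop_sizes):
    print(fft_size, hop_size, stft_grid(16000, fft_size, hop_size, 16000))
# 512 128 (257, 126, 31.25, 8.0)
# 1024 256 (513, 63, 15.625, 16.0)
# 2048 512 (1025, 32, 7.8125, 32.0)
```

win_lengths does not change the grid size in this model, because the window is zero-padded to the FFT size. With the defaults, each window length equals its FFT size.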

forward(x: Tensor, target: Tensor) → Tensor

Forward pass through the MultiResolutionSTFTLoss module.

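Parameters:
  • x (Tensor) – Input (predicted) audio signal.

  • target (Tensor) – Target (reference) audio signal to compare against.

Returns: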

The multi-resolution STFT loss.

Return type:

torch.Tensor