Source code for kaira.channels.analog

"""Analog Channel Implementations for Continuous-Input Signals.

This module provides implementations of channels with continuous inputs, supporting both real and
complex-valued signals. These channels represent various types of noise and distortions found in
analog communication systems.

For a comprehensive overview of analog channel models, see :cite:`goldsmith2005wireless` and :cite:`proakis2007digital`.
"""

from typing import Any, Optional, Union  # Add Union

import torch

from kaira.utils import snr_to_noise_power

from .base import BaseChannel
from .registry import ChannelRegistry


# Change type hints to accept torch.Tensor or float
def _apply_noise(x: torch.Tensor, noise_power: Optional[Union[float, torch.Tensor]] = None, snr_db: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor:
    """Add Gaussian noise to a signal with specified power or SNR.

    Automatically handles both real and complex signals by adding
    appropriate noise to each component.

    Args:
        x (torch.Tensor): The input signal (real or complex)
        noise_power (Optional[Union[float, torch.Tensor]]): The noise power to apply
        snr_db (Optional[Union[float, torch.Tensor]]): The SNR in dB (alternative to noise_power)

    Returns:
        torch.Tensor: The signal with added noise
    """
    # Calculate noise power if SNR specified
    if snr_db is not None:
        signal_power = torch.mean(torch.abs(x) ** 2)
        # Ensure snr_db is float for snr_to_noise_power if it's a tensor
        snr_db_float = snr_db.item() if isinstance(snr_db, torch.Tensor) else snr_db
        noise_power = snr_to_noise_power(signal_power, snr_db_float)

    # Validate that at least one of noise_power or snr_db was provided
    if noise_power is None:
        raise ValueError("Either noise_power or snr_db must be provided")

    # Ensure noise_power is a tensor
    if not isinstance(noise_power, torch.Tensor):
        noise_power = torch.tensor(noise_power, device=x.device, dtype=x.dtype if not torch.is_complex(x) else x.real.dtype)

    # Add appropriate noise type
    if torch.is_complex(x):
        # For complex signals, split noise power between real/imag components
        noise_power_component = noise_power * 0.5
        noise_real = torch.randn_like(x.real) * torch.sqrt(noise_power_component)
        noise_imag = torch.randn_like(x.imag) * torch.sqrt(noise_power_component)
        noise = torch.complex(noise_real, noise_imag)
    else:
        # For real signals, apply all noise power
        noise = torch.randn_like(x) * torch.sqrt(noise_power)

    return x + noise


@ChannelRegistry.register_channel()
class AWGNChannel(BaseChannel):
    """Additive white Gaussian noise (AWGN) channel for signal transmission.

    Adds zero-mean Gaussian noise to the input signal. Real and complex inputs are both
    handled automatically. For complex inputs the noise power is shared equally between
    the real and imaginary components. AWGN channels are the standard baseline model in
    communication theory :cite:`proakis2007digital`.

    Mathematical Model:
        y = x + n
        where n ~ N(0, σ²) for real inputs or n ~ CN(0, σ²) for complex inputs

    Args:
        avg_noise_power (float, optional): The average noise power σ².
        snr_db (float, optional): SNR in dB (alternative to avg_noise_power). If both are
            given, the SNR is used and avg_noise_power is discarded.

    Example:
        >>> # For real-valued signals
        >>> channel = AWGNChannel(avg_noise_power=0.1)
        >>> x_real = torch.ones(10, 1)
        >>> y_real = channel(x_real)  # Real noisy output

        >>> # For complex-valued signals (same channel works)
        >>> x_complex = torch.complex(torch.ones(10, 1), torch.zeros(10, 1))
        >>> y_complex = channel(x_complex)  # Complex noisy output
    """

    avg_noise_power: Optional[float]
    snr_db: Optional[float]

    def __init__(self, avg_noise_power: Optional[float] = None, snr_db: Optional[float] = None, *args: Any, **kwargs: Any):
        """Set up the AWGN channel.

        Args:
            avg_noise_power (float, optional): The average noise power σ².
            snr_db (float, optional): SNR in dB; takes precedence over avg_noise_power.
            *args: Positional arguments forwarded to the base class.
            **kwargs: Keyword arguments forwarded to the base class.

        Raises:
            ValueError: If neither avg_noise_power nor snr_db is given.
        """
        super().__init__(*args, **kwargs)
        if snr_db is None and avg_noise_power is None:
            raise ValueError("Either avg_noise_power or snr_db must be provided")
        # Only one noise specification is kept; the SNR wins when both were passed.
        self.snr_db = snr_db
        self.avg_noise_power = avg_noise_power if snr_db is None else None

    def forward(self, x: torch.Tensor, *args: Any, csi=None, noise=None, **kwargs: Any) -> torch.Tensor:
        """Add Gaussian noise to ``x``.

        Args:
            x (torch.Tensor): The input tensor (real or complex).
            *args: Ignored.
            csi (Optional[torch.Tensor]): Ignored; AWGN has no channel state.
            noise (Optional[torch.Tensor]): Noise to add verbatim instead of drawing new
                samples.
            **kwargs: Ignored.

        Returns:
            torch.Tensor: ``x`` plus Gaussian (or the supplied) noise.
        """
        if noise is None:
            return _apply_noise(x, snr_db=self.snr_db, noise_power=self.avg_noise_power)
        return x + noise


# Alternative name for the same channel class.
GaussianChannel = AWGNChannel
@ChannelRegistry.register_channel()
class LaplacianChannel(BaseChannel):
    """Channel with additive Laplacian (double-exponential) noise.

    Models a channel with noise following the Laplacian distribution, which has heavier
    tails than Gaussian noise. This channel supports both real and complex-valued inputs.
    Laplacian noise is often used to model impulsive noise environments
    :cite:`middleton1977statistical`.

    Mathematical Model:
        y = x + n
        where n follows a Laplacian distribution

    The noise level is fixed by exactly one of the constructor arguments:

    * ``scale``: the Laplacian scale parameter b, applied to every real component. For
      complex inputs it is applied independently to the real and imaginary parts.
    * ``snr_db``: the noise power is derived from the mean power of the input.
    * ``avg_noise_power``: the noise power is given directly.

    When the noise is specified as a power P, the scale is chosen so that the total noise
    power E[|n|²] equals P. A real Laplacian with scale b has variance 2b². For complex
    inputs, P is split evenly between the real and imaginary components, consistent
    with the Gaussian channels in this module.

    Args:
        scale (float, optional): Scale parameter of the Laplacian distribution.
        avg_noise_power (float, optional): The average noise power.
        snr_db (float, optional): SNR in dB (alternative to scale or avg_noise_power).

    Example:
        >>> # Create a Laplacian channel with scale=0.5
        >>> channel = LaplacianChannel(scale=0.5)
        >>> x = torch.ones(10, 1)
        >>> y = channel(x)  # Output with Laplacian noise
    """

    scale: Optional[float]
    avg_noise_power: Optional[float]
    snr_db: Optional[float]

    def __init__(self, scale: Optional[float] = None, avg_noise_power: Optional[float] = None, snr_db: Optional[float] = None, *args: Any, **kwargs: Any):
        """Initialize the Laplacian channel.

        Precedence when several noise arguments are given: ``scale``, then ``snr_db``,
        then ``avg_noise_power``. Only the winning one is stored; the others are set to None.

        Args:
            scale (float, optional): Scale parameter of the Laplacian distribution.
            avg_noise_power (float, optional): The average noise power.
            snr_db (float, optional): SNR in dB (alternative to scale or avg_noise_power).
            *args: Variable length argument list passed to the base class.
            **kwargs: Arbitrary keyword arguments passed to the base class.

        Raises:
            ValueError: If none of scale, avg_noise_power or snr_db is given.
        """
        super().__init__(*args, **kwargs)
        if scale is not None:
            self.scale = scale
            self.avg_noise_power = None
            self.snr_db = None
        elif snr_db is not None:
            self.snr_db = snr_db
            self.scale = None
            self.avg_noise_power = None
        elif avg_noise_power is not None:
            self.avg_noise_power = avg_noise_power
            self.scale = None
            self.snr_db = None
        else:
            raise ValueError("Either scale, avg_noise_power, or snr_db must be provided")

    def _get_laplacian_noise(self, shape, device, dtype=None):
        """Draw standard Laplacian samples (location 0, scale 1, variance 2).

        Uses inverse-CDF sampling: sign(u - 0.5) * -ln(1 - 2|u - 0.5|) for uniform u.

        Args:
            shape: Shape of the sample tensor.
            device: Device to create the samples on.
            dtype: Floating-point dtype of the samples; None uses torch's default dtype.

        Returns:
            torch.Tensor: Laplacian samples of the requested shape.
        """
        u = torch.rand(shape, device=device, dtype=dtype)
        shifted_u = u - 0.5
        sign = torch.sign(shifted_u)
        # Clamp below 1 so the logarithm never sees 0 (this truncates the extreme tail).
        safe_abs_shifted_u = torch.clamp(2 * torch.abs(shifted_u), max=0.999999)
        return sign * (-torch.log(1 - safe_abs_shifted_u))

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply Laplacian noise to the input signal.

        Args:
            x (torch.Tensor): The input tensor (real or complex).
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: The output tensor with Laplacian noise added.
        """
        is_complex = torch.is_complex(x)
        real_dtype = x.real.dtype if is_complex else x.dtype
        # Draw noise in the input's precision for float32/float64 inputs; otherwise fall
        # back to torch's default dtype, as plain torch.rand would.
        noise_dtype = real_dtype if real_dtype in (torch.float32, torch.float64) else torch.get_default_dtype()

        target_noise_power: Optional[Union[float, torch.Tensor]] = None
        if self.snr_db is not None:
            signal_power = torch.mean(torch.abs(x) ** 2)
            snr_db_float = self.snr_db.item() if isinstance(self.snr_db, torch.Tensor) else self.snr_db
            target_noise_power = snr_to_noise_power(signal_power, snr_db_float)
        elif self.avg_noise_power is not None:
            target_noise_power = self.avg_noise_power

        if target_noise_power is not None:
            # as_tensor accepts both Python numbers and tensors from snr_to_noise_power.
            per_component_power = torch.as_tensor(target_noise_power, device=x.device, dtype=noise_dtype)
            if is_complex:
                # Split the total power between the real and imaginary components.
                per_component_power = per_component_power / 2
            # A zero-mean Laplacian with scale b has variance 2*b².
            scale = torch.sqrt(per_component_power / 2)
        else:
            scale = torch.as_tensor(self.scale, device=x.device, dtype=noise_dtype)

        if is_complex:
            noise_real = self._get_laplacian_noise(x.real.shape, x.device, noise_dtype) * scale
            noise_imag = self._get_laplacian_noise(x.imag.shape, x.device, noise_dtype) * scale
            noise = torch.complex(noise_real, noise_imag)
        else:
            noise = self._get_laplacian_noise(x.shape, x.device, noise_dtype) * scale

        return x + noise
@ChannelRegistry.register_channel()
class PoissonChannel(BaseChannel):
    r"""Channel with signal-dependent Poisson noise.

    The output is drawn from a Poisson distribution whose mean is proportional to the
    input, as in photon-counting and optical communication systems
    :cite:`middleton1977statistical`. Real inputs must be non-negative. For complex
    inputs the Poisson draw is applied to the magnitude and the phase is kept unchanged.

    Mathematical Model:
        y ~ Poisson(λ·\|x\|)

    Args:
        rate_factor (float): Scaling factor λ for the Poisson rate.
        normalize (bool): If True, divide the counts by λ to return to the input scale.

    Example:
        >>> # Create a Poisson channel with rate_factor=0.1
        >>> channel = PoissonChannel(rate_factor=0.1)
        >>> x = torch.ones(10, 1)
        >>> y = channel(x)  # Output with Poisson noise
    """

    rate_factor: float

    def __init__(self, rate_factor: float = 1.0, normalize: bool = False, *args: Any, **kwargs: Any):
        """Set up the Poisson channel.

        Args:
            rate_factor (float): Scaling factor λ for the Poisson rate; must be positive.
            normalize (bool): Whether to rescale the output back by 1/λ.
            *args: Positional arguments forwarded to the base class.
            **kwargs: Keyword arguments forwarded to the base class.

        Raises:
            ValueError: If rate_factor is not positive.
        """
        super().__init__(*args, **kwargs)
        if rate_factor <= 0:
            raise ValueError("Rate factor must be positive")
        self.rate_factor = rate_factor
        self.normalize = normalize

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Pass ``x`` through the Poisson channel.

        Args:
            x (torch.Tensor): Non-negative real tensor, or a complex tensor whose
                magnitude is used as the intensity.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            torch.Tensor: Poisson counts (optionally rescaled by 1/λ). For complex inputs,
            these counts become the magnitude of a complex tensor carrying the input's phase.

        Raises:
            ValueError: If a real input contains negative values.
        """
        is_complex = torch.is_complex(x)
        component_dtype = x.real.dtype if is_complex else x.dtype
        factor = torch.tensor(self.rate_factor, device=x.device, dtype=component_dtype)

        if not is_complex and torch.any(x < 0):
            raise ValueError("Input to PoissonChannel must be non-negative")

        intensity = torch.abs(x) if is_complex else x
        counts = torch.poisson(factor * intensity)
        if self.normalize:
            counts = counts / factor

        if is_complex:
            # Rebuild the complex signal from the noisy magnitude and the original phase.
            return torch.polar(counts, torch.angle(x))
        return counts
@ChannelRegistry.register_channel()
class PhaseNoiseChannel(BaseChannel):
    """Channel that rotates the signal by a random Gaussian phase.

    Models phase perturbations of the kind produced by imperfect oscillators and
    synchronization :cite:`demir2000phase`. An independent phase sample is drawn for
    every element of the signal. Real inputs are treated as complex with zero imaginary
    part, so the output is always complex.

    Mathematical Model:
        y = x * exp(j·θ)
        where θ ~ N(0, σ²) is the phase noise

    Args:
        phase_noise_std (float): Standard deviation of phase noise in radians.

    Example:
        >>> channel = PhaseNoiseChannel(phase_noise_std=0.1)
        >>> x = torch.complex(torch.ones(10), torch.zeros(10))
        >>> y = channel(x)  # Same magnitude, perturbed phase
    """

    phase_noise_std: float

    def __init__(self, phase_noise_std: float, *args: Any, **kwargs: Any):
        """Set up the phase-noise channel.

        Args:
            phase_noise_std (float): Standard deviation of the phase noise in radians;
                must be non-negative.
            *args: Positional arguments forwarded to the base class.
            **kwargs: Keyword arguments forwarded to the base class.

        Raises:
            ValueError: If phase_noise_std is negative.
        """
        super().__init__(*args, **kwargs)
        if phase_noise_std < 0:
            raise ValueError("Phase noise standard deviation must be non-negative")
        self.phase_noise_std = phase_noise_std

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Rotate every element of ``x`` by an independent random phase.

        Args:
            x (torch.Tensor): The input tensor; real inputs are promoted to complex.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            torch.Tensor: Complex tensor with phase noise applied.
        """
        real_dtype = x.real.dtype if torch.is_complex(x) else x.dtype
        sigma = torch.tensor(self.phase_noise_std, device=x.device, dtype=real_dtype)

        signal = x if torch.is_complex(x) else torch.complex(x, torch.zeros_like(x))
        theta = torch.randn_like(signal.real) * sigma
        return signal * torch.exp(1j * theta)
@ChannelRegistry.register_channel()
class FlatFadingChannel(BaseChannel):
    """Flat fading channel with configurable distribution and coherence time.

    Models a wireless channel where the fading coefficient remains constant over a
    specified coherence time and then changes to a new independent realization. This
    represents blockwise fading commonly used in communications analysis
    :cite:`tse2005fundamentals` :cite:`rappaport2024wireless`.

    Mathematical Model:
        y[i] = h[⌊i/L⌋] * x[i] + n[i]
        where L is the coherence length, h follows a specified distribution, and n ~ CN(0,σ²)

    Inputs are processed as a (batch, sequence) view. A 1-D input is treated as a single
    sequence, and inputs with more than two dimensions are flattened after the first
    (batch) dimension and reshaped back afterwards. Real inputs are promoted to complex.

    Args:
        fading_type (str): Distribution type for fading coefficients ('rayleigh', 'rician', or 'lognormal')
        coherence_time (int): Number of samples over which the fading coefficient remains constant
        k_factor (float, optional): Rician K-factor (ratio of direct to scattered power), used only
            when fading_type='rician'; must be non-negative
        avg_noise_power (float, optional): The average noise power σ²
        snr_db (float, optional): SNR in dB (alternative to avg_noise_power)
        shadow_sigma_db (float, optional): Standard deviation in dB for log-normal shadowing,
            used only when fading_type='lognormal'

    Example:
        >>> # Create a flat Rayleigh fading channel with coherence time of 10 samples
        >>> channel = FlatFadingChannel('rayleigh', coherence_time=10, snr_db=15)
        >>> x = torch.complex(torch.ones(100), torch.zeros(100))
        >>> y = channel(x)  # Output with block fading effects
    """

    k_factor: Optional[float]
    avg_noise_power: Optional[float]
    snr_db: Optional[float]
    shadow_sigma_db: Optional[float]

    def __init__(
        self,
        fading_type: str,
        coherence_time: int,
        k_factor: Optional[float] = None,
        avg_noise_power: Optional[float] = None,
        snr_db: Optional[float] = None,
        shadow_sigma_db: Optional[float] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the Flat Fading channel.

        Args:
            fading_type (str): Distribution type ('rayleigh', 'rician', 'lognormal').
            coherence_time (int): Samples over which fading is constant.
            k_factor (float, optional): Rician K-factor (for 'rician'); must be non-negative.
            avg_noise_power (float, optional): Average noise power σ².
            snr_db (float, optional): SNR in dB (alternative to avg_noise_power); takes
                precedence when both are given.
            shadow_sigma_db (float, optional): Shadowing std dev in dB (for 'lognormal').
            *args: Variable length argument list passed to the base class.
            **kwargs: Arbitrary keyword arguments passed to the base class.

        Raises:
            ValueError: On an unknown fading type, a missing or negative K-factor for
                Rician fading, a missing shadow_sigma_db for log-normal fading, or when
                neither avg_noise_power nor snr_db is given.
        """
        super().__init__(*args, **kwargs)

        valid_types = ["rayleigh", "rician", "lognormal"]
        if fading_type not in valid_types:
            raise ValueError(f"Fading type must be one of {valid_types}")
        self.fading_type = fading_type

        self.coherence_time = coherence_time
        self.k_factor = k_factor
        self.shadow_sigma_db = shadow_sigma_db

        if fading_type == "rician":
            if k_factor is None:
                raise ValueError("K-factor must be provided for Rician fading")
            # A negative K would make sqrt(K / (K + 1)) NaN in the coefficient generator.
            if k_factor < 0:
                raise ValueError("K-factor must be non-negative")
        if fading_type == "lognormal" and shadow_sigma_db is None:
            raise ValueError("shadow_sigma_db must be provided for lognormal fading")

        if snr_db is not None:
            self.snr_db = snr_db
            self.avg_noise_power = None
        elif avg_noise_power is not None:
            self.avg_noise_power = avg_noise_power
            self.snr_db = None
        else:
            raise ValueError("Either avg_noise_power or snr_db must be provided")

    def _generate_fading_coefficients(self, batch_size, seq_length, device):
        """Generate one complex fading coefficient per coherence block.

        All three models are normalized so that E[|h|²] = 1 before shadowing. For
        'lognormal', see the note in that branch on how the shadowing factor is normalized.

        Args:
            batch_size (int): Number of independent channel realizations
            seq_length (int): Length of the input sequence
            device (torch.device): Device to create tensors on

        Returns:
            torch.Tensor: Complex fading coefficients of shape (batch_size, blocks)
            where blocks = ceil(seq_length / coherence_time)
        """
        num_blocks = (seq_length + self.coherence_time - 1) // self.coherence_time

        if self.fading_type == "rayleigh":
            # Unit-power circularly-symmetric complex Gaussian.
            h_real = torch.randn(batch_size, num_blocks, device=device)
            h_imag = torch.randn(batch_size, num_blocks, device=device)
            h = torch.complex(h_real, h_imag) / (2**0.5)

        elif self.fading_type == "rician":
            if self.k_factor is None:
                raise ValueError("K-factor must be provided for Rician fading")
            k = torch.tensor(self.k_factor, device=device)

            # Deterministic line-of-sight part carries K/(K+1) of the power...
            los_magnitude = torch.sqrt(k / (k + 1))
            los = los_magnitude * torch.ones(batch_size, num_blocks, device=device)

            # ...and the random scattered part carries the remaining 1/(K+1).
            scattered_magnitude = torch.sqrt(1 / (k + 1)) / (2**0.5)
            h_real = torch.randn(batch_size, num_blocks, device=device) * scattered_magnitude
            h_imag = torch.randn(batch_size, num_blocks, device=device) * scattered_magnitude
            scattered = torch.complex(h_real, h_imag)

            h = torch.complex(los, torch.zeros_like(los)) + scattered

        elif self.fading_type == "lognormal":
            # Rayleigh small-scale fading multiplied by a log-normal shadowing factor.
            h_real = torch.randn(batch_size, num_blocks, device=device)
            h_imag = torch.randn(batch_size, num_blocks, device=device)
            h_rayleigh = torch.complex(h_real, h_imag) / (2**0.5)

            if self.shadow_sigma_db is None:
                raise ValueError("shadow_sigma_db must be provided for lognormal fading")
            shadow_sigma_db_tensor = torch.tensor(self.shadow_sigma_db, device=device)
            # Convert the dB standard deviation to natural-log units.
            sigma_ln = shadow_sigma_db_tensor * (torch.log(torch.tensor(10.0, device=device)) / 10)
            # exp(N(mu, sigma²)) with mu = -sigma²/2 has unit mean. Its square (the power
            # factor applied to |h|²) has mean exp(sigma²), not 1.
            # NOTE(review): confirm whether unit-mean amplitude or unit-mean power is
            # intended, and whether shadow_sigma_db is meant in amplitude or power dB.
            ln_mean = -(sigma_ln**2) / 2
            shadow = torch.exp(torch.randn(batch_size, num_blocks, device=device) * sigma_ln + ln_mean)

            h = h_rayleigh * torch.complex(shadow, torch.zeros_like(shadow))

        return h

    def _expand_coefficients(self, h, seq_length):
        """Repeat each block coefficient over its coherence interval.

        Args:
            h (torch.Tensor): Block fading coefficients of shape (batch_size, num_blocks)
            seq_length (int): Target sequence length

        Returns:
            torch.Tensor: Expanded coefficients of shape (batch_size, seq_length)
        """
        # Position i uses block i // coherence_time; one gather covers the whole batch.
        block_indices = torch.arange(seq_length, device=h.device) // self.coherence_time
        return h[:, block_indices]

    def forward(self, x: torch.Tensor, *args: Any, csi=None, noise=None, **kwargs: Any) -> torch.Tensor:
        """Apply flat fading and noise to the input signal.

        Args:
            x (torch.Tensor): The input tensor.
            *args: Additional positional arguments (unused).
            csi (Optional[torch.Tensor]): Pre-computed fading coefficients. If given, they are
                used instead of generating new ones and must broadcast against the
                (batch, sequence) view of the input.
            noise (Optional[torch.Tensor]): Pre-generated noise tensor, added instead of
                generated noise. It must broadcast against the same (batch, sequence) view.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: The complex output after fading and noise, in the input's shape.
        """
        original_shape = x.shape
        is_1d = len(original_shape) == 1
        if is_1d:
            # Treat a 1-D signal as a batch of one sequence.
            x = x.unsqueeze(0)
        if len(x.shape) > 2:
            # Flatten everything after the batch dimension into one sequence axis.
            x = x.reshape(x.shape[0], -1)

        if not torch.is_complex(x):
            x = torch.complex(x, torch.zeros_like(x))

        batch_size, seq_length = x.shape

        if csi is not None:
            h = csi
        else:
            h_blocks = self._generate_fading_coefficients(batch_size, seq_length, x.device)
            h = self._expand_coefficients(h_blocks, seq_length)

        y = h * x

        if noise is not None:
            y = y + noise
        else:
            # Same complex Gaussian noise model as the other channels. With snr_db, the
            # SNR is measured against the faded signal y.
            y = _apply_noise(y, noise_power=self.avg_noise_power, snr_db=self.snr_db)

        if len(original_shape) > 2:
            y = y.reshape(*original_shape)
        elif is_1d:
            y = y.squeeze(0)

        return y
@ChannelRegistry.register_channel()
class NonlinearChannel(BaseChannel):
    """Channel applying a user-supplied nonlinearity, optionally followed by noise.

    Any nonlinear distortion can be modeled by passing a callable that maps a tensor to a
    tensor. A typical example is the Saleh model for traveling-wave tube amplifiers
    :cite:`saleh1981frequency`. Real inputs are always passed straight to the function.
    How complex inputs are handled is controlled by ``complex_mode``.

    Mathematical Model:
        y = f(x) + n
        where f is a nonlinear function and n is optional noise

    Args:
        nonlinear_fn (callable): The nonlinear transformation.
        add_noise (bool): Whether Gaussian noise is added after the nonlinearity.
        avg_noise_power (float, optional): Noise power, used when add_noise is True.
        snr_db (float, optional): SNR in dB (alternative to avg_noise_power).
        complex_mode (str, optional): Treatment of complex inputs:
            'direct' (default) hands the complex tensor to nonlinear_fn unchanged,
            'cartesian' applies nonlinear_fn to the real and imaginary parts separately,
            'polar' applies nonlinear_fn to the magnitude and keeps the phase.

    Example:
        >>> # Create a channel with cubic nonlinearity for real signals
        >>> channel = NonlinearChannel(lambda x: x + 0.2 * x**3)
        >>> x = torch.linspace(-1, 1, 100)
        >>> y = channel(x)  # Output with cubic distortion

        >>> # For complex signals, using polar mode (apply nonlinearity to magnitude only)
        >>> def mag_distortion(x):
        ...     return x * (1 - 0.1 * x)  # compression
        >>> channel = NonlinearChannel(mag_distortion, complex_mode='polar')
        >>> x = torch.complex(torch.randn(100), torch.randn(100))
        >>> y = channel(x)  # Output with magnitude distortion, phase preserved
    """

    avg_noise_power: Optional[float]
    snr_db: Optional[float]

    def __init__(
        self,
        nonlinear_fn,
        add_noise=False,
        avg_noise_power: Optional[float] = None,
        snr_db: Optional[float] = None,
        complex_mode="direct",
        *args: Any,
        **kwargs: Any,
    ):
        """Set up the nonlinear channel.

        Args:
            nonlinear_fn (callable): The nonlinear transformation.
            add_noise (bool): Whether to add noise after the nonlinearity.
            avg_noise_power (float, optional): Noise power when add_noise is True.
            snr_db (float, optional): SNR in dB when add_noise is True.
            complex_mode (str): One of 'direct', 'cartesian', 'polar'.
            *args: Positional arguments forwarded to the base class.
            **kwargs: Keyword arguments forwarded to the base class.

        Raises:
            ValueError: On an unknown complex_mode. With add_noise=True, also when both
                or neither of avg_noise_power and snr_db are given.
        """
        super().__init__(*args, **kwargs)
        self.nonlinear_fn = nonlinear_fn
        self.add_noise = add_noise
        self.complex_mode = complex_mode

        if complex_mode not in ("direct", "cartesian", "polar"):
            raise ValueError("complex_mode must be 'direct', 'cartesian', or 'polar'")

        # Noise settings are only kept when noise is actually requested.
        self.avg_noise_power = None
        self.snr_db = None
        if not add_noise:
            return

        if snr_db is not None and avg_noise_power is not None:
            raise ValueError("Cannot specify both snr_db and avg_noise_power")
        if snr_db is None and avg_noise_power is None:
            raise ValueError("If add_noise=True, either avg_noise_power or snr_db must be provided")
        self.snr_db = snr_db
        self.avg_noise_power = avg_noise_power

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Distort ``x`` with the nonlinearity and optionally add noise.

        Args:
            x (torch.Tensor): The input tensor.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            torch.Tensor: The distorted (and possibly noisy) signal.
        """
        fn = self.nonlinear_fn
        if not torch.is_complex(x):
            y = fn(x)
        elif self.complex_mode == "direct":
            y = fn(x)
        elif self.complex_mode == "cartesian":
            y = torch.complex(fn(x.real), fn(x.imag))
        elif self.complex_mode == "polar":
            # Distort the magnitude only, then restore the original phase.
            y = fn(torch.abs(x)) * torch.exp(1j * torch.angle(x))

        if self.add_noise:
            y = _apply_noise(y, snr_db=self.snr_db, noise_power=self.avg_noise_power)
        return y
@ChannelRegistry.register_channel()
class RayleighFadingChannel(FlatFadingChannel):
    """Block Rayleigh fading channel with additive complex Gaussian noise.

    Shorthand for ``FlatFadingChannel`` with ``fading_type="rayleigh"``. Rayleigh
    fading describes non-line-of-sight propagation where no dominant path exists.

    Mathematical Model:
        y[i] = h[⌊i/L⌋] * x[i] + n[i]
        where L is the coherence length, h follows a Rayleigh distribution, and n ~ CN(0,σ²)

    Args:
        coherence_time (int, optional): Number of samples sharing one fading coefficient.
            Defaults to 1.
        avg_noise_power (float, optional): The average noise power σ²
        snr_db (float, optional): SNR in dB (alternative to avg_noise_power)

    Example:
        >>> # Create a Rayleigh fading channel with coherence time of 10 samples
        >>> channel = RayleighFadingChannel(coherence_time=10, snr_db=15)
        >>> x = torch.complex(torch.ones(100), torch.zeros(100))
        >>> y = channel(x)  # Output with Rayleigh fading
    """

    def __init__(
        self,
        coherence_time=1,
        avg_noise_power: Optional[float] = None,
        snr_db: Optional[float] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Set up the Rayleigh fading channel.

        Args:
            coherence_time (int, optional): Samples per fading block. Defaults to 1.
            avg_noise_power (float, optional): Average noise power σ².
            snr_db (float, optional): SNR in dB (alternative to avg_noise_power).
            *args: Positional arguments forwarded to the parent class.
            **kwargs: Keyword arguments forwarded to the parent class. The fading type,
                coherence time and noise settings always override same-named entries.
        """
        forwarded = dict(
            kwargs,
            coherence_time=coherence_time,
            avg_noise_power=avg_noise_power,
            snr_db=snr_db,
            fading_type="rayleigh",
        )
        super().__init__(*args, **forwarded)
@ChannelRegistry.register_channel()
class RicianFadingChannel(FlatFadingChannel):
    """Block Rician fading channel with additive complex Gaussian noise.

    Shorthand for ``FlatFadingChannel`` with ``fading_type="rician"``. Rician fading
    describes a dominant direct path plus weaker scattered paths.

    Mathematical Model:
        y = h*x + n
        where h follows a Rician distribution with K-factor and n ~ CN(0,σ²)

    The K-factor is the ratio of power in the direct path to power in the scattered
    paths; larger K means a stronger line-of-sight component.

    Args:
        k_factor (float): Rician K-factor (ratio of direct to scattered power)
        coherence_time (int): Number of samples sharing one fading coefficient
        avg_noise_power (float, optional): The average noise power
        snr_db (float, optional): SNR in dB (alternative to avg_noise_power)

    Example:
        >>> # Create a Rician channel with K=5 (strong direct path)
        >>> channel = RicianFadingChannel(k_factor=5, coherence_time=10, snr_db=15)
        >>> x = torch.complex(torch.ones(100), torch.zeros(100))
        >>> y = channel(x)  # Output with Rician fading
    """

    def __init__(
        self,
        k_factor: float = 1.0,
        coherence_time=1,
        avg_noise_power: Optional[float] = None,
        snr_db: Optional[float] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Set up the Rician fading channel.

        Args:
            k_factor (float): Rician K-factor; must be non-negative. Defaults to 1.0.
            coherence_time (int): Samples per fading block. Defaults to 1.
            avg_noise_power (float, optional): Average noise power.
            snr_db (float, optional): SNR in dB (alternative to avg_noise_power).
            *args: Positional arguments forwarded to the parent class.
            **kwargs: Keyword arguments forwarded to the parent class. The fading
                parameters and noise settings always override same-named entries.

        Raises:
            ValueError: If k_factor is negative.
        """
        # Reject a bad K-factor before any parent initialization happens.
        if k_factor < 0:
            raise ValueError("K-factor must be non-negative")
        super().__init__(
            *args,
            **{
                **kwargs,
                "k_factor": k_factor,
                "coherence_time": coherence_time,
                "avg_noise_power": avg_noise_power,
                "snr_db": snr_db,
                "fading_type": "rician",
            },
        )
@ChannelRegistry.register_channel()
class LogNormalFadingChannel(FlatFadingChannel):
    """Block fading channel with log-normal shadowing and additive complex Gaussian noise.

    Shorthand for ``FlatFadingChannel`` with ``fading_type="lognormal"``: Rayleigh
    small-scale fading multiplied by a log-normal shadowing factor. It is intended for
    large-scale power variations caused by obstacles such as buildings, terrain and foliage.

    Mathematical Model:
        y = h*x + n
        where h includes log-normal shadowing and n ~ CN(0,σ²)

    Larger ``shadow_sigma_db`` values produce more severe shadowing.

    Args:
        shadow_sigma_db (float): Standard deviation in dB for log-normal shadowing
        coherence_time (int): Number of samples sharing one fading coefficient
        avg_noise_power (float, optional): The average noise power
        snr_db (float, optional): SNR in dB (alternative to avg_noise_power)

    Example:
        >>> # Create a log-normal shadowing channel with 8 dB standard deviation
        >>> channel = LogNormalFadingChannel(shadow_sigma_db=8.0, coherence_time=100, snr_db=15)
        >>> x = torch.complex(torch.ones(1000), torch.zeros(1000))
        >>> y = channel(x)  # Output with log-normal shadowing
    """

    def __init__(
        self,
        shadow_sigma_db: float = 4.0,
        coherence_time=100,
        avg_noise_power: Optional[float] = None,
        snr_db: Optional[float] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Set up the log-normal fading channel.

        Args:
            shadow_sigma_db (float): Shadowing std dev in dB; must be non-negative.
                Defaults to 4.0.
            coherence_time (int): Samples per fading block. Defaults to 100.
            avg_noise_power (float, optional): Average noise power.
            snr_db (float, optional): SNR in dB (alternative to avg_noise_power).
            *args: Positional arguments forwarded to the parent class.
            **kwargs: Keyword arguments forwarded to the parent class. The fading
                parameters and noise settings always override same-named entries.

        Raises:
            ValueError: If shadow_sigma_db is negative.
        """
        if shadow_sigma_db < 0:
            raise ValueError("shadow_sigma_db must be non-negative")
        forwarded = dict(
            kwargs,
            shadow_sigma_db=shadow_sigma_db,
            coherence_time=coherence_time,
            avg_noise_power=avg_noise_power,
            snr_db=snr_db,
            fading_type="lognormal",
        )
        super().__init__(*args, **forwarded)