# Source code for kaira.channels.uplink_mac

"""Uplink Multiple Access Channel (MAC) implementations.

This module provides channel models for uplink communication scenarios where multiple users
transmit simultaneously to a single receiver. The UplinkMACChannel uses a composition pattern,
accepting existing channel implementations as parameters to model different channel conditions for
individual user transmissions.
"""

from typing import Any, Dict, List, Optional, Union

import torch

from .base import BaseChannel
from .registry import ChannelRegistry


@ChannelRegistry.register_channel()
class UplinkMACChannel(BaseChannel):
    """Uplink Multiple Access Channel (MAC) for modeling multi-user uplink communications.

    This channel models uplink communication scenarios where multiple users transmit
    simultaneously to a single receiver. The channel uses a composition pattern, accepting
    existing channel implementations (e.g., FlatFadingChannel, AWGNChannel) as parameters to
    model different channel conditions for individual user transmissions.

    For every user, the signal is passed through that user's channel and scaled by the user's
    gain. Optional inter-user interference is then mixed in, and the per-user results are
    combined into a single received signal.

    Mathematical Model:
        For N users, the received signal is:

        y = Σᵢ₌₁ᴺ gᵢ · hᵢ(xᵢ) + interference

        where hᵢ(xᵢ) is the output of user i's channel for input xᵢ and gᵢ is user i's gain.
        Any noise comes from the per-user channels themselves; this class adds none of its own.

    Args:
        user_channels (Union[BaseChannel, List[BaseChannel]]): Channel instances for each user.
            Can be a single channel to be shared among all users, or a list (or tuple) of
            channels, one per user. A shared channel is the *same* object for every user, so any
            internal state it keeps is advanced once per user per forward call.
        num_users (Optional[int]): Number of users. Required if user_channels is a single
            channel instance. Inferred from the list length if user_channels is a list.
        user_gains (Optional[Union[float, List[float]]]): Per-user channel gains. Can be a
            single gain applied to all users, or a list/tuple/1-D tensor of gains (one per
            user). A 0-dim tensor is treated as a single gain. Defaults to 1.0 for all users.
        interference_power (float): Power of inter-user interference. Defaults to 0.0.
        combine_method (str): Method for combining user signals.
            Options: 'sum', 'weighted_sum'. Defaults to 'sum'.

    Example:
        >>> # Using the same AWGN channel for all users
        >>> from kaira.channels import AWGNChannel, UplinkMACChannel
        >>> base_channel = AWGNChannel(avg_noise_power=0.1)
        >>> uplink_channel = UplinkMACChannel(
        ...     user_channels=base_channel,
        ...     num_users=3,
        ...     user_gains=[1.0, 0.8, 0.6]
        ... )

        >>> # Using different channels for each user
        >>> from kaira.channels import FlatFadingChannel, RayleighFadingChannel
        >>> user_channels = [
        ...     AWGNChannel(avg_noise_power=0.1),
        ...     FlatFadingChannel(fading_type="rayleigh", coherence_time=10, avg_noise_power=0.05),
        ...     RayleighFadingChannel(coherence_time=5, avg_noise_power=0.15)
        ... ]
        >>> uplink_channel = UplinkMACChannel(user_channels=user_channels)
    """

    def __init__(
        self,
        user_channels: Union[BaseChannel, List[BaseChannel]],
        num_users: Optional[int] = None,
        user_gains: Optional[Union[float, List[float]]] = None,
        interference_power: float = 0.0,
        combine_method: str = "sum",
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the UplinkMAC channel.

        Args:
            user_channels (Union[BaseChannel, List[BaseChannel]]): Channel instances for each
                user (a single shared channel, or a list/tuple with one channel per user).
            num_users (Optional[int]): Number of users. Required if user_channels is a single
                channel.
            user_gains (Optional[Union[float, List[float]]]): Per-user channel gains: a number,
                a list/tuple of numbers, or a tensor (0-dim or 1-D). Defaults to all ones.
            interference_power (float): Power of inter-user interference. Defaults to 0.0.
            combine_method (str): Method for combining user signals. Defaults to 'sum'.
            *args: Variable length argument list passed to the base class.
            **kwargs: Arbitrary keyword arguments passed to the base class.

        Raises:
            ValueError: If the number of channels/gains does not match num_users, if
                num_users is missing or not positive, if interference_power is negative,
                or if combine_method is unknown.
            TypeError: If user_channels or user_gains has an unsupported type.
        """
        super().__init__(*args, **kwargs)

        # Validate and set up user channels
        if isinstance(user_channels, (list, tuple)):
            if num_users is None:
                num_users = len(user_channels)
            elif num_users != len(user_channels):
                raise ValueError(f"Number of user channels ({len(user_channels)}) must match num_users ({num_users})")
            # An empty channel list would otherwise give num_users == 0 and make forward()
            # fail later with an IndexError instead of a clear configuration error.
            if num_users <= 0:
                raise ValueError("num_users must be positive")
            # Copy so that later mutation of the caller's list cannot desynchronize the
            # stored channels from num_users.
            # NOTE(review): if BaseChannel is a torch.nn.Module, a plain list does not register
            # these channels as submodules (.to()/.parameters() will not see them); consider
            # torch.nn.ModuleList — TODO confirm BaseChannel's base class.
            self.user_channels = list(user_channels)
            self.shared_channel = False
        elif isinstance(user_channels, BaseChannel):
            if num_users is None:
                raise ValueError("num_users must be specified when using a shared channel")
            if num_users <= 0:
                raise ValueError("num_users must be positive")
            # Every entry references the same channel object.
            self.user_channels = [user_channels] * num_users
            self.shared_channel = True
        else:
            raise TypeError("user_channels must be a BaseChannel instance or a list of BaseChannel instances")

        self.num_users = num_users

        # Validate and set up user gains.
        # Tensors (e.g. the "user_gains" value returned by get_config()) are converted to plain
        # Python values first: a 0-dim tensor becomes a number, a 1-D tensor becomes a list.
        if isinstance(user_gains, torch.Tensor):
            user_gains = user_gains.detach().tolist()

        if user_gains is None:
            self.user_gains = torch.ones(self.num_users, dtype=torch.float32)
        elif isinstance(user_gains, (int, float)):
            self.user_gains = torch.full((self.num_users,), float(user_gains), dtype=torch.float32)
        elif isinstance(user_gains, (list, tuple)):
            if len(user_gains) != self.num_users:
                raise ValueError(f"Length of user_gains ({len(user_gains)}) must match num_users ({self.num_users})")
            self.user_gains = torch.tensor(user_gains, dtype=torch.float32)
        else:
            raise TypeError("user_gains must be a number or a list of numbers")

        # Validate interference power
        if interference_power < 0:
            raise ValueError("interference_power must be non-negative")
        self.interference_power = interference_power

        # Validate combine method
        valid_methods = ["sum", "weighted_sum"]
        if combine_method not in valid_methods:
            raise ValueError(f"combine_method must be one of {valid_methods}")
        self.combine_method = combine_method

    def forward(
        self,
        x: List[torch.Tensor],
        *args: Any,
        user_csi: Optional[List[torch.Tensor]] = None,
        user_noise: Optional[List[torch.Tensor]] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Apply uplink MAC channel effects to user signals.

        Args:
            x (List[torch.Tensor]): List of input signals, one per user. Each tensor should
                have the same shape.
            *args: Additional positional arguments passed to individual channels.
            user_csi (Optional[List[torch.Tensor]]): Per-user channel state information.
                If provided, entry i is passed to user i's channel as the ``csi`` keyword.
            user_noise (Optional[List[torch.Tensor]]): Per-user noise tensors.
                If provided, entry i is passed to user i's channel as the ``noise`` keyword.
            **kwargs: Additional keyword arguments passed to individual channels.

        Returns:
            torch.Tensor: Combined received signal after applying channel effects,
            per-user gains and inter-user interference.

        Raises:
            TypeError: If x is not a list.
            ValueError: If the number of input signals doesn't match num_users.
            ValueError: If user_csi or user_noise lists don't match num_users.
            ValueError: If the user signals do not all have the same shape.
        """
        user_signals = x

        # Validate inputs
        if not isinstance(x, list):
            raise TypeError("user_signals must be a list of torch.Tensors")
        if len(user_signals) != self.num_users:
            raise ValueError(f"Expected {self.num_users} user signals, got {len(user_signals)}")
        if user_csi is not None and len(user_csi) != self.num_users:
            raise ValueError(f"Expected {self.num_users} user_csi, got {len(user_csi)}")
        if user_noise is not None and len(user_noise) != self.num_users:
            raise ValueError(f"Expected {self.num_users} user_noise, got {len(user_noise)}")

        # All user signals must share one shape so they can later be stacked and summed.
        reference_shape = user_signals[0].shape
        for i, signal in enumerate(user_signals[1:], 1):
            if signal.shape != reference_shape:
                raise ValueError(f"All user signals must have the same shape. " f"User 0: {reference_shape}, User {i}: {signal.shape}")

        # Process each user's signal through their respective channel
        processed_signals = []
        for i, (channel, signal) in enumerate(zip(self.user_channels, user_signals)):
            # Per-user copy so csi/noise for one user never leaks into another user's call.
            channel_kwargs = kwargs.copy()
            if user_csi is not None:
                channel_kwargs["csi"] = user_csi[i]
            if user_noise is not None:
                channel_kwargs["noise"] = user_noise[i]

            # Apply channel effects
            processed_signal = channel(signal, *args, **channel_kwargs)

            # Apply user-specific gain (skipped for unit gain to avoid a needless multiply)
            gain_value = float(self.user_gains[i])
            if gain_value != 1.0:
                processed_signal = processed_signal * gain_value

            processed_signals.append(processed_signal)

        # Add inter-user interference if specified
        if self.interference_power > 0:
            processed_signals = self._add_interference(processed_signals)

        # Combine signals according to the specified method
        return self._combine_signals(processed_signals)

    def _add_interference(self, processed_signals: List[torch.Tensor]) -> List[torch.Tensor]:
        """Add inter-user interference to processed signals.

        User i receives, in addition to its own signal, every other user's signal scaled by
        ``sqrt(interference_power) / sqrt(num_users - 1)``.

        NOTE(review): because both combine methods then sum the per-user signals, this is
        equivalent to scaling the combined signal by
        ``1 + sqrt(interference_power * (num_users - 1))`` — confirm this is the intended model.

        Args:
            processed_signals (List[torch.Tensor]): List of processed user signals.

        Returns:
            List[torch.Tensor]: Signals with added interference (new tensors; inputs are not
            modified).
        """
        if self.interference_power <= 0:
            return processed_signals

        num_interferers = self.num_users - 1
        interfered_signals = []
        for i, signal in enumerate(processed_signals):
            # These scalars depend only on the power, the user count and the device, so build
            # them once per user rather than once per (user, interferer) pair.
            interference_scale = torch.sqrt(torch.tensor(self.interference_power, device=signal.device))
            normalization = torch.sqrt(torch.tensor(num_interferers, device=signal.device))

            # Generate interference from other users (no self-interference)
            interference = torch.zeros_like(signal)
            for j, other_signal in enumerate(processed_signals):
                if j == i:
                    continue
                interference += other_signal * interference_scale / normalization

            interfered_signals.append(signal + interference)

        return interfered_signals

    def _combine_signals(self, signals: List[torch.Tensor]) -> torch.Tensor:
        """Combine processed user signals according to the specified method.

        Both supported methods currently sum the signals: per-user gains are already applied
        in ``forward`` regardless of the method, so 'weighted_sum' and 'sum' give the same
        result.

        Args:
            signals (List[torch.Tensor]): List of processed user signals.

        Returns:
            torch.Tensor: Combined signal.

        Raises:
            ValueError: If combine_method is not recognized.
        """
        if self.combine_method == "sum":
            # Simple summation (superposition principle)
            return torch.sum(torch.stack(signals), dim=0)
        elif self.combine_method == "weighted_sum":
            # Gains were already applied in forward(), so this is also a plain sum.
            return torch.sum(torch.stack(signals), dim=0)
        else:
            # Unreachable through __init__ validation; guards against later attribute mutation.
            raise ValueError(f"Unknown combine method: {self.combine_method}")

    def get_user_csi(self, user_idx: int) -> Optional[torch.Tensor]:
        """Get channel state information for a specific user.

        Prefers the channel's ``get_csi()`` method, falls back to its ``csi`` attribute, and
        returns None when the channel exposes neither.

        Args:
            user_idx (int): Index of the user (0-based).

        Returns:
            Optional[torch.Tensor]: CSI for the specified user, if available.

        Raises:
            ValueError: If user_idx is out of range.
        """
        if not 0 <= user_idx < self.num_users:
            raise ValueError(f"User index {user_idx} is out of range for {self.num_users} users")

        channel = self.user_channels[user_idx]

        # Try to get CSI if the channel supports it
        if hasattr(channel, "get_csi"):
            return channel.get_csi()
        elif hasattr(channel, "csi"):
            return channel.csi
        else:
            return None

    def update_user_gain(self, user_idx: int, new_gain: float) -> None:
        """Update the channel gain for a specific user.

        Args:
            user_idx (int): Index of the user (0-based).
            new_gain (float): New gain value.

        Raises:
            ValueError: If user_idx is out of range.
        """
        if not 0 <= user_idx < self.num_users:
            raise ValueError(f"User index {user_idx} is out of range for {self.num_users} users")

        self.user_gains[user_idx] = float(new_gain)

    def update_interference_power(self, new_power: float) -> None:
        """Update the inter-user interference power.

        Args:
            new_power (float): New interference power.

        Raises:
            ValueError: If new_power is negative.
        """
        if new_power < 0:
            raise ValueError("interference_power must be non-negative")
        self.interference_power = new_power

    def get_config(self) -> Dict[str, Any]:
        """Get a dictionary of the channel's configuration.

        Extends the base-class config with this channel's settings and the configuration of
        the shared channel (``shared_channel_config``) or of each per-user channel
        (``user_channel_configs``).

        Returns:
            Dict[str, Any]: Dictionary of parameter names and values. ``user_gains`` is the
            float32 gain tensor itself (not a copy).
        """
        config = super().get_config()
        config.update(
            {
                "num_users": self.num_users,
                "user_gains": self.user_gains,
                "interference_power": self.interference_power,
                "combine_method": self.combine_method,
                "shared_channel": self.shared_channel,
            }
        )

        # Add channel configurations
        if self.shared_channel:
            config["shared_channel_config"] = self.user_channels[0].get_config()
        else:
            config["user_channel_configs"] = [ch.get_config() for ch in self.user_channels]

        return config

    def __repr__(self) -> str:
        """String representation of the UplinkMACChannel.

        Returns:
            str: String representation of the channel.
        """
        return f"UplinkMACChannel(num_users={self.num_users}, " f"user_gains={self.user_gains.tolist()}, " f"interference_power={self.interference_power}, " f"combine_method={self.combine_method})"