"""Multiple Access Channel models for Kaira.
This module implements models for Multiple Access Channel (MAC) scenarios, where multiple devices
transmit data simultaneously over a shared wireless channel. It provides base classes and utilities
for implementing various MAC protocols and studying their performance.
"""
from typing import Any, List, Optional, Type, Union
import torch
from torch import nn
from kaira.channels import BaseChannel
from kaira.constraints import BaseConstraint
from kaira.models.base import BaseModel
from kaira.models.registry import ModelRegistry
@ModelRegistry.register_model("multiple_access_channel")
class MultipleAccessChannelModel(BaseModel):
    """A model simulating a Multiple Access Channel (MAC).

    In a MAC scenario, multiple transmitters (users) send signals simultaneously
    over a shared channel to a single receiver. The receiver then attempts to
    decode the messages from all users.

    This model supports both shared and separate encoders/decoders based on the
    provided configuration during initialization. A single decoder instance implies
    joint decoding.

    Attributes:
        encoders (nn.ModuleList): One entry per user. A shared encoder instance is
            stored as ``num_users`` references to the same module.
        decoders (nn.ModuleList): Either a single (joint) decoder or one decoder
            per user.
        channel (BaseChannel): The communication channel model.
        power_constraint (BaseConstraint): Power constraint applied to the sum of
            encoded signals.
        num_users (int): The number of users (transmitters).
        num_devices (int): Alias of ``num_users`` kept for compatibility.
    """

    def __init__(
        self,
        encoders: Union[Type[BaseModel], BaseModel, List[BaseModel], nn.ModuleList],
        decoders: Union[Type[BaseModel], BaseModel, List[BaseModel], nn.ModuleList],
        channel: BaseChannel,
        power_constraint: BaseConstraint,
        num_devices: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the MultipleAccessChannelModel.

        Args:
            encoders: Encoder configuration. Can be:
                - A class: a separate instance is created for each device.
                - An instance: this instance is shared across all devices.
                - A list/ModuleList of instances: one encoder per device. Its
                  length must match ``num_devices``.
            decoders: Decoder configuration. Can be:
                - A class: a single instance is created (joint decoder).
                - An instance: used as the single joint decoder.
                - A list/ModuleList of instances: one decoder per device
                  (separate decoding), or a single-element list for joint
                  decoding.
            channel: The channel model instance.
            power_constraint: The power constraint instance.
            num_devices: The number of users/devices. Required when neither
                ``encoders`` nor ``decoders`` is a list; otherwise inferred from
                the list length (encoders take precedence).
            *args: Positional arguments forwarded to the base class initializer
                and to encoder/decoder classes when they are instantiated here.
            **kwargs: Keyword arguments forwarded to the base class initializer
                and to encoder/decoder classes when they are instantiated here.

        Raises:
            ValueError: If ``num_devices`` cannot be determined, is not a positive
                integer, or disagrees with the length of a provided list.
            TypeError: If an encoder/decoder configuration, the channel, or the
                power constraint has an unsupported type.
        """
        super().__init__(*args, **kwargs)

        # --- Determine number of devices ---
        # An encoder list is authoritative: it either defines num_devices or must
        # agree with the explicit argument.
        if isinstance(encoders, (list, nn.ModuleList)):
            num_encoders = len(encoders)
            if num_devices is None:
                num_devices = num_encoders
            elif num_devices != num_encoders:
                raise ValueError(f"Provided num_devices ({num_devices}) does not match the number of encoders ({num_encoders}).")

        # A decoder list must match too, except that a single-element list is
        # always accepted as a joint decoder.
        if isinstance(decoders, (list, nn.ModuleList)):
            num_decoders = len(decoders)
            if num_devices is None:
                num_devices = num_decoders
            elif num_devices != num_decoders and num_decoders != 1:
                raise ValueError(f"Provided num_devices ({num_devices}) does not match the number of decoders ({num_decoders}).")

        # Still unknown only when neither configuration was a list.
        if num_devices is None:
            raise ValueError("num_devices must be specified if encoders and decoders are not provided as lists.")

        if not isinstance(num_devices, int) or num_devices <= 0:
            raise ValueError(f"num_devices must be a positive integer, got {num_devices}")

        self.num_users = num_devices
        self.num_devices = num_devices  # Keep for compatibility

        # --- Initialize encoders and decoders ---
        self.encoders = self._initialize_modules(encoders, num_devices, "Encoder", *args, **kwargs)
        self.decoders = self._initialize_modules(decoders, num_devices, "Decoder", *args, **kwargs)

        # --- Assign channel and constraint ---
        if not isinstance(channel, BaseChannel):
            raise TypeError(f"Channel must be an instance of BaseChannel, but got {type(channel)}")
        self.channel = channel

        if not isinstance(power_constraint, BaseConstraint):
            raise TypeError(f"Power constraint must be an instance of BaseConstraint, but got {type(power_constraint)}")
        self.power_constraint = power_constraint

    def _initialize_modules(
        self,
        module_config: Union[Type[BaseModel], BaseModel, List[BaseModel], nn.ModuleList],
        num_devices: int,
        module_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> nn.ModuleList:
        """Build the ``nn.ModuleList`` of encoders or decoders from a configuration.

        Args:
            module_config: A class, a single module instance, or a list/ModuleList
                of module instances.
            num_devices: Number of users the resulting list has to serve.
            module_name: ``"Encoder"`` or ``"Decoder"``; selects the sharing rules
                and is used in error messages.
            *args: Positional arguments used when instantiating a class.
            **kwargs: Keyword arguments used when instantiating a class.

        Returns:
            For encoders, a list with ``num_devices`` entries (a shared instance is
            repeated). For decoders, either one joint decoder or ``num_devices``
            separate decoders.

        Raises:
            ValueError: If a list has the wrong length.
            TypeError: If the configuration, or any resulting element, is not an
                ``nn.Module``.
        """
        is_decoder = module_name == "Decoder"

        if isinstance(module_config, (list, nn.ModuleList)):
            # A single-element decoder list means joint decoding and is accepted
            # regardless of num_devices; every other list must have one entry per
            # device.
            is_joint_decoder_list = is_decoder and len(module_config) == 1
            if not is_joint_decoder_list and len(module_config) != num_devices:
                raise ValueError(f"Number of {module_name.lower()}s in the list ({len(module_config)}) must match num_devices ({num_devices}).")
            modules_list = list(module_config)
        elif isinstance(module_config, nn.Module):
            # Single instance: one joint decoder, or the same encoder referenced
            # once per device so it can be indexed by user.
            modules_list = [module_config] if is_decoder else [module_config] * num_devices
        elif isinstance(module_config, type):
            # Class: one joint decoder instance, or an independent encoder
            # instance for each device.
            instance_count = 1 if is_decoder else num_devices
            modules_list = [module_config(*args, **kwargs) for _ in range(instance_count)]
        else:
            raise TypeError(f"Invalid type for {module_name.lower()} configuration: {type(module_config)}")

        # Validate all items are nn.Module
        for i, mod in enumerate(modules_list):
            if not isinstance(mod, nn.Module):
                raise TypeError(f"{module_name} at index {i} (or the shared instance) must be an instance of nn.Module, but got {type(mod)}")

        return nn.ModuleList(modules_list)

    def forward(self, x: List[torch.Tensor], *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the Multiple Access Channel model.

        Each user's input is encoded, the encoded signals are summed, the power
        constraint is applied to the sum, the result is passed through the
        channel, and the received signal is decoded.

        Args:
            x: A list of input tensors, one for each user.
            *args: Additional positional arguments passed to encoders, channel,
                and decoder(s).
            **kwargs: Additional keyword arguments passed to encoders, channel,
                and decoder(s).

        Returns:
            The output of the joint decoder, or the outputs of the per-user
            decoders concatenated along ``dim=1``.

        Raises:
            ValueError: If ``x`` is not a list of tensors, its length differs from
                ``num_users``, or the encoder/decoder lists are empty or have an
                inconsistent length.
            IndexError: If there are fewer encoders than users (and more than one).
        """
        if not isinstance(x, list) or not all(isinstance(t, torch.Tensor) for t in x):
            raise ValueError("Input 'x' must be a list of torch.Tensors.")
        if len(x) != self.num_users:
            raise ValueError(f"Number of input tensors ({len(x)}) must match the number of users ({self.num_users}).")
        if not self.encoders:
            raise ValueError("Encoders must be initialized before calling forward.")
        if not self.decoders:
            raise ValueError("Decoders must be initialized before calling forward.")

        # 1. Encode messages for each user.
        # Encoders are stored one per user (a shared encoder is replicated by
        # reference), so indexing by user is correct for shared, separate and
        # partially shared lists alike. A one-element list is still honoured as
        # "shared by everyone" in case the list was replaced after construction.
        num_encoders = len(self.encoders)
        if num_encoders == 1:
            user_encoders = [self.encoders[0]] * self.num_users
        elif num_encoders >= self.num_users:
            user_encoders = [self.encoders[i] for i in range(self.num_users)]
        else:
            raise IndexError(f"Encoder index calculation error. is_shared=False, index={num_encoders}, list_len={num_encoders}")

        encoded_signals = [encoder(message, *args, **kwargs) for encoder, message in zip(user_encoders, x)]

        # 2. Combine encoded signals (summing them simulates superposition on the channel)
        combined_signal = torch.sum(torch.stack(encoded_signals), dim=0)

        # 3. Apply power constraint to the combined signal
        constrained_signal = self.power_constraint(combined_signal)

        # 4. Pass the combined signal through the channel
        received_signal = self.channel(constrained_signal, *args, **kwargs)

        # 5. Decode the received signal
        if len(self.decoders) == 1:
            # Joint decoding: a single decoder processes the received signal.
            return self.decoders[0](received_signal, *args, **kwargs)

        # Separate decoding: every decoder sees the same received signal and
        # produces the reconstruction for its own user.
        if len(self.decoders) != self.num_users:
            raise ValueError(f"Number of separate decoders ({len(self.decoders)}) must match num_users ({self.num_users}) for separate decoding mode.")

        per_user_outputs = [decoder(received_signal, *args, **kwargs) for decoder in self.decoders]

        # Concatenate the outputs along the feature dimension
        return torch.cat(per_user_outputs, dim=1)