Source code for kaira.models.image.yilmaz2023_deepjscc_noma

"""DeepJSCC-NOMA module for Kaira.

This module contains the Yilmaz2023DeepJSCCNOMA model, which implements Distributed Deep Joint
Source-Channel Coding over a Multiple Access Channel as described in the paper by Yilmaz et al.
(2023).
"""

from typing import Any, List, Optional, Tuple, Type, Union

import torch
from torch import nn

from kaira.channels.base import BaseChannel
from kaira.constraints.base import BaseConstraint
from kaira.models.base import BaseModel
from kaira.models.image.tung2022_deepjscc_q import (
    Tung2022DeepJSCCQ2Decoder,
    Tung2022DeepJSCCQ2Encoder,
)
from kaira.models.multiple_access_channel import MultipleAccessChannelModel
from kaira.models.registry import ModelRegistry


@ModelRegistry.register_model()
class Yilmaz2023DeepJSCCNOMAEncoder(Tung2022DeepJSCCQ2Encoder):
    """DeepJSCC-NOMA Encoder Module :cite:`yilmaz2023distributed`.

    Maps input images to latent representations. The class is a thin
    specialisation of :class:`Tung2022DeepJSCCQ2Encoder` that only changes the
    constructor defaults to the values used in :cite:t:`yilmaz2023distributed`.
    """

    def __init__(self, N=64, M=16, in_ch=4, csi_length=1, *args: Any, **kwargs: Any):
        """Initialize the DeepJSCCNOMAEncoder.

        Args:
            N (int, optional): Number of channels in the network. Defaults to 64.
            M (int, optional): Latent dimension of the bottleneck representation.
                Defaults to 16.
            in_ch (int, optional): Number of input channels. Defaults to 4.
            csi_length (int, optional): The number of dimensions in the CSI data.
                Defaults to 1.
            *args: Accepted for interface compatibility; not forwarded to the
                base class.
            **kwargs: Accepted for interface compatibility; not forwarded to the
                base class.
        """
        # Only the four named hyper-parameters reach the parent constructor.
        base_config = {
            "N": N,
            "M": M,
            "in_ch": in_ch,
            "csi_length": csi_length,
        }
        super().__init__(**base_config)

    # No forward() override: the implementation inherited from
    # Tung2022DeepJSCCQ2Encoder is used unchanged.
@ModelRegistry.register_model()
class Yilmaz2023DeepJSCCNOMADecoder(Tung2022DeepJSCCQ2Decoder):
    """DeepJSCC-NOMA Decoder Module :cite:`yilmaz2023distributed`.

    This decoder reconstructs images from received channel signals, supporting
    both individual device decoding and shared decoding for multiple devices.

    This class extends the Tung2022DeepJSCCQ2Decoder class with parameter
    adaptation as used in the paper :cite:t:`yilmaz2023distributed`.

    Attributes:
        num_devices: Number of devices whose reconstructions this decoder emits.
        shared_decoder: Flag recorded from the constructor; not otherwise used
            by this class.
        out_ch_per_device: Number of output channels belonging to one device.
            The total number of output channels requested from the base class
            is ``num_devices * out_ch_per_device``.
    """

    def __init__(self, N=64, M=16, out_ch_per_device=3, csi_length=1, num_devices=1, shared_decoder=False, *args: Any, **kwargs: Any):
        """Initialize the DeepJSCCNOMADecoder.

        Args:
            N (int, optional): Number of channels in the network. Defaults to 64.
            M (int, optional): Latent dimension of the bottleneck representation.
                Defaults to 16.
            out_ch_per_device (int, optional): Number of output channels per
                device. Defaults to 3.
            csi_length (int, optional): The number of dimensions in the CSI data.
                Defaults to 1.
            num_devices (int, optional): Number of devices. Used for shared
                decoder. Defaults to 1.
            shared_decoder (bool, optional): Whether this is a shared decoder.
                Defaults to False.
            *args: Accepted for interface compatibility; not forwarded to the
                base class.
            **kwargs: Accepted for interface compatibility; not forwarded to the
                base class.
        """
        # These are plain Python values (not Parameters/Modules), so assigning
        # them before the parent constructor runs is kept exactly as before.
        self.num_devices = num_devices
        self.shared_decoder = shared_decoder
        # Keep the per-device channel count on the instance: the NOMA model's
        # perfect-SIC path reads ``decoder.out_ch_per_device`` to decide whether
        # a shared decoder's output must be sliced per device.
        self.out_ch_per_device = out_ch_per_device

        super().__init__(
            N=N,
            M=M,
            out_ch=self.num_devices * out_ch_per_device,
            csi_length=csi_length,
        )
# Forward method is inherited from Tung2022DeepJSCCQ2Decoder, which already handles *args, **kwargs # Use Tung2022DeepJSCCQ2 models as default DEFAULT_ENCODER = Yilmaz2023DeepJSCCNOMAEncoder DEFAULT_DECODER = Yilmaz2023DeepJSCCNOMADecoder
@ModelRegistry.register_model("deepjscc_noma")
class Yilmaz2023DeepJSCCNOMAModel(MultipleAccessChannelModel):
    """Distributed Deep Joint Source-Channel Coding over a Multiple Access Channel
    :cite:`yilmaz2023distributed`.

    This model implements the DeepJSCC-NOMA system from the paper by Yilmaz et al.
    (2023), which enables multiple devices to transmit jointly encoded data over a
    shared wireless channel using Non-Orthogonal Multiple Access (NOMA).

    Attributes:
        M: Channel bandwidth expansion/compression factor
        latent_dim: Dimension of latent representation
        shared_encoder: Whether ``encoders[0]`` is used for every device
        shared_decoder: Whether ``decoders[0]`` is used for every device
        use_perfect_sic: Whether to use perfect successive interference cancellation
        use_device_embedding: Whether to use device embeddings
        image_shape: Shape of the input images used for embedding
        embedding_dim: ``image_shape[0] * image_shape[1]``, the embedding width
        csi_length: Length of the CSI vector
        num_devices: Number of transmitting devices
        device_images: Embedding table for device-specific embeddings (only
            created when ``use_device_embedding`` is true)
    """

    def __init__(
        self,
        channel: BaseChannel,
        power_constraint: BaseConstraint,
        encoder: Optional[Union[Type[BaseModel], BaseModel]] = None,  # Allow class or instance
        decoder: Optional[Union[Type[BaseModel], BaseModel]] = None,  # Allow class or instance
        num_devices: int = 2,
        M: float = 1.0,
        latent_dim: int = 16,
        shared_encoder: bool = False,
        shared_decoder: bool = False,
        use_perfect_sic: bool = False,
        use_device_embedding: Optional[bool] = None,
        image_shape: Tuple[int, int] = (32, 32),
        csi_length: int = 1,
        ckpt_path: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the DeepJSCC-NOMA model.

        Args:
            channel: Channel model for transmission
            power_constraint: Power constraint to apply to transmitted signals
            encoder: Encoder network class or instance
                (default: Yilmaz2023DeepJSCCNOMAEncoder)
            decoder: Decoder network class or instance
                (default: Yilmaz2023DeepJSCCNOMADecoder)
            num_devices: Number of transmitting devices
            M: Channel bandwidth expansion/compression factor
            latent_dim: Dimension of latent representation
            shared_encoder: Whether to use a shared encoder across devices
            shared_decoder: Whether to use a shared decoder across devices
            use_perfect_sic: Whether to use perfect successive interference cancellation
            use_device_embedding: Whether to use device embeddings. ``None``
                (the default) means "same as ``shared_encoder``".
            image_shape: Shape of input images (height, width) for determining
                embedding dimensions
            csi_length: The length of CSI (Channel State Information) vector
            ckpt_path: Path to checkpoint file for loading pre-trained weights
            *args: Accepted for interface compatibility; not used.
            **kwargs: Only ``N`` and ``out_ch_per_device`` are read, and only when
                per-device decoders are instantiated here. Nothing is forwarded
                to the base class.
        """
        # Plain (non-Module) attributes; assigned before the base constructor
        # runs, exactly as in the original implementation.
        self.shared_encoder = shared_encoder
        self.shared_decoder = shared_decoder
        self.M = M
        self.latent_dim = latent_dim
        self.use_perfect_sic = use_perfect_sic
        self.use_device_embedding = use_device_embedding if use_device_embedding is not None else shared_encoder
        self.image_shape = image_shape
        self.csi_length = csi_length
        # One embedding vector is reshaped to a single [H, W] image plane.
        self.embedding_dim = image_shape[0] * image_shape[1]
        self.num_devices = num_devices

        encoder_config = encoder if encoder is not None else DEFAULT_ENCODER
        decoder_config = decoder if decoder is not None else DEFAULT_DECODER

        final_decoder_config: Union[Type[BaseModel], BaseModel, List[BaseModel], nn.ModuleList]
        if not self.shared_decoder and isinstance(decoder_config, type):
            # A decoder *class* with non-shared decoding: build one independent
            # instance per device here, each producing a single device's output.
            decoder_args = {
                "N": kwargs.get("N", 64),
                "M": self.latent_dim,
                "out_ch_per_device": kwargs.get("out_ch_per_device", 3),
                "csi_length": self.csi_length,
                "num_devices": 1,
                "shared_decoder": False,
            }
            final_decoder_config = [decoder_config(**decoder_args) for _ in range(self.num_devices)]
        else:
            # Shared decoder, or the caller already supplied an instance: hand
            # it to the base class untouched.
            final_decoder_config = decoder_config

        super().__init__(
            encoders=encoder_config,
            decoders=final_decoder_config,
            channel=channel,
            power_constraint=power_constraint,
            num_devices=num_devices,
        )

        # Sub-modules can only be registered after the base constructor ran.
        if self.use_device_embedding:
            self.device_images = nn.Embedding(self.num_devices, embedding_dim=self.embedding_dim)

        # Weights are loaded last so that every sub-module already exists.
        # NOTE(review): `_load_checkpoint` is not defined in this class in the
        # visible source; presumably provided by a base class -- confirm.
        if ckpt_path is not None:
            self._load_checkpoint(ckpt_path)

    def _device_embedding(self, device_idx: int, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return the embedding plane of one device, shaped ``[batch_size, 1, H, W]``."""
        h, w = self.image_shape
        indices = torch.full((batch_size,), device_idx, dtype=torch.long, device=device)
        return self.device_images(indices).view(batch_size, 1, h, w)

    def _select_device_channels(self, decoder: nn.Module, output: torch.Tensor, device_idx: int) -> torch.Tensor:
        """Cut one device's channels out of a shared decoder's 4D output.

        The output is treated as holding all devices when it has more channels
        than ``decoder.out_ch_per_device``. Decoders that do not expose that
        attribute are treated as multi-device only if they report
        ``num_devices > 1``; otherwise the output is returned unchanged.
        """
        total_channels = output.shape[1]
        per_device = getattr(decoder, "out_ch_per_device", None)
        if per_device is not None:
            holds_all_devices = total_channels > per_device
        else:
            # Fallback so a decoder lacking `out_ch_per_device` does not raise
            # AttributeError here.
            holds_all_devices = getattr(decoder, "num_devices", 1) > 1

        if not holds_all_devices:
            return output

        width = total_channels // self.num_devices
        start_channel = device_idx * width
        end_channel = (device_idx + 1) * width
        return output[:, start_channel:end_channel, :, :]

    def forward(self, x: Union[List[torch.Tensor], torch.Tensor], *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the DeepJSCC-NOMA model.

        Args:
            x: Input data.
                - If use_perfect_sic=True: A single tensor
                  [batch_size, num_devices, channels, height, width].
                - If use_perfect_sic=False: A list of tensors
                  [batch_size, channels, height, width] for each device.
            *args: Additional positional arguments passed to internal components.
            **kwargs: Additional keyword arguments passed to internal components
                (e.g., csi).

        Returns:
            Reconstructed signals with shape
            [batch_size, num_devices, channels, height, width]

        Raises:
            ValueError: If ``x`` has the wrong type, length or shape for the
                selected mode, if ``csi`` is missing in perfect-SIC mode, or if
                a shared decoder's output channels cannot be divided evenly
                among the devices.
        """
        csi = kwargs.get("csi")

        # --- Perfect SIC Mode ---
        if self.use_perfect_sic:
            if not isinstance(x, torch.Tensor):
                raise ValueError(f"Input 'x' must be a single tensor when use_perfect_sic=True, but got {type(x)}.")
            if x.ndim != 5 or x.shape[1] != self.num_devices:
                raise ValueError(f"Input tensor shape mismatch for perfect SIC. Expected [B, {self.num_devices}, C, H, W], got {x.shape}")
            if csi is None:
                raise ValueError("Keyword argument 'csi' must be provided when use_perfect_sic=True.")

            # csi is passed explicitly, so it must not also travel in **kwargs.
            kwargs_without_csi = {key: value for key, value in kwargs.items() if key != "csi"}
            return self._forward_perfect_sic(x, csi, *args, **kwargs_without_csi)

        # --- Standard NOMA Mode (Non-Perfect SIC) ---
        if not isinstance(x, list) or len(x) != self.num_devices:
            raise ValueError(f"Input 'x' must be a list of {self.num_devices} tensors when use_perfect_sic=False, but got {type(x)} with length {len(x) if isinstance(x, list) else 'N/A'}.")

        processed_x: List[torch.Tensor]
        if self.use_device_embedding:
            if not x:
                raise ValueError("Input list 'x' cannot be empty when use_device_embedding is True.")

            batch_size = x[0].size(0)
            device = x[0].device
            processed_x = []
            # len(x) == num_devices was verified above, so enumerate covers
            # exactly one tensor per device.
            for i, device_image in enumerate(x):
                if device_image.ndim != 4:
                    raise ValueError(f"Input tensor for device {i} has unexpected dimensions: {device_image.shape}")
                emb_i = self._device_embedding(i, batch_size, device)
                # Append the embedding as one extra channel.
                processed_x.append(torch.cat([device_image, emb_i], dim=1))
        else:
            processed_x = x

        # Encode and power-normalise each device's signal.
        transmissions: List[torch.Tensor] = []
        for i in range(self.num_devices):
            encoder = self.encoders[0] if self.shared_encoder else self.encoders[i]
            tx = encoder(processed_x[i], *args, **kwargs)
            tx = self.power_constraint(tx, *args, **kwargs)
            transmissions.append(tx)

        # Superimpose all device signals (NOMA) and pass the sum through the channel.
        x_summed = torch.sum(torch.stack(transmissions, dim=1), dim=1)
        decoder_input = self.channel(x_summed, *args, **kwargs)

        decoded_outputs: List[torch.Tensor] = []
        if self.shared_decoder:
            decoder = self.decoders[0]
            # Expected: [B, D*C_out, H, W]
            x_decoded_all = decoder(decoder_input, *args, **kwargs)

            # Prefer the decoder's advertised channel count; otherwise fall back
            # to the actual output shape.
            num_out_channels_total = getattr(decoder, "out_ch", None)
            if num_out_channels_total is None:
                num_out_channels_total = x_decoded_all.shape[1]
                if num_out_channels_total % self.num_devices != 0:
                    raise ValueError(f"Cannot determine output channels per device for shared decoder. Output shape: {x_decoded_all.shape}")
            elif num_out_channels_total % self.num_devices != 0:
                raise ValueError(f"Shared decoder total output channels ({num_out_channels_total}) not divisible by num_devices ({self.num_devices}).")

            num_out_channels_per_device = num_out_channels_total // self.num_devices
            for i in range(self.num_devices):
                start_channel = i * num_out_channels_per_device
                end_channel = (i + 1) * num_out_channels_per_device
                decoded_outputs.append(x_decoded_all[:, start_channel:end_channel, :, :])
        else:
            # Every per-device decoder sees the same superimposed signal.
            for i in range(self.num_devices):
                decoded_outputs.append(self.decoders[i](decoder_input, *args, **kwargs))

        # Shape: [B, N_dev, C_out, H, W]
        return torch.stack(decoded_outputs, dim=1)

    def _forward_perfect_sic(self, x: torch.Tensor, csi: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass with perfect successive interference cancellation.

        Each device's signal is encoded, power-constrained and sent through the
        channel on its own (no superposition), then decoded separately.

        Args:
            x: Input data with shape [batch_size, num_devices, channels, height, width]
            csi: Channel state information with shape [batch_size, csi_length]
            *args: Additional positional arguments passed to internal components.
            **kwargs: Additional keyword arguments passed to internal components.
                Must not contain ``csi``; it is supplied explicitly to the
                encoders and decoders.

        Returns:
            Reconstructed signals with shape
            [batch_size, num_devices, channels, height, width]
        """
        if self.use_device_embedding:
            batch_size = x.size(0)
            # [B, D, 1, H, W]: one embedding plane per device.
            emb = torch.stack(
                [self._device_embedding(i, batch_size, x.device) for i in range(self.num_devices)],
                dim=1,
            )
            # Concatenate along the per-device channel axis.
            x = torch.cat([x, emb], dim=2)

        transmissions: List[torch.Tensor] = []
        for i in range(self.num_devices):
            encoder = self.encoders[0] if self.shared_encoder else self.encoders[i]
            t = encoder(x[:, i, ...], *args, csi=csi, **kwargs)
            # The power constraint and channel are called without csi, as in
            # the original implementation.
            t = self.power_constraint(t, *args, **kwargs)
            transmissions.append(self.channel(t, *args, **kwargs))

        results: List[torch.Tensor] = []
        for i in range(self.num_devices):
            decoder = self.decoders[0] if self.shared_decoder else self.decoders[i]
            xi = decoder(transmissions[i], *args, csi=csi, **kwargs)
            if self.shared_decoder and xi.ndim == 4:
                # A shared decoder may emit all devices' channels at once; keep
                # only the slice belonging to device i.
                xi = self._select_device_channels(decoder, xi, i)
            results.append(xi)

        # Stack to [B, D, C_out, H, W]
        return torch.stack(results, dim=1)