Source code for kaira.models.components.afmodule

"""Attention Feature (AF) Module implementation."""

from typing import Any

import torch
from torch import nn

from kaira.models.base import ChannelAwareBaseModel

from ..registry import ModelRegistry


@ModelRegistry.register_model()
class AFModule(ChannelAwareBaseModel):
    """AFModule: Attention-Feature Module :cite:`xu2021wireless`.

    This module implements an attention mechanism that recalibrates feature maps by
    explicitly modeling interdependencies between channel state information and the
    input features. This module allows the same model to be used during training and
    testing across channels with different signal-to-noise ratio without significant
    performance degradation.
    """

    def __init__(self, N: int, csi_length: int, *args: Any, **kwargs: Any) -> None:
        """Initialize the AFModule.

        Args:
            N (int): The number of input and output features.
            csi_length (int): The length of the channel state information.
            *args: Variable positional arguments passed to the base class.
            **kwargs: Variable keyword arguments passed to the base class.
        """
        super().__init__(*args, **kwargs)
        self.c_in = N
        # Two-layer MLP mapping [context, csi] to one gate in (0, 1) per feature.
        self.layers = nn.Sequential(
            nn.Linear(in_features=N + csi_length, out_features=N),
            nn.LeakyReLU(),
            nn.Linear(in_features=N, out_features=N),
            nn.Sigmoid(),
        )

    @staticmethod
    def _fit_width(tensor: torch.Tensor, width: int, fill_value: float) -> torch.Tensor:
        """Trim or pad ``tensor`` along dim 1 so that it has ``width`` columns.

        Padding is allocated from ``tensor`` itself so it inherits the same dtype
        and device; a default-dtype pad would make ``torch.cat`` promote
        half-precision tensors to float32.

        Args:
            tensor (torch.Tensor): Tensor whose second dimension is adjusted.
            width (int): Target size of dim 1.
            fill_value (float): Value used for the padded columns.

        Returns:
            torch.Tensor: ``tensor`` with dim 1 trimmed or padded to ``width``.
        """
        current = tensor.shape[1]
        if current > width:
            return tensor[:, :width]
        if current < width:
            padding = tensor.new_full((tensor.shape[0], width - current), fill_value)
            return torch.cat([tensor, padding], dim=1)
        return tensor

    def forward(self, x: torch.Tensor, csi: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the AFModule.

        The per-sample context vector is derived from ``x`` according to its rank:

        * 4-D: features are on dim 1; the context is the mean over dims 2 and 3.
        * 3-D: features are on the last dim; the context is the mean over dim 1.
        * 1-D: treated as one feature per sample and returned in its original shape.
        * otherwise: features are on dim 1 and the input itself is the context.

        If the context width does not match what the first linear layer expects it
        is trimmed or zero-padded. If the input has a different number of features
        than the mask, the mask is trimmed or extended with ones (which leave the
        extra features unchanged).

        Args:
            x (torch.Tensor): The input tensor.
            csi (torch.Tensor): Channel State Information tensor. A 1-D tensor is
                viewed as ``(batch_size, 1)``; tensors with more than two dimensions
                are flattened from dim 1 onwards.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: The output tensor after applying the attention mechanism.
        """
        # Promote a 1-D input to (batch, 1) so the column-wise logic below applies;
        # the extra dimension is removed again before returning.
        squeeze_output = x.dim() == 1
        features = x.unsqueeze(1) if squeeze_output else x

        input_dims = features.dim()
        batch_size = features.shape[0]

        # Locate the feature axis and reduce the input to a (batch, features) context.
        if input_dims == 4:
            actual_channels = features.shape[1]
            context = features.mean(dim=(2, 3))
        elif input_dims == 3:
            actual_channels = features.shape[2]
            context = features.mean(dim=1)
        else:
            actual_channels = features.shape[1]
            context = features

        # Bring csi to a 2-D (batch, csi_features) layout.
        if csi.dim() == 1:
            csi = csi.view(batch_size, 1)
        elif csi.dim() > 2:
            csi = csi.flatten(start_dim=1)

        # The first linear layer expects N + csi_length inputs; whatever csi does
        # not occupy must be filled by the context.
        expected_context_dim = self.layers[0].in_features - csi.shape[1]
        context = self._fit_width(context, expected_context_dim, fill_value=0.0)

        mask = self.layers(torch.cat([context, csi], dim=1))

        # Match the mask to the real feature count while it is still 2-D, then
        # add singleton dims so it broadcasts over the non-feature axes.
        mask = self._fit_width(mask, actual_channels, fill_value=1.0)
        if input_dims == 4:
            mask = mask[:, :, None, None]
        elif input_dims == 3:
            mask = mask[:, None, :]

        out = mask * features
        return out.squeeze(1) if squeeze_output else out