# Source code for kaira.models.base

"""Base model definitions for deep learning architectures.

This module provides the foundation for all model implementations in the Kaira framework. The
BaseModel class implements common functionality and enforces a consistent interface across
different model types.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional  # Added imports

import torch
from torch import nn


class BaseModel(nn.Module, ABC):
    """Root abstract class from which every Kaira model derives.

    The class combines ``torch.nn.Module`` with ``abc.ABC``: it behaves like
    any other PyTorch module, but it cannot be instantiated until a subclass
    supplies a concrete :meth:`forward`. Deriving from it gives all models the
    same minimal interface for the framework's training, evaluation, and
    inference pipelines, while leaving the architecture entirely up to the
    subclass.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Set up the underlying ``nn.Module`` state.

        Args:
            *args: Accepted so subclasses may forward arbitrary positional
                arguments; they are not used by this constructor.
            **kwargs: Accepted so subclasses may forward arbitrary keyword
                arguments; they are not used by this constructor.
        """
        # Only the nn.Module machinery is initialised here; the extra
        # arguments are deliberately not passed on.
        super().__init__()

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the model's forward computation.

        Every concrete subclass must override this method to describe how its
        inputs are turned into outputs.

        Args:
            *args: Positional inputs; their meaning is defined by the subclass.
            **kwargs: Optional keyword inputs; their meaning is defined by the
                subclass.

        Returns:
            Any: The model output, whose type depends on the implementation.

        Raises:
            NotImplementedError: When this base implementation is invoked
                directly instead of an override.
        """
        raise NotImplementedError("Subclasses must implement forward method")
class ChannelAwareBaseModel(BaseModel):
    """Abstract base class for models that require channel state information (CSI).

    This class extends BaseModel to standardize how CSI is handled in
    channel-aware models throughout the Kaira framework. It provides utility
    methods for CSI validation, normalization, and transformation, ensuring
    consistent CSI handling across different model implementations.

    CSI typically contains information about the communication channel
    conditions such as:

    - Signal-to-noise ratio (SNR) in dB
    - Channel gain coefficients
    - Fading characteristics
    - Quality indicators

    All subclasses must implement the forward method with explicit CSI parameter.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Initialize the channel-aware model.

        Args:
            *args: Variable positional arguments passed to BaseModel.
            **kwargs: Variable keyword arguments passed to BaseModel.
        """
        super().__init__(*args, **kwargs)

        # CSI configuration
        # Shape of the first CSI tensor that passed validate_csi (None until then).
        self._csi_shape_cache: Optional[torch.Size] = None
        # Not read or written anywhere else in this class.
        self._expected_csi_length: Optional[int] = None

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Define the forward pass computation with CSI support.

        This method must be implemented by all subclasses to define how input
        data is processed through the model with channel state information.

        All implementations must accept CSI either as a positional argument or
        keyword argument named 'csi'. The exact signature can vary to
        accommodate different model architectures (e.g., single input + CSI,
        or multiple inputs + CSI).

        Common patterns:

        - forward(self, x, csi, *args, **kwargs) -> single input with CSI
        - forward(self, x, x_side, csi, *args, **kwargs) -> multiple inputs with CSI
        - forward(self, *args, csi=csi, **kwargs) -> CSI as keyword argument

        Args:
            *args: Variable positional arguments for flexible input handling
            **kwargs: Variable keyword arguments, must support 'csi' parameter

        Returns:
            Any: Model output, type depends on specific implementation

        Raises:
            NotImplementedError: If the subclass does not implement this method
        """
        raise NotImplementedError("Subclasses must implement forward method with CSI support")

    def validate_csi(self, csi: torch.Tensor, expected_shape: Optional[torch.Size] = None) -> torch.Tensor:
        """Validate and ensure CSI tensor is in the correct format.

        The checks run in this order: type, non-emptiness, NaN, infinity, and
        finally (only when ``expected_shape`` is given) exact shape equality.
        The shape of the first tensor that passes all checks is stored in
        ``self._csi_shape_cache``; the cached value is recorded only and is
        not used to validate later calls.

        Args:
            csi (torch.Tensor): The CSI tensor to validate
            expected_shape (Optional[torch.Size]): Expected shape for the CSI
                tensor. If None, no shape check is performed.

        Returns:
            torch.Tensor: The same tensor object that was passed in

        Raises:
            TypeError: If CSI is not a tensor
            ValueError: If CSI tensor is empty, contains NaN or infinite
                values, or does not match ``expected_shape``
        """
        if not isinstance(csi, torch.Tensor):
            raise TypeError(f"CSI must be a torch.Tensor, got {type(csi)}")

        if csi.numel() == 0:
            raise ValueError("CSI tensor cannot be empty")

        if torch.isnan(csi).any():
            raise ValueError("CSI tensor contains NaN values")

        if torch.isinf(csi).any():
            raise ValueError("CSI tensor contains infinite values")

        # Validate shape if expected_shape is provided
        if expected_shape is not None and csi.shape != expected_shape:
            raise ValueError(f"CSI shape mismatch. Expected {expected_shape}, got {csi.shape}")

        # Remember the first validated shape
        if self._csi_shape_cache is None:
            self._csi_shape_cache = csi.shape

        return csi

    def normalize_csi(self, csi: torch.Tensor, method: str = "minmax", target_range: tuple = (0.0, 1.0)) -> torch.Tensor:
        """Normalize CSI values to a specified range.

        For tensors with two or more dimensions the statistics are computed
        along ``dim=1`` (keeping that dimension for broadcasting); for tensors
        with fewer dimensions they are computed over the whole tensor. A zero
        range (minmax) or zero standard deviation (zscore) is replaced by one
        so constant inputs do not cause a division by zero.

        Args:
            csi (torch.Tensor): The CSI tensor to normalize
            method (str): Normalization method. Options: "minmax", "zscore", "none"
            target_range (tuple): Target range for minmax normalization (min, max)

        Returns:
            torch.Tensor: Normalized CSI tensor (the input itself for "none")

        Raises:
            ValueError: If normalization method is not supported
        """
        if method == "none":
            return csi

        if method == "minmax":
            min_val, max_val = target_range

            if csi.dim() >= 2:
                csi_min = torch.min(csi, dim=1, keepdim=True)[0]
                csi_max = torch.max(csi, dim=1, keepdim=True)[0]
            else:
                csi_min = torch.min(csi)
                csi_max = torch.max(csi)

            # Avoid division by zero
            range_vals = csi_max - csi_min
            range_vals = torch.where(range_vals == 0, torch.ones_like(range_vals), range_vals)

            normalized = (csi - csi_min) / range_vals
            return normalized * (max_val - min_val) + min_val

        if method == "zscore":
            # torch.mean / torch.std reject integer dtypes, whereas the minmax
            # branch handles them through true division. Promote such input so
            # both methods accept the same tensors.
            if not (csi.is_floating_point() or csi.is_complex()):
                csi = csi.to(torch.get_default_dtype())

            if csi.dim() >= 2:
                mean = torch.mean(csi, dim=1, keepdim=True)
                std = torch.std(csi, dim=1, keepdim=True, unbiased=False)
            else:
                mean = torch.mean(csi)
                std = torch.std(csi, unbiased=False)

            # Avoid division by zero
            std = torch.where(std == 0, torch.ones_like(std), std)
            return (csi - mean) / std

        raise ValueError(f"Unsupported normalization method: {method}")

    def transform_csi(self, csi: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
        """Transform CSI tensor to match target shape requirements.

        Steps, stopping as soon as the shape matches:

        1. If the target has more dimensions, append trailing singleton
           dimensions; if it has fewer, flatten the trailing dimensions.
        2. If the element counts agree, reshape.
        3. Otherwise flatten, then either pad with copies of the last element
           (zeros when the input is empty) or truncate, and reshape.

        Args:
            csi (torch.Tensor): The CSI tensor to transform
            target_shape (torch.Size): Target shape for the CSI tensor

        Returns:
            torch.Tensor: Transformed CSI tensor with shape ``target_shape``
        """
        # If shapes already match, return as-is
        if csi.shape == target_shape:
            return csi

        target_rank = len(target_shape)
        if target_rank > csi.dim():
            # Add trailing dimensions
            while csi.dim() < target_rank:
                csi = csi.unsqueeze(-1)
        elif target_rank < csi.dim():
            # Remove dimensions by flattening the trailing ones
            csi = csi.flatten(start_dim=target_rank - 1)

        if csi.shape == target_shape:
            return csi

        total_elements = csi.numel()
        # Integer element count of the target shape (1 for an empty shape).
        target_elements = torch.Size(target_shape).numel()

        if total_elements == target_elements:
            return csi.reshape(target_shape)

        flat = csi.flatten()
        if total_elements < target_elements:
            pad_size = target_elements - total_elements
            if total_elements > 0:
                # Repeat the last value. Slicing keeps dtype (including
                # complex) and device; detach() keeps the padding out of the
                # autograd graph, as a plain constant fill would be.
                padding = flat[-1:].detach().expand(pad_size)
            else:
                padding = torch.zeros(pad_size, device=csi.device, dtype=csi.dtype)
            flat = torch.cat([flat, padding])
        else:
            # Truncate to target size
            flat = flat[:target_elements]

        return flat.reshape(target_shape)

    def extract_csi_features(self, csi: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Extract common features from CSI tensor for analysis.

        For complex input the statistics ("mean", "std", "min", "max") are
        computed on the magnitude, and "magnitude", "phase", and "phase_std"
        are added. For real input the statistics are computed on the values
        themselves. "shape" and "numel" are always present. For real tensors
        with at least one dimension, "linear_mean" and "linear_std" are
        computed from ``10 ** (csi / 10)``, i.e. treating the values as dB.

        Args:
            csi (torch.Tensor): The CSI tensor to analyze

        Returns:
            Dict[str, torch.Tensor]: Dictionary containing extracted features
        """
        features: Dict[str, torch.Tensor] = {}

        if csi.is_complex():
            # Extract magnitude and phase
            magnitude = torch.abs(csi)
            phase = torch.angle(csi)

            # Basic statistics on magnitude
            features["mean"] = torch.mean(magnitude)
            features["std"] = torch.std(magnitude)
            features["min"] = torch.min(magnitude)
            features["max"] = torch.max(magnitude)

            # Complex-specific features
            features["magnitude"] = torch.mean(magnitude)
            features["phase"] = torch.mean(phase)
            features["phase_std"] = torch.std(phase)
        else:
            # Basic statistics for real tensors
            features["mean"] = torch.mean(csi)
            features["std"] = torch.std(csi)
            features["min"] = torch.min(csi)
            features["max"] = torch.max(csi)

        # Shape information
        features["shape"] = torch.tensor(csi.shape)
        features["numel"] = torch.tensor(csi.numel())

        # Signal quality indicators (assuming CSI represents SNR in dB)
        if csi.dim() > 0 and not csi.is_complex():
            # Convert dB to linear scale for additional metrics
            linear_csi = 10 ** (csi / 10)
            features["linear_mean"] = torch.mean(linear_csi)
            features["linear_std"] = torch.std(linear_csi)

        return features

    def forward_csi_to_submodules(self, csi: torch.Tensor, modules: List[BaseModel], *args, **kwargs) -> List[Any]:
        """Helper method to consistently pass CSI to submodules.

        The modules are applied as a chain: the first positional argument (if
        any) is the initial input, and each module's output becomes the next
        module's input. Modules that are ``ChannelAwareBaseModel`` instances
        additionally receive ``csi=csi``; other modules are called without it.

        Only the first positional argument is chained. While a current input
        exists, remaining positional arguments are not forwarded; ``*args`` is
        passed through unchanged only when the current input is None.

        Args:
            csi (torch.Tensor): Channel state information tensor
            modules (List[BaseModel]): List of modules to apply
            *args: Positional arguments; the first one seeds the chain
            **kwargs: Keyword arguments passed to every module

        Returns:
            List[Any]: List of outputs from each module, in call order
        """
        outputs = []
        current_input = args[0] if args else None

        for module in modules:
            if isinstance(module, ChannelAwareBaseModel):
                # Module requires CSI - pass it explicitly
                if current_input is not None:
                    output = module(current_input, csi=csi, **kwargs)
                else:
                    output = module(*args, csi=csi, **kwargs)
            else:
                # Standard module - pass without CSI
                if current_input is not None:
                    output = module(current_input, **kwargs)
                else:
                    output = module(*args, **kwargs)

            outputs.append(output)
            current_input = output  # Chain outputs for sequential processing

        return outputs

    def create_csi_for_submodules(self, csi: torch.Tensor, num_modules: int) -> List[torch.Tensor]:
        """Create appropriate CSI tensors for multiple submodules.

        Currently every submodule receives an independent clone of the same
        CSI tensor; this is the hook to extend for module-specific CSI
        transformations.

        Args:
            csi (torch.Tensor): Original CSI tensor
            num_modules (int): Number of submodules that need CSI

        Returns:
            List[torch.Tensor]: One cloned CSI tensor per submodule; empty
                when ``num_modules`` is zero or negative
        """
        if num_modules <= 0:
            return []

        return [csi.clone() for _ in range(num_modules)]

    @staticmethod
    def extract_csi_from_channel_output(channel_output: Any) -> Optional[torch.Tensor]:
        """Extract CSI from channel output if available.

        Some channels return both the transmitted signal and CSI information.
        This static method provides a standardized way to extract CSI from
        various channel output formats:

        - tuple: the second element is returned (assumes ``(signal, csi)``);
          tuples shorter than two elements yield None
        - dict: the value under the ``"csi"`` key, or None if absent
        - any other object: its ``csi`` attribute, if it has one

        Args:
            channel_output: Output from a channel, which may contain CSI

        Returns:
            Optional[torch.Tensor]: Extracted CSI tensor if available, None otherwise
        """
        if isinstance(channel_output, tuple):
            # Assume tuple format: (signal, csi)
            if len(channel_output) >= 2:
                return channel_output[1]
        elif isinstance(channel_output, dict):
            # Dictionary format with CSI key
            return channel_output.get("csi")
        elif hasattr(channel_output, "csi"):
            # Object with CSI attribute
            return channel_output.csi

        return None

    @staticmethod
    def format_csi_for_channel(csi: torch.Tensor, channel_format: str = "tensor") -> Any:
        """Format CSI tensor for passing to channels that expect specific formats.

        Args:
            csi (torch.Tensor): CSI tensor to format
            channel_format (str): Expected format ("tensor", "dict", "kwargs")

        Returns:
            Any: The tensor itself for "tensor"; ``{"csi": csi}`` for "dict"
                and "kwargs"

        Raises:
            ValueError: If ``channel_format`` is not one of the supported values
        """
        if channel_format == "tensor":
            return csi
        if channel_format in ("dict", "kwargs"):
            # Both formats use the same single-key mapping.
            return {"csi": csi}
        raise ValueError(f"Unsupported channel format: {channel_format}")
class ConfigurableModel(BaseModel):
    """Model that supports dynamically adding and removing steps.

    This class extends the basic model functionality with methods to add,
    remove, and manage model steps during runtime. The forward pass applies
    the steps in insertion order, feeding each step's result to the next.

    Note:
        Steps are kept in a plain Python list (``self.steps``). An
        ``nn.Module`` added as a step is therefore not registered as a
        submodule of this model, so its parameters are not reported by
        ``parameters()`` or ``state_dict()`` and are not moved by ``to()``.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Initialize the configurable model with an empty step list.

        Args:
            *args: Variable positional arguments passed to BaseModel.
            **kwargs: Variable keyword arguments passed to BaseModel.
        """
        super().__init__(*args, **kwargs)
        self.steps: List[Callable] = []

    def add_step(self, step: Callable) -> "ConfigurableModel":
        """Add a processing step to the end of the pipeline.

        Args:
            step: A callable that will be added to the processing pipeline.
                It is called as ``step(result, *args, **kwargs)`` during the
                forward pass.

        Returns:
            Self for method chaining

        Raises:
            TypeError: If ``step`` is not callable
        """
        if not callable(step):
            raise TypeError("Step must be callable")
        self.steps.append(step)
        return self

    def remove_step(self, index: int) -> "ConfigurableModel":
        """Remove a processing step from the model.

        Args:
            index: The zero-based index of the step to remove. Negative
                indices are rejected.

        Returns:
            Self for method chaining

        Raises:
            IndexError: If the index is out of range or there are no steps
        """
        step_count = len(self.steps)
        if not 0 <= index < step_count:
            if step_count == 0:
                # An empty pipeline has no valid range to report.
                raise IndexError(f"Step index {index} out of range (no steps to remove)")
            raise IndexError(f"Step index {index} out of range (0-{step_count - 1})")
        del self.steps[index]
        return self

    def forward(self, input_data: Any, *args: Any, **kwargs: Any) -> Any:
        """Process input through all steps sequentially.

        Args:
            input_data (Any): The input to process
            *args (Any): Positional arguments passed to each step
            **kwargs (Any): Additional keyword arguments passed to each step

        Returns:
            The result after applying all steps; ``input_data`` itself when
            no steps have been added
        """
        result = input_data
        for step in self.steps:
            result = step(result, *args, **kwargs)
        return result