"""CNN-based encoder and decoder components for deep communications."""
from typing import Any, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from kaira.models.base import BaseModel
from ..registry import ModelRegistry
@ModelRegistry.register_model()
class ConvEncoder(BaseModel):
    """Convolutional Neural Network (CNN) Encoder for image transmission systems.

    This module implements a CNN-based encoder that maps input images to encoded signals suitable
    for transmission over a communication channel. A stack of ``nn.Conv2d`` layers (each followed
    by the activation) is flattened and projected to ``out_features`` by a single ``nn.Linear``.

    The input width of the final linear layer is fixed at construction time by a dummy forward
    pass of an ``input_size`` image. ``forward`` therefore only accepts images whose convolutional
    feature map flattens to that same number of elements (e.g. images of exactly ``input_size``).
    """

    def __init__(
        self,
        in_channels: int,
        out_features: int,
        hidden_dims: Optional[List[int]] = None,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 1,
        activation: Optional[nn.Module] = None,
        *args: Any,
        input_size: Tuple[int, int] = (32, 32),
        **kwargs: Any,
    ):
        """Initialize the ConvEncoder.

        Args:
            in_channels (int): Number of input channels in the image.
            out_features (int): Dimensionality of the output encoded signals.
            hidden_dims (List[int], optional): List of feature dimensions for hidden layers.
                If None, default dimensions [16, 32, 64] will be used. Must not be empty.
            kernel_size (int, optional): Kernel size for convolutions. Default is 3.
            stride (int, optional): Stride for convolutions. Default is 2.
            padding (int, optional): Padding for convolutions. Default is 1.
            activation (nn.Module, optional): Activation function to use.
                If None, ReLU is used. The same module instance is placed after every
                convolution, so a parametric activation (e.g. ``nn.PReLU``) shares its
                parameters across layers.
            *args: Variable positional arguments passed to the base class.
            input_size (Tuple[int, int], optional): Keyword-only ``(height, width)`` of the
                images this encoder will receive. It is used to size the final linear layer.
                Default is ``(32, 32)``.
            **kwargs: Variable keyword arguments passed to the base class.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
            RuntimeError: If the flattened feature size cannot be determined (zero features).
        """
        super().__init__(*args, **kwargs)
        if hidden_dims is None:
            hidden_dims = [16, 32, 64]
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one feature dimension.")
        if activation is None:
            activation = nn.ReLU()

        # Build CNN encoder layers: in_channels -> hidden_dims[0] -> ... -> hidden_dims[-1],
        # each convolution followed by the activation.
        channels = [in_channels, *hidden_dims]
        layers: List[nn.Module] = []
        for conv_in, conv_out in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv2d(conv_in, conv_out, kernel_size=kernel_size, stride=stride, padding=padding))
            layers.append(activation)
        self.conv_layers = nn.Sequential(*layers)

        # The flattened feature size depends on the input resolution. It is measured exactly
        # by running a dummy ``input_size`` image through the conv stack.
        height, width = input_size
        self.input_size: Tuple[int, int] = (int(height), int(width))
        self._feature_size: Optional[int] = None
        calculated_feature_size = self._get_flattened_size(in_channels, hidden_dims, self.input_size)
        if not calculated_feature_size:
            raise RuntimeError("Could not determine flattened feature size.")
        # Final linear layer maps the flattened conv features to the desired output dimension.
        self.fc = nn.Linear(calculated_feature_size, out_features)

    def _get_flattened_size(
        self,
        in_channels: int,
        hidden_dims: List[int],
        input_size: Tuple[int, int] = (32, 32),
    ) -> Optional[int]:
        """Calculate the flattened size after convolutions.

        The spatial dimensions after the conv stack depend on the input size. This method
        runs a single zero image of shape ``(1, in_channels, *input_size)`` through
        ``self.conv_layers`` without gradient tracking. The result is cached in
        ``self._feature_size``, and later calls return the cached value.

        Args:
            in_channels (int): Number of input channels.
            hidden_dims (List[int]): Hidden dimensions list (kept for signature compatibility;
                the size is measured rather than derived from it).
            input_size (Tuple[int, int], optional): ``(height, width)`` of the dummy image.
                Default is ``(32, 32)``.

        Returns:
            int: Size of the flattened feature vector for one image.
        """
        if self._feature_size is not None:
            return self._feature_size
        height, width = input_size
        dummy_input = torch.zeros(1, in_channels, height, width)
        with torch.no_grad():
            dummy_output = self.conv_layers(dummy_input)
        # Batch size is 1, so numel() is channels * height * width of one feature map.
        self._feature_size = dummy_output.numel()
        return self._feature_size

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass of the ConvEncoder.

        Args:
            x (torch.Tensor): Input image tensor of shape (batch_size, in_channels, height, width).
                The spatial size must flatten to the same feature count as ``input_size``.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_features).
        """
        x = self.conv_layers(x)
        # flatten keeps the batch dimension and, unlike view(), also handles empty batches and
        # non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)
@ModelRegistry.register_model()
class ConvDecoder(BaseModel):
    """Convolutional Neural Network (CNN) Decoder for image transmission systems.

    This module implements a CNN-based decoder that maps received signals back to their
    corresponding images. A linear layer first expands the input vector into a small
    ``hidden_dims[0]``-channel feature map. A stack of ``nn.ConvTranspose2d`` layers then
    upsamples it to ``out_channels``. Any residual size mismatch is resized to ``output_size``
    with bilinear interpolation.
    """

    def __init__(
        self,
        in_features: int,
        out_channels: int,
        output_size: Tuple[int, int],
        hidden_dims: Optional[List[int]] = None,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 1,
        output_padding: int = 1,
        activation: Optional[nn.Module] = None,
        output_activation: Optional[nn.Module] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the ConvDecoder.

        Args:
            in_features (int): Dimensionality of the input received signals.
            out_channels (int): Number of output channels in the reconstructed image.
            output_size (Tuple[int, int]): Height and width of the output image.
            hidden_dims (List[int], optional): List of feature dimensions for hidden layers.
                If None, default dimensions [64, 32, 16] will be used. Must not be empty.
            kernel_size (int, optional): Kernel size for transposed convolutions. Default is 3.
            stride (int, optional): Stride for transposed convolutions. Default is 2. It is also
                used to estimate the size of the initial feature map (one upsampling per layer).
            padding (int, optional): Padding for transposed convolutions. Default is 1.
            output_padding (int, optional): Output padding for transposed convolutions. Default is 1.
            activation (nn.Module, optional): Activation function to use between layers.
                If None, ReLU is used. The same module instance is reused between every pair
                of layers.
            output_activation (nn.Module, optional): Activation function to use at the output.
                If None, Sigmoid is used to output values in [0, 1] range. Pass ``nn.Identity()``
                to disable it.
            *args: Variable positional arguments passed to the base class.
            **kwargs: Variable keyword arguments passed to the base class.

        Raises:
            ValueError: If ``hidden_dims`` is empty or ``stride`` is not positive.
        """
        super().__init__(*args, **kwargs)
        if hidden_dims is None:
            hidden_dims = [64, 32, 16]  # Decoder usually goes from more to fewer channels
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one feature dimension.")
        if activation is None:
            activation = nn.ReLU()
        if output_activation is None:
            output_activation = nn.Sigmoid()  # For image output in [0, 1] range

        # There are len(hidden_dims) transposed convolutions, each upsampling by about ``stride``.
        # For the defaults (stride=2, kernel_size=3, padding=1, output_padding=1) each layer
        # exactly doubles the spatial size. forward() interpolates away any remaining mismatch.
        self.output_height, self.output_width = output_size
        stride_h, stride_w = stride if isinstance(stride, (tuple, list)) else (stride, stride)
        if stride_h < 1 or stride_w < 1:
            raise ValueError("stride must be positive.")
        num_layers = len(hidden_dims)
        # Ensure a minimum initial size of 1x1.
        self.initial_height = max(1, self.output_height // (stride_h**num_layers))
        self.initial_width = max(1, self.output_width // (stride_w**num_layers))
        self.initial_features = hidden_dims[0]

        # Initial linear layer to transform from code vector to initial feature maps.
        self.fc = nn.Linear(in_features, self.initial_features * self.initial_height * self.initial_width)

        # Hidden transposed convolutions hidden_dims[i] -> hidden_dims[i + 1], each followed by
        # the activation.
        layers: List[nn.Module] = []
        for deconv_in, deconv_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(
                nn.ConvTranspose2d(
                    deconv_in,
                    deconv_out,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
            )
            layers.append(activation)
        # Final transposed convolution to the image channels, followed by the output activation.
        layers.append(
            nn.ConvTranspose2d(
                hidden_dims[-1],
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            )
        )
        layers.append(output_activation)
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass of the ConvDecoder.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: Output image tensor of shape (batch_size, out_channels, height, width),
            where (height, width) is ``output_size``.
        """
        batch_size = x.size(0)
        # Apply the initial linear layer and reshape to the initial feature maps.
        x = self.fc(x)
        x = x.view(batch_size, self.initial_features, self.initial_height, self.initial_width)
        x = self.conv_layers(x)
        # Integer division when computing the initial size (or non-default conv settings) can
        # leave the output slightly off. Resize only when it actually differs.
        target_size = (self.output_height, self.output_width)
        if tuple(x.shape[-2:]) != target_size:
            x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
        return x