kaira.models.image.Yilmaz2024DeepJSCCWZSmallDecoder

Inheritance diagram for Yilmaz2024DeepJSCCWZSmallDecoder

class kaira.models.image.Yilmaz2024DeepJSCCWZSmallDecoder(N: int, M: int, encoder: Yilmaz2024DeepJSCCWZSmallEncoder, *args: Any, **kwargs: Any)

Bases: ChannelAwareBaseModel

DeepJSCC-WZ-sm Decoder Module [Yilmaz et al., 2024].

This lightweight decoder reconstructs the original image from the received noisy representation and the available side information. Its structure mirrors the encoder, with upsampling operations and fusion of side-information features.

The decoder follows a multi-scale fusion approach: the side information is encoded with the same encoder as the main signal, and the resulting features are fused with the decoder's features at multiple scales during decoding. This lets the decoder exploit correlations between the received signal and the side information.

DeepJSCC-WZ-sm uses the same encoder parameters to encode the image at the transmitter and the side information at the receiver, which improves parameter efficiency.
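The shared-encoder, multi-scale fusion idea can be illustrated with a toy NumPy sketch. It is not kaira's implementation: `toy_shared_encoder`, `upsample2x` and `toy_decode` are hypothetical names, 2×2 average pooling stands in for the learned encoder, nearest-neighbour upsampling stands in for the learned upsampling layers, fusion is plain channel concatenation, and only three scales are used instead of five.

```python
import numpy as np
from typing import List


def toy_shared_encoder(img: np.ndarray, num_scales: int = 3) -> List[np.ndarray]:
    """Hypothetical stand-in for the shared encoder.

    Applies 2x2 average pooling num_scales times and returns the feature map
    produced at each scale, finest first (H/2, H/4, ...).
    """
    feats = []
    x = img
    for _ in range(num_scales):
        b, c, h, w = x.shape
        # Group each 2x2 spatial block into its own axes, then average it.
        x = x.reshape(b, c, h // 2, 2, w // 2, 2).mean(axis=(3, 5))
        feats.append(x)
    return feats


def upsample2x(x: np.ndarray) -> np.ndarray:
    """Nearest-neighbour upsampling by a factor of 2 in both spatial axes."""
    return x.repeat(2, axis=2).repeat(2, axis=3)


def toy_decode(y: np.ndarray, x_side: np.ndarray) -> np.ndarray:
    """Fuse side-information features into y from the coarsest scale up."""
    # Receiver side: encode x_side with the SAME function the transmitter used.
    side_feats = toy_shared_encoder(x_side)
    h = y
    for side in reversed(side_feats):
        # Fusion by channel concatenation; a real decoder would follow this
        # with learned layers that map the channels back to a fixed width.
        h = np.concatenate([h, side], axis=1)
        h = upsample2x(h)
    return h


rng = np.random.default_rng(0)
x = rng.random((2, 3, 16, 16))                     # image at the transmitter
x_side = x + 0.1 * rng.random((2, 3, 16, 16))      # correlated side information
y = toy_shared_encoder(x)[-1]                      # transmitted code, 2x2 spatial
y_noisy = y + 0.01 * rng.standard_normal(y.shape)  # toy noisy channel
x_hat = toy_decode(y_noisy, x_side)
print(x_hat.shape)  # (2, 12, 16, 16)
```

Because the toy omits the learned layers that would map fused features back to N channels, the channel count grows by 3 at every fusion step; only the spatial path (2×2 back up to 16×16) mirrors what the real decoder does.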

Key features:
  • Progressive upsampling to restore spatial dimensions (H/16×W/16 → H×W)

  • Multi-scale side information fusion at 5 different resolution levels

  • Attention mechanisms to focus on important features

  • Channel-adaptive processing through AFModule layers

  • Residual connections for improved gradient flow
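Assuming each upsampling stage doubles both spatial dimensions, going from H/16×W/16 to H×W takes four stages and visits exactly five resolutions, which would line up with the five fusion levels listed above. The hypothetical helper `fusion_resolutions` below enumerates them; the doubling factor per stage is an assumption, not something this page states.

```python
from typing import List, Tuple


def fusion_resolutions(height: int, width: int, num_upsamplings: int = 4) -> List[Tuple[int, int]]:
    """Spatial sizes visited when decoding from H/16 x W/16 back to H x W,
    assuming every upsampling stage doubles both dimensions."""
    factor = 2 ** num_upsamplings  # 16 for four stages
    if height % factor or width % factor:
        raise ValueError(f"height and width must be divisible by {factor}")
    h, w = height // factor, width // factor
    sizes = [(h, w)]
    for _ in range(num_upsamplings):
        h, w = h * 2, w * 2
        sizes.append((h, w))
    return sizes


print(fusion_resolutions(256, 256))
# [(16, 16), (32, 32), (64, 64), (128, 128), (256, 256)]
```

The divisibility check reflects the H/16 bottleneck: inputs whose height or width is not a multiple of 16 cannot be recovered exactly by four exact doublings.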

Methods

__init__

Initialize the DeepJSCC-WZ-sm decoder.

create_csi_for_submodules

Create appropriate CSI tensors for multiple submodules.

extract_csi_features

Extract common features from CSI tensor for analysis.

extract_csi_from_channel_output

Extract CSI from channel output if available.

format_csi_for_channel

Format CSI tensor for passing to channels that expect specific formats.

forward

Decode the received representation into a reconstructed image.

forward_csi_to_submodules

Helper method to consistently pass CSI to submodules.

normalize_csi

Normalize CSI values to a specified range.

transform_csi

Transform CSI tensor to match target shape requirements.

validate_csi

Validate and ensure CSI tensor is in the correct format.

Attributes

encoder

__init__(N: int, M: int, encoder: Yilmaz2024DeepJSCCWZSmallEncoder, *args: Any, **kwargs: Any) → None

Initialize the DeepJSCC-WZ-sm decoder.

Parameters:
  • N (int) – Number of intermediate channels in the residual blocks. Controls the network capacity and feature dimension.

  • M (int) – Number of input channels from the encoded representation. Matches the encoder’s output channel count.

  • encoder (Yilmaz2024DeepJSCCWZSmallEncoder) – Reference to the small encoder model for feature sharing. This enables the decoder to process side information using the same parameters as the main encoder.

  • *args – Variable positional arguments passed to the base class.

  • **kwargs – Variable keyword arguments passed to the base class.

encoder: Yilmaz2024DeepJSCCWZSmallEncoder

Shared encoder used to encode the side information at the receiver.

forward(x: Tensor, x_side: Tensor, csi: Tensor, *args: Any, **kwargs: Any) → Tensor

Decode the received representation into a reconstructed image.

This method first processes the side information through the shared encoder, then progressively decodes the received signal, fusing it with the side-information features at multiple scales.

Parameters:
  • x (torch.Tensor) – Received noisy encoded representation of shape [B, M, H/16, W/16].

  • x_side (torch.Tensor) – Side information tensor of shape [B, 3, H, W] used to assist decoding.

  • csi (torch.Tensor) – Channel state information tensor of shape [B, 1, 1, 1].

  • *args – Additional positional arguments (passed to internal layers).

  • **kwargs – Additional keyword arguments (passed to internal layers).

Returns:

Reconstructed image tensor of shape [B, 3, H, W].

Return type:

torch.Tensor
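The shape contract documented for `forward` can be written down as a small pure-Python check. `check_forward_shapes` is a hypothetical helper, not part of kaira; it only restates the shapes listed above, and it also accepts `torch.Size` objects because those are tuples.

```python
from typing import Sequence, Tuple


def check_forward_shapes(
    x_shape: Sequence[int],
    x_side_shape: Sequence[int],
    csi_shape: Sequence[int],
    M: int,
) -> Tuple[int, int, int, int]:
    """Validate the documented input shapes and return the expected output shape."""
    B, c, h, w = x_shape
    B_side, c_side, H, W = x_side_shape
    if c != M:
        raise ValueError(f"x must have M={M} channels, got {c}")
    if c_side != 3:
        raise ValueError(f"x_side must have 3 channels, got {c_side}")
    if B_side != B:
        raise ValueError("x and x_side must have the same batch size")
    if tuple(csi_shape) != (B, 1, 1, 1):
        raise ValueError(f"csi must have shape [B, 1, 1, 1] with B={B}")
    if (16 * h, 16 * w) != (H, W):
        raise ValueError("x must be 16x smaller than x_side in both spatial dims")
    # The decoder returns an RGB image at the side information's resolution.
    return (B, 3, H, W)


print(check_forward_shapes((4, 16, 16, 16), (4, 3, 256, 256), (4, 1, 1, 1), M=16))
# (4, 3, 256, 256)
```

With real tensors, the same call would take `x.shape`, `x_side.shape` and `csi.shape` before invoking the decoder.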

create_csi_for_submodules(csi: Tensor, num_modules: int) → List[Tensor]

Create appropriate CSI tensors for multiple submodules.

Parameters:
  • csi (torch.Tensor) – Original CSI tensor

  • num_modules (int) – Number of submodules that need CSI

Returns:

List of CSI tensors for each submodule

Return type:

List[torch.Tensor]

extract_csi_features(csi: Tensor) → Dict[str, Tensor]

Extract common features from CSI tensor for analysis.

Parameters:

csi (torch.Tensor) – The CSI tensor to analyze

Returns:

Dictionary containing extracted features

Return type:

Dict[str, torch.Tensor]

static extract_csi_from_channel_output(channel_output: Any) → Tensor | None

Extract CSI from channel output if available.

Some channels return both the transmitted signal and the CSI. This static method provides a standardized way to extract CSI from various channel output formats.

Parameters:

channel_output – Output from a channel, which may contain CSI

Returns:

Extracted CSI tensor if available, None otherwise

Return type:

Optional[torch.Tensor]

static format_csi_for_channel(csi: Tensor, channel_format: str = 'tensor') → Any

Format CSI tensor for passing to channels that expect specific formats.

Parameters:
  • csi (torch.Tensor) – CSI tensor to format

  • channel_format (str) – Expected format (“tensor”, “dict”, “kwargs”)

Returns:

Formatted CSI in the requested format

Return type:

Any

forward_csi_to_submodules(csi: Tensor, modules: List[BaseModel], *args, **kwargs) → List[Any]

Helper method to consistently pass CSI to submodules.

This method facilitates passing CSI to multiple submodules that require channel state information, ensuring consistent handling across the model.

Parameters:
  • csi (torch.Tensor) – Channel state information tensor

  • modules (List[BaseModel]) – List of modules to apply

  • *args – Positional arguments to pass to modules

  • **kwargs – Keyword arguments to pass to modules

Returns:

List of outputs from each module

Return type:

List[Any]
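Under the assumption that each submodule accepts CSI through a `csi` keyword argument (this page does not specify the calling convention), the helper reduces to a list comprehension. `forward_csi_to_all` is a hypothetical name, and plain functions stand in for `BaseModel` instances.

```python
from typing import Any, Callable, List, Sequence


def forward_csi_to_all(
    csi: Any,
    modules: Sequence[Callable[..., Any]],
    *args: Any,
    **kwargs: Any,
) -> List[Any]:
    """Call every module with the same inputs plus the same CSI, in order."""
    # Assumption: each submodule takes CSI as a `csi` keyword argument.
    return [module(*args, csi=csi, **kwargs) for module in modules]


def add_csi(x: float, csi: float) -> float:
    return x + csi


def scale_by_csi(x: float, csi: float) -> float:
    return x * csi


print(forward_csi_to_all(2.0, [add_csi, scale_by_csi], 3.0))  # [5.0, 6.0]
```

Centralizing the call this way keeps the keyword name in one place, which is the "consistent handling" the method description refers to.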

normalize_csi(csi: Tensor, method: str = 'minmax', target_range: tuple = (0.0, 1.0)) → Tensor

Normalize CSI values to a specified range.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to normalize

  • method (str) – Normalization method. Options: “minmax”, “zscore”, “none”

  • target_range (tuple) – Target range for minmax normalization (min, max)

Returns:

Normalized CSI tensor

Return type:

torch.Tensor

Raises:

ValueError – If normalization method is not supported
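The three documented `normalize_csi` options can be illustrated on a plain list of floats. `normalize_values` is a hypothetical stand-in, not the library method. Assumptions: min/max and mean/std are taken over all values at once (the page does not say whether statistics are computed per sample), z-score uses the population standard deviation, and constant input maps to the lower bound (min-max) or to zero (z-score) to avoid dividing by zero.

```python
import statistics
from typing import List, Sequence, Tuple


def normalize_values(
    values: Sequence[float],
    method: str = "minmax",
    target_range: Tuple[float, float] = (0.0, 1.0),
) -> List[float]:
    """Hypothetical list-based illustration of the normalize_csi options."""
    if method == "none":
        return list(values)
    if method == "minmax":
        lo, hi = min(values), max(values)
        t_min, t_max = target_range
        if hi == lo:
            # Constant input: map to the lower bound (assumed convention).
            return [t_min] * len(values)
        return [t_min + (v - lo) / (hi - lo) * (t_max - t_min) for v in values]
    if method == "zscore":
        mean = statistics.mean(values)
        std = statistics.pstdev(values)  # population std (assumption)
        if std == 0:
            return [0.0] * len(values)
        return [(v - mean) / std for v in values]
    # Mirrors the documented ValueError for unsupported methods.
    raise ValueError(f"Unsupported normalization method: {method!r}")


print(normalize_values([0.0, 5.0, 10.0]))                          # [0.0, 0.5, 1.0]
print(normalize_values([0.0, 5.0, 10.0], target_range=(-1.0, 1.0)))  # [-1.0, 0.0, 1.0]
```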

transform_csi(csi: Tensor, target_shape: Size) → Tensor

Transform CSI tensor to match target shape requirements.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to transform

  • target_shape (torch.Size) – Target shape for the CSI tensor

Returns:

Transformed CSI tensor

Return type:

torch.Tensor
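One common way to make a per-sample CSI tensor of shape [B, 1, 1, 1] match a feature map of shape [B, C, H, W] is to reshape and broadcast it. The NumPy sketch below shows that idea under the assumption that `transform_csi` behaves this way; `broadcast_csi` is a hypothetical name. In PyTorch the analogous operation would be `csi.view(B, 1, 1, 1).expand(target_shape)`.

```python
import numpy as np
from typing import Tuple


def broadcast_csi(csi: np.ndarray, target_shape: Tuple[int, ...]) -> np.ndarray:
    """Hypothetical: expand one CSI value per sample to target_shape."""
    batch = target_shape[0]
    flat = np.asarray(csi).reshape(-1)
    if flat.size == 1:
        # A single scalar CSI is shared by every sample in the batch.
        flat = np.repeat(flat, batch)
    if flat.size != batch:
        raise ValueError(f"expected 1 or {batch} CSI values, got {flat.size}")
    # [B] -> [B, 1, ..., 1] so that broadcasting fills the remaining axes.
    expanded = flat.reshape((batch,) + (1,) * (len(target_shape) - 1))
    return np.broadcast_to(expanded, target_shape)


csi = np.array([10.0, 20.0]).reshape(2, 1, 1, 1)  # [B, 1, 1, 1]
csi_map = broadcast_csi(csi, (2, 8, 4, 4))
print(csi_map.shape)        # (2, 8, 4, 4)
print(csi_map[1, 5, 3, 0])  # 20.0
```

`np.broadcast_to` returns a read-only view rather than copying data; call `.copy()` on the result if it needs to be modified in place.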

validate_csi(csi: Tensor, expected_shape: Size | None = None) → Tensor

Validate and ensure CSI tensor is in the correct format.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to validate

  • expected_shape (Optional[torch.Size]) – Expected shape for the CSI tensor. If None, uses cached shape or infers from tensor.

Returns:

Validated CSI tensor

Return type:

torch.Tensor

Raises:
  • ValueError – If CSI tensor is invalid or has incorrect shape

  • TypeError – If CSI is not a tensor
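A minimal NumPy sketch of the checks described for `validate_csi`. `validate_csi_array` is a hypothetical name, the finiteness test is an assumed example of what an "invalid" CSI could mean, and the cached-shape fallback mentioned for `expected_shape=None` is not modeled.

```python
import numpy as np
from typing import Any, Optional, Tuple


def validate_csi_array(csi: Any, expected_shape: Optional[Tuple[int, ...]] = None) -> np.ndarray:
    """Hypothetical NumPy analogue of the checks validate_csi documents."""
    # TypeError: the input is not an array at all.
    if not isinstance(csi, np.ndarray):
        raise TypeError(f"CSI must be an array, got {type(csi).__name__}")
    # ValueError: the shape does not match the expected one.
    if expected_shape is not None and csi.shape != tuple(expected_shape):
        raise ValueError(f"CSI shape {csi.shape} != expected {tuple(expected_shape)}")
    # Assumed example of an "invalid" CSI: NaN or infinite entries.
    if not np.isfinite(csi).all():
        raise ValueError("CSI contains NaN or infinite values")
    return csi


ok = validate_csi_array(np.ones((2, 1, 1, 1)), expected_shape=(2, 1, 1, 1))
print(ok.shape)  # (2, 1, 1, 1)
```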