kaira.models.image.Yilmaz2024DeepJSCCWZDecoder

Inheritance diagram of Yilmaz2024DeepJSCCWZDecoder

class kaira.models.image.Yilmaz2024DeepJSCCWZDecoder(N: int, M: int, *args: Any, **kwargs: Any)

Bases: ChannelAwareBaseModel

DeepJSCC-WZ Decoder Module [Yilmaz et al., 2024].

The full-size decoder for the DeepJSCC-WZ model. It reconstructs the original image from the received noisy representation and the side information. Its structure mirrors the encoder, with progressive upsampling and feature fusion mechanisms.

Unlike the small variant, this decoder uses a dedicated set of parameters for processing side information, potentially allowing for more specialized feature extraction at the cost of increased parameter count.

Key features:
  • Multi-scale feature fusion with side information at 5 different resolution levels

  • Progressive spatial resolution recovery (4 upsampling stages, H/16×W/16 → H×W)

  • Attention-based feature refinement

  • Channel-adaptive processing through AFModule layers

  • Sophisticated feature reconstruction with residual connections
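The resolution ladder implied by the feature list can be checked with simple arithmetic. The sketch below assumes that each of the four upsampling stages doubles both spatial dimensions; this page does not state the per-stage factor, but 2^4 = 16 is consistent with H/16×W/16 → H×W. `resolution_levels` is a hypothetical helper written for illustration and is not part of kaira.

```python
def resolution_levels(height, width, stages=4):
    """Return the (H, W) sizes visited by the decoder, coarsest first.

    Assumption: every upsampling stage doubles both spatial dimensions,
    so `stages` stages yield `stages + 1` resolution levels.
    """
    factor = 2 ** stages
    if height % factor or width % factor:
        raise ValueError(f"height and width must be divisible by {factor}")
    levels = []
    for stage in range(stages + 1):
        scale = 2 ** (stages - stage)  # 16, 8, 4, 2, 1 when stages=4
        levels.append((height // scale, width // scale))
    return levels


levels = resolution_levels(256, 256)
print(levels)  # [(16, 16), (32, 32), (64, 64), (128, 128), (256, 256)]
```

Under this assumption, four stages produce five levels, which matches the five fusion points listed above.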

Methods

__init__

Initialize the full-size DeepJSCC-WZ decoder.

create_csi_for_submodules

Create appropriate CSI tensors for multiple submodules.

extract_csi_features

Extract common features from CSI tensor for analysis.

extract_csi_from_channel_output

Extract CSI from channel output if available.

format_csi_for_channel

Format CSI tensor for passing to channels that expect specific formats.

forward

Decode the received representation into a reconstructed image.

forward_csi_to_submodules

Helper method to consistently pass CSI to submodules.

normalize_csi

Normalize CSI values to a specified range.

transform_csi

Transform CSI tensor to match target shape requirements.

validate_csi

Validate and ensure CSI tensor is in the correct format.

__init__(N: int, M: int, *args: Any, **kwargs: Any) → None

Initialize the full-size DeepJSCC-WZ decoder.

Parameters:
  • N (int) – Number of intermediate channels in the residual blocks. Controls the network capacity and feature dimension.

  • M (int) – Number of input channels from the encoded representation. Matches the encoder’s output channel count.

  • *args – Variable positional arguments passed to the base class.

  • **kwargs – Variable keyword arguments passed to the base class.

forward(x: Tensor, x_side: Tensor, csi: Tensor, *args: Any, **kwargs: Any) → Tensor

Decode the received representation into a reconstructed image.

This method first processes the side information through the g_a2 encoder path, then progressively decodes the received signal while fusing with side information features at multiple scales.

Parameters:
  • x (torch.Tensor) – Received noisy encoded representation of shape [B, M, H/16, W/16].

  • x_side (torch.Tensor) – Side information tensor of shape [B, 3, H, W] to assist in decoding.

  • csi (torch.Tensor) – Channel state information tensor of shape [B, 1, 1, 1].

  • *args – Additional positional arguments (passed to internal layers).

  • **kwargs – Additional keyword arguments (passed to internal layers).

Returns:

Reconstructed image tensor of shape [B, 3, H, W].

Return type:

torch.Tensor
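The shape contract documented for forward can be written down as a small pre-flight check. The sketch below works on plain shape tuples (for example `tuple(t.shape)`) so that it runs without any dependencies. `check_forward_shapes` is a hypothetical helper, not a kaira function, and the requirement that H and W be divisible by 16 is an assumption inferred from the H/16 and W/16 terms above.

```python
def check_forward_shapes(x_shape, x_side_shape, csi_shape, M):
    """Check the documented input shapes and return the expected output shape."""
    if len(x_side_shape) != 4 or x_side_shape[1] != 3:
        raise ValueError("x_side must have shape [B, 3, H, W]")
    B, _, H, W = x_side_shape
    # Assumption: the H/16 x W/16 latent size implies divisibility by 16.
    if H % 16 or W % 16:
        raise ValueError("H and W must be divisible by 16")
    expected_x = (B, M, H // 16, W // 16)
    if tuple(x_shape) != expected_x:
        raise ValueError(f"x must have shape {expected_x}, got {tuple(x_shape)}")
    if tuple(csi_shape) != (B, 1, 1, 1):
        raise ValueError(f"csi must have shape {(B, 1, 1, 1)}, got {tuple(csi_shape)}")
    return (B, 3, H, W)  # shape of the reconstructed image


# Batch of 2 images at 128x128 with M=16 latent channels.
out_shape = check_forward_shapes(
    x_shape=(2, 16, 8, 8),
    x_side_shape=(2, 3, 128, 128),
    csi_shape=(2, 1, 1, 1),
    M=16,
)
print(out_shape)  # (2, 3, 128, 128)
```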

create_csi_for_submodules(csi: Tensor, num_modules: int) → List[Tensor]

Create appropriate CSI tensors for multiple submodules.

Parameters:
  • csi (torch.Tensor) – Original CSI tensor

  • num_modules (int) – Number of submodules that need CSI

Returns:

List of CSI tensors for each submodule

Return type:

List[torch.Tensor]

extract_csi_features(csi: Tensor) → Dict[str, Tensor]

Extract common features from CSI tensor for analysis.

Parameters:

csi (torch.Tensor) – The CSI tensor to analyze

Returns:

Dictionary containing extracted features

Return type:

Dict[str, torch.Tensor]

static extract_csi_from_channel_output(channel_output: Any) → Tensor | None

Extract CSI from channel output if available.

Some channels return both the transmitted signal and CSI information. This static method provides a standardized way to extract CSI from various channel output formats.

Parameters:

channel_output – Output from a channel, which may contain CSI

Returns:

Extracted CSI tensor if available, None otherwise

Return type:

Optional[torch.Tensor]
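To make the idea of extracting CSI from "various channel output formats" concrete, here is a minimal sketch. This page does not list the formats the method recognises, so the two handled here, a `(signal, csi)` pair and a dict with a `"csi"` key, are assumptions chosen for illustration. `extract_csi_sketch` is a hypothetical stand-in and may not match the library's behaviour.

```python
def extract_csi_sketch(channel_output):
    """Return the CSI carried by a channel output, or None if there is none.

    Assumed formats (not confirmed by the library documentation):
      - dict with a "csi" key
      - 2-element tuple or list laid out as (signal, csi)
    Anything else, such as a bare signal, yields None.
    """
    if isinstance(channel_output, dict):
        return channel_output.get("csi")
    if isinstance(channel_output, (tuple, list)) and len(channel_output) == 2:
        return channel_output[1]
    return None


# Plain Python values stand in for tensors here.
print(extract_csi_sketch(("signal", 10.0)))          # 10.0
print(extract_csi_sketch({"signal": "s", "csi": 5})) # 5
print(extract_csi_sketch("signal only"))             # None
```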

static format_csi_for_channel(csi: Tensor, channel_format: str = 'tensor') → Any

Format CSI tensor for passing to channels that expect specific formats.

Parameters:
  • csi (torch.Tensor) – CSI tensor to format

  • channel_format (str) – Expected format (“tensor”, “dict”, “kwargs”)

Returns:

Formatted CSI in the requested format

Return type:

Any

forward_csi_to_submodules(csi: Tensor, modules: List[BaseModel], *args, **kwargs) → List[Any]

Helper method to consistently pass CSI to submodules.

This method facilitates passing CSI to multiple submodules that require channel state information, ensuring consistent handling across the model.

Parameters:
  • csi (torch.Tensor) – Channel state information tensor

  • modules (List[BaseModel]) – List of modules to apply

  • *args – Positional arguments to pass to modules

  • **kwargs – Keyword arguments to pass to modules

Returns:

List of outputs from each module

Return type:

List[Any]

normalize_csi(csi: Tensor, method: str = 'minmax', target_range: tuple = (0.0, 1.0)) → Tensor

Normalize CSI values to a specified range.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to normalize

  • method (str) – Normalization method. Options: “minmax”, “zscore”, “none”

  • target_range (tuple) – Target range for minmax normalization (min, max)

Returns:

Normalized CSI tensor

Return type:

torch.Tensor

Raises:

ValueError – If normalization method is not supported
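The three documented methods correspond to standard formulas, sketched below on plain Python lists rather than tensors. This is an illustration of the arithmetic, not the library's implementation: `normalize_values` is a hypothetical name, the use of the population standard deviation for "zscore" is an assumption, and so is the handling of constant inputs.

```python
import statistics


def normalize_values(values, method="minmax", target_range=(0.0, 1.0)):
    """Normalize a list of numbers with the 'minmax', 'zscore' or 'none' method."""
    values = [float(v) for v in values]
    if method == "none":
        return values
    if method == "minmax":
        lo, hi = target_range
        v_min, v_max = min(values), max(values)
        span = v_max - v_min
        if span == 0:
            # Assumption: a constant input maps to the lower bound.
            return [float(lo) for _ in values]
        return [lo + (v - v_min) / span * (hi - lo) for v in values]
    if method == "zscore":
        mean = statistics.fmean(values)
        std = statistics.pstdev(values)  # population std (assumption)
        if std == 0:
            return [0.0 for _ in values]
        return [(v - mean) / std for v in values]
    raise ValueError(f"Unsupported normalization method: {method!r}")


snr_db = [0, 5, 10]
print(normalize_values(snr_db))                             # [0.0, 0.5, 1.0]
print(normalize_values(snr_db, target_range=(-1.0, 1.0)))   # [-1.0, 0.0, 1.0]
```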

transform_csi(csi: Tensor, target_shape: Size) → Tensor

Transform CSI tensor to match target shape requirements.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to transform

  • target_shape (torch.Size) – Target shape for the CSI tensor

Returns:

Transformed CSI tensor

Return type:

torch.Tensor

validate_csi(csi: Tensor, expected_shape: Size | None = None) → Tensor

Validate and ensure CSI tensor is in the correct format.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to validate

  • expected_shape (Optional[torch.Size]) – Expected shape for the CSI tensor. If None, uses cached shape or infers from tensor.

Returns:

Validated CSI tensor

Return type:

torch.Tensor

Raises:
  • ValueError – If CSI tensor is invalid or has incorrect shape

  • TypeError – If CSI is not a tensor