kaira.models.image.Yilmaz2024DeepJSCCWZSmallEncoder


Inheritance diagram for Yilmaz2024DeepJSCCWZSmallEncoder

class kaira.models.image.Yilmaz2024DeepJSCCWZSmallEncoder(N: int, M: int, *args: Any, **kwargs: Any)

Bases: ChannelAwareBaseModel

DeepJSCC-WZ-sm Encoder Module [Yilmaz et al., 2024].

This is a lightweight version of the DeepJSCC-WZ encoder. It transforms input images into a compressed latent representation suitable for transmission over noisy channels. The encoder consists of residual blocks with downsampling, attention modules, and adaptive feature modules (AFModule) that incorporate channel state information (CSI).

DeepJSCC-WZ-sm uses the same encoder parameters to encode the image at the transmitter and the side information at the receiver. This sharing makes the design parameter-efficient while maintaining competitive performance.

Architecture highlights:
  • 4 stages of downsampling (16× total spatial reduction)

  • Attention mechanisms to capture important features

  • AFModule layers that adapt features based on channel conditions

  • Progressive compression: 3×H×W → M×(H/16)×(W/16)

  • Channel-aware design through CSI conditioning
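The 16× reduction can be traced stage by stage. The sketch below is a hypothetical helper, not part of kaira; it assumes each of the 4 stages halves both spatial dimensions (2⁴ = 16) and, for simplicity, rejects sizes that are not divisible by 16.

```python
from typing import List, Tuple


def stage_spatial_sizes(height: int, width: int, num_stages: int = 4) -> List[Tuple[int, int]]:
    """Return (H, W) after each stage, assuming every stage halves H and W."""
    factor = 2 ** num_stages  # 4 stages of 2x each -> 16x total reduction
    if height % factor or width % factor:
        # Assumption: this sketch only accepts sizes that divide evenly
        raise ValueError(f"H and W must be divisible by {factor}, got {height}x{width}")
    sizes = []
    h, w = height, width
    for _ in range(num_stages):
        h, w = h // 2, w // 2  # one assumed stride-2 downsampling step
        sizes.append((h, w))
    return sizes


print(stage_spatial_sizes(256, 256))
# [(128, 128), (64, 64), (32, 32), (16, 16)]
```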

Methods

__init__

Initialize the DeepJSCC-WZ-sm encoder.

create_csi_for_submodules

Create appropriate CSI tensors for multiple submodules.

extract_csi_features

Extract common features from CSI tensor for analysis.

extract_csi_from_channel_output

Extract CSI from channel output if available.

format_csi_for_channel

Format CSI tensor for passing to channels that expect specific formats.

forward

Process input image through the encoder.

forward_csi_to_submodules

Helper method to consistently pass CSI to submodules.

normalize_csi

Normalize CSI values to a specified range.

transform_csi

Transform CSI tensor to match target shape requirements.

validate_csi

Validate and ensure CSI tensor is in the correct format.

__init__(N: int, M: int, *args: Any, **kwargs: Any) → None

Initialize the DeepJSCC-WZ-sm encoder.

Parameters:
  • N (int) – Number of intermediate channels in the residual blocks. Controls the network capacity and feature dimension.

  • M (int) – Number of output channels in the final latent representation. Determines the compression rate and bandwidth usage.

  • *args – Variable positional arguments passed to the base class.

  • **kwargs – Variable keyword arguments passed to the base class.
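Because the output shape is M×(H/16)×(W/16) and the input is 3×H×W, the ratio of latent entries to input entries reduces to M/768 whenever H and W are divisible by 16. The helpers below are hypothetical illustrations of that arithmetic; they count real-valued latent entries only, and how those entries map onto channel uses depends on the channel model, which this page does not specify.

```python
import math
from fractions import Fraction


def bandwidth_ratio(m: int, height: int, width: int, downsample: int = 16) -> Fraction:
    """Latent entries per input entry: M*(H/16)*(W/16) / (3*H*W)."""
    latent_entries = m * (height // downsample) * (width // downsample)
    input_entries = 3 * height * width
    return Fraction(latent_entries, input_entries)


def channels_for_ratio(target: Fraction, downsample: int = 16) -> int:
    """Smallest M that reaches `target` (hypothetical helper; H, W divisible by 16)."""
    per_channel = Fraction(1, 3 * downsample ** 2)  # each latent channel adds 1/768
    return math.ceil(target / per_channel)


print(bandwidth_ratio(16, 256, 256))        # 1/48
print(channels_for_ratio(Fraction(1, 12)))  # 64
```

N does not appear in this ratio: it sets the width of the intermediate residual blocks, so it trades capacity against compute rather than bandwidth.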

forward(x: Tensor, csi: Tensor, *args: Any, **kwargs: Any) → Tensor

Process input image through the encoder.

Parameters:
  • x (torch.Tensor) – Input image tensor of shape [B, 3, H, W].

  • csi (torch.Tensor) – Channel state information tensor of shape [B, 1, 1, 1]. Contains SNR or other channel quality indicators.

  • *args – Additional positional arguments (passed to internal layers).

  • **kwargs – Additional keyword arguments (passed to internal layers).

Returns:

Encoded representation ready for transmission.

Shape: [B, M, H/16, W/16], where M is the number of channels specified during initialization.

Return type:

torch.Tensor
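The documented shape contract of forward can be checked before calling the model. The function below is a hypothetical pre-flight check that works on plain shape tuples; rejecting H or W not divisible by 16 is an assumption of this sketch, not documented behaviour.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_forward_shape(x_shape: Shape, csi_shape: Shape, m: int) -> Shape:
    """Validate the documented input shapes and return [B, M, H/16, W/16]."""
    if len(x_shape) != 4 or x_shape[1] != 3:
        raise ValueError(f"x must have shape [B, 3, H, W], got {x_shape}")
    b, _, h, w = x_shape
    if tuple(csi_shape) != (b, 1, 1, 1):
        raise ValueError(f"csi must have shape [{b}, 1, 1, 1], got {csi_shape}")
    if h % 16 or w % 16:
        # Assumption: sizes that do not divide by 16 are rejected here
        raise ValueError(f"H and W should be divisible by 16, got {h}x{w}")
    return (b, m, h // 16, w // 16)


print(expected_forward_shape((8, 3, 64, 64), (8, 1, 1, 1), m=16))  # (8, 16, 4, 4)
```

In PyTorch, a CSI tensor of the documented shape holding one value (for example an SNR) per sample could be built with torch.full((B, 1, 1, 1), value).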

create_csi_for_submodules(csi: Tensor, num_modules: int) → List[Tensor]

Create appropriate CSI tensors for multiple submodules.

Parameters:
  • csi (torch.Tensor) – Original CSI tensor

  • num_modules (int) – Number of submodules that need CSI

Returns:

List of CSI tensors for each submodule

Return type:

List[torch.Tensor]
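A minimal free-function stand-in for this method is sketched below. It assumes that "appropriate" CSI means an independent copy of the original per sample for each submodule, and it uses a list of floats in place of a tensor; the real method may instead transform the CSI per submodule. With tensors, the analogous copy would be csi.clone().

```python
from typing import List, Sequence


def create_csi_for_submodules(csi: Sequence[float], num_modules: int) -> List[List[float]]:
    """Hypothetical sketch: give each submodule its own copy of the CSI."""
    if num_modules < 0:
        raise ValueError("num_modules must be non-negative")
    # Independent copies, so an in-place change in one submodule cannot leak into another
    return [list(csi) for _ in range(num_modules)]
```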

extract_csi_features(csi: Tensor) → Dict[str, Tensor]

Extract common features from CSI tensor for analysis.

Parameters:

csi (torch.Tensor) – The CSI tensor to analyze

Returns:

Dictionary containing extracted features

Return type:

Dict[str, torch.Tensor]
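The page does not list which features are extracted. The sketch below assumes common batch summary statistics (mean, population standard deviation, min, max) and uses a list of per-sample floats in place of a tensor; the key names are hypothetical.

```python
import statistics
from typing import Dict, Sequence


def extract_csi_features(csi: Sequence[float]) -> Dict[str, float]:
    """Hypothetical feature set; the real method's keys may differ."""
    values = [float(v) for v in csi]
    if not values:
        raise ValueError("csi must contain at least one value")
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),  # population std over the batch
        "min": min(values),
        "max": max(values),
    }
```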

static extract_csi_from_channel_output(channel_output: Any) → Tensor | None

Extract CSI from channel output if available.

Some channels return both the output signal and the CSI. This static method provides a standardized way to extract CSI from various channel output formats.

Parameters:

channel_output – Output from a channel, which may contain CSI

Returns:

Extracted CSI tensor if available, None otherwise

Return type:

Optional[torch.Tensor]
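The supported output formats are not enumerated here. The sketch below assumes two plausible ones, a (signal, csi) pair and a dict with a "csi" key, and returns None for anything else, such as a bare signal. Both formats are assumptions for illustration.

```python
from typing import Any, Optional


def extract_csi_from_channel_output(channel_output: Any) -> Optional[Any]:
    """Hypothetical sketch covering two assumed channel output formats."""
    # Assumed format 1: a (signal, csi) pair
    if isinstance(channel_output, tuple) and len(channel_output) == 2:
        return channel_output[1]
    # Assumed format 2: a dict that may carry a "csi" entry
    if isinstance(channel_output, dict):
        return channel_output.get("csi")
    # Anything else (e.g. a bare signal) carries no CSI
    return None
```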

static format_csi_for_channel(csi: Tensor, channel_format: str = 'tensor') → Any

Format CSI tensor for passing to channels that expect specific formats.

Parameters:
  • csi (torch.Tensor) – CSI tensor to format

  • channel_format (str) – Expected format (“tensor”, “dict”, “kwargs”)

Returns:

Formatted CSI in the requested format

Return type:

Any
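One way the three documented formats could behave is sketched below. The "csi" key, the identical mapping for "dict" and "kwargs", and the ValueError for unknown formats are all assumptions; toy_channel is a hypothetical channel that takes CSI as a keyword argument.

```python
from typing import Any


def format_csi_for_channel(csi: Any, channel_format: str = "tensor") -> Any:
    """Hypothetical sketch of the three documented formats."""
    if channel_format == "tensor":
        return csi  # passed through unchanged
    if channel_format in ("dict", "kwargs"):
        # Assumption: both wrap CSI under a "csi" key; the "kwargs" result is
        # meant to be unpacked with ** into a channel call
        return {"csi": csi}
    # Assumption: unknown formats are rejected
    raise ValueError(f"Unsupported channel format: {channel_format!r}")


def toy_channel(x: float, csi: float = 0.0) -> float:
    """Hypothetical channel that accepts CSI as a keyword argument."""
    return x + csi


print(toy_channel(1.0, **format_csi_for_channel(2.0, "kwargs")))  # 3.0
```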

forward_csi_to_submodules(csi: Tensor, modules: List[BaseModel], *args, **kwargs) → List[Any]

Helper method to consistently pass CSI to submodules.

This method facilitates passing CSI to multiple submodules that require channel state information, ensuring consistent handling across the model.

Parameters:
  • csi (torch.Tensor) – Channel state information tensor

  • modules (List[BaseModel]) – List of modules to apply

  • *args – Positional arguments to pass to modules

  • **kwargs – Keyword arguments to pass to modules

Returns:

List of outputs from each module

Return type:

List[Any]
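The helper's behaviour can be pictured as calling every module with the same positional and keyword arguments plus the same CSI. The sketch below assumes CSI is passed as a keyword argument named csi; scale and shift are hypothetical stand-ins for CSI-aware submodules.

```python
from typing import Any, Callable, List, Sequence


def forward_csi_to_submodules(
    csi: Any, modules: Sequence[Callable[..., Any]], *args: Any, **kwargs: Any
) -> List[Any]:
    """Hypothetical sketch: call every module with the same inputs plus the CSI."""
    # Assumption: CSI is handed over as the keyword argument `csi`
    return [module(*args, csi=csi, **kwargs) for module in modules]


def scale(x: float, csi: float) -> float:
    return x * csi


def shift(x: float, csi: float) -> float:
    return x + csi


print(forward_csi_to_submodules(2.0, [scale, shift], 3.0))  # [6.0, 5.0]
```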

normalize_csi(csi: Tensor, method: str = 'minmax', target_range: tuple = (0.0, 1.0)) → Tensor

Normalize CSI values to a specified range.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to normalize

  • method (str) – Normalization method. Options: “minmax”, “zscore”, “none”

  • target_range (tuple) – Target range for minmax normalization (min, max)

Returns:

Normalized CSI tensor

Return type:

torch.Tensor

Raises:

ValueError – If normalization method is not supported
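The three documented methods correspond to standard formulas: min-max scaling maps [min, max] linearly onto target_range, and z-score subtracts the mean and divides by the standard deviation. The sketch below works on a list of floats in place of a tensor; how it handles a constant batch (zero spread) is an assumption, since the page does not specify it.

```python
import statistics
from typing import List, Sequence, Tuple


def normalize_csi(
    csi: Sequence[float], method: str = "minmax", target_range: Tuple[float, float] = (0.0, 1.0)
) -> List[float]:
    """Sketch of the documented normalization methods on a list of floats."""
    values = [float(v) for v in csi]
    if method == "none":
        return values
    if method == "minmax":
        lo, hi = min(values), max(values)
        a, b = target_range
        span = hi - lo
        if span == 0:
            return [a for _ in values]  # assumption: constant input maps to the lower bound
        return [a + (v - lo) / span * (b - a) for v in values]
    if method == "zscore":
        mean, std = statistics.fmean(values), statistics.pstdev(values)
        if std == 0:
            return [0.0 for _ in values]  # assumption: constant input maps to 0
        return [(v - mean) / std for v in values]
    raise ValueError(f"Unsupported normalization method: {method!r}")


print(normalize_csi([0, 5, 10], target_range=(-1.0, 1.0)))  # [-1.0, 0.0, 1.0]
```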

transform_csi(csi: Tensor, target_shape: Size) → Tensor

Transform CSI tensor to match target shape requirements.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to transform

  • target_shape (torch.Size) – Target shape for the CSI tensor

Returns:

Transformed CSI tensor

Return type:

torch.Tensor
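If the transformation is a broadcast (an assumption; the method may reshape instead), the per-sample CSI of shape [B, 1, 1, 1] can be expanded to a feature map shape such as [B, C, H', W']. The hypothetical helper below checks that compatibility on shape tuples using the trailing-dimension rule that torch.Tensor.expand follows; in PyTorch the tensor itself would then be csi.expand(target_shape).

```python
from typing import Tuple


def broadcast_csi_shape(csi_shape: Tuple[int, ...], target_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical check that csi_shape can be expanded to target_shape."""
    if len(csi_shape) > len(target_shape):
        raise ValueError(f"cannot expand {csi_shape} to fewer dimensions {target_shape}")
    # Align trailing dimensions; new leading dimensions count as size 1
    padded = (1,) * (len(target_shape) - len(csi_shape)) + tuple(csi_shape)
    for dim, (src, dst) in enumerate(zip(padded, target_shape)):
        if src not in (1, dst):
            raise ValueError(f"dimension {dim}: cannot expand size {src} to {dst}")
    return tuple(target_shape)
```

Note that under this rule a flat [B] vector does not line up with the batch dimension of [B, C, H', W'], which is one reason to keep CSI in the documented [B, 1, 1, 1] layout.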

validate_csi(csi: Tensor, expected_shape: Size | None = None) → Tensor

Validate and ensure CSI tensor is in the correct format.

Parameters:
  • csi (torch.Tensor) – The CSI tensor to validate

  • expected_shape (Optional[torch.Size]) – Expected shape for the CSI tensor. If None, the cached shape is used when available; otherwise the shape is inferred from the tensor.

Returns:

Validated CSI tensor

Return type:

torch.Tensor

Raises:
  • ValueError – If CSI tensor is invalid or has incorrect shape

  • TypeError – If CSI is not a tensor
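The two documented error types can be mirrored in a stdlib-only sketch. Here a list or tuple of per-sample values stands in for a tensor, expected_len is a simplified, hypothetical replacement for expected_shape, and rejecting NaN or infinite values is an assumed validity rule.

```python
import math
from numbers import Real
from typing import List, Optional, Sequence


def validate_csi(csi: Sequence[float], expected_len: Optional[int] = None) -> List[float]:
    """Hypothetical sketch: TypeError for wrong types, ValueError for bad contents."""
    if not isinstance(csi, (list, tuple)):
        raise TypeError(f"CSI must be a list or tuple, got {type(csi).__name__}")
    if not all(isinstance(v, Real) for v in csi):
        raise TypeError("CSI entries must be real numbers")
    values = [float(v) for v in csi]
    if not values:
        raise ValueError("CSI must not be empty")
    if any(not math.isfinite(v) for v in values):
        raise ValueError("CSI contains NaN or infinite values")  # assumed rule
    if expected_len is not None and len(values) != expected_len:
        raise ValueError(f"expected {expected_len} CSI values, got {len(values)}")
    return values
```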