kaira.models.image.Yilmaz2024DeepJSCCWZConditionalDecoder

- class kaira.models.image.Yilmaz2024DeepJSCCWZConditionalDecoder(N: int, M: int, *args: Any, **kwargs: Any)
Bases: ChannelAwareBaseModel
DeepJSCC-WZ Conditional Decoder Module [Yilmaz et al., 2024].
The decoder counterpart to the conditional encoder. It reconstructs images from the representations produced by the conditional encoder, combining the side information with the received encoded representation to generate high-quality reconstructions.
DeepJSCC-WZ Conditional is designed for scenarios where side information is available at both the encoder and the decoder. It therefore serves as a performance upper bound for the Wyner-Ziv setting, in which side information is available only at the decoder. The decoder's architecture is optimized for the conditional encoder's output, in which side-information correlations have already been exploited during encoding.
Key features:
- Multi-scale feature fusion with side information at 5 different resolution levels
- Progressive upsampling to restore spatial dimensions (H/16×W/16 → H×W)
- Attention-based feature refinement
- Channel-adaptive processing through AFModule layers
- Optimized for conditionally encoded representations
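Restoring H/16×W/16 to H×W takes four doublings, which gives five spatial scales. Assuming each upsampling stage doubles height and width, the five resolution levels at which side information is fused would be H/16, H/8, H/4, H/2 and H. The sketch below computes them for a given image size. The helper `decoder_resolutions` is hypothetical and not part of kaira. The divisibility check reflects the H/16 downsampling factor.

```python
from typing import List, Tuple


def decoder_resolutions(height: int, width: int, num_levels: int = 5) -> List[Tuple[int, int]]:
    """Return (h, w) for each decoder level, coarsest first.

    Assumes (hypothetically) that every upsampling stage doubles the spatial
    size, so five levels span H/16 -> H/8 -> H/4 -> H/2 -> H.
    """
    factor = 2 ** (num_levels - 1)  # 16 when num_levels == 5
    if height % factor or width % factor:
        raise ValueError(f"height and width must be divisible by {factor}")
    # k = 4, 3, 2, 1, 0 gives divisors 16, 8, 4, 2, 1
    return [(height // 2 ** k, width // 2 ** k) for k in range(num_levels - 1, -1, -1)]


print(decoder_resolutions(256, 256))
# -> [(16, 16), (32, 32), (64, 64), (128, 128), (256, 256)]
```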
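Methods
__init__(N, M, *args, **kwargs) – Initialize the conditional DeepJSCC-WZ decoder.
create_csi_for_submodules(csi, num_modules) – Create appropriate CSI tensors for multiple submodules.
extract_csi_features(csi) – Extract common features from CSI tensor for analysis.
extract_csi_from_channel_output(channel_output) – Extract CSI from channel output if available.
format_csi_for_channel(csi[, channel_format]) – Format CSI tensor for passing to channels that expect specific formats.
forward(x, x_side, csi, *args, **kwargs) – Decode the received representation into a reconstructed image.
forward_csi_to_submodules(csi, modules, *args, **kwargs) – Helper method to consistently pass CSI to submodules.
normalize_csi(csi[, method, target_range]) – Normalize CSI values to a specified range.
transform_csi(csi, target_shape) – Transform CSI tensor to match target shape requirements.
validate_csi(csi[, expected_shape]) – Validate and ensure CSI tensor is in the correct format.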
- __init__(N: int, M: int, *args: Any, **kwargs: Any) → None
Initialize the conditional DeepJSCC-WZ decoder.
- Parameters:
N (int) – Number of intermediate channels in the residual blocks. Controls the network capacity and feature dimension.
M (int) – Number of input channels from the encoded representation. Matches the encoder’s output channel count.
*args – Variable positional arguments passed to the base class.
**kwargs – Variable keyword arguments passed to the base class.
- forward(x: Tensor, x_side: Tensor, csi: Tensor, *args: Any, **kwargs: Any) → Tensor
Decode the received representation into a reconstructed image.
This method first processes the side information through the g_a2 encoder path, then progressively decodes the received signal while fusing with side information features at multiple scales.
- Parameters:
x (torch.Tensor) – Received noisy encoded representation of shape [B, M, H/16, W/16].
x_side (torch.Tensor) – Side information tensor of shape [B, 3, H, W] to assist in decoding.
csi (torch.Tensor) – Channel state information tensor of shape [B, 1, 1, 1].
*args – Additional positional arguments (passed to internal layers).
**kwargs – Additional keyword arguments (passed to internal layers).
- Returns:
Reconstructed image tensor of shape [B, 3, H, W].
- Return type:
torch.Tensor
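The shape contract of `forward` ties the three inputs together: `x` is 16× smaller spatially than `x_side`, all batch sizes agree, and the output matches the side image. The following sketch checks that contract on plain shape tuples, so it runs without PyTorch. The helper `forward_output_shape` and the example value M=320 are hypothetical, and kaira may enforce these constraints differently or not at all.

```python
from typing import Sequence, Tuple


def forward_output_shape(x_shape: Sequence[int], x_side_shape: Sequence[int],
                         csi_shape: Sequence[int], M: int) -> Tuple[int, int, int, int]:
    """Check the documented forward() shapes and return the output shape [B, 3, H, W]."""
    B, c, h, w = x_shape            # received representation [B, M, H/16, W/16]
    Bs, cs, H, W = x_side_shape     # side information [B, 3, H, W]
    if c != M:
        raise ValueError(f"x must have M={M} channels, got {c}")
    if Bs != B or cs != 3:
        raise ValueError("x_side must have shape [B, 3, H, W] with the same B as x")
    if (h * 16, w * 16) != (H, W):
        raise ValueError("x must be [B, M, H/16, W/16] relative to x_side")
    if tuple(csi_shape) != (B, 1, 1, 1):
        raise ValueError("csi must have shape [B, 1, 1, 1]")
    return (B, 3, H, W)


# Example: a 256x256 side image and a 320-channel latent at 16x16 (M=320 is illustrative).
print(forward_output_shape((2, 320, 16, 16), (2, 3, 256, 256), (2, 1, 1, 1), M=320))
# -> (2, 3, 256, 256)
```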
- create_csi_for_submodules(csi: Tensor, num_modules: int) → List[Tensor]
Create appropriate CSI tensors for multiple submodules.
- Parameters:
csi (torch.Tensor) – Original CSI tensor
num_modules (int) – Number of submodules that need CSI
- Returns:
List of CSI tensors for each submodule
- Return type:
List[torch.Tensor]
- extract_csi_features(csi: Tensor) → Dict[str, Tensor]
Extract common features from CSI tensor for analysis.
- Parameters:
csi (torch.Tensor) – The CSI tensor to analyze
- Returns:
Dictionary containing extracted features
- Return type:
Dict[str, torch.Tensor]
- static extract_csi_from_channel_output(channel_output: Any) → Tensor | None
Extract CSI from channel output if available.
Some channels return both the transmitted signal and CSI information. This static method provides a standardized way to extract CSI from various channel output formats.
- Parameters:
channel_output – Output from a channel, which may contain CSI
- Returns:
Extracted CSI tensor if available, None otherwise
- Return type:
Optional[torch.Tensor]
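The documentation does not list the channel output formats this method recognizes. The sketch below assumes two common conventions: a `(signal, csi)` tuple and a dict with a `"csi"` key. Any other output is treated as a bare signal with no CSI. The function name `extract_csi_sketch` and both formats are assumptions, not kaira's confirmed behavior.

```python
from typing import Any, Optional


def extract_csi_sketch(channel_output: Any) -> Optional[Any]:
    """Return the CSI carried by a channel output, or None if there is none.

    Assumed (hypothetical) formats:
      * (signal, csi) tuple
      * {"signal": ..., "csi": ...} mapping
    """
    if isinstance(channel_output, tuple) and len(channel_output) == 2:
        return channel_output[1]
    if isinstance(channel_output, dict):
        return channel_output.get("csi")  # None if the key is absent
    return None  # bare signal: no CSI attached
```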
- static format_csi_for_channel(csi: Tensor, channel_format: str = 'tensor') → Any
Format CSI tensor for passing to channels that expect specific formats.
- Parameters:
csi (torch.Tensor) – CSI tensor to format
channel_format (str) – Expected format (“tensor”, “dict”, “kwargs”)
- Returns:
Formatted CSI in the requested format
- Return type:
Any
- forward_csi_to_submodules(csi: Tensor, modules: List[BaseModel], *args, **kwargs) → List[Any]
Helper method to consistently pass CSI to submodules.
This method facilitates passing CSI to multiple submodules that require channel state information, ensuring consistent handling across the model.
- Parameters:
csi (torch.Tensor) – Channel state information tensor
modules (List[BaseModel]) – List of modules to apply
*args – Positional arguments to pass to modules
**kwargs – Keyword arguments to pass to modules
- Returns:
List of outputs from each module
- Return type:
List[Any]
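Based on the parameter descriptions, each module in `modules` receives the same positional and keyword arguments plus the CSI, and the results are collected in order. The following is a minimal sketch of that pattern using plain callables. It assumes the CSI is passed as a `csi=` keyword argument. Both `forward_csi_sketch` and the keyword convention are illustrative assumptions.

```python
from typing import Any, Callable, List


def forward_csi_sketch(csi: Any, modules: List[Callable[..., Any]],
                       *args: Any, **kwargs: Any) -> List[Any]:
    """Apply every module to the same inputs, passing CSI by keyword (assumed convention)."""
    # Passing csi by keyword keeps the call uniform even if modules differ
    # in their positional signatures.
    return [module(*args, csi=csi, **kwargs) for module in modules]


def scale(x: float, csi: float) -> float:
    return x * csi


def shift(x: float, csi: float, offset: float = 0.0) -> float:
    return x + csi + offset


print(forward_csi_sketch(2.0, [scale, shift], 3.0))
# -> [6.0, 5.0]
```

Each module is applied independently to the same inputs rather than chained. This matches the documented return value, a list with one output per module.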
- normalize_csi(csi: Tensor, method: str = 'minmax', target_range: tuple = (0.0, 1.0)) → Tensor
Normalize CSI values to a specified range.
- Parameters:
csi (torch.Tensor) – The CSI tensor to normalize
method (str) – Normalization method. Options: “minmax”, “zscore”, “none”
target_range (tuple) – Target range for minmax normalization (min, max)
- Returns:
Normalized CSI tensor
- Return type:
torch.Tensor
- Raises:
ValueError – If normalization method is not supported
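The three documented methods correspond to standard formulas. "minmax" maps [min, max] linearly onto `target_range`, "zscore" subtracts the mean and divides by the standard deviation, and "none" returns the values unchanged. The sketch below implements them on a plain list of floats. The handling of constant input (zero range or zero standard deviation) and the use of the population standard deviation are assumptions, and kaira's tensor implementation may differ. `normalize_csi_sketch` is a hypothetical name.

```python
import statistics
from typing import List, Sequence, Tuple


def normalize_csi_sketch(csi: Sequence[float], method: str = "minmax",
                         target_range: Tuple[float, float] = (0.0, 1.0)) -> List[float]:
    values = [float(v) for v in csi]
    if method == "none":
        return values
    if method == "minmax":
        lo, hi = min(values), max(values)
        new_lo, new_hi = target_range
        span = hi - lo
        if span == 0:  # constant input: map to the lower bound (assumed convention)
            return [new_lo] * len(values)
        return [new_lo + (v - lo) / span * (new_hi - new_lo) for v in values]
    if method == "zscore":
        mean = statistics.fmean(values)
        std = statistics.pstdev(values)  # population std (assumption)
        if std == 0:  # constant input: all values equal the mean
            return [0.0] * len(values)
        return [(v - mean) / std for v in values]
    raise ValueError(f"Unsupported normalization method: {method!r}")


print(normalize_csi_sketch([0, 5, 10]))
# -> [0.0, 0.5, 1.0]
```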
- transform_csi(csi: Tensor, target_shape: Size) → Tensor
Transform CSI tensor to match target shape requirements.
- Parameters:
csi (torch.Tensor) – The CSI tensor to transform
target_shape (torch.Size) – Target shape for the CSI tensor
- Returns:
Transformed CSI tensor
- Return type:
torch.Tensor
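A typical reason to transform CSI is to let a per-sample value of shape [B] broadcast against a feature map of shape [B, C, H, W]. One plausible interpretation of this method is shown below, working at the shape level only. It pads the CSI shape with trailing singleton dimensions up to the target rank, then checks broadcast compatibility under the usual PyTorch/NumPy rules. With a real tensor, `csi.reshape(padded)` would then broadcast against the target. This is an assumption about what `transform_csi` does, and the helper name is hypothetical.

```python
from typing import Sequence, Tuple


def csi_shape_for_target(csi_shape: Sequence[int], target_shape: Sequence[int]) -> Tuple[int, ...]:
    """Pad csi_shape with trailing 1s to the target rank and verify it broadcasts."""
    if len(csi_shape) > len(target_shape):
        raise ValueError("CSI has more dimensions than the target shape")
    padded = tuple(csi_shape) + (1,) * (len(target_shape) - len(csi_shape))
    # Broadcasting rule: each dimension must be 1 or equal to the target dimension.
    for c, t in zip(padded, target_shape):
        if c != 1 and c != t:
            raise ValueError(f"cannot broadcast CSI shape {padded} to {tuple(target_shape)}")
    return padded


print(csi_shape_for_target((4,), (4, 320, 16, 16)))
# -> (4, 1, 1, 1)
```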
- validate_csi(csi: Tensor, expected_shape: Size | None = None) → Tensor
Validate and ensure CSI tensor is in the correct format.
- Parameters:
csi (torch.Tensor) – The CSI tensor to validate
expected_shape (Optional[torch.Size]) – Expected shape for the CSI tensor. If None, uses cached shape or infers from tensor.
- Returns:
Validated CSI tensor
- Return type:
torch.Tensor
- Raises:
ValueError – If CSI tensor is invalid or has incorrect shape
TypeError – If CSI is not a tensor
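The two documented failure modes map onto two checks: a type check (`TypeError`) and a content or shape check (`ValueError`). The sketch below mirrors them without PyTorch. A flat sequence of numbers plus an explicit shape stands in for a tensor. Treating NaN or infinite entries as "invalid" is an assumption, and the shape caching and inference mentioned for `expected_shape=None` are not modeled. `validate_csi_sketch` is a hypothetical name.

```python
import math
from typing import List, Optional, Sequence, Tuple


def validate_csi_sketch(values: Sequence[float], shape: Tuple[int, ...],
                        expected_shape: Optional[Tuple[int, ...]] = None) -> List[float]:
    # Stand-in for "is it a tensor?": require a list/tuple of numbers.
    if not isinstance(values, (list, tuple)) or not all(isinstance(v, (int, float)) for v in values):
        raise TypeError("CSI must be a sequence of numbers (stand-in for torch.Tensor)")
    if math.prod(shape) != len(values):
        raise ValueError(f"shape {tuple(shape)} does not match {len(values)} values")
    if expected_shape is not None and tuple(shape) != tuple(expected_shape):
        raise ValueError(f"expected CSI shape {tuple(expected_shape)}, got {tuple(shape)}")
    # Assumption: NaN or infinite CSI values count as invalid.
    if not all(math.isfinite(v) for v in values):
        raise ValueError("CSI contains NaN or infinite values")
    return [float(v) for v in values]
```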