Source code for kaira.data.correlation
"""Correlation models for data generation and simulation.
This module contains models for simulating statistical correlations between data sources, which is
particularly useful for distributed source coding scenarios.
"""
from typing import Any, Dict, Optional
import torch
from torch.utils.data import Dataset
from kaira.models.wyner_ziv import WynerZivCorrelationModel
[docs]
class WynerZivCorrelationDataset(Dataset):
    r"""Dataset for Wyner-Ziv coding scenarios with correlated sources.

    This dataset pairs source data with correlated side information according to a
    specified correlation model. It is particularly useful for simulating and
    evaluating Wyner-Ziv coding scenarios where the decoder has access to side
    information that is statistically correlated with the source.

    The side information is generated once, at construction time, by applying the
    correlation model to the whole source tensor; indexing afterwards only slices
    the two stored tensors.

    Attributes:
        model: The correlation model used to generate side information
        data: The source data tensor with shape (n_samples, \*feature_dims)
        correlated_data: The correlated side information produced by ``model``
            from ``data``
    """

    def __init__(
        self,
        source: torch.Tensor,
        correlation_type: str = "gaussian",
        correlation_params: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> None:
        """Initialize the Wyner-Ziv correlated dataset.

        Args:
            source: Source data tensor where the first dimension represents the
                number of samples
            correlation_type: Type of correlation model:
                - 'gaussian': Additive Gaussian noise
                - 'binary': Binary symmetric channel
                - 'custom': User-defined model
            correlation_params: Parameters for the correlation model:
                - For 'gaussian': {'sigma': float} - Standard deviation of the noise
                - For 'binary': {'crossover_prob': float} - Probability of bit flipping
                - For 'custom': {'transform_fn': callable} - Custom transformation function
            *args: Extra positional arguments, forwarded both to the
                ``WynerZivCorrelationModel`` constructor and to the model call
                that generates the side information.
            **kwargs: Extra keyword arguments, forwarded to the same two places
                as ``*args``.
        """
        # The extras are deliberately NOT forwarded to the parent initializer:
        # torch's Dataset base defines no __init__ of its own, so the call
        # resolves to object.__init__, which raises TypeError as soon as any
        # extra argument is supplied. Forwarding them here would make the
        # pass-through to the correlation model unusable.
        super().__init__()

        self.model = WynerZivCorrelationModel(correlation_type, correlation_params, *args, **kwargs)
        self.data = source
        # Side information is computed eagerly for the full source tensor.
        self.correlated_data = self.model(source, *args, **kwargs)

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            int: The number of samples, corresponding to the first dimension of data
        """
        return self.data.size(0)

    def __getitem__(self, idx):
        """Retrieve a source-side information pair from the dataset at the specified index.

        Args:
            idx: Index or slice object to index into the dataset; it is applied
                unchanged to both stored tensors

        Returns:
            tuple: A pair of tensors (source, side_information) representing the
                source data and its correlated side information at the specified
                index/indices
        """
        return self.data[idx], self.correlated_data[idx]