Correlation Models for Data Generation

This example demonstrates the correlation models in Kaira, which are useful for simulating statistical correlations between data sources in distributed source coding scenarios like Wyner-Ziv coding.

import matplotlib.pyplot as plt
import numpy as np
import torch

from kaira.data import WynerZivCorrelationDataset, create_binary_tensor, create_uniform_tensor
from kaira.models.wyner_ziv import WynerZivCorrelationModel

# Plotting imports
from kaira.utils.plotting import PlottingUtils

PlottingUtils.setup_plotting_style()

Imports and Setup

Correlation Models Configuration and Setup

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

1. Introduction to Wyner-Ziv Correlation Models

In Wyner-Ziv coding, there is correlation between the source X and the side information Y available at the decoder. This correlation is critical as it determines the theoretical rate bounds and practical coding efficiency.

# First, let's create a source signal
n_samples = 1
n_features = 1000
source = create_uniform_tensor(size=[n_samples, n_features], low=0.0, high=1.0)

# We'll create different correlation models to demonstrate the relationships
# between the source and side information

2. Gaussian Correlation Model

The Gaussian correlation model adds Gaussian noise to the source. This is equivalent to passing the source through an AWGN channel.

# Create a correlation model with Gaussian noise
sigma_values = [0.1, 0.3, 0.5]
gaussian_models = []
gaussian_side_info = []

for sigma in sigma_values:
    model = WynerZivCorrelationModel(correlation_type="gaussian", correlation_params={"sigma": sigma})
    gaussian_models.append(model)
    # Generate correlated side information
    with torch.no_grad():
        side_info = model(source)
    gaussian_side_info.append(side_info)

Visualizing Gaussian Correlation

Gaussian Correlation Visualization

Let’s visualize the relationship between the source and side information for different noise levels.

fig, axes = plt.subplots(4, 1, figsize=(15, 10))

# Only show a segment for clarity
segment_size = 100
segment_start = 0
segment_end = segment_start + segment_size

# Plot original source
axes[0].plot(source[0, segment_start:segment_end].numpy(), "b-", label="Source X")
axes[0].set_title("Original Source Signal")
axes[0].set_ylabel("Amplitude")
axes[0].grid(True, alpha=0.3)
axes[0].legend()

# Plot side information for each sigma value
colors = ["g", "r", "m"]
for i, (sigma, side_info) in enumerate(zip(sigma_values, gaussian_side_info)):
    axes[i + 1].plot(source[0, segment_start:segment_end].numpy(), "b-", label="Source X")
    axes[i + 1].plot(side_info[0, segment_start:segment_end].numpy(), colors[i] + "-", label=f"Side Info Y (σ={sigma})")
    axes[i + 1].set_title(f"Gaussian Correlation (σ={sigma})")
    axes[i + 1].set_ylabel("Amplitude")
    axes[i + 1].grid(True, alpha=0.3)
    axes[i + 1].legend()

axes[-1].set_xlabel("Sample Index")
plt.tight_layout()
plt.show()
Original Source Signal, Gaussian Correlation (σ=0.1), Gaussian Correlation (σ=0.3), Gaussian Correlation (σ=0.5)

Visualizing the Statistical Dependence

Statistical Dependence Visualization

Let’s plot the joint distribution of X and Y to visualize the correlation strength.

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for i, (sigma, side_info) in enumerate(zip(sigma_values, gaussian_side_info)):
    axes[i].scatter(source.numpy().flatten(), side_info.numpy().flatten(), alpha=0.3, s=10)
    axes[i].set_title(f"Joint Distribution (σ={sigma})")
    axes[i].set_xlabel("Source X")
    axes[i].set_ylabel("Side Information Y")

    # Add regression line to visualize correlation
    z = np.polyfit(source.numpy().flatten(), side_info.numpy().flatten(), 1)
    p = np.poly1d(z)
    axes[i].plot([0, 1], [p(0), p(1)], "r--", alpha=0.8)

    # Calculate and display correlation coefficient
    corr_coef = np.corrcoef(source.numpy().flatten(), side_info.numpy().flatten())[0, 1]
    axes[i].text(0.05, 0.95, f"Correlation: {corr_coef:.4f}", transform=axes[i].transAxes, fontsize=12, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

    axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Joint Distribution (σ=0.1), Joint Distribution (σ=0.3), Joint Distribution (σ=0.5)

3. Binary Symmetric Channel Correlation

For binary sources, we can model correlation as a Binary Symmetric Channel (BSC) where bits are flipped with probability p.

# Create a binary source
binary_source = create_binary_tensor(size=[1, n_features], prob=0.5)

# Create correlation models with different crossover probabilities
crossover_probs = [0.05, 0.1, 0.3]
binary_models = []
binary_side_info = []

for crossover_p in crossover_probs:
    model = WynerZivCorrelationModel(correlation_type="binary", correlation_params={"crossover_prob": crossover_p})
    binary_models.append(model)
    # Generate correlated side information
    with torch.no_grad():
        side_info = model(binary_source)
    binary_side_info.append(side_info)

Visualizing Binary Correlation

Let’s visualize the relationship between the binary source and side information for different crossover probabilities.

plt.figure(figsize=(15, 10))

# Only show a segment for clarity
segment_size = 50
segment_start = 0
segment_end = segment_start + segment_size

# Plot original binary source
ax1 = plt.subplot(4, 1, 1)
plt.step(np.arange(segment_size), binary_source[0, segment_start:segment_end].numpy(), "b-", where="mid", label="Source X")
plt.title("Original Binary Source")
plt.ylabel("Value")
plt.ylim(-0.1, 1.1)
plt.grid(True, alpha=0.3)
plt.legend()

# Plot side information for each crossover probability
colors = ["g", "r", "m"]
for i, (crossover_prob, side_info) in enumerate(zip(crossover_probs, binary_side_info)):
    ax = plt.subplot(4, 1, i + 2, sharex=ax1)
    plt.step(np.arange(segment_size), binary_source[0, segment_start:segment_end].numpy(), "b-", where="mid", label="Source X")
    plt.step(np.arange(segment_size), side_info[0, segment_start:segment_end].numpy(), colors[i] + "-", where="mid", label=f"Side Info Y (p={crossover_prob})")

    # Highlight the flipped bits
    flipped = binary_source[0, segment_start:segment_end] != side_info[0, segment_start:segment_end]
    flipped_indices = np.where(flipped.numpy())[0]
    if len(flipped_indices) > 0:
        plt.scatter(flipped_indices, side_info[0, segment_start:segment_end][flipped].numpy(), s=100, facecolors="none", edgecolors="black")

    plt.title(f"Binary Symmetric Channel Correlation (p={crossover_prob})")
    plt.ylabel("Value")
    plt.ylim(-0.1, 1.1)
    plt.grid(True, alpha=0.3)
    plt.legend()

plt.xlabel("Sample Index")
plt.tight_layout()
plt.show()
Original Binary Source, Binary Symmetric Channel Correlation (p=0.05), Binary Symmetric Channel Correlation (p=0.1), Binary Symmetric Channel Correlation (p=0.3)

4. Custom Correlation Models

WynerZivCorrelationModel also supports custom correlation models through a user-defined transformation function.

# Define a custom transformation function
def custom_transform(x):
    """A custom correlation model where Y = 0.8*X + 0.2*sin(2πX) This introduces both linear
    correlation and nonlinear distortion."""
    return 0.8 * x + 0.2 * torch.sin(2 * np.pi * x)


# Create a custom correlation model
custom_model = WynerZivCorrelationModel(correlation_type="custom", correlation_params={"transform_fn": custom_transform})

# Generate source and correlated side information
source = create_uniform_tensor(size=[1, n_features], low=0.0, high=1.0)
with torch.no_grad():
    custom_side_info = custom_model(source)

Visualizing Custom Correlation

Let’s visualize the relationship for our custom correlation model.

plt.figure(figsize=(12, 10))

# Plot the signals
plt.subplot(2, 1, 1)
plt.plot(source[0, segment_start:segment_end].numpy(), "b-", label="Source X")
plt.plot(custom_side_info[0, segment_start:segment_end].numpy(), "g-", label="Side Info Y (Custom)")
plt.title("Custom Correlation Model")
plt.ylabel("Amplitude")
plt.grid(True, alpha=0.3)
plt.legend()

# Plot the joint distribution
plt.subplot(2, 1, 2)
plt.scatter(source.numpy().flatten(), custom_side_info.numpy().flatten(), alpha=0.3, s=10)
plt.title("Joint Distribution (Custom Model)")
plt.xlabel("Source X")
plt.ylabel("Side Information Y")
plt.grid(True, alpha=0.3)

# Plot the theoretical curve Y = 0.8*X + 0.2*sin(2πX)
x_vals = np.linspace(0, 1, 100)
y_vals = 0.8 * x_vals + 0.2 * np.sin(2 * np.pi * x_vals)
plt.plot(x_vals, y_vals, "r-", alpha=0.8, label="Theoretical Y = 0.8X + 0.2sin(2πX)")
plt.legend()

plt.tight_layout()
plt.show()
Custom Correlation Model, Joint Distribution (Custom Model)

5. Using the WynerZivCorrelationDataset

Kaira provides a dataset class that pairs source data with correlated side information according to a specified model.

# Generate source data
n_samples = 1000
feature_dim = 8
source_data = create_uniform_tensor(size=[n_samples, feature_dim], low=0.0, high=1.0)

# Create datasets with different correlation types
gaussian_dataset = WynerZivCorrelationDataset(source=source_data, correlation_type="gaussian", correlation_params={"sigma": 0.2})

binary_source = create_binary_tensor(size=[n_samples, feature_dim], prob=0.5)
binary_dataset = WynerZivCorrelationDataset(source=binary_source, correlation_type="binary", correlation_params={"crossover_prob": 0.1})

custom_dataset = WynerZivCorrelationDataset(source=source_data, correlation_type="custom", correlation_params={"transform_fn": custom_transform})

print(f"Dataset size: {len(gaussian_dataset)}")
print(f"Sample shape: {gaussian_dataset[0][0].shape}")
print(f"Sample type: {type(gaussian_dataset[0])}")
Dataset size: 1000
Sample shape: torch.Size([8])
Sample type: <class 'tuple'>

Visualizing Dataset Samples

Let’s visualize some samples from our correlation datasets.

plt.figure(figsize=(15, 12))

# Select a few samples to visualize
sample_indices = [0, 1, 2]

# Plot Gaussian correlation dataset samples
plt.subplot(3, 1, 1)
for i, idx in enumerate(sample_indices):
    x, y = gaussian_dataset[idx]
    plt.plot(x.numpy(), "b-", alpha=0.7, label=f"Source X {i+1}" if i == 0 else "_")
    plt.plot(y.numpy(), "g-", alpha=0.7, label=f"Side Info Y {i+1}" if i == 0 else "_")
plt.title("Gaussian Correlation Dataset Samples")
plt.xlabel("Feature Index")
plt.ylabel("Value")
plt.grid(True, alpha=0.3)
plt.legend()

# Plot Binary correlation dataset samples
plt.subplot(3, 1, 2)
for i, idx in enumerate(sample_indices):
    x, y = binary_dataset[idx]
    plt.step(np.arange(feature_dim), x.numpy(), "b-", where="mid", alpha=0.7, label=f"Source X {i+1}" if i == 0 else "_")
    plt.step(np.arange(feature_dim), y.numpy(), "g-", where="mid", alpha=0.7, label=f"Side Info Y {i+1}" if i == 0 else "_")
plt.title("Binary Correlation Dataset Samples")
plt.xlabel("Feature Index")
plt.ylabel("Value")
plt.ylim(-0.1, 1.1)
plt.grid(True, alpha=0.3)
plt.legend()

# Plot Custom correlation dataset samples
plt.subplot(3, 1, 3)
for i, idx in enumerate(sample_indices):
    x, y = custom_dataset[idx]
    plt.plot(x.numpy(), "b-", alpha=0.7, label=f"Source X {i+1}" if i == 0 else "_")
    plt.plot(y.numpy(), "g-", alpha=0.7, label=f"Side Info Y {i+1}" if i == 0 else "_")
plt.title("Custom Correlation Dataset Samples")
plt.xlabel("Feature Index")
plt.ylabel("Value")
plt.grid(True, alpha=0.3)
plt.legend()

plt.tight_layout()
plt.show()
Gaussian Correlation Dataset Samples, Binary Correlation Dataset Samples, Custom Correlation Dataset Samples

6. Application: Distributed Source Coding Simulation

Let’s demonstrate a practical application where we simulate a basic distributed source coding scenario.

# Generate a larger binary source
n_samples = 1
n_bits = 1000
source_bits = create_binary_tensor(size=[n_samples, n_bits], prob=0.5)

# Create correlated side information (BSC with p=0.1)
correlation_model = WynerZivCorrelationModel(correlation_type="binary", correlation_params={"crossover_prob": 0.1})
side_info = correlation_model(source_bits)

# Calculate the empirical joint distribution
joint_counts = torch.zeros(2, 2)
for i in range(n_bits):
    x = int(source_bits[0, i].item())
    y = int(side_info[0, i].item())
    joint_counts[x, y] += 1

joint_probs = joint_counts / n_bits
marginal_x = joint_probs.sum(dim=1)
marginal_y = joint_probs.sum(dim=0)

# Calculate conditional entropies
H_X_given_Y = 0
for x in range(2):
    for y in range(2):
        if joint_probs[x, y] > 0:
            p_x_given_y = joint_probs[x, y] / marginal_y[y]
            if p_x_given_y > 0:
                H_X_given_Y -= marginal_y[y] * p_x_given_y * np.log2(p_x_given_y)

H_X = -sum(p * np.log2(p) if p > 0 else 0 for p in marginal_x)
H_Y = -sum(p * np.log2(p) if p > 0 else 0 for p in marginal_y)
I_X_Y = H_X - H_X_given_Y  # Mutual information

print("Joint Probability Distribution:")
print(joint_probs)
print(f"Entropy of X: H(X) = {H_X:.4f} bits")
print(f"Entropy of Y: H(Y) = {H_Y:.4f} bits")
print(f"Conditional Entropy: H(X|Y) = {H_X_given_Y:.4f} bits")
print(f"Mutual Information: I(X;Y) = {I_X_Y:.4f} bits")
print(f"Theoretical Rate Savings: {I_X_Y/H_X*100:.2f}%")
/home/runner/work/kaira/kaira/examples/data/plot_correlation_models.py:365: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
  H_X_given_Y -= marginal_y[y] * p_x_given_y * np.log2(p_x_given_y)
/home/runner/work/kaira/kaira/examples/data/plot_correlation_models.py:367: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
  H_X = -sum(p * np.log2(p) if p > 0 else 0 for p in marginal_x)
/home/runner/work/kaira/kaira/examples/data/plot_correlation_models.py:368: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
  H_Y = -sum(p * np.log2(p) if p > 0 else 0 for p in marginal_y)
Joint Probability Distribution:
tensor([[0.4470, 0.0580],
        [0.0490, 0.4460]])
Entropy of X: H(X) = 0.9999 bits
Entropy of Y: H(Y) = 1.0000 bits
Conditional Entropy: H(X|Y) = 0.4903 bits
Mutual Information: I(X;Y) = 0.5096 bits
Theoretical Rate Savings: 50.97%

Visualizing Joint Distribution

plt.figure(figsize=(10, 8))

# Plot joint distribution as a heatmap
plt.subplot(2, 2, 1)
plt.imshow(joint_probs.numpy(), cmap="Blues", interpolation="nearest")
plt.colorbar(label="Joint Probability P(X,Y)")
plt.title("Joint Distribution P(X,Y)")
plt.xlabel("Side Information Y")
plt.ylabel("Source X")
plt.xticks([0, 1], ["0", "1"])
plt.yticks([0, 1], ["0", "1"])

for i in range(2):
    for j in range(2):
        plt.text(j, i, f"{joint_probs[i, j]:.3f}", ha="center", va="center", color="black" if joint_probs[i, j] < 0.4 else "white", fontsize=12)

# Plot conditional distribution P(X|Y) as a heatmap
plt.subplot(2, 2, 2)
cond_probs = joint_probs / marginal_y.unsqueeze(0)
plt.imshow(cond_probs.numpy(), cmap="Greens", interpolation="nearest")
plt.colorbar(label="Conditional Probability P(X|Y)")
plt.title("Conditional Distribution P(X|Y)")
plt.xlabel("Side Information Y")
plt.ylabel("Source X")
plt.xticks([0, 1], ["0", "1"])
plt.yticks([0, 1], ["0", "1"])

for i in range(2):
    for j in range(2):
        plt.text(j, i, f"{cond_probs[i, j]:.3f}", ha="center", va="center", color="black" if cond_probs[i, j] < 0.4 else "white", fontsize=12)

# Plot information theoretic quantities
plt.subplot(2, 1, 2)
labels = ["H(X)", "H(Y)", "H(X|Y)", "I(X;Y)"]
values = [H_X, H_Y, H_X_given_Y, I_X_Y]
plt.bar(labels, values, color=["blue", "green", "red", "purple"])
plt.title("Information Theoretic Quantities")
plt.ylabel("Bits")
plt.grid(axis="y", alpha=0.3)

for i, v in enumerate(values):
    plt.text(i, v + 0.02, f"{v:.3f}", ha="center", va="bottom")

plt.tight_layout()
plt.show()
Joint Distribution P(X,Y), Conditional Distribution P(X|Y), Information Theoretic Quantities

Conclusion

This example demonstrated the correlation models in Kaira:

  1. Gaussian correlation for continuous-valued sources

  2. Binary symmetric channel correlation for binary sources

  3. Custom correlation through user-defined functions

  4. Using WynerZivCorrelationDataset for paired data

  5. Application to distributed source coding

These models are useful for:

  • Simulating Wyner-Ziv coding scenarios

  • Evaluating distributed compression algorithms

  • Studying rate-distortion tradeoffs with side information

  • Information theoretic analysis of correlated sources

Total running time of the script: (0 minutes 1.928 seconds)

Gallery generated by Sphinx-Gallery