"""Visualization utilities for benchmark results."""
import json
from pathlib import Path
from typing import Any, Dict, Optional
import matplotlib.pyplot as plt
import seaborn as sns
import torch
# Set style
plt.style.use("seaborn-v0_8")
sns.set_palette("husl")
[docs]
class BenchmarkVisualizer:
"""Visualizer for benchmark results."""
[docs]
def __init__(self, figsize: tuple = (10, 6), dpi: int = 100):
"""Initialize visualizer.
Args:
figsize: Figure size in inches (width, height)
dpi: Figure resolution
"""
self.figsize = figsize
self.dpi = dpi
[docs]
def plot_ber_curve(self, results: Dict[str, Any], save_path: Optional[str] = None) -> plt.Figure:
"""Plot BER vs SNR curve.
Args:
results: Benchmark results containing SNR and BER data
save_path: Optional path to save the figure
Returns:
Matplotlib figure object
"""
fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
snr_range = results.get("snr_range", [])
# Plot simulated BER
if "ber_simulated" in results:
ax.semilogy(snr_range, results["ber_simulated"], "o-", label="Simulated", linewidth=2, markersize=6)
elif "ber_results" in results:
ax.semilogy(snr_range, results["ber_results"], "o-", label="Simulated", linewidth=2, markersize=6)
# Plot theoretical BER if available
if "ber_theoretical" in results:
ax.semilogy(snr_range, results["ber_theoretical"], "--", label="Theoretical", linewidth=2)
# Plot coded and uncoded BER if available
if "ber_uncoded" in results and "ber_coded" in results:
ax.semilogy(snr_range, results["ber_uncoded"], "o-", label="Uncoded", linewidth=2, markersize=6)
ax.semilogy(snr_range, results["ber_coded"], "s-", label="Coded", linewidth=2, markersize=6)
ax.set_xlabel("SNR (dB)", fontsize=12)
ax.set_ylabel("Bit Error Rate", fontsize=12)
# Determine title from benchmark name or context
benchmark_name = results.get("benchmark_name", "")
if not benchmark_name:
# Try to infer from other fields
if "modulation" in results:
benchmark_name = f"BER Simulation ({results['modulation'].upper()})"
elif "constellation_size" in results:
benchmark_name = f"{results['constellation_size']}-QAM BER"
else:
benchmark_name = "BER Performance"
ax.set_title(f"BER Performance - {benchmark_name}", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=11)
# Add text with key metrics
if "rmse" in results:
ax.text(0.02, 0.98, f'RMSE: {results["rmse"]:.2e}', transform=ax.transAxes, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=self.dpi, bbox_inches="tight")
return fig
[docs]
def plot_throughput_comparison(self, results: Dict[str, Any], save_path: Optional[str] = None) -> plt.Figure:
"""Plot throughput comparison.
Args:
results: Benchmark results containing throughput data
save_path: Optional path to save the figure
Returns:
Matplotlib figure object
"""
fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
if "throughput_results" in results:
# Bar plot for different payload sizes
payload_sizes = []
mean_throughputs = []
std_throughputs = []
for size, stats in results["throughput_results"].items():
payload_sizes.append(size)
mean_throughputs.append(stats["mean"])
std_throughputs.append(stats["std"])
x_pos = torch.arange(len(payload_sizes))
bars = ax.bar(x_pos, mean_throughputs, yerr=std_throughputs, capsize=5, alpha=0.7, edgecolor="black")
ax.set_xlabel("Payload Size (bits)", fontsize=12)
ax.set_ylabel("Throughput (bits/s)", fontsize=12)
ax.set_title("Throughput vs Payload Size", fontsize=14)
ax.set_xticks(x_pos)
ax.set_xticklabels([str(size) for size in payload_sizes])
ax.grid(True, alpha=0.3)
# Color bars based on throughput
import matplotlib.colors as mcolors
import numpy as np
colors = mcolors.LinearSegmentedColormap.from_list("viridis", ["purple", "blue", "green", "yellow"])(np.linspace(0, 1, len(bars)))
for bar, color in zip(bars, colors):
bar.set_color(color)
elif "throughput_bps" in results:
# Line plot for OFDM throughput vs SNR
snr_range = results.get("snr_range", [])
ax.plot(snr_range, results["throughput_bps"], "o-", linewidth=2, markersize=6)
ax.set_xlabel("SNR (dB)", fontsize=12)
ax.set_ylabel("Throughput (bits/s)", fontsize=12)
ax.set_title("Throughput vs SNR", fontsize=14)
ax.grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=self.dpi, bbox_inches="tight")
return fig
[docs]
def plot_latency_distribution(self, results: Dict[str, Any], save_path: Optional[str] = None) -> plt.Figure:
"""Plot latency distribution.
Args:
results: Benchmark results containing latency data
save_path: Optional path to save the figure
Returns:
Matplotlib figure object
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), dpi=self.dpi)
# Extract latency statistics
latency_stats = results.get("inference_latency_ms", results)
# Box plot
if "percentiles" in latency_stats:
percentiles = latency_stats["percentiles"]
box_data = [percentiles["p25"], percentiles["p50"], percentiles["p75"]]
bp = ax1.boxplot([box_data], patch_artist=True, labels=["Latency"])
bp["boxes"][0].set_facecolor("lightblue")
bp["boxes"][0].set_alpha(0.7)
ax1.set_ylabel("Latency (ms)", fontsize=12)
ax1.set_title("Latency Distribution", fontsize=14)
ax1.grid(True, alpha=0.3)
# Add statistics text
stats_text = []
if "mean_latency" in latency_stats:
stats_text.append(f"Mean: {latency_stats['mean_latency']:.2f} ms")
if "std_latency" in latency_stats:
stats_text.append(f"Std: {latency_stats['std_latency']:.2f} ms")
if "min_latency" in latency_stats:
stats_text.append(f"Min: {latency_stats['min_latency']:.2f} ms")
if "max_latency" in latency_stats:
stats_text.append(f"Max: {latency_stats['max_latency']:.2f} ms")
if stats_text:
ax1.text(0.02, 0.98, "\n".join(stats_text), transform=ax1.transAxes, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
# Throughput bar (if available)
if "throughput_samples_per_second" in results:
throughput = results["throughput_samples_per_second"]
ax2.bar(["Throughput"], [throughput], color="orange", alpha=0.7)
ax2.set_ylabel("Samples/second", fontsize=12)
ax2.set_title("Processing Throughput", fontsize=14)
ax2.grid(True, alpha=0.3)
else:
ax2.axis("off")
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=self.dpi, bbox_inches="tight")
return fig
[docs]
def plot_constellation(self, constellation: torch.Tensor, received_symbols: Optional[torch.Tensor] = None, save_path: Optional[str] = None) -> plt.Figure:
"""Plot constellation diagram.
Args:
constellation: Ideal constellation points
received_symbols: Optional received symbols to overlay
save_path: Optional path to save the figure
Returns:
Matplotlib figure object
"""
fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
# Plot ideal constellation
ax.scatter(constellation.real, constellation.imag, c="red", s=100, marker="x", linewidths=3, label="Ideal")
# Plot received symbols if provided
if received_symbols is not None:
# Subsample if too many points
if len(received_symbols) > 1000:
indices = torch.randperm(len(received_symbols))[:1000]
received_symbols = received_symbols[indices]
ax.scatter(received_symbols.real, received_symbols.imag, c="blue", s=20, alpha=0.6, label="Received")
ax.set_xlabel("In-Phase", fontsize=12)
ax.set_ylabel("Quadrature", fontsize=12)
ax.set_title("Constellation Diagram", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend()
ax.axis("equal")
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=self.dpi, bbox_inches="tight")
return fig
[docs]
def plot_coding_gain(self, results: Dict[str, Any], save_path: Optional[str] = None) -> plt.Figure:
"""Plot coding gain vs SNR.
Args:
results: Benchmark results containing coding gain data
save_path: Optional path to save the figure
Returns:
Matplotlib figure object
"""
fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
snr_range = results.get("snr_range", [])
coding_gain = results.get("coding_gain_db", [])
# Filter out infinite values
coding_gain_tensor = torch.tensor(coding_gain) if not isinstance(coding_gain, torch.Tensor) else coding_gain
finite_mask = torch.isfinite(coding_gain_tensor)
snr_range_tensor = torch.tensor(snr_range) if not isinstance(snr_range, torch.Tensor) else snr_range
snr_finite = snr_range_tensor[finite_mask]
gain_finite = coding_gain_tensor[finite_mask]
ax.plot(snr_finite, gain_finite, "o-", linewidth=2, markersize=6)
ax.set_xlabel("SNR (dB)", fontsize=12)
ax.set_ylabel("Coding Gain (dB)", fontsize=12)
ax.set_title(f'Coding Gain - {results.get("code_type", "Unknown")} Code', fontsize=14)
ax.grid(True, alpha=0.3)
# Add average coding gain
if "average_coding_gain" in results:
avg_gain = results["average_coding_gain"]
ax.axhline(y=avg_gain, color="red", linestyle="--", alpha=0.7, label=f"Average: {avg_gain:.2f} dB")
ax.legend()
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=self.dpi, bbox_inches="tight")
return fig
[docs]
def plot_benchmark_summary(self, results_file: str, save_path: Optional[str] = None) -> plt.Figure:
"""Plot summary of multiple benchmark results.
Args:
results_file: Path to JSON file containing benchmark results
save_path: Optional path to save the figure
Returns:
Matplotlib figure object
"""
with open(results_file) as f:
data = json.load(f)
benchmarks = data.get("benchmark_results", [])
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12), dpi=self.dpi)
# Success rate
success_count = sum(1 for b in benchmarks if b.get("success", False))
total_count = len(benchmarks)
ax1.pie([success_count, total_count - success_count], labels=["Success", "Failed"], autopct="%1.1f%%", colors=["lightgreen", "lightcoral"])
ax1.set_title("Benchmark Success Rate", fontsize=14)
# Execution times
execution_times = [b.get("execution_time", 0) for b in benchmarks]
bars = ax2.bar(range(len(execution_times)), execution_times, alpha=0.7)
ax2.set_xlabel("Benchmark Index", fontsize=12)
ax2.set_ylabel("Execution Time (s)", fontsize=12)
ax2.set_title("Execution Times", fontsize=14)
ax2.grid(True, alpha=0.3)
# Color bars by execution time
if execution_times:
import matplotlib.colors as mcolors
import numpy as np
colors = mcolors.LinearSegmentedColormap.from_list("plasma", ["purple", "red", "orange", "yellow"])(np.linspace(0, 1, len(bars)))
for bar, color in zip(bars, colors):
bar.set_color(color)
# Device usage
devices = [b.get("device", "unknown") for b in benchmarks]
device_counts: dict[str, int] = {}
for device in devices:
device_counts[device] = device_counts.get(device, 0) + 1
if device_counts:
ax3.pie(device_counts.values(), labels=device_counts.keys(), autopct="%1.1f%%")
ax3.set_title("Device Usage", fontsize=14)
else:
ax3.axis("off")
# Summary statistics
summary_stats = data.get("summary", {})
stats_text = []
if "total_benchmarks" in summary_stats:
stats_text.append(f"Total Benchmarks: {summary_stats['total_benchmarks']}")
if "successful_benchmarks" in summary_stats:
stats_text.append(f"Successful: {summary_stats['successful_benchmarks']}")
if "total_execution_time" in summary_stats:
stats_text.append(f"Total Time: {summary_stats['total_execution_time']:.2f}s")
if "average_execution_time" in summary_stats:
stats_text.append(f"Avg Time: {summary_stats['average_execution_time']:.2f}s")
ax4.text(0.1, 0.9, "\n".join(stats_text), transform=ax4.transAxes, fontsize=12, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8))
ax4.set_title("Summary Statistics", fontsize=14)
ax4.axis("off")
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=self.dpi, bbox_inches="tight")
return fig
[docs]
def create_benchmark_report(self, results_file: str, output_dir: str = "benchmark_plots"):
"""Create a comprehensive visual report from benchmark results.
Args:
results_file: Path to JSON file containing benchmark results
output_dir: Directory to save plots
"""
Path(output_dir).mkdir(exist_ok=True)
with open(results_file) as f:
data = json.load(f)
benchmarks = data.get("benchmark_results", [])
# Create summary plot
summary_fig = self.plot_benchmark_summary(results_file, save_path=f"{output_dir}/summary.png")
plt.close(summary_fig)
# Create individual plots for each benchmark
for i, benchmark in enumerate(benchmarks):
if not benchmark.get("success", False):
continue
benchmark_name = benchmark.get("benchmark_name", f"benchmark_{i}")
safe_name = benchmark_name.replace(" ", "_").replace("(", "").replace(")", "")
try:
# BER plots
if any(key in benchmark for key in ["ber_simulated", "ber_results", "ber_uncoded"]):
ber_fig = self.plot_ber_curve(benchmark, save_path=f"{output_dir}/{safe_name}_ber.png")
plt.close(ber_fig)
# Throughput plots
if "throughput_results" in benchmark or "throughput_bps" in benchmark:
throughput_fig = self.plot_throughput_comparison(benchmark, save_path=f"{output_dir}/{safe_name}_throughput.png")
plt.close(throughput_fig)
# Latency plots
if "inference_latency_ms" in benchmark or "mean_latency" in benchmark:
latency_fig = self.plot_latency_distribution(benchmark, save_path=f"{output_dir}/{safe_name}_latency.png")
plt.close(latency_fig)
# Coding gain plots
if "coding_gain_db" in benchmark:
coding_fig = self.plot_coding_gain(benchmark, save_path=f"{output_dir}/{safe_name}_coding_gain.png")
plt.close(coding_fig)
except Exception as e:
print(f"Warning: Could not create plot for {benchmark_name}: {e}")
print(f"Benchmark report saved to {output_dir}/")