import logging
import sys

import torch

# Plotting libraries are only needed by visualize_attention_heads; keep the module
# importable (e.g. for tqdm_logging) in environments where they are not installed.
try:
    import seaborn as sns
except ImportError:
    sns = None

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None


def visualize_attention_heads(model, test_loader, device):
    """Plot one attention heatmap per head for the first image of ``test_loader``.

    Runs a single no-grad forward pass on the first image of the first batch,
    then reads ``model.attn_weights`` and draws one seaborn heatmap per head in
    a single row of subplots before calling ``plt.show()``.

    Args:
        model: A torch module whose forward pass stores its attention weights
            on ``model.attn_weights``. It is switched to eval mode.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the input image is moved to before the forward pass.

    Raises:
        ImportError: If seaborn or matplotlib is not installed.
    """
    if sns is None or plt is None:
        raise ImportError("visualize_attention_heads requires seaborn and matplotlib")

    model.eval()
    images, _labels = next(iter(test_loader))
    # unsqueeze(0) keeps a batch of one; squeeze(1) then drops dim 1 if it has
    # size 1 (presumably a single image channel -- confirm against the model input).
    image = images[0].unsqueeze(0).squeeze(1).to(device)
    with torch.no_grad():
        model(image)

    # NOTE(review): expected shape [num_heads, seq_len, seq_len] after taking
    # element 0 of model.attn_weights -- layout is model-defined, confirm.
    # Move to CPU and float32 so seaborn/NumPy can consume it even when the model
    # runs on CUDA or in reduced precision.
    attn_weights = model.attn_weights[0].detach().float().cpu()
    num_heads = attn_weights.shape[0]

    # squeeze=False always yields a 2-D array of Axes, so a single head still
    # indexes correctly (plain subplots(1, 1) returns a bare Axes object).
    _, axes = plt.subplots(1, num_heads, figsize=(num_heads * 3, 3), squeeze=False)
    for head_idx, ax in enumerate(axes[0]):
        sns.heatmap(attn_weights[head_idx].numpy(), ax=ax, cbar=False)
        ax.set_title(f"Head {head_idx}")
    plt.tight_layout()
    plt.show()


def tqdm_logging(message, current, total, *, width=100):
    """Log a fixed-width text progress bar for ``current`` out of ``total`` items.

    The bar is ``width`` characters of ``#`` (done) and spaces (remaining),
    followed by the percentage with two decimals. The bar is clamped to
    ``[0, width]`` cells; the printed percentage is the unclamped ratio.

    Args:
        message: Prefix text for the log line.
        current: Number of items processed so far.
        total: Total number of items; ``0`` is treated as already complete.
        width: Number of character cells in the bar (default 100).
    """
    # An empty workload has nothing left to do; avoid dividing by zero.
    percentage = current / total if total else 1.0
    # Clamp so overshoot (current > total) or negative input keeps the bar width fixed.
    filled = min(max(int(round(percentage * width)), 0), width)
    arrow = "#" * filled
    spaces = " " * (width - filled)
    logging.info(message + f" Processed: {arrow}{spaces} -> {percentage * 100:.2f} %")
    # NOTE(review): logging's default handler writes to stderr; this flush only
    # matters if a handler/other output targets stdout -- kept from the original.
    sys.stdout.flush()