Files
nlp_learning/tools/visualize.py
2025-12-30 22:42:47 +08:00

32 lines
959 B
Python

import logging
import sys
import torch
import seaborn as sns
import matplotlib.pyplot as plt
def visualize_attention_heads(model, test_loader, device):
    """Plot one attention heatmap per head for the first test sample.

    Puts ``model`` in eval mode, takes the first image of the first batch
    from ``test_loader``, runs a forward pass without gradients and then
    reads the attention weights that the model stores on
    ``model.attn_weights``.

    Args:
        model: Module whose forward pass sets ``model.attn_weights``.
            ``model.attn_weights[0]`` is indexed as
            ``[num_heads, seq_len, seq_len]`` (per the original comment) —
            TODO confirm against the model implementation.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the sample is moved to before the forward pass.

    Returns:
        The matplotlib ``Figure`` holding the heatmaps (the original
        returned ``None``; returning the figure lets callers save it).
    """
    model.eval()
    images, _ = next(iter(test_loader))
    # Add a batch dim, then drop dim 1 (a no-op unless it has size 1).
    # Presumably this turns a single-channel [C=1, H, W] image into [1, H, W]
    # for a row-sequence model — TODO confirm the expected input shape.
    image = images[0].unsqueeze(0).squeeze(1).to(device)
    with torch.no_grad():
        _ = model(image)
    # seaborn converts its input through NumPy, which cannot read CUDA
    # tensors; move the weights to CPU (and float32) before plotting.
    attn_weights = model.attn_weights[0].detach().float().cpu()
    num_heads = attn_weights.shape[0]
    # squeeze=False keeps a 2-D axes array even when num_heads == 1, where
    # plt.subplots would otherwise return a single Axes that can't be indexed.
    fig, axes = plt.subplots(
        1, num_heads, figsize=(num_heads * 3, 3), squeeze=False
    )
    for head, ax in enumerate(axes[0]):
        sns.heatmap(attn_weights[head].numpy(), ax=ax, cbar=False)
        ax.set_title(f"Head {head}")
    plt.tight_layout()
    plt.show()
    return fig
def tqdm_logging(message, current, total):
percentage = (current / total)
arrow = '#' * int(round(percentage * 100) - 1)
spaces = ' ' * (100 - len(arrow))
logging.info(message + f" Processed: {arrow}{spaces} -> {percentage * 100:.2f} %")
sys.stdout.flush()