32 lines
959 B
Python
32 lines
959 B
Python
import logging
|
|
import sys
|
|
import torch
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def visualize_attention_heads(model, test_loader, device):
    """Plot one heatmap per attention head for the first test sample.

    Puts ``model`` in eval mode, runs it on the first image of the first
    batch from ``test_loader``, then reads the attention weights the model
    exposes on ``model.attn_weights`` and shows one heatmap per head.

    Args:
        model: Module that exposes an ``attn_weights`` attribute after a
            forward pass. NOTE(review): assumed to be set by ``forward`` and
            to be indexable by batch position; confirm against the model.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the input image is moved to before the forward pass.
    """
    model.eval()
    images, _ = next(iter(test_loader))

    # Take the first sample, give it a batch dim of size 1, then drop dim 1
    # (only removed if it has size 1 -- presumably a single channel, so the
    # model sees (1, H, W); TODO confirm the model's expected input shape).
    image = images[0].unsqueeze(0).squeeze(1).to(device)

    with torch.no_grad():
        _ = model(image)

    # First batch element; expected shape [num_heads, seq_len, seq_len].
    # Detach and move to CPU so numpy/seaborn can consume it even when the
    # model ran on a GPU.
    attn_weights = model.attn_weights[0].detach().cpu()
    num_heads = attn_weights.shape[0]

    # squeeze=False keeps `axes` a 2-D array even when num_heads == 1; by
    # default plt.subplots would return a bare Axes that cannot be indexed.
    _, axes = plt.subplots(
        1, num_heads, figsize=(num_heads * 3, 3), squeeze=False
    )
    for head, ax in enumerate(axes[0]):
        sns.heatmap(attn_weights[head].numpy(), ax=ax, cbar=False)
        ax.set_title(f"Head {head}")

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
def tqdm_logging(message, current, total, *, width=100):
    """Log a text progress bar for ``current`` out of ``total`` items.

    Emits one INFO record on the root logger shaped like
    ``"<message> Processed: ####      -> 40.00 %"`` and then flushes stdout.

    Args:
        message: Prefix for the log line.
        current: Number of items processed so far.
        total: Total number of items. A zero total is treated as complete
            instead of raising ZeroDivisionError.
        width: Number of characters in the bar (default 100, i.e. one
            character per percent).
    """
    fraction = current / total if total else 1.0
    # Clamp only the bar, so overshoot (current > total) or negative progress
    # cannot widen or break it; the printed percentage stays exact.
    filled = min(max(round(fraction * width), 0), width)
    bar = "#" * filled + " " * (width - filled)
    logging.info("%s Processed: %s -> %.2f %%", message, bar, fraction * 100)
    sys.stdout.flush()
|