1. 初始化代码;
This commit is contained in:
31
tools/visualize.py
Normal file
31
tools/visualize.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import logging
|
||||
import sys
|
||||
import torch
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def visualize_attention_heads(model, test_loader, device):
|
||||
model.eval()
|
||||
images, _ = next(iter(test_loader))
|
||||
image = images[0].unsqueeze(0).squeeze(1).to(device)
|
||||
with torch.no_grad():
|
||||
_ = model(image)
|
||||
|
||||
attn_weights = model.attn_weights[0] # shape: [num_heads, seq_len, seq_len]
|
||||
num_heads = attn_weights.shape[0]
|
||||
|
||||
fig, axes = plt.subplots(1, num_heads, figsize=(num_heads * 3, 3))
|
||||
for i in range(num_heads):
|
||||
sns.heatmap(attn_weights[i], ax=axes[i], cbar=False)
|
||||
axes[i].set_title(f"Head {i}")
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def tqdm_logging(message, current, total):
|
||||
percentage = (current / total)
|
||||
arrow = '#' * int(round(percentage * 100) - 1)
|
||||
spaces = ' ' * (100 - len(arrow))
|
||||
logging.info(message + f" Processed: {arrow}{spaces} -> {percentage * 100:.2f} %")
|
||||
sys.stdout.flush()
|
||||
Reference in New Issue
Block a user