"""Train a small Transformer encoder on MNIST and plot its attention heads.

Each image is fed to the model as a sequence of rows: the training code drops
the channel axis with ``squeeze(1)`` and the model projects every row
(``input_dim`` features) to ``d_model`` before the encoder layers.
"""

import math
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torchvision import datasets, transforms


def get_device() -> torch.device:
    """Return the preferred compute device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    # ``torch.backends.mps`` is missing on older PyTorch builds, so look it up
    # defensively instead of assuming the attribute exists.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


device = get_device()
# An f-string is required here: ``str + torch.device`` raises TypeError.
print(f"Using device: {device}")


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence."""

    def __init__(self, d_model: int, max_len: int = 100) -> None:
        """Precompute the encoding table.

        Args:
            d_model: Size of the feature axis (odd values are supported).
            max_len: Longest sequence length the table covers.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as a buffer with a leading batch axis: (1, max_len, d_model).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len].to(x.device)


class ScaledDotProductAttention(nn.Module):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``."""

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(output, attention_weights)``.

        Args:
            Q, K, V: Query, key and value tensors; the last axis of ``Q`` is
                used as ``d_k`` for scaling.
            mask: Optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` (or ``False``) are excluded from
                attention. A query row with every key masked yields NaN
                weights, because softmax over all ``-inf`` is undefined.
        """
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # -inf scores become exactly zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention built from four linear maps and scaled dot-product attention."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        """Create the projections.

        Args:
            d_model: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention()
        # Attention weights of the most recent forward pass (CPU, detached).
        self.last_attn_weights = None

    def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (B, L, d_model) to (B, num_heads, L, d_k)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from ``q`` to ``k``/``v`` and return the projected result.

        Args:
            q, k, v: Batch-first inputs whose last axis is ``d_model``.
            mask: Optional mask where 0/False means "do not attend". A 3-D
                mask gets a head axis inserted at position 1; any other rank
                must already broadcast against the per-head score tensor.

        Side effect:
            Stores the detached attention weights, moved to CPU, in
            ``self.last_attn_weights``.
        """
        B = q.size(0)
        Q = self._split_heads(self.Q(q), B)
        K = self._split_heads(self.K(k), B)
        V = self._split_heads(self.V(v), B)

        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)  # share one mask across all heads

        out, attn = self.attn(Q, K, V, mask)
        # Kept on CPU so existing readers of this attribute can use it for
        # plotting without moving it themselves.
        self.last_attn_weights = attn.detach().cpu()

        # Merge the heads back: (B, heads, L, d_k) -> (B, L, heads * d_k).
        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
        return self.out(out)


class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x``; the last axis stays ``d_model`` wide."""
        return self.ff(x)


class EncoderLayer(nn.Module):
    """Self-attention and feed-forward sub-layers, each followed by add-and-LayerNorm."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int) -> None:
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run one encoder layer; ``mask`` is forwarded to the self-attention."""
        attended = self.attn(x, x, x, mask)
        x = self.norm1(x + attended)
        transformed = self.ff(x)
        return self.norm2(x + transformed)


class TransformerEncoder(nn.Module):
    """Sequence classifier: input projection, positional encoding, encoder stack, mean-pool, linear head."""

    def __init__(
        self,
        input_dim: int = 28,
        d_model: int = 128,
        num_heads: int = 4,
        d_ff: int = 256,
        num_layers: int = 2,
        seq_len: int = 28,
        num_classes: int = 10,
    ) -> None:
        """Build the model.

        Args:
            input_dim: Features per sequence step.
            d_model: Internal model width.
            num_heads: Attention heads per layer.
            d_ff: Hidden width of each feed-forward sub-layer.
            num_layers: Number of encoder layers.
            seq_len: Maximum sequence length for the positional table.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.input_fc = nn.Linear(input_dim, d_model)
        self.pos = PositionalEncoding(d_model, max_len=seq_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)]
        )
        self.classifier = nn.Linear(d_model, num_classes)
        # Attention weights of the last encoder layer from the latest forward
        # pass; stays None until forward() has run with at least one layer.
        self.attn_weights = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch-first input of sequences.

        Side effect:
            Updates ``self.attn_weights`` with the last layer's attention.
        """
        x = self.input_fc(x)
        x = self.pos(x)
        for layer in self.layers:
            x = layer(x)
        if len(self.layers) > 0:
            self.attn_weights = self.layers[-1].attn.last_attn_weights
        x = x.mean(dim=1)  # average over sequence positions
        return self.classifier(x)


def train_and_test(
    epochs: int = 5,
    batch_size: int = 64,
    lr: float = 1e-3,
    data_root: str = "./data",
) -> Tuple[nn.Module, torch.utils.data.DataLoader]:
    """Train a ``TransformerEncoder`` on MNIST and report test accuracy.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.
        data_root: Directory where MNIST is stored or downloaded to. A
            forward slash is used so the same path works on every platform.

    Returns:
        The trained model and the test ``DataLoader``.
    """
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=data_root, train=False, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = TransformerEncoder().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for images, labels in train_loader:
            # squeeze(1) drops the size-1 axis at position 1 so each image
            # becomes a (rows, features) sequence.
            images = images.squeeze(1).to(device)
            labels = labels.to(device)
            preds = model(images)
            loss = loss_fn(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.squeeze(1).to(device)
            labels = labels.to(device)
            preds = model(images)
            predicted = preds.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy: {correct / total:.4f}")
    return model, test_loader


def visualize_attention_heads(model, test_loader) -> None:
    """Plot one heatmap per attention head for the first test image.

    The weights shown are those stored in ``model.attn_weights`` after a
    forward pass, i.e. the attention of the last encoder layer.

    Raises:
        RuntimeError: If the model recorded no attention weights.
    """
    model.eval()
    images, _ = next(iter(test_loader))
    # Keep a batch of one and drop the size-1 axis at position 1.
    image = images[:1].squeeze(1).to(device)
    with torch.no_grad():
        model(image)

    if model.attn_weights is None:
        raise RuntimeError("model recorded no attention weights to visualize")

    attn_weights = model.attn_weights[0]  # shape: [num_heads, seq_len, seq_len]
    num_heads = attn_weights.shape[0]
    # squeeze=False keeps ``axes`` two-dimensional even for a single head,
    # so indexing works for any head count.
    fig, axes = plt.subplots(1, num_heads, figsize=(num_heads * 3, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        sns.heatmap(attn_weights[i].detach().cpu().numpy(), ax=ax, cbar=False)
        ax.set_title(f"Head {i}")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    model, test_loader = train_and_test()
    visualize_attention_heads(model, test_loader)