diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/transformer_.py b/models/transformer_.py
new file mode 100644
index 0000000..1edfd2e
--- /dev/null
+++ b/models/transformer_.py
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+from torchvision import datasets, transforms
+import matplotlib.pyplot as plt
+import seaborn as sns
+import math
+
+
+def get_device():
+    if torch.cuda.is_available():
+        return torch.device('cuda:0')
+    if torch.backends.mps.is_available():
+        return torch.device('mps')
+    return torch.device('cpu')
+
+device = get_device()
+print(f"Using device: {device}")
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=100):
+        super().__init__()
+        # fixed sinusoidal positional encodings, precomputed up to max_len
+        pe = torch.zeros(max_len, d_model)
+        pos = torch.arange(0, max_len).unsqueeze(1)
+        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(pos * div)
+        pe[:, 1::2] = torch.cos(pos * div)
+        self.register_buffer('pe', pe.unsqueeze(0))
+
+    def forward(self, x):
+        return x + self.pe[:, :x.size(1)].to(x.device)
+
+class ScaledDotProductAttention(nn.Module):
+    def forward(self, Q, K, V, mask=None):
+        d_k = Q.size(-1)
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
+        if mask is not None:
+            # optionally mask out positions before the softmax
+            scores = scores.masked_fill(mask == 0, float('-inf'))
+        attn = torch.softmax(scores, dim=-1)
+        return torch.matmul(attn, V), attn
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model, num_heads):
+        super().__init__()
+        assert d_model % num_heads == 0
+        self.d_k = d_model // num_heads
+        self.num_heads = num_heads
+        self.Q = nn.Linear(d_model, d_model)
+        self.K = nn.Linear(d_model, d_model)
+        self.V = nn.Linear(d_model, d_model)
+        self.out = nn.Linear(d_model, d_model)
+        self.attn = ScaledDotProductAttention()
+        self.last_attn_weights = None
+
+    def forward(self, q, k, v, mask=None):
+        B = q.size(0)
+        # project and split into heads: [B, num_heads, seq_len, d_k]
+        Q = self.Q(q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
+        K = self.K(k).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
+        V = self.V(v).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
+        out, attn = self.attn(Q, K, V, mask)
+        self.last_attn_weights = attn.detach().cpu()
+        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
+        return self.out(out)
+
+class FeedForward(nn.Module):
+    def __init__(self, d_model, d_ff):
+        super().__init__()
+        self.ff = nn.Sequential(
+            nn.Linear(d_model, d_ff),
+            nn.ReLU(),
+            nn.Linear(d_ff, d_model)
+        )
+
+    def forward(self, x):
+        return self.ff(x)
+
+class EncoderLayer(nn.Module):
+    def __init__(self, d_model, num_heads, d_ff):
+        super().__init__()
+        self.attn = MultiHeadAttention(d_model, num_heads)
+        self.ff = FeedForward(d_model, d_ff)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+    def forward(self, x):
+        # post-norm residual blocks: self-attention, then feed-forward
+        x2 = self.attn(x, x, x)
+        x = self.norm1(x + x2)
+        x2 = self.ff(x)
+        x = self.norm2(x + x2)
+        return x
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, input_dim=28, d_model=128, num_heads=4, d_ff=256, num_layers=2, seq_len=28):
+        super().__init__()
+        self.input_fc = nn.Linear(input_dim, d_model)
+        self.pos = PositionalEncoding(d_model, max_len=seq_len)
+        self.layers = nn.ModuleList([
+            EncoderLayer(d_model, num_heads, d_ff)
+            for _ in range(num_layers)
+        ])
+        self.classifier = nn.Linear(d_model, 10)
+
+    def forward(self, x):
+        x = self.input_fc(x)
+        x = self.pos(x)
+        for layer in self.layers:
+            x = layer(x)
+        # keep the last layer's attention weights for visualization
+        self.attn_weights = self.layers[-1].attn.last_attn_weights
+        x = x.mean(dim=1)
+        return self.classifier(x)
+
+def train_and_test():
+    transform = transforms.ToTensor()
+    train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
+    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
+
+    model = TransformerEncoder().to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    loss_fn = nn.CrossEntropyLoss()
+
+    for epoch in range(5):
+        model.train()
+        total_loss = 0
+        for images, labels in train_loader:
+            # treat each 28x28 image as a sequence of 28 row vectors
+            images = images.squeeze(1).to(device)
+            labels = labels.to(device)
+            preds = model(images)
+            loss = loss_fn(preds, labels)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            total_loss += loss.item()
+        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")
+
+    correct, total = 0, 0
+    model.eval()
+    with torch.no_grad():
+        for images, labels in test_loader:
+            images = images.squeeze(1).to(device)
+            labels = labels.to(device)
+            preds = model(images)
+            predicted = preds.argmax(dim=1)
+            correct += (predicted == labels).sum().item()
+            total += labels.size(0)
+    print(f"Test Accuracy: {correct / total:.4f}")
+    return model, test_loader
+
+def visualize_attention_heads(model, test_loader):
+    model.eval()
+    images, _ = next(iter(test_loader))
+    image = images[0].unsqueeze(0).squeeze(1).to(device)
+    with torch.no_grad():
+        _ = model(image)
+
+    attn_weights = model.attn_weights[0]  # shape: [num_heads, seq_len, seq_len]
+    num_heads = attn_weights.shape[0]
+
+    fig, axes = plt.subplots(1, num_heads, figsize=(num_heads * 3, 3))
+    for i in range(num_heads):
+        sns.heatmap(attn_weights[i], ax=axes[i], cbar=False)
+        axes[i].set_title(f"Head {i}")
+    plt.tight_layout()
+    plt.show()
+
+if __name__ == "__main__":
+    model, test_loader = train_and_test()
+    visualize_attention_heads(model, test_loader)