"""Train and evaluate a small Transformer encoder on MNIST, treating each image row as a token."""
import logging

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tools.devices import DeviceManager
from models.layers.transformer import *
from models.base import BaseModelRunner
from tools.visualize import visualize_attention_heads, tqdm_logging


class Transformer(BaseModelRunner):
    """Runner that trains :class:`TransformerEncoder` on MNIST and reports test accuracy.

    Args:
        args: Run configuration. It is stored on ``self.args`` and logged;
            this class does not otherwise read it.
    """

    def __init__(self, args: dict):
        logging.info("Initializing Transformer Model")
        self.args = args
        logging.info(f"ALL arguments passed: {args.items()}")
        self.model = TransformerEncoder()
        self.device = DeviceManager().device
        self.model.to(self.device)
        logging.info("Model transform from cpu to {}".format(str(self.device)))

    def forward(self, x):
        """Run the wrapped encoder on ``x`` and return its class logits."""
        return self.model(x)

    def run(self, epochs=5, batch_size=64, lr=1e-3, data_root="./data"):
        """Train on MNIST, evaluating on the test split after every epoch.

        Args:
            epochs: Number of passes over the training set.
            batch_size: Mini-batch size used for both the train and test loaders.
            lr: Adam learning rate.
            data_root: Directory MNIST is downloaded to / read from.

        Returns:
            ``(model, test_loader)``: the trained encoder and the test loader.
        """
        train_loader, test_loader = self._build_loaders(data_root, batch_size)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            avg_loss = self._train_epoch(train_loader, optimizer, loss_fn)
            tqdm_logging(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}", epoch, epochs)
            accuracy = self._evaluate(test_loader)
            logging.info(f"Test Accuracy: {accuracy:.4f}")

        return self.model, test_loader

    def run_test(self):
        """Train the model, then visualize its attention heads on the test set."""
        model, test_loader = self.run()
        visualize_attention_heads(model, test_loader, device=self.device)

    @staticmethod
    def _build_loaders(data_root, batch_size):
        """Download (if needed) the MNIST splits and wrap them in DataLoaders."""
        logging.info("Loading MNIST dataset from network.")
        transform = transforms.ToTensor()
        logging.info("Loading MNIST training dataset from network...")
        train_dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
        test_dataset = datasets.MNIST(root=data_root, train=False, transform=transform, download=True)
        # Only the training split is shuffled; evaluation order is irrelevant.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        logging.info("Loaded MNIST dataset from network.")
        return train_loader, test_loader

    def _to_sequences(self, images):
        """Drop the channel axis of an image batch and move it to the model's device.

        MNIST batches from ``ToTensor`` are (B, 1, 28, 28). Squeezing dim 1
        yields (B, 28, 28): 28 row-tokens of 28 pixels each. This matches the
        encoder's default ``input_dim`` and ``seq_len``.
        """
        return images.squeeze(1).to(self.device)

    def _train_epoch(self, loader, optimizer, loss_fn):
        """Run one optimization pass over ``loader``; return the mean per-batch loss."""
        self.model.train()
        total_loss = 0.0
        for images, labels in loader:
            images = self._to_sequences(images)
            labels = labels.to(self.device)
            loss = loss_fn(self.model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Guard against an empty loader rather than dividing by zero.
        return total_loss / max(len(loader), 1)

    def _evaluate(self, loader):
        """Return the model's classification accuracy over ``loader`` (0.0 if it is empty)."""
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in loader:
                labels = labels.to(self.device)
                predicted = self.model(self._to_sequences(images)).argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        return correct / total if total else 0.0


class TransformerEncoder(nn.Module):
    """Sequence classifier built from a stack of encoder layers.

    The input is projected from ``input_dim`` to ``d_model``, and positional
    encoding is added. It then passes through ``num_layers`` encoder layers, is
    mean-pooled over the sequence axis, and is mapped to ``num_classes`` logits.
    """

    def __init__(self, input_dim=28, d_model=128, num_heads=4, d_ff=256, num_layers=2, seq_len=28,
                 num_classes=10):
        super().__init__()
        # Attention weights of the last encoder layer from the most recent
        # forward pass (presumably consumed by visualize_attention_heads).
        self.attn_weights = None
        self.input_fc = nn.Linear(input_dim, d_model)
        self.pos = PositionalEncoding(d_model, max_len=seq_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Assumed (batch, seq_len, input_dim). ``input_fc`` requires the
                last dim to be ``input_dim``, and pooling is over dim 1.
        """
        x = self.input_fc(x)
        x = self.pos(x)
        for layer in self.layers:
            x = layer(x)
        # With num_layers == 0 there is no attention to record.
        if len(self.layers) > 0:
            self.attn_weights = self.layers[-1].attn.last_attn_weights
        x = x.mean(dim=1)
        return self.classifier(x)