1. Initialize code;
models/transformer.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import logging

import torch
import torch.nn as nn
from torchvision import datasets, transforms

from tools.devices import DeviceManager
from models.layers.transformer import *
from models.base import BaseModelRunner
from tools.visualize import visualize_attention_heads, tqdm_logging

class Transformer(BaseModelRunner):
    """Runner that trains a small Transformer encoder on MNIST and evaluates it."""

    def __init__(self, args: dict):
        logging.info("Initializing Transformer model")
        self.args = args
        logging.info(f"All arguments passed: {args}")
        self.model = TransformerEncoder()
        self.device = DeviceManager().device
        self.model.to(self.device)
        logging.info(f"Model moved to {self.device}")

    def forward(self, x):
        # delegate to the wrapped encoder
        return self.model(x)

    def run(self):
        logging.info("Loading MNIST dataset (downloading if necessary)...")
        transform = transforms.ToTensor()
        train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
        test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
        logging.info("Loaded MNIST dataset.")

        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        # train for a fixed number of epochs
        for epoch in range(5):
            self.model.train()
            total_loss = 0.0
            for images, labels in train_loader:
                # treat each (1, 28, 28) image as a sequence of 28 rows of 28 pixels
                images = images.squeeze(1).to(self.device)
                labels = labels.to(self.device)
                predicts = self.model(images)
                loss = loss_fn(predicts, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            tqdm_logging(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}", epoch, 5)

        # evaluate on the test set
        correct, total = 0, 0
        self.model.eval()
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.squeeze(1).to(self.device)
                labels = labels.to(self.device)
                predicts = self.model(images)
                predicted = predicts.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        logging.info(f"Test Accuracy: {correct / total:.4f}")
        return self.model, test_loader

    def run_test(self):
        model, test_loader = self.run()
        visualize_attention_heads(model, test_loader, device=self.device)


class TransformerEncoder(nn.Module):
    """Stack of Transformer encoder layers with mean pooling and a linear classifier."""

    def __init__(self, input_dim=28, d_model=128, num_heads=4, d_ff=256, num_layers=2, seq_len=28):
        super().__init__()
        self.attn_weights = None
        self.input_fc = nn.Linear(input_dim, d_model)
        self.pos = PositionalEncoding(d_model, max_len=seq_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(d_model, 10)

    def forward(self, x):
        x = self.input_fc(x)  # project each 28-pixel row to d_model
        x = self.pos(x)       # add positional encodings
        for layer in self.layers:
            x = layer(x)
            # cache the most recent layer's attention weights for visualization
            self.attn_weights = layer.attn.last_attn_weights
        x = x.mean(dim=1)     # mean-pool over the sequence dimension
        return self.classifier(x)
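
The wildcard import is assumed to supply PositionalEncoding and EncoderLayer. For reference, a minimal sketch of a standard sinusoidal PositionalEncoding with the interface used above; the actual implementation in models/layers/transformer.py may differ:

    import math
    import torch
    import torch.nn as nn

    class PositionalEncoding(nn.Module):
        """Adds fixed sinusoidal position information to a (batch, seq, d_model) tensor."""

        def __init__(self, d_model, max_len=5000):
            super().__init__()
            position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
            )
            pe = torch.zeros(max_len, d_model)
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer("pe", pe.unsqueeze(0))  # shape (1, max_len, d_model)

        def forward(self, x):
            # add the encodings for the first seq_len positions
            return x + self.pe[:, : x.size(1)]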
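
A hypothetical way to exercise this runner end to end, assuming no extra arguments are required beyond the dict passed to the constructor:

    import logging
    from models.transformer import Transformer

    logging.basicConfig(level=logging.INFO)
    runner = Transformer(args={})
    runner.run_test()  # trains 5 epochs on MNIST, logs test accuracy, then visualizes attention heads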