1. First attempt to commit;

2025-12-23 14:37:27 +08:00
parent e66642073e
commit e886c805f0
2 changed files with 168 additions and 0 deletions

models/__init__.py (new file, 0 lines)

models/transformer_.py (new file, 168 lines)

@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import seaborn as sns
import math
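# A minimal Transformer encoder classifier for MNIST: each 28x28 image is treated
# as a sequence of 28 row vectors (28 features each); the encoder output is
# mean-pooled over the sequence and passed to a linear head over the 10 digit classes.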
def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

device = get_device()
print(f"Using device: {device}")
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to the input embeddings."""
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        return x + self.pe[:, :x.size(1)].to(x.device)
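# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V. The sqrt(d_k) scaling
# keeps the attention logits in a range where the softmax does not saturate.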
class ScaledDotProductAttention(nn.Module):
    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Positions where mask == 0 are excluded from attention.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention()
        self.last_attn_weights = None  # cached for later visualization

    def forward(self, q, k, v, mask=None):
        B = q.size(0)
        # Project and split into heads: [B, num_heads, seq_len, d_k]
        Q = self.Q(q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.K(k).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.V(v).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        out, attn = self.attn(Q, K, V, mask)
        self.last_attn_weights = attn.detach().cpu()
        # Merge heads back into [B, seq_len, d_model]
        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
        return self.out(out)
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.ff(x)
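# Standard post-norm encoder block: self-attention and the feed-forward network are
# each wrapped in a residual connection followed by LayerNorm.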
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x2 = self.attn(x, x, x)
        x = self.norm1(x + x2)
        x2 = self.ff(x)
        x = self.norm2(x + x2)
        return x
class TransformerEncoder(nn.Module):
    def __init__(self, input_dim=28, d_model=128, num_heads=4, d_ff=256, num_layers=2, seq_len=28):
        super().__init__()
        self.input_fc = nn.Linear(input_dim, d_model)
        self.pos = PositionalEncoding(d_model, max_len=seq_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(d_model, 10)
        self.attn_weights = None

    def forward(self, x):
        # x: [batch, 28, 28] -- each image row is one token
        x = self.input_fc(x)
        x = self.pos(x)
        for layer in self.layers:
            x = layer(x)
        # Keep the last encoder layer's attention weights for visualization.
        self.attn_weights = self.layers[-1].attn.last_attn_weights
        x = x.mean(dim=1)  # mean-pool over the sequence dimension
        return self.classifier(x)
def train_and_test():
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

    model = TransformerEncoder().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(5):
        model.train()
        total_loss = 0
        for images, labels in train_loader:
            # [B, 1, 28, 28] -> [B, 28, 28]: image rows form the token sequence
            images = images.squeeze(1).to(device)
            labels = labels.to(device)
            preds = model(images)
            loss = loss_fn(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.squeeze(1).to(device)
            labels = labels.to(device)
            preds = model(images)
            predicted = preds.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy: {correct / total:.4f}")
    return model, test_loader
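# Plot one heatmap per attention head, taken from the last encoder layer's attention
# over the 28 row-tokens of a single test image.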
def visualize_attention_heads(model, test_loader):
    model.eval()
    images, _ = next(iter(test_loader))
    image = images[0].unsqueeze(0).squeeze(1).to(device)  # [1, 28, 28]
    with torch.no_grad():
        _ = model(image)
    attn_weights = model.attn_weights[0]  # shape: [num_heads, seq_len, seq_len]
    num_heads = attn_weights.shape[0]
    fig, axes = plt.subplots(1, num_heads, figsize=(num_heads * 3, 3))
    for i in range(num_heads):
        sns.heatmap(attn_weights[i], ax=axes[i], cbar=False)
        axes[i].set_title(f"Head {i}")
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    model, test_loader = train_and_test()
    visualize_attention_heads(model, test_loader)