From aed4f68b0dee502a46d63f2d2f5dc8a3eb9acfa1 Mon Sep 17 00:00:00 2001
From: ZekShawn
Date: Tue, 30 Dec 2025 22:42:47 +0800
Subject: [PATCH] 1. Initialize code;
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 models/bert_.py => confs/__init__.py |   0
 confs/logging.conf                   |  39 +++++++
 models/base.py                       |  16 +++
 models/bert.py                       |   0
 models/layers/__init__.py            |   0
 models/layers/transformer.py         |  74 ++++++++++++
 models/transformer.py                |  91 +++++++++++++++
 models/transformer_.py               | 168 ---------------------------
 tools/__init__.py                    |   0
 tools/devices.py                     |  17 +++
 tools/visualize.py                   |  31 +++++
 11 files changed, 268 insertions(+), 168 deletions(-)
 rename models/bert_.py => confs/__init__.py (100%)
 create mode 100644 confs/logging.conf
 create mode 100644 models/base.py
 create mode 100644 models/bert.py
 create mode 100644 models/layers/__init__.py
 create mode 100644 models/layers/transformer.py
 create mode 100644 models/transformer.py
 delete mode 100644 models/transformer_.py
 create mode 100644 tools/__init__.py
 create mode 100644 tools/devices.py
 create mode 100644 tools/visualize.py

diff --git a/models/bert_.py b/confs/__init__.py
similarity index 100%
rename from models/bert_.py
rename to confs/__init__.py
diff --git a/confs/logging.conf b/confs/logging.conf
new file mode 100644
index 0000000..ddb871f
--- /dev/null
+++ b/confs/logging.conf
@@ -0,0 +1,39 @@
+# logging.conf
+[loggers]
+keys=root,my_module
+
+[handlers]
+keys=consoleHandler,fileHandler
+
+[formatters]
+keys=simpleFormatter,detailedFormatter
+
+[logger_root]
+level=INFO
+handlers=consoleHandler
+
+[logger_my_module]
+level=DEBUG
+qualname=my_module
+handlers=fileHandler
+propagate=0
+
+[handler_consoleHandler]
+class=StreamHandler
+level=INFO
+formatter=simpleFormatter
+args=(sys.stdout,)
+
+[handler_fileHandler]
+class=FileHandler
+level=DEBUG
+formatter=detailedFormatter
+args=('app.log', 'a')
+
+[formatter_simpleFormatter]
+format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
+datefmt=%Y-%m-%d %H:%M:%S
+
+[formatter_detailedFormatter]
+format=%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s
+datefmt=%Y-%m-%d %H:%M:%S
diff --git a/models/base.py b/models/base.py
new file mode 100644
index 0000000..f126a1c
--- /dev/null
+++ b/models/base.py
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+
+
+class BaseModelRunner(ABC):
+
+    @abstractmethod
+    def forward(self):
+        pass
+
+    @abstractmethod
+    def run_test(self):
+        pass
+
+    @abstractmethod
+    def run(self):
+        pass
diff --git a/models/bert.py b/models/bert.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/layers/__init__.py b/models/layers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/layers/transformer.py b/models/layers/transformer.py
new file mode 100644
index 0000000..368a547
--- /dev/null
+++ b/models/layers/transformer.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+import math
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=100):
+        super().__init__()
+        pe = torch.zeros(max_len, d_model)
+        pos = torch.arange(0, max_len).unsqueeze(1)
+        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(pos * div)
+        pe[:, 1::2] = torch.cos(pos * div)
+        self.register_buffer('pe', pe.unsqueeze(0))
+
+    def forward(self, x):
+        return x + self.pe[:, :x.size(1)].to(x.device)
+
+class ScaledDotProductAttention(nn.Module):
+    def forward(self, Q, K, V, mask=None):
+        d_k = Q.size(-1)
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
+        attn = torch.softmax(scores, dim=-1)
+        return torch.matmul(attn, V), attn
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model, num_heads):
+        super().__init__()
+        assert d_model % num_heads == 0
+        self.d_k = d_model // num_heads
+        self.num_heads = num_heads
+        self.Q = nn.Linear(d_model, d_model)
+        self.K = nn.Linear(d_model, d_model)
+        self.V = nn.Linear(d_model, d_model)
+        self.out = nn.Linear(d_model, d_model)
+        self.attn = ScaledDotProductAttention()
+        self.last_attn_weights = None
+
+    def forward(self, q, k, v, mask=None):
+        B = q.size(0)
+        Q = self.Q(q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
+        K = self.K(k).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
+        V = self.V(v).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
+        out, attn = self.attn(Q, K, V)
+        self.last_attn_weights = attn.detach().cpu()
+        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
+        return self.out(out)
+
+class FeedForward(nn.Module):
+    def __init__(self, d_model, d_ff):
+        super().__init__()
+        self.ff = nn.Sequential(
+            nn.Linear(d_model, d_ff),
+            nn.ReLU(),
+            nn.Linear(d_ff, d_model)
+        )
+
+    def forward(self, x):
+        return self.ff(x)
+
+class EncoderLayer(nn.Module):
+    def __init__(self, d_model, num_heads, d_ff):
+        super().__init__()
+        self.attn = MultiHeadAttention(d_model, num_heads)
+        self.ff = FeedForward(d_model, d_ff)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+    def forward(self, x):
+        x2 = self.attn(x, x, x)
+        x = self.norm1(x + x2)
+        x2 = self.ff(x)
+        x = self.norm2(x + x2)
+        return x
diff --git a/models/transformer.py b/models/transformer.py
new file mode 100644
index 0000000..8bb81a2
--- /dev/null
+++ b/models/transformer.py
@@ -0,0 +1,91 @@
+import logging
+from torchvision import datasets, transforms
+from tools.devices import DeviceManager
+from models.layers.transformer import *
+from models.base import BaseModelRunner
+from tools.visualize import visualize_attention_heads, tqdm_logging
+
+
+class Transformer(BaseModelRunner):
+
+    def __init__(self, args: dict):
+        logging.info("Initializing Transformer Model")
+        self.args = args
+        logging.info(f"All arguments passed: {args}")
+        self.model = TransformerEncoder()
+        self.device = DeviceManager().device
+        self.model.to(self.device)
+        logging.info(f"Model moved to device: {self.device}")
+
+    def forward(self, x):
+        return self.model(x)
+
+    def run(self):
+
+        logging.info("Loading MNIST dataset (downloading if not cached)...")
+        transform = transforms.ToTensor()
+        logging.info("Building train/test datasets and data loaders...")
+        train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
+        test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
+        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
+        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
+        logging.info("MNIST dataset ready.")
+
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
+        loss_fn = nn.CrossEntropyLoss()
+
+        for epoch in range(5):
+            self.model.train()
+            total_loss = 0
+            num_batches = len(train_loader)
+            for images, labels in train_loader:
+
+                images = images.squeeze(1).to(self.device)
+                labels = labels.to(self.device)
+                predicts = self.model(images)
+                loss = loss_fn(predicts, labels)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                total_loss += loss.item()
+            tqdm_logging(f"Epoch {epoch + 1}, Loss: {total_loss / num_batches:.4f}", epoch + 1, 5)
+
+        correct, total = 0, 0
+        self.model.eval()
+        with torch.no_grad():
+            for images, labels in test_loader:
+                images = images.squeeze(1).to(self.device)
+                labels = labels.to(self.device)
+                predicts = self.model(images)
+                predicted = predicts.argmax(dim=1)
+                correct += (predicted == labels).sum().item()
+                total += labels.size(0)
+        logging.info(f"Test Accuracy: {correct / total:.4f}")
+        return self.model, test_loader
+
+    def run_test(self):
+        model, test_loader = self.run()
+        visualize_attention_heads(model, test_loader, device=self.device)
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, input_dim=28, d_model=128, num_heads=4, d_ff=256, num_layers=2, seq_len=28):
+        super().__init__()
+        self.attn_weights = None
+        self.input_fc = nn.Linear(input_dim, d_model)
+        self.pos = PositionalEncoding(d_model, max_len=seq_len)
+        self.layers = nn.ModuleList([
+            EncoderLayer(d_model, num_heads, d_ff)
+            for _ in range(num_layers)
+        ])
+        self.classifier = nn.Linear(d_model, 10)
+
+    def forward(self, x):
+        x = self.input_fc(x)
+        x = self.pos(x)
+        layer = None
+        for layer in self.layers:
+            x = layer(x)
+        self.attn_weights = layer.attn.last_attn_weights
+        x = x.mean(dim=1)
+        return self.classifier(x)
diff --git a/models/transformer_.py b/models/transformer_.py
deleted file mode 100644
index 1edfd2e..0000000
--- a/models/transformer_.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import torch
-import torch.nn as nn
-from torchvision import datasets, transforms
-import matplotlib.pyplot as plt
-import seaborn as sns
-import math
-
-
-def get_device():
-    if torch.cuda.is_available():
-        return torch.device('cuda:0')
-    if torch.backends.mps.is_available():
-        return torch.device('mps')
-    return torch.device('cpu')
-
-device = get_device()
-print("Using device: " + device)
-
-class PositionalEncoding(nn.Module):
-    def __init__(self, d_model, max_len=100):
-        super().__init__()
-        pe = torch.zeros(max_len, d_model)
-        pos = torch.arange(0, max_len).unsqueeze(1)
-        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
-        pe[:, 0::2] = torch.sin(pos * div)
-        pe[:, 1::2] = torch.cos(pos * div)
-        self.register_buffer('pe', pe.unsqueeze(0))
-
-    def forward(self, x):
-        return x + self.pe[:, :x.size(1)].to(x.device)
-
-class ScaledDotProductAttention(nn.Module):
-    def forward(self, Q, K, V, mask=None):
-        d_k = Q.size(-1)
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
-        attn = torch.softmax(scores, dim=-1)
-        return torch.matmul(attn, V), attn
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, d_model, num_heads):
-        super().__init__()
-        assert d_model % num_heads == 0
-        self.d_k = d_model // num_heads
-        self.num_heads = num_heads
-        self.Q = nn.Linear(d_model, d_model)
-        self.K = nn.Linear(d_model, d_model)
-        self.V = nn.Linear(d_model, d_model)
-        self.out = nn.Linear(d_model, d_model)
-        self.attn = ScaledDotProductAttention()
-        self.last_attn_weights = None
-
-    def forward(self, q, k, v, mask=None):
-        B = q.size(0)
-        Q = self.Q(q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
-        K = self.K(k).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
-        V = self.V(v).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
-        out, attn = self.attn(Q, K, V)
-        self.last_attn_weights = attn.detach().cpu()
-        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
-        return self.out(out)
-
-class FeedForward(nn.Module):
-    def __init__(self, d_model, d_ff):
-        super().__init__()
-        self.ff = nn.Sequential(
-            nn.Linear(d_model, d_ff),
-            nn.ReLU(),
-            nn.Linear(d_ff, d_model)
-        )
-
-    def forward(self, x):
-        return self.ff(x)
-
-class EncoderLayer(nn.Module):
-    def __init__(self, d_model, num_heads, d_ff):
-        super().__init__()
-        self.attn = MultiHeadAttention(d_model, num_heads)
-        self.ff = FeedForward(d_model, d_ff)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-
-    def forward(self, x):
-        x2 = self.attn(x, x, x)
-        x = self.norm1(x + x2)
-        x2 = self.ff(x)
-        x = self.norm2(x + x2)
-        return x
-
-class TransformerEncoder(nn.Module):
-    def __init__(self, input_dim=28, d_model=128, num_heads=4, d_ff=256, num_layers=2, seq_len=28):
-        super().__init__()
-        self.input_fc = nn.Linear(input_dim, d_model)
-        self.pos = PositionalEncoding(d_model, max_len=seq_len)
-        self.layers = nn.ModuleList([
-            EncoderLayer(d_model, num_heads, d_ff)
-            for _ in range(num_layers)
-        ])
-        self.classifier = nn.Linear(d_model, 10)
-
-    def forward(self, x):
-        x = self.input_fc(x)
-        x = self.pos(x)
-        layer = None
-        for layer in self.layers:
-            x = layer(x)
-        self.attn_weights = layer.attn.last_attn_weights
-        x = x.mean(dim=1)
-        return self.classifier(x)
-
-def train_and_test():
-    transform = transforms.ToTensor()
-    train_dataset = datasets.MNIST(root=".\data", train=True, transform=transform, download=True)
-    test_dataset = datasets.MNIST(root=".\data", train=False, transform=transform, download=True)
-    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
-    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
-
-    model = TransformerEncoder().to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-    loss_fn = nn.CrossEntropyLoss()
-
-    for epoch in range(5):
-        model.train()
-        total_loss = 0
-        for images, labels in train_loader:
-            images = images.squeeze(1).to(device)
-            labels = labels.to(device)
-            preds = model(images)
-            loss = loss_fn(preds, labels)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            total_loss += loss.item()
-        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")
-
-    correct, total = 0, 0
-    model.eval()
-    with torch.no_grad():
-        for images, labels in test_loader:
-            images = images.squeeze(1).to(device)
-            labels = labels.to(device)
-            preds = model(images)
-            predicted = preds.argmax(dim=1)
-            correct += (predicted == labels).sum().item()
-            total += labels.size(0)
-    print(f"Test Accuracy: {correct / total:.4f}")
-    return model, test_loader
-
-def visualize_attention_heads(model, test_loader):
-    model.eval()
-    images, _ = next(iter(test_loader))
-    image = images[0].unsqueeze(0).squeeze(1).to(device)
-    with torch.no_grad():
-        _ = model(image)
-
-    attn_weights = model.attn_weights[0]  # shape: [num_heads, seq_len, seq_len]
-    num_heads = attn_weights.shape[0]
-
-    fig, axes = plt.subplots(1, num_heads, figsize=(num_heads * 3, 3))
-    for i in range(num_heads):
-        sns.heatmap(attn_weights[i], ax=axes[i], cbar=False)
-        axes[i].set_title(f"Head {i}")
-    plt.tight_layout()
-    plt.show()
-
-if __name__ == "__main__":
-    model, test_loader = train_and_test()
-    visualize_attention_heads(model, test_loader)
-
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tools/devices.py b/tools/devices.py
new file mode 100644
index 0000000..7cb4aac
--- /dev/null
+++ b/tools/devices.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+from functools import cached_property
+
+
+class DeviceManager(object):
+
+    def __init__(self):
+        self.device = self.__get_device()
+
+    @staticmethod
+    def __get_device():
+        if torch.cuda.is_available():
+            return torch.device('cuda:0')
+        if torch.backends.mps.is_available():
+            return torch.device('mps')
+        return torch.device('cpu')
diff --git a/tools/visualize.py b/tools/visualize.py
new file mode 100644
index 0000000..fc7709f
--- /dev/null
+++ b/tools/visualize.py
@@ -0,0 +1,31 @@
+import logging
+import sys
+import torch
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+
+def visualize_attention_heads(model, test_loader, device):
+    model.eval()
+    images, _ = next(iter(test_loader))
+    image = images[0].unsqueeze(0).squeeze(1).to(device)
+    with torch.no_grad():
+        _ = model(image)
+
+    attn_weights = model.attn_weights[0]  # shape: [num_heads, seq_len, seq_len]
+    num_heads = attn_weights.shape[0]
+
+    fig, axes = plt.subplots(1, num_heads, figsize=(num_heads * 3, 3))
+    for i in range(num_heads):
+        sns.heatmap(attn_weights[i], ax=axes[i], cbar=False)
+        axes[i].set_title(f"Head {i}")
+    plt.tight_layout()
+    plt.show()
+
+
+def tqdm_logging(message, current, total):
+    percentage = current / total
+    filled = '#' * int(round(percentage * 100))
+    spaces = ' ' * (100 - len(filled))
+    logging.info(message + f" Progress: [{filled}{spaces}] {percentage * 100:.2f} %")
+    sys.stdout.flush()
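
Note (not part of the diff): the patch does not add an entry point, so the sketch below
is only one possible way to wire the new pieces together. The file name main.py and the
empty args dict are assumptions for illustration; it also assumes the working directory
is the repository root so "confs/logging.conf" and "./data" resolve.

# main.py -- hypothetical usage sketch, not included in this patch
import logging.config

from models.transformer import Transformer

if __name__ == "__main__":
    # Apply the logging setup added in confs/logging.conf: INFO messages from the
    # root logger go to stdout via the console handler.
    logging.config.fileConfig("confs/logging.conf")

    # Transformer.__init__ only records and logs the args dict for now, so an empty
    # dict is enough; DeviceManager picks cuda/mps/cpu automatically.
    runner = Transformer(args={})

    # run_test() trains on MNIST for 5 epochs, logs the test accuracy, and then plots
    # the attention heads of the final encoder layer via tools/visualize.py.
    runner.run_test()

With this configuration the tqdm_logging progress lines also end up on stdout, since
the root logger's console handler is bound to sys.stdout.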