# nlp_learning/models/layers/transformer.py
import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the token embeddings."""
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
        pe[:, 1::2] = torch.cos(pos * div)  # odd dimensions
        # Register as a buffer: moves with the module but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions.
        return x + self.pe[:, :x.size(1)].to(x.device)
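

# A minimal sanity-check sketch for PositionalEncoding (not part of the original
# module); d_model=16, max_len=50 and the shapes below are assumed for illustration.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=16, max_len=50)
    x = torch.zeros(2, 10, 16)  # (batch, seq_len, d_model)
    out = pe(x)
    assert out.shape == (2, 10, 16)
    # With a zero input, the output is exactly the sinusoidal table for positions 0..9.
    print("positional encoding slice:", out[0, :3, :4])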


class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / sqrt(d_k)) V; positions where mask == 0 are excluded."""
    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Block masked positions before the softmax so they get zero weight.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn
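

# A sketch showing that the attention weights form a distribution over the keys
# (each row sums to 1); the batch/sequence sizes below are assumed for illustration.
def _demo_scaled_dot_product_attention():
    attn = ScaledDotProductAttention()
    Q = torch.randn(2, 5, 8)  # (batch, num_queries, d_k)
    K = torch.randn(2, 7, 8)  # (batch, num_keys, d_k)
    V = torch.randn(2, 7, 8)
    out, weights = attn(Q, K, V)
    assert out.shape == (2, 5, 8)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(2, 5))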


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention()
        self.last_attn_weights = None  # cached per-head weights for inspection

    def forward(self, q, k, v, mask=None):
        B = q.size(0)
        # Project, then split d_model into (num_heads, d_k): (B, heads, seq_len, d_k).
        Q = self.Q(q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.K(k).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.V(v).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        out, attn = self.attn(Q, K, V, mask)  # pass the mask through to the attention
        self.last_attn_weights = attn.detach().cpu()
        # Merge the heads back together: (B, seq_len, d_model).
        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
        return self.out(out)
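

# A self-attention usage sketch for MultiHeadAttention; d_model=32 and num_heads=4
# are assumed values for illustration, not taken from the original file.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(d_model=32, num_heads=4)
    x = torch.randn(2, 6, 32)  # (batch, seq_len, d_model)
    out = mha(x, x, x)
    assert out.shape == (2, 6, 32)
    # forward cached the per-head weights: (batch, heads, num_queries, num_keys).
    assert mha.last_attn_weights.shape == (2, 4, 6, 6)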


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.ff(x)
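

# A brief sketch: the feed-forward block is applied position-wise and preserves
# d_model; the widths below (d_model=32, d_ff=64) are assumed for illustration.
def _demo_feed_forward():
    ff = FeedForward(d_model=32, d_ff=64)
    x = torch.randn(2, 6, 32)
    assert ff(x).shape == (2, 6, 32)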


class EncoderLayer(nn.Module):
    """Post-norm encoder block: self-attention and feed-forward, each with a residual."""
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x2 = self.attn(x, x, x, mask)  # self-attention sub-layer
        x = self.norm1(x + x2)         # residual connection + layer norm
        x2 = self.ff(x)                # position-wise feed-forward sub-layer
        x = self.norm2(x + x2)
        return x
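

# A minimal end-to-end smoke test of one encoder layer (a sketch with assumed
# hyperparameters, not part of the original module); it also runs the demos above.
if __name__ == "__main__":
    _demo_positional_encoding()
    _demo_scaled_dot_product_attention()
    _demo_multi_head_attention()
    _demo_feed_forward()

    d_model, num_heads, d_ff = 32, 4, 64
    layer = EncoderLayer(d_model, num_heads, d_ff)
    tokens = torch.randn(2, 6, d_model)        # stand-in for embedded tokens
    tokens = PositionalEncoding(d_model)(tokens)
    out = layer(tokens)
    assert out.shape == (2, 6, d_model)
    print("encoder layer output:", out.shape)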