Transformer复习
# No lengthy preamble needed -- straight to the code.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Scale embeddings by sqrt(d_model) and add sinusoidal position codes.

    Uses the encoding from "Attention Is All You Need" (section 3.5)::

        PE[pos, 2j]     = sin(pos / 10000 ** (2j / d_model))
        PE[pos, 2j + 1] = cos(pos / 10000 ** (2j / d_model))

    Args:
        d_model: Embedding width (odd widths are supported).
        max_len: Longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=1024):
        super().__init__()
        self.d_model = d_model
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Each sin/cos pair shares the exponent 2j / d_model, where 2j is the
        # even column index. The previous loop multiplied an already-even index
        # by 2 and gave the cos column a different exponent.
        even_idx = torch.arange(0, d_model, 2, dtype=torch.float)
        inv_freq = torch.exp(even_idx * (-math.log(10000.0) / d_model))
        angles = position * inv_freq  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cos column than sin column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading batch axis so the table broadcasts across a batch in forward().
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x * sqrt(d_model) + PE[:seq_len]``.

        Args:
            x: Embeddings; assumed (batch, seq_len, d_model), which is how
                Encoder/Decoder call it.

        Raises:
            ValueError: If seq_len exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        # Embedding scaling (paper section 3.4) is folded in here.
        x = x * math.sqrt(self.d_model)
        return x + self.pe[:, :seq_len, :]


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Model width; must be divisible by ``heads``.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``heads``.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.d_model = d_model
        self.h = heads
        self.d_k = d_model // heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v``.

        Args:
            q, k, v: Per-head projections, (batch, heads, seq, d_k) as built
                by ``forward``.
            d_k: Per-head width used for the 1/sqrt(d_k) scaling.
            mask: Optional mask; a heads axis is inserted at dim 1, so it is
                presumably (batch, 1, k_len) or (batch, q_len, k_len).
                Positions where it equals 0 receive no attention weight.
            dropout: Optional module applied to the attention weights.

        Returns:
            Attention-weighted values, (batch, heads, q_len, d_k).
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Share one mask across every head.
            mask = mask.unsqueeze(1)
            # The most negative finite value of the dtype; a literal -1e9
            # overflows float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            # Use the module actually passed in (it used to be ignored).
            weights = dropout(weights)
        return torch.matmul(weights, v)

    def forward(self, q, k, v, mask=None):
        """Project q/k/v, attend within each head, and merge the heads.

        Args:
            q: Queries, assumed (batch, q_len, d_model).
            k, v: Keys and values, assumed (batch, k_len, d_model).
            mask: Optional attention mask; see ``attention``.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = q.size(0)
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        q = self.q_linear(q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        attended = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # (batch, heads, q_len, d_k) -> (batch, q_len, d_model)
        concat = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(concat)


class FeedForward(nn.Module):
    """Position-wise feed-forward net: linear -> ReLU -> dropout -> linear.

    Args:
        d_model: Input/output width.
        d_ff: Hidden width.
        dropout: Dropout probability after the ReLU.
    """

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear1(x)))
        return self.linear2(x)


class NormLayer(nn.Module):
    """Normalise over the last dimension with a learnable gain and bias.

    Note: this divides by the *unbiased* standard deviation plus ``eps``, so
    results differ slightly from ``torch.nn.LayerNorm`` (which uses the biased
    variance, with eps inside the square root).
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias


class Encoder(nn.Module):
    """Embed source tokens, add positions, and run N pre-norm encoder layers.

    Args:
        vocal_size: Source vocabulary size (the spelling is kept so keyword
            callers keep working).
        d_model: Model width.
        N: Number of stacked EncoderLayer modules.
        heads: Attention heads per layer.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, vocal_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = nn.Embedding(vocal_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, heads, dropout) for _ in range(N)]
        )
        # Layers normalise only their sublayer inputs, so the residual stream
        # leaving the stack still needs a final normalisation.
        self.norm = NormLayer(d_model)

    def forward(self, src, mask):
        """Encode token ids ``src`` (assumed (batch, src_len)) to (batch, src_len, d_model)."""
        x = self.pe(self.embed(src))
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class EncoderLayer(nn.Module):
    """Pre-norm encoder layer: self-attention then feed-forward, each residual."""

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = NormLayer(d_model)
        self.norm_2 = NormLayer(d_model)
        self.attn = MultiHeadAttention(d_model, heads, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x


class Decoder(nn.Module):
    """Embed target tokens, add positions, and run N pre-norm decoder layers.

    Args:
        vocab_size: Target vocabulary size.
        d_model: Model width.
        N: Number of stacked DecoderLayer modules.
        heads: Attention heads per layer.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, heads, dropout) for _ in range(N)]
        )
        self.norm = NormLayer(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        """Decode ``trg`` ids while attending to encoder outputs ``e_outputs``."""
        x = self.pe(self.embed(trg))
        for layer in self.layers:
            x = layer(x, e_outputs, src_mask, trg_mask)
        return self.norm(x)


class DecoderLayer(nn.Module):
    """Pre-norm decoder layer: masked self-attention, cross-attention, feed-forward."""

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = NormLayer(d_model)
        self.norm_2 = NormLayer(d_model)
        self.norm_3 = NormLayer(d_model)
        self.attn_1 = MultiHeadAttention(d_model, heads, dropout=dropout)
        self.attn_2 = MultiHeadAttention(d_model, heads, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

    def forward(self, x, e_outputs, src_mask, trg_mask):
        # Self-attention over the target, restricted by trg_mask.
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs, src_mask))
        x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ff(x2))
        return x


class Transformer(nn.Module):
    """Encoder-decoder Transformer returning unnormalised target-vocab scores.

    Args:
        src_vocab: Source vocabulary size.
        trg_vocab: Target vocabulary size (width of the output projection).
        d_model: Model width.
        N: Layers in each of the encoder and decoder.
        heads: Attention heads per layer.
        dropout: Dropout probability.
    """

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)

    def forward(self, src, trg, src_mask, trg_mask):
        e_outputs = self.encoder(src, src_mask)
        d_outputs = self.decoder(trg, e_outputs, src_mask, trg_mask)
        # No softmax here; callers apply it (or a loss that expects raw scores).
        return self.out(d_outputs)