PyTorch深度学习框架进阶——注意力机制原理

喜欢花科技君 2025-03-06 06:43:10

注意力机制(Attention Mechanism)在自然语言处理(NLP)和计算机视觉(CV)领域有着广泛的应用,是理解现代深度学习模型(如BERT、GPT)的核心。它的核心思想是让模型在处理输入数据时,能够动态地关注输入的不同部分,从而提高模型的性能。

基本原理加权求和:注意力机制通过为输入的每个部分分配一个权重,来计算一个加权求和的表示。权重通常是通过一个神经网络计算得出的,反映了模型对不同输入部分的关注程度。上下文向量:通过加权求和,模型生成一个上下文向量,这个向量包含了与当前任务最相关的信息。自注意力(Self-Attention):在自注意力机制中,输入序列的每个元素都可以与其他元素进行交互,从而捕捉到序列中元素之间的关系。广泛应用于Transformer模型。应用机器翻译:在翻译过程中,模型可以根据当前翻译的单词动态地关注源语言句子的不同部分。图像处理:在图像分类或目标检测中,注意力机制可以帮助模型关注图像中的重要区域。文本生成:在生成文本时,模型可以根据上下文动态调整关注的内容,从而生成更连贯的文本。代码实现

Scaled Dot-Product Attention

注意力机制的核心是计算查询(Query)、键(Key)、值(Value)之间的相关性:

import torchimport torch.nn as nnimport torch.nn.functional as Fclass ScaledDotProductAttention(nn.Module): def __init__(self, d_k): super().__init__() self.d_k = d_k # Query和Key的维度 def forward(self, Q, K, V, mask=None): # Q: [batch_size, n_heads, seq_len, d_k] # K, V: [batch_size, n_heads, seq_len, d_k] scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k)) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, V) return output, attn_weights

Multi-Head Attention

多头注意力并行计算多个注意力头,增强模型表达能力:

class MultiHeadAttention(nn.Module): def __init__(self, d_model, n_heads): super().__init__() self.d_model = d_model # 输入维度 self.n_heads = n_heads self.d_k = d_model // n_heads self.W_Q = nn.Linear(d_model, d_model) self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.W_O = nn.Linear(d_model, d_model) def split_heads(self, x): # x: [batch_size, seq_len, d_model] batch_size = x.size(0) return x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) def forward(self, Q, K, V, mask=None): Q = self.split_heads(self.W_Q(Q)) # [batch, n_heads, seq_len, d_k] K = self.split_heads(self.W_K(K)) V = self.split_heads(self.W_V(V)) attn_output, attn_weights = ScaledDotProductAttention(self.d_k)(Q, K, V, mask) attn_output = attn_output.transpose(1, 2).contiguous().view(attn_output.size(0), -1, self.d_model) return self.W_O(attn_output)

注意力机制

0 阅读:1

喜欢花科技君

简介:感谢大家的关注