27. 实现 GPT 多头注意力层#

27.1. 介绍#

上一节我们已成实现 Transformer 块的层归一化组件,本小节将继续实现 Transformer 块的多头注意力层。

多头注意力层

27.2. 环境配置#

27.2.1. 安装依赖#

!pip install --upgrade dsxllm

27.2.2. 环境版本#

from dsxllm.util import show_version

show_version()
本书愿景:
+------+--------------------------------------------------------+
| Info |                  《动手学大语言模型》                  |
+------+--------------------------------------------------------+
| 作者 |                       吾辈亦有感                       |
| 哔站 |      https://space.bilibili.com/3546632320715420       |
| 定位 | 基于'从零构建'的理念,用实战帮助程序员快速入门大模型。 |
| 愿景 | 若让你的AI学习之路走的更容易一点,我将倍感荣幸!祝好😄 |
+------+--------------------------------------------------------+
环境信息:
+-------------+--------------+------------------------+
| Python 版本 | PyTorch 版本 | PyTorch Lightning 版本 |
+-------------+--------------+------------------------+
|   3.12.12   |    2.10.0    |         2.6.1          |
+-------------+--------------+------------------------+

27.3. 带掩码多头注意力层#

在 GPT 中,带掩码的多头注意力层(Masked Multi-Head Self-Attention)是其核心组件之一,负责捕捉序列内部的依赖关系,同时确保生成过程是自回归的——即每个位置只能依赖于它之前的位置,不能“看到”未来的信息。

27.3.1. 自注意力机制#

首先,标准的缩放点积自注意力(Scaled Dot-Product Self-Attention)允许序列中的每个位置与序列中的所有其他位置进行交互,计算注意力权重,从而聚合全局信息。公式如下:

\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

其中,\(Q, K, V\) 分别是通过输入嵌入与权重矩阵相乘得到的查询、键和值矩阵。对于长度为 \(n\) 的序列,注意力矩阵 \(QK^T\) 的大小为 \(n \times n\),元素 \((i, j)\) 表示位置 \(i\) 对位置 \(j\) 的注意力分数。经过softmax后,每个位置会聚合所有位置的值。

27.3.2. 多头注意力机制#

多头注意力机制(Multi-Head Attention)通过并行计算多个注意力头,每个头关注序列中的不同子空间,从而捕捉到不同类型的依赖关系。在学习多头注意力详细的计算过程之前,我们先回顾一下标准的单头自注意力机制计算细节。

单头注意力层的计算图如下所示:

单头注意力层

多头注意力 是指将查询、键、值分别投影到 head 个不同的低维空间(称为“头”),然后在每个头上独立地执行注意力计算,最后将所有头的结果拼接并投影回原始维度。这样做的好处是:

  • 捕捉不同类型的依赖关系:不同的头可以关注序列中不同位置的关系(例如,语法关系、长距离依赖、局部关系等)。

  • 增强表达能力:通过多个头,模型可以从多个角度理解序列,提升表示能力。

虽然在 GPT 中,每个头虽然独立地进行自注意力计算,但是每个头的Q、K、V矩阵都是从原始输入通过一个大的线性变换得到的,然后再将大的Q、K、V矩阵分别切分成多个小的Q、K、V矩阵。这样做的好处是,将多个小矩阵的计算转换成了一个大矩阵的计算,可以充分利用 GPU 并行计算的能力,提高计算效率。

所以,多头注意力层的计算图如下所示:

多头注意力层

进行多头注意力计算时,有两个额外的拆分和合并操作:

  • 拆分操作:将大的Q、K、V矩阵分别切分成多个小的Q、K、V矩阵,每个小矩阵的大小为 head_dim

  • 合并操作:将多个头的输出拼接起来,将拼接后的矩阵再通过一个线性变换融合信息,就得到最终的多头注意力输出。

27.3.3. 因果注意力掩码#

在 GPT 这样的自回归语言模型中,目标是预测下一个词,因此生成时必须确保当前位置不能访问未来位置的信息。否则,模型会“作弊”,导致训练和推理不一致。这里使用的掩码是因果注意力掩码,它的作用是在计算注意力时,将未来位置的信息屏蔽掉。

掩码的作用就是在计算注意力权重时,将未来位置的注意力分数设置为负无穷(通常用 −∞ 或一个非常大的负数),使得经过softmax后这些位置的权重变为0。这样,每个位置只能关注它自己和之前的位置。

在GPT这样的自回归语言模型中,目标是预测下一个词,因此生成时必须确保当前位置不能访问未来位置的信息。否则,模型会“作弊”,导致训练和推理不一致。这里使用的掩码是因果注意力掩码,它的作用是在计算注意力时,将未来位置的信息屏蔽掉。

掩码的作用就是在计算注意力权重时,将未来位置的注意力分数设置为负无穷(通常用 −∞ 或一个非常大的负数),使得经过softmax后这些位置的权重变为0。这样,每个位置只能关注它自己和之前的位置。

以输入序列 动手学大模型 为例,使用因果注意力掩码的过程如下所示:

  1. 序列长度为 6,则其对应的注意力矩阵为:

image.png

  1. 假设每个 Token 之间自注意力得分矩阵为:

自注意力得分

  1. 应用因果注意力掩码,将未来位置的注意力分数设置为负无穷:

应用注意力掩码

  1. 使用 softmax 计算注意力权重,掩码位置的权重为 0:

注意力权重

27.4. 带掩码多头注意力层的代码实现#

import torch
from dsxllm.util import print_red


class MultiHeadAttention(torch.nn.Module):
    """
    多头自注意力模块。

    Args:
        input_dim (int): 输入特征维度。
        output_dim (int): 输出特征维度(必须能被 num_heads 整除)。
        context_length (int): 最大序列长度,用于生成因果掩码。
        dropout (float): Dropout 概率。
        num_heads (int): 注意力头的数量。
        qkv_bias (bool): 是否在 Q、K、V 线性投影中使用偏置。
    """

    def __init__(
        self, input_dim, output_dim, context_length, dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        assert output_dim % num_heads == 0, "output_dim must be divisible by num_heads"

        self.output_dim = output_dim
        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads  # 每个头的维度

        # 定义查询、键、值的线性层
        self.query_layer = torch.nn.Linear(
            in_features=input_dim, out_features=output_dim, bias=qkv_bias
        )
        self.key_layer = torch.nn.Linear(
            in_features=input_dim, out_features=output_dim, bias=qkv_bias
        )
        self.value_layer = torch.nn.Linear(
            in_features=input_dim, out_features=output_dim, bias=qkv_bias
        )

        # 定义输出层
        self.output_layer = torch.nn.Linear(
            in_features=output_dim, out_features=output_dim
        )

        self.dropout = torch.nn.Dropout(dropout)

        # 因果掩码:上三角矩阵,用于屏蔽未来位置的注意力
        # 注册为缓冲区,不会作为模型参数更新,但会随模型移动
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """
        前向传播。

        Args:
            x (torch.Tensor): 输入张量,形状为 (batch_size, seq_len, input_dim)。

        Returns:
            torch.Tensor: 输出张量,形状为 (batch_size, seq_len, output_dim)。
        """
        batch_size, seq_len, _ = x.shape

        # 1. 通过线性层得到 Q、K、V,形状为 (batch_size, seq_len, output_dim)
        queries = self.query_layer(x)
        keys = self.key_layer(x)
        values = self.value_layer(x)

        # 2. 将 Q、K、V 拆分为多头,重塑为 (batch_size, seq_len, num_heads, head_dim)
        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, seq_len, self.num_heads, self.head_dim)
        values = values.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # 转置为 (batch_size, seq_len, num_heads, head_dim) ➡️ (batch_size, num_heads, seq_len, head_dim)
        # 这一步是为了将序列长度维度和头维度交换,方便后续的分头进行自注意力计算
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # 3. 使用矩阵计算缩放点积注意力分数,keys 转置最后两维:(batch_size, num_heads, seq_len, head_dim) ➡️ (batch_size, num_heads, head_dim, seq_len)
        # 在PyTorch中,当使用 @ 运算符(或 torch.matmul 函数)进行矩阵乘法时,会自动处理批量维度并对最后两个维度执行标准矩阵乘法
        # 这一步计算的是每个头的注意力分数,形状为 (batch_size, num_heads, seq_len, seq_len)
        attention_scores = queries @ keys.transpose(-2, -1)

        print_red("\n1. 注意力分数:")
        print(attention_scores)

        # 4. 应用因果掩码(屏蔽未来位置)
        # 将原始掩码截断到当前序列长度,并转换为布尔类型,布尔类型张量比浮点数或整数类型占用更少的内存(每个元素 1 字节),这使得它在处理大规模数据时更加高效。
        causal_mask = self.mask.bool()[:seq_len, :seq_len]
        attention_scores.masked_fill_(causal_mask, -torch.inf)

        print_red("\n2. 应用因果掩码后的注意力分数:")
        print(attention_scores)

        # 5. 计算注意力权重(softmax)并应用 dropout
        attention_weights = torch.softmax(
            attention_scores / (self.head_dim**0.5), dim=-1
        )
        print_red("\n3. 应用因果掩码后的注意力权重:")
        print(attention_weights)

        attention_weights = self.dropout(attention_weights)

        # 6. 加权求和得到上下文向量,形状为 (batch_size, num_heads, seq_len, head_dim)
        context = attention_weights @ values

        # 7. 合并多头输出
        # 先转置为 (batch_size, num_tokens, num_heads, head_dim),然后重塑为 (batch_size, num_tokens, output_dim)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, self.output_dim)

        # 8. 输出投影(可选)
        output = self.output_layer(context)

        return output
/Users/kong/opt/anaconda3/envs/dsx-ai/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

27.5. 带掩码多头注意力层的应用实例#

batch_size = 1
seq_len = 6
input_dim, output_dim, dropout, num_heads = 4, 4, 0.1, 2

# 创建输入数据
x = torch.randn(batch_size, seq_len, input_dim)

print("\n多头自注意力层最初输入的数据形状:\n", x.shape, "\n")


# 初始化多头自注意力层
multi_head_attn = MultiHeadAttention(input_dim, output_dim, seq_len, dropout, num_heads)
print("多头自注意力层:\n", multi_head_attn, "\n")

# 使用多头自注意力层进行计算
context_vec = multi_head_attn(x)

print("\n多头自注意力层最终输出的数据形状:\n", context_vec.shape)
多头自注意力层最初输入的数据形状:
 torch.Size([1, 6, 4]) 

多头自注意力层:
 MultiHeadAttention(
  (query_layer): Linear(in_features=4, out_features=4, bias=False)
  (key_layer): Linear(in_features=4, out_features=4, bias=False)
  (value_layer): Linear(in_features=4, out_features=4, bias=False)
  (output_layer): Linear(in_features=4, out_features=4, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
) 


1. 注意力分数:
tensor([[[[-1.9254,  0.5550,  0.0929, -0.0293,  0.5207, -1.0635],
          [ 0.6581,  0.2271, -0.2832,  0.8665, -1.7912, -0.1338],
          [-0.3595, -0.0062,  0.0836, -0.2311,  0.5223, -0.0676],
          [-0.4889,  0.2849, -0.0633,  0.2884, -0.4250, -0.4419],
          [ 1.0251, -0.5967,  0.1322, -0.6032,  0.8883,  0.9255],
          [-1.6281,  0.1580,  0.2664, -0.6645,  1.6453, -0.5279]],

         [[-0.0861, -0.3119,  0.1356, -0.1364,  0.3260,  0.2548],
          [ 0.4471, -1.1365,  0.2674, -0.1142,  0.2327,  1.0212],
          [-0.1740,  0.1345,  0.0044, -0.0474,  0.1245, -0.1356],
          [-0.7856,  0.1128,  0.1942, -0.3616,  0.9074, -0.1919],
          [ 1.2710, -0.0874, -0.3478,  0.6135, -1.5345,  0.2295],
          [-0.4169,  0.4893, -0.0483, -0.0637,  0.1815, -0.4671]]]],
       grad_fn=<UnsafeViewBackward0>)

2. 应用因果掩码后的注意力分数:
tensor([[[[-1.9254,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.6581,  0.2271,    -inf,    -inf,    -inf,    -inf],
          [-0.3595, -0.0062,  0.0836,    -inf,    -inf,    -inf],
          [-0.4889,  0.2849, -0.0633,  0.2884,    -inf,    -inf],
          [ 1.0251, -0.5967,  0.1322, -0.6032,  0.8883,    -inf],
          [-1.6281,  0.1580,  0.2664, -0.6645,  1.6453, -0.5279]],

         [[-0.0861,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.4471, -1.1365,    -inf,    -inf,    -inf,    -inf],
          [-0.1740,  0.1345,  0.0044,    -inf,    -inf,    -inf],
          [-0.7856,  0.1128,  0.1942, -0.3616,    -inf,    -inf],
          [ 1.2710, -0.0874, -0.3478,  0.6135, -1.5345,    -inf],
          [-0.4169,  0.4893, -0.0483, -0.0637,  0.1815, -0.4671]]]],
       grad_fn=<MaskedFillBackward0>)

3. 应用因果掩码后的注意力权重:
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5756, 0.4244, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2738, 0.3516, 0.3746, 0.0000, 0.0000, 0.0000],
          [0.1720, 0.2974, 0.2325, 0.2981, 0.0000, 0.0000],
          [0.3254, 0.1034, 0.1730, 0.1029, 0.2954, 0.0000],
          [0.0442, 0.1563, 0.1687, 0.0873, 0.4473, 0.0962]],

         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.7539, 0.2461, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.2960, 0.3682, 0.3358, 0.0000, 0.0000, 0.0000],
          [0.1603, 0.3027, 0.3206, 0.2164, 0.0000, 0.0000],
          [0.4054, 0.1551, 0.1291, 0.2547, 0.0558, 0.0000],
          [0.1255, 0.2381, 0.1628, 0.1610, 0.1915, 0.1211]]]],
       grad_fn=<SoftmaxBackward0>)

多头自注意力层最终输出的数据形状:
 torch.Size([1, 6, 4])

27.6. 答疑讨论#