返回列表

前缀调优(Prefix Tuning):轻量级参数高效微调新范式

发布于 ·

前缀调优(Prefix Tuning):轻量级参数高效微调新范式

在深度学习领域,大型语言模型(LLMs)展现出了惊人的能力,但它们的训练成本极高。为了适应特定任务,研究者们开发了多种参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)方法。其中,前缀调优(Prefix Tuning)作为一种创新且高效的PEFT技术,近年来受到了广泛关注。它通过为Transformer模型的每个注意力层添加可学习的前缀向量,在不修改原始预训练权重的情况下,显著提升下游任务的性能。

Transformer 回顾与微调挑战

标准的Transformer架构由多个相同的堆叠层组成,每一层通常包含两个子层:一个多头自注意力层和一个前馈神经网络层。每个子层的输出都会经过残差连接层归一化(即"Pre-LN"结构)。

在传统的全参数微调中,我们需要更新模型的所有权重以适应新任务。对于拥有数十亿甚至数千亿参数的LLM,这会带来巨大的计算和存储开销。因此,PEFT的目标是在仅训练少量额外参数的同时,保持或超越全参数微调的效果。

前缀调优的核心思想

前缀调优的核心理念是:在Transformer的每一层中,为输入序列的每个位置添加一个可学习的"前缀"向量,这个前缀会直接影响该层注意力机制的键(Key)和值(Value)计算,从而引导模型产生期望的行为。

具体来说,对于一个长度为 $n$ 的输入序列,前缀调优会在第 $i$ 个Transformer层中为每个注意力头添加两个长度为 $n$ 的可学习向量 $\mathbf{P}i^K$ 和 $\mathbf{P}i^V$。

数学公式

在第 $i$ 层,原始的注意力计算为:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{dk}}\right)V $$

其中,

  • $Q = XWQ$

  • $K = XWK$

  • $V = XWV$

在前缀调优中,我们将这些键和值替换为:

$$ \tilde{K} = [\mathbf{P}i^K; XWK] $$
$$ \tilde{V} = [\mathbf{P}i^V; XWV] $$

这里的 ; 表示拼接操作。这意味着,对于输入序列的每一个元素 $xj$,它的键和值不仅取决于 $xj$ 本身,还受到整个前缀 $\mathbf{P}i^K$ 和 $\mathbf{P}i^V$ 的影响。

最终,注意力输出为:
$$ \text{Attention}(Q, \tilde{K}, \tilde{V}) $$

由于 $Q$ 仍然是 $XWQ$,而 $\tilde{K}$ 和 $\tilde{V}$ 包含了前缀信息,这使得注意力机制能够被外部控制。

实现细节

  1. 前缀长度选择:前缀的长度通常远小于输入序列的长度,例如,如果输入序列有1024个token,前缀可能只有64个token长。这意味着我们只需要训练64 (层数 头数 * 隐藏维度)个参数,这相对于整个模型的参数来说是非常小的。
  2. 位置无关性:前缀向量本身不带有位置编码信息。然而,它们的作用是通过注意力机制来传递的。在实践中,研究者发现,将前缀放置在序列的开头(即 $Pi^K$ 和 $Pi^V$ 的第一个元素对应于输入序列的第一个token)效果最佳。

前缀调优的优点

  1. 参数效率极高:这是前缀调优最显著的优势。例如,对于一个拥有3亿参数的模型,使用前缀调优可能只需要训练不到20万个参数,节省了超过99%的训练资源。
  2. 性能接近全参数微调:尽管只训练了少量参数,但前缀调优在许多基准测试中表现出色,其性能通常与全参数微调或适配器(Adapter)等方法相当,甚至在某些情况下更优。
  3. 易于并行化和部署:前缀可以像其他嵌入一样被缓存和重用,使得模型推理过程非常高效。

代码示例

以下是一个简化版的PyTorch伪代码,展示如何实现前缀调优的核心逻辑。

import torch
import torch.nn as nn
import torch.nn.functional as F

class PrefixTuningLayer(nn.Module):
def init(self, hidden
size, numheads, numprefixtokens=64, device="cpu"):
super(PrefixTuningLayer, self).init()
self.hidden
size = hiddensize
self.num
heads = numheads
self.head
dim = hiddensize // numheads

self.numprefixtokens = numprefixtokens

# 可学习的前缀向量
self.prefixkeys = nn.Parameter(torch.randn(numprefixtokens, hiddensize, device=device))
self.prefixvalues = nn.Parameter(torch.randn(numprefixtokens, hiddensize, device=device))

# 假设我们有标准的线性投影层
self.queryproj = nn.Linear(hiddensize, hiddensize)
self.key
proj = nn.Linear(hiddensize, hiddensize)
self.valueproj = nn.Linear(hiddensize, hiddensize)

def forward(self, x):
"""
Args:
x: 输入张量,形状为 (batch
size, seqlen, hiddensize)
Returns:
output: 输出张量,形状为 (batchsize, seqlen, hiddensize)
"""
batch
size, seqlen, = x.shape

# 投影到查询、键、值
Q = self.queryproj(x) # (batchsize, seqlen, hiddensize)
K = self.keyproj(x) # (batchsize, seqlen, hiddensize)
V = self.valueproj(x) # (batchsize, seqlen, hiddensize)

# 重塑为多头格式
Q = Q.reshape(batchsize, seqlen, self.numheads, self.headdim).transpose(1, 2) # (batchsize, numheads, seqlen, headdim)
K = K.reshape(batchsize, seqlen, self.numheads, self.headdim).transpose(1, 2)
V = V.reshape(batchsize, seqlen, self.numheads, self.headdim).transpose(1, 2)

# 添加前缀到键和值
prefixkeysexpanded = self.prefixkeys.unsqueeze(0).unsqueeze(0) # (1, 1, numprefixtokens, hiddensize)
prefixvaluesexpanded = self.prefixvalues.unsqueeze(0).unsqueeze(0) # (1, 1, numprefixtokens, hiddensize)

# 将前缀扩展到batchsize维度并重塑为多头格式
prefix
keysheads = prefixkeysexpanded.reshape(
1, 1, self.num
prefixtokens, self.numheads, self.headdim
).repeat(batch
size, self.numheads, 1, 1).transpose(1, 2) # (batchsize, numheads, numprefixtokens, headdim)
prefixvaluesheads = prefixvaluesexpanded.reshape(
1, 1, self.numprefixtokens, self.numheads, self.headdim
).repeat(batchsize, self.numheads, 1, 1).transpose(1, 2) # (batchsize, numheads, numprefixtokens, headdim)

# 拼接前缀和输入序列的K和V
K
withprefix = torch.cat([prefixkeysheads, K], dim=-2) # (batchsize, numheads, numprefixtokens + seqlen, headdim)
V
withprefix = torch.cat([prefixvaluesheads, V], dim=-2) # (batchsize, numheads, numprefixtokens + seqlen, headdim)

# 计算注意力
attn
weights = torch.matmul(Q, Kwithprefix.transpose(-1, -2)) / (self.headdim ** 0.5)
attn
weights = F.softmax(attnweights, dim=-1)
output = torch.matmul(attn
weights, Vwithprefix) # (batchsize, numheads, seqlen, headdim)

# 合并多头输出
output = output.transpose(1, 2).reshape(batchsize, seqlen, self.hiddensize)

return output

--- 使用示例 ---

if name == "main": device = "cuda" if torch.cuda.is
available() else "cpu" layer = PrefixTuningLayer(hiddensize=768, numheads=12, numprefixtokens=64, device=device) # 模拟输入 inputtensor = torch.randn(2, 100, 768).to(device) # batchsize=2, seqlen=100 outputtensor = layer(inputtensor) print(f"Input shape: {inputtensor.shape}") print(f"Output shape: {output_tensor.shape}")

与其他PEFT方法的比较

| 方法 | 原理 | 优点 |