前缀调优(Prefix Tuning):轻量级参数高效微调新范式
在深度学习领域,大型语言模型(LLMs)展现出了惊人的能力,但它们的训练成本极高。为了适应特定任务,研究者们开发了多种参数高效微调(Parameter-Efficient Fine-Tuning, PEFT)方法。其中,前缀调优(Prefix Tuning)作为一种创新且高效的PEFT技术,近年来受到了广泛关注。它通过为Transformer模型的每个注意力层添加可学习的前缀向量,在不修改原始预训练权重的情况下,显著提升下游任务的性能。
Transformer 回顾与微调挑战
标准的Transformer架构由多个相同的堆叠层组成,每一层通常包含两个子层:一个多头自注意力层和一个前馈神经网络层。每个子层的输出都会经过残差连接和层归一化(即"Pre-LN"结构)。
在传统的全参数微调中,我们需要更新模型的所有权重以适应新任务。对于拥有数十亿甚至数千亿参数的LLM,这会带来巨大的计算和存储开销。因此,PEFT的目标是在仅训练少量额外参数的同时,保持或超越全参数微调的效果。
前缀调优的核心思想
前缀调优的核心理念是:在Transformer的每一层中,为输入序列的每个位置添加一个可学习的"前缀"向量,这个前缀会直接影响该层注意力机制的键(Key)和值(Value)计算,从而引导模型产生期望的行为。
具体来说,对于一个长度为 $n$ 的输入序列,前缀调优会在第 $i$ 个Transformer层中为每个注意力头添加两个长度为 $n$ 的可学习向量 $\mathbf{P}i^K$ 和 $\mathbf{P}i^V$。
数学公式
在第 $i$ 层,原始的注意力计算为:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{dk}}\right)V $$
其中,
- $Q = XWQ$
- $K = XWK$
- $V = XWV$
在前缀调优中,我们将这些键和值替换为:
$$ \tilde{K} = [\mathbf{P}i^K; XWK] $$
$$ \tilde{V} = [\mathbf{P}i^V; XWV] $$
这里的 ; 表示拼接操作。这意味着,对于输入序列的每一个元素 $xj$,它的键和值不仅取决于 $xj$ 本身,还受到整个前缀 $\mathbf{P}i^K$ 和 $\mathbf{P}i^V$ 的影响。
最终,注意力输出为:
$$ \text{Attention}(Q, \tilde{K}, \tilde{V}) $$
由于 $Q$ 仍然是 $XWQ$,而 $\tilde{K}$ 和 $\tilde{V}$ 包含了前缀信息,这使得注意力机制能够被外部控制。
实现细节
- 前缀长度选择:前缀的长度通常远小于输入序列的长度,例如,如果输入序列有1024个token,前缀可能只有64个token长。这意味着我们只需要训练64 (层数 头数 * 隐藏维度)个参数,这相对于整个模型的参数来说是非常小的。
- 位置无关性:前缀向量本身不带有位置编码信息。然而,它们的作用是通过注意力机制来传递的。在实践中,研究者发现,将前缀放置在序列的开头(即 $Pi^K$ 和 $Pi^V$ 的第一个元素对应于输入序列的第一个token)效果最佳。
前缀调优的优点
- 参数效率极高:这是前缀调优最显著的优势。例如,对于一个拥有3亿参数的模型,使用前缀调优可能只需要训练不到20万个参数,节省了超过99%的训练资源。
- 性能接近全参数微调:尽管只训练了少量参数,但前缀调优在许多基准测试中表现出色,其性能通常与全参数微调或适配器(Adapter)等方法相当,甚至在某些情况下更优。
- 易于并行化和部署:前缀可以像其他嵌入一样被缓存和重用,使得模型推理过程非常高效。
代码示例
以下是一个简化版的PyTorch伪代码,展示如何实现前缀调优的核心逻辑。
import torch
import torch.nn as nn
import torch.nn.functional as F
class PrefixTuningLayer(nn.Module):
def init(self, hiddensize, numheads, numprefixtokens=64, device="cpu"):
super(PrefixTuningLayer, self).init()
self.hiddensize = hiddensize
self.numheads = numheads
self.headdim = hiddensize // numheads
self.numprefixtokens = numprefixtokens
# 可学习的前缀向量
self.prefixkeys = nn.Parameter(torch.randn(numprefixtokens, hiddensize, device=device))
self.prefixvalues = nn.Parameter(torch.randn(numprefixtokens, hiddensize, device=device))
# 假设我们有标准的线性投影层
self.queryproj = nn.Linear(hiddensize, hiddensize)
self.keyproj = nn.Linear(hiddensize, hiddensize)
self.valueproj = nn.Linear(hiddensize, hiddensize)
def forward(self, x):
"""
Args:
x: 输入张量,形状为 (batchsize, seqlen, hiddensize)
Returns:
output: 输出张量,形状为 (batchsize, seqlen, hiddensize)
"""
batchsize, seqlen, = x.shape
# 投影到查询、键、值
Q = self.queryproj(x) # (batchsize, seqlen, hiddensize)
K = self.keyproj(x) # (batchsize, seqlen, hiddensize)
V = self.valueproj(x) # (batchsize, seqlen, hiddensize)
# 重塑为多头格式
Q = Q.reshape(batchsize, seqlen, self.numheads, self.headdim).transpose(1, 2) # (batchsize, numheads, seqlen, headdim)
K = K.reshape(batchsize, seqlen, self.numheads, self.headdim).transpose(1, 2)
V = V.reshape(batchsize, seqlen, self.numheads, self.headdim).transpose(1, 2)
# 添加前缀到键和值
prefixkeysexpanded = self.prefixkeys.unsqueeze(0).unsqueeze(0) # (1, 1, numprefixtokens, hiddensize)
prefixvaluesexpanded = self.prefixvalues.unsqueeze(0).unsqueeze(0) # (1, 1, numprefixtokens, hiddensize)
# 将前缀扩展到batchsize维度并重塑为多头格式
prefixkeysheads = prefixkeysexpanded.reshape(
1, 1, self.numprefixtokens, self.numheads, self.headdim
).repeat(batchsize, self.numheads, 1, 1).transpose(1, 2) # (batchsize, numheads, numprefixtokens, headdim)
prefixvaluesheads = prefixvaluesexpanded.reshape(
1, 1, self.numprefixtokens, self.numheads, self.headdim
).repeat(batchsize, self.numheads, 1, 1).transpose(1, 2) # (batchsize, numheads, numprefixtokens, headdim)
# 拼接前缀和输入序列的K和V
Kwithprefix = torch.cat([prefixkeysheads, K], dim=-2) # (batchsize, numheads, numprefixtokens + seqlen, headdim)
Vwithprefix = torch.cat([prefixvaluesheads, V], dim=-2) # (batchsize, numheads, numprefixtokens + seqlen, headdim)
# 计算注意力
attnweights = torch.matmul(Q, Kwithprefix.transpose(-1, -2)) / (self.headdim ** 0.5)
attnweights = F.softmax(attnweights, dim=-1)
output = torch.matmul(attnweights, Vwithprefix) # (batchsize, numheads, seqlen, headdim)
# 合并多头输出
output = output.transpose(1, 2).reshape(batchsize, seqlen, self.hiddensize)
return output
--- 使用示例 ---
if name == "main":
device = "cuda" if torch.cuda.isavailable() else "cpu"
layer = PrefixTuningLayer(hiddensize=768, numheads=12, numprefixtokens=64, device=device)
# 模拟输入
inputtensor = torch.randn(2, 100, 768).to(device) # batchsize=2, seqlen=100
outputtensor = layer(inputtensor)
print(f"Input shape: {inputtensor.shape}")
print(f"Output shape: {output_tensor.shape}")
与其他PEFT方法的比较
| 方法 | 原理 | 优点 |