Vision Transformer:用Transformer处理图像的革命性突破
引言
近年来,深度学习在计算机视觉领域取得了巨大成功。从卷积神经网络(CNN)到更先进的架构如ResNet、EfficientNet等,我们见证了图像识别、目标检测、语义分割等任务性能的显著提升。然而,这些模型大多依赖于卷积操作来捕捉图像的局部特征,这在一定程度上限制了它们对全局上下文信息的建模能力。
2020年,Google Research团队发表的论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》提出了Vision Transformer(ViT),将Transformer架构首次应用于图像分类任务,并取得了令人瞩目的结果。这篇开创性的工作不仅展示了纯Transformer架构在计算机视觉中的潜力,也开启了将Transformer应用于各种视觉任务的新时代。
传统CNN的局限性
在我们深入探讨Vision Transformer之前,有必要理解为什么需要一种新的方法来处理图像数据。卷积神经网络虽然在图像任务中表现出色,但也存在一些固有的局限性:
- 局部感受野:CNN通过滑动窗口的方式处理图像,每个神经元只能看到输入的一小部分区域。虽然可以通过堆叠多层来扩大感受野,但这会增加模型的深度和计算复杂度。
- 平移不变性假设:CNN通常假设图像中的物体在不同位置出现时应该得到相同的响应。然而,这种假设并不总是成立,特别是在需要精确定位或理解物体间相对位置的任务中。
- 难以建模长距离依赖:在处理大尺寸图像时,CNN很难有效地捕捉远距离像素之间的复杂关系。
Vision Transformer的核心思想
Vision Transformer的核心思想是将图像分解为一系列patch(小图块),然后将这些patch序列化作为Transformer的输入。具体来说,ViT将一幅图像分割成固定大小的patch,然后将每个patch展平并线性映射到一个固定维度的向量,这些向量作为Transformer编码器的输入。
这个过程类似于将图像"阅读"成一个单词序列,只不过这里的"单词"是图像patch而不是自然语言的token。通过这种方式,ViT能够利用Transformer强大的自注意力机制来捕捉patch之间的长距离依赖关系。
ViT架构详解
让我们详细了解一下Vision Transformer的各个组件:
1. 图像分块(Patch Embedding)
首先,原始图像被分割成一系列不重叠的patch。例如,对于一个224×224的RGB图像,如果每个patch的大小为16×16,那么可以得到14×14=196个patch。每个patch包含16×16×3=768个像素值。
然后,每个patch被展平为一个一维向量,并通过一个可学习的线性变换层将其投影到Transformer期望的嵌入维度(通常是768维)。这个过程可以表示为:
class PatchEmbedding(nn.Module):
def init(self, imgsize=224, patchsize=16, inchannels=3, embeddim=768):
super().init()
self.imgsize = imgsize
self.patchsize = patchsize
# 计算patch数量
numpatches = (imgsize // patchsize) ** 2
# 创建patch embedding层
self.proj = nn.Conv2d(
inchannels,
embeddim,
kernelsize=patchsize,
stride=patchsize
)
# 可学习的[CLS] token
self.clstoken = nn.Parameter(torch.zeros(1, 1, embeddim))
# 位置编码
self.posembed = nn.Parameter(torch.zeros(1, numpatches + 1, embeddim))
self.posdrop = nn.Dropout(p=0.1)
def forward(self, x):
B = x.shape[0]
x = self.proj(x) # [B, embeddim, H', W']
x = x.flatten(2).transpose(1, 2) # [B, numpatches, embeddim]
# 添加[CLS] token
clstokens = self.clstoken.expand(B, -1, -1)
x = torch.cat((clstokens, x), dim=1)
# 添加位置编码
x = x + self.posembed
x = self.posdrop(x)
return x
2. Transformer编码器
ViT使用了标准的Transformer编码器架构,由多个相同的层组成。每一层包含两个子层:多头自注意力和前馈神经网络,每个子层都采用了残差连接和层归一化。
class TransformerEncoderLayer(nn.Module):
def init(self, dmodel, nhead, dimfeedforward=2048, dropout=0.1):
super().init()
self.selfattn = nn.MultiheadAttention(dmodel, nhead, dropout=dropout)
self.linear1 = nn.Linear(dmodel, dimfeedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dimfeedforward, dmodel)
self.norm1 = nn.LayerNorm(dmodel)
self.norm2 = nn.LayerNorm(dmodel)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = nn.GELU()
def forward(self, src, srcmask=None, srckeypaddingmask=None):
# Self attention
src2 = self.selfattn(src, src, src, attnmask=srcmask,
keypaddingmask=srckeypaddingmask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
# Feed forward
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
3. 分类头
在Transformer编码器之后,ViT使用一个简单的分类头来进行最终的类别预测。这个分类头通常只包含一个线性层和softmax激活函数。
class ViTClassifier(nn.Module):
def init(self, model, numclasses=1000, dropout=0.5):
super().init()
self.model = model
self.head = nn.Sequential(
nn.Linear(model.embeddim, num_classes),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.model(x)
x = x[:, 0] # 取[CLS] token的输出
x = self.head(x)
return x
ViT的训练策略
由于ViT是一个大规模模型(参数量可达数亿),其训练过程需要精心设计和大量计算资源。以下是ViT训练的关键策略:
- 大规模预训练:ViT通常在ImageNet-21k(1400万图像,21841类)上预训练,然后在ImageNet-1k(120万图像,1000类)上进行微调。这种两阶段训练策略有助于模型学习更丰富的视觉表示。
- 数据增强:使用RandAugment、MixUp等高级数据增强技术来提高模型的泛化能力。
- 优化器选择:使用AdamW优化器,配合余弦退火学习率调度器和权重衰减。
- 批处理大小:由于ViT对计算资源要求较高,通常需要使用较大的批处理大小(如4096或更大)来获得稳定的梯度估计。
ViT的变种与改进
ViT的成功催生了许多基于它的改进架构,以下是一些值得注意的变种:
- DeiT(Data-efficient Image Transformers):通过引入教师网络进行知识蒸馏,使ViT能够在较小的数据集上训练。
- Swin Transformer:提出了层次化的Transformer结构,通过滑动窗口注意力机制实现了比ViT更好的效率和性能。
- CvT(Convolutional Vision Transformer):将CNN的卷积操作融入Transformer架构,平衡了性能和效率。
- LeViT:专门设计用于实时推理的轻量级Transformer,在保持高性能的同时大幅减少了计算量。
ViT的应用与挑战
应用前景
ViT及其变种不仅在图像分类任务中表现出色,还在其他计算机视觉任务中具有广泛应用:
- 目标检测:DETR(Detection Transformer)使用Transformer直接进行端到端的目标检测。
- 语义分割:SegFormer等模型将ViT与轻量级MLP解码器结合,实现了高效的语义分割。
- 视频理解:TimeSformer等模型将Transformer扩展到视频数据,用于动作识别等任务。
面临的挑战
尽管ViT取得了显著成功,但仍然面临一些挑战:
- 计算资源需求:ViT需要大量的GPU内存和计算时间,限制了其在小规模设备上的应用。
- 数据依赖性:ViT在数据量不足时容易过拟合,需要大量标注数据进行训练。
- 理论解释性:与CNN相比,Transformer的决策过程更难解释和理解。
结论
Vision Transformer的出现标志着计算机视觉领域的一个重要转折点。它打破了CNN在图像处理