Vision Transformer

ViT

ViT first brought the Transformer to image processing; it is an encoder-only Transformer. Compared with a CNN of the same parameter count, it achieves better results after training on large-scale datasets.

Methods Used in ViT

Patch Partitioning

Applying a Transformer to images immediately runs into the problem of token count: treating raw pixels as tokens yields far too long a sequence and consumes large amounts of memory. ViT first splits the image into patches and flattens them, then uses a linear layer to map each patch into the Transformer's input dimension, effectively keeping the sequence length manageable.

For a 2D image $\mathbf x\in \mathbb{R}^{H\times W\times C}$, split it into $P\times P$ patches and flatten them to obtain $\mathbf x_p\in \mathbb{R}^{N\times (P^2\cdot C)}$, where $N=\frac{HW}{P^2}$. Here $N$ is the effective input sequence length of the Transformer. Since the Transformer operates on vectors of width $D$, a learnable linear projection $\mathbf E\in\mathbb R^{(P^2\cdot C)\times D}$ maps each flattened patch to $D$ dimensions, i.e. $\mathbf x_p\mathbf E\in \mathbb{R}^{N\times D}$.
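A minimal sketch of the patchify-and-project step (the shapes are assumptions for illustration: a 224×224 RGB image with P = 16 and D = 1024):

import torch
from einops import rearrange

x = torch.randn(1, 3, 224, 224)                  # [B, C, H, W]
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 16, p2 = 16)
# patches: [1, 196, 768], i.e. N = HW / P^2 = 196 tokens of length P^2 * C = 768
proj = torch.nn.Linear(16 * 16 * 3, 1024)        # the learnable projection E
tokens = proj(patches)                           # [1, 196, 1024] = [B, N, D]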

Two ways to turn an image into embeddings:

  • Split the image into patches and flatten each patch into one token.

  • Preprocess with a 2D convolution and flatten the resulting feature map (see the sketch below).
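A minimal sketch of the convolutional variant; with kernel size and stride both equal to the patch size, the convolution computes exactly the same linear projection of each patch (all dimensions here are assumptions for illustration):

import torch
import torch.nn as nn

conv_embed = nn.Conv2d(3, 1024, kernel_size = 16, stride = 16)  # one output column per patch
x = torch.randn(1, 3, 224, 224)
feat = conv_embed(x)                                            # [1, 1024, 14, 14]
tokens = feat.flatten(2).transpose(1, 2)                        # [1, 196, 1024] = [B, N, D]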

Positional Encoding

ViT uses a learnable 1D positional encoding, letting the model learn the positional relations within the image on its own.
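A minimal sketch, assuming N = 196 patches plus one class token and width D = 1024:

import torch
import torch.nn as nn

N, D = 196, 1024
pos_embedding = nn.Parameter(torch.randn(1, N + 1, D))  # +1 slot for the class token
# added elementwise to the token sequence before the encoder: x = x + pos_embedding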

CNN architectures have the following inductive biases:

  • Locality: neighboring regions share similar features
  • Translation invariance

Compared with CNNs, ViT discards these inductive biases. In its encoder, only the MLP acts locally and translation-equivariantly, while the self-attention layers attend globally.

Class Token

The class-token mechanism originates from BERT, and ViT adopts it for image classification: a class token is prepended to the input as the 0th vector and fed through the Transformer together with the tokens derived from the image. The 0th vector of the Transformer output then serves as the feature for the classification prediction. An alternative is to average-pool the output image tokens and use the pooled vector as the classification feature.

To demonstrate that the Transformer itself can also process images well, ViT uses the former.
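A minimal sketch of the two readout options on the encoder output (shapes assumed; the mean-pool variant averages over all tokens, matching the vit-pytorch code below):

import torch

out = torch.randn(1, 197, 1024)   # [B, N + 1, D]; position 0 is the class token
cls_feat = out[:, 0]              # class-token readout, [1, 1024]
mean_feat = out.mean(dim = 1)     # mean-pool readout, [1, 1024]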

Transformer Encoder

ViT's encoder mirrors the original Transformer encoder: a stack of blocks, each containing multi-head self-attention and an MLP. Each block has two residual connections, with LayerNorm applied at the head of each residual branch (pre-norm). The MLP has two layers and uses GELU as the activation.
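In equations (as in the ViT paper, with $\mathbf z_0$ the embedded input sequence and $L$ blocks):

$$\mathbf z'_\ell = \mathrm{MSA}(\mathrm{LN}(\mathbf z_{\ell-1})) + \mathbf z_{\ell-1},\qquad \ell = 1,\dots,L$$

$$\mathbf z_\ell = \mathrm{MLP}(\mathrm{LN}(\mathbf z'_\ell)) + \mathbf z'_\ell,\qquad \ell = 1,\dots,L$$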

The multi-head attention mechanism:

  • Standard self-attention:

    A learnable linear projection maps the sequence $\mathbf z\in \mathbb R ^{N\times D}$ to $q,k,v\in\mathbb R^{N\times D_h}$. The attention weight $A_{ij}$ is determined by the similarity between $q_i$ and $k_j$, and the weights act on $v$ as a weighted sum:

    $$[q,k,v]=\mathbf z\,\mathbf U_{qkv},\qquad A=\mathrm{softmax}\!\left(\frac{qk^\top}{\sqrt{D_h}}\right)\in\mathbb R^{N\times N},\qquad \mathrm{SA}(\mathbf z)=A\,v$$

  • Multi-head self-attention: split the input sequence $\mathbf z\in \mathbb R ^{N\times D}$ along the channel dimension into $k$ slices of equal width $\mathbf z_i\in\mathbb R^{N\times \frac Dk}$, run $k$ self-attention operations in parallel, and concatenate the $k$ outputs into a single output (see the equation below).
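In the notation of the ViT paper, with $D_h = D/k$ per head and a learnable output projection $\mathbf U_{msa}\in\mathbb R^{k\cdot D_h\times D}$:

$$\mathrm{MSA}(\mathbf z) = [\mathrm{SA}_1(\mathbf z);\ \mathrm{SA}_2(\mathbf z);\ \dots;\ \mathrm{SA}_k(\mathbf z)]\,\mathbf U_{msa}$$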

Implementation

vit-pytorch/vit_pytorch/vit.py
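The excerpts below come from that file; to run them standalone they need its imports and the pair helper:

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def pair(t):
    # broadcast a single int to a (height, width) tuple
    return t if isinstance(t, tuple) else (t, t)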

FeedForward corresponds to the MLP feed-forward part inside the Transformer encoder:

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),           # pre-norm: LayerNorm at the head of the residual branch
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

Attention corresponds to the self-attention part inside the Transformer encoder:

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads  # total width across all heads
        # an output projection is only needed when the concatenated heads
        # do not already match the model dimension
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """
        Args:
            x: torch.Tensor, tokens with shape [B, N, D]
        """
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # qkv: tuple of three tensors, each [B, N, inner_dim]
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # q, k, v: [B, heads, N, dim_head]
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # dots: [B, heads, N, N]
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
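A quick shape check of the module (dimensions assumed for illustration):

attn = Attention(dim = 1024, heads = 16, dim_head = 64)
x = torch.randn(2, 197, 1024)   # [B, N, D]
print(attn(x).shape)            # torch.Size([2, 197, 1024])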

The Transformer part of ViT (only the encoder is needed):

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            # residual connections around the pre-norm attention and MLP blocks
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
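Finally, the ViT module assembles the pieces above: patch embedding, class token, positional embedding, the encoder, and a linear classification head: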
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # map each flattened patch to the token dimension the Transformer works with
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """
        Args:
            img: torch.Tensor, [B, C, H, W]
        """
        x = self.to_patch_embedding(img)
        # x: [B, N, D]
        b, n, _ = x.shape

        # prepend the learnable class token to every sequence in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # readout: mean-pool all tokens, or take the class token at position 0
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
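A usage sketch with the hyperparameters from the vit-pytorch README:

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)   # (1, 1000)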