Vision Transformer
ViT
ViT was the first work to apply a pure Transformer to image processing; it is an encoder-only Transformer. Compared with a CNN of the same parameter count, it achieves better results after training on large-scale datasets.
Methods Used in ViT
Patching
The first problem when applying the Transformer to images is that the number of tokens is far too large (roughly one per pixel), which would consume an enormous amount of memory. ViT therefore splits the image into patches, flattens each patch, and uses a linear layer to project it into the Transformer's input dimension; this reduces the sequence length to the number of patches.
For a 2D image $\mathbf x\in \mathbb{R}^{H\times W\times C}$, split it into $P\times P$ patches and flatten them to obtain $\mathbf x_p\in \mathbb{R}^{N\times (P^2\cdot C)}$, where $N=\frac{HW}{P^2}$. $N$ is the effective input sequence length of the Transformer. Since the Transformer works with $D$-dimensional vectors, a learnable linear projection $\mathbf E\in\mathbb{R}^{(P^2\cdot C)\times D}$ maps each patch to $D$ dimensions, giving $\mathbf x_p\mathbf E\in \mathbb{R}^{N\times D}$. For example, a $224\times 224\times 3$ image with $P=16$ yields $N=196$ patches, each flattened to a $16^2\cdot 3=768$-dimensional vector.
Two ways to convert an image into embeddings (both are sketched in the code below):
Split the image into patches and flatten each patch as one token.
Preprocess the image with a 2D convolution and flatten the resulting feature map.
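The two approaches produce the same kind of patch embedding when the convolution's kernel size and stride both equal the patch size. A minimal sketch comparing them, with illustrative sizes (not the paper's configuration), assuming PyTorch and einops are available:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

# Illustrative sizes: batch B, channels C, image H x W, patch size P, embedding dim D
B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
img = torch.randn(B, C, H, W)

# Method 1: split into P x P patches, flatten each patch, then project linearly to D
patchify = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = P, p2 = P),
    nn.Linear(P * P * C, D),
)

# Method 2: a 2D convolution with kernel_size = stride = P, then flatten the feature map
conv_embed = nn.Conv2d(C, D, kernel_size = P, stride = P)

tokens1 = patchify(img)                               # [B, N, D], N = H*W / P^2 = 196
tokens2 = conv_embed(img).flatten(2).transpose(1, 2)  # [B, N, D] as well
print(tokens1.shape, tokens2.shape)
```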
Positional Encoding
A learnable 1D positional embedding is used, letting the model learn the positional relationships between patches on its own.
A CNN's architecture carries the following inductive biases: locality, a two-dimensional neighborhood structure, and translation equivariance.
Compared with a CNN, ViT removes most of these inductive biases: within the encoder only the MLP is local and translation-equivariant, while the self-attention layers attend to the whole image.
Class Token
The class-token mechanism originates from BERT, and ViT adopts it for image classification: a learnable class token is prepended to the input as the 0-th vector and fed into the Transformer together with the tokens derived from the image. The 0-th vector of the Transformer's output then serves as the feature for the classification prediction. An alternative is to average-pool the output image tokens and use the result as the classification feature.
To demonstrate that the Transformer itself also has strong capability for image processing, ViT uses the former.
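A minimal sketch of the class-token pipeline and both pooling options (tensor sizes are illustrative, and the Transformer encoder itself is elided):

```python
import torch
from torch import nn
from einops import repeat

B, N, D = 2, 196, 768                                    # illustrative sizes
patch_tokens = torch.randn(B, N, D)                      # tokens obtained from the image

cls_token = nn.Parameter(torch.randn(1, 1, D))           # learnable class token
pos_embedding = nn.Parameter(torch.randn(1, N + 1, D))   # learnable 1D positional embedding

cls_tokens = repeat(cls_token, '1 1 d -> b 1 d', b = B)  # one copy per sample
x = torch.cat((cls_tokens, patch_tokens), dim = 1)       # prepend as the 0-th token -> [B, N+1, D]
x = x + pos_embedding                                    # add positional information

# ... x would be passed through the Transformer encoder here ...

cls_feature  = x[:, 0]                                   # option 1 (used by ViT): the class token's output
mean_feature = x.mean(dim = 1)                           # option 2: average pooling over all tokens
```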
ViT's encoder follows the same design as the Transformer encoder: a stack of blocks, each containing multi-head self-attention and an MLP. Every block has two residual connections, with LayerNorm applied at the start of each residual branch (pre-norm). The MLP has two layers and uses GELU as its activation function.
Multi-head attention mechanism:
Standard self-attention:
A learnable linear projection maps the sequence $\mathbf z\in \mathbb R ^{N\times D}$ to $q,k,v\in\mathbb R^{N\times D_h}$. The attention weight $A_{ij}$ is determined by the similarity between $q_i$ and $k_j$, and the weights are applied to $v$ as a weighted sum:

$$A=\operatorname{softmax}\!\left(\frac{qk^\top}{\sqrt{D_h}}\right)\in\mathbb R^{N\times N},\qquad \operatorname{SA}(\mathbf z)=Av.$$
Multi-head self-attention: run $k$ self-attention operations ("heads") in parallel, each projecting the input $\mathbf z\in \mathbb R ^{N\times D}$ to $D_h=\frac Dk$ dimensions; the $k$ head outputs are concatenated and linearly projected back to $D$ dimensions.
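A minimal sketch of these two formulas using plain tensor operations. Random weight matrices stand in for the learnable projections, and the final output projection of multi-head attention is omitted:

```python
import torch

N, D, k = 5, 64, 4          # illustrative sizes: N tokens, model dim D, k heads
D_h = D // k
z = torch.randn(N, D)

def self_attention(z, D_h):
    # random matrices stand in for the learnable projections of q, k, v
    Wq, Wk, Wv = torch.randn(3, z.shape[-1], D_h)
    q, k_, v = z @ Wq, z @ Wk, z @ Wv                  # each [N, D_h]
    A = torch.softmax(q @ k_.T / D_h ** 0.5, dim=-1)   # [N, N], each row sums to 1
    return A @ v                                       # weighted sum of the values, [N, D_h]

sa = self_attention(z, D_h)                            # single head: [N, D_h]

# multi-head: run k heads in parallel and concatenate their outputs
msa = torch.cat([self_attention(z, D_h) for _ in range(k)], dim=-1)   # [N, k * D_h] = [N, D]
print(sa.shape, msa.shape)
```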
Code Implementation: vit-pytorch/vit_pytorch/vit.py
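The excerpts below come from that file; they rely on its imports and the small pair helper, reproduced here so the snippets are self-contained:

```python
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    # allow image_size / patch_size to be given as an int or as a (height, width) tuple
    return t if isinstance(t, tuple) else (t, t)
```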
FeedForward
Corresponds to the MLP feed-forward part inside the Transformer encoder:
```python
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
```
Attention
Corresponds to the attention part inside the Transformer encoder:
```python
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # skip the output projection when a single head already matches the model dimension
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # one linear layer produces q, k, v for all heads at once
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """
        Args:
            x: torch.Tensor, token with dimension: [B, N, D]
        """
        x = self.norm(x)  # pre-norm

        # split the projection into q, k, v and separate the heads: [B, N, h*d] -> [B, h, N, d]
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # scaled dot-product attention scores
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # concatenate the heads back: [B, h, N, d] -> [B, N, h*d]
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```
ViT's Transformer (only the encoder is needed):
```python
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        # pre-norm residual blocks: LayerNorm is applied inside Attention and FeedForward
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
```
The full ViT model assembles the patch embedding, class token, positional embedding, Transformer encoder, and classification head:

```python
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # patch embedding: split into patches, flatten, then linearly project to dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # learnable 1D positional embedding (covers the class token plus all patches) and class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """
        Args:
            img: torch.Tensor, [B, C, H, W]
        """
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # prepend the class token and add positional embeddings
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # 'cls': take the class token's output; 'mean': average-pool over all tokens
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
```
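A quick usage check of the class above; the hyperparameter values below are illustrative (in the spirit of the repository's README) rather than the paper's exact configuration:

```python
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)   # one 256 x 256 RGB image
preds = v(img)                      # classification logits of shape [1, 1000]
```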