vision transformer的计算复杂度

Vision transformer

在这里插入图片描述

假设每个图像有 h ∗ w h*w hw 个patch,维度是 C C C

输入的图像 X X X ( 大小为 h w ∗ C hw* C hwC ),和三个系数矩阵相乘 ( 大小为 C ∗ C C*C CC ),得到 q k v qkv qkv 三个向量 ( h w ∗ C hw*C hwC ),复杂度为:
3 h w C 2 3hwC^2 3hwC2

q q q ( h w ∗ C hw*C hwC ) 和 k T k^T kT ( C ∗ h w C*hw Chw ) 相乘得到矩阵 A A A ( h w ∗ h w hw*hw hwhw ),复杂度为: ( h w ) 2 C (hw)^2C (hw)2C

A A A ( h w ∗ h w hw*hw hwhw ) 和 v v v ( h w ∗ C hw*C hwC )相乘,得到多头注意力的结果 ( h w ∗ C hw*C hwC ),复杂度为: ( h w ) 2 C (hw)^2C (hw)2C

经过MLP投影层 ( C ∗ C C*C CC ),得到 ( h w ∗ C hw*C hwC ),复杂度为:
h w C 2 hwC^2 hwC2

所以复杂度之和为: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C

Swin transformer

在这里插入图片描述
基于滑动窗口的多头注意力,是在每个窗口内计算注意力

假设每个窗口有 M × M M×M M×M 个patch

在一个窗口内的复杂度为:

4 M 2 ∗ C + 2 M 4 C 4M^2*C+2M^4C 4M2C+2M4C

共有 h w / M 2 hw /M^2 hw/M2 个窗口,所以复杂度之和为:

4 h w C + 2 M 2 h w C 4hwC+2M^2hwC 4hwC+2M2hwC

Convolutional vision Transformer

使用 s × s s×s s×s 卷积进行卷积投影,有 h w hw hw 个patch,通道维度为 C C C

输入的图像 X X X ( 大小为 h w ∗ C hw* C hwC ),使用三个标准卷积进行投影 ( 大小为 s ∗ s ∗ C s*s*C ssC ),得到 q k v qkv qkv 三个向量 ( h w ∗ C hw*C hwC ),投影的复杂度为:

3 h w s 2 C 2 3hws^2C^2 3hws2C2

使用深度可分离卷积,投影的复杂度为:

3 h w s 2 C 3hws^2C 3hws2C

使用步长大于1的卷积进行多头注意力的投影,减小后面注意力的计算花销。

key和value的步长为2,query的步长为1,key和value的token数量减小了4倍,所以后续的多头注意力计算花销也减小了4倍。

Cross Attention Transformer

在这里插入图片描述

交叉注意力包括IPSA和CPSA,IPSA在单个patch内使用卷积进行投影,CPSA在单个通道计算patch间的注意力

IPSA的复杂度:

patch大小为 N N N,通道数为 C C C

输入的图像 X X X ( 大小为 N 2 ∗ C N^2* C N2C ),使用卷积进行投影 ( 大小为 1 ∗ 1 ∗ C 1*1*C 11C ),得到 q k v qkv qkv 三个向量 ( N 2 ∗ C N^2*C N2C ),复杂度为:
3 N 2 C 2 3N^2C^2 3N2C2

q q q ( N 2 ∗ C N^2*C N2C ) 和 k k k ( C ∗ N 2 C*N^2 CN2 ) 相乘得到矩阵 A A A ( N 2 ∗ N 2 N^2*N^2 N2N2 ),复杂度为: N 4 C 2 N^4C^2 N4C2

A A A ( N 2 ∗ N 2 N^2*N^2 N2N2 ) 和 v v v ( N 2 ∗ C N^2*C N2C )相乘,得到多头注意力的结果 ( N 2 ∗ C N^2*C N2C ),复杂度为: N 4 C 2 N^4C^2 N4C2

经过MLP投影层 ( C ∗ C C*C CC ),得到 ( N 2 ∗ C N^2*C N2C ),复杂度为:
N 2 C 2 N^2C^2 N2C2

单个patch内的复杂度为:

4 N 2 C 2 + 2 N 4 C 2 4N^2C^2+2N^4C^2 4N2C2+2N4C2

共有 H W / N 2 HW/N^2 HW/N2 个patch,所以IPSA总复杂度为:
4 H W C 2 + 2 N 2 H W C 2 4HWC^2+2N^2HWC^2 4HWC2+2N2HWC2

CPSA的复杂度:

patch数目为 H W / N 2 HW/N^2 HW/N2,patch大小为 N 2 N^2 N2

输入的图像 X X X ( 大小为 H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),和三个系数矩阵相乘 ( 大小为 N 2 ∗ N 2 N^2*N^2 N2N2 ),得到 q k v qkv qkv 三个向量 ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),复杂度为:
3 H W N 2 3HWN^2 3HWN2

q q q ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ) 和 k k k ( N 2 ∗ H W / N 2 N^2*HW/N^2 N2HW/N2 ) 相乘得到矩阵 A A A ( H W / N 2 ∗ H W / N 2 HW/N^2*HW/N^2 HW/N2HW/N2 ),复杂度为: ( H W ) 2 / N 2 (HW)^2/N^2 (HW)2/N2

A A A ( H W / N 2 ∗ H W / N 2 HW/N^2*HW/N^2 HW/N2HW/N2 ) 和 v v v ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 )相乘,得到多头注意力的结果 ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),复杂度为: ( H W ) 2 / N 2 (HW)^2/N^2 (HW)2/N2

经过MLP投影层 ( N 2 ∗ N 2 N^2*N^2 N2N2 ),得到 ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),复杂度为:
H W N 2 HWN^2 HWN2

单个通道内的复杂度为:

4 N 2 H W + 2 ( H W / N ) 2 4N^2HW+2(HW/N)^2 4N2HW+2(HW/N)2

共有 C C C 个通道,所以CPSA总复杂度为:
4 N 2 H W C + 2 ( H W / N ) 2 C 4N^2HWC+2(HW/N)^2C 4N2HWC+2(HW/N)2C