vision transformer的计算复杂度
Vision transformer
假设每个图像有 h ∗ w h*w h∗w 个patch,维度是 C C C
输入的图像
X
X
X ( 大小为
h
w
∗
C
hw* C
hw∗C ),和三个系数矩阵相乘 ( 大小为
C
∗
C
C*C
C∗C ),得到
q
k
v
qkv
qkv 三个向量 (
h
w
∗
C
hw*C
hw∗C ),复杂度为:
3
h
w
C
2
3hwC^2
3hwC2
q q q ( h w ∗ C hw*C hw∗C ) 和 k T k^T kT ( C ∗ h w C*hw C∗hw ) 相乘得到矩阵 A A A ( h w ∗ h w hw*hw hw∗hw ),复杂度为: ( h w ) 2 C (hw)^2C (hw)2C
A A A ( h w ∗ h w hw*hw hw∗hw ) 和 v v v ( h w ∗ C hw*C hw∗C )相乘,得到多头注意力的结果 ( h w ∗ C hw*C hw∗C ),复杂度为: ( h w ) 2 C (hw)^2C (hw)2C
经过MLP投影层 (
C
∗
C
C*C
C∗C ),得到 (
h
w
∗
C
hw*C
hw∗C ),复杂度为:
h
w
C
2
hwC^2
hwC2
所以复杂度之和为: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C
Swin transformer
基于滑动窗口的多头注意力,是在每个窗口内计算注意力
假设每个窗口有 M × M M×M M×M 个patch
在一个窗口内的复杂度为:
4 M 2 ∗ C + 2 M 4 C 4M^2*C+2M^4C 4M2∗C+2M4C
共有 h w / M 2 hw /M^2 hw/M2 个窗口,所以复杂度之和为:
4 h w C + 2 M 2 h w C 4hwC+2M^2hwC 4hwC+2M2hwC
Convolutional vision Transformer
使用 s × s s×s s×s 卷积进行卷积投影,有 h w hw hw 个patch,通道维度为 C C C
输入的图像 X X X ( 大小为 h w ∗ C hw* C hw∗C ),使用三个标准卷积进行投影 ( 大小为 s ∗ s ∗ C s*s*C s∗s∗C ),得到 q k v qkv qkv 三个向量 ( h w ∗ C hw*C hw∗C ),投影的复杂度为:
3 h w s 2 C 2 3hws^2C^2 3hws2C2
使用深度可分离卷积,投影的复杂度为:
3 h w s 2 C 3hws^2C 3hws2C
使用步长大于1的卷积进行多头注意力的投影,减小后面注意力的计算花销。
key和value的步长为2,query的步长为1,key和value的token数量减小了4倍,所以后续的多头注意力计算花销也减小了4倍。
Cross Attention Transformer
交叉注意力包括IPSA和CPSA,IPSA在单个patch内使用卷积进行投影,CPSA在单个通道计算patch间的注意力
IPSA的复杂度:
patch大小为 N N N,通道数为 C C C
输入的图像
X
X
X ( 大小为
N
2
∗
C
N^2* C
N2∗C ),使用卷积进行投影 ( 大小为
1
∗
1
∗
C
1*1*C
1∗1∗C ),得到
q
k
v
qkv
qkv 三个向量 (
N
2
∗
C
N^2*C
N2∗C ),复杂度为:
3
N
2
C
2
3N^2C^2
3N2C2
q q q ( N 2 ∗ C N^2*C N2∗C ) 和 k k k ( C ∗ N 2 C*N^2 C∗N2 ) 相乘得到矩阵 A A A ( N 2 ∗ N 2 N^2*N^2 N2∗N2 ),复杂度为: N 4 C 2 N^4C^2 N4C2
A A A ( N 2 ∗ N 2 N^2*N^2 N2∗N2 ) 和 v v v ( N 2 ∗ C N^2*C N2∗C )相乘,得到多头注意力的结果 ( N 2 ∗ C N^2*C N2∗C ),复杂度为: N 4 C 2 N^4C^2 N4C2
经过MLP投影层 (
C
∗
C
C*C
C∗C ),得到 (
N
2
∗
C
N^2*C
N2∗C ),复杂度为:
N
2
C
2
N^2C^2
N2C2
单个patch内的复杂度为:
4 N 2 C 2 + 2 N 4 C 2 4N^2C^2+2N^4C^2 4N2C2+2N4C2
共有
H
W
/
N
2
HW/N^2
HW/N2 个patch,所以IPSA总复杂度为:
4
H
W
C
2
+
2
N
2
H
W
C
2
4HWC^2+2N^2HWC^2
4HWC2+2N2HWC2
CPSA的复杂度:
patch数目为 H W / N 2 HW/N^2 HW/N2,patch大小为 N 2 N^2 N2
输入的图像
X
X
X ( 大小为
H
W
/
N
2
∗
N
2
HW/N^2*N^2
HW/N2∗N2 ),和三个系数矩阵相乘 ( 大小为
N
2
∗
N
2
N^2*N^2
N2∗N2 ),得到
q
k
v
qkv
qkv 三个向量 (
H
W
/
N
2
∗
N
2
HW/N^2*N^2
HW/N2∗N2 ),复杂度为:
3
H
W
N
2
3HWN^2
3HWN2
q q q ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2∗N2 ) 和 k k k ( N 2 ∗ H W / N 2 N^2*HW/N^2 N2∗HW/N2 ) 相乘得到矩阵 A A A ( H W / N 2 ∗ H W / N 2 HW/N^2*HW/N^2 HW/N2∗HW/N2 ),复杂度为: ( H W ) 2 / N 2 (HW)^2/N^2 (HW)2/N2
A A A ( H W / N 2 ∗ H W / N 2 HW/N^2*HW/N^2 HW/N2∗HW/N2 ) 和 v v v ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2∗N2 )相乘,得到多头注意力的结果 ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2∗N2 ),复杂度为: ( H W ) 2 / N 2 (HW)^2/N^2 (HW)2/N2
经过MLP投影层 (
N
2
∗
N
2
N^2*N^2
N2∗N2 ),得到 (
H
W
/
N
2
∗
N
2
HW/N^2*N^2
HW/N2∗N2 ),复杂度为:
H
W
N
2
HWN^2
HWN2
单个通道内的复杂度为:
4 N 2 H W + 2 ( H W / N ) 2 4N^2HW+2(HW/N)^2 4N2HW+2(HW/N)2
共有
C
C
C 个通道,所以CPSA总复杂度为:
4
N
2
H
W
C
+
2
(
H
W
/
N
)
2
C
4N^2HWC+2(HW/N)^2C
4N2HWC+2(HW/N)2C