[2201] VRT: A Video Restoration Transformer
Content
Abstract
video restoration methods
- sliding window-based method
input multiple LQ frames to generate a single HQ frame
each input frame is processed multiple times during inference
⟹ inefficient feature utilization and increased computation cost
- recurrent method
use previously reconstructed HQ frames for subsequent frame reconstruction
3 drawbacks due to the recurrent nature
- limited in parallelization
- poor at long-range temporal dependency modelling
⟸ one frame strongly affects the next frame, but its influence is quickly lost after a few time steps
- significant performance drop on few-frame videos
- parallel method
divide the video sequence into non-overlapping clips and shift them alternately to enable inter-clip interactions
Illustrative comparison of sliding window-based models, recurrent models and the proposed parallel VRT model. Green and blue circles denote low-quality (LQ) input frames and high-quality (HQ) output frames, respectively. $t-1$, $t$ and $t+1$ are frame serial numbers. Dashed lines represent information fusion among different frames.
contributions
- propose Video Restoration Transformer (VRT)
- parallel computation, long-range dependency modelling
- jointly extract, align, fuse frame features at multiple scales
- propose multi-head mutual attention (MMA)
- mutual alignment between frames
- SOTA on video restoration
- video SR, video deblurring, video denoising
Method
model architecture
The framework of the proposed Video Restoration Transformer (VRT). Given T low-quality input frames, VRT reconstructs T high-quality frames in parallel. It jointly extracts features, deals with misalignment, and fuses temporal information at multiple scales. On each scale, it has two kinds of modules: temporal mutual self attention (TMSA) and parallel warping. The down-sampling and up-sampling operations between different scales are omitted for clarity.
given a sequence of low-quality input frames $I^{LQ}\in\mathbb{R}^{T\times H\times W\times C_{in}}$ and a sequence of high-quality target frames $I^{HQ}\in\mathbb{R}^{T\times sH\times sW\times C_{out}}$
where $s$ is the upscaling factor: $s>1$ for sr, $s=1$ for db and dn
aim to restore $T$ HQ frames from $T$ LQ frames in parallel for various video restoration tasks
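As a concrete illustration of the shapes involved, a minimal sketch (the frame count, resolution and scale factor below are arbitrary example values, not settings from the paper):

```python
import torch

T, H, W, C_in, C_out = 6, 64, 64, 3, 3   # example values only, not the paper's settings
s = 4                                    # s > 1 for video SR; s = 1 for deblurring / denoising

lq = torch.randn(T, H, W, C_in)                  # I^LQ: T low-quality input frames
hq_target = torch.randn(T, s * H, s * W, C_out)  # I^HQ: T high-quality target frames
print(lq.shape, hq_target.shape)                 # (6, 64, 64, 3) and (6, 256, 256, 3)
```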
feature extraction
extract shallow features $I^{SF}\in\mathbb{R}^{T\times H\times W\times C}$ from $I^{LQ}$ by a convolution layer
propose a multi-scale network, based on U-Net, that aligns frames at different resolutions
capture features and motions at different scales by TMSA and PW; skip connections are added between features of the same scale
add TMSA layers for further feature refinement to obtain deep features $I^{DF}\in\mathbb{R}^{T\times H\times W\times C}$
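A schematic of how a U-Net-style multi-scale trunk with same-scale skip connections could be organized. This is not the released VRT code: `Stage` is a hypothetical placeholder standing in for a block of TMSA layers plus parallel warping, and plain strided convolutions stand in for the actual down/up-sampling operations.

```python
import torch
import torch.nn as nn

class Stage(nn.Module):
    """Stand-in for a block of TMSA layers + parallel warping at one scale."""
    def __init__(self, c):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(c, c, 3, padding=1), nn.GELU(),
                                  nn.Conv2d(c, c, 3, padding=1))

    def forward(self, x):
        return x + self.body(x)

class MultiScaleTrunk(nn.Module):
    """U-Net-like trunk: process, downsample, process, upsample, with same-scale skips."""
    def __init__(self, c=64, n_scales=3):
        super().__init__()
        self.down_stages = nn.ModuleList([Stage(c) for _ in range(n_scales)])
        self.up_stages = nn.ModuleList([Stage(c) for _ in range(n_scales)])
        self.downs = nn.ModuleList([nn.Conv2d(c, c, 2, stride=2) for _ in range(n_scales - 1)])
        self.ups = nn.ModuleList([nn.ConvTranspose2d(c, c, 2, stride=2) for _ in range(n_scales - 1)])
        self.refine = Stage(c)  # extra refinement at the original scale -> deep features

    def forward(self, x):  # x: (T, C, H, W), frames folded into the batch dimension
        skips = []
        for i, stage in enumerate(self.down_stages):
            x = stage(x)
            if i < len(self.downs):
                skips.append(x)
                x = self.downs[i](x)
        for i, stage in enumerate(self.up_stages):
            if i > 0:
                x = self.ups[i - 1](x) + skips[-i]  # skip connection between same-scale features
            x = stage(x)
        return self.refine(x)

shallow = torch.randn(6, 64, 64, 64)   # I^SF for T = 6 frames
deep = MultiScaleTrunk()(shallow)      # I^DF, same shape as I^SF
print(deep.shape)
```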
reconstruction
restore HQ frames $I^{RHQ}\in\mathbb{R}^{T\times sH\times sW\times C}$ from the addition of shallow feature $I^{SF}$ and deep feature $I^{DF}$ (a sketch of both heads follows the list below)
- sr: sub-pixel conv with upscale factor $s$
- db, dn: single conv
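A minimal sketch of the two reconstruction heads described above (my own illustration, not the released code; `make_reconstruction_head`, the channel sizes and the single-stage PixelShuffle are assumptions):

```python
import torch
import torch.nn as nn

def make_reconstruction_head(c_feat: int, c_out: int, s: int) -> nn.Module:
    """Sub-pixel conv for SR (s > 1); a single conv for deblurring/denoising (s = 1)."""
    if s > 1:
        return nn.Sequential(
            nn.Conv2d(c_feat, c_out * s * s, 3, padding=1),  # expand channels by s^2
            nn.PixelShuffle(s),                              # rearrange to sH x sW
        )
    return nn.Conv2d(c_feat, c_out, 3, padding=1)

feat = torch.randn(1, 64, 64, 64)                       # I^SF + I^DF for one frame
print(make_reconstruction_head(64, 3, 4)(feat).shape)   # (1, 3, 256, 256) for SR
print(make_reconstruction_head(64, 3, 1)(feat).shape)   # (1, 3, 64, 64) for db/dn
```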
loss function Charbonnier loss
between the reconstructed HQ sequence $I^{RHQ}$ and the ground-truth HQ sequence $I^{HQ}$
$$\mathcal{L}=\sqrt{\Vert I^{RHQ}-I^{HQ}\Vert^2+{\epsilon}^2}$$
where $\epsilon$ is a constant, empirically set to $10^{-3}$
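The Charbonnier loss is straightforward to implement; a sketch that follows the formula above literally (many codebases instead apply the square root per element and take the mean, which is a slightly different but equally common variant):

```python
import torch

def charbonnier_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Charbonnier loss: sqrt(||pred - target||^2 + eps^2), a smooth, robust L1-like loss."""
    diff = pred - target
    return torch.sqrt(diff.pow(2).sum() + eps ** 2)

# usage on random tensors with the HQ sequence shape
x = torch.randn(6, 3, 256, 256)
y = torch.randn(6, 3, 256, 256)
print(charbonnier_loss(x, y))
```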
temporal mutual self attention (TMSA)
Illustrations for mutual attention and temporal mutual self attention (TMSA). In (a), we let the orange square (the $i$-th element of the reference frame) query elements in the supporting frame and use their weighted features as a new representation for the orange square. The weights are shown around solid arrows (we only show three examples for clarity). When $A_{i, k}\rightarrow1$ and the rest $A_{i, j}\rightarrow0\ (j\neq k)$, the mutual attention is equivalent to warping the yellow square to the position of the orange square (illustrated as a dashed arrow). (b) shows a stack of temporal mutual self attention (TMSA) layers. The sequence is partitioned into 2-frame clips at each layer and shifted for every other layer to enable cross-clip interactions. Dashed lines represent information fusion among different frames.
mutual attention
given reference frame features $X^R\in\mathbb{R}^{N\times C}$ and neighboring (supporting) frame features $X^S\in\mathbb{R}^{N\times C}$
compute query, key and value by linear projections
$$\begin{aligned} Q^R&=X^RP^Q \\ K^S&=X^SP^K \\ V^S&=X^SP^V \end{aligned}$$
where $P^Q, P^K, P^V\in\mathbb{R}^{C\times D}$ are projection matrices, $N$ is the number of feature elements, and $D$ is the number of channels of the projected features
use $Q^R$ to query $K^S$ and generate an attention map for the weighted sum of $V^S$
$$MA(Q^R, K^S, V^S)=softmax(\frac{Q^R(K^S)^T}{\sqrt{D}})V^S$$
rewriting the equation for the $i$-th element in the reference frame
$$Y_{i, :}^R=\sum_{j=1}^N A_{i, j}V_{j, :}^S$$
where $Y_{i, :}^R$ is the new feature of the $i$-th element in the reference frame, and $A\in\mathbb{R}^{N\times N}$ is the attention map reflecting correlations between the reference and neighboring frames
suppose $K_{k, :}^S$ (yellow box in fig.a) is the most similar element to $Q_{i, :}^R$ (orange box in fig.a), and $K_{j, :}^S\ (j\neq k)$ are dissimilar to $Q_{i, :}^R$, i.e.
$$A_{i, k}>A_{i, j},\quad \forall j\neq k,\ j\leq N$$
in the extreme case
$$\begin{cases} A_{i, k}\rightarrow1 & \\ A_{i, j}\rightarrow0 & \forall j\neq k,\ j\leq N \end{cases}$$
combining the above equations gives
$$Y_{i, :}^R=V_{k, :}^S$$
⟹ this moves the $k$-th element in the neighboring frame to the position of the $i$-th element in the reference frame (red arrow in fig.a)
⟹ equivalent to image warping given an optical flow vector
in practice, the reference frame and the neighboring frame can be exchanged, allowing mutual alignment between the two frames
similar to MSA, perform the attention $h$ times and concatenate the results as multi-head mutual attention (MMA)
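A minimal PyTorch sketch of multi-head mutual attention as defined above: queries come from the reference frame, keys and values from the supporting frame. This is an illustrative implementation, not the authors' code; sharing one set of projections for both directions and attending over whole frames rather than spatial windows are simplifications.

```python
import torch
import torch.nn as nn

class MultiHeadMutualAttention(nn.Module):
    """MMA sketch: queries from the reference frame X^R, keys/values from the supporting frame X^S."""
    def __init__(self, c: int, heads: int):
        super().__init__()
        assert c % heads == 0
        self.heads, self.d = heads, c // heads
        self.proj_q = nn.Linear(c, c, bias=False)  # P^Q
        self.proj_k = nn.Linear(c, c, bias=False)  # P^K
        self.proj_v = nn.Linear(c, c, bias=False)  # P^V

    def forward(self, x_ref: torch.Tensor, x_sup: torch.Tensor) -> torch.Tensor:
        # x_ref, x_sup: (B, N, C)
        b, n, c = x_ref.shape
        q = self.proj_q(x_ref).view(b, n, self.heads, self.d).transpose(1, 2)  # (B, h, N, D)
        k = self.proj_k(x_sup).view(b, n, self.heads, self.d).transpose(1, 2)
        v = self.proj_v(x_sup).view(b, n, self.heads, self.d).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) / self.d ** 0.5   # (B, h, N, N) attention map
        y = attn.softmax(dim=-1) @ v                       # weighted sum of V^S
        return y.transpose(1, 2).reshape(b, n, c)          # concatenate the h heads

mma = MultiHeadMutualAttention(c=64, heads=4)
x1, x2 = torch.randn(2, 128, 64), torch.randn(2, 128, 64)  # (B, N, C): two frames, N elements each
y1 = mma(x1, x2)   # queries from frame 1, keys/values from frame 2
y2 = mma(x2, x1)   # the other direction, giving mutual alignment
print(y1.shape, y2.shape)
```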
benefits of mutual attention
- preserve information from neighboring frames adaptively
avoid “black hole” artifacts
- no inductive biases of locality
inherent to most CNN-based motion estimation
performance drops when 2 neighboring objects move in different directions
- conduct motion estimation and warping on features jointly
optical flows are only estimated on RGB images and are not robust
temporal mutual self attention
combine mutual attention with self-attention
given $X\in\mathbb{R}^{2\times N\times C}$ representing 2 frames
split $X$ into 2 parts of features
$$X_1, X_2=split_0(LN(X))\in\mathbb{R}^{1\times N\times C}$$
where $split_0(\cdot)$ is a split operator along dimension 0
apply MMA on $X_1, X_2$ twice: warp $X_1$ towards $X_2$, and warp $X_2$ towards $X_1$
$$\begin{aligned} Y_1&=MMA(X_1, X_2) \\ Y_2&=MMA(X_2, X_1) \end{aligned}$$
combine warped features and concatenate with MSA result
$$Y=concat_0(concat_2(Y_1, Y_2), MSA(X))$$
where $concat_0(\cdot)$ and $concat_2(\cdot)$ are concatenation operators along dimensions 0 and 2
feed $Y$ into 2 consecutive MLPs with skip connections
$$\begin{aligned} X&=MLP(Y)+X \\ X&=MLP(LN(X))+X \end{aligned}$$
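Putting the pieces together, a hedged sketch of one TMSA layer on a 2-frame clip. `nn.MultiheadAttention` stands in for both the mutual attention (as cross attention) and the MSA branch, spatial window partitioning is omitted, and the concatenation/fusion layout is one plausible reading of the formula above rather than the exact VRT configuration.

```python
import torch
import torch.nn as nn

class TMSALayer(nn.Module):
    """One TMSA-style layer on a 2-frame clip: mutual attention both ways + self attention + MLPs."""
    def __init__(self, c: int, heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(c)
        # cross attention (queries from one frame, keys/values from the other) plays the role of MMA
        self.mma = nn.MultiheadAttention(c, heads, batch_first=True)
        self.msa = nn.MultiheadAttention(c, heads, batch_first=True)
        self.fuse = nn.Linear(2 * c, c)   # MLP fusing the MMA and MSA outputs
        self.norm2 = nn.LayerNorm(c)
        self.mlp = nn.Sequential(nn.Linear(c, 2 * c), nn.GELU(), nn.Linear(2 * c, c))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 2, N, C) -- a 2-frame clip with N spatial elements per frame
        b, t, n, c = x.shape
        h = self.norm1(x)
        x1, x2 = h[:, 0], h[:, 1]                    # split along the frame dimension
        y1, _ = self.mma(x1, x2, x2)                 # queries from frame 1, keys/values from frame 2
        y2, _ = self.mma(x2, x1, x1)                 # the other direction (mutual alignment)
        y_mma = torch.stack([y1, y2], dim=1)         # (B, 2, N, C)
        h_flat = h.reshape(b, t * n, c)
        y_msa, _ = self.msa(h_flat, h_flat, h_flat)  # self attention over the whole clip
        y = torch.cat([y_mma, y_msa.reshape(b, t, n, c)], dim=-1)
        x = self.fuse(y) + x                         # first MLP + skip connection
        return self.mlp(self.norm2(x)) + x           # second MLP + skip connection

layer = TMSALayer(c=64, heads=4)
print(layer(torch.randn(2, 2, 128, 64)).shape)       # (2, 2, 128, 64)
```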
only 2 frames are dealt with at a time
⟸ design of mutual attention
extending to $T$ frames naively means dealing with frame-to-frame pairs exhaustively
⟹ complexity $\mathcal{O}(T^2)$
solution inspired by shifted-window mechanism in Swin
step 1 partition the video sequence into non-overlapping 2-frame clips, and apply MMA-MSA to each clip in parallel
step 2 shift the sequence temporally by 1 frame for every other layer (fig.b) to enable cross-clip connections
⟹ complexity $\mathcal{O}(T)$
the temporal receptive field grows when multiple TMSA modules are stacked together
at layer $\ell\ (\ell\geq2)$, one frame can utilize information from up to $2(\ell-1)$ frames
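A sketch of the clip partitioning and the alternating 1-frame shift that give the $\mathcal{O}(T)$ cost (illustrative only; the cyclic `torch.roll` used here is borrowed from Swin-style shifting, and `tmsa_stack` is a hypothetical helper, not VRT's actual implementation):

```python
import torch
import torch.nn as nn

def tmsa_stack(x: torch.Tensor, layers) -> torch.Tensor:
    """Apply a stack of 2-frame clip layers with an alternating 1-frame temporal shift.

    x: (B, T, N, C) with an even T; each layer takes a batch of (2, N, C) clips.
    The shift is shown as a cyclic roll; VRT handles the sequence boundary differently.
    """
    b, t, n, c = x.shape
    for i, layer in enumerate(layers):
        shifted = i % 2 == 1
        if shifted:
            x = torch.roll(x, shifts=1, dims=1)   # shift the sequence by 1 frame every other layer
        clips = x.reshape(b * (t // 2), 2, n, c)  # partition into non-overlapping 2-frame clips
        x = layer(clips).reshape(b, t, n, c)      # process all clips in parallel -> O(T) cost
        if shifted:
            x = torch.roll(x, shifts=-1, dims=1)  # undo the shift to keep frames in order
    return x

# usage: Identity() stands in for a clip-level layer; the TMSALayer sketch above also fits here
layers = [nn.Identity() for _ in range(4)]
video = torch.randn(1, 6, 128, 64)                # T = 6 frames
print(tmsa_stack(video, layers).shape)            # (1, 6, 128, 64)
```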
parallel warping (PW)
spatial window partition
⟹ mutual attention is unable to deal with large motions well
solution: use feature warping at the end of each stage
Illustration of parallel warping. For every frame feature $X_t\ (t\leq T)$, the neighboring frames $X_{t-1}$ and $X_{t+1}$ are warped towards $X_t$ as $\hat{X}_{t-1}$ and $\hat{X}_{t+1}$, respectively. Then, $X_t$, $\hat{X}_{t-1}$ and $\hat{X}_{t+1}$ are concatenated together (denoted by blue boxes) for feature fusion and dimension reduction with a multi-layer perceptron (MLP). The final output is $\bar{X}_t$. The dashed arrows and circles denote warping operations and warped features, respectively.
given a frame $X_t$ and its neighboring frames $X_{t-1}, X_{t+1}$
step 1 calculate optical flows $O_{t-1, t}, O_{t+1, t}$ from $X_t$ and $X_{t-1}, X_{t+1}$
step 2 use $O_{t-1, t}, O_{t+1, t}$ to warp $X_{t-1}, X_{t+1}$ and obtain initial warped features $X_{t-1}', X_{t+1}'$
$$\begin{aligned} X_{t-1}'&=warp(X_{t-1}, O_{t-1, t}) \\ X_{t+1}'&=warp(X_{t+1}, O_{t+1, t}) \end{aligned}$$
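The $warp(\cdot)$ operator above is standard backward warping of a feature map by an optical flow; a common way to realize it is with `torch.nn.functional.grid_sample`. A sketch, assuming the flow is in pixel units with channel 0 holding horizontal and channel 1 holding vertical displacements:

```python
import torch
import torch.nn.functional as F

def flow_warp(x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp features x (B, C, H, W) by an optical flow (B, 2, H, W) in pixels."""
    b, _, h, w = x.shape
    # base sampling grid of pixel coordinates
    yy, xx = torch.meshgrid(torch.arange(h, device=x.device, dtype=x.dtype),
                            torch.arange(w, device=x.device, dtype=x.dtype),
                            indexing="ij")
    grid_x = xx.unsqueeze(0) + flow[:, 0]          # displaced x coordinates
    grid_y = yy.unsqueeze(0) + flow[:, 1]          # displaced y coordinates
    # normalize to [-1, 1] as expected by grid_sample
    grid_x = 2.0 * grid_x / max(w - 1, 1) - 1.0
    grid_y = 2.0 * grid_y / max(h - 1, 1) - 1.0
    grid = torch.stack((grid_x, grid_y), dim=-1)   # (B, H, W, 2)
    return F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=True)

x_prev = torch.randn(1, 64, 64, 64)                # X_{t-1}
flow = torch.randn(1, 2, 64, 64)                   # O_{t-1, t}
print(flow_warp(x_prev, flow).shape)               # X'_{t-1}: (1, 64, 64, 64)
```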
step 3 predict offset residuals $o_{t-1, t}, o_{t+1, t}$ and modulation masks $m_{t-1, t}, m_{t+1, t}$
$$o_{t-1, t}, o_{t+1, t}, m_{t-1, t}, m_{t+1, t}=\mathcal{C}([O_{t-1, t}, O_{t+1, t}, X_{t-1}', X_{t+1}'])$$
where $\mathcal{C}(\cdot)$ is a convolution layer and $[\cdot]$ is a concatenation operator
step 4 warp $X_{t-1}, X_{t+1}$ with the results above
$$\begin{aligned} \hat{X}_{t-1}&=\mathcal{D}(X_{t-1}, O_{t-1, t}+o_{t-1, t}, m_{t-1, t}) \\ \hat{X}_{t+1}&=\mathcal{D}(X_{t+1}, O_{t+1, t}+o_{t+1, t}, m_{t+1, t}) \end{aligned}$$
where $\mathcal{D}(\cdot)$ is a deformable convolution layer
step 5 concatenate $X_t, \hat{X}_{t-1}, \hat{X}_{t+1}$ and feed them into an MLP to obtain $\bar{X}_t$ with reduced dimension
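A hedged sketch of steps 3-5 for a single neighboring frame, using `torchvision.ops.DeformConv2d` for $\mathcal{D}(\cdot)$. The channel sizes, the offset channel ordering and the fusion MLP are assumptions for illustration; VRT predicts offsets and masks for both neighbors jointly and fuses $X_t$, $\hat{X}_{t-1}$ and $\hat{X}_{t+1}$ together.

```python
import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d

class ParallelWarpSketch(nn.Module):
    """Steps 3-5 for one neighbor (illustrative sizes; VRT handles both neighbors jointly)."""
    def __init__(self, c: int, k: int = 3):
        super().__init__()
        self.k = k
        # C(.): predict offset residuals o and modulation masks m from [flow, initially warped feature]
        self.conv_offset = nn.Conv2d(2 + c, 3 * k * k, 3, padding=1)
        # D(.): modulated deformable convolution
        self.dcn = DeformConv2d(c, c, k, padding=k // 2)
        # MLP for feature fusion and dimension reduction after concatenation
        self.fuse = nn.Sequential(nn.Conv2d(2 * c, c, 1), nn.GELU(), nn.Conv2d(c, c, 1))

    def forward(self, x_t, x_n, x_n_init, flow_n):
        # x_t, x_n: (B, C, H, W); x_n_init: flow-warped neighbor from step 2; flow_n: (B, 2, H, W)
        out = self.conv_offset(torch.cat([flow_n, x_n_init], dim=1))     # step 3
        o, m = out.split([2 * self.k * self.k, self.k * self.k], dim=1)
        # add the base flow at every kernel sampling location; (dy, dx) channel order is assumed
        base = flow_n.flip(1).repeat(1, self.k * self.k, 1, 1)
        x_n_hat = self.dcn(x_n, o + base, mask=torch.sigmoid(m))         # step 4
        return self.fuse(torch.cat([x_t, x_n_hat], dim=1))               # step 5 (one neighbor only)

pw = ParallelWarpSketch(c=64)
x_t, x_n = torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32)
flow = torch.zeros(1, 2, 32, 32)        # zero flow, so the "warped" neighbor equals the neighbor
print(pw(x_t, x_n, x_n, flow).shape)    # (1, 64, 32, 32)
```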
Experiment
dataset
| dataset | resolution | training set | testing set | usage |
| --- | --- | --- | --- | --- |
| REDS | $1280\times720$ | 266 clips | REDS4 (4 clips) | video super resolution (BI), video deblurring |
| Vimeo-90K | $448\times256$ | 64,612 clips | Vimeo-90K-T (7,824 clips) | video super resolution (BI, BD) |
| Vid4 | $720\times480$ | | 4 clips, 34 frames each | video super resolution |
| UDM10 | $1272\times720$ | | 4 clips, 32 frames each | video super resolution |
| DVD | $1280\times720$ | 61 clips, 5,708 frames in total | 10 clips, 1,000 frames in total | video deblurring |
| GoPro | $1280\times720$ | 22 clips, 2,103 frames in total | 11 clips, 1,111 frames in total | video deblurring |
| DAVIS | $854\times480$ | 90 clips | 30 clips | video denoising |
| Set8 | $960\times540$ | | 8 clips, 85 frames each | video denoising |
experiment detail
- data augmentation random flipping, random rotation, random cropping
- input
- sr on REDS: $64\times64$ patches, 6 or 16 frames
- sr on Vimeo-90K: $64\times64$ patches, 7 frames
- db, dn: $192\times192$ patches, 6 frames
- degradation
- sr: bicubic down-sampling (BI), blur down-sampling (BD)
- db: motion blur
- dn: Gaussian noise, $\sigma\in[0, 50]$
- optimizer Adam: $\beta_1=0.9, \beta_2=0.99$, batch size 8, 300K iterations
- learning rate: initial 4e-4 with cosine decay (see the sketch below)
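For reference, a minimal sketch of how the optimizer and learning-rate schedule listed above could be set up in PyTorch (the model, data and loss are stand-ins; any warm-up or per-module learning rates used in practice are omitted):

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, 3, padding=1)   # stand-in for VRT
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4, betas=(0.9, 0.99))
# cosine decay of the learning rate over 300K iterations
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300_000)

for step in range(3):                   # training-loop skeleton (3 dummy steps)
    x = torch.randn(8, 3, 64, 64)       # batch size 8
    loss = model(x).abs().mean()        # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr()[0])
```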
video super resolution
Quantitative comparison (average PSNR/SSIM) with state-of-the-art methods for video super-resolution ($\times4$) on REDS4, Vimeo-90K-T, Vid4 and UDM10. Best and second best results are in red and blue colors, respectively. “$\dag$”: we currently do not have enough GPU memory to train the fully parallel model VRT on 30 frames.
Visual comparison of video super-resolution ($\times4$) methods.
video deblurring
Quantitative comparison (average RGB channel PSNR/SSIM) with state-of-the-art methods for video deblurring on DVD. Best and second best results are in red and blue colors, respectively.
Quantitative comparison (average RGB channel PSNR/SSIM) with state-of-the-art methods for video deblurring on GoPro. Best and second best results are in red and blue colors, respectively.
Quantitative comparison (average RGB channel PSNR/SSIM) with state-of-the-art methods for video deblurring on REDS. Best and second best results are in red and blue colors, respectively.
Visual comparison of video deblurring methods.
video denoising
Quantitative comparison (average RGB channel PSNR) with state-of-the-art methods for video denoising on DAVIS and Set8. $\sigma$ is the additive white Gaussian noise level. Best and second best results are in red and blue colors, respectively.
ablation study
baseline: a small version of VRT with the numbers of layers and channels halved
multi-scale architecture & parallel warping
Ablation study on the multi-scale architecture and parallel warping. Given an input of spatial size $64\times64$, the corresponding feature sizes of each scale are shown in brackets. When some scales are removed, we add more layers to the remaining scales to keep a similar model size.
key findings
- when the number of model scales is reduced, performance drops gradually
⟸ multi-scale processing helps to utilize information from a larger area and deal with large motions between frames
- parallel warping brings an improvement of 0.17dB
temporal mutual self attention
Ablation study on temporal mutual self attention.
key findings
- when MA is replaced with SA, or only SA is used, performance drops by 0.11 to 0.17dB, because
- SA focuses more on the reference frame than on the neighboring frame during attention computation
- MA attends to the neighboring frame and benefits from feature fusion
- using only MA is not enough
⟸ MA cannot preserve information of the reference frame
attention window size
Ablation study on attention window size (frame $\times$ height $\times$ width).
study the temporal window size in the last several TMSA layers of each scale
- when the window size increases from 1 to 2, performance improves slightly
⟸ previous TMSA layers already utilize neighboring 2-frame information well
- when the window size increases to 8, there is an obvious improvement of 0.18dB
⟹ use a window size of $8\times8\times8$ for those layers