关于nn.upsample在GPU上无法兼容BFloat16的问题

在CNN中,nn.upsample常用于上采样操作,尤其是最近大火的扩散模型中,UNet网络的上采样均是采用这个操作执行该任务。尽管如此,nn.upsample在GPU上运行时,与torch.bfloat16会发生冲突,常给出这样的错误:RuntimeError:“upsample_nearest2d_out_frame” not implemented for ‘BFloat16’,从而导致很多高性能计算受阻。torch.bfloat16数据格式,是指"Brain Floating Point"格式占位16位,由Google Brain发明,专门为TPU研制,这种格式有很多优越的性能(详见https://cloud.google.com/tpu/docs/bfloat16?hl=zh-cn);后面人们发现这种数据格式,在GPU框架下的训练速度很快,同时对性能影响很小。如Lightning库(https://lightning.ai/)专门为Pytorch加速时,常使用这种数据格式,我们尝试过,使用这种数据格式训练扩散模型,每迭代1000次,要比其他数据格式快10s左右(在3090上)。因此,这个数据格式nn.upsample这个类在GPU上计算不兼容,将极大地影响学习进程。注意:只是在GPU上,会冲突,在CPU上不会冲突。即:

import torch.nn as nn
data=torch.rand(4,3,8,8,dtype=torch.bfloat16)
up = nn.Upsample(scale_factor=2.0, mode="nearest")
output_up = up(data)

在CPU上,上面这个运行是没有问题的,但是如果在GPU上,如:

data=torch.rand(4,3,8,8,dtype=torch.bfloat16).to('cuda:0')
up = nn.Upsample(scale_factor=2.0, mode="nearest").to('cuda:0')
output = up(data)

这将产生一个错误:RuntimeError: “upsample_nearest2d_out_frame” not implemented for ‘BFloat16’,这个错误的意思,就是在GPU上,upsample与BFloat16格式不匹配。而如果使用其他模式呢?如:mode=‘linear’(3D上采样), ‘bilinear’(4D上采样), ‘bicubic’(4D上采样) and ‘trilinear’(5D上采样),都会产生这样的错误。那有什么解决这个问题呢?这是深层的bug,很难根治,但我们可以重写mode='nearest’下的上采样,具体如下:

class UpsampleDeterministic(nn.Module):
    def __init__(self,upscale=2):
        super(UpsampleDeterministic, self).__init__()
        self.upscale = upscale

    def forward(self, x):
        '''
        x: 4-dim tensor. shape is (batch,channel,h,w)
        output: 4-dim tensor. shape is (batch,channel,self.upscale*h,self.upscale*w)
        '''
        return x[:, :, :, None, :, None]\
        .expand(-1, -1, -1, self.upscale, -1, self.upscale)\
        .reshape(x.size(0), x.size(1), x.size(2)\
                 *self.upscale, x.size(3)*self.upscale)

可以验证,UpsampleDeterministic(upscale=2)与nn.Upsample(scale_factor=2.0, mode=“nearest”)是
完全等价的,即
upsampledet = UpsampleDeterministic(upscale=2)
output_det = upsampledet(data)#与上面的data一样
output_det == output_up
这样,在GPU场景下,我们就用UpsampleDeterministic(upscale=2)代替nn.Upsample(scale_factor=2.0, mode=“nearest”)
即可。这样就是完美的解决了上述问题。注意这个等价关系,只是在mode='nearest’下的等价,其他mode不等价,而在很多模型中的上采样均是mode=‘nearest’。类似地,在torch.nn.functional.upsample下也存在这样的问题,注意:函数格式已经被弃用,使用 torch.nn.functional.interpolate(),如果在GPU场景下,会出现同样的错误,此时,我们可以定义一个函数格式的上采样(nearest模式):

def upsample_deterministic(x,upscale):
    return x[:, :, :, None, :, None]\
    .expand(-1, -1, -1, upscale, -1, upscale)\
    .reshape(x.size(0), x.size(1), x.size(2)\
             *upscale, x.size(3)*upscale)

参考:UpsampleDeterministic(nn.Module)和函数格式,请参考官网问题最底端:https://github.com/pytorch/pytorch/issues/12207。