Pytorch中的contiguous问题
在有的程序中,如果对view、transpose、permute等维数变换操作不恰当,就容易引起
"RuntimeError: : dimension must be contiguous"的错误。在连续性中,有两个常用方法,x=x.contiguous()使一个tensor变连续;x.is_contiguous()是检查x内存是否连续,准确地讲是检查内存排列顺序与以行优先的内存排列顺序是否一致,如果一致则返回True,否则返回False。什么意思呢?我们首先来看个例子:
x = torch.tensor([[1,2,3,4],[5,6,7,8]])
output:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
那在计算机内存中是如何存储这个矩阵的呢?它是以行展开成一维数组进行存储,称为"按行优先" ,这
与C和C++的存储方式一致,而有的程序就是"按列优先"如matlab。现在,我们按行优先拉直这个矩阵,
x = x.flatten()
output:
tensor([1, 2, 3, 4, 5, 6, 7, 8])
这就是内存存储方方式。如果我们要访问这个矩阵的下一个元素,则偏移一个单位就可以了,这个偏移量
我们称为步长(stride=1),但如果是“按列优先”,则stride就不等于1了,因为:
y = x.t()#相当于将x按列优先存储
output:
tensor([[1, 5],
[2, 6],
[3, 7],
[4, 8]])
同样,拉直矩阵有:
y = y.flatten()
output:
tensor([1, 5, 2, 6, 3, 7, 4, 8])
这个时候,如果"按列优先"存储,则矩阵访问下一元素也就2,其步长stride=2对内存方法有一个初步了解之后,我们现在来看一下is_contiguous到底在检查什么?
x = torch.tensor([[1,2,3,4],[5,6,7,8]])
output:
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
#检查连续性
x.is_contiguous()
output:
True
y = x.transpose(0,1)
output:
tensor([[1, 5],
[2, 6],
[3, 7],
[4, 8]])
y.is_contiguous()
output:
False
我们发现,使用transpose操作后,内存就不连续了,is_contiguous()=False,这是因为这个y
的内存存储顺序与x不一样,即x是这样存储:
tensor([1, 2, 3, 4, 5, 6, 7, 8])
y是这样存储(注意Pytorch是按行优先存储):
tensor([1, 5, 2, 6, 3, 7, 4, 8])
我们再检查两个张量的内存地址(指针):
y.data_ptr()
output:6042087982400
x.data_ptr()
output: 6042087982400
我们发现:两个张量的内存是一样的,但是内存存储的排列不一样,这样就是报y.is_contiguous()=Fasle 。因此,is_contiguous()是在检查,同一个内存空间上,两个张量之间的内存方式(按行展开后的内存)是否一致。
如果我们使用contiguous()方法后:
y = y.contiguous()
y.is_contiguous()
output:True
我们发现,y就变得连续了,我们观察y = y.contiguous()的内存地址:
y.data_ptr()
output:6042087982528
x.data_ptr()
output:6042087982400
我们发现,进行y = y.contiguous()操作后,重新分配了一个内存地址存储y。虽然使用y = y.contiguous()
能解决不连续问题,但是这样的操作增加了内存负担。类似地,使用permute方法、目前比较常用的包einops中 rearrange,reduce,repeat方法,都会使得操作产生不连续问题。在大部分程序中,不连续的影响并不是很大,但是在部分程序中,不连续性是致命的,而解决方法就就是contiguous(),重新分配内存空间。在实际中,我们都希望内存是连续的,以提高数据预取的缓存速度,从而减少内存请求次数,加快模型训练。