避免在PyTorch的Dataset中使用numpy.random
最近使用PyTorch的过程中遇到了Dataset在
self.__getitem__
方法里使用numpy的random函数无效的问题,之前似乎就遇到过,但没有太放心上,这次再次遇到。搜索解决方案的时候发现网上一篇博客记录得很详细,我d额这篇博文算是对这篇博客的一个不完整的中文翻译
深度学习用起来的确是好,拿一个比较成熟的网络结构,只要数据得当,我们总能得到一个差不多的结果。但这种拿来就能用的特性会让我们忽略其中的一些设计细节,甚至隐性的Bug。这些问题在正常情况下表现可能不明显,但当条件变得极端,网络表现或许会出乎我们的意料。Dataset + Numpy就是这种情况的一个非常有代表性的体现。一个叫Tanel Pärnamaa的博主发现,自己的代码在使用numpy的random函数对数据进行数据增广时,性能表现不尽满意,最后发现random的函数在PyTorch Dataset类中的表现并不如我们直观想象的那样(甚至他发现网上95%的开源代码都有类似的问题,在OpenAI的代码,NVIDIA的项目中,PyTorch的官方教程里,甚至特斯拉AI部门老大Even Karpathy也遇到了这个问题)
Bug 介绍
当我们使用PyTorch加载或者是对数据进行预处理/增广的时候,最通用的做法是定义一个torch.utils.data.Dataset
的子类,然后重写__getitem__
方法。我们有时候会在方法里加入一些随机的特性(随机裁切/旋转、噪声等等),并且在训练使用DataLoader创建每个训练的batch。为加快数据预处理速度,我们还会设置num_workers
来并行处理。如果我们在做随机操作的时候使用numpy.random
来生成随机数的话,你会发现:
你random的值不是随着每一张图/个数据的改变而改变,而是会随每num_workers
张图/个数据的改变而改变
一个简单的例子
下面是一个简单的例子,Dataloader生成batch_size
为2的数据,并且使用4个进程进行数据读取:
import numpy as np
from torch.utils.data import Dataset, DataLoader
class RandomDataset(Dataset):
def __getitem__(self, index):
return np.random.randint(0, 1000, 3)
def __len__(self):
return 16
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
for batch in dataloader:
print(batch)
最终的结果是:
tensor([[116, 760, 679], # 1st batch, returned by process 0
[754, 897, 764]])
tensor([[116, 760, 679], # 2nd batch, returned by process 1
[754, 897, 764]])
tensor([[116, 760, 679], # 3rd batch, returned by process 2
[754, 897, 764]])
tensor([[116, 760, 679], # 4th batch, returned by process 3
[754, 897, 764]])
tensor([[866, 919, 441], # 5th batch, returned by process 0
[ 20, 727, 680]])
tensor([[866, 919, 441], # 6th batch, returned by process 1
[ 20, 727, 680]])
tensor([[866, 919, 441], # 7th batch, returned by process 2
[ 20, 727, 680]])
tensor([[866, 919, 441], # 8th batch, returned by process 3
[ 20, 727, 680]])
我们会发现每4个random出的数据都是一样的……
Why?
PyTorch使用multiprocessing并行加载数据,worker进程使用fork
的启动方式 (start methods),这种方式将会使每一个worker进程继承父进程的所有资源,包括NumPy的随机数生成器。
解决方案
关于解决方案,详细的可以看原博客,这里只提一下两个简单的解决方法:
- 使用Python 内置的random函数生成随机数(PyTorch可以对Python内置random函数进行正常处理)
- 使用PyTorch自带随机函数
在较新的PyTorch版本中,这个问题似乎已经得到了解决:relevant pull request,但考虑到我们代码的可复现性,这里还是非常建议避免在Dataset中直接使用numpy.random函数,如果非常有必要使用的话,可参考原博客提到的一种方法,通过设置跟worker ID绑定的seed来解决这个问题。
写在最后
事实上这个问题非常隐性,尤其是在batch_size
相对较大而num_workers
又比较小的情况下,这种随机数的Bug可能影响并不大,但这种影响可能是非常严重的,因为问题一旦出现,就很难Debug,在之前参数合适的情况下,我们改变了下num_workers
结果就变差了,倘若直接归咎于炼丹的玄学,你可能跟一个SOTA擦肩而过。炼丹是有玄学,但掌握更多的确定性对我们理解和使用深度学习还是挺重要的。