avatar

3cats

保持理性

  • 首页
  • 分类
  • 标签
  • 文章归档
Home 避免在PyTorch的Dataset中使用numpy.random
文章

避免在PyTorch的Dataset中使用numpy.random

Posted 2022-04-16 Updated 2023-04- 16
By 3CATS
10~13 min read

最近使用PyTorch的过程中遇到了Dataset在self.__getitem__方法里使用numpy的random函数无效的问题,之前似乎就遇到过,但没有太放心上,这次再次遇到。搜索解决方案的时候发现网上一篇博客记录得很详细,我d额这篇博文算是对这篇博客的一个不完整的中文翻译

深度学习用起来的确是好,拿一个比较成熟的网络结构,只要数据得当,我们总能得到一个差不多的结果。但这种拿来就能用的特性会让我们忽略其中的一些设计细节,甚至隐性的Bug。这些问题在正常情况下表现可能不明显,但当条件变得极端,网络表现或许会出乎我们的意料。Dataset + Numpy就是这种情况的一个非常有代表性的体现。一个叫Tanel Pärnamaa的博主发现,自己的代码在使用numpy的random函数对数据进行数据增广时,性能表现不尽满意,最后发现random的函数在PyTorch Dataset类中的表现并不如我们直观想象的那样(甚至他发现网上95%的开源代码都有类似的问题,在OpenAI的代码,NVIDIA的项目中,PyTorch的官方教程里,甚至特斯拉AI部门老大Even Karpathy也遇到了这个问题)

Bug 介绍

当我们使用PyTorch加载或者是对数据进行预处理/增广的时候,最通用的做法是定义一个torch.utils.data.Dataset的子类,然后重写__getitem__方法。我们有时候会在方法里加入一些随机的特性(随机裁切/旋转、噪声等等),并且在训练使用DataLoader创建每个训练的batch。为加快数据预处理速度,我们还会设置num_workers来并行处理。如果我们在做随机操作的时候使用numpy.random来生成随机数的话,你会发现:

你random的值不是随着每一张图/个数据的改变而改变,而是会随每num_workers张图/个数据的改变而改变

一个简单的例子

下面是一个简单的例子,Dataloader生成batch_size为2的数据,并且使用4个进程进行数据读取:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 3)

    def __len__(self):
        return 16
    
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
for batch in dataloader:
    print(batch)

最终的结果是:

tensor([[116, 760, 679],   # 1st batch, returned by process 0
        [754, 897, 764]])
tensor([[116, 760, 679],   # 2nd batch, returned by process 1
        [754, 897, 764]])
tensor([[116, 760, 679],   # 3rd batch, returned by process 2
        [754, 897, 764]])
tensor([[116, 760, 679],   # 4th batch, returned by process 3
        [754, 897, 764]])

tensor([[866, 919, 441],   # 5th batch, returned by process 0
        [ 20, 727, 680]])
tensor([[866, 919, 441],   # 6th batch, returned by process 1
        [ 20, 727, 680]])
tensor([[866, 919, 441],   # 7th batch, returned by process 2
        [ 20, 727, 680]])
tensor([[866, 919, 441],   # 8th batch, returned by process 3
        [ 20, 727, 680]])

我们会发现每4个random出的数据都是一样的……

Why?

PyTorch使用multiprocessing并行加载数据,worker进程使用fork的启动方式 (start methods),这种方式将会使每一个worker进程继承父进程的所有资源,包括NumPy的随机数生成器。

解决方案

关于解决方案,详细的可以看原博客,这里只提一下两个简单的解决方法:

  • 使用Python 内置的random函数生成随机数(PyTorch可以对Python内置random函数进行正常处理)
  • 使用PyTorch自带随机函数

在较新的PyTorch版本中,这个问题似乎已经得到了解决:relevant pull request,但考虑到我们代码的可复现性,这里还是非常建议避免在Dataset中直接使用numpy.random函数,如果非常有必要使用的话,可参考原博客提到的一种方法,通过设置跟worker ID绑定的seed来解决这个问题。

写在最后

事实上这个问题非常隐性,尤其是在batch_size相对较大而num_workers又比较小的情况下,这种随机数的Bug可能影响并不大,但这种影响可能是非常严重的,因为问题一旦出现,就很难Debug,在之前参数合适的情况下,我们改变了下num_workers结果就变差了,倘若直接归咎于炼丹的玄学,你可能跟一个SOTA擦肩而过。炼丹是有玄学,但掌握更多的确定性对我们理解和使用深度学习还是挺重要的。

深度学习
Pytorch
License:  CC BY 4.0
Share

Further Reading

Oct 17, 2023

Qwen-VL 论文阅读

Qwen 中文名是通义千问,由阿里,对应的多模态版本是 Qwen-VL。Qwen 能够根据用户的 prompts, 完成各种视觉任务。相比于其他任务,该模型在 grounding,文本阅读,面向文本问题问答和细粒度对话等方面具有优势。这个模型支持交错图文的输入。 方法 模型基于Qwen,由三部分组成

Apr 16, 2022

避免在PyTorch的Dataset中使用numpy.random

最近使用PyTorch的过程中遇到了Dataset在self.__getitem__方法里使用numpy的random函数无效的问题,之前似乎就遇到过,但没有太放心上,这次再次遇到。搜索解决方案的时候发现网上一篇博客记录得很详细,所以这篇博文是对这篇博客的一个中文概述深度学习用起来的确是好,拿一个比较成

OLDER

第一篇又不是第一篇的博文

NEWER

升级Halo2后的第一次文章更新

Recently Updated

  • 使用 Snapper 管理 NAS 快照
  • 自建NAS照片管理服务启用TLS踩坑记录
  • Qwen-VL 论文阅读
  • 八位堂 Pro2 手柄通过蓝牙连接Linux系统
  • 升级Halo2后的第一次文章更新

Trending Tags

games paper reading nas Pytorch nag

Contents

©2025 3cats. Some rights reserved.

Using the Halo theme Chirpy