快捷键

torch.utils.data.dataloader 的源代码

# mypy: 允许未类型化定义
rDataLoader 和相关派生自 _BaseDataLoaderIter 的迭代器的定义。

为了支持这两个类,在 `./_utils` 中我们定义了许多实用方法。
多进程中的运行函数。例如,数据加载工作循环是
在 `./_utils/worker.py` 文件中。
""

导入 functools
导入 itertools
导入 记录日志
导入 多进程 作为 python_multiprocessing
导入 操作系统
导入 队列
导入 线程
导入 警告
来自 collections.abc 导入 迭代器
来自 打字 导入 任意, 可调用, 通用, 可选, 类型变量, 联合

导入 火炬
导入 torch.distributed 作为 dist
导入 torch.utils.data.graph_settings
来自 torch._utils 导入 异常包装器
来自 torch.utils.data 导入 _utils
来自 torch.utils.data.datapipes.datapipe 导入 (
    _IterDataPipe 序列化包装器,
    _MapDataPipe 序列化包装器,
    IterDataPipe,
    MapDataPipe,
)
来自 torch.utils.data.dataset 导入 数据集, 可迭代数据集
来自 torch.utils.data.sampler 导入 (
    批量采样器,
    随机采样器,
    采样器,
    顺序采样器,
)


__all__ = [
    "数据加载器",
    获取工作信息,
    默认合并,
    默认转换,
]


_T = 类型变量("_T")
_T_co = 类型变量(_T_co, 协变=True)
_worker_init_fn_t = 可调用[[int] ]

理想情况下,我们应该通过 `collate_fn` 的返回类型来参数化 `DataLoader`,但目前还没有办法设置一个默认值,如果用户没有传入自定义的 'collate_fn'。
如果用户没有传入自定义的 'collate_fn',则可以将类型参数设置为默认值。
请参阅 https://github.com/python/mypy/issues/3737。
_collate_fn_t = 可调用[[列表[_T]], 任意]


这些函数曾经定义在这个文件中。但是,它已经被移动到了
# _utils/collate.py。尽管从用户层面访问它相当困难
# (必须显式地直接导入 `import torch.utils.data.dataloader`),因此
# 可能是用户代码在使用它。这种别名保持向后兼容性。
# 方面。
默认的 collate: _collate_fn_t = _工具.汇总.默认汇总
默认转换 = _工具.汇总.默认转换

获取工作者信息 = _工具.工人.获取工作者信息

日志记录器 = 记录日志.获取日志记录器(__name__)


 _数据集类型:
    地图 = 0
    迭代器 = 1

    @staticmethod
    def 创建获取器(类型, 数据集, 自动对齐, collate_fn, drop_last):
        如果 仁慈 == _DatasetKind.Map:
            返回 _工具.获取._MapDatasetFetcher(
                数据集, 自动归一化, 归一化函数, drop_last
            )
        否则:
            返回 _工具.获取._IterableDatasetFetcher(
                dataset, auto_collation, collate_fn, drop_last
            )


 _InfiniteConstantSampler(Sampler):
    r类似于 `itertools.repeat(None, None)`。

用作 `torch.utils.data.IterableDataset` 的采样器。
"源代码"

    def __iter__(self):
        while True:
            产生 None


def _get_distributed_settings():
    如果 距离.是否可用()  距离.已初始化():
        返回 距离.获取世界大小(), 距离.获取排名()
    否则:
        返回 1, 0


def _sharding_worker_init_fn(工作器初始化函数, 世界大小, 排名 ID, 工作器 ID):
    全局工作器 ID = 工作 ID
    信息 = 火炬.工具.数据.获取工人信息()
    断言 信息  not None
    总工人数 = 信息.num_workers
    数据管道 = 信息.数据集
    断言 isinstance(数据管道, (迭代数据管道, 映射数据管道))
    为了使元素在分布式进程间均匀分布,我们应该在分布式上分片数据
    首先处理然后在工作进程中分片
    总工作进程数 *= 世界大小
    全局工作进程 ID = 全局工作进程 ID * 世界大小 + rank_id
    # 对于 BC,使用默认的 SHARDING_PRIORITIES
    火炬.工具.数据.图设置.应用分片(
        数据管道, 总工作进程数, 全局工作进程 ID
    )
    如果 工作进程初始化函数  not :
        worker_init_fn(worker_id)


def _share_dist_seed(生成器, pg):
    _shared_seed = 火炬.空的((), 数据类型=火炬.int64).随机(生成器=生成器)
    如果 isinstance(pg, 距离.流程组):
        距离.广播(_shared_seed, =0, 群组=pg)
    返回 _shared_seed.项目()


[文档] 数据加载器(通用[_T_co)] r"" 数据加载器将数据集和采样器结合,并为给定数据集提供可迭代的对象。 `torch.utils.data.DataLoader` 支持映射样式和可迭代样式的数据集,具有单进程或多进程加载,可自定义加载顺序以及可选的自动批处理(收集)和内存固定。 加载,自定义加载顺序以及可选的自动批处理(收集)和内存固定。 加载顺序以及可选的自动批处理(收集)和内存固定。 请参阅:py:mod:`torch.utils.data` 文档页面以获取更多详细信息。 参数: dataset (Dataset):从其中加载数据的数据集。 batch_size (int, 可选):每次加载的批次样本数量 (默认:`1`)。 shuffle(布尔值,可选):设置为 ``True`` 以在每次 epoch(默认:``False``)时重新排列数据 在每个 epoch(默认:``False``)时进行采样。 sampler(采样器或可迭代对象,可选):定义从数据集中抽取样本的策略。可以是任何具有 ``__len__`` 的 ``Iterable`` 对象 可以是任何具有 ``__len__`` 的 ``Iterable`` 对象,用于从数据集中抽取样本。 实现。如果指定了,则不能指定:attr:`shuffle`。 批量采样器(采样器或可迭代对象,可选):类似于 :attr:`sampler`,但 每次返回一批索引。与...互斥。 :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler` 和:attr:`drop_last`。 num_workers(int,可选):使用多少个子进程来加载数据。 ``0`` 表示数据将在主进程中加载。 (默认:``0``) collate_fn (Callable, 可选): 合并一系列样本 小批量张量。在从批量加载时使用。 地图样式数据集。 pin_memory(布尔值,可选):如果为 ``True``,数据加载器将复制张量 在返回之前将数据元素放入设备/CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的:attr:`collate_fn`返回的批次是自定义类型,请参阅下面的示例。 如果您的数据元素是自定义类型,或者您的:attr:`collate_fn`返回的批次是自定义类型,请参阅下面的示例。 请参阅下面的示例。 drop_last (布尔值,可选): 将其设置为 ``True`` 以丢弃最后一个不完整的批次, 如果数据集大小不能被批处理大小整除。如果 ``False``, 那么数据集的大小不能被批处理大小整除,则最后一个批次 的大小将更小。(默认:``False``) 超时时间(数值,可选):如果为正数,则为收集一个批次的超时时间 来自工人。始终应为非负数。(默认:``0``) worker_init_fn(Callable,可选):如果不为 ``None``,则将在每个 工人子进程上调用,使用工人 ID(一个在 ``[0, num_workers - 1]`` 范围内的 int)作为 输入,在播种和加载数据之前。(默认:``None``) multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选): 如果 ``None`,您的操作系统的默认 `multiprocessing context`_ 将 被使用。(默认:`None`) 生成器(torch.Generator,可选):如果不为 ``None``,则使用此随机数生成器 通过 RandomSampler 生成随机索引和通过多进程生成 `base_seed` 用于工作者。(默认:`None`) prefetch_factor(int,可选,关键字参数):预取批次数 提前由每个工人完成。``2``表示总共 2 * num_workers 个批次的预取数据跨所有工作者。默认值取决于 num_workers 的设置值。如果 num_workers=0,则默认为 ``None``。 如果 num_workers 的值为 0,则默认为 ``None``。 否则,如果 num_workers 的值大于 0,则默认为 ``2``)。 persistent_workers(布尔值,可选):如果设置为 ``True``,数据加载器将不会关闭 数据集被消耗一次后的工作进程。这允许 维护工作进程的 `Dataset` 实例存活。(默认:``False``) pin_memory_device(str,可选):如果 ``pin_memory`` 为 ``True``,则将其固定在的设备上 。如果没有提供,则当前 :ref:`加速器` 将是 默认。此参数不建议使用,并可能被弃用。 in_order (bool, optional): 如果 ``False``,数据加载器将不会强制执行批次以先进先出顺序返回。仅在 ``num_workers > 0`` 时适用。(默认:``True``) ..警告:: 如果使用 ``spawn`` 启动方法,:attr:`worker_init_fn` ..警告:: 如果使用 ``spawn`` 启动方法,:attr:`worker_init_fn` 无法是一个不可序列化的对象,例如 lambda 函数。详见 ref:`multiprocessing-best-practices` 了解更多与 PyTorch 中多进程相关的细节。 有关。 ..警告:: ``len(dataloader)`` 指数基于所使用的采样器的长度。 当 :attr:`dataset` 是一个 :class:`~torch.utils.data.IterableDataset` 时, 它将返回一个基于 ``len(dataset) / batch_size`` 的估计值,根据 :attr:`drop_last` 进行适当的舍入, 无论多进程加载配置如何,这代表了 PyTorch 能做出的最佳猜测。 这代表了 PyTorch 能做出的最佳猜测,因为 PyTorch 信任用户正确处理多进程:attr:`数据集`代码,以避免重复数据。 加载以避免重复数据。 然而,如果分片导致多个工作进程拥有不完整的最后批次, 这个估计仍然可能不准确,因为(1)一个本应完整的批次可以 可以拆分成多个,并且(2)多个批次值的样本也可以被 当设置 :attr:`drop_last` 时会被丢弃。不幸的是,PyTorch 通常无法检测到这种情况。 的情况。 请参阅 `数据集类型`_ 了解这两种类型的数据集的更多详细信息以及如何 `torch.utils.data.IterableDataset` 类的交互 多进程数据加载_ 请参阅 :ref:`可重现性`, 以及 :ref:`dataloader-workers-random-seed`, 和 :ref:`数据加载随机性` 相关的注意事项。 关于随机种子相关问题的说明,请参阅 :ref:`data-loading-randomness`。 ..警告:将`in_order`设置为`False`可能会损害可重复性,并可能导致数据偏差 输入数据分布到训练器中的不平衡数据情况。 .. _多进程上下文: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods "源代码" 数据集: 数据集[_T_co] 批处理大小: 可选[int] num_workers: 整型 持久化内存: 布尔类型 drop_last: 布尔类型 超时: 浮点数 sampler: 联盟[Sampler, 迭代器] pin_memory_device: 字符串 预取因子: 可选[int] _迭代器: 可选["_基础数据加载迭代器"] __已初始化 = def 初始化( self, 数据集: 数据集[_T_co] 批大小: 可选[int] = 1, 打乱顺序: 可选[布尔] = , 样本器: 联盟[样本器, 迭代器, ] = , 批量样本器: 联盟[样本器[列表] 迭代器[列表] ] = , num_workers: 整型 = 0, collate_fn: 可选[_collate_fn_t] = , 持久化内存: 布尔类型 = 错误, drop_last: 布尔类型 = 错误, 超时: 浮点数 = 0, worker_init_fn: 可选[_worker_init_fn_t] = , 多进程上下文=, 生成器=, *, 预取因子: 可选[int] = , 持久化工作者: 布尔类型 = 错误, 内存设备: 字符串 = "", 按顺序: 布尔类型 = True, ): 火炬._C._log_api_usage_once("python.data_loader") 如果 num_workers < 0: 提升 ValueError( num_workers 选项应为非负数; "使用 num_workers=0 来禁用多进程。" ) 如果 超时 < 0: 提升 ValueError("超时选项应该是非负数。") 如果 num_workers == 0 预取因子 not : 提升 ValueError( "prefetch_factor 选项只能在多进程模式下指定。" "请设置 num_workers > 0 以启用多进程,否则将 prefetch_factor 设置为 None。" ) elif num_workers > 0 prefetch_factor : prefetch_factor = 2 elif 预取因子 not 预取因子 < 0: 提升 ValueError("预取因子选项应为非负") 如果 持久工作进程 num_workers == 0: 提升 ValueError("持久工作进程选项需要 num_workers > 0") self.数据集 = 数据集 self.num_workers = num_workers self.预取因子 = 预取因子 self.内存映射 = 内存映射 self.内存映射设备 = 内存映射设备 self.超时 = 超时 self.worker 初始化函数 = worker 初始化函数 self.多进程上下文 = 多进程上下文 self.按顺序 = 按顺序 # 添加向前兼容性,以便经典 DataLoader 可以与 DataPipes 一起工作: # _DataPipeSerializationWrapper 容器使得在不重新定义 pickler 的情况下进行序列化变得更加容易 如果 isinstance(self.数据集, 迭代数据处理管道): self.数据集 = _迭代数据管道序列化包装器(self.数据集) elif isinstance(self.数据集, 地图数据管道): self.数据集 = _地图数据管道序列化包装器(self.数据集) 在检查采样器之前先检查相关数据集,因为我们希望 告诉用户可迭代式数据集与自定义采样器不兼容, 因此,他们不会在修复自定义采样器错误后才发现这种组合不起作用。 这会浪费他们修复自定义采样器错误的时间。 如果 isinstance(数据集, 可迭代数据集): self._数据集类型 = _DatasetKind.迭代器 # [自定义采样器和 IterableDataset] # # `IterableDataset` 不支持自定义 `batch_sampler` 或 # `sampler` 因为键无关紧要(除非我们有一天支持 # 生成器风格的 dataset...)。 # 对于`sampler`,我们始终创建一个虚拟采样器。这是一个 # 无限采样器,即使数据集可能已实现 # finite `__len__` 因为在多进程数据加载中,原始 设置将返回重复数据(可能是有意的) 因此使用与数据集长度匹配的采样器会导致数据丢失(你可能会看到前几批次的重复,但之后永远不会看到任何东西)。因此,`Iterabledataset` 总是使用无限采样器,即无限采样器的一个实例。 # cause data lost (you may have duplicates of the first couple # batches, but never see anything afterwards). Therefore, # `Iterabledataset` always uses an infinite sampler, an instance of # 上文定义了 `_InfiniteConstantSampler`。 # # 自定义 `batch_sampler` 实质上仅控制批大小。 然而,由于不清楚它会有多有用,因此以可迭代方式 数据集本身可以处理。此外,这样做毫无意义。 在多进程数据加载中,批次的分配顺序是一个实现细节,因此用户无法控制 因此我们禁用此选项。如果将来证明这个选项很有用,我们可以重新启用 因此我们禁用此选项。如果将来证明这个选项很有用,我们可以重新启用 因此我们禁用此选项。如果将来证明这个选项很有用,我们可以重新启用 # 这,并支持指定分配的自定义采样器。 # 指定特定工作者。 如果 isinstance(数据集, 迭代数据管道): 如果 打乱 not : 数据集 = 火炬.工具.数据.图形设置.应用打乱设置( 数据集, 打乱=打乱 ) # 我们不能在这里检查 `shuffle is not None`,因为之前 `shuffle=False` 是默认值。 elif 打乱 not {错误, }: 提升 ValueError( fDataLoader 与 IterableDataset:期望未指定打乱选项,但得到 shuffle={打乱}" ) 如果 样本器 not : #参见 NOTE [自定义样本器和 IterableDataset] 提升 ValueError( fDataLoader 与 IterableDataset:期望未指定采样器选项,但获取到 sampler={采样器}" ) elif 批量采样器 not : #参见笔记[自定义采样器和 IterableDataset] 提升 ValueError( DataLoader 与 IterableDataset:期望未指定 fbatch_sampler 选项,但得到 batch_sampler={batch_sampler}" ) 否则: 打乱顺序 = 布尔(打乱) self.数据集类型 = _数据集类型.地图 如果 样本 not None 随机播放: 提升 ValueError("样本选项与随机播放互斥") 如果 批量样本器 not : 使用自定义批采样器进行自动对齐 如果 批处理大小 != 1 打乱顺序 样本器 not None 丢弃最后: 提升 ValueError( "batch_sampler 选项与 batch_size、shuffle、sampler 和 drop_last 互斥" "以及 batch_size、shuffle、sampler 和 " "drop_last" ) 批处理大小 = None drop_last = elif 批处理大小 : 无自动对齐 如果 drop_last: 提升 ValueError( "batch_size=None 选项禁用自动批处理" "且与 drop_last 互斥" ) 如果 样本器 : # 提供默认样本器 如果 self._数据集类型 == _数据集类型.迭代器: # 注意 [自定义采样器和 IterableDataset] 采样器 = _无限常量采样器() 否则: # 地图样式 如果 打乱: 样本器 = 随机样本器(数据集, 生成器=生成器) # type: ignore[arg-type] 否则: 样本器 = 顺序样本器(数据集) # type: ignore[arg-type] 如果 批处理大小 not None 批处理样本器 : 自动归一化无需自定义批量采样器 批量采样器 = 批量采样器(采样器, 批处理大小, 丢弃最后) self.批处理大小 = 批处理大小 self.删除最后 = 删除最后 self.样本器 = 样本器 self.批量样本器 = 批量样本器 self.生成器 = 生成器 如果 合并函数 : 如果 self._自动合并: collate_fn = _工具.collate.default_collate 否则: collate_fn = _工具.汇总.默认转换 self.汇总函数 = 汇总函数 self.持久化工作者 = 持久化工作者 self.__已初始化 = 真实 self._IterableDataset_len_called = ( None # 查看 NOTE [ IterableDataset 和 __len__ ] ) self.迭代器 = None self.检查工作数量合理性() 火炬.设置生命体征(数据加载器, 启用, "真") # 类型: 忽略[attr-defined] def _获取迭代器(self) -> "_基础数据加载迭代器": 如果 self.num_workers == 0: 返回 _单进程数据加载迭代器(self) 否则: self.检查工作数量合理性() 返回 _多进程数据加载迭代器(self) @property def 多进程上下文(self): 返回 self.__多进程上下文 @multiprocessing_context.设置器 def 多进程上下文(self, 多进程上下文): 如果 多进程上下文 not : 如果 self.num_workers > 0: 如果 isinstance(多进程上下文, 字符串): 有效的启动方法 = 火炬.多进程.获取所有启动方法() 如果 多进程上下文 not 有效的启动方法: 提升 ValueError( "multiprocessing_context 选项 " f应指定一个有效的启动方法{有效的启动方法!r},但实际上得到了" fmultiprocessing_context={multiprocessing_context!r}" ) multiprocessing_context = 火炬.多进程.获取上下文( multiprocessing_context ) 如果 not isinstance( 多进程上下文, Python 多进程.上下文.基础上下文 ): 提升 类型错误( "多进程上下文选项应是一个有效的上下文" "对象或指定启动方法的字符串,但得到了" f"multiprocessing_context="{multiprocessing_context}" ) 否则: 提升 ValueError( "multiprocessing_context 只能与" "多进程加载(num_workers > 0),但得到" f"num_workers="{self.num_workers}" ) self.__多进程上下文 = 多进程上下文 def __setattr__(self, 属性, val): 如果 self.__已初始化 属性 ( 批处理大小, 批采样器, 采样器, 最后一个丢弃, "数据集", "持久化工作者", ): 提升 ValueError( f"{属性}属性设置后不应更改{self..__name__}已初始化 ) 超级().__setattr__(属性, val) # 我们引用 '_BaseDataLoaderIter',因为它尚未定义,且定义不能上移 # 因为 '_BaseDataLoaderIter' 引用了 'DataLoader' def __iter__(self) -> "_BaseDataLoaderIter": 当使用单个工作线程时,返回的迭代器应该是 每次创建以避免重置其状态 然而,在多个工作者迭代器的情况下 迭代器在其生命周期中只创建一次 DataLoader 对象,以便工作者可以被重用 如果 self.持久工作者 self.num_workers > 0: 如果 self.迭代器 : self.迭代器 = self.获取迭代器() 否则: self.迭代器.重置(self) 返回 self.迭代器 否则: 返回 self.获取迭代器() @property def 自动归一化(self): 返回 self.批处理采样器 not None @property def _index_sampler(self): # 实际用于为 `_DatasetFetcher` 生成索引的采样器 # (请参阅 _utils/fetch.py) 在每次读取数据时。这将是 `.batch_sampler` 如果处于自动合并模式,否则为 `.sampler`。 我们不能更改 `.sampler` 和 `.batch_sampler` 属性以保持向后兼容。 原因。 如果 self._auto_collation: 返回 self.批处理采样器 否则: 返回 self.采样器 def __len__(self) -> int: 如果 self.数据集类型 == _DatasetKind.迭代器: # NOTE [ IterableDataset 和 __len__ ] # # 对于 `IterableDataset`,`__len__` 可能不准确,当直接 多进程加载数据,因为样本将被重复。 然而,没有任何实际用例应该真正使用这种行为,所以 它应该被视为用户错误。我们通常应该信任用户 代码做正确的事情(例如,为每个副本配置不同的设置) # in `__iter__`), 并提供正确的 `__len__` 如果他们选择这样做 # 实现(如果数据集没有实现,这仍然会抛出异常) # 为了提供进一步的警告,我们跟踪是否调用了 `__len__` # # ,如果 `__len__` 被调用,我们将提供额外的信息 # `DataLoader`,将返回值保存到`self._len_called`中,并警告 # 如果迭代器最终产生了超过这个数量的样本。 # 无法静态验证数据集是否为 Sized 长度 = self._IterableDataset_len_called = 长度(self.数据集) # type: ignore[assignment, arg-type] 如果 ( self.批处理大小 not None ): # IterableDataset 不允许自定义采样器或批量采样器 来自 数学 导入 向上取整 如果 self.drop_last: 长度 = 长度 // self.批处理大小 否则: 长度 = 向上取整(长度 / self.批量大小) 返回 长度 否则: 返回 长度(self._index_sampler) def 检查工作进程数量合理性(self): # 此函数检查根据当前系统资源,数据加载器的工作进程数量是否合理。当前规则是,如果工作进程的数量为 # 当前系统资源的倍数,则认为该数量是合理的。 # 加载数据加载器创建的进程数大于允许的逻辑 CPU 数时 # 我们将弹出警告提示用户注意。 # # 例如,如果当前系统有 2 个物理 CPU,每个 CPU 有 16 个核心。并且每个核心支持 2 # 个线程,那么这里的总逻辑 CPU 数就是 2 * 16 * 2 = 64。假设当前 DataLoader 进程可以使用其中的一半,即 32 个,然后从该进程启动的 worker 的最大合理数量是 32。 现在,假设创建的 DataLoader 具有 num_works = 40,这个数字大于 32。 因此,会触发警告信息来通知用户降低 worker 的数量。 所以会触发警告信息来通知用户降低 worker 的数量。 # 必要的。 # # # [注意] 请注意,此函数仅在 os.sched_getaffinity 可用时才尊重 `cpuset`。 # 它在大多数 Linux 系统中可用,但在 OSX 和 Windows 中不可用。 # 当 os.sched_getaffinity 不可用时,将调用 os.cpu_count(),但它不尊重 cpuset。 # 它不尊重 cpuset。 我们不考虑线程,因为每个工作进程是单线程的 # 此时。 # 我们不设置任何线程标志(例如 OMP_NUM_THREADS, MKL_NUM_THREADS 等) 除了在工作进程中将`torch.set_num_threads`设置为 1 之外,如果传递 在函数中使用依赖于这些线程标志的第三方模块 创建多少个线程(例如,numpy 等),然后是调用者的责任 正确设置这些标志。 def 创建警告信息(建议的 worker 数量, 已创建的 worker 数量, 检查了 cpuset): 建议的最大 worker 消息量 = ( ( ( 我们建议当前系统中最大工人数为{}{}, 是更小的 " 比这个 DataLoader 将要创建的内容。 ).格式( 建议工作进程数, ( 请提供需要翻译的文本 如果 cpuset 已检查 否则 "(`cpuset`不被考虑)" ), ) ) 如果 建议工作进程数 not None 否则 ( 数据加载器无法计算当前系统建议的最大工作进程数。 ) ) 警告信息 = ( f"此数据加载器将创建"{创建的进程数}总共的 worker 进程。{建议的最大 worker 消息数} " "请注意,过度创建工作进程可能会导致 DataLoader 运行缓慢甚至冻结," "如果需要,请降低工作进程数量以避免潜在的缓慢/冻结。" ) 返回 警告信息 如果 not self.num_workers self.num_workers == 0: 返回 尝试根据系统资源计算建议的最大工作线程数 最大工作线程数建议 = None 检查 CPU 集 = 如果 有属性(操作系统, "获取调度亲和性"): try: 最大工作线程建议 = 长度(操作系统.sched_getaffinity(0)) cpuset_checked = 真实 除了 异常: 通过 如果 max_num_worker_suggest : # os.cpu_count() 可能返回 Optional[int] 首先获取 CPU 数量并检查是否为 None 以满足 mypy 检查 cpu_count = 操作系统.cpu_count() 如果 cpu_count not : 最大工作线程建议 = cpu_count 如果 max_num_worker_suggest : warnings.警告( 创建警告信息( 最大工作进程数建议, self.num_workers, CPU 集检查 ) ) 返回 如果 self.num_workers > 最大工作进程数建议: warnings.警告( 创建警告信息( 最大工作线程数建议, self.num_workers, cpuset 已检查 ) )
_基础数据加载迭代器: def 初始化(self, 加载器: 数据加载器) -> : self._数据集 = 加载器.数据集 self.共享种子 = None self._pg = None 如果 isinstance(self._数据集, 迭代数据处理管道): 如果 距离.是否可用() 距离.已初始化(): self._pg = 距离.创建新组(后端=gluu) self.共享种子 = 分享分布种子(加载器.生成器, self._pg) shared_rng = 火炬.生成器() shared_rng.手动播种(self._shared_seed) self.数据集 = 火炬.工具.数据.图设置.应用随机种子( self.数据集, 共享随机数生成器 ) self.数据集类型 = 加载器.数据集类型 self._IterableDataset_len_called = 加载器._IterableDataset_len_called self.自动分词 = 加载器._自动对齐 self._丢弃最后一项 = 加载器.删除最后 self._index_sampler = 加载器._index_sampler self._num_workers = 加载器.num_workers ws, 排名 = 获取分布式设置() self.世界大小 = ws self._排名 = 排名 如果未设置 pin_memory_device,则默认行为为当前加速器。 如果 pin_memory_device 已设置但 pin_memory 未设置,则默认 # 行为关闭。 如果 长度(加载器.pin_memory_device) == 0: 如果 加载器.内存映射 not 火炬.加速器.是否可用(): 警告信息 = ( "将 'pin_memory' 参数设置为 true,但未找到加速器" 然后设备固定内存将不会被使用。 ) warnings.警告(warn_msg) self._pin_memory = 加载器.内存映射 火炬.加速器.是否可用() self._pin_memory_device = None 目前,pin_memory 在 MPS 后端会引发错误(参见 # https://github.com/pytorch/pytorch/issues/86060),因此强制 # 禁用 MPS 上的 pin_memory。一旦固定,请移除此限制 # MPS 的内存分配是固定的。 如果 ( self._内存映射 (acc := 火炬.加速器.当前加速器()) not None acc.类型 == mps ): self._pin_memory = 警告信息 = ( "'pin_memory' 参数设置为 true,但目前 MPS 不支持," "则不会使用设备固定内存。" ) warnings.警告(warn_msg) 否则: 如果 not 加载器.持久化内存: 警告信息 = ( "已设置 'pin_memory_device' 但未设置 'pin_memory' 参数," "设备固定内存将不会被使用。" "如果需要使用设备固定内存,请将 'pin_memory' 设置为 true" ) warnings.警告(warn_msg) self.固定内存 = 加载器.内存映射 self.固定内存设备 = 加载器.内存映射设备 self._超时 = 加载器.超时 self._collate_fn = 加载器.整理函数 self._sampler_iter = 迭代(self._index_sampler) self._base_seed = ( 火炬.空的((), 数据类型=火炬.int64) .随机(生成器=加载器.生成器) .项目() ) self._persistent_workers = 加载器.持久化工作者 self._num_yielded = 0 self._profile_name = f"enumerate(DataLoader)#{self..__name__}.__next__ def __iter__(self) -> "_基础数据加载迭代器": 返回 self def 重置(self, 加载器, 第一次迭代=错误): self._采样迭代 = 迭代(self._索引采样) self._已产出数量 = 0 self._IterableDataset_len_called = 加载器._IterableDataset_len_called 如果 isinstance(self.数据集, 迭代数据处理管道): self.共享种子 = 分享分布种子(加载器.生成器, self._ pg) 共享随机数生成器 = 火炬.生成器() 共享随机数生成器.手动播种(self._共享种子) self._数据集 = 火炬.工具.数据.图设置.应用随机种子( self.数据集, 共享随机数生成器 ) def 下一个索引(self): 返回 下一(self.样本迭代器) 可能引发 StopIteration def _下一数据(self): 提升 未实现异常 def __下一__(self) -> 任意: 火炬.自动微分.分析器.记录功能(self._用户名): 如果 self._sampler_iter : # TODO(https://github.com/pytorch/pytorch/issues/76750) self.重置() 忽略调用参数 数据 = self._next_data() self._num_yielded += 1 如果 ( self.数据集类型 == _DatasetKind.迭代器 self._IterableDataset_len_called not None self._num_yielded > self._IterableDataset_len_called ): 警告信息 = ( fIterableDataset 的长度{self._dataset}被报告为是{self._可迭代数据集_len_被调用}" f"(当访问 len(dataloader))时,但"{self._已获取的样本数"}已获取样本。 ) 如果 self._num_workers > 0: 警告信息 += ( 对于多进程数据加载,这可能是由于未正确配置“ "每个工作器上的可迭代数据集副本。请参阅" "https://maskerprc.github.io/docs/stable/data.html#torch.utils.data.IterableDataset 以获取示例。" ) warnings.警告(warn_msg) 返回 数据 def __len__(self) -> int: 返回 长度(self._index_sampler) def __getstate__(self): # TODO: 为共享迭代器添加有限的序列化支持 # 以跨多个线程用于 HOGWILD。 可能是这样做最好的方法,就是移动样本推送 将数据队列分离到单独的线程,然后仅共享数据队列 但是,在没有非阻塞 API 的情况下发出结束信号很棘手 提升 不支持的操作异常("{}无法序列化, self..__name__) _单进程数据加载迭代器(_基础数据加载迭代器): def 初始化(self, 加载器): 超级().初始化(加载器) 断言 self._超时 == 0 断言 self._num_workers == 0 # 添加向前兼容性,以便经典 DataLoader 可以与 DataPipes 一起工作: # 分布式分片处理 如果 isinstance(self._数据集, (迭代数据处理管道, 地图数据管道)): # 对于 BC,使用默认的 SHARDING_PRIORITIES 火炬.工具.数据.图设置.应用分片( self._数据集, self.世界大小, self._排名 ) self.数据集获取器 = _DatasetKind.创建获取器( self.数据集类型, self._数据集, self._自动合并, self._collate_fn, self._drop_last, ) def _next_data(self): 索引 = self._next_index() # 可能引发 StopIteration 异常 数据 = self._dataset_fetcher 数据集获取器.获取(索引) # 可能引发 StopIteration 异常 如果 self._pin_memory: 数据 = _工具.持久化内存.持久化内存(数据, self._pin_memory_device) 返回 数据 _多进程数据加载迭代器(_基础数据加载迭代器): r"""遍历一次 DataLoader 的数据集,由采样器指定。""" # 数据加载器多进程关闭逻辑 # # 初步: # # 我们的数据模型如下(队列用花括号表示): # # 主进程 || # | || # {index_queue} || # | || # worker processes || DATA # | || # {worker_result_queue} || FLOW # | || # pin_memory_thread of main process || DIRECTION # | || # {数据队列} || # | || # 数据输出 \/ # # P.S. 如果 `worker_result_queue` 和 `pin_memory_thread` 部分可以省略,则 # `pin_memory=False`。 # # # 终止多进程逻辑需要非常谨慎的设计。特别是,我们需要确保 # # 当迭代器的最后一个引用消失或耗尽时,它将优雅地退出工作者。 在这种情况下,应该优雅地退出工作者,因为主进程可能仍然需要继续运行,而我们希望进行清理。 # 在此情况下,工作者应该被优雅地退出,因为主进程可能仍然需要继续运行,我们希望进行清理。 在这种情况下,应该优雅地退出工作者,因为主进程可能仍然需要继续运行,我们希望进行清理。 # 在工作线程中执行释放 GPU 内存等操作代码。 # 自然地,我们在`__del__`中实现了关闭逻辑。 # DataLoaderIterator。 # # 我们将对此处的逻辑讨论推迟到以后。 # 2. 迭代器在加载进程和/或工作进程退出时终止 进程正常退出或出现错误。 # 我们将所有工作者和`pin_memory_thread`设置为`daemon=True`。 # 您可能会问,为什么我们不能让工作进程非守护进程化, 优雅地退出,使用与我们在 `__del__` 中相同的逻辑 迭代器被删除了(参见上面第 1 点)? # 首先,`__del__` 并不保证在调用时会被调用 # 解释器退出。即使被调用,在它执行时, 许多 Python 核心库资源可能已经被释放,甚至 简单的事情,比如获取队列的内部锁,也可能挂起。 因此,在这种情况下,我们实际上需要防止`__del__`的执行, 并依赖于守护进程的自动终止。 # 子代。 # 因此,我们注册了一个`atexit`钩子,该钩子设置了一个全局标志 `# `_utils.python_exit_status`。由于 `atexit` 钩子在 注册顺序的逆序,我们保证这个标志是 设置在释放我们使用的库资源之前(至少在 CPython,是通过在 `atexit` 处理器中定义完成的 `multiprocessing/util.py` https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 当首次注册需要此机制的对象时 # 创建,例如,`mp.Queue` # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 # ) # # 因此在 `__del__` 中,我们检查 `_utils.python_exit_status` 是否已设置 # `None`(释放),如果为空则执行无操作。 # # 然而,仅仅让库的清理代码运行也可能不好, # 因为这样的代码(例如,`multiprocessing.util._exit_function()`) # 包括为`mp.Queue`的线程加入操作,这可能会导致阻塞。 因此,创建线程的主要过程被调用为 # 在创建时取消加入线程。参见后续章节 # [ 3b. 将进程放入队列时不会挂起; ] # 查看更多详情。 # 这里是两个库清理代码可以运行的示例案例 在调用 `__del__` 之前: # 1. 如果我们保留对迭代器的引用,它更常 # 不如尝试在 `multiprocessing` 库清理之前 清除所有活跃的引用对象(https://github.com/pytorch/pytorch/issues/48666) 因此阻止了我们的清理代码首先运行。 # 2. 当在子进程中使用 `DataLoader` 时,也会出现类似的问题。 当进程结束时,它会关闭所有其守护进程子进程。 # 将它们(而不是在超时前)一起退出。 # 对于线程来说,情况类似,但机制不同。这个事实, # 以及多进程的一些实现细节,迫使我们让工作进程成为守护进程。所有的问题都源于当 # 我们的工作进程在某个时候退出时。 # 数据加载器在子进程中使用,由多进程引起 # 看起来大致是这样的代码: # 尝试: # 使用数据加载器调用你的函数() finally: # multiprocessing.util._退出函数() # 上述提及的加入/终止发生在内部 # `_退出函数()`. 现在,如果 `your_function_using_a_dataloader()` 抛出异常时,异常中存储的堆栈跟踪将阻止使用 `DataLoaderIter` 的帧被释放。如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。 如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。 如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。 如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。 # 被调用。这反过来意味着工人没有收到通知。尝试 # 加入 `_exit_function` 将导致程序挂起。 # # 为了上下文,`_exit_function` 也被注册为 `atexit` 调用。 我不明白(@ssnl)为什么需要在 finally 块中这样做。 这段代码可以追溯到 2008 年,原始代码中没有任何注释。 PEP 371 或补丁 https://bugs.python.org/issue3050(包含 finally 块和`atexit`注册)对此进行了解释。 这解释了为什么 finally 块和`atexit`注册需要同时存在。 # # 最后,另一个选择是当我们在`next`中看到错误时,仅通过逻辑关闭工作者。这并不理想,因为 a. 它阻止用户使用 try-catch 来恢复数据加载。 b. 如果用户持有引用,则无法防止挂起。 c. 它不会阻止用户在`next`中看到错误时使用 try-catch 来恢复数据加载。 迭代器。 # 如果任何进程意外地因致命信号而死亡,则所有进程都将退出。 # 如上所示,工作进程被设置为父进程的守护进程子进程。 然而,此类子进程的自动清理仅 如果父进程优雅退出(例如,不是通过致命信号如 SIGKILL),则会出现这种情况。因此,我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即, 所以我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即, 所以我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即, 所以我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即, # 在从队列中获取时,进程不会挂起。 # 即使数据依赖关系设计得非常精心(即,`put()` 总是与 `get()` 相对应),当队列中的数据损坏时(例如,由于 ...),在 `get()` 上仍然可能会发生挂起。 当队列中的数据损坏时(例如,由于 ...),在 `get()` 上仍然可能会发生挂起。 当队列中的数据损坏时(例如,由于 ...),在 `get()` 上仍然可能会发生挂起。 “取消加入线程”或意外退出。 # 对于子进程退出,我们每次尝试从`data_queue`获取数据时都会设置超时。 并在每个超时和错误时检查工作进程的状态。 出错。 查看 `_DataLoaderiter._get_batch()` 和 `_DataLoaderiter._try_get_data()` 以获取详细信息。 此外,对于非 Windows 平台上的子进程退出,我们还会注册一个 SIGCHLD 处理器(在 Windows 上受支持)。 # 另外,对于非 Windows 平台上的子进程退出,我们也会在 Windows 上注册一个 SIGCHLD 处理器。 在 Windows 上注册一个 SIGCHLD 处理器。 主流程,用于检查是否有任何工作进程失败, (Python)处理器。与仅使用上述机制相比,这种方法在检测工作进程失败方面更高效、更快。 请参阅`DataLoader.cpp`和`_utils/signal_handling.py`以获取详细信息。 请参阅`DataLoader.cpp`和`_utils/signal_handling.py`以获取详细信息。 # 对于不是工作者的发送者进行的 `.get()` 调用,我们用超时来保护它们,并在超时发生时检查发送者的状态: 在工作者中,使用 `_utils.worker.ManagerWatchdog` 类 当超时发生时: + 在工作者中,使用 `_utils.worker.ManagerWatchdog` 类 检查主进程状态。 # + 如果 `pin_memory=True`,从 `pin_memory_thread` 获取时 定期检查 `pin_memory_thread` 状态 # 返回或查看 `pin_memory_thread` 已死。 # # b. 将进程放入队列时不会挂起; # # 我们使用 `mp.Queue`,它有一个独立的后台线程将 # 对象从无界缓冲数组中放入。后台线程是守护线程, # 通常在进程结束时自动结束。 # *退出*。 # # 如果接收方在读取管道时突然结束,则连接将永远挂起。 # 在 Python 中,解决这个问题通常是通过调用 `q.cancel_join_thread`。 # 来实现的。 这将防止在最终确定时自动将其连接 (退出)。 # 然而,`cancel_join_thread` 只应在队列 **不会**被其他进程读取或写入时调用 # 处理过程,因为它可能持有锁或留下损坏的数据 # 在队列中,导致其他读者/写者挂起。 # 因此, # + 对于工作进程,我们只这样做(对于它们的输出) 在退出前,请确保队列(例如,`worker_result_queue`)已被清空。 对于`pin_memory_thread`,其输出队列`data_queue`是一个`queue.Queue`,如果队列已满,则会进行阻塞`put`操作。 因此,不会出现上述问题,但结果是在 因此,不会出现上述问题,但结果是在 `_pin_memory_loop`,我们确实需要将`put`操作放在循环中 那样不仅会在成功时跳出循环,还会在主进程停止读取时跳出,即正在关闭。 对于加载进程,我们对所有进程调用`cancel_join_thread()` + 对于加载进程,我们为所有进程调用`cancel_join_thread()` `_index_queues` 因为整个工作进程和 `pin_memory_thread` 的作用就是为了服务加载进程。如果加载进程已经正在退出,我们其实并不关心 `pin_memory_thread` 是为了服务加载进程。如果加载进程已经正在退出,我们其实并不关心 加载进程已经正在退出,我们其实并不关心它是否能够成功退出 队列已损坏。 # # 现在让我们回到 1: 如何优雅地退出工作线程,当最后一个引用到 迭代器已消失。 # 实现此目的,我们实现了以下逻辑以及设计 如上所述的选择: # `workers_done_event`: 在主进程和所有工作进程之间共享的 `multiprocessing.Event` # 处理进程。这用于通知工作进程迭代器正在关闭。 # 关闭。设置后,它们将不再向队列发送处理过的数据, # 并且只等待最后的 `None` 才退出。 # `done_event` 并非必需。也就是说,我们只需检查 `None` 即可。 从输入队列中取出,但允许我们跳过浪费资源 正在关闭时处理数据。 # `pin_memory_thread_done_event`: 粘性内存线程完成事件 一个用于类似目的的 `threading.Event` `workers_done_event`,但它是为`pin_memory_thread`准备的。之所以需要单独的事件,是因为`pin_memory_thread`从 工作进程的输出队列中读取。但是,当工作进程看到`workers_done_event`被设置时,它们只想看到最终的`None`, # 但工作进程,在看到`workers_done_event`被设置后,只想看到最终的`None`,并且是 # 是因为`pin_memory_thread`需要读取工作进程的输出队列。但是,当工作进程看到`workers_done_event`被设置时, 不需要刷新输出队列中的所有数据(例如,它可能调用该队列上的 `cancel_join_thread`,如果其 `IterableDataset` 迭代器偶然耗尽,这超出了主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以 `cancel_join_thread` 在该队列上(如果其 `IterableDataset` 迭代器偶然耗尽,这超出了主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以 `IterableDataset` 迭代器偶然耗尽,这超出了主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以 主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以 # 工作者(见下文),使用两个独立的事件。 # # NOTE:简而言之,协议是主进程将设置这些 # `done_event`,然后相应的进程/线程接收到`None`, # 并且它们可以在收到`None`后随时退出。 # # 注意:使用 `None` 作为最终信号是有效的,因为正常数据将 始终是一个包含两个元素的元组,其中第一个元素是数据的索引 # 转移(不同于数据集索引/键),第二个是 # 数据集键或数据样本(取决于哪一部分) 数据模型队列所在的位置。 # [工作进程] 当加载进程存活时: 从 `index_queue` 获取。 如果获取到其他任何内容, 检查 `workers_done_event`。 如果已设置,则继续到下一个迭代 即,继续获取,直到看到 `None`,然后退出。 否则,处理数据: 如果是从`IterableDataset`中获取并迭代 # 已耗尽,发送 `_IterableDatasetStopIteration` 对象用于表示迭代结束。主进程,当 接收此类对象时,将发送 `None` 到此 # 工作者而不是使用相应的 `index_queue` # 再也不了。 如果超时, 无论是否设置了`workers_done_event`(仍需要看到`None`) 无论是否,都必须继续到下一个迭代。 (循环外部) 如果设置了`workers_done_event`(在`IterableDataset`中这可能是`False`) # `data_queue.cancel_join_thread()`。 (一切都在这里结束:) 主进程不会从中读取; # 其他工作者也将调用 # 取消加入线程。 # # [ 内存固定线程 ] # # 无需检查主线程。如果此线程存活,则主加载 # # 线程必须存活,因为此线程被设置为守护线程。 # 当 `pin_memory_thread_done_event` 事件未设置时: 从 `worker_result_queue` 获取 如果超时,继续在下一次迭代中获取。 否则,处理数据。 当`pin_memory_thread_done_event`未设置时: 将处理过的数据放入`data_queue`(一个带有阻塞 put 的`queue.Queue`) 如果超时,则继续在下一次迭代中输入。 否则,中断,即继续到外部循环。 # 注意:我们不检查主线程的状态,因为 1. 如果进程被致命信号终止,`pin_memory_thread` # 结束。 # 2. 在其他情况下,无论是 __del__ 中的清理还是守护线程的自动退出都将处理它。 # 这也不会忙等待,因为 `.get(timeout)` 不会 # 这也不会忙等待,因为 `.get(timeout)` 不会 # 繁忙等待。 # # [主进程] # 在 DataLoader Iter 的 `__del__` 中 # b. 退出 `pin_memory_thread` # i. 设置 `pin_memory_thread_done_event`。 # ii 在 `worker_result_queue` 中放入 `None`。 # iii. 等待 `pin_memory_thread` 完成。 # iv. 调用 `worker_result_queue.cancel_join_thread()`。 # # c. 退出工作进程。 # i. 设置 `workers_done_event`。 # ii. 在每个工作进程的 `index_queue` 中置为 `None`。 # iii. 等待工作进程结束。 在各个工作者的 `index_queue` 上调用 `.cancel_join_thread()`。 # 注意:(c) 放在 (b) 之后更好,因为它可能会留下损坏 # 工作结果队列中的数据,该数据由 `pin_memory_thread` 从读取,在这种情况下,`pin_memory_thread`只能 尽管在超时发生时会发生,这很慢,但如果是工人在不幸的时刻被信号杀死,情况也相同 但是在其他情况下,我们最好让`worker_result_queue`对`pin_memory_thread`保持非损坏状态 但是在其他情况下,我们最好让`worker_result_queue`对`pin_memory_thread`保持非损坏状态 但是在其他情况下,我们最好让`worker_result_queue`对`pin_memory_thread`保持非损坏状态 # # 如果 `pin_memory=False`,则没有 `pin_memory_thread`,并且(b) # 可以省略 # # NB: `done_event`s 并不是严格必要的。例如,我们可以直接检查 # `index_queue` 中的 `None`,但这允许我们跳过浪费资源 如果我们已经在关闭,则处理 `index_queue` 中已经存在的索引。 def 初始化(self, 加载器): 超级().初始化(加载器) self._prefetch_factor = 加载器.预取因子 self._in_order = 加载器.按顺序 断言 self._num_workers > 0 断言 self._prefetch_factor > 0 如果 加载器.多进程上下文 : 多进程上下文 = 火炬.多进程 否则: 多进程上下文 = 加载器.多进程上下文 self._worker_init_fn = 加载器.工作进程初始化函数 # 添加向前兼容性,以便经典 DataLoader 可以与 DataPipes 一起工作: # 额外的 worker 初始化函数将负责在 MP 和分布式中进行分片 如果 isinstance(self._数据集, (迭代数据处理管道, 地图数据管道)): self._worker_init_fn = functools.偏函数( _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self.排名, ) 无法确定 multiprocessing_context 模块 self._工作结果队列 = 多进程上下文.队列() # type: ignore[var-annotated] self._工作进程 ID 集合 = self.关闭 = self.工人完成事件 = 多进程上下文.活动() self._index_queues = [] self._workers = [] i 范围(self._num_workers): # 没有确定性,multiprocessing_context 模块 索引队列 = 多进程上下文.队列() # type: ignore[var-annotated] 这里需要调用`cancel_join_thread`! 请参阅上面的(2)和(3b)部分。 索引队列.取消加入线程() w = 多进程上下文.流程( 目标=_工具.工人.工作循环, 参数=( self.数据集类型, self._数据集, 索引队列, self._worker 结果队列, self._工作者完成事件, self._auto_collation, self._collate_fn, self._drop_last, self._base_seed, self.工人初始化函数, i, self._num_workers, self._持久化工作进程, self._共享种子, ), ) w.守护进程 = 真实 # NB: Process.start() 实际上需要一些时间,因为它需要 # 启动一个进程并通过管道传递参数。 因此,我们只在它启动后将其添加到 self._workers 列表中,这样我们就不需要在程序死亡之前调用.join(),否则__del__尝试 join 时会得到: AssertionError: 只能 join 已启动的进程。 因此,我们只在它启动后将其添加到 self._workers 列表中,这样我们就不需要在程序死亡之前调用.join(),否则__del__尝试 join 时会得到: AssertionError: 只能 join 已启动的进程。 w.开始() self._索引队列.追加(索引队列) self._工作者.追加(w) 如果 self._内存映射: self._内存线程完成事件 = 线程.活动() 队列未进行类型注解 self._数据队列 = 队列.队列() # type: ignore[var-annotated] 当前设备 = -1 如果 self._内存设备 == cuda: 当前设备 = 火炬.cuda.当前设备() elif self._内存设备 == XPU: 当前设备 = 火炬.xpu.当前设备() elif self._内存设备 == 火炬._C._get_privateuse1_backend_name(): 自定义设备模块 = getattr( 火炬, 火炬._C._get_privateuse1_backend_name() ) 当前设备 = 自定义设备模块.当前设备() elif self._内存设备 : 当前设备 = 火炬.加速器.current_device_index() 内存线程 = 线程.Thread( 目标=_工具.持久化内存._内存循环, 参数=( self._工作结果队列, self._data_queue, 当前设备, self._内存线程完成事件, self._内存设备, ), ) 内存线程.守护进程 = 真实 针对内存线程.开始() # 与工作者类似(见上方注释),我们只注册 # 一旦启动,就注册针对内存线程 self._pin_memory_thread = _pin_memory_thread 否则: self._data_queue = self._worker_result_queue # 类型:忽略[赋值] 在某些罕见情况下,持久的工人(守护进程) 在主进程退出之前会被终止 当主进程退出时 这会导致当 pin_memory_thread 尝试读取时失败 工作结果队列中的损坏数据 在主进程退出前正确顺序使用 atexit 关闭线程和子进程 在主进程退出前正确顺序使用 atexit 关闭线程和子进程 如果 self._persistent_workers self._pin_memory: 导入 atexit w self.工人: atexit.注册(_多进程数据加载迭代器._清理工作进程, w) # .pid 可以仅在进程尚未启动时为 None(不是这种情况,所以忽略) _工具.信号处理._设置工作进程 ID(id(self), 元组(w.进程 ID w self._工作者)) # 类型:忽略[杂项] _工具.信号处理.设置 SIGCHLD 处理程序() self.设置工作进程 ID = 真实 self.重置(加载器, 第一次迭代=True) def 重置(self, 加载器, 第一次迭代=错误): 超级().重置(加载器, 第一次迭代) self.发送索引 = 0 # 下一个要发送给工作者的任务索引 self._rcvd_idx = 0 # 下一个在 __next__ 中返回的任务索引 # 尚未产生的数据信息,即索引在 [rcvd_idx, send_idx) 范围内的任务 # 地图:任务索引 => - (工作者 ID,) 如果数据未获取(待处理) # \ (工作者 ID,数据) 如果数据已获取(顺序错误) self._任务信息 = {} self._待处理任务 = ( 0 # always equal to count(v for v in task_info.values() if len(v) == 1) ) # A list of booleans representing whether each worker still has work to # do, i.e., not having exhausted its iterable dataset object. It always # contains all `True`s if not using an iterable-style dataset # (即,如果 kind 不等于可迭代)。 # 并不意味着这个工作者在这个 epoch 中还有工作要做。 # 这并不意味着工作者已经死亡。在`_persistent_workers`的情况下, # 工作者将在下一个 epoch 中被重置为可用状态。 self._workers_status 工人状态 = [真实 i 范围(self._num_workers)] # 每个工作者未完成的任务数量列表 # 当任务分配给工作者时递增 当那些数据已提供给主线程时递减 每个工作线程最多应有 self._prefetch_factor 个挂起的任务 self.工人_num_tasks = [0 i 范围(self._num_workers] 重置工作队列循环,以便在下一个 epoch 从工作器 0 开始 self._worker_queue_idx_cycle = itertools.循环(范围(self._num_workers)) # 我们在启用的情况下继续预取 如果 not 第一次迭代: 索引 范围(self._num_workers): self._index_queues[索引].放置( _工具.工人._简历迭代(self._共享种子) ) 简历迭代次数 = self._num_workers while 简历迭代次数 > 0: 返回索引, 返回数据 = self._获取数据() 如果 isinstance(返回索引, _工具.工人._恢复迭代): 断言 返回数据 None 简历迭代次数 -= 1 预先启动预取循环 _ 范围(self._prefetch_factor * self._num_workers): self.尝试设置索引() def 尝试获取数据(self, 超时=_工具.MP 状态检查间隔): 尝试在给定超时时间内从`self._data_queue`获取数据一次。 这也可以用作无超时获取的内部循环。 循环条件为发送者状态。 # 如果任何工作进程意外死亡,将引发 `RuntimeError`。此错误 # 可来自 `_utils/signal_handling.py` 中的 SIGCHLD 处理程序 #(仅适用于非 Windows 平台),或以下手动检查错误 # 和超时。 # 返回一个 2 元组: (bool:是否成功获取数据,any:成功获取数据时的数据,否则为 None) try: 数据 = self._数据队列.获取(超时=超时) 返回 (True, 数据) 除了 异常 作为 e: 在超时和错误情况下,我们手动检查是否有任何工作进程 # 失败。注意,这是 Windows 检测的唯一机制 # 工作失败。 failed_workers = [] worker_id, w 列举(self.工人): 如果 self.工人状态[worker_id] not w.是否存活(): 失败的工人.追加(w) self.标记工人为不可用(worker_id) 如果 长度(失败的工人) > 0: 进程 ID 字符串 = “,”.加入(字符串(w.进程 ID) w 失败的工人) 提升 运行时错误( f"数据加载器工作进程(进程 ID("{pids_str}程序异常退出" ) 来自 e 如果 isinstance(e, 队列.空的): 返回 (错误, ) 导入 errno 导入 tempfile try: 如果接近文件描述符限制,则抛出异常。 显然,仅尝试打开一个文件是不够的 测试。 查看[ DataLoader 在 Linux 上和打开文件限制 ] fds_limit_margin = 10 [临时文件.命名临时文件() i 范围(fds_limit_margin] 除了 OSError 作为 e: 如果 e.errno == 错误号.文件描述过多: 提升 运行时错误( 文件打开太多。与“的”通信 工人不再可能。请增加“” "在 shell 中使用`ulimit -n`限制或更改" "通过调用" " `torch.multiprocessing.set_sharing_strategy('file_system')`" "在代码开头" ) 来自 None 提升 # 注意 [Linux 上的 DataLoader 和打开文件限制] # 在 Linux 系统中,当使用多进程与 DataLoader 一起使用时,我们传递数据之间 通过 SHM 文件管理根进程和工作者。我们删除这些文件。 文件系统一旦创建就立即加载,并通过它们保持活跃 通过 AF_UNIX 套接字传递文件描述符。(参见 文档/源/multiprocessing.rst 和维基百科中的'多进程技术笔记'。) (https://github.com/pytorch/pytorch/wiki)。) # 这有时会导致我们超出打开文件的限制。当这种情况发生时, 犯罪文件描述符是通过套接字传来的,Python 的 `socket` # 包无声地移除了文件描述符,只设置了 `MSG_CTRUNC` 标志(可能有点误导,因为手册说明) # 表示由于空间不足而丢弃了一些控制数据 辅助数据缓冲区)。这可能会反映 C 实现的 # 阿福 UNIX 套接字。 # # 这种行为可以通过下面的脚本和说明进行重现。 # 请在笔记底部查看。 # # 当发生这种情况时,标准的 Python `multiprocessing`(而不是) # `torch.multiprocessing`引发`RuntimeError: 收到 0 个 ancdata`项 # # 有时,FD 没有被剥离,你可能会得到一个`OSError:` # 在下面的脚本和 DataLoader 中,你可能会得到一个`Too many open files`,然而 # 这很罕见,并且似乎是非确定性的。 # # # #!/usr/bin/env python3 # import sys # import socket # import os # 导入数组 # 导入 shutil 模块 # 导入 socket 模块 # # # 如果 sys.argv 的长度不等于 4: 打印("用法: ", sys.argv[0], " tmp_dirname 迭代 (发送|接收)") sys.exit(1) # 如果 __name__ == '__main__': dirname = sys.argv[1] # sock_path = dirname + "/sock" # iterations = int(sys.argv[2]) # def dummy_path(i): # return dirname + "/" + str(i) + ".dummy" # # # 如果 sys.argv[3] == 'send': # 当 os.path.exists(sock_path) 为 False 时: # 空操作 # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) # 客户端连接(sock_path) # for i in range(iterations): # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) # ancdata = array.array('i', [fd]) # msg = bytes([i % 256]) # 打印("发送 fd ", fd, " (迭代 #", i, ")") # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) # # # else: # assert sys.argv[3] == 'recv' # # 如果 os.path.exists(dirname): # 抛出异常("目录存在") # # os.mkdir(dirname) # 打开套接字... # 服务器 = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) # 服务器绑定到套接字路径 # # 正在监听... # for i in range(iterations): # a = array.array('i') # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) # assert(len(ancdata) == 1) # cmsg_level, cmsg_type, cmsg_data = ancdata[0] # a.frombytes(cmsg_data) # 打印("收到 fd ", a[0], " (迭代 #", i, ")") # # 删除目录(dirname) # # 步骤重现: # # 1. 运行两个 shell,并在接收端设置文件描述符限制: # (shell1) ulimit -n 1020 # (shell2) ulimit -n 1022 # # 2. 在第一个 shell 中运行上面的脚本,使用`recv`选项 # (shell1) ./test_socket.py sock_tmp 1017 recv # # 3. 在第二个 shell 中运行脚本,使用`send`选项 # (shell2) ./test_socket.py sock_tmp 1017 send def 获取数据(self): # 从 `self._data_queue` 中获取数据。 # # 我们每 `MP_STATUS_CHECK_INTERVAL` 秒检查一次工作者的状态, # 通过运行 `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` 实现。 在循环中。这是检测工作失败的唯一机制。 对于 Windows。对于其他平台,也使用 SIGCHLD 处理程序来检测工作失败。 工作失败检测。 # 如果`pin_memory=True`,我们还需要检查`pin_memory_thread`是否已 # 在超时中死亡。 如果 self._超时 > 0: 成功, 数据 = self.尝试获取数据(self.超时) 如果 成功: 返回 数据 否则: 提升 运行时错误( fDataLoader 超时后{self.超时} ) elif self._内存映射: while self._pin_memory_thread.是否存活(): 成功, 数据 = self._try_get_data() 如果 成功: 返回 数据 否则: # 当 while 条件为假时,即 pin_memory_thread 已死亡。 提升 运行时错误("Pin 内存线程意外退出") 在此情况下,`self._data_queue` 是一个 `queue.Queue`,但我们不需要调用 `.task_done()`,因为我们没有使用 `.join()`。 不需要调用 `.task_done()`,因为我们没有使用 `.join()`。 否则: while True: 成功, 数据 = self._try_get_data() 如果 成功: 返回 数据 def _next_data(self): while True: 如果负责 `self._rcvd_idx` 的工作线程已经结束 由于耗尽了一个 `IterableDataset`,无法完成此任务, 我们尝试将 `self._rcvd_idx` 前进以找到下一个有效的索引。 # 这部分需要在循环中运行,因为 `self._get_data()` 调用和下面的 `_IterableDatasetStopIteration` 检查都可以标记 这一部分需要在循环中运行,因为 `self._get_data()` 调用和下面的 `_IterableDatasetStopIteration` 检查都可以标记 多余的工人已死亡。 while self.接收索引 < self.发送索引: 信息 = self.任务信息.获取(self.接收索引, ) 如果 信息: 工作者 ID = 信息[0] 如果 ( 长度(信息) == 2 self._workers 状态[worker_id] ): # 有数据或仍在活动状态 断开 删除 self._任务信息[self._接收索引] self._rcvd_idx += 1 否则: # 未找到有效的 `self._接收索引`(即未中断) 如果 not self._持久化工作进程: self._关闭工作者() 提升 StopIteration 现在 self._rcvd_idx 是我们要获取的批次索引 检查下一个样本是否已经被生成 如果 长度(self.任务信息[self._接收索引]) == 2: worker_id, 数据 = self._任务信息.弹出(self._接收索引) self._rcvd_idx += 1 返回 self._处理数据(数据, worker_id) 断言 not self.关闭 self._待办任务 > 0 索引, 数据 = self.获取数据() self.未完成任务 -= 1 如果 self.数据集类型 == _数据集类型.迭代器: 检查 IterableDatasetStopIteration 如果 isinstance(数据, _工具.工人._可迭代数据集停止迭代): 如果 self._持久化工作进程: self._工作者状态[数据.worker_id] = 否则: self._标记工作者为不可用(数据.worker_id) self._尝试放入索引() 继续 如果 索引 != self.接收索引: 如果 not self.按顺序: 不要存储以备后用,立即处理 立即从 self._task_info 中删除 # 保持对象大小可控 员工 ID = self.任务信息.弹出(索引)0] 返回 self.处理数据(数据, worker_id) # 存储乱序样本 self._任务信息[索引] += (数据,) 否则: 工作器 ID = self.任务信息.弹出(索引)0] self._rcvd_idx += 1 返回 self.处理数据(数据, worker_id) def 尝试设置索引(self): 最大任务 = self._prefetch_factor * self._num_workers 断言 self.待处理任务 < 最大任务数 try: 索引 = self.下一个索引() 除了 StopIteration 异常: 返回 _ 范围(self._num_workers): # 查找下一个活跃的工作者(如果有) 工作队列索引 = 下一(self.工作队列索引循环) 如果 self.工人状态[工作队列索引]: 如果 self._按顺序: 断开 elif self._工作者任务数[工作队列索引] < 最大任务数 // 总和( self.工作者状态 ): 当 self._in_order 为 False 时,如果工作者有容量,则分配工作给工作者 _workers_status 只在此线程中更新,因此总和保证大于 0 断开 否则: 未找到(即未中断) 返回 self._index_queues[工作队列索引].放置((self.发送索引, 索引)) # type: ignore[possibly-undefined] self.任务信息[self.发送索引] = (工作队列索引,) self.工作者任务数[工作队列索引] += 1 self.未完成任务数 += 1 self.发送索引 += 1 def 处理数据(self, 数据, worker_idx): self._workers_num_tasks[worker_idx] -= 1 self.尝试设置索引() 如果 isinstance(数据, ExceptionWrapper): 数据.reraise() 返回 数据 def 标记工人为不可用(self, worker_id, 关闭=错误): 标记一个工作进程已完成其工作,例如,由于耗尽一个 `IterableDataset`。 这应在 `_MultiProcessingDataLoaderIter` 继续运行时使用。 # `_MultiProcessingDataLoaderIter` 即将运行。 断言 self.工人状态[worker_id] ( self._persistent_workers 关闭 ) # 向特定工作者发送终止信号。 q = self._index_queues[worker_id] # 表示当前进程不再向该队列添加更多数据。 # 结束。 q.放置() 注意这里我们并没有在这里连接工作线程,也没有从 C 侧结构中移除工作线程的进程 ID,因为(1)连接可能很慢,并且 # (2)由于我们不连接,工作线程可能仍然会引发错误,而我们更愿意捕获这些错误,而不是忽略它们,即使它们 # (3)可能不会立即影响主线程的执行 # ,我们仍然希望记录这些错误,以便于后续分析 在工人完成其工作后引发。 将连接推迟到 `_shutdown_workers`,它在调用时执行 所有工作者完成他们的工作(例如,`IterableDataset`副本) 当这个迭代器被垃圾回收时。 self.工人状态[worker_id] = 断言 self._工作者完成事件.已设置() == 关闭 def _关闭工作者(self): # 当关闭此 `_MultiProcessingDataLoaderIter` 时调用。 # 详细请参阅 NOTE [数据加载器多进程关闭逻辑]。 # 该函数的逻辑请参阅。 如果 ( _utils None _工具.python 退出状态 真实 _工具.python 退出状态 None ): # 参考第(2)条注释。如果 Python 正在关闭,则执行无操作。 返回 # 当最后一个引用消失/迭代器耗尽时正常退出。 # 查看(1)和笔记的第二部分。 如果 not self.关闭: self.关闭 = 真实 try: # 当最后一个引用消失/迭代器耗尽时正常退出。 # 查看(1)和笔记的第二部分。 首先退出 `pin_memory_thread`,因为退出工作进程可能会在 `worker_result_queue` 中留下 被损坏的数据,而 `pin_memory_thread` 会从中读取。 如果 有属性(self, _pin_memory_thread): # 使用 hasattr 以防在设置属性之前发生错误。 self._内存线程完成事件.集合() # 如果 pin_memory_thread 正在等待,则发送信息给它。 # 这样它可以醒来并检查 `pin_memory_thread_done_event`。 self._工作结果队列.放置((, )) self._pin_memory_thread.加入() self._工作结果队列.取消加入线程() self._工作结果队列.关闭() 立即退出工作者。 self._工作者完成事件.集合() 员工 ID 范围(长度(self.工人)): 应从 `len(self._workers)` 获取工作者数量,而不是从 `self._num_workers`,以防在启动所有工作者之前发生错误。 `self._num_workers`,以防在启动所有工作者之前发生错误。 # workers. 如果我们使用带有持久工作者的 workers_status 我们必须将其关闭,因为工作器已暂停 如果 self._persistent_workers self.工人状态[worker_id]: self.标记工人为不可用(worker_id, 关闭=True) w self.工人: 我们应该能够在这里连接,但如果出了问题, 我们设置了一个超时,如果工作者未能连接, 它们将在`finally`块中被杀死。 w.加入(超时=_工具.MP 状态检查间隔) q self._index_queues: q.取消加入线程() q.关闭() 最后: 即使这个函数所做的只是将任务放入队列, 即使我们在其上调用`cancel_join_thread`,当工作进程被信号杀死时, 例如,在`Event.set()`中挂起时,也可能发生奇怪的事情。 因此,我们需要用 SIGCHLD 处理程序来保护这一点。 仅从 C 侧数据结构中删除 pids 结束。 # # FIXME: 很遗憾,对于 Windows,我们缺少一个工作进程 # 此函数中包含错误检测机制,如下 # 不提供 SIGCHLD 处理程序。 如果 self._worker_pids_set: _工具.信号处理._remove_worker_pids(id(self)) self._工作进程 ID 集合 = w self.工人: 如果 w.是否存活(): # 现有的机制试图让工作进程退出 平静地,但不幸的是,如果我们不幸达到 这里,我们不应该到达的(例如,pytorch/pytorch#39570), 我们会杀死这个工作进程。 w.终止() 静态方法用于移除对 `_MultiProcessingDataLoaderIter` 的引用 @staticmethod def 清理工作线程(w): try: w.加入(超时=_工具.MP 状态检查间隔) 最后: 如果 w.是否存活(): w.终止() def __del__(self): self.关闭工作线程()

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

查看 PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源,获取您的疑问解答

查看资源