# mypy: 允许未类型化定义
rDataLoader 和相关派生自 _BaseDataLoaderIter 的迭代器的定义。
为了支持这两个类,在 `./_utils` 中我们定义了许多实用方法。
多进程中的运行函数。例如,数据加载工作循环是
在 `./_utils/worker.py` 文件中。
""
导入 functools
导入 itertools
导入
记录日志
导入
多进程
作为 python_multiprocessing
导入
操作系统
导入
队列
导入
线程
导入
警告
来自 collections.abc
导入
迭代器
来自
打字
导入
任意,
可调用,
通用,
可选,
类型变量,
联合
导入
火炬
导入 torch.distributed
作为 dist
导入 torch.utils.data.graph_settings
来自 torch._utils
导入
异常包装器
来自 torch.utils.data
导入 _utils
来自 torch.utils.data.datapipes.datapipe
导入 (
_IterDataPipe 序列化包装器,
_MapDataPipe 序列化包装器,
IterDataPipe,
MapDataPipe,
)
来自 torch.utils.data.dataset
导入
数据集,
可迭代数据集
来自 torch.utils.data.sampler
导入 (
批量采样器,
随机采样器,
采样器,
顺序采样器,
)
__all__ = [
"数据加载器",
获取工作信息,
默认合并,
默认转换,
]
_T = 类型变量("_T")
_T_co = 类型变量(
_T_co,
协变=True)
_worker_init_fn_t = 可调用[[int
]
无]
理想情况下,我们应该通过 `collate_fn` 的返回类型来参数化 `DataLoader`,但目前还没有办法设置一个默认值,如果用户没有传入自定义的 'collate_fn'。
如果用户没有传入自定义的 'collate_fn',则可以将类型参数设置为默认值。
请参阅 https://github.com/python/mypy/issues/3737。
_collate_fn_t = 可调用[[
列表[_T]],
任意]
这些函数曾经定义在这个文件中。但是,它已经被移动到了
# _utils/collate.py。尽管从用户层面访问它相当困难
# (必须显式地直接导入 `import torch.utils.data.dataloader`),因此
# 可能是用户代码在使用它。这种别名保持向后兼容性。
# 方面。
默认的 collate: _collate_fn_t =
_工具.
汇总.
默认汇总
默认转换 =
_工具.
汇总.
默认转换
获取工作者信息 =
_工具.
工人.
获取工作者信息
日志记录器 =
记录日志.
获取日志记录器(__name__)
类
_数据集类型:
地图 = 0
迭代器 = 1
@staticmethod
def 创建获取器(
类型,
数据集,
自动对齐, collate_fn, drop_last):
如果
仁慈 == _DatasetKind.Map:
返回
_工具.
获取._MapDatasetFetcher(
数据集,
自动归一化,
归一化函数, drop_last
)
否则:
返回
_工具.
获取._IterableDatasetFetcher(
dataset, auto_collation, collate_fn, drop_last
)
类 _InfiniteConstantSampler(Sampler):
r类似于 `itertools.repeat(None, None)`。
用作 `torch.utils.data.IterableDataset` 的采样器。
"源代码"
def __iter__(self):
while True:
产生 None
def _get_distributed_settings():
如果
距离.
是否可用()
和
距离.
已初始化():
返回
距离.
获取世界大小(),
距离.
获取排名()
否则:
返回 1, 0
def _sharding_worker_init_fn(工作器初始化函数,
世界大小,
排名 ID,
工作器 ID):
全局工作器 ID =
工作 ID
信息 =
火炬.
工具.
数据.
获取工人信息()
断言
信息
是 not None
总工人数 =
信息.num_workers
数据管道 =
信息.
数据集
断言 isinstance(
数据管道, (
迭代数据管道,
映射数据管道))
为了使元素在分布式进程间均匀分布,我们应该在分布式上分片数据
首先处理然后在工作进程中分片
总工作进程数 *=
世界大小
全局工作进程 ID =
全局工作进程 ID *
世界大小 + rank_id
# 对于 BC,使用默认的 SHARDING_PRIORITIES
火炬.
工具.
数据.
图设置.
应用分片(
数据管道,
总工作进程数,
全局工作进程 ID
)
如果
工作进程初始化函数
是 not
无:
worker_init_fn(worker_id)
def _share_dist_seed(生成器, pg):
_shared_seed = 火炬.
空的((),
数据类型=
火炬.int64).
随机(
生成器=
生成器)
如果 isinstance(pg,
距离.
流程组):
距离.
广播(_shared_seed,
源=0,
群组=pg)
返回 _shared_seed.
项目()
[文档]
类
数据加载器(
通用[_T_co
)]
r""
数据加载器将数据集和采样器结合,并为给定数据集提供可迭代的对象。
`torch.utils.data.DataLoader` 支持映射样式和可迭代样式的数据集,具有单进程或多进程加载,可自定义加载顺序以及可选的自动批处理(收集)和内存固定。
加载,自定义加载顺序以及可选的自动批处理(收集)和内存固定。
加载顺序以及可选的自动批处理(收集)和内存固定。
请参阅:py:mod:`torch.utils.data` 文档页面以获取更多详细信息。
参数:
dataset (Dataset):从其中加载数据的数据集。
batch_size (int, 可选):每次加载的批次样本数量
(默认:`1`)。
shuffle(布尔值,可选):设置为 ``True`` 以在每次 epoch(默认:``False``)时重新排列数据
在每个 epoch(默认:``False``)时进行采样。
sampler(采样器或可迭代对象,可选):定义从数据集中抽取样本的策略。可以是任何具有 ``__len__`` 的 ``Iterable`` 对象
可以是任何具有 ``__len__`` 的 ``Iterable`` 对象,用于从数据集中抽取样本。
实现。如果指定了,则不能指定:attr:`shuffle`。
批量采样器(采样器或可迭代对象,可选):类似于 :attr:`sampler`,但
每次返回一批索引。与...互斥。
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`
和:attr:`drop_last`。
num_workers(int,可选):使用多少个子进程来加载数据。
``0`` 表示数据将在主进程中加载。
(默认:``0``)
collate_fn (Callable, 可选): 合并一系列样本
小批量张量。在从批量加载时使用。
地图样式数据集。
pin_memory(布尔值,可选):如果为 ``True``,数据加载器将复制张量
在返回之前将数据元素放入设备/CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的:attr:`collate_fn`返回的批次是自定义类型,请参阅下面的示例。
如果您的数据元素是自定义类型,或者您的:attr:`collate_fn`返回的批次是自定义类型,请参阅下面的示例。
请参阅下面的示例。
drop_last (布尔值,可选): 将其设置为 ``True`` 以丢弃最后一个不完整的批次,
如果数据集大小不能被批处理大小整除。如果 ``False``,
那么数据集的大小不能被批处理大小整除,则最后一个批次
的大小将更小。(默认:``False``)
超时时间(数值,可选):如果为正数,则为收集一个批次的超时时间
来自工人。始终应为非负数。(默认:``0``)
worker_init_fn(Callable,可选):如果不为 ``None``,则将在每个
工人子进程上调用,使用工人 ID(一个在 ``[0, num_workers - 1]`` 范围内的 int)作为
输入,在播种和加载数据之前。(默认:``None``)
multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选): 如果
``None`,您的操作系统的默认 `multiprocessing context`_ 将
被使用。(默认:`None`)
生成器(torch.Generator,可选):如果不为 ``None``,则使用此随机数生成器
通过 RandomSampler 生成随机索引和通过多进程生成
`base_seed` 用于工作者。(默认:`None`)
prefetch_factor(int,可选,关键字参数):预取批次数
提前由每个工人完成。``2``表示总共
2 * num_workers 个批次的预取数据跨所有工作者。默认值取决于 num_workers 的设置值。如果 num_workers=0,则默认为 ``None``。
如果 num_workers 的值为 0,则默认为 ``None``。
否则,如果 num_workers 的值大于 0,则默认为 ``2``)。
persistent_workers(布尔值,可选):如果设置为 ``True``,数据加载器将不会关闭
数据集被消耗一次后的工作进程。这允许
维护工作进程的 `Dataset` 实例存活。(默认:``False``)
pin_memory_device(str,可选):如果 ``pin_memory`` 为 ``True``,则将其固定在的设备上
。如果没有提供,则当前 :ref:`加速器` 将是
默认。此参数不建议使用,并可能被弃用。
in_order (bool, optional): 如果 ``False``,数据加载器将不会强制执行批次以先进先出顺序返回。仅在 ``num_workers > 0`` 时适用。(默认:``True``)
..警告:: 如果使用 ``spawn`` 启动方法,:attr:`worker_init_fn`
..警告:: 如果使用 ``spawn`` 启动方法,:attr:`worker_init_fn`
无法是一个不可序列化的对象,例如 lambda 函数。详见
ref:`multiprocessing-best-practices` 了解更多与 PyTorch 中多进程相关的细节。
有关。
..警告:: ``len(dataloader)`` 指数基于所使用的采样器的长度。
当 :attr:`dataset` 是一个 :class:`~torch.utils.data.IterableDataset` 时,
它将返回一个基于 ``len(dataset) / batch_size`` 的估计值,根据 :attr:`drop_last` 进行适当的舍入,
无论多进程加载配置如何,这代表了 PyTorch 能做出的最佳猜测。
这代表了 PyTorch 能做出的最佳猜测,因为 PyTorch
信任用户正确处理多进程:attr:`数据集`代码,以避免重复数据。
加载以避免重复数据。
然而,如果分片导致多个工作进程拥有不完整的最后批次,
这个估计仍然可能不准确,因为(1)一个本应完整的批次可以
可以拆分成多个,并且(2)多个批次值的样本也可以被
当设置 :attr:`drop_last` 时会被丢弃。不幸的是,PyTorch 通常无法检测到这种情况。
的情况。
请参阅 `数据集类型`_ 了解这两种类型的数据集的更多详细信息以及如何
`torch.utils.data.IterableDataset` 类的交互
多进程数据加载_
请参阅 :ref:`可重现性`, 以及 :ref:`dataloader-workers-random-seed`, 和 :ref:`数据加载随机性` 相关的注意事项。
关于随机种子相关问题的说明,请参阅 :ref:`data-loading-randomness`。
..警告:将`in_order`设置为`False`可能会损害可重复性,并可能导致数据偏差
输入数据分布到训练器中的不平衡数据情况。
.. _多进程上下文:
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
"源代码"
数据集:
数据集[_T_co]
批处理大小:
可选[int]
num_workers: 整型
持久化内存:
布尔类型
drop_last: 布尔类型
超时:
浮点数
sampler: 联盟[Sampler,
迭代器]
pin_memory_device: 字符串
预取因子:
可选[int]
_迭代器:
可选[
"_基础数据加载迭代器"]
__已初始化 =
假
def 初始化(
self,
数据集:
数据集[_T_co
]
批大小:
可选[int] = 1,
打乱顺序:
可选[
布尔] =
无,
样本器:
联盟[
样本器,
迭代器,
无] =
无,
批量样本器:
联盟[
样本器[
列表
]
迭代器[
列表
]
无] =
无,
num_workers: 整型 = 0,
collate_fn: 可选[_collate_fn_t] =
无,
持久化内存:
布尔类型 =
错误,
drop_last: 布尔类型 =
错误,
超时:
浮点数 = 0,
worker_init_fn: 可选[_worker_init_fn_t] =
无,
多进程上下文=
无,
生成器=
无,
*,
预取因子:
可选[int] =
无,
持久化工作者:
布尔类型 =
错误,
内存设备:
字符串 = "",
按顺序:
布尔类型 = True,
):
火炬._C._log_api_usage_once("python.data_loader")
如果 num_workers < 0:
提升 ValueError(
num_workers 选项应为非负数;
"使用 num_workers=0 来禁用多进程。"
)
如果
超时 < 0:
提升 ValueError(
"超时选项应该是非负数。")
如果 num_workers == 0
和
预取因子
是 not
无:
提升 ValueError(
"prefetch_factor 选项只能在多进程模式下指定。"
"请设置 num_workers > 0 以启用多进程,否则将 prefetch_factor 设置为 None。"
)
elif num_workers > 0 和 prefetch_factor
是
无:
prefetch_factor = 2
elif 预取因子
是 not
无
和
预取因子 < 0:
提升 ValueError(
"预取因子选项应为非负")
如果
持久工作进程
和 num_workers == 0:
提升 ValueError(
"持久工作进程选项需要 num_workers > 0")
self.数据集 =
数据集
self.num_workers = num_workers
self.预取因子 =
预取因子
self.内存映射 =
内存映射
self.内存映射设备 =
内存映射设备
self.超时 =
超时
self.worker 初始化函数 =
worker 初始化函数
self.多进程上下文 =
多进程上下文
self.按顺序 =
按顺序
# 添加向前兼容性,以便经典 DataLoader 可以与 DataPipes 一起工作:
# _DataPipeSerializationWrapper 容器使得在不重新定义 pickler 的情况下进行序列化变得更加容易
如果 isinstance(self.
数据集,
迭代数据处理管道):
self.数据集 =
_迭代数据管道序列化包装器(self.
数据集)
elif isinstance(self.数据集,
地图数据管道):
self.数据集 =
_地图数据管道序列化包装器(self.
数据集)
在检查采样器之前先检查相关数据集,因为我们希望
告诉用户可迭代式数据集与自定义采样器不兼容,
因此,他们不会在修复自定义采样器错误后才发现这种组合不起作用。
这会浪费他们修复自定义采样器错误的时间。
如果 isinstance(
数据集,
可迭代数据集):
self._数据集类型 = _DatasetKind.
迭代器
# [自定义采样器和 IterableDataset]
#
# `IterableDataset` 不支持自定义 `batch_sampler` 或
# `sampler` 因为键无关紧要(除非我们有一天支持
# 生成器风格的 dataset...)。
#
对于`sampler`,我们始终创建一个虚拟采样器。这是一个
# 无限采样器,即使数据集可能已实现
# finite `__len__` 因为在多进程数据加载中,原始
设置将返回重复数据(可能是有意的)
因此使用与数据集长度匹配的采样器会导致数据丢失(你可能会看到前几批次的重复,但之后永远不会看到任何东西)。因此,`Iterabledataset` 总是使用无限采样器,即无限采样器的一个实例。
# cause data lost (you may have duplicates of the first couple
# batches, but never see anything afterwards). Therefore,
# `Iterabledataset` always uses an infinite sampler, an instance of
# 上文定义了 `_InfiniteConstantSampler`。
#
# 自定义 `batch_sampler` 实质上仅控制批大小。
然而,由于不清楚它会有多有用,因此以可迭代方式
数据集本身可以处理。此外,这样做毫无意义。
在多进程数据加载中,批次的分配顺序是一个实现细节,因此用户无法控制
因此我们禁用此选项。如果将来证明这个选项很有用,我们可以重新启用
因此我们禁用此选项。如果将来证明这个选项很有用,我们可以重新启用
因此我们禁用此选项。如果将来证明这个选项很有用,我们可以重新启用
# 这,并支持指定分配的自定义采样器。
# 指定特定工作者。
如果 isinstance(
数据集,
迭代数据管道):
如果
打乱
是 not
无:
数据集 =
火炬.
工具.
数据.
图形设置.
应用打乱设置(
数据集,
打乱=
打乱
)
# 我们不能在这里检查 `shuffle is not None`,因为之前 `shuffle=False` 是默认值。
elif 打乱 not
在 {
错误,
无}:
提升 ValueError(
fDataLoader 与 IterableDataset:期望未指定打乱选项,但得到 shuffle={
打乱}"
)
如果
样本器
是 not
无:
#参见 NOTE [自定义样本器和 IterableDataset]
提升 ValueError(
fDataLoader 与 IterableDataset:期望未指定采样器选项,但获取到 sampler={
采样器}"
)
elif 批量采样器
是 not
无:
#参见笔记[自定义采样器和 IterableDataset]
提升 ValueError(
DataLoader 与 IterableDataset:期望未指定
fbatch_sampler 选项,但得到 batch_sampler={batch_sampler}"
)
否则:
打乱顺序 =
布尔(
打乱)
self.数据集类型 =
_数据集类型.
地图
如果
样本
是 not None
和
随机播放:
提升 ValueError(
"样本选项与随机播放互斥")
如果
批量样本器
是 not
无:
使用自定义批采样器进行自动对齐
如果
批处理大小 != 1
或
打乱顺序
或
样本器
是 not None
或
丢弃最后:
提升 ValueError(
"batch_sampler 选项与 batch_size、shuffle、sampler 和 drop_last 互斥"
"以及 batch_size、shuffle、sampler 和 "
"drop_last"
)
批处理大小 = None
drop_last = 假
elif 批处理大小
是
无:
无自动对齐
如果 drop_last:
提升 ValueError(
"batch_size=None 选项禁用自动批处理"
"且与 drop_last 互斥"
)
如果
样本器
是
无:
# 提供默认样本器
如果 self.
_数据集类型 ==
_数据集类型.
迭代器:
# 注意 [自定义采样器和 IterableDataset]
采样器 =
_无限常量采样器()
否则:
# 地图样式
如果
打乱:
样本器 =
随机样本器(
数据集,
生成器=
生成器) # type: ignore[arg-type]
否则:
样本器 =
顺序样本器(
数据集) # type: ignore[arg-type]
如果
批处理大小
是 not None
和
批处理样本器
是
无:
自动归一化无需自定义批量采样器
批量采样器 =
批量采样器(
采样器,
批处理大小,
丢弃最后)
self.批处理大小 =
批处理大小
self.删除最后 =
删除最后
self.样本器 =
样本器
self.批量样本器 =
批量样本器
self.生成器 =
生成器
如果
合并函数
是
无:
如果 self.
_自动合并:
collate_fn = _工具.collate.default_collate
否则:
collate_fn = _工具.
汇总.
默认转换
self.汇总函数 =
汇总函数
self.持久化工作者 =
持久化工作者
self.__已初始化 =
真实
self._IterableDataset_len_called = (
None # 查看 NOTE [ IterableDataset 和 __len__ ]
)
self.迭代器 = None
self.检查工作数量合理性()
火炬.
设置生命体征(
数据加载器,
启用,
"真")
# 类型: 忽略[attr-defined]
def _获取迭代器(self) ->
"_基础数据加载迭代器":
如果 self.num_workers == 0:
返回
_单进程数据加载迭代器(self)
否则:
self.检查工作数量合理性()
返回
_多进程数据加载迭代器(self)
@property
def 多进程上下文(self):
返回 self.
__多进程上下文
@multiprocessing_context.设置器
def 多进程上下文(self,
多进程上下文):
如果
多进程上下文
是 not
无:
如果 self.num_workers > 0:
如果 isinstance(
多进程上下文,
字符串):
有效的启动方法 =
火炬.
多进程.
获取所有启动方法()
如果
多进程上下文 not
在
有效的启动方法:
提升 ValueError(
"multiprocessing_context 选项 "
f应指定一个有效的启动方法{
有效的启动方法!r}
,但实际上得到了"
fmultiprocessing_context={multiprocessing_context!r}"
)
multiprocessing_context = 火炬.
多进程.
获取上下文(
multiprocessing_context
)
如果 not isinstance(
多进程上下文,
Python 多进程.
上下文.
基础上下文
):
提升
类型错误(
"多进程上下文选项应是一个有效的上下文"
"对象或指定启动方法的字符串,但得到了"
f"multiprocessing_context="{multiprocessing_context}"
)
否则:
提升 ValueError(
"multiprocessing_context 只能与"
"多进程加载(num_workers > 0),但得到"
f"num_workers="{self.num_workers}"
)
self.__多进程上下文 =
多进程上下文
def __setattr__(self, 属性, val):
如果 self.
__已初始化
和
属性
在 (
批处理大小,
批采样器,
采样器,
最后一个丢弃,
"数据集",
"持久化工作者",
):
提升 ValueError(
f"{属性}
属性设置后不应更改{self.
类.__name__}
已初始化
)
超级().__setattr__(
属性, val)
# 我们引用 '_BaseDataLoaderIter',因为它尚未定义,且定义不能上移
# 因为 '_BaseDataLoaderIter' 引用了 'DataLoader'
def __iter__(self) -> "_BaseDataLoaderIter":
当使用单个工作线程时,返回的迭代器应该是
每次创建以避免重置其状态
然而,在多个工作者迭代器的情况下
迭代器在其生命周期中只创建一次
DataLoader 对象,以便工作者可以被重用
如果 self.
持久工作者
和 self.num_workers > 0:
如果 self.
迭代器
是
无:
self.迭代器 = self.
获取迭代器()
否则:
self.迭代器.
重置(self)
返回 self.
迭代器
否则:
返回 self.
获取迭代器()
@property
def 自动归一化(self):
返回 self.
批处理采样器
是 not None
@property
def _index_sampler(self):
# 实际用于为 `_DatasetFetcher` 生成索引的采样器
# (请参阅 _utils/fetch.py) 在每次读取数据时。这将是
`.batch_sampler` 如果处于自动合并模式,否则为 `.sampler`。
我们不能更改 `.sampler` 和 `.batch_sampler` 属性以保持向后兼容。
原因。
如果 self._auto_collation:
返回 self.
批处理采样器
否则:
返回 self.
采样器
def __len__(self) -> int:
如果 self.
数据集类型 == _DatasetKind.
迭代器:
# NOTE [ IterableDataset 和 __len__ ]
#
# 对于 `IterableDataset`,`__len__` 可能不准确,当直接
多进程加载数据,因为样本将被重复。
然而,没有任何实际用例应该真正使用这种行为,所以
它应该被视为用户错误。我们通常应该信任用户
代码做正确的事情(例如,为每个副本配置不同的设置)
# in `__iter__`), 并提供正确的 `__len__` 如果他们选择这样做
# 实现(如果数据集没有实现,这仍然会抛出异常)
# 为了提供进一步的警告,我们跟踪是否调用了 `__len__`
#
# ,如果 `__len__` 被调用,我们将提供额外的信息
# `DataLoader`,将返回值保存到`self._len_called`中,并警告
# 如果迭代器最终产生了超过这个数量的样本。
# 无法静态验证数据集是否为 Sized
长度 = self._IterableDataset_len_called =
长度(self.
数据集) # type: ignore[assignment, arg-type]
如果 (
self.批处理大小
是 not None
): # IterableDataset 不允许自定义采样器或批量采样器
来自
数学
导入
向上取整
如果 self.drop_last:
长度 =
长度 // self.
批处理大小
否则:
长度 =
向上取整(
长度 / self.
批量大小)
返回
长度
否则:
返回
长度(self._index_sampler)
def 检查工作进程数量合理性(self):
# 此函数检查根据当前系统资源,数据加载器的工作进程数量是否合理。当前规则是,如果工作进程的数量为
# 当前系统资源的倍数,则认为该数量是合理的。
# 加载数据加载器创建的进程数大于允许的逻辑 CPU 数时
# 我们将弹出警告提示用户注意。
#
# 例如,如果当前系统有 2 个物理 CPU,每个 CPU 有 16 个核心。并且每个核心支持 2
# 个线程,那么这里的总逻辑 CPU 数就是 2 * 16 * 2 = 64。假设当前
DataLoader 进程可以使用其中的一半,即 32 个,然后从该进程启动的 worker 的最大合理数量是 32。
现在,假设创建的 DataLoader 具有 num_works = 40,这个数字大于 32。
因此,会触发警告信息来通知用户降低 worker 的数量。
所以会触发警告信息来通知用户降低 worker 的数量。
# 必要的。
#
#
# [注意] 请注意,此函数仅在 os.sched_getaffinity 可用时才尊重 `cpuset`。
# 它在大多数 Linux 系统中可用,但在 OSX 和 Windows 中不可用。
# 当 os.sched_getaffinity 不可用时,将调用 os.cpu_count(),但它不尊重 cpuset。
# 它不尊重 cpuset。
我们不考虑线程,因为每个工作进程是单线程的
# 此时。
#
我们不设置任何线程标志(例如 OMP_NUM_THREADS, MKL_NUM_THREADS 等)
除了在工作进程中将`torch.set_num_threads`设置为 1 之外,如果传递
在函数中使用依赖于这些线程标志的第三方模块
创建多少个线程(例如,numpy 等),然后是调用者的责任
正确设置这些标志。
def 创建警告信息(
建议的 worker 数量,
已创建的 worker 数量,
检查了 cpuset):
建议的最大 worker 消息量 = (
(
(
我们建议当前系统中最大工人数为{}{}
, 是更小的 "
比这个 DataLoader 将要创建的内容。
).格式(
建议工作进程数,
(
请提供需要翻译的文本
如果
cpuset 已检查
否则
"(`cpuset`不被考虑)"
),
)
)
如果
建议工作进程数
是 not None
否则 (
数据加载器无法计算当前系统建议的最大工作进程数。
)
)
警告信息 = (
f"此数据加载器将创建"{
创建的进程数}
总共的 worker 进程。{
建议的最大 worker 消息数} "
"请注意,过度创建工作进程可能会导致 DataLoader 运行缓慢甚至冻结,"
"如果需要,请降低工作进程数量以避免潜在的缓慢/冻结。"
)
返回
警告信息
如果 not self.num_workers
或 self.num_workers == 0:
返回
尝试根据系统资源计算建议的最大工作线程数
最大工作线程数建议 = None
检查 CPU 集 =
假
如果
有属性(
操作系统,
"获取调度亲和性"):
try:
最大工作线程建议 =
长度(
操作系统.sched_getaffinity(0))
cpuset_checked = 真实
除了
异常:
通过
如果 max_num_worker_suggest
是
无:
# os.cpu_count() 可能返回 Optional[int]
首先获取 CPU 数量并检查是否为 None 以满足 mypy 检查
cpu_count = 操作系统.cpu_count()
如果 cpu_count
是 not
无:
最大工作线程建议 = cpu_count
如果 max_num_worker_suggest
是
无:
warnings.警告(
创建警告信息(
最大工作进程数建议, self.num_workers,
CPU 集检查
)
)
返回
如果 self.num_workers >
最大工作进程数建议:
warnings.警告(
创建警告信息(
最大工作线程数建议, self.num_workers,
cpuset 已检查
)
)
类
_基础数据加载迭代器:
def 初始化(self,
加载器:
数据加载器) ->
无:
self._数据集 =
加载器.
数据集
self.共享种子 = None
self._pg = None
如果 isinstance(self.
_数据集,
迭代数据处理管道):
如果
距离.
是否可用()
和
距离.
已初始化():
self._pg = 距离.
创建新组(
后端=
gluu)
self.共享种子 =
分享分布种子(
加载器.
生成器, self._pg)
shared_rng = 火炬.
生成器()
shared_rng.手动播种(self._shared_seed)
self.数据集 =
火炬.
工具.
数据.
图设置.
应用随机种子(
self.数据集,
共享随机数生成器
)
self.数据集类型 =
加载器.
数据集类型
self._IterableDataset_len_called = 加载器._IterableDataset_len_called
self.自动分词 =
加载器.
_自动对齐
self._丢弃最后一项 =
加载器.
删除最后
self._index_sampler = 加载器._index_sampler
self._num_workers = 加载器.num_workers
ws, 排名 =
获取分布式设置()
self.世界大小 = ws
self._排名 =
排名
如果未设置 pin_memory_device,则默认行为为当前加速器。
如果 pin_memory_device 已设置但 pin_memory 未设置,则默认
# 行为关闭。
如果
长度(
加载器.pin_memory_device) == 0:
如果
加载器.
内存映射
和 not
火炬.
加速器.
是否可用():
警告信息 = (
"将 'pin_memory' 参数设置为 true,但未找到加速器"
然后设备固定内存将不会被使用。
)
warnings.警告(warn_msg)
self._pin_memory = 加载器.
内存映射
和
火炬.
加速器.
是否可用()
self._pin_memory_device = None
目前,pin_memory 在 MPS 后端会引发错误(参见
# https://github.com/pytorch/pytorch/issues/86060),因此强制
# 禁用 MPS 上的 pin_memory。一旦固定,请移除此限制
# MPS 的内存分配是固定的。
如果 (
self._内存映射
和 (acc :=
火炬.
加速器.
当前加速器())
是 not None
和 acc.
类型 ==
mps
):
self._pin_memory = 假
警告信息 = (
"'pin_memory' 参数设置为 true,但目前 MPS 不支持,"
"则不会使用设备固定内存。"
)
warnings.警告(warn_msg)
否则:
如果 not
加载器.
持久化内存:
警告信息 = (
"已设置 'pin_memory_device' 但未设置 'pin_memory' 参数,"
"设备固定内存将不会被使用。"
"如果需要使用设备固定内存,请将 'pin_memory' 设置为 true"
)
warnings.警告(warn_msg)
self.固定内存 =
加载器.
内存映射
self.固定内存设备 =
加载器.
内存映射设备
self._超时 =
加载器.
超时
self._collate_fn = 加载器.
整理函数
self._sampler_iter = 迭代(self._index_sampler)
self._base_seed = (
火炬.
空的((),
数据类型=
火炬.int64)
.随机(
生成器=
加载器.
生成器)
.项目()
)
self._persistent_workers = 加载器.
持久化工作者
self._num_yielded = 0
self._profile_name = f"enumerate(DataLoader)#{self.类.__name__}
.__next__
def __iter__(self) -> "_基础数据加载迭代器":
返回 self
def 重置(self,
加载器,
第一次迭代=
错误):
self._采样迭代 =
迭代(self.
_索引采样)
self._已产出数量 = 0
self._IterableDataset_len_called = 加载器._IterableDataset_len_called
如果 isinstance(self.
数据集,
迭代数据处理管道):
self.共享种子 =
分享分布种子(
加载器.
生成器, self.
_ pg)
共享随机数生成器 =
火炬.
生成器()
共享随机数生成器.
手动播种(self.
_共享种子)
self._数据集 =
火炬.
工具.
数据.
图设置.
应用随机种子(
self.数据集,
共享随机数生成器
)
def 下一个索引(self):
返回
下一(self.
样本迭代器)
可能引发 StopIteration
def _下一数据(self):
提升
未实现异常
def __下一__(self) ->
任意:
与
火炬.
自动微分.
分析器.
记录功能(self.
_用户名):
如果 self._sampler_iter
是
无:
# TODO(https://github.com/pytorch/pytorch/issues/76750)
self.重置()
忽略调用参数
数据 = self._next_data()
self._num_yielded += 1
如果 (
self.数据集类型 == _DatasetKind.
迭代器
和 self._IterableDataset_len_called
是 not None
和 self._num_yielded > self._IterableDataset_len_called
):
警告信息 = (
fIterableDataset 的长度{self._dataset}
被报告为是{self.
_可迭代数据集_len_被调用}"
f"(当访问 len(dataloader))时,但"{self.
_已获取的样本数"}
已获取样本。
)
如果 self._num_workers > 0:
警告信息 += (
对于多进程数据加载,这可能是由于未正确配置“
"每个工作器上的可迭代数据集副本。请参阅"
"https://maskerprc.github.io/docs/stable/data.html#torch.utils.data.IterableDataset 以获取示例。"
)
warnings.警告(warn_msg)
返回
数据
def __len__(self) -> int:
返回
长度(self._index_sampler)
def __getstate__(self):
# TODO: 为共享迭代器添加有限的序列化支持
# 以跨多个线程用于 HOGWILD。
可能是这样做最好的方法,就是移动样本推送
将数据队列分离到单独的线程,然后仅共享数据队列
但是,在没有非阻塞 API 的情况下发出结束信号很棘手
提升
不支持的操作异常("{}
无法序列化, self.
类.__name__)
类
_单进程数据加载迭代器(
_基础数据加载迭代器):
def 初始化(self,
加载器):
超级().
初始化(
加载器)
断言 self.
_超时 == 0
断言 self._num_workers == 0
# 添加向前兼容性,以便经典 DataLoader 可以与 DataPipes 一起工作:
# 分布式分片处理
如果 isinstance(self.
_数据集, (
迭代数据处理管道,
地图数据管道)):
# 对于 BC,使用默认的 SHARDING_PRIORITIES
火炬.
工具.
数据.
图设置.
应用分片(
self._数据集, self.
世界大小, self.
_排名
)
self.数据集获取器 = _DatasetKind.
创建获取器(
self.数据集类型,
self._数据集,
self._自动合并,
self._collate_fn,
self._drop_last,
)
def _next_data(self):
索引 = self._next_index()
# 可能引发 StopIteration 异常
数据 = self.
_dataset_fetcher
数据集获取器.获取(
索引)
# 可能引发 StopIteration 异常
如果 self._pin_memory:
数据 =
_工具.
持久化内存.
持久化内存(
数据, self._pin_memory_device)
返回
数据
类
_多进程数据加载迭代器(
_基础数据加载迭代器):
r"""遍历一次 DataLoader 的数据集,由采样器指定。"""
# 数据加载器多进程关闭逻辑
#
# 初步:
#
# 我们的数据模型如下(队列用花括号表示):
#
# 主进程 ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {数据队列} ||
# | ||
# 数据输出 \/
#
# P.S. 如果 `worker_result_queue` 和 `pin_memory_thread` 部分可以省略,则
# `pin_memory=False`。
#
#
# 终止多进程逻辑需要非常谨慎的设计。特别是,我们需要确保
#
#
当迭代器的最后一个引用消失或耗尽时,它将优雅地退出工作者。
在这种情况下,应该优雅地退出工作者,因为主进程可能仍然需要继续运行,而我们希望进行清理。
#
在此情况下,工作者应该被优雅地退出,因为主进程可能仍然需要继续运行,我们希望进行清理。
在这种情况下,应该优雅地退出工作者,因为主进程可能仍然需要继续运行,我们希望进行清理。
# 在工作线程中执行释放 GPU 内存等操作代码。
# 自然地,我们在`__del__`中实现了关闭逻辑。
# DataLoaderIterator。
#
# 我们将对此处的逻辑讨论推迟到以后。
#
2. 迭代器在加载进程和/或工作进程退出时终止
进程正常退出或出现错误。
#
我们将所有工作者和`pin_memory_thread`设置为`daemon=True`。
#
您可能会问,为什么我们不能让工作进程非守护进程化,
优雅地退出,使用与我们在 `__del__` 中相同的逻辑
迭代器被删除了(参见上面第 1 点)?
#
首先,`__del__` 并不保证在调用时会被调用
# 解释器退出。即使被调用,在它执行时,
许多 Python 核心库资源可能已经被释放,甚至
简单的事情,比如获取队列的内部锁,也可能挂起。
因此,在这种情况下,我们实际上需要防止`__del__`的执行,
并依赖于守护进程的自动终止。
# 子代。
#
因此,我们注册了一个`atexit`钩子,该钩子设置了一个全局标志
`# `_utils.python_exit_status`。由于 `atexit` 钩子在
注册顺序的逆序,我们保证这个标志是
设置在释放我们使用的库资源之前(至少在
CPython,是通过在 `atexit` 处理器中定义完成的
`multiprocessing/util.py`
https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
当首次注册需要此机制的对象时
# 创建,例如,`mp.Queue`
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
# )
#
# 因此在 `__del__` 中,我们检查 `_utils.python_exit_status` 是否已设置
# `None`(释放),如果为空则执行无操作。
#
# 然而,仅仅让库的清理代码运行也可能不好,
# 因为这样的代码(例如,`multiprocessing.util._exit_function()`)
# 包括为`mp.Queue`的线程加入操作,这可能会导致阻塞。
因此,创建线程的主要过程被调用为
# 在创建时取消加入线程。参见后续章节
# [ 3b. 将进程放入队列时不会挂起; ]
# 查看更多详情。
#
这里是两个库清理代码可以运行的示例案例
在调用 `__del__` 之前:
#
1. 如果我们保留对迭代器的引用,它更常
# 不如尝试在 `multiprocessing` 库清理之前
清除所有活跃的引用对象(https://github.com/pytorch/pytorch/issues/48666)
因此阻止了我们的清理代码首先运行。
#
2. 当在子进程中使用 `DataLoader` 时,也会出现类似的问题。
当进程结束时,它会关闭所有其守护进程子进程。
# 将它们(而不是在超时前)一起退出。
# 对于线程来说,情况类似,但机制不同。这个事实,
# 以及多进程的一些实现细节,迫使我们让工作进程成为守护进程。所有的问题都源于当
# 我们的工作进程在某个时候退出时。
# 数据加载器在子进程中使用,由多进程引起
# 看起来大致是这样的代码:
#
尝试:
# 使用数据加载器调用你的函数()
finally:
# multiprocessing.util._退出函数()
#
上述提及的加入/终止发生在内部
# `_退出函数()`. 现在,如果 `your_function_using_a_dataloader()`
抛出异常时,异常中存储的堆栈跟踪将阻止使用 `DataLoaderIter` 的帧被释放。如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。
如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。
如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。
如果该帧有任何对 `DataLoaderIter` 的引用(例如,在迭代器的方法中),则其 `__del__` 方法,它启动关闭程序,将不会执行。
# 被调用。这反过来意味着工人没有收到通知。尝试
# 加入 `_exit_function` 将导致程序挂起。
#
# 为了上下文,`_exit_function` 也被注册为 `atexit` 调用。
我不明白(@ssnl)为什么需要在 finally 块中这样做。
这段代码可以追溯到 2008 年,原始代码中没有任何注释。
PEP 371 或补丁 https://bugs.python.org/issue3050(包含 finally 块和`atexit`注册)对此进行了解释。
这解释了为什么 finally 块和`atexit`注册需要同时存在。
#
#
最后,另一个选择是当我们在`next`中看到错误时,仅通过逻辑关闭工作者。这并不理想,因为
a. 它阻止用户使用 try-catch 来恢复数据加载。
b. 如果用户持有引用,则无法防止挂起。
c. 它不会阻止用户在`next`中看到错误时使用 try-catch 来恢复数据加载。
迭代器。
#
如果任何进程意外地因致命信号而死亡,则所有进程都将退出。
#
如上所示,工作进程被设置为父进程的守护进程子进程。
然而,此类子进程的自动清理仅
如果父进程优雅退出(例如,不是通过致命信号如 SIGKILL),则会出现这种情况。因此,我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即,
所以我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即,
所以我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即,
所以我们必须确保每个进程都会退出,即使应该与之发送/接收数据的进程被杀掉,即,
#
在从队列中获取时,进程不会挂起。
#
即使数据依赖关系设计得非常精心(即,`put()` 总是与 `get()` 相对应),当队列中的数据损坏时(例如,由于 ...),在 `get()` 上仍然可能会发生挂起。
当队列中的数据损坏时(例如,由于 ...),在 `get()` 上仍然可能会发生挂起。
当队列中的数据损坏时(例如,由于 ...),在 `get()` 上仍然可能会发生挂起。
“取消加入线程”或意外退出。
#
对于子进程退出,我们每次尝试从`data_queue`获取数据时都会设置超时。
并在每个超时和错误时检查工作进程的状态。
出错。
查看 `_DataLoaderiter._get_batch()` 和 `_DataLoaderiter._try_get_data()` 以获取详细信息。
此外,对于非 Windows 平台上的子进程退出,我们还会注册一个 SIGCHLD 处理器(在 Windows 上受支持)。
#
另外,对于非 Windows 平台上的子进程退出,我们也会在 Windows 上注册一个 SIGCHLD 处理器。
在 Windows 上注册一个 SIGCHLD 处理器。
主流程,用于检查是否有任何工作进程失败,
(Python)处理器。与仅使用上述机制相比,这种方法在检测工作进程失败方面更高效、更快。
请参阅`DataLoader.cpp`和`_utils/signal_handling.py`以获取详细信息。
请参阅`DataLoader.cpp`和`_utils/signal_handling.py`以获取详细信息。
#
对于不是工作者的发送者进行的 `.get()` 调用,我们用超时来保护它们,并在超时发生时检查发送者的状态:
在工作者中,使用 `_utils.worker.ManagerWatchdog` 类
当超时发生时:
+ 在工作者中,使用 `_utils.worker.ManagerWatchdog` 类
检查主进程状态。
# + 如果 `pin_memory=True`,从 `pin_memory_thread` 获取时
定期检查 `pin_memory_thread` 状态
# 返回或查看 `pin_memory_thread` 已死。
#
# b. 将进程放入队列时不会挂起;
#
# 我们使用 `mp.Queue`,它有一个独立的后台线程将
# 对象从无界缓冲数组中放入。后台线程是守护线程,
# 通常在进程结束时自动结束。
# *退出*。
#
# 如果接收方在读取管道时突然结束,则连接将永远挂起。
# 在 Python 中,解决这个问题通常是通过调用 `q.cancel_join_thread`。
# 来实现的。
这将防止在最终确定时自动将其连接
(退出)。
#
然而,`cancel_join_thread` 只应在队列
**不会**被其他进程读取或写入时调用
# 处理过程,因为它可能持有锁或留下损坏的数据
# 在队列中,导致其他读者/写者挂起。
#
因此,
# + 对于工作进程,我们只这样做(对于它们的输出)
在退出前,请确保队列(例如,`worker_result_queue`)已被清空。
对于`pin_memory_thread`,其输出队列`data_queue`是一个`queue.Queue`,如果队列已满,则会进行阻塞`put`操作。
因此,不会出现上述问题,但结果是在
因此,不会出现上述问题,但结果是在
`_pin_memory_loop`,我们确实需要将`put`操作放在循环中
那样不仅会在成功时跳出循环,还会在主进程停止读取时跳出,即正在关闭。
对于加载进程,我们对所有进程调用`cancel_join_thread()`
+ 对于加载进程,我们为所有进程调用`cancel_join_thread()`
`_index_queues` 因为整个工作进程和 `pin_memory_thread` 的作用就是为了服务加载进程。如果加载进程已经正在退出,我们其实并不关心
`pin_memory_thread` 是为了服务加载进程。如果加载进程已经正在退出,我们其实并不关心
加载进程已经正在退出,我们其实并不关心它是否能够成功退出
队列已损坏。
#
#
现在让我们回到 1:
如何优雅地退出工作线程,当最后一个引用到
迭代器已消失。
#
实现此目的,我们实现了以下逻辑以及设计
如上所述的选择:
#
`workers_done_event`:
在主进程和所有工作进程之间共享的 `multiprocessing.Event`
# 处理进程。这用于通知工作进程迭代器正在关闭。
# 关闭。设置后,它们将不再向队列发送处理过的数据,
# 并且只等待最后的 `None` 才退出。
# `done_event` 并非必需。也就是说,我们只需检查 `None` 即可。
从输入队列中取出,但允许我们跳过浪费资源
正在关闭时处理数据。
#
`pin_memory_thread_done_event`: 粘性内存线程完成事件
一个用于类似目的的 `threading.Event`
`workers_done_event`,但它是为`pin_memory_thread`准备的。之所以需要单独的事件,是因为`pin_memory_thread`从
工作进程的输出队列中读取。但是,当工作进程看到`workers_done_event`被设置时,它们只想看到最终的`None`,
# 但工作进程,在看到`workers_done_event`被设置后,只想看到最终的`None`,并且是
# 是因为`pin_memory_thread`需要读取工作进程的输出队列。但是,当工作进程看到`workers_done_event`被设置时,
不需要刷新输出队列中的所有数据(例如,它可能调用该队列上的 `cancel_join_thread`,如果其 `IterableDataset` 迭代器偶然耗尽,这超出了主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以
`cancel_join_thread` 在该队列上(如果其 `IterableDataset` 迭代器偶然耗尽,这超出了主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以
`IterableDataset` 迭代器偶然耗尽,这超出了主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以
主进程的控制)。因此,由于我们将在退出 `pin_memory_thread` 之前退出,所以
# 工作者(见下文),使用两个独立的事件。
#
# NOTE:简而言之,协议是主进程将设置这些
# `done_event`,然后相应的进程/线程接收到`None`,
# 并且它们可以在收到`None`后随时退出。
#
# 注意:使用 `None` 作为最终信号是有效的,因为正常数据将
始终是一个包含两个元素的元组,其中第一个元素是数据的索引
# 转移(不同于数据集索引/键),第二个是
# 数据集键或数据样本(取决于哪一部分)
数据模型队列所在的位置。
#
[工作进程]
当加载进程存活时:
从 `index_queue` 获取。
如果获取到其他任何内容,
检查 `workers_done_event`。
如果已设置,则继续到下一个迭代
即,继续获取,直到看到 `None`,然后退出。
否则,处理数据:
如果是从`IterableDataset`中获取并迭代
# 已耗尽,发送 `_IterableDatasetStopIteration`
对象用于表示迭代结束。主进程,当
接收此类对象时,将发送 `None` 到此
# 工作者而不是使用相应的 `index_queue`
# 再也不了。
如果超时,
无论是否设置了`workers_done_event`(仍需要看到`None`)
无论是否,都必须继续到下一个迭代。
(循环外部)
如果设置了`workers_done_event`(在`IterableDataset`中这可能是`False`)
# `data_queue.cancel_join_thread()`。 (一切都在这里结束:)
主进程不会从中读取;
# 其他工作者也将调用
# 取消加入线程。
#
# [ 内存固定线程 ]
# # 无需检查主线程。如果此线程存活,则主加载
# # 线程必须存活,因为此线程被设置为守护线程。
# 当 `pin_memory_thread_done_event` 事件未设置时:
从 `worker_result_queue` 获取
如果超时,继续在下一次迭代中获取。
否则,处理数据。
当`pin_memory_thread_done_event`未设置时:
将处理过的数据放入`data_queue`(一个带有阻塞 put 的`queue.Queue`)
如果超时,则继续在下一次迭代中输入。
否则,中断,即继续到外部循环。
#
注意:我们不检查主线程的状态,因为
1. 如果进程被致命信号终止,`pin_memory_thread`
# 结束。
# 2. 在其他情况下,无论是 __del__ 中的清理还是守护线程的自动退出都将处理它。
# 这也不会忙等待,因为 `.get(timeout)` 不会
# 这也不会忙等待,因为 `.get(timeout)` 不会
# 繁忙等待。
#
# [主进程]
# 在 DataLoader Iter 的 `__del__` 中
# b. 退出 `pin_memory_thread`
# i. 设置 `pin_memory_thread_done_event`。
# ii 在 `worker_result_queue` 中放入 `None`。
# iii. 等待 `pin_memory_thread` 完成。
# iv. 调用 `worker_result_queue.cancel_join_thread()`。
#
# c. 退出工作进程。
# i. 设置 `workers_done_event`。
# ii. 在每个工作进程的 `index_queue` 中置为 `None`。
# iii. 等待工作进程结束。
在各个工作者的 `index_queue` 上调用 `.cancel_join_thread()`。
#
注意:(c) 放在 (b) 之后更好,因为它可能会留下损坏
# 工作结果队列中的数据,该数据由 `pin_memory_thread`
从读取,在这种情况下,`pin_memory_thread`只能
尽管在超时发生时会发生,这很慢,但如果是工人在不幸的时刻被信号杀死,情况也相同
但是在其他情况下,我们最好让`worker_result_queue`对`pin_memory_thread`保持非损坏状态
但是在其他情况下,我们最好让`worker_result_queue`对`pin_memory_thread`保持非损坏状态
但是在其他情况下,我们最好让`worker_result_queue`对`pin_memory_thread`保持非损坏状态
#
# 如果 `pin_memory=False`,则没有 `pin_memory_thread`,并且(b)
# 可以省略
#
# NB: `done_event`s 并不是严格必要的。例如,我们可以直接检查
# `index_queue` 中的 `None`,但这允许我们跳过浪费资源
如果我们已经在关闭,则处理 `index_queue` 中已经存在的索引。
。
def 初始化(self,
加载器):
超级().
初始化(
加载器)
self._prefetch_factor = 加载器.
预取因子
self._in_order = 加载器.
按顺序
断言 self._num_workers > 0
断言 self._prefetch_factor > 0
如果
加载器.
多进程上下文
是
无:
多进程上下文 =
火炬.
多进程
否则:
多进程上下文 =
加载器.
多进程上下文
self._worker_init_fn = 加载器.
工作进程初始化函数
# 添加向前兼容性,以便经典 DataLoader 可以与 DataPipes 一起工作:
# 额外的 worker 初始化函数将负责在 MP 和分布式中进行分片
如果 isinstance(self.
_数据集, (
迭代数据处理管道,
地图数据管道)):
self._worker_init_fn = functools.偏函数(
_sharding_worker_init_fn,
self._worker_init_fn,
self._world_size,
self.排名,
)
无法确定 multiprocessing_context 模块
self._工作结果队列 =
多进程上下文.
队列() # type: ignore[var-annotated]
self._工作进程 ID 集合 =
假
self.关闭 =
假
self.工人完成事件 =
多进程上下文.
活动()
self._index_queues = []
self._workers = []
为 i
在
范围(self._num_workers):
# 没有确定性,multiprocessing_context 模块
索引队列 =
多进程上下文.
队列() # type: ignore[var-annotated]
这里需要调用`cancel_join_thread`!
请参阅上面的(2)和(3b)部分。
索引队列.
取消加入线程()
w = 多进程上下文.
流程(
目标=
_工具.
工人.
工作循环,
参数=(
self.数据集类型,
self._数据集,
索引队列,
self._worker 结果队列,
self._工作者完成事件,
self._auto_collation,
self._collate_fn,
self._drop_last,
self._base_seed,
self.工人初始化函数,
i,
self._num_workers,
self._持久化工作进程,
self._共享种子,
),
)
w.守护进程 =
真实
# NB: Process.start() 实际上需要一些时间,因为它需要
# 启动一个进程并通过管道传递参数。
因此,我们只在它启动后将其添加到 self._workers 列表中,这样我们就不需要在程序死亡之前调用.join(),否则__del__尝试 join 时会得到:
AssertionError: 只能 join 已启动的进程。
因此,我们只在它启动后将其添加到 self._workers 列表中,这样我们就不需要在程序死亡之前调用.join(),否则__del__尝试 join 时会得到:
AssertionError: 只能 join 已启动的进程。
w.开始()
self._索引队列.
追加(
索引队列)
self._工作者.
追加(w)
如果 self.
_内存映射:
self._内存线程完成事件 =
线程.
活动()
队列未进行类型注解
self._数据队列 =
队列.
队列() # type: ignore[var-annotated]
当前设备 = -1
如果 self.
_内存设备 ==
cuda:
当前设备 =
火炬.cuda.
当前设备()
elif self._内存设备 ==
XPU:
当前设备 =
火炬.xpu.
当前设备()
elif self._内存设备 ==
火炬._C._get_privateuse1_backend_name():
自定义设备模块 = getattr(
火炬,
火炬._C._get_privateuse1_backend_name()
)
当前设备 =
自定义设备模块.
当前设备()
elif self._内存设备
是
无:
当前设备 =
火炬.
加速器.current_device_index()
内存线程 =
线程.Thread(
目标=
_工具.
持久化内存.
_内存循环,
参数=(
self._工作结果队列,
self._data_queue,
当前设备,
self._内存线程完成事件,
self._内存设备,
),
)
内存线程.
守护进程 =
真实
针对内存线程.
开始()
# 与工作者类似(见上方注释),我们只注册
# 一旦启动,就注册针对内存线程
self._pin_memory_thread = _pin_memory_thread
否则:
self._data_queue = self._worker_result_queue # 类型:忽略[赋值]
在某些罕见情况下,持久的工人(守护进程)
在主进程退出之前会被终止
当主进程退出时
这会导致当 pin_memory_thread 尝试读取时失败
工作结果队列中的损坏数据
在主进程退出前正确顺序使用 atexit 关闭线程和子进程
在主进程退出前正确顺序使用 atexit 关闭线程和子进程
如果 self._persistent_workers
和 self._pin_memory:
导入 atexit
为 w
在 self.
工人:
atexit.注册(
_多进程数据加载迭代器.
_清理工作进程, w)
# .pid 可以仅在进程尚未启动时为 None(不是这种情况,所以忽略)
_工具.
信号处理.
_设置工作进程 ID(id(self),
元组(w.
进程 ID
为 w
在 self.
_工作者))
# 类型:忽略[杂项]
_工具.
信号处理.
设置 SIGCHLD 处理程序()
self.设置工作进程 ID =
真实
self.重置(
加载器,
第一次迭代=True)
def 重置(self,
加载器,
第一次迭代=
错误):
超级().
重置(
加载器,
第一次迭代)
self.发送索引 = 0
# 下一个要发送给工作者的任务索引
self._rcvd_idx = 0 # 下一个在 __next__ 中返回的任务索引
# 尚未产生的数据信息,即索引在 [rcvd_idx, send_idx) 范围内的任务
# 地图:任务索引 => - (工作者 ID,) 如果数据未获取(待处理)
# \ (工作者 ID,数据) 如果数据已获取(顺序错误)
self._任务信息 = {}
self._待处理任务 = (
0 # always equal to count(v for v in task_info.values() if len(v) == 1)
)
# A list of booleans representing whether each worker still has work to
# do, i.e., not having exhausted its iterable dataset object. It always
# contains all `True`s if not using an iterable-style dataset
# (即,如果 kind 不等于可迭代)。
# 并不意味着这个工作者在这个 epoch 中还有工作要做。
# 这并不意味着工作者已经死亡。在`_persistent_workers`的情况下,
# 工作者将在下一个 epoch 中被重置为可用状态。
self.
_workers_status
工人状态 = [真实
为 i
在
范围(self._num_workers)]
# 每个工作者未完成的任务数量列表
# 当任务分配给工作者时递增
当那些数据已提供给主线程时递减
每个工作线程最多应有 self._prefetch_factor 个挂起的任务
self.工人_num_tasks = [0
为 i
在
范围(self._num_workers
]
重置工作队列循环,以便在下一个 epoch 从工作器 0 开始
self._worker_queue_idx_cycle = itertools.循环(
范围(self._num_workers))
# 我们在启用的情况下继续预取
如果 not
第一次迭代:
为
索引
在
范围(self._num_workers):
self._index_queues[索引].
放置(
_工具.
工人.
_简历迭代(self.
_共享种子)
)
简历迭代次数 = self._num_workers
while 简历迭代次数 > 0:
返回索引,
返回数据 = self.
_获取数据()
如果 isinstance(
返回索引,
_工具.
工人.
_恢复迭代):
断言
返回数据
是 None
简历迭代次数 -= 1
预先启动预取循环
为 _
在
范围(self._prefetch_factor * self._num_workers):
self.尝试设置索引()
def 尝试获取数据(self,
超时=
_工具.
MP 状态检查间隔):
尝试在给定超时时间内从`self._data_queue`获取数据一次。
这也可以用作无超时获取的内部循环。
循环条件为发送者状态。
#
如果任何工作进程意外死亡,将引发 `RuntimeError`。此错误
# 可来自 `_utils/signal_handling.py` 中的 SIGCHLD 处理程序
#(仅适用于非 Windows 平台),或以下手动检查错误
# 和超时。
#
返回一个 2 元组:
(bool:是否成功获取数据,any:成功获取数据时的数据,否则为 None)
try:
数据 = self.
_数据队列.
获取(
超时=
超时)
返回 (True,
数据)
除了
异常
作为 e:
在超时和错误情况下,我们手动检查是否有任何工作进程
# 失败。注意,这是 Windows 检测的唯一机制
# 工作失败。
failed_workers = []
为 worker_id, w
在
列举(self.
工人):
如果 self.
工人状态[worker_id]
和 not w.
是否存活():
失败的工人.
追加(w)
self.标记工人为不可用(worker_id)
如果
长度(
失败的工人) > 0:
进程 ID 字符串 =
“,”.
加入(
字符串(w.
进程 ID)
为 w
在
失败的工人)
提升
运行时错误(
f"数据加载器工作进程(进程 ID("{pids_str}
程序异常退出"
) 来自 e
如果 isinstance(e,
队列.
空的):
返回 (
错误,
无)
导入 errno
导入 tempfile
try:
如果接近文件描述符限制,则抛出异常。
显然,仅尝试打开一个文件是不够的
测试。
查看[ DataLoader 在 Linux 上和打开文件限制 ]
fds_limit_margin = 10
[临时文件.
命名临时文件()
为 i
在
范围(fds_limit_margin
]
除了 OSError
作为 e:
如果 e.errno ==
错误号.
文件描述过多:
提升
运行时错误(
文件打开太多。与“的”通信
工人不再可能。请增加“”
"在 shell 中使用`ulimit -n`限制或更改"
"通过调用"
" `torch.multiprocessing.set_sharing_strategy('file_system')`"
"在代码开头"
) 来自 None
提升
# 注意 [Linux 上的 DataLoader 和打开文件限制]
#
在 Linux 系统中,当使用多进程与 DataLoader 一起使用时,我们传递数据之间
通过 SHM 文件管理根进程和工作者。我们删除这些文件。
文件系统一旦创建就立即加载,并通过它们保持活跃
通过 AF_UNIX 套接字传递文件描述符。(参见
文档/源/multiprocessing.rst 和维基百科中的'多进程技术笔记'。)
(https://github.com/pytorch/pytorch/wiki)。)
#
这有时会导致我们超出打开文件的限制。当这种情况发生时,
犯罪文件描述符是通过套接字传来的,Python 的 `socket`
# 包无声地移除了文件描述符,只设置了
`MSG_CTRUNC` 标志(可能有点误导,因为手册说明)
# 表示由于空间不足而丢弃了一些控制数据
辅助数据缓冲区)。这可能会反映 C 实现的
# 阿福 UNIX 套接字。
#
# 这种行为可以通过下面的脚本和说明进行重现。
# 请在笔记底部查看。
#
# 当发生这种情况时,标准的 Python `multiprocessing`(而不是)
# `torch.multiprocessing`引发`RuntimeError: 收到 0 个 ancdata`项
#
# 有时,FD 没有被剥离,你可能会得到一个`OSError:`
# 在下面的脚本和 DataLoader 中,你可能会得到一个`Too many open files`,然而
# 这很罕见,并且似乎是非确定性的。
#
#
# #!/usr/bin/env python3
# import sys
# import socket
# import os
# 导入数组
# 导入 shutil 模块
# 导入 socket 模块
#
#
# 如果 sys.argv 的长度不等于 4:
打印("用法: ", sys.argv[0], " tmp_dirname 迭代 (发送|接收)")
sys.exit(1)
#
如果 __name__ == '__main__':
dirname = sys.argv[1]
# sock_path = dirname + "/sock"
# iterations = int(sys.argv[2])
# def dummy_path(i):
# return dirname + "/" + str(i) + ".dummy"
#
#
# 如果 sys.argv[3] == 'send':
# 当 os.path.exists(sock_path) 为 False 时:
# 空操作
# client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
# 客户端连接(sock_path)
# for i in range(iterations):
# fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
# ancdata = array.array('i', [fd])
# msg = bytes([i % 256])
# 打印("发送 fd ", fd, " (迭代 #", i, ")")
# client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
#
#
# else:
# assert sys.argv[3] == 'recv'
#
# 如果 os.path.exists(dirname):
# 抛出异常("目录存在")
#
# os.mkdir(dirname)
#
打开套接字...
# 服务器 = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
# 服务器绑定到套接字路径
#
# 正在监听...
# for i in range(iterations):
# a = array.array('i')
# msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
# assert(len(ancdata) == 1)
# cmsg_level, cmsg_type, cmsg_data = ancdata[0]
# a.frombytes(cmsg_data)
# 打印("收到 fd ", a[0], " (迭代 #", i, ")")
#
# 删除目录(dirname)
#
# 步骤重现:
#
# 1. 运行两个 shell,并在接收端设置文件描述符限制:
# (shell1) ulimit -n 1020
# (shell2) ulimit -n 1022
#
# 2. 在第一个 shell 中运行上面的脚本,使用`recv`选项
# (shell1) ./test_socket.py sock_tmp 1017 recv
#
# 3. 在第二个 shell 中运行脚本,使用`send`选项
# (shell2) ./test_socket.py sock_tmp 1017 send
def 获取数据(self):
# 从 `self._data_queue` 中获取数据。
#
# 我们每 `MP_STATUS_CHECK_INTERVAL` 秒检查一次工作者的状态,
# 通过运行 `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` 实现。
在循环中。这是检测工作失败的唯一机制。
对于 Windows。对于其他平台,也使用 SIGCHLD 处理程序来检测工作失败。
工作失败检测。
#
如果`pin_memory=True`,我们还需要检查`pin_memory_thread`是否已
# 在超时中死亡。
如果 self.
_超时 > 0:
成功,
数据 = self.
尝试获取数据(self.
超时)
如果
成功:
返回
数据
否则:
提升
运行时错误(
fDataLoader 超时后{self.
超时}
秒
)
elif self._内存映射:
while self._pin_memory_thread.是否存活():
成功,
数据 = self._try_get_data()
如果
成功:
返回
数据
否则:
# 当 while 条件为假时,即 pin_memory_thread 已死亡。
提升
运行时错误(
"Pin 内存线程意外退出")
在此情况下,`self._data_queue` 是一个 `queue.Queue`,但我们不需要调用 `.task_done()`,因为我们没有使用 `.join()`。
不需要调用 `.task_done()`,因为我们没有使用 `.join()`。
否则:
while True:
成功,
数据 = self._try_get_data()
如果
成功:
返回
数据
def _next_data(self):
while True:
如果负责 `self._rcvd_idx` 的工作线程已经结束
由于耗尽了一个 `IterableDataset`,无法完成此任务,
我们尝试将 `self._rcvd_idx` 前进以找到下一个有效的索引。
#
这部分需要在循环中运行,因为 `self._get_data()` 调用和下面的 `_IterableDatasetStopIteration` 检查都可以标记
这一部分需要在循环中运行,因为 `self._get_data()` 调用和下面的 `_IterableDatasetStopIteration` 检查都可以标记
多余的工人已死亡。
while self.接收索引 < self.
发送索引:
信息 = self.
任务信息.
获取(self.
接收索引,
无)
如果
信息:
工作者 ID =
信息[0]
如果 (
长度(
信息) == 2
或 self.
_workers 状态[worker_id]
): # 有数据或仍在活动状态
断开
删除 self.
_任务信息[self.
_接收索引]
self._rcvd_idx += 1
否则:
# 未找到有效的 `self._接收索引`(即未中断)
如果 not self.
_持久化工作进程:
self._关闭工作者()
提升 StopIteration
现在 self._rcvd_idx 是我们要获取的批次索引
检查下一个样本是否已经被生成
如果
长度(self.
任务信息[self.
_接收索引]) == 2:
worker_id, 数据 = self.
_任务信息.
弹出(self.
_接收索引)
self._rcvd_idx += 1
返回 self.
_处理数据(
数据, worker_id)
断言 not self.
关闭
和 self.
_待办任务 > 0
索引,
数据 = self.
获取数据()
self.未完成任务 -= 1
如果 self.
数据集类型 ==
_数据集类型.
迭代器:
检查 IterableDatasetStopIteration
如果 isinstance(
数据,
_工具.
工人.
_可迭代数据集停止迭代):
如果 self.
_持久化工作进程:
self._工作者状态[
数据.worker_id] =
假
否则:
self._标记工作者为不可用(
数据.worker_id)
self._尝试放入索引()
继续
如果
索引 != self.
接收索引:
如果 not self.
按顺序:
不要存储以备后用,立即处理
立即从 self._task_info 中删除
# 保持对象大小可控
员工 ID = self.
任务信息.
弹出(
索引
)0]
返回 self.
处理数据(
数据, worker_id)
# 存储乱序样本
self._任务信息[
索引] += (
数据,)
否则:
工作器 ID = self.
任务信息.
弹出(
索引
)0]
self._rcvd_idx += 1
返回 self.
处理数据(
数据, worker_id)
def 尝试设置索引(self):
最大任务 = self._prefetch_factor * self._num_workers
断言 self.
待处理任务 <
最大任务数
try:
索引 = self.
下一个索引()
除了
StopIteration 异常:
返回
为 _
在
范围(self._num_workers):
# 查找下一个活跃的工作者(如果有)
工作队列索引 =
下一(self.
工作队列索引循环)
如果 self.
工人状态[
工作队列索引
]:
如果 self.
_按顺序:
断开
elif self._工作者任务数[
工作队列索引] <
最大任务数 //
总和(
self.工作者状态
):
当 self._in_order 为 False 时,如果工作者有容量,则分配工作给工作者
_workers_status 只在此线程中更新,因此总和保证大于 0
断开
否则:
未找到(即未中断)
返回
self._index_queues[工作队列索引].
放置((self.
发送索引,
索引)) # type: ignore[possibly-undefined]
self.任务信息[self.
发送索引] = (
工作队列索引,)
self.工作者任务数[
工作队列索引] += 1
self.未完成任务数 += 1
self.发送索引 += 1
def 处理数据(self,
数据, worker_idx):
self._workers_num_tasks[worker_idx] -= 1
self.尝试设置索引()
如果 isinstance(
数据, ExceptionWrapper):
数据.reraise()
返回
数据
def 标记工人为不可用(self, worker_id,
关闭=
错误):
标记一个工作进程已完成其工作,例如,由于耗尽一个 `IterableDataset`。
这应在 `_MultiProcessingDataLoaderIter` 继续运行时使用。
# `_MultiProcessingDataLoaderIter` 即将运行。
断言 self.
工人状态[worker_id]
或 (
self._persistent_workers 和
关闭
)
# 向特定工作者发送终止信号。
q = self._index_queues[worker_id]
# 表示当前进程不再向该队列添加更多数据。
# 结束。
q.放置(
无)
注意这里我们并没有在这里连接工作线程,也没有从 C 侧结构中移除工作线程的进程 ID,因为(1)连接可能很慢,并且
# (2)由于我们不连接,工作线程可能仍然会引发错误,而我们更愿意捕获这些错误,而不是忽略它们,即使它们
# (3)可能不会立即影响主线程的执行
# ,我们仍然希望记录这些错误,以便于后续分析
在工人完成其工作后引发。
将连接推迟到 `_shutdown_workers`,它在调用时执行
所有工作者完成他们的工作(例如,`IterableDataset`副本)
当这个迭代器被垃圾回收时。
self.工人状态[worker_id] =
假
断言 self.
_工作者完成事件.
已设置() ==
关闭
def _关闭工作者(self):
# 当关闭此 `_MultiProcessingDataLoaderIter` 时调用。
# 详细请参阅 NOTE [数据加载器多进程关闭逻辑]。
# 该函数的逻辑请参阅。
如果 (
_utils 是 None
或
_工具.
python 退出状态
是
真实
或
_工具.
python 退出状态
是 None
):
# 参考第(2)条注释。如果 Python 正在关闭,则执行无操作。
返回
# 当最后一个引用消失/迭代器耗尽时正常退出。
# 查看(1)和笔记的第二部分。
如果 not self.
关闭:
self.关闭 =
真实
try:
# 当最后一个引用消失/迭代器耗尽时正常退出。
# 查看(1)和笔记的第二部分。
首先退出 `pin_memory_thread`,因为退出工作进程可能会在 `worker_result_queue` 中留下
被损坏的数据,而 `pin_memory_thread` 会从中读取。
。
如果
有属性(self,
_pin_memory_thread):
# 使用 hasattr 以防在设置属性之前发生错误。
self._内存线程完成事件.
集合()
# 如果 pin_memory_thread 正在等待,则发送信息给它。
# 这样它可以醒来并检查 `pin_memory_thread_done_event`。
self._工作结果队列.
放置((
无,
无))
self._pin_memory_thread.加入()
self._工作结果队列.
取消加入线程()
self._工作结果队列.
关闭()
立即退出工作者。
self._工作者完成事件.
集合()
为
员工 ID
在
范围(
长度(self.
工人)):
应从 `len(self._workers)` 获取工作者数量,而不是从 `self._num_workers`,以防在启动所有工作者之前发生错误。
`self._num_workers`,以防在启动所有工作者之前发生错误。
# workers.
如果我们使用带有持久工作者的 workers_status
我们必须将其关闭,因为工作器已暂停
如果 self._persistent_workers
或 self.
工人状态[worker_id
]:
self.标记工人为不可用(worker_id,
关闭=True)
为 w
在 self.
工人:
我们应该能够在这里连接,但如果出了问题,
我们设置了一个超时,如果工作者未能连接,
它们将在`finally`块中被杀死。
w.加入(
超时=
_工具.
MP 状态检查间隔)
为 q
在 self._index_queues:
q.取消加入线程()
q.关闭()
最后:
即使这个函数所做的只是将任务放入队列,
即使我们在其上调用`cancel_join_thread`,当工作进程被信号杀死时,
例如,在`Event.set()`中挂起时,也可能发生奇怪的事情。
因此,我们需要用 SIGCHLD 处理程序来保护这一点。
仅从 C 侧数据结构中删除 pids
结束。
#
# FIXME: 很遗憾,对于 Windows,我们缺少一个工作进程
# 此函数中包含错误检测机制,如下
# 不提供 SIGCHLD 处理程序。
如果 self._worker_pids_set:
_工具.
信号处理._remove_worker_pids(id(self))
self._工作进程 ID 集合 =
假
为 w
在 self.
工人:
如果 w.
是否存活():
# 现有的机制试图让工作进程退出
平静地,但不幸的是,如果我们不幸达到
这里,我们不应该到达的(例如,pytorch/pytorch#39570),
我们会杀死这个工作进程。
w.终止()
静态方法用于移除对 `_MultiProcessingDataLoaderIter` 的引用
@staticmethod
def 清理工作线程(w):
try:
w.加入(
超时=
_工具.
MP 状态检查间隔)
最后:
如果 w.
是否存活():
w.终止()
def __del__(self):
self.关闭工作线程()