• 文档 >
  • torch.utils.data
快捷键

torch.utils.data 注释

PyTorch 数据加载工具的核心是 torch.utils.data.DataLoader 类。它表示一个 Python 迭代器,遍历数据集,并支持

  • 沉浸式地图样式和可迭代样式数据集,

  • 自定义数据加载顺序,

  • 自动批量处理,

  • 单进程和多进程数据加载,

  • 自动内存固定。

这些选项由构造函数的参数配置,该参数的签名是:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

下面的部分详细描述了这些选项的效果和用法。

数据集类型 ¶

DataLoader 构造函数最重要的参数是 dataset ,它表示一个用于加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:

  • 映射风格的数据集,

  • 可迭代风格的数据集。

映射风格的数据集 ¶

地图风格的数据集是指实现了 __getitem__()__len__() 协议,并将(可能是非整数的)索引/键映射到数据样本的映射。

例如,当使用 dataset[idx] 访问此类数据集时,可以从磁盘上的文件夹中读取第 idx 张图像及其对应的标签。

更多详情请见 Dataset

可迭代风格的数据集 ¶

可迭代式数据集是 IterableDataset 的子类实例,实现了 __iter__() 协议,并代表对数据样本的可迭代访问。此类数据集特别适用于随机读取昂贵或甚至不可能的情况,以及批量大小依赖于获取的数据的情况。

例如,此类数据集在调用 iter(dataset) 时,可以返回从数据库、远程服务器或实时生成的日志的数据流。

更多详情请见 IterableDataset

注意

当使用多进程数据加载的 IterableDataset 时。相同的数据集对象在每个工作进程中都会被复制,因此副本必须配置不同,以避免数据重复。有关如何实现此操作的文档,请参阅 IterableDataset

数据加载顺序和 Sampler

对于迭代式数据集,数据加载顺序完全由用户定义的迭代器控制。这允许更容易地实现分块读取和动态批量大小(例如,通过每次产生一个批处理样本)。

本节其余部分涉及具有映射式数据集的情况。 torch.utils.data.Sampler 类用于指定数据加载中使用的索引/键的顺序。它们代表对数据集索引的可迭代对象。例如,在常见的随机梯度下降(SGD)情况下, Sampler 可以随机排列索引列表,并逐个产生每个索引,或者为小批量 SGD 产生少量索引。

将根据 shuffle 参数自动构建一个顺序或洗牌采样器,或者用户可以使用 sampler 参数指定一个自定义的 Sampler 对象,该对象在每个时间点产生下一个要获取的索引/键。

可以将一个自定义的 Sampler ,该 Sampler 在每次产生一个批次的索引列表,作为 batch_sampler 参数传递。也可以通过 batch_sizedrop_last 参数启用自动批处理。有关更多详细信息,请参阅下一节。

注意

由于此类数据集没有键或索引的概念,因此 samplerbatch_sampler 都不兼容可迭代式数据集。

加载批处理和非批处理数据

DataLoader 支持通过参数 batch_sizedrop_lastbatch_samplercollate_fn (具有默认函数)自动将单独获取的数据样本整理成批次(默认函数)。

自动批量处理(默认)

这是最常见的用例,对应于获取一个数据小批量并将它们整理成批处理样本,即包含具有一个维度为批处理维度的张量(通常为第一个维度)。

batch_size (默认 1 )不是 None 时,数据加载器将返回批处理样本而不是单个样本。 batch_sizedrop_last 参数用于指定数据加载器如何获取数据集键的批次。对于映射样式数据集,用户可以指定 batch_sampler ,它一次返回一个键的列表。

注意

batch_sizedrop_last 参数本质上用于从 sampler 构建 batch_sampler 。对于映射样式数据集, sampler 由用户提供或基于 shuffle 参数构建。对于可迭代样式数据集, sampler 是一个虚拟的无限迭代器。请参阅有关采样器的更多详细信息。

注意

在从多进程的迭代式数据集获取数据时, drop_last 参数会丢弃每个工作进程数据副本的最后一个非满批次。

使用采样器索引获取样本列表后,将作为 collate_fn 参数传递的函数用于将样本列表汇总成批次。

在这种情况下,从映射式数据集加载大致等同于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

以及从迭代式数据集加载大致等同于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

可以使用自定义的 collate_fn 来定制排序,例如将序列数据填充到批次的最大长度。请参阅本节了解有关 collate_fn 的更多信息。

禁用自动批处理

在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者简单地加载单个样本。例如,直接加载批处理数据(例如从数据库的批量读取或读取内存的连续块)可能更便宜,或者批处理大小与数据相关,或者程序设计为处理单个样本。在这些情况下,最好不使用自动批处理(其中 collate_fn 用于排序样本),而是让数据加载器直接返回 dataset 对象的每个成员。

batch_sizebatch_sampler 都是 Nonebatch_sampler 的默认值已经是 None )时,自动批处理被禁用。从 dataset 获得的每个样本都使用作为 collate_fn 参数传递的函数进行处理。

当自动批处理被禁用时,默认的 collate_fn 只是将 NumPy 数组转换为 PyTorch 张量,其他一切保持不变。

在这种情况下,从地图样式数据集中加载数据大致等同于:

for index in sampler:
    yield collate_fn(dataset[index])

从可迭代样式数据集中加载数据大致等同于:

for data in iter(dataset):
    yield collate_fn(data)

请参阅本节了解更多关于 collate_fn 的信息。

collate_fn ¶合作

collate_fn 在自动批处理启用或禁用时使用略有不同。

当自动批处理被禁用时,每个单独的数据样本都会调用 collate_fn ,输出从数据加载器迭代器中产生。在这种情况下,默认的 collate_fn 简单地将 NumPy 数组转换为 PyTorch 张量。

当启用自动批处理时,每次调用 collate_fn 时都会带有一个数据样本列表。它预期将输入样本收集到一个批次中,以便从数据加载器迭代器中产生。本节其余部分描述了默认的 collate_fndefault_collate() )的行为。

例如,如果每个数据样本由一个 3 通道图像和一个积分类别标签组成,即数据集中的每个元素返回一个元组 (image, class_index) ,默认的 collate_fn 将此类元组列表合并成一个包含批处理图像张量和批处理类别标签张量的单个元组。特别是默认的 collate_fn 具有以下特性:

  • 它始终在前面添加一个新维度作为批处理维度。

  • 它自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。

  • 它保留数据结构,例如,如果每个样本是一个字典,则输出具有相同键集的字典,但值是批处理张量(如果值不能转换为张量,则为列表)。同样适用于 listtuplenamedtuple 等。

用户可以使用自定义的 collate_fn 来实现自定义批处理,例如按除第一个维度以外的维度整理,填充不同长度的序列,或添加对自定义数据类型的支持。

如果您遇到 DataLoader 的输出维度或类型与您的预期不同的情况,您可能需要检查您的 collate_fn

单进程和多进程数据加载 ¶

DataLoader 默认使用单进程数据加载。

在 Python 进程中,全局解释器锁(GIL)阻止了 Python 代码在多个线程中真正并行执行。为了避免数据加载阻塞计算代码,PyTorch 提供了一个简单的开关,通过将参数 num_workers 设置为正整数来执行多进程数据加载。

单进程数据加载(默认)

在此模式下,数据获取是在初始化 DataLoader 的同一进程中完成的。因此,数据加载可能会阻塞计算。然而,当用于进程间共享数据(例如,共享内存、文件描述符)的资源有限,或者整个数据集很小且可以完全加载到内存中时,此模式可能更受欢迎。此外,单进程加载通常显示更易读的错误跟踪,因此对于调试很有用。

多进程数据加载

将参数 num_workers 设置为正整数将启用多进程数据加载,并指定加载工作进程的数量。

警告

经过多次迭代后,加载工作进程将消耗与父进程相同的 CPU 内存,用于父进程中所有从工作进程访问的 Python 对象。如果数据集包含大量数据(例如,在构建数据集时加载一个非常大的文件名列表)并且/或者使用大量工作进程(总内存使用量为 number of workers * size of parent process ),则可能会出现问题。最简单的解决方案是将 Python 对象替换为非引用计数表示,例如 Pandas、Numpy 或 PyArrow 对象。有关此问题的更多详细信息以及如何解决这些问题的示例代码,请参阅问题#13246。

在此模式下,每次创建 DataLoader 的迭代器时(例如,当您调用 enumerate(dataloader) 时),将创建 num_workers 个工作进程。此时, datasetcollate_fnworker_init_fn 将传递给每个工作进程,它们用于初始化和获取数据。这意味着数据集访问及其内部 IO、转换(包括 collate_fn )都在工作进程中运行。

torch.utils.data.get_worker_info() 在工作进程中返回各种有用的信息(包括工作进程 ID、数据集副本、初始种子等),并在主进程中返回 None 。用户可以在数据集代码中使用此函数,并通过 worker_init_fn 单独配置每个数据集副本,以及确定代码是否在工作进程中运行。例如,这可以在数据集分片时特别有用。

对于映射样式的数据集,主进程使用 sampler 生成索引,并将它们发送到工作进程。因此,任何洗牌随机化都在主进程中完成,通过分配索引来引导加载。

对于可迭代样式的数据集,由于每个工作进程都获得 dataset 对象的副本,因此简单的多进程加载通常会得到重复的数据。用户可以使用 torch.utils.data.get_worker_info() 和/或 worker_init_fn 独立配置每个副本。(有关如何实现此操作的说明,请参阅 IterableDataset 文档。)出于类似原因,在多进程加载中, drop_last 参数将丢弃每个工作进程的可迭代数据集副本的最后一个非完整批次。

当迭代结束时或迭代器被垃圾回收时,将关闭工作进程。

警告

通常不建议在多进程加载时返回 CUDA 张量,因为 CUDA 和多进程共享 CUDA 张量存在许多细微之处(参见多进程中的 CUDA)。相反,我们建议使用自动内存固定(即设置 pin_memory=True ),这可以启用快速数据传输到 CUDA 支持的 GPU。

平台特定行为

由于工作进程依赖于 Python multiprocessing ,与 Unix 相比,Windows 上的工作进程启动行为不同。

  • 在 Unix 上, fork() 是默认的 multiprocessing 启动方法。使用 fork() ,子工作进程通常可以直接通过复制的地址空间直接访问 dataset 和 Python 参数函数。

  • 在 Windows 或 MacOS 上, spawn() 是默认的 multiprocessing 启动方法。使用 spawn() 时,将启动另一个解释器来运行您的主脚本,随后是内部工作函数,该函数通过 pickle 序列化接收 datasetcollate_fn 和其他参数。

这种独立的序列化意味着,在使用多进程数据加载时,您需要采取两个步骤来确保与 Windows 兼容:

  • 将您主脚本的大部分代码包裹在 if __name__ == '__main__': 块中,以确保在每次启动工作进程时不会再次运行(很可能会生成错误)。您可以将数据集和 DataLoader 实例创建逻辑放在这里,因为它们不需要在工作进程中重新执行。

  • 确保任何自定义的 collate_fnworker_init_fndataset 代码都被声明为顶级定义,位于 __main__ 检查之外。这确保它们在工作进程中可用。(这是必要的,因为函数仅作为引用序列化,而不是 bytecode 。)

多进程数据加载中的随机性 ¶

默认情况下,每个工作进程都将将其 PyTorch 随机数种子设置为 base_seed + worker_id ,其中 base_seed 是主进程使用其 RNG 生成的长随机数(因此,必须消耗 RNG 状态)或指定的 generator 。然而,在初始化工作进程时,其他库的种子可能会重复,导致每个工作进程返回相同的随机数。(参见 FAQ 中的本节。)

worker_init_fn 中,您可以使用 torch.utils.data.get_worker_info().seedtorch.initial_seed() 访问每个工作进程的 PyTorch 随机数种子,并在数据加载之前使用它来初始化其他库。

内存固定 ¶

当主机到 GPU 的复制操作从固定(页面锁定)内存发起时,速度会快得多。有关何时以及如何一般使用固定内存缓冲区的详细信息,请参阅使用固定内存缓冲区。

对于数据加载,将 pin_memory=True 传递给 DataLoader 将自动将获取的数据 Tensors 放入固定内存中,从而实现更快的 CUDA 启用 GPU 数据传输。

默认内存固定逻辑仅识别 Tensors 以及包含 Tensors 的映射和可迭代对象。默认情况下,如果固定逻辑看到的是一个自定义类型的数据批次(如果您有一个返回自定义批次类型的 collate_fn ),或者如果批次中的每个元素都是自定义类型,固定逻辑将不会识别它们,并且将返回该批次(或这些元素)而不固定内存。要为自定义批次或数据类型启用内存固定,请在您的自定义类型上定义一个 pin_memory() 方法。

请参阅下面的示例。

示例:

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)[source][source]

数据加载器结合数据集和采样器,并为给定数据集提供可迭代的对象。

DataLoader 支持使用单进程或多进程加载的映射风格和可迭代风格的数据集,可自定义加载顺序和可选的自动批处理(收集)以及内存固定。

更多详细信息,请参阅 torch.utils.data 文档页面。

参数:
  • 数据集(数据集)- 从中加载数据的数据集。

  • 批大小(整数,可选)- 每批加载的样本数量(默认: 1 )。

  • 打乱顺序(布尔值,可选)- 设置为 True 以在每个 epoch 重新打乱数据(默认: False )。

  • 样本抽取器(Sampler 或可迭代对象,可选)- 定义从数据集中抽取样本的策略。可以是实现了 __len__ 的任何 Iterable 。如果指定,则必须不指定 shuffle

  • batch_sampler(采样器或可迭代对象,可选)- 类似于 sampler ,但每次返回一批索引。与 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers(整数,可选)- 用于数据加载的子进程数量。 0 表示数据将在主进程中加载。(默认: 0

  • collate_fn(可调用对象,可选)- 将样本列表合并成一个包含张量的迷你批次。在从映射式数据集使用批量加载时使用。

  • pin_memory(布尔值,可选)- 如果 True ,数据加载器将在返回之前将张量复制到设备/CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的 collate_fn 返回的批次是自定义类型,请参阅下面的示例。

  • drop_last (bool, 可选) – 设置为 True 以丢弃最后一个不完整的批次,如果数据集大小不能被批次大小整除。如果 False 且数据集大小不能被批次大小整除,则最后一个批次将更小。(默认: False )

  • timeout (数值,可选) – 如果为正数,则为从工作者那里收集一个批次的超时值。应始终为非负数。(默认: 0 )

  • worker_init_fn (Callable,可选) – 如果不为 None ,则将在每个工作者子进程中以工作者 ID(一个整数, [0, num_workers - 1] )作为输入调用此函数,在播种之后和加载数据之前。(默认: None )

  • multiprocessing_context (str 或 multiprocessing.context.BaseContext,可选) – 如果为 None ,则使用操作系统的默认多进程上下文。(默认: None )

  • generator (torch.Generator, 可选) – 如果不是 None ,则此随机数生成器将被 RandomSampler 用于生成随机索引,以及用于工作进程的 base_seed 。(默认: None )

  • prefetch_factor (int, 可选,关键字参数) – 每个工作进程预先加载的批次数。 2 表示所有工作进程将总共预取 2 * num_workers 批次。 (默认值取决于 num_workers 的设置值。如果 num_workers 的值为 0,则默认为 None 。否则,如果 num_workers > 0 的默认值为 2 。)

  • persistent_workers (bool, 可选) – 如果 True ,则数据加载器在消耗完数据集后不会关闭工作进程。这允许保持工作进程的 Dataset 实例存活。 (默认: False )

  • pin_memory_device (str, 可选) – 如果 pin_memoryTrue ,则在 pin_memory 上。如果没有给出,则当前加速器将是默认值。此参数不建议使用,并可能被弃用。

  • in_order(布尔值,可选)- 如果 False ,数据加载器将不会强制要求返回的批次按先入先出顺序排列。仅当 num_workers > 0 时适用。(默认: True

警告

如果使用 spawn 启动方法, worker_init_fn 不能是一个不可序列化的对象,例如 lambda 函数。有关 PyTorch 中多进程的更多详细信息,请参阅多进程最佳实践。

警告

len(dataloader) 启发式算法基于采样器的长度。当 datasetIterableDataset 时,它将基于 len(dataset) / batch_size 返回一个估计值,并根据 drop_last 进行适当的舍入,无论多进程加载配置如何。这代表了 PyTorch 能做出的最佳猜测,因为 PyTorch 相信用户 dataset 代码能够正确处理多进程加载以避免数据重复。

然而,如果分片导致多个工作进程拥有不完整的最后一个批次,这个估计仍然可能不准确,因为(1)一个本应完整的批次可能被分割成多个批次,并且(2)当 drop_last 被设置时,可能丢弃超过一个批次值的样本。不幸的是,PyTorch 通常无法检测到这种情况。

请参阅数据集类型,以获取有关这两种数据集的更多详细信息以及 IterableDataset 如何与多进程数据加载交互。

警告

请参阅可重现性、我的数据加载器工作进程返回相同的随机数以及多进程数据加载中的随机性相关笔记,以了解随机种子相关问题。

警告

将 in_order 设置为 False 可能会损害可重现性,并可能导致在数据不平衡的情况下,倾斜的数据分布被喂给训练器。

class torch.utils.data.Dataset[source][source]

表示一个抽象类的 Dataset

所有表示从键到数据样本映射的数据集都应该继承它。所有子类都应该重写 __getitem__() ,以支持获取给定键的数据样本。子类还可以选择性地重写 __len__() ,这在许多 Sampler 实现和 DataLoader 默认选项中预期返回数据集的大小。子类还可以选择性地实现 __getitems__() ,以加速批量样本加载。此方法接受样本批次的索引列表,并返回样本列表。

注意

DataLoader 默认构建一个产生整数索引的索引采样器。为了使其与非整数索引/键的映射式数据集兼容,必须提供自定义采样器。

class torch.utils.data.IterableDataset[source][source]

一个可迭代的数据集。

所有表示数据样本可迭代的集合都应该继承它。这种形式的集合在数据来自流时尤其有用。

所有子类都应该重写 __iter__() ,这将返回该数据集中的样本迭代器。

当子类与 DataLoader 一起使用时,数据集中的每个项目都将从 DataLoader 迭代器中产生。当 num_workers > 0 时,每个工作进程将拥有数据集对象的不同副本,因此通常希望独立配置每个副本以避免从工作进程返回重复的数据。 get_worker_info() ,当在工作进程中调用时,返回有关工作进程的信息。它可以用在数据集的 __iter__() 方法或 DataLoaderworker_init_fn 选项中,以修改每个副本的行为。

示例 1:在 __iter__() 中将工作负载分配给所有工作者

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

示例 2:使用 worker_init_fn 在所有工作者之间分配工作负载

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)[source][source]

包装张量的数据集。

每个样本将通过沿第一个维度索引张量来检索。

参数:

*张量(Tensor)- 第一个维度大小相同的张量。

class torch.utils.data.StackDataset(*args, **kwargs)[source][source]

数据集作为多个数据集的堆叠。

这个类用于组装复杂输入数据的各个部分,这些数据以数据集的形式给出。

示例

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
参数:
  • *args(数据集)- 返回的数据集以元组形式堆叠。

  • **kwargs(数据集)- 返回的数据集以字典形式堆叠。

class torch.utils.data.ConcatDataset(datasets)[source][source]

数据集作为多个数据集的拼接。

此类可用于组装不同的现有数据集。

参数:

数据集(序列)- 要拼接的数据集列表

class torch.utils.data.ChainDataset(datasets)[source][source]

多个 IterableDataset 的数据集。

此类可用于组装不同的现有数据集流。链式操作是即时完成的,因此使用此类连接大规模数据集将非常高效。

参数:

datasets (IterableDataset 的可迭代集合) – 要连接的数据集

class torch.utils.data.Subset(dataset, indices)[source][source]

指定索引的子数据集。

参数:
  • 数据集(数据集)- 整个数据集

  • 索引(序列)- 在整个集合中选择的子集索引

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[source][source]

处理每个批次中元素集合的通用归并函数。

该函数还打开功能注册表以处理特定元素类型。default_collate_fn_map 为张量、NumPy 数组、数字和字符串提供默认归并函数。

参数:
  • batch – 要归并的单个批次。

  • collate_fn_map(可选[dict[Union[type, tuple[type, ...]], Callable]])- 可选字典,将元素类型映射到相应的归并函数。如果元素类型不在此字典中,则该函数将按插入顺序遍历字典的每个键,如果元素类型是键的子类,则调用相应的归并函数。

示例

>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle batch of tensors
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

注意

每个归一化函数都需要一个位置参数用于批次和一个关键字参数用于归一化函数字典,即 collate_fn_map。

torch.utils.data.default_collate(batch)[source][source]

接收一个数据批次,并将批次中的元素放入一个具有额外外部维度的张量中 - 批次大小。

确切的输出类型可以是 torch.Tensortorch.Tensor 的序列, torch.Tensor 的集合,或者根据输入类型保持不变。当在 DataLoader 中定义 batch_size 或 batch_sampler 时,这用作归一化的默认函数。

这里是通用输入类型(基于批次中元素的类型)到输出类型的映射:

  • torch.Tensor -> torch.Tensor (添加了外部批次大小维度)

  • NumPy 数组 -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str (未改变)

  • bytes -> bytes (未改变)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

参数:

批次 – 要合并的单个批次

示例

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle specific type
>>> # Option 1: Write custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: In-place modify `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> default_collate_fn_map.update(CustomType, collate_customtype_fn)
>>> default_collate(batch)  # Handle `CustomType` automatically
torch.utils.data.default_convert(data)[source][source]

将每个 NumPy 数组元素转换为 torch.Tensor

如果输入是序列、集合或映射,则尝试将内部每个元素转换为 torch.Tensor 。如果输入不是 NumPy 数组,则保持不变。这是在 DataLoader 中未定义 batch_sampler 和 batch_size 时的默认函数,用于合并。

输入类型到输出类型的映射与 default_collate() 类似。更多详细信息请参阅那里的描述。

参数:

data – 要转换的单个数据点

示例

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
torch.utils.data.get_worker_info()[source][source]

返回当前 DataLoader 迭代器工作进程的信息。

当在工作进程中调用时,此函数返回一个保证具有以下属性的对象:

  • id :当前工作进程 ID。

  • num_workers : 总工人数。

  • seed : 当前工作者的随机种子设置。此值由主进程的 RNG 和工作者 ID 决定。有关详细信息,请参阅 DataLoader 的文档。

  • dataset : 此进程中的数据集对象副本。请注意,这将在主进程中的不同进程中是不同的对象。

当在主进程中调用时,此函数返回 None

注意

当用于传递给 worker_init_fnDataLoader 时,此方法可以用来分别设置每个工作进程,例如,使用 worker_id 来配置 dataset 对象,使其只读取分片数据集的一部分,或者使用 seed 来初始化数据集代码中使用的其他库。

返回类型:

可选[WorkerInfo]

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source][source]

随机将数据集分割为给定长度的非重叠新数据集。

如果给定一个总和为 1 的分数列表,则将自动计算每个分数的长度,长度为 floor(frac * len(dataset))。

计算长度后,如果有任何余数,则将 1 个计数以轮询方式分配给长度,直到没有余数为止。

可选地固定生成器以获得可重复的结果,例如:

示例

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
参数:
  • dataset (数据集) – 要分割的数据集

  • 长度(序列)- 要生成的分割的长度或分数

  • 生成器(Generator)- 用于随机排列的生成器

返回类型:

列[torch.utils.data.dataset.Subset[~_T]]

类 torch.utils.data.Sampler(data_source=None)[source][source]

所有采样器的基类。

每个采样器子类都必须提供一个 __iter__() 方法,提供遍历数据集元素索引或索引列表(批次)的方式,并且可以提供一个 __len__() 方法,返回返回迭代器的长度。

参数:

data_source(数据集)- 此参数未使用,将在 2.2.0 版本中删除。您仍然可以自定义使用它的实现。

示例

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

注意

__len__() 方法不是 DataLoader 严格要求的,但在涉及 DataLoader 长度的任何计算中都是期望的。

class torch.utils.data.SequentialSampler(data_source)[source][source]

依次采样元素,始终按相同顺序。

参数:

data_source (Dataset) – 从中采样的数据集

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source][source]

随机抽取样本。如果不带替换,则从打乱的数据集中抽取。

如果带替换,则用户可以指定 num_samples 来抽取。

参数:
  • data_source(数据集)- 从中抽取样本的数据集

  • replacement(布尔值)- 如果为 True ,则按需带替换抽取样本,默认为 ``False``

  • num_samples (int) – 要抽取的样本数量,默认为`len(dataset)`。

  • generator (Generator) – 用于抽取的生成器。

class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source][source]

从给定的索引列表中随机抽取样本,不重复。

参数:
  • 索引(序列)- 索引的序列

  • 生成器(Generator)- 用于采样的生成器

类 torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source][source] ¶

根据给定的概率(权重)从 [0,..,len(weights)-1] 中采样元素。

参数:
  • 权重(序列)- 一系列权重,不一定要加起来等于一

  • num_samples(整数)- 要抽取的样本数量

  • replacement(布尔值)- 如果是 True ,则进行有放回的抽样。如果不是,则进行无放回抽样,这意味着当抽取一个样本索引时,该索引在该行中不能再被抽取。

  • generator(生成器)- 用于抽样的生成器

示例

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source][source]

包装另一个采样器以生成索引的迷你批次。

参数:
  • sampler (采样器或可迭代对象) – 基础采样器。可以是任何可迭代对象

  • batch_size (int) – 迷你批次的大小。

  • drop_last (bool) – 如果 True ,采样器将丢弃最后一个批次,如果其大小将小于 batch_size

示例

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source][source]

采样器,限制数据加载到数据集的子集。

特别适用于与 torch.nn.parallel.DistributedDataParallel 结合使用。在这种情况下,每个进程可以将 DistributedSampler 实例作为 DataLoader 采样器传递,并加载原始数据集的子集,该子集仅属于它。

注意

数据集假设为固定大小,并且任何实例都始终返回相同的元素,且顺序相同。

参数:
  • 数据集(Dataset)- 用于采样的数据集。

  • num_replicas(int,可选)- 参与分布式训练的进程数。默认情况下,从当前分布式组中检索 world_size

  • rank(int,可选)- 当前进程在 num_replicas 中的排名。默认情况下,从当前分布式组中检索 rank

  • shuffle(布尔值,可选)- 如果 True (默认),采样器将打乱索引。

  • seed(整数,可选)- 用于打乱采样器的随机种子。此数字应在分布式组中的所有进程中保持一致。默认: 0

  • drop_last(布尔值,可选)- 如果 True ,则采样器将丢弃数据尾部以使其能够被副本数整除。如果 False ,采样器将添加额外的索引以使数据能够被副本数整除。默认: False

警告

在分布式模式下,在每个 epoch 开始时调用 set_epoch() 方法,在创建 DataLoader 迭代器之前,以确保跨多个 epoch 正确打乱。否则,将始终使用相同的顺序。

示例:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源