• 文档 >
  • torch >
  • torch.repeat_interleave
快捷键

torch.repeat_interleave

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) Tensor

重复张量中的元素。

警告

这与 torch.Tensor.repeat() 不同,但与 numpy.repeat 相似。

参数:
  • input (Tensor) – 输入张量。

  • repeats(张量或 int)- 每个元素的重复次数。repeats 会广播以适应给定轴的形状。

  • dim(int,可选)- 重复值的维度。默认情况下,使用展平的输入数组,并返回一个扁平的输出数组。

关键字参数:

输出大小(int,可选)- 给定轴上的总输出大小(例如重复项的总和)。如果提供,将避免计算张量输出形状所需的流同步。

返回值:

重复张量,其形状与输入相同,除了在给定的轴上。

返回类型:

张量

示例:

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3)
tensor([[1, 2],
        [3, 4],
        [3, 4]])

如果 repeats 是 tensor([n1, n2, n3, …]),则输出将是 tensor([0, 0, …, 1, 1, …, 2, 2, …, …]),其中 0 出现 n1 次,1 出现 n2 次,2 出现 n3 次,等等。

torch.repeat_interleave(repeats, *) → 张量

重复 0 次,1 重复[1]次,2 重复[2]次,等等。

参数:

重复(张量)- 每个元素的重复次数。

返回值:

重复张量的大小为 sum(repeats)。

返回类型:

张量

示例:

>>> torch.repeat_interleave(torch.tensor([1, 2, 3]))
tensor([0, 1, 1, 2, 2, 2])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源