• 文档 >
  • 广播语义
快捷键

广播语义 ¶

许多 PyTorch 操作支持 NumPy 的广播语义。详情请见 https://numpy.org/doc/stable/user/basics/broadcasting.html。

简而言之,如果 PyTorch 操作支持广播,那么其张量参数可以自动扩展为相等的大小(而不复制数据)。

一般语义学 ¶

两个张量是“可广播的”,如果满足以下规则:

  • 每个张量至少有一个维度。

  • 在迭代维度大小时,从最后一个维度开始,维度大小必须相等,或者其中一个为 1,或者其中一个不存在。

例如:

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

如果两个张量 xy 是“可广播的”,则结果张量的大小计算如下:

  • 如果 xy 的维度数不相等,则将 1 添加到维度较少的张量的维度之前,使它们的长度相等。

  • 然后,对于每个维度大小,结果维度大小是该维度上 xy 大小的最大值。

例如:

# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

原地语义 ¶

一个问题是原地操作不允许原地张量在广播过程中改变形状。

例如:

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向后兼容性

PyTorch 的早期版本允许某些逐点函数在具有相同元素数量的不同形状的张量上执行,此时逐点操作会将每个张量视为一维的。现在 PyTorch 支持广播,将“一维”逐点行为视为已弃用,并在张量不可广播但具有相同元素数量时生成 Python 警告。

注意,引入广播可能会导致向后不兼容的更改,在这种情况下,两个张量形状不同,但可广播且具有相同数量的元素。例如:

>>> torch.add(torch.ones(4,1), torch.randn(4))

之前会生成大小为 torch.Size([4,1])的张量,但现在会生成大小为 torch.Size([4,4])的张量。为了帮助您识别代码中可能存在的由广播引起的向后不兼容性案例,您可以设置 torch.utils.backcompat.broadcast_warning.enabled 为 True,这将在此类情况下生成 Python 警告。

例如:

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源