• 教程 >
  • (原型)MaskedTensor 简介
快捷键

(原型)掩码张量概述 ¶

创建时间:2025 年 4 月 1 日 | 最后更新时间:2025 年 4 月 1 日 | 最后验证:未验证

本教程旨在作为使用 MaskedTensors 的起点,并讨论其掩码语义。

MaskedTensor 作为对 torch.Tensor 的扩展,为用户提供以下功能:

  • 使用任何掩码语义(例如,变长张量、nan*运算符等)

  • 0 和 NaN 梯度的区分

  • 不同的稀疏应用(请参阅下面的教程)

想要更详细地了解 MaskedTensors 是什么,请参阅 torch.masked 文档。

使用掩码张量 ¶

在本节中,我们将讨论如何使用 MaskedTensor,包括如何构建、访问数据以及掩码,以及索引和切片。

准备 ¶

我们将开始进行教程所需的设置:

import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

构造 ¶

有几种不同的方法来构造一个 MaskedTensor:

  • 第一种方法是直接调用 MaskedTensor 类

  • 第二种(也是我们推荐的方法)是使用 masked.masked_tensor()masked.as_masked_tensor() 工厂函数,它们与 torch.tensor()torch.as_tensor() 类似

在本教程中,我们将假设导入行:from torch.masked import masked_tensor。

访问数据和掩码

MaskedTensor 中的底层字段可以通过以下方式访问:

  • MaskedTensor.get_data() 函数

  • MaskedTensor.get_mask() 函数。回想一下, True 表示“指定的”或“有效的”,而 False 表示“未指定的”或“无效的”。

通常情况下,返回的底层数据在未指定条目中可能无效,因此我们建议当用户需要没有任何掩码条目的 Tensor 时,他们应使用 MaskedTensor.to_tensor() (如上所示)来返回一个填充值的 Tensor。

索引和切片

MaskedTensor 是 Tensor 的子类,这意味着它继承了与 torch.Tensor 相同的索引和切片语义。以下是一些常见的索引和切片模式的示例:

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data:\n", data)
print("mask:\n", mask)
# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

为什么 MaskedTensor 很有用?

由于 MaskedTensor 将指定值和不指定值视为一等公民,而不是事后考虑(使用填充值、nan 等),它能够解决常规 Tensor 无法解决的许多缺点;事实上, MaskedTensor 在很大程度上是由于这些反复出现的问题而诞生的。

下面,我们将讨论 PyTorch 今天仍然存在的一些最常见的问题,并展示 MaskedTensor 如何解决这些问题。

区分 0 和 NaN 梯度

torch.Tensor 遇到的一个问题是无法区分未定义(NaN)的梯度与实际为 0 的梯度。因为 PyTorch 没有标记值是否指定/有效或未指定/无效的方法,所以它被迫依赖于 NaN 或 0(取决于用例),导致语义不可靠,因为许多操作并不适合正确处理 NaN 值。更令人困惑的是,有时根据操作顺序,梯度可能会变化(例如,取决于 NaN 值在操作链中出现的早晚)。

MaskedTensor 是这个问题的完美解决方案!

torch.where

在 Issue 10729 中,我们注意到在使用 torch.where() 时,操作顺序可能会影响结果,因为我们难以区分 0 是一个真正的 0 还是来自未定义梯度的 0。因此,我们保持一致性,并屏蔽结果:

当前结果:

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad

MaskedTensor 结果:

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad

这里只提供了所选子集的梯度。实际上,这改变了屏蔽元素的位置梯度,而不是将它们设置为 0。

另一个 torch.where ¶

Issue 52248 是另一个例子。

当前结果:

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

MaskedTensor 结果:

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))

该问题与下一个问题类似(甚至链接到下一个问题),因为它表达了由于无法区分“无梯度”与“零梯度”而导致的对意外行为的挫败感,这反过来使得与其他操作一起工作时难以推理。

当使用掩码时,x/0 返回 NaN 梯度 ¶

在问题 4132 中,用户建议 x.grad 应该是 [0, 1] 而不是 [nan, 1],而 MaskedTensor 通过屏蔽整个梯度使这一点非常明确。

当前结果:

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0)  # => mask is [0, 1]
y[mask].backward()
x.grad

MaskedTensor 结果:

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad

torch.nansum()torch.nanmean()

在第 67180 期中,梯度计算不正确(一个长期存在的问题),而 MaskedTensor 处理得正确。

当前结果:

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

MaskedTensor 结果:

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1

安全的 Softmax ¶

安全的 softmax 是另一个经常出现的问题的绝佳例子。简而言之,如果一个整个批次被“屏蔽”或完全由填充(在 softmax 的情况下,相当于设置为-∞)组成,那么这将导致 NaN,这可能导致训练发散。

幸运的是, MaskedTensor 已经解决了这个问题。考虑以下设置:

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)

例如,我们想要沿着 dim=0 计算 softmax。请注意,第二列是“不安全的”(即完全屏蔽),因此当计算 softmax 时,结果将为 0/0 = nan,因为 exp(-∞) = 0。然而,我们真正希望的是梯度被屏蔽,因为它们是未指定的,对于训练来说将是无效的。

PyTorch 结果:

x.softmax(0)

MaskedTensor 结果:

mt.softmax(0)

实现缺失的 torch.nan*运算符 ¶

在 Issue 61474 中,有一个请求添加额外的运算符来覆盖 torch.nan*的各种应用,例如 torch.nanmaxtorch.nanmin 等。

通常,这些问题更适合使用掩码语义,因此我们建议使用 MaskedTensor 代替引入额外的运算符。由于 nanmean 已经实现,我们可以将其作为比较点:

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan'))  # we want to get the mean of y when ignoring the zeros
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))

在上述示例中,我们已经构建了 y,并希望计算序列的均值,同时忽略零。可以使用 torch.nanmean 来完成此操作,但我们没有实现 torch.nan*的其他操作。 MaskedTensor 通过能够使用基本运算符来解决这个问题,并且我们已经支持了问题中列出的其他操作。例如:

torch.argmin(masked_tensor(y, y != 0))

事实上,忽略 0 的情况下最小参数的索引是索引 1 中的 1。

MaskedTensor 还可以支持当数据完全被掩码时的缩减操作,这与数据 Tensor 完全 nan 的情况上述相同。 nanmean 将返回 nan (一个模糊的返回值),而 MaskedTensor 将更准确地指示掩码结果。

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))

这与安全 softmax 类似的问题,其中 0/0 等于 nan,而我们真正想要的是未定义的值。

结论 ¶

在本教程中,我们介绍了 MaskedTensors 是什么,演示了如何使用它们,并通过一系列示例和问题展示了它们的价值,这些问题它们已经帮助解决。

进一步阅读

要继续学习更多,您可以找到我们的 MaskedTensor 稀疏性教程,了解 MaskedTensor 如何实现稀疏性以及我们目前支持的不同存储格式。

脚本总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的画廊


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源