备注
点击此处下载完整示例代码
(原型)MaskedTensor 高级语义 ¶
创建时间:2025 年 4 月 1 日 | 最后更新时间:2025 年 4 月 1 日 | 最后验证:未验证
在开始本教程之前,请确保您已经阅读了我们的 MaskedTensor 概述教程 。
本教程的目的是帮助用户理解一些高级语义的工作原理以及它们是如何产生的。我们将重点关注其中的两个:
面向 MaskedTensor 和 NumPy 的 MaskedArray 的差异 *. 累加语义
准备工作 ¶
import torch
from torch.masked import masked_tensor
import numpy as np
import warnings
# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)
MaskedTensor 与 NumPy 的 MaskedArray 对比 ¶
NumPy 的 MaskedArray
与 MaskedTensor 存在一些基本的语义差异。
- *. 工厂函数和基本定义反转掩码(类似于
torch.nn.MHA
);即 MaskedTensor
使用True
表示“指定”和False
表示“未指定”、“有效”/“无效”,而 NumPy 则相反。我们认为我们的掩码定义不仅更直观,而且与 PyTorch 的整体语义更加一致。- *. 交集语义。在 NumPy 中,如果两个元素中有一个被掩码,则结果元素也会被掩码 – 实际上,它们应用逻辑或运算符。
掩码元素也会被掩码 – 实际上,它们应用逻辑或运算符。
data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
npm0 = np.ma.masked_array(data.numpy(), (~mask).numpy())
npm1 = np.ma.masked_array(data.numpy(), (mask).numpy())
print("npm0:\n", npm0)
print("npm1:\n", npm1)
print("npm0 + npm1:\n", npm0 + npm1)
同时,MaskedTensor 不支持与不匹配的掩码进行加法或二元运算 - 要了解原因,请参阅关于减少的部分。
mt0 = masked_tensor(data, mask)
mt1 = masked_tensor(data, ~mask)
print("mt0:\n", mt0)
print("mt1:\n", mt1)
try:
mt0 + mt1
except ValueError as e:
print ("mt0 + mt1 failed. Error: ", e)
然而,如果需要这种行为,MaskedTensor 通过提供数据和掩码的访问,并方便地将 MaskedTensor 转换为使用 to_tensor()
填充掩码值的 Tensor 来支持这些语义。例如:
t0 = mt0.to_tensor(0)
t1 = mt1.to_tensor(0)
mt2 = masked_tensor(t0 + t1, mt0.get_mask() & mt1.get_mask())
print("t0:\n", t0)
print("t1:\n", t1)
print("mt2 (t0 + t1):\n", mt2)
注意,掩码是 mt0.get_mask() & mt1.get_mask(),因为 MaskedTensor
的掩码是 NumPy 的逆掩码。
减少语义
在 MaskedTensor 的概述教程中,我们讨论了“实现 torch.nan*操作”。这些都是降维操作的例子——从 Tensor 中移除一个(或多个)维度然后聚合结果的运算符。在本节中,我们将使用降维语义来阐述我们对匹配上方掩码的严格要求。
基本上,:class:`MaskedTensor`执行相同的降维操作,同时忽略被掩码(未指定)的值。以下是一个示例:
data = torch.arange(12, dtype=torch.float).reshape(3, 4)
mask = torch.randint(2, (3, 4), dtype=torch.bool)
mt = masked_tensor(data, mask)
print("data:\n", data)
print("mask:\n", mask)
print("mt:\n", mt)
现在,不同的降维操作(都在 dim=1 维度上):
print("torch.sum:\n", torch.sum(mt, 1))
print("torch.mean:\n", torch.mean(mt, 1))
print("torch.prod:\n", torch.prod(mt, 1))
print("torch.amin:\n", torch.amin(mt, 1))
print("torch.amax:\n", torch.amax(mt, 1))
值得注意的是,掩码元素下的值不一定具有任何特定值,特别是如果行或列完全被掩码(对归一化也是如此)。有关掩码语义的更多详细信息,您可以查阅此 RFC。
现在,我们可以重新审视这个问题:为什么我们要强制执行掩码必须匹配的二进制运算符的不变性?换句话说,为什么我们不使用与 np.ma.masked_array
相同的语义?考虑以下示例:
data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])
npm0 = np.ma.masked_array(data0.numpy(), (mask0).numpy())
npm1 = np.ma.masked_array(data1.numpy(), (mask1).numpy())
print("npm0:", npm0)
print("npm1:", npm1)
现在,让我们尝试加法:
print("(npm0 + npm1).sum(0):\n", (npm0 + npm1).sum(0))
print("npm0.sum(0) + npm1.sum(0):\n", npm0.sum(0) + npm1.sum(0))
和加法应该显然是结合的,但根据 NumPy 的语义,它们并不是,这可能会使用户感到困惑。
另一方面, MaskedTensor
将简单地不允许这种操作,因为 mask0 不等于 mask1。话虽如此,如果用户愿意,仍然有方法可以绕过这一点(例如,使用 to_tensor()
将 MaskedTensor 的未定义元素填充为 0 值,如下所示),但用户现在必须更明确地表达他们的意图。
mt0 = masked_tensor(data0, ~mask0)
mt1 = masked_tensor(data1, ~mask1)
(mt0.to_tensor(0) + mt1.to_tensor(0)).sum(0)
结论 ¶
在本教程中,我们学习了 MaskedTensor 和 NumPy 的 MaskedArray 背后的不同设计决策,以及缩减语义。总的来说,MaskedTensor 旨在避免歧义和混淆的语义(例如,我们试图在二元运算中保持结合性),这反过来又可能需要用户在编写代码时更加有意图,但我们认为这是更好的选择。如果您对此有任何想法,请告诉我们!
脚本总运行时间:(0 分钟 0.000 秒)