• 文档 >
  • torch.masked
快捷键

torch.masked

简介

动机 ¶

警告

PyTorch 的掩码张量 API 处于原型阶段,未来可能会也可能不会发生变化。

MaskedTensor 作为对 torch.Tensor 的扩展,为用户提供以下功能:

  • 使用任何掩码语义(例如可变长度张量、nan*运算符等)

  • 区分 0 和 NaN 梯度

  • 不同的稀疏应用(请参阅下面的教程)

“指定”和“未指定”在 PyTorch 中有着悠久的历史,但没有正式的语义,当然也没有一致性;事实上,MaskedTensor 的诞生源于一系列问题,原始的 torch.Tensor 类无法正确解决。因此,MaskedTensor 的主要目标是成为 PyTorch 中“指定”和“未指定”值的真相来源,在这些值中,它们是第一公民而不是事后考虑。反过来,这应该进一步释放稀疏性的潜力,使操作更安全、更一致,并为用户和开发者提供更流畅、更直观的体验。

什么是 MaskedTensor?

MaskedTensor 是一个张量子类,由以下两部分组成:1)输入(数据),2)掩码。掩码告诉我们哪些输入条目应该被包含或忽略。

例如,假设我们想要屏蔽所有等于 0 的值(用灰色表示)并取最大值:

_images/tensor_comparison.jpg

顶部是普通的张量示例,底部是屏蔽张量,其中所有 0 都被屏蔽。这清楚地表明,根据是否有屏蔽,结果会有所不同。这种灵活的结构允许用户在计算过程中系统地忽略他们想要的任何元素。

我们已经编写了大量的教程来帮助用户入门,例如:

支持的操作 §

一元运算符 ¶

一元运算符是只包含单个输入的运算符。将它们应用于掩码张量相对简单:如果数据在给定索引处被掩码,则应用运算符,否则我们将继续掩码数据。

可用的单运算符有:

abs

计算每个元素在 input 中的绝对值。

absolute

torch.abs() 的别名。

acos

计算每个元素在 input 中的反余弦值。

arccos

torch.acos() 的别名。

acosh

返回一个新张量,包含 input 元素的反双曲余弦值。

arccosh

torch.acosh() 的别名。

angle

计算给定 input 张量的元素角度(以弧度为单位)。

asin

返回一个新张量,包含 input 的元素的反正弦值。

arcsin

torch.asin() 的别名。

asinh

返回一个新张量,包含 input 的元素的逆双曲正弦值。

arcsinh

torch.asinh() 的别名。

atan

返回一个新张量,包含 input 的元素的反正切。

arctan

torch.atan() 的别名。

atanh

返回一个新张量,包含 input 的元素的双曲反正切。

arctanh

torch.atanh() 的别名。

bitwise_not

计算给定输入张量的按位非。

ceil

返回一个新张量,其中包含 input 元素的 ceil 值,即大于或等于每个元素的最小整数。

clamp

input 中的所有元素夹在 [ min , max ] 范围内。

clip

torch.clamp() 的别名。

conj_physical

计算给定 input 张量的逐元素共轭。

cos

返回一个新张量,其中包含 input 元素的余弦值。

cosh

返回一个新张量,其中包含 input 元素的双曲余弦值。

deg2rad

input 中的每个元素从角度(度)转换为弧度的新张量。

digamma

torch.special.digamma() 的别名。

erf

torch.special.erf() 的别名。

erfc

torch.special.erfc() 的别名。

erfinv

torch.special.erfinv() 的别名。

exp

返回一个新张量,其元素为输入张量 input 的指数。

exp2

torch.special.exp2() 的别名。

expm1

torch.special.expm1() 的别名。

fix

torch.trunc() 的别名。

floor

返回一个新张量,包含 input 中元素的向下取整,即每个元素的最大整数。

frac

计算每个元素的分数部分。

lgamma

计算伽玛函数在 input 上的自然对数。

log

返回一个新张量,包含 input 中元素的自然对数。

log10

返回一个新张量,其元素为 input 的以 10 为底的对数。

log1p

返回一个新张量,其元素为(1 + input )的自然对数。

log2

返回一个新张量,其元素为 input 的以 2 为底的对数。

logit

torch.special.logit() 的别名。

i0

torch.special.i0() 的别名。

isnan

返回一个新张量,其中的布尔元素表示 input 的每个元素是否为 NaN。

nan_to_num

input 中的 NaN 、正无穷和负无穷值分别替换为 nanposinfneginf 中指定的值。

neg

返回一个新张量,其元素为 input 元素的相反数。

negative

torch.neg() 的别名。

positive

返回 input

pow

input 中每个元素的幂与 exponent 相乘,并返回结果张量。

rad2deg

input 中的每个元素从弧度转换为度,并返回新的张量。

reciprocal

返回 input 中每个元素倒数的新张量。

round

input 的元素四舍五入到最接近的整数。

rsqrt

返回一个新张量,其中包含 input 中每个元素的平方根的倒数。

sigmoid

torch.special.expit() 的别名。

sign

返回一个新张量,其中包含 input 中每个元素的符号。

sgn

这个函数是 torch.sign()对复数张量的扩展。

signbit

测试 input 中的每个元素是否设置了符号位。

sin

返回一个新张量,其中包含 input 元素的正弦值。

sinc

torch.special.sinc() 的别名。

sinh

返回一个新张量,其元素为 input 的双曲正弦。

sqrt

返回一个新张量,其元素为 input 的平方根。

square

返回一个新张量,其元素为 input 的平方。

tan

返回一个新张量,其元素为 input 的正切。

tanh

返回一个新张量,其元素为 input 的双曲正切值。

trunc

返回一个新张量,其元素为 input 的截断整数值。

可用的就地一元运算符都是上述所有,除了:

angle

计算给定 input 张量的元素角度(以弧度为单位)。

positive

返回 input

signbit

测试 input 中每个元素是否设置了符号位。

isnan

返回一个新张量,其中的布尔元素表示 input 中每个元素是否为 NaN。

二元运算符 §

如您在教程中看到的, MaskedTensor 还实现了二进制运算,但前提是两个 MaskedTensors 中的掩码必须匹配,否则会引发错误。正如错误信息中提到的,如果您需要支持特定的运算符或已提出关于它们应该如何行为的语义,请打开 GitHub 上的问题。目前,我们已决定采用最保守的实现方式,以确保用户确切了解正在发生的情况,并且在使用掩码语义时做出有意的决策。

可用的二进制运算符有:

add

other 加上 alpha 缩放的结果加到 input 上。

atan2

考虑象限的 inputi/otheri\text{input}_{i} / \text{other}_{i} 的逐元素反正切。

arctan2

torch.atan2() 的别名。

bitwise_and

计算 inputother 的按位与。

bitwise_or

计算 inputother 的按位或。

bitwise_xor

计算 inputother 的按位异或。

bitwise_left_shift

计算对 input 进行 other 位左算术移位。

bitwise_right_shift

计算对 input 进行 other 位右算术移位。

div

将输入 input 的每个元素除以对应的 other 元素。

divide

torch.div() 的别名。

floor_divide

fmod

应用 C++的 std::fmod 逐元素操作。

logaddexp

输入指数幂之和的对数。

logaddexp2

输入在 2 为底下的指数幂之和的对数。

mul

input 乘以 other

multiply

torch.mul() 的别名。

nextafter

按元素方式返回 input 之后相对于 other 的下一个浮点数值。

remainder

计算 Python 的逐元素取模运算。

sub

input 中减去 other ,并按 alpha 缩放。

subtract

torch.sub() 的别名。

true_divide

torch.div()rounding_mode=None 的别名。

eq

计算逐元素相等性。

ne

计算逐元素 inputother\text{input} \neq \text{other}

le

计算第 0#元素级。

ge

计算第 0#元素级。

greater

torch.gt() 的别名。

greater_equal

torch.ge() 的别名。

gt

计算元素级 input>other\text{input} > \text{other}

less_equal

torch.le() 的别名

lt

计算元素级 input<other\text{input} < \text{other}

less

torch.lt() 的别名

maximum

计算元素-wise inputother 的最大值。

minimum

计算元素-wise inputother 的最小值。

fmax

计算元素-wise inputother 的最大值。

fmin

计算元素-wise inputother 的最小值。

not_equal

别名: torch.ne()

可用的就地二进制运算符都是上述所有,除了:

logaddexp

输入指数和的对数的对数。

logaddexp2

输入指数和的以 2 为底的对数。

equal

True 如果两个张量具有相同的大小和元素, False 否则。

fmin

计算元素级的 inputother 的最小值。

minimum

计算元素级的 inputother 的最小值。

fmax

计算元素级的 inputother 的最大值。

减少

以下为可用的归约操作(具有自动微分支持)。有关更多信息,概述教程详细介绍了归约的一些示例,而高级语义教程则对某些归约语义的决策进行了更深入的讨论。

sum

返回 input 张量中所有元素的总和。

mean

amin

返回 input 张量在给定维度 dim 上的每个切片的最小值。

amax

返回 input 张量在给定维度 dim 上的每个切片的最大值。

argmin

返回展平张量或沿维度的最小值索引。

argmax

返回 input 张量中所有元素的最大值的索引。

prod

返回 input 张量中所有元素的乘积。

all

测试 input 张量中所有元素是否都评估为 True。

norm

返回给定张量的矩阵范数或向量范数。

var

计算由 dim 指定的维度上的方差。

std

计算由 dim 指定的维度上的标准差。

查看和选择功能

我们包括了多个视图和选择功能;直观上,这些操作符将应用于数据和掩码,然后将结果包装在 MaskedTensor 中。为了快速示例,考虑 select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

当前支持以下操作:

atleast_1d

返回每个输入张量的 1 维视图,其中没有零维。

broadcast_tensors

根据广播语义广播给定的张量。

broadcast_to

input 广播到 shape 的形状。

cat

在指定维度上连接给定的张量序列 tensors

chunk

尝试将张量分割成指定的块数。

column_stack

通过水平堆叠张量 tensors 创建一个新的张量。

dsplit

将具有三个或更多维度的张量 input 深度分割成多个张量,根据 indices_or_sections

flatten

input 通过重塑成一维张量来展平。

hsplit

将具有一个或多个维度的张量 input 根据 indices_or_sections 水平分割成多个张量。

hstack

水平(按列)顺序堆叠张量。

kron

计算由 inputother 表示的克罗内克积,记作 \otimes

meshgrid

根据 attr:tensors 中的 1D 输入创建坐标网格。

narrow

返回一个新的张量,它是 input 张量的缩小版本。

nn.functional.unfold

从批处理的输入张量中提取滑动局部块。

ravel

返回一个连续的展平张量。

select

沿选定维度在给定索引处切割 input 张量。

split

将张量分割成块。

stack

沿新维度连接一系列张量。

t

预期 input 为<=2 维张量,并交换维度 0 和 1。

transpose

返回一个张量,它是 input 的转置版本。

vsplit

将具有两个或更多维度的张量 input 垂直分割成多个张量。

vstack

将张量按顺序垂直堆叠(按行)。

Tensor.expand

返回 self 张量的新视图,将单例维度扩展到更大的大小。

Tensor.expand_as

将此张量扩展到与 other 相同的大小。

Tensor.reshape

返回一个具有与 self 相同数据和元素数量的张量,但具有指定的形状。

Tensor.reshape_as

以与 other 相同的形状返回此张量。

Tensor.unfold

返回原始张量的一个视图,该视图包含从 self 张量在 dimension 维度上的所有大小为 size 的切片。

Tensor.view

返回一个新张量,其数据与 self 张量相同,但类型不同 shape


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

PyTorch 开发者文档全面访问

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源