torch.nested¶
简介
警告
PyTorch 嵌套张量的 API 目前处于原型阶段,未来将进行更改。
嵌套张量允许将不规则形状的数据包含在其中,并作为一个单独的张量进行操作。此类数据以高效的打包表示形式存储在下面,同时暴露标准的 PyTorch 张量接口以应用操作。
嵌套张量的一种常见应用是表达存在于各个领域的可变长度序列数据,例如变化的句子长度、图像大小以及音频/视频剪辑长度。传统上,此类数据通过填充序列到批次中的最大长度来处理,在填充形式上执行计算,然后通过掩码来去除填充。这种方法效率低下且容易出错,嵌套张量正是为了解决这些问题而存在的。
调用嵌套张量上操作的 API 与常规的 torch.Tensor
没有区别,允许与现有模型无缝集成,主要区别在于输入的构建。
由于这是一个原型功能,支持的运算集有限,但正在增长。我们欢迎提出问题、功能请求和贡献。有关贡献的更多信息,请参阅此 Readme。
构造 ¶
注意
PyTorch 中存在两种嵌套张量形式,由构造时指定的布局来区分。布局可以是 torch.strided
或 torch.jagged
。我们建议尽可能使用 torch.jagged
布局。虽然它目前只支持单个稀疏维度,但它具有更好的操作覆盖范围,正在积极开发中,并且与 torch.compile
很好地集成。这些文档遵循此建议,并将具有 torch.jagged
布局的嵌套张量简称为“NJTs”。
构造过程简单,涉及将张量列表传递给 torch.nested.nested_tensor
构造函数。具有 torch.jagged
布局的嵌套张量(又称“NJT”)支持单个稀疏维度。此构造函数将根据下文数据布局部分中描述的布局将输入张量复制到连续的内存块中。
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> print([component for component in nt])
[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])]
列表中的每个张量必须具有相同的维度数,但形状可以在单个维度上有所不同。如果输入组件的维度不匹配,构造函数将抛出错误。
>>> a = torch.randn(50, 128) # 2D tensor
>>> b = torch.randn(2, 50, 128) # 3D tensor
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
...
RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim
在构造过程中,可以通过通常的关键字参数选择 dtype、device 以及是否需要梯度。
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True)
>>> print([component for component in nt])
[tensor([0., 1., 2.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>), tensor([3., 4., 5., 6., 7.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>)]
可以用于从构造函数传递给张量的张量中保留自动微分历史。当使用此构造函数时,梯度将通过嵌套张量反向流回原始组件。请注意,此构造函数仍然会将输入组件复制到一个打包的连续内存块中。
>>> a = torch.randn(12, 512, requires_grad=True)
>>> b = torch.randn(23, 512, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.sum().backward()
>>> a.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
>>> b.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
上述函数都创建连续的 NJT,其中分配了一块内存来存储底层组件的打包形式(有关更多详细信息,请参阅下面的数据布局部分)。
还可以创建一个非连续的 NJT 视图,该视图基于现有的密集张量并带有填充,从而避免内存分配和复制。 torch.nested.narrow()
是实现此目的的工具。
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt.is_contiguous()
False
请注意,嵌套张量作为原始填充密集张量的视图,引用相同的内存而不进行复制/分配。对于非连续 NJT 的操作支持相对有限,因此如果您遇到支持差距,始终可以使用 contiguous()
将其转换为连续的 NJT。
数据布局和形状
为了提高效率,嵌套张量通常将张量组件打包成连续的内存块,并维护额外的元数据来指定批次项边界。对于 torch.jagged
布局,连续的内存块存储在 values
组件中,而 offsets
组件用于界定有锯齿维度的批次项边界。

在必要时,可以直接访问底层的 NJT 组件。
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.values().shape # note the "packing" of the ragged dimension; no padding needed
torch.Size([82, 128])
>>> nt.offsets()
tensor([ 0, 50, 82])
直接从锯齿状的 values
和 offsets
组成部分构建 NJT 也可能很有用; torch.nested.nested_tensor_from_jagged()
构造函数用于此目的。
>>> values = torch.randn(82, 128)
>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64)
>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets)
NJT 具有比其组成部分高一个维度的明确形状。下面示例中的粗糙维度的底层结构由一个符号值( j1
)表示。
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.dim()
3
>>> nt.shape
torch.Size([2, j1, 128])
NJT 必须具有相同的粗糙结构才能相互兼容。例如,要运行涉及两个 NJT 的二进制运算,粗糙结构必须匹配(即它们的形状中必须具有相同的粗糙形状符号)。在细节上,每个符号对应一个确切的 offsets
张量,因此两个 NJT 必须具有相同的 offsets
张量才能相互兼容。
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt1.offsets() is nt2.offsets()
False
>>> nt3 = nt1 + nt2
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
在上述示例中,尽管两个 NJT 的概念形状相同,但它们不共享对同一个 offsets
张量的引用,因此它们的形状不同,它们不兼容。我们认识到这种行为是不直观的,并且正在努力为嵌套张量的 beta 版本放宽这一限制。有关解决方案,请参阅本文件的故障排除部分。
除了 offsets
元数据之外,NJT 还可以计算并缓存其组件的最小和最大序列长度,这对于调用特定内核(例如 SDPA)非常有用。目前还没有公开的 API 可以访问这些信息,但在 beta 版本中将会改变。
支持的操作
本节包含了一组您可能会用到的嵌套张量常见操作列表。这并不是一个全面的列表,因为 PyTorch 中有大约两千个操作。虽然其中相当一部分操作目前支持嵌套张量,但全面支持仍然是一个巨大的任务。嵌套张量的理想状态是支持所有适用于非嵌套张量的 PyTorch 操作。为了帮助我们实现这一目标,请考虑:
在此处请求您用例所需的特定操作,以帮助我们确定优先级。
贡献!为给定的 PyTorch 操作添加嵌套张量支持并不困难;请参阅下面的贡献部分以获取详细信息。
查看嵌套张量的组成部分
unbind()
允许您获取嵌套张量的组成部分视图。
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.unbind()
(tensor([[-0.9916, -0.3363, -0.2799],
[-2.3520, -0.5896, -0.4374]]), tensor([[-2.0969, -1.0104, 1.4841],
[ 2.0952, 0.2973, 0.2516],
[ 0.9035, 1.3623, 0.2026]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
[-2.3481, 2.0236, 0.1975]])
>>> nt.unbind()
(tensor([[-2.9747, -1.0089, -0.8396],
[-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104, 1.4841],
[ 2.0952, 0.2973, 0.2516],
[ 0.9035, 1.3623, 0.2026]]))
注意, nt.unbind()[0]
不是一个副本,而是底层内存的切片,它代表了嵌套张量的第一个条目或组成部分。
到/从填充的转换
将 NJT 转换为具有指定填充值的填充密集张量。稀疏维度将被填充到最大序列长度的大小。
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(6, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> padded = torch.nested.to_padded_tensor(nt, padding=4.2)
>>> padded
tensor([[[ 1.6107, 0.5723, 0.3913],
[ 0.0700, -0.4954, 1.8663],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000]],
[[-0.0479, -0.7610, -0.3484],
[ 1.1345, 1.0556, 0.3634],
[-1.7122, -0.5921, 0.0540],
[-0.5506, 0.7608, 2.0606],
[ 1.5658, -1.1934, 0.3041],
[ 0.1483, -1.1284, 0.6957]]])
这可以作为逃生口来绕过 NJT 支持差距,但理想情况下,应尽可能避免此类转换,以实现最佳内存使用和性能,因为更有效的嵌套张量布局不会产生填充。
反向转换可以使用 torch.nested.narrow()
完成,它将给定的密集张量应用于稀疏结构以产生 NJT。请注意,默认情况下,此操作不会复制底层数据,因此输出的 NJT 通常是连续的。如果需要连续的 NJT,则可能需要显式调用 contiguous()
。
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt = nt.contiguous()
>>> nt.shape
torch.Size([3, j2, 4])
形状操作 ¶
嵌套张量支持广泛的形状操作,包括视图。
>>> a = torch.randn(2, 6)
>>> b = torch.randn(4, 6)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.shape
torch.Size([2, j1, 6])
>>> nt.unsqueeze(-1).shape
torch.Size([2, j1, 6, 1])
>>> nt.unflatten(-1, [2, 3]).shape
torch.Size([2, j1, 2, 3])
>>> torch.cat([nt, nt], dim=2).shape
torch.Size([2, j1, 12])
>>> torch.stack([nt, nt], dim=2).shape
torch.Size([2, j1, 2, 6])
>>> nt.transpose(-1, -2).shape
torch.Size([2, 6, j1])
注意力机制
由于可变长序列是注意力机制的常见输入,嵌套张量支持重要的注意力算子缩放点积注意力(SDPA)和 FlexAttention。有关 NJT 与 SDPA 的使用示例,请参阅此处;有关 NJT 与 FlexAttention 的使用示例,请参阅此处。
与 torch.compile 一起使用
NJTs 设计用于与 torch.compile()
配合使用以实现最佳性能,我们始终建议在可能的情况下使用 torch.compile()
与 NJT 结合。NJT 在作为编译函数或模块的输入传递时,或者在函数内联实例化时,都能即插即用且无需断开图结构。
注意
如果您无法为您的用例使用 torch.compile()
,性能和内存使用可能仍然会从 NJT 的使用中受益,但这种情况并不那么明显。重要的是要确保正在操作的张量足够大,以便性能提升不会因 Python 张量子类的开销而抵消。
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
>>> output.shape
torch.Size([2, j1, 3])
>>> def g(values, offsets): return torch.nested.nested_tensor_from_jagged(values, offsets) * 2.
...
>>> compiled_g = torch.compile(g, fullgraph=True)
>>> output2 = compiled_g(nt.values(), nt.offsets())
>>> output2.shape
torch.Size([2, j1, 3])
注意,NJT 支持动态形状,以避免在结构变化时进行不必要的重新编译。
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> c = torch.randn(5, 3)
>>> d = torch.randn(6, 3)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt2 = torch.nested.nested_tensor([c, d], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output1 = compiled_f(nt1)
>>> output2 = compiled_f(nt2) # NB: No recompile needed even though ragged structure differs
如果您在使用 NJT + torch.compile
时遇到问题或古怪的错误,请提交 PyTorch 问题。在 torch.compile
中实现完整的子类支持是一个长期目标,目前可能存在一些粗糙的边缘。
故障排除
本节包含您在利用嵌套张量时可能遇到的常见错误,以及这些错误的成因和解决建议。
未实现的操作符
随着嵌套张量操作符支持的增长,这种错误变得越来越少,但鉴于 PyTorch 中有数千个操作符,今天仍然有可能遇到。
NotImplementedError: aten.view_as_real.default
错误很简单;我们还没有来得及为这个特定的操作符添加操作符支持。如果您愿意,您可以自己贡献一个实现,或者简单地请求我们在未来的 PyTorch 版本中添加对这个操作符的支持。
杂乱结构不兼容 ¶
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
当调用操作符对多个不兼容杂乱结构的 NJTs 进行操作时,会发生此错误。目前,要求输入 NJTs 必须具有完全相同的 offsets
构成成分,才能具有相同的符号杂乱结构符号(例如, j1
)。
对于这种情况,可以直接从 values
和 offsets
组件构建 NJTs。当两个 NJTs 都引用相同的 offsets
组件时,它们被认为具有相同的杂乱结构,因此是兼容的。
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor_from_jagged(values=torch.randn(82, 128), offsets=nt1.offsets())
>>> nt3 = nt1 + nt2
>>> nt3.shape
torch.Size([2, j1, 128])
torch.compile 中的数据相关操作 ¶
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
这种错误发生在调用在 torch.compile 中执行数据相关操作的 op 时;这通常发生在需要检查 NJT 的 offsets
的值以确定输出形状的 op。例如:
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> def f(nt): return nt.chunk(2, dim=0)[0]
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
在这个例子中,在 NJT 的批处理维度上调用 chunk()
需要检查 NJT 的 offsets
数据以划分打包稀疏维度的批处理项边界。作为解决方案,可以设置几个 torch.compile 标志:
>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True
>>> torch._dynamo.config.capture_scalar_outputs = True
如果设置这些后仍然看到数据相关的操作错误,请向 PyTorch 提交问题。 torch.compile()
的这个区域仍在积极开发中,NJT 支持的一些方面可能不完整。
贡献说明
如果你希望为嵌套张量开发做出贡献,最有效的方法之一是为当前不支持 PyTorch 操作添加嵌套张量支持。这个过程通常包括几个简单的步骤:
确定要添加的操作名称;这应该类似于
aten.view_as_real.default
。该操作的签名可以在aten/src/ATen/native/native_functions.yaml
中找到。在
torch/nested/_internal/ops.py
中注册操作实现,遵循那里为其他操作建立的模式。使用native_functions.yaml
中的签名进行模式验证。
实现操作最常见的方法是将 NJT 拆解为其组成部分,在底层的 values
缓冲区上重新调度操作,并将相关的 NJT 元数据(包括 offsets
)传播到新的输出 NJT。如果操作的输出预期与输入具有不同的形状,则必须计算新的 offsets
等元数据。
当操作应用于批量或稀疏维度时,以下技巧可以帮助快速获得一个可工作的实现:
对于非批量操作,应使用基于
unbind()
的回退方案。对于对稀疏维度的操作,考虑将数据转换为带有适当选择的填充值的填充密集格式,运行操作,然后再转换回 NJT。在
torch.compile
中,这些转换可以融合以避免产生填充的中间结果。
构造和转换函数的详细文档
- torch.nested.nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False)[source][source]¶
从一个张量列表构建一个没有自动微分历史记录的嵌套张量(也称为“叶张量”,参见自动微分机制)。
- 参数:
tensor_list (List[array_like]) – 张量列表,或任何可以传递给 torch.tensor 的内容,
维度。(列表中的每个元素都具有相同的) –
- 关键字参数:
dtype (
torch.dtype
,可选) – 返回嵌套张量的期望类型。默认:如果为 None,则与列表中最左侧的张量相同torch.dtype
。layout (
torch.layout
,可选) – 返回嵌套张量的期望布局。仅支持 strided 和 jagged 布局。默认:如果为 None,则使用 strided 布局。device (
torch.device
,可选) – 返回嵌套张量的期望设备。默认:如果为 None,则与列表中最左侧的张量相同torch.device
。requires_grad (bool,可选) – 如果 autograd 应记录对返回的嵌套张量上的操作。默认:
False
。pin_memory(布尔值,可选)- 如果设置,返回的嵌套张量将在固定内存中分配。仅适用于 CPU 张量。默认值:
False
。
- 返回类型:
示例:
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True) >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) >>> nt = torch.nested.nested_tensor([a, b], requires_grad=True) >>> nt.is_leaf True
- torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None, jagged_dim=None, min_seqlen=None, max_seqlen=None)[source][source]¶
从给定的交错组件构建交错布局嵌套张量。交错布局由一个必需的 values 缓冲区组成,交错维度打包到一个单独的维度中。offsets / lengths 元数据确定如何将此维度分割成批处理元素,并且预期将与 values 缓冲区分配在相同的设备上。
- 预期元数据格式:
偏移量:打包维度内的索引,将其分割成不同大小的批次元素。例如:[0, 2, 3, 6] 表示一个大小为 6 的打包交错维度在概念上应分割成长度为[2, 1, 3]的批次元素。请注意,为了内核方便,需要提供起始和结束偏移量(即形状 batch_size + 1)。
长度:单个批次元素的长度的列表;形状 == batch_size。例如:[2, 1, 3] 表示一个大小为 6 的打包交错维度在概念上应分割成长度为[2, 1, 3]的批次元素。
提供偏移量和长度可能很有用。这描述了一个具有“空洞”的嵌套张量,其中偏移量指示每个批次项的起始位置,长度指定元素的总数(见下例)。
返回的交错布局嵌套张量将是输入值张量的一个视图。
- 参数:
值(
torch.Tensor
)- 基础缓冲区,形状为(sum_B(*), D_1, …, D_N)。交错维度被压缩成一个单一维度,使用偏移量/长度元数据来区分批处理元素。偏移量(可选
torch.Tensor
)- 进入交错维度的偏移量,形状为 B + 1。长度(可选
torch.Tensor
)- 形状为 B 的批处理元素长度。jagged_dim(可选 python:int)- 指示 values 中的哪个维度是压缩的交错维度。如果为 None,则设置为 dim=1(即紧接批处理维度的维度)。默认:None
min_seqlen(可选 python:int)- 如果设置,则使用指定的值作为返回的嵌套张量的缓存最小序列长度。这可以作为一个有用的替代方案,以避免在需要时计算此值,可能避免 GPU -> CPU 同步。默认:None
max_seqlen(可选 python:int)- 如果设置,则使用指定的值作为返回的嵌套张量的缓存最大序列长度。这可以作为一个有用的替代方案,以避免在需要时计算此值,可能避免 GPU -> CPU 同步。默认:None
- 返回类型:
示例:
>>> values = torch.randn(12, 5) >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12]) >>> nt = nested_tensor_from_jagged(values, offsets) >>> # 3D shape with the middle dimension jagged >>> nt.shape torch.Size([5, j2, 5]) >>> # Length of each item in the batch: >>> offsets.diff() tensor([3, 2, 1, 4, 2]) >>> values = torch.randn(6, 5) >>> offsets = torch.tensor([0, 2, 3, 6]) >>> lengths = torch.tensor([1, 1, 2]) >>> # NT with holes >>> nt = nested_tensor_from_jagged(values, offsets, lengths) >>> a, b, c = nt.unbind() >>> # Batch item 1 consists of indices [0, 1) >>> torch.equal(a, values[0:1, :]) True >>> # Batch item 2 consists of indices [2, 3) >>> torch.equal(b, values[2:3, :]) True >>> # Batch item 3 consists of indices [3, 5) >>> torch.equal(c, values[3:5, :]) True
- torch.nested.as_nested_tensor(ts, dtype=None, device=None, layout=None)[source][source]¶
从张量或张量列表/元组构建一个保留 autograd 历史的嵌套张量。
如果传递嵌套张量,除非设备/数据类型/布局不同,否则将直接返回。请注意,转换设备/数据类型将导致复制,而转换布局目前不支持此函数。
如果传递非嵌套张量,则将其视为大小一致的构成元素的批次。如果传递的设备/数据类型与输入不同,或者输入非连续,则将产生复制。否则,将直接使用输入的存储。
如果提供张量列表,则列表中的张量在构建嵌套张量时总是被复制。
- 参数:
ts(Tensor 或 List[Tensor]或 Tuple[Tensor])- 要作为嵌套张量处理的张量,或具有相同 ndim 的张量列表/元组
- 关键字参数:
dtype (
torch.dtype
,可选) – 返回嵌套张量的期望类型。默认:如果为 None,则与列表中最左侧的张量相同torch.dtype
。device (
torch.device
,可选) – 返回嵌套张量的期望设备。默认:如果为 None,则与列表中最左侧的张量相同torch.device
。layout (
torch.layout
,可选) – 返回嵌套张量的期望布局。仅支持 strided 和 jagged 布局。默认:如果为 None,则使用 strided 布局。
- 返回类型:
示例:
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True) >>> b = torch.arange(5, dtype=torch.float, requires_grad=True) >>> nt = torch.nested.as_nested_tensor([a, b]) >>> nt.is_leaf False >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)]) >>> nt.backward(fake_grad) >>> a.grad tensor([1., 1., 1.]) >>> b.grad tensor([0., 0., 0., 0., 0.]) >>> c = torch.randn(3, 5, requires_grad=True) >>> nt2 = torch.nested.as_nested_tensor(c)
- torch.nested.to_padded_tensor(input, padding, output_size=None, out=None) Tensor ¶
通过填充
input
嵌套张量返回一个新的(非嵌套)张量。前导条目将填充嵌套数据,而尾部条目将进行填充。警告
由于嵌套张量和非嵌套张量在内存布局上不同,
to_padded_tensor()
始终复制底层数据。- 参数:
填充(浮点数)- 尾部条目的填充值。
- 关键字参数:
输出大小(元组[int])- 输出张量的大小。如果提供,它必须足够大以容纳所有嵌套数据;否则,将通过取每个嵌套子张量沿每个维度的最大大小来推断。
输出(张量,可选)- 输出张量。
示例:
>>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))]) nested_tensor([ tensor([[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995]]), tensor([[-1.8546, -0.7194, -0.2918, -0.1846], [ 0.2773, 0.8793, -0.5183, -0.6447], [ 1.8009, 1.8468, -0.9832, -1.5272]]) ]) >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0) tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276], [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], [[-1.8546, -0.7194, -0.2918, -0.1846, 0.0000], [ 0.2773, 0.8793, -0.5183, -0.6447, 0.0000], [ 1.8009, 1.8468, -0.9832, -1.5272, 0.0000]]]) >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6)) tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276, 1.0000], [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995, 1.0000], [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], [[-1.8546, -0.7194, -0.2918, -0.1846, 1.0000, 1.0000], [ 0.2773, 0.8793, -0.5183, -0.6447, 1.0000, 1.0000], [ 1.8009, 1.8468, -0.9832, -1.5272, 1.0000, 1.0000], [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]]) >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2)) RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
- torch.nested.masked_select(tensor, mask)[source][source]¶
根据给定的步长张量输入和步长掩码构建嵌套张量,生成的错落布局嵌套张量将在掩码等于 True 的位置保留值。掩码的维度被保留,并使用偏移量表示,这与
masked_select()
不同,其输出被折叠为 1D 张量。参数:tensor(
torch.Tensor
):构建错落布局嵌套张量的步长张量。mask(torch.Tensor
):应用于张量输入的步长掩码张量示例:
>>> tensor = torch.randn(3, 3) >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]]) >>> nt = torch.nested.masked_select(tensor, mask) >>> nt.shape torch.Size([3, j4]) >>> # Length of each item in the batch: >>> nt.offsets().diff() tensor([1, 2, 1]) >>> tensor = torch.randn(6, 5) >>> mask = torch.tensor([False]) >>> nt = torch.nested.masked_select(tensor, mask) >>> nt.shape torch.Size([6, j5]) >>> # Length of each item in the batch: >>> nt.offsets().diff() tensor([0, 0, 0, 0, 0, 0])
- 返回类型:
- torch.nested.narrow(tensor, dim, start, length, layout=torch.strided)[source][source]¶
从
tensor
构建一个嵌套张量(可能是一个视图),这是一个带偏移量的张量。这与 torch.Tensor.narrow 的语义相似,在dim
维度中,新的嵌套张量只显示[start, start+length)区间的元素。由于嵌套表示允许在该维度的每一“行”中具有不同的起始和长度,因此start
和length
也可以是形状为 tensor.shape[0]的张量。根据嵌套张量的布局,会有一些差异。如果使用 strided 布局,torch.narrow 将复制缩小的数据到一个连续的 NT(嵌套张量)中,而 jagged 布局的 narrow()将创建原始带偏移量张量的非连续视图。这种特定的表示对于在 Transformer 模型中表示 kv 缓存非常有用,因为专门的 SDPA 内核可以轻松处理这种格式,从而提高性能。
- 参数:
tensor (
torch.Tensor
) – 一个带偏移量的张量,如果使用 jagged 布局,将作为嵌套张量的底层数据,如果使用 strided 布局,则将进行复制。dim(int)- 应用狭窄操作的维度。对于交错布局,仅支持 dim=1,而 strided 支持所有维度
start(Union[int,
torch.Tensor
])- 狭窄操作的起始元素length(Union[int,
torch.Tensor
])- 狭窄操作期间取出的元素数量
- 关键字参数:
layout(
torch.layout
,可选)- 返回嵌套张量的期望布局。仅支持 strided 和交错布局。默认:如果为 None,则为 strided 布局- 返回类型:
示例:
>>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64) >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64) >>> narrow_base = torch.randn(5, 10, 20) >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged) >>> nt_narrowed.is_contiguous() False