torch.unravel_index¶
- torch.unravel_index(indices, shape)[source][source]¶
将平面索引张量转换为坐标张量元组,这些坐标张量用于索引任意形状的张量。
- 参数:
索引(Tensor)- 包含对任意形状为
shape
的 Tensor 展开版本索引的整数 Tensor。所有元素必须在[0, prod(shape) - 1]
范围内。形状(int,int 序列或 torch.Size)- 任意 Tensor 的形状。所有元素必须为非负数。
- 返回值:
输出中的每个
i
-th Tensor 与shape
的i
维度相对应。每个 Tensor 具有与indices
相同的形状,并包含一个索引,该索引对应于每个由indices
给出的扁平索引的i
维度。- 返回类型:
张量元组的序列
示例:
>>> import torch >>> torch.unravel_index(torch.tensor(4), (3, 2)) (tensor(2), tensor(0)) >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2)) (tensor([2, 0]), tensor([0, 1])) >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2)) (tensor([0, 0, 1, 1, 2, 2]), tensor([0, 1, 0, 1, 0, 1])) >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10)) (tensor([1, 5]), tensor([2, 6]), tensor([3, 7]), tensor([4, 8])) >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10)) (tensor([[1], [5]]), tensor([[2], [6]]), tensor([[3], [7]]), tensor([[4], [8]])) >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100)) (tensor([[12], [56]]), tensor([[34], [78]]))