torch.take
-
torch.take(input, index) → Tensor
返回一个新的张量,包含 input
在给定索引处的元素。输入张量被视为一个 1-D 张量。结果张量的形状与索引相同。
- 参数:
input (Tensor) – 输入张量。
index (长张量) – 张量中的索引
示例:
>>> src = torch.tensor([[4, 3, 5],
... [6, 7, 8]])
>>> torch.take(src, torch.tensor([0, 2, 5]))
tensor([ 4, 5, 8])