torch.index_select¶
- torch.index_select(input, dim, index, *, out=None) Tensor ¶
返回一个新的张量,该张量沿维度
dim
使用index
中的条目索引input
张量,其中index
是一个 LongTensor。返回的张量与原始张量具有相同的维度数(
input
)。第dim
维的大小与index
的长度相同;其他维度与原始张量的大小相同。注意
返回的张量不使用与原始张量相同的存储。如果
out
的形状与预期不同,我们将静默地将其更改为正确的形状,如果需要,重新分配底层存储。- 参数:
input (Tensor) – 输入张量。
dim(整型)- 我们索引的维度
index(IntTensor 或 LongTensor)- 包含索引的 1-D 张量
- 关键字参数:
输出(张量,可选)- 输出张量。
示例:
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) >>> torch.index_select(x, 0, indices) tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> torch.index_select(x, 1, indices) tensor([[ 0.1427, -0.5414], [-0.4664, -0.1228], [-1.1734, 0.7230]])