torch.narrow¶
- torch.narrow(input, dim, start, length) Tensor ¶
返回一个新的张量,它是
input
张量的缩小版本。维度dim
从start
到start + length
输入。返回的张量和input
张量共享相同的底层存储。- 参数:
输入(张量)- 要缩小的张量
dim(整数)- 缩小所沿的维度
start(整数或张量)- 从开始缩小维度的元素索引。可以是负数,表示从 dim 的末尾开始索引。如果是张量,它必须是一个 0 维整数张量(不允许 bools)
长度(int)- 窄化维度的长度,必须是弱正数
示例:
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) >>> torch.narrow(x, 0, 0, 2) tensor([[ 1, 2, 3], [ 4, 5, 6]]) >>> torch.narrow(x, 1, 1, 2) tensor([[ 2, 3], [ 5, 6], [ 8, 9]]) >>> torch.narrow(x, -1, torch.tensor(-1), 1) tensor([[3], [6], [9]])