torch.roll¶
- torch.roll(input, shifts, dims=None) Tensor ¶
将张量
input
沿给定维度滚动。超出最后一个位置的元素将被重新引入到第一个位置。如果dims
为 None,则张量在滚动之前将被展平,然后恢复到原始形状。- 参数:
input (Tensor) – 输入张量。
shifts(整数或整数元组)- 张量元素移动的位置数。如果 shifts 是元组,则 dims 也必须是一个相同大小的元组,并且每个维度将按相应的值滚动
dims(整数或整数元组)- 滚动的轴
示例:
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) >>> x tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) >>> torch.roll(x, 1) tensor([[8, 1], [2, 3], [4, 5], [6, 7]]) >>> torch.roll(x, 1, 0) tensor([[7, 8], [1, 2], [3, 4], [5, 6]]) >>> torch.roll(x, -1, 0) tensor([[3, 4], [5, 6], [7, 8], [1, 2]]) >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) tensor([[6, 5], [8, 7], [2, 1], [4, 3]])