torch.diagflat¶
- torch.diagflat(input, offset=0) Tensor ¶
如果
input
是一个向量(1-D 张量),则返回一个 2-D 平方张量,其元素为input
的对角线元素。如果
input
是一个多维张量,则返回一个对角线元素等于input
展平的 2-D 张量。
参数
offset
控制考虑哪个对角线:如果
offset
= 0,则是主对角线。如果
offset
> 0,则是在主对角线之上的对角线。如果
offset
小于 0,则位于主对角线以下。
- 参数:
input (Tensor) – 输入张量。
偏移量(int,可选)- 要考虑的对角线。默认:0(主对角线)。
示例:
>>> a = torch.randn(3) >>> a tensor([-0.2956, -0.9068, 0.1695]) >>> torch.diagflat(a) tensor([[-0.2956, 0.0000, 0.0000], [ 0.0000, -0.9068, 0.0000], [ 0.0000, 0.0000, 0.1695]]) >>> torch.diagflat(a, 1) tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.9068, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.1695], [ 0.0000, 0.0000, 0.0000, 0.0000]]) >>> a = torch.randn(2, 2) >>> a tensor([[ 0.2094, -0.3018], [-0.1516, 1.9342]]) >>> torch.diagflat(a) tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], [ 0.0000, -0.3018, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.1516, 0.0000], [ 0.0000, 0.0000, 0.0000, 1.9342]])