torch.diag¶
- torch.diag(input, diagonal=0, *, out=None) → Tensor
如果
input
是一个向量(1-D 张量),则返回一个 2-D 平方张量,其元素为input
的对角线元素。如果
input
是一个矩阵(2-D 张量),则返回一个包含input
对角线元素的 1-D 张量。
参数
diagonal
控制要考虑哪个对角线:如果
diagonal
= 0,则是主对角线。如果
diagonal
> 0,则位于主对角线上方。如果
diagonal
< 0,则位于主对角线下方。
- 参数:
input (Tensor) – 输入张量。
对角线(int,可选)- 要考虑的对角线
- 关键字参数:
输出(张量,可选)- 输出张量。
参见
torch.diagonal()
总是返回输入的对角线。torch.diagflat()
总是构建一个对角线元素由输入指定的张量。示例:
获取对角线为输入向量的正方形矩阵:
>>> a = torch.randn(3) >>> a tensor([ 0.5950,-0.0872, 2.3298]) >>> torch.diag(a) tensor([[ 0.5950, 0.0000, 0.0000], [ 0.0000,-0.0872, 0.0000], [ 0.0000, 0.0000, 2.3298]]) >>> torch.diag(a, 1) tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], [ 0.0000, 0.0000,-0.0872, 0.0000], [ 0.0000, 0.0000, 0.0000, 2.3298], [ 0.0000, 0.0000, 0.0000, 0.0000]])
获取给定矩阵的第 k 个对角线:
>>> a = torch.randn(3, 3) >>> a tensor([[-0.4264, 0.0255,-0.1064], [ 0.8795,-0.2429, 0.1374], [ 0.1029,-0.6482,-1.6300]]) >>> torch.diag(a, 0) tensor([-0.4264,-0.2429,-1.6300]) >>> torch.diag(a, 1) tensor([ 0.0255, 0.1374])