torch.sparse.sampled_addmm¶
- torch.sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) Tensor ¶
在指定稀疏模式
input
的位置,执行mat1
和mat2
的密集矩阵乘法。矩阵input
被加到最终结果中。从数学上讲,这执行以下操作:
其中 是
input
的稀疏模式矩阵,alpha
和beta
是缩放因子。 在input
非零值的位置上具有值为 1,其他位置为 0。注意
input
必须是一个稀疏的 CSR 张量。mat1
和mat2
必须是密集张量。- 参数:
输入(张量)- 一个形状为(m, n)的稀疏 CSR 矩阵,用于相加和使用以计算采样矩阵乘法
mat1(张量)- 一个形状为(m, k)的密集矩阵,用于相乘
mat2(张量)- 一个形状为(k, n)的密集矩阵,用于相乘
- 关键字参数:
beta(数字,可选)-
input
( )的乘数alpha(数字,可选)- ( )的乘数
out(张量,可选)- 输出张量。如果为 None 则忽略。默认:None。
示例:
>>> input = torch.eye(3, device='cuda').to_sparse_csr() >>> mat1 = torch.randn(3, 5, device='cuda') >>> mat2 = torch.randn(5, 3, device='cuda') >>> torch.sparse.sampled_addmm(input, mat1, mat2) tensor(crow_indices=tensor([0, 1, 2, 3]), col_indices=tensor([0, 1, 2]), values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0', size=(3, 3), nnz=3, layout=torch.sparse_csr) >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense() tensor([[ 0.2847, 0.0000, 0.0000], [ 0.0000, -0.7805, 0.0000], [ 0.0000, 0.0000, -0.1900]], device='cuda:0') >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5) tensor(crow_indices=tensor([0, 1, 2, 3]), col_indices=tensor([0, 1, 2]), values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0', size=(3, 3), nnz=3, layout=torch.sparse_csr)