mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-01-26 15:49:45 +00:00
26 lines
820 B
Python
26 lines
820 B
Python
import torch
|
|
|
|
class SparseMM(torch.autograd.Function):
    """Redefine sparse @ dense matrix multiplication to enable backpropagation.

    The builtin matrix multiplication operation does not support backpropagation
    in some cases, so this Function computes the product with ``torch.matmul``
    and supplies its own backward pass.

    Only ``dense`` receives a gradient: the gradient returned for ``sparse`` is
    always ``None``, i.e. ``sparse`` is treated as a constant. Because the
    backward pass transposes ``sparse`` with ``.t()``, ``sparse`` must be a
    matrix of at most two dimensions.
    """

    @staticmethod
    def forward(ctx, sparse, dense):
        """Return ``torch.matmul(sparse, dense)``.

        Args:
            ctx: Autograd context; ``sparse`` is stashed on it for ``backward``.
            sparse: Left operand (presumably a sparse tensor, per the name).
            dense: Right operand.

        Returns:
            The matrix product ``sparse @ dense``.
        """
        # Only the left operand is needed to form d(out)/d(dense).
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to the ``dense`` operand.

        Args:
            ctx: Autograd context holding the saved ``sparse`` operand.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            ``(None, grad_dense)``. ``grad_dense`` is
            ``sparse.t() @ grad_output`` when ``dense`` needs a gradient,
            otherwise ``None``.
        """
        (sparse,) = ctx.saved_tensors
        grad_dense = None
        # ctx.needs_input_grad is autograd's built-in per-input flag; index 1
        # is ``dense``. No need to hand-copy dense.requires_grad in forward.
        if ctx.needs_input_grad[1]:
            grad_dense = torch.matmul(sparse.t(), grad_output)
        # ``sparse`` is deliberately treated as a constant: no gradient for it.
        return None, grad_dense
|
|
|
|
|
|
def spmm(sparse, dense, device):
    """Compute ``sparse @ dense`` on ``device`` using the backprop-capable SparseMM.

    Both operands are moved to ``device`` with ``Tensor.to`` and then
    multiplied through ``SparseMM.apply``.

    Args:
        sparse: Left operand; moved to ``device``.
        dense: Right operand; moved to ``device``.
        device: Target device passed to ``Tensor.to``; must not be ``None``.

    Returns:
        The product tensor computed on ``device``.

    Raises:
        ValueError: If ``device`` is ``None``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, after which ``.to(None)`` would silently skip the move.
    if device is None:
        raise ValueError("spmm: device must not be None")
    return SparseMM.apply(sparse.to(device), dense.to(device))