import torch


class SparseMM(torch.autograd.Function):
    """Sparse @ dense matrix multiplication with backpropagation support.

    The builtin matrix multiplication operation does not support
    backpropagation through a sparse operand in some cases, so this
    custom Function computes the gradient for the dense operand
    explicitly. The sparse operand is treated as a constant: no
    gradient is produced for it.
    """

    @staticmethod
    def forward(ctx, sparse, dense):
        # Remember whether the dense operand needs a gradient so that
        # backward() can skip the extra matmul when it does not.
        ctx.req_grad = dense.requires_grad
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        sparse, = ctx.saved_tensors
        if ctx.req_grad:
            # d(S @ D)/dD contracts grad_output with S as S^T @ grad_output.
            grad_input = torch.matmul(sparse.t(), grad_output)
        # First slot is the gradient w.r.t. the sparse operand (not computed).
        return None, grad_input


def spmm(sparse, dense, device):
    """Multiply a sparse matrix by a dense matrix on ``device``.

    Args:
        sparse: sparse tensor, the left operand.
        dense: dense tensor, the right operand.
        device: target device; both operands are moved there before the
            multiplication.

    Returns:
        The dense result of ``sparse @ dense``, with gradient support
        for the dense operand only.

    Raises:
        ValueError: if ``device`` is None.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if device is None:
        raise ValueError("device must not be None")
    return SparseMM.apply(sparse.to(device), dense.to(device))