Files
HandRefinerPortable/mesh_graphormer/modeling/util.py
2024-01-03 19:06:35 -05:00

26 lines
820 B
Python

import torch
class SparseMM(torch.autograd.Function):
    """Sparse @ dense matrix multiplication with an explicit backward pass.

    The builtin matrix multiplication operation does not support
    backpropagation in some cases, so this Function computes
    ``torch.matmul(sparse, dense)`` in ``forward`` and supplies the gradient
    with respect to ``dense`` by hand. No gradient is ever returned for
    ``sparse``; it is treated as a constant.

    Call it as ``SparseMM.apply(sparse, dense)``.
    """

    @staticmethod
    def forward(ctx, sparse, dense):
        """Return ``torch.matmul(sparse, dense)``.

        Only ``sparse`` is saved: the gradient w.r.t. ``dense`` is
        ``sparse.t() @ grad_output`` and does not need ``dense`` itself.
        """
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients ``(None, sparse.t() @ grad_output)``.

        The dense gradient is computed only when autograd reports that
        ``dense`` needs one; otherwise ``None`` is returned for it as well.
        NOTE(review): ``Tensor.t()`` only accepts tensors with <= 2 dims, so
        ``sparse`` is presumably a 2-D matrix — confirm against callers.
        """
        (sparse,) = ctx.saved_tensors
        grad_dense = None
        # needs_input_grad has one flag per forward input; index 1 is `dense`.
        if ctx.needs_input_grad[1]:
            grad_dense = torch.matmul(sparse.t(), grad_output)
        return None, grad_dense
def spmm(sparse, dense, device):
    """Compute ``sparse @ dense`` on ``device`` through ``SparseMM``.

    Both operands are moved to ``device`` with ``Tensor.to`` before the
    product is taken with ``SparseMM.apply``. That Function returns no
    gradient for ``sparse``, so gradients flow back to ``dense`` only.

    Args:
        sparse: sparse left-hand matrix.
        dense: dense right-hand operand.
        device: target device for both operands; must not be ``None``.

    Returns:
        The result of ``SparseMM.apply(sparse, dense)`` on ``device``.

    Raises:
        ValueError: if ``device`` is ``None``.
    """
    # An explicit raise (rather than ``assert``) keeps this check active
    # when Python runs with optimizations enabled (``-O``).
    if device is None:
        raise ValueError("spmm() requires an explicit device, got None")
    sparse = sparse.to(device)
    dense = dense.to(device)
    return SparseMM.apply(sparse, dense)