mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-04-30 19:21:17 +00:00
Fix device issue
This commit is contained in:
@@ -8,7 +8,7 @@ class MeshGraphormerDetector:
|
|||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"):
|
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device=None):
|
||||||
filename = filename or "graphormer_hand_state_dict.bin"
|
filename = filename or "graphormer_hand_state_dict.bin"
|
||||||
hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
|
hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
|
||||||
args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)
|
args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)
|
||||||
|
|||||||
@@ -98,10 +98,9 @@ class MeshGraphormerMediapipe(Preprocessor):
|
|||||||
for i in range(len(output_feat_dim)):
|
for i in range(len(output_feat_dim)):
|
||||||
config_class, model_class = BertConfig, Graphormer
|
config_class, model_class = BertConfig, Graphormer
|
||||||
config = config_class.from_pretrained(
|
config = config_class.from_pretrained(
|
||||||
args.config_name if args.config_name else args.model_name_or_path,
|
args.config_name if args.config_name else args.model_name_or_path
|
||||||
device=args.device,
|
|
||||||
)
|
)
|
||||||
|
setattr(config, "device", args.device)
|
||||||
config.output_attentions = False
|
config.output_attentions = False
|
||||||
config.img_feature_dim = input_feat_dim[i]
|
config.img_feature_dim = input_feat_dim[i]
|
||||||
config.output_feature_dim = output_feat_dim[i]
|
config.output_feature_dim = output_feat_dim[i]
|
||||||
|
|||||||
@@ -2,35 +2,13 @@ from __future__ import division
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
|
||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from .util import spmm
|
||||||
|
|
||||||
data_path = Path(__file__).parent / "data"
|
data_path = Path(__file__).parent / "data"
|
||||||
|
|
||||||
|
|
||||||
sparse_to_dense = lambda x: x
|
sparse_to_dense = lambda x: x
|
||||||
|
|
||||||
class SparseMM(torch.autograd.Function):
|
|
||||||
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
|
|
||||||
The builtin matrix multiplication operation does not support backpropagation in some cases.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, sparse, dense):
|
|
||||||
ctx.req_grad = dense.requires_grad
|
|
||||||
ctx.save_for_backward(sparse)
|
|
||||||
return torch.matmul(sparse, dense)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
grad_input = None
|
|
||||||
sparse, = ctx.saved_tensors
|
|
||||||
if ctx.req_grad:
|
|
||||||
grad_input = torch.matmul(sparse.t(), grad_output)
|
|
||||||
return None, grad_input
|
|
||||||
|
|
||||||
def spmm(sparse, dense):
|
|
||||||
return SparseMM.apply(sparse, dense)
|
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
"""Implementation of the gelu activation function.
|
"""Implementation of the gelu activation function.
|
||||||
@@ -60,7 +38,7 @@ class GraphResBlock(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Graph Residual Block similar to the Bottleneck Residual Block in ResNet
|
Graph Residual Block similar to the Bottleneck Residual Block in ResNet
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_channels, out_channels, mesh_type='body', device="cuda"):
|
def __init__(self, in_channels, out_channels, mesh_type='body', device=None):
|
||||||
super(GraphResBlock, self).__init__()
|
super(GraphResBlock, self).__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
@@ -87,24 +65,6 @@ class GraphResBlock(torch.nn.Module):
|
|||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
# class GraphResBlock(torch.nn.Module):
|
|
||||||
# """
|
|
||||||
# Graph Residual Block similar to the Bottleneck Residual Block in ResNet
|
|
||||||
# """
|
|
||||||
# def __init__(self, in_channels, out_channels, mesh_type='body'):
|
|
||||||
# super(GraphResBlock, self).__init__()
|
|
||||||
# self.in_channels = in_channels
|
|
||||||
# self.out_channels = out_channels
|
|
||||||
# self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
|
|
||||||
# print('Use BertLayerNorm and GeLU in GraphResBlock')
|
|
||||||
# self.norm = BertLayerNorm(self.out_channels)
|
|
||||||
# def forward(self, x):
|
|
||||||
# y = self.conv(x)
|
|
||||||
# y = self.norm(y)
|
|
||||||
# y = gelu(y)
|
|
||||||
# z = x+y
|
|
||||||
# return z
|
|
||||||
|
|
||||||
class GraphLinear(torch.nn.Module):
|
class GraphLinear(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Generalization of 1x1 convolutions on Graphs
|
Generalization of 1x1 convolutions on Graphs
|
||||||
@@ -127,7 +87,7 @@ class GraphLinear(torch.nn.Module):
|
|||||||
|
|
||||||
class GraphConvolution(torch.nn.Module):
|
class GraphConvolution(torch.nn.Module):
|
||||||
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
|
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
|
||||||
def __init__(self, in_features, out_features, mesh='body', bias=True, device="cuda"):
|
def __init__(self, in_features, out_features, mesh='body', bias=True, device=None):
|
||||||
super(GraphConvolution, self).__init__()
|
super(GraphConvolution, self).__init__()
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
@@ -170,7 +130,7 @@ class GraphConvolution(torch.nn.Module):
|
|||||||
for i in range(x.shape[0]):
|
for i in range(x.shape[0]):
|
||||||
support = torch.matmul(x[i], self.weight)
|
support = torch.matmul(x[i], self.weight)
|
||||||
# output.append(torch.matmul(self.adjmat, support))
|
# output.append(torch.matmul(self.adjmat, support))
|
||||||
output.append(spmm(self.adjmat.to(self.device), support.to(self.device)))
|
output.append(spmm(self.adjmat, support, self.device))
|
||||||
output = torch.stack(output, dim=0)
|
output = torch.stack(output, dim=0)
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
output = output + self.bias
|
output = output + self.bias
|
||||||
|
|||||||
@@ -12,11 +12,10 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import json
|
|
||||||
import code
|
|
||||||
from manopth.manolayer import ManoLayer
|
from manopth.manolayer import ManoLayer
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
import mesh_graphormer.modeling.data.config as cfg
|
import mesh_graphormer.modeling.data.config as cfg
|
||||||
|
from .util import spmm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -64,28 +63,6 @@ class MANO(nn.Module):
|
|||||||
return joints
|
return joints
|
||||||
|
|
||||||
|
|
||||||
class SparseMM(torch.autograd.Function):
|
|
||||||
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
|
|
||||||
The builtin matrix multiplication operation does not support backpropagation in some cases.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, sparse, dense):
|
|
||||||
ctx.req_grad = dense.requires_grad
|
|
||||||
ctx.save_for_backward(sparse)
|
|
||||||
return torch.matmul(sparse, dense)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
grad_input = None
|
|
||||||
sparse, = ctx.saved_tensors
|
|
||||||
if ctx.req_grad:
|
|
||||||
grad_input = torch.matmul(sparse.t(), grad_output)
|
|
||||||
return None, grad_input
|
|
||||||
|
|
||||||
def spmm(sparse, dense):
|
|
||||||
return SparseMM.apply(sparse, dense)
|
|
||||||
|
|
||||||
|
|
||||||
def scipy_to_pytorch(A, U, D):
|
def scipy_to_pytorch(A, U, D):
|
||||||
"""Convert scipy sparse matrices to pytorch sparse matrix."""
|
"""Convert scipy sparse matrices to pytorch sparse matrix."""
|
||||||
ptU = []
|
ptU = []
|
||||||
@@ -141,12 +118,13 @@ def get_graph_params(filename, nsize=1):
|
|||||||
class Mesh(object):
|
class Mesh(object):
|
||||||
"""Mesh object that is used for handling certain graph operations."""
|
"""Mesh object that is used for handling certain graph operations."""
|
||||||
def __init__(self, filename=cfg.MANO_sampling_matrix,
|
def __init__(self, filename=cfg.MANO_sampling_matrix,
|
||||||
num_downsampling=1, nsize=1, device=torch.device('cuda')):
|
num_downsampling=1, nsize=1, device=None):
|
||||||
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
|
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
|
||||||
# self._A = [a.to(device) for a in self._A]
|
# self._A = [a.to(device) for a in self._A]
|
||||||
self._U = [u.to(device) for u in self._U]
|
self._U = [u.to(device) for u in self._U]
|
||||||
self._D = [d.to(device) for d in self._D]
|
self._D = [d.to(device) for d in self._D]
|
||||||
self.num_downsampling = num_downsampling
|
self.num_downsampling = num_downsampling
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def downsample(self, x, n1=0, n2=None):
|
def downsample(self, x, n1=0, n2=None):
|
||||||
"""Downsample mesh."""
|
"""Downsample mesh."""
|
||||||
@@ -154,13 +132,13 @@ class Mesh(object):
|
|||||||
n2 = self.num_downsampling
|
n2 = self.num_downsampling
|
||||||
if x.ndimension() < 3:
|
if x.ndimension() < 3:
|
||||||
for i in range(n1, n2):
|
for i in range(n1, n2):
|
||||||
x = spmm(self._D[i], x)
|
x = spmm(self._D[i], x, self.device)
|
||||||
elif x.ndimension() == 3:
|
elif x.ndimension() == 3:
|
||||||
out = []
|
out = []
|
||||||
for i in range(x.shape[0]):
|
for i in range(x.shape[0]):
|
||||||
y = x[i]
|
y = x[i]
|
||||||
for j in range(n1, n2):
|
for j in range(n1, n2):
|
||||||
y = spmm(self._D[j], y)
|
y = spmm(self._D[j], y, self.device)
|
||||||
out.append(y)
|
out.append(y)
|
||||||
x = torch.stack(out, dim=0)
|
x = torch.stack(out, dim=0)
|
||||||
return x
|
return x
|
||||||
@@ -169,13 +147,13 @@ class Mesh(object):
|
|||||||
"""Upsample mesh."""
|
"""Upsample mesh."""
|
||||||
if x.ndimension() < 3:
|
if x.ndimension() < 3:
|
||||||
for i in reversed(range(n2, n1)):
|
for i in reversed(range(n2, n1)):
|
||||||
x = spmm(self._U[i], x)
|
x = spmm(self._U[i], x, self.device)
|
||||||
elif x.ndimension() == 3:
|
elif x.ndimension() == 3:
|
||||||
out = []
|
out = []
|
||||||
for i in range(x.shape[0]):
|
for i in range(x.shape[0]):
|
||||||
y = x[i]
|
y = x[i]
|
||||||
for j in reversed(range(n2, n1)):
|
for j in reversed(range(n2, n1)):
|
||||||
y = spmm(self._U[j], y)
|
y = spmm(self._U[j], y, self.device)
|
||||||
out.append(y)
|
out.append(y)
|
||||||
x = torch.stack(out, dim=0)
|
x = torch.stack(out, dim=0)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ except ImportError:
|
|||||||
|
|
||||||
from mesh_graphormer.utils.geometric_layers import rodrigues
|
from mesh_graphormer.utils.geometric_layers import rodrigues
|
||||||
import mesh_graphormer.modeling.data.config as cfg
|
import mesh_graphormer.modeling.data.config as cfg
|
||||||
|
from .util import spmm
|
||||||
|
|
||||||
|
|
||||||
sparse_to_dense = lambda x: x
|
sparse_to_dense = lambda x: x
|
||||||
@@ -140,27 +141,6 @@ class SMPL(nn.Module):
|
|||||||
joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
|
joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
|
||||||
return joints
|
return joints
|
||||||
|
|
||||||
class SparseMM(torch.autograd.Function):
|
|
||||||
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
|
|
||||||
The builtin matrix multiplication operation does not support backpropagation in some cases.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, sparse, dense):
|
|
||||||
ctx.req_grad = dense.requires_grad
|
|
||||||
ctx.save_for_backward(sparse)
|
|
||||||
return torch.matmul(sparse, dense)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
grad_input = None
|
|
||||||
sparse, = ctx.saved_tensors
|
|
||||||
if ctx.req_grad:
|
|
||||||
grad_input = torch.matmul(sparse.t(), grad_output)
|
|
||||||
return None, grad_input
|
|
||||||
|
|
||||||
def spmm(sparse, dense):
|
|
||||||
return SparseMM.apply(sparse, dense)
|
|
||||||
|
|
||||||
|
|
||||||
def scipy_to_pytorch(A, U, D):
|
def scipy_to_pytorch(A, U, D):
|
||||||
"""Convert scipy sparse matrices to pytorch sparse matrix."""
|
"""Convert scipy sparse matrices to pytorch sparse matrix."""
|
||||||
@@ -217,7 +197,7 @@ def get_graph_params(filename, nsize=1):
|
|||||||
class Mesh(object):
|
class Mesh(object):
|
||||||
"""Mesh object that is used for handling certain graph operations."""
|
"""Mesh object that is used for handling certain graph operations."""
|
||||||
def __init__(self, filename=cfg.SMPL_sampling_matrix,
|
def __init__(self, filename=cfg.SMPL_sampling_matrix,
|
||||||
num_downsampling=1, nsize=1, device=torch.device('cuda')):
|
num_downsampling=1, nsize=1, device=None):
|
||||||
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
|
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
|
||||||
# self._A = [a.to(device) for a in self._A]
|
# self._A = [a.to(device) for a in self._A]
|
||||||
self._U = [u.to(device) for u in self._U]
|
self._U = [u.to(device) for u in self._U]
|
||||||
@@ -233,18 +213,14 @@ class Mesh(object):
|
|||||||
|
|
||||||
self._ref_vertices = ref_vertices.to(device)
|
self._ref_vertices = ref_vertices.to(device)
|
||||||
self.faces = smpl.faces.int().to(device)
|
self.faces = smpl.faces.int().to(device)
|
||||||
|
self.device = device
|
||||||
# @property
|
|
||||||
# def adjmat(self):
|
|
||||||
# """Return the graph adjacency matrix at the specified subsampling level."""
|
|
||||||
# return self._A[self.num_downsampling].float()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ref_vertices(self):
|
def ref_vertices(self):
|
||||||
"""Return the template vertices at the specified subsampling level."""
|
"""Return the template vertices at the specified subsampling level."""
|
||||||
ref_vertices = self._ref_vertices
|
ref_vertices = self._ref_vertices
|
||||||
for i in range(self.num_downsampling):
|
for i in range(self.num_downsampling):
|
||||||
ref_vertices = torch.spmm(self._D[i], ref_vertices)
|
ref_vertices = torch.spmm(self._D[i], ref_vertices, self.device)
|
||||||
return ref_vertices
|
return ref_vertices
|
||||||
|
|
||||||
def downsample(self, x, n1=0, n2=None):
|
def downsample(self, x, n1=0, n2=None):
|
||||||
@@ -253,13 +229,13 @@ class Mesh(object):
|
|||||||
n2 = self.num_downsampling
|
n2 = self.num_downsampling
|
||||||
if x.ndimension() < 3:
|
if x.ndimension() < 3:
|
||||||
for i in range(n1, n2):
|
for i in range(n1, n2):
|
||||||
x = spmm(self._D[i], x)
|
x = spmm(self._D[i], x, self.device)
|
||||||
elif x.ndimension() == 3:
|
elif x.ndimension() == 3:
|
||||||
out = []
|
out = []
|
||||||
for i in range(x.shape[0]):
|
for i in range(x.shape[0]):
|
||||||
y = x[i]
|
y = x[i]
|
||||||
for j in range(n1, n2):
|
for j in range(n1, n2):
|
||||||
y = spmm(self._D[j], y)
|
y = spmm(self._D[j], y, self.device)
|
||||||
out.append(y)
|
out.append(y)
|
||||||
x = torch.stack(out, dim=0)
|
x = torch.stack(out, dim=0)
|
||||||
return x
|
return x
|
||||||
@@ -268,13 +244,13 @@ class Mesh(object):
|
|||||||
"""Upsample mesh."""
|
"""Upsample mesh."""
|
||||||
if x.ndimension() < 3:
|
if x.ndimension() < 3:
|
||||||
for i in reversed(range(n2, n1)):
|
for i in reversed(range(n2, n1)):
|
||||||
x = spmm(self._U[i], x)
|
x = spmm(self._U[i], x, self.device)
|
||||||
elif x.ndimension() == 3:
|
elif x.ndimension() == 3:
|
||||||
out = []
|
out = []
|
||||||
for i in range(x.shape[0]):
|
for i in range(x.shape[0]):
|
||||||
y = x[i]
|
y = x[i]
|
||||||
for j in reversed(range(n2, n1)):
|
for j in reversed(range(n2, n1)):
|
||||||
y = spmm(self._U[j], y)
|
y = spmm(self._U[j], y, self.device)
|
||||||
out.append(y)
|
out.append(y)
|
||||||
x = torch.stack(out, dim=0)
|
x = torch.stack(out, dim=0)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -7,18 +7,18 @@ Licensed under the MIT license.
|
|||||||
import torch
|
import torch
|
||||||
import mesh_graphormer.modeling.data.config as cfg
|
import mesh_graphormer.modeling.data.config as cfg
|
||||||
|
|
||||||
device = "cuda"
|
|
||||||
|
|
||||||
class Graphormer_Body_Network(torch.nn.Module):
|
class Graphormer_Body_Network(torch.nn.Module):
|
||||||
'''
|
'''
|
||||||
End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
|
End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
|
||||||
'''
|
'''
|
||||||
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
|
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler, device):
|
||||||
super(Graphormer_Body_Network, self).__init__()
|
super(Graphormer_Body_Network, self).__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.config.device = device
|
self.config.device = device
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.trans_encoder = trans_encoder
|
self.trans_encoder = trans_encoder
|
||||||
|
self.device = device
|
||||||
self.upsampling = torch.nn.Linear(431, 1723)
|
self.upsampling = torch.nn.Linear(431, 1723)
|
||||||
self.upsampling2 = torch.nn.Linear(1723, 6890)
|
self.upsampling2 = torch.nn.Linear(1723, 6890)
|
||||||
self.cam_param_fc = torch.nn.Linear(3, 1)
|
self.cam_param_fc = torch.nn.Linear(3, 1)
|
||||||
@@ -32,8 +32,8 @@ class Graphormer_Body_Network(torch.nn.Module):
|
|||||||
# Generate T-pose template mesh
|
# Generate T-pose template mesh
|
||||||
template_pose = torch.zeros((1,72))
|
template_pose = torch.zeros((1,72))
|
||||||
template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
|
template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
|
||||||
template_pose = template_pose.to(device)
|
template_pose = template_pose.to(self.device)
|
||||||
template_betas = torch.zeros((1,10)).to(device)
|
template_betas = torch.zeros((1,10)).to(self.device)
|
||||||
template_vertices = smpl(template_pose, template_betas)
|
template_vertices = smpl(template_pose, template_betas)
|
||||||
|
|
||||||
# template mesh simplification
|
# template mesh simplification
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class Graphormer_Hand_Network(torch.nn.Module):
|
|||||||
'''
|
'''
|
||||||
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
|
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
|
||||||
'''
|
'''
|
||||||
def __init__(self, args, config, backbone, trans_encoder, device="cuda"):
|
def __init__(self, args, config, backbone, trans_encoder, device):
|
||||||
super(Graphormer_Hand_Network, self).__init__()
|
super(Graphormer_Hand_Network, self).__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
|
|||||||
@@ -127,7 +127,12 @@ class GraphormerLayer(nn.Module):
|
|||||||
self.device = config.device
|
self.device = config.device
|
||||||
|
|
||||||
if self.has_graph_conv == True:
|
if self.has_graph_conv == True:
|
||||||
self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
|
self.graph_conv = GraphResBlock(
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
mesh_type=self.mesh_type,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
|||||||
26
mesh_graphormer/modeling/util.py
Normal file
26
mesh_graphormer/modeling/util.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class SparseMM(torch.autograd.Function):
|
||||||
|
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
|
||||||
|
The builtin matrix multiplication operation does not support backpropagation in some cases.
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, sparse, dense):
|
||||||
|
ctx.req_grad = dense.requires_grad
|
||||||
|
ctx.save_for_backward(sparse)
|
||||||
|
return torch.matmul(sparse, dense)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
grad_input = None
|
||||||
|
sparse, = ctx.saved_tensors
|
||||||
|
if ctx.req_grad:
|
||||||
|
grad_input = torch.matmul(sparse.t(), grad_output)
|
||||||
|
return None, grad_input
|
||||||
|
|
||||||
|
|
||||||
|
def spmm(sparse, dense, device):
|
||||||
|
assert device is not None
|
||||||
|
sparse = sparse.to(device)
|
||||||
|
dense = dense.to(device)
|
||||||
|
return SparseMM.apply(sparse, dense)
|
||||||
Reference in New Issue
Block a user