Fix device issue

Drop the hard-coded "cuda" defaults, thread an explicit device through the Graphormer networks and Mesh helpers, and consolidate the duplicated SparseMM/spmm code into a shared util module whose spmm moves both operands to the requested device.

huchenlei
2024-01-03 19:06:35 -05:00
parent 0d668aff1c
commit 5926b7e47b
9 changed files with 60 additions and 116 deletions

View File

@@ -8,7 +8,7 @@ class MeshGraphormerDetector:
         self.pipeline = pipeline
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"):
+    def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device=None):
         filename = filename or "graphormer_hand_state_dict.bin"
         hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
         args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)
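With the hard-coded default gone, callers choose the device explicitly. A minimal usage sketch (the import path, repo id, and CPU fallback are illustrative assumptions, not part of this commit):

    import torch
    from controlnet_aux import MeshGraphormerDetector  # assumed import path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = MeshGraphormerDetector.from_pretrained(
        "hr16/ControlNet-HandRefiner-pruned",  # assumed checkpoint repo
        device=device,
    )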

View File

@@ -98,10 +98,9 @@ class MeshGraphormerMediapipe(Preprocessor):
         for i in range(len(output_feat_dim)):
             config_class, model_class = BertConfig, Graphormer
             config = config_class.from_pretrained(
-                args.config_name if args.config_name else args.model_name_or_path,
-                device=args.device,
+                args.config_name if args.config_name else args.model_name_or_path
             )
+            setattr(config, "device", args.device)
             config.output_attentions = False
             config.img_feature_dim = input_feat_dim[i]
             config.output_feature_dim = output_feat_dim[i]
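Instead of smuggling the device through from_pretrained, the pipeline now loads the config untouched and attaches the device afterwards: transformers does not reliably treat arbitrary from_pretrained kwargs as config fields across versions, so an explicit setattr is the safer route. A minimal sketch of the pattern (model id illustrative):

    from transformers import BertConfig

    config = BertConfig.from_pretrained("bert-base-uncased")  # load config as-is
    setattr(config, "device", "cpu")  # runtime-only attribute, not a serialized field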

View File

@@ -2,35 +2,13 @@ from __future__ import division
 import torch
 import torch.nn.functional as F
 import numpy as np
 import scipy.sparse
 import math
 from pathlib import Path
+from .util import spmm
 
 data_path = Path(__file__).parent / "data"
 sparse_to_dense = lambda x: x
 
-class SparseMM(torch.autograd.Function):
-    """Redefine sparse @ dense matrix multiplication to enable backpropagation.
-    The builtin matrix multiplication operation does not support backpropagation in some cases.
-    """
-    @staticmethod
-    def forward(ctx, sparse, dense):
-        ctx.req_grad = dense.requires_grad
-        ctx.save_for_backward(sparse)
-        return torch.matmul(sparse, dense)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input = None
-        sparse, = ctx.saved_tensors
-        if ctx.req_grad:
-            grad_input = torch.matmul(sparse.t(), grad_output)
-        return None, grad_input
-
-def spmm(sparse, dense):
-    return SparseMM.apply(sparse, dense)
-
 def gelu(x):
     """Implementation of the gelu activation function.
@@ -60,7 +38,7 @@ class GraphResBlock(torch.nn.Module):
     """
     Graph Residual Block similar to the Bottleneck Residual Block in ResNet
     """
-    def __init__(self, in_channels, out_channels, mesh_type='body', device="cuda"):
+    def __init__(self, in_channels, out_channels, mesh_type='body', device=None):
         super(GraphResBlock, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -87,24 +65,6 @@ class GraphResBlock(torch.nn.Module):
         return z
 
-# class GraphResBlock(torch.nn.Module):
-#     """
-#     Graph Residual Block similar to the Bottleneck Residual Block in ResNet
-#     """
-#     def __init__(self, in_channels, out_channels, mesh_type='body'):
-#         super(GraphResBlock, self).__init__()
-#         self.in_channels = in_channels
-#         self.out_channels = out_channels
-#         self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
-#         print('Use BertLayerNorm and GeLU in GraphResBlock')
-#         self.norm = BertLayerNorm(self.out_channels)
-#     def forward(self, x):
-#         y = self.conv(x)
-#         y = self.norm(y)
-#         y = gelu(y)
-#         z = x+y
-#         return z
-
 class GraphLinear(torch.nn.Module):
     """
     Generalization of 1x1 convolutions on Graphs
@@ -127,7 +87,7 @@ class GraphLinear(torch.nn.Module):
 class GraphConvolution(torch.nn.Module):
     """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
-    def __init__(self, in_features, out_features, mesh='body', bias=True, device="cuda"):
+    def __init__(self, in_features, out_features, mesh='body', bias=True, device=None):
         super(GraphConvolution, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -170,7 +130,7 @@ class GraphConvolution(torch.nn.Module):
             for i in range(x.shape[0]):
                 support = torch.matmul(x[i], self.weight)
                 # output.append(torch.matmul(self.adjmat, support))
-                output.append(spmm(self.adjmat.to(self.device), support.to(self.device)))
+                output.append(spmm(self.adjmat, support, self.device))
             output = torch.stack(output, dim=0)
             if self.bias is not None:
                 output = output + self.bias

View File

@@ -12,11 +12,10 @@ import numpy as np
 import torch
 import torch.nn as nn
 import os.path as osp
 import json
 import code
 from manopth.manolayer import ManoLayer
 import scipy.sparse
 import mesh_graphormer.modeling.data.config as cfg
+from .util import spmm
 from pathlib import Path
@@ -64,28 +63,6 @@ class MANO(nn.Module):
         return joints
 
-class SparseMM(torch.autograd.Function):
-    """Redefine sparse @ dense matrix multiplication to enable backpropagation.
-    The builtin matrix multiplication operation does not support backpropagation in some cases.
-    """
-    @staticmethod
-    def forward(ctx, sparse, dense):
-        ctx.req_grad = dense.requires_grad
-        ctx.save_for_backward(sparse)
-        return torch.matmul(sparse, dense)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input = None
-        sparse, = ctx.saved_tensors
-        if ctx.req_grad:
-            grad_input = torch.matmul(sparse.t(), grad_output)
-        return None, grad_input
-
-def spmm(sparse, dense):
-    return SparseMM.apply(sparse, dense)
-
 def scipy_to_pytorch(A, U, D):
     """Convert scipy sparse matrices to pytorch sparse matrix."""
     ptU = []
@@ -141,12 +118,13 @@ def get_graph_params(filename, nsize=1):
 class Mesh(object):
     """Mesh object that is used for handling certain graph operations."""
     def __init__(self, filename=cfg.MANO_sampling_matrix,
-                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
+                 num_downsampling=1, nsize=1, device=None):
         self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
         # self._A = [a.to(device) for a in self._A]
         self._U = [u.to(device) for u in self._U]
         self._D = [d.to(device) for d in self._D]
         self.num_downsampling = num_downsampling
+        self.device = device
 
     def downsample(self, x, n1=0, n2=None):
         """Downsample mesh."""
@@ -154,13 +132,13 @@ class Mesh(object):
             n2 = self.num_downsampling
         if x.ndimension() < 3:
             for i in range(n1, n2):
-                x = spmm(self._D[i], x)
+                x = spmm(self._D[i], x, self.device)
         elif x.ndimension() == 3:
             out = []
             for i in range(x.shape[0]):
                 y = x[i]
                 for j in range(n1, n2):
-                    y = spmm(self._D[j], y)
+                    y = spmm(self._D[j], y, self.device)
                 out.append(y)
             x = torch.stack(out, dim=0)
         return x
@@ -169,13 +147,13 @@ class Mesh(object):
         """Upsample mesh."""
         if x.ndimension() < 3:
             for i in reversed(range(n2, n1)):
-                x = spmm(self._U[i], x)
+                x = spmm(self._U[i], x, self.device)
         elif x.ndimension() == 3:
             out = []
             for i in range(x.shape[0]):
                 y = x[i]
                 for j in reversed(range(n2, n1)):
-                    y = spmm(self._U[j], y)
+                    y = spmm(self._U[j], y, self.device)
                 out.append(y)
             x = torch.stack(out, dim=0)
         return x

View File

@@ -16,6 +16,7 @@ except ImportError:
 from mesh_graphormer.utils.geometric_layers import rodrigues
 import mesh_graphormer.modeling.data.config as cfg
+from .util import spmm
 
 sparse_to_dense = lambda x: x
@@ -140,27 +141,6 @@ class SMPL(nn.Module):
         joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
         return joints
 
-class SparseMM(torch.autograd.Function):
-    """Redefine sparse @ dense matrix multiplication to enable backpropagation.
-    The builtin matrix multiplication operation does not support backpropagation in some cases.
-    """
-    @staticmethod
-    def forward(ctx, sparse, dense):
-        ctx.req_grad = dense.requires_grad
-        ctx.save_for_backward(sparse)
-        return torch.matmul(sparse, dense)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input = None
-        sparse, = ctx.saved_tensors
-        if ctx.req_grad:
-            grad_input = torch.matmul(sparse.t(), grad_output)
-        return None, grad_input
-
-def spmm(sparse, dense):
-    return SparseMM.apply(sparse, dense)
-
 def scipy_to_pytorch(A, U, D):
     """Convert scipy sparse matrices to pytorch sparse matrix."""
@@ -217,7 +197,7 @@ def get_graph_params(filename, nsize=1):
 class Mesh(object):
     """Mesh object that is used for handling certain graph operations."""
     def __init__(self, filename=cfg.SMPL_sampling_matrix,
-                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
+                 num_downsampling=1, nsize=1, device=None):
         self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
         # self._A = [a.to(device) for a in self._A]
         self._U = [u.to(device) for u in self._U]
@@ -233,18 +213,14 @@ class Mesh(object):
         self._ref_vertices = ref_vertices.to(device)
         self.faces = smpl.faces.int().to(device)
-
-        # @property
-        # def adjmat(self):
-        #     """Return the graph adjacency matrix at the specified subsampling level."""
-        #     return self._A[self.num_downsampling].float()
+        self.device = device
 
     @property
     def ref_vertices(self):
         """Return the template vertices at the specified subsampling level."""
         ref_vertices = self._ref_vertices
         for i in range(self.num_downsampling):
-            ref_vertices = torch.spmm(self._D[i], ref_vertices)
+            ref_vertices = spmm(self._D[i], ref_vertices, self.device)
         return ref_vertices
 
     def downsample(self, x, n1=0, n2=None):
@@ -253,13 +229,13 @@ class Mesh(object):
             n2 = self.num_downsampling
         if x.ndimension() < 3:
             for i in range(n1, n2):
-                x = spmm(self._D[i], x)
+                x = spmm(self._D[i], x, self.device)
         elif x.ndimension() == 3:
             out = []
             for i in range(x.shape[0]):
                 y = x[i]
                 for j in range(n1, n2):
-                    y = spmm(self._D[j], y)
+                    y = spmm(self._D[j], y, self.device)
                 out.append(y)
             x = torch.stack(out, dim=0)
         return x
@@ -268,13 +244,13 @@ class Mesh(object):
         """Upsample mesh."""
         if x.ndimension() < 3:
             for i in reversed(range(n2, n1)):
-                x = spmm(self._U[i], x)
+                x = spmm(self._U[i], x, self.device)
         elif x.ndimension() == 3:
             out = []
             for i in range(x.shape[0]):
                 y = x[i]
                 for j in reversed(range(n2, n1)):
-                    y = spmm(self._U[j], y)
+                    y = spmm(self._U[j], y, self.device)
                 out.append(y)
             x = torch.stack(out, dim=0)
         return x

View File

@@ -7,18 +7,18 @@ Licensed under the MIT license.
 import torch
 import mesh_graphormer.modeling.data.config as cfg
 
-device = "cuda"
 
 class Graphormer_Body_Network(torch.nn.Module):
     '''
     End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
     '''
-    def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
+    def __init__(self, args, config, backbone, trans_encoder, mesh_sampler, device):
         super(Graphormer_Body_Network, self).__init__()
         self.config = config
         self.config.device = device
         self.backbone = backbone
         self.trans_encoder = trans_encoder
+        self.device = device
         self.upsampling = torch.nn.Linear(431, 1723)
         self.upsampling2 = torch.nn.Linear(1723, 6890)
         self.cam_param_fc = torch.nn.Linear(3, 1)
@@ -32,8 +32,8 @@ class Graphormer_Body_Network(torch.nn.Module):
         # Generate T-pose template mesh
         template_pose = torch.zeros((1,72))
         template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
-        template_pose = template_pose.to(device)
-        template_betas = torch.zeros((1,10)).to(device)
+        template_pose = template_pose.to(self.device)
+        template_betas = torch.zeros((1,10)).to(self.device)
         template_vertices = smpl(template_pose, template_betas)
 
         # template mesh simplification

View File

@@ -12,7 +12,7 @@ class Graphormer_Hand_Network(torch.nn.Module):
     '''
     End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
     '''
-    def __init__(self, args, config, backbone, trans_encoder, device="cuda"):
+    def __init__(self, args, config, backbone, trans_encoder, device):
         super(Graphormer_Hand_Network, self).__init__()
         self.config = config
         self.backbone = backbone

View File

@@ -127,7 +127,12 @@ class GraphormerLayer(nn.Module):
         self.device = config.device
         if self.has_graph_conv == True:
-            self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
+            self.graph_conv = GraphResBlock(
+                config.hidden_size,
+                config.hidden_size,
+                mesh_type=self.mesh_type,
+                device=self.device,
+            )
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

View File

@@ -0,0 +1,26 @@
+import torch
+
+class SparseMM(torch.autograd.Function):
+    """Redefine sparse @ dense matrix multiplication to enable backpropagation.
+    The builtin matrix multiplication operation does not support backpropagation in some cases.
+    """
+    @staticmethod
+    def forward(ctx, sparse, dense):
+        ctx.req_grad = dense.requires_grad
+        ctx.save_for_backward(sparse)
+        return torch.matmul(sparse, dense)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input = None
+        sparse, = ctx.saved_tensors
+        if ctx.req_grad:
+            grad_input = torch.matmul(sparse.t(), grad_output)
+        return None, grad_input
+
+
+def spmm(sparse, dense, device):
+    assert device is not None
+    sparse = sparse.to(device)
+    dense = dense.to(device)
+    return SparseMM.apply(sparse, dense)
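A quick sanity check of the shared helper (values illustrative; the absolute module path is an assumption based on the relative "from .util import spmm" imports above):

    import torch
    from mesh_graphormer.modeling.util import spmm  # assumed module path

    indices = torch.tensor([[0, 1], [1, 0]])  # COO indices
    values = torch.tensor([2.0, 3.0])
    sparse = torch.sparse_coo_tensor(indices, values, (2, 2))
    dense = torch.randn(2, 4, requires_grad=True)

    out = spmm(sparse, dense, "cpu")  # moves both operands, then multiplies
    out.sum().backward()              # SparseMM routes the gradient to the dense operand only
    print(dense.grad.shape)           # torch.Size([2, 4])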