Fix device issue

This commit is contained in:
huchenlei
2024-01-03 19:06:35 -05:00
parent 0d668aff1c
commit 5926b7e47b
9 changed files with 60 additions and 116 deletions

View File

@@ -8,7 +8,7 @@ class MeshGraphormerDetector:
self.pipeline = pipeline self.pipeline = pipeline
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"): def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device=None):
filename = filename or "graphormer_hand_state_dict.bin" filename = filename or "graphormer_hand_state_dict.bin"
hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth" hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir) args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)

View File

@@ -98,10 +98,9 @@ class MeshGraphormerMediapipe(Preprocessor):
for i in range(len(output_feat_dim)): for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained( config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path, args.config_name if args.config_name else args.model_name_or_path
device=args.device,
) )
setattr(config, "device", args.device)
config.output_attentions = False config.output_attentions = False
config.img_feature_dim = input_feat_dim[i] config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i] config.output_feature_dim = output_feat_dim[i]

View File

@@ -2,35 +2,13 @@ from __future__ import division
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
import scipy.sparse
import math import math
from pathlib import Path from pathlib import Path
from .util import spmm
data_path = Path(__file__).parent / "data" data_path = Path(__file__).parent / "data"
sparse_to_dense = lambda x: x sparse_to_dense = lambda x: x
class SparseMM(torch.autograd.Function):
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
The builtin matrix multiplication operation does not support backpropagation in some cases.
"""
@staticmethod
def forward(ctx, sparse, dense):
ctx.req_grad = dense.requires_grad
ctx.save_for_backward(sparse)
return torch.matmul(sparse, dense)
@staticmethod
def backward(ctx, grad_output):
grad_input = None
sparse, = ctx.saved_tensors
if ctx.req_grad:
grad_input = torch.matmul(sparse.t(), grad_output)
return None, grad_input
def spmm(sparse, dense):
return SparseMM.apply(sparse, dense)
def gelu(x): def gelu(x):
"""Implementation of the gelu activation function. """Implementation of the gelu activation function.
@@ -60,7 +38,7 @@ class GraphResBlock(torch.nn.Module):
""" """
Graph Residual Block similar to the Bottleneck Residual Block in ResNet Graph Residual Block similar to the Bottleneck Residual Block in ResNet
""" """
def __init__(self, in_channels, out_channels, mesh_type='body', device="cuda"): def __init__(self, in_channels, out_channels, mesh_type='body', device=None):
super(GraphResBlock, self).__init__() super(GraphResBlock, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@@ -87,24 +65,6 @@ class GraphResBlock(torch.nn.Module):
return z return z
# class GraphResBlock(torch.nn.Module):
# """
# Graph Residual Block similar to the Bottleneck Residual Block in ResNet
# """
# def __init__(self, in_channels, out_channels, mesh_type='body'):
# super(GraphResBlock, self).__init__()
# self.in_channels = in_channels
# self.out_channels = out_channels
# self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
# print('Use BertLayerNorm and GeLU in GraphResBlock')
# self.norm = BertLayerNorm(self.out_channels)
# def forward(self, x):
# y = self.conv(x)
# y = self.norm(y)
# y = gelu(y)
# z = x+y
# return z
class GraphLinear(torch.nn.Module): class GraphLinear(torch.nn.Module):
""" """
Generalization of 1x1 convolutions on Graphs Generalization of 1x1 convolutions on Graphs
@@ -127,7 +87,7 @@ class GraphLinear(torch.nn.Module):
class GraphConvolution(torch.nn.Module): class GraphConvolution(torch.nn.Module):
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
def __init__(self, in_features, out_features, mesh='body', bias=True, device="cuda"): def __init__(self, in_features, out_features, mesh='body', bias=True, device=None):
super(GraphConvolution, self).__init__() super(GraphConvolution, self).__init__()
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
@@ -170,7 +130,7 @@ class GraphConvolution(torch.nn.Module):
for i in range(x.shape[0]): for i in range(x.shape[0]):
support = torch.matmul(x[i], self.weight) support = torch.matmul(x[i], self.weight)
# output.append(torch.matmul(self.adjmat, support)) # output.append(torch.matmul(self.adjmat, support))
output.append(spmm(self.adjmat.to(self.device), support.to(self.device))) output.append(spmm(self.adjmat, support, self.device))
output = torch.stack(output, dim=0) output = torch.stack(output, dim=0)
if self.bias is not None: if self.bias is not None:
output = output + self.bias output = output + self.bias

View File

@@ -12,11 +12,10 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import os.path as osp import os.path as osp
import json
import code
from manopth.manolayer import ManoLayer from manopth.manolayer import ManoLayer
import scipy.sparse import scipy.sparse
import mesh_graphormer.modeling.data.config as cfg import mesh_graphormer.modeling.data.config as cfg
from .util import spmm
from pathlib import Path from pathlib import Path
@@ -64,28 +63,6 @@ class MANO(nn.Module):
return joints return joints
class SparseMM(torch.autograd.Function):
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
The builtin matrix multiplication operation does not support backpropagation in some cases.
"""
@staticmethod
def forward(ctx, sparse, dense):
ctx.req_grad = dense.requires_grad
ctx.save_for_backward(sparse)
return torch.matmul(sparse, dense)
@staticmethod
def backward(ctx, grad_output):
grad_input = None
sparse, = ctx.saved_tensors
if ctx.req_grad:
grad_input = torch.matmul(sparse.t(), grad_output)
return None, grad_input
def spmm(sparse, dense):
return SparseMM.apply(sparse, dense)
def scipy_to_pytorch(A, U, D): def scipy_to_pytorch(A, U, D):
"""Convert scipy sparse matrices to pytorch sparse matrix.""" """Convert scipy sparse matrices to pytorch sparse matrix."""
ptU = [] ptU = []
@@ -141,12 +118,13 @@ def get_graph_params(filename, nsize=1):
class Mesh(object): class Mesh(object):
"""Mesh object that is used for handling certain graph operations.""" """Mesh object that is used for handling certain graph operations."""
def __init__(self, filename=cfg.MANO_sampling_matrix, def __init__(self, filename=cfg.MANO_sampling_matrix,
num_downsampling=1, nsize=1, device=torch.device('cuda')): num_downsampling=1, nsize=1, device=None):
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
# self._A = [a.to(device) for a in self._A] # self._A = [a.to(device) for a in self._A]
self._U = [u.to(device) for u in self._U] self._U = [u.to(device) for u in self._U]
self._D = [d.to(device) for d in self._D] self._D = [d.to(device) for d in self._D]
self.num_downsampling = num_downsampling self.num_downsampling = num_downsampling
self.device = device
def downsample(self, x, n1=0, n2=None): def downsample(self, x, n1=0, n2=None):
"""Downsample mesh.""" """Downsample mesh."""
@@ -154,13 +132,13 @@ class Mesh(object):
n2 = self.num_downsampling n2 = self.num_downsampling
if x.ndimension() < 3: if x.ndimension() < 3:
for i in range(n1, n2): for i in range(n1, n2):
x = spmm(self._D[i], x) x = spmm(self._D[i], x, self.device)
elif x.ndimension() == 3: elif x.ndimension() == 3:
out = [] out = []
for i in range(x.shape[0]): for i in range(x.shape[0]):
y = x[i] y = x[i]
for j in range(n1, n2): for j in range(n1, n2):
y = spmm(self._D[j], y) y = spmm(self._D[j], y, self.device)
out.append(y) out.append(y)
x = torch.stack(out, dim=0) x = torch.stack(out, dim=0)
return x return x
@@ -169,13 +147,13 @@ class Mesh(object):
"""Upsample mesh.""" """Upsample mesh."""
if x.ndimension() < 3: if x.ndimension() < 3:
for i in reversed(range(n2, n1)): for i in reversed(range(n2, n1)):
x = spmm(self._U[i], x) x = spmm(self._U[i], x, self.device)
elif x.ndimension() == 3: elif x.ndimension() == 3:
out = [] out = []
for i in range(x.shape[0]): for i in range(x.shape[0]):
y = x[i] y = x[i]
for j in reversed(range(n2, n1)): for j in reversed(range(n2, n1)):
y = spmm(self._U[j], y) y = spmm(self._U[j], y, self.device)
out.append(y) out.append(y)
x = torch.stack(out, dim=0) x = torch.stack(out, dim=0)
return x return x

View File

@@ -16,6 +16,7 @@ except ImportError:
from mesh_graphormer.utils.geometric_layers import rodrigues from mesh_graphormer.utils.geometric_layers import rodrigues
import mesh_graphormer.modeling.data.config as cfg import mesh_graphormer.modeling.data.config as cfg
from .util import spmm
sparse_to_dense = lambda x: x sparse_to_dense = lambda x: x
@@ -140,27 +141,6 @@ class SMPL(nn.Module):
joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct]) joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
return joints return joints
class SparseMM(torch.autograd.Function):
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
The builtin matrix multiplication operation does not support backpropagation in some cases.
"""
@staticmethod
def forward(ctx, sparse, dense):
ctx.req_grad = dense.requires_grad
ctx.save_for_backward(sparse)
return torch.matmul(sparse, dense)
@staticmethod
def backward(ctx, grad_output):
grad_input = None
sparse, = ctx.saved_tensors
if ctx.req_grad:
grad_input = torch.matmul(sparse.t(), grad_output)
return None, grad_input
def spmm(sparse, dense):
return SparseMM.apply(sparse, dense)
def scipy_to_pytorch(A, U, D): def scipy_to_pytorch(A, U, D):
"""Convert scipy sparse matrices to pytorch sparse matrix.""" """Convert scipy sparse matrices to pytorch sparse matrix."""
@@ -217,7 +197,7 @@ def get_graph_params(filename, nsize=1):
class Mesh(object): class Mesh(object):
"""Mesh object that is used for handling certain graph operations.""" """Mesh object that is used for handling certain graph operations."""
def __init__(self, filename=cfg.SMPL_sampling_matrix, def __init__(self, filename=cfg.SMPL_sampling_matrix,
num_downsampling=1, nsize=1, device=torch.device('cuda')): num_downsampling=1, nsize=1, device=None):
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
# self._A = [a.to(device) for a in self._A] # self._A = [a.to(device) for a in self._A]
self._U = [u.to(device) for u in self._U] self._U = [u.to(device) for u in self._U]
@@ -233,18 +213,14 @@ class Mesh(object):
self._ref_vertices = ref_vertices.to(device) self._ref_vertices = ref_vertices.to(device)
self.faces = smpl.faces.int().to(device) self.faces = smpl.faces.int().to(device)
self.device = device
# @property
# def adjmat(self):
# """Return the graph adjacency matrix at the specified subsampling level."""
# return self._A[self.num_downsampling].float()
@property
def ref_vertices(self):
    """Return the template vertices at the specified subsampling level.

    Starts from ``self._ref_vertices`` and left-multiplies it by the first
    ``self.num_downsampling`` matrices in ``self._D``, in order.

    Returns:
        The result of the chained products; ``self._ref_vertices`` itself
        when ``self.num_downsampling`` is 0.
    """
    vertices = self._ref_vertices
    for level in range(self.num_downsampling):
        # torch.spmm takes exactly two matrices (mat1, mat2); passing a device
        # as a third positional argument raises TypeError.
        # NOTE(review): assumes self._D[level] already lives on the same device
        # as self._ref_vertices (moved in __init__) -- confirm.
        vertices = torch.spmm(self._D[level], vertices)
    return vertices
def downsample(self, x, n1=0, n2=None): def downsample(self, x, n1=0, n2=None):
@@ -253,13 +229,13 @@ class Mesh(object):
n2 = self.num_downsampling n2 = self.num_downsampling
if x.ndimension() < 3: if x.ndimension() < 3:
for i in range(n1, n2): for i in range(n1, n2):
x = spmm(self._D[i], x) x = spmm(self._D[i], x, self.device)
elif x.ndimension() == 3: elif x.ndimension() == 3:
out = [] out = []
for i in range(x.shape[0]): for i in range(x.shape[0]):
y = x[i] y = x[i]
for j in range(n1, n2): for j in range(n1, n2):
y = spmm(self._D[j], y) y = spmm(self._D[j], y, self.device)
out.append(y) out.append(y)
x = torch.stack(out, dim=0) x = torch.stack(out, dim=0)
return x return x
@@ -268,13 +244,13 @@ class Mesh(object):
"""Upsample mesh.""" """Upsample mesh."""
if x.ndimension() < 3: if x.ndimension() < 3:
for i in reversed(range(n2, n1)): for i in reversed(range(n2, n1)):
x = spmm(self._U[i], x) x = spmm(self._U[i], x, self.device)
elif x.ndimension() == 3: elif x.ndimension() == 3:
out = [] out = []
for i in range(x.shape[0]): for i in range(x.shape[0]):
y = x[i] y = x[i]
for j in reversed(range(n2, n1)): for j in reversed(range(n2, n1)):
y = spmm(self._U[j], y) y = spmm(self._U[j], y, self.device)
out.append(y) out.append(y)
x = torch.stack(out, dim=0) x = torch.stack(out, dim=0)
return x return x

View File

@@ -7,18 +7,18 @@ Licensed under the MIT license.
import torch import torch
import mesh_graphormer.modeling.data.config as cfg import mesh_graphormer.modeling.data.config as cfg
device = "cuda"
class Graphormer_Body_Network(torch.nn.Module): class Graphormer_Body_Network(torch.nn.Module):
''' '''
End-to-end Graphormer network for human pose and mesh reconstruction from a single image. End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
''' '''
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler): def __init__(self, args, config, backbone, trans_encoder, mesh_sampler, device):
super(Graphormer_Body_Network, self).__init__() super(Graphormer_Body_Network, self).__init__()
self.config = config self.config = config
self.config.device = device self.config.device = device
self.backbone = backbone self.backbone = backbone
self.trans_encoder = trans_encoder self.trans_encoder = trans_encoder
self.device = device
self.upsampling = torch.nn.Linear(431, 1723) self.upsampling = torch.nn.Linear(431, 1723)
self.upsampling2 = torch.nn.Linear(1723, 6890) self.upsampling2 = torch.nn.Linear(1723, 6890)
self.cam_param_fc = torch.nn.Linear(3, 1) self.cam_param_fc = torch.nn.Linear(3, 1)
@@ -32,8 +32,8 @@ class Graphormer_Body_Network(torch.nn.Module):
# Generate T-pose template mesh # Generate T-pose template mesh
template_pose = torch.zeros((1,72)) template_pose = torch.zeros((1,72))
template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
template_pose = template_pose.to(device) template_pose = template_pose.to(self.device)
template_betas = torch.zeros((1,10)).to(device) template_betas = torch.zeros((1,10)).to(self.device)
template_vertices = smpl(template_pose, template_betas) template_vertices = smpl(template_pose, template_betas)
# template mesh simplification # template mesh simplification

View File

@@ -12,7 +12,7 @@ class Graphormer_Hand_Network(torch.nn.Module):
''' '''
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image. End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
''' '''
def __init__(self, args, config, backbone, trans_encoder, device="cuda"): def __init__(self, args, config, backbone, trans_encoder, device):
super(Graphormer_Hand_Network, self).__init__() super(Graphormer_Hand_Network, self).__init__()
self.config = config self.config = config
self.backbone = backbone self.backbone = backbone

View File

@@ -127,7 +127,12 @@ class GraphormerLayer(nn.Module):
self.device = config.device self.device = config.device
if self.has_graph_conv == True: if self.has_graph_conv == True:
self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) self.graph_conv = GraphResBlock(
config.hidden_size,
config.hidden_size,
mesh_type=self.mesh_type,
device=self.device,
)
self.intermediate = BertIntermediate(config) self.intermediate = BertIntermediate(config)
self.output = BertOutput(config) self.output = BertOutput(config)

View File

@@ -0,0 +1,26 @@
import torch
class SparseMM(torch.autograd.Function):
    """Matrix product ``sparse @ dense`` with a hand-written backward pass.

    Defined as a custom autograd function because the builtin matrix
    multiplication does not support backpropagation in some cases.
    Only the ``dense`` operand ever receives a gradient.
    """

    @staticmethod
    def forward(ctx, sparse, dense):
        """Compute ``sparse @ dense`` and stash what ``backward`` needs."""
        ctx.save_for_backward(sparse)
        # Remember whether a gradient w.r.t. ``dense`` will be wanted at all.
        ctx.req_grad = dense.requires_grad
        product = torch.matmul(sparse, dense)
        return product

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(None, d_dense)``; ``d_dense`` is ``sparse.T @ grad_output``
        when ``dense`` required a gradient, otherwise ``None``."""
        (lhs,) = ctx.saved_tensors
        dense_grad = torch.matmul(lhs.t(), grad_output) if ctx.req_grad else None
        # The first slot belongs to ``sparse``, which is never differentiated.
        return None, dense_grad
def spmm(sparse, dense, device):
    """Compute ``sparse @ dense`` on ``device`` through :class:`SparseMM`.

    Both operands are moved to ``device`` before the product, so callers may
    pass tensors that live on different devices.

    Args:
        sparse: Left operand of the product.
        dense: Right operand of the product.
        device: Target device for both operands; must not be ``None``.

    Returns:
        The result of ``SparseMM.apply`` on the two moved operands.

    Raises:
        AssertionError: If ``device`` is ``None``.
    """
    if device is None:
        # Raised explicitly instead of using an ``assert`` statement so the
        # check still runs under ``python -O``; the exception type is unchanged.
        raise AssertionError("spmm() requires an explicit device, got None")
    return SparseMM.apply(sparse.to(device), dense.to(device))