From 5926b7e47bbfb117cf97030778cac38c22a057f5 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Wed, 3 Jan 2024 19:06:35 -0500 Subject: [PATCH] Fix device issue --- hand_refiner/__init__.py | 2 +- hand_refiner/pipeline.py | 5 +- mesh_graphormer/modeling/_gcnn.py | 50 ++----------------- mesh_graphormer/modeling/_mano.py | 36 +++---------- mesh_graphormer/modeling/_smpl.py | 40 +++------------ .../modeling/bert/e2e_body_network.py | 8 +-- .../modeling/bert/e2e_hand_network.py | 2 +- .../modeling/bert/modeling_graphormer.py | 7 ++- mesh_graphormer/modeling/util.py | 26 ++++++++++ 9 files changed, 60 insertions(+), 116 deletions(-) create mode 100644 mesh_graphormer/modeling/util.py diff --git a/hand_refiner/__init__.py b/hand_refiner/__init__.py index 3a1d72e..25866c6 100644 --- a/hand_refiner/__init__.py +++ b/hand_refiner/__init__.py @@ -8,7 +8,7 @@ class MeshGraphormerDetector: self.pipeline = pipeline @classmethod - def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"): + def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device=None): filename = filename or "graphormer_hand_state_dict.bin" hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth" args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir) diff --git a/hand_refiner/pipeline.py b/hand_refiner/pipeline.py index 970f112..7c7d87e 100644 --- a/hand_refiner/pipeline.py +++ b/hand_refiner/pipeline.py @@ -98,10 +98,9 @@ class MeshGraphormerMediapipe(Preprocessor): for i in range(len(output_feat_dim)): config_class, model_class = BertConfig, Graphormer config = config_class.from_pretrained( - args.config_name if args.config_name else args.model_name_or_path, - device=args.device, + args.config_name if args.config_name else args.model_name_or_path ) - + setattr(config, "device", args.device) config.output_attentions = False config.img_feature_dim = input_feat_dim[i] config.output_feature_dim = output_feat_dim[i] diff --git a/mesh_graphormer/modeling/_gcnn.py b/mesh_graphormer/modeling/_gcnn.py index 8961843..d507d5a 100644 --- a/mesh_graphormer/modeling/_gcnn.py +++ b/mesh_graphormer/modeling/_gcnn.py @@ -2,35 +2,13 @@ from __future__ import division import torch import torch.nn.functional as F import numpy as np -import scipy.sparse import math from pathlib import Path +from .util import spmm + data_path = Path(__file__).parent / "data" - - sparse_to_dense = lambda x: x -class SparseMM(torch.autograd.Function): - """Redefine sparse @ dense matrix multiplication to enable backpropagation. - The builtin matrix multiplication operation does not support backpropagation in some cases. - """ - @staticmethod - def forward(ctx, sparse, dense): - ctx.req_grad = dense.requires_grad - ctx.save_for_backward(sparse) - return torch.matmul(sparse, dense) - - @staticmethod - def backward(ctx, grad_output): - grad_input = None - sparse, = ctx.saved_tensors - if ctx.req_grad: - grad_input = torch.matmul(sparse.t(), grad_output) - return None, grad_input - -def spmm(sparse, dense): - return SparseMM.apply(sparse, dense) - def gelu(x): """Implementation of the gelu activation function. 
@@ -60,7 +38,7 @@ class GraphResBlock(torch.nn.Module): """ Graph Residual Block similar to the Bottleneck Residual Block in ResNet """ - def __init__(self, in_channels, out_channels, mesh_type='body', device="cuda"): + def __init__(self, in_channels, out_channels, mesh_type='body', device=None): super(GraphResBlock, self).__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -87,24 +65,6 @@ class GraphResBlock(torch.nn.Module): return z -# class GraphResBlock(torch.nn.Module): -# """ -# Graph Residual Block similar to the Bottleneck Residual Block in ResNet -# """ -# def __init__(self, in_channels, out_channels, mesh_type='body'): -# super(GraphResBlock, self).__init__() -# self.in_channels = in_channels -# self.out_channels = out_channels -# self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type) -# print('Use BertLayerNorm and GeLU in GraphResBlock') -# self.norm = BertLayerNorm(self.out_channels) -# def forward(self, x): -# y = self.conv(x) -# y = self.norm(y) -# y = gelu(y) -# z = x+y -# return z - class GraphLinear(torch.nn.Module): """ Generalization of 1x1 convolutions on Graphs @@ -127,7 +87,7 @@ class GraphLinear(torch.nn.Module): class GraphConvolution(torch.nn.Module): """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" - def __init__(self, in_features, out_features, mesh='body', bias=True, device="cuda"): + def __init__(self, in_features, out_features, mesh='body', bias=True, device=None): super(GraphConvolution, self).__init__() self.in_features = in_features self.out_features = out_features @@ -170,7 +130,7 @@ class GraphConvolution(torch.nn.Module): for i in range(x.shape[0]): support = torch.matmul(x[i], self.weight) # output.append(torch.matmul(self.adjmat, support)) - output.append(spmm(self.adjmat.to(self.device), support.to(self.device))) + output.append(spmm(self.adjmat, support, self.device)) output = torch.stack(output, dim=0) if self.bias is not None: output = output + self.bias diff --git a/mesh_graphormer/modeling/_mano.py b/mesh_graphormer/modeling/_mano.py index 0997017..b132bf8 100644 --- a/mesh_graphormer/modeling/_mano.py +++ b/mesh_graphormer/modeling/_mano.py @@ -12,11 +12,10 @@ import numpy as np import torch import torch.nn as nn import os.path as osp -import json -import code from manopth.manolayer import ManoLayer import scipy.sparse import mesh_graphormer.modeling.data.config as cfg +from .util import spmm from pathlib import Path @@ -64,28 +63,6 @@ class MANO(nn.Module): return joints -class SparseMM(torch.autograd.Function): - """Redefine sparse @ dense matrix multiplication to enable backpropagation. - The builtin matrix multiplication operation does not support backpropagation in some cases. 
- """ - @staticmethod - def forward(ctx, sparse, dense): - ctx.req_grad = dense.requires_grad - ctx.save_for_backward(sparse) - return torch.matmul(sparse, dense) - - @staticmethod - def backward(ctx, grad_output): - grad_input = None - sparse, = ctx.saved_tensors - if ctx.req_grad: - grad_input = torch.matmul(sparse.t(), grad_output) - return None, grad_input - -def spmm(sparse, dense): - return SparseMM.apply(sparse, dense) - - def scipy_to_pytorch(A, U, D): """Convert scipy sparse matrices to pytorch sparse matrix.""" ptU = [] @@ -141,12 +118,13 @@ def get_graph_params(filename, nsize=1): class Mesh(object): """Mesh object that is used for handling certain graph operations.""" def __init__(self, filename=cfg.MANO_sampling_matrix, - num_downsampling=1, nsize=1, device=torch.device('cuda')): + num_downsampling=1, nsize=1, device=None): self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) # self._A = [a.to(device) for a in self._A] self._U = [u.to(device) for u in self._U] self._D = [d.to(device) for d in self._D] self.num_downsampling = num_downsampling + self.device = device def downsample(self, x, n1=0, n2=None): """Downsample mesh.""" @@ -154,13 +132,13 @@ class Mesh(object): n2 = self.num_downsampling if x.ndimension() < 3: for i in range(n1, n2): - x = spmm(self._D[i], x) + x = spmm(self._D[i], x, self.device) elif x.ndimension() == 3: out = [] for i in range(x.shape[0]): y = x[i] for j in range(n1, n2): - y = spmm(self._D[j], y) + y = spmm(self._D[j], y, self.device) out.append(y) x = torch.stack(out, dim=0) return x @@ -169,13 +147,13 @@ class Mesh(object): """Upsample mesh.""" if x.ndimension() < 3: for i in reversed(range(n2, n1)): - x = spmm(self._U[i], x) + x = spmm(self._U[i], x, self.device) elif x.ndimension() == 3: out = [] for i in range(x.shape[0]): y = x[i] for j in reversed(range(n2, n1)): - y = spmm(self._U[j], y) + y = spmm(self._U[j], y, self.device) out.append(y) x = torch.stack(out, dim=0) return x diff --git a/mesh_graphormer/modeling/_smpl.py b/mesh_graphormer/modeling/_smpl.py index ec68148..e355fce 100644 --- a/mesh_graphormer/modeling/_smpl.py +++ b/mesh_graphormer/modeling/_smpl.py @@ -16,6 +16,7 @@ except ImportError: from mesh_graphormer.utils.geometric_layers import rodrigues import mesh_graphormer.modeling.data.config as cfg +from .util import spmm sparse_to_dense = lambda x: x @@ -140,27 +141,6 @@ class SMPL(nn.Module): joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct]) return joints -class SparseMM(torch.autograd.Function): - """Redefine sparse @ dense matrix multiplication to enable backpropagation. - The builtin matrix multiplication operation does not support backpropagation in some cases. 
-    """
-    @staticmethod
-    def forward(ctx, sparse, dense):
-        ctx.req_grad = dense.requires_grad
-        ctx.save_for_backward(sparse)
-        return torch.matmul(sparse, dense)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input = None
-        sparse, = ctx.saved_tensors
-        if ctx.req_grad:
-            grad_input = torch.matmul(sparse.t(), grad_output)
-        return None, grad_input
-
-def spmm(sparse, dense):
-    return SparseMM.apply(sparse, dense)
-
 
 def scipy_to_pytorch(A, U, D):
     """Convert scipy sparse matrices to pytorch sparse matrix."""
@@ -217,7 +197,7 @@ def get_graph_params(filename, nsize=1):
 class Mesh(object):
     """Mesh object that is used for handling certain graph operations."""
     def __init__(self, filename=cfg.SMPL_sampling_matrix,
-                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
+                 num_downsampling=1, nsize=1, device=None):
         self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
         # self._A = [a.to(device) for a in self._A]
         self._U = [u.to(device) for u in self._U]
@@ -233,18 +213,14 @@ class Mesh(object):
 
         self._ref_vertices = ref_vertices.to(device)
        self.faces = smpl.faces.int().to(device)
-
-    # @property
-    # def adjmat(self):
-    #     """Return the graph adjacency matrix at the specified subsampling level."""
-    #     return self._A[self.num_downsampling].float()
+        self.device = device
 
     @property
     def ref_vertices(self):
         """Return the template vertices at the specified subsampling level."""
         ref_vertices = self._ref_vertices
         for i in range(self.num_downsampling):
-            ref_vertices = torch.spmm(self._D[i], ref_vertices)
+            ref_vertices = spmm(self._D[i], ref_vertices, self.device)
         return ref_vertices
 
     def downsample(self, x, n1=0, n2=None):
@@ -253,13 +229,13 @@ class Mesh(object):
             n2 = self.num_downsampling
         if x.ndimension() < 3:
             for i in range(n1, n2):
-                x = spmm(self._D[i], x)
+                x = spmm(self._D[i], x, self.device)
         elif x.ndimension() == 3:
             out = []
             for i in range(x.shape[0]):
                 y = x[i]
                 for j in range(n1, n2):
-                    y = spmm(self._D[j], y)
+                    y = spmm(self._D[j], y, self.device)
                 out.append(y)
             x = torch.stack(out, dim=0)
         return x
@@ -268,13 +244,13 @@ class Mesh(object):
         """Upsample mesh."""
         if x.ndimension() < 3:
             for i in reversed(range(n2, n1)):
-                x = spmm(self._U[i], x)
+                x = spmm(self._U[i], x, self.device)
         elif x.ndimension() == 3:
             out = []
             for i in range(x.shape[0]):
                 y = x[i]
                 for j in reversed(range(n2, n1)):
-                    y = spmm(self._U[j], y)
+                    y = spmm(self._U[j], y, self.device)
                 out.append(y)
             x = torch.stack(out, dim=0)
         return x
diff --git a/mesh_graphormer/modeling/bert/e2e_body_network.py b/mesh_graphormer/modeling/bert/e2e_body_network.py
index c958047..60b815e 100644
--- a/mesh_graphormer/modeling/bert/e2e_body_network.py
+++ b/mesh_graphormer/modeling/bert/e2e_body_network.py
@@ -7,18 +7,18 @@ Licensed under the MIT license.
 
 import torch
 import mesh_graphormer.modeling.data.config as cfg
-device = "cuda"
 
 class Graphormer_Body_Network(torch.nn.Module):
     '''
     End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
''' - def __init__(self, args, config, backbone, trans_encoder, mesh_sampler): + def __init__(self, args, config, backbone, trans_encoder, mesh_sampler, device): super(Graphormer_Body_Network, self).__init__() self.config = config self.config.device = device self.backbone = backbone self.trans_encoder = trans_encoder + self.device = device self.upsampling = torch.nn.Linear(431, 1723) self.upsampling2 = torch.nn.Linear(1723, 6890) self.cam_param_fc = torch.nn.Linear(3, 1) @@ -32,8 +32,8 @@ class Graphormer_Body_Network(torch.nn.Module): # Generate T-pose template mesh template_pose = torch.zeros((1,72)) template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord - template_pose = template_pose.to(device) - template_betas = torch.zeros((1,10)).to(device) + template_pose = template_pose.to(self.device) + template_betas = torch.zeros((1,10)).to(self.device) template_vertices = smpl(template_pose, template_betas) # template mesh simplification diff --git a/mesh_graphormer/modeling/bert/e2e_hand_network.py b/mesh_graphormer/modeling/bert/e2e_hand_network.py index 488d399..8900008 100644 --- a/mesh_graphormer/modeling/bert/e2e_hand_network.py +++ b/mesh_graphormer/modeling/bert/e2e_hand_network.py @@ -12,7 +12,7 @@ class Graphormer_Hand_Network(torch.nn.Module): ''' End-to-end Graphormer network for hand pose and mesh reconstruction from a single image. ''' - def __init__(self, args, config, backbone, trans_encoder, device="cuda"): + def __init__(self, args, config, backbone, trans_encoder, device): super(Graphormer_Hand_Network, self).__init__() self.config = config self.backbone = backbone diff --git a/mesh_graphormer/modeling/bert/modeling_graphormer.py b/mesh_graphormer/modeling/bert/modeling_graphormer.py index c0833a7..57ce5e1 100644 --- a/mesh_graphormer/modeling/bert/modeling_graphormer.py +++ b/mesh_graphormer/modeling/bert/modeling_graphormer.py @@ -127,7 +127,12 @@ class GraphormerLayer(nn.Module): self.device = config.device if self.has_graph_conv == True: - self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) + self.graph_conv = GraphResBlock( + config.hidden_size, + config.hidden_size, + mesh_type=self.mesh_type, + device=self.device, + ) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) diff --git a/mesh_graphormer/modeling/util.py b/mesh_graphormer/modeling/util.py new file mode 100644 index 0000000..e50dded --- /dev/null +++ b/mesh_graphormer/modeling/util.py @@ -0,0 +1,26 @@ +import torch + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. + """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + + +def spmm(sparse, dense, device): + assert device is not None + sparse = sparse.to(device) + dense = dense.to(device) + return SparseMM.apply(sparse, dense) \ No newline at end of file
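
Note (not part of the patch): a minimal usage sketch of the consolidated spmm() helper added in mesh_graphormer/modeling/util.py, showing the explicit device argument introduced by this change. The tensors, shapes, and device choice below are illustrative assumptions, not values taken from the repository.

import torch
from mesh_graphormer.modeling.util import spmm

# Illustrative inputs: a small sparse matrix and a dense feature matrix
# that starts out on the CPU.
indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
values = torch.ones(3)
adjmat = torch.sparse_coo_tensor(indices, values, (3, 3))
features = torch.randn(3, 4, requires_grad=True)

# spmm() now takes the target device explicitly and moves both operands
# there before multiplying, instead of assuming everything lives on "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
out = spmm(adjmat, features, device)  # (3, 4) result on `device`

# Gradients still reach the dense operand through SparseMM.backward.
out.sum().backward()
print(features.grad.shape)  # torch.Size([3, 4])

Passing the device explicitly matters here because the sparse adjacency and up/downsampling matrices are held in plain Python lists on the Mesh object rather than as module buffers, so they are not moved by the usual model.to(device) call; the helper moves them alongside the activations at multiplication time.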