diff --git a/hand_refiner/pipeline.py b/hand_refiner/pipeline.py index c05c207..970f112 100644 --- a/hand_refiner/pipeline.py +++ b/hand_refiner/pipeline.py @@ -97,8 +97,10 @@ class MeshGraphormerMediapipe(Preprocessor): # init three transformer-encoder blocks in a loop for i in range(len(output_feat_dim)): config_class, model_class = BertConfig, Graphormer - config = config_class.from_pretrained(args.config_name if args.config_name \ - else args.model_name_or_path) + config = config_class.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path, + device=args.device, + ) config.output_attentions = False config.img_feature_dim = input_feat_dim[i] @@ -155,7 +157,7 @@ class MeshGraphormerMediapipe(Preprocessor): #logger.info('Backbone total parameters: {}'.format(backbone_total_params)) # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) - _model = Graphormer_Network(args, config, backbone, trans_encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder, device=args.device) if args.resume_checkpoint!=None and args.resume_checkpoint!='None': # for fine-tuning or resume training or inference, load weights from checkpoint diff --git a/mesh_graphormer/modeling/_gcnn.py b/mesh_graphormer/modeling/_gcnn.py index 43bfe63..8961843 100644 --- a/mesh_graphormer/modeling/_gcnn.py +++ b/mesh_graphormer/modeling/_gcnn.py @@ -9,7 +9,6 @@ data_path = Path(__file__).parent / "data" sparse_to_dense = lambda x: x -device = "cuda" class SparseMM(torch.autograd.Function): """Redefine sparse @ dense matrix multiplication to enable backpropagation. 
@@ -30,8 +29,6 @@ class SparseMM(torch.autograd.Function): return None, grad_input def spmm(sparse, dense): - sparse = sparse.to(device) - dense = dense.to(device) return SparseMM.apply(sparse, dense) @@ -63,12 +60,12 @@ class GraphResBlock(torch.nn.Module): """ Graph Residual Block similar to the Bottleneck Residual Block in ResNet """ - def __init__(self, in_channels, out_channels, mesh_type='body'): + def __init__(self, in_channels, out_channels, mesh_type='body', device="cuda"): super(GraphResBlock, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.lin1 = GraphLinear(in_channels, out_channels // 2) - self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type) + self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type, device=device) self.lin2 = GraphLinear(out_channels // 2, out_channels) self.skip_conv = GraphLinear(in_channels, out_channels) # print('Use BertLayerNorm in GraphResBlock') @@ -130,10 +127,11 @@ class GraphLinear(torch.nn.Module): class GraphConvolution(torch.nn.Module): """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" - def __init__(self, in_features, out_features, mesh='body', bias=True): + def __init__(self, in_features, out_features, mesh='body', bias=True, device="cuda"): super(GraphConvolution, self).__init__() self.in_features = in_features self.out_features = out_features + self.device = device if mesh=='body': adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt') @@ -144,7 +142,7 @@ class GraphConvolution(torch.nn.Module): adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt') adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt') - self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device) + self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(self.device) self.weight = 
torch.nn.Parameter(torch.FloatTensor(in_features, out_features)) if bias: @@ -172,7 +170,7 @@ class GraphConvolution(torch.nn.Module): for i in range(x.shape[0]): support = torch.matmul(x[i], self.weight) # output.append(torch.matmul(self.adjmat, support)) - output.append(spmm(self.adjmat, support)) + output.append(spmm(self.adjmat.to(self.device), support.to(self.device))) output = torch.stack(output, dim=0) if self.bias is not None: output = output + self.bias diff --git a/mesh_graphormer/modeling/_mano.py b/mesh_graphormer/modeling/_mano.py index 01d2d67..0997017 100644 --- a/mesh_graphormer/modeling/_mano.py +++ b/mesh_graphormer/modeling/_mano.py @@ -21,7 +21,6 @@ from pathlib import Path sparse_to_dense = lambda x: x -device = "cuda" class MANO(nn.Module): def __init__(self): @@ -84,8 +83,6 @@ class SparseMM(torch.autograd.Function): return None, grad_input def spmm(sparse, dense): - sparse = sparse.to(device) - dense = dense.to(device) return SparseMM.apply(sparse, dense) diff --git a/mesh_graphormer/modeling/_smpl.py b/mesh_graphormer/modeling/_smpl.py index 15b0f1a..ec68148 100644 --- a/mesh_graphormer/modeling/_smpl.py +++ b/mesh_graphormer/modeling/_smpl.py @@ -19,7 +19,6 @@ import mesh_graphormer.modeling.data.config as cfg sparse_to_dense = lambda x: x -device = "cuda" class SMPL(nn.Module): @@ -160,8 +159,6 @@ class SparseMM(torch.autograd.Function): return None, grad_input def spmm(sparse, dense): - sparse = sparse.to(device) - dense = dense.to(device) return SparseMM.apply(sparse, dense) diff --git a/mesh_graphormer/modeling/bert/e2e_hand_network.py b/mesh_graphormer/modeling/bert/e2e_hand_network.py index 4dc7385..488d399 100644 --- a/mesh_graphormer/modeling/bert/e2e_hand_network.py +++ b/mesh_graphormer/modeling/bert/e2e_hand_network.py @@ -7,17 +7,17 @@ Licensed under the MIT license. 
import torch import mesh_graphormer.modeling.data.config as cfg -device = "cuda" class Graphormer_Hand_Network(torch.nn.Module): ''' End-to-end Graphormer network for hand pose and mesh reconstruction from a single image. ''' - def __init__(self, args, config, backbone, trans_encoder): + def __init__(self, args, config, backbone, trans_encoder, device="cuda"): super(Graphormer_Hand_Network, self).__init__() self.config = config self.backbone = backbone self.trans_encoder = trans_encoder + self.device = device self.upsampling = torch.nn.Linear(195, 778) self.cam_param_fc = torch.nn.Linear(3, 1) self.cam_param_fc2 = torch.nn.Linear(195+21, 150) @@ -28,8 +28,8 @@ class Graphormer_Hand_Network(torch.nn.Module): batch_size = images.size(0) # Generate T-pose template mesh template_pose = torch.zeros((1,48)) - template_pose = template_pose.to(device) - template_betas = torch.zeros((1,10)).to(device) + template_pose = template_pose.to(self.device) + template_betas = torch.zeros((1,10)).to(self.device) template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas) template_vertices = template_vertices/1000.0 template_3d_joints = template_3d_joints/1000.0 @@ -64,7 +64,7 @@ class Graphormer_Hand_Network(torch.nn.Module): # apply mask vertex/joint modeling # meta_masks is a tensor of all the masks, randomly generated in dataloader # we pre-define a [MASK] token, which is a floating-value vector with 0.01s - special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01 + special_token = torch.ones_like(features[:,:-49,:]).to(self.device)*0.01 features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks) # forward pass diff --git a/mesh_graphormer/modeling/bert/modeling_graphormer.py b/mesh_graphormer/modeling/bert/modeling_graphormer.py index 6d167c6..c0833a7 100644 --- a/mesh_graphormer/modeling/bert/modeling_graphormer.py +++ b/mesh_graphormer/modeling/bert/modeling_graphormer.py @@ -14,13 +14,11 @@ import torch from torch 
import nn from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput import mesh_graphormer.modeling.data.config as cfg -from mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock +from mesh_graphormer.modeling._gcnn import GraphResBlock from .modeling_utils import prune_linear_layer LayerNormClass = torch.nn.LayerNorm BertLayerNorm = torch.nn.LayerNorm -device = "cuda" - class BertSelfAttention(nn.Module): def __init__(self, config): @@ -126,6 +124,7 @@ class GraphormerLayer(nn.Module): self.attention = BertAttention(config) self.has_graph_conv = config.graph_conv self.mesh_type = config.mesh_type + self.device = config.device if self.has_graph_conv == True: - self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) + self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type, device=self.device) @@ -235,7 +234,7 @@ class EncoderBlock(BertPreTrainedModel): batch_size = len(img_feats) seq_length = len(img_feats[0]) - input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(device) + input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(self.device) if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) diff --git a/mesh_graphormer/utils/comm.py b/mesh_graphormer/utils/comm.py index b82e7f8..a79c078 100644 --- a/mesh_graphormer/utils/comm.py +++ b/mesh_graphormer/utils/comm.py @@ -7,15 +7,10 @@ This is useful when doing distributed training.
""" import pickle -import time - import torch import torch.distributed as dist -device = "cuda" - - def get_world_size(): if not dist.is_available(): return 1 @@ -104,49 +99,6 @@ def gather_on_master(data): return data_list -def all_gather(data): - """ - Run all_gather on arbitrary picklable data (not necessarily tensors) - Args: - data: any picklable object - Returns: - list[data]: list of data gathered from each rank - """ - world_size = get_world_size() - if world_size == 1: - return [data] - - # serialized to a Tensor - buffer = pickle.dumps(data) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).to(device) - - # obtain Tensor size of each rank - local_size = torch.LongTensor([tensor.numel()]).to(device) - size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)] - dist.all_gather(size_list, local_size) - size_list = [int(size.item()) for size in size_list] - max_size = max(size_list) - - # receiving Tensor from all ranks - # we pad the tensor because torch all_gather does not support - # gathering tensors of different shapes - tensor_list = [] - for _ in size_list: - tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device)) - if local_size != max_size: - padding = torch.ByteTensor(size=(max_size - local_size,)).to(device) - tensor = torch.cat((tensor, padding), dim=0) - dist.all_gather(tensor_list, tensor) - - data_list = [] - for size, tensor in zip(size_list, tensor_list): - buffer = tensor.cpu().numpy().tobytes()[:size] - data_list.append(pickle.loads(buffer)) - - return data_list - - def reduce_dict(input_dict, average=True): """ Args: