mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-03-11 13:29:47 +00:00
Plumbing on device. Allow device != cuda
This commit is contained in:
@@ -97,8 +97,10 @@ class MeshGraphormerMediapipe(Preprocessor):
|
||||
# init three transformer-encoder blocks in a loop
|
||||
for i in range(len(output_feat_dim)):
|
||||
config_class, model_class = BertConfig, Graphormer
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name \
|
||||
else args.model_name_or_path)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
config.output_attentions = False
|
||||
config.img_feature_dim = input_feat_dim[i]
|
||||
@@ -155,7 +157,7 @@ class MeshGraphormerMediapipe(Preprocessor):
|
||||
#logger.info('Backbone total parameters: {}'.format(backbone_total_params))
|
||||
|
||||
# build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
|
||||
_model = Graphormer_Network(args, config, backbone, trans_encoder)
|
||||
_model = Graphormer_Network(args, config, backbone, trans_encoder, device=args.device)
|
||||
|
||||
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
|
||||
# for fine-tuning or resume training or inference, load weights from checkpoint
|
||||
|
||||
@@ -9,7 +9,6 @@ data_path = Path(__file__).parent / "data"
|
||||
|
||||
|
||||
sparse_to_dense = lambda x: x
|
||||
device = "cuda"
|
||||
|
||||
class SparseMM(torch.autograd.Function):
|
||||
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
|
||||
@@ -30,8 +29,6 @@ class SparseMM(torch.autograd.Function):
|
||||
return None, grad_input
|
||||
|
||||
def spmm(sparse, dense):
|
||||
sparse = sparse.to(device)
|
||||
dense = dense.to(device)
|
||||
return SparseMM.apply(sparse, dense)
|
||||
|
||||
|
||||
@@ -63,12 +60,12 @@ class GraphResBlock(torch.nn.Module):
|
||||
"""
|
||||
Graph Residual Block similar to the Bottleneck Residual Block in ResNet
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, mesh_type='body'):
|
||||
def __init__(self, in_channels, out_channels, mesh_type='body', device="cuda"):
|
||||
super(GraphResBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.lin1 = GraphLinear(in_channels, out_channels // 2)
|
||||
self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type)
|
||||
self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type, device=device)
|
||||
self.lin2 = GraphLinear(out_channels // 2, out_channels)
|
||||
self.skip_conv = GraphLinear(in_channels, out_channels)
|
||||
# print('Use BertLayerNorm in GraphResBlock')
|
||||
@@ -130,10 +127,11 @@ class GraphLinear(torch.nn.Module):
|
||||
|
||||
class GraphConvolution(torch.nn.Module):
|
||||
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
|
||||
def __init__(self, in_features, out_features, mesh='body', bias=True):
|
||||
def __init__(self, in_features, out_features, mesh='body', bias=True, device="cuda"):
|
||||
super(GraphConvolution, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.device = device
|
||||
|
||||
if mesh=='body':
|
||||
adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt')
|
||||
@@ -144,7 +142,7 @@ class GraphConvolution(torch.nn.Module):
|
||||
adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt')
|
||||
adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt')
|
||||
|
||||
self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device)
|
||||
self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(self.device)
|
||||
|
||||
self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
|
||||
if bias:
|
||||
@@ -172,7 +170,7 @@ class GraphConvolution(torch.nn.Module):
|
||||
for i in range(x.shape[0]):
|
||||
support = torch.matmul(x[i], self.weight)
|
||||
# output.append(torch.matmul(self.adjmat, support))
|
||||
output.append(spmm(self.adjmat, support))
|
||||
output.append(spmm(self.adjmat.to(self.device), support.to(self.device)))
|
||||
output = torch.stack(output, dim=0)
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
|
||||
@@ -21,7 +21,6 @@ from pathlib import Path
|
||||
|
||||
|
||||
sparse_to_dense = lambda x: x
|
||||
device = "cuda"
|
||||
|
||||
class MANO(nn.Module):
|
||||
def __init__(self):
|
||||
@@ -84,8 +83,6 @@ class SparseMM(torch.autograd.Function):
|
||||
return None, grad_input
|
||||
|
||||
def spmm(sparse, dense):
|
||||
sparse = sparse.to(device)
|
||||
dense = dense.to(device)
|
||||
return SparseMM.apply(sparse, dense)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import mesh_graphormer.modeling.data.config as cfg
|
||||
|
||||
|
||||
sparse_to_dense = lambda x: x
|
||||
device = "cuda"
|
||||
|
||||
class SMPL(nn.Module):
|
||||
|
||||
@@ -160,8 +159,6 @@ class SparseMM(torch.autograd.Function):
|
||||
return None, grad_input
|
||||
|
||||
def spmm(sparse, dense):
|
||||
sparse = sparse.to(device)
|
||||
dense = dense.to(device)
|
||||
return SparseMM.apply(sparse, dense)
|
||||
|
||||
|
||||
|
||||
@@ -7,17 +7,17 @@ Licensed under the MIT license.
|
||||
import torch
|
||||
import mesh_graphormer.modeling.data.config as cfg
|
||||
|
||||
device = "cuda"
|
||||
|
||||
class Graphormer_Hand_Network(torch.nn.Module):
|
||||
'''
|
||||
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
|
||||
'''
|
||||
def __init__(self, args, config, backbone, trans_encoder):
|
||||
def __init__(self, args, config, backbone, trans_encoder, device="cuda"):
|
||||
super(Graphormer_Hand_Network, self).__init__()
|
||||
self.config = config
|
||||
self.backbone = backbone
|
||||
self.trans_encoder = trans_encoder
|
||||
self.device = device
|
||||
self.upsampling = torch.nn.Linear(195, 778)
|
||||
self.cam_param_fc = torch.nn.Linear(3, 1)
|
||||
self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
|
||||
@@ -28,8 +28,8 @@ class Graphormer_Hand_Network(torch.nn.Module):
|
||||
batch_size = images.size(0)
|
||||
# Generate T-pose template mesh
|
||||
template_pose = torch.zeros((1,48))
|
||||
template_pose = template_pose.to(device)
|
||||
template_betas = torch.zeros((1,10)).to(device)
|
||||
template_pose = template_pose.to(self.device)
|
||||
template_betas = torch.zeros((1,10)).to(self.device)
|
||||
template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
|
||||
template_vertices = template_vertices/1000.0
|
||||
template_3d_joints = template_3d_joints/1000.0
|
||||
@@ -64,7 +64,7 @@ class Graphormer_Hand_Network(torch.nn.Module):
|
||||
# apply mask vertex/joint modeling
|
||||
# meta_masks is a tensor of all the masks, randomly generated in dataloader
|
||||
# we pre-define a [MASK] token, which is a floating-value vector with 0.01s
|
||||
special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
|
||||
special_token = torch.ones_like(features[:,:-49,:]).to(self.device)*0.01
|
||||
features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
|
||||
|
||||
# forward pass
|
||||
|
||||
@@ -14,13 +14,11 @@ import torch
|
||||
from torch import nn
|
||||
from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
|
||||
import mesh_graphormer.modeling.data.config as cfg
|
||||
from mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock
|
||||
from mesh_graphormer.modeling._gcnn import GraphResBlock
|
||||
from .modeling_utils import prune_linear_layer
|
||||
LayerNormClass = torch.nn.LayerNorm
|
||||
BertLayerNorm = torch.nn.LayerNorm
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
@@ -126,6 +124,7 @@ class GraphormerLayer(nn.Module):
|
||||
self.attention = BertAttention(config)
|
||||
self.has_graph_conv = config.graph_conv
|
||||
self.mesh_type = config.mesh_type
|
||||
self.device = config.device
|
||||
|
||||
if self.has_graph_conv == True:
|
||||
self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
|
||||
@@ -235,7 +234,7 @@ class EncoderBlock(BertPreTrainedModel):
|
||||
|
||||
batch_size = len(img_feats)
|
||||
seq_length = len(img_feats[0])
|
||||
input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(device)
|
||||
input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(self.device)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
|
||||
@@ -7,15 +7,10 @@ This is useful when doing distributed training.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
@@ -104,49 +99,6 @@ def gather_on_master(data):
|
||||
return data_list
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to(device)
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.LongTensor([tensor.numel()]).to(device)
|
||||
size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device))
|
||||
if local_size != max_size:
|
||||
padding = torch.ByteTensor(size=(max_size - local_size,)).to(device)
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user