Plumb a `device` argument through model construction; allow device != cuda

This commit is contained in:
huchenlei
2024-01-03 15:28:22 -05:00
parent 373886cde5
commit 0d668aff1c
7 changed files with 19 additions and 74 deletions

View File

@@ -97,8 +97,10 @@ class MeshGraphormerMediapipe(Preprocessor):
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
device=args.device,
)
config.output_attentions = False
config.img_feature_dim = input_feat_dim[i]
@@ -155,7 +157,7 @@ class MeshGraphormerMediapipe(Preprocessor):
#logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder, device=args.device)
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
# for fine-tuning or resume training or inference, load weights from checkpoint

View File

@@ -9,7 +9,6 @@ data_path = Path(__file__).parent / "data"
sparse_to_dense = lambda x: x
device = "cuda"
class SparseMM(torch.autograd.Function):
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
@@ -30,8 +29,6 @@ class SparseMM(torch.autograd.Function):
return None, grad_input
def spmm(sparse, dense):
sparse = sparse.to(device)
dense = dense.to(device)
return SparseMM.apply(sparse, dense)
@@ -63,12 +60,12 @@ class GraphResBlock(torch.nn.Module):
"""
Graph Residual Block similar to the Bottleneck Residual Block in ResNet
"""
def __init__(self, in_channels, out_channels, mesh_type='body'):
def __init__(self, in_channels, out_channels, mesh_type='body', device="cuda"):
super(GraphResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lin1 = GraphLinear(in_channels, out_channels // 2)
self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type)
self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type, device=device)
self.lin2 = GraphLinear(out_channels // 2, out_channels)
self.skip_conv = GraphLinear(in_channels, out_channels)
# print('Use BertLayerNorm in GraphResBlock')
@@ -130,10 +127,11 @@ class GraphLinear(torch.nn.Module):
class GraphConvolution(torch.nn.Module):
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
def __init__(self, in_features, out_features, mesh='body', bias=True):
def __init__(self, in_features, out_features, mesh='body', bias=True, device="cuda"):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.device = device
if mesh=='body':
adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt')
@@ -144,7 +142,7 @@ class GraphConvolution(torch.nn.Module):
adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt')
adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt')
self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device)
self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(self.device)
self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
if bias:
@@ -172,7 +170,7 @@ class GraphConvolution(torch.nn.Module):
for i in range(x.shape[0]):
support = torch.matmul(x[i], self.weight)
# output.append(torch.matmul(self.adjmat, support))
output.append(spmm(self.adjmat, support))
output.append(spmm(self.adjmat.to(self.device), support.to(self.device)))
output = torch.stack(output, dim=0)
if self.bias is not None:
output = output + self.bias

View File

@@ -21,7 +21,6 @@ from pathlib import Path
sparse_to_dense = lambda x: x
device = "cuda"
class MANO(nn.Module):
def __init__(self):
@@ -84,8 +83,6 @@ class SparseMM(torch.autograd.Function):
return None, grad_input
def spmm(sparse, dense):
sparse = sparse.to(device)
dense = dense.to(device)
return SparseMM.apply(sparse, dense)

View File

@@ -19,7 +19,6 @@ import mesh_graphormer.modeling.data.config as cfg
sparse_to_dense = lambda x: x
device = "cuda"
class SMPL(nn.Module):
@@ -160,8 +159,6 @@ class SparseMM(torch.autograd.Function):
return None, grad_input
def spmm(sparse, dense):
sparse = sparse.to(device)
dense = dense.to(device)
return SparseMM.apply(sparse, dense)

View File

@@ -7,17 +7,17 @@ Licensed under the MIT license.
import torch
import mesh_graphormer.modeling.data.config as cfg
device = "cuda"
class Graphormer_Hand_Network(torch.nn.Module):
'''
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
'''
def __init__(self, args, config, backbone, trans_encoder):
def __init__(self, args, config, backbone, trans_encoder, device="cuda"):
super(Graphormer_Hand_Network, self).__init__()
self.config = config
self.backbone = backbone
self.trans_encoder = trans_encoder
self.device = device
self.upsampling = torch.nn.Linear(195, 778)
self.cam_param_fc = torch.nn.Linear(3, 1)
self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
@@ -28,8 +28,8 @@ class Graphormer_Hand_Network(torch.nn.Module):
batch_size = images.size(0)
# Generate T-pose template mesh
template_pose = torch.zeros((1,48))
template_pose = template_pose.to(device)
template_betas = torch.zeros((1,10)).to(device)
template_pose = template_pose.to(self.device)
template_betas = torch.zeros((1,10)).to(self.device)
template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
template_vertices = template_vertices/1000.0
template_3d_joints = template_3d_joints/1000.0
@@ -64,7 +64,7 @@ class Graphormer_Hand_Network(torch.nn.Module):
# apply mask vertex/joint modeling
# meta_masks is a tensor of all the masks, randomly generated in dataloader
# we pre-define a [MASK] token, which is a floating-value vector with 0.01s
special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
special_token = torch.ones_like(features[:,:-49,:]).to(self.device)*0.01
features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
# forward pass

View File

@@ -14,13 +14,11 @@ import torch
from torch import nn
from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock
from mesh_graphormer.modeling._gcnn import GraphResBlock
from .modeling_utils import prune_linear_layer
LayerNormClass = torch.nn.LayerNorm
BertLayerNorm = torch.nn.LayerNorm
device = "cuda"
class BertSelfAttention(nn.Module):
def __init__(self, config):
@@ -126,6 +124,7 @@ class GraphormerLayer(nn.Module):
self.attention = BertAttention(config)
self.has_graph_conv = config.graph_conv
self.mesh_type = config.mesh_type
self.device = config.device
if self.has_graph_conv == True:
self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
@@ -235,7 +234,7 @@ class EncoderBlock(BertPreTrainedModel):
batch_size = len(img_feats)
seq_length = len(img_feats[0])
input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(device)
input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(self.device)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)

View File

@@ -7,15 +7,10 @@ This is useful when doing distributed training.
"""
import pickle
import time
import torch
import torch.distributed as dist
device = "cuda"
def get_world_size():
if not dist.is_available():
return 1
@@ -104,49 +99,6 @@ def gather_on_master(data):
return data_list
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to(device)
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()]).to(device)
size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to(device)
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args: