Fix device issue

This commit is contained in:
huchenlei
2024-01-03 19:06:35 -05:00
parent 0d668aff1c
commit 5926b7e47b
9 changed files with 60 additions and 116 deletions

View File

@@ -7,18 +7,18 @@ Licensed under the MIT license.
import torch
import mesh_graphormer.modeling.data.config as cfg
device = "cuda"
class Graphormer_Body_Network(torch.nn.Module):
'''
End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
'''
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler, device):
super(Graphormer_Body_Network, self).__init__()
self.config = config
self.config.device = device
self.backbone = backbone
self.trans_encoder = trans_encoder
self.device = device
self.upsampling = torch.nn.Linear(431, 1723)
self.upsampling2 = torch.nn.Linear(1723, 6890)
self.cam_param_fc = torch.nn.Linear(3, 1)
@@ -32,8 +32,8 @@ class Graphormer_Body_Network(torch.nn.Module):
# Generate T-pose template mesh
template_pose = torch.zeros((1,72))
template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
template_pose = template_pose.to(device)
template_betas = torch.zeros((1,10)).to(device)
template_pose = template_pose.to(self.device)
template_betas = torch.zeros((1,10)).to(self.device)
template_vertices = smpl(template_pose, template_betas)
# template mesh simplification

View File

@@ -12,7 +12,7 @@ class Graphormer_Hand_Network(torch.nn.Module):
'''
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
'''
def __init__(self, args, config, backbone, trans_encoder, device="cuda"):
def __init__(self, args, config, backbone, trans_encoder, device):
super(Graphormer_Hand_Network, self).__init__()
self.config = config
self.backbone = backbone

View File

@@ -127,7 +127,12 @@ class GraphormerLayer(nn.Module):
self.device = config.device
if self.has_graph_conv == True:
self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
self.graph_conv = GraphResBlock(
config.hidden_size,
config.hidden_size,
mesh_type=self.mesh_type,
device=self.device,
)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)