mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-04-30 03:01:32 +00:00
Fix device issue
This commit is contained in:
@@ -7,18 +7,18 @@ Licensed under the MIT license.
|
||||
import torch
|
||||
import mesh_graphormer.modeling.data.config as cfg
|
||||
|
||||
device = "cuda"
|
||||
|
||||
class Graphormer_Body_Network(torch.nn.Module):
|
||||
'''
|
||||
End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
|
||||
'''
|
||||
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
|
||||
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler, device):
|
||||
super(Graphormer_Body_Network, self).__init__()
|
||||
self.config = config
|
||||
self.config.device = device
|
||||
self.backbone = backbone
|
||||
self.trans_encoder = trans_encoder
|
||||
self.device = device
|
||||
self.upsampling = torch.nn.Linear(431, 1723)
|
||||
self.upsampling2 = torch.nn.Linear(1723, 6890)
|
||||
self.cam_param_fc = torch.nn.Linear(3, 1)
|
||||
@@ -32,8 +32,8 @@ class Graphormer_Body_Network(torch.nn.Module):
|
||||
# Generate T-pose template mesh
|
||||
template_pose = torch.zeros((1,72))
|
||||
template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
|
||||
template_pose = template_pose.to(device)
|
||||
template_betas = torch.zeros((1,10)).to(device)
|
||||
template_pose = template_pose.to(self.device)
|
||||
template_betas = torch.zeros((1,10)).to(self.device)
|
||||
template_vertices = smpl(template_pose, template_betas)
|
||||
|
||||
# template mesh simplification
|
||||
|
||||
@@ -12,7 +12,7 @@ class Graphormer_Hand_Network(torch.nn.Module):
|
||||
'''
|
||||
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
|
||||
'''
|
||||
def __init__(self, args, config, backbone, trans_encoder, device="cuda"):
|
||||
def __init__(self, args, config, backbone, trans_encoder, device):
|
||||
super(Graphormer_Hand_Network, self).__init__()
|
||||
self.config = config
|
||||
self.backbone = backbone
|
||||
|
||||
@@ -127,7 +127,12 @@ class GraphormerLayer(nn.Module):
|
||||
self.device = config.device
|
||||
|
||||
if self.has_graph_conv == True:
|
||||
self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
|
||||
self.graph_conv = GraphResBlock(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
mesh_type=self.mesh_type,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
Reference in New Issue
Block a user