mirror of https://github.com/huchenlei/HandRefinerPortable.git (synced 2026-01-26 15:49:45 +00:00)
🔥 Remove unused tools files
@@ -1,750 +0,0 @@
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
Training and evaluation codes for
|
||||
3D human body mesh reconstruction from an image
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
from mesh_graphormer.modeling._smpl import SMPL, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.datasets.build import make_data_loader

from mesh_graphormer.utils.logger import setup_logger
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from mesh_graphormer.utils.geometric_layers import orthographic_projection

device = "cuda"
|
||||
|
||||
from azureml.core.run import Run
|
||||
aml_run = Run.get_context()
|
||||
|
||||
def save_checkpoint(model, args, epoch, iteration, num_trial=10):
|
||||
checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
|
||||
epoch, iteration))
|
||||
if not is_main_process():
|
||||
return checkpoint_dir
|
||||
mkdir(checkpoint_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
for i in range(num_trial):
|
||||
try:
|
||||
torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
|
||||
torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
|
||||
torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
|
||||
logger.info("Save checkpoint to {}".format(checkpoint_dir))
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.info("Failed to save checkpoint after {} trails.".format(num_trial))
|
||||
return checkpoint_dir
|
||||
|
||||
def save_scores(args, split, mpjpe, pampjpe, mpve):
|
||||
eval_log = []
|
||||
res = {}
|
||||
res['mPJPE'] = mpjpe
|
||||
res['PAmPJPE'] = pampjpe
|
||||
res['mPVE'] = mpve
|
||||
eval_log.append(res)
|
||||
with open(op.join(args.output_dir, split+'_eval_logs.json'), 'w') as f:
|
||||
json.dump(eval_log, f)
|
||||
logger.info("Save eval scores to {}".format(args.output_dir))
|
||||
return
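# Usage sketch (illustrative values only): save_scores(args, 'test', 0.05, 0.04, 0.06)
# writes a one-element list, [{"mPJPE": 0.05, "PAmPJPE": 0.04, "mPVE": 0.06}], to
# <output_dir>/test_eval_logs.json.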
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
|
||||
"""
|
||||
Decays the learning rate by a factor of 10 every num_train_epochs/2 epochs
(with the default of 200 training epochs, a single 10x drop at epoch 100).
|
||||
"""
|
||||
lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) ))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
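# Example (sketch, using the argparse defaults lr=1e-4 and num_train_epochs=200):
# the step schedule above keeps lr at 1e-4 for epochs 0-99 and drops it to 1e-5
# from epoch 100 onward, i.e. a single 10x decay at the halfway point.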
|
||||
|
||||
def mean_per_joint_position_error(pred, gt, has_3d_joints):
|
||||
"""
|
||||
Compute mPJPE
|
||||
"""
|
||||
gt = gt[has_3d_joints == 1]
|
||||
gt = gt[:, :, :-1]
|
||||
pred = pred[has_3d_joints == 1]
|
||||
|
||||
with torch.no_grad():
|
||||
gt_pelvis = (gt[:, 2,:] + gt[:, 3,:]) / 2
|
||||
gt = gt - gt_pelvis[:, None, :]
|
||||
pred_pelvis = (pred[:, 2,:] + pred[:, 3,:]) / 2
|
||||
pred = pred - pred_pelvis[:, None, :]
|
||||
error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
|
||||
return error
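# Both prediction and ground truth are re-centered on a pelvis proxy (the midpoint of
# joints 2 and 3 in the 14-joint ordering used here) before the per-joint L2 distance
# is averaged, so this mPJPE is translation-invariant; the Procrustes-aligned error
# (PAmPJPE) is computed separately via reconstruction_error.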
|
||||
|
||||
def mean_per_vertex_error(pred, gt, has_smpl):
|
||||
"""
|
||||
Compute mPVE
|
||||
"""
|
||||
pred = pred[has_smpl == 1]
|
||||
gt = gt[has_smpl == 1]
|
||||
with torch.no_grad():
|
||||
error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
|
||||
return error
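# mPVE: mean Euclidean distance between corresponding predicted and ground-truth
# vertices, evaluated only on samples that carry mesh supervision (has_smpl == 1).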
|
||||
|
||||
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
|
||||
"""
|
||||
Compute 2D reprojection loss if 2D keypoint annotations are available.
|
||||
The confidence (conf) is binary and indicates whether the keypoints exist or not.
|
||||
"""
|
||||
conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
|
||||
loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
|
||||
return loss
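# The binary confidence channel acts as a per-joint mask: unlabelled keypoints are
# zeroed out and add nothing to the summed error (they still count toward the mean's
# denominator).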
|
||||
|
||||
def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, device):
|
||||
"""
|
||||
Compute 3D keypoint loss if 3D keypoint annotations are available.
|
||||
"""
|
||||
conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
|
||||
gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
|
||||
gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
|
||||
conf = conf[has_pose_3d == 1]
|
||||
pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
|
||||
if len(gt_keypoints_3d) > 0:
|
||||
gt_pelvis = (gt_keypoints_3d[:, 2,:] + gt_keypoints_3d[:, 3,:]) / 2
|
||||
gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
|
||||
pred_pelvis = (pred_keypoints_3d[:, 2,:] + pred_keypoints_3d[:, 3,:]) / 2
|
||||
pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
|
||||
return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
|
||||
else:
|
||||
return torch.FloatTensor(1).fill_(0.).to(device)
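# Commented-out sanity-check sketch (hypothetical shapes, not part of training):
#   criterion = torch.nn.MSELoss(reduction='none')
#   gt = torch.rand(2, 14, 4)        # last channel is the per-joint confidence flag
#   pred = torch.rand(2, 14, 3)
#   has_pose_3d = torch.ones(2)
#   loss = keypoint_3d_loss(criterion, pred, gt, has_pose_3d, 'cpu')  # scalar tensor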
|
||||
|
||||
def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, device):
|
||||
"""
|
||||
Compute per-vertex loss if vertex annotations are available.
|
||||
"""
|
||||
pred_vertices_with_shape = pred_vertices[has_smpl == 1]
|
||||
gt_vertices_with_shape = gt_vertices[has_smpl == 1]
|
||||
if len(gt_vertices_with_shape) > 0:
|
||||
return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
|
||||
else:
|
||||
return torch.FloatTensor(1).fill_(0.).to(device)
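# Returning a zero tensor when no sample in the batch has mesh supervision keeps the
# combined loss expression in run() well-defined without adding any gradient signal.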
|
||||
|
||||
def rectify_pose(pose):
|
||||
pose = pose.copy()
|
||||
R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
|
||||
R_root = cv2.Rodrigues(pose[:3])[0]
|
||||
new_root = R_root.dot(R_mod)
|
||||
pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
|
||||
return pose
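# Note: R_mod is the rotation matrix for a 180-degree rotation about the x-axis, so
# rectify_pose flips only the global root orientation of the axis-angle pose vector
# and leaves the remaining per-joint parameters unchanged.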
|
||||
|
||||
def run(args, train_dataloader, val_dataloader, Graphormer_model, smpl, mesh_sampler, renderer):
|
||||
smpl.eval()
|
||||
max_iter = len(train_dataloader)
|
||||
iters_per_epoch = max_iter // args.num_train_epochs
|
||||
if iters_per_epoch<1000:
|
||||
args.logging_steps = 500
|
||||
|
||||
optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0)
|
||||
|
||||
# define loss function (criterion) and optimizer
|
||||
criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
|
||||
criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
|
||||
criterion_vertices = torch.nn.L1Loss().to(device)
|
||||
|
||||
if args.distributed:
|
||||
Graphormer_model = torch.nn.parallel.DistributedDataParallel(
|
||||
Graphormer_model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
' '.join(
|
||||
['Local rank: {o}', 'Max iteration: {a}', 'iters_per_epoch: {b}','num_train_epochs: {c}',]
|
||||
).format(o=args.local_rank, a=max_iter, b=iters_per_epoch, c=args.num_train_epochs)
|
||||
)
|
||||
|
||||
start_training_time = time.time()
|
||||
end = time.time()
|
||||
Graphormer_model.train()
|
||||
batch_time = AverageMeter()
|
||||
data_time = AverageMeter()
|
||||
log_losses = AverageMeter()
|
||||
log_loss_2djoints = AverageMeter()
|
||||
log_loss_3djoints = AverageMeter()
|
||||
log_loss_vertices = AverageMeter()
|
||||
log_eval_metrics = EvalMetricsLogger()
|
||||
|
||||
for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
Graphormer_model.train()
|
||||
iteration += 1
|
||||
epoch = iteration // iters_per_epoch
|
||||
batch_size = images.size(0)
|
||||
adjust_learning_rate(optimizer, epoch, args)
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
images = images.to(device)
|
||||
gt_2d_joints = annotations['joints_2d'].to(device)
|
||||
gt_2d_joints = gt_2d_joints[:,cfg.J24_TO_J14,:]
|
||||
has_2d_joints = annotations['has_2d_joints'].to(device)
|
||||
|
||||
gt_3d_joints = annotations['joints_3d'].to(device)
|
||||
gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3]
|
||||
gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:]
|
||||
gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :]
|
||||
has_3d_joints = annotations['has_3d_joints'].to(device)
|
||||
|
||||
gt_pose = annotations['pose'].to(device)
|
||||
gt_betas = annotations['betas'].to(device)
|
||||
has_smpl = annotations['has_smpl'].to(device)
|
||||
mjm_mask = annotations['mjm_mask'].to(device)
|
||||
mvm_mask = annotations['mvm_mask'].to(device)
|
||||
|
||||
# generate simplified mesh
|
||||
gt_vertices = smpl(gt_pose, gt_betas)
|
||||
gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices, n1=0, n2=2)
|
||||
gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
|
||||
|
||||
# normalize gt based on smpl's pelvis
|
||||
gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
|
||||
gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
|
||||
gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
|
||||
|
||||
# prepare masks for mask vertex/joint modeling
|
||||
mjm_mask_ = mjm_mask.expand(-1,-1,2051)
|
||||
mvm_mask_ = mvm_mask.expand(-1,-1,2051)
|
||||
meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)
|
||||
|
||||
# forward-pass
|
||||
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True)
|
||||
|
||||
# normalize gt based on smpl's pelvis
|
||||
gt_vertices_sub = gt_vertices_sub - gt_smpl_3d_pelvis[:, None, :]
|
||||
gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :]
|
||||
|
||||
# obtain 3d joints, which are regressed from the full mesh
|
||||
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
|
||||
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
|
||||
|
||||
# obtain 2d joints, which are projected from 3d joints of smpl mesh
|
||||
pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
|
||||
pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)
|
||||
|
||||
# compute 3d joint loss (where the joints are directly output from transformer)
|
||||
loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device)
|
||||
# compute 3d vertex loss
|
||||
loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \
|
||||
args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \
|
||||
args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) )
|
||||
# compute 3d joint loss (where the joints are regressed from full mesh)
|
||||
loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device)
|
||||
# compute 2d joint loss
|
||||
loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
|
||||
keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, has_2d_joints)
|
||||
|
||||
loss_3d_joints = loss_3d_joints + loss_reg_3d_joints
|
||||
|
||||
# we empirically use hyperparameters to balance difference losses
|
||||
loss = args.joints_loss_weight*loss_3d_joints + \
|
||||
args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints
|
||||
|
||||
# update logs
|
||||
log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
|
||||
log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
|
||||
log_loss_vertices.update(loss_vertices.item(), batch_size)
|
||||
log_losses.update(loss.item(), batch_size)
|
||||
|
||||
# back prop
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if iteration % args.logging_steps == 0 or iteration == max_iter:
|
||||
eta_seconds = batch_time.avg * (max_iter - iteration)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
logger.info(
|
||||
' '.join(
|
||||
['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
|
||||
).format(eta=eta_string, ep=epoch, iter=iteration,
|
||||
memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
|
||||
+ ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
|
||||
log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
|
||||
optimizer.param_groups[0]['lr'])
|
||||
)
|
||||
|
||||
aml_run.log(name='Loss', value=float(log_losses.avg))
|
||||
aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
|
||||
aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
|
||||
aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))
|
||||
|
||||
visual_imgs = visualize_mesh( renderer,
|
||||
annotations['ori_img'].detach(),
|
||||
annotations['joints_2d'].detach(),
|
||||
pred_vertices.detach(),
|
||||
pred_camera.detach(),
|
||||
pred_2d_joints_from_smpl.detach())
|
||||
visual_imgs = visual_imgs.transpose(0,1)
|
||||
visual_imgs = visual_imgs.transpose(1,2)
|
||||
visual_imgs = np.asarray(visual_imgs)
|
||||
|
||||
if is_main_process()==True:
|
||||
stamp = str(epoch) + '_' + str(iteration)
|
||||
temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
|
||||
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
|
||||
aml_run.log_image(name='visual results', path=temp_fname)
|
||||
|
||||
if iteration % iters_per_epoch == 0:
|
||||
val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
|
||||
Graphormer_model,
|
||||
criterion_keypoints,
|
||||
criterion_vertices,
|
||||
epoch,
|
||||
smpl,
|
||||
mesh_sampler)
|
||||
aml_run.log(name='mPVE', value=float(1000*val_mPVE))
|
||||
aml_run.log(name='mPJPE', value=float(1000*val_mPJPE))
|
||||
aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE))
|
||||
logger.info(
|
||||
' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
|
||||
+ ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE, val_count)
|
||||
)
|
||||
|
||||
if val_PAmPJPE<log_eval_metrics.PAmPJPE:
|
||||
checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
|
||||
log_eval_metrics.update(val_mPVE, val_mPJPE, val_PAmPJPE, epoch)
|
||||
|
||||
|
||||
total_training_time = time.time() - start_training_time
|
||||
total_time_str = str(datetime.timedelta(seconds=total_training_time))
|
||||
logger.info('Total training time: {} ({:.4f} s / iter)'.format(
|
||||
total_time_str, total_training_time / max_iter)
|
||||
)
|
||||
checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
|
||||
|
||||
logger.info(
|
||||
' Best Results:'
|
||||
+ ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, at epoch {:6.2f}'.format(1000*log_eval_metrics.mPVE, 1000*log_eval_metrics.mPJPE, 1000*log_eval_metrics.PAmPJPE, log_eval_metrics.epoch)
|
||||
)
|
||||
|
||||
|
||||
def run_eval_general(args, val_dataloader, Graphormer_model, smpl, mesh_sampler):
|
||||
smpl.eval()
|
||||
criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
|
||||
criterion_vertices = torch.nn.L1Loss().to(device)
|
||||
|
||||
epoch = 0
|
||||
if args.distributed:
|
||||
Graphormer_model = torch.nn.parallel.DistributedDataParallel(
|
||||
Graphormer_model, device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
Graphormer_model.eval()
|
||||
|
||||
val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
|
||||
Graphormer_model,
|
||||
criterion_keypoints,
|
||||
criterion_vertices,
|
||||
epoch,
|
||||
smpl,
|
||||
mesh_sampler)
|
||||
|
||||
aml_run.log(name='mPVE', value=float(1000*val_mPVE))
|
||||
aml_run.log(name='mPJPE', value=float(1000*val_mPJPE))
|
||||
aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE))
|
||||
|
||||
logger.info(
|
||||
' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
|
||||
+ ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f} '.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE)
|
||||
)
|
||||
# checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0)
|
||||
return
|
||||
|
||||
def run_validate(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, smpl, mesh_sampler):
|
||||
batch_time = AverageMeter()
|
||||
mPVE = AverageMeter()
|
||||
mPJPE = AverageMeter()
|
||||
PAmPJPE = AverageMeter()
|
||||
# switch to evaluate mode
|
||||
Graphormer_model.eval()
|
||||
smpl.eval()
|
||||
with torch.no_grad():
|
||||
# end = time.time()
|
||||
for i, (img_keys, images, annotations) in enumerate(val_loader):
|
||||
batch_size = images.size(0)
|
||||
# compute output
|
||||
images = images.to(device)
|
||||
gt_3d_joints = annotations['joints_3d'].to(device)
|
||||
gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3]
|
||||
gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:]
|
||||
gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :]
|
||||
has_3d_joints = annotations['has_3d_joints'].to(device)
|
||||
|
||||
gt_pose = annotations['pose'].to(device)
|
||||
gt_betas = annotations['betas'].to(device)
|
||||
has_smpl = annotations['has_smpl'].to(device)
|
||||
|
||||
# generate simplified mesh
|
||||
gt_vertices = smpl(gt_pose, gt_betas)
|
||||
gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
|
||||
gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices_sub, n1=1, n2=2)
|
||||
|
||||
# normalize gt based on smpl pelvis
|
||||
gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
|
||||
gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
|
||||
gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
|
||||
gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :]
|
||||
|
||||
# forward-pass
|
||||
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler)
|
||||
|
||||
# obtain 3d joints from full mesh
|
||||
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
|
||||
|
||||
pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
|
||||
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
|
||||
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
|
||||
pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
|
||||
|
||||
# measure errors
|
||||
error_vertices = mean_per_vertex_error(pred_vertices, gt_vertices, has_smpl)
|
||||
error_joints = mean_per_joint_position_error(pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints)
|
||||
error_joints_pa = reconstruction_error(pred_3d_joints_from_smpl.cpu().numpy(), gt_3d_joints[:,:,:3].cpu().numpy(), reduction=None)
|
||||
|
||||
if len(error_vertices)>0:
|
||||
mPVE.update(np.mean(error_vertices), int(torch.sum(has_smpl)) )
|
||||
if len(error_joints)>0:
|
||||
mPJPE.update(np.mean(error_joints), int(torch.sum(has_3d_joints)) )
|
||||
if len(error_joints_pa)>0:
|
||||
PAmPJPE.update(np.mean(error_joints_pa), int(torch.sum(has_3d_joints)) )
|
||||
|
||||
val_mPVE = all_gather(float(mPVE.avg))
|
||||
val_mPVE = sum(val_mPVE)/len(val_mPVE)
|
||||
val_mPJPE = all_gather(float(mPJPE.avg))
|
||||
val_mPJPE = sum(val_mPJPE)/len(val_mPJPE)
|
||||
|
||||
val_PAmPJPE = all_gather(float(PAmPJPE.avg))
|
||||
val_PAmPJPE = sum(val_PAmPJPE)/len(val_PAmPJPE)
|
||||
|
||||
val_count = all_gather(float(mPVE.count))
|
||||
val_count = sum(val_count)
|
||||
|
||||
return val_mPVE, val_mPJPE, val_PAmPJPE, val_count
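# Each rank's running averages are gathered and re-averaged across processes, and the
# per-rank sample counts are summed; with unevenly sized validation shards the
# re-averaged metrics are an approximation of the exact global mean.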
|
||||
|
||||
|
||||
def visualize_mesh( renderer,
|
||||
images,
|
||||
gt_keypoints_2d,
|
||||
pred_vertices,
|
||||
pred_camera,
|
||||
pred_keypoints_2d):
|
||||
"""Tensorboard logging."""
|
||||
gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
|
||||
to_lsp = list(range(14))
|
||||
rend_imgs = []
|
||||
batch_size = pred_vertices.shape[0]
|
||||
# Do visualization for the first 10 images of the batch
|
||||
for i in range(min(batch_size, 10)):
|
||||
img = images[i].cpu().numpy().transpose(1,2,0)
|
||||
# Get LSP keypoints from the full list of keypoints
|
||||
gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
|
||||
pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
|
||||
# Get predict vertices for the particular example
|
||||
vertices = pred_vertices[i].cpu().numpy()
|
||||
cam = pred_camera[i].cpu().numpy()
|
||||
# Visualize reconstruction and detected pose
|
||||
rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
|
||||
rend_img = rend_img.transpose(2,0,1)
|
||||
rend_imgs.append(torch.from_numpy(rend_img))
|
||||
rend_imgs = make_grid(rend_imgs, nrow=1)
|
||||
return rend_imgs
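# make_grid returns a single CHW tensor stacking the rendered rows; the caller in
# run() transposes it to HWC, scales by 255 and flips RGB->BGR before cv2.imwrite.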
|
||||
|
||||
def visualize_mesh_test( renderer,
|
||||
images,
|
||||
gt_keypoints_2d,
|
||||
pred_vertices,
|
||||
pred_camera,
|
||||
pred_keypoints_2d,
|
||||
PAmPJPE_h36m_j14):
|
||||
"""Tensorboard logging."""
|
||||
gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
|
||||
to_lsp = list(range(14))
|
||||
rend_imgs = []
|
||||
batch_size = pred_vertices.shape[0]
|
||||
# Do visualization for the first 10 images of the batch
|
||||
for i in range(min(batch_size, 10)):
|
||||
img = images[i].cpu().numpy().transpose(1,2,0)
|
||||
# Get LSP keypoints from the full list of keypoints
|
||||
gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
|
||||
pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
|
||||
# Get predict vertices for the particular example
|
||||
vertices = pred_vertices[i].cpu().numpy()
|
||||
cam = pred_camera[i].cpu().numpy()
|
||||
score = PAmPJPE_h36m_j14[i]
|
||||
# Visualize reconstruction and detected pose
|
||||
rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
|
||||
rend_img = rend_img.transpose(2,0,1)
|
||||
rend_imgs.append(torch.from_numpy(rend_img))
|
||||
rend_imgs = make_grid(rend_imgs, nrow=1)
|
||||
return rend_imgs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
#########################################################
|
||||
# Data related arguments
|
||||
#########################################################
|
||||
parser.add_argument("--data_dir", default='datasets', type=str, required=False,
|
||||
help="Directory with all datasets, each in one subfolder")
|
||||
parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
|
||||
help="Yaml file with all data for training.")
|
||||
parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
|
||||
help="Yaml file with all data for validation.")
|
||||
parser.add_argument("--num_workers", default=4, type=int,
|
||||
help="Workers in dataloader.")
|
||||
parser.add_argument("--img_scale_factor", default=1, type=int,
|
||||
help="adjust image resolution.")
|
||||
#########################################################
|
||||
# Loading/saving checkpoints
|
||||
#########################################################
|
||||
parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
|
||||
help="Path to pre-trained transformer model or model type.")
|
||||
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
|
||||
help="Path to specific checkpoint for resume training.")
|
||||
parser.add_argument("--output_dir", default='output/', type=str, required=False,
|
||||
help="The output directory to save checkpoint and test results.")
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name.")
|
||||
#########################################################
|
||||
# Training parameters
|
||||
#########################################################
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=30, type=int,
|
||||
help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--per_gpu_eval_batch_size", default=30, type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
|
||||
help="The initial lr.")
|
||||
parser.add_argument("--num_train_epochs", default=200, type=int,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--vertices_loss_weight", default=100.0, type=float)
|
||||
parser.add_argument("--joints_loss_weight", default=1000.0, type=float)
|
||||
parser.add_argument("--vloss_w_full", default=0.33, type=float)
|
||||
parser.add_argument("--vloss_w_sub", default=0.33, type=float)
|
||||
parser.add_argument("--vloss_w_sub2", default=0.33, type=float)
|
||||
parser.add_argument("--drop_out", default=0.1, type=float,
|
||||
help="Drop out ratio in BERT.")
|
||||
#########################################################
|
||||
# Model architectures
|
||||
#########################################################
|
||||
parser.add_argument('-a', '--arch', default='hrnet-w64',
|
||||
help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
|
||||
parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
|
||||
help="Update model config if given")
|
||||
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
|
||||
help="Update model config if given")
|
||||
parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
|
||||
help="Update model config if given. Note that the division of "
|
||||
"hidden_size / num_attention_heads should be in integer.")
|
||||
parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
|
||||
help="Update model config if given.")
|
||||
parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
|
||||
help="The Image Feature Dimension.")
|
||||
parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
|
||||
help="The Image Feature Dimension.")
|
||||
parser.add_argument("--which_gcn", default='0,0,1', type=str,
|
||||
help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
|
||||
parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
|
||||
parser.add_argument("--interm_size_scale", default=2, type=int)
|
||||
#########################################################
|
||||
# Others
|
||||
#########################################################
|
||||
parser.add_argument("--run_eval_only", default=False, action='store_true',)
|
||||
parser.add_argument('--logging_steps', type=int, default=1000,
|
||||
help="Log every X steps.")
|
||||
parser.add_argument("--device", type=str, default='cuda',
|
||||
help="cuda or cpu")
|
||||
parser.add_argument('--seed', type=int, default=88,
|
||||
help="random seed for initialization.")
|
||||
parser.add_argument("--local_rank", type=int, default=0,
|
||||
help="For distributed training.")
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
global logger
|
||||
# Setup CUDA, GPU & distributed training
|
||||
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
|
||||
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
|
||||
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
|
||||
|
||||
args.distributed = args.num_gpus > 1
|
||||
args.device = torch.device(args.device)
|
||||
if args.distributed:
|
||||
print("Init distributed training on local rank {} ({}), rank {}, world size {}".format(args.local_rank, int(os.environ["LOCAL_RANK"]), int(os.environ["NODE_RANK"]), args.num_gpus))
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend='nccl', init_method='env://'
|
||||
)
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
args.device = torch.device("cuda", local_rank)
|
||||
synchronize()
|
||||
|
||||
mkdir(args.output_dir)
|
||||
logger = setup_logger("Graphormer", args.output_dir, get_rank())
|
||||
set_seed(args.seed, args.num_gpus)
|
||||
logger.info("Using {} GPUs".format(args.num_gpus))
|
||||
|
||||
# Mesh and SMPL utils
|
||||
smpl = SMPL().to(args.device)
|
||||
mesh_sampler = Mesh()
|
||||
|
||||
# Renderer for visualization
|
||||
renderer = Renderer(faces=smpl.faces.cpu().numpy())
|
||||
|
||||
# Load model
|
||||
trans_encoder = []
|
||||
|
||||
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
|
||||
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
|
||||
output_feat_dim = input_feat_dim[1:] + [3]
|
||||
|
||||
# which encoder block to have graph convs
|
||||
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
|
||||
|
||||
if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
|
||||
# if only run eval, load checkpoint
|
||||
logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
|
||||
_model = torch.load(args.resume_checkpoint)
|
||||
else:
|
||||
# init three transformer-encoder blocks in a loop
|
||||
for i in range(len(output_feat_dim)):
|
||||
config_class, model_class = BertConfig, Graphormer
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name \
|
||||
else args.model_name_or_path)
|
||||
|
||||
config.output_attentions = False
|
||||
config.hidden_dropout_prob = args.drop_out
|
||||
config.img_feature_dim = input_feat_dim[i]
|
||||
config.output_feature_dim = output_feat_dim[i]
|
||||
args.hidden_size = hidden_feat_dim[i]
|
||||
args.intermediate_size = int(args.hidden_size*args.interm_size_scale)
|
||||
|
||||
if which_blk_graph[i]==1:
|
||||
config.graph_conv = True
|
||||
logger.info("Add Graph Conv")
|
||||
else:
|
||||
config.graph_conv = False
|
||||
|
||||
config.mesh_type = args.mesh_type
|
||||
|
||||
# update model structure if specified in arguments
|
||||
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
|
||||
|
||||
for idx, param in enumerate(update_params):
|
||||
arg_param = getattr(args, param)
|
||||
config_param = getattr(config, param)
|
||||
if arg_param > 0 and arg_param != config_param:
|
||||
logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
|
||||
setattr(config, param, arg_param)
|
||||
|
||||
# init a transformer encoder and append it to a list
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
model = model_class(config=config)
|
||||
logger.info("Init model from scratch.")
|
||||
trans_encoder.append(model)
|
||||
|
||||
|
||||
# init ImageNet pre-trained backbone model
|
||||
if args.arch=='hrnet':
|
||||
hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
|
||||
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
|
||||
hrnet_update_config(hrnet_config, hrnet_yaml)
|
||||
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
|
||||
logger.info('=> loading hrnet-v2-w40 model')
|
||||
elif args.arch=='hrnet-w64':
|
||||
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
|
||||
hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
|
||||
hrnet_update_config(hrnet_config, hrnet_yaml)
|
||||
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
|
||||
logger.info('=> loading hrnet-v2-w64 model')
|
||||
else:
|
||||
print("=> using pre-trained model '{}'".format(args.arch))
|
||||
backbone = models.__dict__[args.arch](pretrained=True)
|
||||
# remove the last fc layer
|
||||
backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
|
||||
|
||||
|
||||
trans_encoder = torch.nn.Sequential(*trans_encoder)
|
||||
total_params = sum(p.numel() for p in trans_encoder.parameters())
|
||||
logger.info('Graphormer encoders total parameters: {}'.format(total_params))
|
||||
backbone_total_params = sum(p.numel() for p in backbone.parameters())
|
||||
logger.info('Backbone total parameters: {}'.format(backbone_total_params))
|
||||
|
||||
# build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
|
||||
_model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)
|
||||
|
||||
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
|
||||
# for fine-tuning or resume training or inference, load weights from checkpoint
|
||||
logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
|
||||
# workaround approach to load sparse tensor in graph conv.
|
||||
states = torch.load(args.resume_checkpoint)
|
||||
# states = checkpoint_loaded.state_dict()
|
||||
for k, v in states.items():
|
||||
states[k] = v.cpu()
|
||||
# del checkpoint_loaded
|
||||
_model.load_state_dict(states, strict=False)
|
||||
del states
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
_model.to(args.device)
|
||||
logger.info("Training parameters %s", args)
|
||||
|
||||
if args.run_eval_only==True:
|
||||
val_dataloader = make_data_loader(args, args.val_yaml,
|
||||
args.distributed, is_train=False, scale_factor=args.img_scale_factor)
|
||||
run_eval_general(args, val_dataloader, _model, smpl, mesh_sampler)
|
||||
|
||||
else:
|
||||
train_dataloader = make_data_loader(args, args.train_yaml,
|
||||
args.distributed, is_train=True, scale_factor=args.img_scale_factor)
|
||||
val_dataloader = make_data_loader(args, args.val_yaml,
|
||||
args.distributed, is_train=False, scale_factor=args.img_scale_factor)
|
||||
run(args, train_dataloader, val_dataloader, _model, smpl, mesh_sampler, renderer)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -1,351 +0,0 @@
|
||||
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
End-to-end inference codes for
|
||||
3D human body mesh reconstruction from an image
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import argparse
|
||||
import os
|
||||
import os.path as op
|
||||
import code
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from torchvision.utils import make_grid
|
||||
import gc
|
||||
import numpy as np
|
||||
import cv2
|
||||
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
|
||||
from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
|
||||
from mesh_graphormer.modeling._smpl import SMPL, Mesh
|
||||
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
|
||||
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
|
||||
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
|
||||
import mesh_graphormer.modeling.data.config as cfg
|
||||
from mesh_graphormer.datasets.build import make_data_loader
|
||||
|
||||
from mesh_graphormer.utils.logger import setup_logger
|
||||
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
|
||||
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
|
||||
from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
|
||||
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
|
||||
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
|
||||
from mesh_graphormer.utils.geometric_layers import orthographic_projection
|
||||
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
device = "cuda"
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(224),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])])
|
||||
|
||||
transform_visualize = transforms.Compose([
|
||||
transforms.Resize(224),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor()])
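# Two parallel pipelines over the same 224x224 center crop: `transform` adds ImageNet
# mean/std normalization for the network input, while `transform_visualize` keeps raw
# [0, 1] pixel values so the renderer can overlay the predicted mesh on the image.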
|
||||
|
||||
def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler):
|
||||
# switch to evaluate mode
|
||||
Graphormer_model.eval()
|
||||
smpl.eval()
|
||||
with torch.no_grad():
|
||||
for image_file in image_list:
|
||||
if 'pred' not in image_file:
|
||||
att_all = []
|
||||
img = Image.open(image_file)
|
||||
img_tensor = transform(img)
|
||||
img_visual = transform_visualize(img)
|
||||
|
||||
batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
|
||||
batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
|
||||
# forward-pass
|
||||
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler)
|
||||
|
||||
# obtain 3d joints from full mesh
|
||||
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
|
||||
|
||||
pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
|
||||
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
|
||||
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
|
||||
pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
|
||||
|
||||
# save attention
|
||||
att_max_value = att[-1]
|
||||
att_cpu = np.asarray(att_max_value.cpu().detach())
|
||||
att_all.append(att_cpu)
|
||||
|
||||
# obtain 3d joints, which are regressed from the full mesh
|
||||
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
|
||||
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
|
||||
# obtain 2d joints, which are projected from 3d joints of smpl mesh
|
||||
pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
|
||||
pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera)
|
||||
visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0],
|
||||
pred_vertices[0].detach(),
|
||||
pred_camera.detach())
|
||||
# visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
|
||||
# pred_vertices[0].detach(),
|
||||
# pred_vertices_sub2[0].detach(),
|
||||
# pred_2d_431_vertices_from_smpl[0].detach(),
|
||||
# pred_2d_joints_from_smpl[0].detach(),
|
||||
# pred_camera.detach(),
|
||||
# att[-1][0].detach())
|
||||
|
||||
visual_imgs = visual_imgs_output.transpose(1,2,0)
|
||||
visual_imgs = np.asarray(visual_imgs)
|
||||
|
||||
temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
|
||||
print('save to ', temp_fname)
|
||||
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
|
||||
|
||||
return
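# Usage sketch: main() below collects .png/.jpg files from --image_file_or_path and
# calls run_inference, which writes a `<name>_graphormer_pred.jpg` visualization next
# to each input; files whose names already contain 'pred' are skipped.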
|
||||
|
||||
def visualize_mesh( renderer, images,
|
||||
pred_vertices_full,
|
||||
pred_camera):
|
||||
img = images.cpu().numpy().transpose(1,2,0)
|
||||
# Get predict vertices for the particular example
|
||||
vertices_full = pred_vertices_full.cpu().numpy()
|
||||
cam = pred_camera.cpu().numpy()
|
||||
# Visualize only mesh reconstruction
|
||||
rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
|
||||
rend_img = rend_img.transpose(2,0,1)
|
||||
return rend_img
|
||||
|
||||
def visualize_mesh_and_attention( renderer, images,
|
||||
pred_vertices_full,
|
||||
pred_vertices,
|
||||
pred_2d_vertices,
|
||||
pred_2d_joints,
|
||||
pred_camera,
|
||||
attention):
|
||||
img = images.cpu().numpy().transpose(1,2,0)
|
||||
# Get predict vertices for the particular example
|
||||
vertices_full = pred_vertices_full.cpu().numpy()
|
||||
vertices = pred_vertices.cpu().numpy()
|
||||
vertices_2d = pred_2d_vertices.cpu().numpy()
|
||||
joints_2d = pred_2d_joints.cpu().numpy()
|
||||
cam = pred_camera.cpu().numpy()
|
||||
att = attention.cpu().numpy()
|
||||
# Visualize reconstruction and attention
|
||||
rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
|
||||
rend_img = rend_img.transpose(2,0,1)
|
||||
return rend_img
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
#########################################################
|
||||
# Data related arguments
|
||||
#########################################################
|
||||
parser.add_argument("--num_workers", default=4, type=int,
|
||||
help="Workers in dataloader.")
|
||||
parser.add_argument("--img_scale_factor", default=1, type=int,
|
||||
help="adjust image resolution.")
|
||||
parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str,
|
||||
help="test data")
|
||||
#########################################################
|
||||
# Loading/saving checkpoints
|
||||
#########################################################
|
||||
parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
|
||||
help="Path to pre-trained transformer model or model type.")
|
||||
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
|
||||
help="Path to specific checkpoint for resume training.")
|
||||
parser.add_argument("--output_dir", default='output/', type=str, required=False,
|
||||
help="The output directory to save checkpoint and test results.")
|
||||
parser.add_argument("--config_name", default="", type=str,
|
||||
help="Pretrained config name or path if not the same as model_name.")
|
||||
#########################################################
|
||||
# Model architectures
|
||||
#########################################################
|
||||
parser.add_argument('-a', '--arch', default='hrnet-w64',
|
||||
help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
|
||||
parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
|
||||
help="Update model config if given")
|
||||
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
|
||||
help="Update model config if given")
|
||||
parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
|
||||
help="Update model config if given. Note that the division of "
|
||||
"hidden_size / num_attention_heads should be in integer.")
|
||||
parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
|
||||
help="Update model config if given.")
|
||||
parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
|
||||
help="The Image Feature Dimension.")
|
||||
parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
|
||||
help="The Image Feature Dimension.")
|
||||
parser.add_argument("--which_gcn", default='0,0,1', type=str,
|
||||
help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
|
||||
parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
|
||||
parser.add_argument("--interm_size_scale", default=2, type=int)
|
||||
#########################################################
|
||||
# Others
|
||||
#########################################################
|
||||
parser.add_argument("--run_eval_only", default=True, action='store_true',)
|
||||
parser.add_argument("--device", type=str, default='cuda',
|
||||
help="cuda or cpu")
|
||||
parser.add_argument('--seed', type=int, default=88,
|
||||
help="random seed for initialization.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
global logger
|
||||
# Setup CUDA, GPU & distributed training
|
||||
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
|
||||
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
|
||||
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
|
||||
|
||||
args.distributed = args.num_gpus > 1
|
||||
args.device = torch.device(args.device)
|
||||
|
||||
mkdir(args.output_dir)
|
||||
logger = setup_logger("Graphormer", args.output_dir, get_rank())
|
||||
set_seed(args.seed, args.num_gpus)
|
||||
logger.info("Using {} GPUs".format(args.num_gpus))
|
||||
|
||||
# Mesh and SMPL utils
|
||||
smpl = SMPL().to(args.device)
|
||||
mesh_sampler = Mesh()
|
||||
|
||||
# Renderer for visualization
|
||||
renderer = Renderer(faces=smpl.faces.cpu().numpy())
|
||||
|
||||
# Load model
|
||||
trans_encoder = []
|
||||
|
||||
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
|
||||
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
|
||||
output_feat_dim = input_feat_dim[1:] + [3]
|
||||
|
||||
# which encoder block to have graph convs
|
||||
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
|
||||
|
||||
if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
|
||||
# if only run eval, load checkpoint
|
||||
logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
|
||||
_model = torch.load(args.resume_checkpoint)
|
||||
else:
|
||||
# init three transformer-encoder blocks in a loop
|
||||
for i in range(len(output_feat_dim)):
|
||||
config_class, model_class = BertConfig, Graphormer
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name \
|
||||
else args.model_name_or_path)
|
||||
|
||||
config.output_attentions = False
|
||||
config.img_feature_dim = input_feat_dim[i]
|
||||
config.output_feature_dim = output_feat_dim[i]
|
||||
args.hidden_size = hidden_feat_dim[i]
|
||||
args.intermediate_size = int(args.hidden_size*args.interm_size_scale)
|
||||
|
||||
if which_blk_graph[i]==1:
|
||||
config.graph_conv = True
|
||||
logger.info("Add Graph Conv")
|
||||
else:
|
||||
config.graph_conv = False
|
||||
|
||||
config.mesh_type = args.mesh_type
|
||||
|
||||
# update model structure if specified in arguments
|
||||
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
|
||||
|
||||
for idx, param in enumerate(update_params):
|
||||
arg_param = getattr(args, param)
|
||||
config_param = getattr(config, param)
|
||||
if arg_param > 0 and arg_param != config_param:
|
||||
logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
|
||||
setattr(config, param, arg_param)
|
||||
|
||||
# init a transformer encoder and append it to a list
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
model = model_class(config=config)
|
||||
logger.info("Init model from scratch.")
|
||||
trans_encoder.append(model)
|
||||
|
||||
# init ImageNet pre-trained backbone model
|
||||
if args.arch=='hrnet':
|
||||
hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
|
||||
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
|
||||
hrnet_update_config(hrnet_config, hrnet_yaml)
|
||||
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
|
||||
logger.info('=> loading hrnet-v2-w40 model')
|
||||
elif args.arch=='hrnet-w64':
|
||||
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
|
||||
hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
|
||||
hrnet_update_config(hrnet_config, hrnet_yaml)
|
||||
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
|
||||
logger.info('=> loading hrnet-v2-w64 model')
|
||||
else:
|
||||
print("=> using pre-trained model '{}'".format(args.arch))
|
||||
backbone = models.__dict__[args.arch](pretrained=True)
|
||||
# remove the last fc layer
|
||||
backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
|
||||
|
||||
|
||||
trans_encoder = torch.nn.Sequential(*trans_encoder)
|
||||
total_params = sum(p.numel() for p in trans_encoder.parameters())
|
||||
logger.info('Graphormer encoders total parameters: {}'.format(total_params))
|
||||
backbone_total_params = sum(p.numel() for p in backbone.parameters())
|
||||
logger.info('Backbone total parameters: {}'.format(backbone_total_params))
|
||||
|
||||
# build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
|
||||
_model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)
|
||||
|
||||
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
|
||||
# for fine-tuning or resume training or inference, load weights from checkpoint
|
||||
logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
|
||||
# workaround approach to load sparse tensor in graph conv.
|
||||
states = torch.load(args.resume_checkpoint)
|
||||
# states = checkpoint_loaded.state_dict()
|
||||
for k, v in states.items():
|
||||
states[k] = v.cpu()
|
||||
# del checkpoint_loaded
|
||||
_model.load_state_dict(states, strict=False)
|
||||
del states
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# update configs to enable attention outputs
|
||||
setattr(_model.trans_encoder[-1].config,'output_attentions', True)
|
||||
setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
|
||||
_model.trans_encoder[-1].bert.encoder.output_attentions = True
|
||||
_model.trans_encoder[-1].bert.encoder.output_hidden_states = True
|
||||
for iter_layer in range(4):
|
||||
_model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
|
||||
for inter_block in range(3):
|
||||
setattr(_model.trans_encoder[-1].config,'device', args.device)
|
||||
|
||||
_model.to(args.device)
|
||||
logger.info("Run inference")
|
||||
|
||||
image_list = []
|
||||
if not args.image_file_or_path:
|
||||
raise ValueError("image_file_or_path not specified")
|
||||
if op.isfile(args.image_file_or_path):
|
||||
image_list = [args.image_file_or_path]
|
||||
elif op.isdir(args.image_file_or_path):
|
||||
# should be a path with images only
|
||||
for filename in os.listdir(args.image_file_or_path):
|
||||
if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename:
image_list.append(op.join(args.image_file_or_path, filename))
|
||||
else:
|
||||
raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
|
||||
|
||||
run_inference(args, image_list, _model, smpl, renderer, mesh_sampler)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -1,713 +0,0 @@
|
||||
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
Training and evaluation codes for
|
||||
3D hand mesh reconstruction from an image
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import argparse
|
||||
import os
|
||||
import os.path as op
|
||||
import code
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from torchvision.utils import make_grid
|
||||
import gc
|
||||
import numpy as np
|
||||
import cv2
|
||||
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
|
||||
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
|
||||
from mesh_graphormer.modeling._mano import MANO, Mesh
|
||||
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
|
||||
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
|
||||
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
|
||||
import mesh_graphormer.modeling.data.config as cfg
|
||||
from mesh_graphormer.datasets.build import make_hand_data_loader
|
||||
|
||||
from mesh_graphormer.utils.logger import setup_logger
|
||||
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
|
||||
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
|
||||
from mesh_graphormer.utils.metric_logger import AverageMeter
|
||||
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test, visualize_reconstruction_no_text
|
||||
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
|
||||
from mesh_graphormer.utils.geometric_layers import orthographic_projection
|
||||
|
||||
|
||||
device = "cuda"
|
||||
|
||||
from azureml.core.run import Run
|
||||
aml_run = Run.get_context()
|
||||
|
||||
def save_checkpoint(model, args, epoch, iteration, num_trial=10):
|
||||
checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
|
||||
epoch, iteration))
|
||||
if not is_main_process():
|
||||
return checkpoint_dir
|
||||
mkdir(checkpoint_dir)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
for i in range(num_trial):
|
||||
try:
|
||||
torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
|
||||
torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
|
||||
torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
|
||||
logger.info("Save checkpoint to {}".format(checkpoint_dir))
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.info("Failed to save checkpoint after {} trails.".format(num_trial))
|
||||
return checkpoint_dir
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
|
||||
"""
|
||||
Decays the learning rate by a factor of 10 every num_train_epochs/2 epochs
(with the default of 200 training epochs, a single 10x drop at epoch 100).
|
||||
"""
|
||||
lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) ))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
|
||||
"""
|
||||
Compute 2D reprojection loss if 2D keypoint annotations are available.
|
||||
The confidence is binary and indicates whether the keypoints exist or not.
|
||||
"""
|
||||
conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
|
||||
loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
|
||||
return loss
|
||||
|
||||
def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d):
|
||||
"""
|
||||
Compute 3D keypoint loss if 3D keypoint annotations are available.
|
||||
"""
|
||||
conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
|
||||
gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
|
||||
gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
|
||||
conf = conf[has_pose_3d == 1]
|
||||
pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
|
||||
if len(gt_keypoints_3d) > 0:
|
||||
gt_root = gt_keypoints_3d[:, 0,:]
|
||||
gt_keypoints_3d = gt_keypoints_3d - gt_root[:, None, :]
|
||||
pred_root = pred_keypoints_3d[:, 0,:]
|
||||
pred_keypoints_3d = pred_keypoints_3d - pred_root[:, None, :]
|
||||
return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
|
||||
else:
|
||||
return torch.FloatTensor(1).fill_(0.).to(device)
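# Unlike the body version earlier in this commit, the hand losses root-align both
# prediction and ground truth at joint 0 (the wrist) instead of a pelvis midpoint.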

def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
    """
    Compute per-vertex loss if vertex annotations are available.
    """
    pred_vertices_with_shape = pred_vertices[has_smpl == 1]
    gt_vertices_with_shape = gt_vertices[has_smpl == 1]
    if len(gt_vertices_with_shape) > 0:
        return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
    else:
        return torch.FloatTensor(1).fill_(0.).to(device)


def run(args, train_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):

    max_iter = len(train_dataloader)
    iters_per_epoch = max_iter // args.num_train_epochs

    optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
                                 lr=args.lr,
                                 betas=(0.9, 0.999),
                                 weight_decay=0)

    # define loss function (criterion) and optimizer
    criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
    criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
    criterion_vertices = torch.nn.L1Loss().to(device)

    if args.distributed:
        Graphormer_model = torch.nn.parallel.DistributedDataParallel(
            Graphormer_model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    start_training_time = time.time()
    end = time.time()
    Graphormer_model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_losses = AverageMeter()
    log_loss_2djoints = AverageMeter()
    log_loss_3djoints = AverageMeter()
    log_loss_vertices = AverageMeter()

    for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):

        Graphormer_model.train()
        iteration += 1
        epoch = iteration // iters_per_epoch
        batch_size = images.size(0)
        adjust_learning_rate(optimizer, epoch, args)
        data_time.update(time.time() - end)

        images = images.to(device)
        gt_2d_joints = annotations['joints_2d'].to(device)
        gt_pose = annotations['pose'].to(device)
        gt_betas = annotations['betas'].to(device)
        has_mesh = annotations['has_smpl'].to(device)
        has_3d_joints = has_mesh
        has_2d_joints = has_mesh
        mjm_mask = annotations['mjm_mask'].to(device)
        mvm_mask = annotations['mvm_mask'].to(device)

        # generate mesh
        gt_vertices, gt_3d_joints = mano_model.layer(gt_pose, gt_betas)
        gt_vertices = gt_vertices/1000.0
        gt_3d_joints = gt_3d_joints/1000.0

        gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
        # normalize gt based on hand's wrist
        gt_3d_root = gt_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
        gt_vertices = gt_vertices - gt_3d_root[:, None, :]
        gt_vertices_sub = gt_vertices_sub - gt_3d_root[:, None, :]
        gt_3d_joints = gt_3d_joints - gt_3d_root[:, None, :]
        gt_3d_joints_with_tag = torch.ones((batch_size,gt_3d_joints.shape[1],4)).to(device)
        gt_3d_joints_with_tag[:,:,:3] = gt_3d_joints

        # prepare masks for mask vertex/joint modeling
        mjm_mask_ = mjm_mask.expand(-1,-1,2051)
        mvm_mask_ = mvm_mask.expand(-1,-1,2051)
        meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)

        # forward pass
        pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler, meta_masks=meta_masks, is_train=True)

        # obtain 3d joints, which are regressed from the full mesh
        pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)

        # obtain 2d joints, which are projected from the 3d joints of the mesh
        pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
        pred_2d_joints = orthographic_projection(pred_3d_joints.contiguous(), pred_camera.contiguous())

        # compute 3d joint loss (where the joints are directly output from the transformer)
        loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints_with_tag, has_3d_joints)

        # compute 3d vertex loss
        loss_vertices = ( args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_mesh) + \
                          args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_mesh) )

        # compute 3d joint loss (where the joints are regressed from the full mesh)
        loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_mesh, gt_3d_joints_with_tag, has_3d_joints)
        # compute 2d joint loss
        loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
                         keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_mesh, gt_2d_joints, has_2d_joints)

        loss_3d_joints = loss_3d_joints + loss_reg_3d_joints

        # we empirically use these hyperparameters to balance the different losses
        loss = args.joints_loss_weight*loss_3d_joints + \
               args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints
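
        # Spelled out (comment added for clarity), with the default vloss_w_sub = vloss_w_full = 0.5:
        #   loss = joints_loss_weight   * (3d joint loss + 3d joint-from-mesh loss)
        #        + vertices_loss_weight * (0.5 * coarse-vertex loss + 0.5 * full-vertex loss)
        #        + vertices_loss_weight * (2d joint loss + 2d joint-from-mesh loss)
        # note that the 2d joint term is scaled by vertices_loss_weight, exactly as written above.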

        # update logs
        log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
        log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
        log_loss_vertices.update(loss_vertices.item(), batch_size)
        log_losses.update(loss.item(), batch_size)

        # back prop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if iteration % args.logging_steps == 0 or iteration == max_iter:
            eta_seconds = batch_time.avg * (max_iter - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                ' '.join(
                    ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
                ).format(eta=eta_string, ep=epoch, iter=iteration,
                         memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
                    log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
                    optimizer.param_groups[0]['lr'])
            )

            aml_run.log(name='Loss', value=float(log_losses.avg))
            aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
            aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
            aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))

            visual_imgs = visualize_mesh(renderer,
                                         annotations['ori_img'].detach(),
                                         annotations['joints_2d'].detach(),
                                         pred_vertices.detach(),
                                         pred_camera.detach(),
                                         pred_2d_joints_from_mesh.detach())
            visual_imgs = visual_imgs.transpose(0,1)
            visual_imgs = visual_imgs.transpose(1,2)
            visual_imgs = np.asarray(visual_imgs)

            if is_main_process():
                stamp = str(epoch) + '_' + str(iteration)
                temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
                aml_run.log_image(name='visual results', path=temp_fname)

        if iteration % iters_per_epoch == 0:
            if epoch % 10 == 0:
                checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info('Total training time: {} ({:.4f} s / iter)'.format(
        total_time_str, total_training_time / max_iter)
    )
    checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)

def run_eval_and_save(args, split, val_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):

    criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
    criterion_vertices = torch.nn.L1Loss().to(device)

    if args.distributed:
        Graphormer_model = torch.nn.parallel.DistributedDataParallel(
            Graphormer_model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    Graphormer_model.eval()

    if args.aml_eval:
        run_aml_inference_hand_mesh(args, val_dataloader,
                                    Graphormer_model,
                                    criterion_keypoints,
                                    criterion_vertices,
                                    0,
                                    mano_model, mesh_sampler,
                                    renderer, split)
    else:
        run_inference_hand_mesh(args, val_dataloader,
                                Graphormer_model,
                                criterion_keypoints,
                                criterion_vertices,
                                0,
                                mano_model, mesh_sampler,
                                renderer, split)
    checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0)
    return

def run_aml_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
    # switch to evaluate mode
    Graphormer_model.eval()
    fname_output_save = []
    mesh_output_save = []
    joint_output_save = []
    world_size = get_world_size()
    with torch.no_grad():
        for i, (img_keys, images, annotations) in enumerate(val_loader):
            batch_size = images.size(0)
            # compute output
            images = images.to(device)

            # forward pass
            pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)
            # obtain 3d joints from the full mesh
            pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)

            for j in range(batch_size):
                fname_output_save.append(img_keys[j])
                pred_vertices_list = pred_vertices[j].tolist()
                mesh_output_save.append(pred_vertices_list)
                pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
                joint_output_save.append(pred_3d_joints_from_mesh_list)

    if world_size > 1:
        torch.distributed.barrier()
    print('save results to pred.json')
    output_json_file = 'pred.json'
    print('save results to ', output_json_file)
    with open(output_json_file, 'w') as f:
        json.dump([joint_output_save, mesh_output_save], f)

    azure_ckpt_name = '200'  # args.resume_checkpoint.split('/')[-2].split('-')[1]
    inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
    output_zip_file = args.output_dir + 'ckpt' + azure_ckpt_name + '-' + inference_setting +'-pred.zip'

    resolved_submit_cmd = 'zip ' + output_zip_file + ' ' + output_json_file
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    resolved_submit_cmd = 'rm %s'%(output_json_file)
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    if world_size > 1:
        torch.distributed.barrier()

    return

def run_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
    # switch to evaluate mode
    Graphormer_model.eval()
    fname_output_save = []
    mesh_output_save = []
    joint_output_save = []
    with torch.no_grad():
        for i, (img_keys, images, annotations) in enumerate(val_loader):
            batch_size = images.size(0)
            # compute output
            images = images.to(device)

            # forward pass
            pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)

            # obtain 3d joints from the full mesh
            pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
            # note: the 'pelvis' name is kept from the body-mesh code; for hands this is the wrist root
            pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:]
            pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
            pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

            for j in range(batch_size):
                fname_output_save.append(img_keys[j])
                pred_vertices_list = pred_vertices[j].tolist()
                mesh_output_save.append(pred_vertices_list)
                pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
                joint_output_save.append(pred_3d_joints_from_mesh_list)

            if i % 20 == 0:
                # obtain 3d joints, which are regressed from the full mesh
                pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
                # obtain 2d joints, which are projected from the 3d joints of the mesh
                pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
                visual_imgs = visualize_mesh(renderer,
                                             annotations['ori_img'].detach(),
                                             annotations['joints_2d'].detach(),
                                             pred_vertices.detach(),
                                             pred_camera.detach(),
                                             pred_2d_joints_from_mesh.detach())

                visual_imgs = visual_imgs.transpose(0,1)
                visual_imgs = visual_imgs.transpose(1,2)
                visual_imgs = np.asarray(visual_imgs)

                inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
                temp_fname = args.output_dir + args.resume_checkpoint[0:-9] + 'freihand_results_'+inference_setting+'_batch'+str(i)+'.jpg'
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))

    print('save results to pred.json')
    with open('pred.json', 'w') as f:
        json.dump([joint_output_save, mesh_output_save], f)

    run_exp_name = args.resume_checkpoint.split('/')[-3]
    run_ckpt_name = args.resume_checkpoint.split('/')[-2].split('-')[1]
    inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
    resolved_submit_cmd = 'zip ' + args.output_dir + run_exp_name + '-ckpt'+ run_ckpt_name + '-' + inference_setting +'-pred.zip ' + 'pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    resolved_submit_cmd = 'rm pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    return

def visualize_mesh(renderer,
                   images,
                   gt_keypoints_2d,
                   pred_vertices,
                   pred_camera,
                   pred_keypoints_2d):
    """Tensorboard logging."""
    gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
    to_lsp = list(range(21))
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for the first 10 images of the batch
    for i in range(min(batch_size, 10)):
        img = images[i].cpu().numpy().transpose(1,2,0)
        # Get LSP keypoints from the full list of keypoints
        gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
        pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        # Visualize reconstruction and detected pose
        rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
        rend_img = rend_img.transpose(2,0,1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs

def visualize_mesh_test(renderer,
                        images,
                        gt_keypoints_2d,
                        pred_vertices,
                        pred_camera,
                        pred_keypoints_2d,
                        PAmPJPE):
    """Tensorboard logging."""
    gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
    to_lsp = list(range(21))
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for the first 10 images of the batch
    for i in range(min(batch_size, 10)):
        img = images[i].cpu().numpy().transpose(1,2,0)
        # Get LSP keypoints from the full list of keypoints
        gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
        pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        score = PAmPJPE[i]
        # Visualize reconstruction and detected pose
        rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
        rend_img = rend_img.transpose(2,0,1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs

def visualize_mesh_no_text(renderer,
                           images,
                           pred_vertices,
                           pred_camera):
    """Tensorboard logging."""
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for the first image of the batch only
    for i in range(min(batch_size, 1)):
        img = images[i].cpu().numpy().transpose(1,2,0)
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        # Visualize reconstruction only
        rend_img = visualize_reconstruction_no_text(img, 224, vertices, cam, renderer, color='hand')
        rend_img = rend_img.transpose(2,0,1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs

def parse_args():
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--data_dir", default='datasets', type=str, required=False,
                        help="Directory with all datasets, each in one subfolder")
    parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
                        help="Yaml file with all data for training.")
    parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
                        help="Yaml file with all data for validation.")
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="Adjust image resolution.")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resuming training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoints and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    #########################################################
    # Training parameters
    #########################################################
    parser.add_argument("--per_gpu_train_batch_size", default=64, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
                        help="The initial lr.")
    parser.add_argument("--num_train_epochs", default=200, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--vertices_loss_weight", default=1.0, type=float)
    parser.add_argument("--joints_loss_weight", default=1.0, type=float)
    parser.add_argument("--vloss_w_full", default=0.5, type=float)
    parser.add_argument("--vloss_w_sub", default=0.5, type=float)
    parser.add_argument("--drop_out", default=0.1, type=float,
                        help="Dropout ratio in BERT.")
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--num_attention_heads", default=-1, type=int, required=False,
                        help="Update model config if given. Note that hidden_size "
                             "must be divisible by num_attention_heads.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="Input feature dimension of each encoder block.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="Hidden feature dimension of each encoder block.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="Which encoder blocks have graph conv (Encoder1, Encoder2, Encoder3). Default: only Encoder3 has graph conv.")
    parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")

    #########################################################
    # Others
    #########################################################
    parser.add_argument("--run_eval_only", default=False, action='store_true',)
    parser.add_argument("--multiscale_inference", default=False, action='store_true',)
    # if "multiscale_inference" is enabled, the dataloader applies transformations to the test image
    # based on the rotation "rot" and scale "sc" parameters below
    parser.add_argument("--rot", default=0, type=float)
    parser.add_argument("--sc", default=1.0, type=float)
    parser.add_argument("--aml_eval", default=False, action='store_true',)

    parser.add_argument('--logging_steps', type=int, default=100,
                        help="Log every X steps.")
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="Random seed for initialization.")
    parser.add_argument("--local_rank", type=int, default=0,
                        help="For distributed training.")
    args = parser.parse_args()
    return args

def main(args):
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    args.distributed = args.num_gpus > 1
    args.device = torch.device(args.device)
    if args.distributed:
        print("Init distributed training on local rank {}".format(args.local_rank))
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend='nccl', init_method='env://'
        )
        synchronize()

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and MANO utils
    mano_model = MANO().to(args.device)
    mano_model.layer = mano_model.layer.to(device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=mano_model.face)

    # Load pretrained model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder blocks have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' and 'state_dict' not in args.resume_checkpoint:
        # if only run eval, load checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)

    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name \
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.hidden_dropout_prob = args.drop_out
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size*2)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
            for idx, param in enumerate(update_params):
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)

        # create backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-1])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder)

        if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
            # for fine-tuning, resuming training, or inference, load weights from the checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensors in graph conv
            state_dict = torch.load(args.resume_checkpoint)
            _model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    _model.to(args.device)
    logger.info("Training parameters %s", args)

    if args.run_eval_only:
        val_dataloader = make_hand_data_loader(args, args.val_yaml,
                                               args.distributed, is_train=False, scale_factor=args.img_scale_factor)
        run_eval_and_save(args, 'freihand', val_dataloader, _model, mano_model, renderer, mesh_sampler)

    else:
        train_dataloader = make_hand_data_loader(args, args.train_yaml,
                                                 args.distributed, is_train=True, scale_factor=args.img_scale_factor)
        run(args, train_dataloader, _model, mano_model, renderer, mesh_sampler)

if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -1,338 +0,0 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

End-to-end inference codes for
3D hand mesh reconstruction from an image
"""

from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
from mesh_graphormer.modeling._mano import MANO, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.datasets.build import make_hand_data_loader

from mesh_graphormer.utils.logger import setup_logger
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from mesh_graphormer.utils.metric_logger import AverageMeter
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from mesh_graphormer.utils.geometric_layers import orthographic_projection

from PIL import Image
from torchvision import transforms


device = "cuda"

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])])

transform_visualize = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()])
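
# Note (comment added for clarity): the Normalize() call above uses the standard ImageNet
# mean/std, i.e. the statistics ImageNet-pretrained backbones expect, while
# transform_visualize keeps an un-normalized copy of the crop purely for rendering.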

def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler):
    # switch to evaluate mode
    Graphormer_model.eval()
    mano.eval()
    with torch.no_grad():
        for image_file in image_list:
            if 'pred' not in image_file:
                att_all = []
                print(image_file)
                img = Image.open(image_file)
                img_tensor = transform(img)
                img_visual = transform_visualize(img)

                batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
                batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
                # forward pass
                pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
                # obtain 3d joints from the full mesh
                pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
                pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:]
                pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
                pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

                # save attention
                att_max_value = att[-1]
                att_cpu = np.asarray(att_max_value.cpu().detach())
                att_all.append(att_cpu)

                # obtain 3d joints, which are regressed from the full mesh
                pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
                # obtain 2d joints, which are projected from the 3d joints of the mesh
                pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
                pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())


                visual_imgs_output = visualize_mesh(renderer, batch_visual_imgs[0],
                                                    pred_vertices[0].detach(),
                                                    pred_camera.detach())
                # visual_imgs_output = visualize_mesh_and_attention(renderer, batch_visual_imgs[0],
                #                                                   pred_vertices[0].detach(),
                #                                                   pred_vertices_sub[0].detach(),
                #                                                   pred_2d_coarse_vertices_from_mesh[0].detach(),
                #                                                   pred_2d_joints_from_mesh[0].detach(),
                #                                                   pred_camera.detach(),
                #                                                   att[-1][0].detach())
                visual_imgs = visual_imgs_output.transpose(1,2,0)
                visual_imgs = np.asarray(visual_imgs)

                temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
                print('save to ', temp_fname)
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
    return

def visualize_mesh(renderer, images,
                   pred_vertices_full,
                   pred_camera):
    img = images.cpu().numpy().transpose(1,2,0)
    # Get predicted vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    # Visualize only the mesh reconstruction
    rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
    rend_img = rend_img.transpose(2,0,1)
    return rend_img

def visualize_mesh_and_attention(renderer, images,
                                 pred_vertices_full,
                                 pred_vertices,
                                 pred_2d_vertices,
                                 pred_2d_joints,
                                 pred_camera,
                                 attention):
    img = images.cpu().numpy().transpose(1,2,0)
    # Get predicted vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    vertices = pred_vertices.cpu().numpy()
    vertices_2d = pred_2d_vertices.cpu().numpy()
    joints_2d = pred_2d_joints.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    att = attention.cpu().numpy()
    # Visualize reconstruction and attention
    rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
    rend_img = rend_img.transpose(2,0,1)
    return rend_img

def parse_args():
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="Adjust image resolution.")
    parser.add_argument("--image_file_or_path", default='./samples/hand', type=str,
                        help="test data")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resuming training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoints and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
                        help="Update model config if given. Note that hidden_size "
                             "must be divisible by num_attention_heads.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="Input feature dimension of each encoder block.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="Hidden feature dimension of each encoder block.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="Which encoder blocks have graph conv (Encoder1, Encoder2, Encoder3). Default: only Encoder3 has graph conv.")
    parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")

    #########################################################
    # Others
    #########################################################
    parser.add_argument("--run_eval_only", default=True, action='store_true',)
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="Random seed for initialization.")
    args = parser.parse_args()
    return args

def main(args):
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and MANO utils
    mano_model = MANO().to(args.device)
    mano_model.layer = mano_model.layer.to(device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=mano_model.face)

    # Load pretrained model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder blocks have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' and 'state_dict' not in args.resume_checkpoint:
        # if only run eval, load checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)

    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name \
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size*2)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
            for idx, param in enumerate(update_params):
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)

        # create backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-1])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder)

        if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
            # for fine-tuning, resuming training, or inference, load weights from the checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensors in graph conv
            state_dict = torch.load(args.resume_checkpoint)
            _model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    # update configs to enable attention outputs
    setattr(_model.trans_encoder[-1].config,'output_attentions', True)
    setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
    _model.trans_encoder[-1].bert.encoder.output_attentions = True
    _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
    for iter_layer in range(4):
        _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
    for inter_block in range(3):
        setattr(_model.trans_encoder[-1].config,'device', args.device)

    _model.to(args.device)
    logger.info("Run inference")

    image_list = []
    if not args.image_file_or_path:
        raise ValueError("image_file_or_path not specified")
    if op.isfile(args.image_file_or_path):
        image_list = [args.image_file_or_path]
    elif op.isdir(args.image_file_or_path):
        # should be a path with images only
        for filename in os.listdir(args.image_file_or_path):
            if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename:
                image_list.append(args.image_file_or_path+'/'+filename)
    else:
        raise ValueError("Cannot find images at {}".format(args.image_file_or_path))

    run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)

if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -1,136 +0,0 @@
from __future__ import absolute_import, division, print_function

import argparse
import os
import os.path as op
import code
import json
import zipfile
import torch
import numpy as np
from mesh_graphormer.utils.metric_pampjpe import get_alignMesh


def load_pred_json(filepath):
    archive = zipfile.ZipFile(filepath, 'r')
    jsondata = archive.read('pred.json')
    reference = json.loads(jsondata.decode("utf-8"))
    return reference[0], reference[1]
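
# Usage sketch (added for illustration; the zip name below is hypothetical and simply follows
# the 'ckpt<...>-sc<scale>_rot<rot>-pred.zip' naming scheme produced by the evaluation script):
#   joints, vertices = load_pred_json('output/ckpt200-sc10_rot0-pred.zip')
#   # joints[i]   -> 21x3 list of predicted 3D joints for image i
#   # vertices[i] -> Vx3 list of predicted mesh vertices for image i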


def multiscale_fusion(output_dir):
    s = '10'
    filepath = output_dir+'ckpt200-sc10_rot0-pred.zip'
    ref_joints, ref_vertices = load_pred_json(filepath)
    ref_joints_array = np.asarray(ref_joints)
    ref_vertices_array = np.asarray(ref_vertices)

    rotations = [0.0]
    for i in range(1,10):
        rotations.append(i*10)
        rotations.append(i*-10)

    scale = [0.7,0.8,0.9,1.0,1.1]
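    # (comment added for clarity) The loops above build 19 rotations (0, +/-10, ..., +/-90 degrees)
    # and 5 scales, i.e. 95 prediction files in total; each is aligned to the sc10_rot0 reference
    # with get_alignMesh and then averaged together with it below.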
    multiscale_joints = []
    multiscale_vertices = []

    counter = 0
    for s in scale:
        for r in rotations:
            setting = 'sc%02d_rot%s'%(int(s*10),str(int(r)))
            filepath = output_dir+'ckpt200-'+setting+'-pred.zip'
            joints, vertices = load_pred_json(filepath)
            joints_array = np.asarray(joints)
            vertices_array = np.asarray(vertices)

            pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None)
            pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None)
            print('--------------------------')
            print('scale:', s, 'rotate', r)
            print('PAMPJPE:', 1000*np.mean(pa_joint_error))
            print('PAMPVPE:', 1000*np.mean(pa_vertices_error))
            multiscale_joints.append(pa_joint_array)
            multiscale_vertices.append(pa_vertices_array)
            counter = counter + 1

    overall_joints_array = ref_joints_array.copy()
    overall_vertices_array = ref_vertices_array.copy()
    for i in range(counter):
        overall_joints_array += multiscale_joints[i]
        overall_vertices_array += multiscale_vertices[i]

    overall_joints_array /= (1+counter)
    overall_vertices_array /= (1+counter)
    pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None)
    pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None)
    print('--------------------------')
    print('overall:')
    print('PAMPJPE:', 1000*np.mean(pa_joint_error))
    print('PAMPVPE:', 1000*np.mean(pa_vertices_error))

    joint_output_save = overall_joints_array.tolist()
    mesh_output_save = overall_vertices_array.tolist()

    print('save results to pred.json')
    with open('pred.json', 'w') as f:
        json.dump([joint_output_save, mesh_output_save], f)


    filepath = output_dir+'ckpt200-multisc-pred.zip'
    resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    resolved_submit_cmd = 'rm pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)


def run_multiscale_inference(model_path, mode, output_dir):

    if mode:
        rotations = [0.0]
        for i in range(1,10):
            rotations.append(i*10)
            rotations.append(i*-10)
        scale = [0.7,0.8,0.9,1.0,1.1]
    else:
        rotations = [0.0]
        scale = [1.0]

    job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \
              "--val_yaml freihand_v3/test.yaml " \
              "--resume_checkpoint %s " \
              "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \
              "--multiscale_inference " \
              "--rot %f " \
              "--sc %s " \
              "--arch hrnet-w64 " \
              "--num_hidden_layers 4 " \
              "--num_attention_heads 4 " \
              "--input_feat_dim 2051,512,128 " \
              "--hidden_feat_dim 1024,256,64 " \
              "--output_dir %s"

    for s in scale:
        for r in rotations:
            resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir)
            print(resolved_submit_cmd)
            os.system(resolved_submit_cmd)

def main(args):
    model_path = args.model_path
    mode = args.multiscale_inference
    output_dir = args.output_dir
    run_multiscale_inference(model_path, mode, output_dir)
    if mode:
        multiscale_fusion(output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder")
    parser.add_argument("--model_path")
    parser.add_argument("--multiscale_inference", default=False, action='store_true',)
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    args = parser.parse_args()
    main(args)