diff --git a/mesh_graphormer/tools/run_gphmer_bodymesh.py b/mesh_graphormer/tools/run_gphmer_bodymesh.py
deleted file mode 100644
index 2f65017..0000000
--- a/mesh_graphormer/tools/run_gphmer_bodymesh.py
+++ /dev/null
@@ -1,750 +0,0 @@
-"""
-Copyright (c) Microsoft Corporation.
-Licensed under the MIT license.
-
-Training and evaluation codes for
-3D human body mesh reconstruction from an image
-"""
-
-from __future__ import absolute_import, division, print_function
-import argparse
-import os
-import os.path as op
-import code
-import json
-import time
-import datetime
-import torch
-import torchvision.models as models
-from torchvision.utils import make_grid
-import gc
-import numpy as np
-import cv2
-from mesh_graphormer.modeling.bert import BertConfig, Graphormer
-from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
-from mesh_graphormer.modeling._smpl import SMPL, Mesh
-from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
-from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
-from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
-import mesh_graphormer.modeling.data.config as cfg
-from mesh_graphormer.datasets.build import make_data_loader
-
-from mesh_graphormer.utils.logger import setup_logger
-from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
-from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
-from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
-from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test
-from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
-from mesh_graphormer.utils.geometric_layers import orthographic_projection
-
-
-device = "cuda"
-
-from azureml.core.run import Run
-aml_run = Run.get_context()
-
-def save_checkpoint(model, args, epoch, iteration, num_trial=10):
-    checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
-        epoch, iteration))
-    if not is_main_process():
-        return checkpoint_dir
-    mkdir(checkpoint_dir)
-    model_to_save = model.module if hasattr(model, 'module') else model
-    for i in range(num_trial):
-        try:
-            torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
-            torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
-            torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
-            logger.info("Save checkpoint to {}".format(checkpoint_dir))
-            break
-        except:
-            pass
-    else:
-        logger.info("Failed to save checkpoint after {} trails.".format(num_trial))
-    return checkpoint_dir
-
-def save_scores(args, split, mpjpe, pampjpe, mpve):
-    eval_log = []
-    res = {}
-    res['mPJPE'] = mpjpe
-    res['PAmPJPE'] = pampjpe
-    res['mPVE'] = mpve
-    eval_log.append(res)
-    with open(op.join(args.output_dir, split+'_eval_logs.json'), 'w') as f:
-        json.dump(eval_log, f)
-    logger.info("Save eval scores to {}".format(args.output_dir))
-    return
-
-def adjust_learning_rate(optimizer, epoch, args):
-    """
-    Sets the learning rate to the initial LR decayed by x every y epochs
-    x = 0.1, y = args.num_train_epochs/2.0 = 100
-    """
-    lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) ))
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
-
-def mean_per_joint_position_error(pred, gt, has_3d_joints):
-    """
-    Compute mPJPE
-    """
-    gt = gt[has_3d_joints == 1]
-    gt = gt[:, :, :-1]
-    pred = pred[has_3d_joints == 1]
-
-    with torch.no_grad():
-        gt_pelvis = (gt[:, 2,:] + gt[:, 3,:]) / 2
-        gt = gt - gt_pelvis[:, None, :]
-        pred_pelvis = (pred[:, 2,:] + pred[:, 3,:]) / 2
-        pred = pred - pred_pelvis[:, None, :]
-        error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
-        return error
-
-def mean_per_vertex_error(pred, gt, has_smpl):
-    """
-    Compute mPVE
-    """
-    pred = pred[has_smpl == 1]
-    gt = gt[has_smpl == 1]
-    with torch.no_grad():
-        error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
-        return error
-
-def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
-    """
-    Compute 2D reprojection loss if 2D keypoint annotations are available.
-    The confidence (conf) is binary and indicates whether the keypoints exist or not.
-    """
-    conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
-    loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
-    return loss
-
-def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, device):
-    """
-    Compute 3D keypoint loss if 3D keypoint annotations are available.
-    """
-    conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
-    gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
-    gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
-    conf = conf[has_pose_3d == 1]
-    pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
-    if len(gt_keypoints_3d) > 0:
-        gt_pelvis = (gt_keypoints_3d[:, 2,:] + gt_keypoints_3d[:, 3,:]) / 2
-        gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
-        pred_pelvis = (pred_keypoints_3d[:, 2,:] + pred_keypoints_3d[:, 3,:]) / 2
-        pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
-        return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
-    else:
-        return torch.FloatTensor(1).fill_(0.).to(device)
-
-def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, device):
-    """
-    Compute per-vertex loss if vertex annotations are available.
-    """
-    pred_vertices_with_shape = pred_vertices[has_smpl == 1]
-    gt_vertices_with_shape = gt_vertices[has_smpl == 1]
-    if len(gt_vertices_with_shape) > 0:
-        return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
-    else:
-        return torch.FloatTensor(1).fill_(0.).to(device)
-
-def rectify_pose(pose):
-    pose = pose.copy()
-    R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
-    R_root = cv2.Rodrigues(pose[:3])[0]
-    new_root = R_root.dot(R_mod)
-    pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
-    return pose
-
-def run(args, train_dataloader, val_dataloader, Graphormer_model, smpl, mesh_sampler, renderer):
-    smpl.eval()
-    max_iter = len(train_dataloader)
-    iters_per_epoch = max_iter // args.num_train_epochs
-    if iters_per_epoch<1000:
-        args.logging_steps = 500
-
-    optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
-                                 lr=args.lr,
-                                 betas=(0.9, 0.999),
-                                 weight_decay=0)
-
-    # define loss function (criterion) and optimizer
-    criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
-    criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
-    criterion_vertices = torch.nn.L1Loss().to(device)
-
-    if args.distributed:
-        Graphormer_model = torch.nn.parallel.DistributedDataParallel(
-            Graphormer_model, device_ids=[args.local_rank],
-            output_device=args.local_rank,
-            find_unused_parameters=True,
-        )
-
-    logger.info(
-        ' '.join(
-            ['Local rank: {o}', 'Max iteration: {a}', 'iters_per_epoch: {b}','num_train_epochs: {c}',]
-        ).format(o=args.local_rank, a=max_iter, b=iters_per_epoch, c=args.num_train_epochs)
-    )
-
-    start_training_time = time.time()
-    end = time.time()
-    Graphormer_model.train()
-    batch_time = AverageMeter()
-    data_time = AverageMeter()
-    log_losses = AverageMeter()
-    log_loss_2djoints = AverageMeter()
-    log_loss_3djoints = AverageMeter()
-    log_loss_vertices = AverageMeter()
-    log_eval_metrics = EvalMetricsLogger()
-
-    for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):
-        # gc.collect()
-        # torch.cuda.empty_cache()
-        Graphormer_model.train()
-        iteration += 1
-        epoch = iteration // iters_per_epoch
-        batch_size = images.size(0)
-        adjust_learning_rate(optimizer, epoch, args)
-        data_time.update(time.time() - end)
-
-        images = images.to(device)
-        gt_2d_joints = annotations['joints_2d'].to(device)
-        gt_2d_joints = gt_2d_joints[:,cfg.J24_TO_J14,:]
-        has_2d_joints = annotations['has_2d_joints'].to(device)
-
-        gt_3d_joints = annotations['joints_3d'].to(device)
-        gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3]
-        gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:]
-        gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :]
-        has_3d_joints = annotations['has_3d_joints'].to(device)
-
-        gt_pose = annotations['pose'].to(device)
-        gt_betas = annotations['betas'].to(device)
-        has_smpl = annotations['has_smpl'].to(device)
-        mjm_mask = annotations['mjm_mask'].to(device)
-        mvm_mask = annotations['mvm_mask'].to(device)
-
-        # generate simplified mesh
-        gt_vertices = smpl(gt_pose, gt_betas)
-        gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices, n1=0, n2=2)
-        gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
-
-        # normalize gt based on smpl's pelvis
-        gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
-        gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
-        gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
-
-        # prepare masks for mask vertex/joint modeling
-        mjm_mask_ = mjm_mask.expand(-1,-1,2051)
-        mvm_mask_ = mvm_mask.expand(-1,-1,2051)
-        meta_masks
= torch.cat([mjm_mask_, mvm_mask_], dim=1) - - # forward-pass - pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True) - - # normalize gt based on smpl's pelvis - gt_vertices_sub = gt_vertices_sub - gt_smpl_3d_pelvis[:, None, :] - gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :] - - # obtain 3d joints, which are regressed from the full mesh - pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) - pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] - - # obtain 2d joints, which are projected from 3d joints of smpl mesh - pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera) - pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera) - - # compute 3d joint loss (where the joints are directly output from transformer) - loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device) - # compute 3d vertex loss - loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \ - args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \ - args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) ) - # compute 3d joint loss (where the joints are regressed from full mesh) - loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device) - # compute 2d joint loss - loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \ - keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, has_2d_joints) - - loss_3d_joints = loss_3d_joints + loss_reg_3d_joints - - # we empirically use hyperparameters to balance difference losses - loss = args.joints_loss_weight*loss_3d_joints + \ - args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints - - # update logs - log_loss_2djoints.update(loss_2d_joints.item(), batch_size) - log_loss_3djoints.update(loss_3d_joints.item(), batch_size) - log_loss_vertices.update(loss_vertices.item(), batch_size) - log_losses.update(loss.item(), batch_size) - - # back prop - optimizer.zero_grad() - loss.backward() - optimizer.step() - - batch_time.update(time.time() - end) - end = time.time() - - if iteration % args.logging_steps == 0 or iteration == max_iter: - eta_seconds = batch_time.avg * (max_iter - iteration) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - logger.info( - ' '.join( - ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',] - ).format(eta=eta_string, ep=epoch, iter=iteration, - memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) - + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format( - log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg, - optimizer.param_groups[0]['lr']) - ) - - aml_run.log(name='Loss', value=float(log_losses.avg)) - aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg)) - aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg)) - aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg)) - - visual_imgs = visualize_mesh( renderer, - 
annotations['ori_img'].detach(), - annotations['joints_2d'].detach(), - pred_vertices.detach(), - pred_camera.detach(), - pred_2d_joints_from_smpl.detach()) - visual_imgs = visual_imgs.transpose(0,1) - visual_imgs = visual_imgs.transpose(1,2) - visual_imgs = np.asarray(visual_imgs) - - if is_main_process()==True: - stamp = str(epoch) + '_' + str(iteration) - temp_fname = args.output_dir + 'visual_' + stamp + '.jpg' - cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) - aml_run.log_image(name='visual results', path=temp_fname) - - if iteration % iters_per_epoch == 0: - val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader, - Graphormer_model, - criterion_keypoints, - criterion_vertices, - epoch, - smpl, - mesh_sampler) - aml_run.log(name='mPVE', value=float(1000*val_mPVE)) - aml_run.log(name='mPJPE', value=float(1000*val_mPJPE)) - aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE)) - logger.info( - ' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch) - + ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE, val_count) - ) - - if val_PAmPJPE0: - mPVE.update(np.mean(error_vertices), int(torch.sum(has_smpl)) ) - if len(error_joints)>0: - mPJPE.update(np.mean(error_joints), int(torch.sum(has_3d_joints)) ) - if len(error_joints_pa)>0: - PAmPJPE.update(np.mean(error_joints_pa), int(torch.sum(has_3d_joints)) ) - - val_mPVE = all_gather(float(mPVE.avg)) - val_mPVE = sum(val_mPVE)/len(val_mPVE) - val_mPJPE = all_gather(float(mPJPE.avg)) - val_mPJPE = sum(val_mPJPE)/len(val_mPJPE) - - val_PAmPJPE = all_gather(float(PAmPJPE.avg)) - val_PAmPJPE = sum(val_PAmPJPE)/len(val_PAmPJPE) - - val_count = all_gather(float(mPVE.count)) - val_count = sum(val_count) - - return val_mPVE, val_mPJPE, val_PAmPJPE, val_count - - -def visualize_mesh( renderer, - images, - gt_keypoints_2d, - pred_vertices, - pred_camera, - pred_keypoints_2d): - """Tensorboard logging.""" - gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() - to_lsp = list(range(14)) - rend_imgs = [] - batch_size = pred_vertices.shape[0] - # Do visualization for the first 6 images of the batch - for i in range(min(batch_size, 10)): - img = images[i].cpu().numpy().transpose(1,2,0) - # Get LSP keypoints from the full list of keypoints - gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] - pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] - # Get predict vertices for the particular example - vertices = pred_vertices[i].cpu().numpy() - cam = pred_camera[i].cpu().numpy() - # Visualize reconstruction and detected pose - rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer) - rend_img = rend_img.transpose(2,0,1) - rend_imgs.append(torch.from_numpy(rend_img)) - rend_imgs = make_grid(rend_imgs, nrow=1) - return rend_imgs - -def visualize_mesh_test( renderer, - images, - gt_keypoints_2d, - pred_vertices, - pred_camera, - pred_keypoints_2d, - PAmPJPE_h36m_j14): - """Tensorboard logging.""" - gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() - to_lsp = list(range(14)) - rend_imgs = [] - batch_size = pred_vertices.shape[0] - # Do visualization for the first 6 images of the batch - for i in range(min(batch_size, 10)): - img = images[i].cpu().numpy().transpose(1,2,0) - # Get LSP keypoints from the full list of keypoints - gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] - pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] - # Get predict vertices for the particular example - 
vertices = pred_vertices[i].cpu().numpy() - cam = pred_camera[i].cpu().numpy() - score = PAmPJPE_h36m_j14[i] - # Visualize reconstruction and detected pose - rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score) - rend_img = rend_img.transpose(2,0,1) - rend_imgs.append(torch.from_numpy(rend_img)) - rend_imgs = make_grid(rend_imgs, nrow=1) - return rend_imgs - - -def parse_args(): - parser = argparse.ArgumentParser() - ######################################################### - # Data related arguments - ######################################################### - parser.add_argument("--data_dir", default='datasets', type=str, required=False, - help="Directory with all datasets, each in one subfolder") - parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False, - help="Yaml file with all data for training.") - parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False, - help="Yaml file with all data for validation.") - parser.add_argument("--num_workers", default=4, type=int, - help="Workers in dataloader.") - parser.add_argument("--img_scale_factor", default=1, type=int, - help="adjust image resolution.") - ######################################################### - # Loading/saving checkpoints - ######################################################### - parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, - help="Path to pre-trained transformer model or model type.") - parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, - help="Path to specific checkpoint for resume training.") - parser.add_argument("--output_dir", default='output/', type=str, required=False, - help="The output directory to save checkpoint and test results.") - parser.add_argument("--config_name", default="", type=str, - help="Pretrained config name or path if not the same as model_name.") - ######################################################### - # Training parameters - ######################################################### - parser.add_argument("--per_gpu_train_batch_size", default=30, type=int, - help="Batch size per GPU/CPU for training.") - parser.add_argument("--per_gpu_eval_batch_size", default=30, type=int, - help="Batch size per GPU/CPU for evaluation.") - parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float, - help="The initial lr.") - parser.add_argument("--num_train_epochs", default=200, type=int, - help="Total number of training epochs to perform.") - parser.add_argument("--vertices_loss_weight", default=100.0, type=float) - parser.add_argument("--joints_loss_weight", default=1000.0, type=float) - parser.add_argument("--vloss_w_full", default=0.33, type=float) - parser.add_argument("--vloss_w_sub", default=0.33, type=float) - parser.add_argument("--vloss_w_sub2", default=0.33, type=float) - parser.add_argument("--drop_out", default=0.1, type=float, - help="Drop out ratio in BERT.") - ######################################################### - # Model architectures - ######################################################### - parser.add_argument('-a', '--arch', default='hrnet-w64', - help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') - parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, - help="Update model config if given") - parser.add_argument("--hidden_size", default=-1, type=int, required=False, - help="Update 
model config if given") - parser.add_argument("--num_attention_heads", default=4, type=int, required=False, - help="Update model config if given. Note that the division of " - "hidden_size / num_attention_heads should be in integer.") - parser.add_argument("--intermediate_size", default=-1, type=int, required=False, - help="Update model config if given.") - parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, - help="The Image Feature Dimension.") - parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, - help="The Image Feature Dimension.") - parser.add_argument("--which_gcn", default='0,0,1', type=str, - help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv") - parser.add_argument("--mesh_type", default='body', type=str, help="body or hand") - parser.add_argument("--interm_size_scale", default=2, type=int) - ######################################################### - # Others - ######################################################### - parser.add_argument("--run_eval_only", default=False, action='store_true',) - parser.add_argument('--logging_steps', type=int, default=1000, - help="Log every X steps.") - parser.add_argument("--device", type=str, default='cuda', - help="cuda or cpu") - parser.add_argument('--seed', type=int, default=88, - help="random seed for initialization.") - parser.add_argument("--local_rank", type=int, default=0, - help="For distributed training.") - - - args = parser.parse_args() - return args - - -def main(args): - global logger - # Setup CUDA, GPU & distributed training - args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 - os.environ['OMP_NUM_THREADS'] = str(args.num_workers) - print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) - - args.distributed = args.num_gpus > 1 - args.device = torch.device(args.device) - if args.distributed: - print("Init distributed training on local rank {} ({}), rank {}, world size {}".format(args.local_rank, int(os.environ["LOCAL_RANK"]), int(os.environ["NODE_RANK"]), args.num_gpus)) - torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group( - backend='nccl', init_method='env://' - ) - local_rank = int(os.environ["LOCAL_RANK"]) - args.device = torch.device("cuda", local_rank) - synchronize() - - mkdir(args.output_dir) - logger = setup_logger("Graphormer", args.output_dir, get_rank()) - set_seed(args.seed, args.num_gpus) - logger.info("Using {} GPUs".format(args.num_gpus)) - - # Mesh and SMPL utils - smpl = SMPL().to(args.device) - mesh_sampler = Mesh() - - # Renderer for visualization - renderer = Renderer(faces=smpl.faces.cpu().numpy()) - - # Load model - trans_encoder = [] - - input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] - hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] - output_feat_dim = input_feat_dim[1:] + [3] - - # which encoder block to have graph convs - which_blk_graph = [int(item) for item in args.which_gcn.split(',')] - - if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: - # if only run eval, load checkpoint - logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) - _model = torch.load(args.resume_checkpoint) - else: - # init three transformer-encoder blocks in a loop - for i in range(len(output_feat_dim)): - config_class, model_class = BertConfig, Graphormer - 
config = config_class.from_pretrained(args.config_name if args.config_name \ - else args.model_name_or_path) - - config.output_attentions = False - config.hidden_dropout_prob = args.drop_out - config.img_feature_dim = input_feat_dim[i] - config.output_feature_dim = output_feat_dim[i] - args.hidden_size = hidden_feat_dim[i] - args.intermediate_size = int(args.hidden_size*args.interm_size_scale) - - if which_blk_graph[i]==1: - config.graph_conv = True - logger.info("Add Graph Conv") - else: - config.graph_conv = False - - config.mesh_type = args.mesh_type - - # update model structure if specified in arguments - update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] - - for idx, param in enumerate(update_params): - arg_param = getattr(args, param) - config_param = getattr(config, param) - if arg_param > 0 and arg_param != config_param: - logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) - setattr(config, param, arg_param) - - # init a transformer encoder and append it to a list - assert config.hidden_size % config.num_attention_heads == 0 - model = model_class(config=config) - logger.info("Init model from scratch.") - trans_encoder.append(model) - - - # init ImageNet pre-trained backbone model - if args.arch=='hrnet': - hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' - hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' - hrnet_update_config(hrnet_config, hrnet_yaml) - backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) - logger.info('=> loading hrnet-v2-w40 model') - elif args.arch=='hrnet-w64': - hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' - hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' - hrnet_update_config(hrnet_config, hrnet_yaml) - backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) - logger.info('=> loading hrnet-v2-w64 model') - else: - print("=> using pre-trained model '{}'".format(args.arch)) - backbone = models.__dict__[args.arch](pretrained=True) - # remove the last fc layer - backbone = torch.nn.Sequential(*list(backbone.children())[:-2]) - - - trans_encoder = torch.nn.Sequential(*trans_encoder) - total_params = sum(p.numel() for p in trans_encoder.parameters()) - logger.info('Graphormer encoders total parameters: {}'.format(total_params)) - backbone_total_params = sum(p.numel() for p in backbone.parameters()) - logger.info('Backbone total parameters: {}'.format(backbone_total_params)) - - # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder) - _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler) - - if args.resume_checkpoint!=None and args.resume_checkpoint!='None': - # for fine-tuning or resume training or inference, load weights from checkpoint - logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) - # workaround approach to load sparse tensor in graph conv. 
- states = torch.load(args.resume_checkpoint) - # states = checkpoint_loaded.state_dict() - for k, v in states.items(): - states[k] = v.cpu() - # del checkpoint_loaded - _model.load_state_dict(states, strict=False) - del states - gc.collect() - torch.cuda.empty_cache() - - - _model.to(args.device) - logger.info("Training parameters %s", args) - - if args.run_eval_only==True: - val_dataloader = make_data_loader(args, args.val_yaml, - args.distributed, is_train=False, scale_factor=args.img_scale_factor) - run_eval_general(args, val_dataloader, _model, smpl, mesh_sampler) - - else: - train_dataloader = make_data_loader(args, args.train_yaml, - args.distributed, is_train=True, scale_factor=args.img_scale_factor) - val_dataloader = make_data_loader(args, args.val_yaml, - args.distributed, is_train=False, scale_factor=args.img_scale_factor) - run(args, train_dataloader, val_dataloader, _model, smpl, mesh_sampler, renderer) - - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/mesh_graphormer/tools/run_gphmer_bodymesh_inference.py b/mesh_graphormer/tools/run_gphmer_bodymesh_inference.py deleted file mode 100644 index 0ffc213..0000000 --- a/mesh_graphormer/tools/run_gphmer_bodymesh_inference.py +++ /dev/null @@ -1,351 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. - -End-to-end inference codes for -3D human body mesh reconstruction from an image -""" - -from __future__ import absolute_import, division, print_function -import argparse -import os -import os.path as op -import code -import json -import time -import datetime -import torch -import torchvision.models as models -from torchvision.utils import make_grid -import gc -import numpy as np -import cv2 -from mesh_graphormer.modeling.bert import BertConfig, Graphormer -from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network -from mesh_graphormer.modeling._smpl import SMPL, Mesh -from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat -from mesh_graphormer.modeling.hrnet.config import config as hrnet_config -from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config -import mesh_graphormer.modeling.data.config as cfg -from mesh_graphormer.datasets.build import make_data_loader - -from mesh_graphormer.utils.logger import setup_logger -from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather -from mesh_graphormer.utils.miscellaneous import mkdir, set_seed -from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger -from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text -from mesh_graphormer.utils.metric_pampjpe import reconstruction_error -from mesh_graphormer.utils.geometric_layers import orthographic_projection - -from PIL import Image -from torchvision import transforms - - -device = "cuda" - -transform = transforms.Compose([ - transforms.Resize(224), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - -transform_visualize = transforms.Compose([ - transforms.Resize(224), - transforms.CenterCrop(224), - transforms.ToTensor()]) - -def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler): - # switch to evaluate mode - Graphormer_model.eval() - smpl.eval() - with torch.no_grad(): - for image_file in image_list: - if 'pred' not in image_file: - 
att_all = [] - img = Image.open(image_file) - img_tensor = transform(img) - img_visual = transform_visualize(img) - - batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) - batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device) - # forward-pass - pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler) - - # obtain 3d joints from full mesh - pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) - - pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:] - pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] - pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :] - pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] - - # save attantion - att_max_value = att[-1] - att_cpu = np.asarray(att_max_value.cpu().detach()) - att_all.append(att_cpu) - - # obtain 3d joints, which are regressed from the full mesh - pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) - pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] - # obtain 2d joints, which are projected from 3d joints of smpl mesh - pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera) - pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera) - visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0], - pred_vertices[0].detach(), - pred_camera.detach()) - # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0], - # pred_vertices[0].detach(), - # pred_vertices_sub2[0].detach(), - # pred_2d_431_vertices_from_smpl[0].detach(), - # pred_2d_joints_from_smpl[0].detach(), - # pred_camera.detach(), - # att[-1][0].detach()) - - visual_imgs = visual_imgs_output.transpose(1,2,0) - visual_imgs = np.asarray(visual_imgs) - - temp_fname = image_file[:-4] + '_graphormer_pred.jpg' - print('save to ', temp_fname) - cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) - - return - -def visualize_mesh( renderer, images, - pred_vertices_full, - pred_camera): - img = images.cpu().numpy().transpose(1,2,0) - # Get predict vertices for the particular example - vertices_full = pred_vertices_full.cpu().numpy() - cam = pred_camera.cpu().numpy() - # Visualize only mesh reconstruction - rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue') - rend_img = rend_img.transpose(2,0,1) - return rend_img - -def visualize_mesh_and_attention( renderer, images, - pred_vertices_full, - pred_vertices, - pred_2d_vertices, - pred_2d_joints, - pred_camera, - attention): - img = images.cpu().numpy().transpose(1,2,0) - # Get predict vertices for the particular example - vertices_full = pred_vertices_full.cpu().numpy() - vertices = pred_vertices.cpu().numpy() - vertices_2d = pred_2d_vertices.cpu().numpy() - joints_2d = pred_2d_joints.cpu().numpy() - cam = pred_camera.cpu().numpy() - att = attention.cpu().numpy() - # Visualize reconstruction and attention - rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue') - rend_img = rend_img.transpose(2,0,1) - return rend_img - - -def parse_args(): - parser = argparse.ArgumentParser() - ######################################################### - # Data related arguments - ######################################################### - parser.add_argument("--num_workers", 
default=4, type=int, - help="Workers in dataloader.") - parser.add_argument("--img_scale_factor", default=1, type=int, - help="adjust image resolution.") - parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str, - help="test data") - ######################################################### - # Loading/saving checkpoints - ######################################################### - parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, - help="Path to pre-trained transformer model or model type.") - parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, - help="Path to specific checkpoint for resume training.") - parser.add_argument("--output_dir", default='output/', type=str, required=False, - help="The output directory to save checkpoint and test results.") - parser.add_argument("--config_name", default="", type=str, - help="Pretrained config name or path if not the same as model_name.") - ######################################################### - # Model architectures - ######################################################### - parser.add_argument('-a', '--arch', default='hrnet-w64', - help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') - parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, - help="Update model config if given") - parser.add_argument("--hidden_size", default=-1, type=int, required=False, - help="Update model config if given") - parser.add_argument("--num_attention_heads", default=4, type=int, required=False, - help="Update model config if given. Note that the division of " - "hidden_size / num_attention_heads should be in integer.") - parser.add_argument("--intermediate_size", default=-1, type=int, required=False, - help="Update model config if given.") - parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, - help="The Image Feature Dimension.") - parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, - help="The Image Feature Dimension.") - parser.add_argument("--which_gcn", default='0,0,1', type=str, - help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. 
Default: only Encoder3 has graph conv") - parser.add_argument("--mesh_type", default='body', type=str, help="body or hand") - parser.add_argument("--interm_size_scale", default=2, type=int) - ######################################################### - # Others - ######################################################### - parser.add_argument("--run_eval_only", default=True, action='store_true',) - parser.add_argument("--device", type=str, default='cuda', - help="cuda or cpu") - parser.add_argument('--seed', type=int, default=88, - help="random seed for initialization.") - - args = parser.parse_args() - return args - - -def main(args): - global logger - # Setup CUDA, GPU & distributed training - args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 - os.environ['OMP_NUM_THREADS'] = str(args.num_workers) - print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) - - args.distributed = args.num_gpus > 1 - args.device = torch.device(args.device) - - mkdir(args.output_dir) - logger = setup_logger("Graphormer", args.output_dir, get_rank()) - set_seed(args.seed, args.num_gpus) - logger.info("Using {} GPUs".format(args.num_gpus)) - - # Mesh and SMPL utils - smpl = SMPL().to(args.device) - mesh_sampler = Mesh() - - # Renderer for visualization - renderer = Renderer(faces=smpl.faces.cpu().numpy()) - - # Load model - trans_encoder = [] - - input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] - hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] - output_feat_dim = input_feat_dim[1:] + [3] - - # which encoder block to have graph convs - which_blk_graph = [int(item) for item in args.which_gcn.split(',')] - - if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: - # if only run eval, load checkpoint - logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) - _model = torch.load(args.resume_checkpoint) - else: - # init three transformer-encoder blocks in a loop - for i in range(len(output_feat_dim)): - config_class, model_class = BertConfig, Graphormer - config = config_class.from_pretrained(args.config_name if args.config_name \ - else args.model_name_or_path) - - config.output_attentions = False - config.img_feature_dim = input_feat_dim[i] - config.output_feature_dim = output_feat_dim[i] - args.hidden_size = hidden_feat_dim[i] - args.intermediate_size = int(args.hidden_size*args.interm_size_scale) - - if which_blk_graph[i]==1: - config.graph_conv = True - logger.info("Add Graph Conv") - else: - config.graph_conv = False - - config.mesh_type = args.mesh_type - - # update model structure if specified in arguments - update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] - - for idx, param in enumerate(update_params): - arg_param = getattr(args, param) - config_param = getattr(config, param) - if arg_param > 0 and arg_param != config_param: - logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) - setattr(config, param, arg_param) - - # init a transformer encoder and append it to a list - assert config.hidden_size % config.num_attention_heads == 0 - model = model_class(config=config) - logger.info("Init model from scratch.") - trans_encoder.append(model) - - # init ImageNet pre-trained backbone model - if args.arch=='hrnet': - hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' - 
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' - hrnet_update_config(hrnet_config, hrnet_yaml) - backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) - logger.info('=> loading hrnet-v2-w40 model') - elif args.arch=='hrnet-w64': - hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' - hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' - hrnet_update_config(hrnet_config, hrnet_yaml) - backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) - logger.info('=> loading hrnet-v2-w64 model') - else: - print("=> using pre-trained model '{}'".format(args.arch)) - backbone = models.__dict__[args.arch](pretrained=True) - # remove the last fc layer - backbone = torch.nn.Sequential(*list(backbone.children())[:-2]) - - - trans_encoder = torch.nn.Sequential(*trans_encoder) - total_params = sum(p.numel() for p in trans_encoder.parameters()) - logger.info('Graphormer encoders total parameters: {}'.format(total_params)) - backbone_total_params = sum(p.numel() for p in backbone.parameters()) - logger.info('Backbone total parameters: {}'.format(backbone_total_params)) - - # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder) - _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler) - - if args.resume_checkpoint!=None and args.resume_checkpoint!='None': - # for fine-tuning or resume training or inference, load weights from checkpoint - logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) - # workaround approach to load sparse tensor in graph conv. - states = torch.load(args.resume_checkpoint) - # states = checkpoint_loaded.state_dict() - for k, v in states.items(): - states[k] = v.cpu() - # del checkpoint_loaded - _model.load_state_dict(states, strict=False) - del states - gc.collect() - torch.cuda.empty_cache() - - # update configs to enable attention outputs - setattr(_model.trans_encoder[-1].config,'output_attentions', True) - setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) - _model.trans_encoder[-1].bert.encoder.output_attentions = True - _model.trans_encoder[-1].bert.encoder.output_hidden_states = True - for iter_layer in range(4): - _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True - for inter_block in range(3): - setattr(_model.trans_encoder[-1].config,'device', args.device) - - _model.to(args.device) - logger.info("Run inference") - - image_list = [] - if not args.image_file_or_path: - raise ValueError("image_file_or_path not specified") - if op.isfile(args.image_file_or_path): - image_list = [args.image_file_or_path] - elif op.isdir(args.image_file_or_path): - # should be a path with images only - for filename in os.listdir(args.image_file_or_path): - if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename: - image_list.append(args.image_file_or_path+'/'+filename) - else: - raise ValueError("Cannot find images at {}".format(args.image_file_or_path)) - - run_inference(args, image_list, _model, smpl, renderer, mesh_sampler) - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/mesh_graphormer/tools/run_gphmer_handmesh.py b/mesh_graphormer/tools/run_gphmer_handmesh.py deleted file mode 100644 index 4cfcf55..0000000 --- a/mesh_graphormer/tools/run_gphmer_handmesh.py +++ /dev/null @@ -1,713 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. 
- -Training and evaluation codes for -3D hand mesh reconstruction from an image -""" - -from __future__ import absolute_import, division, print_function -import argparse -import os -import os.path as op -import code -import json -import time -import datetime -import torch -import torchvision.models as models -from torchvision.utils import make_grid -import gc -import numpy as np -import cv2 -from mesh_graphormer.modeling.bert import BertConfig, Graphormer -from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network -from mesh_graphormer.modeling._mano import MANO, Mesh -from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat -from mesh_graphormer.modeling.hrnet.config import config as hrnet_config -from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config -import mesh_graphormer.modeling.data.config as cfg -from mesh_graphormer.datasets.build import make_hand_data_loader - -from mesh_graphormer.utils.logger import setup_logger -from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather -from mesh_graphormer.utils.miscellaneous import mkdir, set_seed -from mesh_graphormer.utils.metric_logger import AverageMeter -from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test, visualize_reconstruction_no_text -from mesh_graphormer.utils.metric_pampjpe import reconstruction_error -from mesh_graphormer.utils.geometric_layers import orthographic_projection - - -device = "cuda" - -from azureml.core.run import Run -aml_run = Run.get_context() - -def save_checkpoint(model, args, epoch, iteration, num_trial=10): - checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format( - epoch, iteration)) - if not is_main_process(): - return checkpoint_dir - mkdir(checkpoint_dir) - model_to_save = model.module if hasattr(model, 'module') else model - for i in range(num_trial): - try: - torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin')) - torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin')) - torch.save(args, op.join(checkpoint_dir, 'training_args.bin')) - logger.info("Save checkpoint to {}".format(checkpoint_dir)) - break - except: - pass - else: - logger.info("Failed to save checkpoint after {} trails.".format(num_trial)) - return checkpoint_dir - -def adjust_learning_rate(optimizer, epoch, args): - """ - Sets the learning rate to the initial LR decayed by x every y epochs - x = 0.1, y = args.num_train_epochs/2.0 = 100 - """ - lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) )) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - -def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d): - """ - Compute 2D reprojection loss if 2D keypoint annotations are available. - The confidence is binary and indicates whether the keypoints exist or not. - """ - conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() - loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean() - return loss - -def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d): - """ - Compute 3D keypoint loss if 3D keypoint annotations are available. 
- """ - conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() - gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone() - gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1] - conf = conf[has_pose_3d == 1] - pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1] - if len(gt_keypoints_3d) > 0: - gt_root = gt_keypoints_3d[:, 0,:] - gt_keypoints_3d = gt_keypoints_3d - gt_root[:, None, :] - pred_root = pred_keypoints_3d[:, 0,:] - pred_keypoints_3d = pred_keypoints_3d - pred_root[:, None, :] - return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean() - else: - return torch.FloatTensor(1).fill_(0.).to(device) - -def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl): - """ - Compute per-vertex loss if vertex annotations are available. - """ - pred_vertices_with_shape = pred_vertices[has_smpl == 1] - gt_vertices_with_shape = gt_vertices[has_smpl == 1] - if len(gt_vertices_with_shape) > 0: - return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape) - else: - return torch.FloatTensor(1).fill_(0.).to(device) - - -def run(args, train_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler): - - max_iter = len(train_dataloader) - iters_per_epoch = max_iter // args.num_train_epochs - - optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()), - lr=args.lr, - betas=(0.9, 0.999), - weight_decay=0) - - # define loss function (criterion) and optimizer - criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device) - criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) - criterion_vertices = torch.nn.L1Loss().to(device) - - if args.distributed: - Graphormer_model = torch.nn.parallel.DistributedDataParallel( - Graphormer_model, device_ids=[args.local_rank], - output_device=args.local_rank, - find_unused_parameters=True, - ) - - start_training_time = time.time() - end = time.time() - Graphormer_model.train() - batch_time = AverageMeter() - data_time = AverageMeter() - log_losses = AverageMeter() - log_loss_2djoints = AverageMeter() - log_loss_3djoints = AverageMeter() - log_loss_vertices = AverageMeter() - - for iteration, (img_keys, images, annotations) in enumerate(train_dataloader): - - Graphormer_model.train() - iteration += 1 - epoch = iteration // iters_per_epoch - batch_size = images.size(0) - adjust_learning_rate(optimizer, epoch, args) - data_time.update(time.time() - end) - - images = images.to(device) - gt_2d_joints = annotations['joints_2d'].to(device) - gt_pose = annotations['pose'].to(device) - gt_betas = annotations['betas'].to(device) - has_mesh = annotations['has_smpl'].to(device) - has_3d_joints = has_mesh - has_2d_joints = has_mesh - mjm_mask = annotations['mjm_mask'].to(device) - mvm_mask = annotations['mvm_mask'].to(device) - - # generate mesh - gt_vertices, gt_3d_joints = mano_model.layer(gt_pose, gt_betas) - gt_vertices = gt_vertices/1000.0 - gt_3d_joints = gt_3d_joints/1000.0 - - gt_vertices_sub = mesh_sampler.downsample(gt_vertices) - # normalize gt based on hand's wrist - gt_3d_root = gt_3d_joints[:,cfg.J_NAME.index('Wrist'),:] - gt_vertices = gt_vertices - gt_3d_root[:, None, :] - gt_vertices_sub = gt_vertices_sub - gt_3d_root[:, None, :] - gt_3d_joints = gt_3d_joints - gt_3d_root[:, None, :] - gt_3d_joints_with_tag = torch.ones((batch_size,gt_3d_joints.shape[1],4)).to(device) - gt_3d_joints_with_tag[:,:,:3] = gt_3d_joints - - # prepare masks for mask vertex/joint modeling - mjm_mask_ = mjm_mask.expand(-1,-1,2051) - mvm_mask_ = mvm_mask.expand(-1,-1,2051) - meta_masks = 
torch.cat([mjm_mask_, mvm_mask_], dim=1) - - # forward-pass - pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler, meta_masks=meta_masks, is_train=True) - - # obtain 3d joints, which are regressed from the full mesh - pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) - - # obtain 2d joints, which are projected from 3d joints of smpl mesh - pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) - pred_2d_joints = orthographic_projection(pred_3d_joints.contiguous(), pred_camera.contiguous()) - - # compute 3d joint loss (where the joints are directly output from transformer) - loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints_with_tag, has_3d_joints) - - # compute 3d vertex loss - loss_vertices = ( args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_mesh) + \ - args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_mesh) ) - - # compute 3d joint loss (where the joints are regressed from full mesh) - loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_mesh, gt_3d_joints_with_tag, has_3d_joints) - # compute 2d joint loss - loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \ - keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_mesh, gt_2d_joints, has_2d_joints) - - loss_3d_joints = loss_3d_joints + loss_reg_3d_joints - - # we empirically use hyperparameters to balance difference losses - loss = args.joints_loss_weight*loss_3d_joints + \ - args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints - - # update logs - log_loss_2djoints.update(loss_2d_joints.item(), batch_size) - log_loss_3djoints.update(loss_3d_joints.item(), batch_size) - log_loss_vertices.update(loss_vertices.item(), batch_size) - log_losses.update(loss.item(), batch_size) - - # back prop - optimizer.zero_grad() - loss.backward() - optimizer.step() - - batch_time.update(time.time() - end) - end = time.time() - - if iteration % args.logging_steps == 0 or iteration == max_iter: - eta_seconds = batch_time.avg * (max_iter - iteration) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - logger.info( - ' '.join( - ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',] - ).format(eta=eta_string, ep=epoch, iter=iteration, - memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) - + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format( - log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg, - optimizer.param_groups[0]['lr']) - ) - - aml_run.log(name='Loss', value=float(log_losses.avg)) - aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg)) - aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg)) - aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg)) - - visual_imgs = visualize_mesh( renderer, - annotations['ori_img'].detach(), - annotations['joints_2d'].detach(), - pred_vertices.detach(), - pred_camera.detach(), - pred_2d_joints_from_mesh.detach()) - visual_imgs = visual_imgs.transpose(0,1) - visual_imgs = visual_imgs.transpose(1,2) - visual_imgs = np.asarray(visual_imgs) - - if is_main_process()==True: - stamp = str(epoch) + '_' + str(iteration) - temp_fname 
= args.output_dir + 'visual_' + stamp + '.jpg' - cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) - aml_run.log_image(name='visual results', path=temp_fname) - - if iteration % iters_per_epoch == 0: - if epoch%10==0: - checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration) - - total_training_time = time.time() - start_training_time - total_time_str = str(datetime.timedelta(seconds=total_training_time)) - logger.info('Total training time: {} ({:.4f} s / iter)'.format( - total_time_str, total_training_time / max_iter) - ) - checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration) - -def run_eval_and_save(args, split, val_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler): - - criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) - criterion_vertices = torch.nn.L1Loss().to(device) - - if args.distributed: - Graphormer_model = torch.nn.parallel.DistributedDataParallel( - Graphormer_model, device_ids=[args.local_rank], - output_device=args.local_rank, - find_unused_parameters=True, - ) - Graphormer_model.eval() - - if args.aml_eval==True: - run_aml_inference_hand_mesh(args, val_dataloader, - Graphormer_model, - criterion_keypoints, - criterion_vertices, - 0, - mano_model, mesh_sampler, - renderer, split) - else: - run_inference_hand_mesh(args, val_dataloader, - Graphormer_model, - criterion_keypoints, - criterion_vertices, - 0, - mano_model, mesh_sampler, - renderer, split) - checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0) - return - -def run_aml_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split): - # switch to evaluate mode - Graphormer_model.eval() - fname_output_save = [] - mesh_output_save = [] - joint_output_save = [] - world_size = get_world_size() - with torch.no_grad(): - for i, (img_keys, images, annotations) in enumerate(val_loader): - batch_size = images.size(0) - # compute output - images = images.to(device) - - # forward-pass - pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler) - # obtain 3d joints from full mesh - pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) - - for j in range(batch_size): - fname_output_save.append(img_keys[j]) - pred_vertices_list = pred_vertices[j].tolist() - mesh_output_save.append(pred_vertices_list) - pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist() - joint_output_save.append(pred_3d_joints_from_mesh_list) - - if world_size > 1: - torch.distributed.barrier() - print('save results to pred.json') - output_json_file = 'pred.json' - print('save results to ', output_json_file) - with open(output_json_file, 'w') as f: - json.dump([joint_output_save, mesh_output_save], f) - - azure_ckpt_name = '200' # args.resume_checkpoint.split('/')[-2].split('-')[1] - inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) - output_zip_file = args.output_dir + 'ckpt' + azure_ckpt_name + '-' + inference_setting +'-pred.zip' - - resolved_submit_cmd = 'zip ' + output_zip_file + ' ' + output_json_file - print(resolved_submit_cmd) - os.system(resolved_submit_cmd) - resolved_submit_cmd = 'rm %s'%(output_json_file) - print(resolved_submit_cmd) - os.system(resolved_submit_cmd) - if world_size > 1: - torch.distributed.barrier() - - return - -def run_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split): - # 
switch to evaluate mode - Graphormer_model.eval() - fname_output_save = [] - mesh_output_save = [] - joint_output_save = [] - with torch.no_grad(): - for i, (img_keys, images, annotations) in enumerate(val_loader): - batch_size = images.size(0) - # compute output - images = images.to(device) - - # forward-pass - pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler) - - # obtain 3d joints from full mesh - pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) - pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:] - pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :] - pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] - - for j in range(batch_size): - fname_output_save.append(img_keys[j]) - pred_vertices_list = pred_vertices[j].tolist() - mesh_output_save.append(pred_vertices_list) - pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist() - joint_output_save.append(pred_3d_joints_from_mesh_list) - - if i%20==0: - # obtain 3d joints, which are regressed from the full mesh - pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) - # obtain 2d joints, which are projected from 3d joints of mesh - pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) - visual_imgs = visualize_mesh( renderer, - annotations['ori_img'].detach(), - annotations['joints_2d'].detach(), - pred_vertices.detach(), - pred_camera.detach(), - pred_2d_joints_from_mesh.detach()) - - visual_imgs = visual_imgs.transpose(0,1) - visual_imgs = visual_imgs.transpose(1,2) - visual_imgs = np.asarray(visual_imgs) - - inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) - temp_fname = args.output_dir + args.resume_checkpoint[0:-9] + 'freihand_results_'+inference_setting+'_batch'+str(i)+'.jpg' - cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) - - print('save results to pred.json') - with open('pred.json', 'w') as f: - json.dump([joint_output_save, mesh_output_save], f) - - run_exp_name = args.resume_checkpoint.split('/')[-3] - run_ckpt_name = args.resume_checkpoint.split('/')[-2].split('-')[1] - inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) - resolved_submit_cmd = 'zip ' + args.output_dir + run_exp_name + '-ckpt'+ run_ckpt_name + '-' + inference_setting +'-pred.zip ' + 'pred.json' - print(resolved_submit_cmd) - os.system(resolved_submit_cmd) - resolved_submit_cmd = 'rm pred.json' - print(resolved_submit_cmd) - os.system(resolved_submit_cmd) - return - -def visualize_mesh( renderer, - images, - gt_keypoints_2d, - pred_vertices, - pred_camera, - pred_keypoints_2d): - """Tensorboard logging.""" - gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() - to_lsp = list(range(21)) - rend_imgs = [] - batch_size = pred_vertices.shape[0] - # Do visualization for the first 6 images of the batch - for i in range(min(batch_size, 10)): - img = images[i].cpu().numpy().transpose(1,2,0) - # Get LSP keypoints from the full list of keypoints - gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] - pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] - # Get predict vertices for the particular example - vertices = pred_vertices[i].cpu().numpy() - cam = pred_camera[i].cpu().numpy() - # Visualize reconstruction and detected pose - rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer) - rend_img = rend_img.transpose(2,0,1) - 
rend_imgs.append(torch.from_numpy(rend_img)) - rend_imgs = make_grid(rend_imgs, nrow=1) - return rend_imgs - -def visualize_mesh_test( renderer, - images, - gt_keypoints_2d, - pred_vertices, - pred_camera, - pred_keypoints_2d, - PAmPJPE): - """Tensorboard logging.""" - gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() - to_lsp = list(range(21)) - rend_imgs = [] - batch_size = pred_vertices.shape[0] - # Do visualization for the first 6 images of the batch - for i in range(min(batch_size, 10)): - img = images[i].cpu().numpy().transpose(1,2,0) - # Get LSP keypoints from the full list of keypoints - gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] - pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] - # Get predict vertices for the particular example - vertices = pred_vertices[i].cpu().numpy() - cam = pred_camera[i].cpu().numpy() - score = PAmPJPE[i] - # Visualize reconstruction and detected pose - rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score) - rend_img = rend_img.transpose(2,0,1) - rend_imgs.append(torch.from_numpy(rend_img)) - rend_imgs = make_grid(rend_imgs, nrow=1) - return rend_imgs - -def visualize_mesh_no_text( renderer, - images, - pred_vertices, - pred_camera): - """Tensorboard logging.""" - rend_imgs = [] - batch_size = pred_vertices.shape[0] - # Do visualization for the first 6 images of the batch - for i in range(min(batch_size, 1)): - img = images[i].cpu().numpy().transpose(1,2,0) - # Get predict vertices for the particular example - vertices = pred_vertices[i].cpu().numpy() - cam = pred_camera[i].cpu().numpy() - # Visualize reconstruction only - rend_img = visualize_reconstruction_no_text(img, 224, vertices, cam, renderer, color='hand') - rend_img = rend_img.transpose(2,0,1) - rend_imgs.append(torch.from_numpy(rend_img)) - rend_imgs = make_grid(rend_imgs, nrow=1) - return rend_imgs - -def parse_args(): - parser = argparse.ArgumentParser() - ######################################################### - # Data related arguments - ######################################################### - parser.add_argument("--data_dir", default='datasets', type=str, required=False, - help="Directory with all datasets, each in one subfolder") - parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False, - help="Yaml file with all data for training.") - parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False, - help="Yaml file with all data for validation.") - parser.add_argument("--num_workers", default=4, type=int, - help="Workers in dataloader.") - parser.add_argument("--img_scale_factor", default=1, type=int, - help="adjust image resolution.") - ######################################################### - # Loading/saving checkpoints - ######################################################### - parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, - help="Path to pre-trained transformer model or model type.") - parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, - help="Path to specific checkpoint for resume training.") - parser.add_argument("--output_dir", default='output/', type=str, required=False, - help="The output directory to save checkpoint and test results.") - parser.add_argument("--config_name", default="", type=str, - help="Pretrained config name or path if not the same as model_name.") - parser.add_argument('-a', '--arch', 
default='hrnet-w64',
-                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
-    #########################################################
-    # Training parameters
-    #########################################################
-    parser.add_argument("--per_gpu_train_batch_size", default=64, type=int,
-                        help="Batch size per GPU/CPU for training.")
-    parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int,
-                        help="Batch size per GPU/CPU for evaluation.")
-    parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
-                        help="The initial learning rate.")
-    parser.add_argument("--num_train_epochs", default=200, type=int,
-                        help="Total number of training epochs to perform.")
-    parser.add_argument("--vertices_loss_weight", default=1.0, type=float)
-    parser.add_argument("--joints_loss_weight", default=1.0, type=float)
-    parser.add_argument("--vloss_w_full", default=0.5, type=float)
-    parser.add_argument("--vloss_w_sub", default=0.5, type=float)
-    parser.add_argument("--drop_out", default=0.1, type=float,
-                        help="Dropout ratio in BERT.")
-    #########################################################
-    # Model architectures
-    #########################################################
-    parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False,
-                        help="Update model config if given")
-    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
-                        help="Update model config if given")
-    parser.add_argument("--num_attention_heads", default=-1, type=int, required=False,
-                        help="Update model config if given. Note that hidden_size "
-                             "must be divisible by num_attention_heads.")
-    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
-                        help="Update model config if given.")
-    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
-                        help="The image feature dimension of each encoder block.")
-    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
-                        help="The hidden feature dimension of each encoder block.")
-    parser.add_argument("--which_gcn", default='0,0,1', type=str,
-                        help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. 
Default: only Encoder3 has graph conv") - parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand") - - ######################################################### - # Others - ######################################################### - parser.add_argument("--run_eval_only", default=False, action='store_true',) - parser.add_argument("--multiscale_inference", default=False, action='store_true',) - # if enable "multiscale_inference", dataloader will apply transformations to the test image based on - # the rotation "rot" and scale "sc" parameters below - parser.add_argument("--rot", default=0, type=float) - parser.add_argument("--sc", default=1.0, type=float) - parser.add_argument("--aml_eval", default=False, action='store_true',) - - parser.add_argument('--logging_steps', type=int, default=100, - help="Log every X steps.") - parser.add_argument("--device", type=str, default='cuda', - help="cuda or cpu") - parser.add_argument('--seed', type=int, default=88, - help="random seed for initialization.") - parser.add_argument("--local_rank", type=int, default=0, - help="For distributed training.") - args = parser.parse_args() - return args - -def main(args): - global logger - # Setup CUDA, GPU & distributed training - args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 - os.environ['OMP_NUM_THREADS'] = str(args.num_workers) - print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) - - args.distributed = args.num_gpus > 1 - args.device = torch.device(args.device) - if args.distributed: - print("Init distributed training on local rank {}".format(args.local_rank)) - torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group( - backend='nccl', init_method='env://' - ) - synchronize() - - mkdir(args.output_dir) - logger = setup_logger("Graphormer", args.output_dir, get_rank()) - set_seed(args.seed, args.num_gpus) - logger.info("Using {} GPUs".format(args.num_gpus)) - - # Mesh and SMPL utils - mano_model = MANO().to(args.device) - mano_model.layer = mano_model.layer.to(device) - mesh_sampler = Mesh() - - # Renderer for visualization - renderer = Renderer(faces=mano_model.face) - - # Load pretrained model - trans_encoder = [] - - input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] - hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] - output_feat_dim = input_feat_dim[1:] + [3] - - # which encoder block to have graph convs - which_blk_graph = [int(item) for item in args.which_gcn.split(',')] - - if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: - # if only run eval, load checkpoint - logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) - _model = torch.load(args.resume_checkpoint) - - else: - # init three transformer-encoder blocks in a loop - for i in range(len(output_feat_dim)): - config_class, model_class = BertConfig, Graphormer - config = config_class.from_pretrained(args.config_name if args.config_name \ - else args.model_name_or_path) - - config.output_attentions = False - config.hidden_dropout_prob = args.drop_out - config.img_feature_dim = input_feat_dim[i] - config.output_feature_dim = output_feat_dim[i] - args.hidden_size = hidden_feat_dim[i] - args.intermediate_size = int(args.hidden_size*2) - - if which_blk_graph[i]==1: - config.graph_conv = True - logger.info("Add Graph Conv") - else: - config.graph_conv = False - - 
config.mesh_type = args.mesh_type - - # update model structure if specified in arguments - update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] - for idx, param in enumerate(update_params): - arg_param = getattr(args, param) - config_param = getattr(config, param) - if arg_param > 0 and arg_param != config_param: - logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) - setattr(config, param, arg_param) - - # init a transformer encoder and append it to a list - assert config.hidden_size % config.num_attention_heads == 0 - model = model_class(config=config) - logger.info("Init model from scratch.") - trans_encoder.append(model) - - # create backbone model - if args.arch=='hrnet': - hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' - hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' - hrnet_update_config(hrnet_config, hrnet_yaml) - backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) - logger.info('=> loading hrnet-v2-w40 model') - elif args.arch=='hrnet-w64': - hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' - hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' - hrnet_update_config(hrnet_config, hrnet_yaml) - backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) - logger.info('=> loading hrnet-v2-w64 model') - else: - print("=> using pre-trained model '{}'".format(args.arch)) - backbone = models.__dict__[args.arch](pretrained=True) - # remove the last fc layer - backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) - - trans_encoder = torch.nn.Sequential(*trans_encoder) - total_params = sum(p.numel() for p in trans_encoder.parameters()) - logger.info('Graphormer encoders total parameters: {}'.format(total_params)) - backbone_total_params = sum(p.numel() for p in backbone.parameters()) - logger.info('Backbone total parameters: {}'.format(backbone_total_params)) - - # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) - _model = Graphormer_Network(args, config, backbone, trans_encoder) - - if args.resume_checkpoint!=None and args.resume_checkpoint!='None': - # for fine-tuning or resume training or inference, load weights from checkpoint - logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) - # workaround approach to load sparse tensor in graph conv. 
- state_dict = torch.load(args.resume_checkpoint) - _model.load_state_dict(state_dict, strict=False) - del state_dict - gc.collect() - torch.cuda.empty_cache() - - _model.to(args.device) - logger.info("Training parameters %s", args) - - if args.run_eval_only==True: - val_dataloader = make_hand_data_loader(args, args.val_yaml, - args.distributed, is_train=False, scale_factor=args.img_scale_factor) - run_eval_and_save(args, 'freihand', val_dataloader, _model, mano_model, renderer, mesh_sampler) - - else: - train_dataloader = make_hand_data_loader(args, args.train_yaml, - args.distributed, is_train=True, scale_factor=args.img_scale_factor) - run(args, train_dataloader, _model, mano_model, renderer, mesh_sampler) - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/mesh_graphormer/tools/run_gphmer_handmesh_inference.py b/mesh_graphormer/tools/run_gphmer_handmesh_inference.py deleted file mode 100644 index 04cd326..0000000 --- a/mesh_graphormer/tools/run_gphmer_handmesh_inference.py +++ /dev/null @@ -1,338 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. - -End-to-end inference codes for -3D hand mesh reconstruction from an image -""" - -from __future__ import absolute_import, division, print_function -import argparse -import os -import os.path as op -import code -import json -import time -import datetime -import torch -import torchvision.models as models -from torchvision.utils import make_grid -import gc -import numpy as np -import cv2 -from mesh_graphormer.modeling.bert import BertConfig, Graphormer -from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network -from mesh_graphormer.modeling._mano import MANO, Mesh -from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat -from mesh_graphormer.modeling.hrnet.config import config as hrnet_config -from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config -import mesh_graphormer.modeling.data.config as cfg -from mesh_graphormer.datasets.build import make_hand_data_loader - -from mesh_graphormer.utils.logger import setup_logger -from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather -from mesh_graphormer.utils.miscellaneous import mkdir, set_seed -from mesh_graphormer.utils.metric_logger import AverageMeter -from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text -from mesh_graphormer.utils.metric_pampjpe import reconstruction_error -from mesh_graphormer.utils.geometric_layers import orthographic_projection - -from PIL import Image -from torchvision import transforms - - -device = "cuda" - -transform = transforms.Compose([ - transforms.Resize(224), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225])]) - -transform_visualize = transforms.Compose([ - transforms.Resize(224), - transforms.CenterCrop(224), - transforms.ToTensor()]) - -def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler): -# switch to evaluate mode - Graphormer_model.eval() - mano.eval() - with torch.no_grad(): - for image_file in image_list: - if 'pred' not in image_file: - att_all = [] - print(image_file) - img = Image.open(image_file) - img_tensor = transform(img) - img_visual = transform_visualize(img) - - batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) - batch_visual_imgs = 
torch.unsqueeze(img_visual, 0).to(device) - # forward-pass - pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) - # obtain 3d joints from full mesh - pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) - pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:] - pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :] - pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] - - # save attantion - att_max_value = att[-1] - att_cpu = np.asarray(att_max_value.cpu().detach()) - att_all.append(att_cpu) - - # obtain 3d joints, which are regressed from the full mesh - pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) - # obtain 2d joints, which are projected from 3d joints of mesh - pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) - pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous()) - - - visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0], - pred_vertices[0].detach(), - pred_camera.detach()) - # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0], - # pred_vertices[0].detach(), - # pred_vertices_sub[0].detach(), - # pred_2d_coarse_vertices_from_mesh[0].detach(), - # pred_2d_joints_from_mesh[0].detach(), - # pred_camera.detach(), - # att[-1][0].detach()) - visual_imgs = visual_imgs_output.transpose(1,2,0) - visual_imgs = np.asarray(visual_imgs) - - temp_fname = image_file[:-4] + '_graphormer_pred.jpg' - print('save to ', temp_fname) - cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) - return - -def visualize_mesh( renderer, images, - pred_vertices_full, - pred_camera): - img = images.cpu().numpy().transpose(1,2,0) - # Get predict vertices for the particular example - vertices_full = pred_vertices_full.cpu().numpy() - cam = pred_camera.cpu().numpy() - # Visualize only mesh reconstruction - rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue') - rend_img = rend_img.transpose(2,0,1) - return rend_img - -def visualize_mesh_and_attention( renderer, images, - pred_vertices_full, - pred_vertices, - pred_2d_vertices, - pred_2d_joints, - pred_camera, - attention): - img = images.cpu().numpy().transpose(1,2,0) - # Get predict vertices for the particular example - vertices_full = pred_vertices_full.cpu().numpy() - vertices = pred_vertices.cpu().numpy() - vertices_2d = pred_2d_vertices.cpu().numpy() - joints_2d = pred_2d_joints.cpu().numpy() - cam = pred_camera.cpu().numpy() - att = attention.cpu().numpy() - # Visualize reconstruction and attention - rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue') - rend_img = rend_img.transpose(2,0,1) - return rend_img - -def parse_args(): - parser = argparse.ArgumentParser() - ######################################################### - # Data related arguments - ######################################################### - parser.add_argument("--num_workers", default=4, type=int, - help="Workers in dataloader.") - parser.add_argument("--img_scale_factor", default=1, type=int, - help="adjust image resolution.") - parser.add_argument("--image_file_or_path", default='./samples/hand', type=str, - help="test data") - ######################################################### - # Loading/saving 
checkpoints
-    #########################################################
-    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
-                        help="Path to pre-trained transformer model or model type.")
-    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
-                        help="Path to specific checkpoint for resume training.")
-    parser.add_argument("--output_dir", default='output/', type=str, required=False,
-                        help="The output directory to save checkpoint and test results.")
-    parser.add_argument("--config_name", default="", type=str,
-                        help="Pretrained config name or path if not the same as model_name.")
-    parser.add_argument('-a', '--arch', default='hrnet-w64',
-                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
-    #########################################################
-    # Model architectures
-    #########################################################
-    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
-                        help="Update model config if given")
-    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
-                        help="Update model config if given")
-    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
-                        help="Update model config if given. Note that hidden_size "
-                             "must be divisible by num_attention_heads.")
-    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
-                        help="Update model config if given.")
-    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
-                        help="The image feature dimension of each encoder block.")
-    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
-                        help="The hidden feature dimension of each encoder block.")
-    parser.add_argument("--which_gcn", default='0,0,1', type=str,
-                        help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. 
Default: only Encoder3 has graph conv") - parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand") - - ######################################################### - # Others - ######################################################### - parser.add_argument("--run_eval_only", default=True, action='store_true',) - parser.add_argument("--device", type=str, default='cuda', - help="cuda or cpu") - parser.add_argument('--seed', type=int, default=88, - help="random seed for initialization.") - args = parser.parse_args() - return args - -def main(args): - global logger - # Setup CUDA, GPU & distributed training - args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 - os.environ['OMP_NUM_THREADS'] = str(args.num_workers) - print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) - - mkdir(args.output_dir) - logger = setup_logger("Graphormer", args.output_dir, get_rank()) - set_seed(args.seed, args.num_gpus) - logger.info("Using {} GPUs".format(args.num_gpus)) - - # Mesh and MANO utils - mano_model = MANO().to(args.device) - mano_model.layer = mano_model.layer.to(device) - mesh_sampler = Mesh() - - # Renderer for visualization - renderer = Renderer(faces=mano_model.face) - - # Load pretrained model - trans_encoder = [] - - input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] - hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] - output_feat_dim = input_feat_dim[1:] + [3] - - # which encoder block to have graph convs - which_blk_graph = [int(item) for item in args.which_gcn.split(',')] - - if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: - # if only run eval, load checkpoint - logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) - _model = torch.load(args.resume_checkpoint) - - else: - # init three transformer-encoder blocks in a loop - for i in range(len(output_feat_dim)): - config_class, model_class = BertConfig, Graphormer - config = config_class.from_pretrained(args.config_name if args.config_name \ - else args.model_name_or_path) - - config.output_attentions = False - config.img_feature_dim = input_feat_dim[i] - config.output_feature_dim = output_feat_dim[i] - args.hidden_size = hidden_feat_dim[i] - args.intermediate_size = int(args.hidden_size*2) - - if which_blk_graph[i]==1: - config.graph_conv = True - logger.info("Add Graph Conv") - else: - config.graph_conv = False - - config.mesh_type = args.mesh_type - - # update model structure if specified in arguments - update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] - for idx, param in enumerate(update_params): - arg_param = getattr(args, param) - config_param = getattr(config, param) - if arg_param > 0 and arg_param != config_param: - logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) - setattr(config, param, arg_param) - - # init a transformer encoder and append it to a list - assert config.hidden_size % config.num_attention_heads == 0 - model = model_class(config=config) - logger.info("Init model from scratch.") - trans_encoder.append(model) - - # create backbone model - if args.arch=='hrnet': - hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' - hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' - hrnet_update_config(hrnet_config, hrnet_yaml) - backbone = 
get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
-            logger.info('=> loading hrnet-v2-w40 model')
-        elif args.arch=='hrnet-w64':
-            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
-            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
-            hrnet_update_config(hrnet_config, hrnet_yaml)
-            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
-            logger.info('=> loading hrnet-v2-w64 model')
-        else:
-            print("=> using pre-trained model '{}'".format(args.arch))
-            backbone = models.__dict__[args.arch](pretrained=True)
-            # remove the last fc layer
-            backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
-
-        trans_encoder = torch.nn.Sequential(*trans_encoder)
-        total_params = sum(p.numel() for p in trans_encoder.parameters())
-        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
-        backbone_total_params = sum(p.numel() for p in backbone.parameters())
-        logger.info('Backbone total parameters: {}'.format(backbone_total_params))
-
-        # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
-        _model = Graphormer_Network(args, config, backbone, trans_encoder)
-
-        if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
-            # for fine-tuning or resume training or inference, load weights from checkpoint
-            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
-            # workaround approach to load sparse tensor in graph conv.
-            state_dict = torch.load(args.resume_checkpoint)
-            _model.load_state_dict(state_dict, strict=False)
-            del state_dict
-            gc.collect()
-            torch.cuda.empty_cache()
-
-    # update configs to enable attention outputs
-    setattr(_model.trans_encoder[-1].config,'output_attentions', True)
-    setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
-    _model.trans_encoder[-1].bert.encoder.output_attentions = True
-    _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
-    for iter_layer in range(4):
-        _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
-    for inter_block in range(3):
-        setattr(_model.trans_encoder[-1].config,'device', args.device)
-
-    _model.to(args.device)
-    logger.info("Run inference")
-
-    image_list = []
-    if not args.image_file_or_path:
-        raise ValueError("image_file_or_path not specified")
-    if op.isfile(args.image_file_or_path):
-        image_list = [args.image_file_or_path]
-    elif op.isdir(args.image_file_or_path):
-        # should be a path with images only
-        for filename in os.listdir(args.image_file_or_path):
-            if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename:
-                image_list.append(args.image_file_or_path+'/'+filename)
-    else:
-        raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
-
-    run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)
-
-if __name__ == "__main__":
-    args = parse_args()
-    main(args)
diff --git a/mesh_graphormer/tools/run_hand_multiscale.py b/mesh_graphormer/tools/run_hand_multiscale.py
deleted file mode 100644
index c00583c..0000000
--- a/mesh_graphormer/tools/run_hand_multiscale.py
+++ /dev/null
@@ -1,136 +0,0 @@
-from __future__ import absolute_import, division, print_function
-
-import argparse
-import os
-import os.path as op
-import code
-import json
-import zipfile
-import torch
-import numpy as np
-from mesh_graphormer.utils.metric_pampjpe import get_alignMesh
-
-
-def load_pred_json(filepath):
-    archive = zipfile.ZipFile(filepath, 'r')
-    jsondata = 
archive.read('pred.json') - reference = json.loads(jsondata.decode("utf-8")) - return reference[0], reference[1] - - -def multiscale_fusion(output_dir): - s = '10' - filepath = output_dir+'ckpt200-sc10_rot0-pred.zip' - ref_joints, ref_vertices = load_pred_json(filepath) - ref_joints_array = np.asarray(ref_joints) - ref_vertices_array = np.asarray(ref_vertices) - - rotations = [0.0] - for i in range(1,10): - rotations.append(i*10) - rotations.append(i*-10) - - scale = [0.7,0.8,0.9,1.0,1.1] - multiscale_joints = [] - multiscale_vertices = [] - - counter = 0 - for s in scale: - for r in rotations: - setting = 'sc%02d_rot%s'%(int(s*10),str(int(r))) - filepath = output_dir+'ckpt200-'+setting+'-pred.zip' - joints, vertices = load_pred_json(filepath) - joints_array = np.asarray(joints) - vertices_array = np.asarray(vertices) - - pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None) - pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None) - print('--------------------------') - print('scale:', s, 'rotate', r) - print('PAMPJPE:', 1000*np.mean(pa_joint_error)) - print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) - multiscale_joints.append(pa_joint_array) - multiscale_vertices.append(pa_vertices_array) - counter = counter + 1 - - overall_joints_array = ref_joints_array.copy() - overall_vertices_array = ref_vertices_array.copy() - for i in range(counter): - overall_joints_array += multiscale_joints[i] - overall_vertices_array += multiscale_vertices[i] - - overall_joints_array /= (1+counter) - overall_vertices_array /= (1+counter) - pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None) - pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None) - print('--------------------------') - print('overall:') - print('PAMPJPE:', 1000*np.mean(pa_joint_error)) - print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) - - joint_output_save = overall_joints_array.tolist() - mesh_output_save = overall_vertices_array.tolist() - - print('save results to pred.json') - with open('pred.json', 'w') as f: - json.dump([joint_output_save, mesh_output_save], f) - - - filepath = output_dir+'ckpt200-multisc-pred.zip' - resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json' - print(resolved_submit_cmd) - os.system(resolved_submit_cmd) - resolved_submit_cmd = 'rm pred.json' - print(resolved_submit_cmd) - os.system(resolved_submit_cmd) - - -def run_multiscale_inference(model_path, mode, output_dir): - - if mode==True: - rotations = [0.0] - for i in range(1,10): - rotations.append(i*10) - rotations.append(i*-10) - scale = [0.7,0.8,0.9,1.0,1.1] - else: - rotations = [0.0] - scale = [1.0] - - job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \ - "--val_yaml freihand_v3/test.yaml " \ - "--resume_checkpoint %s " \ - "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \ - "--multiscale_inference " \ - "--rot %f " \ - "--sc %s " \ - "--arch hrnet-w64 " \ - "--num_hidden_layers 4 " \ - "--num_attention_heads 4 " \ - "--input_feat_dim 2051,512,128 " \ - "--hidden_feat_dim 1024,256,64 " \ - "--output_dir %s" - - for s in scale: - for r in rotations: - resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir) - print(resolved_submit_cmd) - os.system(resolved_submit_cmd) - -def main(args): - model_path = args.model_path - mode = args.multiscale_inference - output_dir = args.output_dir - 
run_multiscale_inference(model_path, mode, output_dir)
-    if mode==True:
-        multiscale_fusion(output_dir)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder")
-    parser.add_argument("--model_path",
-                        help="Path to the checkpoint to evaluate.")
-    parser.add_argument("--multiscale_inference", default=False, action='store_true',)
-    parser.add_argument("--output_dir", default='output/', type=str, required=False,
-                        help="The output directory to save checkpoint and test results.")
-    args = parser.parse_args()
-    main(args)
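
Note on orthographic_projection: the deleted scripts above reproject mesh-regressed 3D joints into the image with orthographic_projection(X, camera), where camera is the predicted weak-perspective triple [scale, tx, ty]. The function lives in mesh_graphormer.utils.geometric_layers and is not part of this patch; the sketch below shows the conventional weak-perspective form that is consistent with how it is called here, as an illustration rather than the repository's exact code.

import torch

def orthographic_projection(X, camera):
    """Weak-perspective projection of 3D points.

    X:      [B, N, 3] points (e.g. joints regressed from the predicted mesh)
    camera: [B, 3]    per-sample parameters [scale, tx, ty]
    returns [B, N, 2] projected 2D points
    """
    camera = camera.view(-1, 1, 3)
    # translate in the image plane, then apply the scalar scale (depth is dropped)
    X_trans = X[:, :, :2] + camera[:, :, 1:]
    shape = X_trans.shape
    X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
    return X_2d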
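
Typical usage of the deleted multiscale driver, assuming a trained FreiHAND hand-mesh checkpoint (the checkpoint path below is illustrative):

    python mesh_graphormer/tools/run_hand_multiscale.py \
        --model_path models/graphormer_hand/state_dict.bin \
        --multiscale_inference \
        --output_dir output/

With --multiscale_inference set, run_multiscale_inference launches one evaluation per setting (5 scales x 19 rotations, 95 runs in total), and multiscale_fusion then aligns each setting's predicted joints and vertices to the unrotated, unscaled reference with get_alignMesh, averages them, and zips the fused result as ckpt200-multisc-pred.zip in the output directory.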