Mirror of https://github.com/huchenlei/HandRefinerPortable.git (synced 2026-01-26 15:49:45 +00:00)

✨ Initial commit
.gitignore (vendored, new file, 161 lines)
@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.vscode/
hand_refiner/__init__.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import numpy as np
from PIL import Image

from .util import resize_image_with_pad, common_input_validate, HWC3, custom_hf_download
from hand_refiner.pipeline import MeshGraphormerMediapipe, args


class MeshGraphormerDetector:
    def __init__(self, pipeline):
        self.pipeline = pipeline

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"):
        filename = filename or "graphormer_hand_state_dict.bin"
        hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
        args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)
        args.hrnet_checkpoint = custom_hf_download(pretrained_model_or_path, hrnet_filename, cache_dir)
        args.device = device
        pipeline = MeshGraphormerMediapipe(args)
        return cls(pipeline)

    def to(self, device):
        self.pipeline._model.to(device)
        self.pipeline.mano_model.to(device)
        self.pipeline.mano_model.layer.to(device)
        return self

    def __call__(self, input_image=None, mask_bbox_padding=30, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)

        depth_map, mask, info = self.pipeline.get_depth(input_image, mask_bbox_padding)
        if depth_map is None:
            # no hand detected (or depth rendering failed): fall back to empty outputs
            depth_map = np.zeros_like(input_image)
            mask = np.zeros_like(input_image)

        # The hand is small
        depth_map, mask = HWC3(depth_map), HWC3(mask)
        depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method)
        depth_map = remove_pad(depth_map)  # note: only the depth map is resized; the mask keeps the input resolution
        if output_type == "pil":
            depth_map = Image.fromarray(depth_map)
            mask = Image.fromarray(mask)

        return depth_map, mask, info
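A quick usage sketch of the detector above. The HuggingFace repo id below is an assumption for illustration (this commit only hardcodes the two checkpoint filenames, not the repo that hosts them), so substitute whatever repo actually provides graphormer_hand_state_dict.bin and hrnetv2_w64_imagenet_pretrained.pth:

import numpy as np
from hand_refiner import MeshGraphormerDetector

# "hr16/ControlNet-HandRefiner-pruned" is an assumed repo id, not pinned by this commit.
detector = MeshGraphormerDetector.from_pretrained(
    "hr16/ControlNet-HandRefiner-pruned", cache_dir="./ckpts", device="cuda")
image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real RGB frame
depth_map, mask, info = detector(image, mask_bbox_padding=30, output_type="np")
# depth_map/mask come back as zero arrays when no hand is detected; info is then None.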
hand_refiner/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml (new file, 92 lines)
@@ -0,0 +1,92 @@
GPUS: (0,1,2,3)
LOG_DIR: 'log/'
DATA_DIR: ''
OUTPUT_DIR: 'output/'
WORKERS: 4
PRINT_FREQ: 1000

MODEL:
  NAME: cls_hrnet
  IMAGE_SIZE:
    - 224
    - 224
  EXTRA:
    STAGE1:
      NUM_MODULES: 1
      NUM_BRANCHES: 1
      BLOCK: BOTTLENECK
      NUM_BLOCKS:
        - 4
      NUM_CHANNELS:
        - 64
      FUSE_METHOD: SUM
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
        - 4
        - 4
      NUM_CHANNELS:
        - 64
        - 128
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
        - 4
        - 4
        - 4
      NUM_CHANNELS:
        - 64
        - 128
        - 256
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
        - 4
        - 4
        - 4
        - 4
      NUM_CHANNELS:
        - 64
        - 128
        - 256
        - 512
      FUSE_METHOD: SUM
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
DATASET:
  DATASET: 'imagenet'
  DATA_FORMAT: 'jpg'
  ROOT: 'data/imagenet/'
  TEST_SET: 'val'
  TRAIN_SET: 'train'
TEST:
  BATCH_SIZE_PER_GPU: 32
  MODEL_FILE: ''
TRAIN:
  BATCH_SIZE_PER_GPU: 32
  BEGIN_EPOCH: 0
  END_EPOCH: 100
  RESUME: true
  LR_FACTOR: 0.1
  LR_STEP:
    - 30
    - 60
    - 90
  OPTIMIZER: sgd
  LR: 0.05
  WD: 0.0001
  MOMENTUM: 0.9
  NESTEROV: true
  SHUFFLE: true
DEBUG:
  DEBUG: false
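(The scraped STAGE1 key read "NUM_RANCHES"; corrected to NUM_BRANCHES to match the other stages.) For reference, pipeline.py below consumes this config roughly as follows; a minimal sketch using the same calls, with a placeholder weights path:

from pathlib import Path
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat

# Merge this YAML over the default HRNet config, then build the W64
# classification backbone with grid features ("weights.pth" is a placeholder).
hrnet_yaml = Path("hand_refiner") / "cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml"
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained="weights.pth")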
hand_refiner/depth_preprocessor.py (new file, 6 lines)
@@ -0,0 +1,6 @@
class Preprocessor:
    def __init__(self) -> None:
        pass

    def get_depth(self, input_dir, file_name):
        return
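This stub only pins down the interface; MeshGraphormerMediapipe in pipeline.py subclasses it and overrides get_depth (with a different signature, (np_image, padding)). A hypothetical alternative backend would follow the same pattern:

from hand_refiner.depth_preprocessor import Preprocessor

class MyDepthBackend(Preprocessor):  # hypothetical example, not part of this commit
    def get_depth(self, np_image, padding):
        # should return (depthmap, mask, info), or (None, None, None) on failure,
        # mirroring what MeshGraphormerMediapipe.get_depth does below
        raise NotImplementedError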
hand_refiner/hand_landmarker.task (new binary file, not shown)
hand_refiner/pipeline.py (new file, 468 lines)
@@ -0,0 +1,468 @@
import gc
import os
from argparse import Namespace
from pathlib import Path

import cv2
import mediapipe as mp
import numpy as np
import torch
import torchvision.models as models
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from packaging import version
from torchvision import transforms
from trimesh import Trimesh
from trimesh.ray.ray_triangle import RayMeshIntersector

import mesh_graphormer
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
from mesh_graphormer.modeling._mano import MANO, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
from mesh_graphormer.utils.miscellaneous import set_seed

from hand_refiner.depth_preprocessor import Preprocessor

args = Namespace(
    num_workers=4,
    img_scale_factor=1,
    image_file_or_path=os.path.join('', 'MeshGraphormer', 'samples', 'hand'),
    model_name_or_path=str(Path(mesh_graphormer.__file__).parent / "modeling/bert/bert-base-uncased"),
    resume_checkpoint=None,
    output_dir='output/',
    config_name='',
    a='hrnet-w64',  # '-a' shorthand kept alongside 'arch'
    arch='hrnet-w64',
    num_hidden_layers=4,
    hidden_size=-1,
    num_attention_heads=4,
    intermediate_size=-1,
    input_feat_dim='2051,512,128',
    hidden_feat_dim='1024,256,64',
    which_gcn='0,0,1',
    mesh_type='hand',
    run_eval_only=True,
    device="cpu",
    seed=88,
    hrnet_checkpoint=None,
)

# Since mediapipe v0.10.5, the hand category has been correct.
if version.parse(mp.__version__) >= version.parse('0.10.5'):
    true_hand_category = {"Right": "right", "Left": "left"}
else:
    true_hand_category = {"Right": "left", "Left": "right"}


class MeshGraphormerMediapipe(Preprocessor):
    def __init__(self, args=args) -> None:
        # Setup CUDA, GPU & distributed training
        args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
        os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
        print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

        set_seed(args.seed, args.num_gpus)

        # Mesh and MANO utils
        mano_model = MANO().to(args.device)
        mano_model.layer = mano_model.layer.to(args.device)
        mesh_sampler = Mesh(device=args.device)

        # Load pretrained model
        trans_encoder = []

        input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
        hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
        output_feat_dim = input_feat_dim[1:] + [3]

        # which encoder blocks get graph convolutions
        which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

        if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' \
                and 'state_dict' not in args.resume_checkpoint:
            # eval-only with a whole-model checkpoint: load it directly
            _model = torch.load(args.resume_checkpoint)
        else:
            # init three transformer-encoder blocks in a loop
            for i in range(len(output_feat_dim)):
                config_class, model_class = BertConfig, Graphormer
                config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)

                config.output_attentions = False
                config.img_feature_dim = input_feat_dim[i]
                config.output_feature_dim = output_feat_dim[i]
                args.hidden_size = hidden_feat_dim[i]
                args.intermediate_size = int(args.hidden_size * 2)

                config.graph_conv = (which_blk_graph[i] == 1)
                config.mesh_type = args.mesh_type

                # update model structure if specified in arguments
                update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
                for param in update_params:
                    arg_param = getattr(args, param)
                    config_param = getattr(config, param)
                    if arg_param > 0 and arg_param != config_param:
                        setattr(config, param, arg_param)

                # init a transformer encoder and append it to a list
                assert config.hidden_size % config.num_attention_heads == 0
                model = model_class(config=config)
                trans_encoder.append(model)

            # create backbone model
            if args.arch == 'hrnet':
                hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
                hrnet_update_config(hrnet_config, hrnet_yaml)
                backbone = get_cls_net_gridfeat(hrnet_config, pretrained=args.hrnet_checkpoint)
            elif args.arch == 'hrnet-w64':
                hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
                hrnet_update_config(hrnet_config, hrnet_yaml)
                backbone = get_cls_net_gridfeat(hrnet_config, pretrained=args.hrnet_checkpoint)
            else:
                print("=> using pre-trained model '{}'".format(args.arch))
                backbone = models.__dict__[args.arch](pretrained=True)
                # remove the last fc layer
                backbone = torch.nn.Sequential(*list(backbone.children())[:-1])

            trans_encoder = torch.nn.Sequential(*trans_encoder)
            # parameter counts (previously logged)
            total_params = sum(p.numel() for p in trans_encoder.parameters())
            backbone_total_params = sum(p.numel() for p in backbone.parameters())

            # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
            _model = Graphormer_Network(args, config, backbone, trans_encoder)

            if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
                # for fine-tuning, resuming training, or inference: load weights from checkpoint
                # (workaround approach to load the sparse tensors used in graph conv)
                state_dict = torch.load(args.resume_checkpoint)
                _model.load_state_dict(state_dict, strict=False)
                del state_dict
                gc.collect()

        # update configs to enable attention outputs
        setattr(_model.trans_encoder[-1].config, 'output_attentions', True)
        setattr(_model.trans_encoder[-1].config, 'output_hidden_states', True)
        _model.trans_encoder[-1].bert.encoder.output_attentions = True
        _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
        for iter_layer in range(4):
            _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
        setattr(_model.trans_encoder[-1].config, 'device', args.device)

        _model.to(args.device)
        self._model = _model
        self.mano_model = mano_model
        self.mesh_sampler = mesh_sampler

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

        base_options = python.BaseOptions(model_asset_path=str(Path(__file__).parent / "hand_landmarker.task"))
        options = vision.HandLandmarkerOptions(base_options=base_options,
                                               min_hand_detection_confidence=0.6,
                                               min_hand_presence_confidence=0.6,
                                               min_tracking_confidence=0.6,
                                               num_hands=2)
        self.detector = vision.HandLandmarker.create_from_options(options)

    def get_rays(self, W, H, fx, fy, cx, cy, c2w_t, center_pixels):  # rot = I
        j, i = np.meshgrid(np.arange(H, dtype=np.float32), np.arange(W, dtype=np.float32))
        if center_pixels:
            i = i.copy() + 0.5
            j = j.copy() + 0.5

        directions = np.stack([(i - cx) / fx, (j - cy) / fy, np.ones_like(i)], -1)
        directions /= np.linalg.norm(directions, axis=-1, keepdims=True)

        rays_o = np.expand_dims(c2w_t, 0).repeat(H * W, 0)

        rays_d = directions  # (H, W, 3)
        rays_d = (rays_d / np.linalg.norm(rays_d, axis=-1, keepdims=True)).reshape(-1, 3)

        return rays_o, rays_d

    def get_mask_bounding_box(self, extrema, H, W, padding=30, dynamic_resize=0.15):
        x_min, x_max, y_min, y_max = extrema
        bb_xpad = max(int((x_max - x_min + 1) * dynamic_resize), padding)
        bb_ypad = max(int((y_max - y_min + 1) * dynamic_resize), padding)
        bbx_min = np.max((x_min - bb_xpad, 0))
        bbx_max = np.min((x_max + bb_xpad, W - 1))
        bby_min = np.max((y_min - bb_ypad, 0))
        bby_max = np.min((y_max + bb_ypad, H - 1))
        return bbx_min, bbx_max, bby_min, bby_max

    def run_inference(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len):
        global args
        H, W = int(crop_len), int(crop_len)
        Graphormer_model.eval()
        mano.eval()
        device = next(Graphormer_model.parameters()).device
        with torch.no_grad():
            img_tensor = self.transform(img)
            batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)

            # forward-pass
            pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)

            # obtain 3d joints, which are regressed from the full mesh
            pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
            pred_camera = pred_camera.cpu()
            pred_vertices = pred_vertices.cpu()
            mesh = Trimesh(vertices=pred_vertices[0], faces=mano.face)
            res = crop_len
            focal_length = 1000 * scale
            camera_t = np.array([-pred_camera[1], -pred_camera[2], -2 * focal_length / (res * pred_camera[0] + 1e-9)])
            pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t
            z_3d_dist = pred_3d_joints_camera[:, 2].clone()

            pred_2d_joints_img_space = ((pred_3d_joints_camera / z_3d_dist[:, None]) * np.array((focal_length, focal_length, 1)))[:, :2] + np.array((W / 2, H / 2))

            rays_o, rays_d = self.get_rays(W, H, focal_length, focal_length, W / 2, H / 2, camera_t, True)
            coords = np.array(list(np.ndindex(H, W))).reshape(H, W, -1).transpose(1, 0, 2).reshape(-1, 2)
            intersector = RayMeshIntersector(mesh)
            points, index_ray, _ = intersector.intersects_location(rays_o, rays_d, multiple_hits=False)

            tri_index = intersector.intersects_first(rays_o, rays_d)
            tri_index = tri_index[index_ray]

            assert len(index_ray) == len(tri_index)

            discriminator = (np.sum(mesh.face_normals[tri_index] * rays_d[index_ray], axis=-1) <= 0)
            points = points[discriminator]  # rays intersecting interior (back-facing) faces are discarded

            if len(points) == 0:
                return None, None
            depth = (points + camera_t)[:, -1]
            index_ray = index_ray[discriminator]
            pixel_ray = coords[index_ray]

            minval = np.min(depth)
            maxval = np.max(depth)
            depthmap = np.zeros([H, W])

            depthmap[pixel_ray[:, 0], pixel_ray[:, 1]] = 1.0 - (0.8 * (depth - minval) / (maxval - minval))
            depthmap *= 255
        return depthmap, pred_2d_joints_img_space

    def get_depth(self, np_image, padding):
        info = {}

        # Load the input image (https://stackoverflow.com/a/76407270).
        image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np_image.copy())

        # Detect hand landmarks from the input image.
        detection_result = self.detector.detect(image)

        handedness_list = detection_result.handedness
        hand_landmarks_list = detection_result.hand_landmarks

        raw_image = image.numpy_view()
        H, W, C = raw_image.shape

        # Hand landmarks can be empty; bail out early in that case.
        if len(hand_landmarks_list) == 0:
            return None, None, None
        raw_image = raw_image[:, :, :3]

        padded_image = np.zeros((H * 2, W * 2, 3))
        padded_image[int(1 / 2 * H):int(3 / 2 * H), int(1 / 2 * W):int(3 / 2 * W)] = raw_image

        # render farther hands first (sorted by the z of landmark 9, the middle-finger MCP)
        # so that nearer hands overwrite them in the depth map
        hand_landmarks_list, handedness_list = zip(
            *sorted(
                zip(hand_landmarks_list, handedness_list), key=lambda x: x[0][9].z, reverse=True
            )
        )

        padded_depthmap = np.zeros((H * 2, W * 2))
        mask = np.zeros((H, W))
        crop_boxes = []
        groundtruth_2d_keypoints = []
        hands = []
        depth_failure = False
        crop_lens = []

        for idx in range(len(hand_landmarks_list)):
            hand = true_hand_category[handedness_list[idx][0].category_name]
            hands.append(hand)
            hand_landmarks = hand_landmarks_list[idx]
            handedness = handedness_list[idx]
            height, width, _ = raw_image.shape
            x_coordinates = [landmark.x for landmark in hand_landmarks]
            y_coordinates = [landmark.y for landmark in hand_landmarks]

            # x_min, x_max, y_min, y_max: extrema from mediapipe keypoint detection
            x_min = int(min(x_coordinates) * width)
            x_max = int(max(x_coordinates) * width)
            x_c = (x_min + x_max) // 2
            y_min = int(min(y_coordinates) * height)
            y_max = int(max(y_coordinates) * height)
            y_c = (y_min + y_max) // 2

            crop_len = (max(x_max - x_min, y_max - y_min) * 1.6) // 2 * 2

            # crop_x_min, crop_x_max, crop_y_min, crop_y_max: bounding box for mesh reconstruction
            crop_x_min = int(x_c - (crop_len / 2 - 1) + W / 2)
            crop_x_max = int(x_c + crop_len / 2 + W / 2)
            crop_y_min = int(y_c - (crop_len / 2 - 1) + H / 2)
            crop_y_max = int(y_c + crop_len / 2 + H / 2)

            cropped = padded_image[crop_y_min:crop_y_max + 1, crop_x_min:crop_x_max + 1]
            crop_boxes.append([crop_y_min, crop_y_max, crop_x_min, crop_x_max])
            crop_lens.append(crop_len)
            if hand == "left":
                cropped = cv2.flip(cropped, 1)

            if crop_len < 224:
                graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC)
            else:
                graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA)
            scale = crop_len / 224
            cropped_depthmap, pred_2d_keypoints = self.run_inference(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, int(crop_len))

            if cropped_depthmap is None:
                depth_failure = True
                break
            groundtruth_2d_keypoints.append(pred_2d_keypoints)

            if hand == "left":
                cropped_depthmap = cv2.flip(cropped_depthmap, 1)
            resized_cropped_depthmap = cv2.resize(cropped_depthmap, (int(crop_len), int(crop_len)), interpolation=cv2.INTER_LINEAR)
            nonzero_y, nonzero_x = (resized_cropped_depthmap != 0).nonzero()
            if len(nonzero_y) == 0 or len(nonzero_x) == 0:
                depth_failure = True
                break
            padded_depthmap[crop_y_min + nonzero_y, crop_x_min + nonzero_x] = resized_cropped_depthmap[nonzero_y, nonzero_x]

            # "nonzero" stands for nonzero value on the depth map;
            # coordinates of nonzero depth pixels in original image space
            original_nonzero_x = crop_x_min + nonzero_x - int(W / 2)
            original_nonzero_y = crop_y_min + nonzero_y - int(H / 2)

            nonzerox_min = min(np.min(original_nonzero_x), x_min)
            nonzerox_max = max(np.max(original_nonzero_x), x_max)
            nonzeroy_min = min(np.min(original_nonzero_y), y_min)
            nonzeroy_max = max(np.max(original_nonzero_y), y_max)

            bbx_min, bbx_max, bby_min, bby_max = self.get_mask_bounding_box((nonzerox_min, nonzerox_max, nonzeroy_min, nonzeroy_max), H, W, padding)
            mask[bby_min:bby_max + 1, bbx_min:bbx_max + 1] = 1.0
        if depth_failure:
            # at least one hand failed depth rendering
            return None, None, None
        depthmap = padded_depthmap[int(1 / 2 * H):int(3 / 2 * H), int(1 / 2 * W):int(3 / 2 * W)].astype(np.uint8)
        mask = (255.0 * mask).astype(np.uint8)
        info["groundtruth_2d_keypoints"] = groundtruth_2d_keypoints
        info["hands"] = hands
        info["crop_boxes"] = crop_boxes
        info["crop_lens"] = crop_lens
        return depthmap, mask, info

    def get_keypoints(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len):
        global args
        H, W = int(crop_len), int(crop_len)
        Graphormer_model.eval()
        mano.eval()
        device = next(Graphormer_model.parameters()).device
        with torch.no_grad():
            img_tensor = self.transform(img)
            batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)

            # forward-pass
            pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)

            # obtain 3d joints, which are regressed from the full mesh
            pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
            pred_camera = pred_camera.cpu()
            pred_vertices = pred_vertices.cpu()
            res = crop_len
            focal_length = 1000 * scale
            camera_t = np.array([-pred_camera[1], -pred_camera[2], -2 * focal_length / (res * pred_camera[0] + 1e-9)])
            pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t
            z_3d_dist = pred_3d_joints_camera[:, 2].clone()
            pred_2d_joints_img_space = ((pred_3d_joints_camera / z_3d_dist[:, None]) * np.array((focal_length, focal_length, 1)))[:, :2] + np.array((W / 2, H / 2))

        return pred_2d_joints_img_space

    def eval_mpjpe(self, sample, info):
        H, W, C = sample.shape
        padded_image = np.zeros((H * 2, W * 2, 3))
        padded_image[int(1 / 2 * H):int(3 / 2 * H), int(1 / 2 * W):int(3 / 2 * W)] = sample
        crop_boxes = info["crop_boxes"]
        hands = info["hands"]
        groundtruth_2d_keypoints = info["groundtruth_2d_keypoints"]
        crop_lens = info["crop_lens"]
        pjpe = 0
        for i in range(len(crop_boxes)):
            crop_y_min, crop_y_max, crop_x_min, crop_x_max = crop_boxes[i]
            cropped = padded_image[crop_y_min:crop_y_max + 1, crop_x_min:crop_x_max + 1]
            hand = hands[i]
            if hand == "left":
                cropped = cv2.flip(cropped, 1)
            crop_len = crop_lens[i]
            scale = crop_len / 224
            if crop_len < 224:
                graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC)
            else:
                graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA)
            generated_keypoint = self.get_keypoints(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, crop_len)
            pjpe += np.sum(np.sqrt(np.sum(((generated_keypoint - groundtruth_2d_keypoints[i]) ** 2).numpy(), axis=1)))
        # mean per-joint position error over all hands, 21 joints each
        mpjpe = pjpe / (len(crop_boxes) * 21)
        return mpjpe
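The camera handling in run_inference/get_keypoints above converts the network's weak-perspective prediction (s, tx, ty) into a fixed-focal perspective camera and then pinhole-projects the MANO joints. A standalone numpy sketch of that arithmetic, on dummy values:

import numpy as np

crop_len = 224.0
scale = crop_len / 224                      # crop size relative to the 224 network input
focal_length = 1000 * scale
s, tx, ty = 0.9, 0.02, -0.01                # dummy weak-perspective prediction
camera_t = np.array([-tx, -ty, -2 * focal_length / (crop_len * s + 1e-9)])

joints_3d = np.random.randn(21, 3) * 0.05   # dummy MANO-scale joints
joints_cam = joints_3d - camera_t           # shift into camera coordinates
z = joints_cam[:, 2:3]
# pinhole projection with the principal point at the crop center
joints_2d = (joints_cam / z)[:, :2] * focal_length + crop_len / 2
print(joints_2d.shape)                      # (21, 2) pixel coordinates in the crop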
hand_refiner/util.py (new file, 193 lines)
@@ -0,0 +1,193 @@
import os
import random
import warnings
from pathlib import Path

import cv2
import numpy as np
from huggingface_hub import hf_hub_download

here = Path(__file__).parent.resolve()


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def make_noise_disk(H, W, C, F):
    noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
    noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
    noise = noise[F: F + H, F: F + W]
    noise -= np.min(noise)
    noise /= np.max(noise)
    if C == 1:
        noise = noise[:, :, None]
    return noise


def nms(x, t, s):
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z


def min_max_norm(x):
    x -= np.min(x)
    x /= np.maximum(np.max(x), 1e-5)
    return x


def safe_step(x, step=2):
    y = x.astype(np.float32) * float(step + 1)
    y = y.astype(np.int32).astype(np.float32) / float(step)
    return y


def img2mask(img, H, W, low=10, high=90):
    assert img.ndim == 3 or img.ndim == 2
    assert img.dtype == np.uint8

    if img.ndim == 3:
        y = img[:, :, random.randrange(0, img.shape[2])]
    else:
        y = img

    y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)

    if random.uniform(0, 1) < 0.5:
        y = 255 - y

    return y < np.percentile(y, random.randrange(low, high))


def safer_memory(x):
    # Fix many MAC/AMD problems
    return np.ascontiguousarray(x.copy()).copy()


UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]


def get_upscale_method(method_str):
    assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
    return getattr(cv2, method_str)


def pad64(x):
    return int(np.ceil(float(x) / 64.0) * 64 - x)


# https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17
# Added upscale_method param
def resize_image_with_pad(input_image, resolution, upscale_method="", skip_hwc3=False):
    if skip_hwc3:
        img = input_image
    else:
        img = HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))
    img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')

    def remove_pad(x):
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad


def common_input_validate(input_image, output_type, **kwargs):
    if "img" in kwargs:
        warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
        input_image = kwargs.pop("img")

    if "return_pil" in kwargs:
        warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
        output_type = "pil" if kwargs["return_pil"] else "np"

    if type(output_type) is bool:
        warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
        if output_type:
            output_type = "pil"

    if input_image is None:
        raise ValueError("input_image must be defined.")

    if not isinstance(input_image, np.ndarray):
        input_image = np.array(input_image, dtype=np.uint8)
        output_type = output_type or "pil"
    else:
        output_type = output_type or "np"

    return (input_image, output_type)


def custom_hf_download(pretrained_model_or_path, filename, cache_dir, subfolder='', use_symlinks=False):
    local_dir = os.path.join(cache_dir, pretrained_model_or_path)
    model_path = os.path.join(local_dir, *subfolder.split('/'), filename)

    if not os.path.exists(model_path):
        print(f"Failed to find {model_path}.\n Downloading from huggingface.co")
        if use_symlinks:
            cache_dir_d = os.getenv("HUGGINGFACE_HUB_CACHE")
            if cache_dir_d is None:
                import platform
                if platform.system() == "Windows":
                    cache_dir_d = os.path.join(os.getenv("USERPROFILE"), ".cache", "huggingface", "hub")
                else:
                    cache_dir_d = os.path.join(os.getenv("HOME"), ".cache", "huggingface", "hub")
            try:
                # test that hard links work between the hub cache and the local cache dir
                if not os.path.exists(cache_dir_d):
                    os.makedirs(cache_dir_d)
                open(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), "w")
                os.link(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(cache_dir, f"linktest_{filename}.txt"))
                os.remove(os.path.join(cache_dir, f"linktest_{filename}.txt"))
                os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt"))
                print("Using symlinks to download models. \n",
                      "Make sure you have enough space on your cache folder. \n",
                      "And do not purge the cache folder after downloading.\n",
                      "Otherwise, you will have to re-download the models every time you run the script.\n",
                      "You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.")
            except:
                print("Maybe not able to create symlink. Disable using symlinks.")
                use_symlinks = False
                cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache")
        else:
            cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache")

        model_path = hf_hub_download(repo_id=pretrained_model_or_path,
                                     cache_dir=cache_dir_d,
                                     local_dir=local_dir,
                                     subfolder=subfolder,
                                     filename=filename,
                                     local_dir_use_symlinks=use_symlinks,
                                     resume_download=True,
                                     etag_timeout=100)
        if not use_symlinks:
            try:
                import shutil
                shutil.rmtree(cache_dir_d)
            except Exception as e:
                print(e)
    return model_path
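A round-trip sketch for resize_image_with_pad/remove_pad, with the shapes worked out under the code above (the 480x640 input size is arbitrary):

import numpy as np
from hand_refiner.util import resize_image_with_pad

img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
padded, remove_pad = resize_image_with_pad(img, 512, upscale_method="INTER_CUBIC")
print(padded.shape)    # (512, 704, 3): short side scaled to 512, then edge-padded to multiples of 64
restored = remove_pad(padded)
print(restored.shape)  # (512, 683, 3): padding cropped back off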
manopth/CHANGES.md (new file, 1 line)
@@ -0,0 +1 @@
* Chumpy is removed
manopth/LICENSE (new file, 674 lines; truncated in this view)
@@ -0,0 +1,674 @@
                    GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The GNU General Public License is a free, copyleft license for
software and other kinds of works.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works.  By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.  We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors.  You can apply it to
your programs, too.

  When we speak of free software, we are referring to freedom, not
price.  Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

  To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights.  Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

  For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received.  You must make sure that they, too, receive
or can get the source code.  And you must show them these terms so they
know their rights.

  Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

  For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software.  For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

  Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so.  This is fundamentally incompatible with the aim of
protecting users' freedom to change the software.  The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable.  Therefore, we
have designed this version of the GPL to prohibit the practice for those
products.  If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

  Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary.  To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

  The precise terms and conditions for copying, distribution and
modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

  "This License" refers to version 3 of the GNU General Public License.

  "Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

  "The Program" refers to any copyrightable work licensed under this
License.  Each licensee is addressed as "you".  "Licensees" and
"recipients" may be individuals or organizations.

  To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy.  The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

  A "covered work" means either the unmodified Program or a work based
on the Program.

  To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy.  Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

  To "convey" a work means any kind of propagation that enables other
parties to make or receive copies.  Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

  An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License.  If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

  1. Source Code.

  The "source code" for a work means the preferred form of the work
for making modifications to it.  "Object code" means any non-source
form of a work.

  A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

  The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form.  A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

  The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities.  However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work.  For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

  The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

  The Corresponding Source for a work in source code form is that
same work.

  2. Basic Permissions.

  All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met.  This License explicitly affirms your unlimited
permission to run the unmodified Program.  The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work.  This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

  You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force.  You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright.  Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

  Conveying under any other circumstances is permitted solely under
the conditions stated below.  Sublicensing is not allowed; section 10
makes it unnecessary.

  3. Protecting Users' Legal Rights From Anti-Circumvention Law.

  No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

  When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

  4. Conveying Verbatim Copies.

  You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

  You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

  You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified
    it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is
    released under this License and any conditions added under section
    7.  This requirement modifies the requirement in section 4 to
    "keep intact all notices".

    c) You must license the entire work, as a whole, under this
    License to anyone who comes into possession of a copy.  This
    License will therefore apply, along with any applicable section 7
    additional terms, to the whole of the work, and all its parts,
    regardless of how they are packaged.  This License gives no
    permission to license the work in any other way, but it does not
    invalidate such permission if you have separately received it.

    d) If the work has interactive user interfaces, each must display
    Appropriate Legal Notices; however, if the Program has interactive
    interfaces that do not display Appropriate Legal Notices, your
    work need not make them do so.

  A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit.  Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

  6. Conveying Non-Source Forms.

  You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

    a) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by the
    Corresponding Source fixed on a durable physical medium
    customarily used for software interchange.

    b) Convey the object code in, or embodied in, a physical product
    (including a physical distribution medium), accompanied by a
    written offer, valid for at least three years and valid for as
    long as you offer spare parts or customer support for that product
    model, to give anyone who possesses the object code either (1) a
    copy of the Corresponding Source for all the software in the
    product that is covered by this License, on a durable physical
    medium customarily used for software interchange, for a price no
    more than your reasonable cost of physically performing this
    conveying of source, or (2) access to copy the
    Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the
    written offer to provide the Corresponding Source.  This
    alternative is allowed only occasionally and noncommercially, and
    only if you received the object code with such an offer, in accord
    with subsection 6b.

    d) Convey the object code by offering access from a designated
    place (gratis or for a charge), and offer equivalent access to the
    Corresponding Source in the same way through the same place at no
    further charge.  You need not require recipients to copy the
    Corresponding Source along with the object code.  If the place to
    copy the object code is a network server, the Corresponding Source
    may be on a different server (operated by you or a third party)
    that supports equivalent copying facilities, provided you maintain
    clear directions next to the object code saying where to find the
    Corresponding Source.  Regardless of what server hosts the
    Corresponding Source, you remain obligated to ensure that it is
    available for as long as needed to satisfy these requirements.

    e) Convey the object code using peer-to-peer transmission, provided
    you inform other peers where the object code and Corresponding
    Source of the work are being offered to the general public at no
    charge under subsection 6d.

  A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

  A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling.  In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage.  For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product.  A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

  "Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
|
a modified version of its Corresponding Source. The information must
|
||||||
|
suffice to ensure that the continued functioning of the modified object
|
||||||
|
code is in no case prevented or interfered with solely because
|
||||||
|
modification has been made.
|
||||||
|
|
||||||
|
If you convey an object code work under this section in, or with, or
|
||||||
|
specifically for use in, a User Product, and the conveying occurs as
|
||||||
|
part of a transaction in which the right of possession and use of the
|
||||||
|
User Product is transferred to the recipient in perpetuity or for a
|
||||||
|
fixed term (regardless of how the transaction is characterized), the
|
||||||
|
Corresponding Source conveyed under this section must be accompanied
|
||||||
|
by the Installation Information. But this requirement does not apply
|
||||||
|
if neither you nor any third party retains the ability to install
|
||||||
|
modified object code on the User Product (for example, the work has
|
||||||
|
been installed in ROM).
|
||||||
|
|
||||||
|
The requirement to provide Installation Information does not include a
|
||||||
|
requirement to continue to provide support service, warranty, or updates
|
||||||
|
for a work that has been modified or installed by the recipient, or for
|
||||||
|
the User Product in which it has been modified or installed. Access to a
|
||||||
|
network may be denied when the modification itself materially and
|
||||||
|
adversely affects the operation of the network or violates the rules and
|
||||||
|
protocols for communication across the network.
|
||||||
|
|
||||||
|
Corresponding Source conveyed, and Installation Information provided,
|
||||||
|
in accord with this section must be in a format that is publicly
|
||||||
|
documented (and with an implementation available to the public in
|
||||||
|
source code form), and must require no special password or key for
|
||||||
|
unpacking, reading or copying.
|
||||||
|
|
||||||
|
7. Additional Terms.
|
||||||
|
|
||||||
|
"Additional permissions" are terms that supplement the terms of this
|
||||||
|
License by making exceptions from one or more of its conditions.
|
||||||
|
Additional permissions that are applicable to the entire Program shall
|
||||||
|
be treated as though they were included in this License, to the extent
|
||||||
|
that they are valid under applicable law. If additional permissions
|
||||||
|
apply only to part of the Program, that part may be used separately
|
||||||
|
under those permissions, but the entire Program remains governed by
|
||||||
|
this License without regard to the additional permissions.
|
||||||
|
|
||||||
|
When you convey a copy of a covered work, you may at your option
|
||||||
|
remove any additional permissions from that copy, or from any part of
|
||||||
|
it. (Additional permissions may be written to require their own
|
||||||
|
removal in certain cases when you modify the work.) You may place
|
||||||
|
additional permissions on material, added by you to a covered work,
|
||||||
|
for which you have or can give appropriate copyright permission.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, for material you
|
||||||
|
add to a covered work, you may (if authorized by the copyright holders of
|
||||||
|
that material) supplement the terms of this License with terms:
|
||||||
|
|
||||||
|
a) Disclaiming warranty or limiting liability differently from the
|
||||||
|
terms of sections 15 and 16 of this License; or
|
||||||
|
|
||||||
|
b) Requiring preservation of specified reasonable legal notices or
|
||||||
|
author attributions in that material or in the Appropriate Legal
|
||||||
|
Notices displayed by works containing it; or
|
||||||
|
|
||||||
|
c) Prohibiting misrepresentation of the origin of that material, or
|
||||||
|
requiring that modified versions of such material be marked in
|
||||||
|
reasonable ways as different from the original version; or
|
||||||
|
|
||||||
|
d) Limiting the use for publicity purposes of names of licensors or
|
||||||
|
authors of the material; or
|
||||||
|
|
||||||
|
e) Declining to grant rights under trademark law for use of some
|
||||||
|
trade names, trademarks, or service marks; or
|
||||||
|
|
||||||
|
f) Requiring indemnification of licensors and authors of that
|
||||||
|
material by anyone who conveys the material (or modified versions of
|
||||||
|
it) with contractual assumptions of liability to the recipient, for
|
||||||
|
any liability that these contractual assumptions directly impose on
|
||||||
|
those licensors and authors.
|
||||||
|
|
||||||
|
All other non-permissive additional terms are considered "further
|
||||||
|
restrictions" within the meaning of section 10. If the Program as you
|
||||||
|
received it, or any part of it, contains a notice stating that it is
|
||||||
|
governed by this License along with a term that is a further
|
||||||
|
restriction, you may remove that term. If a license document contains
|
||||||
|
a further restriction but permits relicensing or conveying under this
|
||||||
|
License, you may add to a covered work material governed by the terms
|
||||||
|
of that license document, provided that the further restriction does
|
||||||
|
not survive such relicensing or conveying.
|
||||||
|
|
||||||
|
If you add terms to a covered work in accord with this section, you
|
||||||
|
must place, in the relevant source files, a statement of the
|
||||||
|
additional terms that apply to those files, or a notice indicating
|
||||||
|
where to find the applicable terms.
|
||||||
|
|
||||||
|
Additional terms, permissive or non-permissive, may be stated in the
|
||||||
|
form of a separately written license, or stated as exceptions;
|
||||||
|
the above requirements apply either way.
|
||||||
|
|
||||||
|
8. Termination.
|
||||||
|
|
||||||
|
You may not propagate or modify a covered work except as expressly
|
||||||
|
provided under this License. Any attempt otherwise to propagate or
|
||||||
|
modify it is void, and will automatically terminate your rights under
|
||||||
|
this License (including any patent licenses granted under the third
|
||||||
|
paragraph of section 11).
|
||||||
|
|
||||||
|
However, if you cease all violation of this License, then your
|
||||||
|
license from a particular copyright holder is reinstated (a)
|
||||||
|
provisionally, unless and until the copyright holder explicitly and
|
||||||
|
finally terminates your license, and (b) permanently, if the copyright
|
||||||
|
holder fails to notify you of the violation by some reasonable means
|
||||||
|
prior to 60 days after the cessation.
|
||||||
|
|
||||||
|
Moreover, your license from a particular copyright holder is
|
||||||
|
reinstated permanently if the copyright holder notifies you of the
|
||||||
|
violation by some reasonable means, this is the first time you have
|
||||||
|
received notice of violation of this License (for any work) from that
|
||||||
|
copyright holder, and you cure the violation prior to 30 days after
|
||||||
|
your receipt of the notice.
|
||||||
|
|
||||||
|
Termination of your rights under this section does not terminate the
|
||||||
|
licenses of parties who have received copies or rights from you under
|
||||||
|
this License. If your rights have been terminated and not permanently
|
||||||
|
reinstated, you do not qualify to receive new licenses for the same
|
||||||
|
material under section 10.
|
||||||
|
|
||||||
|
9. Acceptance Not Required for Having Copies.
|
||||||
|
|
||||||
|
You are not required to accept this License in order to receive or
|
||||||
|
run a copy of the Program. Ancillary propagation of a covered work
|
||||||
|
occurring solely as a consequence of using peer-to-peer transmission
|
||||||
|
to receive a copy likewise does not require acceptance. However,
|
||||||
|
nothing other than this License grants you permission to propagate or
|
||||||
|
modify any covered work. These actions infringe copyright if you do
|
||||||
|
not accept this License. Therefore, by modifying or propagating a
|
||||||
|
covered work, you indicate your acceptance of this License to do so.
|
||||||
|
|
||||||
|
10. Automatic Licensing of Downstream Recipients.
|
||||||
|
|
||||||
|
Each time you convey a covered work, the recipient automatically
|
||||||
|
receives a license from the original licensors, to run, modify and
|
||||||
|
propagate that work, subject to this License. You are not responsible
|
||||||
|
for enforcing compliance by third parties with this License.
|
||||||
|
|
||||||
|
An "entity transaction" is a transaction transferring control of an
|
||||||
|
organization, or substantially all assets of one, or subdividing an
|
||||||
|
organization, or merging organizations. If propagation of a covered
|
||||||
|
work results from an entity transaction, each party to that
|
||||||
|
transaction who receives a copy of the work also receives whatever
|
||||||
|
licenses to the work the party's predecessor in interest had or could
|
||||||
|
give under the previous paragraph, plus a right to possession of the
|
||||||
|
Corresponding Source of the work from the predecessor in interest, if
|
||||||
|
the predecessor has it or can get it with reasonable efforts.
|
||||||
|
|
||||||
|
You may not impose any further restrictions on the exercise of the
|
||||||
|
rights granted or affirmed under this License. For example, you may
|
||||||
|
not impose a license fee, royalty, or other charge for exercise of
|
||||||
|
rights granted under this License, and you may not initiate litigation
|
||||||
|
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||||
|
any patent claim is infringed by making, using, selling, offering for
|
||||||
|
sale, or importing the Program or any portion of it.
|
||||||
|
|
||||||
|
11. Patents.
|
||||||
|
|
||||||
|
A "contributor" is a copyright holder who authorizes use under this
|
||||||
|
License of the Program or a work on which the Program is based. The
|
||||||
|
work thus licensed is called the contributor's "contributor version".
|
||||||
|
|
||||||
|
A contributor's "essential patent claims" are all patent claims
|
||||||
|
owned or controlled by the contributor, whether already acquired or
|
||||||
|
hereafter acquired, that would be infringed by some manner, permitted
|
||||||
|
by this License, of making, using, or selling its contributor version,
|
||||||
|
but do not include claims that would be infringed only as a
|
||||||
|
consequence of further modification of the contributor version. For
|
||||||
|
purposes of this definition, "control" includes the right to grant
|
||||||
|
patent sublicenses in a manner consistent with the requirements of
|
||||||
|
this License.
|
||||||
|
|
||||||
|
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||||
|
patent license under the contributor's essential patent claims, to
|
||||||
|
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||||
|
propagate the contents of its contributor version.
|
||||||
|
|
||||||
|
In the following three paragraphs, a "patent license" is any express
|
||||||
|
agreement or commitment, however denominated, not to enforce a patent
|
||||||
|
(such as an express permission to practice a patent or covenant not to
|
||||||
|
sue for patent infringement). To "grant" such a patent license to a
|
||||||
|
party means to make such an agreement or commitment not to enforce a
|
||||||
|
patent against the party.
|
||||||
|
|
||||||
|
If you convey a covered work, knowingly relying on a patent license,
|
||||||
|
and the Corresponding Source of the work is not available for anyone
|
||||||
|
to copy, free of charge and under the terms of this License, through a
|
||||||
|
publicly available network server or other readily accessible means,
|
||||||
|
then you must either (1) cause the Corresponding Source to be so
|
||||||
|
available, or (2) arrange to deprive yourself of the benefit of the
|
||||||
|
patent license for this particular work, or (3) arrange, in a manner
|
||||||
|
consistent with the requirements of this License, to extend the patent
|
||||||
|
license to downstream recipients. "Knowingly relying" means you have
|
||||||
|
actual knowledge that, but for the patent license, your conveying the
|
||||||
|
covered work in a country, or your recipient's use of the covered work
|
||||||
|
in a country, would infringe one or more identifiable patents in that
|
||||||
|
country that you have reason to believe are valid.
|
||||||
|
|
||||||
|
If, pursuant to or in connection with a single transaction or
|
||||||
|
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||||
|
covered work, and grant a patent license to some of the parties
|
||||||
|
receiving the covered work authorizing them to use, propagate, modify
|
||||||
|
or convey a specific copy of the covered work, then the patent license
|
||||||
|
you grant is automatically extended to all recipients of the covered
|
||||||
|
work and works based on it.
|
||||||
|
|
||||||
|
A patent license is "discriminatory" if it does not include within
|
||||||
|
the scope of its coverage, prohibits the exercise of, or is
|
||||||
|
conditioned on the non-exercise of one or more of the rights that are
|
||||||
|
specifically granted under this License. You may not convey a covered
|
||||||
|
work if you are a party to an arrangement with a third party that is
|
||||||
|
in the business of distributing software, under which you make payment
|
||||||
|
to the third party based on the extent of your activity of conveying
|
||||||
|
the work, and under which the third party grants, to any of the
|
||||||
|
parties who would receive the covered work from you, a discriminatory
|
||||||
|
patent license (a) in connection with copies of the covered work
|
||||||
|
conveyed by you (or copies made from those copies), or (b) primarily
|
||||||
|
for and in connection with specific products or compilations that
|
||||||
|
contain the covered work, unless you entered into that arrangement,
|
||||||
|
or that patent license was granted, prior to 28 March 2007.
|
||||||
|
|
||||||
|
Nothing in this License shall be construed as excluding or limiting
|
||||||
|
any implied license or other defenses to infringement that may
|
||||||
|
otherwise be available to you under applicable patent law.
|
||||||
|
|
||||||
|
12. No Surrender of Others' Freedom.
|
||||||
|
|
||||||
|
If conditions are imposed on you (whether by court order, agreement or
|
||||||
|
otherwise) that contradict the conditions of this License, they do not
|
||||||
|
excuse you from the conditions of this License. If you cannot convey a
|
||||||
|
covered work so as to satisfy simultaneously your obligations under this
|
||||||
|
License and any other pertinent obligations, then as a consequence you may
|
||||||
|
not convey it at all. For example, if you agree to terms that obligate you
|
||||||
|
to collect a royalty for further conveying from those to whom you convey
|
||||||
|
the Program, the only way you could satisfy both those terms and this
|
||||||
|
License would be to refrain entirely from conveying the Program.
|
||||||
|
|
||||||
|
13. Use with the GNU Affero General Public License.
|
||||||
|
|
||||||
|
Notwithstanding any other provision of this License, you have
|
||||||
|
permission to link or combine any covered work with a work licensed
|
||||||
|
under version 3 of the GNU Affero General Public License into a single
|
||||||
|
combined work, and to convey the resulting work. The terms of this
|
||||||
|
License will continue to apply to the part which is the covered work,
|
||||||
|
but the special requirements of the GNU Affero General Public License,
|
||||||
|
section 13, concerning interaction through a network will apply to the
|
||||||
|
combination as such.
|
||||||
|
|
||||||
|
14. Revised Versions of this License.
|
||||||
|
|
||||||
|
The Free Software Foundation may publish revised and/or new versions of
|
||||||
|
the GNU General Public License from time to time. Such new versions will
|
||||||
|
be similar in spirit to the present version, but may differ in detail to
|
||||||
|
address new problems or concerns.
|
||||||
|
|
||||||
|
Each version is given a distinguishing version number. If the
|
||||||
|
Program specifies that a certain numbered version of the GNU General
|
||||||
|
Public License "or any later version" applies to it, you have the
|
||||||
|
option of following the terms and conditions either of that numbered
|
||||||
|
version or of any later version published by the Free Software
|
||||||
|
Foundation. If the Program does not specify a version number of the
|
||||||
|
GNU General Public License, you may choose any version ever published
|
||||||
|
by the Free Software Foundation.
|
||||||
|
|
||||||
|
If the Program specifies that a proxy can decide which future
|
||||||
|
versions of the GNU General Public License can be used, that proxy's
|
||||||
|
public statement of acceptance of a version permanently authorizes you
|
||||||
|
to choose that version for the Program.
|
||||||
|
|
||||||
|
Later license versions may give you additional or different
|
||||||
|
permissions. However, no additional obligations are imposed on any
|
||||||
|
author or copyright holder as a result of your choosing to follow a
|
||||||
|
later version.
|
||||||
|
|
||||||
|
15. Disclaimer of Warranty.
|
||||||
|
|
||||||
|
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||||
|
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||||
|
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||||
|
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||||
|
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||||
|
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||||
|
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||||
|
|
||||||
|
16. Limitation of Liability.
|
||||||
|
|
||||||
|
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||||
|
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||||
|
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||||
|
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||||
|
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||||
|
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||||
|
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||||
|
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||||
|
SUCH DAMAGES.
|
||||||
|
|
||||||
|
17. Interpretation of Sections 15 and 16.
|
||||||
|
|
||||||
|
If the disclaimer of warranty and limitation of liability provided
|
||||||
|
above cannot be given local legal effect according to their terms,
|
||||||
|
reviewing courts shall apply local law that most closely approximates
|
||||||
|
an absolute waiver of all civil liability in connection with the
|
||||||
|
Program, unless a warranty or assumption of liability accompanies a
|
||||||
|
copy of the Program in return for a fee.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
How to Apply These Terms to Your New Programs
|
||||||
|
|
||||||
|
If you develop a new program, and you want it to be of the greatest
|
||||||
|
possible use to the public, the best way to achieve this is to make it
|
||||||
|
free software which everyone can redistribute and change under these terms.
|
||||||
|
|
||||||
|
To do so, attach the following notices to the program. It is safest
|
||||||
|
to attach them to the start of each source file to most effectively
|
||||||
|
state the exclusion of warranty; and each file should have at least
|
||||||
|
the "copyright" line and a pointer to where the full notice is found.
|
||||||
|
|
||||||
|
<one line to give the program's name and a brief idea of what it does.>
|
||||||
|
Copyright (C) <year> <name of author>
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
Also add information on how to contact you by electronic and paper mail.
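
Editorial note, not part of the license text: in a Python source file the same notice would simply sit in a header comment; the file name, year, and author below are placeholders, not values from this commit.

    # frob.py - <one line to give the program's name and a brief idea of what it does.>
    # Copyright (C) <year> <name of author>
    #
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 3 of the License, or
    # (at your option) any later version.
    #
    # You should have received a copy of the GNU General Public License
    # along with this program. If not, see <https://www.gnu.org/licenses/>.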

If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:

    <program> Copyright (C) <year> <name of author>
    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
    This is free software, and you are welcome to redistribute it
    under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".

You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.

The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

1
manopth/__init__.py
Normal file
@@ -0,0 +1 @@
name = 'manopth'

51
manopth/argutils.py
Normal file
@@ -0,0 +1,51 @@
import datetime
import os
import pickle
import subprocess
import sys


def print_args(args):
    opts = vars(args)
    print('======= Options ========')
    for k, v in sorted(opts.items()):
        print('{}: {}'.format(k, v))
    print('========================')


def save_args(args, save_folder, opt_prefix='opt', verbose=True):
    opts = vars(args)
    # Create checkpoint folder
    if not os.path.exists(save_folder):
        os.makedirs(save_folder, exist_ok=True)

    # Save options
    opt_filename = '{}.txt'.format(opt_prefix)
    opt_path = os.path.join(save_folder, opt_filename)
    with open(opt_path, 'a') as opt_file:
        opt_file.write('====== Options ======\n')
        for k, v in sorted(opts.items()):
            opt_file.write(
                '{option}: {value}\n'.format(option=str(k), value=str(v)))
        opt_file.write('=====================\n')
        opt_file.write('launched {} at {}\n'.format(
            str(sys.argv[0]), str(datetime.datetime.now())))

        # Add git info
        label = subprocess.check_output(["git", "describe",
                                         "--always"]).strip()
        if subprocess.call(
                ["git", "branch"],
                stderr=subprocess.STDOUT,
                stdout=open(os.devnull, 'w')) == 0:
            opt_file.write('=== Git info ====\n')
            opt_file.write('{}\n'.format(label))
            commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
            opt_file.write('commit : {}\n'.format(commit.strip()))

    opt_picklename = '{}.pkl'.format(opt_prefix)
    opt_picklepath = os.path.join(save_folder, opt_picklename)
    with open(opt_picklepath, 'wb') as opt_file:
        pickle.dump(opts, opt_file)
    if verbose:
        print('Saved options to {}'.format(opt_path))
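
Usage sketch for save_args, not part of this file; the argument names and
the checkpoints/exp1 folder below are hypothetical:

    import argparse
    from manopth.argutils import save_args

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=1, type=int)
    args = parser.parse_args()
    # Writes opt.txt (options, launch time, git commit) and opt.pkl to the folder
    save_args(args, save_folder='checkpoints/exp1', opt_prefix='opt')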
59
manopth/demo.py
Normal file
@@ -0,0 +1,59 @@
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np
import torch

from manopth.manolayer import ManoLayer


def generate_random_hand(batch_size=1, ncomps=6, mano_root='mano/models'):
    nfull_comps = ncomps + 3  # Add global orientation dims to PCA
    random_pcapose = torch.rand(batch_size, nfull_comps)
    mano_layer = ManoLayer(mano_root=mano_root)
    verts, joints = mano_layer(random_pcapose)
    return {'verts': verts, 'joints': joints, 'faces': mano_layer.th_faces}


def display_hand(hand_info, mano_faces=None, ax=None, alpha=0.2, batch_idx=0, show=True):
    """
    Displays hand batch_idx in batch of hand_info, hand_info as returned by
    generate_random_hand
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
    verts, joints = hand_info['verts'][batch_idx], hand_info['joints'][
        batch_idx]
    if mano_faces is None:
        ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.1)
    else:
        mesh = Poly3DCollection(verts[mano_faces], alpha=alpha)
        face_color = (141 / 255, 184 / 255, 226 / 255)
        edge_color = (50 / 255, 50 / 255, 50 / 255)
        mesh.set_edgecolor(edge_color)
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)
    ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
    cam_equal_aspect_3d(ax, verts.numpy())
    if show:
        plt.show()


def cam_equal_aspect_3d(ax, verts, flip_x=False):
    """
    Centers view on cuboid containing hand and flips y and z axis
    and fixes azimuth
    """
    extents = np.stack([verts.min(0), verts.max(0)], axis=1)
    sz = extents[:, 1] - extents[:, 0]
    centers = np.mean(extents, axis=1)
    maxsize = max(abs(sz))
    r = maxsize / 2
    if flip_x:
        ax.set_xlim(centers[0] + r, centers[0] - r)
    else:
        ax.set_xlim(centers[0] - r, centers[0] + r)
    # Invert y and z axis
    ax.set_ylim(centers[1] + r, centers[1] - r)
    ax.set_zlim(centers[2] + r, centers[2] - r)
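
A minimal sketch of how these helpers compose, not part of this file; it
assumes the MANO pickle files exist under mano/models, which this commit
does not ship:

    from manopth.demo import generate_random_hand, display_hand

    # Sample one random hand (3 global rotation dims + 6 PCA pose dims) and plot it
    hand_info = generate_random_hand(batch_size=1, ncomps=6)
    display_hand(hand_info, mano_faces=hand_info['faces'])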
274
manopth/manolayer.py
Normal file
@@ -0,0 +1,274 @@
import os

import numpy as np
import torch
from torch.nn import Module

from manopth.smpl_handpca_wrapper_HAND_only import ready_arguments
from manopth import rodrigues_layer, rotproj, rot6d
from manopth.tensutils import (th_posemap_axisang, th_with_zeros, th_pack,
                               subtract_flat_id, make_list)


class ManoLayer(Module):
    __constants__ = [
        'use_pca', 'rot', 'ncomps', 'kintree_parents', 'check',
        'side', 'center_idx', 'joint_rot_mode'
    ]

    def __init__(self,
                 center_idx=None,
                 flat_hand_mean=True,
                 ncomps=6,
                 side='right',
                 mano_root='mano/models',
                 use_pca=True,
                 root_rot_mode='axisang',
                 joint_rot_mode='axisang',
                 robust_rot=False):
        """
        Args:
            center_idx: index of center joint in our computations,
                if -1 centers on estimate of palm as middle of base
                of middle finger and wrist
            flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
                flat hand, else match average hand pose
            mano_root: path to MANO pkl files for left and right hand
            ncomps: number of PCA components from pose space (<45)
            side: 'right' or 'left'
            use_pca: Use PCA decomposition for pose space.
            joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca
        """
        super().__init__()

        self.center_idx = center_idx
        self.robust_rot = robust_rot
        if root_rot_mode == 'axisang':
            self.rot = 3
        else:
            self.rot = 6
        self.flat_hand_mean = flat_hand_mean
        self.side = side
        self.use_pca = use_pca
        self.joint_rot_mode = joint_rot_mode
        self.root_rot_mode = root_rot_mode
        if use_pca:
            self.ncomps = ncomps
        else:
            self.ncomps = 45

        if side == 'right':
            self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
        elif side == 'left':
            self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')

        smpl_data = ready_arguments(self.mano_path)

        hands_components = smpl_data['hands_components']

        self.smpl_data = smpl_data

        self.register_buffer('th_betas',
                             torch.Tensor(smpl_data['betas']).unsqueeze(0))
        self.register_buffer('th_shapedirs',
                             torch.Tensor(smpl_data['shapedirs']))
        self.register_buffer('th_posedirs',
                             torch.Tensor(smpl_data['posedirs']))
        self.register_buffer(
            'th_v_template',
            torch.Tensor(smpl_data['v_template']).unsqueeze(0))
        self.register_buffer(
            'th_J_regressor',
            torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
        self.register_buffer('th_weights',
                             torch.Tensor(smpl_data['weights']))
        self.register_buffer('th_faces',
                             torch.Tensor(smpl_data['f'].astype(np.int32)).long())

        # Get hand mean
        hands_mean = np.zeros(hands_components.shape[1]
                              ) if flat_hand_mean else smpl_data['hands_mean']
        hands_mean = hands_mean.copy()
        th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
        if self.use_pca or self.joint_rot_mode == 'axisang':
            # Save as axis-angle
            self.register_buffer('th_hands_mean', th_hands_mean)
            selected_components = hands_components[:ncomps]
            self.register_buffer('th_comps', torch.Tensor(hands_components))
            self.register_buffer('th_selected_comps',
                                 torch.Tensor(selected_components))
        else:
            th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
                th_hands_mean.view(15, 3)).reshape(15, 3, 3)
            self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat)

        # Kinematic chain params
        self.kintree_table = smpl_data['kintree_table']
        parents = list(self.kintree_table[0].tolist())
        self.kintree_parents = parents

    def forward(self,
                th_pose_coeffs,
                th_betas=torch.zeros(1),
                th_trans=torch.zeros(1),
                root_palm=torch.Tensor([0]),
                share_betas=torch.Tensor([0]),
                ):
        """
        Args:
            th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints
                and vertices, else centers on root joint (9th joint)
            th_betas (Tensor (batch_size x 10)): if provided, uses given shape
                parameters for hand shape
            root_palm: return palm as hand root instead of wrist
        """
        # if len(th_pose_coeffs) == 0:
        #     return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0)

        batch_size = th_pose_coeffs.shape[0]
        # Get axis angle from PCA components and coefficients
        if self.use_pca or self.joint_rot_mode == 'axisang':
            # Remove global rot coeffs
            th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot +
                                                 self.ncomps]
            if self.use_pca:
                # PCA components --> axis angles
                th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps)
            else:
                th_full_hand_pose = th_hand_pose_coeffs

            # Concatenate back global rot
            th_full_pose = torch.cat([
                th_pose_coeffs[:, :self.rot],
                self.th_hands_mean + th_full_hand_pose
            ], 1)
            if self.root_rot_mode == 'axisang':
                # compute rotation matrices from axis-angle while skipping global rotation
                th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
                root_rot = th_rot_map[:, :9].view(batch_size, 3, 3)
                th_rot_map = th_rot_map[:, 9:]
                th_pose_map = th_pose_map[:, 9:]
            else:
                # th_posemap offsets by 3, so add offset of 3 to get to self.rot=6
                th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 6:])
                if self.robust_rot:
                    root_rot = rot6d.robust_compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
                else:
                    root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
        else:
            assert th_pose_coeffs.dim() == 4, (
                'When not self.use_pca, '
                'th_pose_coeffs should have 4 dims, got {}'.format(
                    th_pose_coeffs.dim()))
            assert th_pose_coeffs.shape[2:4] == (3, 3), (
                'When not self.use_pca, th_pose_coeffs have 3x3 matrix for two '
                'last dims, got {}'.format(th_pose_coeffs.shape[2:4]))
            th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs)
            th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1)
            th_pose_map = subtract_flat_id(th_rot_map)
            root_rot = th_pose_rots[:, 0]

        # Full axis angle representation with root joint
        if th_betas is None or th_betas.numel() == 1:
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       self.th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
                batch_size, 1, 1)
        else:
            if share_betas:
                th_betas = th_betas.mean(0, keepdim=True).expand(th_betas.shape[0], 10)
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
            # th_pose_map should have shape 20x135

        th_v_posed = th_v_shaped + torch.matmul(
            self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
        # Final T pose with transformation done !

        # Global rigid transformation

        root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
        root_trans = th_with_zeros(torch.cat([root_rot, root_j], 2))

        all_rots = th_rot_map.view(th_rot_map.shape[0], 15, 3, 3)
        lev1_idxs = [1, 4, 7, 10, 13]
        lev2_idxs = [2, 5, 8, 11, 14]
        lev3_idxs = [3, 6, 9, 12, 15]
        lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
        lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
        lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
        lev1_j = th_j[:, lev1_idxs]
        lev2_j = th_j[:, lev2_idxs]
        lev3_j = th_j[:, lev3_idxs]

        # From base to tips
        # Get lev1 results
        all_transforms = [root_trans.unsqueeze(1)]
        lev1_j_rel = lev1_j - root_j.transpose(1, 2)
        lev1_rel_transform_flt = th_with_zeros(torch.cat([lev1_rots, lev1_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
        root_trans_flt = root_trans.unsqueeze(1).repeat(1, 5, 1, 1).view(root_trans.shape[0] * 5, 4, 4)
        lev1_flt = torch.matmul(root_trans_flt, lev1_rel_transform_flt)
        all_transforms.append(lev1_flt.view(all_rots.shape[0], 5, 4, 4))

        # Get lev2 results
        lev2_j_rel = lev2_j - lev1_j
        lev2_rel_transform_flt = th_with_zeros(torch.cat([lev2_rots, lev2_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
        lev2_flt = torch.matmul(lev1_flt, lev2_rel_transform_flt)
        all_transforms.append(lev2_flt.view(all_rots.shape[0], 5, 4, 4))

        # Get lev3 results
        lev3_j_rel = lev3_j - lev2_j
        lev3_rel_transform_flt = th_with_zeros(torch.cat([lev3_rots, lev3_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
        lev3_flt = torch.matmul(lev2_flt, lev3_rel_transform_flt)
        all_transforms.append(lev3_flt.view(all_rots.shape[0], 5, 4, 4))

        reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]
        th_results = torch.cat(all_transforms, 1)[:, reorder_idxs]
        th_results_global = th_results

        joint_js = torch.cat([th_j, th_j.new_zeros(th_j.shape[0], 16, 1)], 2)
        tmp2 = torch.matmul(th_results, joint_js.unsqueeze(3))
        th_results2 = (th_results - torch.cat([tmp2.new_zeros(*tmp2.shape[:2], 4, 3), tmp2], 3)).permute(0, 2, 3, 1)

        th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))

        th_rest_shape_h = torch.cat([
            th_v_posed.transpose(2, 1),
            torch.ones((batch_size, 1, th_v_posed.shape[1]),
                       dtype=th_T.dtype,
                       device=th_T.device),
        ], 1)

        th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
        th_verts = th_verts[:, :, :3]
        th_jtr = th_results_global[:, :, :3, 3]
        # In addition to MANO reference joints we sample vertices on each finger
        # to serve as finger tips
        if self.side == 'right':
            tips = th_verts[:, [745, 317, 444, 556, 673]]
        else:
            tips = th_verts[:, [745, 317, 445, 556, 673]]
        if bool(root_palm):
            palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2
            th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1)
        th_jtr = torch.cat([th_jtr, tips], 1)

        # Reorder joints to match visualization utilities
        th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]

        if th_trans is None or bool(torch.norm(th_trans) == 0):
            if self.center_idx is not None:
                center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
                th_jtr = th_jtr - center_joint
                th_verts = th_verts - center_joint
        else:
            th_jtr = th_jtr + th_trans.unsqueeze(1)
            th_verts = th_verts + th_trans.unsqueeze(1)

        # Scale to millimeters
        th_verts = th_verts * 1000
        th_jtr = th_jtr * 1000
        return th_verts, th_jtr
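
A short usage sketch for the layer above, not part of this file; it assumes
MANO_RIGHT.pkl is available under mano/models:

    import torch
    from manopth.manolayer import ManoLayer

    mano_layer = ManoLayer(ncomps=6, side='right', mano_root='mano/models')
    # The pose input packs 3 global-rotation dims followed by the 6 PCA
    # coefficients; shape (betas) and translation default to zeros
    pose = torch.rand(2, 3 + 6)
    verts, joints = mano_layer(pose)
    print(verts.shape, joints.shape)  # (2, 778, 3) and (2, 21, 3), in millimeters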
37
manopth/posemapper.py
Normal file
@@ -0,0 +1,37 @@
'''
Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
This software is provided for research purposes only.
By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license

More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
For comments or questions, please email us at: mano@tue.mpg.de


About this file:
================
This file defines the pose-mapping utilities used when loading the MANO model.

Modules included:
- lrotmin:
  maps a pose vector to the flattened differences between each joint
  rotation matrix and the identity (the pose-blendshape features).
- posemap:
  returns the pose-mapping function named by the model's 'bs_type' field.

'''

import numpy as np
import cv2


def lrotmin(p):
    if isinstance(p, np.ndarray):
        # Drop the 3 global-orientation dims, keep the joint axis-angles
        p = p.ravel()[3:]
    return np.concatenate(
        [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel()
         for pp in p.reshape((-1, 3))]).ravel()


def posemap(s):
    if s == 'lrotmin':
        return lrotmin
    else:
        raise Exception('Unknown posemapping: %s' % (str(s), ))
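
For reference, a small illustrative sketch of what posemap('lrotmin')
computes, not part of this file: for each of the 15 hand joints, the
flattened difference between its rotation matrix and the identity.

    import numpy as np
    from manopth.posemapper import posemap

    pose = np.zeros(48)  # 3 global + 45 joint axis-angle parameters
    features = posemap('lrotmin')(pose)
    print(features.shape)  # (135,) = 15 joints * 9 matrix entries; all zero at rest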
89
manopth/rodrigues_layer.py
Normal file
@@ -0,0 +1,89 @@
"""
This part reuses code from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py
which is part of a PyTorch port of SMPL.
Thanks to Zhang Xiong (MandyMo) for making this great code available on github !
"""

import argparse
from torch.autograd import gradcheck
import torch
from torch.autograd import Variable

from manopth import argutils


def quat2mat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [batch_size, 4]  4 <===> (w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]

    batch_size = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack([
        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
        w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
        w2 - x2 - y2 + z2
    ], dim=1).view(batch_size, 3, 3)
    return rotMat


def batch_rodrigues(axisang):
    # axisang N x 3: convert axis-angle vectors to flattened rotation matrices
    # by going through the quaternion representation
    axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(axisang_norm, -1)
    axisang_normalized = torch.div(axisang, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
    rot_mat = quat2mat(quat)
    rot_mat = rot_mat.view(rot_mat.shape[0], 9)
    return rot_mat


def th_get_axis_angle(vector):
    angle = torch.norm(vector, 2, 1)
    axes = vector / angle.unsqueeze(1)
    return axes, angle


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--cuda', action='store_true')
    args = parser.parse_args()

    argutils.print_args(args)

    n_components = 6
    rot = 3
    inputs = torch.rand(args.batch_size, rot)
    inputs_var = Variable(inputs.double(), requires_grad=True)
    if args.cuda:
        inputs = inputs.cuda()
    # outputs = batch_rodrigues(inputs)
    test_function = gradcheck(batch_rodrigues, (inputs_var, ))
    print('batch test passed !')

    # The following checks referenced th_cv2_rod_sub_id and th_cv2_rod, which
    # are not defined in this file and would raise a NameError; kept here
    # commented out for reference.
    # inputs = torch.rand(rot)
    # inputs_var = Variable(inputs.double(), requires_grad=True)
    # test_function = gradcheck(th_cv2_rod_sub_id.apply, (inputs_var, ))
    # print('th_cv2_rod test passed')

    # inputs = torch.rand(rot)
    # inputs_var = Variable(inputs.double(), requires_grad=True)
    # test_th = gradcheck(th_cv2_rod.apply, (inputs_var, ))
    # print('th_cv2_rod_id test passed !')
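
A quick illustrative sanity check for batch_rodrigues, not part of this file:
a rotation of pi/2 around z should map the x axis onto the y axis.

    import math
    import torch
    from manopth.rodrigues_layer import batch_rodrigues

    axisang = torch.Tensor([[0.0, 0.0, math.pi / 2]])  # pi/2 around z
    rotmat = batch_rodrigues(axisang).view(3, 3)
    x_axis = torch.Tensor([1.0, 0.0, 0.0])
    print(rotmat @ x_axis)  # approximately (0, 1, 0)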
71
manopth/rot6d.py
Normal file
@@ -0,0 +1,71 @@
import torch


def compute_rotation_matrix_from_ortho6d(poses):
    """
    Code from
    https://github.com/papagina/RotationContinuity
    On the Continuity of Rotation Representations in Neural Networks
    Zhou et al. CVPR19
    https://zhouyisjtu.github.io/project_rotation/rotation.html
    """
    x_raw = poses[:, 0:3]  # batch*3
    y_raw = poses[:, 3:6]  # batch*3

    x = normalize_vector(x_raw)  # batch*3
    z = cross_product(x, y_raw)  # batch*3
    z = normalize_vector(z)  # batch*3
    y = cross_product(z, x)  # batch*3

    x = x.view(-1, 3, 1)
    y = y.view(-1, 3, 1)
    z = z.view(-1, 3, 1)
    matrix = torch.cat((x, y, z), 2)  # batch*3*3
    return matrix


def robust_compute_rotation_matrix_from_ortho6d(poses):
    """
    Instead of making 2nd vector orthogonal to first
    create a base that takes into account the two predicted
    directions equally
    """
    x_raw = poses[:, 0:3]  # batch*3
    y_raw = poses[:, 3:6]  # batch*3

    x = normalize_vector(x_raw)  # batch*3
    y = normalize_vector(y_raw)  # batch*3
    middle = normalize_vector(x + y)
    orthmid = normalize_vector(x - y)
    x = normalize_vector(middle + orthmid)
    y = normalize_vector(middle - orthmid)
    # Their scalar product should be small !
    # assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001
    z = normalize_vector(cross_product(x, y))

    x = x.view(-1, 3, 1)
    y = y.view(-1, 3, 1)
    z = z.view(-1, 3, 1)
    matrix = torch.cat((x, y, z), 2)  # batch*3*3
    # Check for reflection in matrix ! If found, flip last vector TODO
    assert (torch.stack([torch.det(mat) for mat in matrix]) < 0).sum() == 0
    return matrix


def normalize_vector(v):
    batch = v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))  # batch
    v_mag = torch.max(v_mag, v.new([1e-8]))
    v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
    v = v / v_mag
    return v


def cross_product(u, v):
    batch = u.shape[0]
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]

    out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1)

    return out
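
An illustrative check for the 6D rotation representation, not part of this
file: feeding the x and y axes of the identity should recover the identity
rotation, and every output is orthonormal by construction.

    import torch
    from manopth.rot6d import compute_rotation_matrix_from_ortho6d

    sixd = torch.Tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
    rot = compute_rotation_matrix_from_ortho6d(sixd)
    print(torch.allclose(rot[0], torch.eye(3)))  # True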
21
manopth/rotproj.py
Normal file
@@ -0,0 +1,21 @@
import torch


def batch_rotprojs(batches_rotmats):
    proj_rotmats = []
    for batch_idx, batch_rotmats in enumerate(batches_rotmats):
        proj_batch_rotmats = []
        for rot_idx, rotmat in enumerate(batch_rotmats):
            # GPU implementation of svd is VERY slow
            # ~ 2 10^-3 per hit vs 5 10^-5 on cpu
            U, S, V = rotmat.cpu().svd()
            rotmat = torch.matmul(U, V.transpose(0, 1))
            orth_det = rotmat.det()
            # Remove reflection
            if orth_det < 0:
                rotmat[:, 2] = -1 * rotmat[:, 2]

            # Note: moves the result back to GPU unconditionally, so this
            # function assumes a CUDA device is available
            rotmat = rotmat.cuda()
            proj_batch_rotmats.append(rotmat)
        proj_rotmats.append(torch.stack(proj_batch_rotmats))
    return torch.stack(proj_rotmats)
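
In short, batch_rotprojs projects each (possibly non-orthogonal) 3x3 matrix
to the nearest rotation via SVD (U V^T), flipping one column when the
determinant is negative. A hedged sketch, not part of this file; it needs a
CUDA device because of the .cuda() call above:

    import torch
    from manopth.rotproj import batch_rotprojs

    noisy = torch.eye(3).repeat(1, 16, 1, 1) + 0.1 * torch.randn(1, 16, 3, 3)
    rots = batch_rotprojs(noisy)
    print(torch.det(rots[0, 0].cpu()))  # approximately 1.0: a proper rotation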
155
manopth/smpl_handpca_wrapper_HAND_only.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
'''
|
||||||
|
Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
|
||||||
|
This software is provided for research purposes only.
|
||||||
|
By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
|
||||||
|
|
||||||
|
More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
|
||||||
|
For comments or questions, please email us at: mano@tue.mpg.de
|
||||||
|
|
||||||
|
|
||||||
|
About this file:
|
||||||
|
================
|
||||||
|
This file defines a wrapper for the loading functions of the MANO model.
|
||||||
|
|
||||||
|
Modules included:
|
||||||
|
- load_model:
|
||||||
|
loads the MANO model from a given file location (i.e. a .pkl file location),
|
||||||
|
or a dictionary object.
|
||||||
|
|
||||||
|
'''
|
||||||
def col(A):
    return A.reshape((-1, 1))


def MatVecMult(mtx, vec):
    result = mtx.dot(col(vec.ravel())).ravel()
    if len(vec.shape) > 1 and vec.shape[1] > 1:
        result = result.reshape((-1, vec.shape[1]))
    return result


def ready_arguments(fname_or_dict, posekey4vposed='pose'):
    import numpy as np
    import pickle
    from manopth.posemapper import posemap

    if not isinstance(fname_or_dict, dict):
        dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
        # dd = pickle.load(open(fname_or_dict, 'rb'))
    else:
        dd = fname_or_dict

    want_shapemodel = 'shapedirs' in dd
    nposeparms = dd['kintree_table'].shape[1] * 3

    if 'trans' not in dd:
        dd['trans'] = np.zeros(3)
    if 'pose' not in dd:
        dd['pose'] = np.zeros(nposeparms)
    if 'shapedirs' in dd and 'betas' not in dd:
        dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])

    for s in [
            'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
            'betas', 'J'
    ]:
        if (s in dd) and not hasattr(dd[s], 'dterms'):
            dd[s] = np.array(dd[s])

    assert posekey4vposed in dd
    if want_shapemodel:
        dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
        v_shaped = dd['v_shaped']
        J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
        J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
        J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
        dd['J'] = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T
        pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
        dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res)
    else:
        pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
        dd_add = dd['posedirs'].dot(pose_map_res)
        dd['v_posed'] = dd['v_template'] + dd_add

    return dd


def load_model(fname_or_dict, ncomps=6, flat_hand_mean=False, v_template=None):
    '''Loads the fully articulable MANO hand model and replaces its full set
    of pose DOFs with ncomps PCA components.'''
    from manopth.verts import verts_core
    import numpy as np
    import pickle
    import scipy.sparse as sp
    np.random.seed(1)

    if not isinstance(fname_or_dict, dict):
        smpl_data = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
        # smpl_data = pickle.load(open(fname_or_dict, 'rb'))
    else:
        smpl_data = fname_or_dict

    rot = 3  # number of global-orientation DOFs

    hands_components = smpl_data['hands_components']
    hands_mean = np.zeros(hands_components.shape[1]) \
        if flat_hand_mean else smpl_data['hands_mean']
    hands_coeffs = smpl_data['hands_coeffs'][:, :ncomps]

    selected_components = np.vstack((hands_components[:ncomps]))
    hands_mean = hands_mean.copy()

    pose_coeffs = np.zeros(rot + selected_components.shape[0])
    full_hand_pose = pose_coeffs[rot:(rot + ncomps)].dot(selected_components)

    smpl_data['fullpose'] = np.concatenate((pose_coeffs[:rot],
                                            hands_mean + full_hand_pose))
    smpl_data['pose'] = pose_coeffs

    Jreg = smpl_data['J_regressor']
    if not sp.issparse(Jreg):
        smpl_data['J_regressor'] = (sp.csc_matrix(
            (Jreg.data, (Jreg.row, Jreg.col)), shape=Jreg.shape))

    # slightly modify ready_arguments to make sure that it uses the fullpose
    # (which will NOT be pose) for the computation of posedirs
    dd = ready_arguments(smpl_data, posekey4vposed='fullpose')

    # create the smpl formula with the fullpose,
    # but expose the PCA coefficients as smpl.pose for compatibility
    args = {
        'pose': dd['fullpose'],
        'v': dd['v_posed'],
        'J': dd['J'],
        'weights': dd['weights'],
        'kintree_table': dd['kintree_table'],
        'xp': np,
        'want_Jtr': True,
        'bs_style': dd['bs_style'],
    }

    result_previous, meta = verts_core(**args)

    result = result_previous + dd['trans'].reshape((1, 3))
    result.no_translation = result_previous

    if meta is not None:
        for field in ['Jtr', 'A', 'A_global', 'A_weighted']:
            if hasattr(meta, field):
                setattr(result, field, getattr(meta, field))
        # when computing with numpy, meta holds the regressed joints themselves
        setattr(result, 'Jtr', meta)
    if hasattr(result, 'Jtr'):
        result.J_transformed = result.Jtr + dd['trans'].reshape((1, 3))

    for k, v in dd.items():
        setattr(result, k, v)

    if v_template is not None:
        result.v_template[:] = v_template

    return result


if __name__ == '__main__':
    # Smoke test; the pickle path is an example location for the MANO model,
    # adjust as needed (the original called load_model() without the required argument).
    load_model('mano/models/MANO_RIGHT.pkl')
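For orientation, a standalone numpy sketch of the PCA-to-full-pose mapping that load_model builds above; the sizes are the standard MANO ones (3 global DOFs, 45 hand-pose DOFs) and the random matrix is a stand-in for the pickled PCA basis:

import numpy as np
ncomps, rot = 6, 3
hands_components = np.random.randn(45, 45)   # stand-in for smpl_data['hands_components']
hands_mean = np.zeros(45)                    # flat_hand_mean=True case
selected_components = hands_components[:ncomps]
pose_coeffs = np.zeros(rot + ncomps)
fullpose = np.concatenate((pose_coeffs[:rot],
                           hands_mean + pose_coeffs[rot:].dot(selected_components)))
print(fullpose.shape)                        # (48,) = 3 global DOFs + 45 joint DOFs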
47 manopth/tensutils.py Normal file
@@ -0,0 +1,47 @@
import torch

from manopth import rodrigues_layer


def th_posemap_axisang(pose_vectors):
    rot_nb = int(pose_vectors.shape[1] / 3)
    pose_vec_reshaped = pose_vectors.contiguous().view(-1, 3)
    rot_mats = rodrigues_layer.batch_rodrigues(pose_vec_reshaped)
    rot_mats = rot_mats.view(pose_vectors.shape[0], rot_nb * 9)
    pose_maps = subtract_flat_id(rot_mats)
    return pose_maps, rot_mats


def th_with_zeros(tensor):
    batch_size = tensor.shape[0]
    padding = tensor.new([0.0, 0.0, 0.0, 1.0])
    padding.requires_grad = False

    concat_list = [tensor, padding.view(1, 1, 4).repeat(batch_size, 1, 1)]
    cat_res = torch.cat(concat_list, 1)
    return cat_res


def th_pack(tensor):
    batch_size = tensor.shape[0]
    padding = tensor.new_zeros((batch_size, 4, 3))
    padding.requires_grad = False
    pack_list = [padding, tensor]
    pack_res = torch.cat(pack_list, 2)
    return pack_res


def subtract_flat_id(rot_mats):
    # Subtracts identity as a flattened tensor
    rot_nb = int(rot_mats.shape[1] / 9)
    id_flat = torch.eye(
        3, dtype=rot_mats.dtype, device=rot_mats.device).view(1, 9).repeat(
            rot_mats.shape[0], rot_nb)
    # id_flat.requires_grad = False
    results = rot_mats - id_flat
    return results


def make_list(tensor):
    # type: (List[int]) -> List[int]
    return tensor
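A quick sanity check for subtract_flat_id above (torch only; assumes the function is importable): a batch of identity rotations should map to all zeros.

import torch
rot_mats = torch.eye(3).view(1, 9).repeat(2, 16)   # batch of 2, 16 joints, flattened 3x3s
print(subtract_flat_id(rot_mats).abs().max())      # tensor(0.)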
117 manopth/verts.py Normal file
@@ -0,0 +1,117 @@
'''
Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
This software is provided for research purposes only.
By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license

More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
For comments or questions, please email us at: mano@tue.mpg.de


About this file:
================
This file defines linear-blend-skinning helpers for the MANO model.

Functions included:
- verts_decorated:
    poses the MANO template from its parameters and attaches the model
    terms (pose, betas, regressor, ...) to the returned vertices.
- verts_core:
    thin wrapper around lbs.verts_core.

'''

import numpy as np
import mano.webuser.lbs as lbs
from mano.webuser.posemapper import posemap
import scipy.sparse as sp


def ischumpy(x):
    return hasattr(x, 'dterms')


def verts_decorated(trans,
                    pose,
                    v_template,
                    J_regressor,
                    weights,
                    kintree_table,
                    bs_style,
                    f,
                    bs_type=None,
                    posedirs=None,
                    betas=None,
                    shapedirs=None,
                    want_Jtr=False):

    for which in [
            trans, pose, v_template, weights, posedirs, betas, shapedirs
    ]:
        if which is not None:
            assert ischumpy(which)

    v = v_template

    if shapedirs is not None:
        if betas is None:
            betas = np.zeros(shapedirs.shape[-1])
        v_shaped = v + shapedirs.dot(betas)
    else:
        v_shaped = v

    if posedirs is not None:
        v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose))
    else:
        v_posed = v_shaped

    v = v_posed

    if sp.issparse(J_regressor):
        J_tmpx = np.matmul(J_regressor, v_shaped[:, 0])
        J_tmpy = np.matmul(J_regressor, v_shaped[:, 1])
        J_tmpz = np.matmul(J_regressor, v_shaped[:, 2])
        J = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T
    else:
        # The original asserted on the undefined name `J` here; check the
        # regressor instead and regress the joints from it (assumed intent).
        assert ischumpy(J_regressor)
        J = J_regressor.dot(v_shaped)

    assert (bs_style == 'lbs')
    result, Jtr = lbs.verts_core(
        pose, v, J, weights, kintree_table, want_Jtr=True, xp=np)

    tr = trans.reshape((1, 3))
    result = result + tr
    Jtr = Jtr + tr

    result.trans = trans
    result.f = f
    result.pose = pose
    result.v_template = v_template
    result.J = J
    result.J_regressor = J_regressor
    result.weights = weights
    result.kintree_table = kintree_table
    result.bs_style = bs_style
    result.bs_type = bs_type
    if posedirs is not None:
        result.posedirs = posedirs
        result.v_posed = v_posed
    if shapedirs is not None:
        result.shapedirs = shapedirs
        result.betas = betas
        result.v_shaped = v_shaped
    if want_Jtr:
        result.J_transformed = Jtr
    return result


def verts_core(pose,
               v,
               J,
               weights,
               kintree_table,
               bs_style,
               want_Jtr=False,
               xp=np):

    assert (bs_style == 'lbs')
    result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp)
    return result
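As a standalone sketch of the blend-shape arithmetic in verts_decorated (numpy only; the shapedirs tensor is random stand-in data): v_shaped = v_template + shapedirs . betas.

import numpy as np
V = 778                                # MANO vertex count
v_template = np.zeros((V, 3))
shapedirs = np.random.randn(V, 3, 10)  # stand-in shape blend shapes
betas = np.zeros(10)
v_shaped = v_template + shapedirs.dot(betas)   # (V, 3, 10) . (10,) -> (V, 3)
print(v_shaped.shape)                  # (778, 3)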
1 mesh_graphormer/__init__.py Normal file
@@ -0,0 +1 @@
__version__ = "0.1.0"
1 mesh_graphormer/datasets/__init__.py Normal file
@@ -0,0 +1 @@
147 mesh_graphormer/datasets/build.py Normal file
@@ -0,0 +1,147 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

"""

import os.path as op
import torch
import logging
import code
from mesh_graphormer.utils.comm import get_world_size
from mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset)
from mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset)


def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
    print(yaml_file)
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
        # code.interact(local=locals())
        assert op.isfile(yaml_file)
    return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)


class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations


def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_gpu, drop_last=False
    )
    if num_iters is not None and num_iters >= 0:
        batch_sampler = IterationBasedBatchSampler(
            batch_sampler, num_iters, start_iter
        )
    return batch_sampler


def make_data_sampler(dataset, shuffle, distributed):
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler


def make_data_loader(args, yaml_file, is_distributed=True,
                     is_train=True, start_iter=0, scale_factor=1):

    dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        iters_per_batch = len(dataset) // images_per_batch
        num_iters = iters_per_batch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None
        start_iter = 0

    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )
    return data_loader


#==============================================================================================

def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
    print(yaml_file)
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
        # code.interact(local=locals())
        assert op.isfile(yaml_file)
    return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)


def make_hand_data_loader(args, yaml_file, is_distributed=True,
                          is_train=True, start_iter=0, scale_factor=1):

    dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        iters_per_batch = len(dataset) // images_per_batch
        num_iters = iters_per_batch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None
        start_iter = 0

    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )
    return data_loader
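A minimal sketch of how IterationBasedBatchSampler behaves, on a toy dataset (torch only); it keeps cycling the underlying batches until the iteration budget is spent:

from torch.utils.data.sampler import BatchSampler, SequentialSampler

batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
iter_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=5)
for step, batch in enumerate(iter_sampler):
    print(step, batch)   # 5 batches in total, wrapping around the 3 underlying batches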
334 mesh_graphormer/datasets/hand_mesh_tsv.py Normal file
@@ -0,0 +1,334 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

"""

import cv2
import math
import json
from PIL import Image
import os.path as op
import numpy as np
import code

from mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile
from mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
from mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
import torch
import torchvision.transforms as transforms


class HandMeshTSVDataset(object):
    def __init__(self, args, img_file, label_file=None, hw_file=None,
                 linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):

        self.args = args
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.img_tsv = self.get_tsv_file(img_file)
        self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
        self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)

        if self.is_composite:
            assert op.isfile(self.linelist_file)
            self.line_list = [i for i in range(self.hw_tsv.num_rows())]
        else:
            self.line_list = load_linelist_file(linelist_file)

        self.cv2_output = cv2_output
        self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
        self.is_train = is_train
        self.scale_factor = 0.25  # rescale bounding boxes by a factor of [1-scale_factor, 1+scale_factor]
        self.noise_factor = 0.4
        self.rot_factor = 90  # random rotation in the range [-rot_factor, rot_factor]
        self.img_res = 224
        self.image_keys = self.prepare_image_keys()
        self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
                                  'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
        self.root_index = self.joints_definition.index('Wrist')

    def get_tsv_file(self, tsv_file):
        if tsv_file:
            if self.is_composite:
                return CompositeTSVFile(tsv_file, self.linelist_file,
                                        root=self.root)
            tsv_path = find_file_path_in_yaml(tsv_file, self.root)
            return TSVFile(tsv_path)

    def get_valid_tsv(self):
        # sorted by file size
        if self.hw_tsv:
            return self.hw_tsv
        if self.label_tsv:
            return self.label_tsv

    def prepare_image_keys(self):
        tsv = self.get_valid_tsv()
        return [tsv.get_key(i) for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        tsv = self.get_valid_tsv()
        return {tsv.get_key(i): i for i in range(tsv.num_rows())}

    def augm_params(self):
        """Get augmentation parameters."""
        flip = 0         # flipping
        pn = np.ones(3)  # per-channel pixel noise

        if not self.args.multiscale_inference:
            rot = 0      # rotation
            sc = 1.0     # scaling
        else:
            rot = self.args.rot
            sc = self.args.sc

        if self.is_train:
            sc = 1.0
            # Each channel is multiplied with a number
            # in the range [1-noise_factor, 1+noise_factor]
            pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor, 3)

            # The rotation is a number in the range [-2*rot_factor, 2*rot_factor]
            rot = min(2 * self.rot_factor,
                      max(-2 * self.rot_factor, np.random.randn() * self.rot_factor))

            # The scale is multiplied with a number
            # in the range [1-scale_factor, 1+scale_factor]
            sc = min(1 + self.scale_factor,
                     max(1 - self.scale_factor, np.random.randn() * self.scale_factor + 1))
            # but rotation is zero with probability 3/5
            if np.random.uniform() <= 0.6:
                rot = 0

        return flip, pn, rot, sc

    def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
        """Process rgb image and do augmentation."""
        rgb_img = crop(rgb_img, center, scale,
                       [self.img_res, self.img_res], rot=rot)
        # flip the image
        if flip:
            rgb_img = flip_img(rgb_img)
        # add pixel noise in a channel-wise manner
        rgb_img[:, :, 0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 0] * pn[0]))
        rgb_img[:, :, 1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 1] * pn[1]))
        rgb_img[:, :, 2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 2] * pn[2]))
        # (3, 224, 224), float, [0, 1]
        rgb_img = np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0
        return rgb_img

    def j2d_processing(self, kp, center, scale, r, f):
        """Process gt 2D keypoints and apply all augmentation transforms."""
        nparts = kp.shape[0]
        for i in range(nparts):
            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
                                   [self.img_res, self.img_res], rot=r)
        # convert to normalized coordinates
        kp[:, :-1] = 2. * kp[:, :-1] / self.img_res - 1.
        # flip the x coordinates
        if f:
            kp = flip_kp(kp)
        kp = kp.astype('float32')
        return kp

    def j3d_processing(self, S, r, f):
        """Process gt 3D keypoints and apply all augmentation transforms."""
        # in-plane rotation
        rot_mat = np.eye(3)
        if not r == 0:
            rot_rad = -r * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
        S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
        # flip the x coordinates
        if f:
            S = flip_kp(S)
        S = S.astype('float32')
        return S

    def pose_processing(self, pose, r, f):
        """Process SMPL theta parameters and apply all augmentation transforms."""
        # rotation of the pose parameters
        pose = pose.astype('float32')
        pose[:3] = rot_aa(pose[:3], r)
        # flip the pose parameters
        if f:
            pose = flip_pose(pose)
        # (72,), float
        pose = pose.astype('float32')
        return pose

    def get_line_no(self, idx):
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        line_no = self.get_line_no(idx)
        row = self.img_tsv[line_no]
        # use -1 to support old format with multiple columns.
        cv2_im = img_from_base64(row[-1])
        if self.cv2_output:
            return cv2_im.astype(np.float32, copy=True)
        cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
        return cv2_im

    def get_annotations(self, idx):
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv[line_no]
            annotations = json.loads(row[1])
            return annotations
        else:
            return []

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def get_img_info(self, idx):
        if self.hw_tsv is not None:
            line_no = self.get_line_no(idx)
            row = self.hw_tsv[line_no]
            try:
                # json string format with "height" and "width" being the keys
                return json.loads(row[1])[0]
            except ValueError:
                # list of strings representing height and width in order
                hw_str = row[1].split(' ')
                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
                return hw_dict

    def get_img_key(self, idx):
        line_no = self.get_line_no(idx)
        # based on the overhead of reading each row.
        if self.hw_tsv:
            return self.hw_tsv[line_no][0]
        elif self.label_tsv:
            return self.label_tsv[line_no][0]
        else:
            return self.img_tsv[line_no][0]

    def __len__(self):
        if self.line_list is None:
            return self.img_tsv.num_rows()
        else:
            return len(self.line_list)

    def __getitem__(self, idx):
        img = self.get_image(idx)
        img_key = self.get_img_key(idx)
        annotations = self.get_annotations(idx)

        annotations = annotations[0]
        center = annotations['center']
        scale = annotations['scale']
        has_2d_joints = annotations['has_2d_joints']
        has_3d_joints = annotations['has_3d_joints']
        joints_2d = np.asarray(annotations['2d_joints'])
        joints_3d = np.asarray(annotations['3d_joints'])

        if joints_2d.ndim == 3:
            joints_2d = joints_2d[0]
        if joints_3d.ndim == 3:
            joints_3d = joints_3d[0]

        # Get SMPL parameters, if available
        has_smpl = np.asarray(annotations['has_smpl'])
        pose = np.asarray(annotations['pose'])
        betas = np.asarray(annotations['betas'])

        # Get augmentation parameters
        flip, pn, rot, sc = self.augm_params()

        # Process image
        img = self.rgb_processing(img, center, sc * scale, rot, flip, pn)
        img = torch.from_numpy(img).float()
        # Keep the unnormalized image for visualization; the normalized copy feeds the model
        transformed_img = self.normalize_img(img)

        # normalize 3d pose by aligning the wrist as the root (at origin)
        root_coord = joints_3d[self.root_index, :-1]
        joints_3d[:, :-1] = joints_3d[:, :-1] - root_coord[None, :]
        # 3d pose augmentation (random flip + rotation, consistent with image and SMPL)
        joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
        # 2d pose augmentation
        joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc * scale, rot, flip)

        ###################################
        # Masking percentage
        # We observe that 0% or 5% works better for 3D hand mesh
        # We think this is probably because 3D vertices are quite sparse in the down-sampled hand mesh
        mvm_percent = 0.0  # or 0.05
        ###################################

        mjm_mask = np.ones((21, 1))
        if self.is_train:
            num_joints = 21
            pb = np.random.random_sample()
            masked_num = int(pb * mvm_percent * num_joints)  # at most x% of the joints could be masked
            indices = np.random.choice(np.arange(num_joints), replace=False, size=masked_num)
            mjm_mask[indices, :] = 0.0
        mjm_mask = torch.from_numpy(mjm_mask).float()

        mvm_mask = np.ones((195, 1))
        if self.is_train:
            num_vertices = 195
            pb = np.random.random_sample()
            masked_num = int(pb * mvm_percent * num_vertices)  # at most x% of the vertices could be masked
            indices = np.random.choice(np.arange(num_vertices), replace=False, size=masked_num)
            mvm_mask[indices, :] = 0.0
        mvm_mask = torch.from_numpy(mvm_mask).float()

        meta_data = {}
        meta_data['ori_img'] = img
        meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
        meta_data['betas'] = torch.from_numpy(betas).float()
        meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
        meta_data['has_3d_joints'] = has_3d_joints
        meta_data['has_smpl'] = has_smpl
        meta_data['mjm_mask'] = mjm_mask
        meta_data['mvm_mask'] = mvm_mask

        # Get 2D keypoints and apply augmentation transforms
        meta_data['has_2d_joints'] = has_2d_joints
        meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()

        meta_data['scale'] = float(sc * scale)
        meta_data['center'] = np.asarray(center).astype(np.float32)

        return img_key, transformed_img, meta_data


class HandMeshTSVYamlDataset(HandMeshTSVDataset):
    """ TSVDataset taking a Yaml file for easy function call
    """
    def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
        self.cfg = load_from_yaml_file(yaml_file)
        self.is_composite = self.cfg.get('composite', False)
        self.root = op.dirname(yaml_file)

        if not self.is_composite:
            img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
            label_file = find_file_path_in_yaml(self.cfg.get('label', None),
                                                self.root)
            hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
            linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
                                                   self.root)
        else:
            img_file = self.cfg['img']
            hw_file = self.cfg['hw']
            label_file = self.cfg.get('label', None)
            linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
                                                   self.root)

        super(HandMeshTSVYamlDataset, self).__init__(
            args, img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)
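The joint/vertex masking in __getitem__ above is self-contained; a numpy-only sketch with the 5% hand setting enabled:

import numpy as np
np.random.seed(0)
num_joints, mvm_percent = 21, 0.05   # hand setting with masking enabled
pb = np.random.random_sample()
masked_num = int(pb * mvm_percent * num_joints)   # at most 5% of the joints
mjm_mask = np.ones((num_joints, 1))
idx = np.random.choice(np.arange(num_joints), replace=False, size=masked_num)
mjm_mask[idx, :] = 0.0
print(masked_num, int(mjm_mask.sum()))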
337 mesh_graphormer/datasets/human_mesh_tsv.py Normal file
@@ -0,0 +1,337 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

"""

import cv2
import math
import json
from PIL import Image
import os.path as op
import numpy as np
import code

from mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile
from mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
from mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
import torch
import torchvision.transforms as transforms


class MeshTSVDataset(object):
    def __init__(self, img_file, label_file=None, hw_file=None,
                 linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):

        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.img_tsv = self.get_tsv_file(img_file)
        self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
        self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)

        if self.is_composite:
            assert op.isfile(self.linelist_file)
            self.line_list = [i for i in range(self.hw_tsv.num_rows())]
        else:
            self.line_list = load_linelist_file(linelist_file)

        self.cv2_output = cv2_output
        self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
        self.is_train = is_train
        self.scale_factor = 0.25  # rescale bounding boxes by a factor of [1-scale_factor, 1+scale_factor]
        self.noise_factor = 0.4
        self.rot_factor = 30  # random rotation in the range [-rot_factor, rot_factor]
        self.img_res = 224

        self.image_keys = self.prepare_image_keys()

        self.joints_definition = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
                                  'L_Elbow', 'L_Wrist', 'Neck', 'Top_of_Head', 'Pelvis', 'Thorax', 'Spine', 'Jaw', 'Head', 'Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear')
        self.pelvis_index = self.joints_definition.index('Pelvis')

    def get_tsv_file(self, tsv_file):
        if tsv_file:
            if self.is_composite:
                return CompositeTSVFile(tsv_file, self.linelist_file,
                                        root=self.root)
            tsv_path = find_file_path_in_yaml(tsv_file, self.root)
            return TSVFile(tsv_path)

    def get_valid_tsv(self):
        # sorted by file size
        if self.hw_tsv:
            return self.hw_tsv
        if self.label_tsv:
            return self.label_tsv

    def prepare_image_keys(self):
        tsv = self.get_valid_tsv()
        return [tsv.get_key(i) for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        tsv = self.get_valid_tsv()
        return {tsv.get_key(i): i for i in range(tsv.num_rows())}

    def augm_params(self):
        """Get augmentation parameters."""
        flip = 0         # flipping
        pn = np.ones(3)  # per-channel pixel noise
        rot = 0          # rotation
        sc = 1           # scaling
        if self.is_train:
            # We flip with probability 1/2
            if np.random.uniform() <= 0.5:
                flip = 1

            # Each channel is multiplied with a number
            # in the range [1-noise_factor, 1+noise_factor]
            pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor, 3)

            # The rotation is a number in the range [-2*rot_factor, 2*rot_factor]
            rot = min(2 * self.rot_factor,
                      max(-2 * self.rot_factor, np.random.randn() * self.rot_factor))

            # The scale is multiplied with a number
            # in the range [1-scale_factor, 1+scale_factor]
            sc = min(1 + self.scale_factor,
                     max(1 - self.scale_factor, np.random.randn() * self.scale_factor + 1))
            # but rotation is zero with probability 3/5
            if np.random.uniform() <= 0.6:
                rot = 0

        return flip, pn, rot, sc

    def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
        """Process rgb image and do augmentation."""
        rgb_img = crop(rgb_img, center, scale,
                       [self.img_res, self.img_res], rot=rot)
        # flip the image
        if flip:
            rgb_img = flip_img(rgb_img)
        # add pixel noise in a channel-wise manner
        rgb_img[:, :, 0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 0] * pn[0]))
        rgb_img[:, :, 1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 1] * pn[1]))
        rgb_img[:, :, 2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 2] * pn[2]))
        # (3, 224, 224), float, [0, 1]
        rgb_img = np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0
        return rgb_img

    def j2d_processing(self, kp, center, scale, r, f):
        """Process gt 2D keypoints and apply all augmentation transforms."""
        nparts = kp.shape[0]
        for i in range(nparts):
            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
                                   [self.img_res, self.img_res], rot=r)
        # convert to normalized coordinates
        kp[:, :-1] = 2. * kp[:, :-1] / self.img_res - 1.
        # flip the x coordinates
        if f:
            kp = flip_kp(kp)
        kp = kp.astype('float32')
        return kp

    def j3d_processing(self, S, r, f):
        """Process gt 3D keypoints and apply all augmentation transforms."""
        # in-plane rotation
        rot_mat = np.eye(3)
        if not r == 0:
            rot_rad = -r * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
        S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
        # flip the x coordinates
        if f:
            S = flip_kp(S)
        S = S.astype('float32')
        return S

    def pose_processing(self, pose, r, f):
        """Process SMPL theta parameters and apply all augmentation transforms."""
        # rotation of the pose parameters
        pose = pose.astype('float32')
        pose[:3] = rot_aa(pose[:3], r)
        # flip the pose parameters
        if f:
            pose = flip_pose(pose)
        # (72,), float
        pose = pose.astype('float32')
        return pose

    def get_line_no(self, idx):
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        line_no = self.get_line_no(idx)
        row = self.img_tsv[line_no]
        # use -1 to support old format with multiple columns.
        cv2_im = img_from_base64(row[-1])
        if self.cv2_output:
            return cv2_im.astype(np.float32, copy=True)
        cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)

        return cv2_im

    def get_annotations(self, idx):
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv[line_no]
            annotations = json.loads(row[1])
            return annotations
        else:
            return []

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def get_img_info(self, idx):
        if self.hw_tsv is not None:
            line_no = self.get_line_no(idx)
            row = self.hw_tsv[line_no]
            try:
                # json string format with "height" and "width" being the keys
                return json.loads(row[1])[0]
            except ValueError:
                # list of strings representing height and width in order
                hw_str = row[1].split(' ')
                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
                return hw_dict

    def get_img_key(self, idx):
        line_no = self.get_line_no(idx)
        # based on the overhead of reading each row.
        if self.hw_tsv:
            return self.hw_tsv[line_no][0]
        elif self.label_tsv:
            return self.label_tsv[line_no][0]
        else:
            return self.img_tsv[line_no][0]

    def __len__(self):
        if self.line_list is None:
            return self.img_tsv.num_rows()
        else:
            return len(self.line_list)

    def __getitem__(self, idx):
        img = self.get_image(idx)
        img_key = self.get_img_key(idx)
        annotations = self.get_annotations(idx)

        annotations = annotations[0]
        center = annotations['center']
        scale = annotations['scale']
        has_2d_joints = annotations['has_2d_joints']
        has_3d_joints = annotations['has_3d_joints']
        joints_2d = np.asarray(annotations['2d_joints'])
        joints_3d = np.asarray(annotations['3d_joints'])

        if joints_2d.ndim == 3:
            joints_2d = joints_2d[0]
        if joints_3d.ndim == 3:
            joints_3d = joints_3d[0]

        # Get SMPL parameters, if available
        has_smpl = np.asarray(annotations['has_smpl'])
        pose = np.asarray(annotations['pose'])
        betas = np.asarray(annotations['betas'])

        try:
            gender = annotations['gender']
        except KeyError:
            gender = 'none'

        # Get augmentation parameters
        flip, pn, rot, sc = self.augm_params()

        # Process image
        img = self.rgb_processing(img, center, sc * scale, rot, flip, pn)
        img = torch.from_numpy(img).float()
        # Keep the unnormalized image for visualization; the normalized copy feeds the model
        transformed_img = self.normalize_img(img)

        # normalize 3d pose by aligning the pelvis as the root (at origin)
        root_pelvis = joints_3d[self.pelvis_index, :-1]
        joints_3d[:, :-1] = joints_3d[:, :-1] - root_pelvis[None, :]
        # 3d pose augmentation (random flip + rotation, consistent with image and SMPL)
        joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
        # 2d pose augmentation
        joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc * scale, rot, flip)

        ###################################
        # Masking percentage
        # We observe that 30% works better for human body mesh. Further details are reported in the paper.
        mvm_percent = 0.3
        ###################################

        mjm_mask = np.ones((14, 1))
        if self.is_train:
            num_joints = 14
            pb = np.random.random_sample()
            masked_num = int(pb * mvm_percent * num_joints)  # at most x% of the joints could be masked
            indices = np.random.choice(np.arange(num_joints), replace=False, size=masked_num)
            mjm_mask[indices, :] = 0.0
        mjm_mask = torch.from_numpy(mjm_mask).float()

        mvm_mask = np.ones((431, 1))
        if self.is_train:
            num_vertices = 431
            pb = np.random.random_sample()
            masked_num = int(pb * mvm_percent * num_vertices)  # at most x% of the vertices could be masked
            indices = np.random.choice(np.arange(num_vertices), replace=False, size=masked_num)
            mvm_mask[indices, :] = 0.0
        mvm_mask = torch.from_numpy(mvm_mask).float()

        meta_data = {}
        meta_data['ori_img'] = img
        meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
        meta_data['betas'] = torch.from_numpy(betas).float()
        meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
        meta_data['has_3d_joints'] = has_3d_joints
        meta_data['has_smpl'] = has_smpl

        meta_data['mjm_mask'] = mjm_mask
        meta_data['mvm_mask'] = mvm_mask

        # Get 2D keypoints and apply augmentation transforms
        meta_data['has_2d_joints'] = has_2d_joints
        meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
        meta_data['scale'] = float(sc * scale)
        meta_data['center'] = np.asarray(center).astype(np.float32)
        meta_data['gender'] = gender
        return img_key, transformed_img, meta_data


class MeshTSVYamlDataset(MeshTSVDataset):
    """ TSVDataset taking a Yaml file for easy function call
    """
    def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
        self.cfg = load_from_yaml_file(yaml_file)
        self.is_composite = self.cfg.get('composite', False)
        self.root = op.dirname(yaml_file)

        if not self.is_composite:
            img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
            label_file = find_file_path_in_yaml(self.cfg.get('label', None),
                                                self.root)
            hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
            linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
                                                   self.root)
        else:
            img_file = self.cfg['img']
            hw_file = self.cfg['hw']
            label_file = self.cfg.get('label', None)
            linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
                                                   self.root)

        super(MeshTSVYamlDataset, self).__init__(
            img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)
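A numpy-only sketch of the clamped rotation/scale sampling in augm_params above, with the body-dataset constants:

import numpy as np
rot_factor, scale_factor = 30, 0.25
rot = min(2 * rot_factor, max(-2 * rot_factor, np.random.randn() * rot_factor))
sc = min(1 + scale_factor, max(1 - scale_factor, np.random.randn() * scale_factor + 1))
if np.random.uniform() <= 0.6:
    rot = 0   # rotation is dropped 3/5 of the time
print(rot, sc)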
0 mesh_graphormer/modeling/__init__.py Normal file
184 mesh_graphormer/modeling/_gcnn.py Normal file
@@ -0,0 +1,184 @@
from __future__ import division
import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse
import math
from pathlib import Path
data_path = Path(__file__).parent / "data"


sparse_to_dense = lambda x: x
device = "cuda"

class SparseMM(torch.autograd.Function):
    """Redefine sparse @ dense matrix multiplication to enable backpropagation.
    The builtin matrix multiplication operation does not support backpropagation in some cases.
    """
    @staticmethod
    def forward(ctx, sparse, dense):
        ctx.req_grad = dense.requires_grad
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        sparse, = ctx.saved_tensors
        if ctx.req_grad:
            grad_input = torch.matmul(sparse.t(), grad_output)
        return None, grad_input

def spmm(sparse, dense):
    sparse = sparse.to(device)
    dense = dense.to(device)
    return SparseMM.apply(sparse, dense)


def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class BertLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(BertLayerNorm, self).__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class GraphResBlock(torch.nn.Module):
    """
    Graph Residual Block similar to the Bottleneck Residual Block in ResNet
    """
    def __init__(self, in_channels, out_channels, mesh_type='body'):
        super(GraphResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lin1 = GraphLinear(in_channels, out_channels // 2)
        self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type)
        self.lin2 = GraphLinear(out_channels // 2, out_channels)
        self.skip_conv = GraphLinear(in_channels, out_channels)
        # print('Use BertLayerNorm in GraphResBlock')
        self.pre_norm = BertLayerNorm(in_channels)
        self.norm1 = BertLayerNorm(out_channels // 2)
        self.norm2 = BertLayerNorm(out_channels // 2)

    def forward(self, x):
        trans_y = F.relu(self.pre_norm(x)).transpose(1, 2)
        y = self.lin1(trans_y).transpose(1, 2)

        y = F.relu(self.norm1(y))
        y = self.conv(y)

        trans_y = F.relu(self.norm2(y)).transpose(1, 2)
        y = self.lin2(trans_y).transpose(1, 2)

        z = x + y

        return z

# Commented-out alternative GraphResBlock kept as in the original:
# class GraphResBlock(torch.nn.Module):
#     """
#     Graph Residual Block similar to the Bottleneck Residual Block in ResNet
#     """
#     def __init__(self, in_channels, out_channels, mesh_type='body'):
#         super(GraphResBlock, self).__init__()
#         self.in_channels = in_channels
#         self.out_channels = out_channels
#         self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
#         print('Use BertLayerNorm and GeLU in GraphResBlock')
#         self.norm = BertLayerNorm(self.out_channels)
#     def forward(self, x):
#         y = self.conv(x)
#         y = self.norm(y)
#         y = gelu(y)
#         z = x + y
#         return z

class GraphLinear(torch.nn.Module):
    """
    Generalization of 1x1 convolutions on Graphs
    """
    def __init__(self, in_channels, out_channels):
        super(GraphLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels))
        self.b = torch.nn.Parameter(torch.FloatTensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        w_stdv = 1 / (self.in_channels * self.out_channels)
        self.W.data.uniform_(-w_stdv, w_stdv)
        self.b.data.uniform_(-w_stdv, w_stdv)

    def forward(self, x):
        return torch.matmul(self.W[None, :], x) + self.b[None, :, None]

class GraphConvolution(torch.nn.Module):
    """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
    def __init__(self, in_features, out_features, mesh='body', bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        if mesh == 'body':
            adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt')
            adj_mat_value = torch.load(data_path / 'smpl_431_adjmat_values.pt')
            adj_mat_size = torch.load(data_path / 'smpl_431_adjmat_size.pt')
        elif mesh == 'hand':
            adj_indices = torch.load(data_path / 'mano_195_adjmat_indices.pt')
            adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt')
            adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt')

        self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device)

        self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # stdv = 1. / math.sqrt(self.weight.size(1))
        stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        if x.ndimension() == 2:
            support = torch.matmul(x, self.weight)
            output = torch.matmul(self.adjmat, support)
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            output = []
            for i in range(x.shape[0]):
                support = torch.matmul(x[i], self.weight)
                # output.append(torch.matmul(self.adjmat, support))
                output.append(spmm(self.adjmat, support))
            output = torch.stack(output, dim=0)
            if self.bias is not None:
                output = output + self.bias
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
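A CPU-only shape check for BertLayerNorm and gelu above (the graph layers need the adjacency .pt files and a CUDA device, so they are omitted here):

import torch
x = torch.randn(2, 195, 64)   # (batch, mesh vertices, hidden)
ln = BertLayerNorm(64)
y = gelu(ln(x))
print(y.shape)                # torch.Size([2, 195, 64])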
184 mesh_graphormer/modeling/_mano.py Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""
|
||||||
|
This file contains the MANO defination and mesh sampling operations for MANO mesh
|
||||||
|
|
||||||
|
Adapted from opensource projects
|
||||||
|
MANOPTH (https://github.com/hassony2/manopth)
|
||||||
|
Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
|
||||||
|
GraphCMR (https://github.com/nkolot/GraphCMR/)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import division
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import os.path as osp
|
||||||
|
import json
|
||||||
|
import code
|
||||||
|
from manopth.manolayer import ManoLayer
|
||||||
|
import scipy.sparse
|
||||||
|
import mesh_graphormer.modeling.data.config as cfg
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
sparse_to_dense = lambda x: x
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
class MANO(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MANO, self).__init__()
|
||||||
|
|
||||||
|
self.mano_dir = str(Path(__file__).parent / "data")
|
||||||
|
self.layer = self.get_layer()
|
||||||
|
self.vertex_num = 778
|
||||||
|
self.face = self.layer.th_faces.numpy()
|
||||||
|
self.joint_regressor = self.layer.th_J_regressor.numpy()
|
||||||
|
|
||||||
|
self.joint_num = 21
|
||||||
|
self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
|
||||||
|
self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) )
|
||||||
|
self.root_joint_idx = self.joints_name.index('Wrist')
|
||||||
|
|
||||||
|
# add fingertips to joint_regressor
|
||||||
|
self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand)
|
||||||
|
thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
|
||||||
|
indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
|
||||||
|
middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
|
||||||
|
ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
|
||||||
|
pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
|
||||||
|
self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot))
|
||||||
|
self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:]
|
||||||
|
joint_regressor_torch = torch.from_numpy(self.joint_regressor).float()
|
||||||
|
self.register_buffer('joint_regressor_torch', joint_regressor_torch)
|
||||||
|
|
||||||
|
def get_layer(self):
|
||||||
|
return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model
|
||||||
|
|
||||||
|
def get_3d_joints(self, vertices):
|
||||||
|
"""
|
||||||
|
This method is used to get the joint locations from the SMPL mesh
|
||||||
|
Input:
|
||||||
|
vertices: size = (B, 778, 3)
|
||||||
|
Output:
|
||||||
|
3D joints: size = (B, 21, 3)
|
||||||
|
"""
|
||||||
|
joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch])
|
||||||
|
return joints
|
||||||
|
|
||||||
|
|
||||||
|
class SparseMM(torch.autograd.Function):
|
||||||
|
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
|
||||||
|
The builtin matrix multiplication operation does not support backpropagation in some cases.
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, sparse, dense):
|
||||||
|
ctx.req_grad = dense.requires_grad
|
||||||
|
ctx.save_for_backward(sparse)
|
||||||
|
return torch.matmul(sparse, dense)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
grad_input = None
|
||||||
|
sparse, = ctx.saved_tensors
|
||||||
|
if ctx.req_grad:
|
||||||
|
grad_input = torch.matmul(sparse.t(), grad_output)
|
||||||
|
return None, grad_input
|
||||||
|
|
||||||
|
def spmm(sparse, dense):
|
||||||
|
sparse = sparse.to(device)
|
||||||
|
dense = dense.to(device)
|
||||||
|
return SparseMM.apply(sparse, dense)
|
||||||
|
|
||||||
|
|
||||||
|
def scipy_to_pytorch(A, U, D):
    """Convert scipy sparse matrices to pytorch sparse matrix."""
    ptU = []
    ptD = []

    for i in range(len(U)):
        u = scipy.sparse.coo_matrix(U[i])
        # use a distinct name for the index tensor to avoid shadowing the loop variable
        indices = torch.LongTensor(np.array([u.row, u.col]))
        values = torch.FloatTensor(u.data)
        ptU.append(sparse_to_dense(torch.sparse_coo_tensor(indices, values, u.shape)))

    for i in range(len(D)):
        d = scipy.sparse.coo_matrix(D[i])
        indices = torch.LongTensor(np.array([d.row, d.col]))
        values = torch.FloatTensor(d.data)
        ptD.append(sparse_to_dense(torch.sparse_coo_tensor(indices, values, d.shape)))

    return ptU, ptD


def adjmat_sparse(adjmat, nsize=1):
    """Create row-normalized sparse graph adjacency matrix."""
    adjmat = scipy.sparse.csr_matrix(adjmat)
    if nsize > 1:
        orig_adjmat = adjmat.copy()
        for _ in range(1, nsize):
            adjmat = adjmat * orig_adjmat
    adjmat.data = np.ones_like(adjmat.data)
    for i in range(adjmat.shape[0]):
        adjmat[i,i] = 1
    num_neighbors = np.array(1 / adjmat.sum(axis=-1))
    adjmat = adjmat.multiply(num_neighbors)
    adjmat = scipy.sparse.coo_matrix(adjmat)
    indices = torch.LongTensor(np.array([adjmat.row, adjmat.col]))
    values = torch.from_numpy(adjmat.data).float()
    adjmat = sparse_to_dense(torch.sparse_coo_tensor(indices, values, adjmat.shape))
    return adjmat

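Editor's aside: adjmat_sparse adds self-loops and row-normalizes, so a graph convolution with the result averages each vertex with its neighbors. A toy check (with the identity sparse_to_dense above, the function returns a sparse COO tensor):

import numpy as np

A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]], dtype=np.float32)
norm = adjmat_sparse(A).to_dense()                # 3x3, self-loops included
assert np.allclose(norm.sum(dim=1).numpy(), 1.0)  # every row sums to 1
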
def get_graph_params(filename, nsize=1):
    """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
    data = np.load(filename, encoding='latin1', allow_pickle=True)
    A = data['A']
    U = data['U']
    D = data['D']
    U, D = scipy_to_pytorch(A, U, D)
    A = [adjmat_sparse(a, nsize=nsize) for a in A]
    return A, U, D

class Mesh(object):
    """Mesh object that is used for handling certain graph operations."""
    def __init__(self, filename=cfg.MANO_sampling_matrix,
                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
        self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
        # self._A = [a.to(device) for a in self._A]
        self._U = [u.to(device) for u in self._U]
        self._D = [d.to(device) for d in self._D]
        self.num_downsampling = num_downsampling

    def downsample(self, x, n1=0, n2=None):
        """Downsample mesh."""
        if n2 is None:
            n2 = self.num_downsampling
        if x.ndimension() < 3:
            for i in range(n1, n2):
                x = spmm(self._D[i], x)
        elif x.ndimension() == 3:
            out = []
            for i in range(x.shape[0]):
                y = x[i]
                for j in range(n1, n2):
                    y = spmm(self._D[j], y)
                out.append(y)
            x = torch.stack(out, dim=0)
        return x

    def upsample(self, x, n1=1, n2=0):
        """Upsample mesh."""
        if x.ndimension() < 3:
            for i in reversed(range(n2, n1)):
                x = spmm(self._U[i], x)
        elif x.ndimension() == 3:
            out = []
            for i in range(x.shape[0]):
                y = x[i]
                for j in reversed(range(n2, n1)):
                    y = spmm(self._U[j], y)
                out.append(y)
            x = torch.stack(out, dim=0)
        return x
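Editor's aside: a usage sketch under stated assumptions (cfg.MANO_sampling_matrix present, CUDA available). One downsampling level maps the 778 MANO vertices to the 195-vertex coarse mesh, and upsample() walks the same levels in reverse:

import torch

mesh = Mesh()                              # MANO sampling matrices, device='cuda'
verts = torch.rand(2, 778, 3).to("cuda")
verts_sub = mesh.downsample(verts)         # (2, 195, 3)
verts_full = mesh.upsample(verts_sub)      # (2, 778, 3), approximate inverse
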
283 mesh_graphormer/modeling/_smpl.py Normal file
@@ -0,0 +1,283 @@
"""
This file contains the definition of the SMPL model

It is adapted from the open source project GraphCMR (https://github.com/nkolot/GraphCMR/)
"""
from __future__ import division

import torch
import torch.nn as nn
import numpy as np
import scipy.sparse
try:
    import cPickle as pickle
except ImportError:
    import pickle

from mesh_graphormer.utils.geometric_layers import rodrigues
import mesh_graphormer.modeling.data.config as cfg


sparse_to_dense = lambda x: x
device = "cuda"

class SMPL(nn.Module):

    def __init__(self, gender='neutral'):
        super(SMPL, self).__init__()

        if gender == 'm':
            model_file = cfg.SMPL_Male
        elif gender == 'f':
            model_file = cfg.SMPL_Female
        else:
            model_file = cfg.SMPL_FILE

        smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1')
        J_regressor = smpl_model['J_regressor'].tocoo()
        row = J_regressor.row
        col = J_regressor.col
        data = J_regressor.data
        i = torch.LongTensor(np.array([row, col]))
        v = torch.FloatTensor(data)
        J_regressor_shape = [24, 6890]
        self.register_buffer('J_regressor', torch.sparse_coo_tensor(i, v, J_regressor_shape).to_dense())
        self.register_buffer('weights', torch.FloatTensor(smpl_model['weights']))
        self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs']))
        self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template']))
        self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs'])))
        self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64)))
        self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64)))
        id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
        self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))

        self.pose_shape = [24, 3]
        self.beta_shape = [10]
        self.translation_shape = [3]

        self.pose = torch.zeros(self.pose_shape)
        self.beta = torch.zeros(self.beta_shape)
        self.translation = torch.zeros(self.translation_shape)

        self.verts = None
        self.J = None
        self.R = None

        J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float()
        self.register_buffer('J_regressor_extra', J_regressor_extra)
        self.joints_idx = cfg.JOINTS_IDX

        J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float()
        self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct)

    def forward(self, pose, beta):
        device = pose.device
        batch_size = pose.shape[0]
        v_template = self.v_template[None, :]
        shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1)
        beta = beta[:, :, None]
        v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
        # batched sparse matmul not supported in pytorch
        J = []
        for i in range(batch_size):
            J.append(torch.matmul(self.J_regressor, v_shaped[i]))
        J = torch.stack(J, dim=0)
        # input is rotation matrices: (bs,24,3,3)
        if pose.ndimension() == 4:
            R = pose
        # input is axis-angle: (bs,72)
        elif pose.ndimension() == 2:
            pose_cube = pose.view(-1, 3) # (batch_size * 24, 3)
            R = rodrigues(pose_cube).view(batch_size, 24, 3, 3)
            R = R.view(batch_size, 24, 3, 3)
        I_cube = torch.eye(3)[None, None, :].to(device)
        # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1)
        lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1)
        posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1)
        v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3)
        J_ = J.clone()
        J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
        G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
        pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1)
        G_ = torch.cat([G_, pad_row], dim=2)
        G = [G_[:, 0].clone()]
        for i in range(1, 24):
            G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :]))
        G = torch.stack(G, dim=1)

        rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1)
        zeros = torch.zeros(batch_size, 24, 4, 3).to(device)
        rest = torch.cat([zeros, rest], dim=-1)
        rest = torch.matmul(G, rest)
        G = G - rest
        T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1)
        rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
        v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
        return v

    def get_joints(self, vertices):
        """
        This method is used to get the joint locations from the SMPL mesh
        Input:
            vertices: size = (B, 6890, 3)
        Output:
            3D joints: size = (B, 24, 3)  (the 38 regressed joints are sliced by cfg.JOINTS_IDX)
        """
        joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
        joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra])
        joints = torch.cat((joints, joints_extra), dim=1)
        joints = joints[:, cfg.JOINTS_IDX]
        return joints

    def get_h36m_joints(self, vertices):
        """
        This method is used to get the joint locations from the SMPL mesh
        Input:
            vertices: size = (B, 6890, 3)
        Output:
            3D joints: size = (B, 17, 3)  (the Human3.6M 17-joint set)
        """
        joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
        return joints

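Editor's aside: forward() above is plain linear blend skinning — shape blendshapes, pose blendshapes, then per-joint rigid transforms blended by the skinning weights. A usage sketch, assuming the SMPL pickle files referenced in cfg are present:

import torch

smpl = SMPL().to("cuda")
pose = torch.zeros(1, 72).to("cuda")       # axis-angle, 24 joints x 3
beta = torch.zeros(1, 10).to("cuda")       # shape coefficients
verts = smpl(pose, beta)                   # (1, 6890, 3) rest-pose mesh
joints = smpl.get_joints(verts)            # (1, 24, 3) regressed joints
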
class SparseMM(torch.autograd.Function):
    """Redefine sparse @ dense matrix multiplication to enable backpropagation.
    The builtin matrix multiplication operation does not support backpropagation in some cases.
    """
    @staticmethod
    def forward(ctx, sparse, dense):
        ctx.req_grad = dense.requires_grad
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        sparse, = ctx.saved_tensors
        if ctx.req_grad:
            grad_input = torch.matmul(sparse.t(), grad_output)
        return None, grad_input


def spmm(sparse, dense):
    sparse = sparse.to(device)
    dense = dense.to(device)
    return SparseMM.apply(sparse, dense)


def scipy_to_pytorch(A, U, D):
    """Convert scipy sparse matrices to pytorch sparse matrix."""
    ptU = []
    ptD = []

    for i in range(len(U)):
        u = scipy.sparse.coo_matrix(U[i])
        indices = torch.LongTensor(np.array([u.row, u.col]))
        values = torch.FloatTensor(u.data)
        ptU.append(sparse_to_dense(torch.sparse_coo_tensor(indices, values, u.shape)))

    for i in range(len(D)):
        d = scipy.sparse.coo_matrix(D[i])
        indices = torch.LongTensor(np.array([d.row, d.col]))
        values = torch.FloatTensor(d.data)
        ptD.append(sparse_to_dense(torch.sparse_coo_tensor(indices, values, d.shape)))

    return ptU, ptD


def adjmat_sparse(adjmat, nsize=1):
    """Create row-normalized sparse graph adjacency matrix."""
    adjmat = scipy.sparse.csr_matrix(adjmat)
    if nsize > 1:
        orig_adjmat = adjmat.copy()
        for _ in range(1, nsize):
            adjmat = adjmat * orig_adjmat
    adjmat.data = np.ones_like(adjmat.data)
    for i in range(adjmat.shape[0]):
        adjmat[i,i] = 1
    num_neighbors = np.array(1 / adjmat.sum(axis=-1))
    adjmat = adjmat.multiply(num_neighbors)
    adjmat = scipy.sparse.coo_matrix(adjmat)
    indices = torch.LongTensor(np.array([adjmat.row, adjmat.col]))
    values = torch.from_numpy(adjmat.data).float()
    adjmat = sparse_to_dense(torch.sparse_coo_tensor(indices, values, adjmat.shape))
    return adjmat


def get_graph_params(filename, nsize=1):
    """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
    data = np.load(filename, encoding='latin1', allow_pickle=True)
    A = data['A']
    U = data['U']
    D = data['D']
    U, D = scipy_to_pytorch(A, U, D)
    A = [adjmat_sparse(a, nsize=nsize) for a in A]
    return A, U, D

class Mesh(object):
    """Mesh object that is used for handling certain graph operations."""
    def __init__(self, filename=cfg.SMPL_sampling_matrix,
                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
        self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
        # self._A = [a.to(device) for a in self._A]
        self._U = [u.to(device) for u in self._U]
        self._D = [d.to(device) for d in self._D]
        self.num_downsampling = num_downsampling

        # load template vertices from SMPL and normalize them
        smpl = SMPL()
        ref_vertices = smpl.v_template
        center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
        ref_vertices -= center
        ref_vertices /= ref_vertices.abs().max().item()

        self._ref_vertices = ref_vertices.to(device)
        self.faces = smpl.faces.int().to(device)

    # @property
    # def adjmat(self):
    #     """Return the graph adjacency matrix at the specified subsampling level."""
    #     return self._A[self.num_downsampling].float()

    @property
    def ref_vertices(self):
        """Return the template vertices at the specified subsampling level."""
        ref_vertices = self._ref_vertices
        for i in range(self.num_downsampling):
            ref_vertices = torch.spmm(self._D[i], ref_vertices)
        return ref_vertices

    def downsample(self, x, n1=0, n2=None):
        """Downsample mesh."""
        if n2 is None:
            n2 = self.num_downsampling
        if x.ndimension() < 3:
            for i in range(n1, n2):
                x = spmm(self._D[i], x)
        elif x.ndimension() == 3:
            out = []
            for i in range(x.shape[0]):
                y = x[i]
                for j in range(n1, n2):
                    y = spmm(self._D[j], y)
                out.append(y)
            x = torch.stack(out, dim=0)
        return x

    def upsample(self, x, n1=1, n2=0):
        """Upsample mesh."""
        if x.ndimension() < 3:
            for i in reversed(range(n2, n1)):
                x = spmm(self._U[i], x)
        elif x.ndimension() == 3:
            out = []
            for i in range(x.shape[0]):
                y = x[i]
                for j in reversed(range(n2, n1)):
                    y = spmm(self._U[j], y)
                out.append(y)
            x = torch.stack(out, dim=0)
        return x
17 mesh_graphormer/modeling/bert/__init__.py Normal file
@@ -0,0 +1,17 @@
__version__ = "1.0.0"

from .modeling_bert import (BertConfig, BertModel,
                            load_tf_weights_in_bert)

from .modeling_graphormer import Graphormer

from .e2e_body_network import Graphormer_Body_Network

from .e2e_hand_network import Graphormer_Hand_Network

CONFIG_NAME = "config.json"

from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME,
                             PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)

from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE)
16 mesh_graphormer/modeling/bert/bert-base-uncased/config.json Normal file
@@ -0,0 +1,16 @@
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}
103 mesh_graphormer/modeling/bert/e2e_body_network.py Normal file
@@ -0,0 +1,103 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

"""

import torch
import mesh_graphormer.modeling.data.config as cfg

device = "cuda"

class Graphormer_Body_Network(torch.nn.Module):
    '''
    End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
    '''
    def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
        super(Graphormer_Body_Network, self).__init__()
        self.config = config
        self.config.device = device
        self.backbone = backbone
        self.trans_encoder = trans_encoder
        self.upsampling = torch.nn.Linear(431, 1723)
        self.upsampling2 = torch.nn.Linear(1723, 6890)
        self.cam_param_fc = torch.nn.Linear(3, 1)
        self.cam_param_fc2 = torch.nn.Linear(431, 250)
        self.cam_param_fc3 = torch.nn.Linear(250, 3)
        self.grid_feat_dim = torch.nn.Linear(1024, 2051)

    def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False):
        batch_size = images.size(0)
        # Generate T-pose template mesh
        template_pose = torch.zeros((1,72))
        template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
        template_pose = template_pose.to(device)
        template_betas = torch.zeros((1,10)).to(device)
        template_vertices = smpl(template_pose, template_betas)

        # template mesh simplification
        template_vertices_sub = mesh_sampler.downsample(template_vertices)
        template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2)

        # template mesh-to-joint regression
        template_3d_joints = smpl.get_h36m_joints(template_vertices)
        template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
        template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:]
        num_joints = template_3d_joints.shape[1]

        # normalize
        template_3d_joints = template_3d_joints - template_pelvis[:, None, :]
        template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :]

        # concatenate template joints and template vertices, and then duplicate to batch size
        ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1)
        ref_vertices = ref_vertices.expand(batch_size, -1, -1)

        # extract grid features and global image features using a CNN backbone
        image_feat, grid_feat = self.backbone(images)
        # concatenate image feat and 3d mesh template
        image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
        # process grid features
        grid_feat = torch.flatten(grid_feat, start_dim=2)
        grid_feat = grid_feat.transpose(1,2)
        grid_feat = self.grid_feat_dim(grid_feat)
        # concatenate image feat and template mesh to form the joint/vertex queries
        features = torch.cat([ref_vertices, image_feat], dim=2)
        # prepare input tokens including joint/vertex queries and grid features
        features = torch.cat([features, grid_feat],dim=1)

        if is_train:
            # apply mask vertex/joint modeling
            # meta_masks is a tensor of all the masks, randomly generated in dataloader
            # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
            special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
            features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)

        # forward pass
        if self.config.output_attentions:
            features, hidden_states, att = self.trans_encoder(features)
        else:
            features = self.trans_encoder(features)

        pred_3d_joints = features[:,:num_joints,:]
        pred_vertices_sub2 = features[:,num_joints:-49,:]

        # learn camera parameters
        x = self.cam_param_fc(pred_vertices_sub2)
        x = x.transpose(1,2)
        x = self.cam_param_fc2(x)
        x = self.cam_param_fc3(x)
        cam_param = x.transpose(1,2)
        cam_param = cam_param.squeeze()

        temp_transpose = pred_vertices_sub2.transpose(1,2)
        pred_vertices_sub = self.upsampling(temp_transpose)
        pred_vertices_full = self.upsampling2(pred_vertices_sub)
        pred_vertices_sub = pred_vertices_sub.transpose(1,2)
        pred_vertices_full = pred_vertices_full.transpose(1,2)

        if self.config.output_attentions:
            return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att
        else:
            return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full
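Editor's aside: the masked vertex/joint modeling above is a convex combination — wherever meta_masks is 0, the query token is replaced by the constant 0.01 [MASK] vector. A minimal illustration:

import torch

tokens = torch.rand(1, 4, 3)                          # 4 query tokens
meta_masks = torch.tensor([1., 1., 0., 1.]).view(1, 4, 1)
mask_token = torch.ones_like(tokens) * 0.01
masked = tokens * meta_masks + mask_token * (1 - meta_masks)
# masked[0, 2] is now all 0.01; the other tokens are untouched
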
94 mesh_graphormer/modeling/bert/e2e_hand_network.py Normal file
@@ -0,0 +1,94 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

"""

import torch
import mesh_graphormer.modeling.data.config as cfg

device = "cuda"

class Graphormer_Hand_Network(torch.nn.Module):
    '''
    End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
    '''
    def __init__(self, args, config, backbone, trans_encoder):
        super(Graphormer_Hand_Network, self).__init__()
        self.config = config
        self.backbone = backbone
        self.trans_encoder = trans_encoder
        self.upsampling = torch.nn.Linear(195, 778)
        self.cam_param_fc = torch.nn.Linear(3, 1)
        self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
        self.cam_param_fc3 = torch.nn.Linear(150, 3)
        self.grid_feat_dim = torch.nn.Linear(1024, 2051)

    def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False):
        batch_size = images.size(0)
        # Generate T-pose template mesh
        template_pose = torch.zeros((1,48))
        template_pose = template_pose.to(device)
        template_betas = torch.zeros((1,10)).to(device)
        template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
        template_vertices = template_vertices/1000.0
        template_3d_joints = template_3d_joints/1000.0

        template_vertices_sub = mesh_sampler.downsample(template_vertices)

        # normalize
        template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
        template_3d_joints = template_3d_joints - template_root[:, None, :]
        template_vertices = template_vertices - template_root[:, None, :]
        template_vertices_sub = template_vertices_sub - template_root[:, None, :]
        num_joints = template_3d_joints.shape[1]

        # concatenate template joints and template vertices, and then duplicate to batch size
        ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1)
        ref_vertices = ref_vertices.expand(batch_size, -1, -1)

        # extract grid features and global image features using a CNN backbone
        image_feat, grid_feat = self.backbone(images)
        # concatenate image feat and mesh template
        image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
        # process grid features
        grid_feat = torch.flatten(grid_feat, start_dim=2)
        grid_feat = grid_feat.transpose(1,2)
        grid_feat = self.grid_feat_dim(grid_feat)
        # concatenate image feat and template mesh to form the joint/vertex queries
        features = torch.cat([ref_vertices, image_feat], dim=2)
        # prepare input tokens including joint/vertex queries and grid features
        features = torch.cat([features, grid_feat],dim=1)

        if is_train:
            # apply mask vertex/joint modeling
            # meta_masks is a tensor of all the masks, randomly generated in dataloader
            # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
            special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
            features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)

        # forward pass
        if self.config.output_attentions:
            features, hidden_states, att = self.trans_encoder(features)
        else:
            features = self.trans_encoder(features)

        pred_3d_joints = features[:,:num_joints,:]
        pred_vertices_sub = features[:,num_joints:-49,:]

        # learn camera parameters
        x = self.cam_param_fc(features[:,:-49,:])
        x = x.transpose(1,2)
        x = self.cam_param_fc2(x)
        x = self.cam_param_fc3(x)
        cam_param = x.transpose(1,2)
        cam_param = cam_param.squeeze()

        temp_transpose = pred_vertices_sub.transpose(1,2)
        pred_vertices = self.upsampling(temp_transpose)
        pred_vertices = pred_vertices.transpose(1,2)

        if self.config.output_attentions:
            return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att
        else:
            return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices
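Editor's aside: this file never consumes cam_param itself; in Graphormer-style pipelines the three values are conventionally read as weak-perspective scale s and translation (tx, ty). A sketch of that convention (an assumption here, not code from this repository):

import torch

def orthographic_projection(X, camera):
    """Weak-perspective projection: X (B, N, 3), camera (B, 3) as [s, tx, ty]."""
    camera = camera.view(-1, 1, 3)
    return camera[:, :, 0:1] * (X[:, :, :2] + camera[:, :, 1:])  # (B, N, 2)
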
1 mesh_graphormer/modeling/bert/file_utils.py Normal file
@@ -0,0 +1 @@
from transformers.file_utils import *
1 mesh_graphormer/modeling/bert/modeling_bert.py Normal file
@@ -0,0 +1 @@
from transformers.models.bert.modeling_bert import *
328 mesh_graphormer/modeling/bert/modeling_graphormer.py Normal file
@@ -0,0 +1,328 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

"""

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math
import os
import code
import torch
from torch import nn
from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock
from .modeling_utils import prune_linear_layer
LayerNormClass = torch.nn.LayerNorm
BertLayerNorm = torch.nn.LayerNorm

device = "cuda"


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None,
                history_state=None):
        if history_state is not None:
            x_states = torch.cat([history_state, hidden_states], dim=1)
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(x_states)
            mixed_value_layer = self.value(x_states)
        else:
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs

class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        for head in heads:
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Update hyper params
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads

    def forward(self, input_tensor, attention_mask, head_mask=None,
                history_state=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask,
                                 history_state)
        attention_output = self.output(self_outputs[0], input_tensor)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

class GraphormerLayer(nn.Module):
    def __init__(self, config):
        super(GraphormerLayer, self).__init__()
        self.attention = BertAttention(config)
        self.has_graph_conv = config.graph_conv
        self.mesh_type = config.mesh_type

        if self.has_graph_conv:
            self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)

        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def MHA_GCN(self, hidden_states, attention_mask, head_mask=None,
                history_state=None):
        attention_outputs = self.attention(hidden_states, attention_mask,
                                           head_mask, history_state)
        attention_output = attention_outputs[0]

        if self.has_graph_conv:
            if self.mesh_type == 'body':
                joints = attention_output[:,0:14,:]
                vertices = attention_output[:,14:-49,:]
                img_tokens = attention_output[:,-49:,:]

            elif self.mesh_type == 'hand':
                joints = attention_output[:,0:21,:]
                vertices = attention_output[:,21:-49,:]
                img_tokens = attention_output[:,-49:,:]

            vertices = self.graph_conv(vertices)
            joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1)
        else:
            joints_vertices = attention_output

        intermediate_output = self.intermediate(joints_vertices)
        layer_output = self.output(intermediate_output, joints_vertices)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
        return outputs

    def forward(self, hidden_states, attention_mask, head_mask=None,
                history_state=None):
        return self.MHA_GCN(hidden_states, attention_mask, head_mask, history_state)

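Editor's aside: the slicing above relies on a fixed token layout of [joint queries | vertex queries | 49 grid-feature tokens]. For the hand model that is 21 + 195 + 49 = 265 tokens, so:

# attention_output[:, 0:21]   -> joint tokens
# attention_output[:, 21:-49] -> 195 coarse-mesh vertex tokens (fed to the graph conv)
# attention_output[:, -49:]   -> 7x7 image grid tokens (bypass the graph conv)
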
class GraphormerEncoder(nn.Module):
    def __init__(self, config):
        super(GraphormerEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None,
                encoder_history_states=None):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            history_state = None if encoder_history_states is None else encoder_history_states[i]
            layer_outputs = layer_module(
                hidden_states, attention_mask, head_mask[i],
                history_state)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)

        return outputs  # outputs, (hidden states), (attentions)

class EncoderBlock(BertPreTrainedModel):
    def __init__(self, config):
        super(EncoderBlock, self).__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = GraphormerEncoder(config)
        self.pooler = BertPooler(config)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.img_dim = config.img_feature_dim

        try:
            self.use_img_layernorm = config.use_img_layernorm
        except AttributeError:
            self.use_img_layernorm = None

        self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if self.use_img_layernorm:
            self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps)

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None):
        batch_size = len(img_feats)
        seq_length = len(img_feats[0])
        input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long).to(device)

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError

        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
                head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        # Project input token features to the specified hidden size
        img_embedding_output = self.img_embedding(img_feats)

        # We empirically observe that adding an additional learnable position embedding leads to more stable training
        embeddings = position_embeddings + img_embedding_output

        if self.use_img_layernorm:
            embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        encoder_outputs = self.encoder(embeddings,
                                       extended_attention_mask, head_mask=head_mask)
        sequence_output = encoder_outputs[0]

        outputs = (sequence_output,)
        if self.config.output_hidden_states:
            all_hidden_states = encoder_outputs[1]
            outputs = outputs + (all_hidden_states,)
        if self.config.output_attentions:
            all_attentions = encoder_outputs[-1]
            outputs = outputs + (all_attentions,)

        return outputs

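Editor's aside: the extended-attention-mask trick above converts a {0,1} visibility mask into an additive bias that broadcasts over heads and query positions — 0 for visible keys, -10000 (effectively minus infinity after softmax) for masked ones:

import torch

attention_mask = torch.tensor([[1, 1, 0]])             # (batch, seq)
extended = attention_mask[:, None, None, :].float()    # (batch, 1, 1, seq)
extended = (1.0 - extended) * -10000.0
# adding `extended` to raw attention scores suppresses the masked key
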
class Graphormer(BertPreTrainedModel):
    '''
    The architecture of a transformer encoder block we used in Graphormer
    '''
    def __init__(self, config):
        super(Graphormer, self).__init__(config)
        self.config = config
        self.bert = EncoderBlock(config)
        self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim)
        self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim)

    def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
                next_sentence_label=None, position_ids=None, head_mask=None):
        '''
        # self.bert has three outputs
        # predictions[0]: output tokens
        # predictions[1]: all_hidden_states, if "self.config.output_hidden_states" is enabled
        # predictions[2]: attentions, if "self.config.output_attentions" is enabled
        '''
        predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                                attention_mask=attention_mask, head_mask=head_mask)

        # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification.
        pred_score = self.cls_head(predictions[0])
        res_img_feats = self.residual(img_feats)
        pred_score = pred_score + res_img_feats

        if self.config.output_attentions and self.config.output_hidden_states:
            return pred_score, predictions[1], predictions[-1]
        else:
            return pred_score
1 mesh_graphormer/modeling/bert/modeling_utils.py Normal file
@@ -0,0 +1 @@
from transformers.modeling_utils import *
BIN mesh_graphormer/modeling/data/J_regressor_extra.npy Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/MANO_LEFT.pkl Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/MANO_RIGHT.pkl Normal file (Binary file not shown.)
30 mesh_graphormer/modeling/data/README.md Normal file
@@ -0,0 +1,30 @@
# Extra data
Adapted from the open source projects [GraphCMR](https://github.com/nkolot/GraphCMR/) and [Pose2Mesh](https://github.com/hongsukchoi/Pose2Mesh_RELEASE)

Our code requires additional data to run smoothly.

### J_regressor_extra.npy
Joints regressor for joints or landmarks that are not included in the standard set of SMPL joints.

### J_regressor_h36m_correct.npy
Joints regressor reflecting the Human3.6M joints.

### mesh_downsampling.npz
Extra file with precomputed downsampling for the SMPL body mesh.

### mano_downsampling.npz
Extra file with precomputed downsampling for the MANO hand mesh.

### basicModel_neutral_lbs_10_207_0_v1.0.0.pkl
SMPL neutral model. Please visit the official website to download the file [http://smplify.is.tue.mpg.de/](http://smplify.is.tue.mpg.de/)

### basicModel_m_lbs_10_207_0_v1.0.0.pkl
SMPL male model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)

### basicModel_f_lbs_10_207_0_v1.0.0.pkl
SMPL female model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)

### MANO_RIGHT.pkl
MANO hand model. Please visit the official website to download the file [https://mano.is.tue.mpg.de/](https://mano.is.tue.mpg.de/)
47 mesh_graphormer/modeling/data/config.py Normal file
@@ -0,0 +1,47 @@
"""
This file contains definitions of useful data structures and the paths
for the datasets and data files necessary to run the code.

Adapted from the open source projects GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)

"""

from pathlib import Path
folder_path = Path(__file__).parent.parent
JOINT_REGRESSOR_TRAIN_EXTRA = folder_path / 'data/J_regressor_extra.npy'
JOINT_REGRESSOR_H36M_correct = folder_path / 'data/J_regressor_h36m_correct.npy'
SMPL_FILE = folder_path / 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
SMPL_Male = folder_path / 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl'
SMPL_Female = folder_path / 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl'
SMPL_sampling_matrix = folder_path / 'data/mesh_downsampling.npz'
MANO_FILE = folder_path / 'data/MANO_RIGHT.pkl'
MANO_sampling_matrix = folder_path / 'data/mano_downsampling.npz'

JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27]


"""
We follow the body joint definition, loss functions, and evaluation metrics from
the open source project GraphCMR (https://github.com/nkolot/GraphCMR/)

Each dataset uses different sets of joints.
We use a superset of 24 joints such that we include all joints from every dataset.
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
The joints used here are:
"""
J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
            'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
H36M_J17_NAME = ('Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head',
                 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10]

"""
We follow the hand joint definition and mesh topology from
the open source project Manopth (https://github.com/hassony2/manopth)

The hand joints used here are:
"""
J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
          'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
ROOT_INDEX = 0
BIN mesh_graphormer/modeling/data/mano_195_adjmat_indices.pt Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/mano_195_adjmat_size.pt Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/mano_195_adjmat_values.pt Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/mano_downsampling.npz Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/mesh_downsampling.npz Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/smpl_431_adjmat_indices.pt Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/smpl_431_adjmat_size.pt Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/smpl_431_adjmat_values.pt Normal file (Binary file not shown.)
BIN mesh_graphormer/modeling/data/smpl_431_faces.npy Normal file (Binary file not shown.)
9 mesh_graphormer/modeling/hrnet/config/__init__.py Normal file
@@ -0,0 +1,9 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from .default import _C as config
from .default import update_config
from .models import MODEL_EXTRAS
138 mesh_graphormer/modeling/hrnet/config/default.py Normal file
@@ -0,0 +1,138 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from yacs.config import CfgNode as CN


_C = CN()

_C.OUTPUT_DIR = ''
_C.LOG_DIR = ''
_C.DATA_DIR = ''
_C.GPUS = (0,)
_C.WORKERS = 4
_C.PRINT_FREQ = 20
_C.AUTO_RESUME = False
_C.PIN_MEMORY = True
_C.RANK = 0

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'cls_hrnet'
_C.MODEL.INIT_WEIGHTS = True
_C.MODEL.PRETRAINED = ''
_C.MODEL.NUM_JOINTS = 17
_C.MODEL.NUM_CLASSES = 1000
_C.MODEL.TAG_PER_JOINT = True
_C.MODEL.TARGET_TYPE = 'gaussian'
_C.MODEL.IMAGE_SIZE = [256, 256]  # width * height, ex: 192 * 256
_C.MODEL.HEATMAP_SIZE = [64, 64]  # width * height, ex: 24 * 32
_C.MODEL.SIGMA = 2
_C.MODEL.EXTRA = CN(new_allowed=True)

_C.LOSS = CN()
_C.LOSS.USE_OHKM = False
_C.LOSS.TOPK = 8
_C.LOSS.USE_TARGET_WEIGHT = True
_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False

# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'mpii'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'valid'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.HYBRID_JOINTS_TYPE = ''
_C.DATASET.SELECT_DATA = False

# training data augmentation
_C.DATASET.FLIP = True
_C.DATASET.SCALE_FACTOR = 0.25
_C.DATASET.ROT_FACTOR = 30
_C.DATASET.PROB_HALF_BODY = 0.0
_C.DATASET.NUM_JOINTS_HALF_BODY = 8
_C.DATASET.COLOR_RGB = False

# train
_C.TRAIN = CN()

_C.TRAIN.LR_FACTOR = 0.1
_C.TRAIN.LR_STEP = [90, 110]
_C.TRAIN.LR = 0.001

_C.TRAIN.OPTIMIZER = 'adam'
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.NESTEROV = False
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0

_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 140

_C.TRAIN.RESUME = False
_C.TRAIN.CHECKPOINT = ''

_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True

# testing
_C.TEST = CN()

# size of images for each device
_C.TEST.BATCH_SIZE_PER_GPU = 32
# Test Model Epoch
_C.TEST.FLIP_TEST = False
_C.TEST.POST_PROCESS = False
_C.TEST.SHIFT_HEATMAP = False

_C.TEST.USE_GT_BBOX = False

# nms
_C.TEST.IMAGE_THRE = 0.1
_C.TEST.NMS_THRE = 0.6
_C.TEST.SOFT_NMS = False
_C.TEST.OKS_THRE = 0.5
_C.TEST.IN_VIS_THRE = 0.0
_C.TEST.COCO_BBOX_FILE = ''
_C.TEST.BBOX_THRE = 1.0
_C.TEST.MODEL_FILE = ''

# debug
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False
_C.DEBUG.SAVE_BATCH_IMAGES_GT = False
_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
_C.DEBUG.SAVE_HEATMAPS_GT = False
_C.DEBUG.SAVE_HEATMAPS_PRED = False


def update_config(cfg, config_file):
    cfg.defrost()
    cfg.merge_from_file(config_file)
    cfg.freeze()


if __name__ == '__main__':
    import sys
    with open(sys.argv[1], 'w') as f:
        print(_C, file=f)
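Editor's aside: update_config merges an experiment YAML over the defaults above and refreezes the node. A usage sketch — the YAML path below is hypothetical:

from mesh_graphormer.modeling.hrnet.config import config, update_config

update_config(config, 'experiments/cls_hrnet_w64.yaml')  # hypothetical file
print(config.MODEL.NAME, config.TRAIN.LR)
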
47 mesh_graphormer/modeling/hrnet/config/models.py Normal file
@@ -0,0 +1,47 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from yacs.config import CfgNode as CN

# high_resolution_net related params for classification
POSE_HIGH_RESOLUTION_NET = CN()
POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True

POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'

MODEL_EXTRAS = {
    'cls_hrnet': POSE_HIGH_RESOLUTION_NET,
}
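These stage layouts are looked up by backbone name. A small sketch of the
intended access pattern (yacs CfgNode supports attribute access):

    from mesh_graphormer.modeling.hrnet.config.models import MODEL_EXTRAS

    extra = MODEL_EXTRAS['cls_hrnet']
    # four parallel branches at stage 4, one resolution per channel width
    assert extra.STAGE4.NUM_BRANCHES == len(extra.STAGE4.NUM_CHANNELS) == 4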
523
mesh_graphormer/modeling/hrnet/hrnet_cls_net.py
Normal file
@@ -0,0 +1,523 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# Modified by Kevin Lin (keli@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import functools

import numpy as np

import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
import code

BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # lower-resolution branch j: 1x1 conv, then upsample to branch i
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i],
                                       momentum=BN_MOMENTUM),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # higher-resolution branch j: chain of stride-2 3x3 convs
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM),
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}


class HighResolutionNet(nn.Module):

    def __init__(self, cfg, **kwargs):
        super(HighResolutionNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)

        self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion*num_channels

        self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        # Classification Head
        self.incre_modules, self.downsamp_modules, \
            self.final_layer = self._make_head(pre_stage_channels)

        self.classifier = nn.Linear(2048, 1000)

    def _make_head(self, pre_stage_channels):
        head_block = Bottleneck
        head_channels = [32, 64, 128, 256]

        # Increasing the #channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        incre_modules = []
        for i, channels in enumerate(pre_stage_channels):
            incre_module = self._make_layer(head_block,
                                            channels,
                                            head_channels[i],
                                            1,
                                            stride=1)
            incre_modules.append(incre_module)
        incre_modules = nn.ModuleList(incre_modules)

        # downsampling modules
        downsamp_modules = []
        for i in range(len(pre_stage_channels)-1):
            in_channels = head_channels[i] * head_block.expansion
            out_channels = head_channels[i+1] * head_block.expansion

            downsamp_module = nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=3,
                          stride=2,
                          padding=1),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True)
            )

            downsamp_modules.append(downsamp_module)
        downsamp_modules = nn.ModuleList(downsamp_modules)

        final_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=head_channels[3] * head_block.expansion,
                out_channels=2048,
                kernel_size=1,
                stride=1,
                padding=0
            ),
            nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True)
        )

        return incre_modules, downsamp_modules, final_layer

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        nn.BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used in the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        # Classification Head
        y = self.incre_modules[0](y_list[0])
        for i in range(len(self.downsamp_modules)):
            y = self.incre_modules[i+1](y_list[i+1]) + \
                self.downsamp_modules[i](y)

        y = self.final_layer(y)

        if torch._C._get_tracing_state():
            y = y.flatten(start_dim=2).mean(dim=2)
        else:
            y = F.avg_pool2d(y, kernel_size=y.size()
                             [2:]).view(y.size(0), -1)

        # y = self.classifier(y)

        return y

    def init_weights(self, pretrained='',):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained)
            logger.info('=> loading pretrained model {}'.format(pretrained))
            print('=> loading pretrained model {}'.format(pretrained))
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in model_dict.keys()}
            # for k, _ in pretrained_dict.items():
            #     logger.info(
            #         '=> loading {} pretrained model {}'.format(k, pretrained))
            #     print('=> loading {} pretrained model {}'.format(k, pretrained))
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)
        # code.interact(local=locals())


def get_cls_net(config, pretrained, **kwargs):
    model = HighResolutionNet(config, **kwargs)
    model.init_weights(pretrained=pretrained)
    return model
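A minimal construction sketch for this backbone (the YAML name is a
hypothetical placeholder; it must define the MODEL.EXTRA stage layout that
HighResolutionNet.__init__ reads):

    from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
    from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
    from mesh_graphormer.modeling.hrnet.hrnet_cls_net import get_cls_net

    hrnet_update_config(hrnet_config, 'cls_hrnet.yaml')
    # pretrained='' is safe: os.path.isfile('') is False, so only the
    # Kaiming/constant init runs and no checkpoint is loaded
    model = get_cls_net(hrnet_config, pretrained='')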
524
mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py
Normal file
@@ -0,0 +1,524 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# Modified by Kevin Lin (keli@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import functools

import numpy as np

import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
import code

BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # lower-resolution branch j: 1x1 conv, then upsample to branch i
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i],
                                       momentum=BN_MOMENTUM),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # higher-resolution branch j: chain of stride-2 3x3 convs
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM),
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck
}


class HighResolutionNet(nn.Module):

    def __init__(self, cfg, **kwargs):
        super(HighResolutionNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)

        self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion*num_channels

        self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        # Classification Head
        self.incre_modules, self.downsamp_modules, \
            self.final_layer = self._make_head(pre_stage_channels)

        self.classifier = nn.Linear(2048, 1000)

    def _make_head(self, pre_stage_channels):
        head_block = Bottleneck
        head_channels = [32, 64, 128, 256]

        # Increasing the #channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        incre_modules = []
        for i, channels in enumerate(pre_stage_channels):
            incre_module = self._make_layer(head_block,
                                            channels,
                                            head_channels[i],
                                            1,
                                            stride=1)
            incre_modules.append(incre_module)
        incre_modules = nn.ModuleList(incre_modules)

        # downsampling modules
        downsamp_modules = []
        for i in range(len(pre_stage_channels)-1):
            in_channels = head_channels[i] * head_block.expansion
            out_channels = head_channels[i+1] * head_block.expansion

            downsamp_module = nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=3,
                          stride=2,
                          padding=1),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True)
            )

            downsamp_modules.append(downsamp_module)
        downsamp_modules = nn.ModuleList(downsamp_modules)

        final_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=head_channels[3] * head_block.expansion,
                out_channels=2048,
                kernel_size=1,
                stride=1,
                padding=0
            ),
            nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True)
        )

        return incre_modules, downsamp_modules, final_layer

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        nn.BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used in the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        # Classification Head
        y = self.incre_modules[0](y_list[0])
        for i in range(len(self.downsamp_modules)):
            y = self.incre_modules[i+1](y_list[i+1]) + \
                self.downsamp_modules[i](y)

        yy = self.final_layer(y)

        if torch._C._get_tracing_state():
            yy = yy.flatten(start_dim=2).mean(dim=2)
        else:
            yy = F.avg_pool2d(yy, kernel_size=yy.size()
                              [2:]).view(yy.size(0), -1)

        # y = self.classifier(y)
        return yy, y

    def init_weights(self, pretrained='',):
        logger.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained)
            logger.info('=> loading pretrained model {}'.format(pretrained))
            print('=> loading pretrained model {}'.format(pretrained))
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in model_dict.keys()}
            # for k, _ in pretrained_dict.items():
            #     logger.info(
            #         '=> loading {} pretrained model {}'.format(k, pretrained))
            #     print('=> loading {} pretrained model {}'.format(k, pretrained))
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)
        # code.interact(local=locals())


def get_cls_net_gridfeat(config, pretrained, **kwargs):
    model = HighResolutionNet(config, **kwargs)
    model.init_weights(pretrained=pretrained)
    return model
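This file differs from hrnet_cls_net.py only at the end of forward: besides
the pooled vector it also returns the pre-pooling grid feature map, which is
what the Graphormer networks consume. A shape sketch, assuming a 224x224 input
and the default stage widths from models.py:

    backbone = get_cls_net_gridfeat(hrnet_config, pretrained='')
    global_feat, grid_feat = backbone(torch.randn(1, 3, 224, 224))
    # global_feat: (1, 2048) pooled vector
    # grid_feat:   (1, 1024, 7, 7) feature map taken before final_layer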
750
mesh_graphormer/tools/run_gphmer_bodymesh.py
Normal file
@@ -0,0 +1,750 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

Training and evaluation code for
3D human body mesh reconstruction from an image
"""

from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
from mesh_graphormer.modeling._smpl import SMPL, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.datasets.build import make_data_loader

from mesh_graphormer.utils.logger import setup_logger
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from mesh_graphormer.utils.geometric_layers import orthographic_projection


device = "cuda"

from azureml.core.run import Run
aml_run = Run.get_context()

# note: `logger` used below is a module-global expected to be configured via
# the setup_logger import above before these helpers run (not shown in this hunk)


def save_checkpoint(model, args, epoch, iteration, num_trial=10):
    checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
        epoch, iteration))
    if not is_main_process():
        return checkpoint_dir
    mkdir(checkpoint_dir)
    model_to_save = model.module if hasattr(model, 'module') else model
    for i in range(num_trial):
        try:
            torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
            torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
            torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
            logger.info("Save checkpoint to {}".format(checkpoint_dir))
            break
        except:
            pass
    else:
        logger.info("Failed to save checkpoint after {} trials.".format(num_trial))
    return checkpoint_dir
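Note the for/else in save_checkpoint above: in Python the else branch of a for
loop runs only when the loop finishes without hitting break, i.e. here only
when every one of the num_trial save attempts raised. A minimal sketch of the
pattern in isolation (try_once is a hypothetical helper):

    for attempt in range(3):
        if try_once():
            break            # success: the else clause is skipped
    else:
        print('all attempts failed')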
def save_scores(args, split, mpjpe, pampjpe, mpve):
    eval_log = []
    res = {}
    res['mPJPE'] = mpjpe
    res['PAmPJPE'] = pampjpe
    res['mPVE'] = mpve
    eval_log.append(res)
    with open(op.join(args.output_dir, split+'_eval_logs.json'), 'w') as f:
        json.dump(eval_log, f)
    logger.info("Save eval scores to {}".format(args.output_dir))
    return


def adjust_learning_rate(optimizer, epoch, args):
    """
    Sets the learning rate to the initial LR decayed by x every y epochs
    x = 0.1, y = args.num_train_epochs/2.0 = 100
    """
    lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
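A quick worked example of the schedule above, using the value the docstring
assumes (args.num_train_epochs = 200, so y = 100):

    epoch // 100.0 is 0.0 for epochs 0-99 and 1.0 for epochs 100-199, so
    epochs   0- 99: lr = args.lr * 0.1**0 = args.lr
    epochs 100-199: lr = args.lr * 0.1**1 = 0.1 * args.lr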
def mean_per_joint_position_error(pred, gt, has_3d_joints):
    """
    Compute mPJPE
    """
    gt = gt[has_3d_joints == 1]
    gt = gt[:, :, :-1]
    pred = pred[has_3d_joints == 1]

    with torch.no_grad():
        gt_pelvis = (gt[:, 2, :] + gt[:, 3, :]) / 2
        gt = gt - gt_pelvis[:, None, :]
        pred_pelvis = (pred[:, 2, :] + pred[:, 3, :]) / 2
        pred = pred - pred_pelvis[:, None, :]
        error = torch.sqrt(((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
    return error
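Indices 2 and 3 in the slicing above are assumed to be the right and left hips
of the J14 joint convention used in this codebase, so their midpoint acts as a
pelvis proxy; centering both prediction and ground truth on it makes mPJPE
invariant to global translation. A small sanity check (the last gt channel is
the confidence the function strips off):

    pred = torch.randn(1, 14, 3)
    shift = torch.tensor([1.0, 2.0, 3.0])
    gt = torch.cat([pred + shift, torch.ones(1, 14, 1)], dim=-1)
    print(mean_per_joint_position_error(pred, gt, torch.ones(1)))  # ~[0.]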
def mean_per_vertex_error(pred, gt, has_smpl):
    """
    Compute mPVE
    """
    pred = pred[has_smpl == 1]
    gt = gt[has_smpl == 1]
    with torch.no_grad():
        error = torch.sqrt(((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
    return error


def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
    """
    Compute 2D reprojection loss if 2D keypoint annotations are available.
    The confidence (conf) is binary and indicates whether the keypoints exist or not.
    """
    conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
    loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
    return loss


def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, device):
    """
    Compute 3D keypoint loss if 3D keypoint annotations are available.
    """
    conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
    gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
    gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
    conf = conf[has_pose_3d == 1]
    pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
    if len(gt_keypoints_3d) > 0:
        gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2
        gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
        pred_pelvis = (pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2
        pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
        return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
    else:
        return torch.FloatTensor(1).fill_(0.).to(device)


def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, device):
    """
    Compute per-vertex loss if vertex annotations are available.
    """
    pred_vertices_with_shape = pred_vertices[has_smpl == 1]
    gt_vertices_with_shape = gt_vertices[has_smpl == 1]
    if len(gt_vertices_with_shape) > 0:
        return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
    else:
        return torch.FloatTensor(1).fill_(0.).to(device)


def rectify_pose(pose):
    pose = pose.copy()
    R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
    R_root = cv2.Rodrigues(pose[:3])[0]
    new_root = R_root.dot(R_mod)
    pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
    return pose
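rectify_pose flips the global (root) orientation of an axis-angle SMPL pose by
180 degrees about the x axis; only pose[:3] changes. A small sketch, assuming
the usual 72-dim numpy pose vector:

    pose = np.zeros(72)                 # identity root orientation
    flipped = rectify_pose(pose)
    # flipped[:3] is approximately [pi, 0, 0]; flipped[3:] is untouched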
def run(args, train_dataloader, val_dataloader, Graphormer_model, smpl, mesh_sampler, renderer):
|
||||||
|
smpl.eval()
|
||||||
|
max_iter = len(train_dataloader)
|
||||||
|
iters_per_epoch = max_iter // args.num_train_epochs
|
||||||
|
if iters_per_epoch<1000:
|
||||||
|
args.logging_steps = 500
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
|
||||||
|
lr=args.lr,
|
||||||
|
betas=(0.9, 0.999),
|
||||||
|
weight_decay=0)
|
||||||
|
|
||||||
|
# define loss function (criterion) and optimizer
|
||||||
|
criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
|
||||||
|
criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
|
||||||
|
criterion_vertices = torch.nn.L1Loss().to(device)
|
||||||
|
|
||||||
|
if args.distributed:
|
||||||
|
Graphormer_model = torch.nn.parallel.DistributedDataParallel(
|
||||||
|
Graphormer_model, device_ids=[args.local_rank],
|
||||||
|
output_device=args.local_rank,
|
||||||
|
find_unused_parameters=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
' '.join(
|
||||||
|
['Local rank: {o}', 'Max iteration: {a}', 'iters_per_epoch: {b}','num_train_epochs: {c}',]
|
||||||
|
).format(o=args.local_rank, a=max_iter, b=iters_per_epoch, c=args.num_train_epochs)
|
||||||
|
)
|
||||||
|
|
||||||
|
start_training_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
Graphormer_model.train()
|
||||||
|
batch_time = AverageMeter()
|
||||||
|
data_time = AverageMeter()
|
||||||
|
log_losses = AverageMeter()
|
||||||
|
log_loss_2djoints = AverageMeter()
|
||||||
|
log_loss_3djoints = AverageMeter()
|
||||||
|
log_loss_vertices = AverageMeter()
|
||||||
|
log_eval_metrics = EvalMetricsLogger()
|
||||||
|
|
||||||
|
for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):
|
||||||
|
# gc.collect()
|
||||||
|
# torch.cuda.empty_cache()
|
||||||
|
Graphormer_model.train()
|
||||||
|
iteration += 1
|
||||||
|
epoch = iteration // iters_per_epoch
|
||||||
|
batch_size = images.size(0)
|
||||||
|
adjust_learning_rate(optimizer, epoch, args)
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
|
||||||
|
images = images.to(device)
|
||||||
|
gt_2d_joints = annotations['joints_2d'].to(device)
|
||||||
|
gt_2d_joints = gt_2d_joints[:,cfg.J24_TO_J14,:]
|
||||||
|
has_2d_joints = annotations['has_2d_joints'].to(device)
|
||||||
|
|
||||||
|
gt_3d_joints = annotations['joints_3d'].to(device)
|
||||||
|
gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3]
|
||||||
|
gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:]
|
||||||
|
gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :]
|
||||||
|
has_3d_joints = annotations['has_3d_joints'].to(device)
|
||||||
|
|
||||||
|
gt_pose = annotations['pose'].to(device)
|
||||||
|
gt_betas = annotations['betas'].to(device)
|
||||||
|
has_smpl = annotations['has_smpl'].to(device)
|
||||||
|
mjm_mask = annotations['mjm_mask'].to(device)
|
||||||
|
mvm_mask = annotations['mvm_mask'].to(device)
|
||||||
|
|
||||||
|
# generate simplified mesh
|
||||||
|
gt_vertices = smpl(gt_pose, gt_betas)
|
||||||
|
gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices, n1=0, n2=2)
|
||||||
|
gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
|
||||||
|
|
||||||
|
# normalize gt based on smpl's pelvis
|
||||||
|
gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
|
||||||
|
gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
|
||||||
|
gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
|
||||||
|
|
||||||
|
# prepare masks for mask vertex/joint modeling
|
||||||
|
mjm_mask_ = mjm_mask.expand(-1,-1,2051)
|
||||||
|
mvm_mask_ = mvm_mask.expand(-1,-1,2051)
|
||||||
|
meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)
|
||||||
|
|
||||||
|
# forward-pass
|
||||||
|
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True)
|
||||||
|
|
||||||
|
# normalize gt based on smpl's pelvis
|
||||||
|
gt_vertices_sub = gt_vertices_sub - gt_smpl_3d_pelvis[:, None, :]
|
||||||
|
gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :]
|
||||||
|
|
||||||
|
# obtain 3d joints, which are regressed from the full mesh
|
||||||
|
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
|
||||||
|
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
|
||||||
|
|
||||||
|
# obtain 2d joints, which are projected from 3d joints of smpl mesh
|
||||||
|
pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
|
||||||
|
pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)
|
||||||
|
|
||||||
|
# compute 3d joint loss (where the joints are directly output from transformer)
|
||||||
|
loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device)
|
||||||
|
# compute 3d vertex loss
|
||||||
|
loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \
|
||||||
|
args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \
|
||||||
|
args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) )
|
||||||
|
# compute 3d joint loss (where the joints are regressed from full mesh)
|
||||||
|
loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device)
|
||||||
|
# compute 2d joint loss
|
||||||
|
loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
|
||||||
|
keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, has_2d_joints)
|
||||||
|
|
||||||
|
loss_3d_joints = loss_3d_joints + loss_reg_3d_joints
|
||||||
|
|
||||||
|
# we empirically use hyperparameters to balance difference losses
|
||||||
|
loss = args.joints_loss_weight*loss_3d_joints + \
|
||||||
|
args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints
|
||||||
|
|
||||||
|
# update logs
|
||||||
|
log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
|
||||||
|
log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
|
||||||
|
log_loss_vertices.update(loss_vertices.item(), batch_size)
|
||||||
|
log_losses.update(loss.item(), batch_size)
|
||||||
|
|
||||||
|
# back prop
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
        if iteration % args.logging_steps == 0 or iteration == max_iter:
            eta_seconds = batch_time.avg * (max_iter - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                ' '.join(
                    ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
                ).format(eta=eta_string, ep=epoch, iter=iteration,
                         memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
                    log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
                    optimizer.param_groups[0]['lr'])
            )

            aml_run.log(name='Loss', value=float(log_losses.avg))
            aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
            aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
            aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))

            visual_imgs = visualize_mesh(renderer,
                                         annotations['ori_img'].detach(),
                                         annotations['joints_2d'].detach(),
                                         pred_vertices.detach(),
                                         pred_camera.detach(),
                                         pred_2d_joints_from_smpl.detach())
            visual_imgs = visual_imgs.transpose(0, 1)
            visual_imgs = visual_imgs.transpose(1, 2)
            visual_imgs = np.asarray(visual_imgs)

            if is_main_process():
                stamp = str(epoch) + '_' + str(iteration)
                temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:, :, ::-1] * 255))
                aml_run.log_image(name='visual results', path=temp_fname)

        if iteration % iters_per_epoch == 0:
            val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
                                                                       Graphormer_model,
                                                                       criterion_keypoints,
                                                                       criterion_vertices,
                                                                       epoch,
                                                                       smpl,
                                                                       mesh_sampler)
            aml_run.log(name='mPVE', value=float(1000 * val_mPVE))
            aml_run.log(name='mPJPE', value=float(1000 * val_mPJPE))
            aml_run.log(name='PAmPJPE', value=float(1000 * val_PAmPJPE))
            logger.info(
                ' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
                + ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000 * val_mPVE, 1000 * val_mPJPE, 1000 * val_PAmPJPE, val_count)
            )

            if val_PAmPJPE < log_eval_metrics.PAmPJPE:
                checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
                log_eval_metrics.update(val_mPVE, val_mPJPE, val_PAmPJPE, epoch)


    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info('Total training time: {} ({:.4f} s / iter)'.format(
        total_time_str, total_training_time / max_iter)
    )
    checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)

    logger.info(
        ' Best Results:'
        + ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, at epoch {:6.2f}'.format(1000 * log_eval_metrics.mPVE, 1000 * log_eval_metrics.mPJPE, 1000 * log_eval_metrics.PAmPJPE, log_eval_metrics.epoch)
    )


def run_eval_general(args, val_dataloader, Graphormer_model, smpl, mesh_sampler):
    smpl.eval()
    criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
    criterion_vertices = torch.nn.L1Loss().to(device)

    epoch = 0
    if args.distributed:
        Graphormer_model = torch.nn.parallel.DistributedDataParallel(
            Graphormer_model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    Graphormer_model.eval()

    val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
                                                               Graphormer_model,
                                                               criterion_keypoints,
                                                               criterion_vertices,
                                                               epoch,
                                                               smpl,
                                                               mesh_sampler)

    aml_run.log(name='mPVE', value=float(1000 * val_mPVE))
    aml_run.log(name='mPJPE', value=float(1000 * val_mPJPE))
    aml_run.log(name='PAmPJPE', value=float(1000 * val_PAmPJPE))

    logger.info(
        ' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
        + ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f} '.format(1000 * val_mPVE, 1000 * val_mPJPE, 1000 * val_PAmPJPE)
    )
    # checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0)
    return

def run_validate(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, smpl, mesh_sampler):
    batch_time = AverageMeter()
    mPVE = AverageMeter()
    mPJPE = AverageMeter()
    PAmPJPE = AverageMeter()
    # switch to evaluate mode
    Graphormer_model.eval()
    smpl.eval()
    with torch.no_grad():
        # end = time.time()
        for i, (img_keys, images, annotations) in enumerate(val_loader):
            batch_size = images.size(0)
            # compute output
            images = images.to(device)
            gt_3d_joints = annotations['joints_3d'].to(device)
            gt_3d_pelvis = gt_3d_joints[:, cfg.J24_NAME.index('Pelvis'), :3]
            gt_3d_joints = gt_3d_joints[:, cfg.J24_TO_J14, :]
            gt_3d_joints[:, :, :3] = gt_3d_joints[:, :, :3] - gt_3d_pelvis[:, None, :]
            has_3d_joints = annotations['has_3d_joints'].to(device)

            gt_pose = annotations['pose'].to(device)
            gt_betas = annotations['betas'].to(device)
            has_smpl = annotations['has_smpl'].to(device)

            # generate simplified mesh
            gt_vertices = smpl(gt_pose, gt_betas)
            gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
            gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices_sub, n1=1, n2=2)

            # normalize gt based on smpl pelvis
            gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
            gt_smpl_3d_pelvis = gt_smpl_3d_joints[:, cfg.H36M_J17_NAME.index('Pelvis'), :]
            gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
            gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :]

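            # Both ground truth and (below) the predictions are re-centered
            # on the pelvis before errors are computed, so mPVE/mPJPE measure
            # pose and shape accuracy independent of global translation.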
            # forward-pass
            pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler)

            # obtain 3d joints from full mesh
            pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)

            pred_3d_pelvis = pred_3d_joints_from_smpl[:, cfg.H36M_J17_NAME.index('Pelvis'), :]
            pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:, cfg.H36M_J17_TO_J14, :]
            pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
            pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

            # measure errors
            error_vertices = mean_per_vertex_error(pred_vertices, gt_vertices, has_smpl)
            error_joints = mean_per_joint_position_error(pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints)
            error_joints_pa = reconstruction_error(pred_3d_joints_from_smpl.cpu().numpy(), gt_3d_joints[:, :, :3].cpu().numpy(), reduction=None)

            if len(error_vertices) > 0:
                mPVE.update(np.mean(error_vertices), int(torch.sum(has_smpl)))
            if len(error_joints) > 0:
                mPJPE.update(np.mean(error_joints), int(torch.sum(has_3d_joints)))
            if len(error_joints_pa) > 0:
                PAmPJPE.update(np.mean(error_joints_pa), int(torch.sum(has_3d_joints)))

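    # The meters above hold per-rank running averages; all_gather returns a
    # list with one scalar per process, so averaging that list approximates
    # the global mean (it is exact only when every rank processed the same
    # number of samples).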
    val_mPVE = all_gather(float(mPVE.avg))
    val_mPVE = sum(val_mPVE) / len(val_mPVE)
    val_mPJPE = all_gather(float(mPJPE.avg))
    val_mPJPE = sum(val_mPJPE) / len(val_mPJPE)

    val_PAmPJPE = all_gather(float(PAmPJPE.avg))
    val_PAmPJPE = sum(val_PAmPJPE) / len(val_PAmPJPE)

    val_count = all_gather(float(mPVE.count))
    val_count = sum(val_count)

    return val_mPVE, val_mPJPE, val_PAmPJPE, val_count


def visualize_mesh(renderer,
                   images,
                   gt_keypoints_2d,
                   pred_vertices,
                   pred_camera,
                   pred_keypoints_2d):
    """Tensorboard logging."""
    gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
    to_lsp = list(range(14))
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for at most the first 10 images of the batch
    for i in range(min(batch_size, 10)):
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        # Get LSP keypoints from the full list of keypoints
        gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
        pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        # Visualize reconstruction and detected pose
        rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
        rend_img = rend_img.transpose(2, 0, 1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs

def visualize_mesh_test(renderer,
                        images,
                        gt_keypoints_2d,
                        pred_vertices,
                        pred_camera,
                        pred_keypoints_2d,
                        PAmPJPE_h36m_j14):
    """Tensorboard logging."""
    gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
    to_lsp = list(range(14))
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for at most the first 10 images of the batch
    for i in range(min(batch_size, 10)):
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        # Get LSP keypoints from the full list of keypoints
        gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
        pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        score = PAmPJPE_h36m_j14[i]
        # Visualize reconstruction and detected pose
        rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
        rend_img = rend_img.transpose(2, 0, 1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs


def parse_args():
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--data_dir", default='datasets', type=str, required=False,
                        help="Directory with all datasets, each in one subfolder")
    parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
                        help="Yaml file with all data for training.")
    parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
                        help="Yaml file with all data for validation.")
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="adjust image resolution.")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resume training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    #########################################################
    # Training parameters
    #########################################################
    parser.add_argument("--per_gpu_train_batch_size", default=30, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=30, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
                        help="The initial lr.")
    parser.add_argument("--num_train_epochs", default=200, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--vertices_loss_weight", default=100.0, type=float)
    parser.add_argument("--joints_loss_weight", default=1000.0, type=float)
    parser.add_argument("--vloss_w_full", default=0.33, type=float)
    parser.add_argument("--vloss_w_sub", default=0.33, type=float)
    parser.add_argument("--vloss_w_sub2", default=0.33, type=float)
    parser.add_argument("--drop_out", default=0.1, type=float,
                        help="Drop out ratio in BERT.")
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
                        help="Update model config if given. Note that hidden_size "
                             "must be divisible by num_attention_heads.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="Input image feature dimension per encoder block.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="Hidden feature dimension per encoder block.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
    parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
    parser.add_argument("--interm_size_scale", default=2, type=int)
    #########################################################
    # Others
    #########################################################
    parser.add_argument("--run_eval_only", default=False, action='store_true',)
    parser.add_argument('--logging_steps', type=int, default=1000,
                        help="Log every X steps.")
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="random seed for initialization.")
    parser.add_argument("--local_rank", type=int, default=0,
                        help="For distributed training.")


    args = parser.parse_args()
    return args


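# Example usage (a sketch; the yaml paths are placeholders for your own
# dataset files, and the script path assumes this file lives under
# mesh_graphormer/tools/):
#   python mesh_graphormer/tools/run_gphmer_bodymesh.py \
#       --train_yaml <your_train>.yaml --val_yaml <your_val>.yaml \
#       --arch hrnet-w64 --per_gpu_train_batch_size 30 --lr 1e-4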
def main(args):
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    args.distributed = args.num_gpus > 1
    args.device = torch.device(args.device)
    if args.distributed:
        print("Init distributed training on local rank {} ({}), rank {}, world size {}".format(args.local_rank, int(os.environ["LOCAL_RANK"]), int(os.environ["NODE_RANK"]), args.num_gpus))
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend='nccl', init_method='env://'
        )
        local_rank = int(os.environ["LOCAL_RANK"])
        args.device = torch.device("cuda", local_rank)
        synchronize()

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and SMPL utils
    smpl = SMPL().to(args.device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=smpl.faces.cpu().numpy())

    # Load model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder block to have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

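    # With the defaults ('2051,512,128'), the three encoder blocks form a
    # coarse-to-fine regression 2051 -> 512 -> 128 -> 3: the last block
    # outputs 3D coordinates directly, since
    # output_feat_dim = input_feat_dim[1:] + [3].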
    if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' and 'state_dict' not in args.resume_checkpoint:
        # if only run eval, load checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)
    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.hidden_dropout_prob = args.drop_out
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size * args.interm_size_scale)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']

            for idx, param in enumerate(update_params):
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)


        # init ImageNet pre-trained backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-2])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)

        if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
            # for fine-tuning or resume training or inference, load weights from checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensor in graph conv.
            states = torch.load(args.resume_checkpoint)
            # states = checkpoint_loaded.state_dict()
            for k, v in states.items():
                states[k] = v.cpu()
            # del checkpoint_loaded
            _model.load_state_dict(states, strict=False)
            del states
            gc.collect()
            torch.cuda.empty_cache()


    _model.to(args.device)
    logger.info("Training parameters %s", args)

    if args.run_eval_only:
        val_dataloader = make_data_loader(args, args.val_yaml,
                                          args.distributed, is_train=False, scale_factor=args.img_scale_factor)
        run_eval_general(args, val_dataloader, _model, smpl, mesh_sampler)

    else:
        train_dataloader = make_data_loader(args, args.train_yaml,
                                            args.distributed, is_train=True, scale_factor=args.img_scale_factor)
        val_dataloader = make_data_loader(args, args.val_yaml,
                                          args.distributed, is_train=False, scale_factor=args.img_scale_factor)
        run(args, train_dataloader, val_dataloader, _model, smpl, mesh_sampler, renderer)


if __name__ == "__main__":
    args = parse_args()
    main(args)
351
mesh_graphormer/tools/run_gphmer_bodymesh_inference.py
Normal file
@@ -0,0 +1,351 @@
"""
|
||||||
|
Copyright (c) Microsoft Corporation.
|
||||||
|
Licensed under the MIT license.
|
||||||
|
|
||||||
|
End-to-end inference codes for
|
||||||
|
3D human body mesh reconstruction from an image
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import os.path as op
|
||||||
|
import code
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
import gc
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
|
||||||
|
from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
|
||||||
|
from mesh_graphormer.modeling._smpl import SMPL, Mesh
|
||||||
|
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
|
||||||
|
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
|
||||||
|
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
|
||||||
|
import mesh_graphormer.modeling.data.config as cfg
|
||||||
|
from mesh_graphormer.datasets.build import make_data_loader
|
||||||
|
|
||||||
|
from mesh_graphormer.utils.logger import setup_logger
|
||||||
|
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
|
||||||
|
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
|
||||||
|
from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
|
||||||
|
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
|
||||||
|
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
|
||||||
|
from mesh_graphormer.utils.geometric_layers import orthographic_projection
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])])

transform_visualize = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()])

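# Note: `transform` applies the standard ImageNet mean/std normalization the
# pre-trained HRNet/ResNet backbones expect, while `transform_visualize`
# keeps the raw RGB values so the rendered overlay matches the input image.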
def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler):
    # switch to evaluate mode
    Graphormer_model.eval()
    smpl.eval()
    with torch.no_grad():
        for image_file in image_list:
            if 'pred' not in image_file:
                att_all = []
                img = Image.open(image_file)
                img_tensor = transform(img)
                img_visual = transform_visualize(img)

                batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
                batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
                # forward-pass
                pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler)

                # obtain 3d joints from full mesh
                pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)

                pred_3d_pelvis = pred_3d_joints_from_smpl[:, cfg.H36M_J17_NAME.index('Pelvis'), :]
                pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:, cfg.H36M_J17_TO_J14, :]
                pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
                pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

                # save attention
                att_max_value = att[-1]
                att_cpu = np.asarray(att_max_value.cpu().detach())
                att_all.append(att_cpu)

                # obtain 3d joints, which are regressed from the full mesh
                pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
                pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:, cfg.H36M_J17_TO_J14, :]
                # obtain 2d joints, which are projected from 3d joints of smpl mesh
                pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
                pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera)
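                # Note: pred_camera is a weak-perspective camera (a scale s
                # and a 2D translation t); in the reference implementation of
                # orthographic_projection this amounts to roughly
                #   x_2d = s * (X_3d[..., :2] + t),
                # which is how the joints and sub-sampled vertices above are
                # mapped into image space.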
                visual_imgs_output = visualize_mesh(renderer, batch_visual_imgs[0],
                                                    pred_vertices[0].detach(),
                                                    pred_camera.detach())
                # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
                #                                                    pred_vertices[0].detach(),
                #                                                    pred_vertices_sub2[0].detach(),
                #                                                    pred_2d_431_vertices_from_smpl[0].detach(),
                #                                                    pred_2d_joints_from_smpl[0].detach(),
                #                                                    pred_camera.detach(),
                #                                                    att[-1][0].detach())

                visual_imgs = visual_imgs_output.transpose(1, 2, 0)
                visual_imgs = np.asarray(visual_imgs)

                temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
                print('save to ', temp_fname)
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:, :, ::-1] * 255))

    return

def visualize_mesh(renderer, images,
                   pred_vertices_full,
                   pred_camera):
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predicted vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    # Visualize only mesh reconstruction
    rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img

def visualize_mesh_and_attention(renderer, images,
                                 pred_vertices_full,
                                 pred_vertices,
                                 pred_2d_vertices,
                                 pred_2d_joints,
                                 pred_camera,
                                 attention):
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predicted vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    vertices = pred_vertices.cpu().numpy()
    vertices_2d = pred_2d_vertices.cpu().numpy()
    joints_2d = pred_2d_joints.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    att = attention.cpu().numpy()
    # Visualize reconstruction and attention
    rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img


def parse_args():
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="adjust image resolution.")
    parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str,
                        help="test data")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resume training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
                        help="Update model config if given. Note that hidden_size "
                             "must be divisible by num_attention_heads.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="Input image feature dimension per encoder block.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="Hidden feature dimension per encoder block.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
    parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
    parser.add_argument("--interm_size_scale", default=2, type=int)
    #########################################################
    # Others
    #########################################################
    parser.add_argument("--run_eval_only", default=True, action='store_true',)
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="random seed for initialization.")

    args = parser.parse_args()
    return args


def main(args):
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    args.distributed = args.num_gpus > 1
    args.device = torch.device(args.device)

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and SMPL utils
    smpl = SMPL().to(args.device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=smpl.faces.cpu().numpy())

    # Load model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder block to have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' and 'state_dict' not in args.resume_checkpoint:
        # if only run eval, load checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)
    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size * args.interm_size_scale)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']

            for idx, param in enumerate(update_params):
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)

        # init ImageNet pre-trained backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-2])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)

        if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
            # for fine-tuning or resume training or inference, load weights from checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensor in graph conv.
            states = torch.load(args.resume_checkpoint)
            # states = checkpoint_loaded.state_dict()
            for k, v in states.items():
                states[k] = v.cpu()
            # del checkpoint_loaded
            _model.load_state_dict(states, strict=False)
            del states
            gc.collect()
            torch.cuda.empty_cache()

    # update configs to enable attention outputs
    setattr(_model.trans_encoder[-1].config, 'output_attentions', True)
    setattr(_model.trans_encoder[-1].config, 'output_hidden_states', True)
    _model.trans_encoder[-1].bert.encoder.output_attentions = True
    _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
    for iter_layer in range(4):
        _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
    for inter_block in range(3):
        setattr(_model.trans_encoder[-1].config, 'device', args.device)

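    # Only the last encoder block is configured to return attentions and
    # hidden states; run_inference above depends on this when it unpacks
    # (..., hidden_states, att) from the forward pass.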
    _model.to(args.device)
    logger.info("Run inference")

    image_list = []
    if not args.image_file_or_path:
        raise ValueError("image_file_or_path not specified")
    if op.isfile(args.image_file_or_path):
        image_list = [args.image_file_or_path]
    elif op.isdir(args.image_file_or_path):
        # should be a path with images only
        for filename in os.listdir(args.image_file_or_path):
            if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename:
                image_list.append(args.image_file_or_path + '/' + filename)
    else:
        raise ValueError("Cannot find images at {}".format(args.image_file_or_path))

    run_inference(args, image_list, _model, smpl, renderer, mesh_sampler)


if __name__ == "__main__":
    args = parse_args()
    main(args)
713
mesh_graphormer/tools/run_gphmer_handmesh.py
Normal file
@@ -0,0 +1,713 @@
"""
|
||||||
|
Copyright (c) Microsoft Corporation.
|
||||||
|
Licensed under the MIT license.
|
||||||
|
|
||||||
|
Training and evaluation codes for
|
||||||
|
3D hand mesh reconstruction from an image
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import os.path as op
|
||||||
|
import code
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
import gc
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
|
||||||
|
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
|
||||||
|
from mesh_graphormer.modeling._mano import MANO, Mesh
|
||||||
|
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
|
||||||
|
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
|
||||||
|
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
|
||||||
|
import mesh_graphormer.modeling.data.config as cfg
|
||||||
|
from mesh_graphormer.datasets.build import make_hand_data_loader
|
||||||
|
|
||||||
|
from mesh_graphormer.utils.logger import setup_logger
|
||||||
|
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
|
||||||
|
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
|
||||||
|
from mesh_graphormer.utils.metric_logger import AverageMeter
|
||||||
|
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test, visualize_reconstruction_no_text
|
||||||
|
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
|
||||||
|
from mesh_graphormer.utils.geometric_layers import orthographic_projection
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
from azureml.core.run import Run
|
||||||
|
aml_run = Run.get_context()
|
||||||
|
|
||||||
|
def save_checkpoint(model, args, epoch, iteration, num_trial=10):
    checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
        epoch, iteration))
    if not is_main_process():
        return checkpoint_dir
    mkdir(checkpoint_dir)
    model_to_save = model.module if hasattr(model, 'module') else model
    for i in range(num_trial):
        try:
            torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
            torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
            torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
            logger.info("Save checkpoint to {}".format(checkpoint_dir))
            break
        except Exception:
            pass
    else:
        logger.info("Failed to save checkpoint after {} trials.".format(num_trial))
    return checkpoint_dir

def adjust_learning_rate(optimizer, epoch, args):
    """
    Sets the learning rate to the initial LR decayed by x every y epochs
    x = 0.1, y = args.num_train_epochs/2.0 = 100
    """
    lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs / 2.0)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

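# Worked example: with num_train_epochs=200 (the value the docstring
# assumes), epoch // 100 is 0 for epochs 0-99 and 1 afterwards, i.e. a
# single 10x decay:
#   epochs   0- 99 -> lr
#   epochs 100-199 -> 0.1 * lr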
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
    """
    Compute 2D reprojection loss if 2D keypoint annotations are available.
    The confidence is binary and indicates whether the keypoints exist or not.
    """
    conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
    loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
    return loss

def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d):
    """
    Compute 3D keypoint loss if 3D keypoint annotations are available.
    """
    conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
    gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
    gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
    conf = conf[has_pose_3d == 1]
    pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
    if len(gt_keypoints_3d) > 0:
        gt_root = gt_keypoints_3d[:, 0, :]
        gt_keypoints_3d = gt_keypoints_3d - gt_root[:, None, :]
        pred_root = pred_keypoints_3d[:, 0, :]
        pred_keypoints_3d = pred_keypoints_3d - pred_root[:, None, :]
        return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
    else:
        return torch.FloatTensor(1).fill_(0.).to(device)

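# Note: keypoint_3d_loss re-centers prediction and ground truth on joint 0
# before the confidence-weighted MSE, so the 3D supervision is
# root-relative (for the MANO hand, joint 0 is the wrist).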
def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
    """
    Compute per-vertex loss if vertex annotations are available.
    """
    pred_vertices_with_shape = pred_vertices[has_smpl == 1]
    gt_vertices_with_shape = gt_vertices[has_smpl == 1]
    if len(gt_vertices_with_shape) > 0:
        return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
    else:
        return torch.FloatTensor(1).fill_(0.).to(device)


def run(args, train_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):

    max_iter = len(train_dataloader)
    iters_per_epoch = max_iter // args.num_train_epochs

    optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
                                 lr=args.lr,
                                 betas=(0.9, 0.999),
                                 weight_decay=0)

    # define loss function (criterion) and optimizer
    criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
    criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
    criterion_vertices = torch.nn.L1Loss().to(device)

    if args.distributed:
        Graphormer_model = torch.nn.parallel.DistributedDataParallel(
            Graphormer_model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    start_training_time = time.time()
    end = time.time()
    Graphormer_model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_losses = AverageMeter()
    log_loss_2djoints = AverageMeter()
    log_loss_3djoints = AverageMeter()
    log_loss_vertices = AverageMeter()

    for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):

        Graphormer_model.train()
        iteration += 1
        epoch = iteration // iters_per_epoch
        batch_size = images.size(0)
        adjust_learning_rate(optimizer, epoch, args)
        data_time.update(time.time() - end)

        images = images.to(device)
        gt_2d_joints = annotations['joints_2d'].to(device)
        gt_pose = annotations['pose'].to(device)
        gt_betas = annotations['betas'].to(device)
        has_mesh = annotations['has_smpl'].to(device)
        has_3d_joints = has_mesh
        has_2d_joints = has_mesh
        mjm_mask = annotations['mjm_mask'].to(device)
        mvm_mask = annotations['mvm_mask'].to(device)

        # generate mesh
        gt_vertices, gt_3d_joints = mano_model.layer(gt_pose, gt_betas)
        gt_vertices = gt_vertices / 1000.0
        gt_3d_joints = gt_3d_joints / 1000.0
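        # The MANO layer used here appears to report vertices and joints in
        # millimeters; dividing by 1000 puts everything in meters before the
        # root normalization and the losses below.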

        gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
        # normalize gt based on hand's wrist
        gt_3d_root = gt_3d_joints[:, cfg.J_NAME.index('Wrist'), :]
        gt_vertices = gt_vertices - gt_3d_root[:, None, :]
        gt_vertices_sub = gt_vertices_sub - gt_3d_root[:, None, :]
        gt_3d_joints = gt_3d_joints - gt_3d_root[:, None, :]
        gt_3d_joints_with_tag = torch.ones((batch_size, gt_3d_joints.shape[1], 4)).to(device)
        gt_3d_joints_with_tag[:, :, :3] = gt_3d_joints

        # prepare masks for mask vertex/joint modeling
        mjm_mask_ = mjm_mask.expand(-1, -1, 2051)
        mvm_mask_ = mvm_mask.expand(-1, -1, 2051)
        meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)

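        # Masked vertex/joint modeling: the per-joint (mjm) and per-vertex
        # (mvm) masks are broadcast to the 2051-dim input features and
        # concatenated along the token dimension, so during training the
        # network is asked to reconstruct randomly masked joint/vertex
        # queries.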
        # forward-pass
        pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler, meta_masks=meta_masks, is_train=True)

        # obtain 3d joints, which are regressed from the full mesh
        pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)

        # obtain 2d joints, which are projected from 3d joints of the mesh
        pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
        pred_2d_joints = orthographic_projection(pred_3d_joints.contiguous(), pred_camera.contiguous())

        # compute 3d joint loss (where the joints are directly output from transformer)
        loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints_with_tag, has_3d_joints)

        # compute 3d vertex loss
        loss_vertices = (args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_mesh) +
                         args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_mesh))

        # compute 3d joint loss (where the joints are regressed from full mesh)
        loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_mesh, gt_3d_joints_with_tag, has_3d_joints)
        # compute 2d joint loss
        loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
                         keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_mesh, gt_2d_joints, has_2d_joints)

        loss_3d_joints = loss_3d_joints + loss_reg_3d_joints

        # we empirically use hyperparameters to balance the different losses
        loss = args.joints_loss_weight * loss_3d_joints + \
               args.vertices_loss_weight * loss_vertices + args.vertices_loss_weight * loss_2d_joints

        # update logs
        log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
        log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
        log_loss_vertices.update(loss_vertices.item(), batch_size)
        log_losses.update(loss.item(), batch_size)

        # back prop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if iteration % args.logging_steps == 0 or iteration == max_iter:
            eta_seconds = batch_time.avg * (max_iter - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                ' '.join(
                    ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
                ).format(eta=eta_string, ep=epoch, iter=iteration,
                         memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
                    log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
                    optimizer.param_groups[0]['lr'])
            )

            aml_run.log(name='Loss', value=float(log_losses.avg))
            aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
            aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
            aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))

            visual_imgs = visualize_mesh(renderer,
                                         annotations['ori_img'].detach(),
                                         annotations['joints_2d'].detach(),
                                         pred_vertices.detach(),
                                         pred_camera.detach(),
                                         pred_2d_joints_from_mesh.detach())
            visual_imgs = visual_imgs.transpose(0, 1)
            visual_imgs = visual_imgs.transpose(1, 2)
            visual_imgs = np.asarray(visual_imgs)

            if is_main_process():
                stamp = str(epoch) + '_' + str(iteration)
                temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:, :, ::-1] * 255))
                aml_run.log_image(name='visual results', path=temp_fname)

        if iteration % iters_per_epoch == 0:
            if epoch % 10 == 0:
                checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info('Total training time: {} ({:.4f} s / iter)'.format(
        total_time_str, total_training_time / max_iter)
    )
    checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)

def run_eval_and_save(args, split, val_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):

    criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
    criterion_vertices = torch.nn.L1Loss().to(device)

    if args.distributed:
        Graphormer_model = torch.nn.parallel.DistributedDataParallel(
            Graphormer_model, device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    Graphormer_model.eval()

    if args.aml_eval:
        run_aml_inference_hand_mesh(args, val_dataloader,
                                    Graphormer_model,
                                    criterion_keypoints,
                                    criterion_vertices,
                                    0,
                                    mano_model, mesh_sampler,
                                    renderer, split)
    else:
        run_inference_hand_mesh(args, val_dataloader,
                                Graphormer_model,
                                criterion_keypoints,
                                criterion_vertices,
                                0,
                                mano_model, mesh_sampler,
                                renderer, split)
    checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0)
    return

def run_aml_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
    # switch to evaluate mode
    Graphormer_model.eval()
    fname_output_save = []
    mesh_output_save = []
    joint_output_save = []
    world_size = get_world_size()
    with torch.no_grad():
        for i, (img_keys, images, annotations) in enumerate(val_loader):
            batch_size = images.size(0)
            # compute output
            images = images.to(device)

            # forward-pass
            pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)
            # obtain 3d joints from full mesh
            pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)

            for j in range(batch_size):
                fname_output_save.append(img_keys[j])
                pred_vertices_list = pred_vertices[j].tolist()
                mesh_output_save.append(pred_vertices_list)
                pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
                joint_output_save.append(pred_3d_joints_from_mesh_list)

    if world_size > 1:
        torch.distributed.barrier()
    print('save results to pred.json')
    output_json_file = 'pred.json'
    print('save results to ', output_json_file)
    with open(output_json_file, 'w') as f:
        json.dump([joint_output_save, mesh_output_save], f)

    azure_ckpt_name = '200'  # args.resume_checkpoint.split('/')[-2].split('-')[1]
    inference_setting = 'sc%02d_rot%s' % (int(args.sc * 10), str(int(args.rot)))
    output_zip_file = args.output_dir + 'ckpt' + azure_ckpt_name + '-' + inference_setting + '-pred.zip'

    resolved_submit_cmd = 'zip ' + output_zip_file + ' ' + output_json_file
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    resolved_submit_cmd = 'rm %s' % (output_json_file)
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    if world_size > 1:
        torch.distributed.barrier()

    return

def run_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
    # switch to evaluate mode
    Graphormer_model.eval()
    fname_output_save = []
    mesh_output_save = []
    joint_output_save = []
    with torch.no_grad():
        for i, (img_keys, images, annotations) in enumerate(val_loader):
            batch_size = images.size(0)
            # compute output
            images = images.to(device)

            # forward-pass
            pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)

            # obtain 3d joints from full mesh
            pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
            # root-align joints and vertices on the wrist joint
            pred_3d_pelvis = pred_3d_joints_from_mesh[:, cfg.J_NAME.index('Wrist'), :]
            pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
            pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

            for j in range(batch_size):
                fname_output_save.append(img_keys[j])
                pred_vertices_list = pred_vertices[j].tolist()
                mesh_output_save.append(pred_vertices_list)
                pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
                joint_output_save.append(pred_3d_joints_from_mesh_list)

            if i % 20 == 0:
                # obtain 3d joints, which are regressed from the full mesh
                pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
                # obtain 2d joints, which are projected from 3d joints of mesh
                pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
                visual_imgs = visualize_mesh(renderer,
                                             annotations['ori_img'].detach(),
                                             annotations['joints_2d'].detach(),
                                             pred_vertices.detach(),
                                             pred_camera.detach(),
                                             pred_2d_joints_from_mesh.detach())

                visual_imgs = visual_imgs.transpose(0, 1)
                visual_imgs = visual_imgs.transpose(1, 2)
                visual_imgs = np.asarray(visual_imgs)

                inference_setting = 'sc%02d_rot%s' % (int(args.sc * 10), str(int(args.rot)))
                temp_fname = args.output_dir + args.resume_checkpoint[0:-9] + 'freihand_results_' + inference_setting + '_batch' + str(i) + '.jpg'
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:, :, ::-1] * 255))

    print('save results to pred.json')
    with open('pred.json', 'w') as f:
        json.dump([joint_output_save, mesh_output_save], f)

    run_exp_name = args.resume_checkpoint.split('/')[-3]
    run_ckpt_name = args.resume_checkpoint.split('/')[-2].split('-')[1]
    inference_setting = 'sc%02d_rot%s' % (int(args.sc * 10), str(int(args.rot)))
    resolved_submit_cmd = 'zip ' + args.output_dir + run_exp_name + '-ckpt' + run_ckpt_name + '-' + inference_setting + '-pred.zip ' + 'pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    resolved_submit_cmd = 'rm pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    return

def visualize_mesh(renderer,
                   images,
                   gt_keypoints_2d,
                   pred_vertices,
                   pred_camera,
                   pred_keypoints_2d):
    """Tensorboard logging."""
    gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
    to_lsp = list(range(21))
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for the first 10 images of the batch
    for i in range(min(batch_size, 10)):
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        # Get LSP keypoints from the full list of keypoints
        gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
        pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        # Visualize reconstruction and detected pose
        rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
        rend_img = rend_img.transpose(2, 0, 1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs

def visualize_mesh_test(renderer,
                        images,
                        gt_keypoints_2d,
                        pred_vertices,
                        pred_camera,
                        pred_keypoints_2d,
                        PAmPJPE):
    """Tensorboard logging."""
    gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
    to_lsp = list(range(21))
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for the first 10 images of the batch
    for i in range(min(batch_size, 10)):
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        # Get LSP keypoints from the full list of keypoints
        gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
        pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        score = PAmPJPE[i]
        # Visualize reconstruction and detected pose
        rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
        rend_img = rend_img.transpose(2, 0, 1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs

def visualize_mesh_no_text(renderer,
                           images,
                           pred_vertices,
                           pred_camera):
    """Tensorboard logging."""
    rend_imgs = []
    batch_size = pred_vertices.shape[0]
    # Do visualization for the first image of the batch only
    for i in range(min(batch_size, 1)):
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        # Get predicted vertices for the particular example
        vertices = pred_vertices[i].cpu().numpy()
        cam = pred_camera[i].cpu().numpy()
        # Visualize reconstruction only
        rend_img = visualize_reconstruction_no_text(img, 224, vertices, cam, renderer, color='hand')
        rend_img = rend_img.transpose(2, 0, 1)
        rend_imgs.append(torch.from_numpy(rend_img))
    rend_imgs = make_grid(rend_imgs, nrow=1)
    return rend_imgs

def parse_args():
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--data_dir", default='datasets', type=str, required=False,
                        help="Directory with all datasets, each in one subfolder")
    parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
                        help="Yaml file with all data for training.")
    parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
                        help="Yaml file with all data for validation.")
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="Adjust image resolution.")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resume training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    #########################################################
    # Training parameters
    #########################################################
    parser.add_argument("--per_gpu_train_batch_size", default=64, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
                        help="The initial learning rate.")
    parser.add_argument("--num_train_epochs", default=200, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--vertices_loss_weight", default=1.0, type=float)
    parser.add_argument("--joints_loss_weight", default=1.0, type=float)
    parser.add_argument("--vloss_w_full", default=0.5, type=float)
    parser.add_argument("--vloss_w_sub", default=0.5, type=float)
    parser.add_argument("--drop_out", default=0.1, type=float,
                        help="Dropout ratio in BERT.")
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--num_attention_heads", default=-1, type=int, required=False,
                        help="Update model config if given. Note that hidden_size "
                             "must be divisible by num_attention_heads.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="The image feature dimension per encoder block.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="The hidden feature dimension per encoder block.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="Which encoder block has graph conv: Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv.")
    parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")

    #########################################################
    # Others
    #########################################################
    parser.add_argument("--run_eval_only", default=False, action='store_true',)
    parser.add_argument("--multiscale_inference", default=False, action='store_true',)
    # if "multiscale_inference" is enabled, the dataloader will apply transformations
    # to the test image based on the rotation "rot" and scale "sc" parameters below
    parser.add_argument("--rot", default=0, type=float)
    parser.add_argument("--sc", default=1.0, type=float)
    parser.add_argument("--aml_eval", default=False, action='store_true',)

    parser.add_argument('--logging_steps', type=int, default=100,
                        help="Log every X steps.")
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="Random seed for initialization.")
    parser.add_argument("--local_rank", type=int, default=0,
                        help="For distributed training.")
    args = parser.parse_args()
    return args

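# Illustrative invocations (not part of the original script); the yaml and
# checkpoint paths below are assumptions matching the conventions used in
# run_hand_multiscale.py, not verified defaults.
#
#   # training (multi-GPU, via torch.distributed.launch)
#   python -m torch.distributed.launch --nproc_per_node=8 \
#       src/tools/run_gphmer_handmesh.py \
#       --train_yaml freihand_v3/train.yaml --val_yaml freihand_v3/test.yaml \
#       --arch hrnet-w64 --lr 1e-4 --num_train_epochs 200 --output_dir output/
#
#   # evaluation only
#   python src/tools/run_gphmer_handmesh.py \
#       --val_yaml freihand_v3/test.yaml --run_eval_only \
#       --resume_checkpoint <path/to/checkpoint.bin> \
#       --per_gpu_eval_batch_size 32 --output_dir output/
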
def main(args):
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    args.distributed = args.num_gpus > 1
    args.device = torch.device(args.device)
    if args.distributed:
        print("Init distributed training on local rank {}".format(args.local_rank))
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend='nccl', init_method='env://'
        )
        synchronize()

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and MANO utils
    mano_model = MANO().to(args.device)
    mano_model.layer = mano_model.layer.to(device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=mano_model.face)

    # Load pretrained model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder blocks have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' and 'state_dict' not in args.resume_checkpoint:
        # if only running eval, load the full checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)

    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name \
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.hidden_dropout_prob = args.drop_out
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size * 2)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
            for idx, param in enumerate(update_params):
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)

        # create backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-1])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder)

        if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
            # for fine-tuning, resumed training, or inference, load weights from checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensors in graph conv
            state_dict = torch.load(args.resume_checkpoint)
            _model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    _model.to(args.device)
    logger.info("Training parameters %s", args)

    if args.run_eval_only:
        val_dataloader = make_hand_data_loader(args, args.val_yaml,
                                               args.distributed, is_train=False, scale_factor=args.img_scale_factor)
        run_eval_and_save(args, 'freihand', val_dataloader, _model, mano_model, renderer, mesh_sampler)

    else:
        train_dataloader = make_hand_data_loader(args, args.train_yaml,
                                                 args.distributed, is_train=True, scale_factor=args.img_scale_factor)
        run(args, train_dataloader, _model, mano_model, renderer, mesh_sampler)


if __name__ == "__main__":
    args = parse_args()
    main(args)

mesh_graphormer/tools/run_gphmer_handmesh_inference.py (new file, 338 lines)
@@ -0,0 +1,338 @@
"""
|
||||||
|
Copyright (c) Microsoft Corporation.
|
||||||
|
Licensed under the MIT license.
|
||||||
|
|
||||||
|
End-to-end inference codes for
|
||||||
|
3D hand mesh reconstruction from an image
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import os.path as op
|
||||||
|
import code
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
import gc
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
|
||||||
|
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
|
||||||
|
from mesh_graphormer.modeling._mano import MANO, Mesh
|
||||||
|
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
|
||||||
|
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
|
||||||
|
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
|
||||||
|
import mesh_graphormer.modeling.data.config as cfg
|
||||||
|
from mesh_graphormer.datasets.build import make_hand_data_loader
|
||||||
|
|
||||||
|
from mesh_graphormer.utils.logger import setup_logger
|
||||||
|
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
|
||||||
|
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
|
||||||
|
from mesh_graphormer.utils.metric_logger import AverageMeter
|
||||||
|
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
|
||||||
|
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
|
||||||
|
from mesh_graphormer.utils.geometric_layers import orthographic_projection
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize(224),
|
||||||
|
transforms.CenterCrop(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225])])
|
||||||
|
|
||||||
|
transform_visualize = transforms.Compose([
|
||||||
|
transforms.Resize(224),
|
||||||
|
transforms.CenterCrop(224),
|
||||||
|
transforms.ToTensor()])
|
||||||
|
|
||||||
|
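# A minimal, illustrative sketch (not part of the original file) of how the two
# pipelines above are meant to be used on a single image; the default file name
# is a placeholder.
def _example_preprocess(image_file='samples/hand/example.jpg'):
    img = Image.open(image_file)
    img_tensor = transform(img)                  # normalized input for the network
    img_visual = transform_visualize(img)        # un-normalized copy for rendering
    batch_imgs = torch.unsqueeze(img_tensor, 0)  # add batch dimension -> [1, 3, 224, 224]
    return batch_imgs, img_visual
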
def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler):
    # switch to evaluate mode
    Graphormer_model.eval()
    mano.eval()
    with torch.no_grad():
        for image_file in image_list:
            if 'pred' not in image_file:
                att_all = []
                print(image_file)
                img = Image.open(image_file)
                img_tensor = transform(img)
                img_visual = transform_visualize(img)

                batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
                batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
                # forward-pass
                pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
                # obtain 3d joints from full mesh
                pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
                # root-align joints and vertices on the wrist joint
                pred_3d_pelvis = pred_3d_joints_from_mesh[:, cfg.J_NAME.index('Wrist'), :]
                pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
                pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

                # save attention
                att_max_value = att[-1]
                att_cpu = np.asarray(att_max_value.cpu().detach())
                att_all.append(att_cpu)

                # obtain 3d joints, which are regressed from the full mesh
                pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
                # obtain 2d joints, which are projected from 3d joints of mesh
                pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
                pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())

                visual_imgs_output = visualize_mesh(renderer, batch_visual_imgs[0],
                                                    pred_vertices[0].detach(),
                                                    pred_camera.detach())
                # alternative rendering that also overlays the attention maps:
                # visual_imgs_output = visualize_mesh_and_attention(renderer, batch_visual_imgs[0],
                #                                                   pred_vertices[0].detach(),
                #                                                   pred_vertices_sub[0].detach(),
                #                                                   pred_2d_coarse_vertices_from_mesh[0].detach(),
                #                                                   pred_2d_joints_from_mesh[0].detach(),
                #                                                   pred_camera.detach(),
                #                                                   att[-1][0].detach())
                visual_imgs = visual_imgs_output.transpose(1, 2, 0)
                visual_imgs = np.asarray(visual_imgs)

                temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
                print('save to ', temp_fname)
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:, :, ::-1] * 255))
    return

def visualize_mesh(renderer, images,
                   pred_vertices_full,
                   pred_camera):
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predicted vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    # Visualize only the mesh reconstruction
    rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img

def visualize_mesh_and_attention(renderer, images,
                                 pred_vertices_full,
                                 pred_vertices,
                                 pred_2d_vertices,
                                 pred_2d_joints,
                                 pred_camera,
                                 attention):
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predicted vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    vertices = pred_vertices.cpu().numpy()
    vertices_2d = pred_2d_vertices.cpu().numpy()
    joints_2d = pred_2d_joints.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    att = attention.cpu().numpy()
    # Visualize reconstruction and attention
    rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img

def parse_args():
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="Adjust image resolution.")
    parser.add_argument("--image_file_or_path", default='./samples/hand', type=str,
                        help="Test image file, or a directory of test images.")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resume training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
                        help="Update model config if given. Note that hidden_size "
                             "must be divisible by num_attention_heads.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="The image feature dimension per encoder block.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="The hidden feature dimension per encoder block.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="Which encoder block has graph conv: Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv.")
    parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")

    #########################################################
    # Others
    #########################################################
    parser.add_argument("--run_eval_only", default=True, action='store_true',)
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="Random seed for initialization.")
    args = parser.parse_args()
    return args

def main(args):
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and MANO utils
    mano_model = MANO().to(args.device)
    mano_model.layer = mano_model.layer.to(device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=mano_model.face)

    # Load pretrained model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder blocks have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' and 'state_dict' not in args.resume_checkpoint:
        # if only running eval, load the full checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)

    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name \
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size * 2)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
            for idx, param in enumerate(update_params):
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)

        # create backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-1])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder)

        if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
            # for fine-tuning, resumed training, or inference, load weights from checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensors in graph conv
            state_dict = torch.load(args.resume_checkpoint)
            _model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    # update configs to enable attention outputs
    setattr(_model.trans_encoder[-1].config, 'output_attentions', True)
    setattr(_model.trans_encoder[-1].config, 'output_hidden_states', True)
    _model.trans_encoder[-1].bert.encoder.output_attentions = True
    _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
    for iter_layer in range(4):
        _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
    # record the target device on the last encoder block's config
    for inter_block in range(3):
        setattr(_model.trans_encoder[-1].config, 'device', args.device)

    _model.to(args.device)
    logger.info("Run inference")

    image_list = []
    if not args.image_file_or_path:
        raise ValueError("image_file_or_path not specified")
    if op.isfile(args.image_file_or_path):
        image_list = [args.image_file_or_path]
    elif op.isdir(args.image_file_or_path):
        # should be a path with images only; skip previously saved predictions
        for filename in os.listdir(args.image_file_or_path):
            if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename:
                image_list.append(args.image_file_or_path + '/' + filename)
    else:
        raise ValueError("Cannot find images at {}".format(args.image_file_or_path))

    run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)


if __name__ == "__main__":
    args = parse_args()
    main(args)

mesh_graphormer/tools/run_hand_multiscale.py (new file, 136 lines)
@@ -0,0 +1,136 @@
from __future__ import absolute_import, division, print_function

import argparse
import os
import os.path as op
import code
import json
import zipfile
import torch
import numpy as np
from mesh_graphormer.utils.metric_pampjpe import get_alignMesh


def load_pred_json(filepath):
    archive = zipfile.ZipFile(filepath, 'r')
    jsondata = archive.read('pred.json')
    reference = json.loads(jsondata.decode("utf-8"))
    return reference[0], reference[1]

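# Illustrative counterpart to load_pred_json (not part of the original file):
# a sketch of how a compatible prediction archive could be produced, making the
# expected layout -- a zip holding pred.json with [joints, meshes] -- explicit.
def _example_save_pred_zip(joints, meshes, zip_path):
    with zipfile.ZipFile(zip_path, 'w') as archive:
        archive.writestr('pred.json', json.dumps([joints, meshes]))
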
def multiscale_fusion(output_dir):
    # the unrotated, unscaled prediction is the alignment reference
    filepath = output_dir + 'ckpt200-sc10_rot0-pred.zip'
    ref_joints, ref_vertices = load_pred_json(filepath)
    ref_joints_array = np.asarray(ref_joints)
    ref_vertices_array = np.asarray(ref_vertices)

    rotations = [0.0]
    for i in range(1, 10):
        rotations.append(i * 10)
        rotations.append(i * -10)

    scale = [0.7, 0.8, 0.9, 1.0, 1.1]
    multiscale_joints = []
    multiscale_vertices = []

    counter = 0
    for s in scale:
        for r in rotations:
            setting = 'sc%02d_rot%s' % (int(s * 10), str(int(r)))
            filepath = output_dir + 'ckpt200-' + setting + '-pred.zip'
            joints, vertices = load_pred_json(filepath)
            joints_array = np.asarray(joints)
            vertices_array = np.asarray(vertices)

            pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None)
            pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None)
            print('--------------------------')
            print('scale:', s, 'rotate:', r)
            print('PAMPJPE:', 1000 * np.mean(pa_joint_error))
            print('PAMPVPE:', 1000 * np.mean(pa_vertices_error))
            multiscale_joints.append(pa_joint_array)
            multiscale_vertices.append(pa_vertices_array)
            counter = counter + 1

    # average the reference prediction with all aligned multiscale predictions
    overall_joints_array = ref_joints_array.copy()
    overall_vertices_array = ref_vertices_array.copy()
    for i in range(counter):
        overall_joints_array += multiscale_joints[i]
        overall_vertices_array += multiscale_vertices[i]

    overall_joints_array /= (1 + counter)
    overall_vertices_array /= (1 + counter)
    pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None)
    pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None)
    print('--------------------------')
    print('overall:')
    print('PAMPJPE:', 1000 * np.mean(pa_joint_error))
    print('PAMPVPE:', 1000 * np.mean(pa_vertices_error))

    joint_output_save = overall_joints_array.tolist()
    mesh_output_save = overall_vertices_array.tolist()

    print('save results to pred.json')
    with open('pred.json', 'w') as f:
        json.dump([joint_output_save, mesh_output_save], f)

    filepath = output_dir + 'ckpt200-multisc-pred.zip'
    resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)
    resolved_submit_cmd = 'rm pred.json'
    print(resolved_submit_cmd)
    os.system(resolved_submit_cmd)

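# Illustrative usage (not part of the original file): fuse the per-scale
# archives written by run_multiscale_inference into one averaged prediction.
#
#   multiscale_fusion('output/')
#
# This expects output/ckpt200-sc{07,08,09,10,11}_rot*-pred.zip to exist, as
# produced by the scale/rotation sweep below.
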
def run_multiscale_inference(model_path, mode, output_dir):

    if mode:
        rotations = [0.0]
        for i in range(1, 10):
            rotations.append(i * 10)
            rotations.append(i * -10)
        scale = [0.7, 0.8, 0.9, 1.0, 1.1]
    else:
        rotations = [0.0]
        scale = [1.0]

    job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \
              "--val_yaml freihand_v3/test.yaml " \
              "--resume_checkpoint %s " \
              "--per_gpu_eval_batch_size 32 --run_eval_only --num_workers 2 " \
              "--multiscale_inference " \
              "--rot %f " \
              "--sc %s " \
              "--arch hrnet-w64 " \
              "--num_hidden_layers 4 " \
              "--num_attention_heads 4 " \
              "--input_feat_dim 2051,512,128 " \
              "--hidden_feat_dim 1024,256,64 " \
              "--output_dir %s"

    for s in scale:
        for r in rotations:
            resolved_submit_cmd = job_cmd % (model_path, r, s, output_dir)
            print(resolved_submit_cmd)
            os.system(resolved_submit_cmd)

def main(args):
    model_path = args.model_path
    mode = args.multiscale_inference
    output_dir = args.output_dir
    run_multiscale_inference(model_path, mode, output_dir)
    if mode:
        multiscale_fusion(output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder")
    parser.add_argument("--model_path")
    parser.add_argument("--multiscale_inference", default=False, action='store_true',)
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    args = parser.parse_args()
    main(args)

mesh_graphormer/utils/__init__.py (new, empty file)

mesh_graphormer/utils/comm.py (new file, 176 lines)
@@ -0,0 +1,176 @@
"""
|
||||||
|
Copyright (c) Microsoft Corporation.
|
||||||
|
Licensed under the MIT license.
|
||||||
|
|
||||||
|
This file contains primitives for multi-gpu communication.
|
||||||
|
This is useful when doing distributed training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not dist.is_available():
|
||||||
|
return 1
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not dist.is_available():
|
||||||
|
return 0
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def synchronize():
|
||||||
|
"""
|
||||||
|
Helper function to synchronize (barrier) among all processes when
|
||||||
|
using distributed training
|
||||||
|
"""
|
||||||
|
if not dist.is_available():
|
||||||
|
return
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
if world_size == 1:
|
||||||
|
return
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
def gather_on_master(data):
    """Same as all_gather, but gathers data on the master process only, using CPU.
    Thus, this does not work with the NCCL backend unless it adds CPU support.

    The memory consumption of this function is ~3x the data size. While in
    principle it should be ~2x, it is not easy to force Python to release
    memory immediately, so peak memory usage could be up to 3x.
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize to a Tensor
    buffer = pickle.dumps(data)
    # trying to optimize memory, but in fact, it's not guaranteed to be released
    del data
    storage = torch.ByteStorage.from_buffer(buffer)
    del buffer
    tensor = torch.ByteTensor(storage)

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()])
    size_list = [torch.LongTensor([0]) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,))
        tensor = torch.cat((tensor, padding), dim=0)
        del padding

    if is_main_process():
        tensor_list = []
        for _ in size_list:
            tensor_list.append(torch.ByteTensor(size=(max_size,)))
        dist.gather(tensor, gather_list=tensor_list, dst=0)
        del tensor
    else:
        dist.gather(tensor, gather_list=[], dst=0)
        del tensor
        return

    data_list = []
    for tensor in tensor_list:
        buffer = tensor.cpu().numpy().tobytes()
        del tensor
        data_list.append(pickle.loads(buffer))
        del buffer

    return data_list

def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device)

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to(device)
    size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to(device)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list

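# Illustrative usage (not part of the original file): merging per-rank python
# objects after an evaluation shard, assuming torch.distributed is initialized.
def _example_gather_predictions(local_preds):
    gathered = all_gather(local_preds)   # list with one entry per rank
    if is_main_process():
        merged = []
        for preds in gathered:
            merged.extend(preds)
        return merged
    return None
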
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that the process
    with rank 0 has the averaged results. Returns a dict with the same fields
    as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only the main process gets the accumulated values, so only divide
            # by world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
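
# Illustrative usage (not part of the original file): averaging per-rank scalar
# loss tensors so only rank 0 logs them; the loss names here are placeholders.
def _example_log_reduced_losses(loss_vertices, loss_joints, logger):
    loss_dict = {'loss_vertices': loss_vertices.detach(),
                 'loss_joints': loss_joints.detach()}
    reduced = reduce_dict(loss_dict, average=True)
    if is_main_process():
        logger.info('reduced losses: {}'.format(
            {k: v.item() for k, v in reduced.items()}))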

mesh_graphormer/utils/dataset_utils.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""
|
||||||
|
Copyright (c) Microsoft Corporation.
|
||||||
|
Licensed under the MIT license.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import os.path as op
|
||||||
|
import numpy as np
|
||||||
|
import base64
|
||||||
|
import cv2
|
||||||
|
import yaml
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
def img_from_base64(imagestring):
|
||||||
|
try:
|
||||||
|
jpgbytestring = base64.b64decode(imagestring)
|
||||||
|
nparr = np.frombuffer(jpgbytestring, np.uint8)
|
||||||
|
r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
|
return r
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
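# Illustrative round trip for img_from_base64 (not part of the original file):
# encode an image into the base64 form this loader expects, then decode it back.
def _example_base64_roundtrip(img):
    ok, jpg = cv2.imencode('.jpg', img)  # img: HxWx3 uint8 BGR array
    assert ok
    imagestring = base64.b64encode(jpg.tobytes())
    decoded = img_from_base64(imagestring)
    return decoded is not None and decoded.shape == img.shape
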
def load_labelmap(labelmap_file):
    label_dict = None
    if labelmap_file is not None and op.isfile(labelmap_file):
        label_dict = OrderedDict()
        with open(labelmap_file, 'r') as fp:
            for line in fp:
                label = line.strip().split('\t')[0]
                if label in label_dict:
                    raise ValueError("Duplicate label " + label + " in labelmap.")
                else:
                    label_dict[label] = len(label_dict)
    return label_dict


def load_shuffle_file(shuf_file):
    shuf_list = None
    if shuf_file is not None:
        with open(shuf_file, 'r') as fp:
            shuf_list = []
            for i in fp:
                shuf_list.append(int(i.strip()))
    return shuf_list


def load_box_shuffle_file(shuf_file):
    if shuf_file is not None:
        with open(shuf_file, 'r') as fp:
            img_shuf_list = []
            box_shuf_list = []
            for i in fp:
                idx = [int(_) for _ in i.strip().split('\t')]
                img_shuf_list.append(idx[0])
                box_shuf_list.append(idx[1])
        return [img_shuf_list, box_shuf_list]
    return None


def load_from_yaml_file(file_name):
    with open(file_name, 'r') as fp:
        # fall back to the pure-Python loader when libyaml is unavailable
        return yaml.load(fp, Loader=getattr(yaml, 'CLoader', yaml.Loader))

mesh_graphormer/utils/geometric_layers.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""
|
||||||
|
Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula
|
||||||
|
|
||||||
|
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def rodrigues(theta):
|
||||||
|
"""Convert axis-angle representation to rotation matrix.
|
||||||
|
Args:
|
||||||
|
theta: size = [B, 3]
|
||||||
|
Returns:
|
||||||
|
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
||||||
|
"""
|
||||||
|
l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
|
||||||
|
angle = torch.unsqueeze(l1norm, -1)
|
||||||
|
normalized = torch.div(theta, angle)
|
||||||
|
angle = angle * 0.5
|
||||||
|
v_cos = torch.cos(angle)
|
||||||
|
v_sin = torch.sin(angle)
|
||||||
|
quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
|
||||||
|
return quat2mat(quat)
|
||||||
|
|
||||||
|
def quat2mat(quat):
|
||||||
|
"""Convert quaternion coefficients to rotation matrix.
|
||||||
|
Args:
|
||||||
|
quat: size = [B, 4] 4 <===>(w, x, y, z)
|
||||||
|
Returns:
|
||||||
|
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
||||||
|
"""
|
||||||
|
norm_quat = quat
|
||||||
|
norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
|
||||||
|
w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
|
||||||
|
|
||||||
|
B = quat.size(0)
|
||||||
|
|
||||||
|
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
||||||
|
wx, wy, wz = w*x, w*y, w*z
|
||||||
|
xy, xz, yz = x*y, x*z, y*z
|
||||||
|
|
||||||
|
rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
|
||||||
|
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
|
||||||
|
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
|
||||||
|
return rotMat
|
||||||
|
|
||||||
|
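# Illustrative sanity check (not part of the original file): a rotation of
# pi/2 about the z-axis should map to the corresponding rotation matrix.
def _example_check_rodrigues():
    import math
    theta = torch.tensor([[0.0, 0.0, math.pi / 2]])   # axis-angle, size [1, 3]
    expected = torch.tensor([[[0.0, -1.0, 0.0],
                              [1.0,  0.0, 0.0],
                              [0.0,  0.0, 1.0]]])
    return torch.allclose(rodrigues(theta), expected, atol=1e-6)
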
def orthographic_projection(X, camera):
    """Perform orthographic projection of 3D points X using the camera parameters
    Args:
        X: size = [B, N, 3]
        camera: size = [B, 3]  (scale, x-translation, y-translation)
    Returns:
        Projected 2D points -- size = [B, N, 2]
    """
    camera = camera.view(-1, 1, 3)
    X_trans = X[:, :, :2] + camera[:, :, 1:]
    shape = X_trans.shape
    X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
    return X_2d
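
# Illustrative usage (not part of the original file): project a toy batch of
# 3D joints with a scale-and-translation camera [s, tx, ty].
def _example_project():
    X = torch.tensor([[[0.1, 0.2, 0.5],
                       [0.0, 0.0, 1.0]]])      # size [1, 2, 3]
    camera = torch.tensor([[2.0, 0.1, -0.1]])  # size [1, 3]
    return orthographic_projection(X, camera)  # size [1, 2, 2]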

mesh_graphormer/utils/image_ops.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""
|
||||||
|
Image processing tools
|
||||||
|
|
||||||
|
Modified from open source projects:
|
||||||
|
(https://github.com/nkolot/GraphCMR/)
|
||||||
|
(https://github.com/open-mmlab/mmdetection)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import base64
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import scipy.misc
|
||||||
|
|
||||||
|
def img_from_base64(imagestring):
|
||||||
|
try:
|
||||||
|
jpgbytestring = base64.b64decode(imagestring)
|
||||||
|
nparr = np.frombuffer(jpgbytestring, np.uint8)
|
||||||
|
r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
|
return r
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
|
||||||
|
if center is not None and auto_bound:
|
||||||
|
raise ValueError('`auto_bound` conflicts with `center`')
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
if center is None:
|
||||||
|
center = ((w - 1) * 0.5, (h - 1) * 0.5)
|
||||||
|
assert isinstance(center, tuple)
|
||||||
|
|
||||||
|
matrix = cv2.getRotationMatrix2D(center, angle, scale)
|
||||||
|
if auto_bound:
|
||||||
|
cos = np.abs(matrix[0, 0])
|
||||||
|
sin = np.abs(matrix[0, 1])
|
||||||
|
new_w = h * sin + w * cos
|
||||||
|
new_h = h * cos + w * sin
|
||||||
|
matrix[0, 2] += (new_w - w) * 0.5
|
||||||
|
matrix[1, 2] += (new_h - h) * 0.5
|
||||||
|
w = int(np.round(new_w))
|
||||||
|
h = int(np.round(new_h))
|
||||||
|
rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
|
||||||
|
return rotated
|
||||||
|
|
||||||
|
def myimresize(img, size, return_scale=False, interpolation='bilinear'):
|
||||||
|
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
resized_img = cv2.resize(
|
||||||
|
img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR)
|
||||||
|
if not return_scale:
|
||||||
|
return resized_img
|
||||||
|
else:
|
||||||
|
w_scale = size[0] / w
|
||||||
|
h_scale = size[1] / h
|
||||||
|
return resized_img, w_scale, h_scale
|
||||||
|
|
||||||
|
|
||||||
|
def get_transform(center, scale, res, rot=0):
|
||||||
|
"""Generate transformation matrix."""
|
||||||
|
h = 200 * scale
|
||||||
|
t = np.zeros((3, 3))
|
||||||
|
t[0, 0] = float(res[1]) / h
|
||||||
|
t[1, 1] = float(res[0]) / h
|
||||||
|
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
|
||||||
|
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
|
||||||
|
t[2, 2] = 1
|
||||||
|
if not rot == 0:
|
||||||
|
rot = -rot # To match direction of rotation from cropping
|
||||||
|
rot_mat = np.zeros((3,3))
|
||||||
|
rot_rad = rot * np.pi / 180
|
||||||
|
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
|
||||||
|
rot_mat[0,:2] = [cs, -sn]
|
||||||
|
rot_mat[1,:2] = [sn, cs]
|
||||||
|
rot_mat[2,2] = 1
|
||||||
|
# Need to rotate around center
|
||||||
|
t_mat = np.eye(3)
|
||||||
|
t_mat[0,2] = -res[1]/2
|
||||||
|
t_mat[1,2] = -res[0]/2
|
||||||
|
t_inv = t_mat.copy()
|
||||||
|
t_inv[:2,2] *= -1
|
||||||
|
t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
|
||||||
|
return t
|
||||||
|
|
||||||
|
def transform(pt, center, scale, res, invert=0, rot=0):
|
||||||
|
"""Transform pixel location to different reference."""
|
||||||
|
t = get_transform(center, scale, res, rot=rot)
|
||||||
|
if invert:
|
||||||
|
# t = np.linalg.inv(t)
|
||||||
|
t_torch = torch.from_numpy(t)
|
||||||
|
t_torch = torch.inverse(t_torch)
|
||||||
|
t = t_torch.numpy()
|
||||||
|
new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
|
||||||
|
new_pt = np.dot(t, new_pt)
|
||||||
|
return new_pt[:2].astype(int)+1
|
||||||
|
|
||||||
|
def crop(img, center, scale, res, rot=0):
|
||||||
|
"""Crop image according to the supplied bounding box."""
|
||||||
|
# Upper left point
|
||||||
|
ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
|
||||||
|
# Bottom right point
|
||||||
|
br = np.array(transform([res[0]+1,
|
||||||
|
res[1]+1], center, scale, res, invert=1))-1
|
||||||
|
# Padding so that when rotated proper amount of context is included
|
||||||
|
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
||||||
|
if not rot == 0:
|
||||||
|
ul -= pad
|
||||||
|
br += pad
|
||||||
|
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
||||||
|
if len(img.shape) > 2:
|
||||||
|
new_shape += [img.shape[2]]
|
||||||
|
new_img = np.zeros(new_shape)
|
||||||
|
|
||||||
|
# Range to fill new array
|
||||||
|
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
||||||
|
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
||||||
|
# Range to sample from original image
|
||||||
|
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
||||||
|
old_y = max(0, ul[1]), min(len(img), br[1])
|
||||||
|
|
||||||
|
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
|
||||||
|
old_x[0]:old_x[1]]
|
||||||
|
if not rot == 0:
|
||||||
|
# Remove padding
|
||||||
|
# new_img = scipy.misc.imrotate(new_img, rot)
|
||||||
|
new_img = myimrotate(new_img, rot)
|
||||||
|
new_img = new_img[pad:-pad, pad:-pad]
|
||||||
|
|
||||||
|
# new_img = scipy.misc.imresize(new_img, res)
|
||||||
|
new_img = myimresize(new_img, [res[0], res[1]])
|
||||||
|
return new_img
|
||||||
|
|
||||||
|
def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
|
||||||
|
"""'Undo' the image cropping/resizing.
|
||||||
|
This function is used when evaluating mask/part segmentation.
|
||||||
|
"""
|
||||||
|
res = img.shape[:2]
|
||||||
|
# Upper left point
|
||||||
|
ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
|
||||||
|
# Bottom right point
|
||||||
|
br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
|
||||||
|
# size of cropped image
|
||||||
|
crop_shape = [br[1] - ul[1], br[0] - ul[0]]
|
||||||
|
|
||||||
|
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
||||||
|
if len(img.shape) > 2:
|
||||||
|
new_shape += [img.shape[2]]
|
||||||
|
new_img = np.zeros(orig_shape, dtype=np.uint8)
|
||||||
|
# Range to fill new array
|
||||||
|
new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
|
||||||
|
new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
|
||||||
|
# Range to sample from original image
|
||||||
|
old_x = max(0, ul[0]), min(orig_shape[1], br[0])
|
||||||
|
old_y = max(0, ul[1]), min(orig_shape[0], br[1])
|
||||||
|
# img = scipy.misc.imresize(img, crop_shape, interp='nearest')
|
||||||
|
img = myimresize(img, [crop_shape[0],crop_shape[1]])
|
||||||
|
new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
|
||||||
|
return new_img
|
||||||
|
|
||||||
|
def rot_aa(aa, rot):
|
||||||
|
"""Rotate axis angle parameters."""
|
||||||
|
# pose parameters
|
||||||
|
R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
|
||||||
|
[np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
|
||||||
|
[0, 0, 1]])
|
||||||
|
# find the rotation of the body in camera frame
|
||||||
|
per_rdg, _ = cv2.Rodrigues(aa)
|
||||||
|
# apply the global rotation to the global orientation
|
||||||
|
resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
|
||||||
|
aa = (resrot.T)[0]
|
||||||
|
return aa
|
||||||
|
|
||||||
|
def flip_img(img):
|
||||||
|
"""Flip rgb images or masks.
|
||||||
|
channels come last, e.g. (256,256,3).
|
||||||
|
"""
|
||||||
|
img = np.fliplr(img)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def flip_kp(kp):
|
||||||
|
"""Flip keypoints."""
|
||||||
|
flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
|
||||||
|
kp = kp[flipped_parts]
|
||||||
|
kp[:,0] = - kp[:,0]
|
||||||
|
return kp
|
||||||
|
|
||||||
|
def flip_pose(pose):
|
||||||
|
"""Flip pose.
|
||||||
|
The flipping is based on SMPL parameters.
|
||||||
|
"""
|
||||||
|
flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
|
||||||
|
14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
|
||||||
|
34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
|
||||||
|
45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
|
||||||
|
56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
|
||||||
|
pose = pose[flippedParts]
|
||||||
|
# we also negate the second and the third dimension of the axis-angle
|
||||||
|
pose[1::3] = -pose[1::3]
|
||||||
|
pose[2::3] = -pose[2::3]
|
||||||
|
return pose
|
||||||
|
|
||||||
|
def flip_aa(aa):
|
||||||
|
"""Flip axis-angle representation.
|
||||||
|
We negate the second and the third dimension of the axis-angle.
|
||||||
|
"""
|
||||||
|
aa[1] = -aa[1]
|
||||||
|
aa[2] = -aa[2]
|
||||||
|
return aa
|
||||||
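A hedged sketch of the crop/uncrop convention used above: `center` is the bounding-box center in pixels and `scale` is the box size divided by 200 (the constant `h = 200 * scale` in `get_transform`). The image shape and values below are illustrative assumptions, not values from this repo.

import numpy as np

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
center = [320, 240]          # bbox center (x, y)
scale = 1.2                  # bbox size ~ 240 px
patch = crop(img, center, scale, res=[224, 224])     # 224x224 crop around center
restored = uncrop(patch, center, scale, img.shape)   # paste back at original location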
100
mesh_graphormer/utils/logger.py
Normal file
@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import sys
from logging import StreamHandler, Handler, getLevelName


# this class is a copy of logging.FileHandler except we call self.close()
# at the end of each emit. While closing and reopening the file after each
# write is not efficient, it allows us to see partial logs when writing to
# fused Azure blobs, which is very convenient
class FileHandler(StreamHandler):
    """
    A handler class which writes formatted logging records to disk files.
    """
    def __init__(self, filename, mode='a', encoding=None, delay=False):
        """
        Open the specified file and use it as the stream for logging.
        """
        # Issue #27493: add support for Path objects to be passed in
        filename = os.fspath(filename)
        # keep the absolute path, otherwise derived classes which use this
        # may come a cropper when the current directory changes
        self.baseFilename = os.path.abspath(filename)
        self.mode = mode
        self.encoding = encoding
        self.delay = delay
        if delay:
            # We don't open the stream, but we still need to call the
            # Handler constructor to set level, formatter, lock etc.
            Handler.__init__(self)
            self.stream = None
        else:
            StreamHandler.__init__(self, self._open())

    def close(self):
        """
        Closes the stream.
        """
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                StreamHandler.close(self)
        finally:
            self.release()

    def _open(self):
        """
        Open the current base file with the (original) mode and encoding.
        Return the resulting stream.
        """
        return open(self.baseFilename, self.mode, encoding=self.encoding)

    def emit(self, record):
        """
        Emit a record.

        If the stream was not opened because 'delay' was specified in the
        constructor, open it before calling the superclass's emit.
        """
        if self.stream is None:
            self.stream = self._open()
        StreamHandler.emit(self, record)
        self.close()

    def __repr__(self):
        level = getLevelName(self.level)
        return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level)


def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        fh = FileHandler(os.path.join(save_dir, filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
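A minimal usage sketch: rank 0 logs to stdout and to `<save_dir>/log.txt`, while non-zero ranks get a handler-free logger. The "output" directory name is an assumption for illustration, and must exist before the FileHandler opens the file.

import os

os.makedirs("output", exist_ok=True)
logger = setup_logger("mesh_graphormer", "output", distributed_rank=0)
logger.info("training started")   # goes to stdout and to output/log.txt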
45
mesh_graphormer/utils/metric_logger.py
Normal file
@@ -0,0 +1,45 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

Basic logger. It computes and stores the average and current value
"""

class AverageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class EvalMetricsLogger(object):

    def __init__(self):
        self.reset()

    def reset(self):
        # define an upper-bound performance (worst case):
        # 100 millimeters, stored in meters
        self.PAmPJPE = 100.0 / 1000.0
        self.mPJPE = 100.0 / 1000.0
        self.mPVE = 100.0 / 1000.0

        self.epoch = 0

    def update(self, mPVE, mPJPE, PAmPJPE, epoch):
        self.PAmPJPE = PAmPJPE
        self.mPJPE = mPJPE
        self.mPVE = mPVE
        self.epoch = epoch
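A sketch of the intended bookkeeping for AverageMeter: `update()` takes a per-batch mean and the batch size `n`, so `.avg` stays a correctly weighted running mean across batches of different sizes.

losses = AverageMeter()
losses.update(0.50, n=32)   # batch of 32 with mean loss 0.50
losses.update(0.30, n=16)   # batch of 16 with mean loss 0.30
print(losses.avg)           # (0.50*32 + 0.30*16) / 48 ~= 0.433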
99
mesh_graphormer/utils/metric_pampjpe.py
Normal file
@@ -0,0 +1,99 @@
"""
Functions for computing Procrustes alignment and reconstruction error

Parts of the code are adapted from https://github.com/akanazawa/hmr

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np

def compute_similarity_transform(S1, S2):
    """Computes a similarity transform (sR, t) that takes
    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
    where R is a 3x3 rotation matrix, t a 3x1 translation, s a scale.
    i.e. solves the orthogonal Procrustes problem.
    """
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.T
        S2 = S2.T
        transposed = True
    assert(S2.shape[1] == S1.shape[1])

    # 1. Remove mean.
    mu1 = S1.mean(axis=1, keepdims=True)
    mu2 = S2.mean(axis=1, keepdims=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = np.sum(X1**2)

    # 3. The outer product of X1 and X2.
    K = X1.dot(X2.T)

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, s, Vh = np.linalg.svd(K)
    V = Vh.T
    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = np.eye(U.shape[0])
    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
    # Construct R.
    R = V.dot(Z.dot(U.T))

    # 5. Recover scale.
    scale = np.trace(R.dot(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale * (R.dot(mu1))

    # 7. Error:
    S1_hat = scale * R.dot(S1) + t

    if transposed:
        S1_hat = S1_hat.T

    return S1_hat

def compute_similarity_transform_batch(S1, S2):
    """Batched version of compute_similarity_transform."""
    S1_hat = np.zeros_like(S1)
    for i in range(S1.shape[0]):
        S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
    return S1_hat

def reconstruction_error(S1, S2, reduction='mean'):
    """Do Procrustes alignment and compute reconstruction error."""
    S1_hat = compute_similarity_transform_batch(S1, S2)
    re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
    if reduction == 'mean':
        re = re.mean()
    elif reduction == 'sum':
        re = re.sum()
    return re

def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'):
    """Do Procrustes alignment and compute reconstruction error on a joint subset."""
    S1_hat = compute_similarity_transform_batch(S1, S2)
    S1_hat = S1_hat[:, J24_TO_J14, :]
    S2 = S2[:, J24_TO_J14, :]
    re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
    if reduction == 'mean':
        re = re.mean()
    elif reduction == 'sum':
        re = re.sum()
    return re

def get_alignMesh(S1, S2, reduction='mean'):
    """Do Procrustes alignment and compute reconstruction error,
    returning the aligned meshes as well."""
    S1_hat = compute_similarity_transform_batch(S1, S2)
    re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
    if reduction == 'mean':
        re = re.mean()
    elif reduction == 'sum':
        re = re.sum()
    return re, S1_hat, S2
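A quick sanity sketch for PA-MPJPE via the helpers above: aligning a rigidly rotated copy of a point set back onto itself should give near-zero error, since Procrustes alignment removes the rotation.

import numpy as np

S2 = np.random.randn(2, 14, 3)                                 # batch of 2, 14 joints
Rz = np.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])     # 90 deg about z
S1 = S2 @ Rz.T                                                 # rotated "predictions"
print(reconstruction_error(S1, S2))                            # ~0 after alignment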
171
mesh_graphormer/utils/miscellaneous.py
Normal file
@@ -0,0 +1,171 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import os
import os.path as op
import re
import logging
import numpy as np
import torch
import random
import shutil
from .comm import is_main_process
import yaml


def mkdir(path):
    # if it is the current folder, skip;
    # otherwise os.makedirs would raise FileNotFoundError
    if path == '':
        return
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def save_config(cfg, path):
    if is_main_process():
        with open(path, 'w') as f:
            f.write(cfg.dump())


def config_iteration(output_dir, max_iter):
    save_file = os.path.join(output_dir, 'last_checkpoint')
    iteration = -1
    if os.path.exists(save_file):
        with open(save_file, 'r') as f:
            fname = f.read().strip()
        model_name = os.path.basename(fname)
        model_path = os.path.dirname(fname)
        if model_name.startswith('model_') and len(model_name) == 17:
            iteration = int(model_name[-11:-4])
        elif model_name == "model_final":
            iteration = max_iter
        elif model_path.startswith('checkpoint-') and len(model_path) == 18:
            iteration = int(model_path.split('-')[-1])
    return iteration


def get_matching_parameters(model, regexp, none_on_empty=True):
    """Returns parameters matching regular expression"""
    if not regexp:
        if none_on_empty:
            return {}
        else:
            return dict(model.named_parameters())
    compiled_pattern = re.compile(regexp)
    params = {}
    for weight_name, weight in model.named_parameters():
        if compiled_pattern.match(weight_name):
            params[weight_name] = weight
    return params


def freeze_weights(model, regexp):
    """Freeze weights based on regular expression."""
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    for weight_name, weight in get_matching_parameters(model, regexp).items():
        weight.requires_grad = False
        logger.info("Disabled training of {}".format(weight_name))


def unfreeze_weights(model, regexp, backbone_freeze_at=-1,
                     is_distributed=False):
    """Unfreeze weights based on regular expression.
    This is helpful during training to unfreeze frozen weights after
    other unfrozen weights have been trained for some iterations.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    for weight_name, weight in get_matching_parameters(model, regexp).items():
        weight.requires_grad = True
        logger.info("Enabled training of {}".format(weight_name))
    if backbone_freeze_at >= 0:
        logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at))
        if is_distributed:
            model.module.backbone.body._freeze_backbone(backbone_freeze_at)
        else:
            model.backbone.body._freeze_backbone(backbone_freeze_at)


def delete_tsv_files(tsvs):
    for t in tsvs:
        if op.isfile(t):
            try_delete(t)
        line = op.splitext(t)[0] + '.lineidx'
        if op.isfile(line):
            try_delete(line)


def concat_files(ins, out):
    mkdir(op.dirname(out))
    out_tmp = out + '.tmp'
    with open(out_tmp, 'wb') as fp_out:
        for i, f in enumerate(ins):
            logging.info('concatenating {}/{} - {}'.format(i, len(ins), f))
            with open(f, 'rb') as fp_in:
                shutil.copyfileobj(fp_in, fp_out, 1024*1024*10)
    os.rename(out_tmp, out)


def concat_tsv_files(tsvs, out_tsv):
    concat_files(tsvs, out_tsv)
    sizes = [os.stat(t).st_size for t in tsvs]
    sizes = np.cumsum(sizes)
    all_idx = []
    for i, t in enumerate(tsvs):
        for idx in load_list_file(op.splitext(t)[0] + '.lineidx'):
            if i == 0:
                all_idx.append(idx)
            else:
                all_idx.append(str(int(idx) + sizes[i - 1]))
    with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f:
        f.write('\n'.join(all_idx))


def load_list_file(fname):
    with open(fname, 'r') as fp:
        lines = fp.readlines()
    result = [line.strip() for line in lines]
    if len(result) > 0 and result[-1] == '':
        result = result[:-1]
    return result


def try_once(func):
    def func_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logging.info('ignore error \n{}'.format(str(e)))
    return func_wrapper


@try_once
def try_delete(f):
    os.remove(f)


def set_seed(seed, n_gpu):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def print_and_run_cmd(cmd):
    print(cmd)
    os.system(cmd)


def write_to_yaml_file(context, file_name):
    with open(file_name, 'w') as fp:
        yaml.dump(context, fp, encoding='utf-8')


def load_from_yaml_file(yaml_file):
    with open(yaml_file, 'r') as fp:
        return yaml.load(fp, Loader=yaml.CLoader)
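A hedged sketch of the regexp-based freezing helpers above. The toy model and the "^backbone\." name prefix are assumptions for illustration only, not names from this repo; `get_matching_parameters` matches against `model.named_parameters()` names.

import torch

model = torch.nn.Sequential()                    # stand-in for a real model
model.add_module("backbone", torch.nn.Linear(8, 8))
model.add_module("head", torch.nn.Linear(8, 2))
set_seed(42, n_gpu=torch.cuda.device_count())
freeze_weights(model, r"^backbone\.")            # backbone.* -> requires_grad = False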
691
mesh_graphormer/utils/renderer.py
Normal file
@@ -0,0 +1,691 @@
"""
Rendering tools for 3D mesh visualization on 2D image.

Parts of the code are taken from https://github.com/akanazawa/hmr
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import cv2
import code
from opendr.camera import ProjectPoints
from opendr.renderer import ColoredRenderer, TexturedRenderer
from opendr.lighting import LambertianPointLight
import random


# Rotate the points by a specified angle.
def rotateY(points, angle):
    ry = np.array([
        [np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
        [-np.sin(angle), 0., np.cos(angle)]
    ])
    return np.dot(points, ry)


def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None):
    """
    joints is 3 x 19; if it is not, it will be transposed.
    0: Right ankle
    1: Right knee
    2: Right hip
    3: Left hip
    4: Left knee
    5: Left ankle
    6: Right wrist
    7: Right elbow
    8: Right shoulder
    9: Left shoulder
    10: Left elbow
    11: Left wrist
    12: Neck
    13: Head top
    14: nose
    15: left_eye
    16: right_eye
    17: left_ear
    18: right_ear
    """

    if radius is None:
        radius = max(4, (np.mean(input_image.shape[:2]) * 0.01).astype(int))

    colors = {
        'pink': (197, 27, 125),  # L lower leg
        'light_pink': (233, 163, 201),  # L upper leg
        'light_green': (161, 215, 106),  # L lower arm
        'green': (77, 146, 33),  # L upper arm
        'red': (215, 48, 39),  # head
        'light_red': (252, 146, 114),  # head
        'light_orange': (252, 141, 89),  # chest
        'purple': (118, 42, 131),  # R lower leg
        'light_purple': (175, 141, 195),  # R upper
        'light_blue': (145, 191, 219),  # R lower arm
        'blue': (69, 117, 180),  # R upper arm
        'gray': (130, 130, 130),
        'white': (255, 255, 255),
    }

    image = input_image.copy()
    input_is_float = False

    # np.floating instead of the removed np.float alias
    if np.issubdtype(image.dtype, np.floating):
        input_is_float = True
        max_val = image.max()
        if max_val <= 2.:  # should be 1 but sometimes it's slightly above 1
            image = (image * 255).astype(np.uint8)
        else:
            image = (image).astype(np.uint8)

    if joints.shape[0] != 2:
        joints = joints.T
    joints = np.round(joints).astype(int)

    jcolors = [
        'light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink',
        'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue',
        'purple', 'purple', 'red', 'green', 'green', 'white', 'white',
        'purple', 'purple', 'red', 'green', 'green', 'white', 'white'
    ]

    if joints.shape[1] == 19:
        # parent indices; -1 means no parent
        parents = np.array([
            1, 2, 8, 9, 3, 4, 7, 8, 12, 12, 9, 10, 14, -1, 13, -1, -1, 15, 16
        ])
        # Left is light and right is dark
        ecolors = {
            0: 'light_pink',
            1: 'light_pink',
            2: 'light_pink',
            3: 'pink',
            4: 'pink',
            5: 'pink',
            6: 'light_blue',
            7: 'light_blue',
            8: 'light_blue',
            9: 'blue',
            10: 'blue',
            11: 'blue',
            12: 'purple',
            17: 'light_green',
            18: 'light_green',
            14: 'purple'
        }
    elif joints.shape[1] == 14:
        parents = np.array([
            1, 2, 8, 9, 3, 4, 7, 8, -1, -1, 9, 10, 13, -1,
        ])
        ecolors = {
            0: 'light_pink',
            1: 'light_pink',
            2: 'light_pink',
            3: 'pink',
            4: 'pink',
            5: 'pink',
            6: 'light_blue',
            7: 'light_blue',
            10: 'light_blue',
            11: 'blue',
            12: 'purple'
        }
    elif joints.shape[1] == 21:  # hand
        parents = np.array([
            -1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 0, 13, 14, 15, 0, 17, 18, 19,
        ])
        ecolors = {
            0: 'light_purple',
            1: 'light_green',
            2: 'light_green',
            3: 'light_green',
            4: 'light_green',
            5: 'pink',
            6: 'pink',
            7: 'pink',
            8: 'pink',
            9: 'light_blue',
            10: 'light_blue',
            11: 'light_blue',
            12: 'light_blue',
            13: 'light_red',
            14: 'light_red',
            15: 'light_red',
            16: 'light_red',
            17: 'purple',
            18: 'purple',
            19: 'purple',
            20: 'purple',
        }
    else:
        print('Unknown skeleton!!')
        return image  # bail out early; `parents` would be undefined below

    for child in range(len(parents)):
        point = joints[:, child]
        # If invisible skip
        if vis is not None and vis[child] == 0:
            continue
        if draw_edges:
            cv2.circle(image, (point[0], point[1]), radius, colors['white'],
                       -1)
            cv2.circle(image, (point[0], point[1]), radius - 1,
                       colors[jcolors[child]], -1)
        else:
            # cv2.circle(image, (point[0], point[1]), 5, colors['white'], 1)
            cv2.circle(image, (point[0], point[1]), radius - 1,
                       colors[jcolors[child]], 1)
            # cv2.circle(image, (point[0], point[1]), 5, colors['gray'], -1)
        pa_id = parents[child]
        if draw_edges and pa_id >= 0:
            if vis is not None and vis[pa_id] == 0:
                continue
            point_pa = joints[:, pa_id]
            cv2.circle(image, (point_pa[0], point_pa[1]), radius - 1,
                       colors[jcolors[pa_id]], -1)
            if child not in ecolors.keys():
                print('bad')
                import ipdb
                ipdb.set_trace()
            cv2.line(image, (point[0], point[1]), (point_pa[0], point_pa[1]),
                     colors[ecolors[child]], radius - 2)

    # Convert back in original dtype
    if input_is_float:
        if max_val <= 1.:
            image = image.astype(np.float32) / 255.
        else:
            image = image.astype(np.float32)

    return image


def draw_text(input_image, content):
    """
    content is a dict; draws "key: val" on the image.
    Assumes key is str, val is float.
    """
    image = input_image.copy()
    input_is_float = False
    if np.issubdtype(image.dtype, np.floating):
        input_is_float = True
        image = (image * 255).astype(np.uint8)

    black = (255, 255, 0)
    margin = 15
    start_x = 5
    start_y = margin
    for key in sorted(content.keys()):
        text = "%s: %.2g" % (key, content[key])
        cv2.putText(image, text, (start_x, start_y), 0, 0.45, black)
        start_y += margin

    if input_is_float:
        image = image.astype(np.float32) / 255.
    return image


def visualize_reconstruction(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, color='pink', focal_length=1000):
    """Overlays gt_kp and pred_kp on img.
    Draws vertices with debug text.
    Renderer is an instance of SMPLRenderer.
    """
    gt_vis = gt_kp[:, 2].astype(bool)
    loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2)
    debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss}
    # Fix a focal length so we can render with perspective-correct scale
    res = img.shape[1]
    camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] + 1e-9)])
    rend_img = renderer.render(vertices, camera_t=camera_t,
                               img=img, use_bg=True,
                               focal_length=focal_length,
                               body_color=color)
    rend_img = draw_text(rend_img, debug_text)

    # Draw skeleton
    gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
    pred_joint = ((pred_kp + 1) * 0.5) * img_size
    img_with_gt = draw_skeleton(img, gt_joint, draw_edges=False, vis=gt_vis)
    skel_img = draw_skeleton(img_with_gt, pred_joint)

    combined = np.hstack([skel_img, rend_img])

    return combined


def visualize_reconstruction_test(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, score, color='pink', focal_length=1000):
    """Overlays gt_kp and pred_kp on img.
    Draws vertices with debug text, including the PA-MPJPE score.
    Renderer is an instance of SMPLRenderer.
    """
    gt_vis = gt_kp[:, 2].astype(bool)
    loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2)
    debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss, "pa-mpjpe": score*1000}
    # Fix a focal length so we can render with perspective-correct scale
    res = img.shape[1]
    camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] + 1e-9)])
    rend_img = renderer.render(vertices, camera_t=camera_t,
                               img=img, use_bg=True,
                               focal_length=focal_length,
                               body_color=color)
    rend_img = draw_text(rend_img, debug_text)

    # Draw skeleton
    gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
    pred_joint = ((pred_kp + 1) * 0.5) * img_size
    img_with_gt = draw_skeleton(img, gt_joint, draw_edges=False, vis=gt_vis)
    skel_img = draw_skeleton(img_with_gt, pred_joint)

    combined = np.hstack([skel_img, rend_img])

    return combined


def visualize_reconstruction_and_att(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, focal_length=1000):
    """Renders the predicted mesh on img, plus the self-attention from
    selected reference joints to all mesh vertices.
    Renderer is an instance of SMPLRenderer.
    """
    # Fix a focal length so we can render with perspective-correct scale
    res = img.shape[1]
    camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] + 1e-9)])
    rend_img = renderer.render(vertices_full, camera_t=camera_t,
                               img=img, use_bg=True,
                               focal_length=focal_length, body_color='light_blue')

    heads_num, vertex_num, _ = attention.shape

    all_head = np.zeros((vertex_num, vertex_num))

    ###### find max
    # for i in range(vertex_num):
    #     for j in range(vertex_num):
    #         all_head[i,j] = np.max(attention[:,i,j])

    ##### find avg
    for h in range(4):
        att_per_img = attention[h]
        all_head = all_head + att_per_img
    all_head = all_head / 4

    col_sums = all_head.sum(axis=0)
    all_head = all_head / col_sums[np.newaxis, :]

    # code.interact(local=locals())

    combined = []
    if vertex_num > 400:  # body
        selected_joints = [6, 7, 4, 5, 13]  # [6,7,4,5,13,12]
    else:  # hand
        selected_joints = [0, 4, 8, 12, 16, 20]
    # Draw attention
    for ii in range(len(selected_joints)):
        reference_id = selected_joints[ii]
        ref_point = ref_points[reference_id]
        attention_to_show = all_head[reference_id][14::]
        min_v = np.min(attention_to_show)
        max_v = np.max(attention_to_show)
        norm_attention_to_show = (attention_to_show - min_v) / (max_v - min_v)

        vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
        ref_norm = ((ref_point + 1) * 0.5) * img_size
        image = np.zeros_like(rend_img)

        for jj in range(vertices_norm.shape[0]):
            x = int(vertices_norm[jj, 0])
            y = int(vertices_norm[jj, 1])
            cv2.circle(image, (x, y), 1, (255, 255, 255), -1)

        total_to_draw = []
        for jj in range(vertices_norm.shape[0]):
            thres = 0.0
            if norm_attention_to_show[jj] > thres:
                things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
                total_to_draw.append(things)
                # plot_one_line(ref_norm, vertices_norm[jj], image, reference_id, alpha=0.4*(norm_attention_to_show[jj]-thres)/(1-thres) )
        total_to_draw.sort()
        max_att_score = total_to_draw[-1][0]
        for item in total_to_draw:
            attention_score = item[0]
            ref_point = item[1]
            vertex = item[2]
            plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score - thres)/(max_att_score - thres))
        # code.interact(local=locals())
        if len(combined) == 0:
            combined = image
        else:
            combined = np.hstack([combined, image])

    final = np.hstack([img, combined, rend_img])

    return final


def visualize_reconstruction_and_att_local(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, color='light_blue', focal_length=1000):
    """Renders the predicted mesh on img, plus the self-attention from a
    single reference joint, drawn over a dimmed copy of the rendering.
    Renderer is an instance of SMPLRenderer.
    """
    # Fix a focal length so we can render with perspective-correct scale
    res = img.shape[1]
    camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] + 1e-9)])
    rend_img = renderer.render(vertices_full, camera_t=camera_t,
                               img=img, use_bg=True,
                               focal_length=focal_length, body_color=color)
    heads_num, vertex_num, _ = attention.shape
    all_head = np.zeros((vertex_num, vertex_num))

    ##### compute avg attention for 4 attention heads
    for h in range(4):
        att_per_img = attention[h]
        all_head = all_head + att_per_img
    all_head = all_head / 4

    col_sums = all_head.sum(axis=0)
    all_head = all_head / col_sums[np.newaxis, :]

    combined = []
    if vertex_num > 400:  # body
        selected_joints = [7]  # [6,7,4,5,13,12]
    else:  # hand
        selected_joints = [0]  # [0, 4, 8, 12, 16, 20]
    # Draw attention
    for ii in range(len(selected_joints)):
        reference_id = selected_joints[ii]
        ref_point = ref_points[reference_id]
        attention_to_show = all_head[reference_id][14::]
        min_v = np.min(attention_to_show)
        max_v = np.max(attention_to_show)
        norm_attention_to_show = (attention_to_show - min_v) / (max_v - min_v)
        vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
        ref_norm = ((ref_point + 1) * 0.5) * img_size
        image = rend_img * 0.4

        total_to_draw = []
        for jj in range(vertices_norm.shape[0]):
            thres = 0.0
            if norm_attention_to_show[jj] > thres:
                things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
                total_to_draw.append(things)
        total_to_draw.sort()
        max_att_score = total_to_draw[-1][0]
        for item in total_to_draw:
            attention_score = item[0]
            ref_point = item[1]
            vertex = item[2]
            plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score - thres)/(max_att_score - thres))

        for jj in range(vertices_norm.shape[0]):
            x = int(vertices_norm[jj, 0])
            y = int(vertices_norm[jj, 1])
            cv2.circle(image, (x, y), 1, (255, 255, 255), -1)

        if len(combined) == 0:
            combined = image
        else:
            combined = np.hstack([combined, image])

    final = np.hstack([img, combined, rend_img])

    return final


def visualize_reconstruction_no_text(img, img_size, vertices, camera, renderer, color='pink', focal_length=1000):
    """Renders the predicted mesh over img, without debug text.
    Renderer is an instance of SMPLRenderer.
    """
    # Fix a focal length so we can render with perspective-correct scale
    res = img.shape[1]
    camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] + 1e-9)])
    rend_img = renderer.render(vertices, camera_t=camera_t,
                               img=img, use_bg=True,
                               focal_length=focal_length,
                               body_color=color)

    combined = np.hstack([img, rend_img])

    return combined


def plot_one_line(ref, vertex, img, color_index, alpha=0.0, line_thickness=None):
    # 13,6,7,8,3,4,5
    # att_colors = [(255, 221, 104), (255, 255, 0), (255, 215, 227), (210, 240, 119), \
    #               (209, 238, 245), (244, 200, 243), (233, 242, 216)]
    att_colors = [(255, 255, 0), (244, 200, 243), (210, 243, 119), (209, 238, 255), (200, 208, 255), (250, 238, 215)]

    overlay = img.copy()
    # output = img.copy()
    # Plots one attention line on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness

    color = list(att_colors[color_index])
    c1, c2 = (int(ref[0]), int(ref[1])), (int(vertex[0]), int(vertex[1]))
    cv2.line(overlay, c1, c2, (alpha*float(color[0])/255, alpha*float(color[1])/255, alpha*float(color[2])/255), thickness=tl, lineType=cv2.LINE_AA)
    cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)


def cam2pixel(cam_coord, f, c):
    x = cam_coord[:, 0] / (cam_coord[:, 2]) * f[0] + c[0]
    y = cam_coord[:, 1] / (cam_coord[:, 2]) * f[1] + c[1]
    z = cam_coord[:, 2]
    img_coord = np.concatenate((x[:, None], y[:, None], z[:, None]), 1)
    return img_coord


class Renderer(object):
    """
    Render mesh using OpenDR for visualization.
    """

    def __init__(self, width=800, height=600, near=0.5, far=1000, faces=None):
        self.colors = {'hand': [.9, .9, .9], 'pink': [.9, .7, .7], 'light_blue': [0.65098039, 0.74117647, 0.85882353]}
        self.width = width
        self.height = height
        self.faces = faces
        self.renderer = ColoredRenderer()

    def render(self, vertices, faces=None, img=None,
               camera_t=np.zeros([3], dtype=np.float32),
               camera_rot=np.zeros([3], dtype=np.float32),
               camera_center=None,
               use_bg=False,
               bg_color=(0.0, 0.0, 0.0),
               body_color=None,
               focal_length=5000,
               disp_text=False,
               gt_keyp=None,
               pred_keyp=None,
               **kwargs):
        if img is not None:
            height, width = img.shape[:2]
        else:
            height, width = self.height, self.width

        if faces is None:
            faces = self.faces

        if camera_center is None:
            camera_center = np.array([width * 0.5,
                                      height * 0.5])

        self.renderer.camera = ProjectPoints(rt=camera_rot,
                                             t=camera_t,
                                             f=focal_length * np.ones(2),
                                             c=camera_center,
                                             k=np.zeros(5))
        dist = np.abs(self.renderer.camera.t.r[2] -
                      np.mean(vertices, axis=0)[2])
        far = dist + 20

        self.renderer.frustum = {'near': 1.0, 'far': far,
                                 'width': width,
                                 'height': height}

        if img is not None:
            if use_bg:
                self.renderer.background_image = img
            else:
                self.renderer.background_image = np.ones_like(
                    img) * np.array(bg_color)

        if body_color is None:
            color = self.colors['light_blue']
        else:
            color = self.colors[body_color]

        if isinstance(self.renderer, TexturedRenderer):
            color = [1., 1., 1.]

        self.renderer.set(v=vertices, f=faces,
                          vc=color, bgcolor=np.ones(3))
        albedo = self.renderer.vc
        # Construct Back Light (on back right corner)
        yrot = np.radians(120)

        self.renderer.vc = LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-200, -100, -100]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Left Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([800, 10, 300]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Right Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
            vc=albedo,
            light_color=np.array([.7, .7, .7]))

        return self.renderer.r

    def render_vertex_color(self, vertices, faces=None, img=None,
                            camera_t=np.zeros([3], dtype=np.float32),
                            camera_rot=np.zeros([3], dtype=np.float32),
                            camera_center=None,
                            use_bg=False,
                            bg_color=(0.0, 0.0, 0.0),
                            vertex_color=None,
                            focal_length=5000,
                            disp_text=False,
                            gt_keyp=None,
                            pred_keyp=None,
                            **kwargs):
        if img is not None:
            height, width = img.shape[:2]
        else:
            height, width = self.height, self.width

        if faces is None:
            faces = self.faces

        if camera_center is None:
            camera_center = np.array([width * 0.5,
                                      height * 0.5])

        self.renderer.camera = ProjectPoints(rt=camera_rot,
                                             t=camera_t,
                                             f=focal_length * np.ones(2),
                                             c=camera_center,
                                             k=np.zeros(5))
        dist = np.abs(self.renderer.camera.t.r[2] -
                      np.mean(vertices, axis=0)[2])
        far = dist + 20

        self.renderer.frustum = {'near': 1.0, 'far': far,
                                 'width': width,
                                 'height': height}

        if img is not None:
            if use_bg:
                self.renderer.background_image = img
            else:
                self.renderer.background_image = np.ones_like(
                    img) * np.array(bg_color)

        if vertex_color is None:
            vertex_color = self.colors['light_blue']

        self.renderer.set(v=vertices, f=faces,
                          vc=vertex_color, bgcolor=np.ones(3))
        albedo = self.renderer.vc
        # Construct Back Light (on back right corner)
        yrot = np.radians(120)

        self.renderer.vc = LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-200, -100, -100]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Left Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([800, 10, 300]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Right Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
            vc=albedo,
            light_color=np.array([.7, .7, .7]))

        return self.renderer.r
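A small sanity sketch for cam2pixel above: it maps camera-frame points to pixel coordinates with focal lengths f and principal point c, and passes z through unchanged. The intrinsics below are illustrative assumptions.

import numpy as np

pts = np.array([[0.1, -0.2, 2.0]])          # one point, 2 m in front of the camera
f, c = (1000.0, 1000.0), (112.0, 112.0)     # assumed focal lengths and principal point
print(cam2pixel(pts, f, c))                 # [[162., 12., 2.]]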
162
mesh_graphormer/utils/tsv_file.py
Normal file
@@ -0,0 +1,162 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

Definition of TSV class
"""


import logging
import os
import os.path as op


def generate_lineidx(filein, idxout):
    idxout_tmp = idxout + '.tmp'
    with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
        fsize = os.fstat(tsvin.fileno()).st_size
        fpos = 0
        while fpos != fsize:
            tsvout.write(str(fpos) + "\n")
            tsvin.readline()
            fpos = tsvin.tell()
    os.rename(idxout_tmp, idxout)


# alias so that TSVFile.__init__ can still reach this function even though
# its own boolean argument shadows the name `generate_lineidx`
_generate_lineidx = generate_lineidx


def read_to_character(fp, c):
    result = []
    while True:
        s = fp.read(32)
        assert s != ''
        if c in s:
            result.append(s[: s.index(c)])
            break
        else:
            result.append(s)
    return ''.join(result)


class TSVFile(object):
    def __init__(self, tsv_file, generate_lineidx=False):
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # remember the pid of the process that opened the file;
        # if the current pid differs, the file is re-opened.
        self.pid = None
        # generate lineidx if it does not exist
        if not op.isfile(self.lineidx) and generate_lineidx:
            _generate_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except:
            logging.info('{}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx):
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            logging.info('loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, 'r') as fp:
                self._lineidx = [int(i.strip()) for i in fp.readlines()]

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        if self.pid != os.getpid():
            logging.info('re-open {} because the process id changed'.format(self.tsv_file))
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()


class CompositeTSVFile():
    def __init__(self, file_list, seq_file, root='.'):
        if isinstance(file_list, str):
            self.file_list = load_list_file(file_list)
        else:
            assert isinstance(file_list, list)
            self.file_list = file_list

        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        idx_source, idx_row = self.seq[index]
        k = self.tsvs[idx_source].get_key(idx_row)
        return '_'.join([self.file_list[idx_source], k])

    def num_rows(self):
        return len(self.seq)

    def __getitem__(self, index):
        idx_source, idx_row = self.seq[index]
        return self.tsvs[idx_source].seek(idx_row)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        '''
        this function has to be called in the init function if cache_policy is
        enabled. Thus, let's always call it in the init function to keep it simple.
        '''
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, 'r') as fp:
            for line in fp:
                parts = line.strip().split('\t')
                self.seq.append([int(parts[0]), int(parts[1])])
        self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
        self.initialized = True


def load_list_file(fname):
    with open(fname, 'r') as fp:
        lines = fp.readlines()
    result = [line.strip() for line in lines]
    if len(result) > 0 and result[-1] == '':
        result = result[:-1]
    return result
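A sketch of random access through TSVFile: the companion .lineidx file stores byte offsets, so rows can be fetched without scanning. The file name below is an assumption for illustration.

tsv = TSVFile("train.label.tsv", generate_lineidx=True)  # builds train.label.lineidx if missing
print(len(tsv))         # number of rows
first_row = tsv[0]      # list of column strings for row 0
key = tsv.get_key(0)    # first column only, via read_to_character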
116
mesh_graphormer/utils/tsv_file_ops.py
Normal file
@@ -0,0 +1,116 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

Basic operations for TSV files
"""


import errno
import os
import os.path as op
import json
import numpy as np
import base64
import cv2
from tqdm import tqdm
import yaml
from mesh_graphormer.utils.miscellaneous import mkdir
from mesh_graphormer.utils.tsv_file import TSVFile


def img_from_base64(imagestring):
    try:
        jpgbytestring = base64.b64decode(imagestring)
        nparr = np.frombuffer(jpgbytestring, np.uint8)
        r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        return r
    except ValueError:
        return None


def load_linelist_file(linelist_file):
    if linelist_file is not None:
        line_list = []
        with open(linelist_file, 'r') as fp:
            for i in fp:
                line_list.append(int(i.strip()))
        return line_list


def tsv_writer(values, tsv_file, sep='\t'):
    mkdir(op.dirname(tsv_file))
    lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
    idx = 0
    tsv_file_tmp = tsv_file + '.tmp'
    lineidx_file_tmp = lineidx_file + '.tmp'
    with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx:
        assert values is not None
        for value in values:
            assert value is not None
            value = [v if type(v) != bytes else v.decode('utf-8') for v in value]
            v = '{0}\n'.format(sep.join(map(str, value)))
            fp.write(v)
            fpidx.write(str(idx) + '\n')
            idx = idx + len(v)
    os.rename(tsv_file_tmp, tsv_file)
    os.rename(lineidx_file_tmp, lineidx_file)


def tsv_reader(tsv_file, sep='\t'):
    with open(tsv_file, 'r') as fp:
        for i, line in enumerate(fp):
            yield [x.strip() for x in line.split(sep)]


def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'):
    if save_file is not None:
        return save_file
    return op.splitext(tsv_file)[0] + append_str


def get_line_list(linelist_file=None, num_rows=None):
    if linelist_file is not None:
        return load_linelist_file(linelist_file)

    if num_rows is not None:
        return [i for i in range(num_rows)]


def generate_hw_file(img_file, save_file=None):
    rows = tsv_reader(img_file)
    def gen_rows():
        for i, row in tqdm(enumerate(rows)):
            row1 = [row[0]]
            img = img_from_base64(row[-1])
            height = img.shape[0]
            width = img.shape[1]
            row1.append(json.dumps([{"height": height, "width": width}]))
            yield row1

    save_file = config_save_file(img_file, save_file, '.hw.tsv')
    tsv_writer(gen_rows(), save_file)


def generate_linelist_file(label_file, save_file=None, ignore_attrs=()):
    # generate a list of images that have labels;
    # images with only ignore labels are not selected.
    line_list = []
    rows = tsv_reader(label_file)
    for i, row in tqdm(enumerate(rows)):
        labels = json.loads(row[1])
        if labels:
            if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab])
                                     for lab in labels]):
                continue
            line_list.append([i])

    save_file = config_save_file(label_file, save_file, '.linelist.tsv')
    tsv_writer(line_list, save_file)


def load_from_yaml_file(yaml_file):
    with open(yaml_file, 'r') as fp:
        return yaml.load(fp, Loader=yaml.CLoader)


def find_file_path_in_yaml(fname, root):
    if fname is not None:
        if op.isfile(fname):
            return fname
        elif op.isfile(op.join(root, fname)):
            return op.join(root, fname)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
            )
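A round-trip sketch for the writer/reader pair above: tsv_writer also emits the companion .lineidx file that TSVFile relies on. The file names and row contents are assumptions for illustration.

rows = [["img_001", '{"label": 3}'], ["img_002", '{"label": 7}']]
tsv_writer(rows, "demo/labels.tsv")        # also writes demo/labels.lineidx
for row in tsv_reader("demo/labels.tsv"):
    print(row)                             # ['img_001', '{"label": 3}'] ...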
1
requirement.txt
Normal file
@@ -0,0 +1 @@
rtree