Initial commit

huchenlei
2024-01-03 00:39:16 -05:00
commit a1f793c0a7
75 changed files with 10090 additions and 0 deletions

161
.gitignore vendored Normal file

@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.vscode/

42
hand_refiner/__init__.py Normal file

@@ -0,0 +1,42 @@
import numpy as np
from PIL import Image
from .util import resize_image_with_pad, common_input_validate, HWC3, custom_hf_download
from hand_refiner.pipeline import MeshGraphormerMediapipe, args
class MeshGraphormerDetector:
def __init__(self, pipeline):
self.pipeline = pipeline
@classmethod
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"):
filename = filename or "graphormer_hand_state_dict.bin"
hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)
args.hrnet_checkpoint = custom_hf_download(pretrained_model_or_path, hrnet_filename, cache_dir)
args.device = device
pipeline = MeshGraphormerMediapipe(args)
return cls(pipeline)
def to(self, device):
self.pipeline._model.to(device)
self.pipeline.mano_model.to(device)
self.pipeline.mano_model.layer.to(device)
return self
def __call__(self, input_image=None, mask_bbox_padding=30, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
depth_map, mask, info = self.pipeline.get_depth(input_image, mask_bbox_padding)
if depth_map is None:
depth_map = np.zeros_like(input_image)
mask = np.zeros_like(input_image)
# No hand was detected (for example, the hand is too small); fall back to an empty depth map and mask
depth_map, mask = HWC3(depth_map), HWC3(mask)
depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method)
depth_map = remove_pad(depth_map)
if output_type == "pil":
depth_map = Image.fromarray(depth_map)
mask = Image.fromarray(mask)
return depth_map, mask, info
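For orientation, a minimal usage sketch of the detector defined above; the repository id, cache directory and image path are illustrative placeholders, not part of this commit:

import cv2
from hand_refiner import MeshGraphormerDetector

# "some-user/hand-refiner-models", "./ckpts" and "hand.jpg" are hypothetical placeholders
detector = MeshGraphormerDetector.from_pretrained(
    "some-user/hand-refiner-models", cache_dir="./ckpts", device="cuda"
)
image = cv2.cvtColor(cv2.imread("hand.jpg"), cv2.COLOR_BGR2RGB)
depth_map, mask, info = detector(image, mask_bbox_padding=30, output_type="pil")
depth_map.save("hand_depth.png")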

92
hand_refiner/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml Normal file

@@ -0,0 +1,92 @@
GPUS: (0,1,2,3)
LOG_DIR: 'log/'
DATA_DIR: ''
OUTPUT_DIR: 'output/'
WORKERS: 4
PRINT_FREQ: 1000
MODEL:
NAME: cls_hrnet
IMAGE_SIZE:
- 224
- 224
EXTRA:
STAGE1:
NUM_MODULES: 1
NUM_BRANCHES: 1
BLOCK: BOTTLENECK
NUM_BLOCKS:
- 4
NUM_CHANNELS:
- 64
FUSE_METHOD: SUM
STAGE2:
NUM_MODULES: 1
NUM_BRANCHES: 2
BLOCK: BASIC
NUM_BLOCKS:
- 4
- 4
NUM_CHANNELS:
- 64
- 128
FUSE_METHOD: SUM
STAGE3:
NUM_MODULES: 4
NUM_BRANCHES: 3
BLOCK: BASIC
NUM_BLOCKS:
- 4
- 4
- 4
NUM_CHANNELS:
- 64
- 128
- 256
FUSE_METHOD: SUM
STAGE4:
NUM_MODULES: 3
NUM_BRANCHES: 4
BLOCK: BASIC
NUM_BLOCKS:
- 4
- 4
- 4
- 4
NUM_CHANNELS:
- 64
- 128
- 256
- 512
FUSE_METHOD: SUM
CUDNN:
BENCHMARK: true
DETERMINISTIC: false
ENABLED: true
DATASET:
DATASET: 'imagenet'
DATA_FORMAT: 'jpg'
ROOT: 'data/imagenet/'
TEST_SET: 'val'
TRAIN_SET: 'train'
TEST:
BATCH_SIZE_PER_GPU: 32
MODEL_FILE: ''
TRAIN:
BATCH_SIZE_PER_GPU: 32
BEGIN_EPOCH: 0
END_EPOCH: 100
RESUME: true
LR_FACTOR: 0.1
LR_STEP:
- 30
- 60
- 90
OPTIMIZER: sgd
LR: 0.05
WD: 0.0001
MOMENTUM: 0.9
NESTEROV: true
SHUFFLE: true
DEBUG:
DEBUG: false
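This file appears to be the standard HRNet-W64 ImageNet classification config; pipeline.py below only uses it to build the backbone, roughly as in this sketch (the checkpoint path is a placeholder):

from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat

# merge this YAML into the default HRNet config, then build the grid-feature backbone
hrnet_update_config(hrnet_config, "cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml")
backbone = get_cls_net_gridfeat(hrnet_config, pretrained="hrnetv2_w64_imagenet_pretrained.pth")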

6
hand_refiner/depth_preprocessor.py Normal file

@@ -0,0 +1,6 @@
class Preprocessor:
def __init__(self) -> None:
pass
def get_depth(self, input_dir, file_name):
return

Binary file not shown.

468
hand_refiner/pipeline.py Normal file

@@ -0,0 +1,468 @@
import os
import torch
import gc
import numpy as np
from hand_refiner.depth_preprocessor import Preprocessor
import torchvision.models as models
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
from mesh_graphormer.modeling._mano import MANO, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
from mesh_graphormer.utils.miscellaneous import set_seed
from argparse import Namespace
from pathlib import Path
import cv2
from torchvision import transforms
from trimesh import Trimesh
from trimesh.ray.ray_triangle import RayMeshIntersector
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import mesh_graphormer
from packaging import version
args = Namespace(
num_workers=4,
img_scale_factor=1,
image_file_or_path=os.path.join('', 'MeshGraphormer', 'samples', 'hand'),
model_name_or_path=str(Path(mesh_graphormer.__file__).parent / "modeling/bert/bert-base-uncased"),
resume_checkpoint=None,
output_dir='output/',
config_name='',
a='hrnet-w64',
arch='hrnet-w64',
num_hidden_layers=4,
hidden_size=-1,
num_attention_heads=4,
intermediate_size=-1,
input_feat_dim='2051,512,128',
hidden_feat_dim='1024,256,64',
which_gcn='0,0,1',
mesh_type='hand',
run_eval_only=True,
device="cpu",
seed=88,
hrnet_checkpoint=None,
)
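# These defaults mirror MeshGraphormer's inference arguments; from_pretrained() in
# __init__.py overrides resume_checkpoint, hrnet_checkpoint and device before the
# pipeline is constructed.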
# Since mediapipe v0.10.5 the reported hand category (Left/Right) has been correct;
# earlier versions had the two labels swapped.
if version.parse(mp.__version__) >= version.parse('0.10.5'):
true_hand_category = {"Right": "right", "Left": "left"}
else:
true_hand_category = {"Right": "left", "Left": "right"}
class MeshGraphormerMediapipe(Preprocessor):
def __init__(self, args=args) -> None:
# Setup CUDA, GPU & distributed training
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
#mkdir(args.output_dir)
#logger = setup_logger("Graphormer", args.output_dir, get_rank())
set_seed(args.seed, args.num_gpus)
#logger.info("Using {} GPUs".format(args.num_gpus))
# Mesh and MANO utils
mano_model = MANO().to(args.device)
mano_model.layer = mano_model.layer.to(args.device)
mesh_sampler = Mesh(device=args.device)
# Renderer for visualization
# renderer = Renderer(faces=mano_model.face)
# Load pretrained model
trans_encoder = []
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
output_feat_dim = input_feat_dim[1:] + [3]
# which encoder block to have graph convs
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
# if only run eval, load checkpoint
#logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
_model = torch.load(args.resume_checkpoint)
else:
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config.output_attentions = False
config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i]
args.hidden_size = hidden_feat_dim[i]
args.intermediate_size = int(args.hidden_size*2)
if which_blk_graph[i]==1:
config.graph_conv = True
#logger.info("Add Graph Conv")
else:
config.graph_conv = False
config.mesh_type = args.mesh_type
# update model structure if specified in arguments
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
for idx, param in enumerate(update_params):
arg_param = getattr(args, param)
config_param = getattr(config, param)
if arg_param > 0 and arg_param != config_param:
#logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
setattr(config, param, arg_param)
# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
model = model_class(config=config)
#logger.info("Init model from scratch.")
trans_encoder.append(model)
# create backbone model
if args.arch=='hrnet':
hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = args.hrnet_checkpoint
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
#logger.info('=> loading hrnet-v2-w40 model')
elif args.arch=='hrnet-w64':
hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = args.hrnet_checkpoint
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
#logger.info('=> loading hrnet-v2-w64 model')
else:
print("=> using pre-trained model '{}'".format(args.arch))
backbone = models.__dict__[args.arch](pretrained=True)
# remove the last fc layer
backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
trans_encoder = torch.nn.Sequential(*trans_encoder)
total_params = sum(p.numel() for p in trans_encoder.parameters())
#logger.info('Graphormer encoders total parameters: {}'.format(total_params))
backbone_total_params = sum(p.numel() for p in backbone.parameters())
#logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder)
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
# for fine-tuning or resume training or inference, load weights from checkpoint
#logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
# workaround approach to load sparse tensor in graph conv.
state_dict = torch.load(args.resume_checkpoint)
_model.load_state_dict(state_dict, strict=False)
del state_dict
gc.collect()
# update configs to enable attention outputs
setattr(_model.trans_encoder[-1].config,'output_attentions', True)
setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
_model.trans_encoder[-1].bert.encoder.output_attentions = True
_model.trans_encoder[-1].bert.encoder.output_hidden_states = True
for iter_layer in range(4):
_model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
for inter_block in range(3):
setattr(_model.trans_encoder[-1].config,'device', args.device)
_model.to(args.device)
self._model = _model
self.mano_model = mano_model
self.mesh_sampler = mesh_sampler
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
base_options = python.BaseOptions(model_asset_path=str( Path(__file__).parent / "hand_landmarker.task" ))
options = vision.HandLandmarkerOptions(base_options=base_options,
min_hand_detection_confidence=0.6,
min_hand_presence_confidence=0.6,
min_tracking_confidence=0.6,
num_hands=2)
self.detector = vision.HandLandmarker.create_from_options(options)
def get_rays(self, W, H, fx, fy, cx, cy, c2w_t, center_pixels): # rot = I
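# Build one normalized ray per pixel for a pinhole camera with identity rotation:
# directions follow from the intrinsics (fx, fy, cx, cy); the origin is the camera
# translation c2w_t repeated H*W times.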
j, i = np.meshgrid(np.arange(H, dtype=np.float32), np.arange(W, dtype=np.float32))
if center_pixels:
i = i.copy() + 0.5
j = j.copy() + 0.5
directions = np.stack([(i - cx) / fx, (j - cy) / fy, np.ones_like(i)], -1)
directions /= np.linalg.norm(directions, axis=-1, keepdims=True)
rays_o = np.expand_dims(c2w_t,0).repeat(H*W, 0)
rays_d = directions # (H, W, 3)
rays_d = (rays_d / np.linalg.norm(rays_d, axis=-1, keepdims=True)).reshape(-1,3)
return rays_o, rays_d
def get_mask_bounding_box(self, extrema, H, W, padding=30, dynamic_resize=0.15):
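# Pad the hand's bounding box by max(dynamic_resize * box side, padding) pixels per
# side and clamp the result to the image bounds.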
x_min, x_max, y_min, y_max = extrema
bb_xpad = max(int((x_max - x_min + 1) * dynamic_resize), padding)
bb_ypad = max(int((y_max - y_min + 1) * dynamic_resize), padding)
bbx_min = np.max((x_min - bb_xpad, 0))
bbx_max = np.min((x_max + bb_xpad, W-1))
bby_min = np.max((y_min - bb_ypad, 0))
bby_max = np.min((y_max + bb_ypad, H-1))
return bbx_min, bbx_max, bby_min, bby_max
def run_inference(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len):
global args
H, W = int(crop_len), int(crop_len)
Graphormer_model.eval()
mano.eval()
device = next(Graphormer_model.parameters()).device
with torch.no_grad():
img_tensor = self.transform(img)
batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
# obtain 2d joints, which are projected from 3d joints of mesh
#pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
#pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
pred_camera = pred_camera.cpu()
pred_vertices = pred_vertices.cpu()
mesh = Trimesh(vertices=pred_vertices[0], faces=mano.face)
res = crop_len
focal_length = 1000 * scale
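# Turn the predicted weak-perspective camera (scale, tx, ty) into a translation for
# a perspective camera with the focal length above: pred_camera[0] is the scale,
# pred_camera[1:3] the in-plane shift.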
camera_t = np.array([-pred_camera[1], -pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)])
pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t
z_3d_dist = pred_3d_joints_camera[:,2].clone()
pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2))
rays_o, rays_d = self.get_rays(W, H, focal_length, focal_length, W/2, H/2, camera_t, True)
coords = np.array(list(np.ndindex(H,W))).reshape(H,W,-1).transpose(1,0,2).reshape(-1,2)
intersector = RayMeshIntersector(mesh)
points, index_ray, _ = intersector.intersects_location(rays_o, rays_d, multiple_hits=False)
tri_index = intersector.intersects_first(rays_o, rays_d)
tri_index = tri_index[index_ray]
assert len(index_ray) == len(tri_index)
discriminator = (np.sum(mesh.face_normals[tri_index]* rays_d[index_ray], axis=-1)<= 0)
points = points[discriminator] # keep only front-facing hits; rays that intersect interior faces are discarded
if len(points) == 0:
return None, None
depth = (points + camera_t)[:,-1]
index_ray = index_ray[discriminator]
pixel_ray = coords[index_ray]
minval = np.min(depth)
maxval = np.max(depth)
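# Map hit depths to brightness in [0.2, 1.0] (nearer is brighter), scale to 0-255,
# and leave pixels whose rays missed the mesh at 0.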
depthmap = np.zeros([H,W])
depthmap[pixel_ray[:, 0], pixel_ray[:, 1]] = 1.0 - (0.8 * (depth - minval) / (maxval - minval))
depthmap *= 255
return depthmap, pred_2d_joints_img_space
def get_depth(self, np_image, padding):
info = {}
# STEP 3: Load the input image.
#https://stackoverflow.com/a/76407270
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np_image.copy())
# STEP 4: Detect hand landmarks from the input image.
detection_result = self.detector.detect(image)
handedness_list = detection_result.handedness
hand_landmarks_list = detection_result.hand_landmarks
raw_image = image.numpy_view()
H, W, C = raw_image.shape
# Hand landmarks can be empty; return early when no hand is detected.
if len(hand_landmarks_list) == 0:
return None, None, None
raw_image = raw_image[:, :, :3]
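# Embed the image at the centre of a canvas twice its size so that hand crops near
# the border can extend past the original image without indexing out of bounds.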
padded_image = np.zeros((H*2, W*2, 3))
padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = raw_image
hand_landmarks_list, handedness_list = zip(
*sorted(
zip(hand_landmarks_list, handedness_list), key=lambda x: x[0][9].z, reverse=True
)
)
padded_depthmap = np.zeros((H*2, W*2))
mask = np.zeros((H, W))
crop_boxes = []
#bboxes = []
groundtruth_2d_keypoints = []
hands = []
depth_failure = False
crop_lens = []
for idx in range(len(hand_landmarks_list)):
hand = true_hand_category[handedness_list[idx][0].category_name]
hands.append(hand)
hand_landmarks = hand_landmarks_list[idx]
handedness = handedness_list[idx]
height, width, _ = raw_image.shape
x_coordinates = [landmark.x for landmark in hand_landmarks]
y_coordinates = [landmark.y for landmark in hand_landmarks]
# x_min, x_max, y_min, y_max: extrema from mediapipe keypoint detection
x_min = int(min(x_coordinates) * width)
x_max = int(max(x_coordinates) * width)
x_c = (x_min + x_max)//2
y_min = int(min(y_coordinates) * height)
y_max = int(max(y_coordinates) * height)
y_c = (y_min + y_max)//2
#if x_max - x_min < 60 or y_max - y_min < 60:
# continue
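# Crop a square around the hand that is 1.6x the larger side of its landmark
# bounding box, rounded down to an even length.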
crop_len = (max(x_max - x_min, y_max - y_min) * 1.6) //2 * 2
# crop_x_min, crop_x_max, crop_y_min, crop_y_max: bounding box for mesh reconstruction
crop_x_min = int(x_c - (crop_len/2 - 1) + W/2)
crop_x_max = int(x_c + crop_len/2 + W/2)
crop_y_min = int(y_c - (crop_len/2 - 1) + H/2)
crop_y_max = int(y_c + crop_len/2 + H/2)
cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1]
crop_boxes.append([crop_y_min, crop_y_max, crop_x_min, crop_x_max])
crop_lens.append(crop_len)
if hand == "left":
cropped = cv2.flip(cropped, 1)
if crop_len < 224:
graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC)
else:
graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA)
scale = crop_len/224
cropped_depthmap, pred_2d_keypoints = self.run_inference(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, int(crop_len))
if cropped_depthmap is None:
depth_failure = True
break
#keypoints_image_space = pred_2d_keypoints * (crop_y_max - crop_y_min + 1)/224
groundtruth_2d_keypoints.append(pred_2d_keypoints)
if hand == "left":
cropped_depthmap = cv2.flip(cropped_depthmap, 1)
resized_cropped_depthmap = cv2.resize(cropped_depthmap, (int(crop_len), int(crop_len)), interpolation=cv2.INTER_LINEAR)
nonzero_y, nonzero_x = (resized_cropped_depthmap != 0).nonzero()
if len(nonzero_y) == 0 or len(nonzero_x) == 0:
depth_failure = True
break
padded_depthmap[crop_y_min+nonzero_y, crop_x_min+nonzero_x] = resized_cropped_depthmap[nonzero_y, nonzero_x]
# nonzero stands for nonzero value on the depth map
# coordinates of nonzero depth pixels in original image space
original_nonzero_x = crop_x_min+nonzero_x - int(W/2)
original_nonzero_y = crop_y_min+nonzero_y - int(H/2)
nonzerox_min = min(np.min(original_nonzero_x), x_min)
nonzerox_max = max(np.max(original_nonzero_x), x_max)
nonzeroy_min = min(np.min(original_nonzero_y), y_min)
nonzeroy_max = max(np.max(original_nonzero_y), y_max)
bbx_min, bbx_max, bby_min, bby_max = self.get_mask_bounding_box((nonzerox_min, nonzerox_max, nonzeroy_min, nonzeroy_max), H, W, padding)
mask[bby_min:bby_max+1, bbx_min:bbx_max+1] = 1.0
#bboxes.append([int(bbx_min), int(bbx_max), int(bby_min), int(bby_max)])
if depth_failure:
#print("cannot detect normal hands")
return None, None, None
depthmap = padded_depthmap[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)].astype(np.uint8)
mask = (255.0 * mask).astype(np.uint8)
info["groundtruth_2d_keypoints"] = groundtruth_2d_keypoints
info["hands"] = hands
info["crop_boxes"] = crop_boxes
info["crop_lens"] = crop_lens
return depthmap, mask, info
def get_keypoints(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len):
global args
H, W = int(crop_len), int(crop_len)
Graphormer_model.eval()
mano.eval()
device = next(Graphormer_model.parameters()).device
with torch.no_grad():
img_tensor = self.transform(img)
#print(img_tensor)
batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
# obtain 2d joints, which are projected from 3d joints of mesh
#pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
#pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
pred_camera = pred_camera.cpu()
pred_vertices = pred_vertices.cpu()
#
res = crop_len
focal_length = 1000 * scale
camera_t = np.array([-pred_camera[1], -pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)])
pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t
z_3d_dist = pred_3d_joints_camera[:,2].clone()
pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2))
return pred_2d_joints_img_space
def eval_mpjpe(self, sample, info):
H, W, C = sample.shape
padded_image = np.zeros((H*2, W*2, 3))
padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = sample
crop_boxes = info["crop_boxes"]
hands = info["hands"]
groundtruth_2d_keypoints = info["groundtruth_2d_keypoints"]
crop_lens = info["crop_lens"]
pjpe = 0
for i in range(len(crop_boxes)):
crop_y_min, crop_y_max, crop_x_min, crop_x_max = crop_boxes[i]
cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1]
hand = hands[i]
if hand == "left":
cropped = cv2.flip(cropped, 1)
crop_len = crop_lens[i]
scale = crop_len/224
if crop_len < 224:
graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC)
else:
graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA)
generated_keypoint = self.get_keypoints(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, crop_len)
#generated_keypoint = generated_keypoint * ((crop_y_max - crop_y_min + 1)/224)
pjpe += np.sum(np.sqrt(np.sum(((generated_keypoint - groundtruth_2d_keypoints[i]) ** 2).numpy(), axis=1)))
pass
mpjpe = pjpe/(len(crop_boxes) * 21)
return mpjpe

193
hand_refiner/util.py Normal file

@@ -0,0 +1,193 @@
import os
import random
import cv2
import numpy as np
from pathlib import Path
import warnings
from huggingface_hub import hf_hub_download
here = Path(__file__).parent.resolve()
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def make_noise_disk(H, W, C, F):
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
noise = noise[F: F + H, F: F + W]
noise -= np.min(noise)
noise /= np.max(noise)
if C == 1:
noise = noise[:, :, None]
return noise
def nms(x, t, s):
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
z = np.zeros_like(y, dtype=np.uint8)
z[y > t] = 255
return z
def min_max_norm(x):
x -= np.min(x)
x /= np.maximum(np.max(x), 1e-5)
return x
def safe_step(x, step=2):
y = x.astype(np.float32) * float(step + 1)
y = y.astype(np.int32).astype(np.float32) / float(step)
return y
def img2mask(img, H, W, low=10, high=90):
assert img.ndim == 3 or img.ndim == 2
assert img.dtype == np.uint8
if img.ndim == 3:
y = img[:, :, random.randrange(0, img.shape[2])]
else:
y = img
y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
if random.uniform(0, 1) < 0.5:
y = 255 - y
return y < np.percentile(y, random.randrange(low, high))
def safer_memory(x):
# Work around memory-layout issues on Mac/AMD by forcing a contiguous copy
return np.ascontiguousarray(x.copy()).copy()
UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]
def get_upscale_method(method_str):
assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
return getattr(cv2, method_str)
def pad64(x):
return int(np.ceil(float(x) / 64.0) * 64 - x)
#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17
#Added upscale_method param
def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False):
if skip_hwc3:
img = input_image
else:
img = HWC3(input_image)
H_raw, W_raw, _ = img.shape
k = float(resolution) / float(min(H_raw, W_raw))
H_target = int(np.round(float(H_raw) * k))
W_target = int(np.round(float(W_raw) * k))
img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
H_pad, W_pad = pad64(H_target), pad64(W_target)
img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')
def remove_pad(x):
return safer_memory(x[:H_target, :W_target, ...])
return safer_memory(img_padded), remove_pad
def common_input_validate(input_image, output_type, **kwargs):
if "img" in kwargs:
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
input_image = kwargs.pop("img")
if "return_pil" in kwargs:
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
output_type = "pil" if kwargs["return_pil"] else "np"
if type(output_type) is bool:
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
if output_type:
output_type = "pil"
if input_image is None:
raise ValueError("input_image must be defined.")
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
output_type = output_type or "pil"
else:
output_type = output_type or "np"
return (input_image, output_type)
def custom_hf_download(pretrained_model_or_path, filename, cache_dir, subfolder='', use_symlinks=False):
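# Look for the file under cache_dir/<repo>/<subfolder>/<filename>; if it is missing,
# download it from the Hugging Face Hub into that local directory (optionally via
# symlinks into the shared HF cache).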
local_dir = os.path.join(cache_dir, pretrained_model_or_path)
model_path = os.path.join(local_dir, *subfolder.split('/'), filename)
if not os.path.exists(model_path):
print(f"Failed to find {model_path}.\n Downloading from huggingface.co")
if use_symlinks:
cache_dir_d = os.getenv("HUGGINGFACE_HUB_CACHE")
if cache_dir_d is None:
import platform
if platform.system() == "Windows":
cache_dir_d = os.path.join(os.getenv("USERPROFILE"), ".cache", "huggingface", "hub")
else:
cache_dir_d = os.path.join(os.getenv("HOME"), ".cache", "huggingface", "hub")
try:
# test_link
if not os.path.exists(cache_dir_d):
os.makedirs(cache_dir_d)
open(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), "w")
os.link(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(cache_dir, f"linktest_{filename}.txt"))
os.remove(os.path.join(cache_dir, f"linktest_{filename}.txt"))
os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt"))
print("Using symlinks to download models. \n",\
"Make sure you have enough space on your cache folder. \n",\
"And do not purge the cache folder after downloading.\n",\
"Otherwise, you will have to re-download the models every time you run the script.\n",\
"You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.")
except Exception:
print("Could not create symlinks; disabling symlinks and falling back to a local cache directory.")
use_symlinks = False
cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache")
else:
cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache")
model_path = hf_hub_download(repo_id=pretrained_model_or_path,
cache_dir=cache_dir_d,
local_dir=local_dir,
subfolder=subfolder,
filename=filename,
local_dir_use_symlinks=use_symlinks,
resume_download=True,
etag_timeout=100
)
if not use_symlinks:
try:
import shutil
shutil.rmtree(cache_dir_d)
except Exception as e :
print(e)
return model_path
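A small self-contained sketch of how the padding helpers above compose; the dummy image is just a placeholder:

import numpy as np
from hand_refiner.util import HWC3, resize_image_with_pad

img = HWC3(np.zeros((480, 640, 3), dtype=np.uint8))  # any uint8 RGB image
padded, remove_pad = resize_image_with_pad(img, 512, upscale_method="INTER_CUBIC")
# both sides of `padded` are rounded up to a multiple of 64; remove_pad crops back
restored = remove_pad(padded)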

1
manopth/CHANGES.md Normal file

@@ -0,0 +1 @@
* Chumpy is removed

674
manopth/LICENSE Normal file

@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

1
manopth/__init__.py Normal file
View File

@@ -0,0 +1 @@
name = 'manopth'

51
manopth/argutils.py Normal file
View File

@@ -0,0 +1,51 @@
import datetime
import os
import pickle
import subprocess
import sys
def print_args(args):
opts = vars(args)
print('======= Options ========')
for k, v in sorted(opts.items()):
print('{}: {}'.format(k, v))
print('========================')
def save_args(args, save_folder, opt_prefix='opt', verbose=True):
opts = vars(args)
# Create checkpoint folder
if not os.path.exists(save_folder):
os.makedirs(save_folder, exist_ok=True)
# Save options
opt_filename = '{}.txt'.format(opt_prefix)
opt_path = os.path.join(save_folder, opt_filename)
with open(opt_path, 'a') as opt_file:
opt_file.write('====== Options ======\n')
for k, v in sorted(opts.items()):
opt_file.write(
'{option}: {value}\n'.format(option=str(k), value=str(v)))
opt_file.write('=====================\n')
opt_file.write('launched {} at {}\n'.format(
str(sys.argv[0]), str(datetime.datetime.now())))
# Add git info
label = subprocess.check_output(["git", "describe",
"--always"]).strip()
if subprocess.call(
["git", "branch"],
stderr=subprocess.STDOUT,
stdout=open(os.devnull, 'w')) == 0:
opt_file.write('=== Git info ====\n')
opt_file.write('{}\n'.format(label))
commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
opt_file.write('commit : {}\n'.format(commit.strip()))
opt_picklename = '{}.pkl'.format(opt_prefix)
opt_picklepath = os.path.join(save_folder, opt_picklename)
with open(opt_picklepath, 'wb') as opt_file:
pickle.dump(opts, opt_file)
if verbose:
print('Saved options to {}'.format(opt_path))
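For illustration, a minimal sketch of how these option helpers can be called; the checkpoints/demo folder name is an arbitrary example, and save_args shells out to git, so it should be run from inside a git checkout.
import argparse
from manopth import argutils

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=1, type=int)
args = parser.parse_args([])  # use the defaults for this sketch
argutils.print_args(args)  # pretty-prints every parsed option
argutils.save_args(args, save_folder='checkpoints/demo', opt_prefix='opt')  # writes opt.txt and opt.pkl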

59
manopth/demo.py Normal file
View File

@@ -0,0 +1,59 @@
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np
import torch
from manopth.manolayer import ManoLayer
def generate_random_hand(batch_size=1, ncomps=6, mano_root='mano/models'):
nfull_comps = ncomps + 3 # Add global orientation dims to PCA
random_pcapose = torch.rand(batch_size, nfull_comps)
mano_layer = ManoLayer(mano_root=mano_root, ncomps=ncomps)
verts, joints = mano_layer(random_pcapose)
return {'verts': verts, 'joints': joints, 'faces': mano_layer.th_faces}
def display_hand(hand_info, mano_faces=None, ax=None, alpha=0.2, batch_idx=0, show=True):
"""
Displays hand batch_idx in batch of hand_info, hand_info as returned by
generate_random_hand
"""
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
verts, joints = hand_info['verts'][batch_idx], hand_info['joints'][
batch_idx]
if mano_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.1)
else:
mesh = Poly3DCollection(verts[mano_faces], alpha=alpha)
face_color = (141 / 255, 184 / 255, 226 / 255)
edge_color = (50 / 255, 50 / 255, 50 / 255)
mesh.set_edgecolor(edge_color)
mesh.set_facecolor(face_color)
ax.add_collection3d(mesh)
ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
cam_equal_aspect_3d(ax, verts.numpy())
if show:
plt.show()
def cam_equal_aspect_3d(ax, verts, flip_x=False):
"""
Centers view on cuboid containing hand and flips y and z axis
and fixes azimuth
"""
extents = np.stack([verts.min(0), verts.max(0)], axis=1)
sz = extents[:, 1] - extents[:, 0]
centers = np.mean(extents, axis=1)
maxsize = max(abs(sz))
r = maxsize / 2
if flip_x:
ax.set_xlim(centers[0] + r, centers[0] - r)
else:
ax.set_xlim(centers[0] - r, centers[0] + r)
# Invert y and z axis
ax.set_ylim(centers[1] + r, centers[1] - r)
ax.set_zlim(centers[2] + r, centers[2] - r)
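A minimal sketch of how the demo helpers above are intended to be used, assuming the MANO pickle files have been downloaded to mano/models/:
from manopth.demo import generate_random_hand, display_hand

hand_info = generate_random_hand(batch_size=2, ncomps=6, mano_root='mano/models')
display_hand(hand_info, mano_faces=hand_info['faces'], batch_idx=0)  # opens a matplotlib 3D view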

274
manopth/manolayer.py Normal file
View File

@@ -0,0 +1,274 @@
import os
import numpy as np
import torch
from torch.nn import Module
from manopth.smpl_handpca_wrapper_HAND_only import ready_arguments
from manopth import rodrigues_layer, rotproj, rot6d
from manopth.tensutils import (th_posemap_axisang, th_with_zeros, th_pack,
subtract_flat_id, make_list)
class ManoLayer(Module):
__constants__ = [
'use_pca', 'rot', 'ncomps', 'ncomps', 'kintree_parents', 'check',
'side', 'center_idx', 'joint_rot_mode'
]
def __init__(self,
center_idx=None,
flat_hand_mean=True,
ncomps=6,
side='right',
mano_root='mano/models',
use_pca=True,
root_rot_mode='axisang',
joint_rot_mode='axisang',
robust_rot=False):
"""
Args:
center_idx: index of center joint in our computations,
if -1 centers on estimate of palm as middle of base
of middle finger and wrist
flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
flat hand, else match average hand pose
mano_root: path to MANO pkl files for left and right hand
ncomps: number of PCA components from pose space (<45)
side: 'right' or 'left'
use_pca: Use PCA decomposition for pose space.
joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca
"""
super().__init__()
self.center_idx = center_idx
self.robust_rot = robust_rot
if root_rot_mode == 'axisang':
self.rot = 3
else:
self.rot = 6
self.flat_hand_mean = flat_hand_mean
self.side = side
self.use_pca = use_pca
self.joint_rot_mode = joint_rot_mode
self.root_rot_mode = root_rot_mode
if use_pca:
self.ncomps = ncomps
else:
self.ncomps = 45
if side == 'right':
self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
elif side == 'left':
self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')
smpl_data = ready_arguments(self.mano_path)
hands_components = smpl_data['hands_components']
self.smpl_data = smpl_data
self.register_buffer('th_betas',
torch.Tensor(smpl_data['betas']).unsqueeze(0))
self.register_buffer('th_shapedirs',
torch.Tensor(smpl_data['shapedirs']))
self.register_buffer('th_posedirs',
torch.Tensor(smpl_data['posedirs']))
self.register_buffer(
'th_v_template',
torch.Tensor(smpl_data['v_template']).unsqueeze(0))
self.register_buffer(
'th_J_regressor',
torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
self.register_buffer('th_weights',
torch.Tensor(smpl_data['weights']))
self.register_buffer('th_faces',
torch.Tensor(smpl_data['f'].astype(np.int32)).long())
# Get hand mean
hands_mean = np.zeros(hands_components.shape[1]
) if flat_hand_mean else smpl_data['hands_mean']
hands_mean = hands_mean.copy()
th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
if self.use_pca or self.joint_rot_mode == 'axisang':
# Save as axis-angle
self.register_buffer('th_hands_mean', th_hands_mean)
selected_components = hands_components[:ncomps]
self.register_buffer('th_comps', torch.Tensor(hands_components))
self.register_buffer('th_selected_comps',
torch.Tensor(selected_components))
else:
th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
th_hands_mean.view(15, 3)).reshape(15, 3, 3)
self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat)
# Kinematic chain params
self.kintree_table = smpl_data['kintree_table']
parents = list(self.kintree_table[0].tolist())
self.kintree_parents = parents
def forward(self,
th_pose_coeffs,
th_betas=torch.zeros(1),
th_trans=torch.zeros(1),
root_palm=torch.Tensor([0]),
share_betas=torch.Tensor([0]),
):
"""
Args:
th_trans (Tensor (batch_size x 3)): if provided and non-zero, translation applied to joints and vertices;
otherwise the output is centered on the joint selected by center_idx, if set
th_betas (Tensor (batch_size x 10)): if provided, uses the given shape parameters for the hand shape;
otherwise uses the mean shape loaded from the MANO model
root_palm: if true, returns the palm as hand root instead of the wrist
"""
# if len(th_pose_coeffs) == 0:
# return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0)
batch_size = th_pose_coeffs.shape[0]
# Get axis angle from PCA components and coefficients
if self.use_pca or self.joint_rot_mode == 'axisang':
# Remove global rot coeffs
th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot +
self.ncomps]
if self.use_pca:
# PCA components --> axis angles
th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps)
else:
th_full_hand_pose = th_hand_pose_coeffs
# Concatenate back global rot
th_full_pose = torch.cat([
th_pose_coeffs[:, :self.rot],
self.th_hands_mean + th_full_hand_pose
], 1)
if self.root_rot_mode == 'axisang':
# compute rotation matrices from axis-angle while skipping global rotation
th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
root_rot = th_rot_map[:, :9].view(batch_size, 3, 3)
th_rot_map = th_rot_map[:, 9:]
th_pose_map = th_pose_map[:, 9:]
else:
# th_posemap offsets by 3, so add an offset of 3 to get to self.rot=6
th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 6:])
if self.robust_rot:
root_rot = rot6d.robust_compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
else:
root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
else:
assert th_pose_coeffs.dim() == 4, (
'When not self.use_pca, '
'th_pose_coeffs should have 4 dims, got {}'.format(
th_pose_coeffs.dim()))
assert th_pose_coeffs.shape[2:4] == (3, 3), (
'When not self.use_pca, th_pose_coeffs have 3x3 matrix for two'
'last dims, got {}'.format(th_pose_coeffs.shape[2:4]))
th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs)
th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1)
th_pose_map = subtract_flat_id(th_rot_map)
root_rot = th_pose_rots[:, 0]
# Full axis angle representation with root joint
if th_betas is None or th_betas.numel() == 1:
th_v_shaped = torch.matmul(self.th_shapedirs,
self.th_betas.transpose(1, 0)).permute(
2, 0, 1) + self.th_v_template
th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
batch_size, 1, 1)
else:
if share_betas:
th_betas = th_betas.mean(0, keepdim=True).expand(th_betas.shape[0], 10)
th_v_shaped = torch.matmul(self.th_shapedirs,
th_betas.transpose(1, 0)).permute(
2, 0, 1) + self.th_v_template
th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
# th_pose_map should have shape 20x135
th_v_posed = th_v_shaped + torch.matmul(
self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
# Final T pose with transformation done !
# Global rigid transformation
root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
root_trans = th_with_zeros(torch.cat([root_rot, root_j], 2))
all_rots = th_rot_map.view(th_rot_map.shape[0], 15, 3, 3)
lev1_idxs = [1, 4, 7, 10, 13]
lev2_idxs = [2, 5, 8, 11, 14]
lev3_idxs = [3, 6, 9, 12, 15]
lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
lev1_j = th_j[:, lev1_idxs]
lev2_j = th_j[:, lev2_idxs]
lev3_j = th_j[:, lev3_idxs]
# From base to tips
# Get lev1 results
all_transforms = [root_trans.unsqueeze(1)]
lev1_j_rel = lev1_j - root_j.transpose(1, 2)
lev1_rel_transform_flt = th_with_zeros(torch.cat([lev1_rots, lev1_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
root_trans_flt = root_trans.unsqueeze(1).repeat(1, 5, 1, 1).view(root_trans.shape[0] * 5, 4, 4)
lev1_flt = torch.matmul(root_trans_flt, lev1_rel_transform_flt)
all_transforms.append(lev1_flt.view(all_rots.shape[0], 5, 4, 4))
# Get lev2 results
lev2_j_rel = lev2_j - lev1_j
lev2_rel_transform_flt = th_with_zeros(torch.cat([lev2_rots, lev2_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
lev2_flt = torch.matmul(lev1_flt, lev2_rel_transform_flt)
all_transforms.append(lev2_flt.view(all_rots.shape[0], 5, 4, 4))
# Get lev3 results
lev3_j_rel = lev3_j - lev2_j
lev3_rel_transform_flt = th_with_zeros(torch.cat([lev3_rots, lev3_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
lev3_flt = torch.matmul(lev2_flt, lev3_rel_transform_flt)
all_transforms.append(lev3_flt.view(all_rots.shape[0], 5, 4, 4))
reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]
th_results = torch.cat(all_transforms, 1)[:, reorder_idxs]
th_results_global = th_results
joint_js = torch.cat([th_j, th_j.new_zeros(th_j.shape[0], 16, 1)], 2)
tmp2 = torch.matmul(th_results, joint_js.unsqueeze(3))
th_results2 = (th_results - torch.cat([tmp2.new_zeros(*tmp2.shape[:2], 4, 3), tmp2], 3)).permute(0, 2, 3, 1)
th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))
th_rest_shape_h = torch.cat([
th_v_posed.transpose(2, 1),
torch.ones((batch_size, 1, th_v_posed.shape[1]),
dtype=th_T.dtype,
device=th_T.device),
], 1)
th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
th_verts = th_verts[:, :, :3]
th_jtr = th_results_global[:, :, :3, 3]
# In addition to MANO reference joints we sample vertices on each finger
# to serve as finger tips
if self.side == 'right':
tips = th_verts[:, [745, 317, 444, 556, 673]]
else:
tips = th_verts[:, [745, 317, 445, 556, 673]]
if bool(root_palm):
palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2
th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1)
th_jtr = torch.cat([th_jtr, tips], 1)
# Reorder joints to match visualization utilities
th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]
if th_trans is None or bool(torch.norm(th_trans) == 0):
if self.center_idx is not None:
center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
th_jtr = th_jtr - center_joint
th_verts = th_verts - center_joint
else:
th_jtr = th_jtr + th_trans.unsqueeze(1)
th_verts = th_verts + th_trans.unsqueeze(1)
# Scale to millimeters
th_verts = th_verts * 1000
th_jtr = th_jtr * 1000
return th_verts, th_jtr
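A minimal usage sketch for the layer above, assuming MANO_RIGHT.pkl is available under mano/models/; the output shapes follow from the code (778 vertices, 16 MANO joints plus 5 fingertips) and are in millimeters:
import torch
from manopth.manolayer import ManoLayer

ncomps = 6
layer = ManoLayer(mano_root='mano/models', use_pca=True, ncomps=ncomps, flat_hand_mean=True)
pose = torch.rand(4, ncomps + 3)  # 3 axis-angle values for the global rotation + PCA coefficients
betas = torch.rand(4, 10)  # shape parameters
verts, joints = layer(pose, th_betas=betas)
print(verts.shape, joints.shape)  # torch.Size([4, 778, 3]) torch.Size([4, 21, 3])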

37
manopth/posemapper.py Normal file
View File

@@ -0,0 +1,37 @@
'''
Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
This software is provided for research purposes only.
By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
For comments or questions, please email us at: mano@tue.mpg.de
About this file:
================
This file defines the pose-mapping utilities used by the MANO model.
Functions included:
- lrotmin:
maps pose parameters (excluding the global rotation) to flattened (rotation - identity) matrices.
- posemap:
returns the pose-mapping function matching the model's bs_type.
'''
import numpy as np
import cv2
def lrotmin(p):
if isinstance(p, np.ndarray):
p = p.ravel()[3:]
return np.concatenate(
[(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel()
for pp in p.reshape((-1, 3))]).ravel()
def posemap(s):
if s == 'lrotmin':
return lrotmin
else:
raise Exception('Unknown posemapping: %s' % (str(s), ))

View File

@@ -0,0 +1,89 @@
"""
This part reuses code from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py
which is part of a PyTorch port of SMPL.
Thanks to Zhang Xiong (MandyMo) for making this great code available on github !
"""
import argparse
from torch.autograd import gradcheck
import torch
from torch.autograd import Variable
from manopth import argutils
def quat2mat(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
Returns:
Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
"""
norm_quat = quat
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:,
2], norm_quat[:,
3]
batch_size = quat.size(0)
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w * x, w * y, w * z
xy, xz, yz = x * y, x * z, y * z
rotMat = torch.stack([
w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
w2 - x2 - y2 + z2
],
dim=1).view(batch_size, 3, 3)
return rotMat
def batch_rodrigues(axisang):
#axisang N x 3
axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
angle = torch.unsqueeze(axisang_norm, -1)
axisang_normalized = torch.div(axisang, angle)
angle = angle * 0.5
v_cos = torch.cos(angle)
v_sin = torch.sin(angle)
quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
rot_mat = quat2mat(quat)
rot_mat = rot_mat.view(rot_mat.shape[0], 9)
return rot_mat
def th_get_axis_angle(vector):
angle = torch.norm(vector, 2, 1)
axes = vector / angle.unsqueeze(1)
return axes, angle
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--cuda', action='store_true')
args = parser.parse_args()
argutils.print_args(args)
n_components = 6
rot = 3
inputs = torch.rand(args.batch_size, rot)
inputs_var = Variable(inputs.double(), requires_grad=True)
if args.cuda:
inputs = inputs.cuda()
# outputs = batch_rodrigues(inputs)
test_function = gradcheck(batch_rodrigues, (inputs_var, ))
print('batch test passed !')
# The following gradchecks reference th_cv2_rod_sub_id and th_cv2_rod, which are
# not defined in this file, so they are kept disabled to leave the script runnable.
# inputs = torch.rand(rot)
# inputs_var = Variable(inputs.double(), requires_grad=True)
# test_function = gradcheck(th_cv2_rod_sub_id.apply, (inputs_var, ))
# print('th_cv2_rod test passed')
# test_th = gradcheck(th_cv2_rod.apply, (inputs_var, ))
# print('th_cv2_rod_id test passed !')
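Beyond the gradient check above, a quick sanity-check sketch (using the manopth.rodrigues_layer module name under which this file is imported elsewhere in the package) that batch_rodrigues returns proper rotation matrices:
import torch
from manopth.rodrigues_layer import batch_rodrigues

axisang = torch.rand(5, 3)
rotmats = batch_rodrigues(axisang).view(5, 3, 3)
print((rotmats @ rotmats.transpose(1, 2) - torch.eye(3)).abs().max())  # ~0: R R^T = I
print(torch.det(rotmats))  # all close to +1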

71
manopth/rot6d.py Normal file
View File

@@ -0,0 +1,71 @@
import torch
def compute_rotation_matrix_from_ortho6d(poses):
"""
Code from
https://github.com/papagina/RotationContinuity
On the Continuity of Rotation Representations in Neural Networks
Zhou et al. CVPR19
https://zhouyisjtu.github.io/project_rotation/rotation.html
"""
x_raw = poses[:, 0:3] # batch*3
y_raw = poses[:, 3:6] # batch*3
x = normalize_vector(x_raw) # batch*3
z = cross_product(x, y_raw) # batch*3
z = normalize_vector(z) # batch*3
y = cross_product(z, x) # batch*3
x = x.view(-1, 3, 1)
y = y.view(-1, 3, 1)
z = z.view(-1, 3, 1)
matrix = torch.cat((x, y, z), 2) # batch*3*3
return matrix
def robust_compute_rotation_matrix_from_ortho6d(poses):
"""
Instead of making 2nd vector orthogonal to first
create a base that takes into account the two predicted
directions equally
"""
x_raw = poses[:, 0:3] # batch*3
y_raw = poses[:, 3:6] # batch*3
x = normalize_vector(x_raw) # batch*3
y = normalize_vector(y_raw) # batch*3
middle = normalize_vector(x + y)
orthmid = normalize_vector(x - y)
x = normalize_vector(middle + orthmid)
y = normalize_vector(middle - orthmid)
# Their scalar product should be small !
# assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001
z = normalize_vector(cross_product(x, y))
x = x.view(-1, 3, 1)
y = y.view(-1, 3, 1)
z = z.view(-1, 3, 1)
matrix = torch.cat((x, y, z), 2) # batch*3*3
# Check for reflection in matrix ! If found, flip last vector TODO
assert (torch.stack([torch.det(mat) for mat in matrix ])< 0).sum() == 0
return matrix
def normalize_vector(v):
batch = v.shape[0]
v_mag = torch.sqrt(v.pow(2).sum(1)) # batch
v_mag = torch.max(v_mag, v.new([1e-8]))
v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
v = v/v_mag
return v
def cross_product(u, v):
batch = u.shape[0]
i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1)
return out
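A short sketch of the 6D-to-rotation conversion above; the input is just two unconstrained 3D vectors per sample:
import torch
from manopth.rot6d import compute_rotation_matrix_from_ortho6d

poses_6d = torch.rand(8, 6)
R = compute_rotation_matrix_from_ortho6d(poses_6d)  # 8 x 3 x 3
print(torch.det(R))  # orthonormal bases, so determinants are close to +1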

21
manopth/rotproj.py Normal file
View File

@@ -0,0 +1,21 @@
import torch
def batch_rotprojs(batches_rotmats):
proj_rotmats = []
for batch_idx, batch_rotmats in enumerate(batches_rotmats):
proj_batch_rotmats = []
for rot_idx, rotmat in enumerate(batch_rotmats):
# GPU implementation of svd is VERY slow
# ~ 2 10^-3 per hit vs 5 10^-5 on cpu
U, S, V = rotmat.cpu().svd()
rotmat = torch.matmul(U, V.transpose(0, 1))
orth_det = rotmat.det()
# Remove reflection
if orth_det < 0:
rotmat[:, 2] = -1 * rotmat[:, 2]
rotmat = rotmat.cuda()
proj_batch_rotmats.append(rotmat)
proj_rotmats.append(torch.stack(proj_batch_rotmats))
return torch.stack(proj_rotmats)
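A sketch of projecting noisy matrices onto valid rotations with the helper above; note that batch_rotprojs moves each projected matrix back to the GPU, so a CUDA device is required:
import torch
from manopth import rotproj

noisy = torch.eye(3).expand(2, 16, 3, 3) + 0.1 * torch.randn(2, 16, 3, 3)  # (batch, joints, 3, 3)
projected = rotproj.batch_rotprojs(noisy)
print(torch.det(projected.cpu()))  # all +1 after the SVD projection and reflection removal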

View File

@@ -0,0 +1,155 @@
'''
Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
This software is provided for research purposes only.
By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
For comments or questions, please email us at: mano@tue.mpg.de
About this file:
================
This file defines a wrapper for the loading functions of the MANO model.
Modules included:
- load_model:
loads the MANO model from a given file location (i.e. a .pkl file location),
or a dictionary object.
'''
def col(A):
return A.reshape((-1, 1))
def MatVecMult(mtx, vec):
result = mtx.dot(col(vec.ravel())).ravel()
if len(vec.shape) > 1 and vec.shape[1] > 1:
result = result.reshape((-1, vec.shape[1]))
return result
def ready_arguments(fname_or_dict, posekey4vposed='pose'):
import numpy as np
import pickle
from manopth.posemapper import posemap
if not isinstance(fname_or_dict, dict):
dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
# dd = pickle.load(open(fname_or_dict, 'rb'))
else:
dd = fname_or_dict
want_shapemodel = 'shapedirs' in dd
nposeparms = dd['kintree_table'].shape[1] * 3
if 'trans' not in dd:
dd['trans'] = np.zeros(3)
if 'pose' not in dd:
dd['pose'] = np.zeros(nposeparms)
if 'shapedirs' in dd and 'betas' not in dd:
dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])
for s in [
'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
'betas', 'J'
]:
if (s in dd) and not hasattr(dd[s], 'dterms'):
dd[s] = np.array(dd[s])
assert (posekey4vposed in dd)
if want_shapemodel:
dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
v_shaped = dd['v_shaped']
J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
dd['J'] = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T
pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res)
else:
pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
dd_add = dd['posedirs'].dot(pose_map_res)
dd['v_posed'] = dd['v_template'] + dd_add
return dd
def load_model(fname_or_dict, ncomps=6, flat_hand_mean=False, v_template=None):
''' This model loads the fully articulable HAND SMPL model,
and replaces the pose DOFS by ncomps from PCA'''
from manopth.verts import verts_core
import numpy as np
import pickle
import scipy.sparse as sp
np.random.seed(1)
if not isinstance(fname_or_dict, dict):
smpl_data = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
# smpl_data = pickle.load(open(fname_or_dict, 'rb'))
else:
smpl_data = fname_or_dict
rot = 3 # for global orientation!!!
hands_components = smpl_data['hands_components']
hands_mean = np.zeros(hands_components.shape[
1]) if flat_hand_mean else smpl_data['hands_mean']
hands_coeffs = smpl_data['hands_coeffs'][:, :ncomps]
selected_components = np.vstack((hands_components[:ncomps]))
hands_mean = hands_mean.copy()
pose_coeffs = np.zeros(rot + selected_components.shape[0])
full_hand_pose = pose_coeffs[rot:(rot + ncomps)].dot(selected_components)
smpl_data['fullpose'] = np.concatenate((pose_coeffs[:rot],
hands_mean + full_hand_pose))
smpl_data['pose'] = pose_coeffs
Jreg = smpl_data['J_regressor']
if not sp.issparse(Jreg):
smpl_data['J_regressor'] = (sp.csc_matrix(
(Jreg.data, (Jreg.row, Jreg.col)), shape=Jreg.shape))
# slightly modify ready_arguments to make sure that it uses the fullpose
# (which will NOT be pose) for the computation of posedirs
dd = ready_arguments(smpl_data, posekey4vposed='fullpose')
# create the smpl formula with the fullpose,
# but expose the PCA coefficients as smpl.pose for compatibility
args = {
'pose': dd['fullpose'],
'v': dd['v_posed'],
'J': dd['J'],
'weights': dd['weights'],
'kintree_table': dd['kintree_table'],
'xp': np,
'want_Jtr': True,
'bs_style': dd['bs_style'],
}
result_previous, meta = verts_core(**args)
result = result_previous + dd['trans'].reshape((1, 3))
result.no_translation = result_previous
if meta is not None:
for field in ['Jtr', 'A', 'A_global', 'A_weighted']:
if (hasattr(meta, field)):
setattr(result, field, getattr(meta, field))
setattr(result, 'Jtr', meta)
if hasattr(result, 'Jtr'):
result.J_transformed = result.Jtr + dd['trans'].reshape((1, 3))
for k, v in dd.items():
setattr(result, k, v)
if v_template is not None:
result.v_template[:] = v_template
return result
if __name__ == '__main__':
load_model()

47
manopth/tensutils.py Normal file
View File

@@ -0,0 +1,47 @@
import torch
from manopth import rodrigues_layer
def th_posemap_axisang(pose_vectors):
rot_nb = int(pose_vectors.shape[1] / 3)
pose_vec_reshaped = pose_vectors.contiguous().view(-1, 3)
rot_mats = rodrigues_layer.batch_rodrigues(pose_vec_reshaped)
rot_mats = rot_mats.view(pose_vectors.shape[0], rot_nb * 9)
pose_maps = subtract_flat_id(rot_mats)
return pose_maps, rot_mats
def th_with_zeros(tensor):
batch_size = tensor.shape[0]
padding = tensor.new([0.0, 0.0, 0.0, 1.0])
padding.requires_grad = False
concat_list = [tensor, padding.view(1, 1, 4).repeat(batch_size, 1, 1)]
cat_res = torch.cat(concat_list, 1)
return cat_res
def th_pack(tensor):
batch_size = tensor.shape[0]
padding = tensor.new_zeros((batch_size, 4, 3))
padding.requires_grad = False
pack_list = [padding, tensor]
pack_res = torch.cat(pack_list, 2)
return pack_res
def subtract_flat_id(rot_mats):
# Subtracts identity as a flattened tensor
rot_nb = int(rot_mats.shape[1] / 9)
id_flat = torch.eye(
3, dtype=rot_mats.dtype, device=rot_mats.device).view(1, 9).repeat(
rot_mats.shape[0], rot_nb)
# id_flat.requires_grad = False
results = rot_mats - id_flat
return results
def make_list(tensor):
# type: (List[int]) -> List[int]
return tensor
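A small sketch of the tensor helpers above, mirroring how ManoLayer uses them (16 axis-angle rotations per sample: the global rotation plus 15 finger joints):
import torch
from manopth.tensutils import th_posemap_axisang, th_with_zeros

pose_vectors = torch.rand(2, 16 * 3)
pose_maps, rot_mats = th_posemap_axisang(pose_vectors)
print(pose_maps.shape, rot_mats.shape)  # torch.Size([2, 144]) twice: flattened (R - I) and R

rt = torch.rand(2, 3, 4)  # [R|t] blocks
print(th_with_zeros(rt).shape)  # torch.Size([2, 4, 4]) homogeneous transforms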

117
manopth/verts.py Normal file
View File

@@ -0,0 +1,117 @@
'''
Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
This software is provided for research purposes only.
By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
For comments or questions, please email us at: mano@tue.mpg.de
About this file:
================
This file defines linear blend skinning wrappers for the MANO model.
Functions included:
- verts_decorated:
builds the shaped and posed MANO mesh and attaches the model metadata to the result.
- verts_core:
runs the core LBS computation and optionally returns the transformed joints.
'''
import numpy as np
import mano.webuser.lbs as lbs
from mano.webuser.posemapper import posemap
import scipy.sparse as sp
def ischumpy(x):
return hasattr(x, 'dterms')
def verts_decorated(trans,
pose,
v_template,
J_regressor,
weights,
kintree_table,
bs_style,
f,
bs_type=None,
posedirs=None,
betas=None,
shapedirs=None,
want_Jtr=False):
for which in [
trans, pose, v_template, weights, posedirs, betas, shapedirs
]:
if which is not None:
assert ischumpy(which)
v = v_template
if shapedirs is not None:
if betas is None:
betas = np.zeros(shapedirs.shape[-1])
v_shaped = v + shapedirs.dot(betas)
else:
v_shaped = v
if posedirs is not None:
v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose))
else:
v_posed = v_shaped
v = v_posed
if sp.issparse(J_regressor):
J_tmpx = np.matmul(J_regressor, v_shaped[:, 0])
J_tmpy = np.matmul(J_regressor, v_shaped[:, 1])
J_tmpz = np.matmul(J_regressor, v_shaped[:, 2])
J = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T
else:
assert (ischumpy(J))
assert (bs_style == 'lbs')
result, Jtr = lbs.verts_core(
pose, v, J, weights, kintree_table, want_Jtr=True, xp=np)
tr = trans.reshape((1, 3))
result = result + tr
Jtr = Jtr + tr
result.trans = trans
result.f = f
result.pose = pose
result.v_template = v_template
result.J = J
result.J_regressor = J_regressor
result.weights = weights
result.kintree_table = kintree_table
result.bs_style = bs_style
result.bs_type = bs_type
if posedirs is not None:
result.posedirs = posedirs
result.v_posed = v_posed
if shapedirs is not None:
result.shapedirs = shapedirs
result.betas = betas
result.v_shaped = v_shaped
if want_Jtr:
result.J_transformed = Jtr
return result
def verts_core(pose,
v,
J,
weights,
kintree_table,
bs_style,
want_Jtr=False,
xp=np):
assert (bs_style == 'lbs')
result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp)
return result

View File

@@ -0,0 +1 @@
__version__ = "0.1.0"

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,147 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import os.path as op
import torch
import logging
import code
from mesh_graphormer.utils.comm import get_world_size
from mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset)
from mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset)
def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
print(yaml_file)
if not op.isfile(yaml_file):
yaml_file = op.join(args.data_dir, yaml_file)
# code.interact(local=locals())
assert op.isfile(yaml_file)
return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)
class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
"""
Wraps a BatchSampler, resampling from it until
a specified number of iterations have been sampled
"""
def __init__(self, batch_sampler, num_iterations, start_iter=0):
self.batch_sampler = batch_sampler
self.num_iterations = num_iterations
self.start_iter = start_iter
def __iter__(self):
iteration = self.start_iter
while iteration <= self.num_iterations:
# if the underlying sampler has a set_epoch method, like
# DistributedSampler, used for making each process see
# a different split of the dataset, then set it
if hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(iteration)
for batch in self.batch_sampler:
iteration += 1
if iteration > self.num_iterations:
break
yield batch
def __len__(self):
return self.num_iterations
def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
batch_sampler = torch.utils.data.sampler.BatchSampler(
sampler, images_per_gpu, drop_last=False
)
if num_iters is not None and num_iters >= 0:
batch_sampler = IterationBasedBatchSampler(
batch_sampler, num_iters, start_iter
)
return batch_sampler
def make_data_sampler(dataset, shuffle, distributed):
if distributed:
return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
sampler = torch.utils.data.sampler.RandomSampler(dataset)
else:
sampler = torch.utils.data.sampler.SequentialSampler(dataset)
return sampler
def make_data_loader(args, yaml_file, is_distributed=True,
is_train=True, start_iter=0, scale_factor=1):
dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
logger = logging.getLogger(__name__)
if is_train==True:
shuffle = True
images_per_gpu = args.per_gpu_train_batch_size
images_per_batch = images_per_gpu * get_world_size()
iters_per_batch = len(dataset) // images_per_batch
num_iters = iters_per_batch * args.num_train_epochs
logger.info("Train with {} images per GPU.".format(images_per_gpu))
logger.info("Total batch size {}".format(images_per_batch))
logger.info("Total training steps {}".format(num_iters))
else:
shuffle = False
images_per_gpu = args.per_gpu_eval_batch_size
num_iters = None
start_iter = 0
sampler = make_data_sampler(dataset, shuffle, is_distributed)
batch_sampler = make_batch_data_sampler(
sampler, images_per_gpu, num_iters, start_iter
)
data_loader = torch.utils.data.DataLoader(
dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
pin_memory=True,
)
return data_loader
#==============================================================================================
def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
print(yaml_file)
if not op.isfile(yaml_file):
yaml_file = op.join(args.data_dir, yaml_file)
# code.interact(local=locals())
assert op.isfile(yaml_file)
return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)
def make_hand_data_loader(args, yaml_file, is_distributed=True,
is_train=True, start_iter=0, scale_factor=1):
dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
logger = logging.getLogger(__name__)
if is_train==True:
shuffle = True
images_per_gpu = args.per_gpu_train_batch_size
images_per_batch = images_per_gpu * get_world_size()
iters_per_batch = len(dataset) // images_per_batch
num_iters = iters_per_batch * args.num_train_epochs
logger.info("Train with {} images per GPU.".format(images_per_gpu))
logger.info("Total batch size {}".format(images_per_batch))
logger.info("Total training steps {}".format(num_iters))
else:
shuffle = False
images_per_gpu = args.per_gpu_eval_batch_size
num_iters = None
start_iter = 0
sampler = make_data_sampler(dataset, shuffle, is_distributed)
batch_sampler = make_batch_data_sampler(
sampler, images_per_gpu, num_iters, start_iter
)
data_loader = torch.utils.data.DataLoader(
dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
pin_memory=True,
)
return data_loader
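A hedged sketch of driving the hand data loader above. The module path, the yaml path and the namespace fields are assumptions (an argparse-style namespace with at least the attributes the loaders and HandMeshTSVDataset read), not values taken from this commit:
from types import SimpleNamespace
from mesh_graphormer.datasets.build import make_hand_data_loader  # module path assumed

args = SimpleNamespace(data_dir='datasets', num_workers=4,
                       per_gpu_train_batch_size=32, per_gpu_eval_batch_size=32,
                       num_train_epochs=200,
                       multiscale_inference=False, rot=0, sc=1.0)
train_loader = make_hand_data_loader(args, 'freihand/train.yaml',  # hypothetical yaml
                                     is_distributed=False, is_train=True)
for img_keys, images, meta_data in train_loader:
    print(images.shape)  # (B, 3, 224, 224) normalized crops
    break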

View File

@@ -0,0 +1,334 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import cv2
import math
import json
from PIL import Image
import os.path as op
import numpy as np
import code
from mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile
from mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
from mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
import torch
import torchvision.transforms as transforms
class HandMeshTSVDataset(object):
def __init__(self, args, img_file, label_file=None, hw_file=None,
linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
self.args = args
self.img_file = img_file
self.label_file = label_file
self.hw_file = hw_file
self.linelist_file = linelist_file
self.img_tsv = self.get_tsv_file(img_file)
self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)
if self.is_composite:
assert op.isfile(self.linelist_file)
self.line_list = [i for i in range(self.hw_tsv.num_rows())]
else:
self.line_list = load_linelist_file(linelist_file)
self.cv2_output = cv2_output
self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
self.is_train = is_train
self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
self.noise_factor = 0.4
self.rot_factor = 90 # Random rotation in the range [-rot_factor, rot_factor]
self.img_res = 224
self.image_keys = self.prepare_image_keys()
self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
self.root_index = self.joints_definition.index('Wrist')
def get_tsv_file(self, tsv_file):
if tsv_file:
if self.is_composite:
return CompositeTSVFile(tsv_file, self.linelist_file,
root=self.root)
tsv_path = find_file_path_in_yaml(tsv_file, self.root)
return TSVFile(tsv_path)
def get_valid_tsv(self):
# sorted by file size
if self.hw_tsv:
return self.hw_tsv
if self.label_tsv:
return self.label_tsv
def prepare_image_keys(self):
tsv = self.get_valid_tsv()
return [tsv.get_key(i) for i in range(tsv.num_rows())]
def prepare_image_key_to_index(self):
tsv = self.get_valid_tsv()
return {tsv.get_key(i) : i for i in range(tsv.num_rows())}
def augm_params(self):
"""Get augmentation parameters."""
flip = 0 # flipping
pn = np.ones(3) # per channel pixel-noise
if self.args.multiscale_inference == False:
rot = 0 # rotation
sc = 1.0 # scaling
elif self.args.multiscale_inference == True:
rot = self.args.rot
sc = self.args.sc
if self.is_train:
sc = 1.0
# Each channel is multiplied with a number
# in the area [1-opt.noiseFactor,1+opt.noiseFactor]
pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3)
# The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
rot = min(2*self.rot_factor,
max(-2*self.rot_factor, np.random.randn()*self.rot_factor))
# The scale is multiplied with a number
# in the area [1-scaleFactor,1+scaleFactor]
sc = min(1+self.scale_factor,
max(1-self.scale_factor, np.random.randn()*self.scale_factor+1))
# but it is zero with probability 3/5
if np.random.uniform() <= 0.6:
rot = 0
return flip, pn, rot, sc
def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
"""Process rgb image and do augmentation."""
rgb_img = crop(rgb_img, center, scale,
[self.img_res, self.img_res], rot=rot)
# flip the image
if flip:
rgb_img = flip_img(rgb_img)
# in the rgb image we add pixel noise in a channel-wise manner
rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0]))
rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1]))
rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2]))
# (3,224,224),float,[0,1]
rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0
return rgb_img
def j2d_processing(self, kp, center, scale, r, f):
"""Process gt 2D keypoints and apply all augmentation transforms."""
nparts = kp.shape[0]
for i in range(nparts):
kp[i,0:2] = transform(kp[i,0:2]+1, center, scale,
[self.img_res, self.img_res], rot=r)
# convert to normalized coordinates
kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1.
# flip the x coordinates
if f:
kp = flip_kp(kp)
kp = kp.astype('float32')
return kp
def j3d_processing(self, S, r, f):
"""Process gt 3D keypoints and apply all augmentation transforms."""
# in-plane rotation
rot_mat = np.eye(3)
if not r == 0:
rot_rad = -r * np.pi / 180
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0,:2] = [cs, -sn]
rot_mat[1,:2] = [sn, cs]
S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
# flip the x coordinates
if f:
S = flip_kp(S)
S = S.astype('float32')
return S
def pose_processing(self, pose, r, f):
"""Process SMPL theta parameters and apply all augmentation transforms."""
# rotation of the pose parameters
pose = pose.astype('float32')
pose[:3] = rot_aa(pose[:3], r)
# flip the pose parameters
if f:
pose = flip_pose(pose)
# (72),float
pose = pose.astype('float32')
return pose
def get_line_no(self, idx):
return idx if self.line_list is None else self.line_list[idx]
def get_image(self, idx):
line_no = self.get_line_no(idx)
row = self.img_tsv[line_no]
# use -1 to support old format with multiple columns.
cv2_im = img_from_base64(row[-1])
if self.cv2_output:
return cv2_im.astype(np.float32, copy=True)
cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
return cv2_im
def get_annotations(self, idx):
line_no = self.get_line_no(idx)
if self.label_tsv is not None:
row = self.label_tsv[line_no]
annotations = json.loads(row[1])
return annotations
else:
return []
def get_target_from_annotations(self, annotations, img_size, idx):
# This function will be overwritten by each dataset to
# decode the labels to specific formats for each task.
return annotations
def get_img_info(self, idx):
if self.hw_tsv is not None:
line_no = self.get_line_no(idx)
row = self.hw_tsv[line_no]
try:
# json string format with "height" and "width" being the keys
return json.loads(row[1])[0]
except ValueError:
# list of strings representing height and width in order
hw_str = row[1].split(' ')
hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
return hw_dict
def get_img_key(self, idx):
line_no = self.get_line_no(idx)
# based on the overhead of reading each row.
if self.hw_tsv:
return self.hw_tsv[line_no][0]
elif self.label_tsv:
return self.label_tsv[line_no][0]
else:
return self.img_tsv[line_no][0]
def __len__(self):
if self.line_list is None:
return self.img_tsv.num_rows()
else:
return len(self.line_list)
def __getitem__(self, idx):
img = self.get_image(idx)
img_key = self.get_img_key(idx)
annotations = self.get_annotations(idx)
annotations = annotations[0]
center = annotations['center']
scale = annotations['scale']
has_2d_joints = annotations['has_2d_joints']
has_3d_joints = annotations['has_3d_joints']
joints_2d = np.asarray(annotations['2d_joints'])
joints_3d = np.asarray(annotations['3d_joints'])
if joints_2d.ndim==3:
joints_2d = joints_2d[0]
if joints_3d.ndim==3:
joints_3d = joints_3d[0]
# Get SMPL parameters, if available
has_smpl = np.asarray(annotations['has_smpl'])
pose = np.asarray(annotations['pose'])
betas = np.asarray(annotations['betas'])
# Get augmentation parameters
flip,pn,rot,sc = self.augm_params()
# Process image
img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
img = torch.from_numpy(img).float()
# Store image before normalization to use it in visualization
transfromed_img = self.normalize_img(img)
# normalize 3d pose by aligning the wrist as the root (at origin)
root_coord = joints_3d[self.root_index,:-1]
joints_3d[:,:-1] = joints_3d[:,:-1] - root_coord[None,:]
# 3d pose augmentation (random flip + rotation, consistent with image and SMPL)
joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
# 2d pose augmentation
joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
###################################
# Masking percentage
# We observe that 0% or 5% works better for 3D hand mesh
# We think this is probably because 3D vertices are quite sparse in the down-sampled hand mesh
mvm_percent = 0.0 # or 0.05
###################################
mjm_mask = np.ones((21,1))
if self.is_train:
num_joints = 21
pb = np.random.random_sample()
masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
mjm_mask[indices,:] = 0.0
mjm_mask = torch.from_numpy(mjm_mask).float()
mvm_mask = np.ones((195,1))
if self.is_train:
num_vertices = 195
pb = np.random.random_sample()
masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
mvm_mask[indices,:] = 0.0
mvm_mask = torch.from_numpy(mvm_mask).float()
meta_data = {}
meta_data['ori_img'] = img
meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
meta_data['betas'] = torch.from_numpy(betas).float()
meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
meta_data['has_3d_joints'] = has_3d_joints
meta_data['has_smpl'] = has_smpl
meta_data['mjm_mask'] = mjm_mask
meta_data['mvm_mask'] = mvm_mask
# Get 2D keypoints and apply augmentation transforms
meta_data['has_2d_joints'] = has_2d_joints
meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
meta_data['scale'] = float(sc * scale)
meta_data['center'] = np.asarray(center).astype(np.float32)
return img_key, transfromed_img, meta_data
class HandMeshTSVYamlDataset(HandMeshTSVDataset):
""" TSVDataset taking a Yaml file for easy function call
"""
def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
self.cfg = load_from_yaml_file(yaml_file)
self.is_composite = self.cfg.get('composite', False)
self.root = op.dirname(yaml_file)
if self.is_composite==False:
img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
label_file = find_file_path_in_yaml(self.cfg.get('label', None),
self.root)
hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
self.root)
else:
img_file = self.cfg['img']
hw_file = self.cfg['hw']
label_file = self.cfg.get('label', None)
linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
self.root)
super(HandMeshTSVYamlDataset, self).__init__(
args, img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)

View File

@@ -0,0 +1,337 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import cv2
import math
import json
from PIL import Image
import os.path as op
import numpy as np
import code
from mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile
from mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
from mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
import torch
import torchvision.transforms as transforms
class MeshTSVDataset(object):
def __init__(self, img_file, label_file=None, hw_file=None,
linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
self.img_file = img_file
self.label_file = label_file
self.hw_file = hw_file
self.linelist_file = linelist_file
self.img_tsv = self.get_tsv_file(img_file)
self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)
if self.is_composite:
assert op.isfile(self.linelist_file)
self.line_list = [i for i in range(self.hw_tsv.num_rows())]
else:
self.line_list = load_linelist_file(linelist_file)
self.cv2_output = cv2_output
self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
self.is_train = is_train
self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
self.noise_factor = 0.4
self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor]
self.img_res = 224
self.image_keys = self.prepare_image_keys()
self.joints_definition = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
self.pelvis_index = self.joints_definition.index('Pelvis')
def get_tsv_file(self, tsv_file):
if tsv_file:
if self.is_composite:
return CompositeTSVFile(tsv_file, self.linelist_file,
root=self.root)
tsv_path = find_file_path_in_yaml(tsv_file, self.root)
return TSVFile(tsv_path)
def get_valid_tsv(self):
# sorted by file size
if self.hw_tsv:
return self.hw_tsv
if self.label_tsv:
return self.label_tsv
def prepare_image_keys(self):
tsv = self.get_valid_tsv()
return [tsv.get_key(i) for i in range(tsv.num_rows())]
def prepare_image_key_to_index(self):
tsv = self.get_valid_tsv()
return {tsv.get_key(i) : i for i in range(tsv.num_rows())}
def augm_params(self):
"""Get augmentation parameters."""
flip = 0 # flipping
pn = np.ones(3) # per channel pixel-noise
rot = 0 # rotation
sc = 1 # scaling
if self.is_train:
# We flip with probability 1/2
if np.random.uniform() <= 0.5:
flip = 1
# Each channel is multiplied with a number
# in the area [1-opt.noiseFactor,1+opt.noiseFactor]
pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3)
# The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
rot = min(2*self.rot_factor,
max(-2*self.rot_factor, np.random.randn()*self.rot_factor))
# The scale is multiplied with a number
# in the area [1-scaleFactor,1+scaleFactor]
sc = min(1+self.scale_factor,
max(1-self.scale_factor, np.random.randn()*self.scale_factor+1))
# but it is zero with probability 3/5
if np.random.uniform() <= 0.6:
rot = 0
return flip, pn, rot, sc
def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
"""Process rgb image and do augmentation."""
rgb_img = crop(rgb_img, center, scale,
[self.img_res, self.img_res], rot=rot)
# flip the image
if flip:
rgb_img = flip_img(rgb_img)
# in the rgb image we add pixel noise in a channel-wise manner
rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0]))
rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1]))
rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2]))
# (3,224,224),float,[0,1]
rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0
return rgb_img
def j2d_processing(self, kp, center, scale, r, f):
"""Process gt 2D keypoints and apply all augmentation transforms."""
nparts = kp.shape[0]
for i in range(nparts):
kp[i,0:2] = transform(kp[i,0:2]+1, center, scale,
[self.img_res, self.img_res], rot=r)
# convert to normalized coordinates
kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1.
# flip the x coordinates
if f:
kp = flip_kp(kp)
kp = kp.astype('float32')
return kp
def j3d_processing(self, S, r, f):
"""Process gt 3D keypoints and apply all augmentation transforms."""
# in-plane rotation
rot_mat = np.eye(3)
if not r == 0:
rot_rad = -r * np.pi / 180
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0,:2] = [cs, -sn]
rot_mat[1,:2] = [sn, cs]
S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
# flip the x coordinates
if f:
S = flip_kp(S)
S = S.astype('float32')
return S
def pose_processing(self, pose, r, f):
"""Process SMPL theta parameters and apply all augmentation transforms."""
# rotation of the pose parameters
pose = pose.astype('float32')
pose[:3] = rot_aa(pose[:3], r)
# flip the pose parameters
if f:
pose = flip_pose(pose)
# (72),float
pose = pose.astype('float32')
return pose
def get_line_no(self, idx):
return idx if self.line_list is None else self.line_list[idx]
def get_image(self, idx):
line_no = self.get_line_no(idx)
row = self.img_tsv[line_no]
# use -1 to support old format with multiple columns.
cv2_im = img_from_base64(row[-1])
if self.cv2_output:
return cv2_im.astype(np.float32, copy=True)
cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
return cv2_im
def get_annotations(self, idx):
line_no = self.get_line_no(idx)
if self.label_tsv is not None:
row = self.label_tsv[line_no]
annotations = json.loads(row[1])
return annotations
else:
return []
def get_target_from_annotations(self, annotations, img_size, idx):
# This function will be overwritten by each dataset to
# decode the labels to specific formats for each task.
return annotations
def get_img_info(self, idx):
if self.hw_tsv is not None:
line_no = self.get_line_no(idx)
row = self.hw_tsv[line_no]
try:
# json string format with "height" and "width" being the keys
return json.loads(row[1])[0]
except ValueError:
# list of strings representing height and width in order
hw_str = row[1].split(' ')
hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
return hw_dict
def get_img_key(self, idx):
line_no = self.get_line_no(idx)
# prefer the cheaper hw/label TSVs over the image TSV, based on the overhead of reading each row.
if self.hw_tsv:
return self.hw_tsv[line_no][0]
elif self.label_tsv:
return self.label_tsv[line_no][0]
else:
return self.img_tsv[line_no][0]
def __len__(self):
if self.line_list is None:
return self.img_tsv.num_rows()
else:
return len(self.line_list)
def __getitem__(self, idx):
img = self.get_image(idx)
img_key = self.get_img_key(idx)
annotations = self.get_annotations(idx)
annotations = annotations[0]
center = annotations['center']
scale = annotations['scale']
has_2d_joints = annotations['has_2d_joints']
has_3d_joints = annotations['has_3d_joints']
joints_2d = np.asarray(annotations['2d_joints'])
joints_3d = np.asarray(annotations['3d_joints'])
if joints_2d.ndim==3:
joints_2d = joints_2d[0]
if joints_3d.ndim==3:
joints_3d = joints_3d[0]
# Get SMPL parameters, if available
has_smpl = np.asarray(annotations['has_smpl'])
pose = np.asarray(annotations['pose'])
betas = np.asarray(annotations['betas'])
try:
gender = annotations['gender']
except KeyError:
gender = 'none'
# Get augmentation parameters
flip,pn,rot,sc = self.augm_params()
# Process image
img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
img = torch.from_numpy(img).float()
# Keep the un-normalized image (stored below as meta_data['ori_img']) for visualization; the normalized copy is what the network consumes
transfromed_img = self.normalize_img(img)
# normalize 3d pose by aligning the pelvis as the root (at origin)
root_pelvis = joints_3d[self.pelvis_index,:-1]
joints_3d[:,:-1] = joints_3d[:,:-1] - root_pelvis[None,:]
# 3d pose augmentation (random flip + rotation, consistent with the image and SMPL)
joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
# 2d pose augmentation
joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
###################################
# Masking percentage
# We observe that 30% works better for human body mesh. Further details are reported in the paper.
mvm_percent = 0.3
###################################
mjm_mask = np.ones((14,1))
if self.is_train:
num_joints = 14
pb = np.random.random_sample()
masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
mjm_mask[indices,:] = 0.0
mjm_mask = torch.from_numpy(mjm_mask).float()
mvm_mask = np.ones((431,1))
if self.is_train:
num_vertices = 431
pb = np.random.random_sample()
masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
mvm_mask[indices,:] = 0.0
mvm_mask = torch.from_numpy(mvm_mask).float()
meta_data = {}
meta_data['ori_img'] = img
meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
meta_data['betas'] = torch.from_numpy(betas).float()
meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
meta_data['has_3d_joints'] = has_3d_joints
meta_data['has_smpl'] = has_smpl
meta_data['mjm_mask'] = mjm_mask
meta_data['mvm_mask'] = mvm_mask
# Store 2D keypoints (augmentation transforms already applied above)
meta_data['has_2d_joints'] = has_2d_joints
meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
meta_data['scale'] = float(sc * scale)
meta_data['center'] = np.asarray(center).astype(np.float32)
meta_data['gender'] = gender
return img_key, transfromed_img, meta_data
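# __getitem__ returns a triple:
#   img_key          - TSV row key identifying the sample
#   transfromed_img  - normalized (C, H, W) float tensor fed to the network
#   meta_data        - dict with 'ori_img', 'pose', 'betas', 'joints_3d', 'has_3d_joints',
#                      'has_smpl', 'mjm_mask', 'mvm_mask', 'has_2d_joints', 'joints_2d',
#                      'scale', 'center', 'gender'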
class MeshTSVYamlDataset(MeshTSVDataset):
""" TSVDataset taking a Yaml file for easy function call
"""
def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
self.cfg = load_from_yaml_file(yaml_file)
self.is_composite = self.cfg.get('composite', False)
self.root = op.dirname(yaml_file)
if self.is_composite==False:
img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
label_file = find_file_path_in_yaml(self.cfg.get('label', None),
self.root)
hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
self.root)
else:
img_file = self.cfg['img']
hw_file = self.cfg['hw']
label_file = self.cfg.get('label', None)
linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
self.root)
super(MeshTSVYamlDataset, self).__init__(
img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)
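# Illustrative usage sketch (not part of the original file); the YAML path is hypothetical
# and must point to a config with 'img' / 'label' / 'hw' / 'linelist' entries:
#
#   dataset = MeshTSVYamlDataset("datasets/freihand/train.yaml", is_train=True)
#   img_key, transfromed_img, meta_data = dataset[0]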


@@ -0,0 +1,184 @@
from __future__ import division
import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse
import math
from pathlib import Path
data_path = Path(__file__).parent / "data"
sparse_to_dense = lambda x: x
device = "cuda"
class SparseMM(torch.autograd.Function):
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
The builtin matrix multiplication operation does not support backpropagation in some cases.
"""
@staticmethod
def forward(ctx, sparse, dense):
ctx.req_grad = dense.requires_grad
ctx.save_for_backward(sparse)
return torch.matmul(sparse, dense)
@staticmethod
def backward(ctx, grad_output):
grad_input = None
sparse, = ctx.saved_tensors
if ctx.req_grad:
grad_input = torch.matmul(sparse.t(), grad_output)
return None, grad_input
def spmm(sparse, dense):
sparse = sparse.to(device)
dense = dense.to(device)
return SparseMM.apply(sparse, dense)
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class BertLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class GraphResBlock(torch.nn.Module):
"""
Graph Residual Block similar to the Bottleneck Residual Block in ResNet
"""
def __init__(self, in_channels, out_channels, mesh_type='body'):
super(GraphResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lin1 = GraphLinear(in_channels, out_channels // 2)
self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type)
self.lin2 = GraphLinear(out_channels // 2, out_channels)
self.skip_conv = GraphLinear(in_channels, out_channels)
# print('Use BertLayerNorm in GraphResBlock')
self.pre_norm = BertLayerNorm(in_channels)
self.norm1 = BertLayerNorm(out_channels // 2)
self.norm2 = BertLayerNorm(out_channels // 2)
def forward(self, x):
trans_y = F.relu(self.pre_norm(x)).transpose(1,2)
y = self.lin1(trans_y).transpose(1,2)
y = F.relu(self.norm1(y))
y = self.conv(y)
trans_y = F.relu(self.norm2(y)).transpose(1,2)
y = self.lin2(trans_y).transpose(1,2)
z = x+y
return z
# class GraphResBlock(torch.nn.Module):
# """
# Graph Residual Block similar to the Bottleneck Residual Block in ResNet
# """
# def __init__(self, in_channels, out_channels, mesh_type='body'):
# super(GraphResBlock, self).__init__()
# self.in_channels = in_channels
# self.out_channels = out_channels
# self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
# print('Use BertLayerNorm and GeLU in GraphResBlock')
# self.norm = BertLayerNorm(self.out_channels)
# def forward(self, x):
# y = self.conv(x)
# y = self.norm(y)
# y = gelu(y)
# z = x+y
# return z
class GraphLinear(torch.nn.Module):
"""
Generalization of 1x1 convolutions on Graphs
"""
def __init__(self, in_channels, out_channels):
super(GraphLinear, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels))
self.b = torch.nn.Parameter(torch.FloatTensor(out_channels))
self.reset_parameters()
def reset_parameters(self):
w_stdv = 1 / (self.in_channels * self.out_channels)
self.W.data.uniform_(-w_stdv, w_stdv)
self.b.data.uniform_(-w_stdv, w_stdv)
def forward(self, x):
return torch.matmul(self.W[None, :], x) + self.b[None, :, None]
class GraphConvolution(torch.nn.Module):
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
def __init__(self, in_features, out_features, mesh='body', bias=True):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
if mesh=='body':
adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt')
adj_mat_value = torch.load(data_path / 'smpl_431_adjmat_values.pt')
adj_mat_size = torch.load(data_path / 'smpl_431_adjmat_size.pt')
elif mesh=='hand':
adj_indices = torch.load(data_path / 'mano_195_adjmat_indices.pt')
adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt')
adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt')
self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device)
self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
if bias:
self.bias = torch.nn.Parameter(torch.FloatTensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
# stdv = 1. / math.sqrt(self.weight.size(1))
stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, x):
if x.ndimension() == 2:
support = torch.matmul(x, self.weight)
output = torch.matmul(self.adjmat, support)
if self.bias is not None:
output = output + self.bias
return output
else:
output = []
for i in range(x.shape[0]):
support = torch.matmul(x[i], self.weight)
# output.append(torch.matmul(self.adjmat, support))
output.append(spmm(self.adjmat, support))
output = torch.stack(output, dim=0)
if self.bias is not None:
output = output + self.bias
return output
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
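# Illustrative usage sketch (not part of the original file); assumes the MANO adjacency
# matrix .pt files are present under this module's data/ directory and CUDA is available
# (this module hard-codes device = "cuda"):
#
#   block = GraphResBlock(64, 64, mesh_type='hand').to(device)
#   vert_feats = torch.randn(2, 195, 64, device=device)  # (batch, num_vertices, channels)
#   out = block(vert_feats)                               # (2, 195, 64)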


@@ -0,0 +1,184 @@
"""
This file contains the MANO definition and mesh sampling operations for the MANO mesh
Adapted from the open-source projects
MANOPTH (https://github.com/hassony2/manopth)
Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
GraphCMR (https://github.com/nkolot/GraphCMR/)
"""
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
import os.path as osp
import json
import code
from manopth.manolayer import ManoLayer
import scipy.sparse
import mesh_graphormer.modeling.data.config as cfg
from pathlib import Path
sparse_to_dense = lambda x: x
device = "cuda"
class MANO(nn.Module):
def __init__(self):
super(MANO, self).__init__()
self.mano_dir = str(Path(__file__).parent / "data")
self.layer = self.get_layer()
self.vertex_num = 778
self.face = self.layer.th_faces.numpy()
self.joint_regressor = self.layer.th_J_regressor.numpy()
self.joint_num = 21
self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) )
self.root_joint_idx = self.joints_name.index('Wrist')
# add fingertips to joint_regressor
self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand)
thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot))
self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:]
joint_regressor_torch = torch.from_numpy(self.joint_regressor).float()
self.register_buffer('joint_regressor_torch', joint_regressor_torch)
def get_layer(self):
return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model
def get_3d_joints(self, vertices):
"""
This method is used to get the joint locations from the MANO mesh
Input:
vertices: size = (B, 778, 3)
Output:
3D joints: size = (B, 21, 3)
"""
joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch])
return joints
class SparseMM(torch.autograd.Function):
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
The builtin matrix multiplication operation does not support backpropagation in some cases.
"""
@staticmethod
def forward(ctx, sparse, dense):
ctx.req_grad = dense.requires_grad
ctx.save_for_backward(sparse)
return torch.matmul(sparse, dense)
@staticmethod
def backward(ctx, grad_output):
grad_input = None
sparse, = ctx.saved_tensors
if ctx.req_grad:
grad_input = torch.matmul(sparse.t(), grad_output)
return None, grad_input
def spmm(sparse, dense):
sparse = sparse.to(device)
dense = dense.to(device)
return SparseMM.apply(sparse, dense)
def scipy_to_pytorch(A, U, D):
"""Convert scipy sparse matrices to pytorch sparse matrix."""
ptU = []
ptD = []
for i in range(len(U)):
u = scipy.sparse.coo_matrix(U[i])
i = torch.LongTensor(np.array([u.row, u.col]))
v = torch.FloatTensor(u.data)
ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape)))
for i in range(len(D)):
d = scipy.sparse.coo_matrix(D[i])
i = torch.LongTensor(np.array([d.row, d.col]))
v = torch.FloatTensor(d.data)
ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape)))
return ptU, ptD
def adjmat_sparse(adjmat, nsize=1):
"""Create row-normalized sparse graph adjacency matrix."""
adjmat = scipy.sparse.csr_matrix(adjmat)
if nsize > 1:
orig_adjmat = adjmat.copy()
for _ in range(1, nsize):
adjmat = adjmat * orig_adjmat
adjmat.data = np.ones_like(adjmat.data)
for i in range(adjmat.shape[0]):
adjmat[i,i] = 1
num_neighbors = np.array(1 / adjmat.sum(axis=-1))
adjmat = adjmat.multiply(num_neighbors)
adjmat = scipy.sparse.coo_matrix(adjmat)
row = adjmat.row
col = adjmat.col
data = adjmat.data
i = torch.LongTensor(np.array([row, col]))
v = torch.from_numpy(data).float()
adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape))
return adjmat
def get_graph_params(filename, nsize=1):
"""Load and process graph adjacency matrix and upsampling/downsampling matrices."""
data = np.load(filename, encoding='latin1', allow_pickle=True)
A = data['A']
U = data['U']
D = data['D']
U, D = scipy_to_pytorch(A, U, D)
A = [adjmat_sparse(a, nsize=nsize) for a in A]
return A, U, D
class Mesh(object):
"""Mesh object that is used for handling certain graph operations."""
def __init__(self, filename=cfg.MANO_sampling_matrix,
num_downsampling=1, nsize=1, device=torch.device('cuda')):
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
# self._A = [a.to(device) for a in self._A]
self._U = [u.to(device) for u in self._U]
self._D = [d.to(device) for d in self._D]
self.num_downsampling = num_downsampling
def downsample(self, x, n1=0, n2=None):
"""Downsample mesh."""
if n2 is None:
n2 = self.num_downsampling
if x.ndimension() < 3:
for i in range(n1, n2):
x = spmm(self._D[i], x)
elif x.ndimension() == 3:
out = []
for i in range(x.shape[0]):
y = x[i]
for j in range(n1, n2):
y = spmm(self._D[j], y)
out.append(y)
x = torch.stack(out, dim=0)
return x
def upsample(self, x, n1=1, n2=0):
"""Upsample mesh."""
if x.ndimension() < 3:
for i in reversed(range(n2, n1)):
x = spmm(self._U[i], x)
elif x.ndimension() == 3:
out = []
for i in range(x.shape[0]):
y = x[i]
for j in reversed(range(n2, n1)):
y = spmm(self._U[j], y)
out.append(y)
x = torch.stack(out, dim=0)
return x
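# Illustrative usage sketch (not part of the original file); assumes MANO_RIGHT.pkl and
# mano_downsampling.npz are available and CUDA is present (device = "cuda" above):
#
#   mano = MANO().to(device)
#   mesh_sampler = Mesh()
#   verts, joints = mano.layer(torch.zeros(1, 48).to(device), torch.zeros(1, 10).to(device))
#   verts_sub = mesh_sampler.downsample(verts)      # 778 -> 195 vertices
#   verts_full = mesh_sampler.upsample(verts_sub)   # 195 -> 778 vertices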


@@ -0,0 +1,283 @@
"""
This file contains the definition of the SMPL model
It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/)
"""
from __future__ import division
import torch
import torch.nn as nn
import numpy as np
import scipy.sparse
try:
import cPickle as pickle
except ImportError:
import pickle
from mesh_graphormer.utils.geometric_layers import rodrigues
import mesh_graphormer.modeling.data.config as cfg
sparse_to_dense = lambda x: x
device = "cuda"
class SMPL(nn.Module):
def __init__(self, gender='neutral'):
super(SMPL, self).__init__()
if gender=='m':
model_file=cfg.SMPL_Male
elif gender=='f':
model_file=cfg.SMPL_Female
else:
model_file=cfg.SMPL_FILE
smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1')
J_regressor = smpl_model['J_regressor'].tocoo()
row = J_regressor.row
col = J_regressor.col
data = J_regressor.data
i = torch.LongTensor([row, col])
v = torch.FloatTensor(data)
J_regressor_shape = [24, 6890]
self.register_buffer('J_regressor', torch.sparse_coo_tensor(i, v, J_regressor_shape).to_dense())
self.register_buffer('weights', torch.FloatTensor(smpl_model['weights']))
self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs']))
self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template']))
self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs'])))
self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64)))
self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64)))
id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
self.pose_shape = [24, 3]
self.beta_shape = [10]
self.translation_shape = [3]
self.pose = torch.zeros(self.pose_shape)
self.beta = torch.zeros(self.beta_shape)
self.translation = torch.zeros(self.translation_shape)
self.verts = None
self.J = None
self.R = None
J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float()
self.register_buffer('J_regressor_extra', J_regressor_extra)
self.joints_idx = cfg.JOINTS_IDX
J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float()
self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct)
def forward(self, pose, beta):
device = pose.device
batch_size = pose.shape[0]
v_template = self.v_template[None, :]
shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1)
beta = beta[:, :, None]
v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
# batched sparse matmul not supported in pytorch
J = []
for i in range(batch_size):
J.append(torch.matmul(self.J_regressor, v_shaped[i]))
J = torch.stack(J, dim=0)
# input is rotation matrices: (bs, 24, 3, 3)
if pose.ndimension() == 4:
R = pose
# input is axis-angle pose vectors: (bs, 72)
elif pose.ndimension() == 2:
pose_cube = pose.view(-1, 3) # (batch_size * 24, 3)
R = rodrigues(pose_cube).view(batch_size, 24, 3, 3)
R = R.view(batch_size, 24, 3, 3)
I_cube = torch.eye(3)[None, None, :].to(device)
# I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1)
lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1)
posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1)
v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3)
J_ = J.clone()
J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1)
G_ = torch.cat([G_, pad_row], dim=2)
G = [G_[:, 0].clone()]
for i in range(1, 24):
G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :]))
G = torch.stack(G, dim=1)
rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1)
zeros = torch.zeros(batch_size, 24, 4, 3).to(device)
rest = torch.cat([zeros, rest], dim=-1)
rest = torch.matmul(G, rest)
G = G - rest
T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1)
rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
return v
def get_joints(self, vertices):
"""
This method is used to get the joint locations from the SMPL mesh
Input:
vertices: size = (B, 6890, 3)
Output:
3D joints: size = (B, 24, 3)
"""
joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra])
joints = torch.cat((joints, joints_extra), dim=1)
joints = joints[:, cfg.JOINTS_IDX]
return joints
def get_h36m_joints(self, vertices):
"""
This method is used to get the joint locations from the SMPL mesh
Input:
vertices: size = (B, 6890, 3)
Output:
3D joints: size = (B, 17, 3)
"""
joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
return joints
class SparseMM(torch.autograd.Function):
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
The builtin matrix multiplication operation does not support backpropagation in some cases.
"""
@staticmethod
def forward(ctx, sparse, dense):
ctx.req_grad = dense.requires_grad
ctx.save_for_backward(sparse)
return torch.matmul(sparse, dense)
@staticmethod
def backward(ctx, grad_output):
grad_input = None
sparse, = ctx.saved_tensors
if ctx.req_grad:
grad_input = torch.matmul(sparse.t(), grad_output)
return None, grad_input
def spmm(sparse, dense):
sparse = sparse.to(device)
dense = dense.to(device)
return SparseMM.apply(sparse, dense)
def scipy_to_pytorch(A, U, D):
"""Convert scipy sparse matrices to pytorch sparse matrix."""
ptU = []
ptD = []
for i in range(len(U)):
u = scipy.sparse.coo_matrix(U[i])
i = torch.LongTensor(np.array([u.row, u.col]))
v = torch.FloatTensor(u.data)
ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape)))
for i in range(len(D)):
d = scipy.sparse.coo_matrix(D[i])
i = torch.LongTensor(np.array([d.row, d.col]))
v = torch.FloatTensor(d.data)
ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape)))
return ptU, ptD
def adjmat_sparse(adjmat, nsize=1):
"""Create row-normalized sparse graph adjacency matrix."""
adjmat = scipy.sparse.csr_matrix(adjmat)
if nsize > 1:
orig_adjmat = adjmat.copy()
for _ in range(1, nsize):
adjmat = adjmat * orig_adjmat
adjmat.data = np.ones_like(adjmat.data)
for i in range(adjmat.shape[0]):
adjmat[i,i] = 1
num_neighbors = np.array(1 / adjmat.sum(axis=-1))
adjmat = adjmat.multiply(num_neighbors)
adjmat = scipy.sparse.coo_matrix(adjmat)
row = adjmat.row
col = adjmat.col
data = adjmat.data
i = torch.LongTensor(np.array([row, col]))
v = torch.from_numpy(data).float()
adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape))
return adjmat
def get_graph_params(filename, nsize=1):
"""Load and process graph adjacency matrix and upsampling/downsampling matrices."""
data = np.load(filename, encoding='latin1', allow_pickle=True)
A = data['A']
U = data['U']
D = data['D']
U, D = scipy_to_pytorch(A, U, D)
A = [adjmat_sparse(a, nsize=nsize) for a in A]
return A, U, D
class Mesh(object):
"""Mesh object that is used for handling certain graph operations."""
def __init__(self, filename=cfg.SMPL_sampling_matrix,
num_downsampling=1, nsize=1, device=torch.device('cuda')):
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
# self._A = [a.to(device) for a in self._A]
self._U = [u.to(device) for u in self._U]
self._D = [d.to(device) for d in self._D]
self.num_downsampling = num_downsampling
# load template vertices from SMPL and normalize them
smpl = SMPL()
ref_vertices = smpl.v_template
center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
ref_vertices -= center
ref_vertices /= ref_vertices.abs().max().item()
self._ref_vertices = ref_vertices.to(device)
self.faces = smpl.faces.int().to(device)
# @property
# def adjmat(self):
# """Return the graph adjacency matrix at the specified subsampling level."""
# return self._A[self.num_downsampling].float()
@property
def ref_vertices(self):
"""Return the template vertices at the specified subsampling level."""
ref_vertices = self._ref_vertices
for i in range(self.num_downsampling):
ref_vertices = torch.spmm(self._D[i], ref_vertices)
return ref_vertices
def downsample(self, x, n1=0, n2=None):
"""Downsample mesh."""
if n2 is None:
n2 = self.num_downsampling
if x.ndimension() < 3:
for i in range(n1, n2):
x = spmm(self._D[i], x)
elif x.ndimension() == 3:
out = []
for i in range(x.shape[0]):
y = x[i]
for j in range(n1, n2):
y = spmm(self._D[j], y)
out.append(y)
x = torch.stack(out, dim=0)
return x
def upsample(self, x, n1=1, n2=0):
"""Upsample mesh."""
if x.ndimension() < 3:
for i in reversed(range(n2, n1)):
x = spmm(self._U[i], x)
elif x.ndimension() == 3:
out = []
for i in range(x.shape[0]):
y = x[i]
for j in reversed(range(n2, n1)):
y = spmm(self._U[j], y)
out.append(y)
x = torch.stack(out, dim=0)
return x
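# Note: for the SMPL body mesh the sampling hierarchy is 6890 -> 1723 -> 431 vertices;
# downsample(x) maps 6890 -> 1723 and downsample(x_sub, n1=1, n2=2) maps 1723 -> 431,
# mirroring the 431 -> 1723 -> 6890 upsampling layers in Graphormer_Body_Network.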


@@ -0,0 +1,17 @@
__version__ = "1.0.0"
from .modeling_bert import (BertConfig, BertModel,
load_tf_weights_in_bert)
from .modeling_graphormer import Graphormer
from .e2e_body_network import Graphormer_Body_Network
from .e2e_hand_network import Graphormer_Hand_Network
CONFIG_NAME = "config.json"
from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE)


@@ -0,0 +1,16 @@
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 2,
"vocab_size": 30522
}


@@ -0,0 +1,103 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import torch
import mesh_graphormer.modeling.data.config as cfg
device = "cuda"
class Graphormer_Body_Network(torch.nn.Module):
'''
End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
'''
def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
super(Graphormer_Body_Network, self).__init__()
self.config = config
self.config.device = device
self.backbone = backbone
self.trans_encoder = trans_encoder
self.upsampling = torch.nn.Linear(431, 1723)
self.upsampling2 = torch.nn.Linear(1723, 6890)
self.cam_param_fc = torch.nn.Linear(3, 1)
self.cam_param_fc2 = torch.nn.Linear(431, 250)
self.cam_param_fc3 = torch.nn.Linear(250, 3)
self.grid_feat_dim = torch.nn.Linear(1024, 2051)
def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False):
batch_size = images.size(0)
# Generate T-pose template mesh
template_pose = torch.zeros((1,72))
template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
template_pose = template_pose.to(device)
template_betas = torch.zeros((1,10)).to(device)
template_vertices = smpl(template_pose, template_betas)
# template mesh simplification
template_vertices_sub = mesh_sampler.downsample(template_vertices)
template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2)
# template mesh-to-joint regression
template_3d_joints = smpl.get_h36m_joints(template_vertices)
template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:]
num_joints = template_3d_joints.shape[1]
# normalize
template_3d_joints = template_3d_joints - template_pelvis[:, None, :]
template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :]
# concatenate template joints and template vertices, then duplicate to batch size
ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1)
ref_vertices = ref_vertices.expand(batch_size, -1, -1)
# extract grid features and global image features using a CNN backbone
image_feat, grid_feat = self.backbone(images)
# concatenate image features and the 3D mesh template
image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
# process grid features
grid_feat = torch.flatten(grid_feat, start_dim=2)
grid_feat = grid_feat.transpose(1,2)
grid_feat = self.grid_feat_dim(grid_feat)
# concatenate image features and the template mesh to form the joint/vertex queries
features = torch.cat([ref_vertices, image_feat], dim=2)
# prepare input tokens including joint/vertex queries and grid features
features = torch.cat([features, grid_feat],dim=1)
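# token layout at this point: [14 joint queries | 431 vertex queries | 49 grid-feature tokens],
# which is why the masking and prediction steps below slice with [:, :-49, :] and [:, num_joints:-49, :]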
if is_train==True:
# apply mask vertex/joint modeling
# meta_masks is a tensor of all the masks, randomly generated in dataloader
# we pre-define a [MASK] token, which is a floating-value vector with 0.01s
special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
# forward pass
if self.config.output_attentions==True:
features, hidden_states, att = self.trans_encoder(features)
else:
features = self.trans_encoder(features)
pred_3d_joints = features[:,:num_joints,:]
pred_vertices_sub2 = features[:,num_joints:-49,:]
# learn camera parameters
x = self.cam_param_fc(pred_vertices_sub2)
x = x.transpose(1,2)
x = self.cam_param_fc2(x)
x = self.cam_param_fc3(x)
cam_param = x.transpose(1,2)
cam_param = cam_param.squeeze()
temp_transpose = pred_vertices_sub2.transpose(1,2)
pred_vertices_sub = self.upsampling(temp_transpose)
pred_vertices_full = self.upsampling2(pred_vertices_sub)
pred_vertices_sub = pred_vertices_sub.transpose(1,2)
pred_vertices_full = pred_vertices_full.transpose(1,2)
if self.config.output_attentions==True:
return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att
else:
return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full
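# Output shapes for batch size B: cam_param (B, 3), pred_3d_joints (B, 14, 3),
# pred_vertices_sub2 (B, 431, 3), pred_vertices_sub (B, 1723, 3), pred_vertices_full (B, 6890, 3)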


@@ -0,0 +1,94 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import torch
import mesh_graphormer.modeling.data.config as cfg
device = "cuda"
class Graphormer_Hand_Network(torch.nn.Module):
'''
End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
'''
def __init__(self, args, config, backbone, trans_encoder):
super(Graphormer_Hand_Network, self).__init__()
self.config = config
self.backbone = backbone
self.trans_encoder = trans_encoder
self.upsampling = torch.nn.Linear(195, 778)
self.cam_param_fc = torch.nn.Linear(3, 1)
self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
self.cam_param_fc3 = torch.nn.Linear(150, 3)
self.grid_feat_dim = torch.nn.Linear(1024, 2051)
def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False):
batch_size = images.size(0)
# Generate T-pose template mesh
template_pose = torch.zeros((1,48))
template_pose = template_pose.to(device)
template_betas = torch.zeros((1,10)).to(device)
template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
template_vertices = template_vertices/1000.0
template_3d_joints = template_3d_joints/1000.0
template_vertices_sub = mesh_sampler.downsample(template_vertices)
# normalize
template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
template_3d_joints = template_3d_joints - template_root[:, None, :]
template_vertices = template_vertices - template_root[:, None, :]
template_vertices_sub = template_vertices_sub - template_root[:, None, :]
num_joints = template_3d_joints.shape[1]
# concatenate template joints and template vertices, then duplicate to batch size
ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1)
ref_vertices = ref_vertices.expand(batch_size, -1, -1)
# extract grid features and global image features using a CNN backbone
image_feat, grid_feat = self.backbone(images)
# concatenate image features and the mesh template
image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
# process grid features
grid_feat = torch.flatten(grid_feat, start_dim=2)
grid_feat = grid_feat.transpose(1,2)
grid_feat = self.grid_feat_dim(grid_feat)
# concatenate image features and the template mesh to form the joint/vertex queries
features = torch.cat([ref_vertices, image_feat], dim=2)
# prepare input tokens including joint/vertex queries and grid features
features = torch.cat([features, grid_feat],dim=1)
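# token layout at this point: [21 joint queries | 195 vertex queries | 49 grid-feature tokens]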
if is_train==True:
# apply mask vertex/joint modeling
# meta_masks is a tensor of all the masks, randomly generated in dataloader
# we pre-define a [MASK] token, which is a floating-value vector with 0.01s
special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01
features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
# forward pass
if self.config.output_attentions==True:
features, hidden_states, att = self.trans_encoder(features)
else:
features = self.trans_encoder(features)
pred_3d_joints = features[:,:num_joints,:]
pred_vertices_sub = features[:,num_joints:-49,:]
# learn camera parameters
x = self.cam_param_fc(features[:,:-49,:])
x = x.transpose(1,2)
x = self.cam_param_fc2(x)
x = self.cam_param_fc3(x)
cam_param = x.transpose(1,2)
cam_param = cam_param.squeeze()
temp_transpose = pred_vertices_sub.transpose(1,2)
pred_vertices = self.upsampling(temp_transpose)
pred_vertices = pred_vertices.transpose(1,2)
if self.config.output_attentions==True:
return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att
else:
return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices


@@ -0,0 +1 @@
from transformers.file_utils import *


@@ -0,0 +1 @@
from transformers.models.bert.modeling_bert import *


@@ -0,0 +1,328 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import math
import os
import code
import torch
from torch import nn
from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock
from .modeling_utils import prune_linear_layer
LayerNormClass = torch.nn.LayerNorm
BertLayerNorm = torch.nn.LayerNorm
device = "cuda"
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask, head_mask=None,
history_state=None):
if history_state is not None:
x_states = torch.cat([history_state, hidden_states], dim=1)
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(x_states)
mixed_value_layer = self.value(x_states)
else:
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
return outputs
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
for head in heads:
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
def forward(self, input_tensor, attention_mask, head_mask=None,
history_state=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask,
history_state)
attention_output = self.output(self_outputs[0], input_tensor)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class GraphormerLayer(nn.Module):
def __init__(self, config):
super(GraphormerLayer, self).__init__()
self.attention = BertAttention(config)
self.has_graph_conv = config.graph_conv
self.mesh_type = config.mesh_type
if self.has_graph_conv == True:
self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def MHA_GCN(self, hidden_states, attention_mask, head_mask=None,
history_state=None):
attention_outputs = self.attention(hidden_states, attention_mask,
head_mask, history_state)
attention_output = attention_outputs[0]
if self.has_graph_conv==True:
if self.mesh_type == 'body':
joints = attention_output[:,0:14,:]
vertices = attention_output[:,14:-49,:]
img_tokens = attention_output[:,-49:,:]
elif self.mesh_type == 'hand':
joints = attention_output[:,0:21,:]
vertices = attention_output[:,21:-49,:]
img_tokens = attention_output[:,-49:,:]
vertices = self.graph_conv(vertices)
joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1)
else:
joints_vertices = attention_output
intermediate_output = self.intermediate(joints_vertices)
layer_output = self.output(intermediate_output, joints_vertices)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
def forward(self, hidden_states, attention_mask, head_mask=None,
history_state=None):
return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state)
class GraphormerEncoder(nn.Module):
def __init__(self, config):
super(GraphormerEncoder, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, head_mask=None,
encoder_history_states=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
history_state = None if encoder_history_states is None else encoder_history_states[i]
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i],
history_state)
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
class EncoderBlock(BertPreTrainedModel):
def __init__(self, config):
super(EncoderBlock, self).__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = GraphormerEncoder(config)
self.pooler = BertPooler(config)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.img_dim = config.img_feature_dim
try:
self.use_img_layernorm = config.use_img_layernorm
except:
self.use_img_layernorm = None
self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if self.use_img_layernorm:
self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps)
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None,
position_ids=None, head_mask=None):
batch_size = len(img_feats)
seq_length = len(img_feats[0])
input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(device)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
position_embeddings = self.position_embeddings(position_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
if attention_mask.dim() == 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
elif attention_mask.dim() == 3:
extended_attention_mask = attention_mask.unsqueeze(1)
else:
raise NotImplementedError
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
# Project input token features to the specified hidden size
img_embedding_output = self.img_embedding(img_feats)
# We empirically observe that adding an additional learnable position embedding leads to more stable training
embeddings = position_embeddings + img_embedding_output
if self.use_img_layernorm:
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
encoder_outputs = self.encoder(embeddings,
extended_attention_mask, head_mask=head_mask)
sequence_output = encoder_outputs[0]
outputs = (sequence_output,)
if self.config.output_hidden_states:
all_hidden_states = encoder_outputs[1]
outputs = outputs + (all_hidden_states,)
if self.config.output_attentions:
all_attentions = encoder_outputs[-1]
outputs = outputs + (all_attentions,)
return outputs
class Graphormer(BertPreTrainedModel):
'''
The architecture of the transformer encoder block used in Graphormer
'''
def __init__(self, config):
super(Graphormer, self).__init__(config)
self.config = config
self.bert = EncoderBlock(config)
self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim)
self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim)
def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, position_ids=None, head_mask=None):
'''
# self.bert has three outputs
# predictions[0]: output tokens
# predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states"
# predictions[2]: attentions, if enable "self.config.output_attentions"
'''
predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
# We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification.
pred_score = self.cls_head(predictions[0])
res_img_feats = self.residual(img_feats)
pred_score = pred_score + res_img_feats
if self.config.output_attentions and self.config.output_hidden_states:
return pred_score, predictions[1], predictions[-1]
else:
return pred_score
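# Illustrative configuration sketch (not part of the original file); the values are
# hypothetical, but the extra attributes are the ones read by EncoderBlock/GraphormerLayer above:
#
#   from mesh_graphormer.modeling import BertConfig, Graphormer
#   config = BertConfig()                  # standard BERT defaults (hidden_size=768, 12 heads)
#   config.img_feature_dim = 2051          # per-token input size (2048 global CNN feature + 3D coords)
#   config.output_feature_dim = 3          # e.g. 3 for a final block that regresses coordinates
#   config.graph_conv = True               # insert a GraphResBlock after self-attention
#   config.mesh_type = 'hand'              # 'body' or 'hand'
#   config.output_attentions = False
#   config.output_hidden_states = False
#   model = Graphormer(config).to(device)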


@@ -0,0 +1 @@
from transformers.modeling_utils import *

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,30 @@
# Extra data
Adapted from the open-source projects [GraphCMR](https://github.com/nkolot/GraphCMR/) and [Pose2Mesh](https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
Our code requires additional data to run smoothly.
### J_regressor_extra.npy
Joints regressor for joints or landmarks that are not included in the standard set of SMPL joints.
### J_regressor_h36m_correct.npy
Joints regressor reflecting the Human3.6M joints.
### mesh_downsampling.npz
Extra file with precomputed downsampling for the SMPL body mesh.
### mano_downsampling.npz
Extra file with precomputed downsampling for the MANO hand mesh.
### basicModel_neutral_lbs_10_207_0_v1.0.0.pkl
SMPL neutral model. Please visit the official website to download the file [http://smplify.is.tue.mpg.de/](http://smplify.is.tue.mpg.de/)
### basicModel_m_lbs_10_207_0_v1.0.0.pkl
SMPL male model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)
### basicModel_f_lbs_10_207_0_v1.0.0.pkl
SMPL female model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)
### MANO_RIGHT.pkl
MANO hand model. Please visit the official website to download the file [https://mano.is.tue.mpg.de/](https://mano.is.tue.mpg.de/)


@@ -0,0 +1,47 @@
"""
This file contains definitions of useful data structures and the paths
for the datasets and data files necessary to run the code.
Adapted from the open-source projects GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
"""
from pathlib import Path
folder_path = Path(__file__).parent.parent
JOINT_REGRESSOR_TRAIN_EXTRA = folder_path / 'data/J_regressor_extra.npy'
JOINT_REGRESSOR_H36M_correct = folder_path / 'data/J_regressor_h36m_correct.npy'
SMPL_FILE = folder_path / 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
SMPL_Male = folder_path / 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl'
SMPL_Female = folder_path / 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl'
SMPL_sampling_matrix = folder_path / 'data/mesh_downsampling.npz'
MANO_FILE = folder_path / 'data/MANO_RIGHT.pkl'
MANO_sampling_matrix = folder_path / 'data/mano_downsampling.npz'
JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27]
"""
We follow the body joint definition, loss functions, and evaluation metrics from
open source project GraphCMR (https://github.com/nkolot/GraphCMR/)
Each dataset uses different sets of joints.
We use a superset of 24 joints such that we include all joints from every dataset.
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
The joints used here are:
"""
J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head',
'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10]
"""
We follow the hand joint definition and mesh topology from
open source project Manopth (https://github.com/hassony2/manopth)
The hand joints used here are:
"""
J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
ROOT_INDEX = 0

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,9 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------
from .default import _C as config
from .default import update_config
from .models import MODEL_EXTRAS


@@ -0,0 +1,138 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from yacs.config import CfgNode as CN
_C = CN()
_C.OUTPUT_DIR = ''
_C.LOG_DIR = ''
_C.DATA_DIR = ''
_C.GPUS = (0,)
_C.WORKERS = 4
_C.PRINT_FREQ = 20
_C.AUTO_RESUME = False
_C.PIN_MEMORY = True
_C.RANK = 0
# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True
# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'cls_hrnet'
_C.MODEL.INIT_WEIGHTS = True
_C.MODEL.PRETRAINED = ''
_C.MODEL.NUM_JOINTS = 17
_C.MODEL.NUM_CLASSES = 1000
_C.MODEL.TAG_PER_JOINT = True
_C.MODEL.TARGET_TYPE = 'gaussian'
_C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
_C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
_C.MODEL.SIGMA = 2
_C.MODEL.EXTRA = CN(new_allowed=True)
_C.LOSS = CN()
_C.LOSS.USE_OHKM = False
_C.LOSS.TOPK = 8
_C.LOSS.USE_TARGET_WEIGHT = True
_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'mpii'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'valid'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.HYBRID_JOINTS_TYPE = ''
_C.DATASET.SELECT_DATA = False
# training data augmentation
_C.DATASET.FLIP = True
_C.DATASET.SCALE_FACTOR = 0.25
_C.DATASET.ROT_FACTOR = 30
_C.DATASET.PROB_HALF_BODY = 0.0
_C.DATASET.NUM_JOINTS_HALF_BODY = 8
_C.DATASET.COLOR_RGB = False
# train
_C.TRAIN = CN()
_C.TRAIN.LR_FACTOR = 0.1
_C.TRAIN.LR_STEP = [90, 110]
_C.TRAIN.LR = 0.001
_C.TRAIN.OPTIMIZER = 'adam'
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.NESTEROV = False
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0
_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 140
_C.TRAIN.RESUME = False
_C.TRAIN.CHECKPOINT = ''
_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True
# testing
_C.TEST = CN()
# size of images for each device
_C.TEST.BATCH_SIZE_PER_GPU = 32
# Test Model Epoch
_C.TEST.FLIP_TEST = False
_C.TEST.POST_PROCESS = False
_C.TEST.SHIFT_HEATMAP = False
_C.TEST.USE_GT_BBOX = False
# nms
_C.TEST.IMAGE_THRE = 0.1
_C.TEST.NMS_THRE = 0.6
_C.TEST.SOFT_NMS = False
_C.TEST.OKS_THRE = 0.5
_C.TEST.IN_VIS_THRE = 0.0
_C.TEST.COCO_BBOX_FILE = ''
_C.TEST.BBOX_THRE = 1.0
_C.TEST.MODEL_FILE = ''
# debug
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False
_C.DEBUG.SAVE_BATCH_IMAGES_GT = False
_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
_C.DEBUG.SAVE_HEATMAPS_GT = False
_C.DEBUG.SAVE_HEATMAPS_PRED = False
def update_config(cfg, config_file):
cfg.defrost()
cfg.merge_from_file(config_file)
cfg.freeze()
if __name__ == '__main__':
import sys
with open(sys.argv[1], 'w') as f:
print(_C, file=f)
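# Illustrative usage sketch (not part of the original file); the YAML filename is hypothetical,
# with `config` and `update_config` imported as in this package's __init__.py:
#
#   update_config(config, 'experiments/cls_hrnet_w64.yaml')
#   print(config.MODEL.NAME, config.MODEL.IMAGE_SIZE)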


@@ -0,0 +1,47 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Create by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from yacs.config import CfgNode as CN
# high_resoluton_net related params for classification
POSE_HIGH_RESOLUTION_NET = CN()
POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True
POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
MODEL_EXTRAS = {
'cls_hrnet': POSE_HIGH_RESOLUTION_NET,
}


@@ -0,0 +1,523 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# Modified by Kevin Lin (keli@microsoft.com)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import logging
import functools
import numpy as np
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
import code
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion,
momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class HighResolutionModule(nn.Module):
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
num_channels, fuse_method, multi_scale_output=True):
super(HighResolutionModule, self).__init__()
self._check_branches(
num_branches, blocks, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
self.num_branches = num_branches
self.multi_scale_output = multi_scale_output
self.branches = self._make_branches(
num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(False)
def _check_branches(self, num_branches, blocks, num_blocks,
num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(num_inchannels))
logger.error(error_msg)
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index], stride, downsample))
self.num_inchannels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index]))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
num_inchannels = self.num_inchannels
fuse_layers = []
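# Build one fuse layer per output branch: branch j is upsampled to branch i's
# resolution with a 1x1 conv + BN + nearest-neighbor upsample when j > i, passed
# through unchanged (None) when j == i, and downsampled by a chain of strided
# 3x3 conv + BN blocks (ReLU on all but the last) when j < i.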
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_inchannels[i],
1,
1,
0,
bias=False),
nn.BatchNorm2d(num_inchannels[i],
momentum=BN_MOMENTUM),
nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
for k in range(i-j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False),
nn.BatchNorm2d(num_outchannels_conv3x3,
momentum=BN_MOMENTUM)))
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False),
nn.BatchNorm2d(num_outchannels_conv3x3,
momentum=BN_MOMENTUM),
nn.ReLU(False)))
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
for j in range(1, self.num_branches):
if i == j:
y = y + x[j]
else:
y = y + self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
blocks_dict = {
'BASIC': BasicBlock,
'BOTTLENECK': Bottleneck
}
class HighResolutionNet(nn.Module):
def __init__(self, cfg, **kwargs):
super(HighResolutionNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
block = blocks_dict[self.stage1_cfg['BLOCK']]
num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
stage1_out_channel = block.expansion*num_channels
self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
num_channels = self.stage2_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage2_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition1 = self._make_transition_layer(
[stage1_out_channel], num_channels)
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
num_channels = self.stage3_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage3_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition2 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
num_channels = self.stage4_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage4_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition3 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels, multi_scale_output=True)
# Classification Head
self.incre_modules, self.downsamp_modules, \
self.final_layer = self._make_head(pre_stage_channels)
self.classifier = nn.Linear(2048, 1000)
def _make_head(self, pre_stage_channels):
head_block = Bottleneck
head_channels = [32, 64, 128, 256]
# Increasing the #channels on each resolution
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
incre_modules = []
for i, channels in enumerate(pre_stage_channels):
incre_module = self._make_layer(head_block,
channels,
head_channels[i],
1,
stride=1)
incre_modules.append(incre_module)
incre_modules = nn.ModuleList(incre_modules)
# downsampling modules
downsamp_modules = []
for i in range(len(pre_stage_channels)-1):
in_channels = head_channels[i] * head_block.expansion
out_channels = head_channels[i+1] * head_block.expansion
downsamp_module = nn.Sequential(
nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1),
nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
downsamp_modules.append(downsamp_module)
downsamp_modules = nn.ModuleList(downsamp_modules)
final_layer = nn.Sequential(
nn.Conv2d(
in_channels=head_channels[3] * head_block.expansion,
out_channels=2048,
kernel_size=1,
stride=1,
padding=0
),
nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
return incre_modules, downsamp_modules, final_layer
def _make_transition_layer(
self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(nn.Sequential(
nn.Conv2d(num_channels_pre_layer[i],
num_channels_cur_layer[i],
3,
1,
1,
bias=False),
nn.BatchNorm2d(
num_channels_cur_layer[i], momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)))
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i+1-num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = num_channels_cur_layer[i] \
if j == i-num_branches_pre else inchannels
conv3x3s.append(nn.Sequential(
nn.Conv2d(
inchannels, outchannels, 3, 2, 1, bias=False),
nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)))
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(inplanes, planes, stride, downsample))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(inplanes, planes))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, num_inchannels,
multi_scale_output=True):
num_modules = layer_config['NUM_MODULES']
num_branches = layer_config['NUM_BRANCHES']
num_blocks = layer_config['NUM_BLOCKS']
num_channels = layer_config['NUM_CHANNELS']
block = blocks_dict[layer_config['BLOCK']]
fuse_method = layer_config['FUSE_METHOD']
modules = []
for i in range(num_modules):
# multi_scale_output is only applied to the last module
if not multi_scale_output and i == num_modules - 1:
reset_multi_scale_output = False
else:
reset_multi_scale_output = True
modules.append(
HighResolutionModule(num_branches,
block,
num_blocks,
num_inchannels,
num_channels,
fuse_method,
reset_multi_scale_output)
)
num_inchannels = modules[-1].get_num_inchannels()
return nn.Sequential(*modules), num_inchannels
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['NUM_BRANCHES']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['NUM_BRANCHES']):
if self.transition2[i] is not None:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['NUM_BRANCHES']):
if self.transition3[i] is not None:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage4(x_list)
# Classification Head
y = self.incre_modules[0](y_list[0])
for i in range(len(self.downsamp_modules)):
y = self.incre_modules[i+1](y_list[i+1]) + \
self.downsamp_modules[i](y)
y = self.final_layer(y)
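# during tracing/export, express global average pooling as flatten + mean;
# F.avg_pool2d with a kernel size taken from y.size() is data-dependent and
# does not trace cleanly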
if torch._C._get_tracing_state():
y = y.flatten(start_dim=2).mean(dim=2)
else:
y = F.avg_pool2d(y, kernel_size=y.size()
[2:]).view(y.size(0), -1)
# y = self.classifier(y)
return y
def init_weights(self, pretrained='',):
logger.info('=> init weights from normal distribution')
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained)
logger.info('=> loading pretrained model {}'.format(pretrained))
print('=> loading pretrained model {}'.format(pretrained))
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()}
# for k, _ in pretrained_dict.items():
# logger.info(
# '=> loading {} pretrained model {}'.format(k, pretrained))
# print('=> loading {} pretrained model {}'.format(k, pretrained))
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)
# code.interact(local=locals())
def get_cls_net(config, pretrained, **kwargs):
model = HighResolutionNet(config, **kwargs)
model.init_weights(pretrained=pretrained)
return model

View File

@@ -0,0 +1,524 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# Modified by Kevin Lin (keli@microsoft.com)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import logging
import functools
import numpy as np
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
import code
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion,
momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class HighResolutionModule(nn.Module):
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
num_channels, fuse_method, multi_scale_output=True):
super(HighResolutionModule, self).__init__()
self._check_branches(
num_branches, blocks, num_blocks, num_inchannels, num_channels)
self.num_inchannels = num_inchannels
self.fuse_method = fuse_method
self.num_branches = num_branches
self.multi_scale_output = multi_scale_output
self.branches = self._make_branches(
num_branches, blocks, num_blocks, num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(False)
def _check_branches(self, num_branches, blocks, num_blocks,
num_inchannels, num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
logger.error(error_msg)
raise ValueError(error_msg)
if num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(num_inchannels))
logger.error(error_msg)
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.num_inchannels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index], stride, downsample))
self.num_inchannels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(block(self.num_inchannels[branch_index],
num_channels[branch_index]))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
num_inchannels = self.num_inchannels
fuse_layers = []
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_inchannels[i],
1,
1,
0,
bias=False),
nn.BatchNorm2d(num_inchannels[i],
momentum=BN_MOMENTUM),
nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
elif j == i:
fuse_layer.append(None)
else:
conv3x3s = []
for k in range(i-j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False),
nn.BatchNorm2d(num_outchannels_conv3x3,
momentum=BN_MOMENTUM)))
else:
num_outchannels_conv3x3 = num_inchannels[j]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j],
num_outchannels_conv3x3,
3, 2, 1, bias=False),
nn.BatchNorm2d(num_outchannels_conv3x3,
momentum=BN_MOMENTUM),
nn.ReLU(False)))
fuse_layer.append(nn.Sequential(*conv3x3s))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
for j in range(1, self.num_branches):
if i == j:
y = y + x[j]
else:
y = y + self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
blocks_dict = {
'BASIC': BasicBlock,
'BOTTLENECK': Bottleneck
}
class HighResolutionNet(nn.Module):
def __init__(self, cfg, **kwargs):
super(HighResolutionNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
block = blocks_dict[self.stage1_cfg['BLOCK']]
num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
stage1_out_channel = block.expansion*num_channels
self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
num_channels = self.stage2_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage2_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition1 = self._make_transition_layer(
[stage1_out_channel], num_channels)
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
num_channels = self.stage3_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage3_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition2 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
num_channels = self.stage4_cfg['NUM_CHANNELS']
block = blocks_dict[self.stage4_cfg['BLOCK']]
num_channels = [
num_channels[i] * block.expansion for i in range(len(num_channels))]
self.transition3 = self._make_transition_layer(
pre_stage_channels, num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels, multi_scale_output=True)
# Classification Head
self.incre_modules, self.downsamp_modules, \
self.final_layer = self._make_head(pre_stage_channels)
self.classifier = nn.Linear(2048, 1000)
def _make_head(self, pre_stage_channels):
head_block = Bottleneck
head_channels = [32, 64, 128, 256]
# Increasing the #channels on each resolution
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
incre_modules = []
for i, channels in enumerate(pre_stage_channels):
incre_module = self._make_layer(head_block,
channels,
head_channels[i],
1,
stride=1)
incre_modules.append(incre_module)
incre_modules = nn.ModuleList(incre_modules)
# downsampling modules
downsamp_modules = []
for i in range(len(pre_stage_channels)-1):
in_channels = head_channels[i] * head_block.expansion
out_channels = head_channels[i+1] * head_block.expansion
downsamp_module = nn.Sequential(
nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1),
nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
downsamp_modules.append(downsamp_module)
downsamp_modules = nn.ModuleList(downsamp_modules)
final_layer = nn.Sequential(
nn.Conv2d(
in_channels=head_channels[3] * head_block.expansion,
out_channels=2048,
kernel_size=1,
stride=1,
padding=0
),
nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
return incre_modules, downsamp_modules, final_layer
def _make_transition_layer(
self, num_channels_pre_layer, num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(nn.Sequential(
nn.Conv2d(num_channels_pre_layer[i],
num_channels_cur_layer[i],
3,
1,
1,
bias=False),
nn.BatchNorm2d(
num_channels_cur_layer[i], momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)))
else:
transition_layers.append(None)
else:
conv3x3s = []
for j in range(i+1-num_branches_pre):
inchannels = num_channels_pre_layer[-1]
outchannels = num_channels_cur_layer[i] \
if j == i-num_branches_pre else inchannels
conv3x3s.append(nn.Sequential(
nn.Conv2d(
inchannels, outchannels, 3, 2, 1, bias=False),
nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)))
transition_layers.append(nn.Sequential(*conv3x3s))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
)
layers = []
layers.append(block(inplanes, planes, stride, downsample))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(inplanes, planes))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, num_inchannels,
multi_scale_output=True):
num_modules = layer_config['NUM_MODULES']
num_branches = layer_config['NUM_BRANCHES']
num_blocks = layer_config['NUM_BLOCKS']
num_channels = layer_config['NUM_CHANNELS']
block = blocks_dict[layer_config['BLOCK']]
fuse_method = layer_config['FUSE_METHOD']
modules = []
for i in range(num_modules):
# multi_scale_output is only applied to the last module
if not multi_scale_output and i == num_modules - 1:
reset_multi_scale_output = False
else:
reset_multi_scale_output = True
modules.append(
HighResolutionModule(num_branches,
block,
num_blocks,
num_inchannels,
num_channels,
fuse_method,
reset_multi_scale_output)
)
num_inchannels = modules[-1].get_num_inchannels()
return nn.Sequential(*modules), num_inchannels
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['NUM_BRANCHES']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['NUM_BRANCHES']):
if self.transition2[i] is not None:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['NUM_BRANCHES']):
if self.transition3[i] is not None:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage4(x_list)
# Classification Head
y = self.incre_modules[0](y_list[0])
for i in range(len(self.downsamp_modules)):
y = self.incre_modules[i+1](y_list[i+1]) + \
self.downsamp_modules[i](y)
yy = self.final_layer(y)
if torch._C._get_tracing_state():
yy = yy.flatten(start_dim=2).mean(dim=2)
else:
yy = F.avg_pool2d(yy, kernel_size=yy.size()
[2:]).view(yy.size(0), -1)
# y = self.classifier(y)
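# yy: globally average-pooled 2048-d feature vector; y: the 1024-channel
# low-resolution feature map kept from before the final 1x1 conv (the
# 'gridfeat' output of this variant)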
return yy, y
def init_weights(self, pretrained='',):
logger.info('=> init weights from normal distribution')
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained)
logger.info('=> loading pretrained model {}'.format(pretrained))
print('=> loading pretrained model {}'.format(pretrained))
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()}
# for k, _ in pretrained_dict.items():
# logger.info(
# '=> loading {} pretrained model {}'.format(k, pretrained))
# print('=> loading {} pretrained model {}'.format(k, pretrained))
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)
# code.interact(local=locals())
def get_cls_net_gridfeat(config, pretrained, **kwargs):
model = HighResolutionNet(config, **kwargs)
model.init_weights(pretrained=pretrained)
return model

View File

@@ -0,0 +1,750 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
Training and evaluation codes for
3D human body mesh reconstruction from an image
"""
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
from mesh_graphormer.modeling._smpl import SMPL, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.datasets.build import make_data_loader
from mesh_graphormer.utils.logger import setup_logger
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from mesh_graphormer.utils.geometric_layers import orthographic_projection
device = "cuda"
from azureml.core.run import Run
aml_run = Run.get_context()
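# `aml_run` is used below to stream scalar metrics (aml_run.log) and rendered
# images (aml_run.log_image) to Azure ML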
def save_checkpoint(model, args, epoch, iteration, num_trial=10):
checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
epoch, iteration))
if not is_main_process():
return checkpoint_dir
mkdir(checkpoint_dir)
model_to_save = model.module if hasattr(model, 'module') else model
for i in range(num_trial):
try:
torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
logger.info("Save checkpoint to {}".format(checkpoint_dir))
break
except:
pass
else:
logger.info("Failed to save checkpoint after {} trails.".format(num_trial))
return checkpoint_dir
def save_scores(args, split, mpjpe, pampjpe, mpve):
eval_log = []
res = {}
res['mPJPE'] = mpjpe
res['PAmPJPE'] = pampjpe
res['mPVE'] = mpve
eval_log.append(res)
with open(op.join(args.output_dir, split+'_eval_logs.json'), 'w') as f:
json.dump(eval_log, f)
logger.info("Save eval scores to {}".format(args.output_dir))
return
def adjust_learning_rate(optimizer, epoch, args):
"""
Decay the learning rate by a factor of 10 every num_train_epochs/2 epochs
(100 epochs with the default of 200), starting from the initial args.lr.
"""
lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) ))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def mean_per_joint_position_error(pred, gt, has_3d_joints):
"""
Compute mPJPE
"""
gt = gt[has_3d_joints == 1]
gt = gt[:, :, :-1]
pred = pred[has_3d_joints == 1]
with torch.no_grad():
gt_pelvis = (gt[:, 2,:] + gt[:, 3,:]) / 2
gt = gt - gt_pelvis[:, None, :]
pred_pelvis = (pred[:, 2,:] + pred[:, 3,:]) / 2
pred = pred - pred_pelvis[:, None, :]
error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
return error
def mean_per_vertex_error(pred, gt, has_smpl):
"""
Compute mPVE
"""
pred = pred[has_smpl == 1]
gt = gt[has_smpl == 1]
with torch.no_grad():
error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
return error
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
"""
Compute 2D reprojection loss if 2D keypoint annotations are available.
The confidence (conf) is binary and indicates whether the keypoints exist or not.
"""
conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
return loss
def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, device):
"""
Compute 3D keypoint loss if 3D keypoint annotations are available.
"""
conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
conf = conf[has_pose_3d == 1]
pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
if len(gt_keypoints_3d) > 0:
gt_pelvis = (gt_keypoints_3d[:, 2,:] + gt_keypoints_3d[:, 3,:]) / 2
gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
pred_pelvis = (pred_keypoints_3d[:, 2,:] + pred_keypoints_3d[:, 3,:]) / 2
pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
else:
return torch.FloatTensor(1).fill_(0.).to(device)
def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, device):
"""
Compute per-vertex loss if vertex annotations are available.
"""
pred_vertices_with_shape = pred_vertices[has_smpl == 1]
gt_vertices_with_shape = gt_vertices[has_smpl == 1]
if len(gt_vertices_with_shape) > 0:
return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
else:
return torch.FloatTensor(1).fill_(0.).to(device)
def rectify_pose(pose):
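# compose the global (root) orientation pose[:3], an axis-angle vector, with a
# 180-degree rotation about the x-axis; all other joint angles are unchanged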
pose = pose.copy()
R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0]
R_root = cv2.Rodrigues(pose[:3])[0]
new_root = R_root.dot(R_mod)
pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
return pose
def run(args, train_dataloader, val_dataloader, Graphormer_model, smpl, mesh_sampler, renderer):
smpl.eval()
max_iter = len(train_dataloader)
iters_per_epoch = max_iter // args.num_train_epochs
if iters_per_epoch<1000:
args.logging_steps = 500
optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
lr=args.lr,
betas=(0.9, 0.999),
weight_decay=0)
# define loss function (criterion) and optimizer
criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
criterion_vertices = torch.nn.L1Loss().to(device)
if args.distributed:
Graphormer_model = torch.nn.parallel.DistributedDataParallel(
Graphormer_model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True,
)
logger.info(
' '.join(
['Local rank: {o}', 'Max iteration: {a}', 'iters_per_epoch: {b}','num_train_epochs: {c}',]
).format(o=args.local_rank, a=max_iter, b=iters_per_epoch, c=args.num_train_epochs)
)
start_training_time = time.time()
end = time.time()
Graphormer_model.train()
batch_time = AverageMeter()
data_time = AverageMeter()
log_losses = AverageMeter()
log_loss_2djoints = AverageMeter()
log_loss_3djoints = AverageMeter()
log_loss_vertices = AverageMeter()
log_eval_metrics = EvalMetricsLogger()
for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):
# gc.collect()
# torch.cuda.empty_cache()
Graphormer_model.train()
iteration += 1
epoch = iteration // iters_per_epoch
batch_size = images.size(0)
adjust_learning_rate(optimizer, epoch, args)
data_time.update(time.time() - end)
images = images.to(device)
gt_2d_joints = annotations['joints_2d'].to(device)
gt_2d_joints = gt_2d_joints[:,cfg.J24_TO_J14,:]
has_2d_joints = annotations['has_2d_joints'].to(device)
gt_3d_joints = annotations['joints_3d'].to(device)
gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3]
gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:]
gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :]
has_3d_joints = annotations['has_3d_joints'].to(device)
gt_pose = annotations['pose'].to(device)
gt_betas = annotations['betas'].to(device)
has_smpl = annotations['has_smpl'].to(device)
mjm_mask = annotations['mjm_mask'].to(device)
mvm_mask = annotations['mvm_mask'].to(device)
# generate simplified mesh
gt_vertices = smpl(gt_pose, gt_betas)
gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices, n1=0, n2=2)
gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
# normalize gt based on smpl's pelvis
gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
# prepare masks for mask vertex/joint modeling
mjm_mask_ = mjm_mask.expand(-1,-1,2051)
mvm_mask_ = mvm_mask.expand(-1,-1,2051)
meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)
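# the joint/vertex masks are broadcast across the full 2051-d per-token feature
# dimension (matching the default --input_feat_dim of the first encoder block)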
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True)
# normalize gt based on smpl's pelvis
gt_vertices_sub = gt_vertices_sub - gt_smpl_3d_pelvis[:, None, :]
gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :]
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
# obtain 2d joints, which are projected from 3d joints of smpl mesh
pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera)
# compute 3d joint loss (where the joints are directly output from transformer)
loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device)
# compute 3d vertex loss
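# (sub2 / sub / full terms are weighted by --vloss_w_sub2 / --vloss_w_sub /
# --vloss_w_full, each 0.33 by default)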
loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \
args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \
args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) )
# compute 3d joint loss (where the joints are regressed from full mesh)
loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device)
# compute 2d joint loss
loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, has_2d_joints)
loss_3d_joints = loss_3d_joints + loss_reg_3d_joints
# we empirically use hyperparameters to balance the different losses
loss = args.joints_loss_weight*loss_3d_joints + \
args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints
# update logs
log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
log_loss_vertices.update(loss_vertices.item(), batch_size)
log_losses.update(loss.item(), batch_size)
# back prop
optimizer.zero_grad()
loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if iteration % args.logging_steps == 0 or iteration == max_iter:
eta_seconds = batch_time.avg * (max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
logger.info(
' '.join(
['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
).format(eta=eta_string, ep=epoch, iter=iteration,
memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+ ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
optimizer.param_groups[0]['lr'])
)
aml_run.log(name='Loss', value=float(log_losses.avg))
aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))
visual_imgs = visualize_mesh( renderer,
annotations['ori_img'].detach(),
annotations['joints_2d'].detach(),
pred_vertices.detach(),
pred_camera.detach(),
pred_2d_joints_from_smpl.detach())
visual_imgs = visual_imgs.transpose(0,1)
visual_imgs = visual_imgs.transpose(1,2)
visual_imgs = np.asarray(visual_imgs)
if is_main_process()==True:
stamp = str(epoch) + '_' + str(iteration)
temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
aml_run.log_image(name='visual results', path=temp_fname)
if iteration % iters_per_epoch == 0:
val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
Graphormer_model,
criterion_keypoints,
criterion_vertices,
epoch,
smpl,
mesh_sampler)
aml_run.log(name='mPVE', value=float(1000*val_mPVE))
aml_run.log(name='mPJPE', value=float(1000*val_mPJPE))
aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE))
logger.info(
' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
+ ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE, val_count)
)
if val_PAmPJPE<log_eval_metrics.PAmPJPE:
checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
log_eval_metrics.update(val_mPVE, val_mPJPE, val_PAmPJPE, epoch)
total_training_time = time.time() - start_training_time
total_time_str = str(datetime.timedelta(seconds=total_training_time))
logger.info('Total training time: {} ({:.4f} s / iter)'.format(
total_time_str, total_training_time / max_iter)
)
checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
logger.info(
' Best Results:'
+ ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, at epoch {:6.2f}'.format(1000*log_eval_metrics.mPVE, 1000*log_eval_metrics.mPJPE, 1000*log_eval_metrics.PAmPJPE, log_eval_metrics.epoch)
)
def run_eval_general(args, val_dataloader, Graphormer_model, smpl, mesh_sampler):
smpl.eval()
criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
criterion_vertices = torch.nn.L1Loss().to(device)
epoch = 0
if args.distributed:
Graphormer_model = torch.nn.parallel.DistributedDataParallel(
Graphormer_model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True,
)
Graphormer_model.eval()
val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
Graphormer_model,
criterion_keypoints,
criterion_vertices,
epoch,
smpl,
mesh_sampler)
aml_run.log(name='mPVE', value=float(1000*val_mPVE))
aml_run.log(name='mPJPE', value=float(1000*val_mPJPE))
aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE))
logger.info(
' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
+ ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f} '.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE)
)
# checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0)
return
def run_validate(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, smpl, mesh_sampler):
batch_time = AverageMeter()
mPVE = AverageMeter()
mPJPE = AverageMeter()
PAmPJPE = AverageMeter()
# switch to evaluate mode
Graphormer_model.eval()
smpl.eval()
with torch.no_grad():
# end = time.time()
for i, (img_keys, images, annotations) in enumerate(val_loader):
batch_size = images.size(0)
# compute output
images = images.to(device)
gt_3d_joints = annotations['joints_3d'].to(device)
gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3]
gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:]
gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :]
has_3d_joints = annotations['has_3d_joints'].to(device)
gt_pose = annotations['pose'].to(device)
gt_betas = annotations['betas'].to(device)
has_smpl = annotations['has_smpl'].to(device)
# generate simplified mesh
gt_vertices = smpl(gt_pose, gt_betas)
gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices_sub, n1=1, n2=2)
# normalize gt based on smpl pelvis
gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices)
gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :]
gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :]
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler)
# obtain 3d joints from full mesh
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
# measure errors
error_vertices = mean_per_vertex_error(pred_vertices, gt_vertices, has_smpl)
error_joints = mean_per_joint_position_error(pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints)
error_joints_pa = reconstruction_error(pred_3d_joints_from_smpl.cpu().numpy(), gt_3d_joints[:,:,:3].cpu().numpy(), reduction=None)
if len(error_vertices)>0:
mPVE.update(np.mean(error_vertices), int(torch.sum(has_smpl)) )
if len(error_joints)>0:
mPJPE.update(np.mean(error_joints), int(torch.sum(has_3d_joints)) )
if len(error_joints_pa)>0:
PAmPJPE.update(np.mean(error_joints_pa), int(torch.sum(has_3d_joints)) )
val_mPVE = all_gather(float(mPVE.avg))
val_mPVE = sum(val_mPVE)/len(val_mPVE)
val_mPJPE = all_gather(float(mPJPE.avg))
val_mPJPE = sum(val_mPJPE)/len(val_mPJPE)
val_PAmPJPE = all_gather(float(PAmPJPE.avg))
val_PAmPJPE = sum(val_PAmPJPE)/len(val_PAmPJPE)
val_count = all_gather(float(mPVE.count))
val_count = sum(val_count)
return val_mPVE, val_mPJPE, val_PAmPJPE, val_count
def visualize_mesh( renderer,
images,
gt_keypoints_2d,
pred_vertices,
pred_camera,
pred_keypoints_2d):
"""Tensorboard logging."""
gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
to_lsp = list(range(14))
rend_imgs = []
batch_size = pred_vertices.shape[0]
# Do visualization for up to the first 10 images of the batch
for i in range(min(batch_size, 10)):
img = images[i].cpu().numpy().transpose(1,2,0)
# Get LSP keypoints from the full list of keypoints
gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
# Get predict vertices for the particular example
vertices = pred_vertices[i].cpu().numpy()
cam = pred_camera[i].cpu().numpy()
# Visualize reconstruction and detected pose
rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
rend_img = rend_img.transpose(2,0,1)
rend_imgs.append(torch.from_numpy(rend_img))
rend_imgs = make_grid(rend_imgs, nrow=1)
return rend_imgs
def visualize_mesh_test( renderer,
images,
gt_keypoints_2d,
pred_vertices,
pred_camera,
pred_keypoints_2d,
PAmPJPE_h36m_j14):
"""Tensorboard logging."""
gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
to_lsp = list(range(14))
rend_imgs = []
batch_size = pred_vertices.shape[0]
# Do visualization for up to the first 10 images of the batch
for i in range(min(batch_size, 10)):
img = images[i].cpu().numpy().transpose(1,2,0)
# Get LSP keypoints from the full list of keypoints
gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
# Get predict vertices for the particular example
vertices = pred_vertices[i].cpu().numpy()
cam = pred_camera[i].cpu().numpy()
score = PAmPJPE_h36m_j14[i]
# Visualize reconstruction and detected pose
rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
rend_img = rend_img.transpose(2,0,1)
rend_imgs.append(torch.from_numpy(rend_img))
rend_imgs = make_grid(rend_imgs, nrow=1)
return rend_imgs
def parse_args():
parser = argparse.ArgumentParser()
#########################################################
# Data related arguments
#########################################################
parser.add_argument("--data_dir", default='datasets', type=str, required=False,
help="Directory with all datasets, each in one subfolder")
parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
help="Yaml file with all data for training.")
parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
help="Yaml file with all data for validation.")
parser.add_argument("--num_workers", default=4, type=int,
help="Workers in dataloader.")
parser.add_argument("--img_scale_factor", default=1, type=int,
help="adjust image resolution.")
#########################################################
# Loading/saving checkpoints
#########################################################
parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
help="Path to pre-trained transformer model or model type.")
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
help="Path to specific checkpoint for resume training.")
parser.add_argument("--output_dir", default='output/', type=str, required=False,
help="The output directory to save checkpoint and test results.")
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name.")
#########################################################
# Training parameters
#########################################################
parser.add_argument("--per_gpu_train_batch_size", default=30, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=30, type=int,
help="Batch size per GPU/CPU for evaluation.")
parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
help="The initial lr.")
parser.add_argument("--num_train_epochs", default=200, type=int,
help="Total number of training epochs to perform.")
parser.add_argument("--vertices_loss_weight", default=100.0, type=float)
parser.add_argument("--joints_loss_weight", default=1000.0, type=float)
parser.add_argument("--vloss_w_full", default=0.33, type=float)
parser.add_argument("--vloss_w_sub", default=0.33, type=float)
parser.add_argument("--vloss_w_sub2", default=0.33, type=float)
parser.add_argument("--drop_out", default=0.1, type=float,
help="Drop out ratio in BERT.")
#########################################################
# Model architectures
#########################################################
parser.add_argument('-a', '--arch', default='hrnet-w64',
help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
help="Update model config if given")
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
help="Update model config if given")
parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
help="Update model config if given. Note that the division of "
"hidden_size / num_attention_heads should be in integer.")
parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
help="Update model config if given.")
parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--which_gcn", default='0,0,1', type=str,
help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
parser.add_argument("--interm_size_scale", default=2, type=int)
#########################################################
# Others
#########################################################
parser.add_argument("--run_eval_only", default=False, action='store_true',)
parser.add_argument('--logging_steps', type=int, default=1000,
help="Log every X steps.")
parser.add_argument("--device", type=str, default='cuda',
help="cuda or cpu")
parser.add_argument('--seed', type=int, default=88,
help="random seed for initialization.")
parser.add_argument("--local_rank", type=int, default=0,
help="For distributed training.")
args = parser.parse_args()
return args
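# Illustrative evaluation command (editor's sketch: the script filename and the
# yaml path are assumptions, not shown in this diff; all flags are defined above):
#   python run_body_mesh.py --run_eval_only --arch hrnet-w64 \
#       --resume_checkpoint output/checkpoint-200-100000/state_dict.bin \
#       --val_yaml human3.6m/valid.protocol2.yaml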
def main(args):
global logger
# Setup CUDA, GPU & distributed training
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
args.distributed = args.num_gpus > 1
args.device = torch.device(args.device)
if args.distributed:
print("Init distributed training on local rank {} ({}), rank {}, world size {}".format(args.local_rank, int(os.environ["LOCAL_RANK"]), int(os.environ["NODE_RANK"]), args.num_gpus))
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(
backend='nccl', init_method='env://'
)
local_rank = int(os.environ["LOCAL_RANK"])
args.device = torch.device("cuda", local_rank)
synchronize()
mkdir(args.output_dir)
logger = setup_logger("Graphormer", args.output_dir, get_rank())
set_seed(args.seed, args.num_gpus)
logger.info("Using {} GPUs".format(args.num_gpus))
# Mesh and SMPL utils
smpl = SMPL().to(args.device)
mesh_sampler = Mesh()
# Renderer for visualization
renderer = Renderer(faces=smpl.faces.cpu().numpy())
# Load model
trans_encoder = []
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
output_feat_dim = input_feat_dim[1:] + [3]
# which encoder block to have graph convs
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
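# with the defaults, input_feat_dim = [2051, 512, 128] and output_feat_dim =
# [512, 128, 3]: three cascaded encoder blocks progressively shrink the
# per-token features, and the last block outputs 3-d coordinates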
if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
# if only run eval, load checkpoint
logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
_model = torch.load(args.resume_checkpoint)
else:
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config.output_attentions = False
config.hidden_dropout_prob = args.drop_out
config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i]
args.hidden_size = hidden_feat_dim[i]
args.intermediate_size = int(args.hidden_size*args.interm_size_scale)
if which_blk_graph[i]==1:
config.graph_conv = True
logger.info("Add Graph Conv")
else:
config.graph_conv = False
config.mesh_type = args.mesh_type
# update model structure if specified in arguments
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
for idx, param in enumerate(update_params):
arg_param = getattr(args, param)
config_param = getattr(config, param)
if arg_param > 0 and arg_param != config_param:
logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
setattr(config, param, arg_param)
# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
model = model_class(config=config)
logger.info("Init model from scratch.")
trans_encoder.append(model)
# init ImageNet pre-trained backbone model
if args.arch=='hrnet':
hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w40 model')
elif args.arch=='hrnet-w64':
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w64 model')
else:
print("=> using pre-trained model '{}'".format(args.arch))
backbone = models.__dict__[args.arch](pretrained=True)
# remove the last fc layer
backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
trans_encoder = torch.nn.Sequential(*trans_encoder)
total_params = sum(p.numel() for p in trans_encoder.parameters())
logger.info('Graphormer encoders total parameters: {}'.format(total_params))
backbone_total_params = sum(p.numel() for p in backbone.parameters())
logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
# for fine-tuning or resume training or inference, load weights from checkpoint
logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
# workaround approach to load sparse tensor in graph conv.
states = torch.load(args.resume_checkpoint)
# states = checkpoint_loaded.state_dict()
for k, v in states.items():
states[k] = v.cpu()
# del checkpoint_loaded
_model.load_state_dict(states, strict=False)
del states
gc.collect()
torch.cuda.empty_cache()
_model.to(args.device)
logger.info("Training parameters %s", args)
if args.run_eval_only==True:
val_dataloader = make_data_loader(args, args.val_yaml,
args.distributed, is_train=False, scale_factor=args.img_scale_factor)
run_eval_general(args, val_dataloader, _model, smpl, mesh_sampler)
else:
train_dataloader = make_data_loader(args, args.train_yaml,
args.distributed, is_train=True, scale_factor=args.img_scale_factor)
val_dataloader = make_data_loader(args, args.val_yaml,
args.distributed, is_train=False, scale_factor=args.img_scale_factor)
run(args, train_dataloader, val_dataloader, _model, smpl, mesh_sampler, renderer)
if __name__ == "__main__":
args = parse_args()
main(args)


@@ -0,0 +1,351 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
End-to-end inference code for
3D human body mesh reconstruction from an image
"""
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
from mesh_graphormer.modeling._smpl import SMPL, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.datasets.build import make_data_loader
from mesh_graphormer.utils.logger import setup_logger
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from mesh_graphormer.utils.geometric_layers import orthographic_projection
from PIL import Image
from torchvision import transforms
device = "cuda"
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
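# transform applies the 224x224 crop and ImageNet mean/std normalization expected by the pretrained backbone;
# transform_visualize keeps unnormalized pixels for rendering the prediction overlay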
transform_visualize = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor()])
def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler):
# switch to evaluate mode
Graphormer_model.eval()
smpl.eval()
with torch.no_grad():
for image_file in image_list:
if 'pred' not in image_file:
att_all = []
img = Image.open(image_file)
img_tensor = transform(img)
img_visual = transform_visualize(img)
batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler)
# obtain 3d joints from full mesh
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
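# re-center the mesh and regressed joints on the pelvis so the outputs are root-relative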
# save attention
att_max_value = att[-1]
att_cpu = np.asarray(att_max_value.cpu().detach())
att_all.append(att_cpu)
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
# obtain 2d joints, which are projected from 3d joints of smpl mesh
pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
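# orthographic_projection presumably applies a weak-perspective camera with pred_camera = (s, tx, ty),
# i.e. roughly X_2d = s * (X[:, :, :2] + [tx, ty]); this is a sketch of the assumed behaviour,
# not necessarily the exact geometric_layers implementation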
pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera)
visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0],
pred_vertices[0].detach(),
pred_camera.detach())
# visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
# pred_vertices[0].detach(),
# pred_vertices_sub2[0].detach(),
# pred_2d_431_vertices_from_smpl[0].detach(),
# pred_2d_joints_from_smpl[0].detach(),
# pred_camera.detach(),
# att[-1][0].detach())
visual_imgs = visual_imgs_output.transpose(1,2,0)
visual_imgs = np.asarray(visual_imgs)
temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
print('save to ', temp_fname)
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
return
def visualize_mesh( renderer, images,
pred_vertices_full,
pred_camera):
img = images.cpu().numpy().transpose(1,2,0)
# Get predicted vertices for this example
vertices_full = pred_vertices_full.cpu().numpy()
cam = pred_camera.cpu().numpy()
# Visualize only mesh reconstruction
rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
rend_img = rend_img.transpose(2,0,1)
return rend_img
def visualize_mesh_and_attention( renderer, images,
pred_vertices_full,
pred_vertices,
pred_2d_vertices,
pred_2d_joints,
pred_camera,
attention):
img = images.cpu().numpy().transpose(1,2,0)
# Get predicted vertices for this example
vertices_full = pred_vertices_full.cpu().numpy()
vertices = pred_vertices.cpu().numpy()
vertices_2d = pred_2d_vertices.cpu().numpy()
joints_2d = pred_2d_joints.cpu().numpy()
cam = pred_camera.cpu().numpy()
att = attention.cpu().numpy()
# Visualize reconstruction and attention
rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
rend_img = rend_img.transpose(2,0,1)
return rend_img
def parse_args():
parser = argparse.ArgumentParser()
#########################################################
# Data related arguments
#########################################################
parser.add_argument("--num_workers", default=4, type=int,
help="Workers in dataloader.")
parser.add_argument("--img_scale_factor", default=1, type=int,
help="adjust image resolution.")
parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str,
help="test data")
#########################################################
# Loading/saving checkpoints
#########################################################
parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
help="Path to pre-trained transformer model or model type.")
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
help="Path to specific checkpoint for resume training.")
parser.add_argument("--output_dir", default='output/', type=str, required=False,
help="The output directory to save checkpoint and test results.")
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name.")
#########################################################
# Model architectures
#########################################################
parser.add_argument('-a', '--arch', default='hrnet-w64',
help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
help="Update model config if given")
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
help="Update model config if given")
parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
help="Update model config if given. Note that the division of "
"hidden_size / num_attention_heads should be in integer.")
parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
help="Update model config if given.")
parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--which_gcn", default='0,0,1', type=str,
help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
parser.add_argument("--interm_size_scale", default=2, type=int)
#########################################################
# Others
#########################################################
parser.add_argument("--run_eval_only", default=True, action='store_true',)
parser.add_argument("--device", type=str, default='cuda',
help="cuda or cpu")
parser.add_argument('--seed', type=int, default=88,
help="random seed for initialization.")
args = parser.parse_args()
return args
def main(args):
global logger
# Setup CUDA, GPU & distributed training
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
args.distributed = args.num_gpus > 1
args.device = torch.device(args.device)
mkdir(args.output_dir)
logger = setup_logger("Graphormer", args.output_dir, get_rank())
set_seed(args.seed, args.num_gpus)
logger.info("Using {} GPUs".format(args.num_gpus))
# Mesh and SMPL utils
smpl = SMPL().to(args.device)
mesh_sampler = Mesh()
# Renderer for visualization
renderer = Renderer(faces=smpl.faces.cpu().numpy())
# Load model
trans_encoder = []
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
output_feat_dim = input_feat_dim[1:] + [3]
# which encoder block to have graph convs
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
# if only run eval, load checkpoint
logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
_model = torch.load(args.resume_checkpoint)
else:
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config.output_attentions = False
config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i]
args.hidden_size = hidden_feat_dim[i]
args.intermediate_size = int(args.hidden_size*args.interm_size_scale)
if which_blk_graph[i]==1:
config.graph_conv = True
logger.info("Add Graph Conv")
else:
config.graph_conv = False
config.mesh_type = args.mesh_type
# update model structure if specified in arguments
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
for idx, param in enumerate(update_params):
arg_param = getattr(args, param)
config_param = getattr(config, param)
if arg_param > 0 and arg_param != config_param:
logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
setattr(config, param, arg_param)
# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
model = model_class(config=config)
logger.info("Init model from scratch.")
trans_encoder.append(model)
# init ImageNet pre-trained backbone model
if args.arch=='hrnet':
hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w40 model')
elif args.arch=='hrnet-w64':
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w64 model')
else:
print("=> using pre-trained model '{}'".format(args.arch))
backbone = models.__dict__[args.arch](pretrained=True)
# strip the classification head (for a torchvision ResNet, [:-2] drops the avgpool and fc layers) to keep spatial grid features
backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
trans_encoder = torch.nn.Sequential(*trans_encoder)
total_params = sum(p.numel() for p in trans_encoder.parameters())
logger.info('Graphormer encoders total parameters: {}'.format(total_params))
backbone_total_params = sum(p.numel() for p in backbone.parameters())
logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
# for fine-tuning or resume training or inference, load weights from checkpoint
logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
# workaround approach to load sparse tensor in graph conv.
states = torch.load(args.resume_checkpoint)
# states = checkpoint_loaded.state_dict()
for k, v in states.items():
states[k] = v.cpu()
# del checkpoint_loaded
_model.load_state_dict(states, strict=False)
del states
gc.collect()
torch.cuda.empty_cache()
# update configs to enable attention outputs
setattr(_model.trans_encoder[-1].config,'output_attentions', True)
setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
_model.trans_encoder[-1].bert.encoder.output_attentions = True
_model.trans_encoder[-1].bert.encoder.output_hidden_states = True
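# enable attention/hidden-state outputs only on the last (finest) encoder block;
# run_inference reads att[-1] from the forward pass for the (optional) attention visualization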
for iter_layer in range(4):
_model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
for inter_block in range(3):
setattr(_model.trans_encoder[-1].config,'device', args.device)
_model.to(args.device)
logger.info("Run inference")
image_list = []
if not args.image_file_or_path:
raise ValueError("image_file_or_path not specified")
if op.isfile(args.image_file_or_path):
image_list = [args.image_file_or_path]
elif op.isdir(args.image_file_or_path):
# should be a path with images only
for filename in os.listdir(args.image_file_or_path):
if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename:
image_list.append(args.image_file_or_path+'/'+filename)
else:
raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
run_inference(args, image_list, _model, smpl, renderer, mesh_sampler)
if __name__ == "__main__":
args = parse_args()
main(args)


@@ -0,0 +1,713 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
Training and evaluation code for
3D hand mesh reconstruction from an image
"""
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
from mesh_graphormer.modeling._mano import MANO, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.datasets.build import make_hand_data_loader
from mesh_graphormer.utils.logger import setup_logger
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from mesh_graphormer.utils.metric_logger import AverageMeter
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test, visualize_reconstruction_no_text
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from mesh_graphormer.utils.geometric_layers import orthographic_projection
device = "cuda"
from azureml.core.run import Run
aml_run = Run.get_context()
def save_checkpoint(model, args, epoch, iteration, num_trial=10):
checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format(
epoch, iteration))
if not is_main_process():
return checkpoint_dir
mkdir(checkpoint_dir)
model_to_save = model.module if hasattr(model, 'module') else model
for i in range(num_trial):
try:
torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin'))
torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin'))
torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
logger.info("Save checkpoint to {}".format(checkpoint_dir))
break
except:
pass
else:
logger.info("Failed to save checkpoint after {} trails.".format(num_trial))
return checkpoint_dir
def adjust_learning_rate(optimizer, epoch, args):
"""
Decays the learning rate from the initial LR by a factor of 0.1 every
args.num_train_epochs/2.0 epochs (100 epochs with the default settings).
"""
lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) ))
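# example with the defaults (lr=1e-4, num_train_epochs=200): epochs 0-99 use 1e-4, epochs 100-199 use 1e-5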
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d):
"""
Compute 2D reprojection loss if 2D keypoint annotations are available.
The confidence is binary and indicates whether the keypoints exist or not.
"""
conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
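# conf is the per-joint visibility flag stored in the last channel of gt_keypoints_2d;
# joints without a 2D annotation contribute zero loss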
loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
return loss
def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d):
"""
Compute 3D keypoint loss if 3D keypoint annotations are available.
"""
conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
conf = conf[has_pose_3d == 1]
pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
if len(gt_keypoints_3d) > 0:
gt_root = gt_keypoints_3d[:, 0,:]
gt_keypoints_3d = gt_keypoints_3d - gt_root[:, None, :]
pred_root = pred_keypoints_3d[:, 0,:]
pred_keypoints_3d = pred_keypoints_3d - pred_root[:, None, :]
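# align both prediction and ground truth to their root joint (index 0) so the 3D loss is translation-invariant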
return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()
else:
return torch.FloatTensor(1).fill_(0.).to(device)
def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl):
"""
Compute per-vertex loss if vertex annotations are available.
"""
pred_vertices_with_shape = pred_vertices[has_smpl == 1]
gt_vertices_with_shape = gt_vertices[has_smpl == 1]
if len(gt_vertices_with_shape) > 0:
return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape)
else:
return torch.FloatTensor(1).fill_(0.).to(device)
def run(args, train_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):
max_iter = len(train_dataloader)
iters_per_epoch = max_iter // args.num_train_epochs
optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()),
lr=args.lr,
betas=(0.9, 0.999),
weight_decay=0)
# define loss function (criterion) and optimizer
criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device)
criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
criterion_vertices = torch.nn.L1Loss().to(device)
if args.distributed:
Graphormer_model = torch.nn.parallel.DistributedDataParallel(
Graphormer_model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True,
)
start_training_time = time.time()
end = time.time()
Graphormer_model.train()
batch_time = AverageMeter()
data_time = AverageMeter()
log_losses = AverageMeter()
log_loss_2djoints = AverageMeter()
log_loss_3djoints = AverageMeter()
log_loss_vertices = AverageMeter()
for iteration, (img_keys, images, annotations) in enumerate(train_dataloader):
Graphormer_model.train()
iteration += 1
epoch = iteration // iters_per_epoch
batch_size = images.size(0)
adjust_learning_rate(optimizer, epoch, args)
data_time.update(time.time() - end)
images = images.to(device)
gt_2d_joints = annotations['joints_2d'].to(device)
gt_pose = annotations['pose'].to(device)
gt_betas = annotations['betas'].to(device)
has_mesh = annotations['has_smpl'].to(device)
has_3d_joints = has_mesh
has_2d_joints = has_mesh
mjm_mask = annotations['mjm_mask'].to(device)
mvm_mask = annotations['mvm_mask'].to(device)
# generate mesh
gt_vertices, gt_3d_joints = mano_model.layer(gt_pose, gt_betas)
gt_vertices = gt_vertices/1000.0
gt_3d_joints = gt_3d_joints/1000.0
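# the MANO layer appears to output vertices/joints in millimeters; dividing by 1000 converts to meters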
gt_vertices_sub = mesh_sampler.downsample(gt_vertices)
# normalize gt based on hand's wrist
gt_3d_root = gt_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
gt_vertices = gt_vertices - gt_3d_root[:, None, :]
gt_vertices_sub = gt_vertices_sub - gt_3d_root[:, None, :]
gt_3d_joints = gt_3d_joints - gt_3d_root[:, None, :]
gt_3d_joints_with_tag = torch.ones((batch_size,gt_3d_joints.shape[1],4)).to(device)
gt_3d_joints_with_tag[:,:,:3] = gt_3d_joints
# prepare masks for mask vertex/joint modeling
mjm_mask_ = mjm_mask.expand(-1,-1,2051)
mvm_mask_ = mvm_mask.expand(-1,-1,2051)
meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1)
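# the joint and vertex masks are broadcast to the 2051-dim input features and concatenated along the token axis;
# Graphormer_model applies meta_masks for masked joint/vertex modeling when is_train=True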
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler, meta_masks=meta_masks, is_train=True)
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
# obtain 2d joints, which are projected from 3d joints of smpl mesh
pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
pred_2d_joints = orthographic_projection(pred_3d_joints.contiguous(), pred_camera.contiguous())
# compute 3d joint loss (where the joints are directly output from transformer)
loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints_with_tag, has_3d_joints)
# compute 3d vertex loss
loss_vertices = ( args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_mesh) + \
args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_mesh) )
# compute 3d joint loss (where the joints are regressed from full mesh)
loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_mesh, gt_3d_joints_with_tag, has_3d_joints)
# compute 2d joint loss
loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \
keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_mesh, gt_2d_joints, has_2d_joints)
loss_3d_joints = loss_3d_joints + loss_reg_3d_joints
# empirically chosen hyperparameters balance the different loss terms
loss = args.joints_loss_weight*loss_3d_joints + \
args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints
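# with the default weights (joints_loss_weight=1.0, vertices_loss_weight=1.0, vloss_w_sub=vloss_w_full=0.5)
# this reduces to: loss = loss_3d_joints + 0.5*(L1_sub + L1_full) + loss_2d_joints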
# update logs
log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
log_loss_vertices.update(loss_vertices.item(), batch_size)
log_losses.update(loss.item(), batch_size)
# back prop
optimizer.zero_grad()
loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if iteration % args.logging_steps == 0 or iteration == max_iter:
eta_seconds = batch_time.avg * (max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
logger.info(
' '.join(
['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
).format(eta=eta_string, ep=epoch, iter=iteration,
memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
+ ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
optimizer.param_groups[0]['lr'])
)
aml_run.log(name='Loss', value=float(log_losses.avg))
aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg))
aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg))
aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg))
visual_imgs = visualize_mesh( renderer,
annotations['ori_img'].detach(),
annotations['joints_2d'].detach(),
pred_vertices.detach(),
pred_camera.detach(),
pred_2d_joints_from_mesh.detach())
visual_imgs = visual_imgs.transpose(0,1)
visual_imgs = visual_imgs.transpose(1,2)
visual_imgs = np.asarray(visual_imgs)
if is_main_process()==True:
stamp = str(epoch) + '_' + str(iteration)
temp_fname = args.output_dir + 'visual_' + stamp + '.jpg'
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
aml_run.log_image(name='visual results', path=temp_fname)
if iteration % iters_per_epoch == 0:
if epoch%10==0:
checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
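# checkpoints are written every 10 epochs at epoch boundaries; a final save_checkpoint call also runs after training (below)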
total_training_time = time.time() - start_training_time
total_time_str = str(datetime.timedelta(seconds=total_training_time))
logger.info('Total training time: {} ({:.4f} s / iter)'.format(
total_time_str, total_training_time / max_iter)
)
checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration)
def run_eval_and_save(args, split, val_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler):
criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device)
criterion_vertices = torch.nn.L1Loss().to(device)
if args.distributed:
Graphormer_model = torch.nn.parallel.DistributedDataParallel(
Graphormer_model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True,
)
Graphormer_model.eval()
if args.aml_eval==True:
run_aml_inference_hand_mesh(args, val_dataloader,
Graphormer_model,
criterion_keypoints,
criterion_vertices,
0,
mano_model, mesh_sampler,
renderer, split)
else:
run_inference_hand_mesh(args, val_dataloader,
Graphormer_model,
criterion_keypoints,
criterion_vertices,
0,
mano_model, mesh_sampler,
renderer, split)
checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0)
return
def run_aml_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
# switch to evaluate mode
Graphormer_model.eval()
fname_output_save = []
mesh_output_save = []
joint_output_save = []
world_size = get_world_size()
with torch.no_grad():
for i, (img_keys, images, annotations) in enumerate(val_loader):
batch_size = images.size(0)
# compute output
images = images.to(device)
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)
# obtain 3d joints from full mesh
pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
for j in range(batch_size):
fname_output_save.append(img_keys[j])
pred_vertices_list = pred_vertices[j].tolist()
mesh_output_save.append(pred_vertices_list)
pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
joint_output_save.append(pred_3d_joints_from_mesh_list)
if world_size > 1:
torch.distributed.barrier()
print('save results to pred.json')
output_json_file = 'pred.json'
print('save results to ', output_json_file)
with open(output_json_file, 'w') as f:
json.dump([joint_output_save, mesh_output_save], f)
azure_ckpt_name = '200' # args.resume_checkpoint.split('/')[-2].split('-')[1]
inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
output_zip_file = args.output_dir + 'ckpt' + azure_ckpt_name + '-' + inference_setting +'-pred.zip'
resolved_submit_cmd = 'zip ' + output_zip_file + ' ' + output_json_file
print(resolved_submit_cmd)
os.system(resolved_submit_cmd)
resolved_submit_cmd = 'rm %s'%(output_json_file)
print(resolved_submit_cmd)
os.system(resolved_submit_cmd)
if world_size > 1:
torch.distributed.barrier()
return
def run_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split):
# switch to evaluate mode
Graphormer_model.eval()
fname_output_save = []
mesh_output_save = []
joint_output_save = []
with torch.no_grad():
for i, (img_keys, images, annotations) in enumerate(val_loader):
batch_size = images.size(0)
# compute output
images = images.to(device)
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler)
# obtain 3d joints from full mesh
pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:]
pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
for j in range(batch_size):
fname_output_save.append(img_keys[j])
pred_vertices_list = pred_vertices[j].tolist()
mesh_output_save.append(pred_vertices_list)
pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist()
joint_output_save.append(pred_3d_joints_from_mesh_list)
if i%20==0:
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices)
# obtain 2d joints, which are projected from 3d joints of mesh
pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
visual_imgs = visualize_mesh( renderer,
annotations['ori_img'].detach(),
annotations['joints_2d'].detach(),
pred_vertices.detach(),
pred_camera.detach(),
pred_2d_joints_from_mesh.detach())
visual_imgs = visual_imgs.transpose(0,1)
visual_imgs = visual_imgs.transpose(1,2)
visual_imgs = np.asarray(visual_imgs)
inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
temp_fname = args.output_dir + args.resume_checkpoint[0:-9] + 'freihand_results_'+inference_setting+'_batch'+str(i)+'.jpg'
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
print('save results to pred.json')
with open('pred.json', 'w') as f:
json.dump([joint_output_save, mesh_output_save], f)
run_exp_name = args.resume_checkpoint.split('/')[-3]
run_ckpt_name = args.resume_checkpoint.split('/')[-2].split('-')[1]
inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot)))
resolved_submit_cmd = 'zip ' + args.output_dir + run_exp_name + '-ckpt'+ run_ckpt_name + '-' + inference_setting +'-pred.zip ' + 'pred.json'
print(resolved_submit_cmd)
os.system(resolved_submit_cmd)
resolved_submit_cmd = 'rm pred.json'
print(resolved_submit_cmd)
os.system(resolved_submit_cmd)
return
def visualize_mesh( renderer,
images,
gt_keypoints_2d,
pred_vertices,
pred_camera,
pred_keypoints_2d):
"""Tensorboard logging."""
gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
to_lsp = list(range(21))
rend_imgs = []
batch_size = pred_vertices.shape[0]
# Do visualization for the first 10 images of the batch
for i in range(min(batch_size, 10)):
img = images[i].cpu().numpy().transpose(1,2,0)
# Get LSP keypoints from the full list of keypoints
gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
# Get predicted vertices for this example
vertices = pred_vertices[i].cpu().numpy()
cam = pred_camera[i].cpu().numpy()
# Visualize reconstruction and detected pose
rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer)
rend_img = rend_img.transpose(2,0,1)
rend_imgs.append(torch.from_numpy(rend_img))
rend_imgs = make_grid(rend_imgs, nrow=1)
return rend_imgs
def visualize_mesh_test( renderer,
images,
gt_keypoints_2d,
pred_vertices,
pred_camera,
pred_keypoints_2d,
PAmPJPE):
"""Tensorboard logging."""
gt_keypoints_2d = gt_keypoints_2d.cpu().numpy()
to_lsp = list(range(21))
rend_imgs = []
batch_size = pred_vertices.shape[0]
# Do visualization for the first 10 images of the batch
for i in range(min(batch_size, 10)):
img = images[i].cpu().numpy().transpose(1,2,0)
# Get LSP keypoints from the full list of keypoints
gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp]
pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp]
# Get predicted vertices for this example
vertices = pred_vertices[i].cpu().numpy()
cam = pred_camera[i].cpu().numpy()
score = PAmPJPE[i]
# Visualize reconstruction and detected pose
rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score)
rend_img = rend_img.transpose(2,0,1)
rend_imgs.append(torch.from_numpy(rend_img))
rend_imgs = make_grid(rend_imgs, nrow=1)
return rend_imgs
def visualize_mesh_no_text( renderer,
images,
pred_vertices,
pred_camera):
"""Tensorboard logging."""
rend_imgs = []
batch_size = pred_vertices.shape[0]
# Do visualization for the first image of the batch
for i in range(min(batch_size, 1)):
img = images[i].cpu().numpy().transpose(1,2,0)
# Get predicted vertices for this example
vertices = pred_vertices[i].cpu().numpy()
cam = pred_camera[i].cpu().numpy()
# Visualize reconstruction only
rend_img = visualize_reconstruction_no_text(img, 224, vertices, cam, renderer, color='hand')
rend_img = rend_img.transpose(2,0,1)
rend_imgs.append(torch.from_numpy(rend_img))
rend_imgs = make_grid(rend_imgs, nrow=1)
return rend_imgs
def parse_args():
parser = argparse.ArgumentParser()
#########################################################
# Data related arguments
#########################################################
parser.add_argument("--data_dir", default='datasets', type=str, required=False,
help="Directory with all datasets, each in one subfolder")
parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
help="Yaml file with all data for training.")
parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
help="Yaml file with all data for validation.")
parser.add_argument("--num_workers", default=4, type=int,
help="Workers in dataloader.")
parser.add_argument("--img_scale_factor", default=1, type=int,
help="adjust image resolution.")
#########################################################
# Loading/saving checkpoints
#########################################################
parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
help="Path to pre-trained transformer model or model type.")
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
help="Path to specific checkpoint for resume training.")
parser.add_argument("--output_dir", default='output/', type=str, required=False,
help="The output directory to save checkpoint and test results.")
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name.")
parser.add_argument('-a', '--arch', default='hrnet-w64',
help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
#########################################################
# Training parameters
#########################################################
parser.add_argument("--per_gpu_train_batch_size", default=64, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int,
help="Batch size per GPU/CPU for evaluation.")
parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
help="The initial lr.")
parser.add_argument("--num_train_epochs", default=200, type=int,
help="Total number of training epochs to perform.")
parser.add_argument("--vertices_loss_weight", default=1.0, type=float)
parser.add_argument("--joints_loss_weight", default=1.0, type=float)
parser.add_argument("--vloss_w_full", default=0.5, type=float)
parser.add_argument("--vloss_w_sub", default=0.5, type=float)
parser.add_argument("--drop_out", default=0.1, type=float,
help="Drop out ratio in BERT.")
#########################################################
# Model architectures
#########################################################
parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False,
help="Update model config if given")
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
help="Update model config if given")
parser.add_argument("--num_attention_heads", default=-1, type=int, required=False,
help="Update model config if given. Note that the division of "
"hidden_size / num_attention_heads should be in integer.")
parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
help="Update model config if given.")
parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--which_gcn", default='0,0,1', type=str,
help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")
#########################################################
# Others
#########################################################
parser.add_argument("--run_eval_only", default=False, action='store_true',)
parser.add_argument("--multiscale_inference", default=False, action='store_true',)
# if enable "multiscale_inference", dataloader will apply transformations to the test image based on
# the rotation "rot" and scale "sc" parameters below
parser.add_argument("--rot", default=0, type=float)
parser.add_argument("--sc", default=1.0, type=float)
parser.add_argument("--aml_eval", default=False, action='store_true',)
parser.add_argument('--logging_steps', type=int, default=100,
help="Log every X steps.")
parser.add_argument("--device", type=str, default='cuda',
help="cuda or cpu")
parser.add_argument('--seed', type=int, default=88,
help="random seed for initialization.")
parser.add_argument("--local_rank", type=int, default=0,
help="For distributed training.")
args = parser.parse_args()
return args
def main(args):
global logger
# Setup CUDA, GPU & distributed training
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
args.distributed = args.num_gpus > 1
args.device = torch.device(args.device)
if args.distributed:
print("Init distributed training on local rank {}".format(args.local_rank))
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(
backend='nccl', init_method='env://'
)
synchronize()
mkdir(args.output_dir)
logger = setup_logger("Graphormer", args.output_dir, get_rank())
set_seed(args.seed, args.num_gpus)
logger.info("Using {} GPUs".format(args.num_gpus))
# Mesh and SMPL utils
mano_model = MANO().to(args.device)
mano_model.layer = mano_model.layer.to(device)
mesh_sampler = Mesh()
# Renderer for visualization
renderer = Renderer(faces=mano_model.face)
# Load pretrained model
trans_encoder = []
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
output_feat_dim = input_feat_dim[1:] + [3]
# which encoder block to have graph convs
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
# if only run eval, load checkpoint
logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
_model = torch.load(args.resume_checkpoint)
else:
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config.output_attentions = False
config.hidden_dropout_prob = args.drop_out
config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i]
args.hidden_size = hidden_feat_dim[i]
args.intermediate_size = int(args.hidden_size*2)
if which_blk_graph[i]==1:
config.graph_conv = True
logger.info("Add Graph Conv")
else:
config.graph_conv = False
config.mesh_type = args.mesh_type
# update model structure if specified in arguments
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
for idx, param in enumerate(update_params):
arg_param = getattr(args, param)
config_param = getattr(config, param)
if arg_param > 0 and arg_param != config_param:
logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
setattr(config, param, arg_param)
# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
model = model_class(config=config)
logger.info("Init model from scratch.")
trans_encoder.append(model)
# create backbone model
if args.arch=='hrnet':
hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w40 model')
elif args.arch=='hrnet-w64':
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w64 model')
else:
print("=> using pre-trained model '{}'".format(args.arch))
backbone = models.__dict__[args.arch](pretrained=True)
# remove the last fc layer
backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
trans_encoder = torch.nn.Sequential(*trans_encoder)
total_params = sum(p.numel() for p in trans_encoder.parameters())
logger.info('Graphormer encoders total parameters: {}'.format(total_params))
backbone_total_params = sum(p.numel() for p in backbone.parameters())
logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder)
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
# for fine-tuning or resume training or inference, load weights from checkpoint
logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
# workaround approach to load sparse tensor in graph conv.
state_dict = torch.load(args.resume_checkpoint)
_model.load_state_dict(state_dict, strict=False)
del state_dict
gc.collect()
torch.cuda.empty_cache()
_model.to(args.device)
logger.info("Training parameters %s", args)
if args.run_eval_only==True:
val_dataloader = make_hand_data_loader(args, args.val_yaml,
args.distributed, is_train=False, scale_factor=args.img_scale_factor)
run_eval_and_save(args, 'freihand', val_dataloader, _model, mano_model, renderer, mesh_sampler)
else:
train_dataloader = make_hand_data_loader(args, args.train_yaml,
args.distributed, is_train=True, scale_factor=args.img_scale_factor)
run(args, train_dataloader, _model, mano_model, renderer, mesh_sampler)
if __name__ == "__main__":
args = parse_args()
main(args)


@@ -0,0 +1,338 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
End-to-end inference code for
3D hand mesh reconstruction from an image
"""
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from mesh_graphormer.modeling.bert import BertConfig, Graphormer
from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
from mesh_graphormer.modeling._mano import MANO, Mesh
from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import mesh_graphormer.modeling.data.config as cfg
from mesh_graphormer.datasets.build import make_hand_data_loader
from mesh_graphormer.utils.logger import setup_logger
from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from mesh_graphormer.utils.metric_logger import AverageMeter
from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
from mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from mesh_graphormer.utils.geometric_layers import orthographic_projection
from PIL import Image
from torchvision import transforms
device = "cuda"
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
transform_visualize = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor()])
def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler):
# switch to evaluate mode
Graphormer_model.eval()
mano.eval()
with torch.no_grad():
for image_file in image_list:
if 'pred' not in image_file:
att_all = []
print(image_file)
img = Image.open(image_file)
img_tensor = transform(img)
img_visual = transform_visualize(img)
batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
# forward-pass
pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
# obtain 3d joints from full mesh
pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:]
pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
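# note: the variable is named pred_3d_pelvis, but for the hand model this is the MANO wrist (root) joint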
# save attention
att_max_value = att[-1]
att_cpu = np.asarray(att_max_value.cpu().detach())
att_all.append(att_cpu)
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
# obtain 2d joints, which are projected from 3d joints of mesh
pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0],
pred_vertices[0].detach(),
pred_camera.detach())
# visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
# pred_vertices[0].detach(),
# pred_vertices_sub[0].detach(),
# pred_2d_coarse_vertices_from_mesh[0].detach(),
# pred_2d_joints_from_mesh[0].detach(),
# pred_camera.detach(),
# att[-1][0].detach())
visual_imgs = visual_imgs_output.transpose(1,2,0)
visual_imgs = np.asarray(visual_imgs)
temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
print('save to ', temp_fname)
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
return
def visualize_mesh( renderer, images,
pred_vertices_full,
pred_camera):
img = images.cpu().numpy().transpose(1,2,0)
# Get predicted vertices for this example
vertices_full = pred_vertices_full.cpu().numpy()
cam = pred_camera.cpu().numpy()
# Visualize only mesh reconstruction
rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
rend_img = rend_img.transpose(2,0,1)
return rend_img
def visualize_mesh_and_attention( renderer, images,
pred_vertices_full,
pred_vertices,
pred_2d_vertices,
pred_2d_joints,
pred_camera,
attention):
img = images.cpu().numpy().transpose(1,2,0)
# Get predicted vertices for this example
vertices_full = pred_vertices_full.cpu().numpy()
vertices = pred_vertices.cpu().numpy()
vertices_2d = pred_2d_vertices.cpu().numpy()
joints_2d = pred_2d_joints.cpu().numpy()
cam = pred_camera.cpu().numpy()
att = attention.cpu().numpy()
# Visualize reconstruction and attention
rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
rend_img = rend_img.transpose(2,0,1)
return rend_img
def parse_args():
parser = argparse.ArgumentParser()
#########################################################
# Data related arguments
#########################################################
parser.add_argument("--num_workers", default=4, type=int,
help="Workers in dataloader.")
parser.add_argument("--img_scale_factor", default=1, type=int,
help="adjust image resolution.")
parser.add_argument("--image_file_or_path", default='./samples/hand', type=str,
help="test data")
#########################################################
# Loading/saving checkpoints
#########################################################
parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
help="Path to pre-trained transformer model or model type.")
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
help="Path to specific checkpoint for resume training.")
parser.add_argument("--output_dir", default='output/', type=str, required=False,
help="The output directory to save checkpoint and test results.")
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name.")
parser.add_argument('-a', '--arch', default='hrnet-w64',
help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
#########################################################
# Model architectures
#########################################################
parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
help="Update model config if given")
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
help="Update model config if given")
parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
help="Update model config if given. Note that the division of "
"hidden_size / num_attention_heads should be in integer.")
parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
help="Update model config if given.")
parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--which_gcn", default='0,0,1', type=str,
help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")
#########################################################
# Others
#########################################################
parser.add_argument("--run_eval_only", default=True, action='store_true',)
parser.add_argument("--device", type=str, default='cuda',
help="cuda or cpu")
parser.add_argument('--seed', type=int, default=88,
help="random seed for initialization.")
args = parser.parse_args()
return args
def main(args):
global logger
# Setup CUDA, GPU & distributed training
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
mkdir(args.output_dir)
logger = setup_logger("Graphormer", args.output_dir, get_rank())
set_seed(args.seed, args.num_gpus)
logger.info("Using {} GPUs".format(args.num_gpus))
# Mesh and MANO utils
mano_model = MANO().to(args.device)
mano_model.layer = mano_model.layer.to(device)
mesh_sampler = Mesh()
# Renderer for visualization
renderer = Renderer(faces=mano_model.face)
# Load pretrained model
trans_encoder = []
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
output_feat_dim = input_feat_dim[1:] + [3]
# which encoder block to have graph convs
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
# if only run eval, load checkpoint
logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
_model = torch.load(args.resume_checkpoint)
else:
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config.output_attentions = False
config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i]
args.hidden_size = hidden_feat_dim[i]
args.intermediate_size = int(args.hidden_size*2)
if which_blk_graph[i]==1:
config.graph_conv = True
logger.info("Add Graph Conv")
else:
config.graph_conv = False
config.mesh_type = args.mesh_type
# update model structure if specified in arguments
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
for idx, param in enumerate(update_params):
arg_param = getattr(args, param)
config_param = getattr(config, param)
if arg_param > 0 and arg_param != config_param:
logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
setattr(config, param, arg_param)
# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
model = model_class(config=config)
logger.info("Init model from scratch.")
trans_encoder.append(model)
# create backbone model
if args.arch=='hrnet':
hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w40 model')
elif args.arch=='hrnet-w64':
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w64 model')
else:
print("=> using pre-trained model '{}'".format(args.arch))
backbone = models.__dict__[args.arch](pretrained=True)
# remove the last fc layer
backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
trans_encoder = torch.nn.Sequential(*trans_encoder)
total_params = sum(p.numel() for p in trans_encoder.parameters())
logger.info('Graphormer encoders total parameters: {}'.format(total_params))
backbone_total_params = sum(p.numel() for p in backbone.parameters())
logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder)
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
# for fine-tuning or resume training or inference, load weights from checkpoint
logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
# workaround approach to load sparse tensor in graph conv.
state_dict = torch.load(args.resume_checkpoint)
_model.load_state_dict(state_dict, strict=False)
del state_dict
gc.collect()
torch.cuda.empty_cache()
# update configs to enable attention outputs
setattr(_model.trans_encoder[-1].config,'output_attentions', True)
setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
_model.trans_encoder[-1].bert.encoder.output_attentions = True
_model.trans_encoder[-1].bert.encoder.output_hidden_states = True
for iter_layer in range(4):
_model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
for inter_block in range(3):
setattr(_model.trans_encoder[-1].config,'device', args.device)
_model.to(args.device)
logger.info("Run inference")
image_list = []
if not args.image_file_or_path:
raise ValueError("image_file_or_path not specified")
if op.isfile(args.image_file_or_path):
image_list = [args.image_file_or_path]
elif op.isdir(args.image_file_or_path):
# should be a path with images only
for filename in os.listdir(args.image_file_or_path):
if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename:
image_list.append(args.image_file_or_path+'/'+filename)
else:
raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,136 @@
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import zipfile
import torch
import numpy as np
from mesh_graphormer.utils.metric_pampjpe import get_alignMesh
def load_pred_json(filepath):
archive = zipfile.ZipFile(filepath, 'r')
jsondata = archive.read('pred.json')
reference = json.loads(jsondata.decode("utf-8"))
return reference[0], reference[1]
def multiscale_fusion(output_dir):
s = '10'
filepath = output_dir+'ckpt200-sc10_rot0-pred.zip'
ref_joints, ref_vertices = load_pred_json(filepath)
ref_joints_array = np.asarray(ref_joints)
ref_vertices_array = np.asarray(ref_vertices)
rotations = [0.0]
for i in range(1,10):
rotations.append(i*10)
rotations.append(i*-10)
scale = [0.7,0.8,0.9,1.0,1.1]
multiscale_joints = []
multiscale_vertices = []
counter = 0
for s in scale:
for r in rotations:
setting = 'sc%02d_rot%s'%(int(s*10),str(int(r)))
filepath = output_dir+'ckpt200-'+setting+'-pred.zip'
joints, vertices = load_pred_json(filepath)
joints_array = np.asarray(joints)
vertices_array = np.asarray(vertices)
pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None)
pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None)
print('--------------------------')
print('scale:', s, 'rotate', r)
print('PAMPJPE:', 1000*np.mean(pa_joint_error))
print('PAMPVPE:', 1000*np.mean(pa_vertices_error))
multiscale_joints.append(pa_joint_array)
multiscale_vertices.append(pa_vertices_array)
counter = counter + 1
overall_joints_array = ref_joints_array.copy()
overall_vertices_array = ref_vertices_array.copy()
for i in range(counter):
overall_joints_array += multiscale_joints[i]
overall_vertices_array += multiscale_vertices[i]
overall_joints_array /= (1+counter)
overall_vertices_array /= (1+counter)
pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None)
pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None)
print('--------------------------')
print('overall:')
print('PAMPJPE:', 1000*np.mean(pa_joint_error))
print('PAMPVPE:', 1000*np.mean(pa_vertices_error))
joint_output_save = overall_joints_array.tolist()
mesh_output_save = overall_vertices_array.tolist()
print('save results to pred.json')
with open('pred.json', 'w') as f:
json.dump([joint_output_save, mesh_output_save], f)
filepath = output_dir+'ckpt200-multisc-pred.zip'
resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json'
print(resolved_submit_cmd)
os.system(resolved_submit_cmd)
resolved_submit_cmd = 'rm pred.json'
print(resolved_submit_cmd)
os.system(resolved_submit_cmd)
def run_multiscale_inference(model_path, mode, output_dir):
if mode==True:
rotations = [0.0]
for i in range(1,10):
rotations.append(i*10)
rotations.append(i*-10)
scale = [0.7,0.8,0.9,1.0,1.1]
else:
rotations = [0.0]
scale = [1.0]
job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \
"--val_yaml freihand_v3/test.yaml " \
"--resume_checkpoint %s " \
"--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \
"--multiscale_inference " \
"--rot %f " \
"--sc %s " \
"--arch hrnet-w64 " \
"--num_hidden_layers 4 " \
"--num_attention_heads 4 " \
"--input_feat_dim 2051,512,128 " \
"--hidden_feat_dim 1024,256,64 " \
"--output_dir %s"
for s in scale:
for r in rotations:
resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir)
print(resolved_submit_cmd)
os.system(resolved_submit_cmd)
def main(args):
model_path = args.model_path
mode = args.multiscale_inference
output_dir = args.output_dir
run_multiscale_inference(model_path, mode, output_dir)
if mode==True:
multiscale_fusion(output_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder")
parser.add_argument("--model_path")
parser.add_argument("--multiscale_inference", default=False, action='store_true',)
parser.add_argument("--output_dir", default='output/', type=str, required=False,
help="The output directory to save checkpoint and test results.")
args = parser.parse_args()
main(args)

View File

View File

@@ -0,0 +1,176 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import pickle
import time
import torch
import torch.distributed as dist
device = "cuda"
def get_world_size():
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def gather_on_master(data):
"""Same as all_gather, but gathers data on master process only, using CPU.
Thus, this does not work with NCCL backend unless they add CPU support.
The memory consumption of this function is ~3x the data size. While in
principle it should be ~2x, it is not easy to force Python to release
memory immediately, so peak memory usage can be up to 3x.
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
# trying to optimize memory, but in fact, it's not guaranteed to be released
del data
storage = torch.ByteStorage.from_buffer(buffer)
del buffer
tensor = torch.ByteTensor(storage)
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()])
size_list = [torch.LongTensor([0]) for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,))
tensor = torch.cat((tensor, padding), dim=0)
del padding
if is_main_process():
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)))
dist.gather(tensor, gather_list=tensor_list, dst=0)
del tensor
else:
dist.gather(tensor, gather_list=[], dst=0)
del tensor
return
data_list = []
for tensor in tensor_list:
buffer = tensor.cpu().numpy().tobytes()
del tensor
data_list.append(pickle.loads(buffer))
del buffer
return data_list
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to(device)
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()]).to(device)
size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to(device)
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that process with rank
0 has the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict

View File

@@ -0,0 +1,66 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import os
import os.path as op
import numpy as np
import base64
import cv2
import yaml
from collections import OrderedDict
def img_from_base64(imagestring):
try:
jpgbytestring = base64.b64decode(imagestring)
nparr = np.frombuffer(jpgbytestring, np.uint8)
r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
return r
except Exception:
return None
def load_labelmap(labelmap_file):
label_dict = None
if labelmap_file is not None and op.isfile(labelmap_file):
label_dict = OrderedDict()
with open(labelmap_file, 'r') as fp:
for line in fp:
label = line.strip().split('\t')[0]
if label in label_dict:
raise ValueError("Duplicate label " + label + " in labelmap.")
else:
label_dict[label] = len(label_dict)
return label_dict
def load_shuffle_file(shuf_file):
shuf_list = None
if shuf_file is not None:
with open(shuf_file, 'r') as fp:
shuf_list = []
for i in fp:
shuf_list.append(int(i.strip()))
return shuf_list
def load_box_shuffle_file(shuf_file):
if shuf_file is not None:
with open(shuf_file, 'r') as fp:
img_shuf_list = []
box_shuf_list = []
for i in fp:
idx = [int(_) for _ in i.strip().split('\t')]
img_shuf_list.append(idx[0])
box_shuf_list.append(idx[1])
return [img_shuf_list, box_shuf_list]
return None
def load_from_yaml_file(file_name):
with open(file_name, 'r') as fp:
return yaml.load(fp, Loader=yaml.CLoader)

View File

@@ -0,0 +1,58 @@
"""
Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
"""
import torch
def rodrigues(theta):
"""Convert axis-angle representation to rotation matrix.
Args:
theta: size = [B, 3]
Returns:
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
"""
l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
angle = torch.unsqueeze(l1norm, -1)
normalized = torch.div(theta, angle)
angle = angle * 0.5
v_cos = torch.cos(angle)
v_sin = torch.sin(angle)
quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
return quat2mat(quat)
def quat2mat(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
quat: size = [B, 4] 4 <===>(w, x, y, z)
Returns:
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
"""
norm_quat = quat
norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
B = quat.size(0)
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w*x, w*y, w*z
xy, xz, yz = x*y, x*z, y*z
rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
return rotMat
def orthographic_projection(X, camera):
"""Perform orthographic projection of 3D points X using the camera parameters
Args:
X: size = [B, N, 3]
camera: size = [B, 3]
Returns:
Projected 2D points -- size = [B, N, 2]
"""
camera = camera.view(-1, 1, 3)
X_trans = X[:, :, :2] + camera[:, :, 1:]
shape = X_trans.shape
X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
return X_2d
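# Minimal sanity check (illustrative only): a zero axis-angle vector should give
# (approximately) the identity rotation, and an orthographic projection with
# scale 1 and zero translation should simply drop the z coordinate.
if __name__ == "__main__":
    theta = torch.zeros(1, 3)                                  # B=1 axis-angle vector
    print("rotation for zero axis-angle:\n", rodrigues(theta)[0])
    X = torch.randn(2, 5, 3)                                   # two batches of 5 points
    camera = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # [s, tx, ty]
    X_2d = orthographic_projection(X, camera)                  # -> [2, 5, 2]
    print("max |proj - xy|:", (X_2d - X[:, :, :2]).abs().max().item())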

View File

@@ -0,0 +1,208 @@
"""
Image processing tools
Modified from open source projects:
(https://github.com/nkolot/GraphCMR/)
(https://github.com/open-mmlab/mmdetection)
"""
import numpy as np
import base64
import cv2
import torch
# scipy.misc (imrotate/imresize) is no longer used; the cv2-based helpers below replace it
def img_from_base64(imagestring):
try:
jpgbytestring = base64.b64decode(imagestring)
nparr = np.frombuffer(jpgbytestring, np.uint8)
r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
return r
except ValueError:
return None
def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
if center is not None and auto_bound:
raise ValueError('`auto_bound` conflicts with `center`')
h, w = img.shape[:2]
if center is None:
center = ((w - 1) * 0.5, (h - 1) * 0.5)
assert isinstance(center, tuple)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
if auto_bound:
cos = np.abs(matrix[0, 0])
sin = np.abs(matrix[0, 1])
new_w = h * sin + w * cos
new_h = h * cos + w * sin
matrix[0, 2] += (new_w - w) * 0.5
matrix[1, 2] += (new_h - h) * 0.5
w = int(np.round(new_w))
h = int(np.round(new_h))
rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
return rotated
def myimresize(img, size, return_scale=False, interpolation='bilinear'):
h, w = img.shape[:2]
resized_img = cv2.resize(
img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR)
if not return_scale:
return resized_img
else:
w_scale = size[0] / w
h_scale = size[1] / h
return resized_img, w_scale, h_scale
def get_transform(center, scale, res, rot=0):
"""Generate transformation matrix."""
h = 200 * scale
t = np.zeros((3, 3))
t[0, 0] = float(res[1]) / h
t[1, 1] = float(res[0]) / h
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
t[2, 2] = 1
if not rot == 0:
rot = -rot # To match direction of rotation from cropping
rot_mat = np.zeros((3,3))
rot_rad = rot * np.pi / 180
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
rot_mat[0,:2] = [cs, -sn]
rot_mat[1,:2] = [sn, cs]
rot_mat[2,2] = 1
# Need to rotate around center
t_mat = np.eye(3)
t_mat[0,2] = -res[1]/2
t_mat[1,2] = -res[0]/2
t_inv = t_mat.copy()
t_inv[:2,2] *= -1
t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
return t
def transform(pt, center, scale, res, invert=0, rot=0):
"""Transform pixel location to different reference."""
t = get_transform(center, scale, res, rot=rot)
if invert:
# t = np.linalg.inv(t)
t_torch = torch.from_numpy(t)
t_torch = torch.inverse(t_torch)
t = t_torch.numpy()
new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2].astype(int)+1
def crop(img, center, scale, res, rot=0):
"""Crop image according to the supplied bounding box."""
# Upper left point
ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
# Bottom right point
br = np.array(transform([res[0]+1,
res[1]+1], center, scale, res, invert=1))-1
# Padding so that when rotated proper amount of context is included
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
if not rot == 0:
ul -= pad
br += pad
new_shape = [br[1] - ul[1], br[0] - ul[0]]
if len(img.shape) > 2:
new_shape += [img.shape[2]]
new_img = np.zeros(new_shape)
# Range to fill new array
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
# Range to sample from original image
old_x = max(0, ul[0]), min(len(img[0]), br[0])
old_y = max(0, ul[1]), min(len(img), br[1])
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
old_x[0]:old_x[1]]
if not rot == 0:
# Remove padding
# new_img = scipy.misc.imrotate(new_img, rot)
new_img = myimrotate(new_img, rot)
new_img = new_img[pad:-pad, pad:-pad]
# new_img = scipy.misc.imresize(new_img, res)
new_img = myimresize(new_img, [res[0], res[1]])
return new_img
def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
"""'Undo' the image cropping/resizing.
This function is used when evaluating mask/part segmentation.
"""
res = img.shape[:2]
# Upper left point
ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
# Bottom right point
br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
# size of cropped image
crop_shape = [br[1] - ul[1], br[0] - ul[0]]
new_shape = [br[1] - ul[1], br[0] - ul[0]]
if len(img.shape) > 2:
new_shape += [img.shape[2]]
new_img = np.zeros(orig_shape, dtype=np.uint8)
# Range to fill new array
new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
# Range to sample from original image
old_x = max(0, ul[0]), min(orig_shape[1], br[0])
old_y = max(0, ul[1]), min(orig_shape[0], br[1])
# img = scipy.misc.imresize(img, crop_shape, interp='nearest')
img = myimresize(img, [crop_shape[0],crop_shape[1]])
new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
return new_img
def rot_aa(aa, rot):
"""Rotate axis angle parameters."""
# pose parameters
R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
[np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
[0, 0, 1]])
# find the rotation of the body in camera frame
per_rdg, _ = cv2.Rodrigues(aa)
# apply the global rotation to the global orientation
resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
aa = (resrot.T)[0]
return aa
def flip_img(img):
"""Flip rgb images or masks.
channels come last, e.g. (256,256,3).
"""
img = np.fliplr(img)
return img
def flip_kp(kp):
"""Flip keypoints."""
flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
kp = kp[flipped_parts]
kp[:,0] = - kp[:,0]
return kp
def flip_pose(pose):
"""Flip pose.
The flipping is based on SMPL parameters.
"""
flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
pose = pose[flippedParts]
# we also negate the second and the third dimension of the axis-angle
pose[1::3] = -pose[1::3]
pose[2::3] = -pose[2::3]
return pose
def flip_aa(aa):
"""Flip axis-angle representation.
We negate the second and the third dimension of the axis-angle.
"""
aa[1] = -aa[1]
aa[2] = -aa[2]
return aa
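# Minimal usage sketch (illustrative only; the center/scale values are made up).
# `scale * 200` pixels is the side length of the box that gets cropped and
# resized to `res`, following the convention of get_transform above.
if __name__ == "__main__":
    dummy = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    patch = crop(dummy, center=[128, 128], scale=0.9, res=[224, 224], rot=0)
    print("cropped patch shape:", patch.shape)      # expected (224, 224, 3)
    print("flipped patch shape:", flip_img(patch).shape)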

View File

@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import sys
from logging import StreamHandler, Handler, getLevelName
# This class is a copy of logging.FileHandler except that it calls self.close()
# at the end of each emit. Closing and reopening the file after every write is
# not efficient, but it lets us see partial logs when writing to fuse-mounted
# Azure blobs, which is very convenient.
class FileHandler(StreamHandler):
"""
A handler class which writes formatted logging records to disk files.
"""
def __init__(self, filename, mode='a', encoding=None, delay=False):
"""
Open the specified file and use it as the stream for logging.
"""
# Issue #27493: add support for Path objects to be passed in
filename = os.fspath(filename)
#keep the absolute path, otherwise derived classes which use this
#may come a cropper when the current directory changes
self.baseFilename = os.path.abspath(filename)
self.mode = mode
self.encoding = encoding
self.delay = delay
if delay:
#We don't open the stream, but we still need to call the
#Handler constructor to set level, formatter, lock etc.
Handler.__init__(self)
self.stream = None
else:
StreamHandler.__init__(self, self._open())
def close(self):
"""
Closes the stream.
"""
self.acquire()
try:
try:
if self.stream:
try:
self.flush()
finally:
stream = self.stream
self.stream = None
if hasattr(stream, "close"):
stream.close()
finally:
# Issue #19523: call unconditionally to
# prevent a handler leak when delay is set
StreamHandler.close(self)
finally:
self.release()
def _open(self):
"""
Open the current base file with the (original) mode and encoding.
Return the resulting stream.
"""
return open(self.baseFilename, self.mode, encoding=self.encoding)
def emit(self, record):
"""
Emit a record.
If the stream was not opened because 'delay' was specified in the
constructor, open it before calling the superclass's emit.
"""
if self.stream is None:
self.stream = self._open()
StreamHandler.emit(self, record)
self.close()
def __repr__(self):
level = getLevelName(self.level)
return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level)
def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
# don't log results for the non-master process
if distributed_rank > 0:
return logger
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
ch.setFormatter(formatter)
logger.addHandler(ch)
if save_dir:
fh = FileHandler(os.path.join(save_dir, filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
return logger
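# Minimal usage sketch (illustrative only; the directory name is arbitrary).
if __name__ == "__main__":
    os.makedirs("demo_logs", exist_ok=True)
    demo_logger = setup_logger("graphormer.demo", "demo_logs", distributed_rank=0)
    demo_logger.info("records are written to stdout and demo_logs/log.txt")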

View File

@@ -0,0 +1,45 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
Basic logger. It computes and stores the average and current value.
"""
class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class EvalMetricsLogger(object):
def __init__(self):
self.reset()
def reset(self):
# define an upper-bound performance (worst case);
# 100.0/1000.0 corresponds to 100 millimeters expressed in meters
self.PAmPJPE = 100.0/1000.0
self.mPJPE = 100.0/1000.0
self.mPVE = 100.0/1000.0
self.epoch = 0
def update(self, mPVE, mPJPE, PAmPJPE, epoch):
self.PAmPJPE = PAmPJPE
self.mPJPE = mPJPE
self.mPVE = mPVE
self.epoch = epoch
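# Minimal usage sketch (illustrative only): track a running average of a
# per-batch loss, weighting each update by the batch size.
if __name__ == "__main__":
    loss_meter = AverageMeter()
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        loss_meter.update(batch_loss, n=batch_size)
    print("last loss %.3f, running average %.3f" % (loss_meter.val, loss_meter.avg))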

View File

@@ -0,0 +1,99 @@
"""
Functions for computing Procrustes alignment and reconstruction error
Parts of the code are adapted from https://github.com/akanazawa/hmr
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def compute_similarity_transform(S1, S2):
"""Computes a similarity transform (sR, t) that takes
a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
i.e. solves the orthogonal Procrutes problem.
"""
transposed = False
if S1.shape[0] != 3 and S1.shape[0] != 2:
S1 = S1.T
S2 = S2.T
transposed = True
assert(S2.shape[1] == S1.shape[1])
# 1. Remove mean.
mu1 = S1.mean(axis=1, keepdims=True)
mu2 = S2.mean(axis=1, keepdims=True)
X1 = S1 - mu1
X2 = S2 - mu2
# 2. Compute variance of X1 used for scale.
var1 = np.sum(X1**2)
# 3. The outer product of X1 and X2.
K = X1.dot(X2.T)
# 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
# singular vectors of K.
U, s, Vh = np.linalg.svd(K)
V = Vh.T
# Construct Z that fixes the orientation of R to get det(R)=1.
Z = np.eye(U.shape[0])
Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
# Construct R.
R = V.dot(Z.dot(U.T))
# 5. Recover scale.
scale = np.trace(R.dot(K)) / var1
# 6. Recover translation.
t = mu2 - scale*(R.dot(mu1))
# 7. Error:
S1_hat = scale*R.dot(S1) + t
if transposed:
S1_hat = S1_hat.T
return S1_hat
def compute_similarity_transform_batch(S1, S2):
"""Batched version of compute_similarity_transform."""
S1_hat = np.zeros_like(S1)
for i in range(S1.shape[0]):
S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
return S1_hat
def reconstruction_error(S1, S2, reduction='mean'):
"""Do Procrustes alignment and compute reconstruction error."""
S1_hat = compute_similarity_transform_batch(S1, S2)
re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
if reduction == 'mean':
re = re.mean()
elif reduction == 'sum':
re = re.sum()
return re
def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'):
"""Do Procrustes alignment and compute reconstruction error."""
S1_hat = compute_similarity_transform_batch(S1, S2)
S1_hat = S1_hat[:,J24_TO_J14,:]
S2 = S2[:,J24_TO_J14,:]
re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
if reduction == 'mean':
re = re.mean()
elif reduction == 'sum':
re = re.sum()
return re
def get_alignMesh(S1, S2, reduction='mean'):
"""Do Procrustes alignment and compute reconstruction error."""
S1_hat = compute_similarity_transform_batch(S1, S2)
re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
if reduction == 'mean':
re = re.mean()
elif reduction == 'sum':
re = re.sum()
return re, S1_hat, S2
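# Minimal sanity check (illustrative only): a rotated, scaled and translated
# copy of a point set should have (near) zero error after Procrustes alignment.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    S2 = rng.randn(2, 21, 3)                       # batch of 2 hands, 21 joints each
    Q, _ = np.linalg.qr(rng.randn(3, 3))
    Q *= np.sign(np.linalg.det(Q))                 # make it a proper rotation (det=+1)
    S1 = 0.5 * S2.dot(Q.T) + np.array([0.1, -0.2, 0.3])
    print("PA error after alignment: %.2e" % reconstruction_error(S1, S2, reduction='mean'))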

View File

@@ -0,0 +1,171 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import os
import os.path as op
import re
import logging
import numpy as np
import torch
import random
import shutil
from .comm import is_main_process
import yaml
def mkdir(path):
# if it is the current folder, skip.
# otherwise the original code will raise FileNotFoundError
if path == '':
return
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def save_config(cfg, path):
if is_main_process():
with open(path, 'w') as f:
f.write(cfg.dump())
def config_iteration(output_dir, max_iter):
save_file = os.path.join(output_dir, 'last_checkpoint')
iteration = -1
if os.path.exists(save_file):
with open(save_file, 'r') as f:
fname = f.read().strip()
model_name = os.path.basename(fname)
model_path = os.path.dirname(fname)
if model_name.startswith('model_') and len(model_name) == 17:
iteration = int(model_name[-11:-4])
elif model_name == "model_final":
iteration = max_iter
elif model_path.startswith('checkpoint-') and len(model_path) == 18:
iteration = int(model_path.split('-')[-1])
return iteration
def get_matching_parameters(model, regexp, none_on_empty=True):
"""Returns parameters matching regular expression"""
if not regexp:
if none_on_empty:
return {}
else:
return dict(model.named_parameters())
compiled_pattern = re.compile(regexp)
params = {}
for weight_name, weight in model.named_parameters():
if compiled_pattern.match(weight_name):
params[weight_name] = weight
return params
def freeze_weights(model, regexp):
"""Freeze weights based on regular expression."""
logger = logging.getLogger("maskrcnn_benchmark.trainer")
for weight_name, weight in get_matching_parameters(model, regexp).items():
weight.requires_grad = False
logger.info("Disabled training of {}".format(weight_name))
def unfreeze_weights(model, regexp, backbone_freeze_at=-1,
is_distributed=False):
"""Unfreeze weights based on regular expression.
This is helpful during training to unfreeze frozen weights after
the other (unfrozen) weights have been trained for some iterations.
"""
logger = logging.getLogger("maskrcnn_benchmark.trainer")
for weight_name, weight in get_matching_parameters(model, regexp).items():
weight.requires_grad = True
logger.info("Enabled training of {}".format(weight_name))
if backbone_freeze_at >= 0:
logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at))
if is_distributed:
model.module.backbone.body._freeze_backbone(backbone_freeze_at)
else:
model.backbone.body._freeze_backbone(backbone_freeze_at)
def delete_tsv_files(tsvs):
for t in tsvs:
if op.isfile(t):
try_delete(t)
line = op.splitext(t)[0] + '.lineidx'
if op.isfile(line):
try_delete(line)
def concat_files(ins, out):
mkdir(op.dirname(out))
out_tmp = out + '.tmp'
with open(out_tmp, 'wb') as fp_out:
for i, f in enumerate(ins):
logging.info('concating {}/{} - {}'.format(i, len(ins), f))
with open(f, 'rb') as fp_in:
shutil.copyfileobj(fp_in, fp_out, 1024*1024*10)
os.rename(out_tmp, out)
def concat_tsv_files(tsvs, out_tsv):
concat_files(tsvs, out_tsv)
sizes = [os.stat(t).st_size for t in tsvs]
sizes = np.cumsum(sizes)
all_idx = []
for i, t in enumerate(tsvs):
for idx in load_list_file(op.splitext(t)[0] + '.lineidx'):
if i == 0:
all_idx.append(idx)
else:
all_idx.append(str(int(idx) + sizes[i - 1]))
with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f:
f.write('\n'.join(all_idx))
def load_list_file(fname):
with open(fname, 'r') as fp:
lines = fp.readlines()
result = [line.strip() for line in lines]
if len(result) > 0 and result[-1] == '':
result = result[:-1]
return result
def try_once(func):
def func_wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
logging.info('ignore error \n{}'.format(str(e)))
return func_wrapper
@try_once
def try_delete(f):
os.remove(f)
def set_seed(seed, n_gpu):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(seed)
def print_and_run_cmd(cmd):
print(cmd)
os.system(cmd)
def write_to_yaml_file(context, file_name):
with open(file_name, 'w') as fp:
yaml.dump(context, fp, encoding='utf-8')
def load_from_yaml_file(yaml_file):
with open(yaml_file, 'r') as fp:
return yaml.load(fp, Loader=yaml.CLoader)
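# Minimal usage sketch (illustrative only; the module names below are made up).
# In Graphormer training the regexp would typically target the CNN backbone
# or selected encoder blocks.
if __name__ == "__main__":
    toy = torch.nn.ModuleDict({
        "backbone": torch.nn.Linear(8, 8),
        "head": torch.nn.Linear(8, 2),
    })
    freeze_weights(toy, r"backbone\..*")
    for name, p in toy.named_parameters():
        print(name, "requires_grad =", p.requires_grad)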

View File

@@ -0,0 +1,691 @@
"""
Rendering tools for 3D mesh visualization on 2D image.
Parts of the code are taken from https://github.com/akanazawa/hmr
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
import code
from opendr.camera import ProjectPoints
from opendr.renderer import ColoredRenderer, TexturedRenderer
from opendr.lighting import LambertianPointLight
import random
# Rotate the points by a specified angle.
def rotateY(points, angle):
ry = np.array([
[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
[-np.sin(angle), 0., np.cos(angle)]
])
return np.dot(points, ry)
def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None):
"""
joints is 3 x 19; if not, it is transposed.
0: Right ankle
1: Right knee
2: Right hip
3: Left hip
4: Left knee
5: Left ankle
6: Right wrist
7: Right elbow
8: Right shoulder
9: Left shoulder
10: Left elbow
11: Left wrist
12: Neck
13: Head top
14: nose
15: left_eye
16: right_eye
17: left_ear
18: right_ear
"""
if radius is None:
radius = max(4, (np.mean(input_image.shape[:2]) * 0.01).astype(int))
colors = {
'pink': (197, 27, 125), # L lower leg
'light_pink': (233, 163, 201), # L upper leg
'light_green': (161, 215, 106), # L lower arm
'green': (77, 146, 33), # L upper arm
'red': (215, 48, 39), # head
'light_red': (252, 146, 114), # head
'light_orange': (252, 141, 89), # chest
'purple': (118, 42, 131), # R lower leg
'light_purple': (175, 141, 195), # R upper
'light_blue': (145, 191, 219), # R lower arm
'blue': (69, 117, 180), # R upper arm
'gray': (130, 130, 130), #
'white': (255, 255, 255), #
}
image = input_image.copy()
input_is_float = False
if np.issubdtype(image.dtype, np.floating):
input_is_float = True
max_val = image.max()
if max_val <= 2.: # should be 1 but sometimes it's slightly above 1
image = (image * 255).astype(np.uint8)
else:
image = (image).astype(np.uint8)
if joints.shape[0] != 2:
joints = joints.T
joints = np.round(joints).astype(int)
jcolors = [
'light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink',
'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue',
'purple', 'purple', 'red', 'green', 'green', 'white', 'white',
'purple', 'purple', 'red', 'green', 'green', 'white', 'white'
]
if joints.shape[1] == 19:
# parent indices -1 means no parents
parents = np.array([
1, 2, 8, 9, 3, 4, 7, 8, 12, 12, 9, 10, 14, -1, 13, -1, -1, 15, 16
])
# Left is light and right is dark
ecolors = {
0: 'light_pink',
1: 'light_pink',
2: 'light_pink',
3: 'pink',
4: 'pink',
5: 'pink',
6: 'light_blue',
7: 'light_blue',
8: 'light_blue',
9: 'blue',
10: 'blue',
11: 'blue',
12: 'purple',
17: 'light_green',
18: 'light_green',
14: 'purple'
}
elif joints.shape[1] == 14:
parents = np.array([
1,
2,
8,
9,
3,
4,
7,
8,
-1,
-1,
9,
10,
13,
-1,
])
ecolors = {
0: 'light_pink',
1: 'light_pink',
2: 'light_pink',
3: 'pink',
4: 'pink',
5: 'pink',
6: 'light_blue',
7: 'light_blue',
10: 'light_blue',
11: 'blue',
12: 'purple'
}
elif joints.shape[1] == 21: # hand
parents = np.array([
-1,
0,
1,
2,
3,
0,
5,
6,
7,
0,
9,
10,
11,
0,
13,
14,
15,
0,
17,
18,
19,
])
ecolors = {
0: 'light_purple',
1: 'light_green',
2: 'light_green',
3: 'light_green',
4: 'light_green',
5: 'pink',
6: 'pink',
7: 'pink',
8: 'pink',
9: 'light_blue',
10: 'light_blue',
11: 'light_blue',
12: 'light_blue',
13: 'light_red',
14: 'light_red',
15: 'light_red',
16: 'light_red',
17: 'purple',
18: 'purple',
19: 'purple',
20: 'purple',
}
else:
raise ValueError('Unknown skeleton with {} joints'.format(joints.shape[1]))
for child in range(len(parents)):
point = joints[:, child]
# If invisible skip
if vis is not None and vis[child] == 0:
continue
if draw_edges:
cv2.circle(image, (point[0], point[1]), radius, colors['white'],
-1)
cv2.circle(image, (point[0], point[1]), radius - 1,
colors[jcolors[child]], -1)
else:
# cv2.circle(image, (point[0], point[1]), 5, colors['white'], 1)
cv2.circle(image, (point[0], point[1]), radius - 1,
colors[jcolors[child]], 1)
# cv2.circle(image, (point[0], point[1]), 5, colors['gray'], -1)
pa_id = parents[child]
if draw_edges and pa_id >= 0:
if vis is not None and vis[pa_id] == 0:
continue
point_pa = joints[:, pa_id]
cv2.circle(image, (point_pa[0], point_pa[1]), radius - 1,
colors[jcolors[pa_id]], -1)
if child not in ecolors.keys():
print('bad')
import ipdb
ipdb.set_trace()
cv2.line(image, (point[0], point[1]), (point_pa[0], point_pa[1]),
colors[ecolors[child]], radius - 2)
# Convert back in original dtype
if input_is_float:
if max_val <= 1.:
image = image.astype(np.float32) / 255.
else:
image = image.astype(np.float32)
return image
def draw_text(input_image, content):
"""
content is a dict. draws key: val on image
Assumes key is str, val is float
"""
image = input_image.copy()
input_is_float = False
if np.issubdtype(image.dtype, np.floating):
input_is_float = True
image = (image * 255).astype(np.uint8)
black = (255, 255, 0)
margin = 15
start_x = 5
start_y = margin
for key in sorted(content.keys()):
text = "%s: %.2g" % (key, content[key])
cv2.putText(image, text, (start_x, start_y), 0, 0.45, black)
start_y += margin
if input_is_float:
image = image.astype(np.float32) / 255.
return image
def visualize_reconstruction(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, color='pink', focal_length=1000):
"""Overlays gt_kp and pred_kp on img.
Draws vert with text.
Renderer is an instance of SMPLRenderer.
"""
gt_vis = gt_kp[:, 2].astype(bool)
loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2)
debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss}
# Fix a focal length so the mesh can be rendered at a perspective-correct scale
res = img.shape[1]
camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
rend_img = renderer.render(vertices, camera_t=camera_t,
img=img, use_bg=True,
focal_length=focal_length,
body_color=color)
rend_img = draw_text(rend_img, debug_text)
# Draw skeleton
gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
pred_joint = ((pred_kp + 1) * 0.5) * img_size
img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis)
skel_img = draw_skeleton(img_with_gt, pred_joint)
combined = np.hstack([skel_img, rend_img])
return combined
def visualize_reconstruction_test(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, score, color='pink', focal_length=1000):
"""Overlays gt_kp and pred_kp on img.
Draws vert with text.
Renderer is an instance of SMPLRenderer.
"""
gt_vis = gt_kp[:, 2].astype(bool)
loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2)
debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss, "pa-mpjpe": score*1000}
# Fix a focal length so the mesh can be rendered at a perspective-correct scale
res = img.shape[1]
camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
rend_img = renderer.render(vertices, camera_t=camera_t,
img=img, use_bg=True,
focal_length=focal_length,
body_color=color)
rend_img = draw_text(rend_img, debug_text)
# Draw skeleton
gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
pred_joint = ((pred_kp + 1) * 0.5) * img_size
img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis)
skel_img = draw_skeleton(img_with_gt, pred_joint)
combined = np.hstack([skel_img, rend_img])
return combined
def visualize_reconstruction_and_att(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, focal_length=1000):
"""Overlays gt_kp and pred_kp on img.
Draws vert with text.
Renderer is an instance of SMPLRenderer.
"""
# Fix a focal length so the mesh can be rendered at a perspective-correct scale
res = img.shape[1]
camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
rend_img = renderer.render(vertices_full, camera_t=camera_t,
img=img, use_bg=True,
focal_length=focal_length, body_color='light_blue')
heads_num, vertex_num, _ = attention.shape
all_head = np.zeros((vertex_num,vertex_num))
###### find max
# for i in range(vertex_num):
# for j in range(vertex_num):
# all_head[i,j] = np.max(attention[:,i,j])
##### find avg
for h in range(4):
att_per_img = attention[h]
all_head = all_head + att_per_img
all_head = all_head/4
col_sums = all_head.sum(axis=0)
all_head = all_head / col_sums[np.newaxis, :]
# code.interact(local=locals())
combined = []
if vertex_num>400: # body
selected_joints = [6,7,4,5,13] # [6,7,4,5,13,12]
else: # hand
selected_joints = [0, 4, 8, 12, 16, 20]
# Draw attention
for ii in range(len(selected_joints)):
reference_id = selected_joints[ii]
ref_point = ref_points[reference_id]
attention_to_show = all_head[reference_id][14::]
min_v = np.min(attention_to_show)
max_v = np.max(attention_to_show)
norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v)
vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
ref_norm = ((ref_point + 1) * 0.5) * img_size
image = np.zeros_like(rend_img)
for jj in range(vertices_norm.shape[0]):
x = int(vertices_norm[jj,0])
y = int(vertices_norm[jj,1])
cv2.circle(image,(x,y), 1, (255,255,255), -1)
total_to_draw = []
for jj in range(vertices_norm.shape[0]):
thres = 0.0
if norm_attention_to_show[jj]>thres:
things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
total_to_draw.append(things)
# plot_one_line(ref_norm, vertices_norm[jj], image, reference_id, alpha=0.4*(norm_attention_to_show[jj]-thres)/(1-thres) )
total_to_draw.sort()
max_att_score = total_to_draw[-1][0]
for item in total_to_draw:
attention_score = item[0]
ref_point = item[1]
vertex = item[2]
plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) )
# code.interact(local=locals())
if len(combined)==0:
combined = image
else:
combined = np.hstack([combined, image])
final = np.hstack([img, combined, rend_img])
return final
def visualize_reconstruction_and_att_local(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, color='light_blue', focal_length=1000):
"""Overlays gt_kp and pred_kp on img.
Draws vert with text.
Renderer is an instance of SMPLRenderer.
"""
# Fix a focal length so the mesh can be rendered at a perspective-correct scale
res = img.shape[1]
camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
rend_img = renderer.render(vertices_full, camera_t=camera_t,
img=img, use_bg=True,
focal_length=focal_length, body_color=color)
heads_num, vertex_num, _ = attention.shape
all_head = np.zeros((vertex_num,vertex_num))
##### compute avg attention for 4 attention heads
for h in range(4):
att_per_img = attention[h]
all_head = all_head + att_per_img
all_head = all_head/4
col_sums = all_head.sum(axis=0)
all_head = all_head / col_sums[np.newaxis, :]
combined = []
if vertex_num>400: # body
selected_joints = [7] # [6,7,4,5,13,12]
else: # hand
selected_joints = [0] # [0, 4, 8, 12, 16, 20]
# Draw attention
for ii in range(len(selected_joints)):
reference_id = selected_joints[ii]
ref_point = ref_points[reference_id]
attention_to_show = all_head[reference_id][14::]
min_v = np.min(attention_to_show)
max_v = np.max(attention_to_show)
norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v)
vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
ref_norm = ((ref_point + 1) * 0.5) * img_size
image = rend_img*0.4
total_to_draw = []
for jj in range(vertices_norm.shape[0]):
thres = 0.0
if norm_attention_to_show[jj]>thres:
things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
total_to_draw.append(things)
total_to_draw.sort()
max_att_score = total_to_draw[-1][0]
for item in total_to_draw:
attention_score = item[0]
ref_point = item[1]
vertex = item[2]
plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) )
for jj in range(vertices_norm.shape[0]):
x = int(vertices_norm[jj,0])
y = int(vertices_norm[jj,1])
cv2.circle(image,(x,y), 1, (255,255,255), -1)
if len(combined)==0:
combined = image
else:
combined = np.hstack([combined, image])
final = np.hstack([img, combined, rend_img])
return final
def visualize_reconstruction_no_text(img, img_size, vertices, camera, renderer, color='pink', focal_length=1000):
"""Overlays gt_kp and pred_kp on img.
Draws vert with text.
Renderer is an instance of SMPLRenderer.
"""
# Fix a focal length so the mesh can be rendered at a perspective-correct scale
res = img.shape[1]
camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
rend_img = renderer.render(vertices, camera_t=camera_t,
img=img, use_bg=True,
focal_length=focal_length,
body_color=color)
combined = np.hstack([img, rend_img])
return combined
def plot_one_line(ref, vertex, img, color_index, alpha=0.0, line_thickness=None):
# 13,6,7,8,3,4,5
# att_colors = [(255, 221, 104), (255, 255, 0), (255, 215, 227), (210, 240, 119), \
# (209, 238, 245), (244, 200, 243), (233, 242, 216)]
att_colors = [(255, 255, 0), (244, 200, 243), (210, 243, 119), (209, 238, 255), (200, 208, 255), (250, 238, 215)]
overlay = img.copy()
# output = img.copy()
# Plots one bounding box on image img
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
color = list(att_colors[color_index])
c1, c2 = (int(ref[0]), int(ref[1])), (int(vertex[0]), int(vertex[1]))
cv2.line(overlay, c1, c2, (alpha*float(color[0])/255,alpha*float(color[1])/255,alpha*float(color[2])/255) , thickness=tl, lineType=cv2.LINE_AA)
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
def cam2pixel(cam_coord, f, c):
x = cam_coord[:, 0] / (cam_coord[:, 2]) * f[0] + c[0]
y = cam_coord[:, 1] / (cam_coord[:, 2]) * f[1] + c[1]
z = cam_coord[:, 2]
img_coord = np.concatenate((x[:,None], y[:,None], z[:,None]),1)
return img_coord
class Renderer(object):
"""
Render mesh using OpenDR for visualization.
"""
def __init__(self, width=800, height=600, near=0.5, far=1000, faces=None):
self.colors = {'hand': [.9, .9, .9], 'pink': [.9, .7, .7], 'light_blue': [0.65098039, 0.74117647, 0.85882353] }
self.width = width
self.height = height
self.faces = faces
self.renderer = ColoredRenderer()
def render(self, vertices, faces=None, img=None,
camera_t=np.zeros([3], dtype=np.float32),
camera_rot=np.zeros([3], dtype=np.float32),
camera_center=None,
use_bg=False,
bg_color=(0.0, 0.0, 0.0),
body_color=None,
focal_length=5000,
disp_text=False,
gt_keyp=None,
pred_keyp=None,
**kwargs):
if img is not None:
height, width = img.shape[:2]
else:
height, width = self.height, self.width
if faces is None:
faces = self.faces
if camera_center is None:
camera_center = np.array([width * 0.5,
height * 0.5])
self.renderer.camera = ProjectPoints(rt=camera_rot,
t=camera_t,
f=focal_length * np.ones(2),
c=camera_center,
k=np.zeros(5))
dist = np.abs(self.renderer.camera.t.r[2] -
np.mean(vertices, axis=0)[2])
far = dist + 20
self.renderer.frustum = {'near': 1.0, 'far': far,
'width': width,
'height': height}
if img is not None:
if use_bg:
self.renderer.background_image = img
else:
self.renderer.background_image = np.ones_like(
img) * np.array(bg_color)
if body_color is None:
color = self.colors['light_blue']
else:
color = self.colors[body_color]
if isinstance(self.renderer, TexturedRenderer):
color = [1.,1.,1.]
self.renderer.set(v=vertices, f=faces,
vc=color, bgcolor=np.ones(3))
albedo = self.renderer.vc
# Construct Back Light (on back right corner)
yrot = np.radians(120)
self.renderer.vc = LambertianPointLight(
f=self.renderer.f,
v=self.renderer.v,
num_verts=self.renderer.v.shape[0],
light_pos=rotateY(np.array([-200, -100, -100]), yrot),
vc=albedo,
light_color=np.array([1, 1, 1]))
# Construct Left Light
self.renderer.vc += LambertianPointLight(
f=self.renderer.f,
v=self.renderer.v,
num_verts=self.renderer.v.shape[0],
light_pos=rotateY(np.array([800, 10, 300]), yrot),
vc=albedo,
light_color=np.array([1, 1, 1]))
# Construct Right Light
self.renderer.vc += LambertianPointLight(
f=self.renderer.f,
v=self.renderer.v,
num_verts=self.renderer.v.shape[0],
light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
vc=albedo,
light_color=np.array([.7, .7, .7]))
return self.renderer.r
def render_vertex_color(self, vertices, faces=None, img=None,
camera_t=np.zeros([3], dtype=np.float32),
camera_rot=np.zeros([3], dtype=np.float32),
camera_center=None,
use_bg=False,
bg_color=(0.0, 0.0, 0.0),
vertex_color=None,
focal_length=5000,
disp_text=False,
gt_keyp=None,
pred_keyp=None,
**kwargs):
if img is not None:
height, width = img.shape[:2]
else:
height, width = self.height, self.width
if faces is None:
faces = self.faces
if camera_center is None:
camera_center = np.array([width * 0.5,
height * 0.5])
self.renderer.camera = ProjectPoints(rt=camera_rot,
t=camera_t,
f=focal_length * np.ones(2),
c=camera_center,
k=np.zeros(5))
dist = np.abs(self.renderer.camera.t.r[2] -
np.mean(vertices, axis=0)[2])
far = dist + 20
self.renderer.frustum = {'near': 1.0, 'far': far,
'width': width,
'height': height}
if img is not None:
if use_bg:
self.renderer.background_image = img
else:
self.renderer.background_image = np.ones_like(
img) * np.array(bg_color)
if vertex_color is None:
vertex_color = self.colors['light_blue']
self.renderer.set(v=vertices, f=faces,
vc=vertex_color, bgcolor=np.ones(3))
albedo = self.renderer.vc
# Construct Back Light (on back right corner)
yrot = np.radians(120)
self.renderer.vc = LambertianPointLight(
f=self.renderer.f,
v=self.renderer.v,
num_verts=self.renderer.v.shape[0],
light_pos=rotateY(np.array([-200, -100, -100]), yrot),
vc=albedo,
light_color=np.array([1, 1, 1]))
# Construct Left Light
self.renderer.vc += LambertianPointLight(
f=self.renderer.f,
v=self.renderer.v,
num_verts=self.renderer.v.shape[0],
light_pos=rotateY(np.array([800, 10, 300]), yrot),
vc=albedo,
light_color=np.array([1, 1, 1]))
# Construct Right Light
self.renderer.vc += LambertianPointLight(
f=self.renderer.f,
v=self.renderer.v,
num_verts=self.renderer.v.shape[0],
light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
vc=albedo,
light_color=np.array([.7, .7, .7]))
return self.renderer.r
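# Minimal worked example (illustrative only; focal length and principal point
# are made-up values). Only cam2pixel is exercised here; the Renderer class
# above additionally requires opendr at import time.
if __name__ == "__main__":
    points_cam = np.array([[0.0, 0.0, 2.0],        # on the optical axis
                           [0.1, -0.1, 2.0]])      # slightly off-axis
    focal = (1000.0, 1000.0)                       # fx, fy in pixels
    princpt = (112.0, 112.0)                       # cx, cy (e.g. a 224x224 crop)
    print(cam2pixel(points_cam, focal, princpt))   # row 0 lands on (112, 112, 2)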

View File

@@ -0,0 +1,162 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
Definition of TSV class
"""
import logging
import os
import os.path as op
def generate_lineidx(filein, idxout):
idxout_tmp = idxout + '.tmp'
with open(filein, 'r') as tsvin, open(idxout_tmp,'w') as tsvout:
fsize = os.fstat(tsvin.fileno()).st_size
fpos = 0
while fpos!=fsize:
tsvout.write(str(fpos)+"\n")
tsvin.readline()
fpos = tsvin.tell()
os.rename(idxout_tmp, idxout)
def read_to_character(fp, c):
result = []
while True:
s = fp.read(32)
assert s != ''
if c in s:
result.append(s[: s.index(c)])
break
else:
result.append(s)
return ''.join(result)
class TSVFile(object):
def __init__(self, tsv_file, generate_lineidx_if_missing=False):
self.tsv_file = tsv_file
self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
self._fp = None
self._lineidx = None
# remember the pid of the process that opened the file;
# if it differs from the current pid, the file is re-opened.
self.pid = None
# generate the lineidx file if it does not exist; the flag is named
# generate_lineidx_if_missing so that it does not shadow the
# module-level generate_lineidx() function
if not op.isfile(self.lineidx) and generate_lineidx_if_missing:
generate_lineidx(self.tsv_file, self.lineidx)
def __del__(self):
if self._fp:
self._fp.close()
def __str__(self):
return "TSVFile(tsv_file='{}')".format(self.tsv_file)
def __repr__(self):
return str(self)
def num_rows(self):
self._ensure_lineidx_loaded()
return len(self._lineidx)
def seek(self, idx):
self._ensure_tsv_opened()
self._ensure_lineidx_loaded()
try:
pos = self._lineidx[idx]
except:
logging.info('{}-{}'.format(self.tsv_file, idx))
raise
self._fp.seek(pos)
return [s.strip() for s in self._fp.readline().split('\t')]
def seek_first_column(self, idx):
self._ensure_tsv_opened()
self._ensure_lineidx_loaded()
pos = self._lineidx[idx]
self._fp.seek(pos)
return read_to_character(self._fp, '\t')
def get_key(self, idx):
return self.seek_first_column(idx)
def __getitem__(self, index):
return self.seek(index)
def __len__(self):
return self.num_rows()
def _ensure_lineidx_loaded(self):
if self._lineidx is None:
logging.info('loading lineidx: {}'.format(self.lineidx))
with open(self.lineidx, 'r') as fp:
self._lineidx = [int(i.strip()) for i in fp.readlines()]
def _ensure_tsv_opened(self):
if self._fp is None:
self._fp = open(self.tsv_file, 'r')
self.pid = os.getpid()
if self.pid != os.getpid():
logging.info('re-open {} because the process id changed'.format(self.tsv_file))
self._fp = open(self.tsv_file, 'r')
self.pid = os.getpid()
class CompositeTSVFile():
def __init__(self, file_list, seq_file, root='.'):
if isinstance(file_list, str):
self.file_list = load_list_file(file_list)
else:
assert isinstance(file_list, list)
self.file_list = file_list
self.seq_file = seq_file
self.root = root
self.initialized = False
self.initialize()
def get_key(self, index):
idx_source, idx_row = self.seq[index]
k = self.tsvs[idx_source].get_key(idx_row)
return '_'.join([self.file_list[idx_source], k])
def num_rows(self):
return len(self.seq)
def __getitem__(self, index):
idx_source, idx_row = self.seq[index]
return self.tsvs[idx_source].seek(idx_row)
def __len__(self):
return len(self.seq)
def initialize(self):
'''
This function has to be called in __init__ if cache_policy is
enabled; thus we always call it in __init__ to keep things simple.
'''
if self.initialized:
return
self.seq = []
with open(self.seq_file, 'r') as fp:
for line in fp:
parts = line.strip().split('\t')
self.seq.append([int(parts[0]), int(parts[1])])
self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
self.initialized = True
def load_list_file(fname):
with open(fname, 'r') as fp:
lines = fp.readlines()
result = [line.strip() for line in lines]
if len(result) > 0 and result[-1] == '':
result = result[:-1]
return result
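# Minimal usage sketch (illustrative only; file names are arbitrary demo paths).
if __name__ == "__main__":
    with open("demo.tsv", "w") as f:
        f.write("img_0\tlabel_a\n")
        f.write("img_1\tlabel_b\n")
    generate_lineidx("demo.tsv", "demo.lineidx")   # byte offset of every row
    tsv = TSVFile("demo.tsv")
    print("rows:", tsv.num_rows())                 # -> 2
    print("row 1:", tsv.seek(1))                   # -> ['img_1', 'label_b']
    print("key of row 0:", tsv.get_key(0))         # -> 'img_0'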

View File

@@ -0,0 +1,116 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
Basic operations for TSV files
"""
import errno
import os
import os.path as op
import json
import numpy as np
import base64
import cv2
from tqdm import tqdm
import yaml
from mesh_graphormer.utils.miscellaneous import mkdir
from mesh_graphormer.utils.tsv_file import TSVFile
def img_from_base64(imagestring):
try:
jpgbytestring = base64.b64decode(imagestring)
nparr = np.frombuffer(jpgbytestring, np.uint8)
r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
return r
except ValueError:
return None
def load_linelist_file(linelist_file):
if linelist_file is not None:
line_list = []
with open(linelist_file, 'r') as fp:
for i in fp:
line_list.append(int(i.strip()))
return line_list
def tsv_writer(values, tsv_file, sep='\t'):
mkdir(op.dirname(tsv_file))
lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
idx = 0
tsv_file_tmp = tsv_file + '.tmp'
lineidx_file_tmp = lineidx_file + '.tmp'
with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx:
assert values is not None
for value in values:
assert value is not None
value = [v if type(v)!=bytes else v.decode('utf-8') for v in value]
v = '{0}\n'.format(sep.join(map(str, value)))
fp.write(v)
fpidx.write(str(idx) + '\n')
idx = idx + len(v)
os.rename(tsv_file_tmp, tsv_file)
os.rename(lineidx_file_tmp, lineidx_file)
def tsv_reader(tsv_file, sep='\t'):
with open(tsv_file, 'r') as fp:
for i, line in enumerate(fp):
yield [x.strip() for x in line.split(sep)]
def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'):
if save_file is not None:
return save_file
return op.splitext(tsv_file)[0] + append_str
def get_line_list(linelist_file=None, num_rows=None):
if linelist_file is not None:
return load_linelist_file(linelist_file)
if num_rows is not None:
return [i for i in range(num_rows)]
def generate_hw_file(img_file, save_file=None):
rows = tsv_reader(img_file)
def gen_rows():
for i, row in tqdm(enumerate(rows)):
row1 = [row[0]]
img = img_from_base64(row[-1])
height = img.shape[0]
width = img.shape[1]
row1.append(json.dumps([{"height":height, "width": width}]))
yield row1
save_file = config_save_file(img_file, save_file, '.hw.tsv')
tsv_writer(gen_rows(), save_file)
def generate_linelist_file(label_file, save_file=None, ignore_attrs=()):
# generate a list of image that has labels
# images with only ignore labels are not selected.
line_list = []
rows = tsv_reader(label_file)
for i, row in tqdm(enumerate(rows)):
labels = json.loads(row[1])
if labels:
if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab]) \
for lab in labels]):
continue
line_list.append([i])
save_file = config_save_file(label_file, save_file, '.linelist.tsv')
tsv_writer(line_list, save_file)
def load_from_yaml_file(yaml_file):
with open(yaml_file, 'r') as fp:
return yaml.load(fp, Loader=yaml.CLoader)
def find_file_path_in_yaml(fname, root):
if fname is not None:
if op.isfile(fname):
return fname
elif op.isfile(op.join(root, fname)):
return op.join(root, fname)
else:
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
)
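# Minimal usage sketch (illustrative only; the file name is an arbitrary demo
# path). tsv_writer also emits a .lineidx file with the byte offset of each row.
if __name__ == "__main__":
    rows = [["img_0", json.dumps([{"label": "hand"}])],
            ["img_1", json.dumps([{"label": "hand"}])]]
    tsv_writer(rows, "demo_labels.tsv")
    for row in tsv_reader("demo_labels.tsv"):
        print(row[0], json.loads(row[1]))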

1
requirement.txt Normal file
View File

@@ -0,0 +1 @@
rtree