commit a1f793c0a71e63e1cf8f0922251c0a8525315db6 Author: huchenlei Date: Wed Jan 3 00:39:16 2024 -0500 :sparkles: Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..deb74ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,161 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ +.vscode/ \ No newline at end of file diff --git a/hand_refiner/__init__.py b/hand_refiner/__init__.py new file mode 100644 index 0000000..3a1d72e --- /dev/null +++ b/hand_refiner/__init__.py @@ -0,0 +1,42 @@ +import numpy as np +from PIL import Image +from .util import resize_image_with_pad, common_input_validate, HWC3, custom_hf_download +from hand_refiner.pipeline import MeshGraphormerMediapipe, args + +class MeshGraphormerDetector: + def __init__(self, pipeline): + self.pipeline = pipeline + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"): + filename = filename or "graphormer_hand_state_dict.bin" + hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth" + args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir) + args.hrnet_checkpoint = custom_hf_download(pretrained_model_or_path, hrnet_filename, cache_dir) + args.device = device + pipeline = MeshGraphormerMediapipe(args) + return cls(pipeline) + + def to(self, device): + self.pipeline._model.to(device) + self.pipeline.mano_model.to(device) + self.pipeline.mano_model.layer.to(device) + return self + + def __call__(self, input_image=None, mask_bbox_padding=30, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + + depth_map, mask, info = self.pipeline.get_depth(input_image, mask_bbox_padding) + if depth_map is None: + depth_map = np.zeros_like(input_image) + mask = np.zeros_like(input_image) + + #The hand is small + depth_map, mask = HWC3(depth_map), HWC3(mask) + depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method) + depth_map = remove_pad(depth_map) + if output_type == "pil": + depth_map = Image.fromarray(depth_map) + mask = Image.fromarray(mask) + + return depth_map, mask, info diff --git a/hand_refiner/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml b/hand_refiner/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml new file mode 100644 index 0000000..65d7103 --- /dev/null +++ b/hand_refiner/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml @@ -0,0 +1,92 @@ +GPUS: (0,1,2,3) +LOG_DIR: 'log/' +DATA_DIR: '' +OUTPUT_DIR: 'output/' +WORKERS: 4 +PRINT_FREQ: 1000 + +MODEL: + NAME: cls_hrnet + IMAGE_SIZE: + - 224 + - 224 + EXTRA: + STAGE1: + NUM_MODULES: 1 + NUM_RANCHES: 1 + BLOCK: BOTTLENECK + NUM_BLOCKS: + - 4 + NUM_CHANNELS: + - 64 + FUSE_METHOD: SUM + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 64 + - 128 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 64 + - 128 + - 256 + - 512 + FUSE_METHOD: SUM +CUDNN: + BENCHMARK: true + DETERMINISTIC: false + ENABLED: true +DATASET: + DATASET: 'imagenet' + DATA_FORMAT: 'jpg' + ROOT: 'data/imagenet/' + TEST_SET: 'val' + TRAIN_SET: 'train' +TEST: + BATCH_SIZE_PER_GPU: 32 + MODEL_FILE: '' +TRAIN: + BATCH_SIZE_PER_GPU: 32 + BEGIN_EPOCH: 0 + END_EPOCH: 100 + RESUME: true + LR_FACTOR: 0.1 + LR_STEP: + - 30 + - 60 + - 90 + OPTIMIZER: sgd + LR: 0.05 + WD: 0.0001 + MOMENTUM: 0.9 + NESTEROV: true + SHUFFLE: true +DEBUG: + DEBUG: false diff --git a/hand_refiner/depth_preprocessor.py b/hand_refiner/depth_preprocessor.py new file 
mode 100644 index 0000000..496313a --- /dev/null +++ b/hand_refiner/depth_preprocessor.py @@ -0,0 +1,6 @@ +class Preprocessor: + def __init__(self) -> None: + pass + + def get_depth(self, input_dir, file_name): + return \ No newline at end of file diff --git a/hand_refiner/hand_landmarker.task b/hand_refiner/hand_landmarker.task new file mode 100644 index 0000000..0d53faf Binary files /dev/null and b/hand_refiner/hand_landmarker.task differ diff --git a/hand_refiner/pipeline.py b/hand_refiner/pipeline.py new file mode 100644 index 0000000..c05c207 --- /dev/null +++ b/hand_refiner/pipeline.py @@ -0,0 +1,468 @@ +import os +import torch +import gc +import numpy as np +from hand_refiner.depth_preprocessor import Preprocessor + +import torchvision.models as models +from mesh_graphormer.modeling.bert import BertConfig, Graphormer +from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network +from mesh_graphormer.modeling._mano import MANO, Mesh +from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +from mesh_graphormer.utils.miscellaneous import set_seed +from argparse import Namespace +from pathlib import Path +import cv2 +from torchvision import transforms +import numpy as np +import cv2 +from trimesh import Trimesh +from trimesh.ray.ray_triangle import RayMeshIntersector +import mediapipe as mp +from mediapipe.tasks import python +from mediapipe.tasks.python import vision +from torchvision import transforms +from pathlib import Path +import mesh_graphormer +from packaging import version + +args = Namespace( + num_workers=4, + img_scale_factor=1, + image_file_or_path=os.path.join('', 'MeshGraphormer', 'samples', 'hand'), + model_name_or_path=str(Path(mesh_graphormer.__file__).parent / "modeling/bert/bert-base-uncased"), + resume_checkpoint=None, + output_dir='output/', + config_name='', + a='hrnet-w64', + arch='hrnet-w64', + num_hidden_layers=4, + hidden_size=-1, + num_attention_heads=4, + intermediate_size=-1, + input_feat_dim='2051,512,128', + hidden_feat_dim='1024,256,64', + which_gcn='0,0,1', + mesh_type='hand', + run_eval_only=True, + device="cpu", + seed=88, + hrnet_checkpoint=None, +) + +#Since mediapipe v0.10.5, the hand category has been correct +if version.parse(mp.__version__) >= version.parse('0.10.5'): + true_hand_category = {"Right": "right", "Left": "left"} +else: + true_hand_category = {"Right": "left", "Left": "right"} + +class MeshGraphormerMediapipe(Preprocessor): + def __init__(self, args=args) -> None: + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + #mkdir(args.output_dir) + #logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + #logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and MANO utils + mano_model = MANO().to(args.device) + mano_model.layer = mano_model.layer.to(args.device) + mesh_sampler = Mesh(device=args.device) + + # Renderer for visualization + # renderer = Renderer(faces=mano_model.face) + + # Load pretrained model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in 
args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + #logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*2) + + if which_blk_graph[i]==1: + config.graph_conv = True + #logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + #logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + #logger.info("Init model from scratch.") + trans_encoder.append(model) + + # create backbone model + if args.arch=='hrnet': + hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = args.hrnet_checkpoint + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + #logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = args.hrnet_checkpoint + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + #logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + #logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + #logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + #logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load 
sparse tensor in graph conv. + state_dict = torch.load(args.resume_checkpoint) + _model.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + + # update configs to enable attention outputs + setattr(_model.trans_encoder[-1].config,'output_attentions', True) + setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) + _model.trans_encoder[-1].bert.encoder.output_attentions = True + _model.trans_encoder[-1].bert.encoder.output_hidden_states = True + for iter_layer in range(4): + _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True + for inter_block in range(3): + setattr(_model.trans_encoder[-1].config,'device', args.device) + + _model.to(args.device) + self._model = _model + self.mano_model = mano_model + self.mesh_sampler = mesh_sampler + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + + base_options = python.BaseOptions(model_asset_path=str( Path(__file__).parent / "hand_landmarker.task" )) + options = vision.HandLandmarkerOptions(base_options=base_options, + min_hand_detection_confidence=0.6, + min_hand_presence_confidence=0.6, + min_tracking_confidence=0.6, + num_hands=2) + + self.detector = vision.HandLandmarker.create_from_options(options) + + + def get_rays(self, W, H, fx, fy, cx, cy, c2w_t, center_pixels): # rot = I + + j, i = np.meshgrid(np.arange(H, dtype=np.float32), np.arange(W, dtype=np.float32)) + if center_pixels: + i = i.copy() + 0.5 + j = j.copy() + 0.5 + + directions = np.stack([(i - cx) / fx, (j - cy) / fy, np.ones_like(i)], -1) + directions /= np.linalg.norm(directions, axis=-1, keepdims=True) + + rays_o = np.expand_dims(c2w_t,0).repeat(H*W, 0) + + rays_d = directions # (H, W, 3) + rays_d = (rays_d / np.linalg.norm(rays_d, axis=-1, keepdims=True)).reshape(-1,3) + + return rays_o, rays_d + + def get_mask_bounding_box(self, extrema, H, W, padding=30, dynamic_resize=0.15): + x_min, x_max, y_min, y_max = extrema + bb_xpad = max(int((x_max - x_min + 1) * dynamic_resize), padding) + bb_ypad = max(int((y_max - y_min + 1) * dynamic_resize), padding) + bbx_min = np.max((x_min - bb_xpad, 0)) + bbx_max = np.min((x_max + bb_xpad, W-1)) + bby_min = np.max((y_min - bb_ypad, 0)) + bby_max = np.min((y_max + bb_ypad, H-1)) + return bbx_min, bbx_max, bby_min, bby_max + + def run_inference(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len): + global args + H, W = int(crop_len), int(crop_len) + Graphormer_model.eval() + mano.eval() + device = next(Graphormer_model.parameters()).device + with torch.no_grad(): + img_tensor = self.transform(img) + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + #pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + #pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous()) + pred_camera = pred_camera.cpu() + pred_vertices = pred_vertices.cpu() + mesh = Trimesh(vertices=pred_vertices[0], faces=mano.face) + res = crop_len + focal_length = 1000 * scale + camera_t = np.array([-pred_camera[1], 
-pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)]) + pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t + z_3d_dist = pred_3d_joints_camera[:,2].clone() + + pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2)) + + rays_o, rays_d = self.get_rays(W, H, focal_length, focal_length, W/2, H/2, camera_t, True) + coords = np.array(list(np.ndindex(H,W))).reshape(H,W,-1).transpose(1,0,2).reshape(-1,2) + intersector = RayMeshIntersector(mesh) + points, index_ray, _ = intersector.intersects_location(rays_o, rays_d, multiple_hits=False) + + tri_index = intersector.intersects_first(rays_o, rays_d) + + tri_index = tri_index[index_ray] + + assert len(index_ray) == len(tri_index) + + discriminator = (np.sum(mesh.face_normals[tri_index]* rays_d[index_ray], axis=-1)<= 0) + points = points[discriminator] # ray intesects in interior faces, discard them + + if len(points) == 0: + return None, None + depth = (points + camera_t)[:,-1] + index_ray = index_ray[discriminator] + pixel_ray = coords[index_ray] + + minval = np.min(depth) + maxval = np.max(depth) + depthmap = np.zeros([H,W]) + + depthmap[pixel_ray[:, 0], pixel_ray[:, 1]] = 1.0 - (0.8 * (depth - minval) / (maxval - minval)) + depthmap *= 255 + return depthmap, pred_2d_joints_img_space + + + def get_depth(self, np_image, padding): + info = {} + + # STEP 3: Load the input image. + #https://stackoverflow.com/a/76407270 + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np_image.copy()) + + # STEP 4: Detect hand landmarks from the input image. + detection_result = self.detector.detect(image) + + handedness_list = detection_result.handedness + hand_landmarks_list = detection_result.hand_landmarks + + raw_image = image.numpy_view() + H, W, C = raw_image.shape + + + # HANDLANDMARKS CAN BE EMPTY, HANDLE THIS! 
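        # If MediaPipe finds no hands, get_depth() returns (None, None, None); the caller
        # in hand_refiner/__init__.py (__call__) then substitutes all-zero depth and mask
        # images with the same shape as the input.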
+ if len(hand_landmarks_list) == 0: + return None, None, None + raw_image = raw_image[:, :, :3] + + padded_image = np.zeros((H*2, W*2, 3)) + padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = raw_image + + hand_landmarks_list, handedness_list = zip( + *sorted( + zip(hand_landmarks_list, handedness_list), key=lambda x: x[0][9].z, reverse=True + ) + ) + + padded_depthmap = np.zeros((H*2, W*2)) + mask = np.zeros((H, W)) + crop_boxes = [] + #bboxes = [] + groundtruth_2d_keypoints = [] + hands = [] + depth_failure = False + crop_lens = [] + + for idx in range(len(hand_landmarks_list)): + hand = true_hand_category[handedness_list[idx][0].category_name] + hands.append(hand) + hand_landmarks = hand_landmarks_list[idx] + handedness = handedness_list[idx] + height, width, _ = raw_image.shape + x_coordinates = [landmark.x for landmark in hand_landmarks] + y_coordinates = [landmark.y for landmark in hand_landmarks] + + # x_min, x_max, y_min, y_max: extrema from mediapipe keypoint detection + x_min = int(min(x_coordinates) * width) + x_max = int(max(x_coordinates) * width) + x_c = (x_min + x_max)//2 + y_min = int(min(y_coordinates) * height) + y_max = int(max(y_coordinates) * height) + y_c = (y_min + y_max)//2 + + #if x_max - x_min < 60 or y_max - y_min < 60: + # continue + + crop_len = (max(x_max - x_min, y_max - y_min) * 1.6) //2 * 2 + + # crop_x_min, crop_x_max, crop_y_min, crop_y_max: bounding box for mesh reconstruction + crop_x_min = int(x_c - (crop_len/2 - 1) + W/2) + crop_x_max = int(x_c + crop_len/2 + W/2) + crop_y_min = int(y_c - (crop_len/2 - 1) + H/2) + crop_y_max = int(y_c + crop_len/2 + H/2) + + cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1] + crop_boxes.append([crop_y_min, crop_y_max, crop_x_min, crop_x_max]) + crop_lens.append(crop_len) + if hand == "left": + cropped = cv2.flip(cropped, 1) + + if crop_len < 224: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC) + else: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA) + scale = crop_len/224 + cropped_depthmap, pred_2d_keypoints = self.run_inference(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, int(crop_len)) + + if cropped_depthmap is None: + depth_failure = True + break + #keypoints_image_space = pred_2d_keypoints * (crop_y_max - crop_y_min + 1)/224 + groundtruth_2d_keypoints.append(pred_2d_keypoints) + + if hand == "left": + cropped_depthmap = cv2.flip(cropped_depthmap, 1) + resized_cropped_depthmap = cv2.resize(cropped_depthmap, (int(crop_len), int(crop_len)), interpolation=cv2.INTER_LINEAR) + nonzero_y, nonzero_x = (resized_cropped_depthmap != 0).nonzero() + if len(nonzero_y) == 0 or len(nonzero_x) == 0: + depth_failure = True + break + padded_depthmap[crop_y_min+nonzero_y, crop_x_min+nonzero_x] = resized_cropped_depthmap[nonzero_y, nonzero_x] + + # nonzero stands for nonzero value on the depth map + # coordinates of nonzero depth pixels in original image space + original_nonzero_x = crop_x_min+nonzero_x - int(W/2) + original_nonzero_y = crop_y_min+nonzero_y - int(H/2) + + nonzerox_min = min(np.min(original_nonzero_x), x_min) + nonzerox_max = max(np.max(original_nonzero_x), x_max) + nonzeroy_min = min(np.min(original_nonzero_y), y_min) + nonzeroy_max = max(np.max(original_nonzero_y), y_max) + + bbx_min, bbx_max, bby_min, bby_max = self.get_mask_bounding_box((nonzerox_min, nonzerox_max, nonzeroy_min, nonzeroy_max), H, W, padding) + mask[bby_min:bby_max+1, bbx_min:bbx_max+1] = 1.0 
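            # Coordinate bookkeeping for the block above: padded_image is (2H, 2W) with the
            # original frame centred at offset (W/2, H/2), so the crop box and the nonzero
            # depth pixels live in padded coordinates. Subtracting int(W/2) / int(H/2) maps
            # them back to original-image coordinates before get_mask_bounding_box() grows
            # the extremum box by max(15% of its extent, `padding`) and clamps it to
            # [0, W-1] x [0, H-1].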
+ #bboxes.append([int(bbx_min), int(bbx_max), int(bby_min), int(bby_max)]) + if depth_failure: + #print("cannot detect normal hands") + return None, None, None + depthmap = padded_depthmap[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)].astype(np.uint8) + mask = (255.0 * mask).astype(np.uint8) + info["groundtruth_2d_keypoints"] = groundtruth_2d_keypoints + info["hands"] = hands + info["crop_boxes"] = crop_boxes + info["crop_lens"] = crop_lens + return depthmap, mask, info + + def get_keypoints(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len): + global args + H, W = int(crop_len), int(crop_len) + Graphormer_model.eval() + mano.eval() + device = next(Graphormer_model.parameters()).device + with torch.no_grad(): + img_tensor = self.transform(img) + #print(img_tensor) + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + #pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + #pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous()) + pred_camera = pred_camera.cpu() + pred_vertices = pred_vertices.cpu() + # + res = crop_len + focal_length = 1000 * scale + camera_t = np.array([-pred_camera[1], -pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)]) + pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t + z_3d_dist = pred_3d_joints_camera[:,2].clone() + pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2)) + + return pred_2d_joints_img_space + + + def eval_mpjpe(self, sample, info): + H, W, C = sample.shape + padded_image = np.zeros((H*2, W*2, 3)) + padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = sample + crop_boxes = info["crop_boxes"] + hands = info["hands"] + groundtruth_2d_keypoints = info["groundtruth_2d_keypoints"] + crop_lens = info["crop_lens"] + pjpe = 0 + for i in range(len(crop_boxes)):#box in crop_boxes: + crop_y_min, crop_y_max, crop_x_min, crop_x_max = crop_boxes[i] + cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1] + hand = hands[i] + if hand == "left": + cropped = cv2.flip(cropped, 1) + crop_len = crop_lens[i] + scale = crop_len/224 + if crop_len < 224: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC) + else: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA) + generated_keypoint = self.get_keypoints(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, crop_len) + #generated_keypoint = generated_keypoint * ((crop_y_max - crop_y_min + 1)/224) + pjpe += np.sum(np.sqrt(np.sum(((generated_keypoint - groundtruth_2d_keypoints[i]) ** 2).numpy(), axis=1))) + pass + mpjpe = pjpe/(len(crop_boxes) * 21) + return mpjpe + + + + + diff --git a/hand_refiner/util.py b/hand_refiner/util.py new file mode 100644 index 0000000..a781031 --- /dev/null +++ b/hand_refiner/util.py @@ -0,0 +1,193 @@ +import os +import random + +import cv2 +import numpy as np +from pathlib import Path +import warnings +from huggingface_hub import hf_hub_download + +here = 
Path(__file__).parent.resolve() + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def make_noise_disk(H, W, C, F): + noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) + noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) + noise = noise[F: F + H, F: F + W] + noise -= np.min(noise) + noise /= np.max(noise) + if C == 1: + noise = noise[:, :, None] + return noise + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + +def min_max_norm(x): + x -= np.min(x) + x /= np.maximum(np.max(x), 1e-5) + return x + + +def safe_step(x, step=2): + y = x.astype(np.float32) * float(step + 1) + y = y.astype(np.int32).astype(np.float32) / float(step) + return y + + +def img2mask(img, H, W, low=10, high=90): + assert img.ndim == 3 or img.ndim == 2 + assert img.dtype == np.uint8 + + if img.ndim == 3: + y = img[:, :, random.randrange(0, img.shape[2])] + else: + y = img + + y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) + + if random.uniform(0, 1) < 0.5: + y = 255 - y + + return y < np.percentile(y, random.randrange(low, high)) + +def safer_memory(x): + # Fix many MAC/AMD problems + return np.ascontiguousarray(x.copy()).copy() + +UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"] +def get_upscale_method(method_str): + assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}" + return getattr(cv2, method_str) + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + +#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 +#Added upscale_method param +def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False): + if skip_hwc3: + img = input_image + else: + img = HWC3(input_image) + H_raw, W_raw, _ = img.shape + k = float(resolution) / float(min(H_raw, W_raw)) + H_target = int(np.round(float(H_raw) * k)) + W_target = int(np.round(float(W_raw) * k)) + img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA) + H_pad, W_pad = pad64(H_target), pad64(W_target) + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge') + + def remove_pad(x): + return safer_memory(x[:H_target, :W_target, ...]) + + return safer_memory(img_padded), remove_pad + +def common_input_validate(input_image, output_type, **kwargs): + if "img" in kwargs: + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("img") + + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. 
Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + return (input_image, output_type) + +def custom_hf_download(pretrained_model_or_path, filename, cache_dir, subfolder='', use_symlinks=False): + local_dir = os.path.join(cache_dir, pretrained_model_or_path) + model_path = os.path.join(local_dir, *subfolder.split('/'), filename) + + if not os.path.exists(model_path): + print(f"Failed to find {model_path}.\n Downloading from huggingface.co") + if use_symlinks: + cache_dir_d = os.getenv("HUGGINGFACE_HUB_CACHE") + if cache_dir_d is None: + import platform + if platform.system() == "Windows": + cache_dir_d = os.path.join(os.getenv("USERPROFILE"), ".cache", "huggingface", "hub") + else: + cache_dir_d = os.path.join(os.getenv("HOME"), ".cache", "huggingface", "hub") + try: + # test_link + if not os.path.exists(cache_dir_d): + os.makedirs(cache_dir_d) + open(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), "w") + os.link(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(cache_dir, f"linktest_{filename}.txt")) + os.remove(os.path.join(cache_dir, f"linktest_{filename}.txt")) + os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt")) + print("Using symlinks to download models. \n",\ + "Make sure you have enough space on your cache folder. \n",\ + "And do not purge the cache folder after downloading.\n",\ + "Otherwise, you will have to re-download the models every time you run the script.\n",\ + "You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.") + except: + print("Maybe not able to create symlink. Disable using symlinks.") + use_symlinks = False + cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache") + else: + cache_dir_d = os.path.join(cache_dir, pretrained_model_or_path, "cache") + + model_path = hf_hub_download(repo_id=pretrained_model_or_path, + cache_dir=cache_dir_d, + local_dir=local_dir, + subfolder=subfolder, + filename=filename, + local_dir_use_symlinks=use_symlinks, + resume_download=True, + etag_timeout=100 + ) + if not use_symlinks: + try: + import shutil + shutil.rmtree(cache_dir_d) + except Exception as e : + print(e) + return model_path \ No newline at end of file diff --git a/manopth/CHANGES.md b/manopth/CHANGES.md new file mode 100644 index 0000000..27e7d74 --- /dev/null +++ b/manopth/CHANGES.md @@ -0,0 +1 @@ +* Chumpy is removed \ No newline at end of file diff --git a/manopth/LICENSE b/manopth/LICENSE new file mode 100644 index 0000000..e72bfdd --- /dev/null +++ b/manopth/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. 
By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. 
+ + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. 
+ + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. 
This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. 
Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. 
+Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). 
+ + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". 
+ + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. 
+ + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. 
\ No newline at end of file diff --git a/manopth/__init__.py b/manopth/__init__.py new file mode 100644 index 0000000..e27cf86 --- /dev/null +++ b/manopth/__init__.py @@ -0,0 +1 @@ +name = 'manopth' diff --git a/manopth/argutils.py b/manopth/argutils.py new file mode 100644 index 0000000..7e86eb0 --- /dev/null +++ b/manopth/argutils.py @@ -0,0 +1,51 @@ +import datetime +import os +import pickle +import subprocess +import sys + + +def print_args(args): + opts = vars(args) + print('======= Options ========') + for k, v in sorted(opts.items()): + print('{}: {}'.format(k, v)) + print('========================') + + +def save_args(args, save_folder, opt_prefix='opt', verbose=True): + opts = vars(args) + # Create checkpoint folder + if not os.path.exists(save_folder): + os.makedirs(save_folder, exist_ok=True) + + # Save options + opt_filename = '{}.txt'.format(opt_prefix) + opt_path = os.path.join(save_folder, opt_filename) + with open(opt_path, 'a') as opt_file: + opt_file.write('====== Options ======\n') + for k, v in sorted(opts.items()): + opt_file.write( + '{option}: {value}\n'.format(option=str(k), value=str(v))) + opt_file.write('=====================\n') + opt_file.write('launched {} at {}\n'.format( + str(sys.argv[0]), str(datetime.datetime.now()))) + + # Add git info + label = subprocess.check_output(["git", "describe", + "--always"]).strip() + if subprocess.call( + ["git", "branch"], + stderr=subprocess.STDOUT, + stdout=open(os.devnull, 'w')) == 0: + opt_file.write('=== Git info ====\n') + opt_file.write('{}\n'.format(label)) + commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']) + opt_file.write('commit : {}\n'.format(commit.strip())) + + opt_picklename = '{}.pkl'.format(opt_prefix) + opt_picklepath = os.path.join(save_folder, opt_picklename) + with open(opt_picklepath, 'wb') as opt_file: + pickle.dump(opts, opt_file) + if verbose: + print('Saved options to {}'.format(opt_path)) diff --git a/manopth/demo.py b/manopth/demo.py new file mode 100644 index 0000000..0bca468 --- /dev/null +++ b/manopth/demo.py @@ -0,0 +1,59 @@ +from matplotlib import pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import numpy as np +import torch + +from manopth.manolayer import ManoLayer + + +def generate_random_hand(batch_size=1, ncomps=6, mano_root='mano/models'): + nfull_comps = ncomps + 3 # Add global orientation dims to PCA + random_pcapose = torch.rand(batch_size, nfull_comps) + mano_layer = ManoLayer(mano_root=mano_root) + verts, joints = mano_layer(random_pcapose) + return {'verts': verts, 'joints': joints, 'faces': mano_layer.th_faces} + + +def display_hand(hand_info, mano_faces=None, ax=None, alpha=0.2, batch_idx=0, show=True): + """ + Displays hand batch_idx in batch of hand_info, hand_info as returned by + generate_random_hand + """ + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + verts, joints = hand_info['verts'][batch_idx], hand_info['joints'][ + batch_idx] + if mano_faces is None: + ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.1) + else: + mesh = Poly3DCollection(verts[mano_faces], alpha=alpha) + face_color = (141 / 255, 184 / 255, 226 / 255) + edge_color = (50 / 255, 50 / 255, 50 / 255) + mesh.set_edgecolor(edge_color) + mesh.set_facecolor(face_color) + ax.add_collection3d(mesh) + ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r') + cam_equal_aspect_3d(ax, verts.numpy()) + if show: + plt.show() + + +def cam_equal_aspect_3d(ax, verts, flip_x=False): + """ 
+ Centers view on cuboid containing hand and flips y and z axis + and fixes azimuth + """ + extents = np.stack([verts.min(0), verts.max(0)], axis=1) + sz = extents[:, 1] - extents[:, 0] + centers = np.mean(extents, axis=1) + maxsize = max(abs(sz)) + r = maxsize / 2 + if flip_x: + ax.set_xlim(centers[0] + r, centers[0] - r) + else: + ax.set_xlim(centers[0] - r, centers[0] + r) + # Invert y and z axis + ax.set_ylim(centers[1] + r, centers[1] - r) + ax.set_zlim(centers[2] + r, centers[2] - r) diff --git a/manopth/manolayer.py b/manopth/manolayer.py new file mode 100644 index 0000000..5965a04 --- /dev/null +++ b/manopth/manolayer.py @@ -0,0 +1,274 @@ +import os + +import numpy as np +import torch +from torch.nn import Module + +from manopth.smpl_handpca_wrapper_HAND_only import ready_arguments +from manopth import rodrigues_layer, rotproj, rot6d +from manopth.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, + subtract_flat_id, make_list) + + +class ManoLayer(Module): + __constants__ = [ + 'use_pca', 'rot', 'ncomps', 'ncomps', 'kintree_parents', 'check', + 'side', 'center_idx', 'joint_rot_mode' + ] + + def __init__(self, + center_idx=None, + flat_hand_mean=True, + ncomps=6, + side='right', + mano_root='mano/models', + use_pca=True, + root_rot_mode='axisang', + joint_rot_mode='axisang', + robust_rot=False): + """ + Args: + center_idx: index of center joint in our computations, + if -1 centers on estimate of palm as middle of base + of middle finger and wrist + flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match + flat hand, else match average hand pose + mano_root: path to MANO pkl files for left and right hand + ncomps: number of PCA components form pose space (<45) + side: 'right' or 'left' + use_pca: Use PCA decomposition for pose space. 
+ joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca + """ + super().__init__() + + self.center_idx = center_idx + self.robust_rot = robust_rot + if root_rot_mode == 'axisang': + self.rot = 3 + else: + self.rot = 6 + self.flat_hand_mean = flat_hand_mean + self.side = side + self.use_pca = use_pca + self.joint_rot_mode = joint_rot_mode + self.root_rot_mode = root_rot_mode + if use_pca: + self.ncomps = ncomps + else: + self.ncomps = 45 + + if side == 'right': + self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl') + elif side == 'left': + self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl') + + smpl_data = ready_arguments(self.mano_path) + + hands_components = smpl_data['hands_components'] + + self.smpl_data = smpl_data + + self.register_buffer('th_betas', + torch.Tensor(smpl_data['betas']).unsqueeze(0)) + self.register_buffer('th_shapedirs', + torch.Tensor(smpl_data['shapedirs'])) + self.register_buffer('th_posedirs', + torch.Tensor(smpl_data['posedirs'])) + self.register_buffer( + 'th_v_template', + torch.Tensor(smpl_data['v_template']).unsqueeze(0)) + self.register_buffer( + 'th_J_regressor', + torch.Tensor(np.array(smpl_data['J_regressor'].toarray()))) + self.register_buffer('th_weights', + torch.Tensor(smpl_data['weights'])) + self.register_buffer('th_faces', + torch.Tensor(smpl_data['f'].astype(np.int32)).long()) + + # Get hand mean + hands_mean = np.zeros(hands_components.shape[1] + ) if flat_hand_mean else smpl_data['hands_mean'] + hands_mean = hands_mean.copy() + th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0) + if self.use_pca or self.joint_rot_mode == 'axisang': + # Save as axis-angle + self.register_buffer('th_hands_mean', th_hands_mean) + selected_components = hands_components[:ncomps] + self.register_buffer('th_comps', torch.Tensor(hands_components)) + self.register_buffer('th_selected_comps', + torch.Tensor(selected_components)) + else: + th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues( + th_hands_mean.view(15, 3)).reshape(15, 3, 3) + self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat) + + # Kinematic chain params + self.kintree_table = smpl_data['kintree_table'] + parents = list(self.kintree_table[0].tolist()) + self.kintree_parents = parents + + def forward(self, + th_pose_coeffs, + th_betas=torch.zeros(1), + th_trans=torch.zeros(1), + root_palm=torch.Tensor([0]), + share_betas=torch.Tensor([0]), + ): + """ + Args: + th_trans (Tensor (batch_size x ncomps)): if provided, applies trans to joints and vertices + th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters for hand shape + else centers on root joint (9th joint) + root_palm: return palm as hand root instead of wrist + """ + # if len(th_pose_coeffs) == 0: + # return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0) + + batch_size = th_pose_coeffs.shape[0] + # Get axis angle from PCA components and coefficients + if self.use_pca or self.joint_rot_mode == 'axisang': + # Remove global rot coeffs + th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot + + self.ncomps] + if self.use_pca: + # PCA components --> axis angles + th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps) + else: + th_full_hand_pose = th_hand_pose_coeffs + + # Concatenate back global rot + th_full_pose = torch.cat([ + th_pose_coeffs[:, :self.rot], + self.th_hands_mean + th_full_hand_pose + ], 1) + if self.root_rot_mode == 'axisang': + # compute rotation matrixes from axis-angle while skipping global rotation + th_pose_map, th_rot_map = 
th_posemap_axisang(th_full_pose)
+ root_rot = th_rot_map[:, :9].view(batch_size, 3, 3)
+ th_rot_map = th_rot_map[:, 9:]
+ th_pose_map = th_pose_map[:, 9:]
+ else:
+ # th_posemap offsets by 3, so add offset of 3 to get to self.rot=6
+ th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 6:])
+ if self.robust_rot:
+ root_rot = rot6d.robust_compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
+ else:
+ root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
+ else:
+ assert th_pose_coeffs.dim() == 4, (
+ 'When not self.use_pca, '
+ 'th_pose_coeffs should have 4 dims, got {}'.format(
+ th_pose_coeffs.dim()))
+ assert th_pose_coeffs.shape[2:4] == (3, 3), (
+ 'When not self.use_pca, th_pose_coeffs should have a 3x3 matrix for the '
+ 'last two dims, got {}'.format(th_pose_coeffs.shape[2:4]))
+ th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs)
+ th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1)
+ th_pose_map = subtract_flat_id(th_rot_map)
+ root_rot = th_pose_rots[:, 0]
+
+ # Full axis angle representation with root joint
+ if th_betas is None or th_betas.numel() == 1:
+ th_v_shaped = torch.matmul(self.th_shapedirs,
+ self.th_betas.transpose(1, 0)).permute(
+ 2, 0, 1) + self.th_v_template
+ th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
+ batch_size, 1, 1)
+
+ else:
+ if share_betas:
+ th_betas = th_betas.mean(0, keepdim=True).expand(th_betas.shape[0], 10)
+ th_v_shaped = torch.matmul(self.th_shapedirs,
+ th_betas.transpose(1, 0)).permute(
+ 2, 0, 1) + self.th_v_template
+ th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
+ # th_pose_map should have shape 20x135
+
+ th_v_posed = th_v_shaped + torch.matmul(
+ self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
+ # Final T pose with transformation done!
+ + # Global rigid transformation + + root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1) + root_trans = th_with_zeros(torch.cat([root_rot, root_j], 2)) + + all_rots = th_rot_map.view(th_rot_map.shape[0], 15, 3, 3) + lev1_idxs = [1, 4, 7, 10, 13] + lev2_idxs = [2, 5, 8, 11, 14] + lev3_idxs = [3, 6, 9, 12, 15] + lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]] + lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]] + lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]] + lev1_j = th_j[:, lev1_idxs] + lev2_j = th_j[:, lev2_idxs] + lev3_j = th_j[:, lev3_idxs] + + # From base to tips + # Get lev1 results + all_transforms = [root_trans.unsqueeze(1)] + lev1_j_rel = lev1_j - root_j.transpose(1, 2) + lev1_rel_transform_flt = th_with_zeros(torch.cat([lev1_rots, lev1_j_rel.unsqueeze(3)], 3).view(-1, 3, 4)) + root_trans_flt = root_trans.unsqueeze(1).repeat(1, 5, 1, 1).view(root_trans.shape[0] * 5, 4, 4) + lev1_flt = torch.matmul(root_trans_flt, lev1_rel_transform_flt) + all_transforms.append(lev1_flt.view(all_rots.shape[0], 5, 4, 4)) + + # Get lev2 results + lev2_j_rel = lev2_j - lev1_j + lev2_rel_transform_flt = th_with_zeros(torch.cat([lev2_rots, lev2_j_rel.unsqueeze(3)], 3).view(-1, 3, 4)) + lev2_flt = torch.matmul(lev1_flt, lev2_rel_transform_flt) + all_transforms.append(lev2_flt.view(all_rots.shape[0], 5, 4, 4)) + + # Get lev3 results + lev3_j_rel = lev3_j - lev2_j + lev3_rel_transform_flt = th_with_zeros(torch.cat([lev3_rots, lev3_j_rel.unsqueeze(3)], 3).view(-1, 3, 4)) + lev3_flt = torch.matmul(lev2_flt, lev3_rel_transform_flt) + all_transforms.append(lev3_flt.view(all_rots.shape[0], 5, 4, 4)) + + reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15] + th_results = torch.cat(all_transforms, 1)[:, reorder_idxs] + th_results_global = th_results + + joint_js = torch.cat([th_j, th_j.new_zeros(th_j.shape[0], 16, 1)], 2) + tmp2 = torch.matmul(th_results, joint_js.unsqueeze(3)) + th_results2 = (th_results - torch.cat([tmp2.new_zeros(*tmp2.shape[:2], 4, 3), tmp2], 3)).permute(0, 2, 3, 1) + + th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1)) + + th_rest_shape_h = torch.cat([ + th_v_posed.transpose(2, 1), + torch.ones((batch_size, 1, th_v_posed.shape[1]), + dtype=th_T.dtype, + device=th_T.device), + ], 1) + + th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1) + th_verts = th_verts[:, :, :3] + th_jtr = th_results_global[:, :, :3, 3] + # In addition to MANO reference joints we sample vertices on each finger + # to serve as finger tips + if self.side == 'right': + tips = th_verts[:, [745, 317, 444, 556, 673]] + else: + tips = th_verts[:, [745, 317, 445, 556, 673]] + if bool(root_palm): + palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2 + th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1) + th_jtr = torch.cat([th_jtr, tips], 1) + + # Reorder joints to match visualization utilities + th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]] + + if th_trans is None or bool(torch.norm(th_trans) == 0): + if self.center_idx is not None: + center_joint = th_jtr[:, self.center_idx].unsqueeze(1) + th_jtr = th_jtr - center_joint + th_verts = th_verts - center_joint + else: + th_jtr = th_jtr + th_trans.unsqueeze(1) + th_verts = th_verts + th_trans.unsqueeze(1) + + # Scale to milimeters + th_verts = th_verts * 1000 + th_jtr = th_jtr * 1000 + return th_verts, th_jtr diff --git a/manopth/posemapper.py b/manopth/posemapper.py new file mode 100644 index 0000000..9b86ea0 --- /dev/null +++ 
b/manopth/posemapper.py @@ -0,0 +1,37 @@ +''' +Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. + +''' + + +import numpy as np +import cv2 + +def lrotmin(p): + if isinstance(p, np.ndarray): + p = p.ravel()[3:] + return np.concatenate( + [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel() + for pp in p.reshape((-1, 3))]).ravel() + + +def posemap(s): + if s == 'lrotmin': + return lrotmin + else: + raise Exception('Unknown posemapping: %s' % (str(s), )) \ No newline at end of file diff --git a/manopth/rodrigues_layer.py b/manopth/rodrigues_layer.py new file mode 100644 index 0000000..bb5ac1e --- /dev/null +++ b/manopth/rodrigues_layer.py @@ -0,0 +1,89 @@ +""" +This part reuses code from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py +which is part of a PyTorch port of SMPL. +Thanks to Zhang Xiong (MandyMo) for making this great code available on github ! +""" + +import argparse +from torch.autograd import gradcheck +import torch +from torch.autograd import Variable + +from manopth import argutils + + +def quat2mat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [batch_size, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + batch_size = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(batch_size, 3, 3) + return rotMat + + +def batch_rodrigues(axisang): + #axisang N x 3 + axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(axisang_norm, -1) + axisang_normalized = torch.div(axisang, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1) + rot_mat = quat2mat(quat) + rot_mat = rot_mat.view(rot_mat.shape[0], 9) + return rot_mat + + +def th_get_axis_angle(vector): + angle = torch.norm(vector, 2, 1) + axes = vector / angle.unsqueeze(1) + return axes, angle + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=1, type=int) + parser.add_argument('--cuda', action='store_true') + args = parser.parse_args() + + argutils.print_args(args) + + n_components = 6 + rot = 3 + inputs = torch.rand(args.batch_size, rot) + inputs_var = Variable(inputs.double(), requires_grad=True) + if args.cuda: + inputs = inputs.cuda() + # outputs = batch_rodrigues(inputs) + test_function = gradcheck(batch_rodrigues, 
(inputs_var, )) + print('batch test passed !') + + inputs = torch.rand(rot) + inputs_var = Variable(inputs.double(), requires_grad=True) + test_function = gradcheck(th_cv2_rod_sub_id.apply, (inputs_var, )) + print('th_cv2_rod test passed') + + inputs = torch.rand(rot) + inputs_var = Variable(inputs.double(), requires_grad=True) + test_th = gradcheck(th_cv2_rod.apply, (inputs_var, )) + print('th_cv2_rod_id test passed !') diff --git a/manopth/rot6d.py b/manopth/rot6d.py new file mode 100644 index 0000000..c1d60ef --- /dev/null +++ b/manopth/rot6d.py @@ -0,0 +1,71 @@ +import torch + + +def compute_rotation_matrix_from_ortho6d(poses): + """ + Code from + https://github.com/papagina/RotationContinuity + On the Continuity of Rotation Representations in Neural Networks + Zhou et al. CVPR19 + https://zhouyisjtu.github.io/project_rotation/rotation.html + """ + x_raw = poses[:, 0:3] # batch*3 + y_raw = poses[:, 3:6] # batch*3 + + x = normalize_vector(x_raw) # batch*3 + z = cross_product(x, y_raw) # batch*3 + z = normalize_vector(z) # batch*3 + y = cross_product(z, x) # batch*3 + + x = x.view(-1, 3, 1) + y = y.view(-1, 3, 1) + z = z.view(-1, 3, 1) + matrix = torch.cat((x, y, z), 2) # batch*3*3 + return matrix + +def robust_compute_rotation_matrix_from_ortho6d(poses): + """ + Instead of making 2nd vector orthogonal to first + create a base that takes into account the two predicted + directions equally + """ + x_raw = poses[:, 0:3] # batch*3 + y_raw = poses[:, 3:6] # batch*3 + + x = normalize_vector(x_raw) # batch*3 + y = normalize_vector(y_raw) # batch*3 + middle = normalize_vector(x + y) + orthmid = normalize_vector(x - y) + x = normalize_vector(middle + orthmid) + y = normalize_vector(middle - orthmid) + # Their scalar product should be small ! + # assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001 + z = normalize_vector(cross_product(x, y)) + + x = x.view(-1, 3, 1) + y = y.view(-1, 3, 1) + z = z.view(-1, 3, 1) + matrix = torch.cat((x, y, z), 2) # batch*3*3 + # Check for reflection in matrix ! 
If found, flip last vector TODO + assert (torch.stack([torch.det(mat) for mat in matrix ])< 0).sum() == 0 + return matrix + + +def normalize_vector(v): + batch = v.shape[0] + v_mag = torch.sqrt(v.pow(2).sum(1)) # batch + v_mag = torch.max(v_mag, v.new([1e-8])) + v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1]) + v = v/v_mag + return v + + +def cross_product(u, v): + batch = u.shape[0] + i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1] + j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2] + k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0] + + out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1) + + return out diff --git a/manopth/rotproj.py b/manopth/rotproj.py new file mode 100644 index 0000000..91a601d --- /dev/null +++ b/manopth/rotproj.py @@ -0,0 +1,21 @@ +import torch + + +def batch_rotprojs(batches_rotmats): + proj_rotmats = [] + for batch_idx, batch_rotmats in enumerate(batches_rotmats): + proj_batch_rotmats = [] + for rot_idx, rotmat in enumerate(batch_rotmats): + # GPU implementation of svd is VERY slow + # ~ 2 10^-3 per hit vs 5 10^-5 on cpu + U, S, V = rotmat.cpu().svd() + rotmat = torch.matmul(U, V.transpose(0, 1)) + orth_det = rotmat.det() + # Remove reflection + if orth_det < 0: + rotmat[:, 2] = -1 * rotmat[:, 2] + + rotmat = rotmat.cuda() + proj_batch_rotmats.append(rotmat) + proj_rotmats.append(torch.stack(proj_batch_rotmats)) + return torch.stack(proj_rotmats) diff --git a/manopth/smpl_handpca_wrapper_HAND_only.py b/manopth/smpl_handpca_wrapper_HAND_only.py new file mode 100644 index 0000000..bd54652 --- /dev/null +++ b/manopth/smpl_handpca_wrapper_HAND_only.py @@ -0,0 +1,155 @@ +''' +Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. 
+ +''' + +def col(A): + return A.reshape((-1, 1)) + +def MatVecMult(mtx, vec): + result = mtx.dot(col(vec.ravel())).ravel() + if len(vec.shape) > 1 and vec.shape[1] > 1: + result = result.reshape((-1, vec.shape[1])) + return result + +def ready_arguments(fname_or_dict, posekey4vposed='pose'): + import numpy as np + import pickle + from manopth.posemapper import posemap + + if not isinstance(fname_or_dict, dict): + dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1') + # dd = pickle.load(open(fname_or_dict, 'rb')) + else: + dd = fname_or_dict + + want_shapemodel = 'shapedirs' in dd + nposeparms = dd['kintree_table'].shape[1] * 3 + + if 'trans' not in dd: + dd['trans'] = np.zeros(3) + if 'pose' not in dd: + dd['pose'] = np.zeros(nposeparms) + if 'shapedirs' in dd and 'betas' not in dd: + dd['betas'] = np.zeros(dd['shapedirs'].shape[-1]) + + for s in [ + 'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs', + 'betas', 'J' + ]: + if (s in dd) and not hasattr(dd[s], 'dterms'): + dd[s] = np.array(dd[s]) + + assert (posekey4vposed in dd) + if want_shapemodel: + dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template'] + v_shaped = dd['v_shaped'] + J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0]) + J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1]) + J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2]) + dd['J'] = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T + pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed]) + dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res) + else: + pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed]) + dd_add = dd['posedirs'].dot(pose_map_res) + dd['v_posed'] = dd['v_template'] + dd_add + + return dd + + +def load_model(fname_or_dict, ncomps=6, flat_hand_mean=False, v_template=None): + ''' This model loads the fully articulable HAND SMPL model, + and replaces the pose DOFS by ncomps from PCA''' + + from manopth.verts import verts_core + import numpy as np + import pickle + import scipy.sparse as sp + np.random.seed(1) + + if not isinstance(fname_or_dict, dict): + smpl_data = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1') + # smpl_data = pickle.load(open(fname_or_dict, 'rb')) + else: + smpl_data = fname_or_dict + + rot = 3 # for global orientation!!! 
+ + hands_components = smpl_data['hands_components'] + hands_mean = np.zeros(hands_components.shape[ + 1]) if flat_hand_mean else smpl_data['hands_mean'] + hands_coeffs = smpl_data['hands_coeffs'][:, :ncomps] + + selected_components = np.vstack((hands_components[:ncomps])) + hands_mean = hands_mean.copy() + + pose_coeffs = np.zeros(rot + selected_components.shape[0]) + full_hand_pose = pose_coeffs[rot:(rot + ncomps)].dot(selected_components) + + smpl_data['fullpose'] = np.concatenate((pose_coeffs[:rot], + hands_mean + full_hand_pose)) + smpl_data['pose'] = pose_coeffs + + Jreg = smpl_data['J_regressor'] + if not sp.issparse(Jreg): + smpl_data['J_regressor'] = (sp.csc_matrix( + (Jreg.data, (Jreg.row, Jreg.col)), shape=Jreg.shape)) + + # slightly modify ready_arguments to make sure that it uses the fullpose + # (which will NOT be pose) for the computation of posedirs + dd = ready_arguments(smpl_data, posekey4vposed='fullpose') + + # create the smpl formula with the fullpose, + # but expose the PCA coefficients as smpl.pose for compatibility + args = { + 'pose': dd['fullpose'], + 'v': dd['v_posed'], + 'J': dd['J'], + 'weights': dd['weights'], + 'kintree_table': dd['kintree_table'], + 'xp': np, + 'want_Jtr': True, + 'bs_style': dd['bs_style'], + } + + result_previous, meta = verts_core(**args) + + result = result_previous + dd['trans'].reshape((1, 3)) + result.no_translation = result_previous + + if meta is not None: + for field in ['Jtr', 'A', 'A_global', 'A_weighted']: + if (hasattr(meta, field)): + setattr(result, field, getattr(meta, field)) + + setattr(result, 'Jtr', meta) + if hasattr(result, 'Jtr'): + result.J_transformed = result.Jtr + dd['trans'].reshape((1, 3)) + + for k, v in dd.items(): + setattr(result, k, v) + + if v_template is not None: + result.v_template[:] = v_template + + return result + + +if __name__ == '__main__': + load_model() \ No newline at end of file diff --git a/manopth/tensutils.py b/manopth/tensutils.py new file mode 100644 index 0000000..0c64c78 --- /dev/null +++ b/manopth/tensutils.py @@ -0,0 +1,47 @@ +import torch + +from manopth import rodrigues_layer + + +def th_posemap_axisang(pose_vectors): + rot_nb = int(pose_vectors.shape[1] / 3) + pose_vec_reshaped = pose_vectors.contiguous().view(-1, 3) + rot_mats = rodrigues_layer.batch_rodrigues(pose_vec_reshaped) + rot_mats = rot_mats.view(pose_vectors.shape[0], rot_nb * 9) + pose_maps = subtract_flat_id(rot_mats) + return pose_maps, rot_mats + + +def th_with_zeros(tensor): + batch_size = tensor.shape[0] + padding = tensor.new([0.0, 0.0, 0.0, 1.0]) + padding.requires_grad = False + + concat_list = [tensor, padding.view(1, 1, 4).repeat(batch_size, 1, 1)] + cat_res = torch.cat(concat_list, 1) + return cat_res + + +def th_pack(tensor): + batch_size = tensor.shape[0] + padding = tensor.new_zeros((batch_size, 4, 3)) + padding.requires_grad = False + pack_list = [padding, tensor] + pack_res = torch.cat(pack_list, 2) + return pack_res + + +def subtract_flat_id(rot_mats): + # Subtracts identity as a flattened tensor + rot_nb = int(rot_mats.shape[1] / 9) + id_flat = torch.eye( + 3, dtype=rot_mats.dtype, device=rot_mats.device).view(1, 9).repeat( + rot_mats.shape[0], rot_nb) + # id_flat.requires_grad = False + results = rot_mats - id_flat + return results + + +def make_list(tensor): + # type: (List[int]) -> List[int] + return tensor diff --git a/manopth/verts.py b/manopth/verts.py new file mode 100644 index 0000000..7a9e5c3 --- /dev/null +++ b/manopth/verts.py @@ -0,0 +1,117 @@ +''' +Copyright 2017 Javier Romero, Dimitrios 
Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. + +''' + + +import numpy as np +import mano.webuser.lbs as lbs +from mano.webuser.posemapper import posemap +import scipy.sparse as sp + + +def ischumpy(x): + return hasattr(x, 'dterms') + + +def verts_decorated(trans, + pose, + v_template, + J_regressor, + weights, + kintree_table, + bs_style, + f, + bs_type=None, + posedirs=None, + betas=None, + shapedirs=None, + want_Jtr=False): + + for which in [ + trans, pose, v_template, weights, posedirs, betas, shapedirs + ]: + if which is not None: + assert ischumpy(which) + + v = v_template + + if shapedirs is not None: + if betas is None: + betas = np.zeros(shapedirs.shape[-1]) + v_shaped = v + shapedirs.dot(betas) + else: + v_shaped = v + + if posedirs is not None: + v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose)) + else: + v_posed = v_shaped + + v = v_posed + + if sp.issparse(J_regressor): + J_tmpx = np.matmul(J_regressor, v_shaped[:, 0]) + J_tmpy = np.matmul(J_regressor, v_shaped[:, 1]) + J_tmpz = np.matmul(J_regressor, v_shaped[:, 2]) + J = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T + else: + assert (ischumpy(J)) + + assert (bs_style == 'lbs') + result, Jtr = lbs.verts_core( + pose, v, J, weights, kintree_table, want_Jtr=True, xp=np) + + tr = trans.reshape((1, 3)) + result = result + tr + Jtr = Jtr + tr + + result.trans = trans + result.f = f + result.pose = pose + result.v_template = v_template + result.J = J + result.J_regressor = J_regressor + result.weights = weights + result.kintree_table = kintree_table + result.bs_style = bs_style + result.bs_type = bs_type + if posedirs is not None: + result.posedirs = posedirs + result.v_posed = v_posed + if shapedirs is not None: + result.shapedirs = shapedirs + result.betas = betas + result.v_shaped = v_shaped + if want_Jtr: + result.J_transformed = Jtr + return result + + +def verts_core(pose, + v, + J, + weights, + kintree_table, + bs_style, + want_Jtr=False, + xp=np): + + assert (bs_style == 'lbs') + result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp) + return result \ No newline at end of file diff --git a/mesh_graphormer/__init__.py b/mesh_graphormer/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/mesh_graphormer/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/mesh_graphormer/datasets/__init__.py b/mesh_graphormer/datasets/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/mesh_graphormer/datasets/__init__.py @@ -0,0 +1 @@ + diff --git a/mesh_graphormer/datasets/build.py b/mesh_graphormer/datasets/build.py new file mode 100644 index 0000000..16477ea --- /dev/null +++ b/mesh_graphormer/datasets/build.py @@ -0,0 +1,147 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +""" + + +import os.path as op +import torch +import logging +import code +from mesh_graphormer.utils.comm import get_world_size +from mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset) +from mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset) + + +def build_dataset(yaml_file, args, is_train=True, scale_factor=1): + print(yaml_file) + if not op.isfile(yaml_file): + yaml_file = op.join(args.data_dir, yaml_file) + # code.interact(local=locals()) + assert op.isfile(yaml_file) + return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor) + + +class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler): + """ + Wraps a BatchSampler, resampling from it until + a specified number of iterations have been sampled + """ + + def __init__(self, batch_sampler, num_iterations, start_iter=0): + self.batch_sampler = batch_sampler + self.num_iterations = num_iterations + self.start_iter = start_iter + + def __iter__(self): + iteration = self.start_iter + while iteration <= self.num_iterations: + # if the underlying sampler has a set_epoch method, like + # DistributedSampler, used for making each process see + # a different split of the dataset, then set it + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(iteration) + for batch in self.batch_sampler: + iteration += 1 + if iteration > self.num_iterations: + break + yield batch + + def __len__(self): + return self.num_iterations + + +def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0): + batch_sampler = torch.utils.data.sampler.BatchSampler( + sampler, images_per_gpu, drop_last=False + ) + if num_iters is not None and num_iters >= 0: + batch_sampler = IterationBasedBatchSampler( + batch_sampler, num_iters, start_iter + ) + return batch_sampler + + +def make_data_sampler(dataset, shuffle, distributed): + if distributed: + return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) + if shuffle: + sampler = torch.utils.data.sampler.RandomSampler(dataset) + else: + sampler = torch.utils.data.sampler.SequentialSampler(dataset) + return sampler + + +def make_data_loader(args, yaml_file, is_distributed=True, + is_train=True, start_iter=0, scale_factor=1): + + dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor) + logger = logging.getLogger(__name__) + if is_train==True: + shuffle = True + images_per_gpu = args.per_gpu_train_batch_size + images_per_batch = images_per_gpu * get_world_size() + iters_per_batch = len(dataset) // images_per_batch + num_iters = iters_per_batch * args.num_train_epochs + logger.info("Train with {} images per GPU.".format(images_per_gpu)) + logger.info("Total batch size {}".format(images_per_batch)) + logger.info("Total training steps {}".format(num_iters)) + else: + shuffle = False + images_per_gpu = args.per_gpu_eval_batch_size + num_iters = None + start_iter = 0 + + sampler = make_data_sampler(dataset, shuffle, is_distributed) + batch_sampler = make_batch_data_sampler( + sampler, images_per_gpu, num_iters, start_iter + ) + data_loader = torch.utils.data.DataLoader( + dataset, num_workers=args.num_workers, batch_sampler=batch_sampler, + pin_memory=True, + ) + return data_loader + + +#============================================================================================== + +def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1): + print(yaml_file) + if not op.isfile(yaml_file): + yaml_file = 
op.join(args.data_dir, yaml_file) + # code.interact(local=locals()) + assert op.isfile(yaml_file) + return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor) + + +def make_hand_data_loader(args, yaml_file, is_distributed=True, + is_train=True, start_iter=0, scale_factor=1): + + dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor) + logger = logging.getLogger(__name__) + if is_train==True: + shuffle = True + images_per_gpu = args.per_gpu_train_batch_size + images_per_batch = images_per_gpu * get_world_size() + iters_per_batch = len(dataset) // images_per_batch + num_iters = iters_per_batch * args.num_train_epochs + logger.info("Train with {} images per GPU.".format(images_per_gpu)) + logger.info("Total batch size {}".format(images_per_batch)) + logger.info("Total training steps {}".format(num_iters)) + else: + shuffle = False + images_per_gpu = args.per_gpu_eval_batch_size + num_iters = None + start_iter = 0 + + sampler = make_data_sampler(dataset, shuffle, is_distributed) + batch_sampler = make_batch_data_sampler( + sampler, images_per_gpu, num_iters, start_iter + ) + data_loader = torch.utils.data.DataLoader( + dataset, num_workers=args.num_workers, batch_sampler=batch_sampler, + pin_memory=True, + ) + return data_loader + diff --git a/mesh_graphormer/datasets/hand_mesh_tsv.py b/mesh_graphormer/datasets/hand_mesh_tsv.py new file mode 100644 index 0000000..4f4a46f --- /dev/null +++ b/mesh_graphormer/datasets/hand_mesh_tsv.py @@ -0,0 +1,334 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + + +import cv2 +import math +import json +from PIL import Image +import os.path as op +import numpy as np +import code + +from mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile +from mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml +from mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa +import torch +import torchvision.transforms as transforms + + +class HandMeshTSVDataset(object): + def __init__(self, args, img_file, label_file=None, hw_file=None, + linelist_file=None, is_train=True, cv2_output=False, scale_factor=1): + + self.args = args + self.img_file = img_file + self.label_file = label_file + self.hw_file = hw_file + self.linelist_file = linelist_file + self.img_tsv = self.get_tsv_file(img_file) + self.label_tsv = None if label_file is None else self.get_tsv_file(label_file) + self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file) + + if self.is_composite: + assert op.isfile(self.linelist_file) + self.line_list = [i for i in range(self.hw_tsv.num_rows())] + else: + self.line_list = load_linelist_file(linelist_file) + + self.cv2_output = cv2_output + self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.is_train = is_train + self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor] + self.noise_factor = 0.4 + self.rot_factor = 90 # Random rotation in the range [-rot_factor, rot_factor] + self.img_res = 224 + self.image_keys = self.prepare_image_keys() + self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', + 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') + self.root_index = 
self.joints_definition.index('Wrist') + + def get_tsv_file(self, tsv_file): + if tsv_file: + if self.is_composite: + return CompositeTSVFile(tsv_file, self.linelist_file, + root=self.root) + tsv_path = find_file_path_in_yaml(tsv_file, self.root) + return TSVFile(tsv_path) + + def get_valid_tsv(self): + # sorted by file size + if self.hw_tsv: + return self.hw_tsv + if self.label_tsv: + return self.label_tsv + + def prepare_image_keys(self): + tsv = self.get_valid_tsv() + return [tsv.get_key(i) for i in range(tsv.num_rows())] + + def prepare_image_key_to_index(self): + tsv = self.get_valid_tsv() + return {tsv.get_key(i) : i for i in range(tsv.num_rows())} + + + def augm_params(self): + """Get augmentation parameters.""" + flip = 0 # flipping + pn = np.ones(3) # per channel pixel-noise + + if self.args.multiscale_inference == False: + rot = 0 # rotation + sc = 1.0 # scaling + elif self.args.multiscale_inference == True: + rot = self.args.rot + sc = self.args.sc + + if self.is_train: + sc = 1.0 + # Each channel is multiplied with a number + # in the area [1-opt.noiseFactor,1+opt.noiseFactor] + pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3) + + # The rotation is a number in the area [-2*rotFactor, 2*rotFactor] + rot = min(2*self.rot_factor, + max(-2*self.rot_factor, np.random.randn()*self.rot_factor)) + + # The scale is multiplied with a number + # in the area [1-scaleFactor,1+scaleFactor] + sc = min(1+self.scale_factor, + max(1-self.scale_factor, np.random.randn()*self.scale_factor+1)) + # but it is zero with probability 3/5 + if np.random.uniform() <= 0.6: + rot = 0 + + return flip, pn, rot, sc + + def rgb_processing(self, rgb_img, center, scale, rot, flip, pn): + """Process rgb image and do augmentation.""" + rgb_img = crop(rgb_img, center, scale, + [self.img_res, self.img_res], rot=rot) + # flip the image + if flip: + rgb_img = flip_img(rgb_img) + # in the rgb image we add pixel noise in a channel-wise manner + rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0])) + rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1])) + rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2])) + # (3,224,224),float,[0,1] + rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0 + return rgb_img + + def j2d_processing(self, kp, center, scale, r, f): + """Process gt 2D keypoints and apply all augmentation transforms.""" + nparts = kp.shape[0] + for i in range(nparts): + kp[i,0:2] = transform(kp[i,0:2]+1, center, scale, + [self.img_res, self.img_res], rot=r) + # convert to normalized coordinates + kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1. 
+ # flip the x coordinates + if f: + kp = flip_kp(kp) + kp = kp.astype('float32') + return kp + + + def j3d_processing(self, S, r, f): + """Process gt 3D keypoints and apply all augmentation transforms.""" + # in-plane rotation + rot_mat = np.eye(3) + if not r == 0: + rot_rad = -r * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1]) + # flip the x coordinates + if f: + S = flip_kp(S) + S = S.astype('float32') + return S + + def pose_processing(self, pose, r, f): + """Process SMPL theta parameters and apply all augmentation transforms.""" + # rotation or the pose parameters + pose = pose.astype('float32') + pose[:3] = rot_aa(pose[:3], r) + # flip the pose parameters + if f: + pose = flip_pose(pose) + # (72),float + pose = pose.astype('float32') + return pose + + def get_line_no(self, idx): + return idx if self.line_list is None else self.line_list[idx] + + def get_image(self, idx): + line_no = self.get_line_no(idx) + row = self.img_tsv[line_no] + # use -1 to support old format with multiple columns. + cv2_im = img_from_base64(row[-1]) + if self.cv2_output: + return cv2_im.astype(np.float32, copy=True) + cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB) + return cv2_im + + def get_annotations(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv[line_no] + annotations = json.loads(row[1]) + return annotations + else: + return [] + + def get_target_from_annotations(self, annotations, img_size, idx): + # This function will be overwritten by each dataset to + # decode the labels to specific formats for each task. + return annotations + + def get_img_info(self, idx): + if self.hw_tsv is not None: + line_no = self.get_line_no(idx) + row = self.hw_tsv[line_no] + try: + # json string format with "height" and "width" being the keys + return json.loads(row[1])[0] + except ValueError: + # list of strings representing height and width in order + hw_str = row[1].split(' ') + hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])} + return hw_dict + + def get_img_key(self, idx): + line_no = self.get_line_no(idx) + # based on the overhead of reading each row. 
+ if self.hw_tsv:
+ return self.hw_tsv[line_no][0]
+ elif self.label_tsv:
+ return self.label_tsv[line_no][0]
+ else:
+ return self.img_tsv[line_no][0]
+
+ def __len__(self):
+ if self.line_list is None:
+ return self.img_tsv.num_rows()
+ else:
+ return len(self.line_list)
+
+ def __getitem__(self, idx):
+
+ img = self.get_image(idx)
+ img_key = self.get_img_key(idx)
+ annotations = self.get_annotations(idx)
+
+ annotations = annotations[0]
+ center = annotations['center']
+ scale = annotations['scale']
+ has_2d_joints = annotations['has_2d_joints']
+ has_3d_joints = annotations['has_3d_joints']
+ joints_2d = np.asarray(annotations['2d_joints'])
+ joints_3d = np.asarray(annotations['3d_joints'])
+
+ if joints_2d.ndim==3:
+ joints_2d = joints_2d[0]
+ if joints_3d.ndim==3:
+ joints_3d = joints_3d[0]
+
+ # Get SMPL parameters, if available
+ has_smpl = np.asarray(annotations['has_smpl'])
+ pose = np.asarray(annotations['pose'])
+ betas = np.asarray(annotations['betas'])
+
+ # Get augmentation parameters
+ flip,pn,rot,sc = self.augm_params()
+
+ # Process image
+ img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
+ img = torch.from_numpy(img).float()
+ # Store image before normalization to use it in visualization
+ transfromed_img = self.normalize_img(img)
+
+ # normalize 3d pose by aligning the wrist as the root (at origin)
+ root_coord = joints_3d[self.root_index,:-1]
+ joints_3d[:,:-1] = joints_3d[:,:-1] - root_coord[None,:]
+ # 3d pose augmentation (random flip + rotation, consistent with image and SMPL)
+ joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
+ # 2d pose augmentation
+ joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
+
+ ###################################
+ # Masking percentage
+ # We observe that 0% or 5% works better for 3D hand mesh
+ # We think this is probably because 3D vertices are quite sparse in the down-sampled hand mesh
+ mvm_percent = 0.0 # or 0.05
+ ###################################
+
+ mjm_mask = np.ones((21,1))
+ if self.is_train:
+ num_joints = 21
+ pb = np.random.random_sample()
+ masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
+ indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
+ mjm_mask[indices,:] = 0.0
+ mjm_mask = torch.from_numpy(mjm_mask).float()
+
+ mvm_mask = np.ones((195,1))
+ if self.is_train:
+ num_vertices = 195
+ pb = np.random.random_sample()
+ masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
+ indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
+ mvm_mask[indices,:] = 0.0
+ mvm_mask = torch.from_numpy(mvm_mask).float()
+
+ meta_data = {}
+ meta_data['ori_img'] = img
+ meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
+ meta_data['betas'] = torch.from_numpy(betas).float()
+ meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
+ meta_data['has_3d_joints'] = has_3d_joints
+ meta_data['has_smpl'] = has_smpl
+ meta_data['mjm_mask'] = mjm_mask
+ meta_data['mvm_mask'] = mvm_mask
+
+ # Get 2D keypoints and apply augmentation transforms
+ meta_data['has_2d_joints'] = has_2d_joints
+ meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
+
+ meta_data['scale'] = float(sc * scale)
+ meta_data['center'] = np.asarray(center).astype(np.float32)
+
+ return img_key, transfromed_img, meta_data
+
+
+class HandMeshTSVYamlDataset(HandMeshTSVDataset):
+ """ TSVDataset
taking a Yaml file for easy function call + """ + def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1): + self.cfg = load_from_yaml_file(yaml_file) + self.is_composite = self.cfg.get('composite', False) + self.root = op.dirname(yaml_file) + + if self.is_composite==False: + img_file = find_file_path_in_yaml(self.cfg['img'], self.root) + label_file = find_file_path_in_yaml(self.cfg.get('label', None), + self.root) + hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + else: + img_file = self.cfg['img'] + hw_file = self.cfg['hw'] + label_file = self.cfg.get('label', None) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + + super(HandMeshTSVYamlDataset, self).__init__( + args, img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor) diff --git a/mesh_graphormer/datasets/human_mesh_tsv.py b/mesh_graphormer/datasets/human_mesh_tsv.py new file mode 100644 index 0000000..ceebd8f --- /dev/null +++ b/mesh_graphormer/datasets/human_mesh_tsv.py @@ -0,0 +1,337 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import cv2 +import math +import json +from PIL import Image +import os.path as op +import numpy as np +import code + +from mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile +from mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml +from mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa +import torch +import torchvision.transforms as transforms + + +class MeshTSVDataset(object): + def __init__(self, img_file, label_file=None, hw_file=None, + linelist_file=None, is_train=True, cv2_output=False, scale_factor=1): + + self.img_file = img_file + self.label_file = label_file + self.hw_file = hw_file + self.linelist_file = linelist_file + self.img_tsv = self.get_tsv_file(img_file) + self.label_tsv = None if label_file is None else self.get_tsv_file(label_file) + self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file) + + if self.is_composite: + assert op.isfile(self.linelist_file) + self.line_list = [i for i in range(self.hw_tsv.num_rows())] + else: + self.line_list = load_linelist_file(linelist_file) + + self.cv2_output = cv2_output + self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.is_train = is_train + self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor] + self.noise_factor = 0.4 + self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor] + self.img_res = 224 + + self.image_keys = self.prepare_image_keys() + + self.joints_definition = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', + 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear') + self.pelvis_index = self.joints_definition.index('Pelvis') + + def get_tsv_file(self, tsv_file): + if tsv_file: + if self.is_composite: + return CompositeTSVFile(tsv_file, self.linelist_file, + root=self.root) + tsv_path = find_file_path_in_yaml(tsv_file, self.root) + return TSVFile(tsv_path) + + def get_valid_tsv(self): + # sorted by file size + if self.hw_tsv: + return self.hw_tsv + if self.label_tsv: + 
return self.label_tsv + + def prepare_image_keys(self): + tsv = self.get_valid_tsv() + return [tsv.get_key(i) for i in range(tsv.num_rows())] + + def prepare_image_key_to_index(self): + tsv = self.get_valid_tsv() + return {tsv.get_key(i) : i for i in range(tsv.num_rows())} + + + def augm_params(self): + """Get augmentation parameters.""" + flip = 0 # flipping + pn = np.ones(3) # per channel pixel-noise + rot = 0 # rotation + sc = 1 # scaling + if self.is_train: + # We flip with probability 1/2 + if np.random.uniform() <= 0.5: + flip = 1 + + # Each channel is multiplied with a number + # in the area [1-opt.noiseFactor,1+opt.noiseFactor] + pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3) + + # The rotation is a number in the area [-2*rotFactor, 2*rotFactor] + rot = min(2*self.rot_factor, + max(-2*self.rot_factor, np.random.randn()*self.rot_factor)) + + # The scale is multiplied with a number + # in the area [1-scaleFactor,1+scaleFactor] + sc = min(1+self.scale_factor, + max(1-self.scale_factor, np.random.randn()*self.scale_factor+1)) + # but it is zero with probability 3/5 + if np.random.uniform() <= 0.6: + rot = 0 + + return flip, pn, rot, sc + + def rgb_processing(self, rgb_img, center, scale, rot, flip, pn): + """Process rgb image and do augmentation.""" + rgb_img = crop(rgb_img, center, scale, + [self.img_res, self.img_res], rot=rot) + # flip the image + if flip: + rgb_img = flip_img(rgb_img) + # in the rgb image we add pixel noise in a channel-wise manner + rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0])) + rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1])) + rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2])) + # (3,224,224),float,[0,1] + rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0 + return rgb_img + + def j2d_processing(self, kp, center, scale, r, f): + """Process gt 2D keypoints and apply all augmentation transforms.""" + nparts = kp.shape[0] + for i in range(nparts): + kp[i,0:2] = transform(kp[i,0:2]+1, center, scale, + [self.img_res, self.img_res], rot=r) + # convert to normalized coordinates + kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1. + # flip the x coordinates + if f: + kp = flip_kp(kp) + kp = kp.astype('float32') + return kp + + def j3d_processing(self, S, r, f): + """Process gt 3D keypoints and apply all augmentation transforms.""" + # in-plane rotation + rot_mat = np.eye(3) + if not r == 0: + rot_rad = -r * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1]) + # flip the x coordinates + if f: + S = flip_kp(S) + S = S.astype('float32') + return S + + def pose_processing(self, pose, r, f): + """Process SMPL theta parameters and apply all augmentation transforms.""" + # rotation or the pose parameters + pose = pose.astype('float32') + pose[:3] = rot_aa(pose[:3], r) + # flip the pose parameters + if f: + pose = flip_pose(pose) + # (72),float + pose = pose.astype('float32') + return pose + + def get_line_no(self, idx): + return idx if self.line_list is None else self.line_list[idx] + + def get_image(self, idx): + line_no = self.get_line_no(idx) + row = self.img_tsv[line_no] + # use -1 to support old format with multiple columns. 
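# Illustrative sketch of the sampling done in augm_params above, rewritten with np.clip;
# the nested min/max calls in the original are equivalent. The factors are the values set
# in __init__ (noise_factor=0.4, rot_factor=30, scale_factor=0.25).
import numpy as np

noise_factor, rot_factor, scale_factor = 0.4, 30.0, 0.25
pn = np.random.uniform(1 - noise_factor, 1 + noise_factor, 3)     # per-channel pixel noise
rot = float(np.clip(np.random.randn() * rot_factor, -2 * rot_factor, 2 * rot_factor))
if np.random.uniform() <= 0.6:                                    # rotation is dropped 60% of the time
    rot = 0.0
sc = float(np.clip(np.random.randn() * scale_factor + 1.0, 1 - scale_factor, 1 + scale_factor))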
+ cv2_im = img_from_base64(row[-1]) + if self.cv2_output: + return cv2_im.astype(np.float32, copy=True) + cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB) + + return cv2_im + + def get_annotations(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv[line_no] + annotations = json.loads(row[1]) + return annotations + else: + return [] + + def get_target_from_annotations(self, annotations, img_size, idx): + # This function will be overwritten by each dataset to + # decode the labels to specific formats for each task. + return annotations + + + def get_img_info(self, idx): + if self.hw_tsv is not None: + line_no = self.get_line_no(idx) + row = self.hw_tsv[line_no] + try: + # json string format with "height" and "width" being the keys + return json.loads(row[1])[0] + except ValueError: + # list of strings representing height and width in order + hw_str = row[1].split(' ') + hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])} + return hw_dict + + def get_img_key(self, idx): + line_no = self.get_line_no(idx) + # based on the overhead of reading each row. + if self.hw_tsv: + return self.hw_tsv[line_no][0] + elif self.label_tsv: + return self.label_tsv[line_no][0] + else: + return self.img_tsv[line_no][0] + + def __len__(self): + if self.line_list is None: + return self.img_tsv.num_rows() + else: + return len(self.line_list) + + def __getitem__(self, idx): + + img = self.get_image(idx) + img_key = self.get_img_key(idx) + annotations = self.get_annotations(idx) + + annotations = annotations[0] + center = annotations['center'] + scale = annotations['scale'] + has_2d_joints = annotations['has_2d_joints'] + has_3d_joints = annotations['has_3d_joints'] + joints_2d = np.asarray(annotations['2d_joints']) + joints_3d = np.asarray(annotations['3d_joints']) + + if joints_2d.ndim==3: + joints_2d = joints_2d[0] + if joints_3d.ndim==3: + joints_3d = joints_3d[0] + + # Get SMPL parameters, if available + has_smpl = np.asarray(annotations['has_smpl']) + pose = np.asarray(annotations['pose']) + betas = np.asarray(annotations['betas']) + + try: + gender = annotations['gender'] + except KeyError: + gender = 'none' + + # Get augmentation parameters + flip,pn,rot,sc = self.augm_params() + + # Process image + img = self.rgb_processing(img, center, sc*scale, rot, flip, pn) + img = torch.from_numpy(img).float() + # Store image before normalization to use it in visualization + transfromed_img = self.normalize_img(img) + + # normalize 3d pose by aligning the pelvis as the root (at origin) + root_pelvis = joints_3d[self.pelvis_index,:-1] + joints_3d[:,:-1] = joints_3d[:,:-1] - root_pelvis[None,:] + # 3d pose augmentation (random flip + rotation, consistent to image and SMPL) + joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip) + # 2d pose augmentation + joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip) + + ################################### + # Masking percantage + # We observe that 30% works better for human body mesh. Further details are reported in the paper. 
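# Illustrative sketch of the masked vertex/joint modeling set up just below: a random
# fraction pb * mvm_percent of the query tokens is zeroed, so with mvm_percent = 0.3 at
# most 30% of the 14 body joints (or of the 431 coarse vertices) are masked in any sample.
import numpy as np

mvm_percent, num_joints = 0.3, 14
pb = np.random.random_sample()                              # uniform in [0, 1)
masked_num = int(pb * mvm_percent * num_joints)             # between 0 and 4 joints here
indices = np.random.choice(np.arange(num_joints), replace=False, size=masked_num)
mjm_mask = np.ones((num_joints, 1))
mjm_mask[indices, :] = 0.0                                  # masked joints get 0, kept joints stay 1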
+ mvm_percent = 0.3 + ################################### + + mjm_mask = np.ones((14,1)) + if self.is_train: + num_joints = 14 + pb = np.random.random_sample() + masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked + indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num) + mjm_mask[indices,:] = 0.0 + mjm_mask = torch.from_numpy(mjm_mask).float() + + mvm_mask = np.ones((431,1)) + if self.is_train: + num_vertices = 431 + pb = np.random.random_sample() + masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked + indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num) + mvm_mask[indices,:] = 0.0 + mvm_mask = torch.from_numpy(mvm_mask).float() + + meta_data = {} + meta_data['ori_img'] = img + meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float() + meta_data['betas'] = torch.from_numpy(betas).float() + meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float() + meta_data['has_3d_joints'] = has_3d_joints + meta_data['has_smpl'] = has_smpl + + meta_data['mjm_mask'] = mjm_mask + meta_data['mvm_mask'] = mvm_mask + + # Get 2D keypoints and apply augmentation transforms + meta_data['has_2d_joints'] = has_2d_joints + meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float() + meta_data['scale'] = float(sc * scale) + meta_data['center'] = np.asarray(center).astype(np.float32) + meta_data['gender'] = gender + return img_key, transfromed_img, meta_data + + + +class MeshTSVYamlDataset(MeshTSVDataset): + """ TSVDataset taking a Yaml file for easy function call + """ + def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1): + self.cfg = load_from_yaml_file(yaml_file) + self.is_composite = self.cfg.get('composite', False) + self.root = op.dirname(yaml_file) + + if self.is_composite==False: + img_file = find_file_path_in_yaml(self.cfg['img'], self.root) + label_file = find_file_path_in_yaml(self.cfg.get('label', None), + self.root) + hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + else: + img_file = self.cfg['img'] + hw_file = self.cfg['hw'] + label_file = self.cfg.get('label', None) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + + super(MeshTSVYamlDataset, self).__init__( + img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor) diff --git a/mesh_graphormer/modeling/__init__.py b/mesh_graphormer/modeling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mesh_graphormer/modeling/_gcnn.py b/mesh_graphormer/modeling/_gcnn.py new file mode 100644 index 0000000..43bfe63 --- /dev/null +++ b/mesh_graphormer/modeling/_gcnn.py @@ -0,0 +1,184 @@ +from __future__ import division +import torch +import torch.nn.functional as F +import numpy as np +import scipy.sparse +import math +from pathlib import Path +data_path = Path(__file__).parent / "data" + + +sparse_to_dense = lambda x: x +device = "cuda" + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. 
+ """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + sparse = sparse.to(device) + dense = dense.to(device) + return SparseMM.apply(sparse, dense) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + +class BertLayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class GraphResBlock(torch.nn.Module): + """ + Graph Residual Block similar to the Bottleneck Residual Block in ResNet + """ + def __init__(self, in_channels, out_channels, mesh_type='body'): + super(GraphResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lin1 = GraphLinear(in_channels, out_channels // 2) + self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type) + self.lin2 = GraphLinear(out_channels // 2, out_channels) + self.skip_conv = GraphLinear(in_channels, out_channels) + # print('Use BertLayerNorm in GraphResBlock') + self.pre_norm = BertLayerNorm(in_channels) + self.norm1 = BertLayerNorm(out_channels // 2) + self.norm2 = BertLayerNorm(out_channels // 2) + + def forward(self, x): + trans_y = F.relu(self.pre_norm(x)).transpose(1,2) + y = self.lin1(trans_y).transpose(1,2) + + y = F.relu(self.norm1(y)) + y = self.conv(y) + + trans_y = F.relu(self.norm2(y)).transpose(1,2) + y = self.lin2(trans_y).transpose(1,2) + + z = x+y + + return z + +# class GraphResBlock(torch.nn.Module): +# """ +# Graph Residual Block similar to the Bottleneck Residual Block in ResNet +# """ +# def __init__(self, in_channels, out_channels, mesh_type='body'): +# super(GraphResBlock, self).__init__() +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type) +# print('Use BertLayerNorm and GeLU in GraphResBlock') +# self.norm = BertLayerNorm(self.out_channels) +# def forward(self, x): +# y = self.conv(x) +# y = self.norm(y) +# y = gelu(y) +# z = x+y +# return z + +class GraphLinear(torch.nn.Module): + """ + Generalization of 1x1 convolutions on Graphs + """ + def __init__(self, in_channels, out_channels): + super(GraphLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels)) + self.b = torch.nn.Parameter(torch.FloatTensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + w_stdv = 1 / (self.in_channels * self.out_channels) + 
self.W.data.uniform_(-w_stdv, w_stdv) + self.b.data.uniform_(-w_stdv, w_stdv) + + def forward(self, x): + return torch.matmul(self.W[None, :], x) + self.b[None, :, None] + +class GraphConvolution(torch.nn.Module): + """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" + def __init__(self, in_features, out_features, mesh='body', bias=True): + super(GraphConvolution, self).__init__() + self.in_features = in_features + self.out_features = out_features + + if mesh=='body': + adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt') + adj_mat_value = torch.load(data_path / 'smpl_431_adjmat_values.pt') + adj_mat_size = torch.load(data_path / 'smpl_431_adjmat_size.pt') + elif mesh=='hand': + adj_indices = torch.load(data_path / 'mano_195_adjmat_indices.pt') + adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt') + adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt') + + self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device) + + self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features)) + if bias: + self.bias = torch.nn.Parameter(torch.FloatTensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + # stdv = 1. / math.sqrt(self.weight.size(1)) + stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, x): + if x.ndimension() == 2: + support = torch.matmul(x, self.weight) + output = torch.matmul(self.adjmat, support) + if self.bias is not None: + output = output + self.bias + return output + else: + output = [] + for i in range(x.shape[0]): + support = torch.matmul(x[i], self.weight) + # output.append(torch.matmul(self.adjmat, support)) + output.append(spmm(self.adjmat, support)) + output = torch.stack(output, dim=0) + if self.bias is not None: + output = output + self.bias + return output + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_features) + ' -> ' \ + + str(self.out_features) + ')' \ No newline at end of file diff --git a/mesh_graphormer/modeling/_mano.py b/mesh_graphormer/modeling/_mano.py new file mode 100644 index 0000000..01d2d67 --- /dev/null +++ b/mesh_graphormer/modeling/_mano.py @@ -0,0 +1,184 @@ +""" +This file contains the MANO defination and mesh sampling operations for MANO mesh + +Adapted from opensource projects +MANOPTH (https://github.com/hassony2/manopth) +Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) +GraphCMR (https://github.com/nkolot/GraphCMR/) +""" + +from __future__ import division +import numpy as np +import torch +import torch.nn as nn +import os.path as osp +import json +import code +from manopth.manolayer import ManoLayer +import scipy.sparse +import mesh_graphormer.modeling.data.config as cfg +from pathlib import Path + + +sparse_to_dense = lambda x: x +device = "cuda" + +class MANO(nn.Module): + def __init__(self): + super(MANO, self).__init__() + + self.mano_dir = str(Path(__file__).parent / "data") + self.layer = self.get_layer() + self.vertex_num = 778 + self.face = self.layer.th_faces.numpy() + self.joint_regressor = self.layer.th_J_regressor.numpy() + + self.joint_num = 21 + self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 
'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') + self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) ) + self.root_joint_idx = self.joints_name.index('Wrist') + + # add fingertips to joint_regressor + self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand) + thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot)) + self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:] + joint_regressor_torch = torch.from_numpy(self.joint_regressor).float() + self.register_buffer('joint_regressor_torch', joint_regressor_torch) + + def get_layer(self): + return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model + + def get_3d_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 778, 3) + Output: + 3D joints: size = (B, 21, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch]) + return joints + + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. 
+ """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + sparse = sparse.to(device) + dense = dense.to(device) + return SparseMM.apply(sparse, dense) + + +def scipy_to_pytorch(A, U, D): + """Convert scipy sparse matrices to pytorch sparse matrix.""" + ptU = [] + ptD = [] + + for i in range(len(U)): + u = scipy.sparse.coo_matrix(U[i]) + i = torch.LongTensor(np.array([u.row, u.col])) + v = torch.FloatTensor(u.data) + ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape))) + + for i in range(len(D)): + d = scipy.sparse.coo_matrix(D[i]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape))) + + return ptU, ptD + + +def adjmat_sparse(adjmat, nsize=1): + """Create row-normalized sparse graph adjacency matrix.""" + adjmat = scipy.sparse.csr_matrix(adjmat) + if nsize > 1: + orig_adjmat = adjmat.copy() + for _ in range(1, nsize): + adjmat = adjmat * orig_adjmat + adjmat.data = np.ones_like(adjmat.data) + for i in range(adjmat.shape[0]): + adjmat[i,i] = 1 + num_neighbors = np.array(1 / adjmat.sum(axis=-1)) + adjmat = adjmat.multiply(num_neighbors) + adjmat = scipy.sparse.coo_matrix(adjmat) + row = adjmat.row + col = adjmat.col + data = adjmat.data + i = torch.LongTensor(np.array([row, col])) + v = torch.from_numpy(data).float() + adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape)) + return adjmat + +def get_graph_params(filename, nsize=1): + """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" + data = np.load(filename, encoding='latin1', allow_pickle=True) + A = data['A'] + U = data['U'] + D = data['D'] + U, D = scipy_to_pytorch(A, U, D) + A = [adjmat_sparse(a, nsize=nsize) for a in A] + return A, U, D + + +class Mesh(object): + """Mesh object that is used for handling certain graph operations.""" + def __init__(self, filename=cfg.MANO_sampling_matrix, + num_downsampling=1, nsize=1, device=torch.device('cuda')): + self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) + # self._A = [a.to(device) for a in self._A] + self._U = [u.to(device) for u in self._U] + self._D = [d.to(device) for d in self._D] + self.num_downsampling = num_downsampling + + def downsample(self, x, n1=0, n2=None): + """Downsample mesh.""" + if n2 is None: + n2 = self.num_downsampling + if x.ndimension() < 3: + for i in range(n1, n2): + x = spmm(self._D[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in range(n1, n2): + y = spmm(self._D[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x + + def upsample(self, x, n1=1, n2=0): + """Upsample mesh.""" + if x.ndimension() < 3: + for i in reversed(range(n2, n1)): + x = spmm(self._U[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in reversed(range(n2, n1)): + y = spmm(self._U[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x diff --git a/mesh_graphormer/modeling/_smpl.py b/mesh_graphormer/modeling/_smpl.py new file mode 100644 index 0000000..15b0f1a --- /dev/null +++ b/mesh_graphormer/modeling/_smpl.py @@ -0,0 +1,283 @@ +""" +This file contains the definition of 
the SMPL model + +It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) +""" +from __future__ import division + +import torch +import torch.nn as nn +import numpy as np +import scipy.sparse +try: + import cPickle as pickle +except ImportError: + import pickle + +from mesh_graphormer.utils.geometric_layers import rodrigues +import mesh_graphormer.modeling.data.config as cfg + + +sparse_to_dense = lambda x: x +device = "cuda" + +class SMPL(nn.Module): + + def __init__(self, gender='neutral'): + super(SMPL, self).__init__() + + if gender=='m': + model_file=cfg.SMPL_Male + elif gender=='f': + model_file=cfg.SMPL_Female + else: + model_file=cfg.SMPL_FILE + + smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1') + J_regressor = smpl_model['J_regressor'].tocoo() + row = J_regressor.row + col = J_regressor.col + data = J_regressor.data + i = torch.LongTensor([row, col]) + v = torch.FloatTensor(data) + J_regressor_shape = [24, 6890] + self.register_buffer('J_regressor', torch.sparse_coo_tensor(i, v, J_regressor_shape).to_dense()) + self.register_buffer('weights', torch.FloatTensor(smpl_model['weights'])) + self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs'])) + self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template'])) + self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs']))) + self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64))) + self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64))) + id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])} + self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])])) + + self.pose_shape = [24, 3] + self.beta_shape = [10] + self.translation_shape = [3] + + self.pose = torch.zeros(self.pose_shape) + self.beta = torch.zeros(self.beta_shape) + self.translation = torch.zeros(self.translation_shape) + + self.verts = None + self.J = None + self.R = None + + J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float() + self.register_buffer('J_regressor_extra', J_regressor_extra) + self.joints_idx = cfg.JOINTS_IDX + + J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float() + self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct) + + + def forward(self, pose, beta): + device = pose.device + batch_size = pose.shape[0] + v_template = self.v_template[None, :] + shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1) + beta = beta[:, :, None] + v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template + # batched sparse matmul not supported in pytorch + J = [] + for i in range(batch_size): + J.append(torch.matmul(self.J_regressor, v_shaped[i])) + J = torch.stack(J, dim=0) + # input it rotmat: (bs,24,3,3) + if pose.ndimension() == 4: + R = pose + # input it rotmat: (bs,72) + elif pose.ndimension() == 2: + pose_cube = pose.view(-1, 3) # (batch_size * 24, 1, 3) + R = rodrigues(pose_cube).view(batch_size, 24, 3, 3) + R = R.view(batch_size, 24, 3, 3) + I_cube = torch.eye(3)[None, None, :].to(device) + # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1) + lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1) + posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1) + v_posed = v_shaped + 
torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3) + J_ = J.clone() + J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :] + G_ = torch.cat([R, J_[:, :, :, None]], dim=-1) + pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1) + G_ = torch.cat([G_, pad_row], dim=2) + G = [G_[:, 0].clone()] + for i in range(1, 24): + G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :])) + G = torch.stack(G, dim=1) + + rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1) + zeros = torch.zeros(batch_size, 24, 4, 3).to(device) + rest = torch.cat([zeros, rest], dim=-1) + rest = torch.matmul(G, rest) + G = G - rest + T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1) + rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1) + v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0] + return v + + def get_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 6890, 3) + Output: + 3D joints: size = (B, 38, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor]) + joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra]) + joints = torch.cat((joints, joints_extra), dim=1) + joints = joints[:, cfg.JOINTS_IDX] + return joints + + def get_h36m_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 6890, 3) + Output: + 3D joints: size = (B, 24, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct]) + return joints + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. 
+ """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + sparse = sparse.to(device) + dense = dense.to(device) + return SparseMM.apply(sparse, dense) + + +def scipy_to_pytorch(A, U, D): + """Convert scipy sparse matrices to pytorch sparse matrix.""" + ptU = [] + ptD = [] + + for i in range(len(U)): + u = scipy.sparse.coo_matrix(U[i]) + i = torch.LongTensor(np.array([u.row, u.col])) + v = torch.FloatTensor(u.data) + ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape))) + + for i in range(len(D)): + d = scipy.sparse.coo_matrix(D[i]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape))) + + return ptU, ptD + + +def adjmat_sparse(adjmat, nsize=1): + """Create row-normalized sparse graph adjacency matrix.""" + adjmat = scipy.sparse.csr_matrix(adjmat) + if nsize > 1: + orig_adjmat = adjmat.copy() + for _ in range(1, nsize): + adjmat = adjmat * orig_adjmat + adjmat.data = np.ones_like(adjmat.data) + for i in range(adjmat.shape[0]): + adjmat[i,i] = 1 + num_neighbors = np.array(1 / adjmat.sum(axis=-1)) + adjmat = adjmat.multiply(num_neighbors) + adjmat = scipy.sparse.coo_matrix(adjmat) + row = adjmat.row + col = adjmat.col + data = adjmat.data + i = torch.LongTensor(np.array([row, col])) + v = torch.from_numpy(data).float() + adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape)) + return adjmat + +def get_graph_params(filename, nsize=1): + """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" + data = np.load(filename, encoding='latin1', allow_pickle=True) + A = data['A'] + U = data['U'] + D = data['D'] + U, D = scipy_to_pytorch(A, U, D) + A = [adjmat_sparse(a, nsize=nsize) for a in A] + return A, U, D + + +class Mesh(object): + """Mesh object that is used for handling certain graph operations.""" + def __init__(self, filename=cfg.SMPL_sampling_matrix, + num_downsampling=1, nsize=1, device=torch.device('cuda')): + self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) + # self._A = [a.to(device) for a in self._A] + self._U = [u.to(device) for u in self._U] + self._D = [d.to(device) for d in self._D] + self.num_downsampling = num_downsampling + + # load template vertices from SMPL and normalize them + smpl = SMPL() + ref_vertices = smpl.v_template + center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] + ref_vertices -= center + ref_vertices /= ref_vertices.abs().max().item() + + self._ref_vertices = ref_vertices.to(device) + self.faces = smpl.faces.int().to(device) + + # @property + # def adjmat(self): + # """Return the graph adjacency matrix at the specified subsampling level.""" + # return self._A[self.num_downsampling].float() + + @property + def ref_vertices(self): + """Return the template vertices at the specified subsampling level.""" + ref_vertices = self._ref_vertices + for i in range(self.num_downsampling): + ref_vertices = torch.spmm(self._D[i], ref_vertices) + return ref_vertices + + def downsample(self, x, n1=0, n2=None): + """Downsample mesh.""" + if n2 is None: + n2 = self.num_downsampling + if x.ndimension() < 3: + for i in range(n1, n2): + x = 
spmm(self._D[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in range(n1, n2): + y = spmm(self._D[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x + + def upsample(self, x, n1=1, n2=0): + """Upsample mesh.""" + if x.ndimension() < 3: + for i in reversed(range(n2, n1)): + x = spmm(self._U[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in reversed(range(n2, n1)): + y = spmm(self._U[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x diff --git a/mesh_graphormer/modeling/bert/__init__.py b/mesh_graphormer/modeling/bert/__init__.py new file mode 100644 index 0000000..197c5b9 --- /dev/null +++ b/mesh_graphormer/modeling/bert/__init__.py @@ -0,0 +1,17 @@ +__version__ = "1.0.0" + +from .modeling_bert import (BertConfig, BertModel, + load_tf_weights_in_bert) + +from .modeling_graphormer import Graphormer + +from .e2e_body_network import Graphormer_Body_Network + +from .e2e_hand_network import Graphormer_Hand_Network + +CONFIG_NAME = "config.json" + +from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME, + PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) + +from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE) diff --git a/mesh_graphormer/modeling/bert/bert-base-uncased/config.json b/mesh_graphormer/modeling/bert/bert-base-uncased/config.json new file mode 100644 index 0000000..7927667 --- /dev/null +++ b/mesh_graphormer/modeling/bert/bert-base-uncased/config.json @@ -0,0 +1,16 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 2, + "vocab_size": 30522 +} diff --git a/mesh_graphormer/modeling/bert/e2e_body_network.py b/mesh_graphormer/modeling/bert/e2e_body_network.py new file mode 100644 index 0000000..c958047 --- /dev/null +++ b/mesh_graphormer/modeling/bert/e2e_body_network.py @@ -0,0 +1,103 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import torch +import mesh_graphormer.modeling.data.config as cfg + +device = "cuda" + +class Graphormer_Body_Network(torch.nn.Module): + ''' + End-to-end Graphormer network for human pose and mesh reconstruction from a single image. 
+ ''' + def __init__(self, args, config, backbone, trans_encoder, mesh_sampler): + super(Graphormer_Body_Network, self).__init__() + self.config = config + self.config.device = device + self.backbone = backbone + self.trans_encoder = trans_encoder + self.upsampling = torch.nn.Linear(431, 1723) + self.upsampling2 = torch.nn.Linear(1723, 6890) + self.cam_param_fc = torch.nn.Linear(3, 1) + self.cam_param_fc2 = torch.nn.Linear(431, 250) + self.cam_param_fc3 = torch.nn.Linear(250, 3) + self.grid_feat_dim = torch.nn.Linear(1024, 2051) + + + def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False): + batch_size = images.size(0) + # Generate T-pose template mesh + template_pose = torch.zeros((1,72)) + template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord + template_pose = template_pose.to(device) + template_betas = torch.zeros((1,10)).to(device) + template_vertices = smpl(template_pose, template_betas) + + # template mesh simplification + template_vertices_sub = mesh_sampler.downsample(template_vertices) + template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2) + + # template mesh-to-joint regression + template_3d_joints = smpl.get_h36m_joints(template_vertices) + template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:] + template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:] + num_joints = template_3d_joints.shape[1] + + # normalize + template_3d_joints = template_3d_joints - template_pelvis[:, None, :] + template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :] + + # concatinate template joints and template vertices, and then duplicate to batch size + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1) + ref_vertices = ref_vertices.expand(batch_size, -1, -1) + + # extract grid features and global image features using a CNN backbone + image_feat, grid_feat = self.backbone(images) + # concatinate image feat and 3d mesh template + image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1) + # process grid features + grid_feat = torch.flatten(grid_feat, start_dim=2) + grid_feat = grid_feat.transpose(1,2) + grid_feat = self.grid_feat_dim(grid_feat) + # concatinate image feat and template mesh to form the joint/vertex queries + features = torch.cat([ref_vertices, image_feat], dim=2) + # prepare input tokens including joint/vertex queries and grid features + features = torch.cat([features, grid_feat],dim=1) + + if is_train==True: + # apply mask vertex/joint modeling + # meta_masks is a tensor of all the masks, randomly generated in dataloader + # we pre-define a [MASK] token, which is a floating-value vector with 0.01s + special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01 + features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks) + + # forward pass + if self.config.output_attentions==True: + features, hidden_states, att = self.trans_encoder(features) + else: + features = self.trans_encoder(features) + + pred_3d_joints = features[:,:num_joints,:] + pred_vertices_sub2 = features[:,num_joints:-49,:] + + # learn camera parameters + x = self.cam_param_fc(pred_vertices_sub2) + x = x.transpose(1,2) + x = self.cam_param_fc2(x) + x = self.cam_param_fc3(x) + cam_param = x.transpose(1,2) + cam_param = cam_param.squeeze() + + temp_transpose = pred_vertices_sub2.transpose(1,2) + pred_vertices_sub = self.upsampling(temp_transpose) + pred_vertices_full = self.upsampling2(pred_vertices_sub) + 
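# Illustrative sketch of the token layout assumed by the slicing above for the body model:
# 14 joint queries + 431 coarse-vertex queries + 49 grid-feature tokens (presumably a 7x7
# backbone map); the final encoder block outputs 3-D coordinates per token.
import torch

num_joints, num_vertices, num_grid = 14, 431, 49
features = torch.randn(2, num_joints + num_vertices + num_grid, 3)   # (B, 494, 3) example
pred_3d_joints = features[:, :num_joints, :]                         # (B, 14, 3)
pred_vertices_sub2 = features[:, num_joints:-num_grid, :]            # (B, 431, 3)
grid_tokens = features[:, -num_grid:, :]                             # (B, 49, 3), not returned
assert pred_vertices_sub2.shape[1] == num_vertices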
pred_vertices_sub = pred_vertices_sub.transpose(1,2) + pred_vertices_full = pred_vertices_full.transpose(1,2) + + if self.config.output_attentions==True: + return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att + else: + return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full \ No newline at end of file diff --git a/mesh_graphormer/modeling/bert/e2e_hand_network.py b/mesh_graphormer/modeling/bert/e2e_hand_network.py new file mode 100644 index 0000000..4dc7385 --- /dev/null +++ b/mesh_graphormer/modeling/bert/e2e_hand_network.py @@ -0,0 +1,94 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import torch +import mesh_graphormer.modeling.data.config as cfg + +device = "cuda" + +class Graphormer_Hand_Network(torch.nn.Module): + ''' + End-to-end Graphormer network for hand pose and mesh reconstruction from a single image. + ''' + def __init__(self, args, config, backbone, trans_encoder): + super(Graphormer_Hand_Network, self).__init__() + self.config = config + self.backbone = backbone + self.trans_encoder = trans_encoder + self.upsampling = torch.nn.Linear(195, 778) + self.cam_param_fc = torch.nn.Linear(3, 1) + self.cam_param_fc2 = torch.nn.Linear(195+21, 150) + self.cam_param_fc3 = torch.nn.Linear(150, 3) + self.grid_feat_dim = torch.nn.Linear(1024, 2051) + + def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False): + batch_size = images.size(0) + # Generate T-pose template mesh + template_pose = torch.zeros((1,48)) + template_pose = template_pose.to(device) + template_betas = torch.zeros((1,10)).to(device) + template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas) + template_vertices = template_vertices/1000.0 + template_3d_joints = template_3d_joints/1000.0 + + template_vertices_sub = mesh_sampler.downsample(template_vertices) + + # normalize + template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:] + template_3d_joints = template_3d_joints - template_root[:, None, :] + template_vertices = template_vertices - template_root[:, None, :] + template_vertices_sub = template_vertices_sub - template_root[:, None, :] + num_joints = template_3d_joints.shape[1] + + # concatinate template joints and template vertices, and then duplicate to batch size + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1) + ref_vertices = ref_vertices.expand(batch_size, -1, -1) + + # extract grid features and global image features using a CNN backbone + image_feat, grid_feat = self.backbone(images) + # concatinate image feat and mesh template + image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1) + # process grid features + grid_feat = torch.flatten(grid_feat, start_dim=2) + grid_feat = grid_feat.transpose(1,2) + grid_feat = self.grid_feat_dim(grid_feat) + # concatinate image feat and template mesh to form the joint/vertex queries + features = torch.cat([ref_vertices, image_feat], dim=2) + # prepare input tokens including joint/vertex queries and grid features + features = torch.cat([features, grid_feat],dim=1) + + if is_train==True: + # apply mask vertex/joint modeling + # meta_masks is a tensor of all the masks, randomly generated in dataloader + # we pre-define a [MASK] token, which is a floating-value vector with 0.01s + special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01 + features[:,:-49,:] = features[:,:-49,:]*meta_masks + 
special_token*(1-meta_masks) + + # forward pass + if self.config.output_attentions==True: + features, hidden_states, att = self.trans_encoder(features) + else: + features = self.trans_encoder(features) + + pred_3d_joints = features[:,:num_joints,:] + pred_vertices_sub = features[:,num_joints:-49,:] + + # learn camera parameters + x = self.cam_param_fc(features[:,:-49,:]) + x = x.transpose(1,2) + x = self.cam_param_fc2(x) + x = self.cam_param_fc3(x) + cam_param = x.transpose(1,2) + cam_param = cam_param.squeeze() + + temp_transpose = pred_vertices_sub.transpose(1,2) + pred_vertices = self.upsampling(temp_transpose) + pred_vertices = pred_vertices.transpose(1,2) + + if self.config.output_attentions==True: + return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att + else: + return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices \ No newline at end of file diff --git a/mesh_graphormer/modeling/bert/file_utils.py b/mesh_graphormer/modeling/bert/file_utils.py new file mode 100644 index 0000000..0b26c2b --- /dev/null +++ b/mesh_graphormer/modeling/bert/file_utils.py @@ -0,0 +1 @@ +from transformers.file_utils import * \ No newline at end of file diff --git a/mesh_graphormer/modeling/bert/modeling_bert.py b/mesh_graphormer/modeling/bert/modeling_bert.py new file mode 100644 index 0000000..820b69a --- /dev/null +++ b/mesh_graphormer/modeling/bert/modeling_bert.py @@ -0,0 +1 @@ +from transformers.models.bert.modeling_bert import * \ No newline at end of file diff --git a/mesh_graphormer/modeling/bert/modeling_graphormer.py b/mesh_graphormer/modeling/bert/modeling_graphormer.py new file mode 100644 index 0000000..6d167c6 --- /dev/null +++ b/mesh_graphormer/modeling/bert/modeling_graphormer.py @@ -0,0 +1,328 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import math +import os +import code +import torch +from torch import nn +from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput +import mesh_graphormer.modeling.data.config as cfg +from mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock +from .modeling_utils import prune_linear_layer +LayerNormClass = torch.nn.LayerNorm +BertLayerNorm = torch.nn.LayerNorm + +device = "cuda" + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None, + history_state=None): + if history_state is not None: + x_states = torch.cat([history_state, hidden_states], dim=1) + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(x_states) + mixed_value_layer = self.value(x_states) + else: + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + return outputs + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + for head in heads: + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + # Update hyper params + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + + def forward(self, input_tensor, attention_mask, head_mask=None, + history_state=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask, + history_state) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class GraphormerLayer(nn.Module): + def __init__(self, config): + super(GraphormerLayer, self).__init__() + self.attention = BertAttention(config) + self.has_graph_conv = config.graph_conv + self.mesh_type = config.mesh_type + + if self.has_graph_conv == True: + self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, + history_state=None): + attention_outputs = self.attention(hidden_states, attention_mask, + head_mask, history_state) + attention_output = attention_outputs[0] + + if self.has_graph_conv==True: + if self.mesh_type == 'body': + joints = attention_output[:,0:14,:] + vertices = attention_output[:,14:-49,:] + img_tokens = attention_output[:,-49:,:] + + elif self.mesh_type == 'hand': + joints = attention_output[:,0:21,:] + vertices = attention_output[:,21:-49,:] + img_tokens = attention_output[:,-49:,:] + + vertices = self.graph_conv(vertices) + joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1) + else: + joints_vertices = attention_output + + intermediate_output = self.intermediate(joints_vertices) + layer_output = self.output(intermediate_output, joints_vertices) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + return outputs + + def forward(self, hidden_states, attention_mask, head_mask=None, + history_state=None): + return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state) + + +class GraphormerEncoder(nn.Module): + def __init__(self, config): + super(GraphormerEncoder, self).__init__() + self.output_attentions = 
config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, head_mask=None, + encoder_history_states=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + history_state = None if encoder_history_states is None else encoder_history_states[i] + layer_outputs = layer_module( + hidden_states, attention_mask, head_mask[i], + history_state) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + + return outputs # outputs, (hidden states), (attentions) + +class EncoderBlock(BertPreTrainedModel): + def __init__(self, config): + super(EncoderBlock, self).__init__(config) + self.config = config + self.embeddings = BertEmbeddings(config) + self.encoder = GraphormerEncoder(config) + self.pooler = BertPooler(config) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.img_dim = config.img_feature_dim + + try: + self.use_img_layernorm = config.use_img_layernorm + except: + self.use_img_layernorm = None + + self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if self.use_img_layernorm: + self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps) + + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. 
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, + position_ids=None, head_mask=None): + + batch_size = len(img_feats) + seq_length = len(img_feats[0]) + input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(device) + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + position_embeddings = self.position_embeddings(position_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + if attention_mask.dim() == 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + elif attention_mask.dim() == 3: + extended_attention_mask = attention_mask.unsqueeze(1) + else: + raise NotImplementedError + + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + # Project input token features to have spcified hidden size + img_embedding_output = self.img_embedding(img_feats) + + # We empirically observe that adding an additional learnable position embedding leads to more stable training + embeddings = position_embeddings + img_embedding_output + + if self.use_img_layernorm: + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + encoder_outputs = self.encoder(embeddings, + extended_attention_mask, head_mask=head_mask) + sequence_output = encoder_outputs[0] + + outputs = (sequence_output,) + if self.config.output_hidden_states: + all_hidden_states = encoder_outputs[1] + outputs = outputs + (all_hidden_states,) + if self.config.output_attentions: + all_attentions = encoder_outputs[-1] + outputs = outputs + (all_attentions,) + + return outputs + +class Graphormer(BertPreTrainedModel): + ''' + The archtecture of a transformer encoder block we used in Graphormer + ''' + def __init__(self, config): + super(Graphormer, self).__init__(config) + self.config = config + self.bert = EncoderBlock(config) + self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim) + self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim) + + def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, + next_sentence_label=None, position_ids=None, head_mask=None): + ''' + # self.bert has three outputs + # predictions[0]: output tokens + # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states" + # predictions[2]: attentions, if enable "self.config.output_attentions" + ''' + predictions = self.bert(img_feats=img_feats, 
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) + + # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification. + pred_score = self.cls_head(predictions[0]) + res_img_feats = self.residual(img_feats) + pred_score = pred_score + res_img_feats + + if self.config.output_attentions and self.config.output_hidden_states: + return pred_score, predictions[1], predictions[-1] + else: + return pred_score + + \ No newline at end of file diff --git a/mesh_graphormer/modeling/bert/modeling_utils.py b/mesh_graphormer/modeling/bert/modeling_utils.py new file mode 100644 index 0000000..fcfab43 --- /dev/null +++ b/mesh_graphormer/modeling/bert/modeling_utils.py @@ -0,0 +1 @@ +from transformers.modeling_utils import * \ No newline at end of file diff --git a/mesh_graphormer/modeling/data/J_regressor_extra.npy b/mesh_graphormer/modeling/data/J_regressor_extra.npy new file mode 100644 index 0000000..d045e08 Binary files /dev/null and b/mesh_graphormer/modeling/data/J_regressor_extra.npy differ diff --git a/mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy b/mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy new file mode 100644 index 0000000..8fe4518 Binary files /dev/null and b/mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy differ diff --git a/mesh_graphormer/modeling/data/MANO_LEFT.pkl b/mesh_graphormer/modeling/data/MANO_LEFT.pkl new file mode 100644 index 0000000..745210a Binary files /dev/null and b/mesh_graphormer/modeling/data/MANO_LEFT.pkl differ diff --git a/mesh_graphormer/modeling/data/MANO_RIGHT.pkl b/mesh_graphormer/modeling/data/MANO_RIGHT.pkl new file mode 100644 index 0000000..06b42f3 Binary files /dev/null and b/mesh_graphormer/modeling/data/MANO_RIGHT.pkl differ diff --git a/mesh_graphormer/modeling/data/README.md b/mesh_graphormer/modeling/data/README.md new file mode 100644 index 0000000..e7cfc08 --- /dev/null +++ b/mesh_graphormer/modeling/data/README.md @@ -0,0 +1,30 @@ + +# Extra data +Adapted from open source project [GraphCMR](https://github.com/nkolot/GraphCMR/) and [Pose2Mesh](https://github.com/hongsukchoi/Pose2Mesh_RELEASE) + +Our code requires additional data to run smoothly. + +### J_regressor_extra.npy +Joints regressor for joints or landmarks that are not included in the standard set of SMPL joints. + +### J_regressor_h36m_correct.npy +Joints regressor reflecting the Human3.6M joints. + +### mesh_downsampling.npz +Extra file with precomputed downsampling for the SMPL body mesh. + +### mano_downsampling.npz +Extra file with precomputed downsampling for the MANO hand mesh. + +### basicModel_neutral_lbs_10_207_0_v1.0.0.pkl +SMPL neutral model. Please visit the official website to download the file [http://smplify.is.tue.mpg.de/](http://smplify.is.tue.mpg.de/) + +### basicModel_m_lbs_10_207_0_v1.0.0.pkl +SMPL male model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/) + +### basicModel_f_lbs_10_207_0_v1.0.0.pkl +SMPL female model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/) + +### MANO_RIGHT.pkl +MANO hand model. 
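+(The left-hand model, MANO_LEFT.pkl, is included in this commit alongside MANO_RIGHT.pkl.)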
Please visit the official website to download the file [https://mano.is.tue.mpg.de/](https://mano.is.tue.mpg.de/) + diff --git a/mesh_graphormer/modeling/data/config.py b/mesh_graphormer/modeling/data/config.py new file mode 100644 index 0000000..b8eb40f --- /dev/null +++ b/mesh_graphormer/modeling/data/config.py @@ -0,0 +1,47 @@ +""" +This file contains definitions of useful data stuctures and the paths +for the datasets and data files necessary to run the code. + +Adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) + +""" + +from pathlib import Path +folder_path = Path(__file__).parent.parent +JOINT_REGRESSOR_TRAIN_EXTRA = folder_path / 'data/J_regressor_extra.npy' +JOINT_REGRESSOR_H36M_correct = folder_path / 'data/J_regressor_h36m_correct.npy' +SMPL_FILE = folder_path / 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl' +SMPL_Male = folder_path / 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl' +SMPL_Female = folder_path / 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl' +SMPL_sampling_matrix = folder_path / 'data/mesh_downsampling.npz' +MANO_FILE = folder_path / 'data/MANO_RIGHT.pkl' +MANO_sampling_matrix = folder_path / 'data/mano_downsampling.npz' + +JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27] + + +""" +We follow the body joint definition, loss functions, and evaluation metrics from +open source project GraphCMR (https://github.com/nkolot/GraphCMR/) + +Each dataset uses different sets of joints. +We use a superset of 24 joints such that we include all joints from every dataset. +If a dataset doesn't provide annotations for a specific joint, we simply ignore it. +The joints used here are: +""" +J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', +'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear') +H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head', + 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') +J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18] +H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10] + +""" +We follow the hand joint definition and mesh topology from +open source project Manopth (https://github.com/hassony2/manopth) + +The hand joints used here are: +""" +J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', +'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') +ROOT_INDEX = 0 \ No newline at end of file diff --git a/mesh_graphormer/modeling/data/mano_195_adjmat_indices.pt b/mesh_graphormer/modeling/data/mano_195_adjmat_indices.pt new file mode 100644 index 0000000..b67b782 Binary files /dev/null and b/mesh_graphormer/modeling/data/mano_195_adjmat_indices.pt differ diff --git a/mesh_graphormer/modeling/data/mano_195_adjmat_size.pt b/mesh_graphormer/modeling/data/mano_195_adjmat_size.pt new file mode 100644 index 0000000..a9a25c3 Binary files /dev/null and b/mesh_graphormer/modeling/data/mano_195_adjmat_size.pt differ diff --git a/mesh_graphormer/modeling/data/mano_195_adjmat_values.pt b/mesh_graphormer/modeling/data/mano_195_adjmat_values.pt new file mode 100644 index 0000000..b642f60 Binary files /dev/null and 
b/mesh_graphormer/modeling/data/mano_195_adjmat_values.pt differ diff --git a/mesh_graphormer/modeling/data/mano_downsampling.npz b/mesh_graphormer/modeling/data/mano_downsampling.npz new file mode 100644 index 0000000..dba3918 Binary files /dev/null and b/mesh_graphormer/modeling/data/mano_downsampling.npz differ diff --git a/mesh_graphormer/modeling/data/mesh_downsampling.npz b/mesh_graphormer/modeling/data/mesh_downsampling.npz new file mode 100644 index 0000000..d117358 Binary files /dev/null and b/mesh_graphormer/modeling/data/mesh_downsampling.npz differ diff --git a/mesh_graphormer/modeling/data/smpl_431_adjmat_indices.pt b/mesh_graphormer/modeling/data/smpl_431_adjmat_indices.pt new file mode 100644 index 0000000..07699e2 Binary files /dev/null and b/mesh_graphormer/modeling/data/smpl_431_adjmat_indices.pt differ diff --git a/mesh_graphormer/modeling/data/smpl_431_adjmat_size.pt b/mesh_graphormer/modeling/data/smpl_431_adjmat_size.pt new file mode 100644 index 0000000..c93200d Binary files /dev/null and b/mesh_graphormer/modeling/data/smpl_431_adjmat_size.pt differ diff --git a/mesh_graphormer/modeling/data/smpl_431_adjmat_values.pt b/mesh_graphormer/modeling/data/smpl_431_adjmat_values.pt new file mode 100644 index 0000000..0489eba Binary files /dev/null and b/mesh_graphormer/modeling/data/smpl_431_adjmat_values.pt differ diff --git a/mesh_graphormer/modeling/data/smpl_431_faces.npy b/mesh_graphormer/modeling/data/smpl_431_faces.npy new file mode 100644 index 0000000..f80b756 Binary files /dev/null and b/mesh_graphormer/modeling/data/smpl_431_faces.npy differ diff --git a/mesh_graphormer/modeling/hrnet/config/__init__.py b/mesh_graphormer/modeling/hrnet/config/__init__.py new file mode 100644 index 0000000..59be749 --- /dev/null +++ b/mesh_graphormer/modeling/hrnet/config/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from .default import _C as config +from .default import update_config +from .models import MODEL_EXTRAS diff --git a/mesh_graphormer/modeling/hrnet/config/default.py b/mesh_graphormer/modeling/hrnet/config/default.py new file mode 100644 index 0000000..59b9843 --- /dev/null +++ b/mesh_graphormer/modeling/hrnet/config/default.py @@ -0,0 +1,138 @@ + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
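+# Note: _C defined below is the default top-level config tree for the HRNet classifier; update_config() near the end of this file merges an external YAML file on top of these defaults.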
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from yacs.config import CfgNode as CN + + +_C = CN() + +_C.OUTPUT_DIR = '' +_C.LOG_DIR = '' +_C.DATA_DIR = '' +_C.GPUS = (0,) +_C.WORKERS = 4 +_C.PRINT_FREQ = 20 +_C.AUTO_RESUME = False +_C.PIN_MEMORY = True +_C.RANK = 0 + +# Cudnn related params +_C.CUDNN = CN() +_C.CUDNN.BENCHMARK = True +_C.CUDNN.DETERMINISTIC = False +_C.CUDNN.ENABLED = True + +# common params for NETWORK +_C.MODEL = CN() +_C.MODEL.NAME = 'cls_hrnet' +_C.MODEL.INIT_WEIGHTS = True +_C.MODEL.PRETRAINED = '' +_C.MODEL.NUM_JOINTS = 17 +_C.MODEL.NUM_CLASSES = 1000 +_C.MODEL.TAG_PER_JOINT = True +_C.MODEL.TARGET_TYPE = 'gaussian' +_C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256 +_C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32 +_C.MODEL.SIGMA = 2 +_C.MODEL.EXTRA = CN(new_allowed=True) + +_C.LOSS = CN() +_C.LOSS.USE_OHKM = False +_C.LOSS.TOPK = 8 +_C.LOSS.USE_TARGET_WEIGHT = True +_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False + +# DATASET related params +_C.DATASET = CN() +_C.DATASET.ROOT = '' +_C.DATASET.DATASET = 'mpii' +_C.DATASET.TRAIN_SET = 'train' +_C.DATASET.TEST_SET = 'valid' +_C.DATASET.DATA_FORMAT = 'jpg' +_C.DATASET.HYBRID_JOINTS_TYPE = '' +_C.DATASET.SELECT_DATA = False + +# training data augmentation +_C.DATASET.FLIP = True +_C.DATASET.SCALE_FACTOR = 0.25 +_C.DATASET.ROT_FACTOR = 30 +_C.DATASET.PROB_HALF_BODY = 0.0 +_C.DATASET.NUM_JOINTS_HALF_BODY = 8 +_C.DATASET.COLOR_RGB = False + +# train +_C.TRAIN = CN() + +_C.TRAIN.LR_FACTOR = 0.1 +_C.TRAIN.LR_STEP = [90, 110] +_C.TRAIN.LR = 0.001 + +_C.TRAIN.OPTIMIZER = 'adam' +_C.TRAIN.MOMENTUM = 0.9 +_C.TRAIN.WD = 0.0001 +_C.TRAIN.NESTEROV = False +_C.TRAIN.GAMMA1 = 0.99 +_C.TRAIN.GAMMA2 = 0.0 + +_C.TRAIN.BEGIN_EPOCH = 0 +_C.TRAIN.END_EPOCH = 140 + +_C.TRAIN.RESUME = False +_C.TRAIN.CHECKPOINT = '' + +_C.TRAIN.BATCH_SIZE_PER_GPU = 32 +_C.TRAIN.SHUFFLE = True + +# testing +_C.TEST = CN() + +# size of images for each device +_C.TEST.BATCH_SIZE_PER_GPU = 32 +# Test Model Epoch +_C.TEST.FLIP_TEST = False +_C.TEST.POST_PROCESS = False +_C.TEST.SHIFT_HEATMAP = False + +_C.TEST.USE_GT_BBOX = False + +# nms +_C.TEST.IMAGE_THRE = 0.1 +_C.TEST.NMS_THRE = 0.6 +_C.TEST.SOFT_NMS = False +_C.TEST.OKS_THRE = 0.5 +_C.TEST.IN_VIS_THRE = 0.0 +_C.TEST.COCO_BBOX_FILE = '' +_C.TEST.BBOX_THRE = 1.0 +_C.TEST.MODEL_FILE = '' + +# debug +_C.DEBUG = CN() +_C.DEBUG.DEBUG = False +_C.DEBUG.SAVE_BATCH_IMAGES_GT = False +_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False +_C.DEBUG.SAVE_HEATMAPS_GT = False +_C.DEBUG.SAVE_HEATMAPS_PRED = False + + +def update_config(cfg, config_file): + cfg.defrost() + cfg.merge_from_file(config_file) + cfg.freeze() + + +if __name__ == '__main__': + import sys + with open(sys.argv[1], 'w') as f: + print(_C, file=f) + diff --git a/mesh_graphormer/modeling/hrnet/config/models.py b/mesh_graphormer/modeling/hrnet/config/models.py new file mode 100644 index 0000000..7a73bc9 --- /dev/null +++ b/mesh_graphormer/modeling/hrnet/config/models.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Created by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from yacs.config import CfgNode as CN + +# high_resolution_net related params for classification +POSE_HIGH_RESOLUTION_NET = CN() +POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] +POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64 +POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 +POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True + +POSE_HIGH_RESOLUTION_NET.STAGE2 = CN() +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] +POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' +POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' + +POSE_HIGH_RESOLUTION_NET.STAGE3 = CN() +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] +POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' +POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' + +POSE_HIGH_RESOLUTION_NET.STAGE4 = CN() +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] +POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' +POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' + +MODEL_EXTRAS = { + 'cls_hrnet': POSE_HIGH_RESOLUTION_NET, +} diff --git a/mesh_graphormer/modeling/hrnet/hrnet_cls_net.py b/mesh_graphormer/modeling/hrnet/hrnet_cls_net.py new file mode 100644 index 0000000..3388c23 --- /dev/null +++ b/mesh_graphormer/modeling/hrnet/hrnet_cls_net.py @@ -0,0 +1,523 @@ + + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License.
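A minimal usage sketch of the backbone defined in this file: merge a YAML configuration into the yacs defaults from default.py, then call get_cls_net. The two file paths below are placeholder assumptions; everything else names objects defined in this diff.

from mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
from mesh_graphormer.modeling.hrnet.hrnet_cls_net import get_cls_net

hrnet_yaml = 'path/to/cls_hrnet_w64.yaml'        # placeholder; the YAML must define MODEL.EXTRA.STAGE1-STAGE4
hrnet_checkpoint = 'path/to/hrnetv2_w64.pth'     # placeholder; init_weights() skips loading if the file is missing
hrnet_update_config(hrnet_config, hrnet_yaml)    # defrost, merge the YAML over the defaults, freeze
backbone = get_cls_net(hrnet_config, pretrained=hrnet_checkpoint)  # HighResolutionNet with pretrained weights loaded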
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# Modified by Kevin Lin (keli@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import functools + +import numpy as np + +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +import code +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + 
num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d(num_inchannels[i], + momentum=BN_MOMENTUM), + nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + super(HighResolutionNet, self).__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + 
bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion*num_channels + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + # Classification Head + self.incre_modules, self.downsamp_modules, \ + self.final_layer = self._make_head(pre_stage_channels) + + self.classifier = nn.Linear(2048, 1000) + + def _make_head(self, pre_stage_channels): + head_block = Bottleneck + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_module = self._make_layer(head_block, + channels, + head_channels[i], + 1, + stride=1) + incre_modules.append(incre_module) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels)-1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i+1] * head_block.expansion + + downsamp_module = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1), + nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, + kernel_size=1, + stride=1, + padding=0 + ), + nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + 
nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + # Classification Head + y = self.incre_modules[0](y_list[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i+1](y_list[i+1]) + \ + self.downsamp_modules[i](y) + + y = self.final_layer(y) + + if torch._C._get_tracing_state(): + y = y.flatten(start_dim=2).mean(dim=2) + else: + y = F.avg_pool2d(y, kernel_size=y.size() + [2:]).view(y.size(0), -1) + + # y = self.classifier(y) + + return y + + def init_weights(self, pretrained='',): + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', 
nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained) + logger.info('=> loading pretrained model {}'.format(pretrained)) + print('=> loading pretrained model {}'.format(pretrained)) + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + # for k, _ in pretrained_dict.items(): + # logger.info( + # '=> loading {} pretrained model {}'.format(k, pretrained)) + # print('=> loading {} pretrained model {}'.format(k, pretrained)) + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + # code.interact(local=locals()) + +def get_cls_net(config, pretrained, **kwargs): + model = HighResolutionNet(config, **kwargs) + model.init_weights(pretrained=pretrained) + return model diff --git a/mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py b/mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py new file mode 100644 index 0000000..6c44c94 --- /dev/null +++ b/mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py @@ -0,0 +1,524 @@ + + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# Modified by Kevin Lin (keli@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import functools + +import numpy as np + +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +import code +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = 
self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d(num_inchannels[i], + momentum=BN_MOMENTUM), + nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + 
nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + super(HighResolutionNet, self).__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion*num_channels + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + # Classification Head + self.incre_modules, self.downsamp_modules, \ + self.final_layer = self._make_head(pre_stage_channels) + + self.classifier = nn.Linear(2048, 1000) + + def _make_head(self, pre_stage_channels): + head_block = Bottleneck + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_module = self._make_layer(head_block, + 
channels, + head_channels[i], + 1, + stride=1) + incre_modules.append(incre_module) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels)-1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i+1] * head_block.expansion + + downsamp_module = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1), + nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, + kernel_size=1, + stride=1, + padding=0 + ), + nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + 
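+        # Stem note: conv1 and conv2 above both use stride 2, so the HRNet stages below operate at 1/4 of the input resolution.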
x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + # Classification Head + y = self.incre_modules[0](y_list[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i+1](y_list[i+1]) + \ + self.downsamp_modules[i](y) + + yy = self.final_layer(y) + + if torch._C._get_tracing_state(): + yy = yy.flatten(start_dim=2).mean(dim=2) + else: + yy = F.avg_pool2d(yy, kernel_size=yy.size() + [2:]).view(yy.size(0), -1) + + # y = self.classifier(y) + return yy, y + + + + def init_weights(self, pretrained='',): + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained) + logger.info('=> loading pretrained model {}'.format(pretrained)) + print('=> loading pretrained model {}'.format(pretrained)) + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + # for k, _ in pretrained_dict.items(): + # logger.info( + # '=> loading {} pretrained model {}'.format(k, pretrained)) + # print('=> loading {} pretrained model {}'.format(k, pretrained)) + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + # code.interact(local=locals()) + +def get_cls_net_gridfeat(config, pretrained, **kwargs): + model = HighResolutionNet(config, **kwargs) + model.init_weights(pretrained=pretrained) + return model diff --git a/mesh_graphormer/tools/run_gphmer_bodymesh.py b/mesh_graphormer/tools/run_gphmer_bodymesh.py new file mode 100644 index 0000000..2f65017 --- /dev/null +++ b/mesh_graphormer/tools/run_gphmer_bodymesh.py @@ -0,0 +1,750 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +Training and evaluation codes for +3D human body mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from mesh_graphormer.modeling.bert import BertConfig, Graphormer +from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network +from mesh_graphormer.modeling._smpl import SMPL, Mesh +from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import mesh_graphormer.modeling.data.config as cfg +from mesh_graphormer.datasets.build import make_data_loader + +from mesh_graphormer.utils.logger import setup_logger +from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger +from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test +from mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from mesh_graphormer.utils.geometric_layers import orthographic_projection + + +device = "cuda" + +from azureml.core.run import Run +aml_run = Run.get_context() + +def save_checkpoint(model, args, epoch, iteration, num_trial=10): + checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format( + epoch, iteration)) + if not is_main_process(): + return checkpoint_dir + mkdir(checkpoint_dir) + model_to_save = model.module if hasattr(model, 'module') else model + for i in range(num_trial): + try: + torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin')) + torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin')) + torch.save(args, op.join(checkpoint_dir, 'training_args.bin')) + logger.info("Save checkpoint to {}".format(checkpoint_dir)) + break + except: + pass + else: + logger.info("Failed to save checkpoint after {} trails.".format(num_trial)) + return checkpoint_dir + +def save_scores(args, split, mpjpe, pampjpe, mpve): + eval_log = [] + res = {} + res['mPJPE'] = mpjpe + res['PAmPJPE'] = pampjpe + res['mPVE'] = mpve + eval_log.append(res) + with open(op.join(args.output_dir, split+'_eval_logs.json'), 'w') as f: + json.dump(eval_log, f) + logger.info("Save eval scores to {}".format(args.output_dir)) + return + +def adjust_learning_rate(optimizer, epoch, args): + """ + Sets the learning rate to the initial LR decayed by x every y epochs + x = 0.1, y = args.num_train_epochs/2.0 = 100 + """ + lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) )) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def mean_per_joint_position_error(pred, gt, has_3d_joints): + """ + Compute mPJPE + """ + gt = gt[has_3d_joints == 1] + gt = gt[:, :, :-1] + pred = pred[has_3d_joints == 1] + + with torch.no_grad(): + gt_pelvis = (gt[:, 2,:] + gt[:, 3,:]) / 2 + gt = gt - gt_pelvis[:, None, :] + pred_pelvis = (pred[:, 2,:] + pred[:, 3,:]) / 2 + pred = pred - pred_pelvis[:, None, :] + error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + return error + +def mean_per_vertex_error(pred, gt, 
has_smpl): + """ + Compute mPVE + """ + pred = pred[has_smpl == 1] + gt = gt[has_smpl == 1] + with torch.no_grad(): + error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + return error + +def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d): + """ + Compute 2D reprojection loss if 2D keypoint annotations are available. + The confidence (conf) is binary and indicates whether the keypoints exist or not. + """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean() + return loss + +def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, device): + """ + Compute 3D keypoint loss if 3D keypoint annotations are available. + """ + conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() + gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone() + gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1] + conf = conf[has_pose_3d == 1] + pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1] + if len(gt_keypoints_3d) > 0: + gt_pelvis = (gt_keypoints_3d[:, 2,:] + gt_keypoints_3d[:, 3,:]) / 2 + gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :] + pred_pelvis = (pred_keypoints_3d[:, 2,:] + pred_keypoints_3d[:, 3,:]) / 2 + pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :] + return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean() + else: + return torch.FloatTensor(1).fill_(0.).to(device) + +def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, device): + """ + Compute per-vertex loss if vertex annotations are available. + """ + pred_vertices_with_shape = pred_vertices[has_smpl == 1] + gt_vertices_with_shape = gt_vertices[has_smpl == 1] + if len(gt_vertices_with_shape) > 0: + return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape) + else: + return torch.FloatTensor(1).fill_(0.).to(device) + +def rectify_pose(pose): + pose = pose.copy() + R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0] + R_root = cv2.Rodrigues(pose[:3])[0] + new_root = R_root.dot(R_mod) + pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3) + return pose + +def run(args, train_dataloader, val_dataloader, Graphormer_model, smpl, mesh_sampler, renderer): + smpl.eval() + max_iter = len(train_dataloader) + iters_per_epoch = max_iter // args.num_train_epochs + if iters_per_epoch<1000: + args.logging_steps = 500 + + optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()), + lr=args.lr, + betas=(0.9, 0.999), + weight_decay=0) + + # define loss function (criterion) and optimizer + criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_vertices = torch.nn.L1Loss().to(device) + + if args.distributed: + Graphormer_model = torch.nn.parallel.DistributedDataParallel( + Graphormer_model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True, + ) + + logger.info( + ' '.join( + ['Local rank: {o}', 'Max iteration: {a}', 'iters_per_epoch: {b}','num_train_epochs: {c}',] + ).format(o=args.local_rank, a=max_iter, b=iters_per_epoch, c=args.num_train_epochs) + ) + + start_training_time = time.time() + end = time.time() + Graphormer_model.train() + batch_time = AverageMeter() + data_time = AverageMeter() + log_losses = AverageMeter() + log_loss_2djoints = AverageMeter() + log_loss_3djoints = AverageMeter() + log_loss_vertices = AverageMeter() + 
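+    # The AverageMeters above accumulate running statistics for each loss term; their averages are logged every args.logging_steps iterations.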
log_eval_metrics = EvalMetricsLogger() + + for iteration, (img_keys, images, annotations) in enumerate(train_dataloader): + # gc.collect() + # torch.cuda.empty_cache() + Graphormer_model.train() + iteration += 1 + epoch = iteration // iters_per_epoch + batch_size = images.size(0) + adjust_learning_rate(optimizer, epoch, args) + data_time.update(time.time() - end) + + images = images.to(device) + gt_2d_joints = annotations['joints_2d'].to(device) + gt_2d_joints = gt_2d_joints[:,cfg.J24_TO_J14,:] + has_2d_joints = annotations['has_2d_joints'].to(device) + + gt_3d_joints = annotations['joints_3d'].to(device) + gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3] + gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:] + gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :] + has_3d_joints = annotations['has_3d_joints'].to(device) + + gt_pose = annotations['pose'].to(device) + gt_betas = annotations['betas'].to(device) + has_smpl = annotations['has_smpl'].to(device) + mjm_mask = annotations['mjm_mask'].to(device) + mvm_mask = annotations['mvm_mask'].to(device) + + # generate simplified mesh + gt_vertices = smpl(gt_pose, gt_betas) + gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices, n1=0, n2=2) + gt_vertices_sub = mesh_sampler.downsample(gt_vertices) + + # normalize gt based on smpl's pelvis + gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices) + gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:] + gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :] + + # prepare masks for mask vertex/joint modeling + mjm_mask_ = mjm_mask.expand(-1,-1,2051) + mvm_mask_ = mvm_mask.expand(-1,-1,2051) + meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True) + + # normalize gt based on smpl's pelvis + gt_vertices_sub = gt_vertices_sub - gt_smpl_3d_pelvis[:, None, :] + gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :] + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] + + # obtain 2d joints, which are projected from 3d joints of smpl mesh + pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera) + pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera) + + # compute 3d joint loss (where the joints are directly output from transformer) + loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device) + # compute 3d vertex loss + loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \ + args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \ + args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) ) + # compute 3d joint loss (where the joints are regressed from full mesh) + loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device) + # compute 2d joint loss + loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \ + keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, 
has_2d_joints) + + loss_3d_joints = loss_3d_joints + loss_reg_3d_joints + + # we empirically use hyperparameters to balance difference losses + loss = args.joints_loss_weight*loss_3d_joints + \ + args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints + + # update logs + log_loss_2djoints.update(loss_2d_joints.item(), batch_size) + log_loss_3djoints.update(loss_3d_joints.item(), batch_size) + log_loss_vertices.update(loss_vertices.item(), batch_size) + log_losses.update(loss.item(), batch_size) + + # back prop + optimizer.zero_grad() + loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if iteration % args.logging_steps == 0 or iteration == max_iter: + eta_seconds = batch_time.avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + logger.info( + ' '.join( + ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',] + ).format(eta=eta_string, ep=epoch, iter=iteration, + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) + + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format( + log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg, + optimizer.param_groups[0]['lr']) + ) + + aml_run.log(name='Loss', value=float(log_losses.avg)) + aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg)) + aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg)) + aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg)) + + visual_imgs = visualize_mesh( renderer, + annotations['ori_img'].detach(), + annotations['joints_2d'].detach(), + pred_vertices.detach(), + pred_camera.detach(), + pred_2d_joints_from_smpl.detach()) + visual_imgs = visual_imgs.transpose(0,1) + visual_imgs = visual_imgs.transpose(1,2) + visual_imgs = np.asarray(visual_imgs) + + if is_main_process()==True: + stamp = str(epoch) + '_' + str(iteration) + temp_fname = args.output_dir + 'visual_' + stamp + '.jpg' + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + aml_run.log_image(name='visual results', path=temp_fname) + + if iteration % iters_per_epoch == 0: + val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader, + Graphormer_model, + criterion_keypoints, + criterion_vertices, + epoch, + smpl, + mesh_sampler) + aml_run.log(name='mPVE', value=float(1000*val_mPVE)) + aml_run.log(name='mPJPE', value=float(1000*val_mPJPE)) + aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE)) + logger.info( + ' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch) + + ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE, val_count) + ) + + if val_PAmPJPE0: + mPVE.update(np.mean(error_vertices), int(torch.sum(has_smpl)) ) + if len(error_joints)>0: + mPJPE.update(np.mean(error_joints), int(torch.sum(has_3d_joints)) ) + if len(error_joints_pa)>0: + PAmPJPE.update(np.mean(error_joints_pa), int(torch.sum(has_3d_joints)) ) + + val_mPVE = all_gather(float(mPVE.avg)) + val_mPVE = sum(val_mPVE)/len(val_mPVE) + val_mPJPE = all_gather(float(mPJPE.avg)) + val_mPJPE = sum(val_mPJPE)/len(val_mPJPE) + + val_PAmPJPE = all_gather(float(PAmPJPE.avg)) + val_PAmPJPE = sum(val_PAmPJPE)/len(val_PAmPJPE) + + val_count = all_gather(float(mPVE.count)) + val_count = sum(val_count) + + return val_mPVE, val_mPJPE, val_PAmPJPE, val_count + + +def 
visualize_mesh( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(14)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer) + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def visualize_mesh_test( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d, + PAmPJPE_h36m_j14): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(14)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + score = PAmPJPE_h36m_j14[i] + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score) + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--data_dir", default='datasets', type=str, required=False, + help="Directory with all datasets, each in one subfolder") + parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False, + help="Yaml file with all data for training.") + parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False, + help="Yaml file with all data for validation.") + parser.add_argument("--num_workers", default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + ######################################################### + # Loading/saving checkpoints + ######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save 
checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + ######################################################### + # Training parameters + ######################################################### + parser.add_argument("--per_gpu_train_batch_size", default=30, type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=30, type=int, + help="Batch size per GPU/CPU for evaluation.") + parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float, + help="The initial lr.") + parser.add_argument("--num_train_epochs", default=200, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--vertices_loss_weight", default=100.0, type=float) + parser.add_argument("--joints_loss_weight", default=1000.0, type=float) + parser.add_argument("--vloss_w_full", default=0.33, type=float) + parser.add_argument("--vloss_w_sub", default=0.33, type=float) + parser.add_argument("--vloss_w_sub2", default=0.33, type=float) + parser.add_argument("--drop_out", default=0.1, type=float, + help="Drop out ratio in BERT.") + ######################################################### + # Model architectures + ######################################################### + parser.add_argument('-a', '--arch', default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=4, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. 
Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='body', type=str, help="body or hand") + parser.add_argument("--interm_size_scale", default=2, type=int) + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=False, action='store_true',) + parser.add_argument('--logging_steps', type=int, default=1000, + help="Log every X steps.") + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + parser.add_argument("--local_rank", type=int, default=0, + help="For distributed training.") + + + args = parser.parse_args() + return args + + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + args.distributed = args.num_gpus > 1 + args.device = torch.device(args.device) + if args.distributed: + print("Init distributed training on local rank {} ({}), rank {}, world size {}".format(args.local_rank, int(os.environ["LOCAL_RANK"]), int(os.environ["NODE_RANK"]), args.num_gpus)) + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + backend='nccl', init_method='env://' + ) + local_rank = int(os.environ["LOCAL_RANK"]) + args.device = torch.device("cuda", local_rank) + synchronize() + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and SMPL utils + smpl = SMPL().to(args.device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=smpl.faces.cpu().numpy()) + + # Load model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.hidden_dropout_prob = args.drop_out + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*args.interm_size_scale) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + + for idx, param 
in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + + # init ImageNet pre-trained backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-2]) + + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. 
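+        # The checkpoint tensors below are moved to CPU one by one and the state dict
+        # is then loaded with strict=False: this works around deserializing the sparse
+        # tensors used by the graph-conv layers and skips any keys that do not match
+        # the freshly built model.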
+ states = torch.load(args.resume_checkpoint) + # states = checkpoint_loaded.state_dict() + for k, v in states.items(): + states[k] = v.cpu() + # del checkpoint_loaded + _model.load_state_dict(states, strict=False) + del states + gc.collect() + torch.cuda.empty_cache() + + + _model.to(args.device) + logger.info("Training parameters %s", args) + + if args.run_eval_only==True: + val_dataloader = make_data_loader(args, args.val_yaml, + args.distributed, is_train=False, scale_factor=args.img_scale_factor) + run_eval_general(args, val_dataloader, _model, smpl, mesh_sampler) + + else: + train_dataloader = make_data_loader(args, args.train_yaml, + args.distributed, is_train=True, scale_factor=args.img_scale_factor) + val_dataloader = make_data_loader(args, args.val_yaml, + args.distributed, is_train=False, scale_factor=args.img_scale_factor) + run(args, train_dataloader, val_dataloader, _model, smpl, mesh_sampler, renderer) + + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/mesh_graphormer/tools/run_gphmer_bodymesh_inference.py b/mesh_graphormer/tools/run_gphmer_bodymesh_inference.py new file mode 100644 index 0000000..0ffc213 --- /dev/null +++ b/mesh_graphormer/tools/run_gphmer_bodymesh_inference.py @@ -0,0 +1,351 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +End-to-end inference codes for +3D human body mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from mesh_graphormer.modeling.bert import BertConfig, Graphormer +from mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network +from mesh_graphormer.modeling._smpl import SMPL, Mesh +from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import mesh_graphormer.modeling.data.config as cfg +from mesh_graphormer.datasets.build import make_data_loader + +from mesh_graphormer.utils.logger import setup_logger +from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger +from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text +from mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from mesh_graphormer.utils.geometric_layers import orthographic_projection + +from PIL import Image +from torchvision import transforms + + +device = "cuda" + +transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + +transform_visualize = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor()]) + +def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler): + # switch to evaluate mode + Graphormer_model.eval() + smpl.eval() + with torch.no_grad(): + for image_file in image_list: + if 'pred' not in image_file: + 
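+                # outputs of this script are written as '<name>_graphormer_pred.jpg', so any
+                # file with 'pred' in its name is skipped to avoid re-processing results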
att_all = [] + img = Image.open(image_file) + img_tensor = transform(img) + img_visual = transform_visualize(img) + + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device) + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler) + + # obtain 3d joints from full mesh + pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) + + pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:] + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :] + pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] + + # save attantion + att_max_value = att[-1] + att_cpu = np.asarray(att_max_value.cpu().detach()) + att_all.append(att_cpu) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] + # obtain 2d joints, which are projected from 3d joints of smpl mesh + pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera) + pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera) + visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0], + pred_vertices[0].detach(), + pred_camera.detach()) + # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0], + # pred_vertices[0].detach(), + # pred_vertices_sub2[0].detach(), + # pred_2d_431_vertices_from_smpl[0].detach(), + # pred_2d_joints_from_smpl[0].detach(), + # pred_camera.detach(), + # att[-1][0].detach()) + + visual_imgs = visual_imgs_output.transpose(1,2,0) + visual_imgs = np.asarray(visual_imgs) + + temp_fname = image_file[:-4] + '_graphormer_pred.jpg' + print('save to ', temp_fname) + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + + return + +def visualize_mesh( renderer, images, + pred_vertices_full, + pred_camera): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + cam = pred_camera.cpu().numpy() + # Visualize only mesh reconstruction + rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + +def visualize_mesh_and_attention( renderer, images, + pred_vertices_full, + pred_vertices, + pred_2d_vertices, + pred_2d_joints, + pred_camera, + attention): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + vertices = pred_vertices.cpu().numpy() + vertices_2d = pred_2d_vertices.cpu().numpy() + joints_2d = pred_2d_joints.cpu().numpy() + cam = pred_camera.cpu().numpy() + att = attention.cpu().numpy() + # Visualize reconstruction and attention + rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--num_workers", 
default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str, + help="test data") + ######################################################### + # Loading/saving checkpoints + ######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + ######################################################### + # Model architectures + ######################################################### + parser.add_argument('-a', '--arch', default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=4, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. 
Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='body', type=str, help="body or hand") + parser.add_argument("--interm_size_scale", default=2, type=int) + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=True, action='store_true',) + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + + args = parser.parse_args() + return args + + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + args.distributed = args.num_gpus > 1 + args.device = torch.device(args.device) + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and SMPL utils + smpl = SMPL().to(args.device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=smpl.faces.cpu().numpy()) + + # Load model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*args.interm_size_scale) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + # init ImageNet pre-trained backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + 
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-2]) + + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. + states = torch.load(args.resume_checkpoint) + # states = checkpoint_loaded.state_dict() + for k, v in states.items(): + states[k] = v.cpu() + # del checkpoint_loaded + _model.load_state_dict(states, strict=False) + del states + gc.collect() + torch.cuda.empty_cache() + + # update configs to enable attention outputs + setattr(_model.trans_encoder[-1].config,'output_attentions', True) + setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) + _model.trans_encoder[-1].bert.encoder.output_attentions = True + _model.trans_encoder[-1].bert.encoder.output_hidden_states = True + for iter_layer in range(4): + _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True + for inter_block in range(3): + setattr(_model.trans_encoder[-1].config,'device', args.device) + + _model.to(args.device) + logger.info("Run inference") + + image_list = [] + if not args.image_file_or_path: + raise ValueError("image_file_or_path not specified") + if op.isfile(args.image_file_or_path): + image_list = [args.image_file_or_path] + elif op.isdir(args.image_file_or_path): + # should be a path with images only + for filename in os.listdir(args.image_file_or_path): + if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename: + image_list.append(args.image_file_or_path+'/'+filename) + else: + raise ValueError("Cannot find images at {}".format(args.image_file_or_path)) + + run_inference(args, image_list, _model, smpl, renderer, mesh_sampler) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/mesh_graphormer/tools/run_gphmer_handmesh.py b/mesh_graphormer/tools/run_gphmer_handmesh.py new file mode 100644 index 0000000..4cfcf55 --- /dev/null +++ b/mesh_graphormer/tools/run_gphmer_handmesh.py @@ -0,0 +1,713 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +Training and evaluation codes for +3D hand mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from mesh_graphormer.modeling.bert import BertConfig, Graphormer +from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network +from mesh_graphormer.modeling._mano import MANO, Mesh +from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import mesh_graphormer.modeling.data.config as cfg +from mesh_graphormer.datasets.build import make_hand_data_loader + +from mesh_graphormer.utils.logger import setup_logger +from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from mesh_graphormer.utils.metric_logger import AverageMeter +from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test, visualize_reconstruction_no_text +from mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from mesh_graphormer.utils.geometric_layers import orthographic_projection + + +device = "cuda" + +from azureml.core.run import Run +aml_run = Run.get_context() + +def save_checkpoint(model, args, epoch, iteration, num_trial=10): + checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format( + epoch, iteration)) + if not is_main_process(): + return checkpoint_dir + mkdir(checkpoint_dir) + model_to_save = model.module if hasattr(model, 'module') else model + for i in range(num_trial): + try: + torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin')) + torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin')) + torch.save(args, op.join(checkpoint_dir, 'training_args.bin')) + logger.info("Save checkpoint to {}".format(checkpoint_dir)) + break + except: + pass + else: + logger.info("Failed to save checkpoint after {} trails.".format(num_trial)) + return checkpoint_dir + +def adjust_learning_rate(optimizer, epoch, args): + """ + Sets the learning rate to the initial LR decayed by x every y epochs + x = 0.1, y = args.num_train_epochs/2.0 = 100 + """ + lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) )) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d): + """ + Compute 2D reprojection loss if 2D keypoint annotations are available. + The confidence is binary and indicates whether the keypoints exist or not. + """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean() + return loss + +def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d): + """ + Compute 3D keypoint loss if 3D keypoint annotations are available. 
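+    Both prediction and ground truth are centred on the root joint (index 0) before
+    the confidence-weighted MSE is averaged; if no sample in the batch carries 3D
+    annotations, a zero tensor is returned instead.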
+ """ + conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() + gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone() + gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1] + conf = conf[has_pose_3d == 1] + pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1] + if len(gt_keypoints_3d) > 0: + gt_root = gt_keypoints_3d[:, 0,:] + gt_keypoints_3d = gt_keypoints_3d - gt_root[:, None, :] + pred_root = pred_keypoints_3d[:, 0,:] + pred_keypoints_3d = pred_keypoints_3d - pred_root[:, None, :] + return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean() + else: + return torch.FloatTensor(1).fill_(0.).to(device) + +def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl): + """ + Compute per-vertex loss if vertex annotations are available. + """ + pred_vertices_with_shape = pred_vertices[has_smpl == 1] + gt_vertices_with_shape = gt_vertices[has_smpl == 1] + if len(gt_vertices_with_shape) > 0: + return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape) + else: + return torch.FloatTensor(1).fill_(0.).to(device) + + +def run(args, train_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler): + + max_iter = len(train_dataloader) + iters_per_epoch = max_iter // args.num_train_epochs + + optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()), + lr=args.lr, + betas=(0.9, 0.999), + weight_decay=0) + + # define loss function (criterion) and optimizer + criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_vertices = torch.nn.L1Loss().to(device) + + if args.distributed: + Graphormer_model = torch.nn.parallel.DistributedDataParallel( + Graphormer_model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True, + ) + + start_training_time = time.time() + end = time.time() + Graphormer_model.train() + batch_time = AverageMeter() + data_time = AverageMeter() + log_losses = AverageMeter() + log_loss_2djoints = AverageMeter() + log_loss_3djoints = AverageMeter() + log_loss_vertices = AverageMeter() + + for iteration, (img_keys, images, annotations) in enumerate(train_dataloader): + + Graphormer_model.train() + iteration += 1 + epoch = iteration // iters_per_epoch + batch_size = images.size(0) + adjust_learning_rate(optimizer, epoch, args) + data_time.update(time.time() - end) + + images = images.to(device) + gt_2d_joints = annotations['joints_2d'].to(device) + gt_pose = annotations['pose'].to(device) + gt_betas = annotations['betas'].to(device) + has_mesh = annotations['has_smpl'].to(device) + has_3d_joints = has_mesh + has_2d_joints = has_mesh + mjm_mask = annotations['mjm_mask'].to(device) + mvm_mask = annotations['mvm_mask'].to(device) + + # generate mesh + gt_vertices, gt_3d_joints = mano_model.layer(gt_pose, gt_betas) + gt_vertices = gt_vertices/1000.0 + gt_3d_joints = gt_3d_joints/1000.0 + + gt_vertices_sub = mesh_sampler.downsample(gt_vertices) + # normalize gt based on hand's wrist + gt_3d_root = gt_3d_joints[:,cfg.J_NAME.index('Wrist'),:] + gt_vertices = gt_vertices - gt_3d_root[:, None, :] + gt_vertices_sub = gt_vertices_sub - gt_3d_root[:, None, :] + gt_3d_joints = gt_3d_joints - gt_3d_root[:, None, :] + gt_3d_joints_with_tag = torch.ones((batch_size,gt_3d_joints.shape[1],4)).to(device) + gt_3d_joints_with_tag[:,:,:3] = gt_3d_joints + + # prepare masks for mask vertex/joint modeling + mjm_mask_ = mjm_mask.expand(-1,-1,2051) + mvm_mask_ = mvm_mask.expand(-1,-1,2051) + meta_masks = 
torch.cat([mjm_mask_, mvm_mask_], dim=1) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler, meta_masks=meta_masks, is_train=True) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + + # obtain 2d joints, which are projected from 3d joints of smpl mesh + pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + pred_2d_joints = orthographic_projection(pred_3d_joints.contiguous(), pred_camera.contiguous()) + + # compute 3d joint loss (where the joints are directly output from transformer) + loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints_with_tag, has_3d_joints) + + # compute 3d vertex loss + loss_vertices = ( args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_mesh) + \ + args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_mesh) ) + + # compute 3d joint loss (where the joints are regressed from full mesh) + loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_mesh, gt_3d_joints_with_tag, has_3d_joints) + # compute 2d joint loss + loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \ + keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_mesh, gt_2d_joints, has_2d_joints) + + loss_3d_joints = loss_3d_joints + loss_reg_3d_joints + + # we empirically use hyperparameters to balance difference losses + loss = args.joints_loss_weight*loss_3d_joints + \ + args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints + + # update logs + log_loss_2djoints.update(loss_2d_joints.item(), batch_size) + log_loss_3djoints.update(loss_3d_joints.item(), batch_size) + log_loss_vertices.update(loss_vertices.item(), batch_size) + log_losses.update(loss.item(), batch_size) + + # back prop + optimizer.zero_grad() + loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if iteration % args.logging_steps == 0 or iteration == max_iter: + eta_seconds = batch_time.avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + logger.info( + ' '.join( + ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',] + ).format(eta=eta_string, ep=epoch, iter=iteration, + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) + + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format( + log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg, + optimizer.param_groups[0]['lr']) + ) + + aml_run.log(name='Loss', value=float(log_losses.avg)) + aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg)) + aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg)) + aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg)) + + visual_imgs = visualize_mesh( renderer, + annotations['ori_img'].detach(), + annotations['joints_2d'].detach(), + pred_vertices.detach(), + pred_camera.detach(), + pred_2d_joints_from_mesh.detach()) + visual_imgs = visual_imgs.transpose(0,1) + visual_imgs = visual_imgs.transpose(1,2) + visual_imgs = np.asarray(visual_imgs) + + if is_main_process()==True: + stamp = str(epoch) + '_' + str(iteration) + temp_fname 
= args.output_dir + 'visual_' + stamp + '.jpg' + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + aml_run.log_image(name='visual results', path=temp_fname) + + if iteration % iters_per_epoch == 0: + if epoch%10==0: + checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration) + + total_training_time = time.time() - start_training_time + total_time_str = str(datetime.timedelta(seconds=total_training_time)) + logger.info('Total training time: {} ({:.4f} s / iter)'.format( + total_time_str, total_training_time / max_iter) + ) + checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration) + +def run_eval_and_save(args, split, val_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler): + + criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_vertices = torch.nn.L1Loss().to(device) + + if args.distributed: + Graphormer_model = torch.nn.parallel.DistributedDataParallel( + Graphormer_model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True, + ) + Graphormer_model.eval() + + if args.aml_eval==True: + run_aml_inference_hand_mesh(args, val_dataloader, + Graphormer_model, + criterion_keypoints, + criterion_vertices, + 0, + mano_model, mesh_sampler, + renderer, split) + else: + run_inference_hand_mesh(args, val_dataloader, + Graphormer_model, + criterion_keypoints, + criterion_vertices, + 0, + mano_model, mesh_sampler, + renderer, split) + checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0) + return + +def run_aml_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split): + # switch to evaluate mode + Graphormer_model.eval() + fname_output_save = [] + mesh_output_save = [] + joint_output_save = [] + world_size = get_world_size() + with torch.no_grad(): + for i, (img_keys, images, annotations) in enumerate(val_loader): + batch_size = images.size(0) + # compute output + images = images.to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler) + # obtain 3d joints from full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + + for j in range(batch_size): + fname_output_save.append(img_keys[j]) + pred_vertices_list = pred_vertices[j].tolist() + mesh_output_save.append(pred_vertices_list) + pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist() + joint_output_save.append(pred_3d_joints_from_mesh_list) + + if world_size > 1: + torch.distributed.barrier() + print('save results to pred.json') + output_json_file = 'pred.json' + print('save results to ', output_json_file) + with open(output_json_file, 'w') as f: + json.dump([joint_output_save, mesh_output_save], f) + + azure_ckpt_name = '200' # args.resume_checkpoint.split('/')[-2].split('-')[1] + inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) + output_zip_file = args.output_dir + 'ckpt' + azure_ckpt_name + '-' + inference_setting +'-pred.zip' + + resolved_submit_cmd = 'zip ' + output_zip_file + ' ' + output_json_file + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + resolved_submit_cmd = 'rm %s'%(output_json_file) + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + if world_size > 1: + torch.distributed.barrier() + + return + +def run_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split): + # 
switch to evaluate mode + Graphormer_model.eval() + fname_output_save = [] + mesh_output_save = [] + joint_output_save = [] + with torch.no_grad(): + for i, (img_keys, images, annotations) in enumerate(val_loader): + batch_size = images.size(0) + # compute output + images = images.to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler) + + # obtain 3d joints from full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:] + pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :] + pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] + + for j in range(batch_size): + fname_output_save.append(img_keys[j]) + pred_vertices_list = pred_vertices[j].tolist() + mesh_output_save.append(pred_vertices_list) + pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist() + joint_output_save.append(pred_3d_joints_from_mesh_list) + + if i%20==0: + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + visual_imgs = visualize_mesh( renderer, + annotations['ori_img'].detach(), + annotations['joints_2d'].detach(), + pred_vertices.detach(), + pred_camera.detach(), + pred_2d_joints_from_mesh.detach()) + + visual_imgs = visual_imgs.transpose(0,1) + visual_imgs = visual_imgs.transpose(1,2) + visual_imgs = np.asarray(visual_imgs) + + inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) + temp_fname = args.output_dir + args.resume_checkpoint[0:-9] + 'freihand_results_'+inference_setting+'_batch'+str(i)+'.jpg' + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + + print('save results to pred.json') + with open('pred.json', 'w') as f: + json.dump([joint_output_save, mesh_output_save], f) + + run_exp_name = args.resume_checkpoint.split('/')[-3] + run_ckpt_name = args.resume_checkpoint.split('/')[-2].split('-')[1] + inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) + resolved_submit_cmd = 'zip ' + args.output_dir + run_exp_name + '-ckpt'+ run_ckpt_name + '-' + inference_setting +'-pred.zip ' + 'pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + resolved_submit_cmd = 'rm pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + return + +def visualize_mesh( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(21)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer) + rend_img = rend_img.transpose(2,0,1) + 
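+        # HWC -> CHW so torchvision's make_grid can tile the rendered images below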
rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def visualize_mesh_test( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d, + PAmPJPE): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(21)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + score = PAmPJPE[i] + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score) + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def visualize_mesh_no_text( renderer, + images, + pred_vertices, + pred_camera): + """Tensorboard logging.""" + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 1)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + # Visualize reconstruction only + rend_img = visualize_reconstruction_no_text(img, 224, vertices, cam, renderer, color='hand') + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--data_dir", default='datasets', type=str, required=False, + help="Directory with all datasets, each in one subfolder") + parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False, + help="Yaml file with all data for training.") + parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False, + help="Yaml file with all data for validation.") + parser.add_argument("--num_workers", default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + ######################################################### + # Loading/saving checkpoints + ######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + parser.add_argument('-a', '--arch', 
default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + ######################################################### + # Training parameters + ######################################################### + parser.add_argument("--per_gpu_train_batch_size", default=64, type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int, + help="Batch size per GPU/CPU for evaluation.") + parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float, + help="The initial lr.") + parser.add_argument("--num_train_epochs", default=200, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--vertices_loss_weight", default=1.0, type=float) + parser.add_argument("--joints_loss_weight", default=1.0, type=float) + parser.add_argument("--vloss_w_full", default=0.5, type=float) + parser.add_argument("--vloss_w_sub", default=0.5, type=float) + parser.add_argument("--drop_out", default=0.1, type=float, + help="Drop out ratio in BERT.") + ######################################################### + # Model architectures + ######################################################### + parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=-1, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. 
Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand") + + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=False, action='store_true',) + parser.add_argument("--multiscale_inference", default=False, action='store_true',) + # if enable "multiscale_inference", dataloader will apply transformations to the test image based on + # the rotation "rot" and scale "sc" parameters below + parser.add_argument("--rot", default=0, type=float) + parser.add_argument("--sc", default=1.0, type=float) + parser.add_argument("--aml_eval", default=False, action='store_true',) + + parser.add_argument('--logging_steps', type=int, default=100, + help="Log every X steps.") + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + parser.add_argument("--local_rank", type=int, default=0, + help="For distributed training.") + args = parser.parse_args() + return args + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + args.distributed = args.num_gpus > 1 + args.device = torch.device(args.device) + if args.distributed: + print("Init distributed training on local rank {}".format(args.local_rank)) + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + backend='nccl', init_method='env://' + ) + synchronize() + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and SMPL utils + mano_model = MANO().to(args.device) + mano_model.layer = mano_model.layer.to(device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=mano_model.face) + + # Load pretrained model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.hidden_dropout_prob = args.drop_out + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*2) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + 
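+            # mesh_type ('hand' by default in this script) tells the encoder which mesh,
+            # MANO hand or SMPL body, its graph-conv blocks operate on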
config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + # create backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. 
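+        # strict=False lets load_state_dict ignore keys that exist on only one side
+        # (e.g. the sparse graph-conv tensors mentioned above) instead of raising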
+ state_dict = torch.load(args.resume_checkpoint) + _model.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + torch.cuda.empty_cache() + + _model.to(args.device) + logger.info("Training parameters %s", args) + + if args.run_eval_only==True: + val_dataloader = make_hand_data_loader(args, args.val_yaml, + args.distributed, is_train=False, scale_factor=args.img_scale_factor) + run_eval_and_save(args, 'freihand', val_dataloader, _model, mano_model, renderer, mesh_sampler) + + else: + train_dataloader = make_hand_data_loader(args, args.train_yaml, + args.distributed, is_train=True, scale_factor=args.img_scale_factor) + run(args, train_dataloader, _model, mano_model, renderer, mesh_sampler) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/mesh_graphormer/tools/run_gphmer_handmesh_inference.py b/mesh_graphormer/tools/run_gphmer_handmesh_inference.py new file mode 100644 index 0000000..04cd326 --- /dev/null +++ b/mesh_graphormer/tools/run_gphmer_handmesh_inference.py @@ -0,0 +1,338 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +End-to-end inference codes for +3D hand mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from mesh_graphormer.modeling.bert import BertConfig, Graphormer +from mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network +from mesh_graphormer.modeling._mano import MANO, Mesh +from mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import mesh_graphormer.modeling.data.config as cfg +from mesh_graphormer.datasets.build import make_hand_data_loader + +from mesh_graphormer.utils.logger import setup_logger +from mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from mesh_graphormer.utils.metric_logger import AverageMeter +from mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text +from mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from mesh_graphormer.utils.geometric_layers import orthographic_projection + +from PIL import Image +from torchvision import transforms + + +device = "cuda" + +transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + +transform_visualize = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor()]) + +def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler): +# switch to evaluate mode + Graphormer_model.eval() + mano.eval() + with torch.no_grad(): + for image_file in image_list: + if 'pred' not in image_file: + att_all = [] + print(image_file) + img = Image.open(image_file) + img_tensor = transform(img) + img_visual = transform_visualize(img) + + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + batch_visual_imgs = torch.unsqueeze(img_visual, 
0).to(device) + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) + # obtain 3d joints from full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:] + pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :] + pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] + + # save attantion + att_max_value = att[-1] + att_cpu = np.asarray(att_max_value.cpu().detach()) + att_all.append(att_cpu) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous()) + + + visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0], + pred_vertices[0].detach(), + pred_camera.detach()) + # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0], + # pred_vertices[0].detach(), + # pred_vertices_sub[0].detach(), + # pred_2d_coarse_vertices_from_mesh[0].detach(), + # pred_2d_joints_from_mesh[0].detach(), + # pred_camera.detach(), + # att[-1][0].detach()) + visual_imgs = visual_imgs_output.transpose(1,2,0) + visual_imgs = np.asarray(visual_imgs) + + temp_fname = image_file[:-4] + '_graphormer_pred.jpg' + print('save to ', temp_fname) + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + return + +def visualize_mesh( renderer, images, + pred_vertices_full, + pred_camera): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + cam = pred_camera.cpu().numpy() + # Visualize only mesh reconstruction + rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + +def visualize_mesh_and_attention( renderer, images, + pred_vertices_full, + pred_vertices, + pred_2d_vertices, + pred_2d_joints, + pred_camera, + attention): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + vertices = pred_vertices.cpu().numpy() + vertices_2d = pred_2d_vertices.cpu().numpy() + joints_2d = pred_2d_joints.cpu().numpy() + cam = pred_camera.cpu().numpy() + att = attention.cpu().numpy() + # Visualize reconstruction and attention + rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--num_workers", default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + parser.add_argument("--image_file_or_path", default='./samples/hand', type=str, + help="test data") + ######################################################### + # Loading/saving checkpoints + 
######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + parser.add_argument('-a', '--arch', default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + ######################################################### + # Model architectures + ######################################################### + parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=4, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. 
Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand") + + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=True, action='store_true',) + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + args = parser.parse_args() + return args + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and MANO utils + mano_model = MANO().to(args.device) + mano_model.layer = mano_model.layer.to(device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=mano_model.face) + + # Load pretrained model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*2) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + # create backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = 
get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. + state_dict = torch.load(args.resume_checkpoint) + _model.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + torch.cuda.empty_cache() + + # update configs to enable attention outputs + setattr(_model.trans_encoder[-1].config,'output_attentions', True) + setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) + _model.trans_encoder[-1].bert.encoder.output_attentions = True + _model.trans_encoder[-1].bert.encoder.output_hidden_states = True + for iter_layer in range(4): + _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True + for inter_block in range(3): + setattr(_model.trans_encoder[-1].config,'device', args.device) + + _model.to(args.device) + logger.info("Run inference") + + image_list = [] + if not args.image_file_or_path: + raise ValueError("image_file_or_path not specified") + if op.isfile(args.image_file_or_path): + image_list = [args.image_file_or_path] + elif op.isdir(args.image_file_or_path): + # should be a path with images only + for filename in os.listdir(args.image_file_or_path): + if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename: + image_list.append(args.image_file_or_path+'/'+filename) + else: + raise ValueError("Cannot find images at {}".format(args.image_file_or_path)) + + run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/mesh_graphormer/tools/run_hand_multiscale.py b/mesh_graphormer/tools/run_hand_multiscale.py new file mode 100644 index 0000000..c00583c --- /dev/null +++ b/mesh_graphormer/tools/run_hand_multiscale.py @@ -0,0 +1,136 @@ +from __future__ import absolute_import, division, print_function + +import argparse +import os +import os.path as op +import code +import json +import zipfile +import torch +import numpy as np +from mesh_graphormer.utils.metric_pampjpe import get_alignMesh + + +def load_pred_json(filepath): + archive = zipfile.ZipFile(filepath, 'r') + jsondata = 
archive.read('pred.json') + reference = json.loads(jsondata.decode("utf-8")) + return reference[0], reference[1] + + +def multiscale_fusion(output_dir): + s = '10' + filepath = output_dir+'ckpt200-sc10_rot0-pred.zip' + ref_joints, ref_vertices = load_pred_json(filepath) + ref_joints_array = np.asarray(ref_joints) + ref_vertices_array = np.asarray(ref_vertices) + + rotations = [0.0] + for i in range(1,10): + rotations.append(i*10) + rotations.append(i*-10) + + scale = [0.7,0.8,0.9,1.0,1.1] + multiscale_joints = [] + multiscale_vertices = [] + + counter = 0 + for s in scale: + for r in rotations: + setting = 'sc%02d_rot%s'%(int(s*10),str(int(r))) + filepath = output_dir+'ckpt200-'+setting+'-pred.zip' + joints, vertices = load_pred_json(filepath) + joints_array = np.asarray(joints) + vertices_array = np.asarray(vertices) + + pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None) + pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None) + print('--------------------------') + print('scale:', s, 'rotate', r) + print('PAMPJPE:', 1000*np.mean(pa_joint_error)) + print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) + multiscale_joints.append(pa_joint_array) + multiscale_vertices.append(pa_vertices_array) + counter = counter + 1 + + overall_joints_array = ref_joints_array.copy() + overall_vertices_array = ref_vertices_array.copy() + for i in range(counter): + overall_joints_array += multiscale_joints[i] + overall_vertices_array += multiscale_vertices[i] + + overall_joints_array /= (1+counter) + overall_vertices_array /= (1+counter) + pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None) + pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None) + print('--------------------------') + print('overall:') + print('PAMPJPE:', 1000*np.mean(pa_joint_error)) + print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) + + joint_output_save = overall_joints_array.tolist() + mesh_output_save = overall_vertices_array.tolist() + + print('save results to pred.json') + with open('pred.json', 'w') as f: + json.dump([joint_output_save, mesh_output_save], f) + + + filepath = output_dir+'ckpt200-multisc-pred.zip' + resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + resolved_submit_cmd = 'rm pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + + +def run_multiscale_inference(model_path, mode, output_dir): + + if mode==True: + rotations = [0.0] + for i in range(1,10): + rotations.append(i*10) + rotations.append(i*-10) + scale = [0.7,0.8,0.9,1.0,1.1] + else: + rotations = [0.0] + scale = [1.0] + + job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \ + "--val_yaml freihand_v3/test.yaml " \ + "--resume_checkpoint %s " \ + "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \ + "--multiscale_inference " \ + "--rot %f " \ + "--sc %s " \ + "--arch hrnet-w64 " \ + "--num_hidden_layers 4 " \ + "--num_attention_heads 4 " \ + "--input_feat_dim 2051,512,128 " \ + "--hidden_feat_dim 1024,256,64 " \ + "--output_dir %s" + + for s in scale: + for r in rotations: + resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir) + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + +def main(args): + model_path = args.model_path + mode = args.multiscale_inference + output_dir = args.output_dir + 
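+    # run one inference pass per (scale, rotation) setting, then fuse the aligned
+    # predictions afterwards when multiscale mode is enabled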
run_multiscale_inference(model_path, mode, output_dir) + if mode==True: + multiscale_fusion(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder") + parser.add_argument("--model_path") + parser.add_argument("--multiscale_inference", default=False, action='store_true',) + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + args = parser.parse_args() + main(args) diff --git a/mesh_graphormer/utils/__init__.py b/mesh_graphormer/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mesh_graphormer/utils/comm.py b/mesh_graphormer/utils/comm.py new file mode 100644 index 0000000..b82e7f8 --- /dev/null +++ b/mesh_graphormer/utils/comm.py @@ -0,0 +1,176 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import pickle +import time + +import torch +import torch.distributed as dist + + +device = "cuda" + + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def gather_on_master(data): + """Same as all_gather, but gathers data on master process only, using CPU. + Thus, this does not work with NCCL backend unless they add CPU support. + + The memory consumption of this function is ~ 3x of data size. While in + principal, it should be ~2x, it's not easy to force Python to release + memory immediately and thus, peak memory usage could be up to 3x. 
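+    Non-master ranks return None.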
+ """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + # trying to optimize memory, but in fact, it's not guaranteed to be released + del data + storage = torch.ByteStorage.from_buffer(buffer) + del buffer + tensor = torch.ByteTensor(storage) + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]) + size_list = [torch.LongTensor([0]) for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)) + tensor = torch.cat((tensor, padding), dim=0) + del padding + + if is_main_process(): + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,))) + dist.gather(tensor, gather_list=tensor_list, dst=0) + del tensor + else: + dist.gather(tensor, gather_list=[], dst=0) + del tensor + return + + data_list = [] + for tensor in tensor_list: + buffer = tensor.cpu().numpy().tobytes() + del tensor + data_list.append(pickle.loads(buffer)) + del buffer + + return data_list + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device) + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]).to(device) + size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device)) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to(device) + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/mesh_graphormer/utils/dataset_utils.py b/mesh_graphormer/utils/dataset_utils.py new file mode 100644 index 0000000..cb66451 --- /dev/null +++ b/mesh_graphormer/utils/dataset_utils.py @@ -0,0 +1,66 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + + +import os +import os.path as op +import numpy as np +import base64 +import cv2 +import yaml +from collections import OrderedDict + + +def img_from_base64(imagestring): + try: + jpgbytestring = base64.b64decode(imagestring) + nparr = np.frombuffer(jpgbytestring, np.uint8) + r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + return r + except: + return None + + +def load_labelmap(labelmap_file): + label_dict = None + if labelmap_file is not None and op.isfile(labelmap_file): + label_dict = OrderedDict() + with open(labelmap_file, 'r') as fp: + for line in fp: + label = line.strip().split('\t')[0] + if label in label_dict: + raise ValueError("Duplicate label " + label + " in labelmap.") + else: + label_dict[label] = len(label_dict) + return label_dict + + +def load_shuffle_file(shuf_file): + shuf_list = None + if shuf_file is not None: + with open(shuf_file, 'r') as fp: + shuf_list = [] + for i in fp: + shuf_list.append(int(i.strip())) + return shuf_list + + +def load_box_shuffle_file(shuf_file): + if shuf_file is not None: + with open(shuf_file, 'r') as fp: + img_shuf_list = [] + box_shuf_list = [] + for i in fp: + idx = [int(_) for _ in i.strip().split('\t')] + img_shuf_list.append(idx[0]) + box_shuf_list.append(idx[1]) + return [img_shuf_list, box_shuf_list] + return None + + +def load_from_yaml_file(file_name): + with open(file_name, 'r') as fp: + return yaml.load(fp, Loader=yaml.CLoader) diff --git a/mesh_graphormer/utils/geometric_layers.py b/mesh_graphormer/utils/geometric_layers.py new file mode 100644 index 0000000..ad4bf9d --- /dev/null +++ b/mesh_graphormer/utils/geometric_layers.py @@ -0,0 +1,58 @@ +""" +Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula + +Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR +""" +import torch + +def rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat2mat(quat) + +def quat2mat(quat): + """Convert quaternion coefficients to rotation matrix. 
+ Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def orthographic_projection(X, camera): + """Perform orthographic projection of 3D points X using the camera parameters + Args: + X: size = [B, N, 3] + camera: size = [B, 3] + Returns: + Projected 2D points -- size = [B, N, 2] + """ + camera = camera.view(-1, 1, 3) + X_trans = X[:, :, :2] + camera[:, :, 1:] + shape = X_trans.shape + X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape) + return X_2d diff --git a/mesh_graphormer/utils/image_ops.py b/mesh_graphormer/utils/image_ops.py new file mode 100644 index 0000000..eb62d8f --- /dev/null +++ b/mesh_graphormer/utils/image_ops.py @@ -0,0 +1,208 @@ +""" +Image processing tools + +Modified from open source projects: +(https://github.com/nkolot/GraphCMR/) +(https://github.com/open-mmlab/mmdetection) + +""" + +import numpy as np +import base64 +import cv2 +import torch +import scipy.misc + +def img_from_base64(imagestring): + try: + jpgbytestring = base64.b64decode(imagestring) + nparr = np.frombuffer(jpgbytestring, np.uint8) + r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + return r + except ValueError: + return None + +def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False): + if center is not None and auto_bound: + raise ValueError('`auto_bound` conflicts with `center`') + h, w = img.shape[:2] + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + assert isinstance(center, tuple) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + if auto_bound: + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) + new_w = h * sin + w * cos + new_h = h * cos + w * sin + matrix[0, 2] += (new_w - w) * 0.5 + matrix[1, 2] += (new_h - h) * 0.5 + w = int(np.round(new_w)) + h = int(np.round(new_h)) + rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value) + return rotated + +def myimresize(img, size, return_scale=False, interpolation='bilinear'): + + h, w = img.shape[:2] + resized_img = cv2.resize( + img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3,3)) + rot_rad = rot * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + rot_mat[2,2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0,2] = -res[1]/2 + t_mat[1,2] = -res[0]/2 + t_inv = t_mat.copy() + t_inv[:2,2] *= -1 + t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) + return t + +def 
transform(pt, center, scale, res, invert=0, rot=0): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res, rot=rot) + if invert: + # t = np.linalg.inv(t) + t_torch = torch.from_numpy(t) + t_torch = torch.inverse(t_torch) + t = t_torch.numpy() + new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2].astype(int)+1 + +def crop(img, center, scale, res, rot=0): + """Crop image according to the supplied bounding box.""" + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1, + res[1]+1], center, scale, res, invert=1))-1 + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + if not rot == 0: + ul -= pad + br += pad + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], + old_x[0]:old_x[1]] + if not rot == 0: + # Remove padding + # new_img = scipy.misc.imrotate(new_img, rot) + new_img = myimrotate(new_img, rot) + new_img = new_img[pad:-pad, pad:-pad] + + # new_img = scipy.misc.imresize(new_img, res) + new_img = myimresize(new_img, [res[0], res[1]]) + return new_img + +def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True): + """'Undo' the image cropping/resizing. + This function is used when evaluating mask/part segmentation. + """ + res = img.shape[:2] + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1 + # size of cropped image + crop_shape = [br[1] - ul[1], br[0] - ul[0]] + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(orig_shape, dtype=np.uint8) + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] + new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(orig_shape[1], br[0]) + old_y = max(0, ul[1]), min(orig_shape[0], br[1]) + # img = scipy.misc.imresize(img, crop_shape, interp='nearest') + img = myimresize(img, [crop_shape[0],crop_shape[1]]) + new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] + return new_img + +def rot_aa(aa, rot): + """Rotate axis angle parameters.""" + # pose parameters + R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], + [0, 0, 1]]) + # find the rotation of the body in camera frame + per_rdg, _ = cv2.Rodrigues(aa) + # apply the global rotation to the global orientation + resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg)) + aa = (resrot.T)[0] + return aa + +def flip_img(img): + """Flip rgb images or masks. + channels come last, e.g. (256,256,3). 
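+    Only the width axis is reversed (np.fliplr).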
+ """ + img = np.fliplr(img) + return img + +def flip_kp(kp): + """Flip keypoints.""" + flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22] + kp = kp[flipped_parts] + kp[:,0] = - kp[:,0] + return kp + +def flip_pose(pose): + """Flip pose. + The flipping is based on SMPL parameters. + """ + flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, + 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, + 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41, + 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55, + 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68] + pose = pose[flippedParts] + # we also negate the second and the third dimension of the axis-angle + pose[1::3] = -pose[1::3] + pose[2::3] = -pose[2::3] + return pose + +def flip_aa(aa): + """Flip axis-angle representation. + We negate the second and the third dimension of the axis-angle. + """ + aa[1] = -aa[1] + aa[2] = -aa[2] + return aa \ No newline at end of file diff --git a/mesh_graphormer/utils/logger.py b/mesh_graphormer/utils/logger.py new file mode 100644 index 0000000..0131396 --- /dev/null +++ b/mesh_graphormer/utils/logger.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os +import sys +from logging import StreamHandler, Handler, getLevelName + + +# this class is a copy of logging.FileHandler except we end self.close() +# at the end of each emit. While closing file and reopening file after each +# write is not efficient, it allows us to see partial logs when writing to +# fused Azure blobs, which is very convenient +class FileHandler(StreamHandler): + """ + A handler class which writes formatted logging records to disk files. + """ + def __init__(self, filename, mode='a', encoding=None, delay=False): + """ + Open the specified file and use it as the stream for logging. + """ + # Issue #27493: add support for Path objects to be passed in + filename = os.fspath(filename) + #keep the absolute path, otherwise derived classes which use this + #may come a cropper when the current directory changes + self.baseFilename = os.path.abspath(filename) + self.mode = mode + self.encoding = encoding + self.delay = delay + if delay: + #We don't open the stream, but we still need to call the + #Handler constructor to set level, formatter, lock etc. + Handler.__init__(self) + self.stream = None + else: + StreamHandler.__init__(self, self._open()) + + def close(self): + """ + Closes the stream. + """ + self.acquire() + try: + try: + if self.stream: + try: + self.flush() + finally: + stream = self.stream + self.stream = None + if hasattr(stream, "close"): + stream.close() + finally: + # Issue #19523: call unconditionally to + # prevent a handler leak when delay is set + StreamHandler.close(self) + finally: + self.release() + + def _open(self): + """ + Open the current base file with the (original) mode and encoding. + Return the resulting stream. + """ + return open(self.baseFilename, self.mode, encoding=self.encoding) + + def emit(self, record): + """ + Emit a record. + + If the stream was not opened because 'delay' was specified in the + constructor, open it before calling the superclass's emit. 
+ """ + if self.stream is None: + self.stream = self._open() + StreamHandler.emit(self, record) + self.close() + + def __repr__(self): + level = getLevelName(self.level) + return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level) + + +def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + if save_dir: + fh = FileHandler(os.path.join(save_dir, filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger diff --git a/mesh_graphormer/utils/metric_logger.py b/mesh_graphormer/utils/metric_logger.py new file mode 100644 index 0000000..ddaa0ab --- /dev/null +++ b/mesh_graphormer/utils/metric_logger.py @@ -0,0 +1,45 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Basic logger. It Computes and stores the average and current value +""" + +class AverageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + + +class EvalMetricsLogger(object): + + def __init__(self): + self.reset() + + def reset(self): + # define a upper-bound performance (worst case) + # numbers are in unit millimeter + self.PAmPJPE = 100.0/1000.0 + self.mPJPE = 100.0/1000.0 + self.mPVE = 100.0/1000.0 + + self.epoch = 0 + + def update(self, mPVE, mPJPE, PAmPJPE, epoch): + self.PAmPJPE = PAmPJPE + self.mPJPE = mPJPE + self.mPVE = mPVE + self.epoch = epoch diff --git a/mesh_graphormer/utils/metric_pampjpe.py b/mesh_graphormer/utils/metric_pampjpe.py new file mode 100644 index 0000000..89fe55b --- /dev/null +++ b/mesh_graphormer/utils/metric_pampjpe.py @@ -0,0 +1,99 @@ +""" +Functions for compuing Procrustes alignment and reconstruction error + +Parts of the code are adapted from https://github.com/akanazawa/hmr + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +def compute_similarity_transform(S1, S2): + """Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + """ + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.T + S2 = S2.T + transposed = True + assert(S2.shape[1] == S1.shape[1]) + + # 1. Remove mean. + mu1 = S1.mean(axis=1, keepdims=True) + mu2 = S2.mean(axis=1, keepdims=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = np.sum(X1**2) + + # 3. The outer product of X1 and X2. + K = X1.dot(X2.T) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, Vh = np.linalg.svd(K) + V = Vh.T + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = np.eye(U.shape[0]) + Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) + # Construct R. + R = V.dot(Z.dot(U.T)) + + # 5. Recover scale. + scale = np.trace(R.dot(K)) / var1 + + # 6. Recover translation. 
+ t = mu2 - scale*(R.dot(mu1)) + + # 7. Error: + S1_hat = scale*R.dot(S1) + t + + if transposed: + S1_hat = S1_hat.T + + return S1_hat + +def compute_similarity_transform_batch(S1, S2): + """Batched version of compute_similarity_transform.""" + S1_hat = np.zeros_like(S1) + for i in range(S1.shape[0]): + S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) + return S1_hat + +def reconstruction_error(S1, S2, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re + + +def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + S1_hat = S1_hat[:,J24_TO_J14,:] + S2 = S2[:,J24_TO_J14,:] + re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re + +def get_alignMesh(S1, S2, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re, S1_hat, S2 diff --git a/mesh_graphormer/utils/miscellaneous.py b/mesh_graphormer/utils/miscellaneous.py new file mode 100644 index 0000000..3de72c6 --- /dev/null +++ b/mesh_graphormer/utils/miscellaneous.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import errno +import os +import os.path as op +import re +import logging +import numpy as np +import torch +import random +import shutil +from .comm import is_main_process +import yaml + + +def mkdir(path): + # if it is the current folder, skip. 
+ # otherwise the original code will raise FileNotFoundError + if path == '': + return + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def save_config(cfg, path): + if is_main_process(): + with open(path, 'w') as f: + f.write(cfg.dump()) + + +def config_iteration(output_dir, max_iter): + save_file = os.path.join(output_dir, 'last_checkpoint') + iteration = -1 + if os.path.exists(save_file): + with open(save_file, 'r') as f: + fname = f.read().strip() + model_name = os.path.basename(fname) + model_path = os.path.dirname(fname) + if model_name.startswith('model_') and len(model_name) == 17: + iteration = int(model_name[-11:-4]) + elif model_name == "model_final": + iteration = max_iter + elif model_path.startswith('checkpoint-') and len(model_path) == 18: + iteration = int(model_path.split('-')[-1]) + return iteration + + +def get_matching_parameters(model, regexp, none_on_empty=True): + """Returns parameters matching regular expression""" + if not regexp: + if none_on_empty: + return {} + else: + return dict(model.named_parameters()) + compiled_pattern = re.compile(regexp) + params = {} + for weight_name, weight in model.named_parameters(): + if compiled_pattern.match(weight_name): + params[weight_name] = weight + return params + + +def freeze_weights(model, regexp): + """Freeze weights based on regular expression.""" + logger = logging.getLogger("maskrcnn_benchmark.trainer") + for weight_name, weight in get_matching_parameters(model, regexp).items(): + weight.requires_grad = False + logger.info("Disabled training of {}".format(weight_name)) + + +def unfreeze_weights(model, regexp, backbone_freeze_at=-1, + is_distributed=False): + """Unfreeze weights based on regular expression. + This is helpful during training to unfreeze freezed weights after + other unfreezed weights have been trained for some iterations. 
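+    If backbone_freeze_at >= 0, the backbone is additionally re-frozen up to that stage.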
+ """ + logger = logging.getLogger("maskrcnn_benchmark.trainer") + for weight_name, weight in get_matching_parameters(model, regexp).items(): + weight.requires_grad = True + logger.info("Enabled training of {}".format(weight_name)) + if backbone_freeze_at >= 0: + logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at)) + if is_distributed: + model.module.backbone.body._freeze_backbone(backbone_freeze_at) + else: + model.backbone.body._freeze_backbone(backbone_freeze_at) + + +def delete_tsv_files(tsvs): + for t in tsvs: + if op.isfile(t): + try_delete(t) + line = op.splitext(t)[0] + '.lineidx' + if op.isfile(line): + try_delete(line) + + +def concat_files(ins, out): + mkdir(op.dirname(out)) + out_tmp = out + '.tmp' + with open(out_tmp, 'wb') as fp_out: + for i, f in enumerate(ins): + logging.info('concating {}/{} - {}'.format(i, len(ins), f)) + with open(f, 'rb') as fp_in: + shutil.copyfileobj(fp_in, fp_out, 1024*1024*10) + os.rename(out_tmp, out) + + +def concat_tsv_files(tsvs, out_tsv): + concat_files(tsvs, out_tsv) + sizes = [os.stat(t).st_size for t in tsvs] + sizes = np.cumsum(sizes) + all_idx = [] + for i, t in enumerate(tsvs): + for idx in load_list_file(op.splitext(t)[0] + '.lineidx'): + if i == 0: + all_idx.append(idx) + else: + all_idx.append(str(int(idx) + sizes[i - 1])) + with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f: + f.write('\n'.join(all_idx)) + + +def load_list_file(fname): + with open(fname, 'r') as fp: + lines = fp.readlines() + result = [line.strip() for line in lines] + if len(result) > 0 and result[-1] == '': + result = result[:-1] + return result + + +def try_once(func): + def func_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logging.info('ignore error \n{}'.format(str(e))) + return func_wrapper + + +@try_once +def try_delete(f): + os.remove(f) + + +def set_seed(seed, n_gpu): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(seed) + + +def print_and_run_cmd(cmd): + print(cmd) + os.system(cmd) + + +def write_to_yaml_file(context, file_name): + with open(file_name, 'w') as fp: + yaml.dump(context, fp, encoding='utf-8') + + +def load_from_yaml_file(yaml_file): + with open(yaml_file, 'r') as fp: + return yaml.load(fp, Loader=yaml.CLoader) + + diff --git a/mesh_graphormer/utils/renderer.py b/mesh_graphormer/utils/renderer.py new file mode 100644 index 0000000..3b2f8a9 --- /dev/null +++ b/mesh_graphormer/utils/renderer.py @@ -0,0 +1,691 @@ +""" +Rendering tools for 3D mesh visualization on 2D image. + +Parts of the code are taken from https://github.com/akanazawa/hmr +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import cv2 +import code +from opendr.camera import ProjectPoints +from opendr.renderer import ColoredRenderer, TexturedRenderer +from opendr.lighting import LambertianPointLight +import random + + +# Rotate the points by a specified angle. +def rotateY(points, angle): + ry = np.array([ + [np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], + [-np.sin(angle), 0., np.cos(angle)] + ]) + return np.dot(points, ry) + +def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None): + """ + joints is 3 x 19. but if not will transpose it. 
+ 0: Right ankle + 1: Right knee + 2: Right hip + 3: Left hip + 4: Left knee + 5: Left ankle + 6: Right wrist + 7: Right elbow + 8: Right shoulder + 9: Left shoulder + 10: Left elbow + 11: Left wrist + 12: Neck + 13: Head top + 14: nose + 15: left_eye + 16: right_eye + 17: left_ear + 18: right_ear + """ + + if radius is None: + radius = max(4, (np.mean(input_image.shape[:2]) * 0.01).astype(int)) + + colors = { + 'pink': (197, 27, 125), # L lower leg + 'light_pink': (233, 163, 201), # L upper leg + 'light_green': (161, 215, 106), # L lower arm + 'green': (77, 146, 33), # L upper arm + 'red': (215, 48, 39), # head + 'light_red': (252, 146, 114), # head + 'light_orange': (252, 141, 89), # chest + 'purple': (118, 42, 131), # R lower leg + 'light_purple': (175, 141, 195), # R upper + 'light_blue': (145, 191, 219), # R lower arm + 'blue': (69, 117, 180), # R upper arm + 'gray': (130, 130, 130), # + 'white': (255, 255, 255), # + } + + image = input_image.copy() + input_is_float = False + + if np.issubdtype(image.dtype, np.float): + input_is_float = True + max_val = image.max() + if max_val <= 2.: # should be 1 but sometimes it's slightly above 1 + image = (image * 255).astype(np.uint8) + else: + image = (image).astype(np.uint8) + + if joints.shape[0] != 2: + joints = joints.T + joints = np.round(joints).astype(int) + + jcolors = [ + 'light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink', + 'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue', + 'purple', 'purple', 'red', 'green', 'green', 'white', 'white', + 'purple', 'purple', 'red', 'green', 'green', 'white', 'white' + ] + + if joints.shape[1] == 19: + # parent indices -1 means no parents + parents = np.array([ + 1, 2, 8, 9, 3, 4, 7, 8, 12, 12, 9, 10, 14, -1, 13, -1, -1, 15, 16 + ]) + # Left is light and right is dark + ecolors = { + 0: 'light_pink', + 1: 'light_pink', + 2: 'light_pink', + 3: 'pink', + 4: 'pink', + 5: 'pink', + 6: 'light_blue', + 7: 'light_blue', + 8: 'light_blue', + 9: 'blue', + 10: 'blue', + 11: 'blue', + 12: 'purple', + 17: 'light_green', + 18: 'light_green', + 14: 'purple' + } + elif joints.shape[1] == 14: + parents = np.array([ + 1, + 2, + 8, + 9, + 3, + 4, + 7, + 8, + -1, + -1, + 9, + 10, + 13, + -1, + ]) + ecolors = { + 0: 'light_pink', + 1: 'light_pink', + 2: 'light_pink', + 3: 'pink', + 4: 'pink', + 5: 'pink', + 6: 'light_blue', + 7: 'light_blue', + 10: 'light_blue', + 11: 'blue', + 12: 'purple' + } + elif joints.shape[1] == 21: # hand + parents = np.array([ + -1, + 0, + 1, + 2, + 3, + 0, + 5, + 6, + 7, + 0, + 9, + 10, + 11, + 0, + 13, + 14, + 15, + 0, + 17, + 18, + 19, + ]) + ecolors = { + 0: 'light_purple', + 1: 'light_green', + 2: 'light_green', + 3: 'light_green', + 4: 'light_green', + 5: 'pink', + 6: 'pink', + 7: 'pink', + 8: 'pink', + 9: 'light_blue', + 10: 'light_blue', + 11: 'light_blue', + 12: 'light_blue', + 13: 'light_red', + 14: 'light_red', + 15: 'light_red', + 16: 'light_red', + 17: 'purple', + 18: 'purple', + 19: 'purple', + 20: 'purple', + } + else: + print('Unknown skeleton!!') + + for child in range(len(parents)): + point = joints[:, child] + # If invisible skip + if vis is not None and vis[child] == 0: + continue + if draw_edges: + cv2.circle(image, (point[0], point[1]), radius, colors['white'], + -1) + cv2.circle(image, (point[0], point[1]), radius - 1, + colors[jcolors[child]], -1) + else: + # cv2.circle(image, (point[0], point[1]), 5, colors['white'], 1) + cv2.circle(image, (point[0], point[1]), radius - 1, + colors[jcolors[child]], 1) + # cv2.circle(image, (point[0], 
point[1]), 5, colors['gray'], -1) + pa_id = parents[child] + if draw_edges and pa_id >= 0: + if vis is not None and vis[pa_id] == 0: + continue + point_pa = joints[:, pa_id] + cv2.circle(image, (point_pa[0], point_pa[1]), radius - 1, + colors[jcolors[pa_id]], -1) + if child not in ecolors.keys(): + print('bad') + import ipdb + ipdb.set_trace() + cv2.line(image, (point[0], point[1]), (point_pa[0], point_pa[1]), + colors[ecolors[child]], radius - 2) + + # Convert back in original dtype + if input_is_float: + if max_val <= 1.: + image = image.astype(np.float32) / 255. + else: + image = image.astype(np.float32) + + return image + +def draw_text(input_image, content): + """ + content is a dict. draws key: val on image + Assumes key is str, val is float + """ + image = input_image.copy() + input_is_float = False + if np.issubdtype(image.dtype, np.float): + input_is_float = True + image = (image * 255).astype(np.uint8) + + black = (255, 255, 0) + margin = 15 + start_x = 5 + start_y = margin + for key in sorted(content.keys()): + text = "%s: %.2g" % (key, content[key]) + cv2.putText(image, text, (start_x, start_y), 0, 0.45, black) + start_y += margin + + if input_is_float: + image = image.astype(np.float32) / 255. + return image + +def visualize_reconstruction(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, color='pink', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + gt_vis = gt_kp[:, 2].astype(bool) + loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2) + debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss} + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, + body_color=color) + rend_img = draw_text(rend_img, debug_text) + + # Draw skeleton + gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size + pred_joint = ((pred_kp + 1) * 0.5) * img_size + img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis) + skel_img = draw_skeleton(img_with_gt, pred_joint) + + combined = np.hstack([skel_img, rend_img]) + + return combined + +def visualize_reconstruction_test(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, score, color='pink', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. 
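+    The pa-mpjpe score is drawn onto the rendered image together with the camera parameters.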
+ """ + gt_vis = gt_kp[:, 2].astype(bool) + loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2) + debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss, "pa-mpjpe": score*1000} + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, + body_color=color) + rend_img = draw_text(rend_img, debug_text) + + # Draw skeleton + gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size + pred_joint = ((pred_kp + 1) * 0.5) * img_size + img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis) + skel_img = draw_skeleton(img_with_gt, pred_joint) + + combined = np.hstack([skel_img, rend_img]) + + return combined + + + +def visualize_reconstruction_and_att(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices_full, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, body_color='light_blue') + + + heads_num, vertex_num, _ = attention.shape + + all_head = np.zeros((vertex_num,vertex_num)) + + ###### find max + # for i in range(vertex_num): + # for j in range(vertex_num): + # all_head[i,j] = np.max(attention[:,i,j]) + + ##### find avg + for h in range(4): + att_per_img = attention[h] + all_head = all_head + att_per_img + all_head = all_head/4 + + col_sums = all_head.sum(axis=0) + all_head = all_head / col_sums[np.newaxis, :] + + + # code.interact(local=locals()) + + combined = [] + if vertex_num>400: # body + selected_joints = [6,7,4,5,13] # [6,7,4,5,13,12] + else: # hand + selected_joints = [0, 4, 8, 12, 16, 20] + # Draw attention + for ii in range(len(selected_joints)): + reference_id = selected_joints[ii] + ref_point = ref_points[reference_id] + attention_to_show = all_head[reference_id][14::] + min_v = np.min(attention_to_show) + max_v = np.max(attention_to_show) + norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v) + + vertices_norm = ((vertices_2d + 1) * 0.5) * img_size + ref_norm = ((ref_point + 1) * 0.5) * img_size + image = np.zeros_like(rend_img) + + for jj in range(vertices_norm.shape[0]): + x = int(vertices_norm[jj,0]) + y = int(vertices_norm[jj,1]) + cv2.circle(image,(x,y), 1, (255,255,255), -1) + + total_to_draw = [] + for jj in range(vertices_norm.shape[0]): + thres = 0.0 + if norm_attention_to_show[jj]>thres: + things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]] + total_to_draw.append(things) + # plot_one_line(ref_norm, vertices_norm[jj], image, reference_id, alpha=0.4*(norm_attention_to_show[jj]-thres)/(1-thres) ) + total_to_draw.sort() + max_att_score = total_to_draw[-1][0] + for item in total_to_draw: + attention_score = item[0] + ref_point = item[1] + vertex = item[2] + plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) ) + # code.interact(local=locals()) + if len(combined)==0: + combined = image + else: + combined = np.hstack([combined, image]) + + final = np.hstack([img, combined, rend_img]) + + return final + + +def visualize_reconstruction_and_att_local(img, 
img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, color='light_blue', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices_full, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, body_color=color) + heads_num, vertex_num, _ = attention.shape + all_head = np.zeros((vertex_num,vertex_num)) + + ##### compute avg attention for 4 attention heads + for h in range(4): + att_per_img = attention[h] + all_head = all_head + att_per_img + all_head = all_head/4 + + col_sums = all_head.sum(axis=0) + all_head = all_head / col_sums[np.newaxis, :] + + combined = [] + if vertex_num>400: # body + selected_joints = [7] # [6,7,4,5,13,12] + else: # hand + selected_joints = [0] # [0, 4, 8, 12, 16, 20] + # Draw attention + for ii in range(len(selected_joints)): + reference_id = selected_joints[ii] + ref_point = ref_points[reference_id] + attention_to_show = all_head[reference_id][14::] + min_v = np.min(attention_to_show) + max_v = np.max(attention_to_show) + norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v) + vertices_norm = ((vertices_2d + 1) * 0.5) * img_size + ref_norm = ((ref_point + 1) * 0.5) * img_size + image = rend_img*0.4 + + total_to_draw = [] + for jj in range(vertices_norm.shape[0]): + thres = 0.0 + if norm_attention_to_show[jj]>thres: + things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]] + total_to_draw.append(things) + total_to_draw.sort() + max_att_score = total_to_draw[-1][0] + for item in total_to_draw: + attention_score = item[0] + ref_point = item[1] + vertex = item[2] + plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) ) + + for jj in range(vertices_norm.shape[0]): + x = int(vertices_norm[jj,0]) + y = int(vertices_norm[jj,1]) + cv2.circle(image,(x,y), 1, (255,255,255), -1) + + if len(combined)==0: + combined = image + else: + combined = np.hstack([combined, image]) + + final = np.hstack([img, combined, rend_img]) + + return final + + +def visualize_reconstruction_no_text(img, img_size, vertices, camera, renderer, color='pink', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. 
+ """ + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, + body_color=color) + + + combined = np.hstack([img, rend_img]) + + return combined + + +def plot_one_line(ref, vertex, img, color_index, alpha=0.0, line_thickness=None): + # 13,6,7,8,3,4,5 + # att_colors = [(255, 221, 104), (255, 255, 0), (255, 215, 227), (210, 240, 119), \ + # (209, 238, 245), (244, 200, 243), (233, 242, 216)] + att_colors = [(255, 255, 0), (244, 200, 243), (210, 243, 119), (209, 238, 255), (200, 208, 255), (250, 238, 215)] + + + overlay = img.copy() + # output = img.copy() + # Plots one bounding box on image img + tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + + color = list(att_colors[color_index]) + c1, c2 = (int(ref[0]), int(ref[1])), (int(vertex[0]), int(vertex[1])) + cv2.line(overlay, c1, c2, (alpha*float(color[0])/255,alpha*float(color[1])/255,alpha*float(color[2])/255) , thickness=tl, lineType=cv2.LINE_AA) + cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img) + + + +def cam2pixel(cam_coord, f, c): + x = cam_coord[:, 0] / (cam_coord[:, 2]) * f[0] + c[0] + y = cam_coord[:, 1] / (cam_coord[:, 2]) * f[1] + c[1] + z = cam_coord[:, 2] + img_coord = np.concatenate((x[:,None], y[:,None], z[:,None]),1) + return img_coord + + +class Renderer(object): + """ + Render mesh using OpenDR for visualization. + """ + + def __init__(self, width=800, height=600, near=0.5, far=1000, faces=None): + self.colors = {'hand': [.9, .9, .9], 'pink': [.9, .7, .7], 'light_blue': [0.65098039, 0.74117647, 0.85882353] } + self.width = width + self.height = height + self.faces = faces + self.renderer = ColoredRenderer() + + def render(self, vertices, faces=None, img=None, + camera_t=np.zeros([3], dtype=np.float32), + camera_rot=np.zeros([3], dtype=np.float32), + camera_center=None, + use_bg=False, + bg_color=(0.0, 0.0, 0.0), + body_color=None, + focal_length=5000, + disp_text=False, + gt_keyp=None, + pred_keyp=None, + **kwargs): + if img is not None: + height, width = img.shape[:2] + else: + height, width = self.height, self.width + + if faces is None: + faces = self.faces + + if camera_center is None: + camera_center = np.array([width * 0.5, + height * 0.5]) + + self.renderer.camera = ProjectPoints(rt=camera_rot, + t=camera_t, + f=focal_length * np.ones(2), + c=camera_center, + k=np.zeros(5)) + dist = np.abs(self.renderer.camera.t.r[2] - + np.mean(vertices, axis=0)[2]) + far = dist + 20 + + self.renderer.frustum = {'near': 1.0, 'far': far, + 'width': width, + 'height': height} + + if img is not None: + if use_bg: + self.renderer.background_image = img + else: + self.renderer.background_image = np.ones_like( + img) * np.array(bg_color) + + if body_color is None: + color = self.colors['light_blue'] + else: + color = self.colors[body_color] + + if isinstance(self.renderer, TexturedRenderer): + color = [1.,1.,1.] 
+ + self.renderer.set(v=vertices, f=faces, + vc=color, bgcolor=np.ones(3)) + albedo = self.renderer.vc + # Construct Back Light (on back right corner) + yrot = np.radians(120) + + self.renderer.vc = LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-200, -100, -100]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Left Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([800, 10, 300]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Right Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-500, 500, 1000]), yrot), + vc=albedo, + light_color=np.array([.7, .7, .7])) + + return self.renderer.r + + + def render_vertex_color(self, vertices, faces=None, img=None, + camera_t=np.zeros([3], dtype=np.float32), + camera_rot=np.zeros([3], dtype=np.float32), + camera_center=None, + use_bg=False, + bg_color=(0.0, 0.0, 0.0), + vertex_color=None, + focal_length=5000, + disp_text=False, + gt_keyp=None, + pred_keyp=None, + **kwargs): + if img is not None: + height, width = img.shape[:2] + else: + height, width = self.height, self.width + + if faces is None: + faces = self.faces + + if camera_center is None: + camera_center = np.array([width * 0.5, + height * 0.5]) + + self.renderer.camera = ProjectPoints(rt=camera_rot, + t=camera_t, + f=focal_length * np.ones(2), + c=camera_center, + k=np.zeros(5)) + dist = np.abs(self.renderer.camera.t.r[2] - + np.mean(vertices, axis=0)[2]) + far = dist + 20 + + self.renderer.frustum = {'near': 1.0, 'far': far, + 'width': width, + 'height': height} + + if img is not None: + if use_bg: + self.renderer.background_image = img + else: + self.renderer.background_image = np.ones_like( + img) * np.array(bg_color) + + if vertex_color is None: + vertex_color = self.colors['light_blue'] + + + self.renderer.set(v=vertices, f=faces, + vc=vertex_color, bgcolor=np.ones(3)) + albedo = self.renderer.vc + # Construct Back Light (on back right corner) + yrot = np.radians(120) + + self.renderer.vc = LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-200, -100, -100]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Left Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([800, 10, 300]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Right Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-500, 500, 1000]), yrot), + vc=albedo, + light_color=np.array([.7, .7, .7])) + + return self.renderer.r \ No newline at end of file diff --git a/mesh_graphormer/utils/tsv_file.py b/mesh_graphormer/utils/tsv_file.py new file mode 100644 index 0000000..8e79a9e --- /dev/null +++ b/mesh_graphormer/utils/tsv_file.py @@ -0,0 +1,162 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+
+Definition of TSV class
+"""
+
+
+import logging
+import os
+import os.path as op
+
+
+def generate_lineidx_file(filein, idxout):
+    # Renamed from generate_lineidx: the TSVFile constructor takes a boolean
+    # argument of that name, which shadowed this helper at its only call site.
+    idxout_tmp = idxout + '.tmp'
+    with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
+        fsize = os.fstat(tsvin.fileno()).st_size
+        fpos = 0
+        while fpos != fsize:
+            tsvout.write(str(fpos) + "\n")
+            tsvin.readline()
+            fpos = tsvin.tell()
+    os.rename(idxout_tmp, idxout)
+
+
+# Backwards-compatible alias for callers importing the old name.
+generate_lineidx = generate_lineidx_file
+
+
+def read_to_character(fp, c):
+    result = []
+    while True:
+        s = fp.read(32)
+        assert s != ''
+        if c in s:
+            result.append(s[: s.index(c)])
+            break
+        else:
+            result.append(s)
+    return ''.join(result)
+
+
+class TSVFile(object):
+    def __init__(self, tsv_file, generate_lineidx=False):
+        self.tsv_file = tsv_file
+        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
+        self._fp = None
+        self._lineidx = None
+        # Remember the pid of the process that opened the file.
+        # If it differs from the current pid (e.g. after a fork), re-open the file.
+        self.pid = None
+        # generate the lineidx file if it does not exist
+        if not op.isfile(self.lineidx) and generate_lineidx:
+            generate_lineidx_file(self.tsv_file, self.lineidx)
+
+    def __del__(self):
+        if self._fp:
+            self._fp.close()
+
+    def __str__(self):
+        return "TSVFile(tsv_file='{}')".format(self.tsv_file)
+
+    def __repr__(self):
+        return str(self)
+
+    def num_rows(self):
+        self._ensure_lineidx_loaded()
+        return len(self._lineidx)
+
+    def seek(self, idx):
+        self._ensure_tsv_opened()
+        self._ensure_lineidx_loaded()
+        try:
+            pos = self._lineidx[idx]
+        except:
+            logging.info('{}-{}'.format(self.tsv_file, idx))
+            raise
+        self._fp.seek(pos)
+        return [s.strip() for s in self._fp.readline().split('\t')]
+
+    def seek_first_column(self, idx):
+        self._ensure_tsv_opened()
+        self._ensure_lineidx_loaded()
+        pos = self._lineidx[idx]
+        self._fp.seek(pos)
+        return read_to_character(self._fp, '\t')
+
+    def get_key(self, idx):
+        return self.seek_first_column(idx)
+
+    def __getitem__(self, index):
+        return self.seek(index)
+
+    def __len__(self):
+        return self.num_rows()
+
+    def _ensure_lineidx_loaded(self):
+        if self._lineidx is None:
+            logging.info('loading lineidx: {}'.format(self.lineidx))
+            with open(self.lineidx, 'r') as fp:
+                self._lineidx = [int(i.strip()) for i in fp.readlines()]
+
+    def _ensure_tsv_opened(self):
+        if self._fp is None:
+            self._fp = open(self.tsv_file, 'r')
+            self.pid = os.getpid()
+
+        if self.pid != os.getpid():
+            logging.info('re-open {} because the process id changed'.format(self.tsv_file))
+            self._fp = open(self.tsv_file, 'r')
+            self.pid = os.getpid()
+
+
+class CompositeTSVFile():
+    def __init__(self, file_list, seq_file, root='.'):
+        if isinstance(file_list, str):
+            self.file_list = load_list_file(file_list)
+        else:
+            assert isinstance(file_list, list)
+            self.file_list = file_list
+
+        self.seq_file = seq_file
+        self.root = root
+        self.initialized = False
+        self.initialize()
+
+    def get_key(self, index):
+        idx_source, idx_row = self.seq[index]
+        k = self.tsvs[idx_source].get_key(idx_row)
+        return '_'.join([self.file_list[idx_source], k])
+
+    def num_rows(self):
+        return len(self.seq)
+
+    def __getitem__(self, index):
+        idx_source, idx_row = self.seq[index]
+        return self.tsvs[idx_source].seek(idx_row)
+
+    def __len__(self):
+        return len(self.seq)
+
+    def initialize(self):
+        '''
+        This function has to be called in the init function if cache_policy is
+        enabled, so we always call it in the init function to keep things simple.
+ ''' + if self.initialized: + return + self.seq = [] + with open(self.seq_file, 'r') as fp: + for line in fp: + parts = line.strip().split('\t') + self.seq.append([int(parts[0]), int(parts[1])]) + self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list] + self.initialized = True + + +def load_list_file(fname): + with open(fname, 'r') as fp: + lines = fp.readlines() + result = [line.strip() for line in lines] + if len(result) > 0 and result[-1] == '': + result = result[:-1] + return result + + diff --git a/mesh_graphormer/utils/tsv_file_ops.py b/mesh_graphormer/utils/tsv_file_ops.py new file mode 100644 index 0000000..4d10f12 --- /dev/null +++ b/mesh_graphormer/utils/tsv_file_ops.py @@ -0,0 +1,116 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Basic operations for TSV files +""" + + +import os +import os.path as op +import json +import numpy as np +import base64 +import cv2 +from tqdm import tqdm +import yaml +from mesh_graphormer.utils.miscellaneous import mkdir +from mesh_graphormer.utils.tsv_file import TSVFile + + +def img_from_base64(imagestring): + try: + jpgbytestring = base64.b64decode(imagestring) + nparr = np.frombuffer(jpgbytestring, np.uint8) + r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + return r + except ValueError: + return None + +def load_linelist_file(linelist_file): + if linelist_file is not None: + line_list = [] + with open(linelist_file, 'r') as fp: + for i in fp: + line_list.append(int(i.strip())) + return line_list + +def tsv_writer(values, tsv_file, sep='\t'): + mkdir(op.dirname(tsv_file)) + lineidx_file = op.splitext(tsv_file)[0] + '.lineidx' + idx = 0 + tsv_file_tmp = tsv_file + '.tmp' + lineidx_file_tmp = lineidx_file + '.tmp' + with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx: + assert values is not None + for value in values: + assert value is not None + value = [v if type(v)!=bytes else v.decode('utf-8') for v in value] + v = '{0}\n'.format(sep.join(map(str, value))) + fp.write(v) + fpidx.write(str(idx) + '\n') + idx = idx + len(v) + os.rename(tsv_file_tmp, tsv_file) + os.rename(lineidx_file_tmp, lineidx_file) + +def tsv_reader(tsv_file, sep='\t'): + with open(tsv_file, 'r') as fp: + for i, line in enumerate(fp): + yield [x.strip() for x in line.split(sep)] + +def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'): + if save_file is not None: + return save_file + return op.splitext(tsv_file)[0] + append_str + +def get_line_list(linelist_file=None, num_rows=None): + if linelist_file is not None: + return load_linelist_file(linelist_file) + + if num_rows is not None: + return [i for i in range(num_rows)] + +def generate_hw_file(img_file, save_file=None): + rows = tsv_reader(img_file) + def gen_rows(): + for i, row in tqdm(enumerate(rows)): + row1 = [row[0]] + img = img_from_base64(row[-1]) + height = img.shape[0] + width = img.shape[1] + row1.append(json.dumps([{"height":height, "width": width}])) + yield row1 + + save_file = config_save_file(img_file, save_file, '.hw.tsv') + tsv_writer(gen_rows(), save_file) + +def generate_linelist_file(label_file, save_file=None, ignore_attrs=()): + # generate a list of image that has labels + # images with only ignore labels are not selected. 
+    line_list = []
+    rows = tsv_reader(label_file)
+    for i, row in tqdm(enumerate(rows)):
+        labels = json.loads(row[1])
+        if labels:
+            if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab]) \
+                                for lab in labels]):
+                continue
+            line_list.append([i])
+
+    save_file = config_save_file(label_file, save_file, '.linelist.tsv')
+    tsv_writer(line_list, save_file)
+
+def load_from_yaml_file(yaml_file):
+    with open(yaml_file, 'r') as fp:
+        # yaml.CLoader requires PyYAML built against libyaml; fall back to the
+        # pure-Python loader when the C loader is unavailable.
+        return yaml.load(fp, Loader=getattr(yaml, 'CLoader', yaml.Loader))
+
+def find_file_path_in_yaml(fname, root):
+    if fname is not None:
+        if op.isfile(fname):
+            return fname
+        elif op.isfile(op.join(root, fname)):
+            return op.join(root, fname)
+        else:
+            import errno  # errno is not imported at module level in this file
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
+            )
diff --git a/requirement.txt b/requirement.txt
new file mode 100644
index 0000000..7c2c262
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1 @@
+rtree
\ No newline at end of file
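
Editor's note: a minimal usage sketch of the TSV helpers added above, not part of the commit. The file paths are hypothetical, the package is assumed to be importable as mesh_graphormer, and the mkdir helper imported by tsv_file_ops.py is assumed to create missing output directories. tsv_writer emits both the .tsv file and its .lineidx companion, so TSVFile can random-access rows without regenerating the index.

from mesh_graphormer.utils.tsv_file import TSVFile
from mesh_graphormer.utils.tsv_file_ops import tsv_writer, tsv_reader

rows = [
    ["img_0001", '{"height": 480, "width": 640}'],
    ["img_0002", '{"height": 720, "width": 1280}'],
]
# Writes tmp_tsv/example.tsv plus tmp_tsv/example.lineidx next to it.
tsv_writer(rows, "tmp_tsv/example.tsv")

tsv = TSVFile("tmp_tsv/example.tsv")
print(tsv.num_rows())   # 2
print(tsv.seek(1))      # ['img_0002', '{"height": 720, "width": 1280}']
print(tsv.get_key(0))   # img_0001

# Sequential iteration without the index:
for row in tsv_reader("tmp_tsv/example.tsv"):
    print(row)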
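
Along the same lines, a hedged sketch of generate_hw_file, which expects the last column of each row to hold a base64-encoded image (see img_from_base64 above). The image content and paths here are dummies chosen purely for illustration.

import base64

import cv2
import numpy as np

from mesh_graphormer.utils.tsv_file_ops import tsv_writer, generate_hw_file

# A dummy 480x640 image, JPEG-encoded then base64-encoded, matching the
# format img_from_base64 decodes.
img = np.zeros((480, 640, 3), dtype=np.uint8)
ok, jpg = cv2.imencode(".jpg", img)
assert ok
row = ["img_0001", base64.b64encode(jpg.tobytes()).decode("utf-8")]

tsv_writer([row], "tmp_tsv/images.tsv")
# Writes tmp_tsv/images.hw.tsv with rows like: img_0001 <tab> [{"height": 480, "width": 640}]
generate_hw_file("tmp_tsv/images.tsv")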