mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-01-26 15:49:45 +00:00
335 lines
13 KiB
Python
335 lines
13 KiB
Python
"""
|
|
Copyright (c) Microsoft Corporation.
|
|
Licensed under the MIT license.
|
|
|
|
"""
|
|
|
|
|
|
import cv2
|
|
import math
|
|
import json
|
|
from PIL import Image
|
|
import os.path as op
|
|
import numpy as np
|
|
import code
|
|
|
|
from mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile
|
|
from mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
|
|
from mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
|
|
|
|
class HandMeshTSVDataset(object):
    """Hand-mesh dataset backed by TSV files.

    Rows come from an image TSV (base64-encoded images), an optional label
    TSV (JSON annotations: 2D/3D joints, SMPL pose/betas, crop center/scale),
    and an optional height/width TSV. ``__getitem__`` returns
    ``(img_key, normalized_image_tensor, meta_data_dict)``.

    NOTE(review): ``self.is_composite`` and ``self.root`` are read here but
    never assigned in this class -- they are expected to be set by a subclass
    (see HandMeshTSVYamlDataset) before this ``__init__`` runs. Confirm
    before instantiating this class directly.
    """

    def __init__(self, args, img_file, label_file=None, hw_file=None,
            linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
        # args: namespace with at least ``multiscale_inference`` (and
        #       ``rot`` / ``sc`` when multiscale_inference is True).
        # img_file / label_file / hw_file: TSV paths (or composite specs).
        # linelist_file: optional file restricting which rows are used.
        # NOTE(review): the ``scale_factor`` parameter is accepted but
        # ignored -- self.scale_factor is hard-coded to 0.25 below.
        self.args = args
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file
        self.img_tsv = self.get_tsv_file(img_file)
        self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
        self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)

        if self.is_composite:
            # Composite datasets use every row of the hw TSV; the linelist
            # file must exist because CompositeTSVFile was built from it.
            assert op.isfile(self.linelist_file)
            self.line_list = [i for i in range(self.hw_tsv.num_rows())]
        else:
            self.line_list = load_linelist_file(linelist_file)

        self.cv2_output = cv2_output
        # Standard ImageNet channel statistics.
        self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
        self.is_train = is_train
        self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
        self.noise_factor = 0.4
        self.rot_factor = 90 # Random rotation in the range [-rot_factor, rot_factor]
        self.img_res = 224
        self.image_keys = self.prepare_image_keys()
        # 21 hand joints: wrist plus four joints per finger.
        self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
                'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
        self.root_index = self.joints_definition.index('Wrist')

    def get_tsv_file(self, tsv_file):
        """Open ``tsv_file`` as a TSVFile (or CompositeTSVFile).

        Returns None implicitly when ``tsv_file`` is falsy.
        """
        if tsv_file:
            if self.is_composite:
                return CompositeTSVFile(tsv_file, self.linelist_file,
                                        root=self.root)
            tsv_path = find_file_path_in_yaml(tsv_file, self.root)
            return TSVFile(tsv_path)

    def get_valid_tsv(self):
        """Return the cheapest available TSV to enumerate row keys from."""
        # sorted by file size
        if self.hw_tsv:
            return self.hw_tsv
        if self.label_tsv:
            return self.label_tsv
        # NOTE(review): returns None implicitly when neither TSV is set;
        # callers below would then fail on ``tsv.get_key`` -- confirm that
        # at least one of hw/label is always provided.

    def prepare_image_keys(self):
        """List the image key of every row, in row order."""
        tsv = self.get_valid_tsv()
        return [tsv.get_key(i) for i in range(tsv.num_rows())]

    def prepare_image_key_to_index(self):
        """Build a mapping from image key to row index."""
        tsv = self.get_valid_tsv()
        return {tsv.get_key(i) : i for i in range(tsv.num_rows())}

    def augm_params(self):
        """Get augmentation parameters.

        Returns:
            flip (int): horizontal-flip flag (always 0 in this class).
            pn (np.ndarray): per-channel pixel-noise multipliers, shape (3,).
            rot (float): in-plane rotation in degrees.
            sc (float): bounding-box scale multiplier.
        """
        flip = 0            # flipping
        pn = np.ones(3)     # per channel pixel-noise

        if self.args.multiscale_inference == False:
            rot = 0 # rotation
            sc = 1.0 # scaling
        elif self.args.multiscale_inference == True:
            rot = self.args.rot
            sc = self.args.sc

        if self.is_train:
            sc = 1.0
            # Each channel is multiplied with a number
            # in the area [1-opt.noiseFactor,1+opt.noiseFactor]
            pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3)

            # The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
            rot = min(2*self.rot_factor,
                    max(-2*self.rot_factor, np.random.randn()*self.rot_factor))

            # The scale is multiplied with a number
            # in the area [1-scaleFactor,1+scaleFactor]
            sc = min(1+self.scale_factor,
                    max(1-self.scale_factor, np.random.randn()*self.scale_factor+1))
            # but it is zero with probability 3/5
            if np.random.uniform() <= 0.6:
                rot = 0

        return flip, pn, rot, sc

    def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
        """Process rgb image and do augmentation.

        Crops around ``center`` at ``scale``, rotates by ``rot`` degrees,
        optionally flips, applies per-channel noise ``pn``, and converts to
        CHW float32 in [0, 1].
        """
        rgb_img = crop(rgb_img, center, scale,
                      [self.img_res, self.img_res], rot=rot)
        # flip the image
        if flip:
            rgb_img = flip_img(rgb_img)
        # in the rgb image we add pixel noise in a channel-wise manner
        rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0]))
        rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1]))
        rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2]))
        # (3,224,224),float,[0,1]
        rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0
        return rgb_img

    def j2d_processing(self, kp, center, scale, r, f):
        """Process gt 2D keypoints and apply all augmentation transforms.

        kp is (n_joints, 3): x, y, confidence. Coordinates are mapped into
        the cropped image, then normalized to [-1, 1]; the confidence
        column (last) is left untouched. Mutates ``kp`` in place.
        """
        nparts = kp.shape[0]
        for i in range(nparts):
            kp[i,0:2] = transform(kp[i,0:2]+1, center, scale,
                                  [self.img_res, self.img_res], rot=r)
        # convert to normalized coordinates
        kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1.
        # flip the x coordinates
        if f:
            kp = flip_kp(kp)
        kp = kp.astype('float32')
        return kp

    def j3d_processing(self, S, r, f):
        """Process gt 3D keypoints and apply all augmentation transforms.

        S is (n_joints, 4); the last column (confidence) is not rotated.
        Mutates ``S`` in place before the final copy-on-cast.
        """
        # in-plane rotation
        rot_mat = np.eye(3)
        if not r == 0:
            rot_rad = -r * np.pi / 180
            sn,cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0,:2] = [cs, -sn]
            rot_mat[1,:2] = [sn, cs]
        S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
        # flip the x coordinates
        if f:
            S = flip_kp(S)
        S = S.astype('float32')
        return S

    def pose_processing(self, pose, r, f):
        """Process SMPL theta parameters and apply all augmentation transforms."""
        # rotation or the pose parameters
        pose = pose.astype('float32')
        # rotate only the global orientation (first axis-angle triple)
        pose[:3] = rot_aa(pose[:3], r)
        # flip the pose parameters
        if f:
            pose = flip_pose(pose)
        # (72),float
        pose = pose.astype('float32')
        return pose

    def get_line_no(self, idx):
        """Map a dataset index to a TSV row number via the optional linelist."""
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        """Decode the image at ``idx``; BGR float32 if cv2_output, else RGB."""
        line_no = self.get_line_no(idx)
        row = self.img_tsv[line_no]
        # use -1 to support old format with multiple columns.
        cv2_im = img_from_base64(row[-1])
        if self.cv2_output:
            return cv2_im.astype(np.float32, copy=True)
        cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
        return cv2_im

    def get_annotations(self, idx):
        """Return the parsed JSON annotations for ``idx`` ([] if no labels)."""
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv[line_no]
            annotations = json.loads(row[1])
            return annotations
        else:
            return []

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def get_img_info(self, idx):
        """Return {"height": h, "width": w} for row ``idx`` (or None)."""
        if self.hw_tsv is not None:
            line_no = self.get_line_no(idx)
            row = self.hw_tsv[line_no]
            try:
                # json string format with "height" and "width" being the keys
                return json.loads(row[1])[0]
            except ValueError:
                # list of strings representing height and width in order
                hw_str = row[1].split(' ')
                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
                return hw_dict

    def get_img_key(self, idx):
        """Return the key (first TSV column) of row ``idx``."""
        line_no = self.get_line_no(idx)
        # based on the overhead of reading each row.
        if self.hw_tsv:
            return self.hw_tsv[line_no][0]
        elif self.label_tsv:
            return self.label_tsv[line_no][0]
        else:
            return self.img_tsv[line_no][0]

    def __len__(self):
        if self.line_list is None:
            return self.img_tsv.num_rows()
        else:
            return len(self.line_list)

    def __getitem__(self, idx):
        """Return (img_key, normalized image tensor, meta_data dict)."""
        img = self.get_image(idx)
        img_key = self.get_img_key(idx)
        annotations = self.get_annotations(idx)

        # NOTE(review): assumes annotations is a non-empty list whose first
        # element holds all keys used below -- an unlabeled dataset
        # (get_annotations -> []) would raise IndexError here.
        annotations = annotations[0]
        center = annotations['center']
        scale = annotations['scale']
        has_2d_joints = annotations['has_2d_joints']
        has_3d_joints = annotations['has_3d_joints']
        joints_2d = np.asarray(annotations['2d_joints'])
        joints_3d = np.asarray(annotations['3d_joints'])

        # Some annotation files wrap joints in an extra leading axis.
        if joints_2d.ndim==3:
            joints_2d = joints_2d[0]
        if joints_3d.ndim==3:
            joints_3d = joints_3d[0]

        # Get SMPL parameters, if available
        has_smpl = np.asarray(annotations['has_smpl'])
        pose = np.asarray(annotations['pose'])
        betas = np.asarray(annotations['betas'])

        # Get augmentation parameters
        flip,pn,rot,sc = self.augm_params()

        # Process image
        img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
        img = torch.from_numpy(img).float()
        # Store image before normalization to use it in visualization
        # (``img`` keeps the un-normalized copy; note the variable name typo
        # "transfromed" is kept for compatibility with existing code)
        transfromed_img = self.normalize_img(img)

        # normalize 3d pose by aligning the wrist as the root (at origin)
        root_coord = joints_3d[self.root_index,:-1]
        joints_3d[:,:-1] = joints_3d[:,:-1] - root_coord[None,:]
        # 3d pose augmentation (random flip + rotation, consistent to image and SMPL)
        joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
        # 2d pose augmentation
        joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)

        ###################################
        # Masking percantage
        # We observe that 0% or 5% works better for 3D hand mesh
        # We think this is probably becasue 3D vertices are quite sparse in the down-sampled hand mesh
        mvm_percent = 0.0 # or 0.05
        ###################################

        # Random joint masking (disabled while mvm_percent == 0.0).
        mjm_mask = np.ones((21,1))
        if self.is_train:
            num_joints = 21
            pb = np.random.random_sample()
            masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
            indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
            mjm_mask[indices,:] = 0.0
        mjm_mask = torch.from_numpy(mjm_mask).float()

        # Random vertex masking over the 195 down-sampled mesh vertices.
        mvm_mask = np.ones((195,1))
        if self.is_train:
            num_vertices = 195
            pb = np.random.random_sample()
            masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
            indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
            mvm_mask[indices,:] = 0.0
        mvm_mask = torch.from_numpy(mvm_mask).float()

        meta_data = {}
        meta_data['ori_img'] = img
        meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
        meta_data['betas'] = torch.from_numpy(betas).float()
        meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
        meta_data['has_3d_joints'] = has_3d_joints
        meta_data['has_smpl'] = has_smpl
        meta_data['mjm_mask'] = mjm_mask
        meta_data['mvm_mask'] = mvm_mask

        # Get 2D keypoints and apply augmentation transforms
        meta_data['has_2d_joints'] = has_2d_joints
        meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()

        meta_data['scale'] = float(sc * scale)
        meta_data['center'] = np.asarray(center).astype(np.float32)

        return img_key, transfromed_img, meta_data
|
|
|
|
|
|
class HandMeshTSVYamlDataset(HandMeshTSVDataset):
    """Hand-mesh TSV dataset configured from a single YAML file.

    The YAML names the image/label/hw/linelist TSV entries (or composite
    specs); this class resolves them relative to the YAML's directory and
    delegates everything else to HandMeshTSVDataset.
    """

    def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
        self.cfg = load_from_yaml_file(yaml_file)
        self.is_composite = self.cfg.get('composite', False)
        self.root = op.dirname(yaml_file)

        cfg, root = self.cfg, self.root
        if self.is_composite == False:  # noqa: E712 -- loose check kept as in upstream
            # Plain TSVs: resolve each entry relative to the YAML directory.
            images = find_file_path_in_yaml(cfg['img'], root)
            labels = find_file_path_in_yaml(cfg.get('label', None), root)
            sizes = find_file_path_in_yaml(cfg.get('hw', None), root)
            lines = find_file_path_in_yaml(cfg.get('linelist', None), root)
        else:
            # Composite specs are passed through verbatim; only the
            # linelist is resolved to a concrete path.
            images = cfg['img']
            sizes = cfg['hw']
            labels = cfg.get('label', None)
            lines = find_file_path_in_yaml(cfg.get('linelist', None), root)

        super(HandMeshTSVYamlDataset, self).__init__(
            args, images, labels, sizes, lines,
            is_train, cv2_output=cv2_output, scale_factor=scale_factor)
|