mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-04-30 03:01:32 +00:00
✨ Initial commit
This commit is contained in:
0
mesh_graphormer/utils/__init__.py
Normal file
0
mesh_graphormer/utils/__init__.py
Normal file
176
mesh_graphormer/utils/comm.py
Normal file
176
mesh_graphormer/utils/comm.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
This file contains primitives for multi-gpu communication.
|
||||
This is useful when doing distributed training.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def get_world_size():
    """Number of processes in the default distributed group (1 when not distributed)."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
||||
|
||||
|
||||
def get_rank():
    """Rank of this process in the default distributed group (0 when not distributed)."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
||||
|
||||
|
||||
def is_main_process():
    """True only on rank 0 (or when not running distributed at all)."""
    return not get_rank()
|
||||
|
||||
|
||||
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.  A no-op outside multi-process runs.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() > 1:
        dist.barrier()
|
||||
|
||||
|
||||
def gather_on_master(data):
    """Same as all_gather, but gathers data on master process only, using CPU.
    Thus, this does not work with NCCL backend unless they add CPU support.

    The memory consumption of this function is ~ 3x of data size. While in
    principal, it should be ~2x, it's not easy to force Python to release
    memory immediately and thus, peak memory usage could be up to 3x.

    Returns a list with one entry per rank on the master process and
    None on every other rank.
    """
    world_size = get_world_size()
    if world_size == 1:
        # single-process run: nothing to gather
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    # trying to optimize memory, but in fact, it's not guaranteed to be released
    del data
    storage = torch.ByteStorage.from_buffer(buffer)
    del buffer
    tensor = torch.ByteTensor(storage)

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()])
    size_list = [torch.LongTensor([0]) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # pad to a common length: dist.gather requires identically-shaped tensors
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,))
        tensor = torch.cat((tensor, padding), dim=0)
        del padding

    if is_main_process():
        tensor_list = []
        for _ in size_list:
            tensor_list.append(torch.ByteTensor(size=(max_size,)))
        dist.gather(tensor, gather_list=tensor_list, dst=0)
        del tensor
    else:
        dist.gather(tensor, gather_list=[], dst=0)
        del tensor
        # non-master ranks produce no result
        return

    data_list = []
    for tensor in tensor_list:
        buffer = tensor.cpu().numpy().tobytes()
        del tensor
        # pickle stops at the end of its payload, so the trailing padding
        # bytes appended above are ignored harmlessly
        data_list.append(pickle.loads(buffer))
        del buffer

    return data_list
|
||||
|
||||
|
||||
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor, moved to the module-level `device` ("cuda")
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device)

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to(device)
    size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to(device)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # truncate to the true payload length before unpickling
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        # nothing to reduce in a single-process run
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        # assumes all values are tensors of identical shape -- TODO confirm
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
|
||||
66
mesh_graphormer/utils/dataset_utils.py
Normal file
66
mesh_graphormer/utils/dataset_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import os.path as op
|
||||
import numpy as np
|
||||
import base64
|
||||
import cv2
|
||||
import yaml
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def img_from_base64(imagestring):
    """Decode a base64-encoded image string into a BGR numpy array.

    Returns None when decoding fails; note cv2.imdecode itself returns None
    for bytes that are not a valid image.
    """
    try:
        jpgbytestring = base64.b64decode(imagestring)
        nparr = np.frombuffer(jpgbytestring, np.uint8)
        return cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception:
        # narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of being swallowed as "no image"
        return None
|
||||
|
||||
|
||||
def load_labelmap(labelmap_file):
    """Read a TSV labelmap (label in the first column of each line) into an
    OrderedDict mapping label -> index; None when the file is absent."""
    if labelmap_file is None or not op.isfile(labelmap_file):
        return None
    label_dict = OrderedDict()
    with open(labelmap_file, 'r') as fp:
        for line in fp:
            label = line.strip().split('\t')[0]
            if label in label_dict:
                raise ValueError("Duplicate label " + label + " in labelmap.")
            label_dict[label] = len(label_dict)
    return label_dict
|
||||
|
||||
|
||||
def load_shuffle_file(shuf_file):
    """Read one integer per line into a list; None when no file is given."""
    if shuf_file is None:
        return None
    with open(shuf_file, 'r') as fp:
        return [int(line.strip()) for line in fp]
|
||||
|
||||
|
||||
def load_box_shuffle_file(shuf_file):
    """Read tab-separated (image_idx, box_idx) pairs into two parallel lists.

    Returns [img_indices, box_indices], or None when no file is given.
    """
    if shuf_file is None:
        return None
    img_shuf_list = []
    box_shuf_list = []
    with open(shuf_file, 'r') as fp:
        for line in fp:
            cols = [int(tok) for tok in line.strip().split('\t')]
            img_shuf_list.append(cols[0])
            box_shuf_list.append(cols[1])
    return [img_shuf_list, box_shuf_list]
|
||||
|
||||
|
||||
def load_from_yaml_file(file_name):
    """Parse a YAML file and return the loaded object.

    Fix: ``yaml.CLoader`` only exists when PyYAML is built against libyaml;
    fall back to the equivalent pure-Python ``yaml.Loader`` so this does not
    raise AttributeError on such installs.
    """
    with open(file_name, 'r') as fp:
        return yaml.load(fp, Loader=getattr(yaml, 'CLoader', yaml.Loader))
|
||||
58
mesh_graphormer/utils/geometric_layers.py
Normal file
58
mesh_graphormer/utils/geometric_layers.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula
|
||||
|
||||
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
|
||||
"""
|
||||
import torch
|
||||
|
||||
def rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.
    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    # rotation angle is the (epsilon-shifted) L2 norm of the axis-angle vector
    angle = torch.norm(theta + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = torch.div(theta, angle)
    half_angle = angle * 0.5
    # assemble a unit quaternion (w, x, y, z) and convert it to a matrix
    quat = torch.cat([torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)
    return quat2mat(quat)
|
||||
|
||||
def quat2mat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    batch = quat.size(0)
    # normalize to a unit quaternion before conversion
    unit = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = unit[:, 0], unit[:, 1], unit[:, 2], unit[:, 3]

    ww, xx, yy, zz = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # standard quaternion-to-rotation-matrix formula, row major
    rot = torch.stack([
        ww + xx - yy - zz, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
        2 * wz + 2 * xy, ww - xx + yy - zz, 2 * yz - 2 * wx,
        2 * xz - 2 * wy, 2 * wx + 2 * yz, ww - xx - yy + zz,
    ], dim=1).view(batch, 3, 3)
    return rot
|
||||
|
||||
def orthographic_projection(X, camera):
    """Perform orthographic projection of 3D points X using the camera parameters
    Args:
        X: size = [B, N, 3]
        camera: size = [B, 3]
    Returns:
        Projected 2D points -- size = [B, N, 2]
    """
    cam = camera.view(-1, 1, 3)
    # translate the x/y coordinates, then scale by the first camera parameter
    shifted = X[:, :, :2] + cam[:, :, 1:]
    shape = shifted.shape
    return (cam[:, :, 0] * shifted.view(shape[0], -1)).view(shape)
|
||||
208
mesh_graphormer/utils/image_ops.py
Normal file
208
mesh_graphormer/utils/image_ops.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Image processing tools
|
||||
|
||||
Modified from open source projects:
|
||||
(https://github.com/nkolot/GraphCMR/)
|
||||
(https://github.com/open-mmlab/mmdetection)
|
||||
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import base64
|
||||
import cv2
|
||||
import torch
|
||||
import scipy.misc
|
||||
|
||||
def img_from_base64(imagestring):
    """Decode a base64 string into a BGR image array.

    Returns None on an invalid base64 payload (binascii.Error subclasses
    ValueError); cv2.imdecode returns None on non-image bytes.
    """
    try:
        raw = base64.b64decode(imagestring)
        pixels = np.frombuffer(raw, np.uint8)
        return cv2.imdecode(pixels, cv2.IMREAD_COLOR)
    except ValueError:
        return None
|
||||
|
||||
def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
    """Rotate *img* by *angle* degrees around *center* (image center by default).

    When auto_bound is True, the output canvas is enlarged so the whole
    rotated image fits; this is incompatible with an explicit center.
    """
    if center is not None and auto_bound:
        raise ValueError('`auto_bound` conflicts with `center`')
    h, w = img.shape[:2]
    if center is None:
        center = ((w - 1) * 0.5, (h - 1) * 0.5)
    assert isinstance(center, tuple)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    if auto_bound:
        # grow the canvas to the rotated bounding box and shift the
        # transform so the result stays centered
        cos = np.abs(matrix[0, 0])
        sin = np.abs(matrix[0, 1])
        new_w = h * sin + w * cos
        new_h = h * cos + w * sin
        matrix[0, 2] += (new_w - w) * 0.5
        matrix[1, 2] += (new_h - h) * 0.5
        w = int(np.round(new_w))
        h = int(np.round(new_h))
    rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
    return rotated
|
||||
|
||||
def myimresize(img, size, return_scale=False, interpolation='bilinear'):
    """Resize *img* to (width, height) = (size[0], size[1]).

    Args:
        img: HxW[xC] image array.
        size: (width, height) target size.
        return_scale: when True, also return (w_scale, h_scale).
        interpolation: 'nearest', 'bilinear', 'bicubic', 'area' or 'lanczos'.

    Fix: ``interpolation`` was previously accepted but silently ignored
    (cv2.INTER_LINEAR was always used); it is now honoured, with the same
    'bilinear' default so existing callers behave identically.
    """
    interp_codes = {
        'nearest': cv2.INTER_NEAREST,
        'bilinear': cv2.INTER_LINEAR,
        'bicubic': cv2.INTER_CUBIC,
        'area': cv2.INTER_AREA,
        'lanczos': cv2.INTER_LANCZOS4,
    }
    h, w = img.shape[:2]
    resized_img = cv2.resize(
        img, (size[0], size[1]), interpolation=interp_codes[interpolation])
    if not return_scale:
        return resized_img
    return resized_img, size[0] / w, size[1] / h
|
||||
|
||||
|
||||
def get_transform(center, scale, res, rot=0):
    """Generate transformation matrix.

    Maps original-image pixel coordinates into a res[0] x res[1] crop
    centered at *center*; *scale* is in units of 200 pixels and *rot* in
    degrees.  Returns a 3x3 homogeneous matrix.
    """
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot # To match direction of rotation from cropping
        rot_mat = np.zeros((3,3))
        rot_rad = rot * np.pi / 180
        sn,cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0,:2] = [cs, -sn]
        rot_mat[1,:2] = [sn, cs]
        rot_mat[2,2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0,2] = -res[1]/2
        t_mat[1,2] = -res[0]/2
        t_inv = t_mat.copy()
        t_inv[:2,2] *= -1
        # translate to crop center, rotate, translate back
        t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
    return t
|
||||
|
||||
def transform(pt, center, scale, res, invert=0, rot=0):
    """Transform pixel location to different reference.

    With invert=1 the inverse mapping (crop -> original image) is applied.
    Coordinates are treated as 1-based; returns integer (x, y).
    """
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        # t = np.linalg.inv(t)
        # torch.inverse is used instead of np.linalg.inv (same result here)
        t_torch = torch.from_numpy(t)
        t_torch = torch.inverse(t_torch)
        t = t_torch.numpy()
    # homogeneous coordinates, shifted to 0-based
    new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int)+1
|
||||
|
||||
def crop(img, center, scale, res, rot=0):
    """Crop image according to the supplied bounding box.

    The crop window is derived by inverse-transforming the output corners;
    the result is resized to res and optionally rotated by *rot* degrees.
    """
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
    # Bottom right point
    br = np.array(transform([res[0]+1,
                             res[1]+1], center, scale, res, invert=1))-1
    # Padding so that when rotated proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if not rot == 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])

    # copy the overlapping region; out-of-bounds area stays zero
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
                                                        old_x[0]:old_x[1]]
    if not rot == 0:
        # Remove padding
        # new_img = scipy.misc.imrotate(new_img, rot)
        new_img = myimrotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    # new_img = scipy.misc.imresize(new_img, res)
    new_img = myimresize(new_img, [res[0], res[1]])
    return new_img
|
||||
|
||||
def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
    """'Undo' the image cropping/resizing.
    This function is used when evaluating mask/part segmentation.

    NOTE(review): *rot* and *is_rgb* are accepted but never used in this
    body -- presumably kept for signature symmetry with crop(); confirm.
    """
    res = img.shape[:2]
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
    # Bottom right point
    br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
    # size of cropped image
    crop_shape = [br[1] - ul[1], br[0] - ul[0]]

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(orig_shape, dtype=np.uint8)
    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(orig_shape[1], br[0])
    old_y = max(0, ul[1]), min(orig_shape[0], br[1])
    # img = scipy.misc.imresize(img, crop_shape, interp='nearest')
    img = myimresize(img, [crop_shape[0],crop_shape[1]])
    # paste the resized crop back into a canvas of the original shape
    new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
    return new_img
|
||||
|
||||
def rot_aa(aa, rot):
    """Rotate axis angle parameters.

    Applies an in-plane rotation of *rot* degrees to the global orientation
    *aa* given in axis-angle form; returns the rotated axis-angle vector.
    """
    # pose parameters: rotation of -rot degrees about the camera z-axis
    R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
                  [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
                  [0, 0, 1]])
    # find the rotation of the body in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
    aa = (resrot.T)[0]
    return aa
|
||||
|
||||
def flip_img(img):
    """Flip rgb images or masks horizontally.
    channels come last, e.g. (256,256,3).
    """
    return np.fliplr(img)
|
||||
|
||||
def flip_kp(kp):
    """Flip keypoints: swap left/right joint indices and negate x."""
    flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
    flipped = kp[flipped_parts]
    flipped[:, 0] *= -1
    return flipped
|
||||
|
||||
def flip_pose(pose):
    """Flip pose.
    The flipping is based on SMPL parameters.
    """
    flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
                    14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
                    34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
                    45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
                    56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
    swapped = pose[flippedParts]
    # negate the second and third axis-angle components of every joint
    swapped[1::3] *= -1
    swapped[2::3] *= -1
    return swapped
|
||||
|
||||
def flip_aa(aa):
    """Flip axis-angle representation.
    We negate the second and the third dimension of the axis-angle.
    (Modifies *aa* in place and returns it.)
    """
    aa[1:3] = -aa[1:3]
    return aa
|
||||
100
mesh_graphormer/utils/logger.py
Normal file
100
mesh_graphormer/utils/logger.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from logging import StreamHandler, Handler, getLevelName
|
||||
|
||||
|
||||
# this class is a copy of logging.FileHandler except we end self.close()
|
||||
# at the end of each emit. While closing file and reopening file after each
|
||||
# write is not efficient, it allows us to see partial logs when writing to
|
||||
# fused Azure blobs, which is very convenient
|
||||
class FileHandler(StreamHandler):
    """
    A handler class which writes formatted logging records to disk files.

    Copy of logging.FileHandler except that self.close() is called at the
    end of each emit, so partial logs become visible on fuse-mounted Azure
    blob storage, at the cost of reopening the file per record.
    """
    def __init__(self, filename, mode='a', encoding=None, delay=False):
        """
        Open the specified file and use it as the stream for logging.
        """
        # Issue #27493: add support for Path objects to be passed in
        filename = os.fspath(filename)
        #keep the absolute path, otherwise derived classes which use this
        #may come a cropper when the current directory changes
        self.baseFilename = os.path.abspath(filename)
        self.mode = mode
        self.encoding = encoding
        self.delay = delay
        if delay:
            #We don't open the stream, but we still need to call the
            #Handler constructor to set level, formatter, lock etc.
            Handler.__init__(self)
            self.stream = None
        else:
            StreamHandler.__init__(self, self._open())

    def close(self):
        """
        Closes the stream.
        """
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                StreamHandler.close(self)
        finally:
            self.release()

    def _open(self):
        """
        Open the current base file with the (original) mode and encoding.
        Return the resulting stream.
        """
        return open(self.baseFilename, self.mode, encoding=self.encoding)

    def emit(self, record):
        """
        Emit a record.

        If the stream was not opened because 'delay' was specified in the
        constructor, open it before calling the superclass's emit.
        """
        if self.stream is None:
            self.stream = self._open()
        StreamHandler.emit(self, record)
        # close after every record so flushed content is durable immediately
        self.close()

    def __repr__(self):
        level = getLevelName(self.level)
        return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level)
|
||||
|
||||
|
||||
def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
    """Create a DEBUG-level logger that writes to stdout and, when save_dir
    is given, to save_dir/filename.

    Non-master ranks (distributed_rank > 0) get the logger back with no
    handlers attached, so they stay silent.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    console = logging.StreamHandler(stream=sys.stdout)
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    if save_dir:
        file_handler = FileHandler(os.path.join(save_dir, filename))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
|
||||
45
mesh_graphormer/utils/metric_logger.py
Normal file
45
mesh_graphormer/utils/metric_logger.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
Basic logger. It Computes and stores the average and current value
|
||||
"""
|
||||
|
||||
class AverageMeter(object):
    """Tracks the latest value plus running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count
|
||||
|
||||
|
||||
|
||||
class EvalMetricsLogger(object):
    """Stores the most recent evaluation metrics (in meters) and their epoch."""

    def __init__(self):
        self.reset()

    def reset(self):
        # upper-bound (worst case) starting values: 100 mm in meters
        self.mPVE = 100.0/1000.0
        self.mPJPE = 100.0/1000.0
        self.PAmPJPE = 100.0/1000.0
        self.epoch = 0

    def update(self, mPVE, mPJPE, PAmPJPE, epoch):
        """Overwrite the stored metrics with the latest evaluation result."""
        self.mPVE = mPVE
        self.mPJPE = mPJPE
        self.PAmPJPE = PAmPJPE
        self.epoch = epoch
|
||||
99
mesh_graphormer/utils/metric_pampjpe.py
Normal file
99
mesh_graphormer/utils/metric_pampjpe.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Functions for compuing Procrustes alignment and reconstruction error
|
||||
|
||||
Parts of the code are adapted from https://github.com/akanazawa/hmr
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
def compute_similarity_transform(S1, S2):
    """Computes a similarity transform (sR, t) that takes
    a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
    where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrutes problem.

    Accepts either (3 x N)/(2 x N) or the transposed (N x d) layout; the
    aligned copy of S1 is returned in the same layout as the input.
    """
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        # inputs are N x d; work internally in d x N
        S1 = S1.T
        S2 = S2.T
        transposed = True
    assert(S2.shape[1] == S1.shape[1])

    # 1. Remove mean.
    mu1 = S1.mean(axis=1, keepdims=True)
    mu2 = S2.mean(axis=1, keepdims=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = np.sum(X1**2)

    # 3. The outer product of X1 and X2.
    K = X1.dot(X2.T)

    # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
    # singular vectors of K.
    U, s, Vh = np.linalg.svd(K)
    V = Vh.T
    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = np.eye(U.shape[0])
    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
    # Construct R.
    R = V.dot(Z.dot(U.T))

    # 5. Recover scale.
    scale = np.trace(R.dot(K)) / var1

    # 6. Recover translation.
    t = mu2 - scale*(R.dot(mu1))

    # 7. Error:
    S1_hat = scale*R.dot(S1) + t

    if transposed:
        S1_hat = S1_hat.T

    return S1_hat
|
||||
|
||||
def compute_similarity_transform_batch(S1, S2):
    """Batched version of compute_similarity_transform."""
    aligned = np.zeros_like(S1)
    for idx, (s1, s2) in enumerate(zip(S1, S2)):
        aligned[idx] = compute_similarity_transform(s1, s2)
    return aligned
|
||||
|
||||
def reconstruction_error(S1, S2, reduction='mean'):
    """Do Procrustes alignment and compute reconstruction error.

    The per-sample error is the mean Euclidean distance between aligned
    S1 and S2; *reduction* selects 'mean', 'sum', or per-sample values.
    """
    S1_hat = compute_similarity_transform_batch(S1, S2)
    per_sample = np.sqrt(((S1_hat - S2) ** 2).sum(axis=-1)).mean(axis=-1)
    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    return per_sample
|
||||
|
||||
|
||||
def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'):
    """Do Procrustes alignment and compute reconstruction error on the
    J24_TO_J14 joint subset (alignment uses all joints)."""
    S1_hat = compute_similarity_transform_batch(S1, S2)
    sel_hat = S1_hat[:, J24_TO_J14, :]
    sel_ref = S2[:, J24_TO_J14, :]
    per_sample = np.sqrt(((sel_hat - sel_ref) ** 2).sum(axis=-1)).mean(axis=-1)
    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    return per_sample
|
||||
|
||||
def get_alignMesh(S1, S2, reduction='mean'):
    """Do Procrustes alignment and compute reconstruction error.

    Returns (error, aligned_S1, S2) so callers can reuse the aligned mesh.
    """
    S1_hat = compute_similarity_transform_batch(S1, S2)
    per_sample = np.sqrt(((S1_hat - S2) ** 2).sum(axis=-1)).mean(axis=-1)
    if reduction == 'mean':
        err = per_sample.mean()
    elif reduction == 'sum':
        err = per_sample.sum()
    else:
        err = per_sample
    return err, S1_hat, S2
|
||||
171
mesh_graphormer/utils/miscellaneous.py
Normal file
171
mesh_graphormer/utils/miscellaneous.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
import errno
|
||||
import os
|
||||
import os.path as op
|
||||
import re
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
import shutil
|
||||
from .comm import is_main_process
|
||||
import yaml
|
||||
|
||||
|
||||
def mkdir(path):
    """Create *path* (including parents), tolerating an existing directory.

    An empty path means the current folder and is skipped, since
    os.makedirs('') would raise FileNotFoundError.
    """
    if path == '':
        return
    try:
        os.makedirs(path)
    except OSError as e:
        # re-raise anything other than "already exists"
        if e.errno != errno.EEXIST:
            raise
|
||||
|
||||
|
||||
def save_config(cfg, path):
    """Write cfg.dump() to *path*, but only on the master process."""
    if not is_main_process():
        return
    with open(path, 'w') as f:
        f.write(cfg.dump())
|
||||
|
||||
|
||||
def config_iteration(output_dir, max_iter):
    """Recover the training iteration recorded in output_dir/last_checkpoint.

    Understands three checkpoint naming schemes; returns -1 when no
    checkpoint file exists or the recorded name is unrecognized.
    """
    save_file = os.path.join(output_dir, 'last_checkpoint')
    if not os.path.exists(save_file):
        return -1
    with open(save_file, 'r') as f:
        fname = f.read().strip()
    model_name = os.path.basename(fname)
    model_path = os.path.dirname(fname)
    if model_name.startswith('model_') and len(model_name) == 17:
        # e.g. "model_0001234.pth" -> 1234
        return int(model_name[-11:-4])
    if model_name == "model_final":
        return max_iter
    if model_path.startswith('checkpoint-') and len(model_path) == 18:
        # e.g. "checkpoint-1234567/..." -> 1234567
        return int(model_path.split('-')[-1])
    return -1
|
||||
|
||||
|
||||
def get_matching_parameters(model, regexp, none_on_empty=True):
    """Returns parameters matching regular expression.

    With a falsy *regexp*: {} when none_on_empty, otherwise every named
    parameter of *model*.
    """
    if not regexp:
        return {} if none_on_empty else dict(model.named_parameters())
    pattern = re.compile(regexp)
    return {name: param
            for name, param in model.named_parameters()
            if pattern.match(name)}
|
||||
|
||||
|
||||
def freeze_weights(model, regexp):
    """Freeze weights based on regular expression."""
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    for name, param in get_matching_parameters(model, regexp).items():
        param.requires_grad = False
        logger.info("Disabled training of {}".format(name))
|
||||
|
||||
|
||||
def unfreeze_weights(model, regexp, backbone_freeze_at=-1,
        is_distributed=False):
    """Unfreeze weights based on regular expression.
    This is helpful during training to unfreeze freezed weights after
    other unfreezed weights have been trained for some iterations.

    When backbone_freeze_at >= 0, the backbone body is (re-)frozen up to
    that stage afterwards; is_distributed selects model.module under DDP.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    for weight_name, weight in get_matching_parameters(model, regexp).items():
        weight.requires_grad = True
        logger.info("Enabled training of {}".format(weight_name))
    if backbone_freeze_at >= 0:
        logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at))
        if is_distributed:
            model.module.backbone.body._freeze_backbone(backbone_freeze_at)
        else:
            model.backbone.body._freeze_backbone(backbone_freeze_at)
|
||||
|
||||
|
||||
def delete_tsv_files(tsvs):
    """Best-effort removal of each TSV file and its companion .lineidx file."""
    for tsv in tsvs:
        for path in (tsv, op.splitext(tsv)[0] + '.lineidx'):
            if op.isfile(path):
                try_delete(path)
|
||||
|
||||
|
||||
def concat_files(ins, out):
    """Concatenate the files listed in *ins* into *out*.

    Written to a .tmp file first and moved into place with os.rename so a
    partially-written output is never observed under the final name.
    """
    mkdir(op.dirname(out))
    out_tmp = out + '.tmp'
    with open(out_tmp, 'wb') as fp_out:
        for i, f in enumerate(ins):
            logging.info('concating {}/{} - {}'.format(i, len(ins), f))
            with open(f, 'rb') as fp_in:
                # copy in 10 MB chunks
                shutil.copyfileobj(fp_in, fp_out, 1024*1024*10)
    os.rename(out_tmp, out)
|
||||
|
||||
|
||||
def concat_tsv_files(tsvs, out_tsv):
    """Concatenate TSV shards and rebuild the merged .lineidx offset file.

    Each shard's .lineidx holds byte offsets local to that shard, so the
    offsets of shard i are shifted by the cumulative size of shards < i.
    """
    concat_files(tsvs, out_tsv)
    sizes = [os.stat(t).st_size for t in tsvs]
    sizes = np.cumsum(sizes)
    all_idx = []
    for i, t in enumerate(tsvs):
        for idx in load_list_file(op.splitext(t)[0] + '.lineidx'):
            if i == 0:
                all_idx.append(idx)
            else:
                all_idx.append(str(int(idx) + sizes[i - 1]))
    with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f:
        f.write('\n'.join(all_idx))
|
||||
|
||||
|
||||
def load_list_file(fname):
    """Read a text file into a list of stripped lines, dropping a single
    trailing empty line if present."""
    with open(fname, 'r') as fp:
        result = [line.strip() for line in fp]
    if result and result[-1] == '':
        result = result[:-1]
    return result
|
||||
|
||||
|
||||
def try_once(func):
    """Decorator: call *func*, logging and swallowing any exception.

    The wrapped call returns None when an exception occurred.
    Fix: apply functools.wraps so the wrapper keeps the wrapped function's
    name/docstring instead of reporting as "func_wrapper".
    """
    import functools

    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logging.info('ignore error \n{}'.format(str(e)))
    return func_wrapper
|
||||
|
||||
|
||||
@try_once
def try_delete(f):
    """Delete file *f*; errors are logged and ignored via @try_once."""
    os.remove(f)
|
||||
|
||||
|
||||
def set_seed(seed, n_gpu):
    """Seed the python, numpy and torch RNGs; also CUDA when GPUs are in use."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def print_and_run_cmd(cmd):
    """Echo *cmd* to stdout, then execute it through the shell.

    The exit status of os.system is ignored.
    NOTE(review): os.system runs via the shell; *cmd* must be trusted.
    """
    print(cmd)
    os.system(cmd)
|
||||
|
||||
|
||||
def write_to_yaml_file(context, file_name):
    """Serialize *context* to *file_name* as YAML (utf-8 encoded)."""
    with open(file_name, 'w') as fp:
        yaml.dump(context, fp, encoding='utf-8')
|
||||
|
||||
|
||||
def load_from_yaml_file(yaml_file):
    """Parse a YAML file and return the loaded object.

    Fix: ``yaml.CLoader`` only exists when PyYAML is built against libyaml;
    fall back to the equivalent pure-Python ``yaml.Loader`` so this does not
    raise AttributeError on such installs.
    """
    with open(yaml_file, 'r') as fp:
        return yaml.load(fp, Loader=getattr(yaml, 'CLoader', yaml.Loader))
|
||||
|
||||
|
||||
691
mesh_graphormer/utils/renderer.py
Normal file
691
mesh_graphormer/utils/renderer.py
Normal file
@@ -0,0 +1,691 @@
|
||||
"""
|
||||
Rendering tools for 3D mesh visualization on 2D image.
|
||||
|
||||
Parts of the code are taken from https://github.com/akanazawa/hmr
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import code
|
||||
from opendr.camera import ProjectPoints
|
||||
from opendr.renderer import ColoredRenderer, TexturedRenderer
|
||||
from opendr.lighting import LambertianPointLight
|
||||
import random
|
||||
|
||||
|
||||
# Rotate the points by a specified angle.
|
||||
def rotateY(points, angle):
|
||||
ry = np.array([
|
||||
[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.],
|
||||
[-np.sin(angle), 0., np.cos(angle)]
|
||||
])
|
||||
return np.dot(points, ry)
|
||||
|
||||
def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None):
    """Draw 2D joints (and optionally their bones) on a copy of *input_image*.

    joints is 2 x K (transposed automatically when given K x 2 / 3 x K).
    Supported K: 19 (body + face), 14 (LSP body), 21 (hand).
    vis: optional per-joint visibility mask; invisible joints are skipped.
    radius: circle radius in px; defaults to ~1% of the mean image side.
    Returns the annotated image in the same value range as the input.

    Joint order for K == 19:
    0: Right ankle   1: Right knee   2: Right hip    3: Left hip
    4: Left knee     5: Left ankle   6: Right wrist  7: Right elbow
    8: Right shoulder  9: Left shoulder  10: Left elbow  11: Left wrist
    12: Neck  13: Head top  14: nose  15: left_eye  16: right_eye
    17: left_ear  18: right_ear
    """
    if radius is None:
        radius = max(4, (np.mean(input_image.shape[:2]) * 0.01).astype(int))

    colors = {
        'pink': (197, 27, 125),  # L lower leg
        'light_pink': (233, 163, 201),  # L upper leg
        'light_green': (161, 215, 106),  # L lower arm
        'green': (77, 146, 33),  # L upper arm
        'red': (215, 48, 39),  # head
        'light_red': (252, 146, 114),  # head
        'light_orange': (252, 141, 89),  # chest
        'purple': (118, 42, 131),  # R lower leg
        'light_purple': (175, 141, 195),  # R upper
        'light_blue': (145, 191, 219),  # R lower arm
        'blue': (69, 117, 180),  # R upper arm
        'gray': (130, 130, 130),
        'white': (255, 255, 255),
    }

    image = input_image.copy()
    input_is_float = False

    # BUG FIX: np.float was removed from NumPy (>= 1.24); np.floating
    # matches all floating dtypes.
    if np.issubdtype(image.dtype, np.floating):
        input_is_float = True
        max_val = image.max()
        if max_val <= 2.:  # should be 1 but sometimes it's slightly above 1
            image = (image * 255).astype(np.uint8)
        else:
            image = (image).astype(np.uint8)

    if joints.shape[0] != 2:
        joints = joints.T
    joints = np.round(joints).astype(int)

    jcolors = [
        'light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink',
        'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue',
        'purple', 'purple', 'red', 'green', 'green', 'white', 'white',
        'purple', 'purple', 'red', 'green', 'green', 'white', 'white'
    ]

    if joints.shape[1] == 19:
        # parent indices; -1 means no parent
        parents = np.array([
            1, 2, 8, 9, 3, 4, 7, 8, 12, 12, 9, 10, 14, -1, 13, -1, -1, 15, 16
        ])
        # Left is light and right is dark
        ecolors = {
            0: 'light_pink', 1: 'light_pink', 2: 'light_pink',
            3: 'pink', 4: 'pink', 5: 'pink',
            6: 'light_blue', 7: 'light_blue', 8: 'light_blue',
            9: 'blue', 10: 'blue', 11: 'blue',
            12: 'purple', 17: 'light_green', 18: 'light_green', 14: 'purple',
        }
    elif joints.shape[1] == 14:
        parents = np.array([1, 2, 8, 9, 3, 4, 7, 8, -1, -1, 9, 10, 13, -1])
        ecolors = {
            0: 'light_pink', 1: 'light_pink', 2: 'light_pink',
            3: 'pink', 4: 'pink', 5: 'pink',
            6: 'light_blue', 7: 'light_blue', 10: 'light_blue',
            11: 'blue', 12: 'purple',
        }
    elif joints.shape[1] == 21:  # hand
        parents = np.array([
            -1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 0, 13, 14, 15, 0, 17, 18, 19
        ])
        ecolors = {
            0: 'light_purple',
            1: 'light_green', 2: 'light_green', 3: 'light_green', 4: 'light_green',
            5: 'pink', 6: 'pink', 7: 'pink', 8: 'pink',
            9: 'light_blue', 10: 'light_blue', 11: 'light_blue', 12: 'light_blue',
            13: 'light_red', 14: 'light_red', 15: 'light_red', 16: 'light_red',
            17: 'purple', 18: 'purple', 19: 'purple', 20: 'purple',
        }
    else:
        # BUG FIX: previously only printed 'Unknown skeleton!!' and then
        # crashed with a NameError on the undefined `parents`; fail fast.
        raise ValueError('Unknown skeleton with %d joints' % joints.shape[1])

    for child in range(len(parents)):
        point = joints[:, child]
        # If invisible, skip
        if vis is not None and vis[child] == 0:
            continue
        if draw_edges:
            cv2.circle(image, (point[0], point[1]), radius, colors['white'],
                       -1)
            cv2.circle(image, (point[0], point[1]), radius - 1,
                       colors[jcolors[child]], -1)
        else:
            cv2.circle(image, (point[0], point[1]), radius - 1,
                       colors[jcolors[child]], 1)
        pa_id = parents[child]
        if draw_edges and pa_id >= 0:
            if vis is not None and vis[pa_id] == 0:
                continue
            point_pa = joints[:, pa_id]
            cv2.circle(image, (point_pa[0], point_pa[1]), radius - 1,
                       colors[jcolors[pa_id]], -1)
            if child not in ecolors:
                # BUG FIX: replaced the interactive ipdb breakpoint with an
                # exception so headless runs fail loudly instead of hanging.
                raise KeyError('no edge color defined for joint %d' % child)
            cv2.line(image, (point[0], point[1]), (point_pa[0], point_pa[1]),
                     colors[ecolors[child]], radius - 2)

    # Convert back to the original dtype / value range
    if input_is_float:
        if max_val <= 1.:
            image = image.astype(np.float32) / 255.
        else:
            image = image.astype(np.float32)

    return image
|
||||
|
||||
def draw_text(input_image, content):
    """Draw 'key: value' debug lines on a copy of *input_image*.

    content is a dict; keys are sorted, values formatted with %.2g.
    Float input is scaled to uint8 for drawing, then scaled back.
    """
    image = input_image.copy()
    input_is_float = False
    # BUG FIX: np.float was removed from NumPy (>= 1.24); np.floating
    # matches all floating dtypes.
    if np.issubdtype(image.dtype, np.floating):
        input_is_float = True
        image = (image * 255).astype(np.uint8)

    # NOTE(review): named 'black' but the BGR value is not black —
    # presumably intentional for visibility; confirm before renaming.
    black = (255, 255, 0)
    margin = 15
    start_x = 5
    start_y = margin
    for key in sorted(content.keys()):
        text = "%s: %.2g" % (key, content[key])
        cv2.putText(image, text, (start_x, start_y), 0, 0.45, black)
        start_y += margin

    if input_is_float:
        image = image.astype(np.float32) / 255.
    return image
|
||||
|
||||
def visualize_reconstruction(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, color='pink', focal_length=1000):
    """Render the mesh over *img* and draw GT/predicted skeletons.

    Returns np.hstack([skeleton image, rendered mesh with debug text]).
    Renderer is an instance of SMPLRenderer.
    """
    gt_vis = gt_kp[:, 2].astype(bool)
    loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis]) ** 2)
    debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss}

    # Convert the weak-perspective camera into a perspective translation
    # at a fixed focal length.
    res = img.shape[1]
    depth = 2 * focal_length / (res * camera[0] + 1e-9)
    camera_t = np.array([camera[1], camera[2], depth])
    rend_img = renderer.render(
        vertices, camera_t=camera_t, img=img, use_bg=True,
        focal_length=focal_length, body_color=color)
    rend_img = draw_text(rend_img, debug_text)

    # Map normalized [-1, 1] keypoints to pixel coordinates.
    gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
    pred_joint = ((pred_kp + 1) * 0.5) * img_size
    with_gt = draw_skeleton(img, gt_joint, draw_edges=False, vis=gt_vis)
    skel_img = draw_skeleton(with_gt, pred_joint)

    return np.hstack([skel_img, rend_img])
|
||||
|
||||
def visualize_reconstruction_test(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, score, color='pink', focal_length=1000):
    """Like visualize_reconstruction, additionally printing the PA-MPJPE
    *score* (in mm) in the debug overlay.

    Returns np.hstack([skeleton image, rendered mesh with debug text]).
    """
    gt_vis = gt_kp[:, 2].astype(bool)
    loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis]) ** 2)
    debug_text = {
        "sc": camera[0], "tx": camera[1], "ty": camera[2],
        "kpl": loss, "pa-mpjpe": score * 1000,
    }

    # Weak-perspective camera -> perspective translation at fixed focal length.
    res = img.shape[1]
    depth = 2 * focal_length / (res * camera[0] + 1e-9)
    camera_t = np.array([camera[1], camera[2], depth])
    rend_img = renderer.render(
        vertices, camera_t=camera_t, img=img, use_bg=True,
        focal_length=focal_length, body_color=color)
    rend_img = draw_text(rend_img, debug_text)

    # Map normalized [-1, 1] keypoints to pixel coordinates.
    gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size
    pred_joint = ((pred_kp + 1) * 0.5) * img_size
    with_gt = draw_skeleton(img, gt_joint, draw_edges=False, vis=gt_vis)
    skel_img = draw_skeleton(with_gt, pred_joint)

    return np.hstack([skel_img, rend_img])
|
||||
|
||||
|
||||
|
||||
def visualize_reconstruction_and_att(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, focal_length=1000):
    """Render the mesh and, for a fixed set of reference joints, an
    attention-line panel per joint.

    Returns np.hstack([img, attention panels, rendered mesh]).
    Renderer is an instance of SMPLRenderer.
    """
    # Fix a focal length so the weak-perspective camera renders with a
    # perspective-correct scale.
    res = img.shape[1]
    camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
    rend_img = renderer.render(vertices_full, camera_t=camera_t,
                               img=img, use_bg=True,
                               focal_length=focal_length, body_color='light_blue')

    heads_num, vertex_num, _ = attention.shape

    all_head = np.zeros((vertex_num,vertex_num))

    ###### find max
    # for i in range(vertex_num):
    #     for j in range(vertex_num):
    #         all_head[i,j] = np.max(attention[:,i,j])

    ##### find avg
    # NOTE(review): hard-coded 4 assumes exactly 4 attention heads even
    # though heads_num is available above — confirm against the model.
    for h in range(4):
        att_per_img = attention[h]
        all_head = all_head + att_per_img
    all_head = all_head/4

    # Column-normalize so attention into each target token sums to 1.
    col_sums = all_head.sum(axis=0)
    all_head = all_head / col_sums[np.newaxis, :]

    # code.interact(local=locals())

    combined = []
    if vertex_num>400: # body
        selected_joints = [6,7,4,5,13] # [6,7,4,5,13,12]
    else: # hand
        selected_joints = [0, 4, 8, 12, 16, 20]
    # Draw one attention panel per selected reference joint.
    for ii in range(len(selected_joints)):
        reference_id = selected_joints[ii]
        ref_point = ref_points[reference_id]
        # Skip the first 14 tokens — presumably joint tokens precede the
        # vertex tokens in the attention matrix; verify against the model.
        attention_to_show = all_head[reference_id][14::]
        min_v = np.min(attention_to_show)
        max_v = np.max(attention_to_show)
        norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v)

        # Map normalized [-1, 1] coordinates into pixel space.
        vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
        ref_norm = ((ref_point + 1) * 0.5) * img_size
        image = np.zeros_like(rend_img)

        # White dot for every 2D vertex.
        for jj in range(vertices_norm.shape[0]):
            x = int(vertices_norm[jj,0])
            y = int(vertices_norm[jj,1])
            cv2.circle(image,(x,y), 1, (255,255,255), -1)

        total_to_draw = []
        for jj in range(vertices_norm.shape[0]):
            thres = 0.0
            if norm_attention_to_show[jj]>thres:
                things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
                total_to_draw.append(things)
                # plot_one_line(ref_norm, vertices_norm[jj], image, reference_id, alpha=0.4*(norm_attention_to_show[jj]-thres)/(1-thres) )
        # Sort by attention score so stronger lines are drawn last (on top).
        total_to_draw.sort()
        max_att_score = total_to_draw[-1][0]
        for item in total_to_draw:
            attention_score = item[0]
            ref_point = item[1]
            vertex = item[2]
            plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) )
        # code.interact(local=locals())
        if len(combined)==0:
            combined = image
        else:
            combined = np.hstack([combined, image])

    final = np.hstack([img, combined, rend_img])

    return final
|
||||
|
||||
|
||||
def visualize_reconstruction_and_att_local(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, color='light_blue', focal_length=1000):
    """Render the mesh and one attention panel for a single reference joint,
    drawn over a dimmed copy of the rendering.

    Returns np.hstack([img, attention panel, rendered mesh]).
    Renderer is an instance of SMPLRenderer.
    """
    # Fix a focal length so the weak-perspective camera renders with a
    # perspective-correct scale.
    res = img.shape[1]
    camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)])
    rend_img = renderer.render(vertices_full, camera_t=camera_t,
                               img=img, use_bg=True,
                               focal_length=focal_length, body_color=color)
    heads_num, vertex_num, _ = attention.shape
    all_head = np.zeros((vertex_num,vertex_num))

    ##### compute avg attention for 4 attention heads
    # NOTE(review): hard-coded 4 assumes exactly 4 heads even though
    # heads_num is available above — confirm against the model.
    for h in range(4):
        att_per_img = attention[h]
        all_head = all_head + att_per_img
    all_head = all_head/4

    # Column-normalize so attention into each target token sums to 1.
    col_sums = all_head.sum(axis=0)
    all_head = all_head / col_sums[np.newaxis, :]

    combined = []
    if vertex_num>400: # body
        selected_joints = [7] # [6,7,4,5,13,12]
    else: # hand
        selected_joints = [0] # [0, 4, 8, 12, 16, 20]
    # Draw the attention panel(s).
    for ii in range(len(selected_joints)):
        reference_id = selected_joints[ii]
        ref_point = ref_points[reference_id]
        # Skip the first 14 tokens — presumably joint tokens precede the
        # vertex tokens in the attention matrix; verify against the model.
        attention_to_show = all_head[reference_id][14::]
        min_v = np.min(attention_to_show)
        max_v = np.max(attention_to_show)
        norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v)
        # Map normalized [-1, 1] coordinates into pixel space.
        vertices_norm = ((vertices_2d + 1) * 0.5) * img_size
        ref_norm = ((ref_point + 1) * 0.5) * img_size
        # Dimmed rendering as the background of the attention panel.
        image = rend_img*0.4

        total_to_draw = []
        for jj in range(vertices_norm.shape[0]):
            thres = 0.0
            if norm_attention_to_show[jj]>thres:
                things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]]
                total_to_draw.append(things)
        # Sort by attention score so stronger lines are drawn last (on top).
        total_to_draw.sort()
        max_att_score = total_to_draw[-1][0]
        for item in total_to_draw:
            attention_score = item[0]
            ref_point = item[1]
            vertex = item[2]
            plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) )

        # White dot for every 2D vertex, drawn on top of the lines.
        for jj in range(vertices_norm.shape[0]):
            x = int(vertices_norm[jj,0])
            y = int(vertices_norm[jj,1])
            cv2.circle(image,(x,y), 1, (255,255,255), -1)

        if len(combined)==0:
            combined = image
        else:
            combined = np.hstack([combined, image])

    final = np.hstack([img, combined, rend_img])

    return final
|
||||
|
||||
|
||||
def visualize_reconstruction_no_text(img, img_size, vertices, camera, renderer, color='pink', focal_length=1000):
    """Render the mesh over *img* and return [img | render] side by side.

    Renderer is an instance of SMPLRenderer.
    """
    # Weak-perspective camera -> perspective translation at fixed focal length.
    res = img.shape[1]
    depth = 2 * focal_length / (res * camera[0] + 1e-9)
    cam_t = np.array([camera[1], camera[2], depth])
    rendered = renderer.render(
        vertices, camera_t=cam_t, img=img, use_bg=True,
        focal_length=focal_length, body_color=color)

    return np.hstack([img, rendered])
|
||||
|
||||
|
||||
def plot_one_line(ref, vertex, img, color_index, alpha=0.0, line_thickness=None):
    """Alpha-blend one attention line from *ref* to *vertex* into *img*,
    modifying *img* in place."""
    # 13,6,7,8,3,4,5
    # att_colors = [(255, 221, 104), (255, 255, 0), (255, 215, 227), (210, 240, 119), \
    #     (209, 238, 245), (244, 200, 243), (233, 242, 216)]
    att_colors = [(255, 255, 0), (244, 200, 243), (210, 243, 119), (209, 238, 255), (200, 208, 255), (250, 238, 215)]

    overlay = img.copy()
    # line thickness scaled from the image size
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1

    r, g, b = att_colors[color_index]
    start = (int(ref[0]), int(ref[1]))
    end = (int(vertex[0]), int(vertex[1]))
    # NOTE(review): channels are pre-scaled by alpha/255 before blending,
    # so lines are dimmer than the palette color — presumably intentional.
    shade = (alpha * float(r) / 255, alpha * float(g) / 255, alpha * float(b) / 255)
    cv2.line(overlay, start, end, shade, thickness=tl, lineType=cv2.LINE_AA)
    cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
|
||||
|
||||
|
||||
|
||||
def cam2pixel(cam_coord, f, c):
    """Project Nx3 camera-space points to pixel coordinates.

    f: (fx, fy) focal lengths; c: (cx, cy) principal point.
    Returns Nx3 array of (u, v, depth).
    """
    depth = cam_coord[:, 2]
    u = cam_coord[:, 0] / depth * f[0] + c[0]
    v = cam_coord[:, 1] / depth * f[1] + c[1]
    return np.stack((u, v, depth), axis=1)
|
||||
|
||||
|
||||
class Renderer(object):
    """
    Render mesh using OpenDR for visualization.
    """

    def __init__(self, width=800, height=600, near=0.5, far=1000, faces=None):
        # NOTE(review): `near` and `far` are accepted but not stored — the
        # frustum is rebuilt per render() call; confirm this is intentional.
        # Named vertex-color presets used by body_color=.
        self.colors = {'hand': [.9, .9, .9], 'pink': [.9, .7, .7], 'light_blue': [0.65098039, 0.74117647, 0.85882353] }
        self.width = width
        self.height = height
        # Default face (triangle index) array used when render() gets none.
        self.faces = faces
        self.renderer = ColoredRenderer()

    def render(self, vertices, faces=None, img=None,
               camera_t=np.zeros([3], dtype=np.float32),
               camera_rot=np.zeros([3], dtype=np.float32),
               camera_center=None,
               use_bg=False,
               bg_color=(0.0, 0.0, 0.0),
               body_color=None,
               focal_length=5000,
               disp_text=False,
               gt_keyp=None,
               pred_keyp=None,
               **kwargs):
        """Render *vertices*/*faces* with a uniform body color and three
        Lambertian point lights; returns the rendered image array.

        img: optional background; with use_bg=True it is used as-is,
        otherwise a constant bg_color canvas of the same shape is used.
        camera_t / camera_rot: extrinsics passed to ProjectPoints.
        disp_text, gt_keyp, pred_keyp and **kwargs are accepted but unused
        in this implementation.
        """
        if img is not None:
            height, width = img.shape[:2]
        else:
            height, width = self.height, self.width

        if faces is None:
            faces = self.faces

        if camera_center is None:
            # Default principal point: image center.
            camera_center = np.array([width * 0.5,
                                      height * 0.5])

        self.renderer.camera = ProjectPoints(rt=camera_rot,
                                             t=camera_t,
                                             f=focal_length * np.ones(2),
                                             c=camera_center,
                                             k=np.zeros(5))
        # Far plane: distance from camera to mesh center plus margin.
        dist = np.abs(self.renderer.camera.t.r[2] -
                      np.mean(vertices, axis=0)[2])
        far = dist + 20

        self.renderer.frustum = {'near': 1.0, 'far': far,
                                 'width': width,
                                 'height': height}

        if img is not None:
            if use_bg:
                self.renderer.background_image = img
            else:
                self.renderer.background_image = np.ones_like(
                    img) * np.array(bg_color)

        if body_color is None:
            color = self.colors['light_blue']
        else:
            color = self.colors[body_color]

        if isinstance(self.renderer, TexturedRenderer):
            # Textured rendering supplies its own colors; neutralize vc.
            color = [1.,1.,1.]

        self.renderer.set(v=vertices, f=faces,
                          vc=color, bgcolor=np.ones(3))
        albedo = self.renderer.vc
        # Construct Back Light (on back right corner)
        yrot = np.radians(120)

        self.renderer.vc = LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-200, -100, -100]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Left Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([800, 10, 300]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Right Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
            vc=albedo,
            light_color=np.array([.7, .7, .7]))

        return self.renderer.r

    def render_vertex_color(self, vertices, faces=None, img=None,
                            camera_t=np.zeros([3], dtype=np.float32),
                            camera_rot=np.zeros([3], dtype=np.float32),
                            camera_center=None,
                            use_bg=False,
                            bg_color=(0.0, 0.0, 0.0),
                            vertex_color=None,
                            focal_length=5000,
                            disp_text=False,
                            gt_keyp=None,
                            pred_keyp=None,
                            **kwargs):
        """Same pipeline as render(), but takes an explicit per-vertex
        color array (*vertex_color*) instead of a named body color.

        NOTE(review): this duplicates render() almost line-for-line;
        candidate for refactoring into a shared helper.
        """
        if img is not None:
            height, width = img.shape[:2]
        else:
            height, width = self.height, self.width

        if faces is None:
            faces = self.faces

        if camera_center is None:
            # Default principal point: image center.
            camera_center = np.array([width * 0.5,
                                      height * 0.5])

        self.renderer.camera = ProjectPoints(rt=camera_rot,
                                             t=camera_t,
                                             f=focal_length * np.ones(2),
                                             c=camera_center,
                                             k=np.zeros(5))
        # Far plane: distance from camera to mesh center plus margin.
        dist = np.abs(self.renderer.camera.t.r[2] -
                      np.mean(vertices, axis=0)[2])
        far = dist + 20

        self.renderer.frustum = {'near': 1.0, 'far': far,
                                 'width': width,
                                 'height': height}

        if img is not None:
            if use_bg:
                self.renderer.background_image = img
            else:
                self.renderer.background_image = np.ones_like(
                    img) * np.array(bg_color)

        if vertex_color is None:
            vertex_color = self.colors['light_blue']

        self.renderer.set(v=vertices, f=faces,
                          vc=vertex_color, bgcolor=np.ones(3))
        albedo = self.renderer.vc
        # Construct Back Light (on back right corner)
        yrot = np.radians(120)

        self.renderer.vc = LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-200, -100, -100]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Left Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([800, 10, 300]), yrot),
            vc=albedo,
            light_color=np.array([1, 1, 1]))

        # Construct Right Light
        self.renderer.vc += LambertianPointLight(
            f=self.renderer.f,
            v=self.renderer.v,
            num_verts=self.renderer.v.shape[0],
            light_pos=rotateY(np.array([-500, 500, 1000]), yrot),
            vc=albedo,
            light_color=np.array([.7, .7, .7]))

        return self.renderer.r
|
||||
162
mesh_graphormer/utils/tsv_file.py
Normal file
162
mesh_graphormer/utils/tsv_file.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
Definition of TSV class
|
||||
"""
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path as op
|
||||
|
||||
|
||||
def generate_lineidx(filein, idxout):
    """Write a .lineidx companion for *filein*: one byte offset per line.

    Built atomically via a .tmp file then renamed into place.
    """
    tmp_path = idxout + '.tmp'
    with open(filein, 'r') as src, open(tmp_path, 'w') as dst:
        total = os.fstat(src.fileno()).st_size
        offset = 0
        while offset != total:
            dst.write(str(offset) + "\n")
            src.readline()
            offset = src.tell()
    os.rename(tmp_path, idxout)
|
||||
|
||||
|
||||
def read_to_character(fp, c):
    """Read *fp* forward until character *c*; return the text before it.

    Reads in 32-char chunks; asserts if EOF arrives before *c* is seen.
    """
    pieces = []
    while True:
        chunk = fp.read(32)
        assert chunk != ''
        if c in chunk:
            pieces.append(chunk[: chunk.index(c)])
            return ''.join(pieces)
        pieces.append(chunk)
|
||||
|
||||
|
||||
class TSVFile(object):
    """Random access into a .tsv file via a sibling .lineidx offset file."""

    def __init__(self, tsv_file, generate_lineidx=False):
        """
        Args:
            tsv_file: path to the .tsv data file.
            generate_lineidx: when True, build the .lineidx if missing.
        """
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # the object remembers which process opened the file.
        # If the pid is not equal to the current pid, we will re-open the file.
        self.pid = None
        # generate lineidx if not exist
        if not op.isfile(self.lineidx) and generate_lineidx:
            # BUG FIX: the boolean parameter `generate_lineidx` shadowed the
            # module-level function of the same name, so calling it raised
            # TypeError.  Look the function up in the module globals instead.
            globals()['generate_lineidx'](self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        """Number of rows, taken from the line index."""
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        """Return row *idx* as a list of stripped column strings."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except Exception:
            # BUG FIX: narrowed a bare `except:` that also trapped
            # KeyboardInterrupt/SystemExit; still logs and re-raises.
            logging.info('{}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx):
        """Return only the first column of row *idx* (cheap key lookup)."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx):
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        # Lazily parse the .lineidx file into a list of int offsets.
        if self._lineidx is None:
            logging.info('loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, 'r') as fp:
                self._lineidx = [int(i.strip()) for i in fp.readlines()]

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        # Re-open after fork: an inherited file handle shares its offset
        # with the parent, which corrupts concurrent reads.
        if self.pid != os.getpid():
            logging.info('re-open {} because the process id changed'.format(self.tsv_file))
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()
|
||||
|
||||
|
||||
class CompositeTSVFile():
    """Present several TSVFiles, stitched together by a (source, row)
    sequence file, as a single indexable dataset."""

    def __init__(self, file_list, seq_file, root='.'):
        if isinstance(file_list, str):
            self.file_list = load_list_file(file_list)
        else:
            assert isinstance(file_list, list)
            self.file_list = file_list

        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        src, row = self.seq[index]
        key = self.tsvs[src].get_key(row)
        # Prefix with the source file name so keys are globally unique.
        return '_'.join([self.file_list[src], key])

    def num_rows(self):
        return len(self.seq)

    def __getitem__(self, index):
        src, row = self.seq[index]
        return self.tsvs[src].seek(row)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        '''
        this function has to be called in init function if cache_policy is
        enabled. Thus, let's always call it in init funciton to make it simple.
        '''
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, 'r') as fp:
            for line in fp:
                cols = line.strip().split('\t')
                self.seq.append([int(cols[0]), int(cols[1])])
        self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
        self.initialized = True
|
||||
|
||||
|
||||
def load_list_file(fname):
    """Return the stripped lines of *fname*, removing one trailing blank
    entry if present (files commonly end with a newline)."""
    with open(fname, 'r') as fp:
        stripped = [ln.strip() for ln in fp.readlines()]
    if len(stripped) > 0 and not stripped[-1]:
        stripped = stripped[:-1]
    return stripped
|
||||
|
||||
|
||||
116
mesh_graphormer/utils/tsv_file_ops.py
Normal file
116
mesh_graphormer/utils/tsv_file_ops.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Copyright (c) Microsoft Corporation.
|
||||
Licensed under the MIT license.
|
||||
|
||||
Basic operations for TSV files
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import os.path as op
|
||||
import json
|
||||
import numpy as np
|
||||
import base64
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
import yaml
|
||||
from mesh_graphormer.utils.miscellaneous import mkdir
|
||||
from mesh_graphormer.utils.tsv_file import TSVFile
|
||||
|
||||
|
||||
def img_from_base64(imagestring):
    """Decode a base64-encoded image payload into a BGR ndarray.

    Returns None when the base64 payload is malformed.
    """
    try:
        raw = base64.b64decode(imagestring)
        buf = np.frombuffer(raw, np.uint8)
        return cv2.imdecode(buf, cv2.IMREAD_COLOR)
    except ValueError:
        # binascii.Error (bad base64) is a ValueError subclass
        return None
|
||||
|
||||
def load_linelist_file(linelist_file):
    """Read a newline-separated list of row indices as ints.

    Returns None when *linelist_file* is None (no restriction requested).
    """
    if linelist_file is None:
        return None
    with open(linelist_file, 'r') as fp:
        return [int(line.strip()) for line in fp]
|
||||
|
||||
def tsv_writer(values, tsv_file, sep='\t'):
    """Write iterable-of-rows *values* to *tsv_file* and its .lineidx.

    Bytes cells are decoded as utf-8; everything else goes through str().
    Both outputs are written to .tmp files, then renamed into place.
    """
    mkdir(op.dirname(tsv_file))
    lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
    tmp_tsv = tsv_file + '.tmp'
    tmp_idx = lineidx_file + '.tmp'
    offset = 0
    with open(tmp_tsv, 'w') as fp, open(tmp_idx, 'w') as fpidx:
        assert values is not None
        for value in values:
            assert value is not None
            cells = [v if type(v) != bytes else v.decode('utf-8') for v in value]
            line = '{0}\n'.format(sep.join(map(str, cells)))
            fp.write(line)
            fpidx.write(str(offset) + '\n')
            # NOTE(review): offsets advance by character count, not bytes —
            # exact only for ASCII content; confirm non-ASCII is never written.
            offset = offset + len(line)
    os.rename(tmp_tsv, tsv_file)
    os.rename(tmp_idx, lineidx_file)
|
||||
|
||||
def tsv_reader(tsv_file, sep='\t'):
    """Lazily yield each row of *tsv_file* as stripped column strings."""
    with open(tsv_file, 'r') as fp:
        for line in fp:
            yield [col.strip() for col in line.split(sep)]
|
||||
|
||||
def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'):
    """Pick an output path: *save_file* when given, otherwise *tsv_file*
    with its extension replaced by *append_str*."""
    if save_file is not None:
        return save_file
    return op.splitext(tsv_file)[0] + append_str
|
||||
|
||||
def get_line_list(linelist_file=None, num_rows=None):
    """Resolve which rows to use: the contents of *linelist_file* when
    given, else 0..num_rows-1, else None."""
    if linelist_file is not None:
        return load_linelist_file(linelist_file)
    if num_rows is not None:
        return list(range(num_rows))
|
||||
|
||||
def generate_hw_file(img_file, save_file=None):
    """Write a .hw.tsv with each image's [{'height':..,'width':..}] json."""
    rows = tsv_reader(img_file)

    def gen_rows():
        # Decode each base64 image just to read its dimensions.
        for i, row in tqdm(enumerate(rows)):
            img = img_from_base64(row[-1])
            hw_json = json.dumps([{"height": img.shape[0], "width": img.shape[1]}])
            yield [row[0], hw_json]

    out_path = config_save_file(img_file, save_file, '.hw.tsv')
    tsv_writer(gen_rows(), out_path)
|
||||
|
||||
def generate_linelist_file(label_file, save_file=None, ignore_attrs=()):
    """Write a .linelist.tsv listing indices of rows with usable labels.

    Rows with empty labels are skipped, as are rows where every label
    carries at least one truthy attribute from *ignore_attrs*.
    """
    keep = []
    for idx, row in tqdm(enumerate(tsv_reader(label_file))):
        labels = json.loads(row[1])
        if not labels:
            continue
        if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab])
                                 for lab in labels]):
            continue
        keep.append([idx])

    out_path = config_save_file(label_file, save_file, '.linelist.tsv')
    tsv_writer(keep, out_path)
|
||||
|
||||
def load_from_yaml_file(yaml_file):
    """Parse *yaml_file* and return the loaded object.

    BUG FIX: yaml.CLoader only exists when PyYAML was built against
    libyaml; referencing it unconditionally raises AttributeError on
    pure-Python installs.  Fall back to the Python Loader.
    """
    loader = getattr(yaml, 'CLoader', yaml.Loader)
    with open(yaml_file, 'r') as fp:
        return yaml.load(fp, Loader=loader)
|
||||
|
||||
def find_file_path_in_yaml(fname, root):
    """Resolve *fname* either as-is or relative to *root*.

    Returns None when fname is None; raises FileNotFoundError when the
    file exists in neither location.
    """
    # BUG FIX: `errno` was referenced below but never imported, so the
    # error path itself raised NameError instead of FileNotFoundError.
    import errno

    if fname is None:
        return None
    if op.isfile(fname):
        return fname
    joined = op.join(root, fname)
    if op.isfile(joined):
        return joined
    raise FileNotFoundError(
        errno.ENOENT, os.strerror(errno.ENOENT), joined
    )
|
||||
Reference in New Issue
Block a user