mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-04-29 02:31:21 +00:00
100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
"""
|
|
Functions for computing Procrustes alignment and reconstruction error
|
|
|
|
Parts of the code are adapted from https://github.com/akanazawa/hmr
|
|
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
import numpy as np
|
|
|
|
def compute_similarity_transform(S1, S2):
    """Compute the similarity transform (sR, t) that best maps S1 onto S2.

    Solves the orthogonal Procrustes problem with scaling: finds a scale
    ``s``, a rotation ``R`` with det(R) = +1 and a translation ``t`` that
    minimise the squared distance between ``s * R @ S1 + t`` and ``S2``,
    then returns the transformed ``S1``.

    Args:
        S1: Source point set, either ``(D, N)`` with D in {2, 3} or
            ``(N, D)``. The layout is inferred from ``S1.shape[0]``: if it is
            2 or 3 the array is taken as ``(D, N)``; otherwise both inputs are
            transposed. NOTE: an ``(N, D)`` array with N == 2 or N == 3 is
            therefore read as ``(D, N)``.
        S2: Target point set with the same shape and layout as ``S1``.

    Returns:
        ``S1`` aligned to ``S2``, in the same layout as the inputs.

    Raises:
        ValueError: if ``S1`` and ``S2`` do not have the same shape.
    """
    S1 = np.asarray(S1)
    S2 = np.asarray(S2)
    # Equal shapes are required for the matrix products below to line up.
    if S1.shape != S2.shape:
        raise ValueError(
            'S1 and S2 must have the same shape, got {} and {}'.format(
                S1.shape, S2.shape))

    # Work internally in (D, N) layout; remember to undo the transpose.
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.T
        S2 = S2.T
        transposed = True

    # 1. Remove the centroids.
    mu1 = S1.mean(axis=1, keepdims=True)
    mu2 = S2.mean(axis=1, keepdims=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Total squared spread of the centred source points (used for scale).
    var1 = np.sum(X1 ** 2)

    # 3. Cross-covariance of the centred point sets.
    K = X1.dot(X2.T)

    # 4. With K = U S V^T, R = V U^T maximises trace(R K).
    U, _, Vh = np.linalg.svd(K)
    V = Vh.T
    # Flip the last singular direction if needed so det(R) = +1, i.e. R is a
    # proper rotation rather than a reflection.
    Z = np.eye(U.shape[0])
    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
    R = V.dot(Z.dot(U.T))

    # 5. Optimal scale. If every source point coincides (var1 == 0) the
    # division would be 0/0 = NaN; the least-squares optimum then maps every
    # point onto S2's centroid, which scale = 0 produces via step 6.
    if var1 == 0:
        scale = 0.0
    else:
        scale = np.trace(R.dot(K)) / var1

    # 6. Translation taking the scaled/rotated centroid of S1 onto S2's.
    t = mu2 - scale * R.dot(mu1)

    # 7. Apply the transform to the original (uncentred) source points.
    S1_hat = scale * R.dot(S1) + t

    return S1_hat.T if transposed else S1_hat
|
|
|
|
def compute_similarity_transform_batch(S1, S2):
    """Batched version of compute_similarity_transform.

    Aligns each ``S1[i]`` to ``S2[i]`` independently by calling
    ``compute_similarity_transform`` once per index along axis 0.

    Args:
        S1: Batch of source point sets; ``S1[i]`` is one point set.
        S2: Batch of target point sets, indexed in parallel with ``S1``.

    Returns:
        Array with the same shape as ``S1`` holding each aligned ``S1[i]``.
        Floating-point inputs keep their dtype; any other dtype (e.g. integer)
        yields a float64 result, because aligned coordinates are generally
        not integral.
    """
    S1 = np.asarray(S1)
    S2 = np.asarray(S2)
    # np.zeros_like(S1) would inherit an integer dtype and silently truncate
    # the aligned coordinates, so only reuse S1's dtype when it is floating.
    if np.issubdtype(S1.dtype, np.floating):
        out_dtype = S1.dtype
    else:
        out_dtype = np.float64
    S1_hat = np.zeros(S1.shape, dtype=out_dtype)
    for i in range(S1.shape[0]):
        S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
    return S1_hat
|
|
|
|
def reconstruction_error(S1, S2, reduction='mean'):
    """Procrustes-align each ``S1[i]`` to ``S2[i]`` and measure the residual.

    For every sample, the Euclidean distance between aligned ``S1`` and
    ``S2`` is taken along the last axis (treated as coordinates) and then
    averaged over the next axis.

    Args:
        S1: Batch of predicted point sets (batched along axis 0).
        S2: Batch of target point sets with the same layout as ``S1``.
        reduction: ``'mean'`` or ``'sum'`` to reduce the per-sample errors;
            any other value returns them unreduced.

    Returns:
        A scalar for ``'mean'``/``'sum'``, otherwise the per-sample errors.
    """
    aligned = compute_similarity_transform_batch(S1, S2)
    residual = aligned - S2
    point_dist = np.sqrt(np.sum(residual ** 2, axis=-1))
    per_sample = point_dist.mean(axis=-1)

    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    return per_sample
|
|
|
|
|
|
def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'):
    """Align using all points, then score only a selected subset of them.

    Each ``S1[i]`` is Procrustes-aligned to ``S2[i]`` using every point. The
    error is then computed only on the entries that ``J24_TO_J14`` selects
    along axis 1. Judging by the name, this is presumably a 24-to-14 joint
    index map; confirm with callers.

    Args:
        S1: Batch of predicted point sets, indexable as ``[:, idx, :]``.
        S2: Batch of target point sets with the same layout as ``S1``.
        J24_TO_J14: Index applied to axis 1 of both arrays before scoring.
        reduction: ``'mean'`` or ``'sum'`` to reduce the per-sample errors;
            any other value returns them unreduced.

    Returns:
        A scalar for ``'mean'``/``'sum'``, otherwise the per-sample errors.
    """
    aligned = compute_similarity_transform_batch(S1, S2)

    # Alignment used every point; only the selected ones are scored.
    pred_subset = aligned[:, J24_TO_J14, :]
    target_subset = S2[:, J24_TO_J14, :]

    squared = np.square(pred_subset - target_subset)
    per_sample = np.sqrt(squared.sum(axis=-1)).mean(axis=-1)

    if reduction == 'mean':
        result = per_sample.mean()
    elif reduction == 'sum':
        result = per_sample.sum()
    else:
        result = per_sample
    return result
|
|
|
|
def get_alignMesh(S1, S2, reduction='mean'):
    """Procrustes-align ``S1`` to ``S2`` and return the error and the arrays.

    Computes the same error as ``reconstruction_error``: the per-sample mean,
    over the second-to-last axis, of the Euclidean distance along the last
    axis. It also hands back the aligned ``S1`` and the untouched ``S2``.

    Args:
        S1: Batch of source point sets (batched along axis 0).
        S2: Batch of target point sets with the same layout as ``S1``.
        reduction: ``'mean'`` or ``'sum'`` to reduce the per-sample errors;
            any other value leaves them unreduced.

    Returns:
        Tuple ``(error, S1_aligned, S2)``. ``error`` is a scalar for
        ``'mean'``/``'sum'`` and the per-sample array otherwise; ``S2`` is
        the same object that was passed in.
    """
    S1_aligned = compute_similarity_transform_batch(S1, S2)
    dist = np.sqrt(np.square(S1_aligned - S2).sum(axis=-1))
    error = dist.mean(axis=-1)

    if reduction == 'mean':
        error = error.mean()
    elif reduction == 'sum':
        error = error.sum()

    return error, S1_aligned, S2
|