mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-28 02:44:04 +00:00
131 lines
5.6 KiB
Python
131 lines
5.6 KiB
Python
import torch
|
|
import numpy as np
|
|
from scipy.ndimage import gaussian_filter
|
|
|
|
class HeatmapHead(torch.nn.Module):
    """Heatmap prediction head that also decodes heatmaps into keypoints.

    The network part is ``deconv_layers -> conv_layers -> final_layer``; each
    deconv/conv stage is followed by ``InstanceNorm2d`` and ``SiLU``.
    ``forward`` runs the network and then decodes, on the CPU with NumPy, one
    keypoint per output channel using an argmax followed by a DARK/UDP style
    sub-pixel refinement.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Number of heatmaps (keypoints) produced by ``final_layer``.
        input_size: Two-element size of the model input. Element 0 is used to
            scale the x coordinate and element 1 the y coordinate of decoded
            keypoints.
        heatmap_scale: Integer divisor between ``input_size`` and the nominal
            heatmap size stored in ``self.heatmap_size``.
        deconv_out_channels: Output channels of each transposed-conv stage;
            a falsy value replaces the stage stack with ``Identity``.
        deconv_kernel_sizes: Kernel size of each transposed-conv stage
            (2, 3 or 4 are supported).
        conv_out_channels: Output channels of each conv stage; a falsy value
            replaces the stage stack with ``Identity``.
        conv_kernel_sizes: Kernel size of each conv stage.
        final_layer_kernel_size: Kernel size of the final projection conv.
        device: Forwarded to every layer constructor.
        dtype: Forwarded to every layer constructor.
        operations: Namespace providing ``Conv2d`` and ``ConvTranspose2d``.
            Falls back to ``torch.nn`` when ``None``.

    Raises:
        ValueError: If a deconv kernel size other than 2, 3 or 4 is given.
    """

    def __init__(
        self,
        in_channels=640,
        out_channels=133,
        input_size=(768, 1024),
        heatmap_scale=4,
        deconv_out_channels=(640,),
        deconv_kernel_sizes=(4,),
        conv_out_channels=(640,),
        conv_kernel_sizes=(1,),
        final_layer_kernel_size=1,
        device=None, dtype=None, operations=None
    ):
        super().__init__()

        # The signature advertises operations=None, so make that default usable
        # instead of failing on the first attribute access.
        if operations is None:
            operations = torch.nn

        self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale)
        # Maps heatmap coordinates onto input coordinates so that the first and
        # last pixel centres of both grids coincide (hence the "- 1" terms).
        self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32)

        # Deconv layers
        if deconv_out_channels:
            deconv_layers = []
            for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
                # Padding choices that pair with stride=2 below.
                if kernel_size == 4:
                    padding, output_padding = 1, 0
                elif kernel_size == 3:
                    padding, output_padding = 1, 1
                elif kernel_size == 2:
                    padding, output_padding = 0, 0
                else:
                    raise ValueError(f'Unsupported kernel size {kernel_size}')

                deconv_layers.extend([
                    operations.ConvTranspose2d(in_channels, out_ch, kernel_size,
                                               stride=2, padding=padding, output_padding=output_padding,
                                               bias=False, device=device, dtype=dtype),
                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
                    torch.nn.SiLU(inplace=True)
                ])
                in_channels = out_ch
            self.deconv_layers = torch.nn.Sequential(*deconv_layers)
        else:
            self.deconv_layers = torch.nn.Identity()

        # Conv layers
        if conv_out_channels:
            conv_layers = []
            for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
                padding = (kernel_size - 1) // 2
                conv_layers.extend([
                    operations.Conv2d(in_channels, out_ch, kernel_size,
                                      stride=1, padding=padding, device=device, dtype=dtype),
                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
                    torch.nn.SiLU(inplace=True)
                ])
                in_channels = out_ch
            self.conv_layers = torch.nn.Sequential(*conv_layers)
        else:
            self.conv_layers = torch.nn.Identity()

        self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size,
                                             padding=final_layer_kernel_size // 2, device=device, dtype=dtype)

    def forward(self, x):
        """Predict heatmaps for ``x`` and decode them to keypoints.

        Args:
            x: Input feature tensor fed to ``deconv_layers``; the network
                output is treated as ``(B, K, H, W)``.

        Returns:
            A tuple ``(batch_keypoints, batch_scores)`` of two lists with one
            entry per batch element: a float32 ``(K, 2)`` array of ``(x, y)``
            coordinates scaled by ``self.scale_factor``, and a ``(K,)`` array
            holding each heatmap's maximum. Keypoints whose score is ``<= 0``
            are reported as ``-1``.
        """
        heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
        # detach() so decoding also works when autograd is enabled; .numpy()
        # refuses tensors that require grad.
        heatmaps_np = heatmaps.detach().float().cpu().numpy()  # (B, K, H, W)

        batch_keypoints = []
        batch_scores = []
        for sample in heatmaps_np:
            # Decoding mutates its input, so hand it a private copy.
            keypoints, scores = self._decode_sample(sample.copy())
            batch_keypoints.append(keypoints)
            batch_scores.append(scores)

        return batch_keypoints, batch_scores

    @staticmethod
    def _blur_heatmaps(hm):
        """Gaussian-blur each ``(H, W)`` map of ``hm`` in place, keeping its maximum."""
        K, H, W = hm.shape
        border = 5  # (kernel-1)//2 for kernel=11
        # Only the interior is ever written, so the zero border survives reuse.
        padded = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
        for k in range(K):
            origin_max = np.max(hm[k])
            padded[border:-border, border:-border] = hm[k]
            blurred = gaussian_filter(padded, sigma=2.0)
            hm[k] = blurred[border:-border, border:-border]
            cur_max = np.max(hm[k])
            if cur_max > 0:
                hm[k] *= origin_max / cur_max

    def _decode_sample(self, hm):
        """Decode one ``(K, H, W)`` heatmap stack into keypoints and scores.

        ``hm`` is modified in place.
        """
        K, H, W = hm.shape

        # --- vectorised argmax ---
        flat = hm.reshape(K, -1)
        idx = np.argmax(flat, axis=1)
        # Fancy indexing copies, so the scores are unaffected by the in-place
        # blur / log applied to hm further down.
        scores = flat[np.arange(K), idx]
        y_locs, x_locs = np.unravel_index(idx, (H, W))
        keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32)  # (K, 2) in heatmap space
        invalid = scores <= 0.
        keypoints[invalid] = -1

        # --- DARK sub-pixel refinement (UDP) ---
        # 1. Gaussian blur with max-preserving normalisation
        self._blur_heatmaps(hm)
        # 2. Log-space for Taylor expansion
        np.clip(hm, 1e-3, 50., out=hm)
        np.log(hm, out=hm)
        # 3. Hessian-based Newton step
        # Edge-pad by one pixel so every neighbour lookup of an in-range
        # keypoint stays inside its own channel, then gather from the flat view.
        hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
        row_stride = W + 2
        # Build the flat index in int64: float32 cannot represent integers
        # above 2**24 exactly, which would select wrong pixels on large maps.
        kp_int = keypoints.astype(np.int64)
        index = kp_int[:, 0] + 1 + (kp_int[:, 1] + 1) * row_stride
        index += row_stride * (H + 2) * np.arange(K, dtype=np.int64)
        index = index.reshape(-1, 1)

        i_ = hm_pad[index]
        ix1 = hm_pad[index + 1]
        iy1 = hm_pad[index + row_stride]
        ix1y1 = hm_pad[index + row_stride + 1]
        ix1_y1_ = hm_pad[index - row_stride - 1]
        ix1_ = hm_pad[index - 1]
        iy1_ = hm_pad[index - row_stride]

        dx = 0.5 * (ix1 - ix1_)
        dy = 0.5 * (iy1 - iy1_)
        derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)

        dxx = ix1 - 2 * i_ + ix1_
        dyy = iy1 - 2 * i_ + iy1_
        dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
        hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
        # A small multiple of the identity keeps flat regions invertible.
        hessian_inv = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
        keypoints -= np.einsum('imn,ink->imk', hessian_inv, derivative).squeeze(axis=-1)

        # --- restore to input image space ---
        keypoints = keypoints * self.scale_factor
        # Invalid entries went through the refinement with placeholder
        # coordinates; overwrite whatever it produced.
        keypoints[invalid] = -1

        return keypoints, scores
|