import torch
import numpy as np
from scipy.ndimage import gaussian_filter


class HeatmapHead(torch.nn.Module):
    """Heatmap prediction head with built-in DARK/UDP keypoint decoding.

    The head runs optional transposed-conv upsampling layers, optional conv
    layers and a final conv that emits one heatmap per keypoint.  ``forward``
    then decodes those heatmaps on the CPU with NumPy into sub-pixel keypoint
    coordinates expressed in input-image space.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Number of heatmaps (keypoints) produced.
        input_size: Network input size.  ``input_size[0]`` scales the x
            coordinate and ``input_size[1]`` the y coordinate, i.e. it is
            interpreted as ``(width, height)``.
        heatmap_scale: Integer ratio between input size and heatmap size,
            used only to derive the coordinate scale factor.
        deconv_out_channels: Output channels of each stride-2 transposed
            conv; a falsy value disables the deconv stage.
        deconv_kernel_sizes: Kernel size (2, 3 or 4) of each transposed conv.
        conv_out_channels: Output channels of each conv layer; a falsy value
            disables the conv stage.
        conv_kernel_sizes: Kernel size of each conv layer.
        final_layer_kernel_size: Kernel size of the final projection conv.
        device: Device forwarded to every created layer.
        dtype: Dtype forwarded to every created layer.
        operations: Namespace providing ``Conv2d`` and ``ConvTranspose2d``
            factories.  Falls back to ``torch.nn`` when ``None``.

    Raises:
        ValueError: If a deconv kernel size other than 2, 3 or 4 is given.
    """

    # kernel_size -> (padding, output_padding).  With stride 2 each of these
    # combinations exactly doubles the spatial resolution.
    _DECONV_PADDING = {4: (1, 0), 3: (1, 1), 2: (0, 0)}

    # Zero border added around a heatmap before blurring:
    # (kernel - 1) // 2 for an 11-pixel kernel.
    _BLUR_BORDER = 5
    _BLUR_SIGMA = 2.0

    def __init__(
        self,
        in_channels=640,
        out_channels=133,
        input_size=(768, 1024),
        heatmap_scale=4,
        deconv_out_channels=(640,),
        deconv_kernel_sizes=(4,),
        conv_out_channels=(640,),
        conv_kernel_sizes=(1,),
        final_layer_kernel_size=1,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        # The signature advertises ``operations=None``; without this fallback
        # the default would crash on the first ``operations.<Layer>`` lookup.
        if operations is None:
            operations = torch.nn

        self.heatmap_size = (
            input_size[0] // heatmap_scale,
            input_size[1] // heatmap_scale,
        )
        # Maps heatmap pixel centres onto input pixel centres (the "- 1" on
        # both sides aligns the first and last pixels of the two grids).
        self.scale_factor = (
            (np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)
        ).astype(np.float32)

        # Deconv layers
        if deconv_out_channels:
            deconv_layers = []
            for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
                if kernel_size not in self._DECONV_PADDING:
                    raise ValueError(f'Unsupported kernel size {kernel_size}')
                padding, output_padding = self._DECONV_PADDING[kernel_size]
                deconv_layers.extend([
                    operations.ConvTranspose2d(
                        in_channels, out_ch, kernel_size, stride=2,
                        padding=padding, output_padding=output_padding,
                        bias=False, device=device, dtype=dtype,
                    ),
                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
                    torch.nn.SiLU(inplace=True),
                ])
                in_channels = out_ch
            self.deconv_layers = torch.nn.Sequential(*deconv_layers)
        else:
            self.deconv_layers = torch.nn.Identity()

        # Conv layers
        if conv_out_channels:
            conv_layers = []
            for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
                padding = (kernel_size - 1) // 2
                conv_layers.extend([
                    operations.Conv2d(
                        in_channels, out_ch, kernel_size, stride=1,
                        padding=padding, device=device, dtype=dtype,
                    ),
                    torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
                    torch.nn.SiLU(inplace=True),
                ])
                in_channels = out_ch
            self.conv_layers = torch.nn.Sequential(*conv_layers)
        else:
            self.conv_layers = torch.nn.Identity()

        self.final_layer = operations.Conv2d(
            in_channels, out_channels,
            kernel_size=final_layer_kernel_size,
            padding=final_layer_kernel_size // 2,
            device=device, dtype=dtype,
        )

    def forward(self, x):
        """Predict heatmaps for ``x`` and decode them to keypoints.

        Args:
            x: Feature tensor fed to the deconv/conv stack; the resulting
                heatmap tensor must be 4-D ``(B, K, H, W)``.

        Returns:
            A ``(batch_keypoints, batch_scores)`` tuple of Python lists with
            one entry per batch element: a ``(K, 2)`` float32 array of
            ``(x, y)`` coordinates in input-image space, and a ``(K,)`` array
            of raw heatmap maxima.  Keypoints whose score is ``<= 0`` are
            reported as ``(-1, -1)``.
        """
        heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
        # detach(): .numpy() refuses tensors that require grad, which is the
        # case whenever the caller has not disabled autograd.
        heatmaps_np = heatmaps.detach().float().cpu().numpy()  # (B, K, H, W)
        B, K, H, W = heatmaps_np.shape

        batch_keypoints = []
        batch_scores = []
        for b in range(B):
            # Work on a private copy; decoding blurs/logs the map in place.
            keypoints, scores = self._decode_single(heatmaps_np[b].copy())
            batch_keypoints.append(keypoints)
            batch_scores.append(scores)
        return batch_keypoints, batch_scores

    def _decode_single(self, hm):
        """Decode one ``(K, H, W)`` heatmap stack; ``hm`` is modified in place."""
        K, H, W = hm.shape

        # --- vectorised argmax ---
        flat = hm.reshape(K, -1)
        peak_idx = np.argmax(flat, axis=1)
        # Fancy indexing copies, so the scores survive the in-place blur.
        scores = flat[np.arange(K), peak_idx].copy()
        y_locs, x_locs = np.unravel_index(peak_idx, (H, W))
        keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32)  # (K, 2)

        invalid = scores <= 0.
        keypoints[invalid] = -1

        # --- DARK sub-pixel refinement (UDP) ---
        self._blur_in_place(hm)
        # Log-space so the peak neighbourhood is quadratic for the Taylor step.
        np.clip(hm, 1e-3, 50., hm)
        np.log(hm, hm)
        keypoints -= self._newton_offsets(hm, keypoints)

        # --- restore to input image space ---
        keypoints = keypoints * self.scale_factor
        # Invalid entries went through the refinement with placeholder
        # coordinates; overwrite whatever it produced for them.
        keypoints[invalid] = -1
        return keypoints, scores

    @classmethod
    def _blur_in_place(cls, hm):
        """Gaussian-blur every heatmap in ``hm`` while preserving its maximum."""
        border = cls._BLUR_BORDER
        for k in range(hm.shape[0]):
            origin_max = np.max(hm[k])
            # Zero-pad so the filter's boundary handling does not touch the
            # heatmap itself, then crop the padding away again.
            padded = np.pad(hm[k], border)
            padded = gaussian_filter(padded, sigma=cls._BLUR_SIGMA)
            hm[k] = padded[border:-border, border:-border]
            cur_max = np.max(hm[k])
            if cur_max > 0:
                hm[k] *= origin_max / cur_max

    @staticmethod
    def _newton_offsets(hm, keypoints):
        """Return the ``(K, 2)`` Newton step ``H^-1 @ grad`` at each keypoint.

        ``hm`` is the log-heatmap stack ``(K, H, W)`` and ``keypoints`` holds
        integer-valued ``(x, y)`` peak locations.  Derivatives are taken with
        finite differences on an edge-padded, flattened copy of ``hm``.
        """
        K, H, W = hm.shape
        row = W + 2  # row stride of the padded map
        hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()

        # Flat index of every keypoint inside its own padded heatmap.  Entries
        # at (-1, -1) may wrap around for the first map; their result is
        # discarded by the caller.
        index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * row
        index += row * (H + 2) * np.arange(0, K)
        index = index.astype(int).reshape(-1, 1)

        i_ = hm_pad[index]
        ix1 = hm_pad[index + 1]
        iy1 = hm_pad[index + row]
        ix1y1 = hm_pad[index + row + 1]
        ix1_y1_ = hm_pad[index - row - 1]
        ix1_ = hm_pad[index - 1]
        iy1_ = hm_pad[index - row]

        dx = 0.5 * (ix1 - ix1_)
        dy = 0.5 * (iy1 - iy1_)
        derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)

        dxx = ix1 - 2 * i_ + ix1_
        dyy = iy1 - 2 * i_ + iy1_
        dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
        hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
        # A tiny diagonal term keeps flat regions (zero Hessian) invertible.
        hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
        return np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1)