From f38de2a2fedfafa4bf298806d1efcabb4a357cbc Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 10 May 2026 14:47:44 -0600 Subject: [PATCH] Add tipsv2 locally and fix gradient checkpointing for it --- .../models/diffusion_feature_extraction.py | 11 +- toolkit/models/tipsv2.py | 867 ++++++++++++++++++ 2 files changed, 872 insertions(+), 6 deletions(-) create mode 100644 toolkit/models/tipsv2.py diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 7332a123..32a1d0fa 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -829,13 +829,12 @@ class DiffusionFeatureExtractor7(nn.Module): self.version = 7 self.sd_ref = weakref.ref(sd) if sd is not None else None + from toolkit.models.tipsv2 import TIPSv2DPTModel pretrained_model_name = "google/tipsv2-b14-dpt" - self.model = AutoModel.from_pretrained( - pretrained_model_name, - device_map=device, - dtype=torch.float32, - trust_remote_code=True - ).to(device) + self.model = TIPSv2DPTModel.from_pretrained( + pretrained_model_name, + dtype=dtype, + ).to(device, dtype=dtype) self.losses = {} self.log_every = 100 diff --git a/toolkit/models/tipsv2.py b/toolkit/models/tipsv2.py new file mode 100644 index 00000000..ff751924 --- /dev/null +++ b/toolkit/models/tipsv2.py @@ -0,0 +1,867 @@ +"""Local implementation of google/tipsv2-b14-dpt. + +Self-contained port of the remote `trust_remote_code=True` model into ai-toolkit. +Includes vision encoder + DPT depth/normals/segmentation heads, with optional +gradient checkpointing on the vision transformer blocks. The text encoder is +intentionally not included — only the dense-prediction stack is used here. + +Original remote code: https://huggingface.co/google/tipsv2-b14-dpt + https://huggingface.co/google/tipsv2-b14 +""" + +import functools +import math +from dataclasses import dataclass +from typing import Callable, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + + +# ───────────────────────── Vision Transformer ────────────────────────────── + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def _make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + image_hw = _make_2tuple(img_size) + patch_hw = _make_2tuple(patch_size) + self.img_size = image_hw + self.patch_size = patch_hw + self.patches_resolution = ( + image_hw[0] // patch_hw[0], + image_hw[1] // patch_hw[1], + ) + 
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + self.flatten_embedding = flatten_embedding + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, h, w = x.shape + ph, pw = self.patch_size + assert h % ph == 0, f"Input height {h} not divisible by patch {ph}" + assert w % pw == 0, f"Input width {w} not divisible by patch {pw}" + x = self.proj(x) + h, w = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, h, w, self.embed_dim) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, n, c = x.shape + qkv = ( + self.qkv(x) + .reshape(b, n, 3, self.num_heads, c // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + # Use SDPA — drops the manual attention matmul + softmax and supports flash on cuda. + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0 + ) + x = x.transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, dim: int, init_values: Union[float, torch.Tensor] = 1e-5 + ) -> None: + super().__init__() + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.gamma + + +class _DropPath(nn.Module): + def __init__(self, drop_prob: float = 0.0): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.drop_prob == 0.0 or not self.training: + return x + keep = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = x.new_empty(shape).bernoulli_(keep) + if keep > 0.0: + mask.div_(keep) + return x * mask + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = _DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else 
nn.Identity() + ) + self.drop_path2 = _DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """DINOv2-style ViT used as the TIPSv2 vision backbone.""" + + def __init__( + self, + img_size: int = 224, + patch_size: int = 14, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + ffn_bias: bool = True, + proj_bias: bool = True, + drop_path_rate: float = 0.0, + init_values: Optional[float] = 1.0, + ffn_layer: str = "mlp", + num_register_tokens: int = 1, + interpolate_antialias: bool = True, + interpolate_offset: float = 0.0, + ): + super().__init__() + norm_layer = functools.partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.gradient_checkpointing = False + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if ffn_layer != "mlp": + raise NotImplementedError( + f"ffn_layer={ffn_layer!r} not supported in local port" + ) + + dpr = [drop_path_rate * i / max(depth - 1, 1) for i in range(depth)] + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + ffn_layer=Mlp, + init_values=init_values, + ) + for i in range(depth) + ] + ) + # Maintain weight-key compat with the upstream non-chunked branch. 
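+        # (Upstream DINOv2 can also wrap blocks into chunked ModuleLists; the flat
+        # layout used here keeps state-dict keys as `blocks.{i}.*`, i.e. the
+        # non-chunked layout the released checkpoints appear to use.)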
+ self.chunked_blocks = False + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + # ---- gradient checkpointing toggles ------------------------------------ + + def gradient_checkpointing_enable(self, **_kwargs) -> None: + self.gradient_checkpointing = True + + def gradient_checkpointing_disable(self) -> None: + self.gradient_checkpointing = False + + enable_gradient_checkpointing = gradient_checkpointing_enable + disable_gradient_checkpointing = gradient_checkpointing_disable + + # ---- positional embedding / token prep --------------------------------- + + def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + num_patches = self.pos_embed.shape[1] - 1 + if npatch == num_patches and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + side = int(math.sqrt(num_patches)) + assert num_patches == side * side + kwargs = {} + if self.interpolate_offset: + kwargs["scale_factor"] = ( + float(w0 + self.interpolate_offset) / side, + float(h0 + self.interpolate_offset) / side, + ) + else: + kwargs["size"] = (w0, h0) + patch_pos_embed = F.interpolate( + patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2), + mode="bilinear", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks( + self, x: torch.Tensor, masks: Optional[torch.Tensor] = None + ) -> torch.Tensor: + _, _, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + if self.register_tokens is not None: + x = torch.cat( + (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), + dim=1, + ) + return x + + # ---- block runner with optional checkpointing -------------------------- + + def _run_blocks( + self, x: torch.Tensor, collect_indices: Optional[Sequence[int]] = None + ): + collected = [] if collect_indices is not None else None + use_ckpt = self.gradient_checkpointing and self.training + for i, blk in enumerate(self.blocks): + if use_ckpt: + x = torch.utils.checkpoint.checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + if collected is not None and i in collect_indices: + collected.append(x) + return (x, collected) if collected is not None else x + + # ---- public forwards --------------------------------------------------- + + def forward_features( + self, x: torch.Tensor, masks: Optional[torch.Tensor] = None + ) -> dict: + x = self.prepare_tokens_with_masks(x, masks) + x = self._run_blocks(x) + x_norm = self.norm(x) + return { + "x_norm_1st_clstoken": x_norm[:, :1], + "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence[int]] = 1, + reshape: bool = False, + return_class_token: bool = False, + norm: bool = 
True, + ): + x_in = x + x = self.prepare_tokens_with_masks(x) + total = len(self.blocks) + indices = list(range(total - n, total)) if isinstance(n, int) else list(n) + _, outputs = self._run_blocks(x, collect_indices=indices) + # Preserve the requested ordering. + order = {idx: pos for pos, idx in enumerate(sorted(indices))} + outputs = [outputs[order[idx]] for idx in indices] + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + b, _, w, h = x_in.shape + outputs = [ + out.reshape(b, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, x: torch.Tensor, is_training: bool = False): + ret = self.forward_features(x) + if is_training: + return ret + return ( + self.head(ret["x_norm_1st_clstoken"]), + self.head(ret["x_norm_2nd_clstoken"]), + ret["x_norm_patchtokens"], + ) + + +def _vit_base(patch_size: int = 14, **kwargs) -> VisionTransformer: + return VisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + num_register_tokens=1, + **kwargs, + ) + + +# ───────────────────────────── DPT heads ─────────────────────────────────── + + +class PreActResidualConvUnit(nn.Module): + def __init__(self, features: int): + super().__init__() + self.conv1 = nn.Conv2d(features, features, 3, padding=1, bias=False) + self.conv2 = nn.Conv2d(features, features, 3, padding=1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = F.relu(x) + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + return x + residual + + +class FeatureFusionBlock(nn.Module): + def __init__(self, features: int, has_residual: bool = False, expand: bool = False): + super().__init__() + self.has_residual = has_residual + if has_residual: + self.residual_unit = PreActResidualConvUnit(features) + self.main_unit = PreActResidualConvUnit(features) + out_features = features // 2 if expand else features + self.out_conv = nn.Conv2d(features, out_features, 1, bias=True) + + def forward( + self, x: torch.Tensor, residual: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.has_residual and residual is not None: + if residual.shape != x.shape: + residual = F.interpolate( + residual, size=x.shape[2:], mode="bilinear", align_corners=False + ) + residual = self.residual_unit(residual) + x = x + residual + x = self.main_unit(x) + x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + x = self.out_conv(x) + return x + + +class ReassembleBlocks(nn.Module): + def __init__( + self, + input_embed_dim: int = 1024, + out_channels: Tuple[int, ...] 
= (128, 256, 512, 1024), + readout_type: str = "project", + ): + super().__init__() + self.readout_type = readout_type + self.out_projections = nn.ModuleList( + [nn.Conv2d(input_embed_dim, ch, 1) for ch in out_channels] + ) + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d(out_channels[3], out_channels[3], 3, stride=2, padding=1), + ] + ) + if readout_type == "project": + self.readout_projects = nn.ModuleList( + [nn.Linear(2 * input_embed_dim, input_embed_dim) for _ in out_channels] + ) + + def forward(self, features): + out = [] + for i, (cls_token, x) in enumerate(features): + B, D, H, W = x.shape + if self.readout_type == "project": + x_flat = x.flatten(2).transpose(1, 2) + readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1], -1) + x_cat = torch.cat([x_flat, readout], dim=-1) + x_proj = F.gelu(self.readout_projects[i](x_cat)) + x = x_proj.transpose(1, 2).reshape(B, D, H, W) + x = self.out_projections[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +def _build_fusion_stack(channels: int) -> nn.ModuleList: + return nn.ModuleList( + [ + FeatureFusionBlock(channels, has_residual=False), + FeatureFusionBlock(channels, has_residual=True), + FeatureFusionBlock(channels, has_residual=True), + FeatureFusionBlock(channels, has_residual=True), + ] + ) + + +class _DPTHeadBase(nn.Module): + """Shared reassemble + fuse + project trunk used by all three task heads.""" + + def __init__( + self, + input_embed_dim: int, + channels: int, + post_process_channels: Tuple[int, ...], + readout_type: str, + ): + super().__init__() + self.reassemble = ReassembleBlocks( + input_embed_dim=input_embed_dim, + out_channels=post_process_channels, + readout_type=readout_type, + ) + self.convs = nn.ModuleList( + [ + nn.Conv2d(ch, channels, 3, padding=1, bias=False) + for ch in post_process_channels + ] + ) + self.fusion_blocks = _build_fusion_stack(channels) + self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True) + + def _trunk(self, intermediate_features) -> torch.Tensor: + x = self.reassemble(intermediate_features) + x = [self.convs[i](feat) for i, feat in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, 4): + out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) + return self.project(out) + + +class DPTDepthHead(_DPTHeadBase): + def __init__( + self, + input_embed_dim: int = 1024, + channels: int = 256, + post_process_channels: Tuple[int, ...] 
= (128, 256, 512, 1024), + readout_type: str = "project", + num_depth_bins: int = 256, + min_depth: float = 1e-3, + max_depth: float = 10.0, + ): + super().__init__(input_embed_dim, channels, post_process_channels, readout_type) + self.num_depth_bins = num_depth_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.depth_head = nn.Linear(channels, num_depth_bins) + + def forward(self, intermediate_features, image_size=None) -> torch.Tensor: + out = F.relu(self._trunk(intermediate_features)) + out = out.permute(0, 2, 3, 1) + out = self.depth_head(out) + bin_centers = torch.linspace( + self.min_depth, + self.max_depth, + self.num_depth_bins, + device=out.device, + dtype=out.dtype, + ) + out = F.relu(out) + self.min_depth + out_norm = out / out.sum(dim=-1, keepdim=True) + depth = torch.einsum("bhwn,n->bhw", out_norm, bin_centers).unsqueeze(1) + if image_size is not None: + depth = F.interpolate( + depth, size=image_size, mode="bilinear", align_corners=False + ) + return depth + + +class DPTNormalsHead(_DPTHeadBase): + def __init__( + self, + input_embed_dim: int = 1024, + channels: int = 256, + post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024), + readout_type: str = "project", + ): + super().__init__(input_embed_dim, channels, post_process_channels, readout_type) + self.normals_head = nn.Linear(channels, 3) + + def forward(self, intermediate_features, image_size=None) -> torch.Tensor: + out = self._trunk(intermediate_features) + out = out.permute(0, 2, 3, 1) + out = self.normals_head(out) + out = F.normalize(out, p=2, dim=-1) + out = out.permute(0, 3, 1, 2) + if image_size is not None: + out = F.interpolate( + out, size=image_size, mode="bilinear", align_corners=False + ) + return out + + +class DPTSegmentationHead(_DPTHeadBase): + def __init__( + self, + input_embed_dim: int = 1024, + channels: int = 256, + post_process_channels: Tuple[int, ...] = (128, 256, 512, 1024), + readout_type: str = "project", + num_classes: int = 150, + ): + super().__init__(input_embed_dim, channels, post_process_channels, readout_type) + self.segmentation_head = nn.Linear(channels, num_classes) + + def forward(self, intermediate_features, image_size=None) -> torch.Tensor: + out = self._trunk(intermediate_features) + out = out.permute(0, 2, 3, 1) + out = self.segmentation_head(out) + out = out.permute(0, 3, 1, 2) + if image_size is not None: + out = F.interpolate( + out, size=image_size, mode="bilinear", align_corners=False + ) + return out + + +# ───────────────────────────── Top-level model ───────────────────────────── + + +@dataclass +class TIPSv2DPTOutput: + depth: Optional[torch.Tensor] = None + normals: Optional[torch.Tensor] = None + segmentation: Optional[torch.Tensor] = None + + +# Hard-coded config for the b14-dpt variant — matches config.json on the hub. +_B14_DPT_CONFIG = dict( + backbone_repo="google/tipsv2-b14", + embed_dim=768, + channels=256, + post_process_channels=(96, 192, 384, 768), + block_indices=(2, 5, 8, 11), + readout_type="project", + num_depth_bins=256, + min_depth=1e-3, + max_depth=10.0, + num_seg_classes=150, + # Vision encoder + vision_fn="vit_base", + patch_size=14, + img_size=448, + init_values=1.0, + num_register_tokens=1, + ffn_layer="mlp", +) + + +class TIPSv2DPTModel(nn.Module): + """TIPSv2 DPT dense-prediction model (depth, normals, segmentation). + + Use :meth:`from_pretrained` to load weights for `google/tipsv2-b14-dpt`. 
+ """ + + def __init__(self, config: Optional[dict] = None): + super().__init__() + cfg = dict(_B14_DPT_CONFIG) + if config: + cfg.update(config) + self.config = cfg + + builders = {"vit_base": _vit_base} + if cfg["vision_fn"] not in builders: + raise NotImplementedError(f"vision_fn={cfg['vision_fn']!r} not supported") + + self.vision_encoder = builders[cfg["vision_fn"]]( + img_size=cfg["img_size"], + patch_size=cfg["patch_size"], + ffn_layer=cfg["ffn_layer"], + init_values=cfg["init_values"], + interpolate_antialias=True, + interpolate_offset=0.0, + ) + + ppc = tuple(cfg["post_process_channels"]) + self.depth_head = DPTDepthHead( + input_embed_dim=cfg["embed_dim"], + channels=cfg["channels"], + post_process_channels=ppc, + readout_type=cfg["readout_type"], + num_depth_bins=cfg["num_depth_bins"], + min_depth=cfg["min_depth"], + max_depth=cfg["max_depth"], + ) + self.normals_head = DPTNormalsHead( + input_embed_dim=cfg["embed_dim"], + channels=cfg["channels"], + post_process_channels=ppc, + readout_type=cfg["readout_type"], + ) + self.segmentation_head = DPTSegmentationHead( + input_embed_dim=cfg["embed_dim"], + channels=cfg["channels"], + post_process_channels=ppc, + readout_type=cfg["readout_type"], + num_classes=cfg["num_seg_classes"], + ) + + # ---- properties + checkpointing --------------------------------------- + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def gradient_checkpointing_enable(self, **kwargs) -> None: + """Enable gradient checkpointing on the vision transformer blocks.""" + self.vision_encoder.gradient_checkpointing_enable(**kwargs) + + def gradient_checkpointing_disable(self) -> None: + self.vision_encoder.gradient_checkpointing_disable() + + enable_gradient_checkpointing = gradient_checkpointing_enable + disable_gradient_checkpointing = gradient_checkpointing_disable + + # ---- core inference path ---------------------------------------------- + + def _extract_intermediate(self, pixel_values: torch.Tensor): + intermediate = self.vision_encoder.get_intermediate_layers( + pixel_values, + n=tuple(self.config["block_indices"]), + reshape=True, + return_class_token=True, + norm=True, + ) + # Returned as (cls_token, patch_feats) tuples to match the remote API. 
+ return [(cls_tok, patch_feat) for patch_feat, cls_tok in intermediate] + + def predict_depth(self, pixel_values: torch.Tensor) -> torch.Tensor: + h, w = pixel_values.shape[2:] + return self.depth_head( + self._extract_intermediate(pixel_values), image_size=(h, w) + ) + + def predict_normals(self, pixel_values: torch.Tensor) -> torch.Tensor: + h, w = pixel_values.shape[2:] + return self.normals_head( + self._extract_intermediate(pixel_values), image_size=(h, w) + ) + + def predict_segmentation(self, pixel_values: torch.Tensor) -> torch.Tensor: + h, w = pixel_values.shape[2:] + return self.segmentation_head( + self._extract_intermediate(pixel_values), image_size=(h, w) + ) + + def forward(self, pixel_values: torch.Tensor) -> TIPSv2DPTOutput: + h, w = pixel_values.shape[2:] + feats = self._extract_intermediate(pixel_values) + return TIPSv2DPTOutput( + depth=self.depth_head(feats, image_size=(h, w)), + normals=self.normals_head(feats, image_size=(h, w)), + segmentation=self.segmentation_head(feats, image_size=(h, w)), + ) + + # ---- loader ----------------------------------------------------------- + + @classmethod + def from_pretrained( + cls, + model_id: str = "google/tipsv2-b14-dpt", + device: Union[str, torch.device] = "cpu", + dtype: torch.dtype = torch.float32, + cache_dir: Optional[str] = None, + ) -> "TIPSv2DPTModel": + """Build the model and load weights from the hub. + + Pulls the DPT head weights from ``model_id`` (default + ``google/tipsv2-b14-dpt``) and the vision-encoder weights from the + backbone repo specified in the DPT config. + """ + from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + + if model_id != "google/tipsv2-b14-dpt": + raise NotImplementedError( + f"Local TIPSv2DPTModel only supports 'google/tipsv2-b14-dpt'; got {model_id!r}" + ) + + model = cls() + + dpt_ckpt = hf_hub_download(model_id, "model.safetensors", cache_dir=cache_dir) + dpt_state = load_file(dpt_ckpt) + + backbone_ckpt = hf_hub_download( + model.config["backbone_repo"], + "model.safetensors", + cache_dir=cache_dir, + ) + backbone_state = load_file(backbone_ckpt) + # Backbone repo stores both vision and text encoders — keep only vision_encoder.*. + backbone_state = { + k: v for k, v in backbone_state.items() if k.startswith("vision_encoder.") + } + + merged = {**dpt_state, **backbone_state} + missing, unexpected = model.load_state_dict(merged, strict=False) + if missing: + print( + f"[tipsv2] Missing keys ({len(missing)}): {missing[:8]}{'...' if len(missing) > 8 else ''}" + ) + if unexpected: + print( + f"[tipsv2] Unexpected keys ({len(unexpected)}): {unexpected[:8]}{'...' if len(unexpected) > 8 else ''}" + ) + + model.to(device=device, dtype=dtype) + return model
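
For reference, a minimal way to exercise the new module once the patch is applied. The device, dtype, and random 448x448 input below are illustrative, and the checkpointing toggle only takes effect while the model is in training mode:

    import torch
    from toolkit.models.tipsv2 import TIPSv2DPTModel

    # Downloads DPT-head weights from the hub and backbone weights from google/tipsv2-b14.
    model = TIPSv2DPTModel.from_pretrained(
        "google/tipsv2-b14-dpt", device="cuda", dtype=torch.bfloat16
    )
    model.gradient_checkpointing_enable()  # checkpoints the ViT blocks during training passes
    pixels = torch.randn(1, 3, 448, 448, device="cuda", dtype=torch.bfloat16)  # H, W divisible by 14
    with torch.no_grad():
        out = model(pixels)  # TIPSv2DPTOutput with depth, normals, segmentation
    print(out.depth.shape, out.normals.shape, out.segmentation.shape)
    # expected: (1, 1, 448, 448), (1, 3, 448, 448), (1, 150, 448, 448)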