mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Add sapiens2 as a diffusion feature extractor
This commit is contained in:
@@ -365,6 +365,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# must be set to train for gradient checkpointing to work
|
||||
self.dfe.vision_encoder.train()
|
||||
self.dfe.vision_encoder.gradient_checkpointing = True
|
||||
elif hasattr(self.dfe, 'model') and self.train_config.gradient_checkpointing:
|
||||
if hasattr(self.dfe.model, 'enable_gradient_checkpointing'):
|
||||
self.dfe.model.train()
|
||||
self.dfe.model.enable_gradient_checkpointing()
|
||||
elif hasattr(self.dfe.model, 'gradient_checkpointing'):
|
||||
self.dfe.model.train()
|
||||
self.dfe.model.gradient_checkpointing = True
|
||||
else:
|
||||
print_acc("Warning: Could not enable gradient checkpointing on diffusion feature extractor model.")
|
||||
else:
|
||||
self.dfe.eval()
|
||||
|
||||
@@ -664,7 +673,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
|
||||
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||
elif self.dfe.version in [3, 4, 5, 6, 7, 8]:
|
||||
elif self.dfe.version in [3, 4, 5, 6, 7, 8, 9, 10]:
|
||||
dfe_loss = self.dfe(
|
||||
noise=noise,
|
||||
noise_pred=noise_pred,
|
||||
|
||||
@@ -11,6 +11,8 @@ import weakref
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.models.sapiens2 import Sapiens2
|
||||
import huggingface_hub
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
@@ -1020,6 +1022,156 @@ class DiffusionFeatureExtractor8(DiffusionFeatureExtractor7):
|
||||
super().__init__(device=device, dtype=dtype, vae=vae, sd=sd, partial_step=True)
|
||||
self.version = 8
|
||||
|
||||
class DiffusionFeatureExtractor9(nn.Module):
    """Diffusion feature extractor (version 9) built on a Sapiens2 1B encoder.

    Predicted and target latents are decoded to images through the base
    model referenced by ``sd`` and compared in Sapiens2 feature space with
    an MSE loss.
    """

    def __init__(
        self,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
        vae=None,
        sd=None,
        partial_step: bool = False
    ):
        """Download the Sapiens2 checkpoint and build the encoder.

        Args:
            device: Device the encoder is moved to after loading.
            dtype: Parameter dtype of the encoder.
            vae: Accepted for signature parity with the other extractors;
                not used by this class.
            sd: Base model; only a weak reference is kept. ``forward`` uses
                its ``vae`` attribute and ``decode_latents`` method.
            partial_step: When True, ``forward`` compares a small partial
                denoising step instead of the full x0 estimate.
        """
        super().__init__()

        self.version = 9
        self.sd_ref = weakref.ref(sd) if sd is not None else None
        ckpt_path = huggingface_hub.hf_hub_download(repo_id="facebook/sapiens2-pretrain-1b", filename="sapiens2_1b_pretrain.safetensors")
        # Build and load first, then move once to the requested device. A
        # hard-coded .cuda() here would ignore `device` and fail on hosts
        # without CUDA.
        self.model = Sapiens2(arch="sapiens2_1b", img_size=(1024, 768), patch_size=16).eval()  # img_size is (H, W)
        self.model.load_state_dict(load_file(ckpt_path))
        self.model.to(device, dtype=dtype)

        self.losses = {}
        self.log_every = 100
        self.step = 0
        self.do_partial_step = partial_step

    def _get_sd(self):
        """Return the live base model, or raise if it is missing or collected."""
        sd = self.sd_ref() if self.sd_ref is not None else None
        if sd is None:
            raise RuntimeError(
                "DiffusionFeatureExtractor9 needs a live `sd` model reference."
            )
        return sd

    def get_pred(self, tensor_0_1: torch.Tensor):
        """Run the encoder on a batch of RGB images with values in [0, 1].

        The input is ImageNet-normalized. If either side is not a multiple
        of 16, or the area exceeds 1024 * 1024 pixels, it is resized
        bilinearly to the nearest smaller multiple-of-16 size that fits.

        Args:
            tensor_0_1: Tensor of shape (bs, 3, h, w).

        Returns:
            The first element of the encoder's output.

        Raises:
            ValueError: If the input is not 4-D with 3 channels.
        """
        if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3:
            raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}")

        x = tensor_0_1.to(self.model.device, dtype=self.model.dtype)
        # Apply ImageNet normalization to the (B, C, H, W) RGB tensor.
        mean = torch.as_tensor((0.485, 0.456, 0.406), dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
        std = torch.as_tensor((0.229, 0.224, 0.225), dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
        x = (x - mean) / std

        max_res = 1024
        p = 16
        height, width = x.shape[-2], x.shape[-1]
        too_large = height * width > max_res * max_res
        if (width % p != 0) or (height % p != 0) or too_large:
            target_h, target_w = height, width
            if too_large:
                scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h))
                target_h = int(target_h * scale_factor)
                target_w = int(target_w * scale_factor)
            # Floor to a whole number of patches, but never below one patch:
            # a zero-sized target would make the interpolation fail.
            target_h = max(p, (target_h // p) * p)
            target_w = max(p, (target_w // p) * p)
            x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False)

        features = self.model(x)[0]
        return features

    def forward(
        self,
        noise,
        noise_pred,
        noisy_latents,
        timesteps,
        batch: DataLoaderBatchDTO,
        scheduler: CustomFlowMatchEulerDiscreteScheduler,
        model=None
    ):
        """Compute the feature-space loss between prediction and target.

        Args:
            noise: Noise used to build ``noisy_latents``.
            noise_pred: Model prediction; a 5-D tensor is reduced to its
                first frame along dim 2.
            noisy_latents: Noised latents fed to the model.
            timesteps: Timesteps; divided by 1000 before use.
            batch: Provides ``tensor`` (target images, presumably in
                [-1, 1] — confirm against the data loader) and ``latents``.
            scheduler: Unused here.
            model: Unused here.

        Returns:
            Scalar loss tensor (MSE * 100, and * 10 more in partial-step mode).

        Raises:
            RuntimeError: If the ``sd`` reference is missing or dead.
        """
        sd = self._get_sd()
        dtype = torch.bfloat16
        device = sd.vae.device
        tensors = batch.tensor.to(device, dtype=dtype)

        # Video latents are (B, C, T, H, W): only the first frame is used.
        if len(noise_pred.shape) == 5:
            noise = noise[:, :, 0, :, :]
            noise_pred = noise_pred[:, :, 0, :, :]
            noisy_latents = noisy_latents[:, :, 0, :, :]

        # Batch images are laid out differently, (B, T, C, H, W).
        if len(tensors.shape) == 5:
            tensors = tensors[:, 0, :, :, :]

        with torch.no_grad():
            tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
            # expand shape to match noise_pred
            while len(tv.shape) < len(noise_pred.shape):
                tv = tv.unsqueeze(-1)

        if not self.do_partial_step:
            with torch.no_grad():
                target_0_1 = (tensors + 1) / 2  # 0 to 1
            # Full step: estimate x0 directly from the prediction.
            stepped_latents = noisy_latents - tv * noise_pred
        else:
            # Random partial step in [0.02, 0.05), clamped so we never pass t=0.
            step = torch.rand_like(tv) * 0.03 + 0.02
            next_step = torch.clamp(tv - step, min=0.0)
            stepped_latents = noisy_latents + (next_step - tv) * noise_pred

            with torch.no_grad():
                # Build the target by noising the clean latents to next_step.
                target_latents = batch.latents.to(sd.vae.device, dtype=sd.vae.dtype)
                target_latents = (1.0 - next_step) * target_latents + next_step * noise
                target_n1p1 = sd.decode_latents(target_latents)
                target_0_1 = (target_n1p1 + 1) / 2  # 0 to 1

        latents = stepped_latents.to(sd.vae.device, dtype=sd.vae.dtype)
        tensors_n1p1 = sd.decode_latents(latents)
        pred_images = (tensors_n1p1 + 1) / 2  # 0 to 1

        device = self.model.device
        dtype = self.model.dtype

        with torch.no_grad():
            target = self.get_pred(target_0_1)

        pred_images = pred_images.to(device, dtype=dtype)
        pred = self.get_pred(pred_images)

        loss = torch.nn.functional.mse_loss(
            pred.float(), target.float()
        ) * 100.0

        if self.do_partial_step:
            loss = loss * 10.0

        self.losses['loss'] = self.losses.get('loss', 0.0) + loss.item()
        # Count this step before checking, so every logging window averages
        # exactly `log_every` samples.
        self.step += 1
        if self.step % self.log_every == 0:
            print("DFE losses:")
            for key in self.losses:
                self.losses[key] /= self.log_every
                # print in 2.000e-01 format
                print(f" - {key}: {self.losses[key]:.3e}")
                self.losses[key] = 0.0

        return loss
|
||||
|
||||
def load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureExtractor:
|
||||
if model_path == "v3":
|
||||
dfe = DiffusionFeatureExtractor3(vae=vae)
|
||||
@@ -1045,6 +1197,10 @@ def load_dfe(model_path, vae=None, sd: 'BaseModel' = None) -> DiffusionFeatureEx
|
||||
dfe = DiffusionFeatureExtractor8(vae=vae, sd=sd)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if model_path == "v9":
|
||||
dfe = DiffusionFeatureExtractor9(vae=vae, sd=sd)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
# if it ends with safetensors
|
||||
|
||||
944
toolkit/models/sapiens2.py
Normal file
944
toolkit/models/sapiens2.py
Normal file
@@ -0,0 +1,944 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# Modified for AI Toolkit by Ostris
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# https://raw.githubusercontent.com/facebookresearch/sapiens2/refs/heads/main/sapiens/backbones/standalone/sapiens2.py
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
def to_2tuple(x):
    """Normalize ``x`` into a pair.

    Strings, bytes and any non-sequence value are duplicated into
    ``(x, x)``. A sequence of length 2 is returned as a tuple.

    Raises:
        ValueError: If ``x`` is a sequence whose length is not 2.
    """
    treat_as_scalar = isinstance(x, (str, bytes)) or not isinstance(x, Sequence)
    if treat_as_scalar:
        return (x, x)

    pair = tuple(x)
    if len(pair) != 2:
        raise ValueError("Expected scalar or length-2 iterable")
    return pair
|
||||
|
||||
|
||||
class RopePositionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
*,
|
||||
num_heads: int,
|
||||
base: float | None = 100.0,
|
||||
min_period: float | None = None,
|
||||
max_period: float | None = None,
|
||||
normalize_coords: Literal["min", "max", "separate"] = "separate",
|
||||
shift_coords: float | None = None,
|
||||
jitter_coords: float | None = None,
|
||||
rescale_coords: float | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
assert embed_dim % (4 * num_heads) == 0
|
||||
both_periods = min_period is not None and max_period is not None
|
||||
if (base is None and not both_periods) or (base is not None and both_periods):
|
||||
raise ValueError(
|
||||
"Either `base` or `min_period`+`max_period` must be provided."
|
||||
)
|
||||
|
||||
D_head = embed_dim // num_heads
|
||||
self.base = base
|
||||
self.min_period = min_period
|
||||
self.max_period = max_period
|
||||
self.D_head = D_head
|
||||
self.normalize_coords = normalize_coords
|
||||
self.shift_coords = shift_coords
|
||||
self.jitter_coords = jitter_coords
|
||||
self.rescale_coords = rescale_coords
|
||||
|
||||
# Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
|
||||
self.dtype = dtype or torch.float32 # Don't rely on self.periods.dtype
|
||||
self.register_buffer(
|
||||
"periods",
|
||||
torch.empty(D_head // 4, device=device, dtype=self.dtype),
|
||||
persistent=True,
|
||||
)
|
||||
self._init_weights()
|
||||
|
||||
def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
|
||||
device = self.periods.device
|
||||
dtype = self.dtype
|
||||
dd = {"device": device, "dtype": dtype}
|
||||
# Prepare coords in range [-1, +1]
|
||||
if self.normalize_coords == "max":
|
||||
max_HW = max(H, W)
|
||||
coords_h = torch.arange(0.5, H, **dd) / max_HW # [H]
|
||||
coords_w = torch.arange(0.5, W, **dd) / max_HW # [W]
|
||||
elif self.normalize_coords == "min":
|
||||
min_HW = min(H, W)
|
||||
coords_h = torch.arange(0.5, H, **dd) / min_HW # [H]
|
||||
coords_w = torch.arange(0.5, W, **dd) / min_HW # [W]
|
||||
elif self.normalize_coords == "separate":
|
||||
coords_h = torch.arange(0.5, H, **dd) / H # [H]
|
||||
coords_w = torch.arange(0.5, W, **dd) / W # [W]
|
||||
else:
|
||||
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
|
||||
coords = torch.stack(
|
||||
torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1
|
||||
) # [H, W, 2]
|
||||
coords = coords.flatten(0, 1) # [HW, 2]
|
||||
coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1]
|
||||
|
||||
# Shift coords by adding a uniform value in [-shift, shift]
|
||||
if self.training and self.shift_coords is not None:
|
||||
shift_hw = torch.empty(2, **dd).uniform_(
|
||||
-self.shift_coords, self.shift_coords
|
||||
)
|
||||
coords += shift_hw[None, :]
|
||||
|
||||
# Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
|
||||
if self.training and self.jitter_coords is not None:
|
||||
jitter_max = np.log(self.jitter_coords)
|
||||
jitter_min = -jitter_max
|
||||
jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
|
||||
coords *= jitter_hw[None, :]
|
||||
|
||||
# Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
|
||||
if self.training and self.rescale_coords is not None:
|
||||
rescale_max = np.log(self.rescale_coords)
|
||||
rescale_min = -rescale_max
|
||||
rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
|
||||
coords *= rescale_hw
|
||||
|
||||
# Prepare angles and sin/cos
|
||||
angles = (
|
||||
2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
|
||||
) # [HW, 2, D//4]
|
||||
angles = angles.flatten(1, 2) # [HW, D//2]
|
||||
angles = angles.tile(2) # [HW, D]
|
||||
cos = torch.cos(angles) # [HW, D]
|
||||
sin = torch.sin(angles) # [HW, D]
|
||||
|
||||
return (sin, cos) # 2 * [HW, D]
|
||||
|
||||
def _init_weights(self):
|
||||
device = self.periods.device
|
||||
dtype = self.dtype
|
||||
if self.base is not None:
|
||||
periods = self.base ** (
|
||||
2
|
||||
* torch.arange(self.D_head // 4, device=device, dtype=dtype)
|
||||
/ (self.D_head // 2)
|
||||
) # [D//4]
|
||||
else:
|
||||
base = self.max_period / self.min_period
|
||||
exponents = torch.linspace(
|
||||
0, 1, self.D_head // 4, device=device, dtype=dtype
|
||||
) # [D//4] range [0, 1]
|
||||
periods = base**exponents # range [1, max_period / min_period]
|
||||
periods = periods / base # range [min_period / max_period, 1]
|
||||
periods = periods * self.max_period # range [min_period, max_period]
|
||||
self.periods.data = periods
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
class Tokenizer(nn.Module):
    """Windowed self-attention pooling that emits one token per window.

    Tokens are grouped into non-overlapping ``window_size`` x
    ``window_size`` windows. A shared CLS token is prepended to each
    window, the window is run through a stack of
    ``TransformerEncoderLayer2`` blocks, and the CLS output becomes the
    pooled token for that window.
    """

    def __init__(
        self,
        embed_dims: int,
        window_size: int = 4,
        num_heads: int = 4,
        num_tokenizer_layers: int = 1,
        qkv_bias: bool = True,
        use_qk_norm: bool = False,
        chunk_size: int = 1024,  # max windows per chunk
    ):
        """Create the positional embedding, CLS token and encoder blocks.

        Args:
            embed_dims: Channel width of the tokens.
            window_size: Side length of each pooling window.
            num_heads: Attention heads per block.
            num_tokenizer_layers: Number of stacked encoder blocks.
            qkv_bias: Whether the Q/K/V projections have a bias.
            use_qk_norm: Whether the blocks normalize Q and K.
            chunk_size: Maximum number of windows processed at once.
        """
        super().__init__()
        self.ws = window_size
        self.chunk_size = chunk_size

        # local absolute positional embeddings for [CLS] + patch tokens
        self.local_pos_embed = nn.Parameter(
            torch.zeros(1, 1 + window_size * window_size, embed_dims)
        )
        trunc_normal_(self.local_pos_embed, std=0.02)

        # build N identical TransformerEncoderLayer blocks
        self.blocks = nn.ModuleList(
            [
                TransformerEncoderLayer2(
                    embed_dims=embed_dims,
                    num_heads=num_heads,
                    feedforward_channels=embed_dims * 4,  # standard FFN size
                    qkv_bias=qkv_bias,
                    use_qk_norm=use_qk_norm,
                )
                for _ in range(num_tokenizer_layers)
            ]
        )

        # shared CLS token for pooling
        self.w_cls = nn.Parameter(torch.zeros(1, 1, embed_dims))
        trunc_normal_(self.w_cls, std=0.02)
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        hw: Tuple[int, int],
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Pool each window of the token grid into a single token.

        Args:
            x: Tokens of shape (B, N, C) with N = H * W.
            hw: Grid size (H, W) before reduction; both must be divisible
                by the window size.

        Returns:
            Tuple of the pooled tokens, shape (B, (H/ws)*(W/ws), C), and
            the reduced grid size (H/ws, W/ws).
        """
        B, N, C = x.shape
        H, W = hw
        ws = self.ws
        assert H % ws == 0 and W % ws == 0, (
            f"Image size {H}×{W} must be divisible by window {ws}."
        )

        # reshape tokens -> non-overlapping windows
        x = x.view(B, H, W, C)

        ph, pw = H // ws, W // ws  ## ints in eager mode
        ph, pw = int(ph), int(pw)  ## ints in scripting mode
        x = x.view(B, ph, ws, pw, ws, C)  # B, H/ws, ws, W/ws, ws, C
        x = x.permute(0, 1, 3, 2, 4, 5)  # B, H/ws, W/ws, ws, ws, C
        x = x.contiguous().view(B * ph * pw, ws * ws, C)  # (B*H/ws*W/ws), ws², C

        total_windows = x.size(0)
        # Keep the step at least 1: range() rejects a zero step, which an
        # empty batch (zero windows) would otherwise produce.
        chunk_size = max(1, int(min(self.chunk_size, total_windows)))
        token_out = x.new_empty(total_windows, C)

        use_ckpt = torch.is_grad_enabled() and self.gradient_checkpointing

        def _run_blocks(t: torch.Tensor) -> torch.Tensor:
            for blk in self.blocks:
                t = blk(t)
            return t

        for i in range(0, total_windows, chunk_size):
            chunk = x[i : i + chunk_size]  # (m, ws², C)
            m = chunk.size(0)
            cls = self.w_cls.expand(m, -1, -1)  # (m, 1, C)
            chunk = torch.cat([cls, chunk], dim=1)  # (m, 1+ws², C)
            chunk = chunk + self.local_pos_embed  # add local PE

            if use_ckpt:
                chunk = checkpoint(_run_blocks, chunk, use_reentrant=False)
            else:
                chunk = _run_blocks(chunk)

            token_out[i : i + m] = chunk[:, 0]  # take CLS out

        token = token_out.view(B, ph * pw, C)  # (B, (H/ws)*(W/ws), C)
        return token, (ph, pw)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
num_kv_heads=None,
|
||||
input_dims=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
proj_bias=True,
|
||||
use_qk_norm=True,
|
||||
v_shortcut=False,
|
||||
layer_scale_init_value=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
# Core dims
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads or num_heads
|
||||
assert self.num_heads % self.num_kv_heads == 0, (
|
||||
"num_kv_heads must divide num_heads"
|
||||
)
|
||||
self.head_dim = embed_dims // num_heads
|
||||
self.input_dims = input_dims or embed_dims
|
||||
# Features
|
||||
self.attn_drop = attn_drop
|
||||
self.v_shortcut = v_shortcut
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
# Attention operation selection
|
||||
if qk_scale is not None:
|
||||
scale = qk_scale
|
||||
else:
|
||||
scale = self.head_dim**-0.5
|
||||
|
||||
assert qk_scale is None, "qk_scale is not supported"
|
||||
self.attn_op = F.scaled_dot_product_attention
|
||||
|
||||
# Q/K/V projections
|
||||
self.wq = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
|
||||
self.wk = nn.Linear(
|
||||
self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias
|
||||
)
|
||||
self.wv = nn.Linear(
|
||||
self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias
|
||||
)
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-6)
|
||||
self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-6)
|
||||
|
||||
# Output projection + dropout
|
||||
self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
# Optional LayerScale
|
||||
if layer_scale_init_value > 0:
|
||||
self.gamma = LayerScale(embed_dims, scale=layer_scale_init_value)
|
||||
else:
|
||||
self.gamma = nn.Identity()
|
||||
|
||||
def apply_rope(
|
||||
self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
# All operations will use the dtype of rope, the output is cast back to the dtype of q and k
|
||||
q_dtype = q.dtype
|
||||
k_dtype = k.dtype
|
||||
sin, cos = rope
|
||||
rope_dtype = sin.dtype
|
||||
q = q.to(dtype=rope_dtype)
|
||||
k = k.to(dtype=rope_dtype)
|
||||
N = q.shape[-2]
|
||||
prefix = N - sin.shape[-2] ## extra tokens
|
||||
assert prefix >= 0
|
||||
q_prefix = q[:, :, :prefix, :]
|
||||
q = self._rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
|
||||
q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head]
|
||||
k_prefix = k[:, :, :prefix, :]
|
||||
k = self._rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
|
||||
k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head]
|
||||
q = q.to(dtype=q_dtype)
|
||||
k = k.to(dtype=k_dtype)
|
||||
return q, k
|
||||
|
||||
def _rope_rotate_half(self, x: Tensor) -> Tensor:
|
||||
# x: [ x0 x1 x2 x3 x4 x5]
|
||||
# out: [-x3 -x4 -x5 x0 x1 x2]
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat([-x2, x1], dim=-1)
|
||||
|
||||
def _rope_apply(self, x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
|
||||
# x: [..., D], eg [x0, x1, x2, x3, x4, x5]
|
||||
# sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
|
||||
# cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
|
||||
return (x * cos) + (self._rope_rotate_half(x) * sin)
|
||||
|
||||
def forward(self, x, rope=None):
|
||||
B, N, _ = x.shape
|
||||
# Q: (B, N, num_heads, head_dim)
|
||||
q = self.wq(x).view(B, N, self.num_heads, self.head_dim)
|
||||
# K/V: (B, N, num_kv_heads, head_dim)
|
||||
k = self.wk(x).view(B, N, self.num_kv_heads, self.head_dim)
|
||||
v = self.wv(x).view(B, N, self.num_kv_heads, self.head_dim)
|
||||
|
||||
# (B, heads, N, head_dim)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
if self.use_qk_norm:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Repeat KV heads if group ratio >1
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
factor = self.num_heads // self.num_kv_heads
|
||||
k = k.repeat_interleave(factor, dim=1)
|
||||
v = v.repeat_interleave(factor, dim=1)
|
||||
|
||||
if rope is not None:
|
||||
q, k = self.apply_rope(q, k, rope)
|
||||
|
||||
# Scaled dot-product attention
|
||||
attn_out = self.attn_op(
|
||||
q, k, v, dropout_p=self.attn_drop if self.training else 0.0
|
||||
) # (B, num_heads, N, head_dim)
|
||||
|
||||
# Merge heads -> (B, N, embed_dims)
|
||||
out = attn_out.permute(0, 2, 1, 3).reshape(B, N, self.embed_dims)
|
||||
|
||||
# Output projection + drop + layer scale
|
||||
out = self.proj(out)
|
||||
out = self.gamma(self.proj_drop(out))
|
||||
|
||||
# Optional V-shortcut (only when MQA)
|
||||
if self.v_shortcut and self.num_kv_heads == 1:
|
||||
raise NotImplementedError
|
||||
return out
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
class TransformerEncoderLayer2(nn.Module):
    """Pre-norm transformer block: RMSNorm + attention, then RMSNorm + SwiGLU FFN."""

    def __init__(
        self,
        embed_dims,
        num_heads,
        num_kv_heads=None,
        feedforward_channels=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        layer_scale_init_value=0.0,
        use_qk_norm=True,
        qkv_bias=True,
    ):
        """Assemble the two norms, the attention module and the FFN.

        Args:
            embed_dims: Token width.
            num_heads: Number of query heads.
            num_kv_heads: Number of K/V heads, forwarded to the attention.
            feedforward_channels: Hidden width handed to the FFN.
            drop_rate: Dropout after the attention output projection.
            attn_drop_rate: Dropout applied inside attention.
            layer_scale_init_value: Forwarded to the attention module.
            use_qk_norm: Forwarded to the attention module.
            qkv_bias: Forwarded to the attention module.
        """
        super().__init__()

        self.embed_dims = embed_dims

        # Attention sub-layer
        self.ln1 = nn.RMSNorm(embed_dims, eps=1e-6)
        attn_kwargs = dict(
            embed_dims=embed_dims,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            qkv_bias=qkv_bias,
            layer_scale_init_value=layer_scale_init_value,
            use_qk_norm=use_qk_norm,
        )
        self.attn = GroupedQueryAttention(**attn_kwargs)

        # Feed-forward sub-layer
        self.ln2 = nn.RMSNorm(embed_dims, eps=1e-6)
        self.ffn = SwiGLUFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
        )

    @property
    def norm1(self):
        """Alias for the first norm layer (``ln1``)."""
        return self.ln1

    @property
    def norm2(self):
        """Alias for the second norm layer (``ln2``)."""
        return self.ln2

    def forward(self, x, rope=None):
        """Run attention with a residual add, then the FFN.

        The FFN receives the un-normalized tensor as ``identity``;
        presumably it adds the residual itself — confirm against SwiGLUFFN.
        """
        attended = self.attn(self.ln1(x), rope=rope)
        x = x + attended
        return self.ffn(self.ln2(x), identity=x)
|
||||
|
||||
|
||||
##-----------------------------------
|
||||
class Sapiens2(nn.Module):
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(
|
||||
["sapiens2_0.1b"],
|
||||
{
|
||||
"embed_dims": 768,
|
||||
"num_layers": 12,
|
||||
"num_heads": 12,
|
||||
"feedforward_channels": 768 * 4,
|
||||
"num_tokenizer_layers": 2,
|
||||
},
|
||||
),
|
||||
**dict.fromkeys(
|
||||
["sapiens2_0.4b"],
|
||||
{
|
||||
"embed_dims": 1024,
|
||||
"num_layers": 24,
|
||||
"num_heads": 16,
|
||||
"feedforward_channels": 1024 * 4,
|
||||
"num_tokenizer_layers": 2,
|
||||
},
|
||||
),
|
||||
**dict.fromkeys(
|
||||
["sapiens2_0.8b"],
|
||||
{
|
||||
"embed_dims": 1280,
|
||||
"num_layers": 32,
|
||||
"num_heads": 16,
|
||||
"feedforward_channels": 1280 * 4,
|
||||
"num_tokenizer_layers": 3,
|
||||
},
|
||||
),
|
||||
**dict.fromkeys(
|
||||
["sapiens2_1b"],
|
||||
{
|
||||
"embed_dims": 1536,
|
||||
"num_layers": 40,
|
||||
"num_heads": 24,
|
||||
"feedforward_channels": 1536 * 4,
|
||||
"num_tokenizer_layers": 4,
|
||||
},
|
||||
),
|
||||
**dict.fromkeys(
|
||||
["sapiens2_5b"],
|
||||
{
|
||||
"embed_dims": 2432,
|
||||
"num_layers": 56,
|
||||
"num_heads": 32,
|
||||
"feedforward_channels": 2432 * 4,
|
||||
"num_tokenizer_layers": 6,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
num_extra_tokens = 1 # class token
|
||||
OUT_TYPES = {"raw", "cls_token", "featmap"}
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arch="sapiens2_1b",
|
||||
img_size=(1024, 768),
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
out_indices=-1,
|
||||
drop_rate=0.0,
|
||||
window_size=4,
|
||||
use_tokenizer=False, ## 4k resolution
|
||||
use_qk_norm=True,
|
||||
qkv_bias=True,
|
||||
final_norm=True,
|
||||
out_type="raw",
|
||||
with_cls_token=True,
|
||||
layer_scale_init_value=1e-4, ## non zero init to activate layerscale
|
||||
frozen_stages=-1,
|
||||
patch_cfg=dict(),
|
||||
layer_cfgs=dict(),
|
||||
pos_embed_rope_base: float = 100.0,
|
||||
pos_embed_rope_min_period: float | None = None,
|
||||
pos_embed_rope_max_period: float | None = None,
|
||||
pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
|
||||
pos_embed_rope_shift_coords: float | None = None,
|
||||
pos_embed_rope_jitter_coords: float | None = None,
|
||||
pos_embed_rope_rescale_coords: float | None = None,
|
||||
pos_embed_rope_dtype: str = "bf16",
|
||||
n_storage_tokens: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), (
|
||||
f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
|
||||
)
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
|
||||
self.embed_dims = self.arch_settings["embed_dims"]
|
||||
self.num_layers = self.arch_settings["num_layers"]
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.window_size = window_size
|
||||
img_size = to_2tuple(img_size)
|
||||
encoder_img_size = (
|
||||
(img_size[0] // window_size, img_size[1] // window_size)
|
||||
if use_tokenizer
|
||||
else img_size
|
||||
)
|
||||
self.img_size = to_2tuple(encoder_img_size)
|
||||
|
||||
# Set patch embedding
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
input_size=self.img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=True,
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
self.rope_embed = RopePositionEmbedding(
|
||||
embed_dim=self.embed_dims,
|
||||
num_heads=self.arch_settings["num_heads"],
|
||||
base=pos_embed_rope_base,
|
||||
min_period=pos_embed_rope_min_period,
|
||||
max_period=pos_embed_rope_max_period,
|
||||
normalize_coords=pos_embed_rope_normalize_coords,
|
||||
shift_coords=pos_embed_rope_shift_coords,
|
||||
jitter_coords=pos_embed_rope_jitter_coords,
|
||||
rescale_coords=pos_embed_rope_rescale_coords,
|
||||
dtype=torch.bfloat16 if pos_embed_rope_dtype == "bf16" else torch.float32,
|
||||
)
|
||||
|
||||
# Set out type
|
||||
if out_type not in self.OUT_TYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported `out_type` {out_type}, please "
|
||||
f"choose from {self.OUT_TYPES}"
|
||||
)
|
||||
self.out_type = out_type
|
||||
|
||||
if use_tokenizer == True:
|
||||
self.tokenizer = Tokenizer(
|
||||
embed_dims=self.embed_dims,
|
||||
window_size=self.window_size,
|
||||
num_heads=self.arch_settings["num_heads"],
|
||||
num_tokenizer_layers=self.arch_settings["num_tokenizer_layers"],
|
||||
qkv_bias=True,
|
||||
use_qk_norm=False,
|
||||
)
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
# Set cls + storage tokens
|
||||
self.with_cls_token = with_cls_token
|
||||
if with_cls_token:
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
elif out_type != "cls_token":
|
||||
self.cls_token = None
|
||||
self.num_extra_tokens = 0
|
||||
else:
|
||||
raise ValueError('with_cls_token must be True when `out_type="cls_token"`.')
|
||||
|
||||
## registers
|
||||
self.n_storage_tokens = int(n_storage_tokens)
|
||||
self.storage_tokens = (
|
||||
nn.Parameter(torch.zeros(1, self.n_storage_tokens, self.embed_dims))
|
||||
if self.n_storage_tokens > 0
|
||||
else None
|
||||
)
|
||||
# how many non-patch tokens are at the front
|
||||
self.num_extra_tokens = (
|
||||
1 if self.cls_token is not None else 0
|
||||
) + self.n_storage_tokens
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), (
|
||||
f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.'
|
||||
)
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = self.num_layers + index
|
||||
assert 0 <= out_indices[i] <= self.num_layers, (
|
||||
f"Invalid out_indices {index}"
|
||||
)
|
||||
self.out_indices = out_indices
|
||||
|
||||
self.blocks = nn.Sequential()
|
||||
if isinstance(layer_cfgs, dict):
|
||||
layer_cfgs = [layer_cfgs] * self.num_layers
|
||||
|
||||
mhsa_early, mhsa_late = 8, 8
|
||||
for i in range(self.num_layers):
|
||||
if i < mhsa_early or i >= self.num_layers - mhsa_late:
|
||||
num_kv_heads = None ## use MHSA
|
||||
else:
|
||||
num_kv_heads = self.arch_settings["num_heads"] // 2 # Use GQA
|
||||
|
||||
_layer_cfg = dict(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.arch_settings["num_heads"],
|
||||
num_kv_heads=num_kv_heads,
|
||||
feedforward_channels=self.arch_settings["feedforward_channels"],
|
||||
use_qk_norm=use_qk_norm,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
drop_rate=drop_rate,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.blocks.append(TransformerEncoderLayer2(**_layer_cfg))
|
||||
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6)
|
||||
|
||||
# freeze stages only when self.frozen_stages > 0
|
||||
if self.frozen_stages > 0:
|
||||
self._freeze_stages()
|
||||
|
||||
## load init weights
|
||||
self.init_weights()
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def enable_gradient_checkpointing(self, enable=True):
|
||||
self.gradient_checkpointing = enable
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.gradient_checkpointing = enable
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def init_weights(self):
|
||||
# Initialize class token and storage token embeddings
|
||||
if self.with_cls_token:
|
||||
trunc_normal_(self.cls_token, std=0.02)
|
||||
|
||||
if self.storage_tokens is not None:
|
||||
trunc_normal_(self.storage_tokens, std=0.02)
|
||||
|
||||
# Apply custom initialization to all submodules
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
# Use a truncated normal distribution for linear layer weights
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)):
|
||||
# Initialize normalization layers to act as an identity function
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
if hasattr(m, "weight") and m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
# Initialize conv layer weights like linear layers
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _freeze_stages(self):
|
||||
## freeze tokenizer
|
||||
if self.frozen_stages >= 1 and self.tokenizer is not None:
|
||||
self.tokenizer.eval()
|
||||
for param in self.tokenizer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# freeze patch embedding
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
# freeze cls_token
|
||||
if self.cls_token is not None:
|
||||
self.cls_token.requires_grad = False
|
||||
if self.storage_tokens is not None:
|
||||
self.storage_tokens.requires_grad = False
|
||||
# freeze layers
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = self.blocks[i - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# freeze the last layer norm
|
||||
if self.frozen_stages == len(self.blocks):
|
||||
if self.final_norm:
|
||||
self.ln1.eval()
|
||||
for param in self.ln1.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
    """Run the backbone and collect the requested intermediate outputs.

    Steps: patch-embed the input, optionally pass it through the tokenizer,
    prepend the class/storage tokens that exist, then run every block in
    ``self.blocks``. ``self.ln1`` is applied after the last block when
    ``self.final_norm`` is truthy.

    Args:
        x: Input batch; its first dimension is used as the batch size.
            # NOTE(review): presumably a (B, C, H, W) image tensor, confirm
            # against the patch embedding in use.

    Returns:
        A tuple with one entry per block index contained in
        ``self.out_indices``, each produced by ``self._format_output``.
    """
    batch_size = x.shape[0]

    # The patch embedding returns the token sequence and the patch grid size.
    x, patch_resolution = self.patch_embed(x)
    if self.tokenizer is not None:
        x, patch_resolution = self.tokenizer(x, patch_resolution)

    # Prepend [CLS] and storage tokens, broadcast over the batch.
    prefix = []
    if self.cls_token is not None:
        prefix.append(self.cls_token.expand(batch_size, -1, -1))
    if self.storage_tokens is not None:
        prefix.append(self.storage_tokens.expand(batch_size, -1, -1))
    if prefix:
        x = torch.cat(prefix + [x], dim=1)

    rope_sincos = self.rope_embed(H=patch_resolution[0], W=patch_resolution[1])

    # Checkpointing is only active while training.
    use_checkpoint = self.gradient_checkpointing and self.training
    last_idx = len(self.blocks) - 1

    outs = []
    for i, layer in enumerate(self.blocks):
        if use_checkpoint:
            # ``rope`` is passed by keyword so this path binds arguments
            # exactly like the plain call below (non-reentrant checkpoint
            # forwards keyword arguments to the wrapped callable).
            x = checkpoint(layer, x, rope=rope_sincos, use_reentrant=False)
        else:
            x = layer(x, rope=rope_sincos)

        if i == last_idx and self.final_norm:
            x = self.ln1(x)

        if i in self.out_indices:
            outs.append(self._format_output(x, patch_resolution))

    return tuple(outs)
|
||||
|
||||
def _format_output(self, x, hw):
|
||||
if self.out_type == "raw":
|
||||
return x
|
||||
if self.out_type == "cls_token":
|
||||
return x[:, 0]
|
||||
|
||||
patch_token = x[:, self.num_extra_tokens :]
|
||||
if self.out_type == "featmap":
|
||||
B = x.size(0)
|
||||
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
|
||||
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
|
||||
|
||||
@property
def norm1(self):
    """Alias exposing ``self.ln1`` under the name ``norm1``."""
    return self.ln1
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a tensor.

    Holds one weight per channel, initialised to ``scale``, and multiplies
    the input by it along the channel axis.

    Args:
        dim: Number of channels (length of the weight vector).
        inplace: When ``True`` the input is scaled in place with ``mul_``.
        data_format: ``"channels_last"`` (channel is the last axis) or
            ``"channels_first"`` (channel is axis 1).
        scale: Initial value of every weight.
    """

    def __init__(
        self,
        dim: int,
        inplace: bool = False,
        data_format: str = "channels_last",
        scale: float = 1e-5,
    ):
        super().__init__()
        assert data_format in (
            "channels_last",
            "channels_first",
        ), "'data_format' could only be channels_last or channels_first."
        self.inplace = inplace
        self.data_format = data_format
        self.weight = nn.Parameter(torch.ones(dim) * scale)

    def forward(self, x) -> torch.Tensor:
        """Return ``x`` multiplied by the per-channel weight."""
        # Build a view shape that broadcasts the weight along the channel
        # axis only: (1, C, 1, ...) or (1, ..., 1, C).
        if self.data_format == "channels_first":
            shape = (1, -1) + (1,) * (x.dim() - 2)
        else:
            shape = (1,) * (x.dim() - 1) + (-1,)
        weight = self.weight.view(*shape)

        if self.inplace:
            return x.mul_(weight)
        return x * weight
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
class PatchEmbed(nn.Module):
    """Project an image into a sequence of patch tokens with one ``Conv2d``.

    Args:
        in_channels: Number of input channels.
        embed_dims: Number of output channels of the projection.
        kernel_size: Convolution kernel size (int or pair).
        stride: Convolution stride; falls back to ``kernel_size`` when
            ``None``.
        padding: Accepted for interface compatibility but currently ignored:
            the value is overwritten and the projection always uses zero
            padding.
        dilation: Convolution dilation (int or pair).
        bias: Whether the projection has a bias term.
        input_size: Optional expected input size (int or pair). When given,
            ``init_input_size`` and the resulting ``init_out_size`` are
            recorded; otherwise both are ``None``.
    """

    def __init__(
        self,
        in_channels=3,
        embed_dims=768,
        kernel_size=16,
        stride=16,
        padding="corner",
        dilation=1,
        bias=True,
        input_size=None,
    ):
        super().__init__()

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)
        # NOTE: the caller-supplied ``padding`` is deliberately discarded
        # here; the convolution is always built without padding.
        padding = to_2tuple(0)

        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        if input_size:
            input_size = to_2tuple(input_size)
            self.init_input_size = input_size
            # Standard convolution output-size formula, per spatial axis.
            h_out, w_out = (
                (input_size[axis] + 2 * padding[axis]
                 - dilation[axis] * (kernel_size[axis] - 1) - 1)
                // stride[axis] + 1
                for axis in (0, 1)
            )
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        """Embed ``x`` and return ``(tokens, (H_out, W_out))``.

        The projection output is flattened over its spatial axes and
        transposed, so tokens are indexed as ``(batch, H_out * W_out,
        channels)``.
        """
        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        tokens = x.flatten(2).transpose(1, 2)
        return tokens, out_size
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py

    Computes ``w3(silu(x1) * x2)`` where ``x1, x2`` are the two halves of
    ``w12(x)``, optionally rescales the result with a ``LayerScale`` and adds
    a residual connection.

    Args:
        embed_dims: Input feature width.
        feedforward_channels: Hidden width; defaults to ``embed_dims``.
        out_dims: Output feature width; defaults to ``embed_dims``.
        layer_scale_init_value: When ``> 0`` the output is scaled by a
            ``LayerScale`` initialised to this value; otherwise no scaling.
        bias: Whether both linear layers have a bias term.
        add_identity: Whether to add the residual connection.
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.0,
        bias: bool = True,
        add_identity: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dims = embed_dims
        self.out_dims = out_dims or embed_dims
        hidden_dims = feedforward_channels or embed_dims

        # One fused projection produces both the gate and the value halves.
        self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)
        self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)

        if layer_scale_init_value > 0:
            # The scale is applied to the output of ``w3``, so it must be
            # sized by ``out_dims`` (identical to ``embed_dims`` by default).
            self.gamma2 = LayerScale(dim=self.out_dims, scale=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

        self.add_identity = add_identity

    def forward(
        self, x: torch.Tensor, identity: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the FFN.

        Args:
            x: Input tensor whose last axis has ``embed_dims`` features.
            identity: Tensor used for the residual; defaults to ``x``.

        Returns:
            The FFN output, plus ``identity`` when the residual applies
            (``add_identity`` is set and ``out_dims == embed_dims``).
        """
        gate, value = self.w12(x).chunk(2, dim=-1)
        out = self.gamma2(self.w3(F.silu(gate) * value))

        # No residual when the widths differ or the caller disabled it.
        if self.out_dims != self.embed_dims or not self.add_identity:
            return out

        residual = x if identity is None else identity
        return residual + out
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Per-channel (R, G, B) statistics used for ImageNet normalization.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def imagenet_normalize(tensors_0_1: torch.Tensor) -> torch.Tensor:
    """Apply ImageNet normalization to a (B, C, H, W) RGB tensor in [0, 1].

    The mean and std are created with the input's dtype and device and
    broadcast over the batch and spatial axes.

    Args:
        tensors_0_1: Image batch with 3 channels on axis 1.

    Returns:
        ``(tensors_0_1 - mean) / std`` as a new tensor.
    """

    def _stat(values):
        # Shape (1, 3, 1, 1) so it broadcasts against (B, 3, H, W).
        return torch.as_tensor(
            values, dtype=tensors_0_1.dtype, device=tensors_0_1.device
        ).view(1, 3, 1, 1)

    return (tensors_0_1 - _stat(_IMAGENET_MEAN)) / _stat(_IMAGENET_STD)
|
||||
Reference in New Issue
Block a user