mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Add a dino version of DFE
This commit is contained in:
@@ -5,7 +5,7 @@ from torch import nn
|
||||
from safetensors.torch import load_file
|
||||
import torch.nn.functional as F
|
||||
from diffusers import AutoencoderTiny
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
from transformers import AutoImageProcessor, AutoModel, SiglipImageProcessor, SiglipVisionModel
|
||||
import lpips
|
||||
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
@@ -644,6 +644,167 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
|
||||
# return stepped_latents, predicted_images
|
||||
return predicted_images
|
||||
|
||||
|
||||
class DiffusionFeatureExtractor6(nn.Module):
|
||||
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
|
||||
super().__init__()
|
||||
self.version = 6
|
||||
if vae is None:
|
||||
raise ValueError("vae must be provided for DFE4")
|
||||
self.vae = vae
|
||||
# pretrained_model_name = "facebook/dinov3-vits16-pretrain-lvd1689m"
|
||||
# pretrained_model_name = "facebook/dinov3-vitl16-pretrain-lvd1689m"
|
||||
pretrained_model_name = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
|
||||
# pretrained_model_name = "facebook/dinov3-vit7b16-pretrain-lvd1689m"
|
||||
self.processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
|
||||
self.model = AutoModel.from_pretrained(
|
||||
pretrained_model_name,
|
||||
device_map=device,
|
||||
dtype=dtype,
|
||||
).to(device, dtype=dtype)
|
||||
|
||||
self.losses = {}
|
||||
self.log_every = 100
|
||||
self.step = 0
|
||||
|
||||
def prepare_inputs(self, tensor_0_1: torch.Tensor):
|
||||
"""
|
||||
tensor_0_1: (bs, 3, h, w), float, values in [0, 1]
|
||||
returns: {"pixel_values": (bs, 3, H, W)} ready for the vision transformer
|
||||
"""
|
||||
|
||||
if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3:
|
||||
raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}")
|
||||
|
||||
x = tensor_0_1
|
||||
if not torch.is_floating_point(x):
|
||||
x = x.float()
|
||||
|
||||
# Resize
|
||||
# if not divisible by 16 or total pixels > max_res*max_res, resize to fit within 16 patches
|
||||
max_res = 512
|
||||
p = 16
|
||||
if (x.shape[-1] % p != 0) or (x.shape[-2] % p != 0) or (x.shape[-1] * x.shape[-2] > max_res * max_res):
|
||||
target_h = x.shape[-2]
|
||||
target_w = x.shape[-1]
|
||||
if x.shape[-1] * target_h > max_res * max_res:
|
||||
scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h))
|
||||
target_h = int(target_h * scale_factor)
|
||||
target_w = int(target_w * scale_factor)
|
||||
target_h = (target_h // p) * p
|
||||
target_w = (target_w // p) * p
|
||||
x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False)
|
||||
|
||||
# Rescale (HF processors usually assume uint8 0..255 inputs; your inputs are already 0..1)
|
||||
if self.processor.do_rescale:
|
||||
# If it looks like [0..1], skip to avoid double-scaling.
|
||||
# If user accidentally passed 0..255 floats, this will fix it.
|
||||
if x.detach().max().item() > 1.0 + 1e-6:
|
||||
x = x * float(self.processor.rescale_factor or 1.0 / 255.0)
|
||||
|
||||
# Normalize
|
||||
if self.processor.do_normalize:
|
||||
mean = torch.tensor(self.processor.image_mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
std = torch.tensor(self.processor.image_std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
x = (x - mean) / std
|
||||
|
||||
return {"pixel_values": x}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
noise,
|
||||
noise_pred,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
batch: DataLoaderBatchDTO,
|
||||
scheduler: CustomFlowMatchEulerDiscreteScheduler,
|
||||
model=None
|
||||
):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
tensors = batch.tensor.to(device, dtype=dtype)
|
||||
is_video = False
|
||||
# stack time for video models on the batch dimension
|
||||
if len(noise_pred.shape) == 5:
|
||||
# B, C, T, H, W = images.shape
|
||||
# only take first time
|
||||
noise = noise[:, :, 0, :, :]
|
||||
noise_pred = noise_pred[:, :, 0, :, :]
|
||||
noisy_latents = noisy_latents[:, :, 0, :, :]
|
||||
is_video = True
|
||||
|
||||
if len(tensors.shape) == 5:
|
||||
# batch is different
|
||||
# (B, T, C, H, W)
|
||||
# only take first time
|
||||
tensors = tensors[:, 0, :, :, :]
|
||||
|
||||
with torch.no_grad():
|
||||
tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
|
||||
# expand shape to match noise_pred
|
||||
while len(tv.shape) < len(noise_pred.shape):
|
||||
tv = tv.unsqueeze(-1)
|
||||
# min 0.001
|
||||
tv = torch.clamp(tv, min=0.001)
|
||||
|
||||
# step latent
|
||||
x0 = noisy_latents - tv * noise_pred
|
||||
|
||||
stepped_latents = x0
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
|
||||
shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
|
||||
latents = (latents / scaling_factor) + shift_factor
|
||||
if is_video:
|
||||
# if video, we need to unsqueeze the latents to match the vae input shape
|
||||
latents = latents.unsqueeze(2)
|
||||
tensors_n1p1 = self.vae.decode(latents) # -1 to 1
|
||||
if hasattr(tensors_n1p1, 'sample'):
|
||||
tensors_n1p1 = tensors_n1p1.sample
|
||||
|
||||
if is_video:
|
||||
# if video, we need to squeeze the tensors to match the output shape
|
||||
tensors_n1p1 = tensors_n1p1.squeeze(2)
|
||||
|
||||
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
|
||||
|
||||
with torch.no_grad():
|
||||
target_img = tensors.to(device, dtype=dtype)
|
||||
# go from -1 to 1 to 0 to 1
|
||||
target_img = (target_img + 1) / 2
|
||||
target_dino_input = self.prepare_inputs(target_img)
|
||||
target_dino_output = self.model(**target_dino_input, output_hidden_states=True)['hidden_states'][-1].detach()
|
||||
# normalize
|
||||
target_dino_output = (target_dino_output - target_dino_output.mean()) / (target_dino_output.std() + 1e-6)
|
||||
pred_dino_input = self.prepare_inputs(pred_images)
|
||||
pred_dino_output = self.model(**pred_dino_input, output_hidden_states=True)['hidden_states'][-1]
|
||||
# normalize
|
||||
pred_dino_output = (pred_dino_output - pred_dino_output.mean()) / (pred_dino_output.std() + 1e-6)
|
||||
dino_loss = torch.nn.functional.mse_loss(
|
||||
pred_dino_output.float(), target_dino_output.float()
|
||||
)
|
||||
|
||||
if 'dinov3' not in self.losses:
|
||||
self.losses['dinov3'] = dino_loss.item()
|
||||
else:
|
||||
self.losses['dinov3'] += dino_loss.item()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.step % self.log_every == 0 and self.step > 0:
|
||||
print(f"DFE losses:")
|
||||
for key in self.losses:
|
||||
self.losses[key] /= self.log_every
|
||||
# print in 2.000e-01 format
|
||||
print(f" - {key}: {self.losses[key]:.3e}")
|
||||
self.losses[key] = 0.0
|
||||
|
||||
# total_loss += mse_loss
|
||||
self.step += 1
|
||||
|
||||
return dino_loss
|
||||
|
||||
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
||||
if model_path == "v3":
|
||||
dfe = DiffusionFeatureExtractor3(vae=vae)
|
||||
@@ -657,6 +818,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
||||
dfe = DiffusionFeatureExtractor5(vae=vae)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if model_path == "v6":
|
||||
dfe = DiffusionFeatureExtractor6(vae=vae)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
# if it ende with safetensors
|
||||
|
||||
Reference in New Issue
Block a user