From b04c64e0f8ecbb88b906a502844b5cefcfe8efd9 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 4 Mar 2026 08:20:37 -0700
Subject: [PATCH] Add a dino version of DFE

---
 extensions_built_in/sd_trainer/SDTrainer.py  |   2 +-
 .../models/diffusion_feature_extraction.py   | 167 +++++++++++++++++-
 2 files changed, 167 insertions(+), 2 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 2ca7f5fc..63c3874b 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -654,7 +654,7 @@ class SDTrainer(BaseSDTrainProcess):
                         dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
 
                     additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
-                elif self.dfe.version in [3, 4, 5]:
+                elif self.dfe.version in [3, 4, 5, 6]:
                     dfe_loss = self.dfe(
                         noise=noise,
                         noise_pred=noise_pred,
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 30a19d3d..c1effd80 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -5,7 +5,7 @@ from torch import nn
 from safetensors.torch import load_file
 import torch.nn.functional as F
 from diffusers import AutoencoderTiny
-from transformers import SiglipImageProcessor, SiglipVisionModel
+from transformers import AutoImageProcessor, AutoModel, SiglipImageProcessor, SiglipVisionModel
 import lpips
 
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
@@ -644,6 +644,167 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
         # return stepped_latents, predicted_images
         return predicted_images
+
+class DiffusionFeatureExtractor6(nn.Module):
+    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
+        super().__init__()
+        self.version = 6
+        if vae is None:
+            raise ValueError("vae must be provided for DFE6")
+        self.vae = vae
+        # pretrained_model_name = "facebook/dinov3-vits16-pretrain-lvd1689m"
+        # pretrained_model_name = "facebook/dinov3-vitl16-pretrain-lvd1689m"
+        pretrained_model_name = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
+        # pretrained_model_name = "facebook/dinov3-vit7b16-pretrain-lvd1689m"
+        self.processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
+        self.model = AutoModel.from_pretrained(
+            pretrained_model_name,
+            device_map=device,
+            dtype=dtype,
+        ).to(device, dtype=dtype)
+
+        self.losses = {}
+        self.log_every = 100
+        self.step = 0
+
+    def prepare_inputs(self, tensor_0_1: torch.Tensor):
+        """
+        tensor_0_1: (bs, 3, h, w), float, values in [0, 1]
+        returns: {"pixel_values": (bs, 3, H, W)} ready for the vision transformer
+        """
+
+        if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3:
+            raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}")
+
+        x = tensor_0_1
+        if not torch.is_floating_point(x):
+            x = x.float()
+
+        # Resize
+        # if the dims are not divisible by the 16px patch size, or total pixels exceed
+        # max_res * max_res, resize so both constraints hold
+        max_res = 512
+        p = 16
+        if (x.shape[-1] % p != 0) or (x.shape[-2] % p != 0) or (x.shape[-1] * x.shape[-2] > max_res * max_res):
+            target_h = x.shape[-2]
+            target_w = x.shape[-1]
+            if x.shape[-1] * target_h > max_res * max_res:
+                scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h))
+                target_h = int(target_h * scale_factor)
+                target_w = int(target_w * scale_factor)
+            target_h = (target_h // p) * p
+            target_w = (target_w // p) * p
+            x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False)
+
+        # Rescale (HF processors usually assume uint8 0..255 inputs; your inputs are already 0..1)
+        if self.processor.do_rescale:
+            # If it looks like [0..1], skip to avoid double-scaling.
+            # If user accidentally passed 0..255 floats, this will fix it.
+            if x.detach().max().item() > 1.0 + 1e-6:
+                x = x * float(self.processor.rescale_factor or 1.0 / 255.0)
+
+        # Normalize
+        if self.processor.do_normalize:
+            mean = torch.tensor(self.processor.image_mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
+            std = torch.tensor(self.processor.image_std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
+            x = (x - mean) / std
+
+        return {"pixel_values": x}
+
+    def forward(
+        self,
+        noise,
+        noise_pred,
+        noisy_latents,
+        timesteps,
+        batch: DataLoaderBatchDTO,
+        scheduler: CustomFlowMatchEulerDiscreteScheduler,
+        model=None
+    ):
+        dtype = torch.bfloat16
+        device = self.vae.device
+        tensors = batch.tensor.to(device, dtype=dtype)
+        is_video = False
+        # stack time for video models on the batch dimension
+        if len(noise_pred.shape) == 5:
+            # B, C, T, H, W = images.shape
+            # only take first time
+            noise = noise[:, :, 0, :, :]
+            noise_pred = noise_pred[:, :, 0, :, :]
+            noisy_latents = noisy_latents[:, :, 0, :, :]
+            is_video = True
+
+        if len(tensors.shape) == 5:
+            # batch is different
+            # (B, T, C, H, W)
+            # only take first time
+            tensors = tensors[:, 0, :, :, :]
+
+        with torch.no_grad():
+            tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
+            # expand shape to match noise_pred
+            while len(tv.shape) < len(noise_pred.shape):
+                tv = tv.unsqueeze(-1)
+            # min 0.001
+            tv = torch.clamp(tv, min=0.001)
+
+        # step latent
+        x0 = noisy_latents - tv * noise_pred
+
+        stepped_latents = x0
+
+        latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
+
+        scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
+        shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
+        latents = (latents / scaling_factor) + shift_factor
+        if is_video:
+            # if video, we need to unsqueeze the latents to match the vae input shape
+            latents = latents.unsqueeze(2)
+        tensors_n1p1 = self.vae.decode(latents)  # -1 to 1
+        if hasattr(tensors_n1p1, 'sample'):
+            tensors_n1p1 = tensors_n1p1.sample
+
+        if is_video:
+            # if video, we need to squeeze the tensors to match the output shape
+            tensors_n1p1 = tensors_n1p1.squeeze(2)
+
+        pred_images = (tensors_n1p1 + 1) / 2  # 0 to 1
+
+        with torch.no_grad():
+            target_img = tensors.to(device, dtype=dtype)
+            # go from -1 to 1 to 0 to 1
+            target_img = (target_img + 1) / 2
+            target_dino_input = self.prepare_inputs(target_img)
+            target_dino_output = self.model(**target_dino_input, output_hidden_states=True)['hidden_states'][-1].detach()
+            # normalize
+            target_dino_output = (target_dino_output - target_dino_output.mean()) / (target_dino_output.std() + 1e-6)
+        pred_dino_input = self.prepare_inputs(pred_images)
+        pred_dino_output = self.model(**pred_dino_input, output_hidden_states=True)['hidden_states'][-1]
+        # normalize
+        pred_dino_output = (pred_dino_output - pred_dino_output.mean()) / (pred_dino_output.std() + 1e-6)
+        dino_loss = torch.nn.functional.mse_loss(
+            pred_dino_output.float(), target_dino_output.float()
+        )
+
+        if 'dinov3' not in self.losses:
+            self.losses['dinov3'] = dino_loss.item()
+        else:
+            self.losses['dinov3'] += dino_loss.item()
+
+        with torch.no_grad():
+            if self.step % self.log_every == 0 and self.step > 0:
+                print(f"DFE losses:")
+                for key in self.losses:
+                    self.losses[key] /= self.log_every
+                    # print in 2.000e-01 format
+                    print(f" - {key}: {self.losses[key]:.3e}")
+                    self.losses[key] = 0.0
+
+        # total_loss += mse_loss
+        self.step += 1
+
+        return dino_loss
+
 
 def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
     if model_path == "v3":
         dfe = DiffusionFeatureExtractor3(vae=vae)
@@ -657,6 +818,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
         dfe = DiffusionFeatureExtractor5(vae=vae)
         dfe.eval()
         return dfe
+    if model_path == "v6":
+        dfe = DiffusionFeatureExtractor6(vae=vae)
+        dfe.eval()
+        return dfe
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file not found: {model_path}")
     # if it ende with safetensors
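
Usage sketch (not part of the patch): how the new v6 extractor is obtained and driven,
mirroring the SDTrainer call path for DFE versions 3-6. The surrounding variables
(vae, noise, noise_pred, noisy_latents, timesteps, batch, scheduler, train_config) are
assumed to already exist as they do in the trainer; the final weighting line is an
assumption based on the neighbouring DFE branch in the first hunk.

    from toolkit.models.diffusion_feature_extraction import load_dfe

    # "v6" selects the DINOv3-based DiffusionFeatureExtractor6 added above.
    # The shared VAE is required so the stepped latents can be decoded to pixels.
    dfe = load_dfe("v6", vae=vae)

    dino_loss = dfe(
        noise=noise,
        noise_pred=noise_pred,
        noisy_latents=noisy_latents,
        timesteps=timesteps,
        batch=batch,
        scheduler=scheduler,
    )

    # Assumed weighting, analogous to the existing DFE branch in SDTrainer.py:
    additional_loss = dino_loss * train_config.diffusion_feature_extractor_weight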