Add a dino version of DFE

This commit is contained in:
Jaret Burkett
2026-03-04 08:20:37 -07:00
parent 9dee42fc09
commit b04c64e0f8
2 changed files with 167 additions and 2 deletions

View File

@@ -654,7 +654,7 @@ class SDTrainer(BaseSDTrainProcess):
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
elif self.dfe.version in [3, 4, 5]:
elif self.dfe.version in [3, 4, 5, 6]:
dfe_loss = self.dfe(
noise=noise,
noise_pred=noise_pred,

View File

@@ -5,7 +5,7 @@ from torch import nn
from safetensors.torch import load_file
import torch.nn.functional as F
from diffusers import AutoencoderTiny
from transformers import SiglipImageProcessor, SiglipVisionModel
from transformers import AutoImageProcessor, AutoModel, SiglipImageProcessor, SiglipVisionModel
import lpips
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
@@ -644,6 +644,167 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
# return stepped_latents, predicted_images
return predicted_images
class DiffusionFeatureExtractor6(nn.Module):
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
super().__init__()
self.version = 6
if vae is None:
raise ValueError("vae must be provided for DFE4")
self.vae = vae
# pretrained_model_name = "facebook/dinov3-vits16-pretrain-lvd1689m"
# pretrained_model_name = "facebook/dinov3-vitl16-pretrain-lvd1689m"
pretrained_model_name = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
# pretrained_model_name = "facebook/dinov3-vit7b16-pretrain-lvd1689m"
self.processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
self.model = AutoModel.from_pretrained(
pretrained_model_name,
device_map=device,
dtype=dtype,
).to(device, dtype=dtype)
self.losses = {}
self.log_every = 100
self.step = 0
def prepare_inputs(self, tensor_0_1: torch.Tensor):
"""
tensor_0_1: (bs, 3, h, w), float, values in [0, 1]
returns: {"pixel_values": (bs, 3, H, W)} ready for the vision transformer
"""
if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3:
raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}")
x = tensor_0_1
if not torch.is_floating_point(x):
x = x.float()
# Resize
# if not divisible by 16 or total pixels > max_res*max_res, resize to fit within 16 patches
max_res = 512
p = 16
if (x.shape[-1] % p != 0) or (x.shape[-2] % p != 0) or (x.shape[-1] * x.shape[-2] > max_res * max_res):
target_h = x.shape[-2]
target_w = x.shape[-1]
if x.shape[-1] * target_h > max_res * max_res:
scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h))
target_h = int(target_h * scale_factor)
target_w = int(target_w * scale_factor)
target_h = (target_h // p) * p
target_w = (target_w // p) * p
x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False)
# Rescale (HF processors usually assume uint8 0..255 inputs; your inputs are already 0..1)
if self.processor.do_rescale:
# If it looks like [0..1], skip to avoid double-scaling.
# If user accidentally passed 0..255 floats, this will fix it.
if x.detach().max().item() > 1.0 + 1e-6:
x = x * float(self.processor.rescale_factor or 1.0 / 255.0)
# Normalize
if self.processor.do_normalize:
mean = torch.tensor(self.processor.image_mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
std = torch.tensor(self.processor.image_std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
x = (x - mean) / std
return {"pixel_values": x}
def forward(
self,
noise,
noise_pred,
noisy_latents,
timesteps,
batch: DataLoaderBatchDTO,
scheduler: CustomFlowMatchEulerDiscreteScheduler,
model=None
):
dtype = torch.bfloat16
device = self.vae.device
tensors = batch.tensor.to(device, dtype=dtype)
is_video = False
# stack time for video models on the batch dimension
if len(noise_pred.shape) == 5:
# B, C, T, H, W = images.shape
# only take first time
noise = noise[:, :, 0, :, :]
noise_pred = noise_pred[:, :, 0, :, :]
noisy_latents = noisy_latents[:, :, 0, :, :]
is_video = True
if len(tensors.shape) == 5:
# batch is different
# (B, T, C, H, W)
# only take first time
tensors = tensors[:, 0, :, :, :]
with torch.no_grad():
tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
# expand shape to match noise_pred
while len(tv.shape) < len(noise_pred.shape):
tv = tv.unsqueeze(-1)
# min 0.001
tv = torch.clamp(tv, min=0.001)
# step latent
x0 = noisy_latents - tv * noise_pred
stepped_latents = x0
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
latents = (latents / scaling_factor) + shift_factor
if is_video:
# if video, we need to unsqueeze the latents to match the vae input shape
latents = latents.unsqueeze(2)
tensors_n1p1 = self.vae.decode(latents) # -1 to 1
if hasattr(tensors_n1p1, 'sample'):
tensors_n1p1 = tensors_n1p1.sample
if is_video:
# if video, we need to squeeze the tensors to match the output shape
tensors_n1p1 = tensors_n1p1.squeeze(2)
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
with torch.no_grad():
target_img = tensors.to(device, dtype=dtype)
# go from -1 to 1 to 0 to 1
target_img = (target_img + 1) / 2
target_dino_input = self.prepare_inputs(target_img)
target_dino_output = self.model(**target_dino_input, output_hidden_states=True)['hidden_states'][-1].detach()
# normalize
target_dino_output = (target_dino_output - target_dino_output.mean()) / (target_dino_output.std() + 1e-6)
pred_dino_input = self.prepare_inputs(pred_images)
pred_dino_output = self.model(**pred_dino_input, output_hidden_states=True)['hidden_states'][-1]
# normalize
pred_dino_output = (pred_dino_output - pred_dino_output.mean()) / (pred_dino_output.std() + 1e-6)
dino_loss = torch.nn.functional.mse_loss(
pred_dino_output.float(), target_dino_output.float()
)
if 'dinov3' not in self.losses:
self.losses['dinov3'] = dino_loss.item()
else:
self.losses['dinov3'] += dino_loss.item()
with torch.no_grad():
if self.step % self.log_every == 0 and self.step > 0:
print(f"DFE losses:")
for key in self.losses:
self.losses[key] /= self.log_every
# print in 2.000e-01 format
print(f" - {key}: {self.losses[key]:.3e}")
self.losses[key] = 0.0
# total_loss += mse_loss
self.step += 1
return dino_loss
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
if model_path == "v3":
dfe = DiffusionFeatureExtractor3(vae=vae)
@@ -657,6 +818,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
dfe = DiffusionFeatureExtractor5(vae=vae)
dfe.eval()
return dfe
if model_path == "v6":
dfe = DiffusionFeatureExtractor6(vae=vae)
dfe.eval()
return dfe
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")
# if it ende with safetensors