mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Add a dino version of DFE
This commit is contained in:
@@ -654,7 +654,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
|
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
|
||||||
|
|
||||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||||
elif self.dfe.version in [3, 4, 5]:
|
elif self.dfe.version in [3, 4, 5, 6]:
|
||||||
dfe_loss = self.dfe(
|
dfe_loss = self.dfe(
|
||||||
noise=noise,
|
noise=noise,
|
||||||
noise_pred=noise_pred,
|
noise_pred=noise_pred,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from torch import nn
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from diffusers import AutoencoderTiny
|
from diffusers import AutoencoderTiny
|
||||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
from transformers import AutoImageProcessor, AutoModel, SiglipImageProcessor, SiglipVisionModel
|
||||||
import lpips
|
import lpips
|
||||||
|
|
||||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
@@ -644,6 +644,167 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
|
|||||||
# return stepped_latents, predicted_images
|
# return stepped_latents, predicted_images
|
||||||
return predicted_images
|
return predicted_images
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionFeatureExtractor6(nn.Module):
    """DINOv3-based diffusion feature extractor (DFE version 6).

    Steps the model's noise prediction back to an x0 estimate, decodes it
    through the provided VAE, and compares DINOv3 hidden-state features of the
    decoded prediction against those of the ground-truth batch image with an
    MSE loss. Both feature maps are globally standardized before comparison.
    """

    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
        """
        Args:
            device: device to place the DINOv3 model on.
            dtype: compute dtype for the DINOv3 model.
            vae: a diffusers VAE (or compatible decoder); required — used to
                decode stepped latents into pixel space.

        Raises:
            ValueError: if no VAE is supplied.
        """
        super().__init__()
        self.version = 6
        if vae is None:
            # fixed: message previously said "DFE4" (copy-paste from v4)
            raise ValueError("vae must be provided for DFE6")
        self.vae = vae
        # Alternative DINOv3 checkpoints, kept for reference:
        # pretrained_model_name = "facebook/dinov3-vits16-pretrain-lvd1689m"
        # pretrained_model_name = "facebook/dinov3-vitl16-pretrain-lvd1689m"
        pretrained_model_name = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
        # pretrained_model_name = "facebook/dinov3-vit7b16-pretrain-lvd1689m"
        self.processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
        self.model = AutoModel.from_pretrained(
            pretrained_model_name,
            device_map=device,
            dtype=dtype,
        ).to(device, dtype=dtype)

        # Running loss totals for periodic console logging.
        self.losses = {}
        self.log_every = 100
        self.step = 0

    def prepare_inputs(self, tensor_0_1: torch.Tensor):
        """Convert a [0, 1] image batch into DINOv3 model inputs.

        Replicates the HF image processor's rescale/normalize steps on-device
        so gradients can flow through (the processor itself is numpy-based).

        Args:
            tensor_0_1: (bs, 3, h, w) float tensor with values in [0, 1].

        Returns:
            dict with a single "pixel_values" entry ready for the vision model.

        Raises:
            ValueError: if the input is not a (bs, 3, h, w) tensor.
        """
        if tensor_0_1.ndim != 4 or tensor_0_1.shape[1] != 3:
            raise ValueError(f"Expected (bs, 3, h, w), got {tuple(tensor_0_1.shape)}")

        x = tensor_0_1
        if not torch.is_floating_point(x):
            x = x.float()

        # Resize: DINOv3 uses 16px patches, so both sides must be divisible
        # by 16; also cap total pixels at max_res*max_res to bound memory.
        max_res = 512
        p = 16
        if (x.shape[-1] % p != 0) or (x.shape[-2] % p != 0) or (x.shape[-1] * x.shape[-2] > max_res * max_res):
            target_h = x.shape[-2]
            target_w = x.shape[-1]
            if target_w * target_h > max_res * max_res:
                # Uniform scale so the area fits within max_res^2.
                scale_factor = math.sqrt((max_res * max_res) / (target_w * target_h))
                target_h = int(target_h * scale_factor)
                target_w = int(target_w * scale_factor)
            # Snap down to a patch multiple; clamp to at least one patch so a
            # tiny input can never floor to a zero-sized dimension.
            target_h = max(p, (target_h // p) * p)
            target_w = max(p, (target_w // p) * p)
            x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False)

        # Rescale (HF processors usually assume uint8 0..255 inputs; ours are
        # already 0..1).
        if self.processor.do_rescale:
            # If it looks like [0..1], skip to avoid double-scaling.
            # If the caller accidentally passed 0..255 floats, this fixes it.
            if x.detach().max().item() > 1.0 + 1e-6:
                x = x * float(self.processor.rescale_factor or 1.0 / 255.0)

        # Normalize with the processor's per-channel mean/std.
        if self.processor.do_normalize:
            mean = torch.tensor(self.processor.image_mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
            std = torch.tensor(self.processor.image_std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
            x = (x - mean) / std

        return {"pixel_values": x}

    def forward(
        self,
        noise,
        noise_pred,
        noisy_latents,
        timesteps,
        batch: DataLoaderBatchDTO,
        scheduler: CustomFlowMatchEulerDiscreteScheduler,
        model=None
    ):
        """Compute the DINOv3 feature-matching loss for one training step.

        Args:
            noise: sampled noise (unused for the x0 step itself; kept for API
                parity with other DFE versions).
            noise_pred: the model's noise/velocity prediction.
            noisy_latents: the noised latents fed to the model.
            timesteps: per-sample timesteps in [0, 1000].
            batch: batch DTO whose ``tensor`` holds ground-truth images in
                [-1, 1].
            scheduler: flow-match scheduler (unused here; kept for API parity).
            model: optional model handle (unused).

        Returns:
            Scalar MSE loss between standardized DINOv3 features of the
            decoded prediction and the ground-truth image.
        """
        dtype = torch.bfloat16
        device = self.vae.device
        tensors = batch.tensor.to(device, dtype=dtype)
        is_video = False
        # Video models carry a time axis; only the first frame is compared.
        if len(noise_pred.shape) == 5:
            # B, C, T, H, W
            noise = noise[:, :, 0, :, :]
            noise_pred = noise_pred[:, :, 0, :, :]
            noisy_latents = noisy_latents[:, :, 0, :, :]
            is_video = True

        if len(tensors.shape) == 5:
            # batch tensor layout differs: (B, T, C, H, W); take first frame
            tensors = tensors[:, 0, :, :, :]

        with torch.no_grad():
            tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
            # Expand to broadcast against noise_pred.
            while len(tv.shape) < len(noise_pred.shape):
                tv = tv.unsqueeze(-1)
            # Clamp so a zero timestep cannot zero out the step.
            tv = torch.clamp(tv, min=0.001)

        # Flow-match style single Euler step back to the x0 estimate.
        x0 = noisy_latents - tv * noise_pred

        stepped_latents = x0

        latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)

        # Undo the VAE's latent scaling/shift before decoding.
        scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0
        shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0
        latents = (latents / scaling_factor) + shift_factor
        if is_video:
            # Video VAEs expect a time axis; re-add a singleton frame dim.
            latents = latents.unsqueeze(2)
        tensors_n1p1 = self.vae.decode(latents)  # -1 to 1
        if hasattr(tensors_n1p1, 'sample'):
            tensors_n1p1 = tensors_n1p1.sample

        if is_video:
            # Drop the singleton time axis again.
            tensors_n1p1 = tensors_n1p1.squeeze(2)

        pred_images = (tensors_n1p1 + 1) / 2  # 0 to 1

        with torch.no_grad():
            target_img = tensors.to(device, dtype=dtype)
            # Ground truth arrives in [-1, 1]; map to [0, 1].
            target_img = (target_img + 1) / 2
            target_dino_input = self.prepare_inputs(target_img)
            target_dino_output = self.model(**target_dino_input, output_hidden_states=True)['hidden_states'][-1].detach()
            # Standardize so the loss is scale-invariant across checkpoints.
            target_dino_output = (target_dino_output - target_dino_output.mean()) / (target_dino_output.std() + 1e-6)
        pred_dino_input = self.prepare_inputs(pred_images)
        pred_dino_output = self.model(**pred_dino_input, output_hidden_states=True)['hidden_states'][-1]
        # Standardize the prediction features the same way.
        pred_dino_output = (pred_dino_output - pred_dino_output.mean()) / (pred_dino_output.std() + 1e-6)
        dino_loss = torch.nn.functional.mse_loss(
            pred_dino_output.float(), target_dino_output.float()
        )

        if 'dinov3' not in self.losses:
            self.losses['dinov3'] = dino_loss.item()
        else:
            self.losses['dinov3'] += dino_loss.item()

        with torch.no_grad():
            if self.step % self.log_every == 0 and self.step > 0:
                print(f"DFE losses:")
                for key in self.losses:
                    # Report the running mean, then reset the accumulator.
                    self.losses[key] /= self.log_every
                    # print in 2.000e-01 format
                    print(f" - {key}: {self.losses[key]:.3e}")
                    self.losses[key] = 0.0

        self.step += 1

        return dino_loss
|
||||||
|
|
||||||
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
||||||
if model_path == "v3":
|
if model_path == "v3":
|
||||||
dfe = DiffusionFeatureExtractor3(vae=vae)
|
dfe = DiffusionFeatureExtractor3(vae=vae)
|
||||||
@@ -657,6 +818,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
|||||||
dfe = DiffusionFeatureExtractor5(vae=vae)
|
dfe = DiffusionFeatureExtractor5(vae=vae)
|
||||||
dfe.eval()
|
dfe.eval()
|
||||||
return dfe
|
return dfe
|
||||||
|
if model_path == "v6":
|
||||||
|
dfe = DiffusionFeatureExtractor6(vae=vae)
|
||||||
|
dfe.eval()
|
||||||
|
return dfe
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||||
# if it ends with safetensors
|
# if it ends with safetensors
|
||||||
|
|||||||
Reference in New Issue
Block a user