mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 17:29:27 +00:00
Various code to support experiments.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import torch
|
||||
import os
|
||||
from torch import nn
|
||||
@@ -351,12 +352,251 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
|
||||
return total_loss
|
||||
|
||||
class DiffusionFeatureExtractor4(nn.Module):
|
||||
def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
|
||||
super().__init__()
|
||||
self.version = 4
|
||||
if vae is None:
|
||||
raise ValueError("vae must be provided for DFE4")
|
||||
self.vae = vae
|
||||
# image_encoder_path = "google/siglip-so400m-patch14-384"
|
||||
image_encoder_path = "google/siglip2-so400m-patch16-naflex"
|
||||
from transformers import Siglip2ImageProcessor, Siglip2VisionModel
|
||||
try:
|
||||
self.image_processor = Siglip2ImageProcessor.from_pretrained(
|
||||
image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.image_processor = Siglip2ImageProcessor()
|
||||
|
||||
self.image_processor.max_num_patches = 1024
|
||||
|
||||
self.vision_encoder = Siglip2VisionModel.from_pretrained(
|
||||
image_encoder_path,
|
||||
ignore_mismatched_sizes=True
|
||||
).to(device, dtype=dtype)
|
||||
|
||||
self.losses = {}
|
||||
self.log_every = 100
|
||||
self.step = 0
|
||||
|
||||
def _target_hw(self, h, w, patch, max_patches, eps: float = 1e-5):
|
||||
def _snap(x, s):
|
||||
x = math.ceil((x * s) / patch) * patch
|
||||
return max(patch, int(x))
|
||||
|
||||
lo, hi = eps / 10, 1.0
|
||||
while hi - lo >= eps:
|
||||
mid = (lo + hi) / 2
|
||||
th, tw = _snap(h, mid), _snap(w, mid)
|
||||
if (th // patch) * (tw // patch) <= max_patches:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid
|
||||
return _snap(h, lo), _snap(w, lo)
|
||||
|
||||
|
||||
def tensors_to_siglip_like_features(self, batch: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
batch: (bs, 3, H, W) tensor already in the desired value range
|
||||
(e.g. [-1, 1] or [0, 1]); no extra rescale / normalize here.
|
||||
|
||||
Returns:
|
||||
dict(
|
||||
pixel_values – (bs, L, P) where L = n_h*n_w, P = 3*patch*patch
|
||||
pixel_attention_mask– (L,) all-ones
|
||||
spatial_shapes – (n_h, n_w)
|
||||
)
|
||||
"""
|
||||
if batch.ndim != 4:
|
||||
raise ValueError("Expected (bs, 3, H, W) tensor")
|
||||
|
||||
bs, c, H, W = batch.shape
|
||||
proc = self.image_processor
|
||||
patch = proc.patch_size
|
||||
max_patches = proc.max_num_patches
|
||||
|
||||
# One shared resize for the whole batch
|
||||
tgt_h, tgt_w = self._target_hw(H, W, patch, max_patches)
|
||||
batch = torch.nn.functional.interpolate(
|
||||
batch, size=(tgt_h, tgt_w), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
n_h, n_w = tgt_h // patch, tgt_w // patch
|
||||
# flat_dim = c * patch * patch
|
||||
num_p = n_h * n_w
|
||||
|
||||
# unfold → (bs, flat_dim, num_p) → (bs, num_p, flat_dim)
|
||||
patches = (
|
||||
torch.nn.functional.unfold(batch, kernel_size=patch, stride=patch)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
attn_mask = torch.ones(num_p, dtype=torch.long, device=batch.device)
|
||||
spatial = torch.tensor((n_h, n_w), device=batch.device, dtype=torch.int32)
|
||||
|
||||
# repeat attn_mask for each batch element
|
||||
attn_mask = attn_mask.unsqueeze(0).repeat(bs, 1)
|
||||
spatial = spatial.unsqueeze(0).repeat(bs, 1)
|
||||
|
||||
return {
|
||||
"pixel_values": patches, # (bs, num_patches, patch_dim)
|
||||
"pixel_attention_mask": attn_mask, # (num_patches,)
|
||||
"spatial_shapes": spatial
|
||||
}
|
||||
|
||||
def get_siglip_features(self, tensors_0_1):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
|
||||
tensors_0_1 = torch.clamp(tensors_0_1, 0.0, 1.0)
|
||||
|
||||
mean = torch.tensor(self.image_processor.image_mean).to(
|
||||
device, dtype=dtype
|
||||
).detach()
|
||||
std = torch.tensor(self.image_processor.image_std).to(
|
||||
device, dtype=dtype
|
||||
).detach()
|
||||
# tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
|
||||
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
|
||||
|
||||
encoder_kwargs = self.tensors_to_siglip_like_features(clip_image)
|
||||
id_embeds = self.vision_encoder(
|
||||
pixel_values=encoder_kwargs['pixel_values'],
|
||||
pixel_attention_mask=encoder_kwargs['pixel_attention_mask'],
|
||||
spatial_shapes=encoder_kwargs['spatial_shapes'],
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# embeds = id_embeds['hidden_states'][-2] # penultimate layer
|
||||
embeds = id_embeds['pooler_output']
|
||||
return embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
noise,
|
||||
noise_pred,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
batch: DataLoaderBatchDTO,
|
||||
scheduler: CustomFlowMatchEulerDiscreteScheduler,
|
||||
clip_weight=1.0,
|
||||
mse_weight=0.0,
|
||||
model=None
|
||||
):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
tensors = batch.tensor.to(device, dtype=dtype)
|
||||
is_video = False
|
||||
# stack time for video models on the batch dimension
|
||||
if len(noise_pred.shape) == 5:
|
||||
# B, C, T, H, W = images.shape
|
||||
# only take first time
|
||||
noise = noise[:, :, 0, :, :]
|
||||
noise_pred = noise_pred[:, :, 0, :, :]
|
||||
noisy_latents = noisy_latents[:, :, 0, :, :]
|
||||
is_video = True
|
||||
|
||||
if len(tensors.shape) == 5:
|
||||
# batch is different
|
||||
# (B, T, C, H, W)
|
||||
# only take first time
|
||||
tensors = tensors[:, 0, :, :, :]
|
||||
|
||||
if model is not None and hasattr(model, 'get_stepped_pred'):
|
||||
stepped_latents = model.get_stepped_pred(noise_pred, noise)
|
||||
else:
|
||||
# stepped_latents = noise - noise_pred
|
||||
# first we step the scheduler from current timestep to the very end for a full denoise
|
||||
bs = noise_pred.shape[0]
|
||||
noise_pred_chunks = torch.chunk(noise_pred, bs)
|
||||
timestep_chunks = torch.chunk(timesteps, bs)
|
||||
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||
stepped_chunks = []
|
||||
for idx in range(bs):
|
||||
model_output = noise_pred_chunks[idx]
|
||||
timestep = timestep_chunks[idx]
|
||||
scheduler._step_index = None
|
||||
scheduler._init_step_index(timestep)
|
||||
sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||
|
||||
sigma = scheduler.sigmas[scheduler.step_index]
|
||||
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
stepped_chunks.append(prev_sample)
|
||||
|
||||
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0
|
||||
shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0
|
||||
latents = (latents / scaling_factor) + shift_factor
|
||||
if is_video:
|
||||
# if video, we need to unsqueeze the latents to match the vae input shape
|
||||
latents = latents.unsqueeze(2)
|
||||
tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1
|
||||
|
||||
if is_video:
|
||||
# if video, we need to squeeze the tensors to match the output shape
|
||||
tensors_n1p1 = tensors_n1p1.squeeze(2)
|
||||
|
||||
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
|
||||
|
||||
total_loss = 0
|
||||
|
||||
with torch.no_grad():
|
||||
target_img = tensors.to(device, dtype=dtype)
|
||||
# go from -1 to 1 to 0 to 1
|
||||
target_img = (target_img + 1) / 2
|
||||
if clip_weight > 0:
|
||||
target_clip_output = self.get_siglip_features(target_img).detach()
|
||||
if clip_weight > 0:
|
||||
pred_clip_output = self.get_siglip_features(pred_images)
|
||||
clip_loss = torch.nn.functional.mse_loss(
|
||||
pred_clip_output.float(), target_clip_output.float()
|
||||
) * clip_weight
|
||||
|
||||
if 'clip_loss' not in self.losses:
|
||||
self.losses['clip_loss'] = clip_loss.item()
|
||||
else:
|
||||
self.losses['clip_loss'] += clip_loss.item()
|
||||
|
||||
total_loss += clip_loss
|
||||
if mse_weight > 0:
|
||||
mse_loss = torch.nn.functional.mse_loss(
|
||||
pred_images.float(), target_img.float()
|
||||
) * mse_weight
|
||||
|
||||
if 'mse_loss' not in self.losses:
|
||||
self.losses['mse_loss'] = mse_loss.item()
|
||||
else:
|
||||
self.losses['mse_loss'] += mse_loss.item()
|
||||
|
||||
total_loss += mse_loss
|
||||
|
||||
if self.step % self.log_every == 0 and self.step > 0:
|
||||
print(f"DFE losses:")
|
||||
for key in self.losses:
|
||||
self.losses[key] /= self.log_every
|
||||
# print in 2.000e-01 format
|
||||
print(f" - {key}: {self.losses[key]:.3e}")
|
||||
self.losses[key] = 0.0
|
||||
|
||||
# total_loss += mse_loss
|
||||
self.step += 1
|
||||
|
||||
return total_loss
|
||||
|
||||
def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
||||
if model_path == "v3":
|
||||
dfe = DiffusionFeatureExtractor3(vae=vae)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if model_path == "v4":
|
||||
dfe = DiffusionFeatureExtractor4(vae=vae)
|
||||
dfe.eval()
|
||||
return dfe
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
# if it ende with safetensors
|
||||
|
||||
Reference in New Issue
Block a user