mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-22 21:33:59 +00:00
Added a guidance burning loss. Modified DFE to work with new model. Bug fixes
This commit is contained in:
@@ -61,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
||||
self.negative_prompt_pool: Union[List[str], None] = None
|
||||
self.batch_negative_prompt: Union[List[str], None] = None
|
||||
self.cfm_cache = None
|
||||
|
||||
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||
|
||||
@@ -84,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
|
||||
|
||||
self.dfe: Optional[DiffusionFeatureExtractor] = None
|
||||
self.unconditional_embeds = None
|
||||
|
||||
if self.train_config.diff_output_preservation:
|
||||
if self.trigger_word is None:
|
||||
@@ -95,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# always do a prior prediction when doing diff output preservation
|
||||
self.do_prior_prediction = True
|
||||
|
||||
# store the loss target for a batch so we can use it in a loss
|
||||
self._guidance_loss_target_batch: float = 0.0
|
||||
if isinstance(self.train_config.guidance_loss_target, (int, float)):
|
||||
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target)
|
||||
elif isinstance(self.train_config.guidance_loss_target, list):
|
||||
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0])
|
||||
else:
|
||||
raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")
|
||||
|
||||
|
||||
def before_model_load(self):
|
||||
@@ -135,6 +144,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
def hook_before_train_loop(self):
|
||||
super().hook_before_train_loop()
|
||||
|
||||
# cache unconditional embeds (blank prompt)
|
||||
with torch.no_grad():
|
||||
self.unconditional_embeds = self.sd.encode_prompt(
|
||||
[''],
|
||||
long_prompts=self.do_long_prompts
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=self.sd.torch_dtype
|
||||
).detach()
|
||||
|
||||
if self.train_config.do_prior_divergence:
|
||||
self.do_prior_prediction = True
|
||||
# move vae to device if we did not cache latents
|
||||
@@ -476,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
||||
else:
|
||||
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
||||
|
||||
if self.train_config.do_guidance_loss:
|
||||
with torch.no_grad():
|
||||
# we make cached blank prompt embeds that match the batch size
|
||||
unconditional_embeds = concat_prompt_embeds(
|
||||
[self.unconditional_embeds] * noisy_latents.shape[0],
|
||||
)
|
||||
cfm_pred = self.predict_noise(
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=unconditional_embeds,
|
||||
unconditional_embeds=None,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
# zero cfg
|
||||
|
||||
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
|
||||
batch_size = target.shape[0]
|
||||
positive_flat = target.view(batch_size, -1)
|
||||
negative_flat = cfm_pred.view(batch_size, -1)
|
||||
# Calculate the per-sample dot product
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
# Squared norm of the unconditional prediction
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
st_star = dot_product / squared_norm
|
||||
|
||||
alpha = st_star
|
||||
|
||||
is_video = len(target.shape) == 5
|
||||
|
||||
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
|
||||
|
||||
guidance_scale = self._guidance_loss_target_batch
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
|
||||
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
|
||||
|
||||
unconditional_target = cfm_pred * alpha
|
||||
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
||||
|
||||
|
||||
if target is None:
|
||||
@@ -895,6 +955,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
guidance_embedding_scale = self.train_config.cfg_scale
|
||||
if self.train_config.do_guidance_loss:
|
||||
guidance_embedding_scale = self._guidance_loss_target_batch
|
||||
|
||||
prior_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
@@ -902,6 +966,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
guidance_embedding_scale=guidance_embedding_scale,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
batch=batch,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
@@ -945,13 +1010,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**kwargs,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
guidance_embedding_scale = self.train_config.cfg_scale
|
||||
if self.train_config.do_guidance_loss:
|
||||
guidance_embedding_scale = self._guidance_loss_target_batch
|
||||
return self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
guidance_embedding_scale=self.train_config.cfg_scale,
|
||||
guidance_embedding_scale=guidance_embedding_scale,
|
||||
detach_unconditional=False,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
|
||||
@@ -959,80 +1027,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def cfm_augment_tensors(
    self,
    images: torch.Tensor
) -> torch.Tensor:
    """Build a contrast batch from the previously seen images.

    Every output sample is drawn at random (with replacement) from
    ``self.cfm_cache`` and resized to the spatial size of ``images``.
    Afterwards the cache is replaced by a copy of ``images`` so the next
    call samples from this batch.

    Args:
        images: Current batch. It is indexed as a 4-D tensor below
            (dim 3 is flipped, dims 2 and 3 give the resize target).

    Returns:
        A tensor with the batch size and spatial size of ``images``.
    """
    if self.cfm_cache is None:
        # First call only: nothing is cached yet, so seed the cache with a
        # copy of the current batch flipped along dim 3.
        self.cfm_cache = torch.flip(images, [3]).clone()

    cache = self.cfm_cache

    # One random cache index per incoming sample. random.randint is called
    # once per sample, in order, so the RNG stream is consumed exactly as a
    # per-sample loop would consume it.
    picks = [
        random.randint(0, cache.shape[0] - 1)
        for _ in range(images.shape[0])
    ]
    index = torch.tensor(picks, dtype=torch.long, device=cache.device)

    # Single gather instead of slicing per sample and concatenating; this
    # also yields an empty batch (rather than raising) when images is empty.
    augmented = cache.index_select(0, index)

    # Resize to match the input. Skipped when the sizes already agree, since
    # a same-size bilinear resize returns the input values unchanged.
    target_size = (images.shape[2], images.shape[3])
    if tuple(augmented.shape[2:]) != target_size:
        augmented = torch.nn.functional.interpolate(augmented, size=target_size, mode='bilinear')

    # The current batch becomes the sampling pool for the next call.
    self.cfm_cache = images.clone()
    return augmented
|
||||
|
||||
def get_cfm_loss(
    self,
    noisy_latents: torch.Tensor,
    noise: torch.Tensor,
    noise_pred: torch.Tensor,
    conditional_embeds: PromptEmbeds,
    timesteps: torch.Tensor,
    batch: 'DataLoaderBatchDTO',
    alpha: float = 0.1,
):
    """Flow-matching MSE loss plus a cosine-similarity contrast term.

    Args:
        noisy_latents: Not read by this method.
        noise: Noise used for the loss target and for noising the contrast
            latents.
        noise_pred: Model prediction for the real batch.
        conditional_embeds: Prompt embeds reused for the contrast prediction.
        timesteps: Timesteps shared by the real and the contrast prediction.
        batch: Source of ``tensor`` (images) and ``latents``.
        alpha: Weight applied to the contrast term.

    Returns:
        ``mean(mse(noise_pred, target)) + alpha * (-mean(cosine_sim))``.

    Raises:
        ValueError: If the model neither provides ``get_loss_target`` nor is
            flow matching.
    """
    train_dtype = get_torch_dtype(self.train_config.dtype)

    # Pick the regression target: model-provided if available, otherwise the
    # forward-ODE flow-matching target.
    if hasattr(self.sd, 'get_loss_target'):
        target = self.sd.get_loss_target(
            noise=noise,
            batch=batch,
            timesteps=timesteps,
        )
        target = target.detach()
    elif self.sd.is_flow_matching:
        target = (noise - batch.latents).detach()
    else:
        raise ValueError("CFM loss only works with flow matching")

    flow_term = torch.nn.functional.mse_loss(
        noise_pred.float(),
        target.float(),
        reduction="none",
    )

    # The contrast branch is computed without gradients.
    with torch.no_grad():
        contrast_images = self.cfm_augment_tensors(batch.tensor)
        contrast_images = contrast_images.to(self.device_torch, dtype=train_dtype)

        contrast_latents = self.sd.encode_images(contrast_images)
        contrast_latents = contrast_latents.to(self.device_torch, dtype=train_dtype)

        # same noise and timesteps as the real sample
        contrast_noisy_latents = self.sd.add_noise(
            original_samples=contrast_latents,
            noise=noise,
            timesteps=timesteps,
        )
        contrast_pred = self.predict_noise(
            noisy_latents=contrast_noisy_latents,
            timesteps=timesteps,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=None,
            batch=batch,
        )

    # Cosine similarity along dim 1 between the contrast prediction and the
    # real prediction, then averaged over all remaining elements.
    similarity = torch.nn.functional.cosine_similarity(
        contrast_pred.float(),
        noise_pred.float(),
        dim=1,
        eps=1e-6,
    )
    # NOTE(review): because the term is negated, minimising the loss raises
    # the similarity to the contrast prediction — confirm the sign is intended.
    contrast_term = -similarity.mean()

    return flow_term.mean() + alpha * contrast_term
|
||||
|
||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||
with torch.no_grad():
|
||||
@@ -1658,6 +1652,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
)
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
|
||||
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
|
||||
|
||||
if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list):
|
||||
batch_size = noisy_latents.shape[0]
|
||||
# update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1]
|
||||
self._guidance_loss_target_batch = [
|
||||
random.uniform(
|
||||
self.train_config.guidance_loss_target[0],
|
||||
self.train_config.guidance_loss_target[1]
|
||||
) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
self.before_unet_predict()
|
||||
|
||||
@@ -1757,25 +1761,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
||||
prior_to_calculate_loss = None
|
||||
|
||||
if self.train_config.loss_type == 'cfm':
|
||||
loss = self.get_cfm_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
noise=noise,
|
||||
noise_pred=noise_pred,
|
||||
conditional_embeds=conditional_embeds,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
)
|
||||
else:
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_to_calculate_loss,
|
||||
)
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_to_calculate_loss,
|
||||
)
|
||||
|
||||
if self.train_config.diff_output_preservation:
|
||||
# send the loss backwards otherwise checkpointing will fail
|
||||
|
||||
@@ -1012,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = None
|
||||
is_reg = any(batch.get_is_reg_list())
|
||||
cfm_batch = None
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
|
||||
@@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess):
|
||||
prompt_image_configs = []
|
||||
for _ in range(self.generate_config.num_repeats):
|
||||
for prompt in self.generate_config.prompts:
|
||||
# remove --
|
||||
prompt = prompt.replace('--', '').strip()
|
||||
width = self.generate_config.width
|
||||
height = self.generate_config.height
|
||||
# prompt = self.clean_prompt(prompt)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
|
||||
from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
|
||||
import random
|
||||
|
||||
import torch
|
||||
@@ -413,7 +413,7 @@ class TrainConfig:
|
||||
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||
|
||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow
|
||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow
|
||||
|
||||
# scale the prediction by this. Increase for more detail, decrease for less
|
||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||
@@ -467,6 +467,12 @@ class TrainConfig:
|
||||
# forces same noise for the same image at a given size.
|
||||
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
|
||||
self.blended_blur_noise = kwargs.get('blended_blur_noise', False)
|
||||
|
||||
# guidance loss
|
||||
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
|
||||
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
|
||||
if isinstance(self.guidance_loss_target, tuple):
|
||||
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||
|
||||
|
||||
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
|
||||
@@ -1145,3 +1151,6 @@ def validate_configs(
|
||||
if model_config.use_flux_cfg:
|
||||
# bypass the embedding
|
||||
train_config.bypass_guidance_embedding = True
|
||||
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
|
||||
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
|
||||
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
|
||||
|
||||
@@ -815,7 +815,10 @@ class BaseModel:
|
||||
|
||||
# predict the noise residual
|
||||
if self.unet.device != self.device_torch:
|
||||
self.unet.to(self.device_torch)
|
||||
try:
|
||||
self.unet.to(self.device_torch)
|
||||
except Exception as e:
|
||||
pass
|
||||
if self.unet.dtype != self.torch_dtype:
|
||||
self.unet = self.unet.to(dtype=self.torch_dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user