mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added a guidance burning loss. Modified DFE to work with new model. Bug fixes
This commit is contained in:
@@ -61,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
||||||
self.negative_prompt_pool: Union[List[str], None] = None
|
self.negative_prompt_pool: Union[List[str], None] = None
|
||||||
self.batch_negative_prompt: Union[List[str], None] = None
|
self.batch_negative_prompt: Union[List[str], None] = None
|
||||||
self.cfm_cache = None
|
|
||||||
|
|
||||||
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||||
|
|
||||||
@@ -84,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
|
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
|
||||||
|
|
||||||
self.dfe: Optional[DiffusionFeatureExtractor] = None
|
self.dfe: Optional[DiffusionFeatureExtractor] = None
|
||||||
|
self.unconditional_embeds = None
|
||||||
|
|
||||||
if self.train_config.diff_output_preservation:
|
if self.train_config.diff_output_preservation:
|
||||||
if self.trigger_word is None:
|
if self.trigger_word is None:
|
||||||
@@ -95,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
# always do a prior prediction when doing diff output preservation
|
# always do a prior prediction when doing diff output preservation
|
||||||
self.do_prior_prediction = True
|
self.do_prior_prediction = True
|
||||||
|
|
||||||
|
# store the guidance loss target for the current batch so the loss and the noise predictions can read it
|
||||||
|
self._guidance_loss_target_batch: float = 0.0
|
||||||
|
if isinstance(self.train_config.guidance_loss_target, (int, float)):
|
||||||
|
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target)
|
||||||
|
elif isinstance(self.train_config.guidance_loss_target, list):
|
||||||
|
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")
|
||||||
|
|
||||||
|
|
||||||
def before_model_load(self):
|
def before_model_load(self):
|
||||||
@@ -135,6 +144,16 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
def hook_before_train_loop(self):
|
def hook_before_train_loop(self):
|
||||||
super().hook_before_train_loop()
|
super().hook_before_train_loop()
|
||||||
|
|
||||||
|
# cache unconditional embeds (blank prompt)
|
||||||
|
with torch.no_grad():
|
||||||
|
self.unconditional_embeds = self.sd.encode_prompt(
|
||||||
|
[''],
|
||||||
|
long_prompts=self.do_long_prompts
|
||||||
|
).to(
|
||||||
|
self.device_torch,
|
||||||
|
dtype=self.sd.torch_dtype
|
||||||
|
).detach()
|
||||||
|
|
||||||
if self.train_config.do_prior_divergence:
|
if self.train_config.do_prior_divergence:
|
||||||
self.do_prior_prediction = True
|
self.do_prior_prediction = True
|
||||||
# move vae to device if we did not cache latents
|
# move vae to device if we did not cache latents
|
||||||
@@ -476,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
||||||
|
|
||||||
|
if self.train_config.do_guidance_loss:
|
||||||
|
with torch.no_grad():
|
||||||
|
# repeat the cached blank-prompt embeds so they match the batch size
|
||||||
|
unconditional_embeds = concat_prompt_embeds(
|
||||||
|
[self.unconditional_embeds] * noisy_latents.shape[0],
|
||||||
|
)
|
||||||
|
cfm_pred = self.predict_noise(
|
||||||
|
noisy_latents=noisy_latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
conditional_embeds=unconditional_embeds,
|
||||||
|
unconditional_embeds=None,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# zero cfg
|
||||||
|
|
||||||
|
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
|
||||||
|
batch_size = target.shape[0]
|
||||||
|
positive_flat = target.view(batch_size, -1)
|
||||||
|
negative_flat = cfm_pred.view(batch_size, -1)
|
||||||
|
# Calculate dot product
|
||||||
|
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||||
|
# Squared norm of the unconditional prediction
|
||||||
|
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||||
|
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||||
|
st_star = dot_product / squared_norm
|
||||||
|
|
||||||
|
alpha = st_star
|
||||||
|
|
||||||
|
is_video = len(target.shape) == 5
|
||||||
|
|
||||||
|
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
|
||||||
|
|
||||||
|
guidance_scale = self._guidance_loss_target_batch
|
||||||
|
if isinstance(guidance_scale, list):
|
||||||
|
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
|
||||||
|
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
|
||||||
|
|
||||||
|
unconditional_target = cfm_pred * alpha
|
||||||
|
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
||||||
|
|
||||||
|
|
||||||
if target is None:
|
if target is None:
|
||||||
@@ -895,6 +955,10 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if unconditional_embeds is not None:
|
if unconditional_embeds is not None:
|
||||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||||
|
|
||||||
|
guidance_embedding_scale = self.train_config.cfg_scale
|
||||||
|
if self.train_config.do_guidance_loss:
|
||||||
|
guidance_embedding_scale = self._guidance_loss_target_batch
|
||||||
|
|
||||||
prior_pred = self.sd.predict_noise(
|
prior_pred = self.sd.predict_noise(
|
||||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||||
@@ -902,6 +966,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
unconditional_embeddings=unconditional_embeds,
|
unconditional_embeddings=unconditional_embeds,
|
||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
guidance_scale=self.train_config.cfg_scale,
|
guidance_scale=self.train_config.cfg_scale,
|
||||||
|
guidance_embedding_scale=guidance_embedding_scale,
|
||||||
rescale_cfg=self.train_config.cfg_rescale,
|
rescale_cfg=self.train_config.cfg_rescale,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
**pred_kwargs # adapter residuals in here
|
**pred_kwargs # adapter residuals in here
|
||||||
@@ -945,13 +1010,16 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
guidance_embedding_scale = self.train_config.cfg_scale
|
||||||
|
if self.train_config.do_guidance_loss:
|
||||||
|
guidance_embedding_scale = self._guidance_loss_target_batch
|
||||||
return self.sd.predict_noise(
|
return self.sd.predict_noise(
|
||||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||||
unconditional_embeddings=unconditional_embeds,
|
unconditional_embeddings=unconditional_embeds,
|
||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
guidance_scale=self.train_config.cfg_scale,
|
guidance_scale=self.train_config.cfg_scale,
|
||||||
guidance_embedding_scale=self.train_config.cfg_scale,
|
guidance_embedding_scale=guidance_embedding_scale,
|
||||||
detach_unconditional=False,
|
detach_unconditional=False,
|
||||||
rescale_cfg=self.train_config.cfg_rescale,
|
rescale_cfg=self.train_config.cfg_rescale,
|
||||||
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
|
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
|
||||||
@@ -959,80 +1027,6 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def cfm_augment_tensors(
|
|
||||||
self,
|
|
||||||
images: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if self.cfm_cache is None:
|
|
||||||
# flip the current one. Only need this for first time
|
|
||||||
self.cfm_cache = torch.flip(images, [3]).clone()
|
|
||||||
augmented_tensor_list = []
|
|
||||||
for i in range(images.shape[0]):
|
|
||||||
# get a random one
|
|
||||||
idx = random.randint(0, self.cfm_cache.shape[0] - 1)
|
|
||||||
augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
|
|
||||||
augmented = torch.cat(augmented_tensor_list, dim=0)
|
|
||||||
# resize to match the input
|
|
||||||
augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
|
|
||||||
self.cfm_cache = images.clone()
|
|
||||||
return augmented
|
|
||||||
|
|
||||||
def get_cfm_loss(
|
|
||||||
self,
|
|
||||||
noisy_latents: torch.Tensor,
|
|
||||||
noise: torch.Tensor,
|
|
||||||
noise_pred: torch.Tensor,
|
|
||||||
conditional_embeds: PromptEmbeds,
|
|
||||||
timesteps: torch.Tensor,
|
|
||||||
batch: 'DataLoaderBatchDTO',
|
|
||||||
alpha: float = 0.1,
|
|
||||||
):
|
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
|
||||||
if hasattr(self.sd, 'get_loss_target'):
|
|
||||||
target = self.sd.get_loss_target(
|
|
||||||
noise=noise,
|
|
||||||
batch=batch,
|
|
||||||
timesteps=timesteps,
|
|
||||||
).detach()
|
|
||||||
|
|
||||||
elif self.sd.is_flow_matching:
|
|
||||||
# forward ODE
|
|
||||||
target = (noise - batch.latents).detach()
|
|
||||||
else:
|
|
||||||
raise ValueError("CFM loss only works with flow matching")
|
|
||||||
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
|
||||||
with torch.no_grad():
|
|
||||||
# we need to compute the contrast
|
|
||||||
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
|
|
||||||
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
|
|
||||||
cfm_noisy_latents = self.sd.add_noise(
|
|
||||||
original_samples=cfm_latents,
|
|
||||||
noise=noise,
|
|
||||||
timesteps=timesteps,
|
|
||||||
)
|
|
||||||
cfm_pred = self.predict_noise(
|
|
||||||
noisy_latents=cfm_noisy_latents,
|
|
||||||
timesteps=timesteps,
|
|
||||||
conditional_embeds=conditional_embeds,
|
|
||||||
unconditional_embeds=None,
|
|
||||||
batch=batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
|
|
||||||
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W)
|
|
||||||
|
|
||||||
# # Compute cosine similarity at each pixel
|
|
||||||
# sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W)
|
|
||||||
|
|
||||||
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
|
|
||||||
# Compute cosine similarity at each pixel
|
|
||||||
sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W)
|
|
||||||
|
|
||||||
# Average over spatial dimensions, then batch
|
|
||||||
contrastive_loss = -sim.mean()
|
|
||||||
|
|
||||||
loss = fm_loss.mean() + alpha * contrastive_loss
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -1658,6 +1652,16 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
)
|
)
|
||||||
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
|
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
|
||||||
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
|
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
|
||||||
|
|
||||||
|
if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list):
|
||||||
|
batch_size = noisy_latents.shape[0]
|
||||||
|
# update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1]
|
||||||
|
self._guidance_loss_target_batch = [
|
||||||
|
random.uniform(
|
||||||
|
self.train_config.guidance_loss_target[0],
|
||||||
|
self.train_config.guidance_loss_target[1]
|
||||||
|
) for _ in range(batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
self.before_unet_predict()
|
self.before_unet_predict()
|
||||||
|
|
||||||
@@ -1757,25 +1761,15 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
||||||
prior_to_calculate_loss = None
|
prior_to_calculate_loss = None
|
||||||
|
|
||||||
if self.train_config.loss_type == 'cfm':
|
loss = self.calculate_loss(
|
||||||
loss = self.get_cfm_loss(
|
noise_pred=noise_pred,
|
||||||
noisy_latents=noisy_latents,
|
noise=noise,
|
||||||
noise=noise,
|
noisy_latents=noisy_latents,
|
||||||
noise_pred=noise_pred,
|
timesteps=timesteps,
|
||||||
conditional_embeds=conditional_embeds,
|
batch=batch,
|
||||||
timesteps=timesteps,
|
mask_multiplier=mask_multiplier,
|
||||||
batch=batch,
|
prior_pred=prior_to_calculate_loss,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
loss = self.calculate_loss(
|
|
||||||
noise_pred=noise_pred,
|
|
||||||
noise=noise,
|
|
||||||
noisy_latents=noisy_latents,
|
|
||||||
timesteps=timesteps,
|
|
||||||
batch=batch,
|
|
||||||
mask_multiplier=mask_multiplier,
|
|
||||||
prior_pred=prior_to_calculate_loss,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.train_config.diff_output_preservation:
|
if self.train_config.diff_output_preservation:
|
||||||
# send the loss backwards otherwise checkpointing will fail
|
# send the loss backwards otherwise checkpointing will fail
|
||||||
|
|||||||
@@ -1012,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
imgs = None
|
imgs = None
|
||||||
is_reg = any(batch.get_is_reg_list())
|
is_reg = any(batch.get_is_reg_list())
|
||||||
cfm_batch = None
|
|
||||||
if batch.tensor is not None:
|
if batch.tensor is not None:
|
||||||
imgs = batch.tensor
|
imgs = batch.tensor
|
||||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||||
|
|||||||
@@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess):
|
|||||||
prompt_image_configs = []
|
prompt_image_configs = []
|
||||||
for _ in range(self.generate_config.num_repeats):
|
for _ in range(self.generate_config.num_repeats):
|
||||||
for prompt in self.generate_config.prompts:
|
for prompt in self.generate_config.prompts:
|
||||||
|
# strip every '--' from the prompt and trim surrounding whitespace
|
||||||
|
prompt = prompt.replace('--', '').strip()
|
||||||
width = self.generate_config.width
|
width = self.generate_config.width
|
||||||
height = self.generate_config.height
|
height = self.generate_config.height
|
||||||
# prompt = self.clean_prompt(prompt)
|
# prompt = self.clean_prompt(prompt)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
|
from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -413,7 +413,7 @@ class TrainConfig:
|
|||||||
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
||||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||||
|
|
||||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow
|
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow
|
||||||
|
|
||||||
# scale the prediction by this. Increase for more detail, decrease for less
|
# scale the prediction by this. Increase for more detail, decrease for less
|
||||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||||
@@ -467,6 +467,12 @@ class TrainConfig:
|
|||||||
# forces same noise for the same image at a given size.
|
# forces same noise for the same image at a given size.
|
||||||
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
|
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
|
||||||
self.blended_blur_noise = kwargs.get('blended_blur_noise', False)
|
self.blended_blur_noise = kwargs.get('blended_blur_noise', False)
|
||||||
|
|
||||||
|
# guidance loss
|
||||||
|
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
|
||||||
|
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
|
||||||
|
if isinstance(self.guidance_loss_target, tuple):
|
||||||
|
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||||
|
|
||||||
|
|
||||||
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
|
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
|
||||||
@@ -1145,3 +1151,6 @@ def validate_configs(
|
|||||||
if model_config.use_flux_cfg:
|
if model_config.use_flux_cfg:
|
||||||
# bypass the embedding
|
# bypass the embedding
|
||||||
train_config.bypass_guidance_embedding = True
|
train_config.bypass_guidance_embedding = True
|
||||||
|
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
|
||||||
|
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
|
||||||
|
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
|
||||||
|
|||||||
@@ -815,7 +815,10 @@ class BaseModel:
|
|||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
if self.unet.device != self.device_torch:
|
if self.unet.device != self.device_torch:
|
||||||
self.unet.to(self.device_torch)
|
try:
|
||||||
|
self.unet.to(self.device_torch)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
if self.unet.dtype != self.torch_dtype:
|
if self.unet.dtype != self.torch_dtype:
|
||||||
self.unet = self.unet.to(dtype=self.torch_dtype)
|
self.unet = self.unet.to(dtype=self.torch_dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user