Added a guidance burning loss. Modified DFE to work with new model. Bug fixes

This commit is contained in:
Jaret Burkett
2025-06-23 08:38:27 -06:00
parent 8602470952
commit ba1274d99e
5 changed files with 106 additions and 99 deletions

View File

@@ -61,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess):
self._clip_image_embeds_unconditional: Union[List[str], None] = None
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.cfm_cache = None
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
@@ -84,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess):
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
self.dfe: Optional[DiffusionFeatureExtractor] = None
self.unconditional_embeds = None
if self.train_config.diff_output_preservation:
if self.trigger_word is None:
@@ -95,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess):
# always do a prior prediction when doing diff output preservation
self.do_prior_prediction = True
# store the loss target for a batch so we can use it in a loss
self._guidance_loss_target_batch: float = 0.0
if isinstance(self.train_config.guidance_loss_target, (int, float)):
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target)
elif isinstance(self.train_config.guidance_loss_target, list):
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0])
else:
raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")
def before_model_load(self):
@@ -135,6 +144,16 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self):
super().hook_before_train_loop()
# cache unconditional embeds (blank prompt)
with torch.no_grad():
self.unconditional_embeds = self.sd.encode_prompt(
[''],
long_prompts=self.do_long_prompts
).to(
self.device_torch,
dtype=self.sd.torch_dtype
).detach()
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
# move vae to device if we did not cache latents
@@ -476,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess):
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
else:
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
if self.train_config.do_guidance_loss:
with torch.no_grad():
# we make cached blank prompt embeds that match the batch size
unconditional_embeds = concat_prompt_embeds(
[self.unconditional_embeds] * noisy_latents.shape[0],
)
cfm_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=timesteps,
conditional_embeds=unconditional_embeds,
unconditional_embeds=None,
batch=batch,
)
# zero cfg
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
batch_size = target.shape[0]
positive_flat = target.view(batch_size, -1)
negative_flat = cfm_pred.view(batch_size, -1)
# Calculate the dot product
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of the unconditional prediction
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
alpha = st_star
is_video = len(target.shape) == 5
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
guidance_scale = self._guidance_loss_target_batch
if isinstance(guidance_scale, list):
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
unconditional_target = cfm_pred * alpha
target = unconditional_target + guidance_scale * (target - unconditional_target)
if target is None:
@@ -895,6 +955,10 @@ class SDTrainer(BaseSDTrainProcess):
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
guidance_embedding_scale = self.train_config.cfg_scale
if self.train_config.do_guidance_loss:
guidance_embedding_scale = self._guidance_loss_target_batch
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
@@ -902,6 +966,7 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
guidance_embedding_scale=guidance_embedding_scale,
rescale_cfg=self.train_config.cfg_rescale,
batch=batch,
**pred_kwargs # adapter residuals in here
@@ -945,13 +1010,16 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs,
):
dtype = get_torch_dtype(self.train_config.dtype)
guidance_embedding_scale = self.train_config.cfg_scale
if self.train_config.do_guidance_loss:
guidance_embedding_scale = self._guidance_loss_target_batch
return self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
guidance_embedding_scale=self.train_config.cfg_scale,
guidance_embedding_scale=guidance_embedding_scale,
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
@@ -959,80 +1027,6 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
)
def cfm_augment_tensors(
    self,
    images: torch.Tensor
) -> torch.Tensor:
    """Build a "contrast" batch from the previously cached batch.

    On the first call there is no cached batch yet, so a width-flipped
    copy of *images* seeds the cache. One cached sample is drawn at
    random per input image, the picks are stacked, resized to the
    spatial size of *images*, and the cache is then replaced with a
    copy of *images* for the next call.
    """
    if self.cfm_cache is None:
        # No previous batch yet: seed the cache with a horizontal flip
        # of the current one (assumes 4D NCHW input — flip is on dim 3).
        self.cfm_cache = torch.flip(images, [3]).clone()
    cache = self.cfm_cache
    batch_size = images.shape[0]
    # Draw one random cached sample per input image; slicing with
    # [j:j + 1] keeps each pick 4D so they can be concatenated.
    picks = [
        cache[j:j + 1]
        for j in (
            random.randint(0, cache.shape[0] - 1) for _ in range(batch_size)
        )
    ]
    augmented = torch.cat(picks, dim=0)
    # Match the spatial resolution of the incoming batch.
    augmented = torch.nn.functional.interpolate(
        augmented,
        size=(images.shape[2], images.shape[3]),
        mode='bilinear',
    )
    # Remember the current batch so it serves as contrast next time.
    self.cfm_cache = images.clone()
    return augmented
def get_cfm_loss(
self,
noisy_latents: torch.Tensor,
noise: torch.Tensor,
noise_pred: torch.Tensor,
conditional_embeds: PromptEmbeds,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
alpha: float = 0.1,
):
"""Contrastive flow-matching (CFM) loss.

Combines a standard flow-matching MSE term with a contrastive term
that pushes the model's prediction away from its prediction on an
unrelated ("contrast") image noised with the same noise/timesteps:
``loss = mean(mse(noise_pred, target)) + alpha * (-mean(cos_sim))``.

Raises:
    ValueError: if the model is not flow matching and exposes no
        ``get_loss_target`` helper.
"""
dtype = get_torch_dtype(self.train_config.dtype)
# Prefer a model-supplied loss target when the wrapper provides one.
if hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,
batch=batch,
timesteps=timesteps,
).detach()
elif self.sd.is_flow_matching:
# forward ODE: flow-matching velocity target is (noise - clean latents)
target = (noise - batch.latents).detach()
else:
raise ValueError("CFM loss only works with flow matching")
# Per-element flow-matching loss; reduced with .mean() at the end.
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
with torch.no_grad():
# we need to compute the contrast: a different image (from the
# previous batch, via cfm_augment_tensors) encoded and noised with
# the SAME noise and timesteps as the real sample.
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
cfm_noisy_latents = self.sd.add_noise(
original_samples=cfm_latents,
noise=noise,
timesteps=timesteps,
)
# NOTE(review): the diff rendering lost indentation here — it is unclear
# whether this predict_noise call is inside the no_grad block above.
# That decides whether gradients flow through cfm_pred; confirm against
# the original source before relying on this.
cfm_pred = self.predict_noise(
noisy_latents=cfm_noisy_latents,
timesteps=timesteps,
conditional_embeds=conditional_embeds,
unconditional_embeds=None,
batch=batch,
)
# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W)
# # Compute cosine similarity at each pixel
# sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W)
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
# Compute cosine similarity at each pixel (across the channel dim)
sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W)
# Average over spatial dimensions, then batch. Negated so that the
# optimizer DEcreases similarity with the contrast prediction.
contrastive_loss = -sim.mean()
# alpha weights the contrastive term against the flow-matching term.
loss = fm_loss.mean() + alpha * contrastive_loss
return loss
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
with torch.no_grad():
@@ -1658,6 +1652,16 @@ class SDTrainer(BaseSDTrainProcess):
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list):
batch_size = noisy_latents.shape[0]
# update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1]
self._guidance_loss_target_batch = [
random.uniform(
self.train_config.guidance_loss_target[0],
self.train_config.guidance_loss_target[1]
) for _ in range(batch_size)
]
self.before_unet_predict()
@@ -1757,25 +1761,15 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
prior_to_calculate_loss = None
if self.train_config.loss_type == 'cfm':
loss = self.get_cfm_loss(
noisy_latents=noisy_latents,
noise=noise,
noise_pred=noise_pred,
conditional_embeds=conditional_embeds,
timesteps=timesteps,
batch=batch,
)
else:
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
if self.train_config.diff_output_preservation:
# send the loss backwards otherwise checkpointing will fail

View File

@@ -1012,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
cfm_batch = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)

View File

@@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess):
prompt_image_configs = []
for _ in range(self.generate_config.num_repeats):
for prompt in self.generate_config.prompts:
# remove --
prompt = prompt.replace('--', '').strip()
width = self.generate_config.width
height = self.generate_config.height
# prompt = self.clean_prompt(prompt)

View File

@@ -1,6 +1,6 @@
import os
import time
from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
import random
import torch
@@ -413,7 +413,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
@@ -467,6 +467,12 @@ class TrainConfig:
# forces same noise for the same image at a given size.
self.force_consistent_noise = kwargs.get('force_consistent_noise', False)
self.blended_blur_noise = kwargs.get('blended_blur_noise', False)
# contrastive loss
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
@@ -1145,3 +1151,6 @@ def validate_configs(
if model_config.use_flux_cfg:
# bypass the embedding
train_config.bypass_guidance_embedding = True
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")

View File

@@ -815,7 +815,10 @@ class BaseModel:
# predict the noise residual
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
try:
self.unet.to(self.device_torch)
except Exception as e:
pass
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)