From ba1274d99e369f9ee6acee9997cd2d6e8e26d853 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 23 Jun 2025 08:38:27 -0600 Subject: [PATCH] Added a guidance burning loss. Modified DFE to work with new model. Bug fixes --- extensions_built_in/sd_trainer/SDTrainer.py | 184 ++++++++++---------- jobs/process/BaseSDTrainProcess.py | 1 - jobs/process/GenerateProcess.py | 2 + toolkit/config_modules.py | 13 +- toolkit/models/base_model.py | 5 +- 5 files changed, 106 insertions(+), 99 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 1b9f35b1..1722f236 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -61,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess): self._clip_image_embeds_unconditional: Union[List[str], None] = None self.negative_prompt_pool: Union[List[str], None] = None self.batch_negative_prompt: Union[List[str], None] = None - self.cfm_cache = None self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" @@ -84,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess): self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None self.dfe: Optional[DiffusionFeatureExtractor] = None + self.unconditional_embeds = None if self.train_config.diff_output_preservation: if self.trigger_word is None: @@ -95,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess): # always do a prior prediction when doing diff output preservation self.do_prior_prediction = True + + # store the loss target for a batch so we can use it in a loss + self._guidance_loss_target_batch: float = 0.0 + if isinstance(self.train_config.guidance_loss_target, (int, float)): + self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target) + elif isinstance(self.train_config.guidance_loss_target, list): + self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0]) + else: + raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}") def before_model_load(self): @@ -135,6 +144,16 @@ class SDTrainer(BaseSDTrainProcess): def hook_before_train_loop(self): super().hook_before_train_loop() + # cache unconditional embeds (blank prompt) + with torch.no_grad(): + self.unconditional_embeds = self.sd.encode_prompt( + [''], + long_prompts=self.do_long_prompts + ).to( + self.device_torch, + dtype=self.sd.torch_dtype + ).detach() + if self.train_config.do_prior_divergence: self.do_prior_prediction = True # move vae to device if we did not cache latents @@ -476,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess): additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight else: raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}") + + if self.train_config.do_guidance_loss: + with torch.no_grad(): + # we make cached blank prompt embeds that match the batch size + unconditional_embeds = concat_prompt_embeds( + [self.unconditional_embeds] * noisy_latents.shape[0], + ) + cfm_pred = self.predict_noise( + noisy_latents=noisy_latents, + timesteps=timesteps, + conditional_embeds=unconditional_embeds, + unconditional_embeds=None, + batch=batch, + ) + + # zero cfg + + # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557 + batch_size = target.shape[0] + positive_flat = target.view(batch_size, -1) + negative_flat = cfm_pred.view(batch_size, -1) + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + alpha = st_star + + is_video = len(target.shape) == 5 + + alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1) + + guidance_scale = self._guidance_loss_target_batch + if isinstance(guidance_scale, list): + guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype) + guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1) + + unconditional_target = cfm_pred * alpha + target = unconditional_target + guidance_scale * (target - unconditional_target) if target is None: @@ -895,6 +955,10 @@ class SDTrainer(BaseSDTrainProcess): if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch prior_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), @@ -902,6 +966,7 @@ class SDTrainer(BaseSDTrainProcess): unconditional_embeddings=unconditional_embeds, timestep=timesteps, guidance_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, rescale_cfg=self.train_config.cfg_rescale, batch=batch, **pred_kwargs # adapter residuals in here @@ -945,13 +1010,16 @@ class SDTrainer(BaseSDTrainProcess): **kwargs, ): dtype = get_torch_dtype(self.train_config.dtype) + guidance_embedding_scale = self.train_config.cfg_scale + if self.train_config.do_guidance_loss: + guidance_embedding_scale = self._guidance_loss_target_batch return self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), unconditional_embeddings=unconditional_embeds, timestep=timesteps, guidance_scale=self.train_config.cfg_scale, - guidance_embedding_scale=self.train_config.cfg_scale, + guidance_embedding_scale=guidance_embedding_scale, detach_unconditional=False, rescale_cfg=self.train_config.cfg_rescale, bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, @@ -959,80 +1027,6 @@ class SDTrainer(BaseSDTrainProcess): **kwargs ) - def cfm_augment_tensors( - self, - images: torch.Tensor - ) -> torch.Tensor: - if self.cfm_cache is None: - # flip the current one. Only need this for first time - self.cfm_cache = torch.flip(images, [3]).clone() - augmented_tensor_list = [] - for i in range(images.shape[0]): - # get a random one - idx = random.randint(0, self.cfm_cache.shape[0] - 1) - augmented_tensor_list.append(self.cfm_cache[idx:idx + 1]) - augmented = torch.cat(augmented_tensor_list, dim=0) - # resize to match the input - augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear') - self.cfm_cache = images.clone() - return augmented - - def get_cfm_loss( - self, - noisy_latents: torch.Tensor, - noise: torch.Tensor, - noise_pred: torch.Tensor, - conditional_embeds: PromptEmbeds, - timesteps: torch.Tensor, - batch: 'DataLoaderBatchDTO', - alpha: float = 0.1, - ): - dtype = get_torch_dtype(self.train_config.dtype) - if hasattr(self.sd, 'get_loss_target'): - target = self.sd.get_loss_target( - noise=noise, - batch=batch, - timesteps=timesteps, - ).detach() - - elif self.sd.is_flow_matching: - # forward ODE - target = (noise - batch.latents).detach() - else: - raise ValueError("CFM loss only works with flow matching") - fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - with torch.no_grad(): - # we need to compute the contrast - cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype) - cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype) - cfm_noisy_latents = self.sd.add_noise( - original_samples=cfm_latents, - noise=noise, - timesteps=timesteps, - ) - cfm_pred = self.predict_noise( - noisy_latents=cfm_noisy_latents, - timesteps=timesteps, - conditional_embeds=conditional_embeds, - unconditional_embeds=None, - batch=batch, - ) - - # v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1) - # v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W) - - # # Compute cosine similarity at each pixel - # sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W) - - cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) - # Compute cosine similarity at each pixel - sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W) - - # Average over spatial dimensions, then batch - contrastive_loss = -sim.mean() - - loss = fm_loss.mean() + alpha * contrastive_loss - return loss def train_single_accumulation(self, batch: DataLoaderBatchDTO): with torch.no_grad(): @@ -1658,6 +1652,16 @@ class SDTrainer(BaseSDTrainProcess): ) pred_kwargs['down_block_additional_residuals'] = down_block_res_samples pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + + if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list): + batch_size = noisy_latents.shape[0] + # update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1] + self._guidance_loss_target_batch = [ + random.uniform( + self.train_config.guidance_loss_target[0], + self.train_config.guidance_loss_target[1] + ) for _ in range(batch_size) + ] self.before_unet_predict() @@ -1757,25 +1761,15 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diff_output_preservation and not do_inverted_masked_prior: prior_to_calculate_loss = None - if self.train_config.loss_type == 'cfm': - loss = self.get_cfm_loss( - noisy_latents=noisy_latents, - noise=noise, - noise_pred=noise_pred, - conditional_embeds=conditional_embeds, - timesteps=timesteps, - batch=batch, - ) - else: - loss = self.calculate_loss( - noise_pred=noise_pred, - noise=noise, - noisy_latents=noisy_latents, - timesteps=timesteps, - batch=batch, - mask_multiplier=mask_multiplier, - prior_pred=prior_to_calculate_loss, - ) + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_to_calculate_loss, + ) if self.train_config.diff_output_preservation: # send the loss backwards otherwise checkpointing will fail diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 1528e772..814796e6 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1012,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) imgs = None is_reg = any(batch.get_is_reg_list()) - cfm_batch = None if batch.tensor is not None: imgs = batch.tensor imgs = imgs.to(self.device_torch, dtype=dtype) diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py index e0cb32d8..44fc6b28 100644 --- a/jobs/process/GenerateProcess.py +++ b/jobs/process/GenerateProcess.py @@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess): prompt_image_configs = [] for _ in range(self.generate_config.num_repeats): for prompt in self.generate_config.prompts: + # remove -- + prompt = prompt.replace('--', '').strip() width = self.generate_config.width height = self.generate_config.height # prompt = self.clean_prompt(prompt) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index c682b588..b53b6613 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,6 +1,6 @@ import os import time -from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict +from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict import random import torch @@ -413,7 +413,7 @@ class TrainConfig: self.correct_pred_norm = kwargs.get('correct_pred_norm', False) self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) - self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow + self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow # scale the prediction by this. Increase for more detail, decrease for less self.pred_scaler = kwargs.get('pred_scaler', 1.0) @@ -467,6 +467,12 @@ class TrainConfig: # forces same noise for the same image at a given size. self.force_consistent_noise = kwargs.get('force_consistent_noise', False) self.blended_blur_noise = kwargs.get('blended_blur_noise', False) + + # contrastive loss + self.do_guidance_loss = kwargs.get('do_guidance_loss', False) + self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0) + if isinstance(self.guidance_loss_target, tuple): + self.guidance_loss_target = list(self.guidance_loss_target) ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] @@ -1145,3 +1151,6 @@ def validate_configs( if model_config.use_flux_cfg: # bypass the embedding train_config.bypass_guidance_embedding = True + if train_config.bypass_guidance_embedding and train_config.do_guidance_loss: + raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. " + "Please set bypass_guidance_embedding to False or do_guidance_loss to False.") diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 17550dde..17dc6b55 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -815,7 +815,10 @@ class BaseModel: # predict the noise residual if self.unet.device != self.device_torch: - self.unet.to(self.device_torch) + try: + self.unet.to(self.device_torch) + except Exception as e: + pass if self.unet.dtype != self.torch_dtype: self.unet = self.unet.to(dtype=self.torch_dtype)