From 6f308fc46e52f5470213182fa47c2fea73b84cf0 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 4 Nov 2025 09:16:15 -0700
Subject: [PATCH] When doing guidance loss, make CFG zero an optional target
 instead of a forced one.

---
 extensions_built_in/sd_trainer/SDTrainer.py | 40 +++++++++++----------
 toolkit/config_modules.py                   |  1 +
 2 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 59d75988..2233e073 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -673,39 +673,41 @@ class SDTrainer(BaseSDTrainProcess):
             unconditional_embeds = concat_prompt_embeds(
                 [self.unconditional_embeds] * noisy_latents.shape[0],
             )
-            cfm_pred = self.predict_noise(
+            unconditional_target = self.predict_noise(
                 noisy_latents=noisy_latents,
                 timesteps=timesteps,
                 conditional_embeds=unconditional_embeds,
                 unconditional_embeds=None,
                 batch=batch,
             )
-
-            # zero cfg
-
-            # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
-            batch_size = target.shape[0]
-            positive_flat = target.view(batch_size, -1)
-            negative_flat = cfm_pred.view(batch_size, -1)
-            # Calculate dot production
-            dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
-            # Squared norm of uncondition
-            squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
-            # st_star = v_cond^T * v_uncond / ||v_uncond||^2
-            st_star = dot_product / squared_norm
-
-            alpha = st_star
-            is_video = len(target.shape) == 5
-            alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
+            is_video = len(target.shape) == 5
+            if self.train_config.do_guidance_loss_cfg_zero:
+                # zero cfg
+                # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
+                batch_size = target.shape[0]
+                positive_flat = target.view(batch_size, -1)
+                negative_flat = unconditional_target.view(batch_size, -1)
+                # Calculate dot product
+                dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+                # Squared norm of the unconditional prediction
+                squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+                # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+                st_star = dot_product / squared_norm
+
+                alpha = st_star
+
+                alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
+            else:
+                alpha = 1.0
 
             guidance_scale = self._guidance_loss_target_batch
             if isinstance(guidance_scale, list):
                 guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
 
             guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
 
-            unconditional_target = cfm_pred * alpha
+            unconditional_target = unconditional_target * alpha
 
             target = unconditional_target + guidance_scale * (target - unconditional_target)
 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 6aeb9466..22a78189 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -541,6 +541,7 @@ class TrainConfig:
         # contrastive loss
         self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
         self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
+        self.do_guidance_loss_cfg_zero: bool = kwargs.get('do_guidance_loss_cfg_zero', False)
         self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
         if isinstance(self.guidance_loss_target, tuple):
             self.guidance_loss_target = list(self.guidance_loss_target)
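
With this change, the CFG-Zero-star rescaling of the unconditional prediction is opt-in
rather than always applied. A minimal sketch of turning the new flag on through
TrainConfig, assuming it is constructed from keyword args as the kwargs.get calls in the
diff suggest (the surrounding values are illustrative, not recommendations):

    from toolkit.config_modules import TrainConfig

    train_config = TrainConfig(
        do_guidance_loss=True,            # enable the guidance loss path
        guidance_loss_target=3.0,         # guidance scale target; an int, or a list per the type hint
        do_guidance_loss_cfg_zero=True,   # new flag; defaults to False
        unconditional_prompt='',          # prompt used for the unconditional prediction
    )

When the flag is left at its default of False, alpha stays 1.0, so the raw unconditional
prediction is used as the base of the target, i.e. standard classifier-free guidance
extrapolation; when True, the unconditional prediction is first scaled by st_star as in
CFG-Zero-star.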