When doing guidance loss, make CFG zero an optional target instead of a forced one.

Jaret Burkett
2025-11-04 09:16:15 -07:00
parent c984369294
commit 6f308fc46e
2 changed files with 21 additions and 19 deletions


@@ -673,20 +673,21 @@ class SDTrainer(BaseSDTrainProcess):
 unconditional_embeds = concat_prompt_embeds(
     [self.unconditional_embeds] * noisy_latents.shape[0],
 )
-cfm_pred = self.predict_noise(
+unconditional_target = self.predict_noise(
     noisy_latents=noisy_latents,
     timesteps=timesteps,
     conditional_embeds=unconditional_embeds,
     unconditional_embeds=None,
     batch=batch,
 )
+is_video = len(target.shape) == 5
+if self.train_config.do_guidance_loss_cfg_zero:
     # zero cfg
     # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
     batch_size = target.shape[0]
     positive_flat = target.view(batch_size, -1)
-    negative_flat = cfm_pred.view(batch_size, -1)
+    negative_flat = unconditional_target.view(batch_size, -1)
     # Calculate dot product
     dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
     # Squared norm of the unconditional prediction
@@ -696,16 +697,16 @@ class SDTrainer(BaseSDTrainProcess):
     alpha = st_star
-    is_video = len(target.shape) == 5
     alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
+else:
+    alpha = 1.0
 guidance_scale = self._guidance_loss_target_batch
 if isinstance(guidance_scale, list):
     guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
     guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
-unconditional_target = cfm_pred * alpha
+unconditional_target = unconditional_target * alpha
 target = unconditional_target + guidance_scale * (target - unconditional_target)
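
For readers following along outside the diff, below is a minimal self-contained sketch of the target computation after this change. The names mirror the diff; the squared_norm and st_star lines fall between the two hunks above and are reconstructed from the linked CFG-Zero-star reference, so treat them, and the build_guidance_target wrapper itself, as illustrative assumptions rather than verbatim project code.

import torch

def build_guidance_target(target, unconditional_target, guidance_scale, do_cfg_zero):
    # target: conditional prediction; unconditional_target: unconditional prediction.
    # guidance_scale: a float or a per-sample list, as in _guidance_loss_target_batch.
    is_video = len(target.shape) == 5
    if do_cfg_zero:
        batch_size = target.shape[0]
        positive_flat = target.view(batch_size, -1)
        negative_flat = unconditional_target.view(batch_size, -1)
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        # CFG-Zero* projection coefficient s* = <cond, uncond> / ||uncond||^2
        # (assumed from the referenced wan_pipeline.py; not shown in the hunks above)
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
        alpha = dot_product / squared_norm
        alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
    else:
        # CFG zero disabled: leave the unconditional prediction unscaled
        alpha = 1.0
    if isinstance(guidance_scale, list):
        guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
        guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
    unconditional_target = unconditional_target * alpha
    return unconditional_target + guidance_scale * (target - unconditional_target)

With do_cfg_zero=False this reduces to the plain CFG extrapolation the trainer used before, which is what makes the CFG-zero rescaling optional rather than forced.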


@@ -541,6 +541,7 @@ class TrainConfig:
 # contrastive loss
 self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
 self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
+self.do_guidance_loss_cfg_zero: bool = kwargs.get('do_guidance_loss_cfg_zero', False)
 self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
 if isinstance(self.guidance_loss_target, tuple):
     self.guidance_loss_target = list(self.guidance_loss_target)
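
Because the new flag is read with kwargs.get(..., False), existing configs keep their previous behavior. A hypothetical sketch of opting in, assuming TrainConfig is constructed from keyword arguments as the kwargs.get calls suggest:

config = TrainConfig(
    do_guidance_loss=True,             # enable the guidance loss itself
    guidance_loss_target=3.0,          # CFG scale used to build the loss target
    do_guidance_loss_cfg_zero=True,    # new: opt in to CFG-zero rescaling
    unconditional_prompt='',
)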