mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
When soing guidance loss, make CFG zero an optional target instead of a forced one.
This commit is contained in:
@@ -673,39 +673,40 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
unconditional_embeds = concat_prompt_embeds(
|
||||
[self.unconditional_embeds] * noisy_latents.shape[0],
|
||||
)
|
||||
cfm_pred = self.predict_noise(
|
||||
unconditional_target = self.predict_noise(
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=unconditional_embeds,
|
||||
unconditional_embeds=None,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
# zero cfg
|
||||
|
||||
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
|
||||
batch_size = target.shape[0]
|
||||
positive_flat = target.view(batch_size, -1)
|
||||
negative_flat = cfm_pred.view(batch_size, -1)
|
||||
# Calculate dot production
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
# Squared norm of uncondition
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
st_star = dot_product / squared_norm
|
||||
|
||||
alpha = st_star
|
||||
|
||||
is_video = len(target.shape) == 5
|
||||
|
||||
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
|
||||
if self.train_config.do_guidance_loss_cfg_zero:
|
||||
# zero cfg
|
||||
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
|
||||
batch_size = target.shape[0]
|
||||
positive_flat = target.view(batch_size, -1)
|
||||
negative_flat = unconditional_target.view(batch_size, -1)
|
||||
# Calculate dot production
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
# Squared norm of uncondition
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
st_star = dot_product / squared_norm
|
||||
|
||||
alpha = st_star
|
||||
|
||||
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
guidance_scale = self._guidance_loss_target_batch
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
|
||||
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
|
||||
|
||||
unconditional_target = cfm_pred * alpha
|
||||
unconditional_target = unconditional_target * alpha
|
||||
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
||||
|
||||
|
||||
|
||||
@@ -541,6 +541,7 @@ class TrainConfig:
|
||||
# contrastive loss
|
||||
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
|
||||
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
|
||||
self.do_guidance_loss_cfg_zero: bool = kwargs.get('do_guidance_loss_cfg_zero', False)
|
||||
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
||||
if isinstance(self.guidance_loss_target, tuple):
|
||||
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||
|
||||
Reference in New Issue
Block a user