mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
When soing guidance loss, make CFG zero an optional target instead of a forced one.
This commit is contained in:
@@ -673,20 +673,21 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
unconditional_embeds = concat_prompt_embeds(
|
unconditional_embeds = concat_prompt_embeds(
|
||||||
[self.unconditional_embeds] * noisy_latents.shape[0],
|
[self.unconditional_embeds] * noisy_latents.shape[0],
|
||||||
)
|
)
|
||||||
cfm_pred = self.predict_noise(
|
unconditional_target = self.predict_noise(
|
||||||
noisy_latents=noisy_latents,
|
noisy_latents=noisy_latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
conditional_embeds=unconditional_embeds,
|
conditional_embeds=unconditional_embeds,
|
||||||
unconditional_embeds=None,
|
unconditional_embeds=None,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
is_video = len(target.shape) == 5
|
||||||
|
|
||||||
|
if self.train_config.do_guidance_loss_cfg_zero:
|
||||||
# zero cfg
|
# zero cfg
|
||||||
|
|
||||||
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
|
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
|
||||||
batch_size = target.shape[0]
|
batch_size = target.shape[0]
|
||||||
positive_flat = target.view(batch_size, -1)
|
positive_flat = target.view(batch_size, -1)
|
||||||
negative_flat = cfm_pred.view(batch_size, -1)
|
negative_flat = unconditional_target.view(batch_size, -1)
|
||||||
# Calculate dot production
|
# Calculate dot production
|
||||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||||
# Squared norm of uncondition
|
# Squared norm of uncondition
|
||||||
@@ -696,16 +697,16 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
alpha = st_star
|
alpha = st_star
|
||||||
|
|
||||||
is_video = len(target.shape) == 5
|
|
||||||
|
|
||||||
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
|
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
guidance_scale = self._guidance_loss_target_batch
|
guidance_scale = self._guidance_loss_target_batch
|
||||||
if isinstance(guidance_scale, list):
|
if isinstance(guidance_scale, list):
|
||||||
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
|
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
|
||||||
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
|
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
|
||||||
|
|
||||||
unconditional_target = cfm_pred * alpha
|
unconditional_target = unconditional_target * alpha
|
||||||
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -541,6 +541,7 @@ class TrainConfig:
|
|||||||
# contrastive loss
|
# contrastive loss
|
||||||
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
|
self.do_guidance_loss = kwargs.get('do_guidance_loss', False)
|
||||||
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
|
self.guidance_loss_target: Union[int, List[int, int]] = kwargs.get('guidance_loss_target', 3.0)
|
||||||
|
self.do_guidance_loss_cfg_zero: bool = kwargs.get('do_guidance_loss_cfg_zero', False)
|
||||||
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
|
||||||
if isinstance(self.guidance_loss_target, tuple):
|
if isinstance(self.guidance_loss_target, tuple):
|
||||||
self.guidance_loss_target = list(self.guidance_loss_target)
|
self.guidance_loss_target = list(self.guidance_loss_target)
|
||||||
|
|||||||
Reference in New Issue
Block a user