Add model hooks to polarity loss

This commit is contained in:
Jaret Burkett
2025-04-17 09:00:10 -06:00
parent 5961ef6c9f
commit c90615f8bb

View File

@@ -453,12 +453,14 @@ def get_guided_loss_polarity(
noise,
timesteps
).detach()
conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch)
unconditional_noisy_latents = sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch)
# double up everything to run it through all at once
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])