Added ability to do cfg during training. Various bug fixes

This commit is contained in:
Jaret Burkett
2024-01-02 11:29:57 -07:00
parent afc231efc1
commit 65c08b09c3
4 changed files with 71 additions and 4 deletions

View File

@@ -1,5 +1,5 @@
import torch
from typing import Literal
from typing import Literal, Optional
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
@@ -193,6 +193,7 @@ def get_direct_guidance_loss(
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
with torch.no_grad():
@@ -222,9 +223,14 @@ def get_direct_guidance_loss(
# sd.network.multiplier = network_weight_list
# do our prediction with LoRA active on the scaled guidance latents
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(device, dtype=dtype).detach()
unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
prediction = sd.predict_noise(
latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(),
conditional_embeddings=concat_prompt_embeds([conditional_embeds,conditional_embeds]).to(device, dtype=dtype).detach(),
unconditional_embeddings=unconditional_embeds,
timestep=torch.cat([timesteps, timesteps]),
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
@@ -482,12 +488,14 @@ def get_guidance_loss(
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
# TODO add others and process individual batch items separately
guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
if guidance_type == "targeted":
assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
return get_targeted_guidance_loss(
noisy_latents,
conditional_embeds,
@@ -501,6 +509,7 @@ def get_guidance_loss(
**kwargs
)
elif guidance_type == "polarity":
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
return get_guided_loss_polarity(
noisy_latents,
conditional_embeds,
@@ -515,6 +524,7 @@ def get_guidance_loss(
)
elif guidance_type == "targeted_polarity":
assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
return get_targeted_polarity_loss(
noisy_latents,
conditional_embeds,
@@ -538,6 +548,7 @@ def get_guidance_loss(
batch,
noise,
sd,
unconditional_embeds=unconditional_embeds,
**kwargs
)
else: