Imporvements to ip weight adaptation. Bug fixes. Added masking to direct guidance loss. Allow importing a file for random triggers. Handle bas meta images with improper sizing.

This commit is contained in:
Jaret Burkett
2024-01-11 12:22:16 -07:00
parent b2a54c8f36
commit 290393f7ae
5 changed files with 101 additions and 34 deletions

View File

@@ -194,6 +194,8 @@ def get_direct_guidance_loss(
noise: torch.Tensor,
sd: 'StableDiffusion',
unconditional_embeds: Optional[PromptEmbeds] = None,
mask_multiplier=None,
prior_pred=None,
**kwargs
):
with torch.no_grad():
@@ -248,6 +250,8 @@ def get_direct_guidance_loss(
noise.detach().float(),
reduction="none"
)
if mask_multiplier is not None:
guidance_loss = guidance_loss * mask_multiplier
guidance_loss = guidance_loss.mean([1, 2, 3])
@@ -489,6 +493,8 @@ def get_guidance_loss(
noise: torch.Tensor,
sd: 'StableDiffusion',
unconditional_embeds: Optional[PromptEmbeds] = None,
mask_multiplier=None,
prior_pred=None,
**kwargs
):
# TODO add others and process individual batch items separately
@@ -549,6 +555,8 @@ def get_guidance_loss(
noise,
sd,
unconditional_embeds=unconditional_embeds,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
**kwargs
)
else: