mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Imporvements to ip weight adaptation. Bug fixes. Added masking to direct guidance loss. Allow importing a file for random triggers. Handle bas meta images with improper sizing.
This commit is contained in:
@@ -194,6 +194,8 @@ def get_direct_guidance_loss(
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||
mask_multiplier=None,
|
||||
prior_pred=None,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
@@ -248,6 +250,8 @@ def get_direct_guidance_loss(
|
||||
noise.detach().float(),
|
||||
reduction="none"
|
||||
)
|
||||
if mask_multiplier is not None:
|
||||
guidance_loss = guidance_loss * mask_multiplier
|
||||
|
||||
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||
|
||||
@@ -489,6 +493,8 @@ def get_guidance_loss(
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||
mask_multiplier=None,
|
||||
prior_pred=None,
|
||||
**kwargs
|
||||
):
|
||||
# TODO add others and process individual batch items separately
|
||||
@@ -549,6 +555,8 @@ def get_guidance_loss(
|
||||
noise,
|
||||
sd,
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_pred,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user