Adjustments to guidance

This commit is contained in:
Jaret Burkett
2024-04-19 15:00:35 -06:00
parent 2d0a1be59d
commit 377b81ee3e
2 changed files with 48 additions and 6 deletions

View File

@@ -482,6 +482,7 @@ def get_guided_loss_polarity(
return loss
def get_guided_tnt(
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
@@ -492,6 +493,7 @@ def get_guided_tnt(
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
prior_pred: torch.Tensor = None,
**kwargs
):
dtype = get_torch_dtype(sd.torch_dtype)
@@ -528,6 +530,7 @@ def get_guided_tnt(
sd.network.multiplier = cat_network_weight_list
sd.network.is_active = True
prediction = sd.predict_noise(
latents=cat_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
@@ -547,11 +550,41 @@ def get_guided_tnt(
that_prediction.float(),
noise.float(),
reduction="none"
) * -1.0
)
loss = this_loss + that_loss
with torch.no_grad():
tnt_loss = this_loss - that_loss
loss = loss.mean([1, 2, 3])
# create a mask by scaling loss from 0 to mean to 1 to 0
# this will act to regularize unchanged areas to prior prediction
loss_min = tnt_loss.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
loss_mean = tnt_loss.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
mask = value_map(
torch.abs(tnt_loss),
loss_min,
loss_mean,
0.0,
1.0
).clamp(0.0, 1.0).detach()
prior_mask = 1.0 - mask
this_loss = this_loss * mask
that_loss = that_loss * prior_mask
this_loss = this_loss.mean([1, 2, 3])
that_loss = that_loss.mean([1, 2, 3])
prior_loss = torch.nn.functional.mse_loss(
this_prediction.float(),
prior_pred.detach().float(),
reduction="none"
)
prior_loss = prior_loss * prior_mask
prior_loss = prior_loss.mean([1, 2, 3])
loss = prior_loss + this_loss - that_loss
loss.backward()
@@ -612,7 +645,7 @@ def get_guidance_loss(
)
elif guidance_type == "tnt":
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
return get_guided_loss_polarity(
return get_guided_tnt(
noisy_latents,
conditional_embeds,
match_adapter_assist,
@@ -622,6 +655,7 @@ def get_guidance_loss(
batch,
noise,
sd,
prior_pred=prior_pred,
**kwargs
)