Added this not that guidance. Added ability to replace prompts.

This commit is contained in:
Jaret Burkett
2024-02-28 20:10:14 -07:00
parent 561914d8e6
commit 337945de9a
7 changed files with 114 additions and 5 deletions

View File

@@ -482,6 +482,87 @@ def get_guided_loss_polarity(
return loss
def get_guided_tnt(
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
**kwargs
):
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch
with torch.no_grad():
dtype = get_torch_dtype(dtype)
noise = noise.to(device, dtype=dtype).detach()
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
conditional_noisy_latents = sd.add_noise(
conditional_latents,
noise,
timesteps
).detach()
unconditional_noisy_latents = sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
# double up everything to run it through all at once
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
# turn the LoRA network back on.
sd.unet.train()
if sd.network is not None:
cat_network_weight_list = [weight for weight in network_weight_list * 2]
sd.network.multiplier = cat_network_weight_list
sd.network.is_active = True
prediction = sd.predict_noise(
latents=cat_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
timestep=cat_timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)
this_loss = torch.nn.functional.mse_loss(
this_prediction.float(),
noise.float(),
reduction="none"
)
that_loss = torch.nn.functional.mse_loss(
that_prediction.float(),
noise.float(),
reduction="none"
) * -1.0
loss = this_loss + that_loss
loss = loss.mean([1, 2, 3])
loss.backward()
# detach it so parent class can run backward on no grads without throwing error
loss = loss.detach()
loss.requires_grad_(True)
return loss
# this processes all guidance losses based on the batch information
def get_guidance_loss(
noisy_latents: torch.Tensor,
@@ -529,6 +610,20 @@ def get_guidance_loss(
sd,
**kwargs
)
elif guidance_type == "tnt":
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
return get_guided_loss_polarity(
noisy_latents,
conditional_embeds,
match_adapter_assist,
network_weight_list,
timesteps,
pred_kwargs,
batch,
noise,
sd,
**kwargs
)
elif guidance_type == "targeted_polarity":
assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"