mirror of https://github.com/ostris/ai-toolkit.git
Added "this not that" (tnt) guidance. Added the ability to replace prompts.
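
In this scheme, each batch pairs "this" (conditional) latents with "that" (unconditional) latents, runs both through the model under the same conditional prompt in one doubled batch, and optimizes roughly MSE(pred_this, noise) - MSE(pred_that, noise): the network is rewarded for denoising the wanted examples and penalized for reconstructing the unwanted ones.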
@@ -482,6 +482,87 @@ def get_guided_loss_polarity(
    return loss


def get_guided_tnt(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # noise the "this" (conditional) and "that" (unconditional) latents identically
        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()

        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

    # turn the LoRA network back on.
    sd.unet.train()
    if sd.network is not None:
        cat_network_weight_list = [weight for weight in network_weight_list * 2]
        sd.network.multiplier = cat_network_weight_list
        sd.network.is_active = True

    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )
    this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)

    # standard denoising loss on the "this" half
    this_loss = torch.nn.functional.mse_loss(
        this_prediction.float(),
        noise.float(),
        reduction="none"
    )

    # negated loss on the "that" half pushes its prediction away from the noise target
    that_loss = torch.nn.functional.mse_loss(
        that_prediction.float(),
        noise.float(),
        reduction="none"
    ) * -1.0

    loss = this_loss + that_loss

    loss = loss.mean([1, 2, 3])

    loss.backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
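
For illustration only, a minimal self-contained sketch of the signed loss arithmetic above; tnt_loss and the toy tensors are hypothetical stand-ins for the trainer's real model predictions and noise target:

import torch

# hypothetical helper mirroring the this/that arithmetic in get_guided_tnt
def tnt_loss(this_pred: torch.Tensor, that_pred: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # pull the "this" prediction toward the noise target
    this_loss = torch.nn.functional.mse_loss(this_pred.float(), noise.float(), reduction="none")
    # push the "that" prediction away from it (negated loss)
    that_loss = torch.nn.functional.mse_loss(that_pred.float(), noise.float(), reduction="none") * -1.0
    # per-sample mean over channel/height/width, matching loss.mean([1, 2, 3]) above
    return (this_loss + that_loss).mean([1, 2, 3])

# toy shapes: batch of 2, 4-channel 8x8 latents
noise = torch.randn(2, 4, 8, 8)
doubled_prediction = torch.randn(4, 4, 8, 8)  # stand-in for the doubled-batch model output
this_pred, that_pred = torch.chunk(doubled_prediction, 2, dim=0)
print(tnt_loss(this_pred, that_pred, noise))  # tensor of shape [2], one signed loss per sample
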
# this processes all guidance losses based on the batch information
def get_guidance_loss(
        noisy_latents: torch.Tensor,
@@ -529,6 +610,20 @@ def get_guidance_loss(
            sd,
            **kwargs
        )
    elif guidance_type == "tnt":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
        return get_guided_tnt(
            noisy_latents,
            conditional_embeds,
            match_adapter_assist,
            network_weight_list,
            timesteps,
            pred_kwargs,
            batch,
            noise,
            sd,
            **kwargs
        )

    elif guidance_type == "targeted_polarity":
        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
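
A distilled sketch of the dispatch pattern used here, with hypothetical simplified signatures (the real get_guidance_loss forwards noisy latents, embeds, timesteps, and more to each handler):

# hypothetical table-driven equivalent of the elif chain above
GUIDANCE_FNS = {
    "tnt": get_guided_tnt,
    "polarity": get_guided_loss_polarity,
}

def dispatch_guidance(guidance_type: str, *args, **kwargs):
    try:
        fn = GUIDANCE_FNS[guidance_type]
    except KeyError:
        raise ValueError(f"unknown guidance_type: {guidance_type}")
    return fn(*args, **kwargs)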