mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Targeted guidance work
This commit is contained in:
@@ -206,63 +206,31 @@ def get_targeted_guidance_loss(
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
# apply random offset to both latents
|
||||
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
offset = offset * 0.1
|
||||
conditional_latents = conditional_latents + offset
|
||||
unconditional_latents = unconditional_latents + offset
|
||||
|
||||
# get random scale 0f 0.8 to 1.2
|
||||
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
scale = scale * 0.4
|
||||
scale = scale + 0.8
|
||||
conditional_latents = conditional_latents * scale
|
||||
unconditional_latents = unconditional_latents * scale
|
||||
|
||||
diff_mask = get_differential_mask(
|
||||
conditional_latents,
|
||||
unconditional_latents,
|
||||
threshold=0.2,
|
||||
gradient=True
|
||||
)
|
||||
|
||||
# standardize inpute to std of 1
|
||||
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
||||
# # apply random offset to both latents
|
||||
# offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
# offset = offset * 0.1
|
||||
# conditional_latents = conditional_latents + offset
|
||||
# unconditional_latents = unconditional_latents + offset
|
||||
#
|
||||
# # scale the latents to std of 1
|
||||
# conditional_latents = conditional_latents / combo_std
|
||||
# unconditional_latents = unconditional_latents / combo_std
|
||||
# # get random scale 0f 0.8 to 1.2
|
||||
# scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
# scale = scale * 0.4
|
||||
# scale = scale + 0.8
|
||||
# conditional_latents = conditional_latents * scale
|
||||
# unconditional_latents = unconditional_latents * scale
|
||||
|
||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||
|
||||
|
||||
|
||||
# get a -0.5 to 0.5 multiplier for the diff noise
|
||||
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
# noise_multiplier = noise_multiplier - 0.5
|
||||
noise_multiplier = 1.0
|
||||
|
||||
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
|
||||
# scale it to the timestep
|
||||
unconditional_diff_noise = sd.add_noise(
|
||||
torch.zeros_like(unconditional_latents),
|
||||
unconditional_diff_noise,
|
||||
unconditional_diff,
|
||||
timesteps
|
||||
)
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
||||
|
||||
baseline_noisy_latents = sd.add_noise(
|
||||
unconditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
target_noise = noise + unconditional_diff_noise
|
||||
|
||||
noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
target_noise,
|
||||
@@ -276,7 +244,7 @@ def get_targeted_guidance_loss(
|
||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||
# This acts as our control to preserve the unaltered parts of the image.
|
||||
baseline_prediction = sd.predict_noise(
|
||||
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
|
||||
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
@@ -286,17 +254,16 @@ def get_targeted_guidance_loss(
|
||||
# determine the error for the baseline prediction
|
||||
baseline_prediction_error = baseline_prediction - noise
|
||||
|
||||
prediction_target = baseline_prediction_error + unconditional_diff_noise
|
||||
|
||||
prediction_target = prediction_target.detach().requires_grad_(False)
|
||||
|
||||
|
||||
# turn the LoRA network back on.
|
||||
sd.unet.train()
|
||||
sd.network.is_active = True
|
||||
|
||||
sd.network.multiplier = network_weight_list
|
||||
|
||||
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
||||
# masked_noise = noise * diff_mask
|
||||
# pred_target = unmasked_noise + unconditional_diff_noise
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = sd.predict_noise(
|
||||
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
||||
@@ -306,32 +273,20 @@ def get_targeted_guidance_loss(
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
|
||||
|
||||
baselined_prediction = prediction - baseline_prediction
|
||||
prediction_error = prediction - noise
|
||||
|
||||
guidance_loss = torch.nn.functional.mse_loss(
|
||||
baselined_prediction.float(),
|
||||
prediction_error.float(),
|
||||
# unconditional_diff_noise.float(),
|
||||
unconditional_diff_noise.float(),
|
||||
prediction_target.float(),
|
||||
reduction="none"
|
||||
)
|
||||
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||
|
||||
guidance_loss = guidance_loss.mean()
|
||||
|
||||
|
||||
# do the masked noise prediction
|
||||
masked_noise_loss = torch.nn.functional.mse_loss(
|
||||
prediction.float(),
|
||||
target_noise.float(),
|
||||
reduction="none"
|
||||
) * diff_mask
|
||||
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
||||
masked_noise_loss = masked_noise_loss.mean()
|
||||
|
||||
|
||||
loss = guidance_loss + masked_noise_loss
|
||||
# loss = guidance_loss + masked_noise_loss
|
||||
loss = guidance_loss
|
||||
|
||||
loss.backward()
|
||||
|
||||
@@ -342,158 +297,6 @@ def get_targeted_guidance_loss(
|
||||
return loss
|
||||
|
||||
|
||||
def get_targeted_guidance_loss_WIP(
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: 'PromptEmbeds',
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Perform targeted guidance (working title)
|
||||
dtype = get_torch_dtype(sd.torch_dtype)
|
||||
device = sd.device_torch
|
||||
|
||||
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
# apply random offset to both latents
|
||||
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
offset = offset * 0.1
|
||||
conditional_latents = conditional_latents + offset
|
||||
unconditional_latents = unconditional_latents + offset
|
||||
|
||||
# get random scale 0f 0.8 to 1.2
|
||||
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
scale = scale * 0.4
|
||||
scale = scale + 0.8
|
||||
conditional_latents = conditional_latents * scale
|
||||
unconditional_latents = unconditional_latents * scale
|
||||
|
||||
diff_mask = get_differential_mask(
|
||||
conditional_latents,
|
||||
unconditional_latents,
|
||||
threshold=0.2,
|
||||
gradient=True
|
||||
)
|
||||
|
||||
# standardize inpute to std of 1
|
||||
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
||||
#
|
||||
# # scale the latents to std of 1
|
||||
# conditional_latents = conditional_latents / combo_std
|
||||
# unconditional_latents = unconditional_latents / combo_std
|
||||
|
||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||
|
||||
|
||||
|
||||
# get a -0.5 to 0.5 multiplier for the diff noise
|
||||
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
# noise_multiplier = noise_multiplier - 0.5
|
||||
noise_multiplier = 1.0
|
||||
|
||||
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
|
||||
# scale it to the timestep
|
||||
unconditional_diff_noise = sd.add_noise(
|
||||
torch.zeros_like(unconditional_latents),
|
||||
unconditional_diff_noise,
|
||||
timesteps
|
||||
)
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
||||
|
||||
baseline_noisy_latents = sd.add_noise(
|
||||
unconditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
target_noise = noise + unconditional_diff_noise
|
||||
noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
target_noise,
|
||||
# noise,
|
||||
timesteps
|
||||
).detach()
|
||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||
sd.network.is_active = False
|
||||
sd.unet.eval()
|
||||
|
||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||
# This acts as our control to preserve the unaltered parts of the image.
|
||||
baseline_prediction = sd.predict_noise(
|
||||
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach().requires_grad_(False)
|
||||
|
||||
# turn the LoRA network back on.
|
||||
sd.unet.train()
|
||||
sd.network.is_active = True
|
||||
|
||||
sd.network.multiplier = network_weight_list
|
||||
|
||||
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
||||
# masked_noise = noise * diff_mask
|
||||
# pred_target = unmasked_noise + unconditional_diff_noise
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = sd.predict_noise(
|
||||
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
|
||||
|
||||
baselined_prediction = prediction - baseline_prediction
|
||||
|
||||
guidance_loss = torch.nn.functional.mse_loss(
|
||||
baselined_prediction.float(),
|
||||
# unconditional_diff_noise.float(),
|
||||
unconditional_diff_noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||
|
||||
guidance_loss = guidance_loss.mean()
|
||||
|
||||
|
||||
# do the masked noise prediction
|
||||
masked_noise_loss = torch.nn.functional.mse_loss(
|
||||
prediction.float(),
|
||||
target_noise.float(),
|
||||
reduction="none"
|
||||
) * diff_mask
|
||||
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
||||
masked_noise_loss = masked_noise_loss.mean()
|
||||
|
||||
|
||||
loss = guidance_loss + masked_noise_loss
|
||||
|
||||
loss.backward()
|
||||
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
loss.requires_grad_(True)
|
||||
|
||||
return loss
|
||||
|
||||
def get_guided_loss_polarity(
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
|
||||
Reference in New Issue
Block a user