mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Targeted guidance work
This commit is contained in:
@@ -206,63 +206,31 @@ def get_targeted_guidance_loss(
|
|||||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||||
|
|
||||||
# apply random offset to both latents
|
# # apply random offset to both latents
|
||||||
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
# offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
offset = offset * 0.1
|
# offset = offset * 0.1
|
||||||
conditional_latents = conditional_latents + offset
|
# conditional_latents = conditional_latents + offset
|
||||||
unconditional_latents = unconditional_latents + offset
|
# unconditional_latents = unconditional_latents + offset
|
||||||
|
|
||||||
# get random scale of 0.8 to 1.2
|
|
||||||
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
|
||||||
scale = scale * 0.4
|
|
||||||
scale = scale + 0.8
|
|
||||||
conditional_latents = conditional_latents * scale
|
|
||||||
unconditional_latents = unconditional_latents * scale
|
|
||||||
|
|
||||||
diff_mask = get_differential_mask(
|
|
||||||
conditional_latents,
|
|
||||||
unconditional_latents,
|
|
||||||
threshold=0.2,
|
|
||||||
gradient=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# standardize input to std of 1
|
|
||||||
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
|
||||||
#
|
#
|
||||||
# # scale the latents to std of 1
|
# # get random scale of 0.8 to 1.2
|
||||||
# conditional_latents = conditional_latents / combo_std
|
# scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||||
# unconditional_latents = unconditional_latents / combo_std
|
# scale = scale * 0.4
|
||||||
|
# scale = scale + 0.8
|
||||||
|
# conditional_latents = conditional_latents * scale
|
||||||
|
# unconditional_latents = unconditional_latents * scale
|
||||||
|
|
||||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# get a -0.5 to 0.5 multiplier for the diff noise
|
|
||||||
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
|
||||||
# noise_multiplier = noise_multiplier - 0.5
|
|
||||||
noise_multiplier = 1.0
|
|
||||||
|
|
||||||
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
|
||||||
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
|
||||||
|
|
||||||
# scale it to the timestep
|
# scale it to the timestep
|
||||||
unconditional_diff_noise = sd.add_noise(
|
unconditional_diff_noise = sd.add_noise(
|
||||||
torch.zeros_like(unconditional_latents),
|
torch.zeros_like(unconditional_latents),
|
||||||
unconditional_diff_noise,
|
unconditional_diff,
|
||||||
timesteps
|
timesteps
|
||||||
)
|
)
|
||||||
|
|
||||||
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
|
||||||
|
|
||||||
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
||||||
|
|
||||||
baseline_noisy_latents = sd.add_noise(
|
|
||||||
unconditional_latents,
|
|
||||||
noise,
|
|
||||||
timesteps
|
|
||||||
).detach()
|
|
||||||
|
|
||||||
target_noise = noise + unconditional_diff_noise
|
target_noise = noise + unconditional_diff_noise
|
||||||
|
|
||||||
noisy_latents = sd.add_noise(
|
noisy_latents = sd.add_noise(
|
||||||
conditional_latents,
|
conditional_latents,
|
||||||
target_noise,
|
target_noise,
|
||||||
@@ -276,7 +244,7 @@ def get_targeted_guidance_loss(
|
|||||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||||
# This acts as our control to preserve the unaltered parts of the image.
|
# This acts as our control to preserve the unaltered parts of the image.
|
||||||
baseline_prediction = sd.predict_noise(
|
baseline_prediction = sd.predict_noise(
|
||||||
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
|
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
||||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
guidance_scale=1.0,
|
guidance_scale=1.0,
|
||||||
@@ -286,17 +254,16 @@ def get_targeted_guidance_loss(
|
|||||||
# determine the error for the baseline prediction
|
# determine the error for the baseline prediction
|
||||||
baseline_prediction_error = baseline_prediction - noise
|
baseline_prediction_error = baseline_prediction - noise
|
||||||
|
|
||||||
|
prediction_target = baseline_prediction_error + unconditional_diff_noise
|
||||||
|
|
||||||
|
prediction_target = prediction_target.detach().requires_grad_(False)
|
||||||
|
|
||||||
|
|
||||||
# turn the LoRA network back on.
|
# turn the LoRA network back on.
|
||||||
sd.unet.train()
|
sd.unet.train()
|
||||||
sd.network.is_active = True
|
sd.network.is_active = True
|
||||||
|
|
||||||
sd.network.multiplier = network_weight_list
|
sd.network.multiplier = network_weight_list
|
||||||
|
|
||||||
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
|
||||||
# masked_noise = noise * diff_mask
|
|
||||||
# pred_target = unmasked_noise + unconditional_diff_noise
|
|
||||||
|
|
||||||
# do our prediction with LoRA active on the scaled guidance latents
|
# do our prediction with LoRA active on the scaled guidance latents
|
||||||
prediction = sd.predict_noise(
|
prediction = sd.predict_noise(
|
||||||
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
||||||
@@ -306,32 +273,20 @@ def get_targeted_guidance_loss(
|
|||||||
**pred_kwargs # adapter residuals in here
|
**pred_kwargs # adapter residuals in here
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prediction_error = prediction - noise
|
||||||
|
|
||||||
baselined_prediction = prediction - baseline_prediction
|
|
||||||
|
|
||||||
guidance_loss = torch.nn.functional.mse_loss(
|
guidance_loss = torch.nn.functional.mse_loss(
|
||||||
baselined_prediction.float(),
|
prediction_error.float(),
|
||||||
# unconditional_diff_noise.float(),
|
# unconditional_diff_noise.float(),
|
||||||
unconditional_diff_noise.float(),
|
prediction_target.float(),
|
||||||
reduction="none"
|
reduction="none"
|
||||||
)
|
)
|
||||||
guidance_loss = guidance_loss.mean([1, 2, 3])
|
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||||
|
|
||||||
guidance_loss = guidance_loss.mean()
|
guidance_loss = guidance_loss.mean()
|
||||||
|
|
||||||
|
# loss = guidance_loss + masked_noise_loss
|
||||||
# do the masked noise prediction
|
loss = guidance_loss
|
||||||
masked_noise_loss = torch.nn.functional.mse_loss(
|
|
||||||
prediction.float(),
|
|
||||||
target_noise.float(),
|
|
||||||
reduction="none"
|
|
||||||
) * diff_mask
|
|
||||||
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
|
||||||
masked_noise_loss = masked_noise_loss.mean()
|
|
||||||
|
|
||||||
|
|
||||||
loss = guidance_loss + masked_noise_loss
|
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
@@ -342,158 +297,6 @@ def get_targeted_guidance_loss(
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def get_targeted_guidance_loss_WIP(
|
|
||||||
noisy_latents: torch.Tensor,
|
|
||||||
conditional_embeds: 'PromptEmbeds',
|
|
||||||
match_adapter_assist: bool,
|
|
||||||
network_weight_list: list,
|
|
||||||
timesteps: torch.Tensor,
|
|
||||||
pred_kwargs: dict,
|
|
||||||
batch: 'DataLoaderBatchDTO',
|
|
||||||
noise: torch.Tensor,
|
|
||||||
sd: 'StableDiffusion',
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
with torch.no_grad():
|
|
||||||
# Perform targeted guidance (working title)
|
|
||||||
dtype = get_torch_dtype(sd.torch_dtype)
|
|
||||||
device = sd.device_torch
|
|
||||||
|
|
||||||
|
|
||||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
|
||||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
|
||||||
|
|
||||||
# apply random offset to both latents
|
|
||||||
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
|
||||||
offset = offset * 0.1
|
|
||||||
conditional_latents = conditional_latents + offset
|
|
||||||
unconditional_latents = unconditional_latents + offset
|
|
||||||
|
|
||||||
# get random scale of 0.8 to 1.2
|
|
||||||
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
|
||||||
scale = scale * 0.4
|
|
||||||
scale = scale + 0.8
|
|
||||||
conditional_latents = conditional_latents * scale
|
|
||||||
unconditional_latents = unconditional_latents * scale
|
|
||||||
|
|
||||||
diff_mask = get_differential_mask(
|
|
||||||
conditional_latents,
|
|
||||||
unconditional_latents,
|
|
||||||
threshold=0.2,
|
|
||||||
gradient=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# standardize input to std of 1
|
|
||||||
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
|
||||||
#
|
|
||||||
# # scale the latents to std of 1
|
|
||||||
# conditional_latents = conditional_latents / combo_std
|
|
||||||
# unconditional_latents = unconditional_latents / combo_std
|
|
||||||
|
|
||||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# get a -0.5 to 0.5 multiplier for the diff noise
|
|
||||||
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
|
||||||
# noise_multiplier = noise_multiplier - 0.5
|
|
||||||
noise_multiplier = 1.0
|
|
||||||
|
|
||||||
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
|
||||||
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
|
||||||
|
|
||||||
# scale it to the timestep
|
|
||||||
unconditional_diff_noise = sd.add_noise(
|
|
||||||
torch.zeros_like(unconditional_latents),
|
|
||||||
unconditional_diff_noise,
|
|
||||||
timesteps
|
|
||||||
)
|
|
||||||
|
|
||||||
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
|
||||||
|
|
||||||
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
|
||||||
|
|
||||||
baseline_noisy_latents = sd.add_noise(
|
|
||||||
unconditional_latents,
|
|
||||||
noise,
|
|
||||||
timesteps
|
|
||||||
).detach()
|
|
||||||
|
|
||||||
target_noise = noise + unconditional_diff_noise
|
|
||||||
noisy_latents = sd.add_noise(
|
|
||||||
conditional_latents,
|
|
||||||
target_noise,
|
|
||||||
# noise,
|
|
||||||
timesteps
|
|
||||||
).detach()
|
|
||||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
|
||||||
sd.network.is_active = False
|
|
||||||
sd.unet.eval()
|
|
||||||
|
|
||||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
|
||||||
# This acts as our control to preserve the unaltered parts of the image.
|
|
||||||
baseline_prediction = sd.predict_noise(
|
|
||||||
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
|
|
||||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
|
||||||
timestep=timesteps,
|
|
||||||
guidance_scale=1.0,
|
|
||||||
**pred_kwargs # adapter residuals in here
|
|
||||||
).detach().requires_grad_(False)
|
|
||||||
|
|
||||||
# turn the LoRA network back on.
|
|
||||||
sd.unet.train()
|
|
||||||
sd.network.is_active = True
|
|
||||||
|
|
||||||
sd.network.multiplier = network_weight_list
|
|
||||||
|
|
||||||
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
|
||||||
# masked_noise = noise * diff_mask
|
|
||||||
# pred_target = unmasked_noise + unconditional_diff_noise
|
|
||||||
|
|
||||||
# do our prediction with LoRA active on the scaled guidance latents
|
|
||||||
prediction = sd.predict_noise(
|
|
||||||
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
|
||||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
|
||||||
timestep=timesteps,
|
|
||||||
guidance_scale=1.0,
|
|
||||||
**pred_kwargs # adapter residuals in here
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
baselined_prediction = prediction - baseline_prediction
|
|
||||||
|
|
||||||
guidance_loss = torch.nn.functional.mse_loss(
|
|
||||||
baselined_prediction.float(),
|
|
||||||
# unconditional_diff_noise.float(),
|
|
||||||
unconditional_diff_noise.float(),
|
|
||||||
reduction="none"
|
|
||||||
)
|
|
||||||
guidance_loss = guidance_loss.mean([1, 2, 3])
|
|
||||||
|
|
||||||
guidance_loss = guidance_loss.mean()
|
|
||||||
|
|
||||||
|
|
||||||
# do the masked noise prediction
|
|
||||||
masked_noise_loss = torch.nn.functional.mse_loss(
|
|
||||||
prediction.float(),
|
|
||||||
target_noise.float(),
|
|
||||||
reduction="none"
|
|
||||||
) * diff_mask
|
|
||||||
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
|
||||||
masked_noise_loss = masked_noise_loss.mean()
|
|
||||||
|
|
||||||
|
|
||||||
loss = guidance_loss + masked_noise_loss
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# detach it so parent class can run backward on no grads without throwing error
|
|
||||||
loss = loss.detach()
|
|
||||||
loss.requires_grad_(True)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def get_guided_loss_polarity(
|
def get_guided_loss_polarity(
|
||||||
noisy_latents: torch.Tensor,
|
noisy_latents: torch.Tensor,
|
||||||
conditional_embeds: PromptEmbeds,
|
conditional_embeds: PromptEmbeds,
|
||||||
|
|||||||
Reference in New Issue
Block a user