Targeted guidance work

Jaret Burkett
2023-12-09 19:06:18 -07:00
parent eaa0fb6253
commit e5177833b2

@@ -206,63 +206,31 @@ def get_targeted_guidance_loss(
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
# apply random offset to both latents
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
offset = offset * 0.1
conditional_latents = conditional_latents + offset
unconditional_latents = unconditional_latents + offset
# get random scale of 0.8 to 1.2
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
scale = scale * 0.4
scale = scale + 0.8
conditional_latents = conditional_latents * scale
unconditional_latents = unconditional_latents * scale
diff_mask = get_differential_mask(
conditional_latents,
unconditional_latents,
threshold=0.2,
gradient=True
)
# standardize input to std of 1
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
# # apply random offset to both latents
# offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
# offset = offset * 0.1
# conditional_latents = conditional_latents + offset
# unconditional_latents = unconditional_latents + offset
#
# # scale the latents to std of 1
# conditional_latents = conditional_latents / combo_std
# unconditional_latents = unconditional_latents / combo_std
# # get random scale of 0.8 to 1.2
# scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
# scale = scale * 0.4
# scale = scale + 0.8
# conditional_latents = conditional_latents * scale
# unconditional_latents = unconditional_latents * scale
unconditional_diff = (unconditional_latents - conditional_latents)
# get a -0.5 to 0.5 multiplier for the diff noise
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
# noise_multiplier = noise_multiplier - 0.5
noise_multiplier = 1.0
# unconditional_diff_noise = unconditional_diff * noise_multiplier
unconditional_diff_noise = unconditional_diff * noise_multiplier
# scale it to the timestep
unconditional_diff_noise = sd.add_noise(
torch.zeros_like(unconditional_latents),
unconditional_diff_noise,
unconditional_diff,
timesteps
)
unconditional_diff_noise = unconditional_diff_noise * 0.2
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
baseline_noisy_latents = sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
target_noise = noise + unconditional_diff_noise
noisy_latents = sd.add_noise(
conditional_latents,
target_noise,
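
Taken together, this first hunk builds the training input: both latents receive the same per-sample offset and scale (an augmentation that leaves their difference intact apart from the scale factor), the conditional/unconditional difference is rescaled to each timestep's noise level, and the result is folded into the target noise. Below is a minimal self-contained sketch of the same computation, with the caveat that `scheduler.add_noise` is the diffusers API that `sd.add_noise` presumably wraps, and `build_diff_noise_target` is a hypothetical helper name, not code from this repo.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

def build_diff_noise_target(conditional_latents, unconditional_latents,
                            noise, timesteps, diff_weight=0.2):
    b = conditional_latents.shape[0]
    dev, dt = conditional_latents.device, conditional_latents.dtype
    # identical per-sample offset for both latents, so their difference is unchanged
    offset = torch.randn((b, 1, 1, 1), device=dev, dtype=dt) * 0.1
    conditional_latents = conditional_latents + offset
    unconditional_latents = unconditional_latents + offset
    # identical per-sample scale in [0.8, 1.2]
    scale = torch.rand((b, 1, 1, 1), device=dev, dtype=dt) * 0.4 + 0.8
    conditional_latents = conditional_latents * scale
    unconditional_latents = unconditional_latents * scale
    diff = unconditional_latents - conditional_latents
    # add_noise on a zero sample reduces to sqrt(1 - alpha_bar_t) * diff,
    # i.e. the difference is rescaled to the timestep's noise level
    diff_noise = scheduler.add_noise(torch.zeros_like(diff), diff, timesteps)
    diff_noise = (diff_noise * diff_weight).detach()
    target_noise = noise + diff_noise
    noisy_latents = scheduler.add_noise(conditional_latents, target_noise, timesteps)
    return noisy_latents.detach(), target_noise, diff_noise
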
@@ -276,7 +244,7 @@ def get_targeted_guidance_loss(
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
# This acts as our control to preserve the unaltered parts of the image.
baseline_prediction = sd.predict_noise(
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
latents=noisy_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
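
This hunk's one-line change points the baseline pass at the same `noisy_latents` the LoRA pass will see, so the two predictions are made on identical inputs and differ only in whether the LoRA weights are active. The surrounding pattern is sketched below, assuming `sd.network` exposes the `is_active` flag and `predict_noise` has the signature shown in the diff (the diff also restores `sd.network.multiplier` when re-enabling, omitted here):

import torch

def parent_baseline(sd, noisy_latents, conditional_embeds, timesteps, **pred_kwargs):
    sd.network.is_active = False  # bypass the LoRA layers entirely
    sd.unet.eval()
    with torch.no_grad():
        pred = sd.predict_noise(
            latents=noisy_latents.detach(),
            conditional_embeddings=conditional_embeds.detach(),
            timestep=timesteps,
            guidance_scale=1.0,  # plain conditional prediction, no CFG
            **pred_kwargs,  # adapter residuals in here
        )
    sd.unet.train()
    sd.network.is_active = True  # LoRA back on for the trainable pass
    return pred.detach().requires_grad_(False)
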
@@ -286,17 +254,16 @@ def get_targeted_guidance_loss(
# determine the error for the baseline prediction
baseline_prediction_error = baseline_prediction - noise
prediction_target = baseline_prediction_error + unconditional_diff_noise
prediction_target = prediction_target.detach().requires_grad_(False)
# turn the LoRA network back on.
sd.unet.train()
sd.network.is_active = True
sd.network.multiplier = network_weight_list
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
# masked_noise = noise * diff_mask
# pred_target = unmasked_noise + unconditional_diff_noise
# do our prediction with LoRA active on the scaled guidance latents
prediction = sd.predict_noise(
latents=noisy_latents.to(device, dtype=dtype).detach(),
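
The target built in this hunk has a useful algebraic reading. With the names from the diff:

# prediction_error  = prediction - noise                  (the LoRA pass below)
# prediction_target = (baseline_prediction - noise) + unconditional_diff_noise
#
# MSE(prediction_error, prediction_target)
#   = || prediction - baseline_prediction - unconditional_diff_noise ||^2
#
# so the loss constrains only the LoRA-induced delta, steering it to
# reproduce the timestep-scaled differential noise while leaving the
# parent model's own reconstruction error untouched.
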
@@ -306,32 +273,20 @@ def get_targeted_guidance_loss(
**pred_kwargs # adapter residuals in here
)
baselined_prediction = prediction - baseline_prediction
prediction_error = prediction - noise
guidance_loss = torch.nn.functional.mse_loss(
baselined_prediction.float(),
prediction_error.float(),
# unconditional_diff_noise.float(),
unconditional_diff_noise.float(),
prediction_target.float(),
reduction="none"
)
guidance_loss = guidance_loss.mean([1, 2, 3])
guidance_loss = guidance_loss.mean()
# do the masked noise prediction
masked_noise_loss = torch.nn.functional.mse_loss(
prediction.float(),
target_noise.float(),
reduction="none"
) * diff_mask
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
masked_noise_loss = masked_noise_loss.mean()
loss = guidance_loss + masked_noise_loss
# loss = guidance_loss + masked_noise_loss
loss = guidance_loss
loss.backward()
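
In the loss assembly the commit drops the masked-noise term, keeping the combined version only as a comment (`loss = guidance_loss`). A sketch of this tail end, assuming `diff_mask` broadcasts over the latent dimensions; the `detach().requires_grad_(True)` step visible in the next hunk hands the caller a grad-enabled leaf so an outer `loss.backward()` runs without error while contributing no gradients:

import torch
import torch.nn.functional as F

def assemble_loss(prediction, noise, prediction_target, target_noise,
                  diff_mask, use_masked_term=False):
    prediction_error = prediction - noise
    guidance_loss = F.mse_loss(
        prediction_error.float(), prediction_target.float(), reduction="none"
    ).mean([1, 2, 3]).mean()
    loss = guidance_loss
    if use_masked_term:  # disabled in this commit
        masked = F.mse_loss(
            prediction.float(), target_noise.float(), reduction="none"
        ) * diff_mask
        loss = loss + masked.mean([1, 2, 3]).mean()
    loss.backward()
    # detached leaf with requires_grad so the parent trainer's backward is a no-op
    return loss.detach().requires_grad_(True)
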
@@ -342,158 +297,6 @@ def get_targeted_guidance_loss(
return loss
def get_targeted_guidance_loss_WIP(
noisy_latents: torch.Tensor,
conditional_embeds: 'PromptEmbeds',
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
**kwargs
):
with torch.no_grad():
# Perform targeted guidance (working title)
dtype = get_torch_dtype(sd.torch_dtype)
device = sd.device_torch
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
# apply random offset to both latents
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
offset = offset * 0.1
conditional_latents = conditional_latents + offset
unconditional_latents = unconditional_latents + offset
# get random scale of 0.8 to 1.2
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
scale = scale * 0.4
scale = scale + 0.8
conditional_latents = conditional_latents * scale
unconditional_latents = unconditional_latents * scale
diff_mask = get_differential_mask(
conditional_latents,
unconditional_latents,
threshold=0.2,
gradient=True
)
# standardize input to std of 1
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
#
# # scale the latents to std of 1
# conditional_latents = conditional_latents / combo_std
# unconditional_latents = unconditional_latents / combo_std
unconditional_diff = (unconditional_latents - conditional_latents)
# get a -0.5 to 0.5 multiplier for the diff noise
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
# noise_multiplier = noise_multiplier - 0.5
noise_multiplier = 1.0
# unconditional_diff_noise = unconditional_diff * noise_multiplier
unconditional_diff_noise = unconditional_diff * noise_multiplier
# scale it to the timestep
unconditional_diff_noise = sd.add_noise(
torch.zeros_like(unconditional_latents),
unconditional_diff_noise,
timesteps
)
unconditional_diff_noise = unconditional_diff_noise * 0.2
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
baseline_noisy_latents = sd.add_noise(
unconditional_latents,
noise,
timesteps
).detach()
target_noise = noise + unconditional_diff_noise
noisy_latents = sd.add_noise(
conditional_latents,
target_noise,
# noise,
timesteps
).detach()
# Disable the LoRA network so we can predict parent network knowledge without it
sd.network.is_active = False
sd.unet.eval()
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
# This acts as our control to preserve the unaltered parts of the image.
baseline_prediction = sd.predict_noise(
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach().requires_grad_(False)
# turn the LoRA network back on.
sd.unet.train()
sd.network.is_active = True
sd.network.multiplier = network_weight_list
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
# masked_noise = noise * diff_mask
# pred_target = unmasked_noise + unconditional_diff_noise
# do our prediction with LoRA active on the scaled guidance latents
prediction = sd.predict_noise(
latents=noisy_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
baselined_prediction = prediction - baseline_prediction
guidance_loss = torch.nn.functional.mse_loss(
baselined_prediction.float(),
# unconditional_diff_noise.float(),
unconditional_diff_noise.float(),
reduction="none"
)
guidance_loss = guidance_loss.mean([1, 2, 3])
guidance_loss = guidance_loss.mean()
# do the masked noise prediction
masked_noise_loss = torch.nn.functional.mse_loss(
prediction.float(),
target_noise.float(),
reduction="none"
) * diff_mask
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
masked_noise_loss = masked_noise_loss.mean()
loss = guidance_loss + masked_noise_loss
loss.backward()
# detach it so the parent class can run backward on no grads without throwing an error
loss = loss.detach()
loss.requires_grad_(True)
return loss
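
`get_differential_mask` itself is not part of this diff. Given the call sites (`threshold=0.2`, `gradient=True`) and how the mask multiplies a per-pixel loss, a plausible reading is a soft spatial mask over where the two latents disagree. The sketch below is an assumption about that helper, not the repo's implementation:

import torch

def get_differential_mask(cond, uncond, threshold=0.2, gradient=True):
    # magnitude of disagreement across channels, normalized to [0, 1] per sample
    diff = (cond - uncond).abs().amax(dim=1, keepdim=True)
    diff = diff / diff.amax(dim=(1, 2, 3), keepdim=True).clamp_min(1e-8)
    if gradient:
        # soft ramp: 0 below the threshold, then linear up to 1
        mask = ((diff - threshold) / (1.0 - threshold)).clamp(0.0, 1.0)
    else:
        mask = (diff > threshold).to(cond.dtype)
    return mask
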
def get_guided_loss_polarity(
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,