diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index cb22807a..6efdfcca 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -206,63 +206,31 @@ def get_targeted_guidance_loss(
         conditional_latents = batch.latents.to(device, dtype=dtype).detach()
         unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
 
-        # apply random offset to both latents
-        offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        offset = offset * 0.1
-        conditional_latents = conditional_latents + offset
-        unconditional_latents = unconditional_latents + offset
-
-        # get random scale 0f 0.8 to 1.2
-        scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        scale = scale * 0.4
-        scale = scale + 0.8
-        conditional_latents = conditional_latents * scale
-        unconditional_latents = unconditional_latents * scale
-
-        diff_mask = get_differential_mask(
-            conditional_latents,
-            unconditional_latents,
-            threshold=0.2,
-            gradient=True
-        )
-
-        # standardize inpute to std of 1
-        # combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
+        # # apply random offset to both latents
+        # offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
+        # offset = offset * 0.1
+        # conditional_latents = conditional_latents + offset
+        # unconditional_latents = unconditional_latents + offset
         #
-        # # scale the latents to std of 1
-        # conditional_latents = conditional_latents / combo_std
-        # unconditional_latents = unconditional_latents / combo_std
+        # # get random scale of 0.8 to 1.2
+        # scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
+        # scale = scale * 0.4
+        # scale = scale + 0.8
+        # conditional_latents = conditional_latents * scale
+        # unconditional_latents = unconditional_latents * scale
 
         unconditional_diff = (unconditional_latents - conditional_latents)
-
-
-        # get a -0.5 to 0.5 multiplier for the diff noise
-        # noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        # noise_multiplier = noise_multiplier - 0.5
-        noise_multiplier = 1.0
-
-        # unconditional_diff_noise = unconditional_diff * noise_multiplier
-        unconditional_diff_noise = unconditional_diff * noise_multiplier
-
         # scale it to the timestep
         unconditional_diff_noise = sd.add_noise(
             torch.zeros_like(unconditional_latents),
-            unconditional_diff_noise,
+            unconditional_diff,
             timesteps
         )
-
-        unconditional_diff_noise = unconditional_diff_noise * 0.2
-
         unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
 
-        baseline_noisy_latents = sd.add_noise(
-            unconditional_latents,
-            noise,
-            timesteps
-        ).detach()
-
         target_noise = noise + unconditional_diff_noise
+
         noisy_latents = sd.add_noise(
             conditional_latents,
             target_noise,
@@ -276,7 +244,7 @@ def get_targeted_guidance_loss(
     # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
     # This acts as our control to preserve the unaltered parts of the image.
     baseline_prediction = sd.predict_noise(
-        latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
+        latents=noisy_latents.to(device, dtype=dtype).detach(),
         conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
         timestep=timesteps,
         guidance_scale=1.0,
@@ -286,17 +254,16 @@ def get_targeted_guidance_loss(
 
     # determine the error for the baseline prediction
    baseline_prediction_error = baseline_prediction - noise
 
+    prediction_target = baseline_prediction_error + unconditional_diff_noise
+
+    prediction_target = prediction_target.detach().requires_grad_(False)
+
     # turn the LoRA network back on.
     sd.unet.train()
     sd.network.is_active = True
 
     sd.network.multiplier = network_weight_list
-
-    # unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
-    # masked_noise = noise * diff_mask
-    # pred_target = unmasked_noise + unconditional_diff_noise
-
     # do our prediction with LoRA active on the scaled guidance latents
     prediction = sd.predict_noise(
         latents=noisy_latents.to(device, dtype=dtype).detach(),
@@ -306,32 +273,20 @@ def get_targeted_guidance_loss(
         **pred_kwargs # adapter residuals in here
     )
-
-
-    baselined_prediction = prediction - baseline_prediction
+    prediction_error = prediction - noise
 
     guidance_loss = torch.nn.functional.mse_loss(
-        baselined_prediction.float(),
+        prediction_error.float(),
         # unconditional_diff_noise.float(),
-        unconditional_diff_noise.float(),
+        prediction_target.float(),
         reduction="none"
     )
     guidance_loss = guidance_loss.mean([1, 2, 3])
     guidance_loss = guidance_loss.mean()
-
-    # do the masked noise prediction
-    masked_noise_loss = torch.nn.functional.mse_loss(
-        prediction.float(),
-        target_noise.float(),
-        reduction="none"
-    ) * diff_mask
-    masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
-    masked_noise_loss = masked_noise_loss.mean()
-
-
-    loss = guidance_loss + masked_noise_loss
+    # loss = guidance_loss + masked_noise_loss
+    loss = guidance_loss
 
     loss.backward()
@@ -342,158 +297,6 @@ def get_targeted_guidance_loss(
     return loss
 
 
-def get_targeted_guidance_loss_WIP(
-        noisy_latents: torch.Tensor,
-        conditional_embeds: 'PromptEmbeds',
-        match_adapter_assist: bool,
-        network_weight_list: list,
-        timesteps: torch.Tensor,
-        pred_kwargs: dict,
-        batch: 'DataLoaderBatchDTO',
-        noise: torch.Tensor,
-        sd: 'StableDiffusion',
-        **kwargs
-):
-    with torch.no_grad():
-        # Perform targeted guidance (working title)
-        dtype = get_torch_dtype(sd.torch_dtype)
-        device = sd.device_torch
-
-
-        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
-        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
-
-        # apply random offset to both latents
-        offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        offset = offset * 0.1
-        conditional_latents = conditional_latents + offset
-        unconditional_latents = unconditional_latents + offset
-
-        # get random scale 0f 0.8 to 1.2
-        scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        scale = scale * 0.4
-        scale = scale + 0.8
-        conditional_latents = conditional_latents * scale
-        unconditional_latents = unconditional_latents * scale
-
-        diff_mask = get_differential_mask(
-            conditional_latents,
-            unconditional_latents,
-            threshold=0.2,
-            gradient=True
-        )
-
-        # standardize inpute to std of 1
-        # combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
-        #
-        # # scale the latents to std of 1
-        # conditional_latents = conditional_latents / combo_std
-        # unconditional_latents = unconditional_latents / combo_std
-
-        unconditional_diff = (unconditional_latents - conditional_latents)
-
-
-
-        # get a -0.5 to 0.5 multiplier for the diff noise
-        # noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        # noise_multiplier = noise_multiplier - 0.5
-        noise_multiplier = 1.0
-
-        # unconditional_diff_noise = unconditional_diff * noise_multiplier
-        unconditional_diff_noise = unconditional_diff * noise_multiplier
-
-        # scale it to the timestep
-        unconditional_diff_noise = sd.add_noise(
-            torch.zeros_like(unconditional_latents),
-            unconditional_diff_noise,
-            timesteps
-        )
-
-        unconditional_diff_noise = unconditional_diff_noise * 0.2
-
-        unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
-
-        baseline_noisy_latents = sd.add_noise(
-            unconditional_latents,
-            noise,
-            timesteps
-        ).detach()
-
-        target_noise = noise + unconditional_diff_noise
-        noisy_latents = sd.add_noise(
-            conditional_latents,
-            target_noise,
-            # noise,
-            timesteps
-        ).detach()
-        # Disable the LoRA network so we can predict parent network knowledge without it
-        sd.network.is_active = False
-        sd.unet.eval()
-
-        # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
-        # This acts as our control to preserve the unaltered parts of the image.
-        baseline_prediction = sd.predict_noise(
-            latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
-            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
-            timestep=timesteps,
-            guidance_scale=1.0,
-            **pred_kwargs # adapter residuals in here
-        ).detach().requires_grad_(False)
-
-    # turn the LoRA network back on.
-    sd.unet.train()
-    sd.network.is_active = True
-
-    sd.network.multiplier = network_weight_list
-
-    # unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
-    # masked_noise = noise * diff_mask
-    # pred_target = unmasked_noise + unconditional_diff_noise
-
-    # do our prediction with LoRA active on the scaled guidance latents
-    prediction = sd.predict_noise(
-        latents=noisy_latents.to(device, dtype=dtype).detach(),
-        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
-        timestep=timesteps,
-        guidance_scale=1.0,
-        **pred_kwargs # adapter residuals in here
-    )
-
-
-    baselined_prediction = prediction - baseline_prediction
-
-    guidance_loss = torch.nn.functional.mse_loss(
-        baselined_prediction.float(),
-        # unconditional_diff_noise.float(),
-        unconditional_diff_noise.float(),
-        reduction="none"
-    )
-    guidance_loss = guidance_loss.mean([1, 2, 3])
-
-    guidance_loss = guidance_loss.mean()
-
-
-    # do the masked noise prediction
-    masked_noise_loss = torch.nn.functional.mse_loss(
-        prediction.float(),
-        target_noise.float(),
-        reduction="none"
-    ) * diff_mask
-    masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
-    masked_noise_loss = masked_noise_loss.mean()
-
-
-    loss = guidance_loss + masked_noise_loss
-
-    loss.backward()
-
-    # detach it so parent class can run backward on no grads without throwing error
-    loss = loss.detach()
-    loss.requires_grad_(True)
-
-    return loss
-
 
 def get_guided_loss_polarity(
         noisy_latents: torch.Tensor,
         conditional_embeds: PromptEmbeds,
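
Stripped of the deleted experiments (random offset/scale jitter, differential mask, the WIP copy of the function), the surviving code path reduces to the sketch below. This is a minimal reconstruction, not the toolkit's API: predict_noise and set_lora_active are hypothetical stand-ins for sd.predict_noise and the sd.network.is_active toggle, and alphas/sigmas are assumed DDPM-style scheduler coefficients such that sd.add_noise(x, n, t) behaves like alphas * x + sigmas * n.

import torch
import torch.nn.functional as F

def targeted_guidance_loss_sketch(
        conditional_latents: torch.Tensor,    # latents of the "conditional" target images
        unconditional_latents: torch.Tensor,  # latents of the paired baseline images
        noise: torch.Tensor,                  # gaussian noise, same shape as the latents
        alphas: torch.Tensor,                 # assumed scheduler coefficients for the drawn
        sigmas: torch.Tensor,                 # timesteps, broadcastable to (B, 1, 1, 1)
        predict_noise,                        # hypothetical: noisy latents -> noise prediction
        set_lora_active,                      # hypothetical: bool -> None, toggles the LoRA
) -> torch.Tensor:
    with torch.no_grad():
        # direction that turns the conditional image into the unconditional one
        unconditional_diff = unconditional_latents - conditional_latents

        # "scale it to the timestep": under the assumed schedule,
        # sd.add_noise(zeros, diff, t) reduces to sigmas * diff
        unconditional_diff_noise = sigmas * unconditional_diff

        # fold the scaled diff into the noise, then noise the conditional latents with it
        target_noise = noise + unconditional_diff_noise
        noisy_latents = alphas * conditional_latents + sigmas * target_noise

        # baseline: what the frozen parent network does with these latents
        set_lora_active(False)
        baseline_prediction = predict_noise(noisy_latents)
        baseline_prediction_error = baseline_prediction - noise

        # the LoRA should reproduce the parent's error plus the injected diff noise
        prediction_target = baseline_prediction_error + unconditional_diff_noise

    # prediction with the LoRA active; only its deviation from plain noise is compared
    set_lora_active(True)
    prediction = predict_noise(noisy_latents)
    prediction_error = prediction - noise

    return F.mse_loss(prediction_error.float(), prediction_target.float())

Since prediction_error - prediction_target equals prediction - baseline_prediction - unconditional_diff_noise, the gradient only pushes the LoRA to account for the injected diff: whatever the parent network already gets wrong is present on both sides and cancels, rather than being trained into the adapter.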