From 377b81ee3e83f23dd9db97f18cbd21f60864424f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 19 Apr 2024 15:00:35 -0600 Subject: [PATCH] Adjustments to guidance --- extensions_built_in/sd_trainer/SDTrainer.py | 12 +++++- toolkit/guidance.py | 42 +++++++++++++++++++-- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 1daf3773..5742edea 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -16,7 +16,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.config_modules import GuidanceConfig from toolkit.data_loader import get_dataloader_datasets from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO -from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss +from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter from toolkit.custom_adapter import CustomAdapter @@ -1293,8 +1293,16 @@ class SDTrainer(BaseSDTrainProcess): do_correct_pred_norm_prior = self.train_config.correct_pred_norm + do_guidance_prior = False + + if batch.unconditional_latents is not None: + # for this not that, we need a prior pred to normalize + guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type + if guidance_type == 'tnt': + do_guidance_prior = True + if (( - has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): with self.timer('prior predict'): prior_pred = self.get_prior_prediction( noisy_latents=noisy_latents, diff --git a/toolkit/guidance.py b/toolkit/guidance.py index d8e6ad5a..13ab65ea 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -482,6 +482,7 @@ def get_guided_loss_polarity( return loss + def get_guided_tnt( noisy_latents: torch.Tensor, conditional_embeds: PromptEmbeds, @@ -492,6 +493,7 @@ def get_guided_tnt( batch: 'DataLoaderBatchDTO', noise: torch.Tensor, sd: 'StableDiffusion', + prior_pred: torch.Tensor = None, **kwargs ): dtype = get_torch_dtype(sd.torch_dtype) @@ -528,6 +530,7 @@ def get_guided_tnt( sd.network.multiplier = cat_network_weight_list sd.network.is_active = True + prediction = sd.predict_noise( latents=cat_latents.to(device, dtype=dtype).detach(), conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(), @@ -547,11 +550,41 @@ def get_guided_tnt( that_prediction.float(), noise.float(), reduction="none" - ) * -1.0 + ) - loss = this_loss + that_loss + with torch.no_grad(): + tnt_loss = this_loss - that_loss - loss = loss.mean([1, 2, 3]) + # create a mask by scaling loss from 0 to mean to 1 to 0 + # this will act to regularize unchanged areas to prior prediction + loss_min = tnt_loss.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0] + loss_mean = tnt_loss.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) + mask = value_map( + torch.abs(tnt_loss), + loss_min, + loss_mean, + 0.0, + 1.0 + ).clamp(0.0, 1.0).detach() + + prior_mask = 1.0 - mask + + + this_loss = this_loss * mask + that_loss = that_loss * prior_mask + + this_loss = this_loss.mean([1, 2, 3]) + that_loss = that_loss.mean([1, 2, 3]) + + prior_loss = torch.nn.functional.mse_loss( + this_prediction.float(), + prior_pred.detach().float(), + reduction="none" + ) + prior_loss = prior_loss * prior_mask + prior_loss = prior_loss.mean([1, 2, 3]) + + loss = prior_loss + this_loss - that_loss loss.backward() @@ -612,7 +645,7 @@ def get_guidance_loss( ) elif guidance_type == "tnt": assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance" - return get_guided_loss_polarity( + return get_guided_tnt( noisy_latents, conditional_embeds, match_adapter_assist, @@ -622,6 +655,7 @@ def get_guidance_loss( batch, noise, sd, + prior_pred=prior_pred, **kwargs )