From 64f2b085b7b931f55dda668351f7dda03ce4736e Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 23 Jun 2024 14:47:40 -0600 Subject: [PATCH] Minor fixes --- testing/merge_in_text_encoder_adapter.py | 12 +++++-- toolkit/config_modules.py | 2 +- toolkit/guidance.py | 40 ++++++------------------ 3 files changed, 19 insertions(+), 35 deletions(-) diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py index 9158384a..6903a6ad 100644 --- a/testing/merge_in_text_encoder_adapter.py +++ b/testing/merge_in_text_encoder_adapter.py @@ -7,16 +7,21 @@ from safetensors.torch import load_file, save_file from collections import OrderedDict import json -model_path = "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/objective_reality_v2.safetensors" +model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000" te_path = "google/flan-t5-xl" te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors" -output_path = "/home/jaret/Dev/models/hf/t5xl_sd15_v1" +output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw" print("Loading te adapter") te_aug_sd = load_file(te_aug_path) print("Loading model") -sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16) +is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path) + +if is_diffusers: + sd = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) +else: + sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16) print("Loading Text Encoder") # Load the text encoder @@ -74,6 +79,7 @@ for name in sd.unet.attn_processors.keys(): print("Saving unmodified model") +sd = sd.to("cpu", torch.float16) sd.save_pretrained( output_path, safe_serialization=True, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index c6a3245b..73b09764 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -339,7 +339,7 @@ class ModelConfig: self.is_v2: bool = kwargs.get('is_v2', False) self.is_xl: bool = kwargs.get('is_xl', False) self.is_pixart: bool = kwargs.get('is_pixart', False) - self.is_pixart_sigma: bool = kwargs.get('is_pixart', False) + self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False) self.is_v3: bool = kwargs.get('is_v3', False) if self.is_pixart_sigma: self.is_pixart = True diff --git a/toolkit/guidance.py b/toolkit/guidance.py index ba69c3e6..9f6bc1e1 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -240,7 +240,7 @@ def get_direct_guidance_loss( noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0) - guidance_scale = 1.25 + guidance_scale = 1.1 guidance_pred = noise_pred_uncond + guidance_scale * ( noise_pred_cond - noise_pred_uncond ) @@ -552,39 +552,17 @@ def get_guided_tnt( reduction="none" ) - with torch.no_grad(): - tnt_loss = this_loss - that_loss - - # create a mask by scaling loss from 0 to mean to 1 to 0 - # this will act to regularize unchanged areas to prior prediction - loss_min = tnt_loss.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0] - loss_mean = tnt_loss.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True) - mask = value_map( - torch.abs(tnt_loss), - loss_min, - loss_mean, - 0.0, - 1.0 - ).clamp(0.0, 1.0).detach() - - prior_mask = 1.0 - mask - - - this_loss = this_loss * mask - that_loss = that_loss * prior_mask - this_loss = this_loss.mean([1, 2, 3]) - that_loss = that_loss.mean([1, 2, 3]) + # negative loss on that + that_loss = -that_loss.mean([1, 2, 3]) - prior_loss = torch.nn.functional.mse_loss( - this_prediction.float(), - prior_pred.detach().float(), - reduction="none" - ) - prior_loss = prior_loss * prior_mask - prior_loss = prior_loss.mean([1, 2, 3]) + with torch.no_grad(): + # match that loss with this loss so it is not a negative value and same scale + that_loss_scaler = torch.abs(this_loss) / torch.abs(that_loss) - loss = prior_loss + this_loss - that_loss + that_loss = that_loss * that_loss_scaler * 0.01 + + loss = this_loss + that_loss loss = loss.mean()