From fa6d91ba7610edaa6116118da9d0f1bdd7242731 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 9 Nov 2023 06:09:36 -0700 Subject: [PATCH] Diffirential guidance working, but I may have a better way --- extensions_built_in/sd_trainer/SDTrainer.py | 51 +++++++++++++++++---- toolkit/basic.py | 22 +++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index a64c289f..18025c29 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -3,7 +3,7 @@ from typing import Union from diffusers import T2IAdapter from toolkit import train_tools -from toolkit.basic import value_map +from toolkit.basic import value_map, adain from toolkit.config_modules import GuidanceConfig from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.ip_adapter import IPAdapter @@ -202,20 +202,39 @@ class SDTrainer(BaseSDTrainProcess): # calculate the differential between our conditional (target image) and out unconditional ("bad" image) target_differential = unconditional_noisy_latents - conditional_noisy_latents + + # scale the target differential by the scheduler + # todo, scale it the right way + # target_differential = self.sd.noise_scheduler.add_noise( + # torch.zeros_like(target_differential), + # target_differential, + # timesteps + # ) + target_differential = target_differential.detach() # add the target differential to the target latents as if it were noise with the scheduler scaled to # the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done # properly - guidance_latents = self.sd.noise_scheduler.add_noise( - conditional_noisy_latents, - target_differential, - timesteps - ) + # guidance_latents = self.sd.noise_scheduler.add_noise( + # conditional_noisy_latents, + # target_differential, + # timesteps + # ) + + # guidance_latents = conditional_noisy_latents + target_differential + # target_noise = conditional_noisy_latents + target_differential # With LoRA network bypassed, predict noise to get a baseline of what the network # wants to do with the latents + noise. Pass our target latents here for the input. target_unconditional = self.sd.predict_noise( + latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ).detach() + target_conditional = self.sd.predict_noise( latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), timestep=timesteps, @@ -223,6 +242,12 @@ class SDTrainer(BaseSDTrainProcess): **pred_kwargs # adapter residuals in here ).detach() + # we calculate the networks current knowledge so we do not overlearn what we know + current_knowledge = target_unconditional - target_conditional + + # we now have the differential noise prediction needed to create our convergence target + target_unknown_knowledge = target_differential - current_knowledge + # turn the LoRA network back on. self.sd.unet.train() self.network.is_active = True @@ -231,7 +256,7 @@ class SDTrainer(BaseSDTrainProcess): # with LoRA active, predict the noise with the scaled differential latents added. This will allow us # the opportunity to predict the differential + noise that was added to the latents. prediction_unconditional = self.sd.predict_noise( - latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(), + latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), timestep=timesteps, guidance_scale=1.0, @@ -240,18 +265,26 @@ class SDTrainer(BaseSDTrainProcess): # remove the baseline conditional prediction. This will leave only the divergence from the baseline and # the prediction of the added differential noise - prediction_positive = prediction_unconditional - target_unconditional + # prediction_positive = prediction_unconditional - target_unconditional + prediction_positive = target_unconditional - prediction_unconditional # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents # this is the diffusion training process. # This will guide the network to make identical predictions it previously did for everything EXCEPT our # differential between the conditional and unconditional images + positive_loss = torch.nn.functional.mse_loss( prediction_positive.float(), - target_differential.float(), + target_unknown_knowledge.float(), reduction="none" ) + + # add adain loss + positive_loss = positive_loss + positive_loss = positive_loss.mean([1, 2, 3]) + + # positive_loss = positive_loss + adain_loss.mean([1, 2, 3]) # send it backwards BEFORE switching network polarity positive_loss = self.apply_snr(positive_loss, timesteps) positive_loss = positive_loss.mean() diff --git a/toolkit/basic.py b/toolkit/basic.py index f0464d69..9a64ca11 100644 --- a/toolkit/basic.py +++ b/toolkit/basic.py @@ -11,3 +11,25 @@ def flush(garbage_collect=True): torch.cuda.empty_cache() if garbage_collect: gc.collect() + + +def adain(content_features, style_features): + # Assumes that the content and style features are of shape (batch_size, channels, width, height) + + # Step 1: Calculate mean and variance of content features + content_mean, content_var = torch.mean(content_features, dim=[2, 3], keepdim=True), torch.var(content_features, + dim=[2, 3], + keepdim=True) + # Step 2: Calculate mean and variance of style features + style_mean, style_var = torch.mean(style_features, dim=[2, 3], keepdim=True), torch.var(style_features, dim=[2, 3], + keepdim=True) + + # Step 3: Normalize content features + content_std = torch.sqrt(content_var + 1e-5) + normalized_content = (content_features - content_mean) / content_std + + # Step 4: Scale and shift normalized content with style's statistics + style_std = torch.sqrt(style_var + 1e-5) + stylized_content = normalized_content * style_std + style_mean + + return stylized_content