From 92cb5ae09601ba2cc24b14e2f4a19cc33d66974f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 1 Dec 2023 06:30:52 -0700 Subject: [PATCH] Reworked targeted guidance algo --- extensions_built_in/sd_trainer/SDTrainer.py | 33 ++++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 93bee42d..f3a8a8ae 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -189,15 +189,31 @@ class SDTrainer(BaseSDTrainProcess): ): with torch.no_grad(): # Perform targeted guidance (working title) - conditional_noisy_latents = noisy_latents.detach() # target images dtype = get_torch_dtype(self.train_config.dtype) - if batch.unconditional_latents is not None: - # unconditional latents are the "neutral" images. Add noise here identical to - # the noise added to the conditional latents, at the same timesteps + conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach() + unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach() - unconditional_noisy_latents = self.sd.add_noise(batch.unconditional_latents, noise, timesteps).detach() + unconditional_diff = unconditional_latents - conditional_latents + conditional_diff = conditional_latents - unconditional_latents + # we need to determine the amount of signal and noise that would be present at the current timestep + conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps) + unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps) + + target_noise = noise + unconditional_signal + + conditional_noisy_latents = self.sd.add_noise( + unconditional_latents + conditional_signal, + target_noise, + timesteps + ).detach() + + unconditional_noisy_latents = self.sd.add_noise( + unconditional_latents, + noise, + timesteps + ).detach() # Disable the LoRA network so we can predict parent network knowledge without it self.network.is_active = False @@ -227,9 +243,12 @@ class SDTrainer(BaseSDTrainProcess): **pred_kwargs # adapter residuals in here ) + # remove baseline from our prediction to extract our differential prediction + prediction = prediction - baseline_prediction + loss = torch.nn.functional.mse_loss( prediction.float(), - baseline_prediction.float(), + unconditional_signal.float(), reduction="none" ) loss = loss.mean([1, 2, 3]) @@ -303,7 +322,6 @@ class SDTrainer(BaseSDTrainProcess): match_adapter_assist = False - # check if we are matching the adapter assistant if self.assistant_adapter: if self.train_config.match_adapter_chance == 1.0: @@ -321,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess): if file_item.is_reg: loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight - adapter_images = None sigmas = None if has_adapter_img and (self.adapter or self.assistant_adapter):