From 083cefa78cd675908aa16dd77ec5a4b9f0b51ba0 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 10 Sep 2023 18:36:23 -0600 Subject: [PATCH] Bugfixes for slider reference --- .../ImageReferenceSliderTrainerProcess.py | 8 +++++++- toolkit/network_mixins.py | 4 +++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py index ea5413fe..11fa8a9e 100644 --- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py +++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py @@ -159,7 +159,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): # if training text encoder enable grads, else do context of no grad with torch.set_grad_enabled(self.train_config.train_text_encoder): - conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype) + # fix issue with them being tuples sometimes + prompt_list = [] + for prompt in prompts: + if isinstance(prompt, tuple): + prompt = prompt[0] + prompt_list.append(prompt) + conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype) conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds]) # if self.model_config.is_xl: diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 7190f56d..7955642c 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -55,7 +55,7 @@ class ToolkitModuleMixin: self.is_checkpointing = False self.is_normalizing = False self.normalize_scaler = 1.0 - self._multiplier: Union[float, list, torch.Tensor] = 1.0 + self._multiplier: Union[float, list, torch.Tensor] = None # this allows us to set different multipliers on a per item in a batch basis # allowing us to run positive and negative weights in the same batch @@ -123,6 +123,8 @@ class ToolkitModuleMixin: return lx * scale def forward(self: Module, x): + if self._multiplier is None: + self.set_multiplier(0.0) org_forwarded = self.org_forward(x) lora_output = self._call_forward(x)