Bugfixes for slider reference

commit 083cefa78c
parent b5ec8e4eb1
Author: Jaret Burkett
Date: 2023-09-10 18:36:23 -06:00

2 changed files with 10 additions and 2 deletions

@@ -159,7 +159,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         # if training text encoder enable grads, else do context of no grad
         with torch.set_grad_enabled(self.train_config.train_text_encoder):
-            conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
+            # fix issue with them being tuples sometimes
+            prompt_list = []
+            for prompt in prompts:
+                if isinstance(prompt, tuple):
+                    prompt = prompt[0]
+                prompt_list.append(prompt)
+            conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
         conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
         # if self.model_config.is_xl:
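
Note: the hunk above works around prompts that sometimes arrive from the batch as tuples rather than plain strings; only the first element is kept before calling encode_prompt. A minimal standalone sketch of that unwrapping pattern, outside the trainer; normalize_prompts is a hypothetical helper name, not part of the toolkit:

    # Minimal sketch of the tuple-unwrapping pattern from the hunk above,
    # lifted out of the trainer so it can be tested on its own.
    # normalize_prompts is a hypothetical helper name, not part of the toolkit.
    def normalize_prompts(prompts):
        prompt_list = []
        for prompt in prompts:
            # prompts sometimes arrive as tuples; keep only the first element,
            # matching the fix in the diff
            if isinstance(prompt, tuple):
                prompt = prompt[0]
            prompt_list.append(prompt)
        return prompt_list

    assert normalize_prompts(["a cat", ("a dog", 1)]) == ["a cat", "a dog"]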

@@ -55,7 +55,7 @@ class ToolkitModuleMixin:
         self.is_checkpointing = False
         self.is_normalizing = False
         self.normalize_scaler = 1.0
-        self._multiplier: Union[float, list, torch.Tensor] = 1.0
+        self._multiplier: Union[float, list, torch.Tensor] = None
         # this allows us to set different multipliers on a per item in a batch basis
         # allowing us to run positive and negative weights in the same batch
@@ -123,6 +123,8 @@ class ToolkitModuleMixin:
         return lx * scale
 
     def forward(self: Module, x):
+        if self._multiplier is None:
+            self.set_multiplier(0.0)
         org_forwarded = self.org_forward(x)
         lora_output = self._call_forward(x)
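
Note: the second file's changes make _multiplier start as None and have forward lazily initialize it to 0.0 on first use, so a module whose multiplier was never explicitly set defaults to 0.0 rather than the previous 1.0. A simplified standalone sketch of that lazy-default pattern, assuming set_multiplier simply stores the value; the real ToolkitModuleMixin also supports per-batch-item multipliers and additional scaling:

    # Simplified standalone sketch of the lazy-default pattern in the hunk above.
    # TinyLoRALinear is a toy stand-in, not the toolkit's actual module class.
    import torch
    import torch.nn as nn

    class TinyLoRALinear(nn.Module):
        def __init__(self, features: int):
            super().__init__()
            self.org_module = nn.Linear(features, features)
            self.lora_down = nn.Linear(features, 4, bias=False)
            self.lora_up = nn.Linear(4, features, bias=False)
            self._multiplier = None  # unset until the trainer decides

        def set_multiplier(self, value: float):
            self._multiplier = value

        def forward(self, x):
            # if no multiplier was ever set, default to 0.0 so the LoRA branch
            # contributes nothing instead of applying at the old default of 1.0
            if self._multiplier is None:
                self.set_multiplier(0.0)
            org_forwarded = self.org_module(x)
            lora_output = self.lora_up(self.lora_down(x))
            return org_forwarded + lora_output * self._multiplier

    module = TinyLoRALinear(8)
    out = module(torch.randn(2, 8))  # behaves like the base Linear on first call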