Bugfixes for slider reference

This commit is contained in:
Jaret Burkett
2023-09-10 18:36:23 -06:00
parent b5ec8e4eb1
commit 083cefa78c
2 changed files with 10 additions and 2 deletions

View File

@@ -159,7 +159,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
# if training text encoder enable grads, else do context of no grad
with torch.set_grad_enabled(self.train_config.train_text_encoder):
conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
# fix issue with them being tuples sometimes
prompt_list = []
for prompt in prompts:
if isinstance(prompt, tuple):
prompt = prompt[0]
prompt_list.append(prompt)
conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
# if self.model_config.is_xl: