mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-05 18:49:50 +00:00
Bugfixes for slider reference
This commit is contained in:
@@ -159,7 +159,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
|
||||
# if training text encoder enable grads, else do context of no grad
|
||||
with torch.set_grad_enabled(self.train_config.train_text_encoder):
|
||||
conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
|
||||
# fix issue with them being tuples sometimes
|
||||
prompt_list = []
|
||||
for prompt in prompts:
|
||||
if isinstance(prompt, tuple):
|
||||
prompt = prompt[0]
|
||||
prompt_list.append(prompt)
|
||||
conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
|
||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
|
||||
# if self.model_config.is_xl:
|
||||
|
||||
Reference in New Issue
Block a user