Added long prompts to general training

This commit is contained in:
Jaret Burkett
2023-10-23 08:12:58 -06:00
parent 9905a1e205
commit dc36bbb3c8

View File

@@ -285,7 +285,7 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
self.device_torch,
dtype=dtype)
else:
@@ -296,7 +296,7 @@ class SDTrainer(BaseSDTrainProcess):
te.eval()
else:
self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
self.device_torch,
dtype=dtype)