mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
added prompt dropout to happen independently on each TE
This commit is contained in:
@@ -62,8 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# offload it. Already cached
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(1000)
|
||||
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
|
||||
|
||||
# you can expand these in a child class to make customization easier
|
||||
@@ -478,9 +476,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
with self.timer('encode_prompt'):
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
|
||||
long_prompts=True).to(
|
||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
conditioned_prompts, prompt_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=True).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
else:
|
||||
@@ -491,9 +490,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
|
||||
long_prompts=True).to(
|
||||
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
conditioned_prompts, prompt_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=True).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user