Added prompt dropout to happen independently on each TE

This commit is contained in:
Jaret Burkett
2023-11-14 05:26:51 -07:00
parent 7782caa468
commit 4f9cdd916a
7 changed files with 144 additions and 15 deletions

View File

@@ -62,8 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
# offload it. Already cached
self.sd.vae.to('cpu')
flush()
self.sd.noise_scheduler.set_timesteps(1000)
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
# you can expand these in a child class to make customization easier
@@ -478,9 +476,10 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
long_prompts=True).to(
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=True).to(
self.device_torch,
dtype=dtype)
else:
@@ -491,9 +490,10 @@ class SDTrainer(BaseSDTrainProcess):
te.eval()
else:
self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
long_prompts=True).to(
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=True).to(
self.device_torch,
dtype=dtype)