added prompt dropout to happen indempendently on each TE

This commit is contained in:
Jaret Burkett
2023-11-14 05:26:51 -07:00
parent 7782caa468
commit 4f9cdd916a
7 changed files with 144 additions and 15 deletions

View File

@@ -344,6 +344,11 @@ class StableDiffusion:
else:
noise_scheduler = get_sampler(sampler)
try:
noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
except:
pass
if sampler.startswith("sample_") and self.is_xl:
# using kdiffusion
Pipe = StableDiffusionKDiffusionXLPipeline
@@ -722,7 +727,8 @@ class StableDiffusion:
refiner_pred = self.refiner_unet(
input_chunks[1],
timestep_chunks[1],
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], # just use the first second text encoder
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
# just use the first second text encoder
added_cond_kwargs={
"text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
# "time_ids": added_cond_kwargs_chunked['time_ids'][1],
@@ -740,7 +746,8 @@ class StableDiffusion:
# just use the first second text encoder
added_cond_kwargs={
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
"time_ids": self.get_time_ids_from_latents(latent_model_input,
requires_aesthetic_score=True),
},
**kwargs,
).sample
@@ -845,7 +852,8 @@ class StableDiffusion:
num_images_per_prompt=1,
force_all=False,
long_prompts=False,
max_length=None
max_length=None,
dropout_prob=0.0,
) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt
@@ -875,12 +883,18 @@ class StableDiffusion:
use_text_encoder_2=use_encoder_2,
truncate=not long_prompts,
max_length=max_length,
dropout_prob=dropout_prob,
)
)
else:
return PromptEmbeds(
train_tools.encode_prompts(
self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
self.tokenizer,
self.text_encoder,
prompt,
truncate=not long_prompts,
max_length=max_length,
dropout_prob=dropout_prob
)
)
@@ -1011,8 +1025,9 @@ class StableDiffusion:
state_dict[new_key] = v
return state_dict
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
str, Parameter]:
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
OrderedDict[
str, Parameter]:
named_params: OrderedDict[str, Parameter] = OrderedDict()
if vae:
for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -1198,7 +1213,8 @@ class StableDiffusion:
print(f"Found {len(params)} trainable parameter in text encoder")
if refiner:
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
state_dict_keys=True)
refiner_lr = refiner_lr if refiner_lr is not None else default_lr
params = []
for key, diffusers_key in ldm_diffusers_keymap.items():