mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-10 21:19:49 +00:00
added prompt dropout to happen independently on each TE
This commit is contained in:
@@ -344,6 +344,11 @@ class StableDiffusion:
|
||||
else:
|
||||
noise_scheduler = get_sampler(sampler)
|
||||
|
||||
try:
|
||||
noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
|
||||
except:
|
||||
pass
|
||||
|
||||
if sampler.startswith("sample_") and self.is_xl:
|
||||
# using kdiffusion
|
||||
Pipe = StableDiffusionKDiffusionXLPipeline
|
||||
@@ -722,7 +727,8 @@ class StableDiffusion:
|
||||
refiner_pred = self.refiner_unet(
|
||||
input_chunks[1],
|
||||
timestep_chunks[1],
|
||||
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], # just use the first second text encoder
|
||||
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
|
||||
# just use the first second text encoder
|
||||
added_cond_kwargs={
|
||||
"text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
|
||||
# "time_ids": added_cond_kwargs_chunked['time_ids'][1],
|
||||
@@ -740,7 +746,8 @@ class StableDiffusion:
|
||||
# just use the first second text encoder
|
||||
added_cond_kwargs={
|
||||
"text_embeds": text_embeddings.pooled_embeds,
|
||||
"time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
|
||||
"time_ids": self.get_time_ids_from_latents(latent_model_input,
|
||||
requires_aesthetic_score=True),
|
||||
},
|
||||
**kwargs,
|
||||
).sample
|
||||
@@ -845,7 +852,8 @@ class StableDiffusion:
|
||||
num_images_per_prompt=1,
|
||||
force_all=False,
|
||||
long_prompts=False,
|
||||
max_length=None
|
||||
max_length=None,
|
||||
dropout_prob=0.0,
|
||||
) -> PromptEmbeds:
|
||||
# sd1.5 embeddings are (bs, 77, 768)
|
||||
prompt = prompt
|
||||
@@ -875,12 +883,18 @@ class StableDiffusion:
|
||||
use_text_encoder_2=use_encoder_2,
|
||||
truncate=not long_prompts,
|
||||
max_length=max_length,
|
||||
dropout_prob=dropout_prob,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PromptEmbeds(
|
||||
train_tools.encode_prompts(
|
||||
self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt,
|
||||
truncate=not long_prompts,
|
||||
max_length=max_length,
|
||||
dropout_prob=dropout_prob
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1011,8 +1025,9 @@ class StableDiffusion:
|
||||
state_dict[new_key] = v
|
||||
return state_dict
|
||||
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
|
||||
str, Parameter]:
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
|
||||
OrderedDict[
|
||||
str, Parameter]:
|
||||
named_params: OrderedDict[str, Parameter] = OrderedDict()
|
||||
if vae:
|
||||
for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
|
||||
@@ -1198,7 +1213,8 @@ class StableDiffusion:
|
||||
print(f"Found {len(params)} trainable parameter in text encoder")
|
||||
|
||||
if refiner:
|
||||
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
|
||||
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
|
||||
state_dict_keys=True)
|
||||
refiner_lr = refiner_lr if refiner_lr is not None else default_lr
|
||||
params = []
|
||||
for key, diffusers_key in ldm_diffusers_keymap.items():
|
||||
|
||||
Reference in New Issue
Block a user