Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so that the embeddings can be manipulated before generation.

This commit is contained in:
Jaret Burkett
2023-09-16 17:41:07 -06:00
parent 27f343fc08
commit c698837241
11 changed files with 214 additions and 78 deletions

View File

@@ -368,6 +368,19 @@ class StableDiffusion:
torch.manual_seed(gen_config.seed)
torch.cuda.manual_seed(gen_config.seed)
# encode the prompt ourselves so we can do fun stuff with embeddings
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
)
# allow any manipulations to take place to embeddings
gen_config.post_process_embeddings(
conditional_embeds,
unconditional_embeds,
)
# TODO: should the text encoder also be disabled here when it is disabled for the model, or only during training?
if self.is_xl:
# fix guidance rescale for sdxl
@@ -382,10 +395,14 @@ class StableDiffusion:
extra['use_karras_sigmas'] = True
img = pipeline(
prompt=gen_config.prompt,
prompt_2=gen_config.prompt_2,
negative_prompt=gen_config.negative_prompt,
negative_prompt_2=gen_config.negative_prompt_2,
# prompt=gen_config.prompt,
# prompt_2=gen_config.prompt_2,
prompt_embeds=conditional_embeds.text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
# negative_prompt=gen_config.negative_prompt,
# negative_prompt_2=gen_config.negative_prompt_2,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
@@ -395,8 +412,10 @@ class StableDiffusion:
).images[0]
else:
img = pipeline(
prompt=gen_config.prompt,
negative_prompt=gen_config.negative_prompt,
# prompt=gen_config.prompt,
prompt_embeds=conditional_embeds.text_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
# negative_prompt=gen_config.negative_prompt,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
@@ -625,21 +644,25 @@ class StableDiffusion:
# return latents_steps
return latents
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt
# if it is not a list, make it one
if not isinstance(prompt, list):
prompt = [prompt]
if prompt2 is not None and not isinstance(prompt2, list):
prompt2 = [prompt2]
if self.is_xl:
return PromptEmbeds(
train_tools.encode_prompts_xl(
self.tokenizer,
self.text_encoder,
prompt,
prompt2,
num_images_per_prompt=num_images_per_prompt,
use_text_encoder_1=self.use_text_encoder_1,
use_text_encoder_2=self.use_text_encoder_2,
use_text_encoder_1=self.use_text_encoder_1 or force_all,
use_text_encoder_2=self.use_text_encoder_2 or force_all,
)
)
else: