mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so I can manipulate them.
This commit is contained in:
@@ -368,6 +368,19 @@ class StableDiffusion:
|
||||
torch.manual_seed(gen_config.seed)
|
||||
torch.cuda.manual_seed(gen_config.seed)
|
||||
|
||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||
|
||||
unconditional_embeds = self.encode_prompt(
|
||||
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
|
||||
)
|
||||
|
||||
# allow any manipulations to take place to embeddings
|
||||
gen_config.post_process_embeddings(
|
||||
conditional_embeds,
|
||||
unconditional_embeds,
|
||||
)
|
||||
|
||||
# TODO: do we disable the text encoder here as well if it is disabled for the model, or only do that for training?
|
||||
if self.is_xl:
|
||||
# fix guidance rescale for sdxl
|
||||
@@ -382,10 +395,14 @@ class StableDiffusion:
|
||||
extra['use_karras_sigmas'] = True
|
||||
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
prompt_2=gen_config.prompt_2,
|
||||
negative_prompt=gen_config.negative_prompt,
|
||||
negative_prompt_2=gen_config.negative_prompt_2,
|
||||
# prompt=gen_config.prompt,
|
||||
# prompt_2=gen_config.prompt_2,
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||
negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||
# negative_prompt=gen_config.negative_prompt,
|
||||
# negative_prompt_2=gen_config.negative_prompt_2,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
@@ -395,8 +412,10 @@ class StableDiffusion:
|
||||
).images[0]
|
||||
else:
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
negative_prompt=gen_config.negative_prompt,
|
||||
# prompt=gen_config.prompt,
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||
# negative_prompt=gen_config.negative_prompt,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
@@ -625,21 +644,25 @@ class StableDiffusion:
|
||||
# return latents_steps
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
|
||||
def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
|
||||
# sd1.5 embeddings are (bs, 77, 768)
|
||||
prompt = prompt
|
||||
# if it is not a list, make it one
|
||||
if not isinstance(prompt, list):
|
||||
prompt = [prompt]
|
||||
|
||||
if prompt2 is not None and not isinstance(prompt2, list):
|
||||
prompt2 = [prompt2]
|
||||
if self.is_xl:
|
||||
return PromptEmbeds(
|
||||
train_tools.encode_prompts_xl(
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt,
|
||||
prompt2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
use_text_encoder_1=self.use_text_encoder_1,
|
||||
use_text_encoder_2=self.use_text_encoder_2,
|
||||
use_text_encoder_1=self.use_text_encoder_1 or force_all,
|
||||
use_text_encoder_2=self.use_text_encoder_2 or force_all,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user