Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun

This commit is contained in:
Jaret Burkett
2023-08-19 05:54:22 -06:00
parent 80e2f4a2a4
commit 90eedb78bf
7 changed files with 239 additions and 35 deletions

View File

@@ -16,10 +16,6 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from leco import train_util
import torch
from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
@@ -124,6 +120,9 @@ class StableDiffusion:
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
def load_model(self):
if self.is_loaded:
return
@@ -309,6 +308,7 @@ class StableDiffusion:
torch.manual_seed(gen_config.seed)
torch.cuda.manual_seed(gen_config.seed)
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.is_xl:
img = pipeline(
prompt=gen_config.prompt,
@@ -393,7 +393,7 @@ class StableDiffusion:
dtype = latents.dtype
if self.is_xl:
prompt_ids = train_util.get_add_time_ids(
prompt_ids = train_tools.get_add_time_ids(
height,
width,
dynamic_crops=False, # look into this
@@ -444,7 +444,7 @@ class StableDiffusion:
if do_classifier_free_guidance:
# todo check this with larget batches
add_time_ids = train_util.concat_embeddings(
add_time_ids = train_tools.concat_embeddings(
add_time_ids, add_time_ids, int(latents.shape[0])
)
else:
@@ -459,6 +459,7 @@ class StableDiffusion:
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
added_cond_kwargs = {
# todo can we zero here the second text encoder? or match a blank string?
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
@@ -541,16 +542,18 @@ class StableDiffusion:
prompt = [prompt]
if self.is_xl:
return PromptEmbeds(
train_util.encode_prompts_xl(
train_tools.encode_prompts_xl(
self.tokenizer,
self.text_encoder,
prompt,
num_images_per_prompt=num_images_per_prompt,
use_text_encoder_1=self.use_text_encoder_1,
use_text_encoder_2=self.use_text_encoder_2,
)
)
else:
return PromptEmbeds(
train_util.encode_prompts(
train_tools.encode_prompts(
self.tokenizer, self.text_encoder, prompt
)
)