mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun
This commit is contained in:
@@ -16,10 +16,6 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
from leco import train_util
|
||||
import torch
|
||||
from library import model_util
|
||||
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
||||
@@ -124,6 +120,9 @@ class StableDiffusion:
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
|
||||
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
||||
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
return
|
||||
@@ -309,6 +308,7 @@ class StableDiffusion:
|
||||
torch.manual_seed(gen_config.seed)
|
||||
torch.cuda.manual_seed(gen_config.seed)
|
||||
|
||||
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
|
||||
if self.is_xl:
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
@@ -393,7 +393,7 @@ class StableDiffusion:
|
||||
dtype = latents.dtype
|
||||
|
||||
if self.is_xl:
|
||||
prompt_ids = train_util.get_add_time_ids(
|
||||
prompt_ids = train_tools.get_add_time_ids(
|
||||
height,
|
||||
width,
|
||||
dynamic_crops=False, # look into this
|
||||
@@ -444,7 +444,7 @@ class StableDiffusion:
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# todo check this with larget batches
|
||||
add_time_ids = train_util.concat_embeddings(
|
||||
add_time_ids = train_tools.concat_embeddings(
|
||||
add_time_ids, add_time_ids, int(latents.shape[0])
|
||||
)
|
||||
else:
|
||||
@@ -459,6 +459,7 @@ class StableDiffusion:
|
||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
added_cond_kwargs = {
|
||||
# todo can we zero here the second text encoder? or match a blank string?
|
||||
"text_embeds": text_embeddings.pooled_embeds,
|
||||
"time_ids": add_time_ids,
|
||||
}
|
||||
@@ -541,16 +542,18 @@ class StableDiffusion:
|
||||
prompt = [prompt]
|
||||
if self.is_xl:
|
||||
return PromptEmbeds(
|
||||
train_util.encode_prompts_xl(
|
||||
train_tools.encode_prompts_xl(
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
use_text_encoder_1=self.use_text_encoder_1,
|
||||
use_text_encoder_2=self.use_text_encoder_2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PromptEmbeds(
|
||||
train_util.encode_prompts(
|
||||
train_tools.encode_prompts(
|
||||
self.tokenizer, self.text_encoder, prompt
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user