mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Bug fixes and improvements to token injection
This commit is contained in:
@@ -71,6 +71,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
flush()
|
||||
# 9.18 gb
|
||||
noise = noise.to(self.device_torch, dtype=dtype)
|
||||
|
||||
|
||||
@@ -684,6 +684,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# zero any gradients
|
||||
optimizer.zero_grad()
|
||||
flush()
|
||||
|
||||
self.lr_scheduler.step(self.step_num)
|
||||
|
||||
@@ -721,6 +722,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
### HOOK ###
|
||||
loss_dict = self.hook_train_loop(batch)
|
||||
flush()
|
||||
# setup the networks to gradient checkpointing and everything works
|
||||
if self.embedding is not None or self.train_config.train_text_encoder:
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.train()
|
||||
else:
|
||||
self.sd.text_encoder.train()
|
||||
|
||||
self.sd.unet.train()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.train_config.optimizer.lower().startswith('dadaptation') or \
|
||||
|
||||
@@ -173,7 +173,7 @@ class CaptionProcessingDTOMixin:
|
||||
self: 'FileItemDTO',
|
||||
trigger=None,
|
||||
to_replace_list=None,
|
||||
add_if_not_present=True
|
||||
add_if_not_present=False
|
||||
):
|
||||
raw_caption = self.raw_caption
|
||||
if raw_caption is None:
|
||||
|
||||
@@ -142,7 +142,7 @@ class Embedding:
|
||||
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
|
||||
output_prompt = prompt
|
||||
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
|
||||
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", embedding_tokens]
|
||||
default_replacements = ["[name]", "[trigger]"]
|
||||
|
||||
replace_with = embedding_tokens if expand_token else self.trigger
|
||||
if to_replace_list is None:
|
||||
@@ -167,7 +167,7 @@ class Embedding:
|
||||
|
||||
if num_instances > 1:
|
||||
print(
|
||||
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@ from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_ldm_model_from_diffusers
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import torch
|
||||
from library import model_util
|
||||
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
||||
from diffusers.schedulers import DDPMScheduler
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline
|
||||
@@ -233,6 +231,9 @@ class StableDiffusion:
|
||||
# scheduler doesn't get set sometimes, so we set it here
|
||||
pipe.scheduler = self.noise_scheduler
|
||||
|
||||
# add hacks to unet to help training
|
||||
# pipe.unet = prepare_unet_for_training(pipe.unet)
|
||||
|
||||
if self.model_config.vae_path is not None:
|
||||
external_vae = load_vae(self.model_config.vae_path, dtype)
|
||||
pipe.vae = external_vae
|
||||
@@ -476,6 +477,7 @@ class StableDiffusion:
|
||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
with torch.no_grad():
|
||||
# get the embeddings
|
||||
if text_embeddings is None and conditional_embeddings is None:
|
||||
raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
|
||||
@@ -500,17 +502,14 @@ class StableDiffusion:
|
||||
raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
|
||||
|
||||
if self.is_xl:
|
||||
with torch.no_grad():
|
||||
# 16, 6 for bs of 4
|
||||
if add_time_ids is None:
|
||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# todo check this with larget batches
|
||||
add_time_ids = train_tools.concat_embeddings(
|
||||
add_time_ids, add_time_ids, int(latents.shape[0])
|
||||
)
|
||||
else:
|
||||
# concat to fit batch size
|
||||
add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
|
||||
add_time_ids = torch.cat([add_time_ids] * 2)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
@@ -546,6 +545,7 @@ class StableDiffusion:
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if do_classifier_free_guidance:
|
||||
# if we are doing classifier free guidance, need to double up
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
Reference in New Issue
Block a user