Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Bug fixes and improvements to token injection
@@ -63,7 +63,7 @@ class SDTrainer(BaseSDTrainProcess):
         # detach the embeddings
         conditional_embeds = conditional_embeds.detach()
         self.optimizer.zero_grad()
         flush()
         flush()

         noise_pred = self.sd.predict_noise(
             latents=noisy_latents.to(self.device_torch, dtype=dtype),
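The detach() above matters for memory: once the prompt embeddings are cut out of the autograd graph, the UNet forward and backward below no longer keep the text encoder's activations alive. A minimal stand-alone sketch of the pattern in plain PyTorch, not the toolkit's own classes:

    import torch
    import torch.nn as nn

    text_encoder = nn.Linear(8, 16)   # stand-in for the real text encoder
    unet = nn.Linear(16, 16)          # stand-in for the denoising UNet

    embeds = text_encoder(torch.randn(2, 8))
    embeds = embeds.detach()          # treat the embeddings as constants from here on

    loss = unet(embeds).pow(2).mean()
    loss.backward()

    # The UNet receives gradients, the text encoder does not.
    assert unet.weight.grad is not None
    assert text_encoder.weight.grad is None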
@@ -71,6 +71,7 @@ class SDTrainer(BaseSDTrainProcess):
             timestep=timesteps,
             guidance_scale=1.0,
         )
+        flush()
         # 9.18 gb
         noise = noise.to(self.device_torch, dtype=dtype)
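flush() is called repeatedly in these hunks, right after the large forward passes. Its implementation is not part of this diff; a common shape for such a helper (an assumption, not necessarily ai-toolkit's exact code) is a garbage collection followed by a CUDA cache flush:

    import gc
    import torch

    def flush():
        # Drop unreferenced Python objects, then return cached CUDA blocks
        # to the allocator so the peak memory between steps stays lower.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

Calling it after big allocations trades a little speed for a smaller memory high-water mark, which fits the "9.18 gb" bookkeeping comment in the hunk above.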
@@ -684,6 +684,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             # zero any gradients
             optimizer.zero_grad()
+            flush()
 
             self.lr_scheduler.step(self.step_num)
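This is ordinary optimizer and scheduler bookkeeping around a training step; the toolkit's scheduler is stepped with an explicit step number. The same ordering in a self-contained PyTorch loop (toy model and data, no toolkit classes):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

    for step_num in range(3):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()

        # zero any gradients before the next iteration, then advance the LR schedule
        optimizer.zero_grad()
        lr_scheduler.step()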
@@ -721,6 +722,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
             ### HOOK ###
             loss_dict = self.hook_train_loop(batch)
             flush()
+            # setup the networks to gradient checkpointing and everything works
+            if self.embedding is not None or self.train_config.train_text_encoder:
+                if isinstance(self.sd.text_encoder, list):
+                    for te in self.sd.text_encoder:
+                        te.train()
+                else:
+                    self.sd.text_encoder.train()
+
+            self.sd.unet.train()
 
             with torch.no_grad():
                 if self.train_config.optimizer.lower().startswith('dadaptation') or \
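The added block puts the networks back into training mode after the hook, and it has to handle self.sd.text_encoder being either a single module or, for SDXL, a list of two encoders. A small helper expressing the same pattern (a sketch; set_train_mode is a hypothetical name, not a toolkit function):

    from typing import List, Union
    import torch.nn as nn

    def set_train_mode(text_encoder: Union[nn.Module, List[nn.Module]], unet: nn.Module) -> None:
        # SDXL keeps two text encoders in a list; SD 1.x/2.x has just one module.
        encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
        for te in encoders:
            te.train()
        unet.train()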
@@ -173,7 +173,7 @@ class CaptionProcessingDTOMixin:
             self: 'FileItemDTO',
             trigger=None,
             to_replace_list=None,
-            add_if_not_present=True
+            add_if_not_present=False
     ):
         raw_caption = self.raw_caption
         if raw_caption is None:
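The only change in this hunk is the default for add_if_not_present, which now presumably leaves captions alone unless they already contain a trigger placeholder, so callers must opt in to have the trigger added. A toy illustration of what such a flag typically controls (inject_trigger is hypothetical, not the DTO's real method):

    def inject_trigger(caption: str, trigger: str, add_if_not_present: bool = False) -> str:
        # Replace explicit placeholders first.
        caption = caption.replace("[trigger]", trigger)
        # Optionally prepend the trigger when the caption never mentions it.
        if add_if_not_present and trigger not in caption:
            caption = f"{trigger}, {caption}"
        return caption

    print(inject_trigger("a photo of [trigger] at night", "sks_person"))
    # -> "a photo of sks_person at night"
    print(inject_trigger("a photo at night", "sks_person"))
    # -> unchanged, because add_if_not_present now defaults to False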
@@ -141,8 +141,8 @@ class Embedding:
     # however, on training we don't use that pipeline, so we have to do it ourselves
     def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
         output_prompt = prompt
-        embedding_tokens = self.embedding_tokens[0]  # should be the same
-        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", embedding_tokens]
+        embedding_tokens = self.embedding_tokens[0]  # should be the same
+        default_replacements = ["[name]", "[trigger]"]
 
         replace_with = embedding_tokens if expand_token else self.trigger
         if to_replace_list is None:
@@ -167,7 +167,7 @@ class Embedding:
 
         if num_instances > 1:
             print(
-                f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
 
         return output_prompt
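Taken together, these two hunks narrow the default replacement list to the explicit placeholders and make the duplicate-token warning report the token that was actually inserted. A self-contained sketch of that replace, count, warn pattern (simplified, not the Embedding class itself):

    def inject_token(prompt: str, token: str, to_replace_list=None) -> str:
        if to_replace_list is None:
            to_replace_list = ["[name]", "[trigger]"]
        output_prompt = prompt
        for placeholder in to_replace_list:
            output_prompt = output_prompt.replace(placeholder, token)

        num_instances = output_prompt.count(token)
        if num_instances > 1:
            print(f"Warning: {token} appears {num_instances} times in prompt {output_prompt}. "
                  f"This may cause issues.")
        return output_prompt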
@@ -24,8 +24,6 @@ from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
-from library import model_util
-from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
 from diffusers.schedulers import DDPMScheduler
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline
@@ -233,6 +231,9 @@ class StableDiffusion:
         # scheduler doesn't get set sometimes, so we set it here
         pipe.scheduler = self.noise_scheduler
 
+        # add hacks to unet to help training
+        # pipe.unet = prepare_unet_for_training(pipe.unet)
+
         if self.model_config.vae_path is not None:
             external_vae = load_vae(self.model_config.vae_path, dtype)
             pipe.vae = external_vae
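Overriding pipeline components after loading, as this hunk does with the noise scheduler and an external VAE, is ordinary diffusers usage. A sketch with stock diffusers calls (the checkpoint ids are only placeholders; the toolkit's load_vae helper and config objects are not used here):

    import torch
    from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # Swap the scheduler: build one from the existing scheduler's config.
    pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    # Swap the VAE for an externally trained one.
    pipe.vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
    )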
@@ -476,54 +477,52 @@ class StableDiffusion:
             unconditional_embeddings: Union[PromptEmbeds, None] = None,
             **kwargs,
     ):
-        # get the embeddings
-        if text_embeddings is None and conditional_embeddings is None:
-            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
-        if text_embeddings is None and unconditional_embeddings is not None:
-            text_embeddings = train_tools.concat_prompt_embeddings(
-                unconditional_embeddings,  # negative embedding
-                conditional_embeddings,  # positive embedding
-                1,  # batch size
-            )
-        elif text_embeddings is None and conditional_embeddings is not None:
-            # not doing cfg
-            text_embeddings = conditional_embeddings
+        with torch.no_grad():
+            # get the embeddings
+            if text_embeddings is None and conditional_embeddings is None:
+                raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+            if text_embeddings is None and unconditional_embeddings is not None:
+                text_embeddings = train_tools.concat_prompt_embeddings(
+                    unconditional_embeddings,  # negative embedding
+                    conditional_embeddings,  # positive embedding
+                    1,  # batch size
+                )
+            elif text_embeddings is None and conditional_embeddings is not None:
+                # not doing cfg
+                text_embeddings = conditional_embeddings
 
-        # CFG is comparing neg and positive, if we have concatenated embeddings
-        # then we are doing it, otherwise we are not and takes half the time.
-        do_classifier_free_guidance = True
+            # CFG is comparing neg and positive, if we have concatenated embeddings
+            # then we are doing it, otherwise we are not and takes half the time.
+            do_classifier_free_guidance = True
 
-        # check if batch size of embeddings matches batch size of latents
-        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
-            do_classifier_free_guidance = False
-        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
-            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+            # check if batch size of embeddings matches batch size of latents
+            if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+                do_classifier_free_guidance = False
+            elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+                raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
 
         if self.is_xl:
-            if add_time_ids is None:
-                add_time_ids = self.get_time_ids_from_latents(latents)
+            with torch.no_grad():
+                # 16, 6 for bs of 4
+                if add_time_ids is None:
+                    add_time_ids = self.get_time_ids_from_latents(latents)
 
-            if do_classifier_free_guidance:
-                # todo check this with larger batches
-                add_time_ids = torch.cat([add_time_ids] * 2)
+                if do_classifier_free_guidance:
+                    # todo check this with larger batches
+                    add_time_ids = train_tools.concat_embeddings(
+                        add_time_ids, add_time_ids, int(latents.shape[0])
+                    )
+                    latent_model_input = torch.cat([latents] * 2)
+                else:
+                    # concat to fit batch size
+                    add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
+                    latent_model_input = latents
 
-            if do_classifier_free_guidance:
-                latent_model_input = torch.cat([latents] * 2)
-            else:
-                latent_model_input = latents
-
-            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
+                latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
 
             added_cond_kwargs = {
                 # todo can we zero here the second text encoder? or match a blank string?
                 "text_embeds": text_embeddings.pooled_embeds,
                 "time_ids": add_time_ids,
             }
 
             # predict the noise residual
             noise_pred = self.unet(
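The input preparation in this hunk (deciding whether classifier-free guidance is in play by comparing batch sizes, then doubling the latents and, for SDXL, the time ids to match the concatenated negative plus positive embeddings) runs under torch.no_grad(), since none of it needs gradients. A stripped-down sketch of that preparation logic, tensor shapes only, no real models:

    import torch

    def prepare_cfg_inputs(latents: torch.Tensor, text_embeds: torch.Tensor, add_time_ids: torch.Tensor):
        # If the embeddings already hold negative + positive halves, their batch
        # is twice the latent batch and we are doing classifier-free guidance.
        if latents.shape[0] == text_embeds.shape[0]:
            do_cfg = False
        elif latents.shape[0] * 2 == text_embeds.shape[0]:
            do_cfg = True
        else:
            raise ValueError("Embedding batch must equal or double the latent batch")

        with torch.no_grad():
            if do_cfg:
                latent_model_input = torch.cat([latents] * 2)
                add_time_ids = torch.cat([add_time_ids] * 2)
            else:
                latent_model_input = latents
        return latent_model_input, add_time_ids, do_cfg

    lat = torch.randn(2, 4, 64, 64)
    emb = torch.randn(4, 77, 768)   # negative + positive halves for 2 latents
    tids = torch.randn(1, 6)
    x, t, cfg = prepare_cfg_inputs(lat, emb, tids)
    print(x.shape, t.shape, cfg)    # torch.Size([4, 4, 64, 64]) torch.Size([2, 6]) True

The real method additionally tiles add_time_ids up to the latent batch size in the non-CFG branch; the sketch skips that for brevity.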
@@ -546,25 +545,26 @@ class StableDiffusion:
                 noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
         else:
-            if do_classifier_free_guidance:
-                # if we are doing classifier free guidance, need to double up
-                latent_model_input = torch.cat([latents] * 2)
-            else:
-                latent_model_input = latents
+            with torch.no_grad():
+                if do_classifier_free_guidance:
+                    # if we are doing classifier free guidance, need to double up
+                    latent_model_input = torch.cat([latents] * 2)
+                else:
+                    latent_model_input = latents
 
-            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
+                latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
 
-            # check if we need to concat timesteps
-            if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
-                ts_bs = timestep.shape[0]
-                if ts_bs != latent_model_input.shape[0]:
-                    if ts_bs == 1:
-                        timestep = torch.cat([timestep] * latent_model_input.shape[0])
-                    elif ts_bs * 2 == latent_model_input.shape[0]:
-                        timestep = torch.cat([timestep] * 2)
-                    else:
-                        raise ValueError(
-                            f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
+                # check if we need to concat timesteps
+                if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
+                    ts_bs = timestep.shape[0]
+                    if ts_bs != latent_model_input.shape[0]:
+                        if ts_bs == 1:
+                            timestep = torch.cat([timestep] * latent_model_input.shape[0])
+                        elif ts_bs * 2 == latent_model_input.shape[0]:
+                            timestep = torch.cat([timestep] * 2)
+                        else:
+                            raise ValueError(
+                                f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
 
             # predict the noise residual
             noise_pred = self.unet(
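The non-XL branch prepares its inputs the same way under torch.no_grad(), and also broadcasts a per-sample timestep tensor to the (possibly doubled) latent batch. In isolation that broadcasting looks like this (shapes only, a sketch rather than the method itself):

    import torch

    def match_timesteps(timestep: torch.Tensor, batch_size: int) -> torch.Tensor:
        ts_bs = timestep.shape[0]
        if ts_bs == batch_size:
            return timestep
        if ts_bs == 1:
            return torch.cat([timestep] * batch_size)
        if ts_bs * 2 == batch_size:
            # classifier-free guidance doubled the latents
            return torch.cat([timestep] * 2)
        raise ValueError(f"Cannot broadcast {ts_bs} timesteps to a batch of {batch_size}")

    print(match_timesteps(torch.tensor([[10]]), 4).shape)        # torch.Size([4, 1])
    print(match_timesteps(torch.tensor([[10], [20]]), 4).shape)  # torch.Size([4, 1])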