Bug fixes and improvements to token injection

Jaret Burkett
2023-09-08 06:10:59 -06:00
parent 92a086d5a5
commit ce4f9fe02a
5 changed files with 74 additions and 63 deletions

View File

@@ -63,7 +63,7 @@ class SDTrainer(BaseSDTrainProcess):
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
self.optimizer.zero_grad()
flush()
flush()
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -71,6 +71,7 @@ class SDTrainer(BaseSDTrainProcess):
timestep=timesteps,
guidance_scale=1.0,
)
flush()
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype)
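The flush() calls added around the noise prediction release cached GPU memory between steps. flush itself is not part of this diff; a minimal sketch of what such a helper usually does, assuming the common gc-plus-CUDA-cache pattern:

import gc
import torch

def flush():
    # drop Python objects that are no longer referenced
    gc.collect()
    # return cached CUDA blocks to the driver so peak VRAM stays lower
    if torch.cuda.is_available():
        torch.cuda.empty_cache()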

View File

@@ -684,6 +684,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# zero any gradients
optimizer.zero_grad()
flush()
self.lr_scheduler.step(self.step_num)
@@ -721,6 +722,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
loss_dict = self.hook_train_loop(batch)
flush()
# put the networks back into train mode so gradient checkpointing and everything else works
if self.embedding is not None or self.train_config.train_text_encoder:
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.train()
else:
self.sd.text_encoder.train()
self.sd.unet.train()
with torch.no_grad():
if self.train_config.optimizer.lower().startswith('dadaptation') or \

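The new block puts the text encoder(s) and UNet back into train mode after the training hook runs, covering both the single-encoder and SDXL dual-encoder layouts. A standalone sketch of the same pattern; the function name and signature below are illustrative, not the toolkit's API:

from typing import List, Union
from torch import nn

def set_train_mode(text_encoder: Union[nn.Module, List[nn.Module]], unet: nn.Module) -> None:
    # SDXL models carry two text encoders, so the attribute may be a list
    encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
    for te in encoders:
        te.train()  # nn.Module.train() re-enables dropout and training-time hooks
    unet.train()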
View File

@@ -173,7 +173,7 @@ class CaptionProcessingDTOMixin:
self: 'FileItemDTO',
trigger=None,
to_replace_list=None,
add_if_not_present=True
add_if_not_present=False
):
raw_caption = self.raw_caption
if raw_caption is None:

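Changing the default of add_if_not_present to False means the trigger is only injected where a caption already contains a placeholder, rather than being appended to every caption. A hedged sketch of that behaviour; the replacement logic is not shown in this hunk, so the helper below is an illustration only:

def inject_trigger(caption: str, trigger: str, add_if_not_present: bool = False) -> str:
    # placeholders mirror the toolkit's "[name]" / "[trigger]" convention
    placeholders = ["[name]", "[trigger]"]
    replaced = False
    for ph in placeholders:
        if ph in caption:
            caption = caption.replace(ph, trigger)
            replaced = True
    # with the new default, captions without a placeholder are left untouched
    if not replaced and add_if_not_present:
        caption = f"{trigger}, {caption}"
    return caption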
View File

@@ -141,8 +141,8 @@ class Embedding:
# however, during training we don't use that pipeline, so we have to do it ourselves
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
output_prompt = prompt
embedding_tokens = self.embedding_tokens[0] # should be the same
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", embedding_tokens]
embedding_tokens = self.embedding_tokens[0] # should be the same
default_replacements = ["[name]", "[trigger]"]
replace_with = embedding_tokens if expand_token else self.trigger
if to_replace_list is None:
@@ -167,7 +167,7 @@ class Embedding:
if num_instances > 1:
print(
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
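Restricting default_replacements to "[name]" and "[trigger]" stops the embedding's own name and trigger text from being rewritten when they occur naturally in a prompt. A simplified, self-contained sketch of the injection step; the function below is illustrative and omits the add_if_not_present handling:

def inject_embedding(prompt: str, trigger: str, embedding_tokens: str,
                     expand_token: bool = False, to_replace_list=None) -> str:
    # only the explicit placeholders are rewritten by default
    replacements = ["[name]", "[trigger]"] if to_replace_list is None else to_replace_list
    replace_with = embedding_tokens if expand_token else trigger
    for token in replacements:
        prompt = prompt.replace(token, replace_with)
    # warn when the trigger ends up in the prompt more than once
    count = prompt.count(replace_with)
    if count > 1:
        print(f"Warning: {replace_with} token appears {count} times in prompt {prompt}. This may cause issues.")
    return prompt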

View File

@@ -24,8 +24,6 @@ from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline
@@ -233,6 +231,9 @@ class StableDiffusion:
# scheduler doesn't get set sometimes, so we set it here
pipe.scheduler = self.noise_scheduler
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
if self.model_config.vae_path is not None:
external_vae = load_vae(self.model_config.vae_path, dtype)
pipe.vae = external_vae
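Setting pipe.scheduler and swapping in an external VAE follows the usual diffusers pattern. A minimal sketch using public example checkpoints; the model ids and dtype below are placeholders, not the toolkit's configured vae_path:

import torch
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# the scheduler is not always carried over, so set it explicitly
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
# swap in an external VAE, keeping the pipeline dtype
pipe.vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
)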
@@ -476,54 +477,52 @@ class StableDiffusion:
unconditional_embeddings: Union[PromptEmbeds, None] = None,
**kwargs,
):
# get the embeddings
if text_embeddings is None and conditional_embeddings is None:
raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
if text_embeddings is None and unconditional_embeddings is not None:
text_embeddings = train_tools.concat_prompt_embeddings(
unconditional_embeddings, # negative embedding
conditional_embeddings, # positive embedding
1, # batch size
)
elif text_embeddings is None and conditional_embeddings is not None:
# not doing cfg
text_embeddings = conditional_embeddings
with torch.no_grad():
# get the embeddings
if text_embeddings is None and conditional_embeddings is None:
raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
if text_embeddings is None and unconditional_embeddings is not None:
text_embeddings = train_tools.concat_prompt_embeddings(
unconditional_embeddings, # negative embedding
conditional_embeddings, # positive embedding
1, # batch size
)
elif text_embeddings is None and conditional_embeddings is not None:
# not doing cfg
text_embeddings = conditional_embeddings
# CFG compares the negative and positive prompts; if the embeddings are concatenated
# we are doing CFG, otherwise we skip it and take half the time.
do_classifier_free_guidance = True
# CFG compares the negative and positive prompts; if the embeddings are concatenated
# we are doing CFG, otherwise we skip it and take half the time.
do_classifier_free_guidance = True
# check if batch size of embeddings matches batch size of latents
if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
do_classifier_free_guidance = False
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
# check if batch size of embeddings matches batch size of latents
if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
do_classifier_free_guidance = False
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
if self.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
with torch.no_grad():
# 16, 6 for bs of 4
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
if do_classifier_free_guidance:
# todo check this with larger batches
add_time_ids = torch.cat([add_time_ids] * 2)
if do_classifier_free_guidance:
# todo check this with larger batches
add_time_ids = train_tools.concat_embeddings(
add_time_ids, add_time_ids, int(latents.shape[0])
)
latent_model_input = torch.cat([latents] * 2)
else:
# concat to fit batch size
add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
latent_model_input = latents
if do_classifier_free_guidance:
latent_model_input = torch.cat([latents] * 2)
else:
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
added_cond_kwargs = {
# todo: can we zero out the second text encoder here, or match a blank string?
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
added_cond_kwargs = {
# todo: can we zero out the second text encoder here, or match a blank string?
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
# predict the noise residual
noise_pred = self.unet(
@@ -546,25 +545,26 @@ class StableDiffusion:
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
else:
if do_classifier_free_guidance:
# if we are doing classifier free guidance, need to double up
latent_model_input = torch.cat([latents] * 2)
else:
latent_model_input = latents
with torch.no_grad():
if do_classifier_free_guidance:
# if we are doing classifier free guidance, need to double up
latent_model_input = torch.cat([latents] * 2)
else:
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
# check if we need to concat timesteps
if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
ts_bs = timestep.shape[0]
if ts_bs != latent_model_input.shape[0]:
if ts_bs == 1:
timestep = torch.cat([timestep] * latent_model_input.shape[0])
elif ts_bs * 2 == latent_model_input.shape[0]:
timestep = torch.cat([timestep] * 2)
else:
raise ValueError(
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
# check if we need to concat timesteps
if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
ts_bs = timestep.shape[0]
if ts_bs != latent_model_input.shape[0]:
if ts_bs == 1:
timestep = torch.cat([timestep] * latent_model_input.shape[0])
elif ts_bs * 2 == latent_model_input.shape[0]:
timestep = torch.cat([timestep] * 2)
else:
raise ValueError(
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
# predict the noise residual
noise_pred = self.unet(
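
The refactor moves all of the input preparation inside torch.no_grad() so only the UNet forward pass stays on the autograd tape, while the batch-size checks decide whether classifier-free guidance is active and broadcast timestep to match the latent batch. A condensed sketch of that shape; the function and tensor names are illustrative, not the toolkit's predict_noise:

import torch

def prepare_inputs(latents: torch.Tensor, text_embeds: torch.Tensor, timestep: torch.Tensor):
    # illustrative only: mirrors the batch-size logic, not the full predict_noise
    with torch.no_grad():
        # CFG is active only when the embeddings hold a negative+positive pair per latent
        if latents.shape[0] == text_embeds.shape[0]:
            do_cfg = False
        elif latents.shape[0] * 2 == text_embeds.shape[0]:
            do_cfg = True
        else:
            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

        latent_model_input = torch.cat([latents] * 2) if do_cfg else latents

        # broadcast the timestep tensor to the (possibly doubled) latent batch
        if timestep.ndim > 1 and timestep.shape[0] != latent_model_input.shape[0]:
            if timestep.shape[0] == 1:
                timestep = torch.cat([timestep] * latent_model_input.shape[0])
            elif timestep.shape[0] * 2 == latent_model_input.shape[0]:
                timestep = torch.cat([timestep] * 2)
            else:
                raise ValueError("Batch size of latents must be the same or half the batch size of timesteps")
    return latent_model_input, timestep, do_cfg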