From ce4f9fe02a2a863ba36d9edcaa85bf46cdb37ee5 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 8 Sep 2023 06:10:59 -0600
Subject: [PATCH] Bug fixes and improvements to token injection

---
 extensions_built_in/sd_trainer/SDTrainer.py |   3 +-
 jobs/process/BaseSDTrainProcess.py          |  10 ++
 toolkit/dataloader_mixins.py                |   2 +-
 toolkit/embedding.py                        |   6 +-
 toolkit/stable_diffusion_model.py           | 116 ++++++++++----------
 5 files changed, 74 insertions(+), 63 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 1ae27c9a..c9b7178f 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -63,7 +63,7 @@ class SDTrainer(BaseSDTrainProcess):
                 # detach the embeddings
                 conditional_embeds = conditional_embeds.detach()
                 self.optimizer.zero_grad()
-                flush()
+        flush()
 
         noise_pred = self.sd.predict_noise(
             latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -71,6 +71,7 @@ class SDTrainer(BaseSDTrainProcess):
             timestep=timesteps,
             guidance_scale=1.0,
         )
+        flush()  # 9.18 gb
 
         noise = noise.to(self.device_torch, dtype=dtype)
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index c6b516de..d2062912 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -684,6 +684,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         # zero any gradients
         optimizer.zero_grad()
+        flush()
 
         self.lr_scheduler.step(self.step_num)
 
@@ -721,6 +722,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 ### HOOK ###
                 loss_dict = self.hook_train_loop(batch)
                 flush()
+                # set the networks back to train mode so gradient checkpointing and everything works
+                if self.embedding is not None or self.train_config.train_text_encoder:
+                    if isinstance(self.sd.text_encoder, list):
+                        for te in self.sd.text_encoder:
+                            te.train()
+                    else:
+                        self.sd.text_encoder.train()
+
+                self.sd.unet.train()
 
                 with torch.no_grad():
                     if self.train_config.optimizer.lower().startswith('dadaptation') or \
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 7ed869b6..787657ec 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -173,7 +173,7 @@ class CaptionProcessingDTOMixin:
             self: 'FileItemDTO',
             trigger=None,
             to_replace_list=None,
-            add_if_not_present=True
+            add_if_not_present=False
     ):
         raw_caption = self.raw_caption
         if raw_caption is None:
diff --git a/toolkit/embedding.py b/toolkit/embedding.py
index b326dde3..142b3913 100644
--- a/toolkit/embedding.py
+++ b/toolkit/embedding.py
@@ -141,8 +141,8 @@ class Embedding:
     # however, on training we don't use that pipeline, so we have to do it ourselves
     def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
         output_prompt = prompt
-        embedding_tokens = self.embedding_tokens[0]  # shoudl be the same
-        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", embedding_tokens]
+        embedding_tokens = self.embedding_tokens[0]  # should be the same
+        default_replacements = ["[name]", "[trigger]"]
 
         replace_with = embedding_tokens if expand_token else self.trigger
         if to_replace_list is None:
@@ -167,7 +167,7 @@ class Embedding:
 
         if num_instances > 1:
             print(
-                f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
 
         return output_prompt
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 5e773acc..176b77a3 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -24,8 +24,6 @@ from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
-from library import model_util
-from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
 from diffusers.schedulers import DDPMScheduler
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline
@@ -233,6 +231,9 @@ class StableDiffusion:
         # scheduler doesn't get set sometimes, so we set it here
         pipe.scheduler = self.noise_scheduler
 
+        # add hacks to unet to help training
+        # pipe.unet = prepare_unet_for_training(pipe.unet)
+
         if self.model_config.vae_path is not None:
             external_vae = load_vae(self.model_config.vae_path, dtype)
             pipe.vae = external_vae
@@ -476,54 +477,52 @@ class StableDiffusion:
             unconditional_embeddings: Union[PromptEmbeds, None] = None,
             **kwargs,
     ):
-        # get the embeddings
-        if text_embeddings is None and conditional_embeddings is None:
-            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
-        if text_embeddings is None and unconditional_embeddings is not None:
-            text_embeddings = train_tools.concat_prompt_embeddings(
-                unconditional_embeddings,  # negative embedding
-                conditional_embeddings,  # positive embedding
-                1,  # batch size
-            )
-        elif text_embeddings is None and conditional_embeddings is not None:
-            # not doing cfg
-            text_embeddings = conditional_embeddings
+        with torch.no_grad():
+            # get the embeddings
+            if text_embeddings is None and conditional_embeddings is None:
+                raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+            if text_embeddings is None and unconditional_embeddings is not None:
+                text_embeddings = train_tools.concat_prompt_embeddings(
+                    unconditional_embeddings,  # negative embedding
+                    conditional_embeddings,  # positive embedding
+                    1,  # batch size
+                )
+            elif text_embeddings is None and conditional_embeddings is not None:
+                # not doing cfg
+                text_embeddings = conditional_embeddings
 
-        # CFG is comparing neg and positive, if we have concatenated embeddings
-        # then we are doing it, otherwise we are not and takes half the time.
-        do_classifier_free_guidance = True
+            # CFG compares neg and positive; if we have concatenated embeddings
+            # then we are doing it, otherwise we are not, and it takes half the time.
+            do_classifier_free_guidance = True
 
-        # check if batch size of embeddings matches batch size of latents
-        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
-            do_classifier_free_guidance = False
-        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
-            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+            # check if batch size of embeddings matches batch size of latents
+            if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+                do_classifier_free_guidance = False
+            elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+                raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
 
         if self.is_xl:
-            if add_time_ids is None:
-                add_time_ids = self.get_time_ids_from_latents(latents)
+            with torch.no_grad():
+                # 16, 6 for bs of 4
+                if add_time_ids is None:
+                    add_time_ids = self.get_time_ids_from_latents(latents)
+
+                if do_classifier_free_guidance:
+                    # todo check this with larger batches
+                    add_time_ids = torch.cat([add_time_ids] * 2)
 
             if do_classifier_free_guidance:
-                # todo check this with larget batches
-                add_time_ids = train_tools.concat_embeddings(
-                    add_time_ids, add_time_ids, int(latents.shape[0])
-                )
+                latent_model_input = torch.cat([latents] * 2)
             else:
-                # concat to fit batch size
-                add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
+                latent_model_input = latents
 
-            if do_classifier_free_guidance:
-                latent_model_input = torch.cat([latents] * 2)
-            else:
-                latent_model_input = latents
+            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
 
-            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
-
-            added_cond_kwargs = {
-                # todo can we zero here the second text encoder? or match a blank string?
-                "text_embeds": text_embeddings.pooled_embeds,
-                "time_ids": add_time_ids,
-            }
+            added_cond_kwargs = {
+                # todo can we zero here the second text encoder? or match a blank string?
+ "text_embeds": text_embeddings.pooled_embeds, + "time_ids": add_time_ids, + } # predict the noise residual noise_pred = self.unet( @@ -546,25 +545,26 @@ class StableDiffusion: noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) else: - if do_classifier_free_guidance: - # if we are doing classifier free guidance, need to double up - latent_model_input = torch.cat([latents] * 2) - else: - latent_model_input = latents + with torch.no_grad(): + if do_classifier_free_guidance: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents - latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep) + latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep) - # check if we need to concat timesteps - if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: - ts_bs = timestep.shape[0] - if ts_bs != latent_model_input.shape[0]: - if ts_bs == 1: - timestep = torch.cat([timestep] * latent_model_input.shape[0]) - elif ts_bs * 2 == latent_model_input.shape[0]: - timestep = torch.cat([timestep] * 2) - else: - raise ValueError( - f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat([timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2) + else: + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") # predict the noise residual noise_pred = self.unet(