From 5276975fb0b73503bce1e3726728d1c4b18c8ee3 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 15 Jan 2024 08:31:09 -0700
Subject: [PATCH] Added additional config options for custom plugins I needed

---
 extensions_built_in/sd_trainer/SDTrainer.py |  5 ++-
 jobs/process/BaseSDTrainProcess.py          | 14 +++++--
 toolkit/config_modules.py                   |  1 +
 toolkit/dataloader_mixins.py                |  3 --
 toolkit/guidance.py                         |  1 +
 toolkit/lora_special.py                     |  1 +
 toolkit/stable_diffusion_model.py           | 43 ++++++++++-----------
 7 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index f18e95c7..64f1da48 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -802,7 +802,7 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.train_text_encoder:
             grad_on_text_encoder = True
 
-        if self.embedding:
+        if self.embedding is not None:
             grad_on_text_encoder = True
 
         if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
@@ -1095,13 +1095,14 @@ class SDTrainer(BaseSDTrainProcess):
            else:
                with self.timer('predict_unet'):
                    if unconditional_embeds is not None:
-                       unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
+                       unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
                    noise_pred = self.sd.predict_noise(
                        latents=noisy_latents.to(self.device_torch, dtype=dtype),
                        conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
                        unconditional_embeddings=unconditional_embeds,
                        timestep=timesteps,
                        guidance_scale=self.train_config.cfg_scale,
+                       detach_unconditional=False,
                        **pred_kwargs
                    )
            self.after_unet_predict()
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 1b816161..fcccb68c 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -333,6 +333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # remove all but the latest max_step_saves_to_keep
         # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
 
+        # remove duplicates
+        items_to_remove = list(dict.fromkeys(items_to_remove))
+
         for item in items_to_remove:
             self.print(f"Removing old save: {item}")
             if os.path.isdir(item):
@@ -758,7 +761,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         do_double = False
 
         with self.timer('prepare_noise'):
-            num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
+            num_train_timesteps = self.train_config.num_train_timesteps
 
             if self.train_config.noise_scheduler in ['custom_lcm']:
                 # we store this value on our custom one
@@ -791,14 +794,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
            orig_timesteps = torch.rand((batch_size,), device=latents.device)
 
            if content_or_style == 'content':
-               timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
+               timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
            elif content_or_style == 'style':
-               timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
+               timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps
 
            timestep_indices = value_map(
                timestep_indices,
                0,
-               self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
+               self.train_config.num_train_timesteps - 1,
                min_noise_steps,
                max_noise_steps - 1
            )
@@ -1234,6 +1237,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                # load last saved weights
                if latest_save_path is not None:
                    self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
+                   if self.embedding.step > 1:
+                       self.step_num = self.embedding.step
+                       self.start_step = self.step_num
                # self.step_num = self.embedding.step
                # self.start_step = self.step_num
 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index e2856544..9293ebb4 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -222,6 +222,7 @@ class TrainConfig:
         self.negative_prompt = kwargs.get('negative_prompt', None)
         # multiplier applied to loss on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
+        self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
         # dropout that happens before encoding. It functions independently per text encoder
         self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
 
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index fc40b69f..7f2c52c9 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -610,9 +610,6 @@ class ClipImageFileItemDTOMixin:
         if self.dataset_config.clip_image_shuffle_augmentations:
             self.build_clip_imag_augmentation_transform()
 
-        # save the original tensor
-        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
-
         open_cv_image = np.array(img)
         # Convert RGB to BGR
         open_cv_image = open_cv_image[:, :, ::-1].copy()
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index 1fe3a95c..ef005c8f 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -413,6 +413,7 @@ def get_guided_loss_polarity(
     device = sd.device_torch
     with torch.no_grad():
         dtype = get_torch_dtype(dtype)
+        noise = noise.to(device, dtype=dtype).detach()
         conditional_latents = batch.latents.to(device, dtype=dtype).detach()
         unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
 
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 4f566b03..2ad5f3f0 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -51,6 +51,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
            use_bias: bool = False,
            **kwargs
    ):
+       self.can_merge_in = True
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index f1170355..cb30273a 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -26,7 +26,7 @@ from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
-from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
+from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
@@ -701,11 +701,6 @@ class StableDiffusion:
         # get scheduler class name
         scheduler_class_name = self.noise_scheduler.__class__.__name__
 
-        index_noise_schedulers = [
-            'DPMSolverMultistepScheduler',
-            'EulerDiscreteSchedulerOutput',
-        ]
-
         # todo handle if timestep is single value
         original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
 
@@ -770,6 +765,7 @@ class StableDiffusion:
            conditional_embeddings: Union[PromptEmbeds, None] = None,
            unconditional_embeddings: Union[PromptEmbeds, None] = None,
            is_input_scaled=False,
+           detach_unconditional=False,
            **kwargs,
    ):
        with torch.no_grad():
@@ -777,11 +773,10 @@ class StableDiffusion:
        if text_embeddings is None and conditional_embeddings is None:
            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
        if text_embeddings is None and unconditional_embeddings is not None:
-           text_embeddings = train_tools.concat_prompt_embeddings(
+           text_embeddings = concat_prompt_embeds([
                unconditional_embeddings,  # negative embedding
                conditional_embeddings,  # positive embedding
-               1,  # batch size
-           )
+           ])
        elif text_embeddings is None and conditional_embeddings is not None:
            # not doing cfg
            text_embeddings = conditional_embeddings
@@ -932,7 +927,7 @@ class StableDiffusion:
        with torch.no_grad():
            if do_classifier_free_guidance:
                # if we are doing classifier free guidance, need to double up
-               latent_model_input = torch.cat([latents] * 2)
+               latent_model_input = torch.cat([latents] * 2, dim=0)
                timestep = torch.cat([timestep] * 2)
            else:
                latent_model_input = latents
@@ -946,7 +941,7 @@ class StableDiffusion:
                if ts_bs == 1:
                    timestep = torch.cat([timestep] * latent_model_input.shape[0])
                elif ts_bs * 2 == latent_model_input.shape[0]:
-                   timestep = torch.cat([timestep] * 2)
+                   timestep = torch.cat([timestep] * 2, dim=0)
                else:
                    raise ValueError(
                        f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
@@ -961,7 +956,9 @@ class StableDiffusion:
 
        if do_classifier_free_guidance:
            # perform guidance
-           noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+           noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
+           if detach_unconditional:
+               noise_pred_uncond = noise_pred_uncond.detach()
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
@@ -973,13 +970,15 @@ class StableDiffusion:
 
        return noise_pred
 
-   def step_scheduler(self, model_input, latent_input, timestep_tensor):
+   def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
+       if noise_scheduler is None:
+           noise_scheduler = self.noise_scheduler
        # // sometimes they are on the wrong device, no idea why
-       if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
+       if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler):
            try:
-               self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
-               self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
-               self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
+               noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch)
+               noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch)
+               noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch)
            except Exception as e:
                pass
 
@@ -993,12 +992,12 @@ class StableDiffusion:
 
        for idx in range(model_input.shape[0]):
            # Reset it so it is unique for each chunk
-           if hasattr(self.noise_scheduler, '_step_index'):
-               self.noise_scheduler._step_index = None
-           if hasattr(self.noise_scheduler, 'is_scale_input_called'):
-               self.noise_scheduler.is_scale_input_called = True
+           if hasattr(noise_scheduler, '_step_index'):
+               noise_scheduler._step_index = None
+           if hasattr(noise_scheduler, 'is_scale_input_called'):
+               noise_scheduler.is_scale_input_called = True
            out_chunks.append(
-               self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
+               noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
                    0]
            )
        return torch.cat(out_chunks, dim=0)
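
A few usage notes on the changes above. The checkpoint cleanup in BaseSDTrainProcess.py now deduplicates its removal list with dict.fromkeys, which drops repeated paths while keeping first-seen order (unlike set(), which would reorder). A minimal sketch; the file names here are made up for illustration:

items_to_remove = [
    'output/save_000000100.safetensors',
    'output/save_000000200.safetensors',
    'output/save_000000100.safetensors',  # hypothetical duplicate, e.g. matched twice
]
# dict keys are unique and insertion-ordered (Python 3.7+), so this
# removes duplicates without reshuffling the removal order
items_to_remove = list(dict.fromkeys(items_to_remove))
print(items_to_remove)
# ['output/save_000000100.safetensors', 'output/save_000000200.safetensors']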
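
The new num_train_timesteps option replaces reads of self.sd.noise_scheduler.config['num_train_timesteps'] in the trainer, so the timestep range no longer depends on whichever scheduler happens to be loaded. A minimal sketch of the kwargs pattern, using a stand-in class rather than the toolkit's TrainConfig:

import torch

class TrainConfig:  # stand-in, mirrors only the new field
    def __init__(self, **kwargs):
        # defaults to 1000, matching the diff
        self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)

config = TrainConfig(num_train_timesteps=1000)
# uniform integer timesteps in [0, num_train_timesteps), a stand-in
# for the trainer's noise preparation step
timesteps = torch.randint(0, config.num_train_timesteps, (4,))
print(timesteps)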
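
For context on the content/style hunk: 'content' samples timesteps as u ** 3 (biased toward low noise), 'style' as 1 - u ** 3 (biased toward high noise), and value_map then rescales the result into [min_noise_steps, max_noise_steps - 1]. A standalone sketch; this value_map is an assumed linear remap, since only its call site appears in the diff:

import torch

def value_map(x, in_min, in_max, out_min, out_max):
    # assumed implementation: linear rescale from one range to another
    return (x - in_min) / (in_max - in_min) * (out_max - out_min) + out_min

num_train_timesteps = 1000
min_noise_steps, max_noise_steps = 0, 1000
u = torch.rand((8,))

content_t = u ** 3 * num_train_timesteps        # clusters near 0 (low noise)
style_t = (1 - u ** 3) * num_train_timesteps    # clusters near 1000 (high noise)

content_t = value_map(content_t, 0, num_train_timesteps - 1,
                      min_noise_steps, max_noise_steps - 1).long()
print(content_t)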
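
predict_noise now builds its CFG batch with concat_prompt_embeds from toolkit.prompt_utils instead of train_tools.concat_prompt_embeddings, dropping the batch-size argument. Roughly, it concatenates the negative and positive embeddings along the batch dimension, unconditional first. A sketch with a stand-in PromptEmbeds, not the toolkit's class:

import torch
from dataclasses import dataclass

@dataclass
class PromptEmbeds:  # stand-in with a single field
    text_embeds: torch.Tensor

def concat_prompt_embeds(embeds):
    # assumed behavior: batch-dim concat, unconditional first
    return PromptEmbeds(torch.cat([e.text_embeds for e in embeds], dim=0))

uncond = PromptEmbeds(torch.zeros(1, 77, 768))
cond = PromptEmbeds(torch.ones(1, 77, 768))
both = concat_prompt_embeds([uncond, cond])
print(both.text_embeds.shape)  # torch.Size([2, 77, 768])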
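
The new detach_unconditional flag controls gradient flow through classifier-free guidance: when true, the unconditional half of the doubled batch is detached so the loss only backpropagates through the conditional prediction (SDTrainer passes False, keeping the old behavior). A self-contained sketch of the arithmetic; the shapes and the cfg_combine helper are illustrative, not the toolkit's API:

import torch

def cfg_combine(noise_pred, guidance_scale, detach_unconditional=False):
    # noise_pred is the doubled-up CFG batch: [uncond, cond] along dim 0
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
    if detach_unconditional:
        # stop gradients from flowing through the unconditional branch
        noise_pred_uncond = noise_pred_uncond.detach()
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

pred = torch.randn(4, 4, 8, 8, requires_grad=True)
out = cfg_combine(pred, guidance_scale=7.5, detach_unconditional=True)
out.sum().backward()
# the unconditional half receives no gradient, the conditional half does
print(pred.grad[:2].abs().sum().item(), pred.grad[2:].abs().sum().item())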
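
Finally, step_scheduler now takes an optional noise_scheduler argument, defaulting to self.noise_scheduler, so a custom plugin can step a scheduler other than the model's own. The per-sample loop resets _step_index so each chunk is stepped independently. A minimal standalone sketch of that loop against a diffusers DDPMScheduler; the tensors are random stand-ins for real predictions and latents:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

model_input = torch.randn(2, 4, 8, 8)   # stand-in noise predictions
latent_input = torch.randn(2, 4, 8, 8)  # stand-in noisy latents
timesteps = torch.tensor([980, 960])

out_chunks = []
for idx in range(model_input.shape[0]):
    # reset internal step tracking so each sample is stepped independently
    if hasattr(scheduler, '_step_index'):
        scheduler._step_index = None
    out_chunks.append(
        scheduler.step(model_input[idx:idx + 1], timesteps[idx],
                       latent_input[idx:idx + 1], return_dict=False)[0]
    )
denoised = torch.cat(out_chunks, dim=0)
print(denoised.shape)  # torch.Size([2, 4, 8, 8])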