From 4f9cdd916a392cf94ceeeeca1ed54bee56bb8394 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 14 Nov 2023 05:26:51 -0700
Subject: [PATCH] added prompt dropout to happen independently on each TE

---
 extensions_built_in/sd_trainer/SDTrainer.py | 16 ++---
 jobs/process/BaseSDTrainProcess.py          | 23 +++++++
 scripts/make_lcm_sdxl_model.py              | 67 +++++++++++++++++++++
 toolkit/config_modules.py                   |  3 +
 toolkit/sampler.py                          |  5 ++
 toolkit/stable_diffusion_model.py           | 30 ++++++---
 toolkit/train_tools.py                      | 15 +++++
 7 files changed, 144 insertions(+), 15 deletions(-)
 create mode 100644 scripts/make_lcm_sdxl_model.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 0dcba154..49f39725 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -62,8 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
             # offload it. Already cached
             self.sd.vae.to('cpu')
             flush()
-
-        self.sd.noise_scheduler.set_timesteps(1000)
         add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
 
         # you can expand these in a child class to make customization easier
@@ -478,9 +476,10 @@ class SDTrainer(BaseSDTrainProcess):
         with self.timer('encode_prompt'):
             if grad_on_text_encoder:
                 with torch.set_grad_enabled(True):
-                    conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                               long_prompts=True).to(
-                    # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                    conditional_embeds = self.sd.encode_prompt(
+                        conditioned_prompts, prompt_2,
+                        dropout_prob=self.train_config.prompt_dropout_prob,
+                        long_prompts=True).to(
                         self.device_torch,
                         dtype=dtype)
             else:
@@ -491,9 +490,10 @@ class SDTrainer(BaseSDTrainProcess):
                             te.eval()
                     else:
                         self.sd.text_encoder.eval()
-                    conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                               long_prompts=True).to(
-                    # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                    conditional_embeds = self.sd.encode_prompt(
+                        conditioned_prompts, prompt_2,
+                        dropout_prob=self.train_config.prompt_dropout_prob,
+                        long_prompts=True).to(
                         self.device_torch,
                         dtype=dtype)
 
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 1c6453be..0f17ce40 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -419,6 +419,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with open(path_to_save, 'w') as f:
                 json.dump(json_data, f, indent=4)
 
+        # save optimizer
+        if self.optimizer is not None:
+            try:
+                filename = f'optimizer.pt'
+                file_path = os.path.join(self.save_root, filename)
+                torch.save(self.optimizer.state_dict(), file_path)
+            except Exception as e:
+                print(e)
+                print("Could not save optimizer")
+
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)
@@ -1121,6 +1131,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
                                       optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
 
+        # check if it exists
+        optimizer_state_filename = f'optimizer.pt'
+        optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
+        if os.path.exists(optimizer_state_file_path):
+            # try to load
+            try:
+                print(f"Loading optimizer state from {optimizer_state_file_path}")
+                optimizer_state_dict = torch.load(optimizer_state_file_path)
+                optimizer.load_state_dict(optimizer_state_dict)
+            except Exception as e:
+                print(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                print(e)
+
         lr_scheduler_params = self.train_config.lr_scheduler_params
 
         # make sure it had bare minimum
diff --git a/scripts/make_lcm_sdxl_model.py b/scripts/make_lcm_sdxl_model.py
new file mode 100644
index 00000000..20e95ce7
--- /dev/null
+++ b/scripts/make_lcm_sdxl_model.py
@@ -0,0 +1,67 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    'input_path',
+    type=str,
+    help='Path to original sdxl model'
+)
+parser.add_argument(
+    'output_path',
+    type=str,
+    help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+if args.sdxl:
+    adapter_id = "latent-consistency/lcm-lora-sdxl"
+if args.refiner:
+    adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.ssd:
+    adapter_id = "latent-consistency/lcm-lora-ssd-1b"
+else:
+    adapter_id = "latent-consistency/lcm-lora-sdv1-5"
+
+
+diffusers_model_config = ModelConfig(
+    name_or_path=args.input_path,
+    is_xl=args.sdxl,
+    is_v2=args.sd2,
+    is_ssd=args.ssd,
+    dtype=dtype,
+)
+diffusers_sd = StableDiffusion(
+    model_config=diffusers_model_config,
+    device=device,
+    dtype=dtype,
+)
+diffusers_sd.load_model()
+
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.load_lora_weights(adapter_id)
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+
+diffusers_sd.save(args.output_path, meta=meta)
+
+
+print(f"Saved to {args.output_path}")
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 7f0efafb..fb69c2fc 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -194,6 +194,9 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
 
+        # dropout that happens before encoding. It functions independently per text encoder
+        self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
+
         # match the norm of the noise before computing loss. This will help the model maintain its
         # current understandin of the brightness of images.
diff --git a/toolkit/sampler.py b/toolkit/sampler.py
index 098f4527..5e585279 100644
--- a/toolkit/sampler.py
+++ b/toolkit/sampler.py
@@ -12,7 +12,9 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
+    LCMScheduler
 )
+
 from k_diffusion.external import CompVisDenoiser
 
 # scheduler:
@@ -72,12 +74,15 @@ def get_sampler(
         scheduler_cls = KDPM2DiscreteScheduler
     elif sampler == "dpm_2_a":
         scheduler_cls = KDPM2AncestralDiscreteScheduler
+    elif sampler == "lcm":
+        scheduler_cls = LCMScheduler
 
     config = copy.deepcopy(sdxl_sampler_config)
     config.update(sched_init_args)
 
     scheduler = scheduler_cls.from_config(config)
+
     return scheduler
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index ca9aceff..9a72d647 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -344,6 +344,11 @@ class StableDiffusion:
         else:
             noise_scheduler = get_sampler(sampler)
 
+        try:
+            noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
+        except:
+            pass
+
         if sampler.startswith("sample_") and self.is_xl:
             # using kdiffusion
             Pipe = StableDiffusionKDiffusionXLPipeline
@@ -722,7 +727,8 @@ class StableDiffusion:
                     refiner_pred = self.refiner_unet(
                         input_chunks[1],
                         timestep_chunks[1],
-                        encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],  # just use the first second text encoder
+                        encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
+                        # just use the first second text encoder
                         added_cond_kwargs={
                             "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
                             # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
@@ -740,7 +746,8 @@ class StableDiffusion:
                     # just use the first second text encoder
                     added_cond_kwargs={
                         "text_embeds": text_embeddings.pooled_embeds,
-                        "time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
+                        "time_ids": self.get_time_ids_from_latents(latent_model_input,
+                                                                   requires_aesthetic_score=True),
                     },
                     **kwargs,
                 ).sample
@@ -845,7 +852,8 @@ class StableDiffusion:
             num_images_per_prompt=1,
             force_all=False,
             long_prompts=False,
-            max_length=None
+            max_length=None,
+            dropout_prob=0.0,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt
@@ -875,12 +883,18 @@ class StableDiffusion:
                     use_text_encoder_2=use_encoder_2,
                     truncate=not long_prompts,
                     max_length=max_length,
+                    dropout_prob=dropout_prob,
                 )
             )
         else:
             return PromptEmbeds(
                 train_tools.encode_prompts(
-                    self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    truncate=not long_prompts,
+                    max_length=max_length,
+                    dropout_prob=dropout_prob
                 )
             )
 
@@ -1011,8 +1025,9 @@ class StableDiffusion:
                 state_dict[new_key] = v
         return state_dict
 
-    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
-        str, Parameter]:
+    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
+            OrderedDict[
+                str, Parameter]:
         named_params: OrderedDict[str, Parameter] = OrderedDict()
         if vae:
             for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -1198,7 +1213,8 @@ class StableDiffusion:
             print(f"Found {len(params)} trainable parameter in text encoder")
 
         if refiner:
-            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
+            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
+                                                 state_dict_keys=True)
             refiner_lr = refiner_lr if refiner_lr is not None else default_lr
             params = []
             for key, diffusers_key in ldm_diffusers_keymap.items():
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 4b79e024..267b7588 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -537,6 +537,7 @@ def encode_prompts_xl(
     use_text_encoder_2: bool = True,  # sdxl
     truncate: bool = True,
     max_length=None,
+    dropout_prob=0.0,
 ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
     # text_encoder and text_encoder_2's penuultimate layer's output
     text_embeds_list = []
@@ -553,6 +554,12 @@ def encode_prompts_xl(
         if idx == 1 and not use_text_encoder_2:
             prompt_list_to_use = ["" for _ in prompts]
 
+        if dropout_prob > 0.0:
+            # randomly drop out prompts
+            prompt_list_to_use = [
+                prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
+            ]
+
         text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
         # set the max length for the next one
         if idx == 0:
@@ -598,9 +605,17 @@ def encode_prompts(
         prompts: list[str],
         truncate: bool = True,
         max_length=None,
+        dropout_prob=0.0,
 ):
     if max_length is None:
         max_length = tokenizer.model_max_length
+
+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
     text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
 
     text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
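
Note (appended by the editor, not part of the patch above): a minimal sketch of the behavior this patch introduces, based only on what the diff shows. Because the dropout check runs inside the per-encoder loop of encode_prompts_xl, each text encoder draws its own random number for every prompt, so a caption can be dropped for one encoder while the other still sees it. The helper name drop_prompts below is hypothetical; it just mirrors the list comprehension from the patch.

    import torch

    def drop_prompts(prompts, dropout_prob):
        # each prompt is independently replaced with "" with probability dropout_prob
        return [p if torch.rand(1).item() > dropout_prob else "" for p in prompts]

    prompts = ["a photo of a cat", "a photo of a dog"]
    # one independent draw per text encoder, as in the SDXL two-encoder path
    per_encoder_prompts = [drop_prompts(prompts, dropout_prob=0.1) for _ in range(2)]

The feature is switched on by setting prompt_dropout_prob in the training config; it defaults to 0.0 (no dropout) and is read through TrainConfig's kwargs, so a train option of that name should be sufficient. For the new scripts/make_lcm_sdxl_model.py, a plausible invocation based on its argparse definition (paths are placeholders) is: python scripts/make_lcm_sdxl_model.py /path/to/input_model /path/to/output_model --sdxl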