From 016687bda1e8600625bcc9a71c73ce9957dc9121 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 17 Mar 2024 10:21:47 -0600
Subject: [PATCH] Adapter work. Bug fixes. Auto adjust LR when resuming
 optimizer.

---
 .../advanced_generator/ReferenceGenerator.py | 43 +++++++++++++------
 extensions_built_in/sd_trainer/SDTrainer.py  |  2 +
 jobs/process/BaseSDTrainProcess.py           | 11 +++++
 testing/test_bucket_dataloader.py            |  1 +
 toolkit/config_modules.py                    |  5 ++-
 toolkit/custom_adapter.py                    | 18 +++++++-
 toolkit/models/ilora.py                      |  3 ++
 toolkit/stable_diffusion_model.py            | 16 +++++++
 8 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/extensions_built_in/advanced_generator/ReferenceGenerator.py b/extensions_built_in/advanced_generator/ReferenceGenerator.py
index 14ae1ff5..19e3b6e5 100644
--- a/extensions_built_in/advanced_generator/ReferenceGenerator.py
+++ b/extensions_built_in/advanced_generator/ReferenceGenerator.py
@@ -7,7 +7,7 @@ import numpy as np
 from PIL import Image
 from diffusers import T2IAdapter
 from torch.utils.data import DataLoader
-from diffusers import StableDiffusionXLAdapterPipeline
+from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
 from tqdm import tqdm
 
 from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
@@ -100,25 +100,43 @@ class ReferenceGenerator(BaseExtensionProcess):
 
         if self.generate_config.t2i_adapter_path is not None:
             self.adapter = T2IAdapter.from_pretrained(
-                "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, varient="fp16"
+                self.generate_config.t2i_adapter_path,
+                torch_dtype=self.torch_dtype,
+                variant="fp16"
             ).to(device)
             midas_depth = MidasDetector.from_pretrained(
                 "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
             ).to(device)
 
-            pipe = StableDiffusionXLAdapterPipeline(
-                vae=self.sd.vae,
-                unet=self.sd.unet,
-                text_encoder=self.sd.text_encoder[0],
-                text_encoder_2=self.sd.text_encoder[1],
-                tokenizer=self.sd.tokenizer[0],
-                tokenizer_2=self.sd.tokenizer[1],
-                scheduler=get_sampler(self.generate_config.sampler),
-                adapter=self.adapter,
-            ).to(device)
+            if self.model_config.is_xl:
+                pipe = StableDiffusionXLAdapterPipeline(
+                    vae=self.sd.vae,
+                    unet=self.sd.unet,
+                    text_encoder=self.sd.text_encoder[0],
+                    text_encoder_2=self.sd.text_encoder[1],
+                    tokenizer=self.sd.tokenizer[0],
+                    tokenizer_2=self.sd.tokenizer[1],
+                    scheduler=get_sampler(self.generate_config.sampler),
+                    adapter=self.adapter,
+                ).to(device, dtype=self.torch_dtype)
+            else:
+                pipe = StableDiffusionAdapterPipeline(
+                    vae=self.sd.vae,
+                    unet=self.sd.unet,
+                    text_encoder=self.sd.text_encoder,
+                    tokenizer=self.sd.tokenizer,
+                    scheduler=get_sampler(self.generate_config.sampler),
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                    adapter=self.adapter,
+                ).to(device, dtype=self.torch_dtype)
             pipe.set_progress_bar_config(disable=True)
 
+            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+            # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
+
         self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
 
         num_batches = len(self.data_loader)
@@ -176,6 +194,7 @@ class ReferenceGenerator(BaseExtensionProcess):
                     adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
                     guidance_scale=self.generate_config.guidance_scale,
                 ).images[0]
+                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                 gen_images.save(output_path)
 
                 # save caption
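The hunk above stops hard-coding the SDXL depth adapter: the adapter now loads from generate_config.t2i_adapter_path, and SD 1.5 models get a StableDiffusionAdapterPipeline instead of the XL variant. Below is a minimal standalone sketch of the same selection pattern, built with diffusers' stock from_pretrained loaders rather than the toolkit's in-memory components; the model IDs and the is_xl flag are illustrative assumptions, not taken from the patch.

# Sketch only: pick the T2I-Adapter pipeline class by model family.
import torch
from diffusers import (
    T2IAdapter,
    StableDiffusionAdapterPipeline,
    StableDiffusionXLAdapterPipeline,
)

is_xl = True  # assumption: would come from the model config
device = "cuda" if torch.cuda.is_available() else "cpu"

adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",  # example adapter path
    torch_dtype=torch.float16,
    variant="fp16",
)

if is_xl:
    pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        adapter=adapter,
        torch_dtype=torch.float16,
    )
else:
    pipe = StableDiffusionAdapterPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        adapter=adapter,
        torch_dtype=torch.float16,
        safety_checker=None,  # mirrors the patch: no safety checker in the toolkit
    )
pipe.to(device)
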
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index ab2831f3..180c4226 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -796,6 +796,7 @@ class SDTrainer(BaseSDTrainProcess):
                 unconditional_embeddings=unconditional_embeds,
                 timestep=timesteps,
                 guidance_scale=self.train_config.cfg_scale,
+                rescale_cfg=self.train_config.cfg_rescale,
                 **pred_kwargs  # adapter residuals in here
             )
             if was_unet_training:
@@ -1355,6 +1356,7 @@ class SDTrainer(BaseSDTrainProcess):
                 timestep=timesteps,
                 guidance_scale=self.train_config.cfg_scale,
                 detach_unconditional=False,
+                rescale_cfg=self.train_config.cfg_rescale,
                 **pred_kwargs
             )
             self.after_unet_predict()
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 304d7bac..efcdcfdb 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1352,6 +1352,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # try to load
             # previous param groups
             # previous_params = copy.deepcopy(optimizer.param_groups)
+            previous_lrs = []
+            for group in optimizer.param_groups:
+                previous_lrs.append(group['lr'])
+
             try:
                 print(f"Loading optimizer state from {optimizer_state_file_path}")
                 optimizer_state_dict = torch.load(optimizer_state_file_path)
@@ -1360,6 +1364,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print(f"Failed to load optimizer state from {optimizer_state_file_path}")
                 print(e)
 
+            # update the optimizer LR from the params
+            print(f"Updating optimizer LR from params")
+            if len(previous_lrs) > 0:
+                for i, group in enumerate(optimizer.param_groups):
+                    group['lr'] = previous_lrs[i]
+                    group['initial_lr'] = previous_lrs[i]
+
             # Update the learning rates if they changed
             # optimizer.param_groups = previous_params
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index e8208107..a223307a 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -42,6 +42,7 @@ dataset_config = DatasetConfig(
     resolution=resolution,
     caption_ext='json',
     default_caption='default',
+    clip_image_path='/mnt/Datasets/face_pairs2/control_clean',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
     poi='person',
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 9576a74c..3d007354 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -218,7 +218,7 @@ class TrainConfig:
         self.xformers = kwargs.get('xformers', False)
         self.sdp = kwargs.get('sdp', False)
         self.train_unet = kwargs.get('train_unet', True)
-        self.train_text_encoder = kwargs.get('train_text_encoder', True)
+        self.train_text_encoder = kwargs.get('train_text_encoder', False)
         self.train_refiner = kwargs.get('train_refiner', True)
         self.train_turbo = kwargs.get('train_turbo', False)
         self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False)
@@ -298,6 +298,9 @@ class TrainConfig:
         self.do_random_cfg = kwargs.get('do_random_cfg', False)
         self.cfg_scale = kwargs.get('cfg_scale', 1.0)
         self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
+        self.cfg_rescale = kwargs.get('cfg_rescale', None)
+        if self.cfg_rescale is None:
+            self.cfg_rescale = self.cfg_scale
 
         # applies the inverse of the prediction mean and std to the target to correct
         # for norm drift
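The BaseSDTrainProcess hunk above is the "auto adjust LR when resuming" fix from the subject line: optimizer.load_state_dict restores the learning rates saved in the checkpoint, so the freshly configured values are snapshotted first and written back afterwards. A self-contained sketch of that resume pattern; the Linear model and the "optimizer.pt" path are hypothetical stand-ins.

# Sketch of the LR-restore-on-resume pattern (stand-in model and path).
import torch

model = torch.nn.Linear(4, 4)  # stand-in for the trained network
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# remember the LRs configured for this run...
previous_lrs = [group["lr"] for group in optimizer.param_groups]

# ...because load_state_dict overwrites them with the checkpointed values
optimizer.load_state_dict(torch.load("optimizer.pt"))  # hypothetical path

# put the configured LRs back so a changed config wins over the checkpoint
for lr, group in zip(previous_lrs, optimizer.param_groups):
    group["lr"] = lr
    group["initial_lr"] = lr  # schedulers resumed with last_epoch read this

Restoring initial_lr as well matters because PyTorch LR schedulers constructed with a non-default last_epoch rebuild their base rates from that key.
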
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 616fc466..c27483c0 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -132,9 +132,16 @@ class CustomAdapter(torch.nn.Module):
             vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
             if self.config.image_encoder_arch == 'clip':
                 vision_tokens = vision_tokens + 1
+
+            vision_hidden_size = self.vision_encoder.config.hidden_size
+
+            if self.config.clip_layer == 'image_embeds':
+                vision_tokens = 1
+                vision_hidden_size = self.vision_encoder.config.projection_dim
+
             self.ilora_module = InstantLoRAModule(
                 vision_tokens=vision_tokens,
-                vision_hidden_size=self.vision_encoder.config.hidden_size,
+                vision_hidden_size=vision_hidden_size,
                 sd=self.sd_ref()
             )
         elif self.adapter_type == 'text_encoder':
@@ -731,7 +738,14 @@ class CustomAdapter(torch.nn.Module):
                     clip_image,
                     output_hidden_states=True
                 )
-                img_embeds = id_embeds['last_hidden_state']
+                if self.config.clip_layer == 'penultimate_hidden_states':
+                    img_embeds = id_embeds.hidden_states[-2]
+                elif self.config.clip_layer == 'last_hidden_state':
+                    img_embeds = id_embeds.hidden_states[-1]
+                elif self.config.clip_layer == 'image_embeds':
+                    img_embeds = id_embeds.image_embeds
+                else:
+                    raise ValueError(f"unknown clip layer: {self.config.clip_layer}")
 
                 if self.config.quad_image:
                     # get the outputs of the quat
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index 2698d15a..c7c9fd47 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -106,6 +106,9 @@ class InstantLoRAModule(torch.nn.Module):
         # this will be used to add the vector to the original forward
 
     def forward(self, img_embeds):
+        # expand token rank if only rank 2
+        if len(img_embeds.shape) == 2:
+            img_embeds = img_embeds.unsqueeze(1)
         img_embeds = self.resampler(img_embeds)
 
         self.img_embeds = img_embeds
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index c8ae4025..b88c8019 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -863,6 +863,7 @@ class StableDiffusion:
             unconditional_embeddings: Union[PromptEmbeds, None] = None,
             is_input_scaled=False,
             detach_unconditional=False,
+            rescale_cfg=None,
             **kwargs,
     ):
         # get the embeddings
@@ -1111,6 +1112,21 @@ class StableDiffusion:
                 noise_pred = noise_pred_uncond + guidance_scale * (
                         noise_pred_text - noise_pred_uncond
                 )
+                if rescale_cfg is not None and rescale_cfg != guidance_scale:
+                    with torch.no_grad():
+                        # do cfg at the target rescale so we can match it
+                        target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
+                                noise_pred_text - noise_pred_uncond
+                        )
+                        target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
+                        target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
+
+                        pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach()
+                        pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach()
+
+                    # match the mean and std
+                    noise_pred = (noise_pred - pred_mean) / pred_std
+                    noise_pred = (noise_pred * target_std) + target_mean
 
                 # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
                 if guidance_rescale > 0.0:
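The predict_noise hunk above implements the new rescale_cfg parameter: guidance still runs at guidance_scale, but the result is renormalized so its per-sample mean and standard deviation match a reference prediction made at the target scale, while gradients keep flowing through the guided prediction itself. An isolated sketch of that statistic matching, with random tensors standing in for the unconditional and text-conditioned noise predictions (shapes are illustrative).

# Sketch of the rescale_cfg statistic matching with stand-in tensors.
import torch

noise_pred_uncond = torch.randn(2, 4, 64, 64)  # stand-in predictions
noise_pred_text = torch.randn(2, 4, 64, 64)
guidance_scale, rescale_cfg = 7.5, 3.5         # assumed example scales

# ordinary classifier-free guidance at the training CFG scale
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if rescale_cfg is not None and rescale_cfg != guidance_scale:
    with torch.no_grad():
        # reference prediction at the target scale, used only for its statistics
        target = noise_pred_uncond + rescale_cfg * (noise_pred_text - noise_pred_uncond)
        target_mean = target.mean([1, 2, 3], keepdim=True)
        target_std = target.std([1, 2, 3], keepdim=True)
        pred_mean = noise_pred.mean([1, 2, 3], keepdim=True)
        pred_std = noise_pred.std([1, 2, 3], keepdim=True)
    # normalize the guided prediction, then restore the target statistics
    noise_pred = (noise_pred - pred_mean) / pred_std
    noise_pred = noise_pred * target_std + target_mean

This differs from diffusers' guidance_rescale (the URL referenced just above in the source), which rescales toward the text prediction's standard deviation; here the target statistics come from a second CFG pass at a lower scale.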