From 427847ac4cf5e009d75b376145c9d18bd53315df Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 26 Mar 2024 11:35:26 -0600 Subject: [PATCH] Small tweaks and fixes for specialized ip adapter training --- extensions_built_in/sd_trainer/SDTrainer.py | 14 ++--- toolkit/config_modules.py | 3 + toolkit/dataloader_mixins.py | 6 ++ toolkit/ip_adapter.py | 68 ++++++++++++++++++++- toolkit/stable_diffusion_model.py | 12 +++- toolkit/util/inverse_cfg.py | 25 ++++++++ 6 files changed, 117 insertions(+), 11 deletions(-) create mode 100644 toolkit/util/inverse_cfg.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 180c4226..ba9ad115 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1211,7 +1211,7 @@ class SDTrainer(BaseSDTrainProcess): clip_images, drop=True, is_training=True, - has_been_preprocessed=True, + has_been_preprocessed=False, quad_count=quad_count ) if self.train_config.do_cfg: @@ -1222,7 +1222,7 @@ class SDTrainer(BaseSDTrainProcess): ).detach(), is_training=True, drop=True, - has_been_preprocessed=True, + has_been_preprocessed=False, quad_count=quad_count ) elif has_clip_image: @@ -1230,14 +1230,14 @@ class SDTrainer(BaseSDTrainProcess): clip_images.detach().to(self.device_torch, dtype=dtype), is_training=True, has_been_preprocessed=True, - quad_count=quad_count + quad_count=quad_count, + # do cfg on clip embeds to normalize the embeddings for when doing cfg + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None ) if self.train_config.do_cfg: unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( - torch.zeros( - (noisy_latents.shape[0], 3, image_size, image_size), - device=self.device_torch, dtype=dtype - ).detach(), + clip_images.detach().to(self.device_torch, dtype=dtype), is_training=True, drop=True, has_been_preprocessed=True, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 5fa79376..3caecbca 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -352,6 +352,9 @@ class ModelConfig: if self.is_vega: self.is_xl = True + # for text encoder quant. Only works with pixart currently + self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4 + class ReferenceDatasetConfig: def __init__(self, **kwargs): diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 20ffc3a7..c164a3e0 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -728,6 +728,12 @@ class ClipImageFileItemDTOMixin: # do a flip img = img.transpose(Image.FLIP_TOP_BOTTOM) + # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data + if img.width != img.height: + # resize to the smallest dimension + min_size = min(img.width, img.height) + img = img.resize((min_size, min_size), Image.BICUBIC) + if self.has_clip_augmentations: self.clip_image_tensor = self.augment_clip_image(img, transform=None) else: diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 72cab509..a59d5789 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -14,6 +14,7 @@ from toolkit.models.zipper_resampler import ZipperResampler from toolkit.paths import REPOS_ROOT from toolkit.saving import load_ip_adapter_model from toolkit.train_tools import get_torch_dtype +from toolkit.util.inverse_cfg import inverse_classifier_guidance sys.path.append(REPOS_ROOT) from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional @@ -376,8 +377,10 @@ class IPAdapter(torch.nn.Module): output_dim = sd.unet.config['cross_attention_dim'] if is_pixart: - heads = 20 - dim = 4096 + # heads = 20 + heads = 12 + # dim = 4096 + dim = 1280 output_dim = 4096 if self.config.image_encoder_arch.startswith('convnext'): @@ -628,6 +631,23 @@ class IPAdapter(torch.nn.Module): clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() return clip_image_embeds + def get_empty_clip_image(self, batch_size: int) -> torch.Tensor: + with torch.no_grad(): + tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device) + noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, + dtype=get_torch_dtype(self.sd_ref().dtype)) + tensors_0_1 = tensors_0_1 * noise_scale + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + return clip_image.detach() + def get_clip_image_embeds_from_tensors( self, tensors_0_1: torch.Tensor, @@ -635,6 +655,7 @@ class IPAdapter(torch.nn.Module): is_training=False, has_been_preprocessed=False, quad_count=4, + cfg_embed_strength=None, # perform CFG on embeds with unconditional as negative ) -> torch.Tensor: if self.sd_ref().unet.device != self.device: self.to(self.sd_ref().unet.device) @@ -642,6 +663,7 @@ class IPAdapter(torch.nn.Module): self.to(self.sd_ref().unet.device) if not self.config.train: is_training = False + uncond_clip = None with torch.no_grad(): # on training the clip image is created in the dataloader if not has_been_preprocessed: @@ -749,6 +771,48 @@ class IPAdapter(torch.nn.Module): # rearrange to (batch, tokens, size) clip_image_embeds = clip_image_embeds.permute(0, 2, 1) + # apply unconditional if doing cfg on embeds + with torch.no_grad(): + if cfg_embed_strength is not None: + uncond_clip = self.get_empty_clip_image(tensors_0_1.shape[0]).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if self.config.quad_image: + # split the 4x4 grid and stack on batch + ci1, ci2 = uncond_clip.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + uncond_clip = torch.cat(to_cat, dim=0).detach() + uncond_clip_output = self.image_encoder( + uncond_clip, output_hidden_states=True + ) + + if self.config.clip_layer == 'penultimate_hidden_states': + uncond_clip_output_embeds = uncond_clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + uncond_clip_output_embeds = uncond_clip_output.hidden_states[-1] + else: + uncond_clip_output_embeds = uncond_clip_output.image_embeds + if self.config.adapter_type == "clip_face": + l2_norm = torch.norm(uncond_clip_output_embeds, p=2) + uncond_clip_output_embeds = uncond_clip_output_embeds / l2_norm + + uncond_clip_output_embeds = uncond_clip_output_embeds.detach() + + + # apply inverse cfg + clip_image_embeds = inverse_classifier_guidance( + clip_image_embeds, + uncond_clip_output_embeds, + cfg_embed_strength + ) + + if self.config.quad_image: # get the outputs of the quat chunks = clip_image_embeds.chunk(quad_count, dim=0) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index b88c8019..214ab44e 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -48,6 +48,7 @@ from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler from transformers import T5EncoderModel from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT +from toolkit.util.inverse_cfg import inverse_classifier_guidance # tell it to shut up diffusers.logging.set_verbosity(diffusers.logging.ERROR) @@ -234,13 +235,20 @@ class StableDiffusion: flush() print("Injecting alt weights") elif self.model_config.is_pixart: + te_kwargs = {} + # handle quantization of TE + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" # load the TE in 8bit mode text_encoder = T5EncoderModel.from_pretrained( "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder", - load_in_8bit=True, - device_map="auto", torch_dtype=self.torch_dtype, + **te_kwargs ) # load the transformer diff --git a/toolkit/util/inverse_cfg.py b/toolkit/util/inverse_cfg.py new file mode 100644 index 00000000..0c85544a --- /dev/null +++ b/toolkit/util/inverse_cfg.py @@ -0,0 +1,25 @@ +import torch + + +def inverse_classifier_guidance( + noise_pred_cond: torch.Tensor, + noise_pred_uncond: torch.Tensor, + guidance_scale: torch.Tensor +): + """ + Adjust the noise_pred_cond for the classifier free guidance algorithm + to ensure that the final noise prediction equals the original noise_pred_cond. + """ + # To make noise_pred equal noise_pred_cond_orig, we adjust noise_pred_cond + # based on the formula used in the algorithm. + # We derive the formula to find the correct adjustment for noise_pred_cond: + # noise_pred_cond = (noise_pred_cond_orig - noise_pred_uncond * guidance_scale) / (guidance_scale - 1) + # It's important to check if guidance_scale is not 1 to avoid division by zero. + if guidance_scale == 1: + # If guidance_scale is 1, adjusting is not needed or possible in the same way, + # since it would lead to division by zero. This also means the algorithm inherently + # doesn't alter the noise_pred_cond in relation to noise_pred_uncond. + # Thus, we return the original values, though this situation might need special handling. + return noise_pred_cond + adjusted_noise_pred_cond = (noise_pred_cond - noise_pred_uncond) / guidance_scale + return adjusted_noise_pred_cond