From 72de68d8aa6d328f428efff414769b4f1739a8ca Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 13 Mar 2024 07:24:08 -0600
Subject: [PATCH] WIP on clip vision encoder

---
 extensions_built_in/sd_trainer/SDTrainer.py |  35 ++++-
 toolkit/clip_vision_adapter.py              | 164 +++++++++++++-------
 toolkit/config_modules.py                   |   3 +-
 toolkit/stable_diffusion_model.py           |  35 ++++-
 4 files changed, 164 insertions(+), 73 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 65424bc9..ab2831f3 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -740,14 +740,36 @@ class SDTrainer(BaseSDTrainProcess):
                 embeds_to_use = conditional_embeds.clone().detach()
                 # handle clip vision adapter by removing triggers from prompt and replacing with the class name
-                if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
+                if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
                     prompt_list = batch.get_caption_list()
+                    class_name = ''
+
+                    triggers = ['[trigger]', '[name]']
+                    remove_tokens = []
+
+                    if self.embed_config is not None:
+                        triggers.append(self.embed_config.trigger)
+                        for i in range(1, self.embed_config.tokens):
+                            remove_tokens.append(f"{self.embed_config.trigger}_{i}")
+                        if self.embed_config.trigger_class_name is not None:
+                            class_name = self.embed_config.trigger_class_name
+
+                    if self.adapter is not None:
+                        triggers.append(self.adapter_config.trigger)
+                        for i in range(1, self.adapter_config.num_tokens):
+                            remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
+                        if self.adapter_config.trigger_class_name is not None:
+                            class_name = self.adapter_config.trigger_class_name
+
                     for idx, prompt in enumerate(prompt_list):
-                        prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
+                        for remove_token in remove_tokens:
+                            prompt = prompt.replace(remove_token, '')
+                        for trigger in triggers:
+                            prompt = prompt.replace(trigger, class_name)
                         prompt_list[idx] = prompt

                     embeds_to_use = self.sd.encode_prompt(
-                        prompt,
+                        prompt_list,
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
                         dtype=dtype).detach()
@@ -1030,7 +1052,8 @@ class SDTrainer(BaseSDTrainProcess):
                 if has_clip_image:
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                         clip_images.detach().to(self.device_torch, dtype=dtype),
-                        is_training=True
+                        is_training=True,
+                        has_been_preprocessed=True
                     )
                 else:
                     # just do a blank one
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                         torch.zeros(
                             (noisy_latents.shape[0], 3, 512, 512),
                             device=self.device_torch, dtype=dtype
                         ),
-                        is_training=True
+                        is_training=True,
+                        has_been_preprocessed=True,
+                        drop=True
                     )
                 # it will be injected into the tokenizer when called
                 self.adapter(conditional_clip_embeds)
diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py
index 4691bb0d..83580fa9 100644
--- a/toolkit/clip_vision_adapter.py
+++ b/toolkit/clip_vision_adapter.py
@@ -4,6 +4,8 @@
 import torch
 import weakref
 
 from toolkit.config_modules import AdapterConfig
+from toolkit.models.clip_fusion import ZipperBlock
+from toolkit.models.zipper_resampler import ZipperModule
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.train_tools import get_torch_dtype
@@ -24,11 +26,11 @@ import torch.nn as nn
 class Embedder(nn.Module):
     def __init__(
             self,
-            num_input_tokens: int = 50,
+            num_input_tokens: int = 1,
             input_dim: int = 1024,
             num_output_tokens: int = 8,
             output_dim: int = 768,
-            mid_dim: int = 128,
+            mid_dim: int = 1024
     ):
         super(Embedder, self).__init__()
         self.num_output_tokens = num_output_tokens
@@ -36,29 +38,24 @@ class Embedder(nn.Module):
         self.input_dim = input_dim
         self.output_dim = output_dim
 
-        # Convolutional layer to reduce channel dimension
-        self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)
-
-        # GELU Activation
+        self.layer_norm = nn.LayerNorm(input_dim)
+        self.fc1 = nn.Linear(input_dim, mid_dim)
         self.gelu = nn.GELU()
+        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
 
-        # Layer Normalization
-        self.layer_norm = nn.LayerNorm(mid_dim)
-
-        # Adaptive pooling to change sequence length
-        self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)
-
-        # Linear layer for final transformation
-        self.final_linear = nn.Linear(mid_dim, output_dim)
+        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
 
     def forward(self, x):
-        x = x.permute(0, 2, 1)  # Adjust for Conv1d
-        x = self.conv(x)
+        x = self.layer_norm(x)
+        x = self.fc1(x)
         x = self.gelu(x)
-        x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1)  # Apply LayerNorm
-        x = self.adaptive_pool(x)
-        x = x.permute(0, 2, 1)  # Adjust back
-        x = self.final_linear(x)
+        x = self.fc2(x)
+        x = x.view(-1, self.num_output_tokens, self.output_dim)
+
+        # repeat the static tokens for each batch
+        static_tokens = torch.stack([self.static_tokens] * x.shape[0])
+        x = static_tokens + x
 
         return x
@@ -140,24 +137,29 @@ class ClipVisionAdapter(torch.nn.Module):
             self.image_encoder.train()
         else:
             self.image_encoder.eval()
-        # self.embedder = Embedder(
-        #     num_output_tokens=self.config.num_tokens,
-        #     num_input_tokens=self.image_encoder.config.top_k,  # max_position_embeddings ?
-        #     input_dim=self.image_encoder.config.hidden_size,
-        #     output_dim=sd.unet.config['cross_attention_dim'],
-        # ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
-        heads = 12 if not sd.is_xl else 20
-        # dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-        dim = sd.unet.config['cross_attention_dim']
-        self.embedder = Resampler(
-            dim=dim,
-            depth=4,
-            dim_head=64,
-            heads=heads,
-            num_queries=self.config.num_tokens,  # usually 16
-            embedding_dim=self.image_encoder.config.hidden_size,
-            output_dim=sd.unet.config['cross_attention_dim'],
-            ff_mult=4
+
+        # max_seq_len = CLIP tokens + CLS token
+        image_encoder_state_dict = self.image_encoder.state_dict()
+        in_tokens = 257
+        if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+            # clip
+            in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+        if hasattr(self.image_encoder.config, 'hidden_sizes'):
+            embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+        else:
+            embedding_dim = self.image_encoder.config.hidden_size
+
+        if self.config.clip_layer == 'image_embeds':
+            in_tokens = 1
+            embedding_dim = self.image_encoder.config.projection_dim
+
+        self.embedder = Embedder(
+            num_output_tokens=self.config.num_tokens,
+            num_input_tokens=in_tokens,
+            input_dim=embedding_dim,
+            output_dim=self.sd_ref().unet.config['cross_attention_dim'],
+            mid_dim=embedding_dim * self.config.num_tokens,
         ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
 
         self.embedder.train()
@@ -186,37 +188,76 @@ class ClipVisionAdapter(torch.nn.Module):
     def get_clip_image_embeds_from_tensors(
             self, tensors_0_1: torch.Tensor, drop=False,
-            is_training=False
+            is_training=False,
+            has_been_preprocessed=False
     ) -> torch.Tensor:
         with torch.no_grad():
-            # tensors should be 0-1
-            # todo: add support for sdxl
-            if tensors_0_1.ndim == 3:
-                tensors_0_1 = tensors_0_1.unsqueeze(0)
-            # training tensors are 0 - 1
-            tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
-            # if images are out of this range throw error
-            if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
-                raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
-                    tensors_0_1.min(), tensors_0_1.max()
-                ))
+            if not has_been_preprocessed:
+                # tensors should be 0-1
+                if tensors_0_1.ndim == 3:
+                    tensors_0_1 = tensors_0_1.unsqueeze(0)
+                # training tensors are 0 - 1
+                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
 
-            clip_image = self.clip_image_processor(
-                images=tensors_0_1,
-                return_tensors="pt",
-                do_resize=True,
-                do_rescale=False,
-            ).pixel_values
-            clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
-            if drop:
-                clip_image = clip_image * 0
+                # if images are out of this range throw error
+                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                        tensors_0_1.min(), tensors_0_1.max()
+                    ))
+                # unconditional
+                if drop:
+                    if self.clip_noise_zero:
+                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
+                        tensors_0_1 = tensors_0_1 * noise_scale
+                    else:
+                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+                    # tensors_0_1 = tensors_0_1 * 0
+                clip_image = self.clip_image_processor(
+                    images=tensors_0_1,
+                    return_tensors="pt",
+                    do_resize=True,
+                    do_rescale=False,
+                ).pixel_values
+            else:
+                if drop:
+                    # scale the noise down
+                    if self.clip_noise_zero:
+                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
+                        tensors_0_1 = tensors_0_1 * noise_scale
+                    else:
+                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+                    # tensors_0_1 = tensors_0_1 * 0
+                    mean = torch.tensor(self.clip_image_processor.image_mean).to(
+                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+                    ).detach()
+                    std = torch.tensor(self.clip_image_processor.image_std).to(
+                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+                    ).detach()
+                    tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+                    clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
+
+                else:
+                    clip_image = tensors_0_1
+            clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
         with torch.set_grad_enabled(is_training):
             if is_training:
                 self.image_encoder.train()
             else:
                 self.image_encoder.eval()
             clip_output = self.image_encoder(clip_image, output_hidden_states=True)
-            clip_image_embeds = clip_output.hidden_states[-2]
+
+            if self.config.clip_layer == 'penultimate_hidden_states':
+                # they skip last layer for ip+
+                # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
+                clip_image_embeds = clip_output.hidden_states[-2]
+            elif self.config.clip_layer == 'last_hidden_state':
+                clip_image_embeds = clip_output.hidden_states[-1]
+            else:
+                clip_image_embeds = clip_output.image_embeds
         return clip_image_embeds
 
     import torch
@@ -236,6 +277,9 @@ class ClipVisionAdapter(torch.nn.Module):
     # adds it to the tokenizer
     def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
         clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        if clip_image_embeds.ndim == 2:
+            # expand the token dimension
+            clip_image_embeds = clip_image_embeds.unsqueeze(1)
         image_prompt_embeds = self.embedder(clip_image_embeds)
         # todo add support for multiple batch sizes
         if image_prompt_embeds.shape[0] != 1:
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 1bf89dcb..9576a74c 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -166,7 +166,7 @@ class AdapterConfig:
 
         # clip vision
         self.trigger = kwargs.get('trigger', 'tri993r')
-        self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
+        self.trigger_class_name = kwargs.get('trigger_class_name', None)
         self.class_names = kwargs.get('class_names', [])
 
@@ -188,6 +188,7 @@ class EmbeddingConfig:
         self.tokens = kwargs.get('tokens', 4)
         self.init_words = kwargs.get('init_words', '*')
         self.save_format = kwargs.get('save_format', 'safetensors')
+        self.trigger_class_name = kwargs.get('trigger_class_name', None)
 
 # used for inverted masked prior
 ContentOrStyleType = Literal['balanced', 'style', 'content']
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 5bd72254..c8ae4025 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -39,7 +39,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
-    StableDiffusionXLImg2ImgPipeline, LCMScheduler
+    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel
 import diffusers
 from diffusers import \
     AutoencoderKL, \
@@ -242,10 +242,21 @@ class StableDiffusion:
                 device_map="auto",
                 torch_dtype=self.torch_dtype,
             )
+
+            # load the transformer
+            subfolder = "transformer"
+            # check if it is just the unet
+            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
+                subfolder = None
+            # load the transformer only from the save
+            transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder)
+
             # replace the to function with a no-op since it throws an error instead of a warning
             text_encoder.to = lambda *args, **kwargs: None
             pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
-                model_path,
+                "PixArt-alpha/PixArt-XL-2-1024-MS",
+                transformer=transformer,
                 text_encoder=text_encoder,
                 dtype=dtype,
                 device=self.device_torch,
@@ -1081,10 +1092,14 @@ class StableDiffusion:
                 else:
                     noise_pred = noise_pred
             else:
+                if self.unet.device != self.device_torch:
+                    self.unet.to(self.device_torch)
+                if self.unet.dtype != self.torch_dtype:
+                    self.unet = self.unet.to(dtype=self.torch_dtype)
                 noise_pred = self.unet(
                     latent_model_input.to(self.device_torch, self.torch_dtype),
                     timestep,
-                    encoder_hidden_states=text_embeddings.text_embeds,
+                    encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                     **kwargs,
                 ).sample
@@ -1485,10 +1500,16 @@ class StableDiffusion:
             # saving in diffusers format
             if not output_file.endswith('.safetensors'):
                 # diffusers
-                self.pipeline.save_pretrained(
-                    save_directory=output_file,
-                    safe_serialization=True,
-                )
+                if self.is_pixart:
+                    self.unet.save_pretrained(
+                        save_directory=output_file,
+                        safe_serialization=True,
+                    )
+                else:
+                    self.pipeline.save_pretrained(
+                        save_directory=output_file,
+                        safe_serialization=True,
+                    )
                 # save out meta config
                 meta_path = os.path.join(output_file, 'aitk_meta.yaml')
                 with open(meta_path, 'w') as f:
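For reference, the reworked Embedder above is small enough to exercise on its own. The following standalone sketch mirrors the module added in this patch but is not part of the diff; the demo dimensions (one pooled 1024-dim image embedding projected to 4 tokens of size 768) are illustrative assumptions, and broadcasting is used in place of the patch's torch.stack repeat, which gives the same result for this shape.

import torch
import torch.nn as nn


class Embedder(nn.Module):
    # projects one pooled CLIP vision embedding into num_output_tokens tokens
    # in the text encoder's cross-attention dimension, plus learned static tokens
    def __init__(self, num_input_tokens=1, input_dim=1024, num_output_tokens=8,
                 output_dim=768, mid_dim=1024):
        super().__init__()
        self.num_output_tokens = num_output_tokens
        self.output_dim = output_dim
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))

    def forward(self, x):
        # x: (batch, 1, input_dim), one pooled image embedding per sample
        x = self.layer_norm(x)
        x = self.fc2(self.gelu(self.fc1(x)))
        x = x.view(-1, self.num_output_tokens, self.output_dim)
        # the learned static tokens act as a shared additive bank for every sample
        return self.static_tokens.unsqueeze(0) + x


if __name__ == "__main__":
    embedder = Embedder(num_input_tokens=1, input_dim=1024,
                        num_output_tokens=4, output_dim=768, mid_dim=1024 * 4)
    out = embedder(torch.randn(2, 1, 1024))
    print(out.shape)  # torch.Size([2, 4, 768])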
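The has_been_preprocessed branch of get_clip_image_embeds_from_tensors has to build its unconditional (drop=True) input in 0-1 pixel space and then re-apply the CLIP normalization by hand, since the image processor was already run upstream. Below is a rough sketch of just that step, not part of the patch; the helper name is hypothetical, and the mean/std values in the usage example are the standard OpenAI CLIP constants, assumed here for illustration.

import torch


def drop_to_unconditional(tensors_0_1, image_mean, image_std,
                          clip_noise_zero=True, device="cpu", dtype=torch.float32):
    # replace the batch with per-sample scaled noise (or zeros) in 0-1 space
    if clip_noise_zero:
        noise = torch.rand_like(tensors_0_1).detach()
        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=device, dtype=dtype)
        tensors_0_1 = noise * noise_scale
    else:
        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
    # quantize to 8-bit levels, then apply the processor's normalization manually
    tensors_0_1 = torch.clip(255. * tensors_0_1, 0, 255).round() / 255.0
    mean = torch.tensor(image_mean, device=device, dtype=dtype).view(1, 3, 1, 1)
    std = torch.tensor(image_std, device=device, dtype=dtype).view(1, 3, 1, 1)
    return (tensors_0_1 - mean) / std


uncond = drop_to_unconditional(
    torch.rand(2, 3, 224, 224),
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
)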