From 05ae95ca8987457adbf487a7866713601c7e7f0a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 24 Dec 2023 13:26:04 -0700 Subject: [PATCH] Added a clip vision adapter trainer. Only works for sd15 for now --- extensions_built_in/sd_trainer/SDTrainer.py | 46 ++- jobs/process/BaseSDTrainProcess.py | 32 +- toolkit/clip_vision_adapter.py | 331 ++++++++++++++++++++ toolkit/config_modules.py | 21 +- toolkit/resampler.py | 160 ++++++++++ toolkit/stable_diffusion_model.py | 16 +- 6 files changed, 586 insertions(+), 20 deletions(-) create mode 100644 toolkit/clip_vision_adapter.py create mode 100644 toolkit/resampler.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index dfd51e72..e0824b17 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -5,6 +5,7 @@ from diffusers import T2IAdapter import torch.functional as F from toolkit import train_tools from toolkit.basic import value_map, adain, get_mean_std +from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.config_modules import GuidanceConfig from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss @@ -504,6 +505,7 @@ class SDTrainer(BaseSDTrainProcess): noise: torch.Tensor, **kwargs ): + # todo for embeddings, we need to run without trigger words was_unet_training = self.sd.unet.training was_network_active = False if self.network is not None: @@ -519,13 +521,28 @@ class SDTrainer(BaseSDTrainProcess): # do a prediction here so we can match its output with network multiplier set to 0.0 with torch.no_grad(): dtype = get_torch_dtype(self.train_config.dtype) + + embeds_to_use = conditional_embeds.clone().detach() + # handle clip vision adapter by removing triggers from prompt and replacing with the class name + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): + prompt_list = batch.get_caption_list() + for idx, prompt in enumerate(prompt_list): + prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt) + prompt_list[idx] = prompt + + embeds_to_use = self.sd.encode_prompt( + prompt, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype).detach() + # dont use network on this # self.network.multiplier = 0.0 self.sd.unet.eval() prior_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(), timestep=timesteps, guidance_scale=1.0, **pred_kwargs # adapter residuals in here @@ -666,6 +683,9 @@ class SDTrainer(BaseSDTrainProcess): if self.embedding: grad_on_text_encoder = True + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + grad_on_text_encoder = True + # have a blank network so we can wrap it in a context and set multipliers without checking every time if self.network is not None: network = self.network @@ -745,6 +765,26 @@ class SDTrainer(BaseSDTrainProcess): prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] with network: + # encode clip adapter here so embeds are active for tokenizer + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('encode_clip_vision_embeds'): + if has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True + ) + else: + # just do a blank one + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, 512, 512), + device=self.device_torch, dtype=dtype + ), + is_training=True + ) + # it will be injected into the tokenizer when called + self.adapter(conditional_clip_embeds) + with self.timer('encode_prompt'): if grad_on_text_encoder: with torch.set_grad_enabled(True): @@ -912,6 +952,10 @@ class SDTrainer(BaseSDTrainProcess): with self.timer('restore_embeddings'): # Let's make sure we don't update any embedding weights besides the newly added token self.embedding.restore_embeddings() + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('restore_adapter'): + # Let's make sure we don't update any embedding weights besides the newly added token + self.adapter.restore_embeddings() loss_dict = OrderedDict( {'loss': loss.item()} diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 25fbf12a..5fe985fc 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -5,7 +5,7 @@ import json import shutil from collections import OrderedDict import os -from typing import Union, List +from typing import Union, List, Optional import numpy as np import yaml @@ -17,6 +17,7 @@ import torch import torch.backends.cuda from toolkit.basic import value_map +from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.embedding import Embedding @@ -138,7 +139,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network: Union[Network, None] = None - self.adapter: Union[T2IAdapter, IPAdapter, None] = None + self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None self.embedding: Union[Embedding, None] = None is_training_adapter = self.adapter_config is not None and self.adapter_config.train @@ -202,7 +203,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here if self.embedding is not None: prompt = self.embedding.inject_embedding_to_prompt( - prompt, add_if_not_present=False + prompt, expand_token=True, add_if_not_present=False + ) + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): + prompt = self.adapter.inject_trigger_into_prompt( + prompt, expand_token=True, add_if_not_present=False ) if self.trigger_word is not None: prompt = self.sd.inject_trigger_into_prompt( @@ -400,6 +405,8 @@ class BaseSDTrainProcess(BaseTrainProcess): # add _lora to name if self.adapter_config.type == 't2i': adapter_name += '_t2i' + elif self.adapter_config.type == 'clip': + adapter_name += '_clip' else: adapter_name += '_ip' @@ -647,6 +654,13 @@ class BaseSDTrainProcess(BaseTrainProcess): add_if_not_present=not is_reg, ) + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + prompt = self.adapter.inject_trigger_into_prompt( + prompt, + expand_token=True, + add_if_not_present=not is_reg, + ) + # make sure trigger is in the prompts if not a regularization run if self.trigger_word is not None: prompt = self.sd.inject_trigger_into_prompt( @@ -840,7 +854,12 @@ class BaseSDTrainProcess(BaseTrainProcess): def setup_adapter(self): # t2i adapter is_t2i = self.adapter_config.type == 't2i' - suffix = 't2i' if is_t2i else 'ip' + if self.adapter_config.type == 't2i': + suffix = 't2i' + elif self.adapter_config.type == 'clip': + suffix = 'clip' + else: + suffix = 'ip' adapter_name = self.name if self.network_config is not None: adapter_name = f"{adapter_name}_{suffix}" @@ -865,6 +884,11 @@ class BaseSDTrainProcess(BaseTrainProcess): downscale_factor=self.adapter_config.downscale_factor, adapter_type=self.adapter_config.adapter_type, ) + elif self.adapter_config.type == 'clip': + self.adapter = ClipVisionAdapter( + sd=self.sd, + adapter_config=self.adapter_config, + ) else: self.adapter = IPAdapter( sd=self.sd, diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py new file mode 100644 index 00000000..93ef4be2 --- /dev/null +++ b/toolkit/clip_vision_adapter.py @@ -0,0 +1,331 @@ +from typing import TYPE_CHECKING, Mapping, Any + +import torch +import weakref + +from toolkit.config_modules import AdapterConfig +from toolkit.prompt_utils import PromptEmbeds +from toolkit.train_tools import get_torch_dtype + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPVisionModel +) + +from toolkit.resampler import Resampler + +import torch.nn as nn + + +class Embedder(nn.Module): + def __init__( + self, + num_input_tokens: int = 50, + input_dim: int = 1024, + num_output_tokens: int = 8, + output_dim: int = 768, + mid_dim: int = 128, + ): + super(Embedder, self).__init__() + self.num_output_tokens = num_output_tokens + self.num_input_tokens = num_input_tokens + self.input_dim = input_dim + self.output_dim = output_dim + + # Convolutional layer to reduce channel dimension + self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1) + + # GELU Activation + self.gelu = nn.GELU() + + # Layer Normalization + self.layer_norm = nn.LayerNorm(mid_dim) + + # Adaptive pooling to change sequence length + self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens) + + # Linear layer for final transformation + self.final_linear = nn.Linear(mid_dim, output_dim) + + def forward(self, x): + x = x.permute(0, 2, 1) # Adjust for Conv1d + x = self.conv(x) + x = self.gelu(x) + x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1) # Apply LayerNorm + x = self.adaptive_pool(x) + x = x.permute(0, 2, 1) # Adjust back + x = self.final_linear(x) + return x + + +class ClipVisionAdapter(torch.nn.Module): + def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig): + super().__init__() + self.config = adapter_config + self.trigger = adapter_config.trigger + self.trigger_class_name = adapter_config.trigger_class_name + self.sd_ref: weakref.ref = weakref.ref(sd) + # embedding stuff + self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder] + self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer] + placeholder_tokens = [self.trigger] + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, self.config.num_tokens): + additional_tokens.append(f"{self.trigger}_{i}") + placeholder_tokens += additional_tokens + + # handle dual tokenizer + self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [ + self.sd_ref().tokenizer] + self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [ + self.sd_ref().text_encoder] + + self.placeholder_token_ids = [] + self.embedding_tokens = [] + + print(f"Adding {placeholder_tokens} tokens to tokenizer") + print(f"Adding {self.config.num_tokens} tokens to tokenizer") + + for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list): + num_added_tokens = tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != self.config.num_tokens: + raise ValueError( + f"The tokenizer already contains the token {self.trigger}. Please pass a different" + f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}" + ) + + # Convert the initializer_token, placeholder_token to ids + init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False) + # if length of token ids is more than number of orm embedding tokens fill with * + if len(init_token_ids) > self.config.num_tokens: + init_token_ids = init_token_ids[:self.config.num_tokens] + elif len(init_token_ids) < self.config.num_tokens: + pad_token_id = tokenizer.encode(["*"], add_special_tokens=False) + init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids)) + + placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False) + self.placeholder_token_ids.append(placeholder_token_ids) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids): + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # replace "[name] with this. on training. This is automatically generated in pipeline on inference + self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))) + + # backup text encoder embeddings + self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list] + + try: + self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = CLIPImageProcessor() + self.device = self.sd_ref().unet.device + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( + self.config.image_encoder_path, + ignore_mismatched_sizes=True + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + if self.config.train_image_encoder: + self.image_encoder.train() + else: + self.image_encoder.eval() + # self.embedder = Embedder( + # num_output_tokens=self.config.num_tokens, + # num_input_tokens=self.image_encoder.config.top_k, # max_position_embeddings ? + # input_dim=self.image_encoder.config.hidden_size, + # output_dim=sd.unet.config['cross_attention_dim'], + # ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + heads = 12 if not sd.is_xl else 20 + # dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280 + dim = sd.unet.config['cross_attention_dim'] + self.embedder = Resampler( + dim=dim, + depth=4, + dim_head=64, + heads=heads, + num_queries=self.config.num_tokens, # usually 16 + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=sd.unet.config['cross_attention_dim'], + ff_mult=4 + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + + self.embedder.train() + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + state_dict = { + 'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + } + if self.config.train_image_encoder: + state_dict['image_encoder'] = self.image_encoder.state_dict( + *args, destination=destination, prefix=prefix, + keep_vars=keep_vars) + + return state_dict + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + self.embedder.load_state_dict(state_dict["embedder"], strict=strict) + if self.config.train_image_encoder and 'image_encoder' in state_dict: + self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict) + + def parameters(self, *args, **kwargs): + yield from self.embedder.parameters(*args, **kwargs) + + def named_parameters(self, *args, **kwargs): + yield from self.embedder.named_parameters(*args, **kwargs) + + def get_clip_image_embeds_from_tensors( + self, tensors_0_1: torch.Tensor, drop=False, + is_training=False + ) -> torch.Tensor: + with torch.no_grad(): + # tensors should be 0-1 + # todo: add support for sdxl + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + + clip_image = self.clip_image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16).detach() + if drop: + clip_image = clip_image * 0 + with torch.set_grad_enabled(is_training): + if is_training: + self.image_encoder.train() + else: + self.image_encoder.eval() + clip_output = self.image_encoder(clip_image, output_hidden_states=True) + clip_image_embeds = clip_output.hidden_states[-2] + return clip_image_embeds + + import torch + + def set_vec(self, new_vector, text_encoder_idx=0): + # Get the embedding layer + embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings() + + # Indices to replace in the embeddings + indices_to_replace = self.placeholder_token_ids[text_encoder_idx] + + # Replace the specified embeddings with new_vector + for idx in indices_to_replace: + vector_idx = idx - indices_to_replace[0] + embedding_layer.weight[idx] = new_vector[vector_idx] + + # adds it to the tokenizer + def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + image_prompt_embeds = self.embedder(clip_image_embeds) + # todo add support for multiple batch sizes + if image_prompt_embeds.shape[0] != 1: + raise ValueError("Batch size must be 1 for embedder for now") + + # output on sd1.5 is bs, num_tokens, 768 + if len(self.text_encoder_list) == 1: + # add it to the text encoder + self.set_vec(image_prompt_embeds[0], text_encoder_idx=0) + else: + raise ValueError("Multiple text encoders not supported yet") + # just a place to put a breakpoint + pass + + def restore_embeddings(self): + # Let's make sure we don't update any embedding weights besides the newly added token + for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip( + self.text_encoder_list, + self.tokenizer_list, + self.orig_embeds_params, + self.placeholder_token_ids + ): + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[ + min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False + with torch.no_grad(): + text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds[index_no_updates] + # detach it all + text_encoder.get_input_embeddings().weight.detach_() + + def enable_gradient_checkpointing(self): + self.image_encoder.gradient_checkpointing = True + + def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True): + output_prompt = prompt + embedding_tokens = self.embedding_tokens[0] # shoudl be the same + default_replacements = ["[name]", "[trigger]"] + + replace_with = embedding_tokens if expand_token else self.trigger + if to_replace_list is None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + if num_instances > 1: + print( + f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") + + return output_prompt + + # reverses injection with class name. useful for normalizations + def inject_trigger_class_name_into_prompt(self, prompt): + output_prompt = prompt + embedding_tokens = self.embedding_tokens[0] # shoudl be the same + + default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger] + + replace_with = self.config.trigger_class_name + to_replace_list = default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = output_prompt.count(replace_with) + + if num_instances > 1: + print( + f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") + + return output_prompt diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 96074c0d..d1e3a05e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -14,6 +14,7 @@ SaveFormat = Literal['safetensors', 'diffusers'] if TYPE_CHECKING: from toolkit.guidance import GuidanceType + class SaveConfig: def __init__(self, **kwargs): self.save_every: int = kwargs.get('save_every', 1000) @@ -47,7 +48,8 @@ class SampleConfig: self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) self.ext: ImgExt = kwargs.get('format', 'jpg') self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) - self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists + self.refiner_start_at = kwargs.get('refiner_start_at', + 0.5) # step to start using refiner on sample if it exists class LormModuleSettingsConfig: @@ -130,7 +132,7 @@ AdapterTypes = Literal['t2i', 'ip', 'ip+'] class AdapterConfig: def __init__(self, **kwargs): - self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip + self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip self.in_channels: int = kwargs.get('in_channels', 3) self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) @@ -153,15 +155,9 @@ class AdapterConfig: self.train_image_encoder: bool = kwargs.get('train_image_encoder', False) self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid - - -class ClipTokenMakerConfig: - def __init__(self, **kwargs): - self.image_encoder_path: str = kwargs.get('image_encoder_path', None) - self.num_tokens: int = kwargs.get('num_tokens', 8) - - - + # clip vision + self.trigger = kwargs.get('trigger', 'tri993r') + self.trigger_class_name = kwargs.get('trigger_class_name', 'person') class EmbeddingConfig: @@ -401,7 +397,8 @@ class DatasetConfig: self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black) - self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located + self.unconditional_path: str = kwargs.get('unconditional_path', + None) # path where matching unconditional images are located self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1 self.poi: Union[str, None] = kwargs.get('poi', diff --git a/toolkit/resampler.py b/toolkit/resampler.py new file mode 100644 index 00000000..9ace5a3a --- /dev/null +++ b/toolkit/resampler.py @@ -0,0 +1,160 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py +# and https://github.com/tencent-ailab/IP-Adapter/blob/9fc189e3fb389cc2b60a7d0c0850e083a716ea6e/ip_adapter/resampler.py + +import math + +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + max_seq_len: int = 257, # CLIP tokens + CLS token + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, + # number of latents derived from mean pooled representation of the sequence + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +def masked_mean(t, *, dim, mask=None): + if mask is None: + return t.mean(dim=dim) + + denom = mask.sum(dim=dim, keepdim=True) + mask = rearrange(mask, "b n -> b n 1") + masked_t = t.masked_fill(~mask, 0.0) + + return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 18d78a52..6bac0aee 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint from tqdm import tqdm from torchvision.transforms import Resize, transforms +from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.ip_adapter import IPAdapter from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ convert_vae_state_dict, load_vae @@ -472,7 +473,7 @@ class StableDiffusion: validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) extra['image'] = validation_image extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale - if isinstance(self.adapter, IPAdapter): + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): transform = transforms.Compose([ transforms.ToTensor(), ]) @@ -483,6 +484,12 @@ class StableDiffusion: torch.manual_seed(gen_config.seed) torch.cuda.manual_seed(gen_config.seed) + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) + self.adapter(conditional_clip_embeds) + # encode the prompt ourselves so we can do fun stuff with embeddings conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) @@ -496,8 +503,8 @@ class StableDiffusion: unconditional_embeds, ) - if self.adapter is not None and isinstance(self.adapter, - IPAdapter) and gen_config.adapter_image_path is not None: + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: # apply the image projection conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) @@ -1445,6 +1452,9 @@ class StableDiffusion: elif isinstance(self.adapter, T2IAdapter): requires_grad = self.adapter.adapter.conv_in.weight.requires_grad adapter_device = self.adapter.device + elif isinstance(self.adapter, ClipVisionAdapter): + requires_grad = self.adapter.embedder.training + adapter_device = self.adapter.device else: raise ValueError(f"Unknown adapter type: {type(self.adapter)}") self.device_state['adapter'] = {