From d87b49882cae7065ba68453dd00cf587fd19b98a Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 11 Mar 2024 15:18:42 -0600
Subject: [PATCH] Work on embedding adapters

---
 toolkit/custom_adapter.py    |  4 +++
 toolkit/ip_adapter.py        | 47 +++++++++++++++++++++++++++++++++---
 toolkit/models/vd_adapter.py | 10 ++++++--
 3 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 18780d87..616fc466 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -767,6 +767,10 @@ class CustomAdapter(torch.nn.Module):
                 clip_image_embeds = clip_output.hidden_states[-1]
             else:
                 clip_image_embeds = clip_output.image_embeds
+            # TODO should we always norm image embeds?
+            # get norm embeddings
+            l2_norm = torch.norm(clip_image_embeds, p=2)
+            clip_image_embeds = clip_image_embeds / l2_norm
             if not is_training or not self.config.train_image_encoder:
                 clip_image_embeds = clip_image_embeds.detach()
 
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index e30754de..72cab509 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -21,7 +21,7 @@ from collections import OrderedDict
 from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
     AttnProcessor2_0
 from ipadapter.ip_adapter.ip_adapter import ImageProjModel
-from ipadapter.ip_adapter.resampler import Resampler
+from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
 from toolkit.config_modules import AdapterConfig
 from toolkit.prompt_utils import PromptEmbeds
 import weakref
@@ -51,6 +51,33 @@ from torch.utils.checkpoint import checkpoint
 import torch.nn.functional as F
 
 
+class MLPProjModelClipFace(torch.nn.Module):
+    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+        self.norm = torch.nn.LayerNorm(id_embeddings_dim)
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
+            torch.nn.GELU(),
+            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
+        )
+        # Initialize the last linear layer weights near zero
+        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
+        torch.nn.init.zeros_(self.proj[2].bias)
+        # # Custom initialization for LayerNorm to output near zero
+        # torch.nn.init.constant_(self.norm.weight, 0.1)  # Small weights near zero
+        # torch.nn.init.zeros_(self.norm.bias)  # Bias to zero
+
+    def forward(self, x):
+        x = self.norm(x)
+        x = self.proj(x)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return x
+
+
 class CustomIPAttentionProcessor(IPAttnProcessor2_0):
     def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
         super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
@@ -189,7 +216,7 @@ class IPAdapter(torch.nn.Module):
         self.clip_noise_zero = True
         self.unconditional: torch.Tensor = None
         self.additional_loss = None
-        if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
+        if self.config.image_encoder_arch.startswith("clip"):
             try:
                 self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
             except EnvironmentError:
@@ -324,10 +351,18 @@ class IPAdapter(torch.nn.Module):
                 clip_embeddings_dim=self.image_encoder.config.projection_dim,
                 clip_extra_context_tokens=self.config.num_tokens,  # usually 4
             )
+        elif adapter_config.type == 'ip_clip_face':
+            cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
+            image_proj_model = MLPProjModelClipFace(
+                cross_attention_dim=cross_attn_dim,
+                id_embeddings_dim=1024,
+                num_tokens=self.config.num_tokens,  # usually 4
+            )
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
+                'convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]
 
             image_encoder_state_dict = self.image_encoder.state_dict()
@@ -419,7 +454,7 @@ class IPAdapter(torch.nn.Module):
 
         for name in attn_processor_keys:
             cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
-            sd.unet.config['cross_attention_dim']
+                sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
             elif name.startswith("up_blocks"):
@@ -704,6 +739,10 @@ class IPAdapter(torch.nn.Module):
             else:
                 clip_image_embeds = clip_output.image_embeds
 
+            if self.config.adapter_type == "clip_face":
+                l2_norm = torch.norm(clip_image_embeds, p=2)
+                clip_image_embeds = clip_image_embeds / l2_norm
+
             if self.config.image_encoder_arch.startswith('convnext'):
                 # flatten the width height layers to make the token space
                 clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index a83e154b..7e3c8b68 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -216,7 +216,10 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
             adapter_hidden_states = torch.cat([
                 self.unconditional_embeds,
                 adapter_hidden_states
-            ])
+            ], dim=0)
+            # if it is image embeds, we need to add a 1 dim at inx 1
+            if len(adapter_hidden_states.shape) == 2:
+                adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
 
         # conditional_batch_size = adapter_hidden_states.shape[0]
         # conditional_query = query
@@ -268,7 +271,10 @@ class VisionDirectAdapter(torch.nn.Module):
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
 
-        self.token_size = vision_model.config.hidden_size
+        if adapter.config.clip_layer == "image_embeds":
+            self.token_size = vision_model.config.projection_dim
+        else:
+            self.token_size = vision_model.config.hidden_size
 
         # init adapter modules
         attn_procs = {}
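
Note on the new projector: a minimal, hypothetical smoke test (not part of the commit) showing how MLPProjModelClipFace maps one ID embedding per sample to a set of cross-attention context tokens. It assumes the patch has been applied and uses the class defaults (id_embeddings_dim=512, cross_attention_dim=768, num_tokens=4); the batch size and random inputs are illustrative only.

    import torch

    from toolkit.ip_adapter import MLPProjModelClipFace  # class added by this patch

    # Hypothetical standalone check of the projector.
    proj = MLPProjModelClipFace(cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4)

    id_embeds = torch.randn(2, 512)   # one 512-dim ID/face embedding per sample
    tokens = proj(id_embeds)          # LayerNorm -> 2-layer MLP -> reshape to tokens

    print(tokens.shape)               # torch.Size([2, 4, 768]): num_tokens tokens per sample

In the patch itself the projector is instantiated with id_embeddings_dim=1024 for the 'ip_clip_face' adapter type, with cross_attention_dim taken from the UNet config (or 4096 for PixArt).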
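
Note on the normalization added in custom_adapter.py and ip_adapter.py: torch.norm(clip_image_embeds, p=2) is called without a dim argument, so it returns a single scalar (the L2 norm over every element of the batch tensor) and the whole batch is divided by that one value. The sketch below only illustrates that torch.norm behavior next to a per-sample variant; it is not part of the commit.

    import torch

    embeds = torch.randn(2, 512)                     # stand-in for clip_image_embeds

    # As written in the patch: one scalar norm over the entire tensor.
    global_norm = torch.norm(embeds, p=2)            # 0-dim tensor (a single scalar)
    scaled = embeds / global_norm                    # every sample scaled by the same value

    # Per-sample L2 normalization, for comparison: one norm per row.
    row_norm = torch.norm(embeds, p=2, dim=-1, keepdim=True)   # shape [2, 1]
    unit_rows = embeds / row_norm                    # each row now has unit length

    print(global_norm.shape, row_norm.shape)         # torch.Size([]) torch.Size([2, 1])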