Work on embedding adapters

Jaret Burkett
2024-03-11 15:18:42 -06:00
parent f415bac7b5
commit d87b49882c
3 changed files with 55 additions and 6 deletions

View File

@@ -767,6 +767,10 @@ class CustomAdapter(torch.nn.Module):
            clip_image_embeds = clip_output.hidden_states[-1]
        else:
            clip_image_embeds = clip_output.image_embeds

        # TODO should we always norm image embeds?
        # get norm embeddings
        # note: with no dim= argument, torch.norm returns a single scalar for the whole tensor,
        # so every embedding in the batch is divided by the same shared value
        l2_norm = torch.norm(clip_image_embeds, p=2)
        clip_image_embeds = clip_image_embeds / l2_norm

        if not is_training or not self.config.train_image_encoder:
            clip_image_embeds = clip_image_embeds.detach()
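
Note (not part of the diff): because torch.norm is called without a dim argument, the scaling above is by one global scalar rather than per embedding. A minimal sketch contrasting the two, with an assumed (batch, embed_dim) shape:

import torch

# Sketch only, shapes assumed for illustration.
clip_image_embeds = torch.randn(4, 768)                # (batch, embed_dim)

# As written in the hunk above: one scalar norm over the whole tensor
global_norm = torch.norm(clip_image_embeds, p=2)
scaled_together = clip_image_embeds / global_norm      # all rows share one scale factor

# Per-embedding alternative: each row ends up with unit L2 norm
row_norms = torch.norm(clip_image_embeds, p=2, dim=-1, keepdim=True)
unit_rows = clip_image_embeds / row_norms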

View File

@@ -21,7 +21,7 @@ from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
    AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref
@@ -51,6 +51,33 @@ from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F


class MLPProjModelClipFace(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        self.norm = torch.nn.LayerNorm(id_embeddings_dim)

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        # Initialize the last linear layer weights near zero
        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
        torch.nn.init.zeros_(self.proj[2].bias)

        # # Custom initialization for LayerNorm to output near zero
        # torch.nn.init.constant_(self.norm.weight, 0.1)  # Small weights near zero
        # torch.nn.init.zeros_(self.norm.bias)  # Bias to zero

    def forward(self, x):
        x = self.norm(x)
        x = self.proj(x)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return x


class CustomIPAttentionProcessor(IPAttnProcessor2_0):
    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
        super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
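
Note (not part of the diff): the new MLPProjModelClipFace maps a single ID/face embedding to num_tokens cross-attention tokens via a LayerNorm, a two-layer MLP, and a reshape. A shape walk-through using the constructor defaults above:

import torch

proj = MLPProjModelClipFace(cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4)

id_embeds = torch.randn(2, 512)    # (batch, id_embeddings_dim)
tokens = proj(id_embeds)           # LayerNorm -> Linear(512, 1024) -> GELU -> Linear(1024, 3072) -> reshape
print(tokens.shape)                # torch.Size([2, 4, 768]) = (batch, num_tokens, cross_attention_dim)

With the last linear initialized near zero, these tokens start close to zero, so the adapter initially contributes very little to cross-attention.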
@@ -189,7 +216,7 @@ class IPAdapter(torch.nn.Module):
        self.clip_noise_zero = True
        self.unconditional: torch.Tensor = None
        self.additional_loss = None

        if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
        if self.config.image_encoder_arch.startswith("clip"):
            try:
                self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
@@ -324,10 +351,18 @@ class IPAdapter(torch.nn.Module):
                clip_embeddings_dim=self.image_encoder.config.projection_dim,
                clip_extra_context_tokens=self.config.num_tokens,  # usually 4
            )
        elif adapter_config.type == 'ip_clip_face':
            cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
            image_proj_model = MLPProjModelClipFace(
                cross_attention_dim=cross_attn_dim,
                id_embeddings_dim=1024,
                num_tokens=self.config.num_tokens,  # usually 4
            )
        elif adapter_config.type == 'ip+':
            heads = 12 if not sd.is_xl else 20
            dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
                'convnext') else \
                self.image_encoder.config.hidden_sizes[-1]

            image_encoder_state_dict = self.image_encoder.state_dict()
@@ -419,7 +454,7 @@ class IPAdapter(torch.nn.Module):
        for name in attn_processor_keys:
            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
                sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
@@ -704,6 +739,10 @@ class IPAdapter(torch.nn.Module):
            else:
                clip_image_embeds = clip_output.image_embeds

            if self.config.adapter_type == "clip_face":
                l2_norm = torch.norm(clip_image_embeds, p=2)
                clip_image_embeds = clip_image_embeds / l2_norm

            if self.config.image_encoder_arch.startswith('convnext'):
                # flatten the width height layers to make the token space
                clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
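
Note (not part of the diff): the clip_face branch reuses the whole-tensor norm pattern flagged in the first hunk, and the convnext branch flattens the spatial feature map into a flat token axis. A sketch of that reshape with assumed shapes:

import torch

clip_image_embeds = torch.randn(2, 1024, 12, 12)    # (batch, channels, height, width), illustrative only
flat = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
print(flat.shape)                                   # torch.Size([2, 1024, 144]) = (batch, channels, h*w)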

View File

@@ -216,7 +216,10 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
        adapter_hidden_states = torch.cat([
            self.unconditional_embeds,
            adapter_hidden_states
        ])
        ], dim=0)
        # if it is image embeds, we need to add a dim of 1 at index 1
        if len(adapter_hidden_states.shape) == 2:
            adapter_hidden_states = adapter_hidden_states.unsqueeze(1)

        # conditional_batch_size = adapter_hidden_states.shape[0]
        # conditional_query = query
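
Note (not part of the diff): the added guard covers the case where the adapter receives pooled image_embeds, which arrive as a 2-D (batch, dim) tensor rather than (batch, tokens, dim); unsqueeze(1) gives each pooled embedding a single token axis so the attention below always sees a 3-D shape. A shape sketch with assumed dimensions:

import torch

unconditional_embeds = torch.zeros(1, 768)          # assumed (1, dim) unconditional embedding
adapter_hidden_states = torch.randn(1, 768)         # pooled image_embeds: (batch, dim)

adapter_hidden_states = torch.cat([unconditional_embeds, adapter_hidden_states], dim=0)
if len(adapter_hidden_states.shape) == 2:
    adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
print(adapter_hidden_states.shape)                  # torch.Size([2, 1, 768]) = (batch, tokens=1, dim)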
@@ -268,7 +271,10 @@ class VisionDirectAdapter(torch.nn.Module):
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.vision_model_ref: weakref.ref = weakref.ref(vision_model)

        self.token_size = vision_model.config.hidden_size
        if adapter.config.clip_layer == "image_embeds":
            self.token_size = vision_model.config.projection_dim
        else:
            self.token_size = vision_model.config.hidden_size

        # init adapter modules
        attn_procs = {}
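
Note (not part of the diff): token_size now tracks which encoder output feeds the adapter. Pooled image_embeds come out of CLIP's projection head and are projection_dim wide, while hidden states are hidden_size wide. A quick check with a Hugging Face CLIP vision model; the checkpoint name is an example, not necessarily the encoder configured in this repo:

from transformers import CLIPVisionModelWithProjection

vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
print(vision_model.config.hidden_size)      # 1024 -> width of hidden_states tokens
print(vision_model.config.projection_dim)   # 768  -> width of the pooled image_embeds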