Mirror of https://github.com/ostris/ai-toolkit.git
Work on embedding adapters
@@ -767,6 +767,10 @@ class CustomAdapter(torch.nn.Module):
                 clip_image_embeds = clip_output.hidden_states[-1]
             else:
                 clip_image_embeds = clip_output.image_embeds
+            # TODO should we always norm image embeds?
+            # get norm embeddings
+            l2_norm = torch.norm(clip_image_embeds, p=2)
+            clip_image_embeds = clip_image_embeds / l2_norm

             if not is_training or not self.config.train_image_encoder:
                 clip_image_embeds = clip_image_embeds.detach()
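For context on the normalization added above: called without a dim argument, torch.norm(t, p=2) reduces over the entire tensor, so the whole batch is divided by one scalar. A minimal sketch, assuming an arbitrary embedding shape, with a per-sample variant shown only for comparison:

    import torch

    # batch of 4 pooled CLIP image embeddings; the shape is an assumption for illustration
    clip_image_embeds = torch.randn(4, 1024)

    # what the added lines do: torch.norm with no dim reduces over the whole tensor,
    # so every element is divided by a single scalar
    l2_norm = torch.norm(clip_image_embeds, p=2)
    globally_normed = clip_image_embeds / l2_norm

    # a per-sample variant, shown only for comparison (not what the commit does)
    per_sample = clip_image_embeds / clip_image_embeds.norm(p=2, dim=-1, keepdim=True)

    print(globally_normed.norm())        # ~1.0 for the tensor as a whole
    print(per_sample.norm(dim=-1))       # ~1.0 for every row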
@@ -21,7 +21,7 @@ from collections import OrderedDict
 from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
     AttnProcessor2_0
 from ipadapter.ip_adapter.ip_adapter import ImageProjModel
-from ipadapter.ip_adapter.resampler import Resampler
+from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
 from toolkit.config_modules import AdapterConfig
 from toolkit.prompt_utils import PromptEmbeds
 import weakref
@@ -51,6 +51,33 @@ from torch.utils.checkpoint import checkpoint
 import torch.nn.functional as F


+class MLPProjModelClipFace(torch.nn.Module):
+    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+        self.norm = torch.nn.LayerNorm(id_embeddings_dim)
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
+            torch.nn.GELU(),
+            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
+        )
+        # Initialize the last linear layer weights near zero
+        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
+        torch.nn.init.zeros_(self.proj[2].bias)
+        # # Custom initialization for LayerNorm to output near zero
+        # torch.nn.init.constant_(self.norm.weight, 0.1)  # Small weights near zero
+        # torch.nn.init.zeros_(self.norm.bias)  # Bias to zero
+
+    def forward(self, x):
+        x = self.norm(x)
+        x = self.proj(x)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return x
+
+
 class CustomIPAttentionProcessor(IPAttnProcessor2_0):
     def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
         super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
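A quick shape check of MLPProjModelClipFace as defined above, assuming one pooled ID embedding per image (the batch size and dimensions below are illustrative):

    import torch

    proj = MLPProjModelClipFace(cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4)

    id_embeds = torch.randn(2, 512)            # (batch, id_embeddings_dim)
    tokens = proj(id_embeds)
    print(tokens.shape)                        # torch.Size([2, 4, 768]) = (batch, num_tokens, cross_attention_dim)

    # the near-zero init of the last Linear keeps the initial tokens small
    print(tokens.abs().max())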
@@ -189,7 +216,7 @@ class IPAdapter(torch.nn.Module):
         self.clip_noise_zero = True
         self.unconditional: torch.Tensor = None
         self.additional_loss = None
-        if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
+        if self.config.image_encoder_arch.startswith("clip"):
             try:
                 self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
             except EnvironmentError:
@@ -324,10 +351,18 @@ class IPAdapter(torch.nn.Module):
                 clip_embeddings_dim=self.image_encoder.config.projection_dim,
                 clip_extra_context_tokens=self.config.num_tokens, # usually 4
             )
+        elif adapter_config.type == 'ip_clip_face':
+            cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
+            image_proj_model = MLPProjModelClipFace(
+                cross_attention_dim=cross_attn_dim,
+                id_embeddings_dim=1024,
+                num_tokens=self.config.num_tokens, # usually 4
+            )
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
+                'convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]

             image_encoder_state_dict = self.image_encoder.state_dict()
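Condensed sketch of what the new ip_clip_face branch builds, with hypothetical stand-ins for is_pixart and the UNet's cross_attention_dim; only the 1024 ID-embedding width comes from the diff itself:

    import torch

    # hypothetical stand-ins for values the adapter reads from the loaded model
    is_pixart = False
    unet_cross_attention_dim = 768             # placeholder for sd.unet.config['cross_attention_dim']

    cross_attn_dim = 4096 if is_pixart else unet_cross_attention_dim
    image_proj_model = MLPProjModelClipFace(
        cross_attention_dim=cross_attn_dim,
        id_embeddings_dim=1024,                # the dimension hard-coded in the diff above
        num_tokens=4,
    )
    print(image_proj_model(torch.randn(1, 1024)).shape)   # torch.Size([1, 4, 768])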
@@ -419,7 +454,7 @@ class IPAdapter(torch.nn.Module):

         for name in attn_processor_keys:
             cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
-                sd.unet.config['cross_attention_dim']
+                sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
             elif name.startswith("up_blocks"):
@@ -704,6 +739,10 @@ class IPAdapter(torch.nn.Module):
             else:
                 clip_image_embeds = clip_output.image_embeds

+            if self.config.adapter_type == "clip_face":
+                l2_norm = torch.norm(clip_image_embeds, p=2)
+                clip_image_embeds = clip_image_embeds / l2_norm
+
             if self.config.image_encoder_arch.startswith('convnext'):
                 # flatten the width height layers to make the token space
                 clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
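For the convnext path above, view(B, C, -1) only collapses the spatial grid into the last axis; a small sketch with an assumed feature-map shape:

    import torch

    # assumed convnext feature map layout: (batch, channels, height, width)
    feat = torch.randn(2, 1024, 16, 16)
    tokens = feat.view(feat.size(0), feat.size(1), -1)
    print(tokens.shape)   # torch.Size([2, 1024, 256]): channels stay at dim 1, H*W collapsed into dim 2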
@@ -216,7 +216,10 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
             adapter_hidden_states = torch.cat([
                 self.unconditional_embeds,
                 adapter_hidden_states
-            ])
+            ], dim=0)
+            # if it is image embeds, we need to add a 1 dim at inx 1
+            if len(adapter_hidden_states.shape) == 2:
+                adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
             # conditional_batch_size = adapter_hidden_states.shape[0]
             # conditional_query = query

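The added unsqueeze(1) handles pooled image_embeds of shape (batch, dim) as opposed to per-token hidden states of shape (batch, tokens, dim); a minimal illustration with assumed sizes:

    import torch

    pooled = torch.randn(2, 1024)              # (batch, dim) pooled image_embeds
    hidden = torch.randn(2, 257, 1024)         # (batch, tokens, dim) hidden states

    for adapter_hidden_states in (pooled, hidden):
        # mirror of the added check: give pooled embeds a token axis at index 1
        if len(adapter_hidden_states.shape) == 2:
            adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
        print(adapter_hidden_states.shape)     # (2, 1, 1024) then (2, 257, 1024)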
@@ -268,7 +271,10 @@ class VisionDirectAdapter(torch.nn.Module):
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)

-        self.token_size = vision_model.config.hidden_size
+        if adapter.config.clip_layer == "image_embeds":
+            self.token_size = vision_model.config.projection_dim
+        else:
+            self.token_size = vision_model.config.hidden_size

         # init adapter modules
         attn_procs = {}
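The token width now depends on which CLIP output the adapter consumes: pooled image_embeds have the projection width, while hidden states keep the transformer width. A sketch using transformers' CLIPVisionModelWithProjection; the checkpoint name and clip_layer value are examples, not taken from the commit:

    from transformers import CLIPVisionModelWithProjection

    # example checkpoint only; any CLIP vision tower with a projection head exposes both fields
    vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

    clip_layer = "image_embeds"                # hypothetical value of adapter.config.clip_layer
    if clip_layer == "image_embeds":
        token_size = vision_model.config.projection_dim   # width of the pooled, projected image_embeds
    else:
        token_size = vision_model.config.hidden_size      # width of the per-patch hidden states
    print(token_size)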