mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Work on embedding adapters
This commit is contained in:
@@ -767,6 +767,10 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
clip_image_embeds = clip_output.hidden_states[-1]
|
clip_image_embeds = clip_output.hidden_states[-1]
|
||||||
else:
|
else:
|
||||||
clip_image_embeds = clip_output.image_embeds
|
clip_image_embeds = clip_output.image_embeds
|
||||||
|
# TODO should we always norm image embeds?
|
||||||
|
# get norm embeddings
|
||||||
|
l2_norm = torch.norm(clip_image_embeds, p=2)
|
||||||
|
clip_image_embeds = clip_image_embeds / l2_norm
|
||||||
|
|
||||||
if not is_training or not self.config.train_image_encoder:
|
if not is_training or not self.config.train_image_encoder:
|
||||||
clip_image_embeds = clip_image_embeds.detach()
|
clip_image_embeds = clip_image_embeds.detach()
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from collections import OrderedDict
|
|||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||||
AttnProcessor2_0
|
AttnProcessor2_0
|
||||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
|
||||||
from toolkit.config_modules import AdapterConfig
|
from toolkit.config_modules import AdapterConfig
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
import weakref
|
import weakref
|
||||||
@@ -51,6 +51,33 @@ from torch.utils.checkpoint import checkpoint
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class MLPProjModelClipFace(torch.nn.Module):
|
||||||
|
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.num_tokens = num_tokens
|
||||||
|
self.norm = torch.nn.LayerNorm(id_embeddings_dim)
|
||||||
|
|
||||||
|
self.proj = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
|
||||||
|
torch.nn.GELU(),
|
||||||
|
torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
|
||||||
|
)
|
||||||
|
# Initialize the last linear layer weights near zero
|
||||||
|
torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
|
||||||
|
torch.nn.init.zeros_(self.proj[2].bias)
|
||||||
|
# # Custom initialization for LayerNorm to output near zero
|
||||||
|
# torch.nn.init.constant_(self.norm.weight, 0.1) # Small weights near zero
|
||||||
|
# torch.nn.init.zeros_(self.norm.bias) # Bias to zero
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class CustomIPAttentionProcessor(IPAttnProcessor2_0):
|
class CustomIPAttentionProcessor(IPAttnProcessor2_0):
|
||||||
def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
|
def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
|
||||||
super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
|
super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
|
||||||
@@ -189,7 +216,7 @@ class IPAdapter(torch.nn.Module):
|
|||||||
self.clip_noise_zero = True
|
self.clip_noise_zero = True
|
||||||
self.unconditional: torch.Tensor = None
|
self.unconditional: torch.Tensor = None
|
||||||
self.additional_loss = None
|
self.additional_loss = None
|
||||||
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
|
if self.config.image_encoder_arch.startswith("clip"):
|
||||||
try:
|
try:
|
||||||
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
@@ -324,10 +351,18 @@ class IPAdapter(torch.nn.Module):
|
|||||||
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
||||||
clip_extra_context_tokens=self.config.num_tokens, # usually 4
|
clip_extra_context_tokens=self.config.num_tokens, # usually 4
|
||||||
)
|
)
|
||||||
|
elif adapter_config.type == 'ip_clip_face':
|
||||||
|
cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
|
||||||
|
image_proj_model = MLPProjModelClipFace(
|
||||||
|
cross_attention_dim=cross_attn_dim,
|
||||||
|
id_embeddings_dim=1024,
|
||||||
|
num_tokens=self.config.num_tokens, # usually 4
|
||||||
|
)
|
||||||
elif adapter_config.type == 'ip+':
|
elif adapter_config.type == 'ip+':
|
||||||
heads = 12 if not sd.is_xl else 20
|
heads = 12 if not sd.is_xl else 20
|
||||||
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
|
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
|
||||||
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
|
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
|
||||||
|
'convnext') else \
|
||||||
self.image_encoder.config.hidden_sizes[-1]
|
self.image_encoder.config.hidden_sizes[-1]
|
||||||
|
|
||||||
image_encoder_state_dict = self.image_encoder.state_dict()
|
image_encoder_state_dict = self.image_encoder.state_dict()
|
||||||
@@ -419,7 +454,7 @@ class IPAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
for name in attn_processor_keys:
|
for name in attn_processor_keys:
|
||||||
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
|
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
|
||||||
sd.unet.config['cross_attention_dim']
|
sd.unet.config['cross_attention_dim']
|
||||||
if name.startswith("mid_block"):
|
if name.startswith("mid_block"):
|
||||||
hidden_size = sd.unet.config['block_out_channels'][-1]
|
hidden_size = sd.unet.config['block_out_channels'][-1]
|
||||||
elif name.startswith("up_blocks"):
|
elif name.startswith("up_blocks"):
|
||||||
@@ -704,6 +739,10 @@ class IPAdapter(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
clip_image_embeds = clip_output.image_embeds
|
clip_image_embeds = clip_output.image_embeds
|
||||||
|
|
||||||
|
if self.config.adapter_type == "clip_face":
|
||||||
|
l2_norm = torch.norm(clip_image_embeds, p=2)
|
||||||
|
clip_image_embeds = clip_image_embeds / l2_norm
|
||||||
|
|
||||||
if self.config.image_encoder_arch.startswith('convnext'):
|
if self.config.image_encoder_arch.startswith('convnext'):
|
||||||
# flatten the width height layers to make the token space
|
# flatten the width height layers to make the token space
|
||||||
clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
|
clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
|
||||||
|
|||||||
@@ -216,7 +216,10 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
|
|||||||
adapter_hidden_states = torch.cat([
|
adapter_hidden_states = torch.cat([
|
||||||
self.unconditional_embeds,
|
self.unconditional_embeds,
|
||||||
adapter_hidden_states
|
adapter_hidden_states
|
||||||
])
|
], dim=0)
|
||||||
|
# if it is image embeds, we need to add a 1 dim at inx 1
|
||||||
|
if len(adapter_hidden_states.shape) == 2:
|
||||||
|
adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
|
||||||
# conditional_batch_size = adapter_hidden_states.shape[0]
|
# conditional_batch_size = adapter_hidden_states.shape[0]
|
||||||
# conditional_query = query
|
# conditional_query = query
|
||||||
|
|
||||||
@@ -268,7 +271,10 @@ class VisionDirectAdapter(torch.nn.Module):
|
|||||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||||
self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
|
self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
|
||||||
|
|
||||||
self.token_size = vision_model.config.hidden_size
|
if adapter.config.clip_layer == "image_embeds":
|
||||||
|
self.token_size = vision_model.config.projection_dim
|
||||||
|
else:
|
||||||
|
self.token_size = vision_model.config.hidden_size
|
||||||
|
|
||||||
# init adapter modules
|
# init adapter modules
|
||||||
attn_procs = {}
|
attn_procs = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user