Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: Added an experimental CLIP fusion model that is showing promise for embedding concepts
@@ -5,6 +5,7 @@ from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -72,12 +73,14 @@ class CustomAdapter(torch.nn.Module):
        # add for dataloader
        self.clip_image_processor = self.image_processor

        self.clip_fusion_module: CLIPFusionModule = None

        self.setup_adapter()

        if self.adapter_type == 'photo_maker':
            # try to load from our name_or_path
            if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
                self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)

            # add the trigger word to the tokenizer
            if isinstance(self.sd_ref().tokenizer, list):
                for tokenizer in self.sd_ref().tokenizer:
@@ -90,31 +93,45 @@ class CustomAdapter(torch.nn.Module):
            sd = self.sd_ref()
            embed_dim = sd.unet.config['cross_attention_dim']
            self.fuse_module = FuseModule(embed_dim)
        elif self.adapter_type == 'clip_fusion':
            sd = self.sd_ref()
            embed_dim = sd.unet.config['cross_attention_dim']

            vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
            if self.config.image_encoder_arch == 'clip':
                vision_tokens = vision_tokens + 1
            self.clip_fusion_module = CLIPFusionModule(
                text_hidden_size=embed_dim,
                text_tokens=77,
                vision_hidden_size=self.vision_encoder.config.hidden_size,
                vision_tokens=vision_tokens
            )
        else:
            raise ValueError(f"unknown adapter type: {self.adapter_type}")

    def forward(self, *args, **kwargs):
        if self.adapter_type == 'photo_maker':
            id_pixel_values = args[0]
            prompt_embeds: PromptEmbeds = args[1]
            class_tokens_mask = args[2]

            grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()

            with torch.set_grad_enabled(grads_on_image_encoder):
                id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)

            if not grads_on_image_encoder:
                id_embeds = id_embeds.detach()

            prompt_embeds = prompt_embeds.detach()

            updated_prompt_embeds = self.fuse_module(
                prompt_embeds, id_embeds, class_tokens_mask
            )

            return updated_prompt_embeds
        else:
            # don't think this is used
            # if self.adapter_type == 'photo_maker':
            #     id_pixel_values = args[0]
            #     prompt_embeds: PromptEmbeds = args[1]
            #     class_tokens_mask = args[2]
            #
            #     grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
            #
            #     with torch.set_grad_enabled(grads_on_image_encoder):
            #         id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
            #
            #     if not grads_on_image_encoder:
            #         id_embeds = id_embeds.detach()
            #
            #     prompt_embeds = prompt_embeds.detach()
            #
            #     updated_prompt_embeds = self.fuse_module(
            #         prompt_embeds, id_embeds, class_tokens_mask
            #     )
            #
            #     return updated_prompt_embeds
            # else:
            raise NotImplementedError

    def setup_clip(self):
@@ -226,7 +243,9 @@ class CustomAdapter(torch.nn.Module):
            # self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
            # self.sd_ref().pipeline.fuse_lora()
            pass
-       if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
+       if 'clip_fusion' in state_dict:
+           self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)
+       if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'):
            self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
            # check to see if the fuse weights are there
            fuse_weights = {}
@@ -235,7 +254,31 @@ class CustomAdapter(torch.nn.Module):
                    k = k.replace('fuse_module.', '')
                    fuse_weights[k] = v
            if len(fuse_weights) > 0:
                try:
                    self.fuse_module.load_state_dict(fuse_weights, strict=strict)
                except Exception as e:

                    print(e)
                    # force load it
                    print(f"force loading fuse module as it did not match")
                    current_state_dict = self.fuse_module.state_dict()
                    for k, v in fuse_weights.items():
                        if len(v.shape) == 1:
                            current_state_dict[k] = v[:current_state_dict[k].shape[0]]
                        elif len(v.shape) == 2:
                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]]
                        elif len(v.shape) == 3:
                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
                                                      :current_state_dict[k].shape[2]]
                        elif len(v.shape) == 4:
                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
                                                      :current_state_dict[k].shape[2], :current_state_dict[k].shape[3]]
                        else:
                            raise ValueError(f"unknown shape: {v.shape}")
                    self.fuse_module.load_state_dict(current_state_dict, strict=strict)

            if 'vision_encoder' in state_dict:
                self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)

            if 'fuse_module' in state_dict:
                self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)

@@ -252,6 +295,12 @@ class CustomAdapter(torch.nn.Module):

            # todo save LoRA
            return state_dict

        elif self.adapter_type == 'clip_fusion':
            if self.config.train_image_encoder:
                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
            state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
            return state_dict
        else:
            raise NotImplementedError
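# Note (annotation, not part of the commit): for adapter_type == 'clip_fusion' the saved
# checkpoint therefore contains state_dict['clip_fusion'] (the CLIPFusionModule weights)
# and, only when config.train_image_encoder is set, state_dict['vision_encoder'];
# load_state_dict above accepts the encoder weights under either 'vision_encoder' or 'id_encoder'.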
@@ -260,7 +309,9 @@ class CustomAdapter(torch.nn.Module):
            prompt: Union[List[str], str],
            is_unconditional: bool = False,
    ):
-       if self.adapter_type == 'photo_maker':
+       if self.adapter_type == 'clip_fusion':
+           return prompt
+       elif self.adapter_type == 'photo_maker':
            if is_unconditional:
                return prompt
            else:
@@ -310,7 +361,8 @@ class CustomAdapter(torch.nn.Module):
                    # add the first one to the front of the prompt
                    tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
                    our_class = self.config.class_names[0]
-                   prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
+                   prompt = " ".join(
+                       [self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt

                    # add the prompt to the list
                    new_prompt_list.append(prompt)
@@ -329,7 +381,6 @@ class CustomAdapter(torch.nn.Module):

                class_token = tokenized_prompt[flag_idx - 1]

                boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
                boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
                boolean_mask = boolean_mask.to(self.device)
@@ -357,7 +408,7 @@ class CustomAdapter(torch.nn.Module):
            has_been_preprocessed=False,
            is_unconditional=False
    ) -> PromptEmbeds:
-       if self.adapter_type == 'photo_maker':
+       if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
            if is_unconditional:
                # we don't condition the negative embeds for photo maker
                return prompt_embeds
@@ -384,10 +435,9 @@ class CustomAdapter(torch.nn.Module):
                clip_image = tensors_0_1
            clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

            if self.adapter_type == 'photo_maker':
                # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
                clip_image = clip_image.unsqueeze(1)

                with torch.set_grad_enabled(is_training):
                    if is_training and self.config.train_image_encoder:
                        self.vision_encoder.train()
@@ -408,11 +458,46 @@ class CustomAdapter(torch.nn.Module):
                    id_embeds,
                    self.token_mask
                )

                return prompt_embeds
            elif self.adapter_type == 'clip_fusion':
                with torch.set_grad_enabled(is_training):
                    if is_training and self.config.train_image_encoder:
                        self.vision_encoder.train()
                        clip_image = clip_image.requires_grad_(True)
                        id_embeds = self.vision_encoder(
                            clip_image,
                            output_hidden_states=True,
                        )
                    else:
                        with torch.no_grad():
                            self.vision_encoder.eval()
                            id_embeds = self.vision_encoder(
                                clip_image, output_hidden_states=True
                            )

                    img_embeds = id_embeds['last_hidden_state']

                    if not is_training or not self.config.train_image_encoder:
                        img_embeds = img_embeds.detach()

                    prompt_embeds.text_embeds = self.clip_fusion_module(
                        prompt_embeds.text_embeds,
                        img_embeds
                    )
                    return prompt_embeds

            else:
                raise NotImplementedError

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.type == 'photo_maker':
            yield from self.fuse_module.parameters(recurse)
            if self.config.train_image_encoder:
                yield from self.vision_encoder.parameters(recurse)
        elif self.config.type == 'clip_fusion':
            yield from self.clip_fusion_module.parameters(recurse)
            if self.config.train_image_encoder:
                yield from self.vision_encoder.parameters(recurse)
        else:
            raise NotImplementedError
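Taken together, the clip_fusion path above encodes the control image, takes the vision encoder's last_hidden_state, and lets CLIPFusionModule rewrite prompt_embeds.text_embeds. A minimal sketch of that flow outside the adapter follows; it is illustrative only and not part of the commit, and the model id, the random image tensor, and the text-embedding stand-in are assumptions.

# Illustrative sketch -- mirrors the clip_fusion branch of the prompt-embed conditioning above.
import torch
from transformers import CLIPVisionModelWithProjection

from toolkit.models.clip_fusion import CLIPFusionModule

vision_encoder = CLIPVisionModelWithProjection.from_pretrained('openai/clip-vit-large-patch14')
clip_image = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed control image

with torch.no_grad():
    vision_encoder.eval()
    id_embeds = vision_encoder(clip_image, output_hidden_states=True)
img_embeds = id_embeds['last_hidden_state'].detach()  # (1, 257, 1024) for ViT-L/14

fusion = CLIPFusionModule(
    text_hidden_size=768,  # in the adapter this is sd.unet.config['cross_attention_dim']
    text_tokens=77,
    vision_hidden_size=vision_encoder.config.hidden_size,
    vision_tokens=img_embeds.shape[1],
)
text_embeds = torch.randn(1, 77, 768)  # stand-in for prompt_embeds.text_embeds
fused = fusion(text_embeds, img_embeds)  # same shape as text_embeds, now image-conditioned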
toolkit/models/clip_fusion.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import torch
import torch.nn as nn


# Conv1d MLP
# MLP that can alternately be used as a conv1d on dim 1
class MLPC(nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim,
            hidden_dim,
            do_conv=False,
            use_residual=True
    ):
        super().__init__()
        self.do_conv = do_conv
        if use_residual:
            assert in_dim == out_dim
        # don't normalize if using conv
        if not do_conv:
            self.layernorm = nn.LayerNorm(in_dim)

        if do_conv:
            self.fc1 = nn.Conv1d(in_dim, hidden_dim, 1)
            self.fc2 = nn.Conv1d(hidden_dim, out_dim, 1)
        else:
            self.fc1 = nn.Linear(in_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, out_dim)

        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        if not self.do_conv:
            x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            x = x + residual
        return x
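# Illustrative summary (annotation, not part of the file), assuming an input of shape
# (batch, tokens, features):
#   do_conv=False: LayerNorm over the last dim, then Linear(in_dim -> hidden_dim), GELU,
#     Linear(hidden_dim -> out_dim); it remaps the feature dimension.
#   do_conv=True: two 1x1 Conv1d layers that treat dim 1 as channels, so the same block
#     remaps the token dimension (in_dim tokens -> out_dim tokens) without a permute.
#   use_residual=True adds the input back, hence the assert that in_dim == out_dim.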


class ZipperBlock(nn.Module):
    def __init__(
            self,
            in_size,
            in_tokens,
            out_size,
            out_tokens,
            hidden_size,
            hidden_tokens,
    ):
        super().__init__()
        self.in_size = in_size
        self.in_tokens = in_tokens
        self.out_size = out_size
        self.out_tokens = out_tokens
        self.hidden_size = hidden_size
        self.hidden_tokens = hidden_tokens
        # permute to (batch_size, out_size, in_tokens)

        self.zip_token = MLPC(
            in_dim=self.in_tokens,
            out_dim=self.out_tokens,
            hidden_dim=self.hidden_tokens,
            do_conv=True,  # no need to permute
            use_residual=False
        )

        # permute to (batch_size, out_tokens, out_size)

        # in shape: (batch_size, in_tokens, in_size)
        self.zip_size = MLPC(
            in_dim=self.in_size,
            out_dim=self.out_size,
            hidden_dim=self.hidden_size,
            use_residual=False
        )

    def forward(self, x):
        x = self.zip_token(x)
        x = self.zip_size(x)
        return x
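# Illustrative shape trace (annotation, not part of the file), for x of shape
# (batch, in_tokens, in_size):
#   zip_token (1x1 Conv1d over dim 1):   (batch, in_tokens, in_size) -> (batch, out_tokens, in_size)
#   zip_size  (Linear over the last dim): (batch, out_tokens, in_size) -> (batch, out_tokens, out_size)
# hidden_tokens and hidden_size are only the widths of the two intermediate projections.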


# CLIPFusionModule
# Fuses any size of vision and text embeddings into a single embedding.
# remaps tokens and vectors.
class CLIPFusionModule(nn.Module):
    def __init__(
            self,
            text_hidden_size: int = 768,
            text_tokens: int = 77,
            vision_hidden_size: int = 1024,
            vision_tokens: int = 257,
            num_blocks: int = 2,
    ):
        super(CLIPFusionModule, self).__init__()

        self.text_hidden_size = text_hidden_size
        self.text_tokens = text_tokens
        self.vision_hidden_size = vision_hidden_size
        self.vision_tokens = vision_tokens

        self.resampler = ZipperBlock(
            in_size=self.vision_hidden_size,
            in_tokens=self.vision_tokens,
            out_size=self.text_hidden_size,
            out_tokens=self.text_tokens,
            hidden_size=self.vision_hidden_size * 2,
            hidden_tokens=self.vision_tokens * 2
        )

        self.zipper_blocks = torch.nn.ModuleList([
            ZipperBlock(
                in_size=self.text_hidden_size * 2,
                in_tokens=self.text_tokens,
                out_size=self.text_hidden_size,
                out_tokens=self.text_tokens,
                hidden_size=self.text_hidden_size * 2,
                hidden_tokens=self.text_tokens * 2
            ) for i in range(num_blocks)
        ])

    def forward(self, text_embeds, vision_embeds):
        # text_embeds = (batch_size, 77, 768)
        # vision_embeds = (batch_size, 257, 1024)
        # output = (batch_size, 77, 768)

        vision_embeds = self.resampler(vision_embeds)
        x = vision_embeds
        for i, block in enumerate(self.zipper_blocks):
            res = x
            x = torch.cat([text_embeds, x], dim=-1)
            x = block(x)
            x = x + res

        x = text_embeds + x

        return x
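Since the module only remaps token and feature dimensions, a quick sanity check with random tensors (illustrative only, not part of the commit) confirms the contract: the output matches the shape of the text embeddings it was built for.

# Illustrative shape check for CLIPFusionModule (not part of the commit).
import torch

from toolkit.models.clip_fusion import CLIPFusionModule

fusion = CLIPFusionModule(
    text_hidden_size=768,
    text_tokens=77,
    vision_hidden_size=1024,
    vision_tokens=257,
    num_blocks=2,
)

text_embeds = torch.randn(2, 77, 768)      # e.g. CLIP text encoder output
vision_embeds = torch.randn(2, 257, 1024)  # e.g. CLIP ViT-L/14 last_hidden_state

out = fusion(text_embeds, vision_embeds)
assert out.shape == text_embeds.shape      # (2, 77, 768)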
toolkit/photomaker.py
@@ -119,11 +119,12 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
        super().__init__(config, *model_args, **model_kwargs)
        self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)

-   def forward(self, id_pixel_values, do_projection2=True):
+   def forward(self, id_pixel_values, do_projection2=True, output_full=False):
        b, num_inputs, c, h, w = id_pixel_values.shape
        id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)

-       shared_id_embeds = self.vision_model(id_pixel_values)[1]
+       # last_hidden_state, 1, 257, 1024
+       vision_output = self.vision_model(id_pixel_values, output_hidden_states=True)
+       shared_id_embeds = vision_output[1]
        id_embeds = self.visual_projection(shared_id_embeds)

        id_embeds = id_embeds.view(b, num_inputs, 1, -1)
@@ -133,6 +134,8 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
        id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
        id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)

+       if output_full:
+           return id_embeds, vision_output
        return id_embeds