Added an experimental clip fusion model that is showing promise for embedding concepts

This commit is contained in:
Jaret Burkett
2024-01-17 13:13:04 -07:00
parent 655533d4c7
commit 86c70a2a1f
3 changed files with 320 additions and 89 deletions

View File

@@ -5,6 +5,7 @@ from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -72,12 +73,14 @@ class CustomAdapter(torch.nn.Module):
# add for dataloader
self.clip_image_processor = self.image_processor
self.clip_fusion_module: CLIPFusionModule = None
self.setup_adapter()
if self.adapter_type == 'photo_maker':
# try to load from our name_or_path
if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
# add the trigger word to the tokenizer
if isinstance(self.sd_ref().tokenizer, list):
for tokenizer in self.sd_ref().tokenizer:
@@ -90,31 +93,45 @@ class CustomAdapter(torch.nn.Module):
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
self.fuse_module = FuseModule(embed_dim)
elif self.adapter_type == 'clip_fusion':
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
if self.config.image_encoder_arch == 'clip':
vision_tokens = vision_tokens + 1
self.clip_fusion_module = CLIPFusionModule(
text_hidden_size=embed_dim,
text_tokens=77,
vision_hidden_size=self.vision_encoder.config.hidden_size,
vision_tokens=vision_tokens
)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")
def forward(self, *args, **kwargs):
if self.adapter_type == 'photo_maker':
id_pixel_values = args[0]
prompt_embeds: PromptEmbeds = args[1]
class_tokens_mask = args[2]
grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
with torch.set_grad_enabled(grads_on_image_encoder):
id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
if not grads_on_image_encoder:
id_embeds = id_embeds.detach()
prompt_embeds = prompt_embeds.detach()
updated_prompt_embeds = self.fuse_module(
prompt_embeds, id_embeds, class_tokens_mask
)
return updated_prompt_embeds
else:
# dont think this is used
# if self.adapter_type == 'photo_maker':
# id_pixel_values = args[0]
# prompt_embeds: PromptEmbeds = args[1]
# class_tokens_mask = args[2]
#
# grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
#
# with torch.set_grad_enabled(grads_on_image_encoder):
# id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
#
# if not grads_on_image_encoder:
# id_embeds = id_embeds.detach()
#
# prompt_embeds = prompt_embeds.detach()
#
# updated_prompt_embeds = self.fuse_module(
# prompt_embeds, id_embeds, class_tokens_mask
# )
#
# return updated_prompt_embeds
# else:
raise NotImplementedError
def setup_clip(self):
@@ -226,7 +243,9 @@ class CustomAdapter(torch.nn.Module):
# self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
# self.sd_ref().pipeline.fuse_lora()
pass
if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
if 'clip_fusion' in state_dict:
self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)
if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'):
self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
# check to see if the fuse weights are there
fuse_weights = {}
@@ -235,7 +254,31 @@ class CustomAdapter(torch.nn.Module):
k = k.replace('fuse_module.', '')
fuse_weights[k] = v
if len(fuse_weights) > 0:
try:
self.fuse_module.load_state_dict(fuse_weights, strict=strict)
except Exception as e:
print(e)
# force load it
print(f"force loading fuse module as it did not match")
current_state_dict = self.fuse_module.state_dict()
for k, v in fuse_weights.items():
if len(v.shape) == 1:
current_state_dict[k] = v[:current_state_dict[k].shape[0]]
elif len(v.shape) == 2:
current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]]
elif len(v.shape) == 3:
current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
:current_state_dict[k].shape[2]]
elif len(v.shape) == 4:
current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
:current_state_dict[k].shape[2], :current_state_dict[k].shape[3]]
else:
raise ValueError(f"unknown shape: {v.shape}")
self.fuse_module.load_state_dict(current_state_dict, strict=strict)
if 'vision_encoder' in state_dict:
self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
if 'fuse_module' in state_dict:
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
@@ -252,6 +295,12 @@ class CustomAdapter(torch.nn.Module):
# todo save LoRA
return state_dict
elif self.adapter_type == 'clip_fusion':
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
return state_dict
else:
raise NotImplementedError
@@ -260,7 +309,9 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type == 'photo_maker':
if self.adapter_type == 'clip_fusion':
return prompt
elif self.adapter_type == 'photo_maker':
if is_unconditional:
return prompt
else:
@@ -310,7 +361,8 @@ class CustomAdapter(torch.nn.Module):
# add the first one to the front of the prompt
tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
our_class = self.config.class_names[0]
prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
prompt = " ".join(
[self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
# add the prompt to the list
new_prompt_list.append(prompt)
@@ -329,8 +381,7 @@ class CustomAdapter(torch.nn.Module):
class_token = tokenized_prompt[flag_idx - 1]
boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool)
boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
boolean_mask = boolean_mask.to(self.device)
# zero pad it to 77
@@ -357,7 +408,7 @@ class CustomAdapter(torch.nn.Module):
has_been_preprocessed=False,
is_unconditional=False
) -> PromptEmbeds:
if self.adapter_type == 'photo_maker':
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
if is_unconditional:
# we dont condition the negative embeds for photo maker
return prompt_embeds
@@ -384,10 +435,9 @@ class CustomAdapter(torch.nn.Module):
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if self.adapter_type == 'photo_maker':
# Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
clip_image = clip_image.unsqueeze(1)
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
@@ -408,11 +458,46 @@ class CustomAdapter(torch.nn.Module):
id_embeds,
self.token_mask
)
return prompt_embeds
elif self.adapter_type == 'clip_fusion':
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
clip_image = clip_image.requires_grad_(True)
id_embeds = self.vision_encoder(
clip_image,
output_hidden_states=True,
)
else:
with torch.no_grad():
self.vision_encoder.eval()
id_embeds = self.vision_encoder(
clip_image, output_hidden_states=True
)
img_embeds = id_embeds['last_hidden_state']
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()
prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds
else:
raise NotImplementedError
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.type == 'photo_maker':
yield from self.fuse_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'clip_fusion':
yield from self.clip_fusion_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
else:
raise NotImplementedError

View File

@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
# Conv1d MLP
# MLP that can alternately be used as a conv1d on dim 1
class MLPC(nn.Module):
def __init__(
self,
in_dim,
out_dim,
hidden_dim,
do_conv=False,
use_residual=True
):
super().__init__()
self.do_conv = do_conv
if use_residual:
assert in_dim == out_dim
# dont normalize if using conv
if not do_conv:
self.layernorm = nn.LayerNorm(in_dim)
if do_conv:
self.fc1 = nn.Conv1d(in_dim, hidden_dim, 1)
self.fc2 = nn.Conv1d(hidden_dim, out_dim, 1)
else:
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, out_dim)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
if not self.do_conv:
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
if self.use_residual:
x = x + residual
return x
class ZipperBlock(nn.Module):
def __init__(
self,
in_size,
in_tokens,
out_size,
out_tokens,
hidden_size,
hidden_tokens,
):
super().__init__()
self.in_size = in_size
self.in_tokens = in_tokens
self.out_size = out_size
self.out_tokens = out_tokens
self.hidden_size = hidden_size
self.hidden_tokens = hidden_tokens
# permute to (batch_size, out_size, in_tokens)
self.zip_token = MLPC(
in_dim=self.in_tokens,
out_dim=self.out_tokens,
hidden_dim=self.hidden_tokens,
do_conv=True, # no need to permute
use_residual=False
)
# permute to (batch_size, out_tokens, out_size)
# in shpae: (batch_size, in_tokens, in_size)
self.zip_size = MLPC(
in_dim=self.in_size,
out_dim=self.out_size,
hidden_dim=self.hidden_size,
use_residual=False
)
def forward(self, x):
x = self.zip_token(x)
x = self.zip_size(x)
return x
# CLIPFusionModule
# Fuses any size of vision and text embeddings into a single embedding.
# remaps tokens and vectors.
class CLIPFusionModule(nn.Module):
def __init__(
self,
text_hidden_size: int = 768,
text_tokens: int = 77,
vision_hidden_size: int = 1024,
vision_tokens: int = 257,
num_blocks: int = 2,
):
super(CLIPFusionModule, self).__init__()
self.text_hidden_size = text_hidden_size
self.text_tokens = text_tokens
self.vision_hidden_size = vision_hidden_size
self.vision_tokens = vision_tokens
self.resampler = ZipperBlock(
in_size=self.vision_hidden_size,
in_tokens=self.vision_tokens,
out_size=self.text_hidden_size,
out_tokens=self.text_tokens,
hidden_size=self.vision_hidden_size * 2,
hidden_tokens=self.vision_tokens * 2
)
self.zipper_blocks = torch.nn.ModuleList([
ZipperBlock(
in_size=self.text_hidden_size * 2,
in_tokens=self.text_tokens,
out_size=self.text_hidden_size,
out_tokens=self.text_tokens,
hidden_size=self.text_hidden_size * 2,
hidden_tokens=self.text_tokens * 2
) for i in range(num_blocks)
])
def forward(self, text_embeds, vision_embeds):
# text_embeds = (batch_size, 77, 768)
# vision_embeds = (batch_size, 257, 1024)
# output = (batch_size, 77, 768)
vision_embeds = self.resampler(vision_embeds)
x = vision_embeds
for i, block in enumerate(self.zipper_blocks):
res = x
x = torch.cat([text_embeds, x], dim=-1)
x = block(x)
x = x + res
x = text_embeds + x
return x

View File

@@ -119,11 +119,12 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
super().__init__(config, *model_args, **model_kwargs)
self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
def forward(self, id_pixel_values, do_projection2=True):
def forward(self, id_pixel_values, do_projection2=True, output_full=False):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
shared_id_embeds = self.vision_model(id_pixel_values)[1]
# last_hidden_state, 1, 257, 1024
vision_output = self.vision_model(id_pixel_values, output_hidden_states=True)
shared_id_embeds = vision_output[1]
id_embeds = self.visual_projection(shared_id_embeds)
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
@@ -133,6 +134,8 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
if output_full:
return id_embeds, vision_output
return id_embeds