Added an experimental CLIP fusion model that is showing promise for embedding concepts

Jaret Burkett
2024-01-17 13:13:04 -07:00
parent 655533d4c7
commit 86c70a2a1f
3 changed files with 320 additions and 89 deletions

View File

@@ -5,6 +5,7 @@ from PIL import Image
 from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -72,50 +73,66 @@ class CustomAdapter(torch.nn.Module):
         # add for dataloader
         self.clip_image_processor = self.image_processor
+        self.clip_fusion_module: CLIPFusionModule = None

         self.setup_adapter()

-        # try to load from our name_or_path
-        if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
-            self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
+        if self.adapter_type == 'photo_maker':
+            # try to load from our name_or_path
+            if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
+                self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)

         # add the trigger word to the tokenizer
         if isinstance(self.sd_ref().tokenizer, list):
             for tokenizer in self.sd_ref().tokenizer:
                 tokenizer.add_tokens([self.flag_word], special_tokens=True)
         else:
             self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)

     def setup_adapter(self):
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
             embed_dim = sd.unet.config['cross_attention_dim']
             self.fuse_module = FuseModule(embed_dim)
+        elif self.adapter_type == 'clip_fusion':
+            sd = self.sd_ref()
+            embed_dim = sd.unet.config['cross_attention_dim']
+            vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
+            if self.config.image_encoder_arch == 'clip':
+                vision_tokens = vision_tokens + 1
+            self.clip_fusion_module = CLIPFusionModule(
+                text_hidden_size=embed_dim,
+                text_tokens=77,
+                vision_hidden_size=self.vision_encoder.config.hidden_size,
+                vision_tokens=vision_tokens
+            )
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")

     def forward(self, *args, **kwargs):
-        if self.adapter_type == 'photo_maker':
-            id_pixel_values = args[0]
-            prompt_embeds: PromptEmbeds = args[1]
-            class_tokens_mask = args[2]
-
-            grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
-
-            with torch.set_grad_enabled(grads_on_image_encoder):
-                id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
-
-            if not grads_on_image_encoder:
-                id_embeds = id_embeds.detach()
-
-            prompt_embeds = prompt_embeds.detach()
-
-            updated_prompt_embeds = self.fuse_module(
-                prompt_embeds, id_embeds, class_tokens_mask
-            )
-
-            return updated_prompt_embeds
-        else:
-            raise NotImplementedError
+        # dont think this is used
+        # if self.adapter_type == 'photo_maker':
+        #     id_pixel_values = args[0]
+        #     prompt_embeds: PromptEmbeds = args[1]
+        #     class_tokens_mask = args[2]
+        #
+        #     grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
+        #
+        #     with torch.set_grad_enabled(grads_on_image_encoder):
+        #         id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
+        #
+        #     if not grads_on_image_encoder:
+        #         id_embeds = id_embeds.detach()
+        #
+        #     prompt_embeds = prompt_embeds.detach()
+        #
+        #     updated_prompt_embeds = self.fuse_module(
+        #         prompt_embeds, id_embeds, class_tokens_mask
+        #     )
+        #
+        #     return updated_prompt_embeds
+        # else:
+        raise NotImplementedError

     def setup_clip(self):
         adapter_config = self.config
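
Note (not part of the diff): a quick sanity check of the vision_tokens arithmetic in setup_adapter above, assuming a CLIP ViT-L/14 style image encoder (image_size 224, patch_size 14); the numbers are illustrative, not read from this repo's configs.

    image_size, patch_size = 224, 14
    patch_tokens = (image_size // patch_size) ** 2   # 16 * 16 = 256 patch tokens
    vision_tokens = patch_tokens + 1                 # +1 class token when image_encoder_arch == 'clip'
    assert vision_tokens == 257                      # matches CLIPFusionModule's default vision_tokens
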
@@ -226,7 +243,9 @@ class CustomAdapter(torch.nn.Module):
             # self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
             # self.sd_ref().pipeline.fuse_lora()
             pass
-        if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
+        if 'clip_fusion' in state_dict:
+            self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)
+        if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'):
             self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
             # check to see if the fuse weights are there
             fuse_weights = {}
@@ -235,7 +254,31 @@ class CustomAdapter(torch.nn.Module):
                     k = k.replace('fuse_module.', '')
                     fuse_weights[k] = v
             if len(fuse_weights) > 0:
-                self.fuse_module.load_state_dict(fuse_weights, strict=strict)
+                try:
+                    self.fuse_module.load_state_dict(fuse_weights, strict=strict)
+                except Exception as e:
+                    print(e)
+                    # force load it
+                    print(f"force loading fuse module as it did not match")
+                    current_state_dict = self.fuse_module.state_dict()
+                    for k, v in fuse_weights.items():
+                        if len(v.shape) == 1:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0]]
+                        elif len(v.shape) == 2:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]]
+                        elif len(v.shape) == 3:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+                                                      :current_state_dict[k].shape[2]]
+                        elif len(v.shape) == 4:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+                                                      :current_state_dict[k].shape[2], :current_state_dict[k].shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {v.shape}")
+                    self.fuse_module.load_state_dict(current_state_dict, strict=strict)
+
+        if 'vision_encoder' in state_dict:
+            self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)

         if 'fuse_module' in state_dict:
             self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
@@ -252,6 +295,12 @@ class CustomAdapter(torch.nn.Module):
             # todo save LoRA
             return state_dict
+        elif self.adapter_type == 'clip_fusion':
+            if self.config.train_image_encoder:
+                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+            state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
+            return state_dict
         else:
             raise NotImplementedError
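
Note (not part of the diff): a minimal sketch of the checkpoint layout the clip_fusion branches above save and load; the keys mirror state_dict()/load_state_dict(), while the file name is hypothetical.

    import torch

    checkpoint = {
        'clip_fusion': {},     # CLIPFusionModule weights, always saved for clip_fusion
        'vision_encoder': {},  # CLIP vision tower weights, only saved when train_image_encoder is True
    }
    torch.save(checkpoint, 'clip_fusion_adapter.bin')  # hypothetical path
    assert 'clip_fusion' in torch.load('clip_fusion_adapter.bin')  # key consumed by load_state_dict above
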
@@ -260,7 +309,9 @@ class CustomAdapter(torch.nn.Module):
             prompt: Union[List[str], str],
             is_unconditional: bool = False,
     ):
-        if self.adapter_type == 'photo_maker':
+        if self.adapter_type == 'clip_fusion':
+            return prompt
+        elif self.adapter_type == 'photo_maker':
             if is_unconditional:
                 return prompt
             else:
@@ -310,7 +361,8 @@ class CustomAdapter(torch.nn.Module):
                     # add the first one to the front of the prompt
                     tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
                     our_class = self.config.class_names[0]
-                    prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
+                    prompt = " ".join(
+                        [self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt

                 # add the prompt to the list
                 new_prompt_list.append(prompt)
@@ -329,8 +381,7 @@ class CustomAdapter(torch.nn.Module):
                 class_token = tokenized_prompt[flag_idx - 1]

-                boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
+                boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool)
                 boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
                 boolean_mask = boolean_mask.to(self.device)
                 # zero pad it to 77
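
Note (not part of the diff): a small worked example of how the class-token mask above is assembled for the photo_maker path; the values are made up, and the final padding step mirrors the "zero pad it to 77" comment.

    import torch

    flag_idx, num_control_images = 5, 2  # hypothetical values
    boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
    boolean_mask = torch.cat((boolean_mask, torch.ones(num_control_images, dtype=torch.bool)))
    boolean_mask = torch.cat((boolean_mask, torch.zeros(77 - boolean_mask.shape[0], dtype=torch.bool)))
    print(boolean_mask.shape, boolean_mask.sum())  # torch.Size([77]) tensor(2)
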
@@ -357,62 +408,96 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             is_unconditional=False
     ) -> PromptEmbeds:
-        if self.adapter_type == 'photo_maker':
+        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
                 return prompt_embeds
             with torch.no_grad():
                 # on training the clip image is created in the dataloader
                 if not has_been_preprocessed:
                     # tensors should be 0-1
                     if tensors_0_1.ndim == 3:
                         tensors_0_1 = tensors_0_1.unsqueeze(0)
                     # training tensors are 0 - 1
                     tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
                     # if images are out of this range throw error
                     if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                         raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                             tensors_0_1.min(), tensors_0_1.max()
                         ))
                     clip_image = self.image_processor(
                         images=tensors_0_1,
                         return_tensors="pt",
                         do_resize=True,
                         do_rescale=False,
                     ).pixel_values
                 else:
                     clip_image = tensors_0_1
                 clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

-            # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
-            clip_image = clip_image.unsqueeze(1)
-            with torch.set_grad_enabled(is_training):
-                if is_training and self.config.train_image_encoder:
-                    self.vision_encoder.train()
-                    clip_image = clip_image.requires_grad_(True)
-                    id_embeds = self.vision_encoder(
-                        clip_image,
-                        do_projection2=isinstance(self.sd_ref().text_encoder, list),
-                    )
-                else:
-                    with torch.no_grad():
-                        self.vision_encoder.eval()
-                        id_embeds = self.vision_encoder(
-                            clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
-                        ).detach()
-
-            prompt_embeds.text_embeds = self.fuse_module(
-                prompt_embeds.text_embeds,
-                id_embeds,
-                self.token_mask
-            )
-            return prompt_embeds
+            if self.adapter_type == 'photo_maker':
+                # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
+                clip_image = clip_image.unsqueeze(1)
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            do_projection2=isinstance(self.sd_ref().text_encoder, list),
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
+                            ).detach()
+
+                prompt_embeds.text_embeds = self.fuse_module(
+                    prompt_embeds.text_embeds,
+                    id_embeds,
+                    self.token_mask
+                )
+                return prompt_embeds
+            elif self.adapter_type == 'clip_fusion':
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            output_hidden_states=True,
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, output_hidden_states=True
+                            )
+
+                img_embeds = id_embeds['last_hidden_state']
+
+                if not is_training or not self.config.train_image_encoder:
+                    img_embeds = img_embeds.detach()
+
+                prompt_embeds.text_embeds = self.clip_fusion_module(
+                    prompt_embeds.text_embeds,
+                    img_embeds
+                )
+                return prompt_embeds
+        else:
+            raise NotImplementedError

     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.type == 'photo_maker':
             yield from self.fuse_module.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'clip_fusion':
+            yield from self.clip_fusion_module.parameters(recurse)
+            if self.config.train_image_encoder:
+                yield from self.vision_encoder.parameters(recurse)
+        else:
+            raise NotImplementedError
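
Note (not part of the diff): a minimal sketch of how the new clip_fusion branch of parameters() is typically consumed by a trainer; only the fusion module is optimized unless train_image_encoder is enabled in the adapter config.

    import torch

    def build_adapter_optimizer(adapter, lr=1e-4):
        # 'adapter' stands in for a CustomAdapter configured with type 'clip_fusion'
        # (hypothetical instance); parameters() already yields only the trainable parts.
        return torch.optim.AdamW(adapter.parameters(), lr=lr)
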

View File

@@ -0,0 +1,143 @@
+import torch
+import torch.nn as nn
+
+
+# Conv1d MLP
+# MLP that can alternately be used as a conv1d on dim 1
+class MLPC(nn.Module):
+    def __init__(
+            self,
+            in_dim,
+            out_dim,
+            hidden_dim,
+            do_conv=False,
+            use_residual=True
+    ):
+        super().__init__()
+        self.do_conv = do_conv
+        if use_residual:
+            assert in_dim == out_dim
+        # dont normalize if using conv
+        if not do_conv:
+            self.layernorm = nn.LayerNorm(in_dim)
+        if do_conv:
+            self.fc1 = nn.Conv1d(in_dim, hidden_dim, 1)
+            self.fc2 = nn.Conv1d(hidden_dim, out_dim, 1)
+        else:
+            self.fc1 = nn.Linear(in_dim, hidden_dim)
+            self.fc2 = nn.Linear(hidden_dim, out_dim)
+        self.use_residual = use_residual
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        residual = x
+        if not self.do_conv:
+            x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        if self.use_residual:
+            x = x + residual
+        return x
+
+
+class ZipperBlock(nn.Module):
+    def __init__(
+            self,
+            in_size,
+            in_tokens,
+            out_size,
+            out_tokens,
+            hidden_size,
+            hidden_tokens,
+    ):
+        super().__init__()
+        self.in_size = in_size
+        self.in_tokens = in_tokens
+        self.out_size = out_size
+        self.out_tokens = out_tokens
+        self.hidden_size = hidden_size
+        self.hidden_tokens = hidden_tokens
+
+        # permute to (batch_size, out_size, in_tokens)
+        self.zip_token = MLPC(
+            in_dim=self.in_tokens,
+            out_dim=self.out_tokens,
+            hidden_dim=self.hidden_tokens,
+            do_conv=True,  # no need to permute
+            use_residual=False
+        )
+        # permute to (batch_size, out_tokens, out_size)
+        # in shpae: (batch_size, in_tokens, in_size)
+        self.zip_size = MLPC(
+            in_dim=self.in_size,
+            out_dim=self.out_size,
+            hidden_dim=self.hidden_size,
+            use_residual=False
+        )
+
+    def forward(self, x):
+        x = self.zip_token(x)
+        x = self.zip_size(x)
+        return x
+
+
+# CLIPFusionModule
+# Fuses any size of vision and text embeddings into a single embedding.
+# remaps tokens and vectors.
+class CLIPFusionModule(nn.Module):
+    def __init__(
+            self,
+            text_hidden_size: int = 768,
+            text_tokens: int = 77,
+            vision_hidden_size: int = 1024,
+            vision_tokens: int = 257,
+            num_blocks: int = 2,
+    ):
+        super(CLIPFusionModule, self).__init__()
+
+        self.text_hidden_size = text_hidden_size
+        self.text_tokens = text_tokens
+        self.vision_hidden_size = vision_hidden_size
+        self.vision_tokens = vision_tokens
+
+        self.resampler = ZipperBlock(
+            in_size=self.vision_hidden_size,
+            in_tokens=self.vision_tokens,
+            out_size=self.text_hidden_size,
+            out_tokens=self.text_tokens,
+            hidden_size=self.vision_hidden_size * 2,
+            hidden_tokens=self.vision_tokens * 2
+        )
+        self.zipper_blocks = torch.nn.ModuleList([
+            ZipperBlock(
+                in_size=self.text_hidden_size * 2,
+                in_tokens=self.text_tokens,
+                out_size=self.text_hidden_size,
+                out_tokens=self.text_tokens,
+                hidden_size=self.text_hidden_size * 2,
+                hidden_tokens=self.text_tokens * 2
+            ) for i in range(num_blocks)
+        ])
+
+    def forward(self, text_embeds, vision_embeds):
+        # text_embeds = (batch_size, 77, 768)
+        # vision_embeds = (batch_size, 257, 1024)
+        # output = (batch_size, 77, 768)
+        vision_embeds = self.resampler(vision_embeds)
+        x = vision_embeds
+        for i, block in enumerate(self.zipper_blocks):
+            res = x
+            x = torch.cat([text_embeds, x], dim=-1)
+            x = block(x)
+            x = x + res
+        x = text_embeds + x
+        return x
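
Note (not part of the diff): a quick shape check of the new module; the defaults match a 77-token text encoder and a CLIP ViT-L/14 style vision tower (257 tokens of width 1024).

    import torch
    from toolkit.models.clip_fusion import CLIPFusionModule

    fusion = CLIPFusionModule(text_hidden_size=768, text_tokens=77,
                              vision_hidden_size=1024, vision_tokens=257)
    text_embeds = torch.randn(2, 77, 768)      # (batch, text_tokens, text_hidden_size)
    vision_embeds = torch.randn(2, 257, 1024)  # (batch, vision_tokens, vision_hidden_size)
    out = fusion(text_embeds, vision_embeds)   # resample vision tokens to 77 x 768, then fuse with text
    print(out.shape)                           # torch.Size([2, 77, 768])
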

View File

@@ -119,11 +119,12 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
         super().__init__(config, *model_args, **model_kwargs)
         self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)

-    def forward(self, id_pixel_values, do_projection2=True):
+    def forward(self, id_pixel_values, do_projection2=True, output_full=False):
         b, num_inputs, c, h, w = id_pixel_values.shape
         id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)

-        shared_id_embeds = self.vision_model(id_pixel_values)[1]
+        # last_hidden_state, 1, 257, 1024
+        vision_output = self.vision_model(id_pixel_values, output_hidden_states=True)
+        shared_id_embeds = vision_output[1]
         id_embeds = self.visual_projection(shared_id_embeds)
         id_embeds = id_embeds.view(b, num_inputs, 1, -1)
@@ -133,6 +134,8 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
             id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
             id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)

+        if output_full:
+            return id_embeds, vision_output
+
         return id_embeds
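
Note (not part of the diff): illustrative use of the new output_full flag; the checkpoint name is a placeholder, and PhotoMakerCLIPEncoder loads like any CLIPVisionModelWithProjection subclass (the extra visual_projection_2 is randomly initialized).

    import torch
    from toolkit.photomaker import PhotoMakerCLIPEncoder

    encoder = PhotoMakerCLIPEncoder.from_pretrained("openai/clip-vit-large-patch14")  # placeholder checkpoint
    pixel_values = torch.randn(1, 1, 3, 224, 224)  # (b, num_inputs, c, h, w)
    id_embeds, vision_output = encoder(pixel_values, output_full=True)
    # vision_output.hidden_states is populated because forward now calls the
    # vision model with output_hidden_states=True
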