diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 540b7a69..ebc1b1fe 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -5,6 +5,7 @@
 from PIL import Image
 from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -72,50 +73,66 @@ class CustomAdapter(torch.nn.Module):
         # add for dataloader
         self.clip_image_processor = self.image_processor
 
+        self.clip_fusion_module: CLIPFusionModule = None
+
         self.setup_adapter()
 
-        # try to load from our name_or_path
-        if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
-            self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
-
-        # add the trigger word to the tokenizer
-        if isinstance(self.sd_ref().tokenizer, list):
-            for tokenizer in self.sd_ref().tokenizer:
-                tokenizer.add_tokens([self.flag_word], special_tokens=True)
-        else:
-            self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
+        if self.adapter_type == 'photo_maker':
+            # try to load from our name_or_path
+            if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
+                self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
+            # add the trigger word to the tokenizer
+            if isinstance(self.sd_ref().tokenizer, list):
+                for tokenizer in self.sd_ref().tokenizer:
+                    tokenizer.add_tokens([self.flag_word], special_tokens=True)
+            else:
+                self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
 
     def setup_adapter(self):
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
             embed_dim = sd.unet.config['cross_attention_dim']
             self.fuse_module = FuseModule(embed_dim)
+        elif self.adapter_type == 'clip_fusion':
+            sd = self.sd_ref()
+            embed_dim = sd.unet.config['cross_attention_dim']
+
+            vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
+            if self.config.image_encoder_arch == 'clip':
+                vision_tokens = vision_tokens + 1
+            self.clip_fusion_module = CLIPFusionModule(
+                text_hidden_size=embed_dim,
+                text_tokens=77,
+                vision_hidden_size=self.vision_encoder.config.hidden_size,
+                vision_tokens=vision_tokens
+            )
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
 
     def forward(self, *args, **kwargs):
-        if self.adapter_type == 'photo_maker':
-            id_pixel_values = args[0]
-            prompt_embeds: PromptEmbeds = args[1]
-            class_tokens_mask = args[2]
-
-            grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
-
-            with torch.set_grad_enabled(grads_on_image_encoder):
-                id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
-
-            if not grads_on_image_encoder:
-                id_embeds = id_embeds.detach()
-
-            prompt_embeds = prompt_embeds.detach()
-
-            updated_prompt_embeds = self.fuse_module(
-                prompt_embeds, id_embeds, class_tokens_mask
-            )
-
-            return updated_prompt_embeds
-        else:
-            raise NotImplementedError
+        # don't think this is used
+        # if self.adapter_type == 'photo_maker':
+        #     id_pixel_values = args[0]
+        #     prompt_embeds: PromptEmbeds = args[1]
+        #     class_tokens_mask = args[2]
+        #
+        #     grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
+        #
+        #     with torch.set_grad_enabled(grads_on_image_encoder):
+        #         id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
+        #
+        #     if not grads_on_image_encoder:
+        #         id_embeds = id_embeds.detach()
+        #
+        #     prompt_embeds = prompt_embeds.detach()
+        #
+        #     updated_prompt_embeds = self.fuse_module(
+        #         prompt_embeds, id_embeds, class_tokens_mask
+        #     )
+        #
+        #     return updated_prompt_embeds
+        # else:
+        raise NotImplementedError
 
     def setup_clip(self):
         adapter_config = self.config
@@ -226,7 +243,9 @@ class CustomAdapter(torch.nn.Module):
             # self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
             # self.sd_ref().pipeline.fuse_lora()
             pass
-        if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
+        if 'clip_fusion' in state_dict:
+            self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)
+        if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'):
             self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
             # check to see if the fuse weights are there
             fuse_weights = {}
@@ -235,7 +254,31 @@ class CustomAdapter(torch.nn.Module):
                     k = k.replace('fuse_module.', '')
                     fuse_weights[k] = v
             if len(fuse_weights) > 0:
-                self.fuse_module.load_state_dict(fuse_weights, strict=strict)
+                try:
+                    self.fuse_module.load_state_dict(fuse_weights, strict=strict)
+                except Exception as e:
+
+                    print(e)
+                    # force load it
+                    print(f"force loading fuse module as it did not match")
+                    current_state_dict = self.fuse_module.state_dict()
+                    for k, v in fuse_weights.items():
+                        if len(v.shape) == 1:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0]]
+                        elif len(v.shape) == 2:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]]
+                        elif len(v.shape) == 3:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+                                                      :current_state_dict[k].shape[2]]
+                        elif len(v.shape) == 4:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+                                                      :current_state_dict[k].shape[2], :current_state_dict[k].shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {v.shape}")
+                    self.fuse_module.load_state_dict(current_state_dict, strict=strict)
+
+        if 'vision_encoder' in state_dict:
+            self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
 
         if 'fuse_module' in state_dict:
             self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
@@ -252,6 +295,12 @@ class CustomAdapter(torch.nn.Module):
             # todo save LoRA
             return state_dict
+
+        elif self.adapter_type == 'clip_fusion':
+            if self.config.train_image_encoder:
+                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+            state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
+            return state_dict
         else:
             raise NotImplementedError
 
@@ -260,7 +309,9 @@ class CustomAdapter(torch.nn.Module):
             prompt: Union[List[str], str],
             is_unconditional: bool = False,
     ):
-        if self.adapter_type == 'photo_maker':
+        if self.adapter_type == 'clip_fusion':
+            return prompt
+        elif self.adapter_type == 'photo_maker':
             if is_unconditional:
                 return prompt
             else:
@@ -310,7 +361,8 @@ class CustomAdapter(torch.nn.Module):
                     # add the first one to the front of the prompt
                    tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
                     our_class = self.config.class_names[0]
-                    prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
+                    prompt = " ".join(
+                        [self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
 
                     # add the prompt to the list
                     new_prompt_list.append(prompt)
@@ -329,8 +381,7 @@ class CustomAdapter(torch.nn.Module):
 
                     class_token = tokenized_prompt[flag_idx - 1]
 
-
-                    boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool)
+                    boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
                     boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
                     boolean_mask = boolean_mask.to(self.device)
                     # zero pad it to 77
@@ -357,62 +408,96 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             is_unconditional=False
     ) -> PromptEmbeds:
-        if self.adapter_type == 'photo_maker':
+        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
                 return prompt_embeds
-            with torch.no_grad():
-                # on training the clip image is created in the dataloader
-                if not has_been_preprocessed:
-                    # tensors should be 0-1
-                    if tensors_0_1.ndim == 3:
-                        tensors_0_1 = tensors_0_1.unsqueeze(0)
-                    # training tensors are 0 - 1
-                    tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
-                    # if images are out of this range throw error
-                    if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
-                        raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
-                            tensors_0_1.min(), tensors_0_1.max()
-                        ))
-                    clip_image = self.image_processor(
-                        images=tensors_0_1,
-                        return_tensors="pt",
-                        do_resize=True,
-                        do_rescale=False,
-                    ).pixel_values
-                else:
-                    clip_image = tensors_0_1
-                clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+            with torch.no_grad():
+                # on training the clip image is created in the dataloader
+                if not has_been_preprocessed:
+                    # tensors should be 0-1
+                    if tensors_0_1.ndim == 3:
+                        tensors_0_1 = tensors_0_1.unsqueeze(0)
+                    # training tensors are 0 - 1
+                    tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+                    # if images are out of this range throw error
+                    if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+                        raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                            tensors_0_1.min(), tensors_0_1.max()
+                        ))
+                    clip_image = self.image_processor(
+                        images=tensors_0_1,
+                        return_tensors="pt",
+                        do_resize=True,
+                        do_rescale=False,
+                    ).pixel_values
+                else:
+                    clip_image = tensors_0_1
+                clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+
+            if self.adapter_type == 'photo_maker':
+                # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
+                clip_image = clip_image.unsqueeze(1)
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            do_projection2=isinstance(self.sd_ref().text_encoder, list),
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
+                            ).detach()
+
+                prompt_embeds.text_embeds = self.fuse_module(
+                    prompt_embeds.text_embeds,
+                    id_embeds,
+                    self.token_mask
+                )
+                return prompt_embeds
+            elif self.adapter_type == 'clip_fusion':
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            output_hidden_states=True,
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, output_hidden_states=True
+                            )
+
+                img_embeds = id_embeds['last_hidden_state']
+
+                if not is_training or not self.config.train_image_encoder:
+                    img_embeds = img_embeds.detach()
+
+                prompt_embeds.text_embeds = self.clip_fusion_module(
+                    prompt_embeds.text_embeds,
+                    img_embeds
+                )
+                return prompt_embeds
 
-            # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
-            clip_image = clip_image.unsqueeze(1)
-
-            with torch.set_grad_enabled(is_training):
-                if is_training and self.config.train_image_encoder:
-                    self.vision_encoder.train()
-                    clip_image = clip_image.requires_grad_(True)
-                    id_embeds = self.vision_encoder(
-                        clip_image,
-                        do_projection2=isinstance(self.sd_ref().text_encoder, list),
-                    )
-                else:
-                    with torch.no_grad():
-                        self.vision_encoder.eval()
-                        id_embeds = self.vision_encoder(
-                            clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
-                        ).detach()
-
-            prompt_embeds.text_embeds = self.fuse_module(
-                prompt_embeds.text_embeds,
-                id_embeds,
-                self.token_mask
-            )
-
-            return prompt_embeds
+        else:
+            raise NotImplementedError
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.type == 'photo_maker':
             yield from self.fuse_module.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'clip_fusion':
+            yield from self.clip_fusion_module.parameters(recurse)
+            if self.config.train_image_encoder:
+                yield from self.vision_encoder.parameters(recurse)
+        else:
+            raise NotImplementedError
diff --git a/toolkit/models/clip_fusion.py b/toolkit/models/clip_fusion.py
new file mode 100644
index 00000000..3a9448b0
--- /dev/null
+++ b/toolkit/models/clip_fusion.py
@@ -0,0 +1,143 @@
+import torch
+import torch.nn as nn
+
+
+# Conv1d MLP
+# MLP that can alternately be used as a conv1d on dim 1
+class MLPC(nn.Module):
+    def __init__(
+            self,
+            in_dim,
+            out_dim,
+            hidden_dim,
+            do_conv=False,
+            use_residual=True
+    ):
+        super().__init__()
+        self.do_conv = do_conv
+        if use_residual:
+            assert in_dim == out_dim
+        # don't normalize if using conv
+        if not do_conv:
+            self.layernorm = nn.LayerNorm(in_dim)
+
+        if do_conv:
+            self.fc1 = nn.Conv1d(in_dim, hidden_dim, 1)
+            self.fc2 = nn.Conv1d(hidden_dim, out_dim, 1)
+        else:
+            self.fc1 = nn.Linear(in_dim, hidden_dim)
+            self.fc2 = nn.Linear(hidden_dim, out_dim)
+
+        self.use_residual = use_residual
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        residual = x
+        if not self.do_conv:
+            x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        if self.use_residual:
+            x = x + residual
+        return x
+
+
+class ZipperBlock(nn.Module):
+    def __init__(
+            self,
+            in_size,
+            in_tokens,
+            out_size,
+            out_tokens,
+            hidden_size,
+            hidden_tokens,
+    ):
+        super().__init__()
+        self.in_size = in_size
+        self.in_tokens = in_tokens
+        self.out_size = out_size
+        self.out_tokens = out_tokens
+        self.hidden_size = hidden_size
+        self.hidden_tokens = hidden_tokens
+        # permute to (batch_size, out_size, in_tokens)
+
+        self.zip_token = MLPC(
+            in_dim=self.in_tokens,
+            out_dim=self.out_tokens,
+            hidden_dim=self.hidden_tokens,
+            do_conv=True,  # no need to permute
+            use_residual=False
+        )
+
+        # permute to (batch_size, out_tokens, out_size)
+
+        # in shape: (batch_size, in_tokens, in_size)
+        self.zip_size = MLPC(
+            in_dim=self.in_size,
+            out_dim=self.out_size,
+            hidden_dim=self.hidden_size,
+            use_residual=False
+        )
+
+    def forward(self, x):
+        x = self.zip_token(x)
+        x = self.zip_size(x)
+        return x
+
+
+# CLIPFusionModule
+# Fuses any size of vision and text embeddings into a single embedding.
+# remaps tokens and vectors.
+class CLIPFusionModule(nn.Module):
+    def __init__(
+            self,
+            text_hidden_size: int = 768,
+            text_tokens: int = 77,
+            vision_hidden_size: int = 1024,
+            vision_tokens: int = 257,
+            num_blocks: int = 2,
+    ):
+        super(CLIPFusionModule, self).__init__()
+
+        self.text_hidden_size = text_hidden_size
+        self.text_tokens = text_tokens
+        self.vision_hidden_size = vision_hidden_size
+        self.vision_tokens = vision_tokens
+
+        self.resampler = ZipperBlock(
+            in_size=self.vision_hidden_size,
+            in_tokens=self.vision_tokens,
+            out_size=self.text_hidden_size,
+            out_tokens=self.text_tokens,
+            hidden_size=self.vision_hidden_size * 2,
+            hidden_tokens=self.vision_tokens * 2
+        )
+
+        self.zipper_blocks = torch.nn.ModuleList([
+            ZipperBlock(
+                in_size=self.text_hidden_size * 2,
+                in_tokens=self.text_tokens,
+                out_size=self.text_hidden_size,
+                out_tokens=self.text_tokens,
+                hidden_size=self.text_hidden_size * 2,
+                hidden_tokens=self.text_tokens * 2
+            ) for i in range(num_blocks)
+        ])
+
+    def forward(self, text_embeds, vision_embeds):
+        # text_embeds = (batch_size, 77, 768)
+        # vision_embeds = (batch_size, 257, 1024)
+        # output = (batch_size, 77, 768)
+
+        vision_embeds = self.resampler(vision_embeds)
+        x = vision_embeds
+        for i, block in enumerate(self.zipper_blocks):
+            res = x
+            x = torch.cat([text_embeds, x], dim=-1)
+            x = block(x)
+            x = x + res
+
+        x = text_embeds + x
+
+        return x
diff --git a/toolkit/photomaker.py b/toolkit/photomaker.py
index 9f8d69ef..80379695 100644
--- a/toolkit/photomaker.py
+++ b/toolkit/photomaker.py
@@ -119,11 +119,12 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
         super().__init__(config, *model_args, **model_kwargs)
         self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
 
-    def forward(self, id_pixel_values, do_projection2=True):
+    def forward(self, id_pixel_values, do_projection2=True, output_full=False):
         b, num_inputs, c, h, w = id_pixel_values.shape
         id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
-
-        shared_id_embeds = self.vision_model(id_pixel_values)[1]
+        # last_hidden_state, 1, 257, 1024
+        vision_output = self.vision_model(id_pixel_values, output_hidden_states=True)
+        shared_id_embeds = vision_output[1]
         id_embeds = self.visual_projection(shared_id_embeds)
 
         id_embeds = id_embeds.view(b, num_inputs, 1, -1)
@@ -133,6 +134,8 @@ class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
             id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
 
             id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
 
+        if output_full:
+            return id_embeds, vision_output
         return id_embeds
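
A minimal usage sketch of the new CLIPFusionModule follows (illustrative only, not part of the patch), assuming the module's default shapes from clip_fusion.py: 77 text tokens of width 768, and 257 vision tokens of width 1024 (a 16x16 patch grid plus a class token from a CLIP-style vision tower). In CustomAdapter the text side comes from prompt_embeds.text_embeds and the vision side from the vision encoder's last_hidden_state; random tensors stand in for both here.

import torch

from toolkit.models.clip_fusion import CLIPFusionModule

# Defaults mirror the new module: 77 x 768 text tokens, 257 x 1024 vision tokens.
fusion = CLIPFusionModule(
    text_hidden_size=768,
    text_tokens=77,
    vision_hidden_size=1024,
    vision_tokens=257,
)

# Stand-ins for prompt_embeds.text_embeds and the vision tower's last_hidden_state (batch of 2).
text_embeds = torch.randn(2, 77, 768)
vision_embeds = torch.randn(2, 257, 1024)

# The resampler ZipperBlock first remaps 257 vision tokens onto 77 tokens with a
# 1x1 Conv1d over the token dimension, then maps 1024 channels to 768 with an MLP
# over the channel dimension; the stacked ZipperBlocks then mix the concatenated
# text/vision features, and the result is added back onto the text embeddings.
fused = fusion(text_embeds, vision_embeds)
assert fused.shape == (2, 77, 768)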