From b2a54c8f3609b46c53dd909cdbb2c79f02e2ef17 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 9 Jan 2024 20:52:21 -0700
Subject: [PATCH] Added siglip support

---
 toolkit/dataloader_mixins.py | 25 +++++++++++++++++--------
 toolkit/ip_adapter.py        | 25 ++++++++++++++++++++++---
 toolkit/saving.py            |  1 -
 3 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 3fcdc85..0197ed3 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -410,10 +410,10 @@ class ImageProcessingDTOMixin:
 
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
 
         if self.dataset_config.buckets:
             # scale and crop based on file item
@@ -527,10 +527,10 @@ class ControlFileItemDTOMixin:
 
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
 
         if self.dataset_config.buckets:
             # scale and crop based on file item
@@ -638,6 +638,15 @@ class ClipImageFileItemDTOMixin:
             print(f"Error: {e}")
             print(f"Error loading image: {self.clip_image_path}")
 
+        img = img.convert('RGB')
+
+        if self.flip_x:
+            # do a flip
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        if self.flip_y:
+            # do a flip
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
         if self.has_clip_augmentations:
             self.clip_image_tensor = self.augment_clip_image(img, transform=None)
         else:
@@ -822,10 +831,10 @@ class MaskFileItemDTOMixin:
 
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
 
         # randomly apply a blur up to 0.5% of the size of the min (width, height)
         min_size = min(img.width, img.height)
@@ -906,10 +915,10 @@ class UnconditionalFileItemDTOMixin:
 
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
 
         if self.dataset_config.buckets:
             # scale and crop based on file item
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index abae9d7..9ebecfb 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -174,6 +174,15 @@ class IPAdapter(torch.nn.Module):
             self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                 adapter_config.image_encoder_path,
                 ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'siglip':
+            from transformers import SiglipImageProcessor, SiglipVisionModel
+            try:
+                self.clip_image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+            except EnvironmentError:
+                self.clip_image_processor = SiglipImageProcessor()
+            self.image_encoder = SiglipVisionModel.from_pretrained(
+                adapter_config.image_encoder_path,
+                ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'vit':
             try:
                 self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
@@ -241,8 +250,10 @@ class IPAdapter(torch.nn.Module):
                 input_size=preprocessor_input_size,
                 clip_input_size=self.image_encoder.config.image_size,
             )
-
-        self.input_size = self.clip_image_processor.size['shortest_edge']
+        if 'height' in self.clip_image_processor.size:
+            self.input_size = self.clip_image_processor.size['height']
+        else:
+            self.input_size = self.clip_image_processor.crop_size['height']
         self.current_scale = 1.0
         self.is_active = True
         if adapter_config.type == 'ip':
@@ -258,14 +269,22 @@ class IPAdapter(torch.nn.Module):
             embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
                 self.image_encoder.config.hidden_sizes[-1]
 
+            image_encoder_state_dict = self.image_encoder.state_dict()
+            # max_seq_len = CLIP tokens + CLS token
+            max_seq_len = 257
+            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+                # clip
+                max_seq_len = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
             # ip-adapter-plus
             image_proj_model = Resampler(
                 dim=dim,
                 depth=4,
                 dim_head=64,
                 heads=heads,
-                num_queries=self.config.num_tokens,  # usually 16
+                num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len,
                 embedding_dim=embedding_dim,
+                max_seq_len=max_seq_len,
                 output_dim=sd.unet.config['cross_attention_dim'],
                 ff_mult=4
             )
diff --git a/toolkit/saving.py b/toolkit/saving.py
index 3a5789f..e3c63b2 100644
--- a/toolkit/saving.py
+++ b/toolkit/saving.py
@@ -256,7 +256,6 @@ def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDi
         if key.startswith('conditioner.embedders.1'):
             has_dual_text_encoders = True
             break
-
     # map through the keys and values
     for key, value in model_keymap.items():
         # ignore bias weights
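
Reviewer note (not part of the patch): the diff relies on two behaviors of the transformers image processors and vision models. First, SigLIP processors expose size as {'height', 'width'} while CLIP processors use {'shortest_edge'}, which is why input_size now prefers size['height'] and falls back to crop_size['height']. Second, max_seq_len is read from the rows of the vision tower's position embedding, so it tracks whichever encoder is loaded instead of being hard-coded to CLIP's 257 tokens. The sketch below is a minimal illustration of both points, assuming a transformers version with SigLIP support (4.37+); the checkpoint name is only an illustrative example, not one referenced by the toolkit.

    from transformers import CLIPImageProcessor, SiglipImageProcessor, SiglipVisionModel

    def resolve_input_size(processor) -> int:
        # mirrors the patched lookup: prefer size['height'], fall back to crop_size['height']
        if 'height' in processor.size:
            return processor.size['height']
        return processor.crop_size['height']

    clip_processor = CLIPImageProcessor()      # default size={'shortest_edge': 224}, crop_size={'height': 224, 'width': 224}
    siglip_processor = SiglipImageProcessor()  # default size={'height': 224, 'width': 224}, no crop_size lookup needed

    print(resolve_input_size(clip_processor))    # 224, taken from crop_size
    print(resolve_input_size(siglip_processor))  # 224, taken from size

    # max_seq_len comes straight from the position-embedding table of the vision tower.
    # CLIP ViT-L/14 at 224px has 257 rows (16x16 patches + CLS); SigLIP base at 224px has 196 (no CLS).
    encoder = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")  # illustrative checkpoint
    state_dict = encoder.state_dict()
    key = "vision_model.embeddings.position_embedding.weight"
    max_seq_len = int(state_dict[key].shape[0])
    print(max_seq_len)  # 196 for this checkpoint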