Added siglip support

Jaret Burkett
2024-01-09 20:52:21 -07:00
parent b767d29b3c
commit b2a54c8f36
3 changed files with 39 additions and 12 deletions

View File

@@ -410,10 +410,10 @@ class ImageProcessingDTOMixin:
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
         if self.dataset_config.buckets:
             # scale and crop based on file item
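The two-line pairs in this and the following hunks fix a silent no-op: PIL's Image.transpose() returns a new image and never modifies its receiver, so without the reassignment the flip augmentation was never applied. A minimal standalone illustration (Pillow only, not code from this repo):

    from PIL import Image

    img = Image.new('RGB', (4, 2))
    flipped = img.transpose(Image.FLIP_LEFT_RIGHT)  # returns a NEW Image
    assert flipped is not img  # the original object is left unflipped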
@@ -527,10 +527,10 @@ class ControlFileItemDTOMixin:
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
         if self.dataset_config.buckets:
             # scale and crop based on file item
@@ -638,6 +638,15 @@ class ClipImageFileItemDTOMixin:
             print(f"Error: {e}")
             print(f"Error loading image: {self.clip_image_path}")
         img = img.convert('RGB')
+        if self.flip_x:
+            # do a flip
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        if self.flip_y:
+            # do a flip
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
         if self.has_clip_augmentations:
             self.clip_image_tensor = self.augment_clip_image(img, transform=None)
         else:
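This new block applies the same flips to the clip (conditioning) image that are applied to the training image, so the pair stays spatially aligned. A minimal sketch of the invariant, with a hypothetical flip_pair helper:

    from PIL import Image

    def flip_pair(img, clip_img, flip_x=False, flip_y=False):
        # flip both images identically so conditioning stays pixel-aligned
        if flip_x:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            clip_img = clip_img.transpose(Image.FLIP_LEFT_RIGHT)
        if flip_y:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
            clip_img = clip_img.transpose(Image.FLIP_TOP_BOTTOM)
        return img, clip_img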
@@ -822,10 +831,10 @@ class MaskFileItemDTOMixin:
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
         # randomly apply a blur up to 0.5% of the size of the min (width, height)
         min_size = min(img.width, img.height)
@@ -906,10 +915,10 @@ class UnconditionalFileItemDTOMixin:
         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
         if self.dataset_config.buckets:
             # scale and crop based on file item

View File

@@ -174,6 +174,15 @@ class IPAdapter(torch.nn.Module):
             self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                 adapter_config.image_encoder_path,
                 ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'siglip':
+            from transformers import SiglipImageProcessor, SiglipVisionModel
+            try:
+                self.clip_image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+            except EnvironmentError:
+                self.clip_image_processor = SiglipImageProcessor()
+            self.image_encoder = SiglipVisionModel.from_pretrained(
+                adapter_config.image_encoder_path,
+                ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'vit':
             try:
                 self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
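The try/except falls back to a default-configured SiglipImageProcessor when the checkpoint ships no preprocessor config. For reference, a minimal standalone sketch of the same classes (transformers >= 4.37; the checkpoint id is only an example):

    import torch
    from PIL import Image
    from transformers import SiglipImageProcessor, SiglipVisionModel

    model_id = "google/siglip-base-patch16-224"  # example checkpoint
    processor = SiglipImageProcessor.from_pretrained(model_id)
    encoder = SiglipVisionModel.from_pretrained(model_id)

    inputs = processor(images=Image.new('RGB', (512, 512)), return_tensors="pt")
    with torch.no_grad():
        out = encoder(**inputs)
    print(out.last_hidden_state.shape)  # (1, 196, 768): 14x14 patches, no CLS token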
@@ -241,8 +250,10 @@ class IPAdapter(torch.nn.Module):
                 input_size=preprocessor_input_size,
                 clip_input_size=self.image_encoder.config.image_size,
             )
-        self.input_size = self.clip_image_processor.size['shortest_edge']
+        if 'height' in self.clip_image_processor.size:
+            self.input_size = self.clip_image_processor.size['height']
+        else:
+            self.input_size = self.clip_image_processor.crop_size['height']
         self.current_scale = 1.0
         self.is_active = True
         if adapter_config.type == 'ip':
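The branch replaces the hard-coded CLIP assumption because the two processor families describe their target resolution differently: CLIPImageProcessor defaults to size={'shortest_edge': 224} with crop_size={'height': 224, 'width': 224}, while SiglipImageProcessor uses size={'height': 224, 'width': 224} and performs no crop. The same logic as a standalone sketch (hypothetical resolve_input_size name):

    from transformers import CLIPImageProcessor, SiglipImageProcessor

    def resolve_input_size(proc) -> int:
        if 'height' in proc.size:
            return proc.size['height']  # siglip-style size dict
        return proc.crop_size['height']  # clip-style: final size comes from the center crop

    assert resolve_input_size(SiglipImageProcessor()) == 224
    assert resolve_input_size(CLIPImageProcessor()) == 224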
@@ -258,14 +269,22 @@ class IPAdapter(torch.nn.Module):
             embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
                 self.image_encoder.config.hidden_sizes[-1]
             image_encoder_state_dict = self.image_encoder.state_dict()
+            # max_seq_len = CLIP tokens + CLS token
+            max_seq_len = 257
+            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+                # clip
+                max_seq_len = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
             # ip-adapter-plus
             image_proj_model = Resampler(
                 dim=dim,
                 depth=4,
                 dim_head=64,
                 heads=heads,
-                num_queries=self.config.num_tokens,  # usually 16
+                num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len,
                 embedding_dim=embedding_dim,
                 max_seq_len=max_seq_len,
                 output_dim=sd.unet.config['cross_attention_dim'],
                 ff_mult=4
             )
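Two related changes here: max_seq_len is now read from the encoder's position-embedding table rather than assumed, and num_tokens <= 0 now means one Resampler query per encoder token. Since the table has one row per output token, the numbers can be checked by hand:

    # CLIP ViT-L/14 @ 224px: (224 // 14) ** 2 patches + 1 CLS token = 257 rows
    # SigLIP B/16  @ 224px:  (224 // 16) ** 2 patches, no CLS token = 196 rows
    state_dict = image_encoder.state_dict()  # image_encoder as constructed above
    key = "vision_model.embeddings.position_embedding.weight"
    max_seq_len = int(state_dict[key].shape[0]) if key in state_dict else 257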

View File

@@ -256,7 +256,6 @@ def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
         if key.startswith('conditioner.embedders.1'):
             has_dual_text_encoders = True
             break
     # map through the keys and values
     for key, value in model_keymap.items():
         # ignore bias weights
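For context on the scan above: SDXL-format checkpoints store their second text encoder under the 'conditioner.embedders.1' prefix, so a prefix test over the keymap is enough to detect the dual-encoder layout. The same check written as a one-liner:

    has_dual_text_encoders = any(
        key.startswith('conditioner.embedders.1') for key in model_keymap
    )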