Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-02-06 05:29:57 +00:00
Added siglip support
@@ -410,10 +410,10 @@ class ImageProcessingDTOMixin:

         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)

         if self.dataset_config.buckets:
             # scale and crop based on file item
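
The reassignment above matters because PIL's Image.transpose() is not an in-place operation: it returns a new image and leaves the original untouched, so the old call silently discarded the flipped result. A minimal sketch of the behavior (the file path is illustrative only):

    from PIL import Image

    img = Image.open("example.jpg")  # hypothetical path, for illustration
    flipped = img.transpose(Image.FLIP_LEFT_RIGHT)
    assert flipped is not img        # the original image object is unchanged
    # the result has to be assigned back, which is what the fix above does
    img = img.transpose(Image.FLIP_LEFT_RIGHT)

The same reassignment fix is applied to every flip site below (ControlFileItemDTOMixin, MaskFileItemDTOMixin, UnconditionalFileItemDTOMixin).
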
@@ -527,10 +527,10 @@ class ControlFileItemDTOMixin:

         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)

         if self.dataset_config.buckets:
             # scale and crop based on file item
@@ -638,6 +638,15 @@ class ClipImageFileItemDTOMixin:
             print(f"Error: {e}")
             print(f"Error loading image: {self.clip_image_path}")

+        img = img.convert('RGB')
+
+        if self.flip_x:
+            # do a flip
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        if self.flip_y:
+            # do a flip
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
         if self.has_clip_augmentations:
             self.clip_image_tensor = self.augment_clip_image(img, transform=None)
         else:
@@ -822,10 +831,10 @@ class MaskFileItemDTOMixin:

         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)

         # randomly apply a blur up to 0.5% of the size of the min (width, height)
         min_size = min(img.width, img.height)
@@ -906,10 +915,10 @@ class UnconditionalFileItemDTOMixin:

         if self.flip_x:
             # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
         if self.flip_y:
             # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)

         if self.dataset_config.buckets:
             # scale and crop based on file item

@@ -174,6 +174,15 @@ class IPAdapter(torch.nn.Module):
             self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                 adapter_config.image_encoder_path,
                 ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'siglip':
+            from transformers import SiglipImageProcessor, SiglipVisionModel
+            try:
+                self.clip_image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+            except EnvironmentError:
+                self.clip_image_processor = SiglipImageProcessor()
+            self.image_encoder = SiglipVisionModel.from_pretrained(
+                adapter_config.image_encoder_path,
+                ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'vit':
             try:
                 self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
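
For reference, a minimal sketch of the new 'siglip' branch above, assuming the google/siglip-base-patch16-224 checkpoint stands in for adapter_config.image_encoder_path (any SigLIP vision checkpoint on the Hugging Face Hub should load the same way):

    import torch
    from PIL import Image
    from transformers import SiglipImageProcessor, SiglipVisionModel

    path = "google/siglip-base-patch16-224"  # stand-in for adapter_config.image_encoder_path
    try:
        processor = SiglipImageProcessor.from_pretrained(path)
    except EnvironmentError:
        # fall back to default preprocessing when the checkpoint ships no processor config
        processor = SiglipImageProcessor()
    encoder = SiglipVisionModel.from_pretrained(path)

    image = Image.new("RGB", (512, 512))     # dummy image, illustration only
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():
        tokens = encoder(pixel_values).last_hidden_state
    print(tokens.shape)  # (1, 196, 768) for this checkpoint: 14x14 patch tokens, no CLS token

Unlike CLIP, SigLIP emits patch tokens only (no CLS token), which is why the sequence-length handling further down reads the token count from the encoder instead of hard-coding it.
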
@@ -241,8 +250,10 @@ class IPAdapter(torch.nn.Module):
                 input_size=preprocessor_input_size,
                 clip_input_size=self.image_encoder.config.image_size,
             )
-
-        self.input_size = self.clip_image_processor.size['shortest_edge']
+        if 'height' in self.clip_image_processor.size:
+            self.input_size = self.clip_image_processor.size['height']
+        else:
+            self.input_size = self.clip_image_processor.crop_size['height']
         self.current_scale = 1.0
         self.is_active = True
         if adapter_config.type == 'ip':
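
The input_size change above accounts for the two processors' different size conventions: CLIPImageProcessor reports size={'shortest_edge': ...} plus a separate crop_size, while SiglipImageProcessor reports size={'height': ..., 'width': ...} and does no cropping, so the old size['shortest_edge'] lookup would fail with a KeyError for SigLIP. A small sketch of the same resolution logic against the default processors:

    from transformers import CLIPImageProcessor, SiglipImageProcessor

    def resolve_input_size(processor) -> int:
        # mirrors the branch added above: prefer size['height'], fall back to crop_size['height']
        if "height" in processor.size:
            return processor.size["height"]
        return processor.crop_size["height"]

    clip_proc = CLIPImageProcessor()      # defaults: size={'shortest_edge': 224}, crop_size={'height': 224, 'width': 224}
    siglip_proc = SiglipImageProcessor()  # defaults: size={'height': 224, 'width': 224}, no crop step

    print(resolve_input_size(clip_proc), resolve_input_size(siglip_proc))  # 224 224
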
@@ -258,14 +269,22 @@ class IPAdapter(torch.nn.Module):
             embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
                 self.image_encoder.config.hidden_sizes[-1]

+            image_encoder_state_dict = self.image_encoder.state_dict()
+            # max_seq_len = CLIP tokens + CLS token
+            max_seq_len = 257
+            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+                # clip
+                max_seq_len = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
             # ip-adapter-plus
             image_proj_model = Resampler(
                 dim=dim,
                 depth=4,
                 dim_head=64,
                 heads=heads,
-                num_queries=self.config.num_tokens,  # usually 16
+                num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len,
                 embedding_dim=embedding_dim,
+                max_seq_len=max_seq_len,
                 output_dim=sd.unet.config['cross_attention_dim'],
                 ff_mult=4
             )
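
The max_seq_len handling above sizes the Resampler from the vision tower itself: the number of image tokens equals the number of rows in the position-embedding table, which is patches + 1 CLS token for CLIP (e.g. 257 for a ViT-H/14 encoder at 224x224, the hard-coded fallback) but patches only for SigLIP, which has no CLS token. A hedged sketch of the same lookup, assuming a SigLIP checkpoint:

    from transformers import SiglipVisionModel

    encoder = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
    state_dict = encoder.state_dict()

    max_seq_len = 257  # fallback: 256 patches + 1 CLS token for CLIP ViT-H/14
    if "vision_model.embeddings.position_embedding.weight" in state_dict:
        max_seq_len = int(state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

    print(max_seq_len)  # 196 for patch-16 at 224x224: (224 // 16) ** 2, no CLS token

With num_queries now defaulting to max_seq_len when config.num_tokens is 0, the Resampler can keep one query per image token instead of compressing to the usual 16.
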
@@ -256,7 +256,6 @@ def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDi
         if key.startswith('conditioner.embedders.1'):
             has_dual_text_encoders = True
             break
-
     # map through the keys and values
     for key, value in model_keymap.items():
         # ignore bias weights