Mirror of https://github.com/ostris/ai-toolkit.git
Work on additional image embedding methods. Finalized zipper resampler. It works amazingly well.
@@ -9,6 +9,7 @@ from torch.nn.modules.module import T
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
+from toolkit.models.zipper_resampler import ZipperResampler
 from toolkit.paths import REPOS_ROOT
 from toolkit.saving import load_ip_adapter_model
 from toolkit.train_tools import get_torch_dtype
@@ -33,6 +34,7 @@ from transformers import (
     CLIPVisionModel,
     AutoImageProcessor,
     ConvNextModel,
+    ConvNextV2ForImageClassification,
     ConvNextForImageClassification,
     ConvNextImageProcessor
 )
@@ -226,6 +228,20 @@ class IPAdapter(torch.nn.Module):
                 adapter_config.image_encoder_path,
                 use_safetensors=True,
             ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'convnextv2':
+            try:
+                self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+            except EnvironmentError:
+                print(f"could not load image processor from {adapter_config.image_encoder_path}")
+                self.clip_image_processor = ConvNextImageProcessor(
+                    size=512,
+                    image_mean=[0.485,0.456,0.406],
+                    image_std=[0.229, 0.224, 0.225],
+                )
+            self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
+                adapter_config.image_encoder_path,
+                use_safetensors=True,
+            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'vit-hybrid':
             try:
                 self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
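Note on the new 'convnextv2' branch above: the encoder is loaded as a ConvNeXt V2 classification model, but it is the intermediate feature maps, not the classification logits, that the adapter consumes downstream. Below is a minimal sketch of that flow; the checkpoint name and the choice of the last hidden state are illustrative assumptions, not taken from this commit.

# Sketch only (not part of this commit). Assumes the public checkpoint
# "facebook/convnextv2-tiny-1k-224"; ai-toolkit's config would supply its own path.
import torch
from PIL import Image
from transformers import AutoImageProcessor, ConvNextV2ForImageClassification

processor = AutoImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224")
encoder = ConvNextV2ForImageClassification.from_pretrained(
    "facebook/convnextv2-tiny-1k-224", use_safetensors=True
)

image = Image.open("example.jpg").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    out = encoder(pixel_values=pixel_values, output_hidden_states=True)

# ConvNeXt hidden states are spatial maps (batch, channels, h, w); flattening
# the grid yields the token sequence a resampler can ingest.
feature_map = out.hidden_states[-1]               # e.g. (1, 768, 7, 7) for a 224px input
tokens = feature_map.flatten(2).transpose(1, 2)   # (1, h*w, channels)
print(tokens.shape)

With the fallback ConvNextImageProcessor(size=512, ...) used above, the stride-32 backbone produces a 16x16 grid, which matches the 16 * 16 token count the 'ipz' branch later assumes for convnext encoders.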
@@ -275,8 +291,12 @@ class IPAdapter(torch.nn.Module):
             )
         if 'height' in self.clip_image_processor.size:
             self.input_size = self.clip_image_processor.size['height']
-        else:
+        elif hasattr(self.clip_image_processor, 'crop_size'):
             self.input_size = self.clip_image_processor.crop_size['height']
+        elif 'shortest_edge' in self.clip_image_processor.size.keys():
+            self.input_size = self.clip_image_processor.size['shortest_edge']
+        else:
+            raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
         self.current_scale = 1.0
         self.is_active = True
         if adapter_config.type == 'ip':
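For reference, the branches above cover the different size conventions that transformers image processors expose. A quick illustrative check follows; the printed values are typical defaults and may vary by checkpoint and library version.

# Illustration only: which branch each processor family would hit.
from transformers import CLIPImageProcessor, ViTImageProcessor, ConvNextImageProcessor

clip_proc = CLIPImageProcessor()
print(clip_proc.size)       # e.g. {'shortest_edge': 224} -> no 'height' key
print(clip_proc.crop_size)  # e.g. {'height': 224, 'width': 224} -> crop_size branch

vit_proc = ViTImageProcessor()
print(vit_proc.size)        # e.g. {'height': 224, 'width': 224} -> 'height' branch

convnext_proc = ConvNextImageProcessor()
print(convnext_proc.size)   # e.g. {'shortest_edge': ...} -> 'shortest_edge' branch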
@@ -311,6 +331,39 @@ class IPAdapter(torch.nn.Module):
                 output_dim=sd.unet.config['cross_attention_dim'],
                 ff_mult=4
             )
+        elif adapter_config.type == 'ipz':
+            dim = sd.unet.config['cross_attention_dim']
+            if hasattr(self.image_encoder.config, 'hidden_sizes'):
+                embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+            else:
+                embedding_dim = self.image_encoder.config.hidden_size
+
+            image_encoder_state_dict = self.image_encoder.state_dict()
+            # max_seq_len = CLIP tokens + CLS token
+            in_tokens = 257
+            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+                # clip
+                in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+            if self.config.image_encoder_arch.startswith('convnext'):
+                in_tokens = 16 * 16
+                embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+
+            is_conv_next = self.config.image_encoder_arch.startswith('convnext')
+
+            out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens
+            # ip-adapter-plus
+            image_proj_model = ZipperResampler(
+                in_size=embedding_dim,
+                in_tokens=in_tokens,
+                out_size=dim,
+                out_tokens=out_tokens,
+                hidden_size=embedding_dim,
+                hidden_tokens=in_tokens,
+                # num_blocks=1 if not is_conv_next else 2,
+                num_blocks=1 if not is_conv_next else 2,
+                is_conv_input=is_conv_next
+            )
         elif adapter_config.type == 'ilora':
             # we apply the clip encodings to the LoRA
             image_proj_model = None
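On the new 'ipz' branch: in_tokens is the image-feature sequence length fed to the ZipperResampler. For a CLIP vision tower it is read from the position-embedding table (patch tokens plus the CLS token, hence the 257 default for a 224px ViT-H/14 encoder); for ConvNeXt encoders it is the flattened 16x16 grid of the deepest feature map. A small worked sketch follows, using only the constructor arguments visible in this diff; the concrete dimensions (1536, 2048) are assumptions for a ConvNeXt V2-large encoder driving an SDXL-sized cross-attention dim, not values taken from the commit.

# Worked example (illustration only) of the token-count arithmetic above.
clip_in_tokens = (224 // 14) ** 2 + 1   # 16*16 patch tokens + 1 CLS token = 257
convnext_in_tokens = (512 // 32) ** 2   # stride-32 ConvNeXt on a 512px input -> 16 * 16 = 256

# Projector wiring as in the diff; the dimensions below are assumptions, not from this commit.
from toolkit.models.zipper_resampler import ZipperResampler  # module added by this commit

image_proj_model = ZipperResampler(
    in_size=1536,                    # assumed encoder hidden_sizes[-1] (ConvNeXt V2-large)
    in_tokens=convnext_in_tokens,
    out_size=2048,                   # assumed unet cross_attention_dim (SDXL)
    out_tokens=convnext_in_tokens,   # num_tokens <= 0 keeps the input token count
    hidden_size=1536,
    hidden_tokens=convnext_in_tokens,
    num_blocks=2,                    # the diff uses 2 blocks for conv inputs, else 1
    is_conv_input=True,
)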