Work on additional image embedding methods. Finalized the zipper resampler. It works amazingly well.

Jaret Burkett
2024-02-10 09:00:05 -07:00
parent a8481c1670
commit e074058faa
7 changed files with 261 additions and 47 deletions


@@ -9,6 +9,7 @@ from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.zipper_resampler import ZipperResampler
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
@@ -33,6 +34,7 @@ from transformers import (
CLIPVisionModel,
AutoImageProcessor,
ConvNextModel,
ConvNextV2ForImageClassification,
ConvNextForImageClassification,
ConvNextImageProcessor
)
@@ -226,6 +228,20 @@ class IPAdapter(torch.nn.Module):
adapter_config.image_encoder_path,
use_safetensors=True,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'convnextv2':
try:
self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
print(f"could not load image processor from {adapter_config.image_encoder_path}")
self.clip_image_processor = ConvNextImageProcessor(
size=512,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
)
self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
adapter_config.image_encoder_path,
use_safetensors=True,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'vit-hybrid':
try:
self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -275,8 +291,12 @@ class IPAdapter(torch.nn.Module):
)
if 'height' in self.clip_image_processor.size:
self.input_size = self.clip_image_processor.size['height']
else:
elif hasattr(self.clip_image_processor, 'crop_size'):
self.input_size = self.clip_image_processor.crop_size['height']
elif 'shortest_edge' in self.clip_image_processor.size.keys():
self.input_size = self.clip_image_processor.size['shortest_edge']
else:
raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
self.current_scale = 1.0
self.is_active = True
if adapter_config.type == 'ip':
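For context on the fallback order just above: recent transformers image processors report their resolution in different shapes (ViT-style processors use size={'height', 'width'}, CLIP processors use size={'shortest_edge'} plus a separate crop_size dict, and ConvNeXt processors only expose size={'shortest_edge'}). A minimal standalone sketch of the same lookup, with typical default values noted in comments; illustrative only, not part of this commit:

def resolve_input_size(processor) -> int:
    # ViT-style: size = {'height': 224, 'width': 224}
    if 'height' in processor.size:
        return processor.size['height']
    # CLIP-style: size = {'shortest_edge': 224}, crop_size = {'height': 224, 'width': 224}
    if hasattr(processor, 'crop_size'):
        return processor.crop_size['height']
    # ConvNeXt-style: size = {'shortest_edge': 384}
    if 'shortest_edge' in processor.size:
        return processor.size['shortest_edge']
    raise ValueError(f"unknown image processor size: {processor.size}")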
@@ -311,6 +331,39 @@ class IPAdapter(torch.nn.Module):
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
)
elif adapter_config.type == 'ipz':
dim = sd.unet.config['cross_attention_dim']
if hasattr(self.image_encoder.config, 'hidden_sizes'):
embedding_dim = self.image_encoder.config.hidden_sizes[-1]
else:
embedding_dim = self.image_encoder.config.hidden_size
image_encoder_state_dict = self.image_encoder.state_dict()
# max_seq_len = CLIP tokens + CLS token
in_tokens = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# clip
in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
if self.config.image_encoder_arch.startswith('convnext'):
in_tokens = 16 * 16
embedding_dim = self.image_encoder.config.hidden_sizes[-1]
is_conv_next = self.config.image_encoder_arch.startswith('convnext')
out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens
# ip-adapter-plus style projection, implemented with the zipper resampler
image_proj_model = ZipperResampler(
in_size=embedding_dim,
in_tokens=in_tokens,
out_size=dim,
out_tokens=out_tokens,
hidden_size=embedding_dim,
hidden_tokens=in_tokens,
num_blocks=1 if not is_conv_next else 2,
is_conv_input=is_conv_next
)
elif adapter_config.type == 'ilora':
# we apply the clip encodings to the LoRA
image_proj_model = None
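As a worked check of the token counts used in the 'ipz' branch above (illustrative arithmetic only, assuming a CLIP ViT with 14px patches at 224px input and the 512px ConvNeXt fallback processor configured earlier; not code from this commit):

# CLIP ViT (patch size 14, 224px input): (224 / 14)^2 = 256 patch tokens, plus 1 CLS token
clip_in_tokens = (224 // 14) ** 2 + 1       # 257, the position-embedding row count read above
# ConvNeXt / ConvNeXt V2 (512px input): a total stride of 32 leaves a 16x16 final feature map
convnext_in_tokens = (512 // 32) ** 2       # 256, i.e. the hard-coded 16 * 16

The ZipperResampler then maps those in_tokens of width embedding_dim down (or up) to out_tokens of width cross_attention_dim, with out_tokens falling back to in_tokens when num_tokens is not set in the adapter config.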