Created a size-agnostic feature encoder (SAFE) model to be trained in place of CLIP for IP adapters. It is mostly conv layers, so it will hopefully be able to handle facial features better than CLIP can. Also includes bug fixes.

This commit is contained in:
Jaret Burkett
2023-12-28 12:20:27 -07:00
parent d11ed7f66c
commit eeee4a1620
5 changed files with 286 additions and 6 deletions

View File

@@ -32,6 +32,8 @@ from transformers import (
ConvNextForImageClassification,
ConvNextImageProcessor
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
@@ -175,6 +177,19 @@ class IPAdapter(torch.nn.Module):
except EnvironmentError:
self.clip_image_processor = ViTFeatureExtractor()
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'safe':
try:
self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = SAFEImageProcessor()
self.image_encoder = SAFEVisionModel(
in_channels=3,
num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
reducer_channels=self.config.safe_reducer_channels,
channels=self.config.safe_channels,
downscale_factor=8
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'convnext':
try:
self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)