mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Created a size-agnostic feature encoder (SAFE) model to be trained in place of CLIP for IP adapters. It is mostly conv layers, so it will hopefully be able to handle facial features better than CLIP can. Also bug fixes.
This commit is contained in:
@@ -32,6 +32,8 @@ from transformers import (
|
||||
ConvNextForImageClassification,
|
||||
ConvNextImageProcessor
|
||||
)
|
||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||
|
||||
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
|
||||
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
@@ -175,6 +177,19 @@ class IPAdapter(torch.nn.Module):
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = ViTFeatureExtractor()
|
||||
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'safe':
|
||||
try:
|
||||
self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = SAFEImageProcessor()
|
||||
self.image_encoder = SAFEVisionModel(
|
||||
in_channels=3,
|
||||
num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
|
||||
num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
|
||||
reducer_channels=self.config.safe_reducer_channels,
|
||||
channels=self.config.safe_channels,
|
||||
downscale_factor=8
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'convnext':
|
||||
try:
|
||||
self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
|
||||
Reference in New Issue
Block a user