mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Prepare for upcoming breaking changes in newer versions of the transformers library
This commit is contained in:
@@ -19,8 +19,7 @@ from safetensors.torch import save_file, load_file
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
import torch.backends.cuda
|
||||
from huggingface_hub import HfApi, Repository, interpreter_login
|
||||
from huggingface_hub.utils import HfFolder
|
||||
from huggingface_hub import HfApi, interpreter_login
|
||||
from toolkit.memory_management import MemoryManager
|
||||
|
||||
from toolkit.basic import value_map
|
||||
|
||||
@@ -36,21 +36,12 @@ if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
CLIPVisionModel,
|
||||
AutoImageProcessor,
|
||||
ConvNextModel,
|
||||
ConvNextForImageClassification,
|
||||
ConvNextImageProcessor,
|
||||
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer, BitsAndBytesConfig
|
||||
)
|
||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||
|
||||
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
|
||||
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
|
||||
from toolkit.models.llm_adapter import LLMAdapter
|
||||
|
||||
import torch.nn.functional as F
|
||||
@@ -372,13 +363,6 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit':
|
||||
try:
|
||||
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.image_processor = ViTFeatureExtractor()
|
||||
self.vision_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'safe':
|
||||
try:
|
||||
self.image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
@@ -406,20 +390,6 @@ class CustomAdapter(torch.nn.Module):
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit-hybrid':
|
||||
try:
|
||||
self.image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
print(f"could not load image processor from {adapter_config.image_encoder_path}")
|
||||
self.image_processor = ViTHybridImageProcessor(
|
||||
size=320,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
self.vision_encoder = ViTHybridForImageClassification.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
else:
|
||||
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
|
||||
|
||||
|
||||
@@ -37,10 +37,6 @@ from transformers import (
|
||||
)
|
||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||
|
||||
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
|
||||
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@@ -404,13 +400,6 @@ class IPAdapter(torch.nn.Module):
|
||||
self.image_encoder = SiglipVisionModel.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit':
|
||||
try:
|
||||
self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = ViTFeatureExtractor()
|
||||
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'safe':
|
||||
try:
|
||||
self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
@@ -452,20 +441,6 @@ class IPAdapter(torch.nn.Module):
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit-hybrid':
|
||||
try:
|
||||
self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
print(f"could not load image processor from {adapter_config.image_encoder_path}")
|
||||
self.clip_image_processor = ViTHybridImageProcessor(
|
||||
size=320,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
self.image_encoder = ViTHybridForImageClassification.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
else:
|
||||
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
|
||||
|
||||
|
||||
@@ -30,9 +30,6 @@ from transformers import (
|
||||
)
|
||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||
|
||||
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
|
||||
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
Reference in New Issue
Block a user