mirror of https://github.com/ostris/ai-toolkit.git
Added initial direct vision pixtral support
@@ -17,6 +17,7 @@ from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
+from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
 
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
@@ -257,6 +258,13 @@ class CustomAdapter(torch.nn.Module):
             self.vision_encoder = SiglipVisionModel.from_pretrained(
                 adapter_config.image_encoder_path,
                 ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+        elif self.config.image_encoder_arch == 'pixtral':
+            self.image_processor = PixtralVisionImagePreprocessorCompatible(
+                max_image_size=self.config.pixtral_max_image_size,
+            )
+            self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
+                adapter_config.image_encoder_path,
+            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         elif self.config.image_encoder_arch == 'vit':
             try:
                 self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
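
Note: the new 'pixtral' branch mirrors the existing SigLIP path: it builds a Pixtral-compatible image preprocessor from the adapter config, then loads the encoder weights from image_encoder_path and casts them to the model dtype. A minimal standalone sketch of the same call pattern follows; the checkpoint path, max_image_size value, and device/dtype choices below are placeholders, not values from this commit.

import torch
from toolkit.models.pixtral_vision import (
    PixtralVisionEncoderCompatible,
    PixtralVisionImagePreprocessorCompatible,
)

# Placeholder values for illustration; the real ones come from the adapter config.
image_encoder_path = "/path/to/pixtral/vision/encoder"  # assumed checkpoint location
max_image_size = 512                                     # assumed pixtral_max_image_size

# Same construction order as the new branch: preprocessor first, then encoder.
device = "cuda" if torch.cuda.is_available() else "cpu"
image_processor = PixtralVisionImagePreprocessorCompatible(max_image_size=max_image_size)
vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(image_encoder_path)
vision_encoder = vision_encoder.to(device, dtype=torch.bfloat16)
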
@@ -700,9 +708,11 @@ class CustomAdapter(torch.nn.Module):
         else:
             return prompt_embeds
 
-    def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
+    def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor:
         with torch.no_grad():
-            tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
+            if shape is None:
+                shape = [batch_size, 3, self.input_size, self.input_size]
+            tensors_0_1 = torch.rand(shape, device=self.device)
             noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                      dtype=get_torch_dtype(self.sd_ref().dtype))
             tensors_0_1 = tensors_0_1 * noise_scale
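
Note: get_empty_clip_image previously always produced a square [batch_size, 3, input_size, input_size] placeholder. The new optional shape argument lets a caller request a placeholder matching an arbitrary input tensor, which matters for Pixtral-style encoders that are not tied to one fixed square resolution. A simplified sketch of the default-vs-override behaviour follows; class state and the noise scaling are omitted, and the 224 default is just an example value.

import torch

def get_empty_clip_image(batch_size: int, shape=None, input_size: int = 224) -> torch.Tensor:
    # Simplified stand-in for the method above: fall back to the old square
    # default when no shape is given, otherwise honour the requested shape.
    if shape is None:
        shape = [batch_size, 3, input_size, input_size]
    return torch.rand(shape)

assert get_empty_clip_image(2).shape == (2, 3, 224, 224)                            # old behaviour
assert get_empty_clip_image(2, shape=[2, 3, 512, 384]).shape == (2, 3, 512, 384)    # new override
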
@@ -761,7 +771,7 @@ class CustomAdapter(torch.nn.Module):
         batch_size = clip_image.shape[0]
         if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
             # add an unconditional so we can save it
-            unconditional = self.get_empty_clip_image(batch_size).to(
+            unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
                 clip_image.device, dtype=clip_image.dtype
             )
             clip_image = torch.cat([unconditional, clip_image], dim=0)
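
Note: at this call site the unconditional placeholder is concatenated with clip_image along the batch dimension, so once image sizes are no longer fixed the placeholder has to share clip_image's exact shape or the torch.cat would fail. A tiny illustration with a made-up non-square size:

import torch

clip_image = torch.rand([2, 3, 512, 384])       # hypothetical non-square Pixtral input
unconditional = torch.rand(clip_image.shape)    # placeholder now matches it exactly
batch = torch.cat([unconditional, clip_image], dim=0)
assert batch.shape == (4, 3, 512, 384)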