Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Handle random resizing for pixtral input on direct vision adapter
@@ -208,6 +208,7 @@ class AdapterConfig:
         self.ilora_up: bool = kwargs.get('ilora_up', True)
 
         self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
+        self.pixtral_random_image_size: bool = kwargs.get('pixtral_random_image_size', False)
 
         self.flux_only_double: bool = kwargs.get('flux_only_double', False)
 
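These options arrive through AdapterConfig's kwargs, so enabling the new behavior is a matter of passing them in. A minimal sketch, assuming the kwargs are forwarded to AdapterConfig unchanged (the surrounding config plumbing in ai-toolkit is not shown in this diff):

adapter_config = AdapterConfig(
    image_encoder_arch='pixtral',        # the arch check in the adapter below relies on this field
    pixtral_max_image_size=512,          # upper bound for the sampled size
    pixtral_random_image_size=True,      # sample a fresh size per batch
)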
@@ -1,3 +1,4 @@
+import math
 import torch
 import sys
 
@@ -18,6 +19,7 @@ from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEn
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
 from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
+import random
 
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
@@ -767,6 +769,32 @@ class CustomAdapter(torch.nn.Module):
                 ).pixel_values
             else:
                 clip_image = tensors_0_1
 
+            # if is pixtral, optionally resize to a randomly sampled size
+            if self.config.image_encoder_arch == 'pixtral' and self.config.pixtral_random_image_size:
+                # get the random target size
+                random_size = random.randint(256, self.config.pixtral_max_image_size)
+                # images are already sized for the max size; fit them to the pixtral
+                # patch size, which may shrink them or slightly enlarge them
+                h, w = clip_image.shape[2], clip_image.shape[3]
+                current_base_size = int(math.sqrt(w * h))
+                ratio = current_base_size / random_size
+                if ratio > 1:
+                    w = round(w / ratio)
+                    h = round(h / ratio)
 
+                width_tokens = (w - 1) // self.image_processor.image_patch_size + 1
+                height_tokens = (h - 1) // self.image_processor.image_patch_size + 1
+                assert width_tokens > 0
+                assert height_tokens > 0
 
+                new_image_size = (
+                    height_tokens * self.image_processor.image_patch_size,
+                    width_tokens * self.image_processor.image_patch_size,
+                )
 
+                # resize the image; F.interpolate expects size as (H, W)
+                clip_image = F.interpolate(clip_image, size=new_image_size, mode='bicubic', align_corners=False)
 
             batch_size = clip_image.shape[0]
             if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
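The block above boils down to three steps: sample a target base size in [256, pixtral_max_image_size], shrink the batch so sqrt(w*h) matches it, then snap both sides up to whole multiples of the vision patch size so the encoder sees an integral token grid. A self-contained sketch of the same math; the function name and the default patch size of 16 are illustrative assumptions (the commit reads the patch size from self.image_processor.image_patch_size):

import math
import random

import torch
import torch.nn.functional as F

def random_resize_to_patch_grid(images: torch.Tensor, max_size: int = 512, patch: int = 16) -> torch.Tensor:
    # images: (B, C, H, W), values in [0, 1]
    random_size = random.randint(256, max_size)
    h, w = images.shape[2], images.shape[3]
    ratio = int(math.sqrt(w * h)) / random_size
    if ratio > 1:  # only ever shrink toward the sampled size
        w, h = round(w / ratio), round(h / ratio)
    # ceil-divide each side by the patch size, then scale back up, so both
    # dimensions end up positive multiples of `patch`
    new_w = ((w - 1) // patch + 1) * patch
    new_h = ((h - 1) // patch + 1) * patch
    return F.interpolate(images, size=(new_h, new_w), mode='bicubic', align_corners=False)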
@@ -456,9 +456,12 @@ class PixtralVisionImagePreprocessor:
         self.max_image_size = max_image_size
         self.image_token = 10
 
-    def _image_to_num_tokens(self, img: torch.Tensor) -> Tuple[int, int]:
+    def _image_to_num_tokens(self, img: torch.Tensor, max_image_size=None) -> Tuple[int, int]:
         w: Union[int, float]
         h: Union[int, float]
 
+        if max_image_size is None:
+            max_image_size = self.max_image_size
+
         w, h = img.shape[-1], img.shape[-2]
 
@@ -467,7 +470,7 @@
         # ratio = max(h / self.max_image_size, w / self.max_image_size)  # original
 
         base_size = int(math.sqrt(w * h))
-        ratio = base_size / self.max_image_size
+        ratio = base_size / max_image_size
         if ratio > 1:
             w = round(w / ratio)
             h = round(h / ratio)
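A quick worked example of why the area-based rule replaces the commented-out max-side rule: for a 768×512 input with max_image_size=512, base_size = int(sqrt(768·512)) = 627, so ratio ≈ 1.22 and the image lands at roughly 627×418, whose geometric mean is back near 512. The max-side rule would have produced 512×341 instead, shrinking wide images more aggressively:

w, h = 768, 512
ratio_area = int(math.sqrt(w * h)) / 512  # ≈ 1.22 -> rounds to 627x418 (keeps area near 512²)
ratio_side = max(w / 512, h / 512)        # 1.50  -> 512x341 (caps the longest side)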
@@ -477,7 +480,7 @@
 
         return width_tokens, height_tokens
 
-    def __call__(self, image: torch.Tensor) -> torch.Tensor:
+    def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
         """
         Converts ImageChunks to numpy image arrays and image token ids
 
@@ -495,8 +498,11 @@
         if image.min() < 0.0 or image.max() > 1.0:
             raise ValueError(
                 f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
 
+        if max_image_size is None:
+            max_image_size = self.max_image_size
+
-        w, h = self._image_to_num_tokens(image)
+        w, h = self._image_to_num_tokens(image, max_image_size=max_image_size)
         assert w > 0
         assert h > 0
 
@@ -526,6 +532,7 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
             'height': max_image_size,
             'width': max_image_size
         }
+        self.max_image_size = max_image_size
         self.image_mean = DATASET_MEAN
         self.image_std = DATASET_STD
 
@@ -535,13 +542,16 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
             return_tensors="pt",
             do_resize=True,
             do_rescale=False,
+            max_image_size=None,
     ) -> torch.Tensor:
+        if max_image_size is None:
+            max_image_size = self.max_image_size
         out_stack = []
         if len(images.shape) == 3:
             images = images.unsqueeze(0)
         for i in range(images.shape[0]):
             image = images[i]
-            processed_image = super().__call__(image)
+            processed_image = super().__call__(image, max_image_size=max_image_size)
             out_stack.append(processed_image)
 
         output = torch.stack(out_stack, dim=0)
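End to end, the per-call max_image_size lets CustomAdapter feed its randomly sampled size through without mutating the preprocessor's configured default. A hedged usage sketch; the constructor argument is inferred from the __init__ fragment above:

preprocessor = PixtralVisionImagePreprocessorCompatible(max_image_size=512)
batch = torch.rand(4, 3, 512, 512)                     # values must sit in [0, 1]
default_out = preprocessor(batch)                      # falls back to self.max_image_size
smaller_out = preprocessor(batch, max_image_size=384)  # one-off override, default untouched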