Handle random resizing for pixtral input on direct vision adapter

Author: Jaret Burkett
Date: 2024-09-28 14:53:38 -06:00
Parent: 69aa92bce5
Commit: e4c82803e1
3 changed files with 44 additions and 5 deletions


@@ -208,6 +208,7 @@ class AdapterConfig:
         self.ilora_up: bool = kwargs.get('ilora_up', True)
         self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
+        self.pixtral_random_image_size: bool = kwargs.get('pixtral_random_image_size', False)
         self.flux_only_double: bool = kwargs.get('flux_only_double', False)
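
For context, a minimal usage sketch of the new flag (a hypothetical construction: AdapterConfig reads every option via kwargs.get, but which other keys a real adapter config carries is not shown in this commit):

# hypothetical sketch: enable random pixtral input sizes on the vision adapter
adapter_config = AdapterConfig(
    image_encoder_arch='pixtral',      # assumption: read later as self.config.image_encoder_arch
    pixtral_max_image_size=512,        # upper bound on the vision input size
    pixtral_random_image_size=True,    # sample a size in [256, max] per call instead of always using the max
)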


@@ -1,3 +1,4 @@
+import math
 import torch
 import sys
@@ -18,6 +19,7 @@ from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEn
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
 from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
+import random
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
@@ -767,6 +769,32 @@ class CustomAdapter(torch.nn.Module):
             ).pixel_values
         else:
             clip_image = tensors_0_1
+        # if is pixtral
+        if self.config.image_encoder_arch == 'pixtral' and self.config.pixtral_random_image_size:
+            # get the random size
+            random_size = random.randint(256, self.config.pixtral_max_image_size)
+            # images are already sized for the max size; scale them down toward the random size, then snap to the pixtral patch size
+            h, w = clip_image.shape[2], clip_image.shape[3]
+            current_base_size = int(math.sqrt(w * h))
+            ratio = current_base_size / random_size
+            if ratio > 1:
+                w = round(w / ratio)
+                h = round(h / ratio)
+            width_tokens = (w - 1) // self.image_processor.image_patch_size + 1
+            height_tokens = (h - 1) // self.image_processor.image_patch_size + 1
+            assert width_tokens > 0
+            assert height_tokens > 0
+            # F.interpolate expects (height, width) for a (B, C, H, W) tensor
+            new_image_size = (
+                height_tokens * self.image_processor.image_patch_size,
+                width_tokens * self.image_processor.image_patch_size,
+            )
+            # resize the image
+            clip_image = F.interpolate(clip_image, size=new_image_size, mode='bicubic', align_corners=False)
         batch_size = clip_image.shape[0]
         if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
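
As a rough illustration of what the added block computes, here is a self-contained sketch of the same math (function and variable names are mine, not from the commit; a patch size of 16 is assumed, and the resize target is passed to F.interpolate in (height, width) order):

import math
import random

import torch
import torch.nn.functional as F


def random_resize_to_patch_grid(images: torch.Tensor, max_size: int = 512, patch: int = 16) -> torch.Tensor:
    """Shrink a (B, C, H, W) batch toward a random target size, snapped to the patch grid."""
    h, w = images.shape[2], images.shape[3]
    target = random.randint(256, max_size)
    ratio = int(math.sqrt(w * h)) / target      # geometric-mean size vs. the random target
    if ratio > 1:                               # only shrink; never upscale past the prepared size
        w, h = round(w / ratio), round(h / ratio)
    new_h = ((h - 1) // patch + 1) * patch      # round each side up to a whole number of patches
    new_w = ((w - 1) // patch + 1) * patch
    return F.interpolate(images, size=(new_h, new_w), mode='bicubic', align_corners=False)


# e.g. if the random target lands on 300, a 384x512 input comes out 272x352 (both multiples of 16)
out = random_resize_to_patch_grid(torch.rand(1, 3, 384, 512))
print(out.shape)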


@@ -456,9 +456,12 @@ class PixtralVisionImagePreprocessor:
         self.max_image_size = max_image_size
         self.image_token = 10
-    def _image_to_num_tokens(self, img: torch.Tensor) -> Tuple[int, int]:
+    def _image_to_num_tokens(self, img: torch.Tensor, max_image_size=None) -> Tuple[int, int]:
         w: Union[int, float]
         h: Union[int, float]
+        if max_image_size is None:
+            max_image_size = self.max_image_size
         w, h = img.shape[-1], img.shape[-2]
@@ -467,7 +470,7 @@ class PixtralVisionImagePreprocessor:
         # ratio = max(h / self.max_image_size, w / self.max_image_size) # original
         base_size = int(math.sqrt(w * h))
-        ratio = base_size / self.max_image_size
+        ratio = base_size / max_image_size
         if ratio > 1:
             w = round(w / ratio)
             h = round(h / ratio)
@@ -477,7 +480,7 @@ class PixtralVisionImagePreprocessor:
         return width_tokens, height_tokens
-    def __call__(self, image: torch.Tensor) -> torch.Tensor:
+    def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
         """
         Converts ImageChunks to numpy image arrays and image token ids
@@ -495,8 +498,11 @@ class PixtralVisionImagePreprocessor:
         if image.min() < 0.0 or image.max() > 1.0:
             raise ValueError(
                 f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
+        if max_image_size is None:
+            max_image_size = self.max_image_size
-        w, h = self._image_to_num_tokens(image)
+        w, h = self._image_to_num_tokens(image, max_image_size=max_image_size)
         assert w > 0
         assert h > 0
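
To make the ratio and token arithmetic above concrete, a small worked example (numbers chosen for illustration; an image_patch_size of 16 is assumed):

import math

# assume a 1024x768 (W x H) image with max_image_size=512 and a 16-pixel patch
w, h = 1024, 768
base_size = int(math.sqrt(w * h))           # int(886.8...) -> 886
ratio = base_size / 512                     # ~1.73, so the image must shrink
w, h = round(w / ratio), round(h / ratio)   # -> 592, 444
width_tokens = (w - 1) // 16 + 1            # ceil(592 / 16) = 37
height_tokens = (h - 1) // 16 + 1           # ceil(444 / 16) = 28
# the encoder input would then be 37 x 28 patches, i.e. 592 x 448 pixels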
@@ -526,6 +532,7 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
             'height': max_image_size,
             'width': max_image_size
         }
+        self.max_image_size = max_image_size
         self.image_mean = DATASET_MEAN
         self.image_std = DATASET_STD
@@ -535,13 +542,16 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
return_tensors="pt",
do_resize=True,
do_rescale=False,
max_image_size=None,
) -> torch.Tensor:
if max_image_size is None:
max_image_size = self.max_image_size
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
processed_image = super().__call__(image)
processed_image = super().__call__(image, max_image_size=max_image_size)
out_stack.append(processed_image)
output = torch.stack(out_stack, dim=0)
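
Putting the third file's change together, a hedged usage sketch of the per-call override (the constructor signature beyond max_image_size, and the exact return value, are assumptions; inputs must already be scaled to [0, 1], as the value check above enforces):

import torch

processor = PixtralVisionImagePreprocessorCompatible(max_image_size=512)
images = torch.rand(4, 3, 512, 512)                  # values already in [0, 1]

full_res = processor(images)                         # falls back to self.max_image_size (512)
smaller = processor(images, max_image_size=320)      # per-call override added by this commit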