From e4c82803e1fb9e2e5753de1d251b97e3f342b152 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 28 Sep 2024 14:53:38 -0600
Subject: [PATCH] Handle random resizing for pixtral input on direct vision
 adapter

---
 toolkit/config_modules.py        |  1 +
 toolkit/custom_adapter.py        | 28 ++++++++++++++++++++++++++++
 toolkit/models/pixtral_vision.py | 20 +++++++++++++++-----
 3 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 0306bd70..ad8baa4a 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -208,6 +208,7 @@ class AdapterConfig:
         self.ilora_up: bool = kwargs.get('ilora_up', True)
 
         self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
+        self.pixtral_random_image_size: bool = kwargs.get('pixtral_random_image_size', False)
 
         self.flux_only_double: bool = kwargs.get('flux_only_double', False)
 
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 2f5a79d4..07a8e180 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -1,3 +1,4 @@
+import math
 import torch
 import sys
 
@@ -18,6 +19,7 @@ from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEn
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
 from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
+import random
 
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
@@ -767,6 +769,32 @@ class CustomAdapter(torch.nn.Module):
                 ).pixel_values
             else:
                 clip_image = tensors_0_1
+
+            # pixtral only: optionally feed the encoder a randomly sized input
+            if self.config.image_encoder_arch == 'pixtral' and self.config.pixtral_random_image_size:
+                # pick a random target size
+                random_size = random.randint(256, self.config.pixtral_max_image_size)
+                # images are already sized for the max size; fit them to the pixtral patch size when scaling them down further
+                h, w = clip_image.shape[2], clip_image.shape[3]
+                current_base_size = int(math.sqrt(w * h))
+                ratio = current_base_size / random_size
+                if ratio > 1:
+                    w = round(w / ratio)
+                    h = round(h / ratio)
+
+                width_tokens = (w - 1) // self.image_processor.image_patch_size + 1
+                height_tokens = (h - 1) // self.image_processor.image_patch_size + 1
+                assert width_tokens > 0
+                assert height_tokens > 0
+
+                new_image_size = (
+                    height_tokens * self.image_processor.image_patch_size,
+                    width_tokens * self.image_processor.image_patch_size,
+                )
+
+                # resize the image; F.interpolate takes size as (height, width)
+                clip_image = F.interpolate(clip_image, size=new_image_size, mode='bicubic', align_corners=False)
 
             batch_size = clip_image.shape[0]
 
             if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
diff --git a/toolkit/models/pixtral_vision.py b/toolkit/models/pixtral_vision.py
index e6c972bd..51da56b1 100644
--- a/toolkit/models/pixtral_vision.py
+++ b/toolkit/models/pixtral_vision.py
@@ -456,9 +456,12 @@ class PixtralVisionImagePreprocessor:
         self.max_image_size = max_image_size
         self.image_token = 10
 
-    def _image_to_num_tokens(self, img: torch.Tensor) -> Tuple[int, int]:
+    def _image_to_num_tokens(self, img: torch.Tensor, max_image_size=None) -> Tuple[int, int]:
         w: Union[int, float]
         h: Union[int, float]
+
+        if max_image_size is None:
+            max_image_size = self.max_image_size
 
         w, h = img.shape[-1], img.shape[-2]
 
@@ -467,7 +470,7 @@ class PixtralVisionImagePreprocessor:
 
         # ratio = max(h / self.max_image_size, w / self.max_image_size) # original
         base_size = int(math.sqrt(w * h))
-        ratio = base_size / self.max_image_size
+        ratio = base_size / max_image_size
         if ratio > 1:
             w = round(w / ratio)
             h = round(h / ratio)
@@ -477,7 +480,7 @@ class PixtralVisionImagePreprocessor:
 
         return width_tokens, height_tokens
 
-    def __call__(self, image: torch.Tensor) -> torch.Tensor:
+    def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
         """
         Converts ImageChunks to numpy image arrays and image token ids
 
@@ -495,8 +498,11 @@ class PixtralVisionImagePreprocessor:
         if image.min() < 0.0 or image.max() > 1.0:
             raise ValueError(
                 f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
+
+        if max_image_size is None:
+            max_image_size = self.max_image_size
 
-        w, h = self._image_to_num_tokens(image)
+        w, h = self._image_to_num_tokens(image, max_image_size=max_image_size)
         assert w > 0
         assert h > 0
@@ -526,6 +532,7 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
             'height': max_image_size,
             'width': max_image_size
         }
+        self.max_image_size = max_image_size
         self.image_mean = DATASET_MEAN
         self.image_std = DATASET_STD
@@ -535,13 +542,16 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
             return_tensors="pt",
             do_resize=True,
             do_rescale=False,
+            max_image_size=None,
     ) -> torch.Tensor:
+        if max_image_size is None:
+            max_image_size = self.max_image_size
         out_stack = []
         if len(images.shape) == 3:
             images = images.unsqueeze(0)
         for i in range(images.shape[0]):
             image = images[i]
-            processed_image = super().__call__(image)
+            processed_image = super().__call__(image, max_image_size=max_image_size)
             out_stack.append(processed_image)
         output = torch.stack(out_stack, dim=0)
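
Note on the resize math (a standalone sketch for illustration; not part of the
patch): Pixtral's vision tower consumes non-overlapping square patches, so both
spatial dimensions of the encoder input must be multiples of image_patch_size.
The block added to custom_adapter.py scales the image toward a random base size
(the geometric mean of width and height), then rounds each side up to the next
patch multiple with the ceiling-division identity (n - 1) // p + 1. The sketch
below reproduces just that snapping step; the function name snap_to_patch_grid,
the default patch size of 16, and the example shapes are assumptions for
illustration, not toolkit API:

    import math

    import torch
    import torch.nn.functional as F


    def snap_to_patch_grid(image: torch.Tensor, target_base: int, patch_size: int = 16) -> torch.Tensor:
        """Scale a (B, C, H, W) image toward target_base, snapping both sides to patch multiples."""
        h, w = image.shape[2], image.shape[3]
        ratio = math.sqrt(w * h) / target_base  # current base size vs. target
        if ratio > 1:  # only ever scale down, matching the patch's behavior
            w = round(w / ratio)
            h = round(h / ratio)
        # round each side up to the next multiple of patch_size
        new_h = ((h - 1) // patch_size + 1) * patch_size
        new_w = ((w - 1) // patch_size + 1) * patch_size
        # F.interpolate takes size as (height, width)
        return F.interpolate(image, size=(new_h, new_w), mode='bicubic', align_corners=False)


    x = torch.rand(1, 3, 512, 384)
    print(snap_to_patch_grid(x, target_base=300).shape)  # torch.Size([1, 3, 352, 272])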