mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Handle random resizing for Pixtral input on direct vision adapter
This commit is contained in:
@@ -456,9 +456,12 @@ class PixtralVisionImagePreprocessor:
|
||||
self.max_image_size = max_image_size
|
||||
self.image_token = 10
|
||||
|
||||
def _image_to_num_tokens(self, img: torch.Tensor) -> Tuple[int, int]:
|
||||
def _image_to_num_tokens(self, img: torch.Tensor, max_image_size = None) -> Tuple[int, int]:
|
||||
w: Union[int, float]
|
||||
h: Union[int, float]
|
||||
|
||||
if max_image_size is None:
|
||||
max_image_size = self.max_image_size
|
||||
|
||||
w, h = img.shape[-1], img.shape[-2]
|
||||
|
||||
@@ -467,7 +470,7 @@ class PixtralVisionImagePreprocessor:
|
||||
# ratio = max(h / self.max_image_size, w / self.max_image_size) # original
|
||||
|
||||
base_size = int(math.sqrt(w * h))
|
||||
ratio = base_size / self.max_image_size
|
||||
ratio = base_size / max_image_size
|
||||
if ratio > 1:
|
||||
w = round(w / ratio)
|
||||
h = round(h / ratio)
|
||||
@@ -477,7 +480,7 @@ class PixtralVisionImagePreprocessor:
|
||||
|
||||
return width_tokens, height_tokens
|
||||
|
||||
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
||||
def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
|
||||
"""
|
||||
Converts ImageChunks to numpy image arrays and image token ids
|
||||
|
||||
@@ -495,8 +498,11 @@ class PixtralVisionImagePreprocessor:
|
||||
if image.min() < 0.0 or image.max() > 1.0:
|
||||
raise ValueError(
|
||||
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
||||
|
||||
if max_image_size is None:
|
||||
max_image_size = self.max_image_size
|
||||
|
||||
w, h = self._image_to_num_tokens(image)
|
||||
w, h = self._image_to_num_tokens(image, max_image_size=max_image_size)
|
||||
assert w > 0
|
||||
assert h > 0
|
||||
|
||||
@@ -526,6 +532,7 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
|
||||
'height': max_image_size,
|
||||
'width': max_image_size
|
||||
}
|
||||
self.max_image_size = max_image_size
|
||||
self.image_mean = DATASET_MEAN
|
||||
self.image_std = DATASET_STD
|
||||
|
||||
@@ -535,13 +542,16 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
max_image_size=None,
|
||||
) -> torch.Tensor:
|
||||
if max_image_size is None:
|
||||
max_image_size = self.max_image_size
|
||||
out_stack = []
|
||||
if len(images.shape) == 3:
|
||||
images = images.unsqueeze(0)
|
||||
for i in range(images.shape[0]):
|
||||
image = images[i]
|
||||
processed_image = super().__call__(image)
|
||||
processed_image = super().__call__(image, max_image_size=max_image_size)
|
||||
out_stack.append(processed_image)
|
||||
|
||||
output = torch.stack(out_stack, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user