Handle random resizing for pixtral input on direct vision adapter

This commit is contained in:
Jaret Burkett
2024-09-28 14:53:38 -06:00
parent 69aa92bce5
commit e4c82803e1
3 changed files with 44 additions and 5 deletions

View File

@@ -456,9 +456,12 @@ class PixtralVisionImagePreprocessor:
self.max_image_size = max_image_size
self.image_token = 10
def _image_to_num_tokens(self, img: torch.Tensor) -> Tuple[int, int]:
def _image_to_num_tokens(self, img: torch.Tensor, max_image_size = None) -> Tuple[int, int]:
w: Union[int, float]
h: Union[int, float]
if max_image_size is None:
max_image_size = self.max_image_size
w, h = img.shape[-1], img.shape[-2]
@@ -467,7 +470,7 @@ class PixtralVisionImagePreprocessor:
# ratio = max(h / self.max_image_size, w / self.max_image_size) # original
base_size = int(math.sqrt(w * h))
ratio = base_size / self.max_image_size
ratio = base_size / max_image_size
if ratio > 1:
w = round(w / ratio)
h = round(h / ratio)
@@ -477,7 +480,7 @@ class PixtralVisionImagePreprocessor:
return width_tokens, height_tokens
def __call__(self, image: torch.Tensor) -> torch.Tensor:
def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
"""
Converts ImageChunks to numpy image arrays and image token ids
@@ -495,8 +498,11 @@ class PixtralVisionImagePreprocessor:
if image.min() < 0.0 or image.max() > 1.0:
raise ValueError(
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
if max_image_size is None:
max_image_size = self.max_image_size
w, h = self._image_to_num_tokens(image)
w, h = self._image_to_num_tokens(image, max_image_size=max_image_size)
assert w > 0
assert h > 0
@@ -526,6 +532,7 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
'height': max_image_size,
'width': max_image_size
}
self.max_image_size = max_image_size
self.image_mean = DATASET_MEAN
self.image_std = DATASET_STD
@@ -535,13 +542,16 @@ class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
return_tensors="pt",
do_resize=True,
do_rescale=False,
max_image_size=None,
) -> torch.Tensor:
if max_image_size is None:
max_image_size = self.max_image_size
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
processed_image = super().__call__(image)
processed_image = super().__call__(image, max_image_size=max_image_size)
out_stack.append(processed_image)
output = torch.stack(out_stack, dim=0)