diff --git a/toolkit/models/pixtral_vision.py b/toolkit/models/pixtral_vision.py index ebebce40..e6c972bd 100644 --- a/toolkit/models/pixtral_vision.py +++ b/toolkit/models/pixtral_vision.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional, Tuple, Any, Union import os import torch @@ -461,7 +462,12 @@ class PixtralVisionImagePreprocessor: w, h = img.shape[-1], img.shape[-2] - ratio = max(h / self.max_image_size, w / self.max_image_size) + # originally, pixtral used the largest of the 2 dimensions, but we + # will use the base size of the image based on number of pixels. + # ratio = max(h / self.max_image_size, w / self.max_image_size) # original + + base_size = int(math.sqrt(w * h)) + ratio = base_size / self.max_image_size if ratio > 1: w = round(w / ratio) h = round(h / ratio)