diff --git a/adetailer/common.py b/adetailer/common.py index f12b682..fe91b3f 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -51,16 +51,16 @@ def get_models(model_dir: Union[str, Path]) -> OrderedDict[str, Optional[str]]: def create_mask_from_bbox( - image: Image.Image, bboxes: list[list[float]] + bboxes: list[list[float]], shape: tuple[int, int] ) -> list[Image.Image]: """ Parameters ---------- - image: Image.Image - The image to create the mask from bboxes: list[list[float]] list of [x1, y1, x2, y2] bounding boxes + shape: tuple[int, int] + shape of the image (width, height) Returns ------- @@ -70,7 +70,7 @@ def create_mask_from_bbox( """ masks = [] for bbox in bboxes: - mask = Image.new("L", image.size, 0) + mask = Image.new("L", shape, 0) mask_draw = ImageDraw.Draw(mask) mask_draw.rectangle(bbox, fill=255) masks.append(mask) diff --git a/adetailer/mediapipe.py b/adetailer/mediapipe.py index 2e1ea43..7a3e775 100644 --- a/adetailer/mediapipe.py +++ b/adetailer/mediapipe.py @@ -45,7 +45,7 @@ def mediapipe_predict( bboxes.append([x1, y1, x2, y2]) - masks = create_mask_from_bbox(image, bboxes) + masks = create_mask_from_bbox(bboxes, image.size) preview = Image.fromarray(preview_array) return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index 90bc68f..8d378f1 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -33,7 +33,7 @@ def ultralytics_predict( bboxes = bboxes.tolist() if pred[0].masks is None: - masks = create_mask_from_bbox(image, bboxes) + masks = create_mask_from_bbox(bboxes, image.size) else: masks = mask_to_pil(pred[0].masks.data, image.size) preview = pred[0].plot() @@ -56,17 +56,17 @@ def ultralytics_check(): print(message) -def mask_to_pil(masks, orig_shape: tuple[int, int]) -> list[Image.Image]: +def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]: """ Parameters ---------- masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W). The device can be CUDA, but `to_pil_image` takes care of that. - orig_shape: tuple[int, int] + shape: tuple[int, int] (width, height) of the original image """ from torchvision.transforms.functional import to_pil_image n = masks.shape[0] - return [to_pil_image(masks[i], mode="L").resize(orig_shape) for i in range(n)] + return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]