NOTE(review): this patch was recovered from a copy in which all line breaks
and indentation had been collapsed. Leading whitespace inside the hunks is
reconstructed from Python syntax (the depth of the `print(message)` context
line is assumed), and the whitespace-only change to the "The device can be
CUDA" docstring line could not be recovered -- confirm against blob 90bc68f
before applying.

diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py
index 80e03d8..90bc68f 100644
--- a/adetailer/ultralytics.py
+++ b/adetailer/ultralytics.py
@@ -35,7 +35,7 @@ def ultralytics_predict(
     if pred[0].masks is None:
         masks = create_mask_from_bbox(image, bboxes)
     else:
-        masks = mask_to_pil(pred[0].masks.data)
+        masks = mask_to_pil(pred[0].masks.data, image.size)
     preview = pred[0].plot()
     preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
     preview = Image.fromarray(preview)
@@ -56,14 +56,17 @@ def ultralytics_check():
         print(message)
 
 
-def mask_to_pil(masks) -> list[Image.Image]:
+def mask_to_pil(masks, orig_shape: tuple[int, int]) -> list[Image.Image]:
     """
     Parameters
     ----------
     masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
-        The device can be CUDA, but `to_pil_image` takes care of that.
+        The device can be CUDA, but `to_pil_image` takes care of that.
+
+    orig_shape: tuple[int, int]
+        (width, height) of the original image
     """
     from torchvision.transforms.functional import to_pil_image
 
     n = masks.shape[0]
-    return [to_pil_image(masks[i], mode="L") for i in range(n)]
+    return [to_pil_image(masks[i], mode="L").resize(orig_shape) for i in range(n)]