fix: resize seg mask

This commit is contained in:
Bingsu
2023-05-08 10:52:56 +09:00
parent ae44df5e30
commit 623d4d20ac

View File

@@ -35,7 +35,7 @@ def ultralytics_predict(
if pred[0].masks is None:
masks = create_mask_from_bbox(image, bboxes)
else:
masks = mask_to_pil(pred[0].masks.data)
masks = mask_to_pil(pred[0].masks.data, image.size)
preview = pred[0].plot()
preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
preview = Image.fromarray(preview)
@@ -56,14 +56,17 @@ def ultralytics_check():
print(message)
def mask_to_pil(masks) -> list[Image.Image]:
def mask_to_pil(masks, orig_shape: tuple[int, int]) -> list[Image.Image]:
"""
Parameters
----------
masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
The device can be CUDA, but `to_pil_image` takes care of that.
The device can be CUDA, but `to_pil_image` takes care of that.
orig_shape: tuple[int, int]
(width, height) of the original image
"""
from torchvision.transforms.functional import to_pil_image
n = masks.shape[0]
return [to_pil_image(masks[i], mode="L") for i in range(n)]
return [to_pil_image(masks[i], mode="L").resize(orig_shape) for i in range(n)]