fix: misc

This commit is contained in:
Bingsu
2023-05-09 16:53:59 +09:00
parent 6095dbc41b
commit 4b3e5f7e96
3 changed files with 9 additions and 9 deletions

View File

@@ -51,16 +51,16 @@ def get_models(model_dir: Union[str, Path]) -> OrderedDict[str, Optional[str]]:
def create_mask_from_bbox(
image: Image.Image, bboxes: list[list[float]]
bboxes: list[list[float]], shape: tuple[int, int]
) -> list[Image.Image]:
"""
Parameters
----------
image: Image.Image
The image to create the mask from
bboxes: list[list[float]]
list of [x1, y1, x2, y2]
bounding boxes
shape: tuple[int, int]
shape of the image (width, height)
Returns
-------
@@ -70,7 +70,7 @@ def create_mask_from_bbox(
"""
masks = []
for bbox in bboxes:
mask = Image.new("L", image.size, 0)
mask = Image.new("L", shape, 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.rectangle(bbox, fill=255)
masks.append(mask)

View File

@@ -45,7 +45,7 @@ def mediapipe_predict(
bboxes.append([x1, y1, x2, y2])
masks = create_mask_from_bbox(image, bboxes)
masks = create_mask_from_bbox(bboxes, image.size)
preview = Image.fromarray(preview_array)
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)

View File

@@ -33,7 +33,7 @@ def ultralytics_predict(
bboxes = bboxes.tolist()
if pred[0].masks is None:
masks = create_mask_from_bbox(image, bboxes)
masks = create_mask_from_bbox(bboxes, image.size)
else:
masks = mask_to_pil(pred[0].masks.data, image.size)
preview = pred[0].plot()
@@ -56,17 +56,17 @@ def ultralytics_check():
print(message)
def mask_to_pil(masks, orig_shape: tuple[int, int]) -> list[Image.Image]:
def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
"""
Parameters
----------
masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
The device can be CUDA, but `to_pil_image` takes care of that.
orig_shape: tuple[int, int]
shape: tuple[int, int]
(width, height) of the original image
"""
from torchvision.transforms.functional import to_pil_image
n = masks.shape[0]
return [to_pil_image(masks[i], mode="L").resize(orig_shape) for i in range(n)]
return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]