mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-04-28 10:11:24 +00:00
fix: resize seg mask
This commit is contained in:
@@ -35,7 +35,7 @@ def ultralytics_predict(
|
|||||||
if pred[0].masks is None:
|
if pred[0].masks is None:
|
||||||
masks = create_mask_from_bbox(image, bboxes)
|
masks = create_mask_from_bbox(image, bboxes)
|
||||||
else:
|
else:
|
||||||
masks = mask_to_pil(pred[0].masks.data)
|
masks = mask_to_pil(pred[0].masks.data, image.size)
|
||||||
preview = pred[0].plot()
|
preview = pred[0].plot()
|
||||||
preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
|
preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
|
||||||
preview = Image.fromarray(preview)
|
preview = Image.fromarray(preview)
|
||||||
@@ -56,14 +56,17 @@ def ultralytics_check():
|
|||||||
print(message)
|
print(message)
|
||||||
|
|
||||||
|
|
||||||
def mask_to_pil(masks) -> list[Image.Image]:
|
def mask_to_pil(masks, orig_shape: tuple[int, int]) -> list[Image.Image]:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
|
masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
|
||||||
The device can be CUDA, but `to_pil_image` takes care of that.
|
The device can be CUDA, but `to_pil_image` takes care of that.
|
||||||
|
|
||||||
|
orig_shape: tuple[int, int]
|
||||||
|
(width, height) of the original image
|
||||||
"""
|
"""
|
||||||
from torchvision.transforms.functional import to_pil_image
|
from torchvision.transforms.functional import to_pil_image
|
||||||
|
|
||||||
n = masks.shape[0]
|
n = masks.shape[0]
|
||||||
return [to_pil_image(masks[i], mode="L") for i in range(n)]
|
return [to_pil_image(masks[i], mode="L").resize(orig_shape) for i in range(n)]
|
||||||
|
|||||||
Reference in New Issue
Block a user