feat: add segment model

This commit is contained in:
Bingsu
2023-05-08 10:27:55 +09:00
parent d055f24895
commit 256d24055c
2 changed files with 18 additions and 1 deletions

View File

@@ -38,6 +38,7 @@ def get_models(model_dir: Union[str, Path]) -> OrderedDict[str, Optional[str]]:
"mediapipe_face_full": None,
"mediapipe_face_short": None,
"hand_yolov8n.pt": hf_hub_download(repo_id, "hand_yolov8n.pt"),
"person_yolov8n-seg.pt": hf_hub_download(repo_id, "person_yolov8n-seg.pt"),
}
)

View File

@@ -32,7 +32,10 @@ def ultralytics_predict(
return PredictOutput()
bboxes = bboxes.tolist()
masks = create_mask_from_bbox(image, bboxes)
if pred[0].masks is None:
masks = create_mask_from_bbox(image, bboxes)
else:
masks = mask_to_pil(pred[0].masks.data)
preview = pred[0].plot()
preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
preview = Image.fromarray(preview)
@@ -51,3 +54,16 @@ def ultralytics_check():
if p == "C:\\":
message = "[-] ADetailer: if you get stuck here, try moving the stable-diffusion-webui to a different directory, or try running as administrator."
print(message)
def mask_to_pil(masks) -> list[Image.Image]:
    """
    Convert a stack of masks into a list of single-channel PIL images.

    Parameters
    ----------
    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
        The device can be CUDA, but `to_pil_image` takes care of that.

    Returns
    -------
    list[Image.Image]
        One image in mode "L" per entry along the first axis of `masks`,
        in the same order.
    """
    # Imported lazily so torchvision is only loaded when masks are converted.
    from torchvision.transforms.functional import to_pil_image

    count = masks.shape[0]
    images = []
    for index in range(count):
        # Each slice is a single (H, W) mask; convert it on its own.
        images.append(to_pil_image(masks[index], mode="L"))
    return images