feat: ultraytics predict

This commit is contained in:
Bingsu
2023-04-26 15:04:45 +09:00
parent ee65200527
commit 50e2a88b5b
4 changed files with 97 additions and 1 deletions

3
adetailer/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .common import PredictOutput, get_models
__all__ = ["PredictOutput", "get_models"]

66
adetailer/common.py Normal file
View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
@dataclass
class PredictOutput:
bboxes: list[list[int]] | None = None
masks: list[Image.Image] | None = None
example: Image.Image | None = None
def get_models(model_dir: str | Path) -> OrderedDict[str, str | None]:
model_dir = Path(model_dir)
model_paths = [
p for p in model_dir.rglob("*") if p.is_file() and p.suffix in (".pt", ".pth")
]
models = OrderedDict(
{
"face_yolo8n.pt": hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt"),
"face_yolo8s.pt": hf_hub_download("Bingsu/adetailer", "face_yolov8s.pt"),
"mediapipe_face_full": None,
"mediapipe_face_short": None,
}
)
for path in model_paths:
if path.name in models:
continue
models[path.name] = str(path)
return models
def create_mask_from_bbox(
image: Image.Image, bboxes: list[list[float]]
) -> list[Image.Image]:
"""
Parameters
----------
image: Image.Image
The image to create the mask from
bboxes: list[list[float]]
list of [x1, y1, x2, y2]
bounding boxes
Returns
-------
masks: list[Image.Image]
A list of masks
"""
masks = []
for bbox in bboxes:
mask = Image.new("L", image.size, 0)
mask_draw = ImageDraw.Draw(mask)
mask_draw.rectangle(bbox, fill=255)
masks.append(mask)
return masks

27
adetailer/ultralytics.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from pathlib import Path
import cv2
from PIL import Image
from ultralytics import YOLO
from adetailer import PredictOutput
from adetailer.common import create_mask_from_bbox
def ultralytics_predict(
model_path: str | Path, image: Image.Image, confidence: float = 0.25
) -> PredictOutput:
model_path = str(model_path)
model = YOLO(model_path)
pred = model(image, conf=confidence, hide_labels=True)
bboxes = pred[0].xyxy.cpu().numpy()
masks = create_mask_from_bbox(image, bboxes)
example = pred[0].plot()
example = cv2.cvtColor(example, cv2.COLOR_BGR2RGB)
example = Image.fromarray(example)
return PredictOutput(bboxes=bboxes, masks=masks, example=example)

View File

@@ -22,4 +22,4 @@ ignore = ["B008", "B905", "E501"]
unfixable = ["F401"]
[tool.ruff.isort]
known-first-party = ["modules", "launch"]
known-first-party = ["modules", "launch"]