mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 11:19:53 +00:00
feat: ultraytics predict
This commit is contained in:
3
adetailer/__init__.py
Normal file
3
adetailer/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .common import PredictOutput, get_models
|
||||
|
||||
__all__ = ["PredictOutput", "get_models"]
|
||||
66
adetailer/common.py
Normal file
66
adetailer/common.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictOutput:
|
||||
bboxes: list[list[int]] | None = None
|
||||
masks: list[Image.Image] | None = None
|
||||
example: Image.Image | None = None
|
||||
|
||||
|
||||
def get_models(model_dir: str | Path) -> OrderedDict[str, str | None]:
|
||||
model_dir = Path(model_dir)
|
||||
model_paths = [
|
||||
p for p in model_dir.rglob("*") if p.is_file() and p.suffix in (".pt", ".pth")
|
||||
]
|
||||
|
||||
models = OrderedDict(
|
||||
{
|
||||
"face_yolo8n.pt": hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt"),
|
||||
"face_yolo8s.pt": hf_hub_download("Bingsu/adetailer", "face_yolov8s.pt"),
|
||||
"mediapipe_face_full": None,
|
||||
"mediapipe_face_short": None,
|
||||
}
|
||||
)
|
||||
|
||||
for path in model_paths:
|
||||
if path.name in models:
|
||||
continue
|
||||
models[path.name] = str(path)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def create_mask_from_bbox(
|
||||
image: Image.Image, bboxes: list[list[float]]
|
||||
) -> list[Image.Image]:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
image: Image.Image
|
||||
The image to create the mask from
|
||||
bboxes: list[list[float]]
|
||||
list of [x1, y1, x2, y2]
|
||||
bounding boxes
|
||||
|
||||
Returns
|
||||
-------
|
||||
masks: list[Image.Image]
|
||||
A list of masks
|
||||
|
||||
"""
|
||||
masks = []
|
||||
for bbox in bboxes:
|
||||
mask = Image.new("L", image.size, 0)
|
||||
mask_draw = ImageDraw.Draw(mask)
|
||||
mask_draw.rectangle(bbox, fill=255)
|
||||
masks.append(mask)
|
||||
return masks
|
||||
27
adetailer/ultralytics.py
Normal file
27
adetailer/ultralytics.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from ultralytics import YOLO
|
||||
|
||||
from adetailer import PredictOutput
|
||||
from adetailer.common import create_mask_from_bbox
|
||||
|
||||
|
||||
def ultralytics_predict(
|
||||
model_path: str | Path, image: Image.Image, confidence: float = 0.25
|
||||
) -> PredictOutput:
|
||||
model_path = str(model_path)
|
||||
|
||||
model = YOLO(model_path)
|
||||
pred = model(image, conf=confidence, hide_labels=True)
|
||||
|
||||
bboxes = pred[0].xyxy.cpu().numpy()
|
||||
masks = create_mask_from_bbox(image, bboxes)
|
||||
example = pred[0].plot()
|
||||
example = cv2.cvtColor(example, cv2.COLOR_BGR2RGB)
|
||||
example = Image.fromarray(example)
|
||||
|
||||
return PredictOutput(bboxes=bboxes, masks=masks, example=example)
|
||||
@@ -22,4 +22,4 @@ ignore = ["B008", "B905", "E501"]
|
||||
unfixable = ["F401"]
|
||||
|
||||
[tool.ruff.isort]
|
||||
known-first-party = ["modules", "launch"]
|
||||
known-first-party = ["modules", "launch"]
|
||||
|
||||
Reference in New Issue
Block a user