diff --git a/adetailer/__init__.py b/adetailer/__init__.py
new file mode 100644
index 0000000..aee037e
--- /dev/null
+++ b/adetailer/__init__.py
@@ -0,0 +1,3 @@
+from .common import PredictOutput, get_models
+
+__all__ = ["PredictOutput", "get_models"]
diff --git a/adetailer/common.py b/adetailer/common.py
new file mode 100644
index 0000000..44bf5d6
--- /dev/null
+++ b/adetailer/common.py
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+from huggingface_hub import hf_hub_download
+from PIL import Image, ImageDraw
+
+
+@dataclass
+class PredictOutput:
+    """Detections of one predictor call; every field defaults to None."""
+
+    # One [x1, y1, x2, y2] box per detection. Coordinates may be floats, which
+    # is also what create_mask_from_bbox accepts.
+    bboxes: list[list[float]] | None = None
+    # Mask images, one per entry of `bboxes`.
+    masks: list[Image.Image] | None = None
+    # Preview image of the prediction.
+    example: Image.Image | None = None
+
+
+def get_models(model_dir: str | Path) -> OrderedDict[str, str | None]:
+    """
+    List the built-in models followed by the model files found on disk.
+
+    Parameters
+    ----------
+    model_dir: str | Path
+        Directory searched recursively for ``.pt`` / ``.pth`` files
+
+    Returns
+    -------
+    models: OrderedDict[str, str | None]
+        Model name -> path of its weight file. The built-in entries come
+        first (the mediapipe ones map to ``None``), then the local files
+        sorted by path so the order is the same on every run.
+
+    """
+    model_dir = Path(model_dir)
+    # rglob yields entries in filesystem order; sort for a stable listing.
+    model_paths = sorted(
+        p for p in model_dir.rglob("*") if p.is_file() and p.suffix in (".pt", ".pth")
+    )
+
+    # Both hf_hub_download calls run on every call of this function.
+    # NOTE(review): the keys say "yolo8" while the hub files are "yolov8", so
+    # a local "face_yolov8n.pt" is not treated as a duplicate below. Kept
+    # as-is because callers may already rely on these names -- confirm.
+    models = OrderedDict(
+        {
+            "face_yolo8n.pt": hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt"),
+            "face_yolo8s.pt": hf_hub_download("Bingsu/adetailer", "face_yolov8s.pt"),
+            "mediapipe_face_full": None,
+            "mediapipe_face_short": None,
+        }
+    )
+
+    for path in model_paths:
+        # First entry wins: never overwrite a built-in or an earlier file.
+        if path.name in models:
+            continue
+        models[path.name] = str(path)
+
+    return models
+
+
+def create_mask_from_bbox(
+    image: Image.Image, bboxes: list[list[float]]
+) -> list[Image.Image]:
+    """
+    Create one rectangular mask per bounding box.
+
+    Parameters
+    ----------
+    image: Image.Image
+        The image to create the mask from (only its size is used)
+    bboxes: list[list[float]]
+        list of [x1, y1, x2, y2]
+        bounding boxes
+
+    Returns
+    -------
+    masks: list[Image.Image]
+        A list of masks: mode "L", black with the box filled with 255
+
+    """
+    masks = []
+    for bbox in bboxes:
+        mask = Image.new("L", image.size, 0)
+        mask_draw = ImageDraw.Draw(mask)
+        # Pass plain Python floats so that rows of a numpy array (or any other
+        # numeric sequence) are handled exactly like a list.
+        mask_draw.rectangle([float(v) for v in bbox], fill=255)
+        masks.append(mask)
+    return masks
diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py
new file mode 100644
index 0000000..c55bfb9
--- /dev/null
+++ b/adetailer/ultralytics.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import cv2
+from PIL import Image
+from ultralytics import YOLO
+
+from adetailer import PredictOutput
+from adetailer.common import create_mask_from_bbox
+
+
+def ultralytics_predict(
+    model_path: str | Path, image: Image.Image, confidence: float = 0.25
+) -> PredictOutput:
+    """
+    Detect objects in ``image`` with a YOLO model and build a mask for each.
+
+    Parameters
+    ----------
+    model_path: str | Path
+        Path of the weight file; the model is loaded anew on every call
+    image: Image.Image
+        The image to run the detection on
+    confidence: float
+        Value passed to the model as ``conf``
+
+    Returns
+    -------
+    output: PredictOutput
+        ``bboxes`` as a list of [x1, y1, x2, y2], the matching ``masks`` and
+        the plotted prediction as ``example``. Both lists are empty when
+        nothing is detected.
+
+    """
+    model_path = str(model_path)
+
+    model = YOLO(model_path)
+    pred = model(image, conf=confidence, hide_labels=True)
+
+    # The coordinates live on `Results.boxes`; `Results` itself has no `xyxy`.
+    # tolist() gives the plain nested list that PredictOutput.bboxes declares.
+    bboxes = pred[0].boxes.xyxy.cpu().numpy().tolist()
+    masks = create_mask_from_bbox(image, bboxes)
+    # plot() returns a BGR array, hence the conversion before handing it to PIL.
+    example = pred[0].plot()
+    example = cv2.cvtColor(example, cv2.COLOR_BGR2RGB)
+    example = Image.fromarray(example)
+
+    return PredictOutput(bboxes=bboxes, masks=masks, example=example)
diff --git a/pyproject.toml b/pyproject.toml
index 4148d27..508dcdf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,4 +22,4 @@ ignore = ["B008", "B905", "E501"]
 unfixable = ["F401"]
 
 [tool.ruff.isort]
-known-first-party = ["modules", "launch"]
\ No newline at end of file
+known-first-party = ["modules", "launch"]