mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-04 04:50:02 +00:00
refactor: generic int | float typing
This commit is contained in:
@@ -4,7 +4,7 @@ import os
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image, ImageDraw
|
||||
@@ -14,10 +14,12 @@ from torchvision.transforms.functional import to_pil_image
|
||||
REPO_ID = "Bingsu/adetailer"
|
||||
_download_failed = False
|
||||
|
||||
T = TypeVar("T", int, float)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictOutput:
|
||||
bboxes: list[list[int | float]] = field(default_factory=list)
|
||||
class PredictOutput(Generic[T]):
|
||||
bboxes: list[list[T]] = field(default_factory=list)
|
||||
masks: list[Image.Image] = field(default_factory=list)
|
||||
preview: Optional[Image.Image] = None
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from functools import partial, reduce
|
||||
from math import dist
|
||||
from typing import Any
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -26,6 +26,9 @@ class MergeInvert(IntEnum):
|
||||
MERGE_INVERT = 2
|
||||
|
||||
|
||||
T = TypeVar("T", int, float)
|
||||
|
||||
|
||||
def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
|
||||
return cv2.dilate(arr, kernel, iterations=1)
|
||||
@@ -96,7 +99,7 @@ def has_intersection(im1: Any, im2: Any) -> bool:
|
||||
return not is_all_black(cv2.bitwise_and(arr1, arr2))
|
||||
|
||||
|
||||
def bbox_area(bbox: list[float]) -> float:
|
||||
def bbox_area(bbox: list[T]) -> T:
|
||||
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
||||
|
||||
|
||||
@@ -141,25 +144,25 @@ def mask_preprocess(
|
||||
|
||||
|
||||
# Bbox sorting
|
||||
def _key_left_to_right(bbox: list[float]) -> float:
|
||||
def _key_left_to_right(bbox: list[T]) -> T:
|
||||
"""
|
||||
Left to right
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox: list[float]
|
||||
bbox: list[int] | list[float]
|
||||
list of [x1, y1, x2, y2]
|
||||
"""
|
||||
return bbox[0]
|
||||
|
||||
|
||||
def _key_center_to_edge(bbox: list[float], *, center: tuple[float, float]) -> float:
|
||||
def _key_center_to_edge(bbox: list[T], *, center: tuple[float, float]) -> float:
|
||||
"""
|
||||
Center to edge
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox: list[float]
|
||||
bbox: list[int] | list[float]
|
||||
list of [x1, y1, x2, y2]
|
||||
image: Image.Image
|
||||
the image
|
||||
@@ -168,21 +171,21 @@ def _key_center_to_edge(bbox: list[float], *, center: tuple[float, float]) -> fl
|
||||
return dist(center, bbox_center)
|
||||
|
||||
|
||||
def _key_area(bbox: list[float]) -> float:
|
||||
def _key_area(bbox: list[T]) -> T:
|
||||
"""
|
||||
Large to small
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bbox: list[float]
|
||||
bbox: list[int] | list[float]
|
||||
list of [x1, y1, x2, y2]
|
||||
"""
|
||||
return -bbox_area(bbox)
|
||||
|
||||
|
||||
def sort_bboxes(
|
||||
pred: PredictOutput, order: int | SortBy = SortBy.NONE
|
||||
) -> PredictOutput:
|
||||
pred: PredictOutput[T], order: int | SortBy = SortBy.NONE
|
||||
) -> PredictOutput[T]:
|
||||
if order == SortBy.NONE or len(pred.bboxes) <= 1:
|
||||
return pred
|
||||
|
||||
@@ -205,12 +208,14 @@ def sort_bboxes(
|
||||
|
||||
|
||||
# Filter by ratio
|
||||
def is_in_ratio(bbox: list[float], low: float, high: float, orig_area: int) -> bool:
|
||||
def is_in_ratio(bbox: list[T], low: float, high: float, orig_area: int) -> bool:
|
||||
area = bbox_area(bbox)
|
||||
return low <= area / orig_area <= high
|
||||
|
||||
|
||||
def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
|
||||
def filter_by_ratio(
|
||||
pred: PredictOutput[T], low: float, high: float
|
||||
) -> PredictOutput[T]:
|
||||
if not pred.bboxes:
|
||||
return pred
|
||||
|
||||
@@ -223,7 +228,7 @@ def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutp
|
||||
return pred
|
||||
|
||||
|
||||
def filter_k_largest(pred: PredictOutput, k: int = 0) -> PredictOutput:
|
||||
def filter_k_largest(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
|
||||
if not pred.bboxes or k == 0:
|
||||
return pred
|
||||
areas = [bbox_area(bbox) for bbox in pred.bboxes]
|
||||
|
||||
@@ -28,7 +28,7 @@ def mediapipe_predict(
|
||||
|
||||
def mediapipe_face_detection(
|
||||
model_type: int, image: Image.Image, confidence: float = 0.3
|
||||
) -> PredictOutput:
|
||||
) -> PredictOutput[float]:
|
||||
import mediapipe as mp
|
||||
|
||||
img_width, img_height = image.size
|
||||
@@ -68,7 +68,9 @@ def mediapipe_face_detection(
|
||||
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
|
||||
|
||||
|
||||
def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictOutput:
|
||||
def mediapipe_face_mesh(
|
||||
image: Image.Image, confidence: float = 0.3
|
||||
) -> PredictOutput[int]:
|
||||
import mediapipe as mp
|
||||
|
||||
mp_face_mesh = mp.solutions.face_mesh
|
||||
@@ -115,7 +117,7 @@ def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictO
|
||||
|
||||
def mediapipe_face_mesh_eyes_only(
|
||||
image: Image.Image, confidence: float = 0.3
|
||||
) -> PredictOutput:
|
||||
) -> PredictOutput[int]:
|
||||
import mediapipe as mp
|
||||
|
||||
mp_face_mesh = mp.solutions.face_mesh
|
||||
|
||||
@@ -21,7 +21,7 @@ def ultralytics_predict(
|
||||
confidence: float = 0.3,
|
||||
device: str = "",
|
||||
classes: str = "",
|
||||
) -> PredictOutput:
|
||||
) -> PredictOutput[float]:
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO(model_path)
|
||||
|
||||
Reference in New Issue
Block a user