From 599c3cc7fcfea9c0260e2cccb07af91644744c5b Mon Sep 17 00:00:00 2001 From: Dowon Date: Sun, 19 May 2024 23:54:38 +0900 Subject: [PATCH] refactor: generic int | float typing --- adetailer/common.py | 8 +++++--- adetailer/mask.py | 31 ++++++++++++++++++------------- adetailer/mediapipe.py | 8 +++++--- adetailer/ultralytics.py | 2 +- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/adetailer/common.py b/adetailer/common.py index dfd7952..12fd77b 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -4,7 +4,7 @@ import os from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional +from typing import Any, Generic, Optional, TypeVar from huggingface_hub import hf_hub_download from PIL import Image, ImageDraw @@ -14,10 +14,12 @@ from torchvision.transforms.functional import to_pil_image REPO_ID = "Bingsu/adetailer" _download_failed = False +T = TypeVar("T", int, float) + @dataclass -class PredictOutput: - bboxes: list[list[int | float]] = field(default_factory=list) +class PredictOutput(Generic[T]): + bboxes: list[list[T]] = field(default_factory=list) masks: list[Image.Image] = field(default_factory=list) preview: Optional[Image.Image] = None diff --git a/adetailer/mask.py b/adetailer/mask.py index 3cd2edd..9496aa4 100644 --- a/adetailer/mask.py +++ b/adetailer/mask.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import IntEnum from functools import partial, reduce from math import dist -from typing import Any +from typing import Any, TypeVar import cv2 import numpy as np @@ -26,6 +26,9 @@ class MergeInvert(IntEnum): MERGE_INVERT = 2 +T = TypeVar("T", int, float) + + def _dilate(arr: np.ndarray, value: int) -> np.ndarray: kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value)) return cv2.dilate(arr, kernel, iterations=1) @@ -96,7 +99,7 @@ def has_intersection(im1: Any, im2: Any) -> bool: return not is_all_black(cv2.bitwise_and(arr1, arr2)) -def bbox_area(bbox: list[float]) -> float: +def bbox_area(bbox: list[T]) -> T: return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) @@ -141,25 +144,25 @@ def mask_preprocess( # Bbox sorting -def _key_left_to_right(bbox: list[float]) -> float: +def _key_left_to_right(bbox: list[T]) -> T: """ Left to right Parameters ---------- - bbox: list[float] + bbox: list[int] | list[float] list of [x1, y1, x2, y2] """ return bbox[0] -def _key_center_to_edge(bbox: list[float], *, center: tuple[float, float]) -> float: +def _key_center_to_edge(bbox: list[T], *, center: tuple[float, float]) -> float: """ Center to edge Parameters ---------- - bbox: list[float] + bbox: list[int] | list[float] list of [x1, y1, x2, y2] image: Image.Image the image @@ -168,21 +171,21 @@ def _key_center_to_edge(bbox: list[float], *, center: tuple[float, float]) -> fl return dist(center, bbox_center) -def _key_area(bbox: list[float]) -> float: +def _key_area(bbox: list[T]) -> T: """ Large to small Parameters ---------- - bbox: list[float] + bbox: list[int] | list[float] list of [x1, y1, x2, y2] """ return -bbox_area(bbox) def sort_bboxes( - pred: PredictOutput, order: int | SortBy = SortBy.NONE -) -> PredictOutput: + pred: PredictOutput[T], order: int | SortBy = SortBy.NONE +) -> PredictOutput[T]: if order == SortBy.NONE or len(pred.bboxes) <= 1: return pred @@ -205,12 +208,14 @@ def sort_bboxes( # Filter by ratio -def is_in_ratio(bbox: list[float], low: float, high: float, orig_area: int) -> bool: +def is_in_ratio(bbox: list[T], low: float, high: float, orig_area: int) -> bool: area = bbox_area(bbox) return low <= area / orig_area <= high -def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput: +def filter_by_ratio( + pred: PredictOutput[T], low: float, high: float +) -> PredictOutput[T]: if not pred.bboxes: return pred @@ -223,7 +228,7 @@ def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutp return pred -def filter_k_largest(pred: PredictOutput, k: int = 0) -> PredictOutput: +def filter_k_largest(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]: if not pred.bboxes or k == 0: return pred areas = [bbox_area(bbox) for bbox in pred.bboxes] diff --git a/adetailer/mediapipe.py b/adetailer/mediapipe.py index 3d530a5..25c6900 100644 --- a/adetailer/mediapipe.py +++ b/adetailer/mediapipe.py @@ -28,7 +28,7 @@ def mediapipe_predict( def mediapipe_face_detection( model_type: int, image: Image.Image, confidence: float = 0.3 -) -> PredictOutput: +) -> PredictOutput[float]: import mediapipe as mp img_width, img_height = image.size @@ -68,7 +68,9 @@ def mediapipe_face_detection( return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) -def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictOutput: +def mediapipe_face_mesh( + image: Image.Image, confidence: float = 0.3 +) -> PredictOutput[int]: import mediapipe as mp mp_face_mesh = mp.solutions.face_mesh @@ -115,7 +117,7 @@ def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictO def mediapipe_face_mesh_eyes_only( image: Image.Image, confidence: float = 0.3 -) -> PredictOutput: +) -> PredictOutput[int]: import mediapipe as mp mp_face_mesh = mp.solutions.face_mesh diff --git a/adetailer/ultralytics.py b/adetailer/ultralytics.py index 0c9fb86..dc93482 100644 --- a/adetailer/ultralytics.py +++ b/adetailer/ultralytics.py @@ -21,7 +21,7 @@ def ultralytics_predict( confidence: float = 0.3, device: str = "", classes: str = "", -) -> PredictOutput: +) -> PredictOutput[float]: from ultralytics import YOLO model = YOLO(model_path)