refactor: generic int | float typing

This commit is contained in:
Dowon
2024-05-19 23:54:38 +09:00
parent 9d1b6bf64a
commit 599c3cc7fc
4 changed files with 29 additions and 20 deletions

View File

@@ -4,7 +4,7 @@ import os
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
from typing import Any, Generic, Optional, TypeVar
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
@@ -14,10 +14,12 @@ from torchvision.transforms.functional import to_pil_image
REPO_ID = "Bingsu/adetailer"
_download_failed = False
T = TypeVar("T", int, float)
@dataclass
class PredictOutput:
bboxes: list[list[int | float]] = field(default_factory=list)
class PredictOutput(Generic[T]):
bboxes: list[list[T]] = field(default_factory=list)
masks: list[Image.Image] = field(default_factory=list)
preview: Optional[Image.Image] = None

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from enum import IntEnum
from functools import partial, reduce
from math import dist
from typing import Any
from typing import Any, TypeVar
import cv2
import numpy as np
@@ -26,6 +26,9 @@ class MergeInvert(IntEnum):
MERGE_INVERT = 2
T = TypeVar("T", int, float)
def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
    """
    Dilate an array once with a square structuring element.

    Parameters
    ----------
    arr: np.ndarray
        array passed to ``cv2.dilate``
    value: int
        side length of the ``value`` x ``value`` rectangular kernel

    Returns
    -------
    np.ndarray
        result of a single dilation pass
    """
    side = (value, value)
    structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, side)
    dilated = cv2.dilate(arr, structuring_element, iterations=1)
    return dilated
@@ -96,7 +99,7 @@ def has_intersection(im1: Any, im2: Any) -> bool:
return not is_all_black(cv2.bitwise_and(arr1, arr2))
def bbox_area(bbox: list[T]) -> T:
    """
    Area of a bounding box.

    Parameters
    ----------
    bbox: list[int] | list[float]
        list of [x1, y1, x2, y2]

    Returns
    -------
    int | float
        (x2 - x1) * (y2 - y1); the result has the same numeric type as
        the coordinates.
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height
@@ -141,25 +144,25 @@ def mask_preprocess(
# Bbox sorting
def _key_left_to_right(bbox: list[float]) -> float:
def _key_left_to_right(bbox: list[T]) -> T:
"""
Left to right
Parameters
----------
bbox: list[float]
bbox: list[int] | list[float]
list of [x1, y1, x2, y2]
"""
return bbox[0]
def _key_center_to_edge(bbox: list[float], *, center: tuple[float, float]) -> float:
def _key_center_to_edge(bbox: list[T], *, center: tuple[float, float]) -> float:
"""
Center to edge
Parameters
----------
bbox: list[float]
bbox: list[int] | list[float]
list of [x1, y1, x2, y2]
center: tuple[float, float]
the point that the distance to the bbox center is measured from
@@ -168,21 +171,21 @@ def _key_center_to_edge(bbox: list[float], *, center: tuple[float, float]) -> fl
return dist(center, bbox_center)
def _key_area(bbox: list[T]) -> T:
    """
    Large to small

    Parameters
    ----------
    bbox: list[int] | list[float]
        list of [x1, y1, x2, y2]

    Returns
    -------
    int | float
        the negated bbox area, so an ascending sort puts the largest
        box first.
    """
    area = bbox_area(bbox)
    return -area
def sort_bboxes(
pred: PredictOutput, order: int | SortBy = SortBy.NONE
) -> PredictOutput:
pred: PredictOutput[T], order: int | SortBy = SortBy.NONE
) -> PredictOutput[T]:
if order == SortBy.NONE or len(pred.bboxes) <= 1:
return pred
@@ -205,12 +208,14 @@ def sort_bboxes(
# Filter by ratio
def is_in_ratio(bbox: list[T], low: float, high: float, orig_area: int) -> bool:
    """
    Check whether a bbox covers an acceptable fraction of a reference area.

    Parameters
    ----------
    bbox: list[int] | list[float]
        list of [x1, y1, x2, y2]
    low: float
        inclusive lower bound of the area ratio
    high: float
        inclusive upper bound of the area ratio
    orig_area: int
        reference area the bbox area is divided by

    Returns
    -------
    bool
        True if ``low <= bbox_area(bbox) / orig_area <= high``.

    Raises
    ------
    ZeroDivisionError
        if ``orig_area`` is 0 (for plain Python numbers).
    """
    ratio = bbox_area(bbox) / orig_area
    return low <= ratio <= high
def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
def filter_by_ratio(
pred: PredictOutput[T], low: float, high: float
) -> PredictOutput[T]:
if not pred.bboxes:
return pred
@@ -223,7 +228,7 @@ def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutp
return pred
def filter_k_largest(pred: PredictOutput, k: int = 0) -> PredictOutput:
def filter_k_largest(pred: PredictOutput[T], k: int = 0) -> PredictOutput[T]:
if not pred.bboxes or k == 0:
return pred
areas = [bbox_area(bbox) for bbox in pred.bboxes]

View File

@@ -28,7 +28,7 @@ def mediapipe_predict(
def mediapipe_face_detection(
model_type: int, image: Image.Image, confidence: float = 0.3
) -> PredictOutput:
) -> PredictOutput[float]:
import mediapipe as mp
img_width, img_height = image.size
@@ -68,7 +68,9 @@ def mediapipe_face_detection(
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictOutput:
def mediapipe_face_mesh(
image: Image.Image, confidence: float = 0.3
) -> PredictOutput[int]:
import mediapipe as mp
mp_face_mesh = mp.solutions.face_mesh
@@ -115,7 +117,7 @@ def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictO
def mediapipe_face_mesh_eyes_only(
image: Image.Image, confidence: float = 0.3
) -> PredictOutput:
) -> PredictOutput[int]:
import mediapipe as mp
mp_face_mesh = mp.solutions.face_mesh

View File

@@ -21,7 +21,7 @@ def ultralytics_predict(
confidence: float = 0.3,
device: str = "",
classes: str = "",
) -> PredictOutput:
) -> PredictOutput[float]:
from ultralytics import YOLO
model = YOLO(model_path)