mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-13 17:30:01 +00:00
feat: mediapipe face mesh
This commit is contained in:
@@ -13,7 +13,7 @@ repo_id = "Bingsu/adetailer"
|
||||
|
||||
@dataclass
class PredictOutput:
    """Container for the result of one prediction call.

    All fields have defaults, so ``PredictOutput()`` represents
    "nothing detected".
    """

    # One entry per detection; coordinates may be ints or floats depending
    # on which predictor produced them.
    bboxes: list[list[int | float]] = field(default_factory=list)
    # Mask images produced alongside the bounding boxes.
    masks: list[Image.Image] = field(default_factory=list)
    # Optional preview image; None when no preview was produced.
    preview: Optional[Image.Image] = None
|
||||
|
||||
@@ -54,6 +54,7 @@ def get_models(
|
||||
{
|
||||
"mediapipe_face_full": None,
|
||||
"mediapipe_face_short": None,
|
||||
"mediapipe_face_mesh": None,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -94,3 +95,29 @@ def create_mask_from_bbox(
|
||||
mask_draw.rectangle(bbox, fill=255)
|
||||
masks.append(mask)
|
||||
return masks
|
||||
|
||||
|
||||
def create_bbox_from_mask(
    masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
    """
    Compute a bounding box for each mask after resizing it to ``shape``.

    Parameters
    ----------
        masks: list[Image.Image]
            A list of masks
        shape: tuple[int, int]
            shape of the image (width, height)

    Returns
    -------
        bboxes: list[list[int]]
            A list of bounding boxes, one per mask for which ``getbbox()``
            returned a box. Masks whose ``getbbox()`` is ``None`` are
            skipped, so the result can be shorter than ``masks``.

    """
    bboxes = []
    for mask in masks:
        # Measure the box in the coordinate space of the target image.
        resized = mask.resize(shape)
        bbox = resized.getbbox()
        if bbox is not None:
            bboxes.append(list(bbox))
    return bboxes
|
||||
|
||||
@@ -1,19 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from adetailer import PredictOutput
|
||||
from adetailer.common import create_mask_from_bbox
|
||||
from adetailer.common import create_bbox_from_mask, create_mask_from_bbox
|
||||
|
||||
|
||||
def mediapipe_predict(
    model_type: str, image: Image.Image, confidence: float = 0.3
) -> PredictOutput:
    """
    Run the mediapipe predictor selected by ``model_type`` on ``image``.

    Parameters
    ----------
        model_type: str
            One of "mediapipe_face_short", "mediapipe_face_full" or
            "mediapipe_face_mesh"
        image: Image.Image
            The image to run the predictor on
        confidence: float
            Confidence value forwarded to the selected predictor

    Returns
    -------
        PredictOutput
            Whatever the selected predictor returns

    Raises
    ------
        RuntimeError
            If ``model_type`` is not one of the supported names
    """
    # The detector's model type is bound positionally: predictors are invoked
    # as func(image, confidence), and a keyword-bound ``model_type`` would
    # collide with the first positional argument and raise TypeError.
    mapping = {
        "mediapipe_face_short": partial(mediapipe_face_detection, 0),
        "mediapipe_face_full": partial(mediapipe_face_detection, 1),
        "mediapipe_face_mesh": mediapipe_face_mesh,
    }
    if model_type in mapping:
        func = mapping[model_type]
        return func(image, confidence)
    raise RuntimeError(f"[-] ADetailer: Invalid mediapipe model type: {model_type}")
|
||||
|
||||
|
||||
def mediapipe_face_detection(
|
||||
model_type: int, image: Image.Image, confidence: float = 0.3
|
||||
) -> PredictOutput:
|
||||
import mediapipe as mp
|
||||
|
||||
if isinstance(model_type, str):
|
||||
model_type = mediapipe_model_name_to_type(model_type)
|
||||
img_width, img_height = image.size
|
||||
|
||||
mp_face_detection = mp.solutions.face_detection
|
||||
@@ -51,12 +65,47 @@ def mediapipe_predict(
|
||||
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
|
||||
|
||||
|
||||
def mediapipe_model_name_to_type(name: str) -> int:
    """Translate a mediapipe face model name into its integer model type.

    The lookup is case-insensitive. Raises ``ValueError`` when the
    lower-cased name is not a known model name.
    """
    key = name.lower()
    known_types = {"mediapipe_face_short": 0, "mediapipe_face_full": 1}
    model_type = known_types.get(key)
    if model_type is None:
        raise ValueError(f"[-] ADetailer: Invalid model name: {key}")
    return model_type
|
||||
def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictOutput:
    """
    Detect faces with MediaPipe Face Mesh and build one mask per face.

    For every detected face the landmarks are scaled to pixel coordinates,
    their convex hull is filled white on a black "L" mask of the image size,
    and the face mesh tesselation is drawn onto a preview copy of the image.

    Parameters
    ----------
        image: Image.Image
            The image to search for faces
        confidence: float
            Passed to FaceMesh as ``min_detection_confidence``

    Returns
    -------
        PredictOutput
            Bounding boxes derived from the masks, the masks themselves and
            the preview image; an empty ``PredictOutput()`` when no face
            landmarks are reported.
    """
    import mediapipe as mp
    from scipy.spatial import ConvexHull

    mp_face_mesh = mp.solutions.face_mesh
    draw_util = mp.solutions.drawing_utils
    drawing_styles = mp.solutions.drawing_styles

    w, h = image.size

    with mp_face_mesh.FaceMesh(
        static_image_mode=True, max_num_faces=20, min_detection_confidence=confidence
    ) as face_mesh:
        # MediaPipe's process() and draw_landmarks() both require a
        # three-channel image, so normalise RGBA/L/P inputs to RGB first.
        arr = np.array(image.convert("RGB"))
        pred = face_mesh.process(arr)

        if pred.multi_face_landmarks is None:
            return PredictOutput()

        preview = arr.copy()
        masks = []
        # Same style for every face; look it up once instead of per face.
        tesselation_style = drawing_styles.get_default_face_mesh_tesselation_style()

        for landmarks in pred.multi_face_landmarks:
            draw_util.draw_landmarks(
                image=preview,
                landmark_list=landmarks,
                connections=mp_face_mesh.FACEMESH_TESSELATION,
                landmark_drawing_spec=None,
                connection_drawing_spec=tesselation_style,
            )

            # Landmark coordinates are multiplied by the image size to get
            # pixel positions.
            points = np.array([(land.x * w, land.y * h) for land in landmarks.landmark])
            hull = ConvexHull(points)
            vertices = hull.vertices
            outline = list(zip(points[vertices, 0], points[vertices, 1]))

            mask = Image.new("L", image.size, "black")
            draw = ImageDraw.Draw(mask)
            draw.polygon(outline, fill="white")
            masks.append(mask)

        bboxes = create_bbox_from_mask(masks, image.size)
        preview = Image.fromarray(preview)
        return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
|
||||
|
||||
Reference in New Issue
Block a user