From 841f511216ba2a3d583f7ec72008960c992b6f3f Mon Sep 17 00:00:00 2001 From: Bingsu Date: Mon, 29 May 2023 14:56:00 +0900 Subject: [PATCH] feat: mediapipe face mesh --- adetailer/common.py | 29 +++++++++++++++- adetailer/mediapipe.py | 77 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 91 insertions(+), 15 deletions(-) diff --git a/adetailer/common.py b/adetailer/common.py index 73d0ba8..064153f 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -13,7 +13,7 @@ repo_id = "Bingsu/adetailer" @dataclass class PredictOutput: - bboxes: list[list[float]] = field(default_factory=list) + bboxes: list[list[int | float]] = field(default_factory=list) masks: list[Image.Image] = field(default_factory=list) preview: Optional[Image.Image] = None @@ -54,6 +54,7 @@ def get_models( { "mediapipe_face_full": None, "mediapipe_face_short": None, + "mediapipe_face_mesh": None, } ) @@ -94,3 +95,29 @@ def create_mask_from_bbox( mask_draw.rectangle(bbox, fill=255) masks.append(mask) return masks + + +def create_bbox_from_mask( + masks: list[Image.Image], shape: tuple[int, int] +) -> list[list[int]]: + """ + Parameters + ---------- + masks: list[Image.Image] + A list of masks + shape: tuple[int, int] + shape of the image (width, height) + + Returns + ------- + bboxes: list[list[float]] + A list of bounding boxes + + """ + bboxes = [] + for mask in masks: + mask = mask.resize(shape) + bbox = mask.getbbox() + if bbox is not None: + bboxes.append(list(bbox)) + return bboxes diff --git a/adetailer/mediapipe.py b/adetailer/mediapipe.py index 066eabc..dfaca74 100644 --- a/adetailer/mediapipe.py +++ b/adetailer/mediapipe.py @@ -1,19 +1,33 @@ from __future__ import annotations +from functools import partial + import numpy as np -from PIL import Image +from PIL import Image, ImageDraw from adetailer import PredictOutput -from adetailer.common import create_mask_from_bbox +from adetailer.common import create_bbox_from_mask, create_mask_from_bbox def mediapipe_predict( - model_type: int | str, image: Image.Image, confidence: float = 0.3 + model_type: str, image: Image.Image, confidence: float = 0.3 +) -> PredictOutput: + mapping = { + "mediapipe_face_short": partial(mediapipe_face_detection, model_type=0), + "mediapipe_face_full": partial(mediapipe_face_detection, model_type=1), + "mediapipe_face_mesh": mediapipe_face_mesh, + } + if model_type in mapping: + func = mapping[model_type] + return func(image, confidence) + raise RuntimeError(f"[-] ADetailer: Invalid mediapipe model type: {model_type}") + + +def mediapipe_face_detection( + model_type: int, image: Image.Image, confidence: float = 0.3 ) -> PredictOutput: import mediapipe as mp - if isinstance(model_type, str): - model_type = mediapipe_model_name_to_type(model_type) img_width, img_height = image.size mp_face_detection = mp.solutions.face_detection @@ -51,12 +65,47 @@ def mediapipe_predict( return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) -def mediapipe_model_name_to_type(name: str) -> int: - name = name.lower() - mapping = { - "mediapipe_face_short": 0, - "mediapipe_face_full": 1, - } - if name not in mapping: - raise ValueError(f"[-] ADetailer: Invalid model name: {name}") - return mapping[name] +def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictOutput: + import mediapipe as mp + from scipy.spatial import ConvexHull + + mp_face_mesh = mp.solutions.face_mesh + draw_util = mp.solutions.drawing_utils + drawing_styles = mp.solutions.drawing_styles + + w, h = image.size + + with mp_face_mesh.FaceMesh( + static_image_mode=True, max_num_faces=20, min_detection_confidence=confidence + ) as face_mesh: + arr = np.array(image) + pred = face_mesh.process(arr) + + if pred.multi_face_landmarks is None: + return PredictOutput() + + preview = arr.copy() + masks = [] + + for landmarks in pred.multi_face_landmarks: + draw_util.draw_landmarks( + image=preview, + landmark_list=landmarks, + connections=mp_face_mesh.FACEMESH_TESSELATION, + landmark_drawing_spec=None, + connection_drawing_spec=drawing_styles.get_default_face_mesh_tesselation_style(), + ) + + points = np.array([(land.x * w, land.y * h) for land in landmarks.landmark]) + hull = ConvexHull(points) + vertices = hull.vertices + outline = list(zip(points[vertices, 0], points[vertices, 1])) + + mask = Image.new("L", image.size, "black") + draw = ImageDraw.Draw(mask) + draw.polygon(outline, fill="white") + masks.append(mask) + + bboxes = create_bbox_from_mask(masks, image.size) + preview = Image.fromarray(preview) + return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)