feat: mediapipe face mesh

This commit is contained in:
Bingsu
2023-05-29 14:56:00 +09:00
parent 45afcc0a8b
commit 841f511216
2 changed files with 91 additions and 15 deletions

View File

@@ -13,7 +13,7 @@ repo_id = "Bingsu/adetailer"
@dataclass
class PredictOutput:
bboxes: list[list[float]] = field(default_factory=list)
bboxes: list[list[int | float]] = field(default_factory=list)
masks: list[Image.Image] = field(default_factory=list)
preview: Optional[Image.Image] = None
@@ -54,6 +54,7 @@ def get_models(
{
"mediapipe_face_full": None,
"mediapipe_face_short": None,
"mediapipe_face_mesh": None,
}
)
@@ -94,3 +95,29 @@ def create_mask_from_bbox(
mask_draw.rectangle(bbox, fill=255)
masks.append(mask)
return masks
def create_bbox_from_mask(
masks: list[Image.Image], shape: tuple[int, int]
) -> list[list[int]]:
"""
Parameters
----------
masks: list[Image.Image]
A list of masks
shape: tuple[int, int]
shape of the image (width, height)
Returns
-------
bboxes: list[list[float]]
A list of bounding boxes
"""
bboxes = []
for mask in masks:
mask = mask.resize(shape)
bbox = mask.getbbox()
if bbox is not None:
bboxes.append(list(bbox))
return bboxes

View File

@@ -1,19 +1,33 @@
from __future__ import annotations
from functools import partial
import numpy as np
from PIL import Image
from PIL import Image, ImageDraw
from adetailer import PredictOutput
from adetailer.common import create_mask_from_bbox
from adetailer.common import create_bbox_from_mask, create_mask_from_bbox
def mediapipe_predict(
model_type: int | str, image: Image.Image, confidence: float = 0.3
model_type: str, image: Image.Image, confidence: float = 0.3
) -> PredictOutput:
mapping = {
"mediapipe_face_short": partial(mediapipe_face_detection, model_type=0),
"mediapipe_face_full": partial(mediapipe_face_detection, model_type=1),
"mediapipe_face_mesh": mediapipe_face_mesh,
}
if model_type in mapping:
func = mapping[model_type]
return func(image, confidence)
raise RuntimeError(f"[-] ADetailer: Invalid mediapipe model type: {model_type}")
def mediapipe_face_detection(
model_type: int, image: Image.Image, confidence: float = 0.3
) -> PredictOutput:
import mediapipe as mp
if isinstance(model_type, str):
model_type = mediapipe_model_name_to_type(model_type)
img_width, img_height = image.size
mp_face_detection = mp.solutions.face_detection
@@ -51,12 +65,47 @@ def mediapipe_predict(
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
def mediapipe_model_name_to_type(name: str) -> int:
name = name.lower()
mapping = {
"mediapipe_face_short": 0,
"mediapipe_face_full": 1,
}
if name not in mapping:
raise ValueError(f"[-] ADetailer: Invalid model name: {name}")
return mapping[name]
def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictOutput:
import mediapipe as mp
from scipy.spatial import ConvexHull
mp_face_mesh = mp.solutions.face_mesh
draw_util = mp.solutions.drawing_utils
drawing_styles = mp.solutions.drawing_styles
w, h = image.size
with mp_face_mesh.FaceMesh(
static_image_mode=True, max_num_faces=20, min_detection_confidence=confidence
) as face_mesh:
arr = np.array(image)
pred = face_mesh.process(arr)
if pred.multi_face_landmarks is None:
return PredictOutput()
preview = arr.copy()
masks = []
for landmarks in pred.multi_face_landmarks:
draw_util.draw_landmarks(
image=preview,
landmark_list=landmarks,
connections=mp_face_mesh.FACEMESH_TESSELATION,
landmark_drawing_spec=None,
connection_drawing_spec=drawing_styles.get_default_face_mesh_tesselation_style(),
)
points = np.array([(land.x * w, land.y * h) for land in landmarks.landmark])
hull = ConvexHull(points)
vertices = hull.vertices
outline = list(zip(points[vertices, 0], points[vertices, 1]))
mask = Image.new("L", image.size, "black")
draw = ImageDraw.Draw(mask)
draw.polygon(outline, fill="white")
masks.append(mask)
bboxes = create_bbox_from_mask(masks, image.size)
preview = Image.fromarray(preview)
return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)