diff --git a/adetailer/mediapipe.py b/adetailer/mediapipe.py index 067aa53..95084cf 100644 --- a/adetailer/mediapipe.py +++ b/adetailer/mediapipe.py @@ -97,6 +97,7 @@ def mediapipe_face_mesh( preview = arr.copy() masks = [] + confidences = [] for landmarks in pred.multi_face_landmarks: draw_util.draw_landmarks( @@ -116,10 +117,13 @@ def mediapipe_face_mesh( draw = ImageDraw.Draw(mask) draw.polygon(outline, fill="white") masks.append(mask) + confidences.append(1.0) # Confidence is unknown bboxes = create_bbox_from_mask(masks, image.size) preview = Image.fromarray(preview) - return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) + return PredictOutput( + bboxes=bboxes, masks=masks, confidences=confidences, preview=preview + ) def mediapipe_face_mesh_eyes_only( @@ -145,6 +149,8 @@ def mediapipe_face_mesh_eyes_only( preview = image.copy() masks = [] + confidences = [] + for landmarks in pred.multi_face_landmarks: points = np.array( [[land.x * w, land.y * h] for land in landmarks.landmark], dtype=int @@ -159,10 +165,13 @@ def mediapipe_face_mesh_eyes_only( for outline in (left_outline, right_outline): draw.polygon(outline, fill="white") masks.append(mask) + confidences.append(1.0) # Confidence is unknown bboxes = create_bbox_from_mask(masks, image.size) preview = draw_preview(preview, bboxes, masks) - return PredictOutput(bboxes=bboxes, masks=masks, preview=preview) + return PredictOutput( + bboxes=bboxes, masks=masks, confidences=confidences, preview=preview + ) def draw_preview( diff --git a/tests/test_mediapipe.py b/tests/test_mediapipe.py index 7ddcdfe..900d056 100644 --- a/tests/test_mediapipe.py +++ b/tests/test_mediapipe.py @@ -16,3 +16,6 @@ from adetailer.mediapipe import mediapipe_predict def test_mediapipe(sample_image2: Image.Image, model_name: str): result = mediapipe_predict(model_name, sample_image2) assert result.preview is not None + assert len(result.bboxes) > 0 + assert len(result.masks) > 0 + assert len(result.confidences) > 0 diff --git a/tests/test_ultralytics.py b/tests/test_ultralytics.py index c772607..7ae53a6 100644 --- a/tests/test_ultralytics.py +++ b/tests/test_ultralytics.py @@ -48,3 +48,6 @@ def test_yolo_world(sample_image2: Image.Image, klass: str): model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-worldv2.pt") result = ultralytics_predict(model_path, sample_image2, classes=klass) assert result.preview is not None + assert len(result.bboxes) > 0 + assert len(result.masks) > 0 + assert len(result.confidences) > 0