diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..fc3e22d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,29 @@ +from functools import cache + +import pytest +import requests +from PIL import Image + + +@cache +def _sample_image(): + url = "https://i.imgur.com/E5OVXvn.png" + resp = requests.get(url, stream=True) + return Image.open(resp.raw) + + +@cache +def _sample_image2(): + url = "https://i.imgur.com/px5UT7T.png" + resp = requests.get(url, stream=True) + return Image.open(resp.raw) + + +@pytest.fixture() +def sample_image(): + return _sample_image() + + +@pytest.fixture() +def sample_image2(): + return _sample_image2() diff --git a/tests/test_mediapipe.py b/tests/test_mediapipe.py new file mode 100644 index 0000000..7ddcdfe --- /dev/null +++ b/tests/test_mediapipe.py @@ -0,0 +1,18 @@ +import pytest +from PIL import Image + +from adetailer.mediapipe import mediapipe_predict + + +@pytest.mark.parametrize( + "model_name", + [ + "mediapipe_face_short", + "mediapipe_face_full", + "mediapipe_face_mesh", + "mediapipe_face_mesh_eyes_only", + ], +) +def test_mediapipe(sample_image2: Image.Image, model_name: str): + result = mediapipe_predict(model_name, sample_image2) + assert result.preview is not None diff --git a/tests/test_ultralytics.py b/tests/test_ultralytics.py new file mode 100644 index 0000000..ad855ad --- /dev/null +++ b/tests/test_ultralytics.py @@ -0,0 +1,48 @@ +import pytest +from huggingface_hub import hf_hub_download +from PIL import Image + +from adetailer.ultralytics import ultralytics_predict + +repo_id = "Bingsu/adetailer" + + +@pytest.mark.parametrize( + "model_name", + [ + "face_yolov8n.pt", + "face_yolov8n_v2.pt", + "face_yolov8s.pt", + "hand_yolov8n.pt", + "hand_yolov8s.pt", + "person_yolov8n-seg.pt", + "person_yolov8s-seg.pt", + "person_yolov8m-seg.pt", + "deepfashion2_yolov8s-seg.pt", + ], +) +def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str): + model_path = hf_hub_download(repo_id, model_name) + result = ultralytics_predict(model_path, sample_image) + assert result.preview is not None + + +def test_yolo_world_default(sample_image: Image.Image): + result = ultralytics_predict("yolov8l-world.pt", sample_image) + assert result.preview is not None + + +@pytest.mark.parametrize( + "klass", + [ + "person", + "bird", + "yellow bird", + "person,glasses,headphone", + "person,bird", + "glasses,yellow bird", + ], +) +def test_yolo_world(sample_image2: Image.Image, klass: str): + result = ultralytics_predict("yolov8l-world.pt", sample_image2, classes=klass) + assert result.preview is not None