From 9d46fcd714fc84dd8489e1b396e7b6b4ba73bdda Mon Sep 17 00:00:00 2001 From: Dowon Date: Fri, 1 Mar 2024 11:02:42 +0900 Subject: [PATCH] fix: download yolo world from hf --- adetailer/common.py | 8 +++++--- tests/test_ultralytics.py | 10 +++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/adetailer/common.py b/adetailer/common.py index 5a8b1ab..317d9c6 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -9,7 +9,7 @@ from huggingface_hub import hf_hub_download from PIL import Image, ImageDraw from rich import print -repo_id = "Bingsu/adetailer" +REPO_ID = "Bingsu/adetailer" _download_failed = False @@ -20,7 +20,7 @@ class PredictOutput: preview: Optional[Image.Image] = None -def hf_download(file: str): +def hf_download(file: str, repo_id: str = REPO_ID) -> str | None: global _download_failed if _download_failed: @@ -56,11 +56,13 @@ def get_models( "hand_yolov8n.pt": hf_download("hand_yolov8n.pt"), "person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"), "person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"), + "yolov8x-world.pt": hf_download( + "yolov8x-world.pt", repo_id="Bingsu/yolo-world-mirror" + ), } ) models.update( { - "yolov8x-world.pt": "yolov8x-world.pt", "mediapipe_face_full": None, "mediapipe_face_short": None, "mediapipe_face_mesh": None, diff --git a/tests/test_ultralytics.py b/tests/test_ultralytics.py index f3885de..fe40097 100644 --- a/tests/test_ultralytics.py +++ b/tests/test_ultralytics.py @@ -4,8 +4,6 @@ from PIL import Image from adetailer.ultralytics import ultralytics_predict -repo_id = "Bingsu/adetailer" - @pytest.mark.parametrize( "model_name", @@ -22,13 +20,14 @@ repo_id = "Bingsu/adetailer" ], ) def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str): - model_path = hf_hub_download(repo_id, model_name) + model_path = hf_hub_download("Bingsu/adetailer", model_name) result = ultralytics_predict(model_path, sample_image) assert result.preview is not None def test_yolo_world_default(sample_image: Image.Image): - result = ultralytics_predict("yolov8x-world.pt", sample_image) + model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-world.pt") + result = ultralytics_predict(model_path, sample_image) assert result.preview is not None @@ -44,5 +43,6 @@ def test_yolo_world_default(sample_image: Image.Image): ], ) def test_yolo_world(sample_image2: Image.Image, klass: str): - result = ultralytics_predict("yolov8x-world.pt", sample_image2, classes=klass) + model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-world.pt") + result = ultralytics_predict(model_path, sample_image2, classes=klass) assert result.preview is not None