fix: download yolo world from hf

This commit is contained in:
Dowon
2024-03-01 11:02:42 +09:00
parent 9e9dcd5bca
commit 9d46fcd714
2 changed files with 10 additions and 8 deletions

View File

@@ -9,7 +9,7 @@ from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from rich import print from rich import print
repo_id = "Bingsu/adetailer" REPO_ID = "Bingsu/adetailer"
_download_failed = False _download_failed = False
@@ -20,7 +20,7 @@ class PredictOutput:
preview: Optional[Image.Image] = None preview: Optional[Image.Image] = None
def hf_download(file: str): def hf_download(file: str, repo_id: str = REPO_ID) -> str | None:
global _download_failed global _download_failed
if _download_failed: if _download_failed:
@@ -56,11 +56,13 @@ def get_models(
"hand_yolov8n.pt": hf_download("hand_yolov8n.pt"), "hand_yolov8n.pt": hf_download("hand_yolov8n.pt"),
"person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"), "person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"),
"person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"), "person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"),
"yolov8x-world.pt": hf_download(
"yolov8x-world.pt", repo_id="Bingsu/yolo-world-mirror"
),
} }
) )
models.update( models.update(
{ {
"yolov8x-world.pt": "yolov8x-world.pt",
"mediapipe_face_full": None, "mediapipe_face_full": None,
"mediapipe_face_short": None, "mediapipe_face_short": None,
"mediapipe_face_mesh": None, "mediapipe_face_mesh": None,

View File

@@ -4,8 +4,6 @@ from PIL import Image
from adetailer.ultralytics import ultralytics_predict from adetailer.ultralytics import ultralytics_predict
repo_id = "Bingsu/adetailer"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
@@ -22,13 +20,14 @@ repo_id = "Bingsu/adetailer"
], ],
) )
def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str): def test_ultralytics_hf_models(sample_image: Image.Image, model_name: str):
model_path = hf_hub_download(repo_id, model_name) model_path = hf_hub_download("Bingsu/adetailer", model_name)
result = ultralytics_predict(model_path, sample_image) result = ultralytics_predict(model_path, sample_image)
assert result.preview is not None assert result.preview is not None
def test_yolo_world_default(sample_image: Image.Image): def test_yolo_world_default(sample_image: Image.Image):
result = ultralytics_predict("yolov8x-world.pt", sample_image) model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-world.pt")
result = ultralytics_predict(model_path, sample_image)
assert result.preview is not None assert result.preview is not None
@@ -44,5 +43,6 @@ def test_yolo_world_default(sample_image: Image.Image):
], ],
) )
def test_yolo_world(sample_image2: Image.Image, klass: str): def test_yolo_world(sample_image2: Image.Image, klass: str):
result = ultralytics_predict("yolov8x-world.pt", sample_image2, classes=klass) model_path = hf_hub_download("Bingsu/yolo-world-mirror", "yolov8x-world.pt")
result = ultralytics_predict(model_path, sample_image2, classes=klass)
assert result.preview is not None assert result.preview is not None