diff --git a/adetailer/common.py b/adetailer/common.py index 12fd77b..ffc3323 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -2,6 +2,7 @@ from __future__ import annotations import os from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from pathlib import Path from typing import Any, Generic, Optional, TypeVar @@ -24,7 +25,7 @@ class PredictOutput(Generic[T]): preview: Optional[Image.Image] = None -def hf_download(file: str, repo_id: str = REPO_ID) -> str | None: +def hf_download(file: str, repo_id: str = REPO_ID) -> str: global _download_failed if _download_failed: @@ -52,6 +53,19 @@ def scan_model_dir(path: Path) -> list[Path]: return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"] +def download_models(*names: str) -> dict[str, str]: + models = OrderedDict() + with ThreadPoolExecutor() as executor: + for name in names: + if "-world" in name: + models[name] = executor.submit( + hf_download, name, repo_id="Bingsu/yolo-world-mirror" + ) + else: + models[name] = executor.submit(hf_download, name) + return {name: future.result() for name, future in models.items()} + + def get_models( *dirs: str | os.PathLike[str], huggingface: bool = True ) -> OrderedDict[str, str]: @@ -64,18 +78,16 @@ def get_models( models = OrderedDict() if huggingface: - models.update( - { - "face_yolov8n.pt": hf_download("face_yolov8n.pt"), - "face_yolov8s.pt": hf_download("face_yolov8s.pt"), - "hand_yolov8n.pt": hf_download("hand_yolov8n.pt"), - "person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"), - "person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"), - "yolov8x-worldv2.pt": hf_download( - "yolov8x-worldv2.pt", repo_id="Bingsu/yolo-world-mirror" - ), - } - ) + to_download = [ + "face_yolov8n.pt", + "face_yolov8s.pt", + "hand_yolov8n.pt", + "person_yolov8n-seg.pt", + "person_yolov8s-seg.pt", + "yolov8x-worldv2.pt", + ] + models.update(download_models(*to_download)) + models.update( { "mediapipe_face_full": "mediapipe_face_full",