feat: concurrent download models

This commit is contained in:
Dowon
2024-05-20 00:11:01 +09:00
parent 599c3cc7fc
commit 2f9b9ab0f6

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import os
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Generic, Optional, TypeVar
@@ -24,7 +25,7 @@ class PredictOutput(Generic[T]):
preview: Optional[Image.Image] = None
def hf_download(file: str, repo_id: str = REPO_ID) -> str | None:
def hf_download(file: str, repo_id: str = REPO_ID) -> str:
global _download_failed
if _download_failed:
@@ -52,6 +53,19 @@ def scan_model_dir(path: Path) -> list[Path]:
return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"]
def download_models(*names: str) -> dict[str, str]:
models = OrderedDict()
with ThreadPoolExecutor() as executor:
for name in names:
if "-world" in name:
models[name] = executor.submit(
hf_download, name, repo_id="Bingsu/yolo-world-mirror"
)
else:
models[name] = executor.submit(hf_download, name)
return {name: future.result() for name, future in models.items()}
def get_models(
*dirs: str | os.PathLike[str], huggingface: bool = True
) -> OrderedDict[str, str]:
@@ -64,18 +78,16 @@ def get_models(
models = OrderedDict()
if huggingface:
models.update(
{
"face_yolov8n.pt": hf_download("face_yolov8n.pt"),
"face_yolov8s.pt": hf_download("face_yolov8s.pt"),
"hand_yolov8n.pt": hf_download("hand_yolov8n.pt"),
"person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"),
"person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"),
"yolov8x-worldv2.pt": hf_download(
"yolov8x-worldv2.pt", repo_id="Bingsu/yolo-world-mirror"
),
}
)
to_download = [
"face_yolov8n.pt",
"face_yolov8s.pt",
"hand_yolov8n.pt",
"person_yolov8n-seg.pt",
"person_yolov8s-seg.pt",
"yolov8x-worldv2.pt",
]
models.update(download_models(*to_download))
models.update(
{
"mediapipe_face_full": "mediapipe_face_full",