mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-04 04:50:02 +00:00
feat: concurrent download models
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
@@ -24,7 +25,7 @@ class PredictOutput(Generic[T]):
|
||||
preview: Optional[Image.Image] = None
|
||||
|
||||
|
||||
def hf_download(file: str, repo_id: str = REPO_ID) -> str | None:
|
||||
def hf_download(file: str, repo_id: str = REPO_ID) -> str:
|
||||
global _download_failed
|
||||
|
||||
if _download_failed:
|
||||
@@ -52,6 +53,19 @@ def scan_model_dir(path: Path) -> list[Path]:
|
||||
return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"]
|
||||
|
||||
|
||||
def download_models(*names: str) -> dict[str, str]:
|
||||
models = OrderedDict()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
for name in names:
|
||||
if "-world" in name:
|
||||
models[name] = executor.submit(
|
||||
hf_download, name, repo_id="Bingsu/yolo-world-mirror"
|
||||
)
|
||||
else:
|
||||
models[name] = executor.submit(hf_download, name)
|
||||
return {name: future.result() for name, future in models.items()}
|
||||
|
||||
|
||||
def get_models(
|
||||
*dirs: str | os.PathLike[str], huggingface: bool = True
|
||||
) -> OrderedDict[str, str]:
|
||||
@@ -64,18 +78,16 @@ def get_models(
|
||||
|
||||
models = OrderedDict()
|
||||
if huggingface:
|
||||
models.update(
|
||||
{
|
||||
"face_yolov8n.pt": hf_download("face_yolov8n.pt"),
|
||||
"face_yolov8s.pt": hf_download("face_yolov8s.pt"),
|
||||
"hand_yolov8n.pt": hf_download("hand_yolov8n.pt"),
|
||||
"person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"),
|
||||
"person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"),
|
||||
"yolov8x-worldv2.pt": hf_download(
|
||||
"yolov8x-worldv2.pt", repo_id="Bingsu/yolo-world-mirror"
|
||||
),
|
||||
}
|
||||
)
|
||||
to_download = [
|
||||
"face_yolov8n.pt",
|
||||
"face_yolov8s.pt",
|
||||
"hand_yolov8n.pt",
|
||||
"person_yolov8n-seg.pt",
|
||||
"person_yolov8s-seg.pt",
|
||||
"yolov8x-worldv2.pt",
|
||||
]
|
||||
models.update(download_models(*to_download))
|
||||
|
||||
models.update(
|
||||
{
|
||||
"mediapipe_face_full": "mediapipe_face_full",
|
||||
|
||||
Reference in New Issue
Block a user