feat: check hf mirror, use local cache

This commit is contained in:
Dowon
2024-05-22 00:34:22 +09:00
parent 0fd53ce2f3
commit ba039cbbae

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import os
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Generic, Optional, TypeVar
@@ -24,14 +25,22 @@ class PredictOutput(Generic[T]):
preview: Optional[Image.Image] = None
def hf_download(file: str, repo_id: str = REPO_ID) -> str:
try:
path = hf_hub_download(repo_id, file)
except Exception:
msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
print(msg)
path = "INVALID"
return path
def hf_download(file: str, repo_id: str = REPO_ID, check_remote: bool = True) -> str:
if check_remote:
with suppress(Exception):
return hf_hub_download(repo_id, file, etag_timeout=1)
with suppress(Exception):
return hf_hub_download(
repo_id, file, etag_timeout=1, endpoint="https://hf-mirror.com"
)
with suppress(Exception):
return hf_hub_download(repo_id, file, local_files_only=True)
msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
print(msg)
return "INVALID"
def safe_mkdir(path: str | os.PathLike[str]) -> None:
@@ -46,16 +55,23 @@ def scan_model_dir(path: Path) -> list[Path]:
return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"]
def download_models(*names: str) -> dict[str, str]:
def download_models(*names: str, check_remote: bool = True) -> dict[str, str]:
models = OrderedDict()
with ThreadPoolExecutor() as executor:
for name in names:
if "-world" in name:
models[name] = executor.submit(
hf_download, name, repo_id="Bingsu/yolo-world-mirror"
hf_download,
name,
repo_id="Bingsu/yolo-world-mirror",
check_remote=check_remote,
)
else:
models[name] = executor.submit(hf_download, name)
models[name] = executor.submit(
hf_download,
name,
check_remote=check_remote,
)
return {name: future.result() for name, future in models.items()}
@@ -70,16 +86,15 @@ def get_models(
model_paths.extend(scan_model_dir(Path(dir_)))
models = OrderedDict()
if huggingface:
to_download = [
"face_yolov8n.pt",
"face_yolov8s.pt",
"hand_yolov8n.pt",
"person_yolov8n-seg.pt",
"person_yolov8s-seg.pt",
"yolov8x-worldv2.pt",
]
models.update(download_models(*to_download))
to_download = [
"face_yolov8n.pt",
"face_yolov8s.pt",
"hand_yolov8n.pt",
"person_yolov8n-seg.pt",
"person_yolov8s-seg.pt",
"yolov8x-worldv2.pt",
]
models.update(download_models(*to_download, check_remote=huggingface))
models.update(
{