feat: check hf mirror, use local cache

This commit is contained in:
Dowon
2024-05-22 00:34:22 +09:00
parent 0fd53ce2f3
commit ba039cbbae

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import os import os
from collections import OrderedDict from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Generic, Optional, TypeVar from typing import Any, Generic, Optional, TypeVar
@@ -24,14 +25,22 @@ class PredictOutput(Generic[T]):
preview: Optional[Image.Image] = None preview: Optional[Image.Image] = None
def hf_download(file: str, repo_id: str = REPO_ID) -> str: def hf_download(file: str, repo_id: str = REPO_ID, check_remote: bool = True) -> str:
try: if check_remote:
path = hf_hub_download(repo_id, file) with suppress(Exception):
except Exception: return hf_hub_download(repo_id, file, etag_timeout=1)
msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
print(msg) with suppress(Exception):
path = "INVALID" return hf_hub_download(
return path repo_id, file, etag_timeout=1, endpoint="https://hf-mirror.com"
)
with suppress(Exception):
return hf_hub_download(repo_id, file, local_files_only=True)
msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
print(msg)
return "INVALID"
def safe_mkdir(path: str | os.PathLike[str]) -> None: def safe_mkdir(path: str | os.PathLike[str]) -> None:
@@ -46,16 +55,23 @@ def scan_model_dir(path: Path) -> list[Path]:
return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"] return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"]
def download_models(*names: str) -> dict[str, str]: def download_models(*names: str, check_remote: bool = True) -> dict[str, str]:
models = OrderedDict() models = OrderedDict()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
for name in names: for name in names:
if "-world" in name: if "-world" in name:
models[name] = executor.submit( models[name] = executor.submit(
hf_download, name, repo_id="Bingsu/yolo-world-mirror" hf_download,
name,
repo_id="Bingsu/yolo-world-mirror",
check_remote=check_remote,
) )
else: else:
models[name] = executor.submit(hf_download, name) models[name] = executor.submit(
hf_download,
name,
check_remote=check_remote,
)
return {name: future.result() for name, future in models.items()} return {name: future.result() for name, future in models.items()}
@@ -70,16 +86,15 @@ def get_models(
model_paths.extend(scan_model_dir(Path(dir_))) model_paths.extend(scan_model_dir(Path(dir_)))
models = OrderedDict() models = OrderedDict()
if huggingface: to_download = [
to_download = [ "face_yolov8n.pt",
"face_yolov8n.pt", "face_yolov8s.pt",
"face_yolov8s.pt", "hand_yolov8n.pt",
"hand_yolov8n.pt", "person_yolov8n-seg.pt",
"person_yolov8n-seg.pt", "person_yolov8s-seg.pt",
"person_yolov8s-seg.pt", "yolov8x-worldv2.pt",
"yolov8x-worldv2.pt", ]
] models.update(download_models(*to_download, check_remote=huggingface))
models.update(download_models(*to_download))
models.update( models.update(
{ {