From ba039cbbae879a50a1bb1756c22b4503d3133eed Mon Sep 17 00:00:00 2001 From: Dowon Date: Wed, 22 May 2024 00:34:22 +0900 Subject: [PATCH] feat: check hf mirror, use local cache --- adetailer/common.py | 57 ++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/adetailer/common.py b/adetailer/common.py index 470bb2f..f9e42fc 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -3,6 +3,7 @@ from __future__ import annotations import os from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from dataclasses import dataclass, field from pathlib import Path from typing import Any, Generic, Optional, TypeVar @@ -24,14 +25,22 @@ class PredictOutput(Generic[T]): preview: Optional[Image.Image] = None -def hf_download(file: str, repo_id: str = REPO_ID) -> str: - try: - path = hf_hub_download(repo_id, file) - except Exception: - msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface" - print(msg) - path = "INVALID" - return path +def hf_download(file: str, repo_id: str = REPO_ID, check_remote: bool = True) -> str: + if check_remote: + with suppress(Exception): + return hf_hub_download(repo_id, file, etag_timeout=1) + + with suppress(Exception): + return hf_hub_download( + repo_id, file, etag_timeout=1, endpoint="https://hf-mirror.com" + ) + + with suppress(Exception): + return hf_hub_download(repo_id, file, local_files_only=True) + + msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface" + print(msg) + return "INVALID" def safe_mkdir(path: str | os.PathLike[str]) -> None: @@ -46,16 +55,23 @@ def scan_model_dir(path: Path) -> list[Path]: return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"] -def download_models(*names: str) -> dict[str, str]: +def download_models(*names: str, check_remote: bool = True) -> dict[str, str]: models = OrderedDict() with ThreadPoolExecutor() as executor: for name in names: if "-world" in name: models[name] = executor.submit( - hf_download, name, repo_id="Bingsu/yolo-world-mirror" + hf_download, + name, + repo_id="Bingsu/yolo-world-mirror", + check_remote=check_remote, ) else: - models[name] = executor.submit(hf_download, name) + models[name] = executor.submit( + hf_download, + name, + check_remote=check_remote, + ) return {name: future.result() for name, future in models.items()} @@ -70,16 +86,15 @@ def get_models( model_paths.extend(scan_model_dir(Path(dir_))) models = OrderedDict() - if huggingface: - to_download = [ - "face_yolov8n.pt", - "face_yolov8s.pt", - "hand_yolov8n.pt", - "person_yolov8n-seg.pt", - "person_yolov8s-seg.pt", - "yolov8x-worldv2.pt", - ] - models.update(download_models(*to_download)) + to_download = [ + "face_yolov8n.pt", + "face_yolov8s.pt", + "hand_yolov8n.pt", + "person_yolov8n-seg.pt", + "person_yolov8s-seg.pt", + "yolov8x-worldv2.pt", + ] + models.update(download_models(*to_download, check_remote=huggingface)) models.update( {