mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-05-01 03:31:21 +00:00
feat: check hf mirror, use local cache
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generic, Optional, TypeVar
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
@@ -24,14 +25,22 @@ class PredictOutput(Generic[T]):
|
|||||||
preview: Optional[Image.Image] = None
|
preview: Optional[Image.Image] = None
|
||||||
|
|
||||||
|
|
||||||
def hf_download(file: str, repo_id: str = REPO_ID) -> str:
|
def hf_download(file: str, repo_id: str = REPO_ID, check_remote: bool = True) -> str:
|
||||||
try:
|
if check_remote:
|
||||||
path = hf_hub_download(repo_id, file)
|
with suppress(Exception):
|
||||||
except Exception:
|
return hf_hub_download(repo_id, file, etag_timeout=1)
|
||||||
msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
|
|
||||||
print(msg)
|
with suppress(Exception):
|
||||||
path = "INVALID"
|
return hf_hub_download(
|
||||||
return path
|
repo_id, file, etag_timeout=1, endpoint="https://hf-mirror.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
with suppress(Exception):
|
||||||
|
return hf_hub_download(repo_id, file, local_files_only=True)
|
||||||
|
|
||||||
|
msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
|
||||||
|
print(msg)
|
||||||
|
return "INVALID"
|
||||||
|
|
||||||
|
|
||||||
def safe_mkdir(path: str | os.PathLike[str]) -> None:
|
def safe_mkdir(path: str | os.PathLike[str]) -> None:
|
||||||
@@ -46,16 +55,23 @@ def scan_model_dir(path: Path) -> list[Path]:
|
|||||||
return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"]
|
return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"]
|
||||||
|
|
||||||
|
|
||||||
def download_models(*names: str) -> dict[str, str]:
|
def download_models(*names: str, check_remote: bool = True) -> dict[str, str]:
|
||||||
models = OrderedDict()
|
models = OrderedDict()
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
for name in names:
|
for name in names:
|
||||||
if "-world" in name:
|
if "-world" in name:
|
||||||
models[name] = executor.submit(
|
models[name] = executor.submit(
|
||||||
hf_download, name, repo_id="Bingsu/yolo-world-mirror"
|
hf_download,
|
||||||
|
name,
|
||||||
|
repo_id="Bingsu/yolo-world-mirror",
|
||||||
|
check_remote=check_remote,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
models[name] = executor.submit(hf_download, name)
|
models[name] = executor.submit(
|
||||||
|
hf_download,
|
||||||
|
name,
|
||||||
|
check_remote=check_remote,
|
||||||
|
)
|
||||||
return {name: future.result() for name, future in models.items()}
|
return {name: future.result() for name, future in models.items()}
|
||||||
|
|
||||||
|
|
||||||
@@ -70,16 +86,15 @@ def get_models(
|
|||||||
model_paths.extend(scan_model_dir(Path(dir_)))
|
model_paths.extend(scan_model_dir(Path(dir_)))
|
||||||
|
|
||||||
models = OrderedDict()
|
models = OrderedDict()
|
||||||
if huggingface:
|
to_download = [
|
||||||
to_download = [
|
"face_yolov8n.pt",
|
||||||
"face_yolov8n.pt",
|
"face_yolov8s.pt",
|
||||||
"face_yolov8s.pt",
|
"hand_yolov8n.pt",
|
||||||
"hand_yolov8n.pt",
|
"person_yolov8n-seg.pt",
|
||||||
"person_yolov8n-seg.pt",
|
"person_yolov8s-seg.pt",
|
||||||
"person_yolov8s-seg.pt",
|
"yolov8x-worldv2.pt",
|
||||||
"yolov8x-worldv2.pt",
|
]
|
||||||
]
|
models.update(download_models(*to_download, check_remote=huggingface))
|
||||||
models.update(download_models(*to_download))
|
|
||||||
|
|
||||||
models.update(
|
models.update(
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user