fix(scripts): fix paths

This commit is contained in:
Dowon
2024-04-13 15:49:03 +09:00
parent e390875198
commit 89ee330271
2 changed files with 28 additions and 18 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import os
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
@@ -37,16 +38,27 @@ def hf_download(file: str, repo_id: str = REPO_ID) -> str | None:
return path
def scan_model_dir(path_: str | Path) -> list[Path]:
if not path_ or not (path := Path(path_)).is_dir():
def safe_mkdir(path: str | os.PathLike[str]) -> None:
path = Path(path)
if not path.exists() and path.parent.exists() and os.access(path.parent, os.W_OK):
path.mkdir()
def scan_model_dir(path: Path) -> list[Path]:
if not path.is_dir():
return []
return [p for p in path.rglob("*") if p.is_file() and p.suffix in (".pt", ".pth")]
return [p for p in path.rglob("*") if p.is_file() and p.suffix == ".pt"]
def get_models(
model_dir: str | Path, extra_dir: str | Path = "", huggingface: bool = True
*dirs: str | os.PathLike[str], huggingface: bool = True
) -> OrderedDict[str, str]:
model_paths = [*scan_model_dir(model_dir), *scan_model_dir(extra_dir)]
model_paths = []
for dir_ in dirs:
if not dir_:
continue
model_paths.extend(scan_model_dir(Path(dir_)))
models = OrderedDict()
if huggingface:

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import os
import platform
import re
import sys
@@ -10,7 +9,7 @@ from copy import copy
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, cast
import gradio as gr
import torch
@@ -25,8 +24,8 @@ from adetailer import (
mediapipe_predict,
ultralytics_predict,
)
from adetailer.args import BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
from adetailer.common import PredictOutput, ensure_pil_image
from adetailer.args import BBOX_SORTBY, SCRIPT_DEFAULT, ADetailerArgs, SkipImg2ImgOrig
from adetailer.common import PredictOutput, ensure_pil_image, safe_mkdir
from adetailer.mask import (
filter_by_ratio,
filter_k_largest,
@@ -69,19 +68,18 @@ if TYPE_CHECKING:
no_huggingface = getattr(cmd_opts, "ad_no_huggingface", False)
adetailer_dir = Path(paths.models_path, "adetailer")
safe_mkdir(adetailer_dir)
extra_models_dir = shared.opts.data.get("ad_extra_models_dir", "")
model_mapping = get_models(
adetailer_dir, extra_dir=extra_models_dir, huggingface=not no_huggingface
adetailer_dir,
extra_models_dir,
huggingface=not no_huggingface,
)
txt2img_submit_button = img2img_submit_button = None
SCRIPT_DEFAULT = "dynamic_prompting,dynamic_thresholding,wildcard_recursive,wildcards,lora_block_weight,negpip,soft_inpainting"
if (
not adetailer_dir.exists()
and adetailer_dir.parent.exists()
and os.access(adetailer_dir.parent, os.W_OK)
):
adetailer_dir.mkdir()
txt2img_submit_button = img2img_submit_button = None
txt2img_submit_button = cast(gr.Button, txt2img_submit_button)
img2img_submit_button = cast(gr.Button, img2img_submit_button)
print(
f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"