feat: rich print, traceback

This commit is contained in:
Bingsu
2023-07-01 09:45:45 +09:00
parent 30330aa4d0
commit eec616afa3
4 changed files with 27 additions and 7 deletions

View File

@@ -7,6 +7,7 @@ from typing import Optional, Union
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from rich import print
repo_id = "Bingsu/adetailer"
@@ -22,7 +23,7 @@ def hf_download(file: str):
try:
path = hf_hub_download(repo_id, file)
except Exception:
msg = f"[-] ADetailer: Failed to load model {file!r}"
msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
print(msg)
path = "INVALID"
return path

View File

@@ -20,7 +20,7 @@ def mediapipe_predict(
if model_type in mapping:
func = mapping[model_type]
return func(image, confidence)
msg = f"[-] ADetailer: Invalid mediapipe model type: {model_type}"
msg = f"[-] ADetailer: Invalid mediapipe model type: {model_type}, Available: {list(mapping.keys())!r}"
raise RuntimeError(msg)

View File

@@ -1,11 +1,12 @@
from __future__ import annotations
import io
import os
import platform
import re
import sys
import traceback
from contextlib import contextmanager, suppress
from contextlib import contextmanager
from copy import copy, deepcopy
from functools import partial
from pathlib import Path
@@ -14,6 +15,8 @@ from typing import Any
import gradio as gr
import torch
from rich import print
from rich.console import Console
import modules
from adetailer import (
@@ -42,10 +45,6 @@ from sd_webui.processing import (
)
from sd_webui.shared import cmd_opts, opts, state
with suppress(ImportError):
from rich import print
no_huggingface = getattr(cmd_opts, "ad_no_huggingface", False)
adetailer_dir = Path(models_path, "adetailer")
model_mapping = get_models(adetailer_dir, huggingface=not no_huggingface)
@@ -84,6 +83,18 @@ def pause_total_tqdm():
opts.data["multiple_tqdm"] = orig
@contextmanager
def rich_traceback():
string = io.StringIO()
console = Console(file=string, force_terminal=True)
try:
yield
except Exception as e:
console.print_exception(show_locals=True)
output = "\n" + string.getvalue()
raise RuntimeError(output) from e
class AfterDetailerScript(scripts.Script):
def __init__(self):
super().__init__()
@@ -519,6 +530,9 @@ class AfterDetailerScript(scripts.Script):
if is_mediapipe:
print(f"mediapipe: {steps} detected.")
_user_pt = p.prompt
_user_ng = p.negative_prompt
p2 = copy(i2i)
for j in range(steps):
p2.image_mask = masks[j]
@@ -541,6 +555,7 @@ class AfterDetailerScript(scripts.Script):
return False
@rich_traceback()
def postprocess_image(self, p, pp, *args_):
if getattr(p, "_disable_adetailer", False):
return

View File

@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable
def on_app_started(callback: Callable):
pass
def on_ui_settings(callback: Callable):
pass
@@ -17,6 +20,7 @@ if TYPE_CHECKING:
else:
from modules.script_callbacks import (
on_after_component,
on_app_started,
on_before_ui,
on_ui_settings,
)