diff --git a/CHANGELOG.md b/CHANGELOG.md index 44daffc..f92f104 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 2023-07-02 + +- v23.7.0 +- `NansException`이 발생하면 로그에 표시하고 원본 이미지를 반환하게 설정 +- `rich`를 사용한 에러 트레이싱 + - install.py에 `rich` 추가 +- 생성 중에 컴포넌트의 값을 변경하면 args의 값도 함께 변경되는 문제 수정 (issue #180) + ## 2023-06-28 - v23.6.4 diff --git a/adetailer/__version__.py b/adetailer/__version__.py index 5da8775..f4744b8 100644 --- a/adetailer/__version__.py +++ b/adetailer/__version__.py @@ -1 +1 @@ -__version__ = "23.6.4" +__version__ = "23.7.0" diff --git a/adetailer/common.py b/adetailer/common.py index 08477ed..481adbb 100644 --- a/adetailer/common.py +++ b/adetailer/common.py @@ -7,6 +7,7 @@ from typing import Optional, Union from huggingface_hub import hf_hub_download from PIL import Image, ImageDraw +from rich import print repo_id = "Bingsu/adetailer" @@ -22,7 +23,7 @@ def hf_download(file: str): try: path = hf_hub_download(repo_id, file) except Exception: - msg = f"[-] ADetailer: Failed to load model {file!r}" + msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface" print(msg) path = "INVALID" return path diff --git a/adetailer/mediapipe.py b/adetailer/mediapipe.py index 7f949d3..0ec6b16 100644 --- a/adetailer/mediapipe.py +++ b/adetailer/mediapipe.py @@ -20,7 +20,7 @@ def mediapipe_predict( if model_type in mapping: func = mapping[model_type] return func(image, confidence) - msg = f"[-] ADetailer: Invalid mediapipe model type: {model_type}" + msg = f"[-] ADetailer: Invalid mediapipe model type: {model_type}, Available: {list(mapping.keys())!r}" raise RuntimeError(msg) diff --git a/adetailer/traceback.py b/adetailer/traceback.py new file mode 100644 index 0000000..46dc85f --- /dev/null +++ b/adetailer/traceback.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import io +import platform +import sys +from typing import Any, Callable + +from rich.console import Console, Group +from rich.panel import Panel +from rich.table import Table +from rich.traceback import Traceback + +from adetailer.__version__ import __version__ + + +def processing(*args: Any) -> dict[str, Any]: + try: + from modules.processing import ( + StableDiffusionProcessingImg2Img, + StableDiffusionProcessingTxt2Img, + ) + except ImportError: + return {} + + p = None + for arg in args: + if isinstance( + arg, (StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img) + ): + p = arg + break + + if p is None: + return {} + + info = { + "prompt": p.prompt, + "negative_prompt": p.negative_prompt, + "n_iter": p.n_iter, + "batch_size": p.batch_size, + "width": p.width, + "height": p.height, + "sampler_name": p.sampler_name, + "enable_hr": getattr(p, "enable_hr", False), + "hr_upscaler": getattr(p, "hr_upscaler", ""), + } + + info.update(sd_models()) + return info + + +def sd_models() -> dict[str, str]: + try: + from modules import shared + + opts = shared.opts + except Exception: + return {} + + return { + "checkpoint": getattr(opts, "sd_model_checkpoint", "------"), + "vae": getattr(opts, "sd_vae", "------"), + "unet": getattr(opts, "sd_unet", "------"), + } + + +def ad_args(*args: Any) -> dict[str, str]: + ad_args = [ + arg + for arg in args + if isinstance(arg, dict) and arg.get("ad_model", "None") != "None" + ] + if not ad_args: + return {} + + arg0 = ad_args[0] + return { + "version": __version__, + "ad_model": arg0["ad_model"], + "ad_prompt": arg0.get("ad_prompt", ""), + "ad_negative_prompt": arg0.get("ad_negative_prompt", ""), + "ad_controlnet_model": arg0.get("ad_controlnet_model", "None"), + } + + +def sys_info() -> dict[str, Any]: + try: + import launch + + version = launch.git_tag() + commit = launch.commit_hash() + except Exception: + version = commit = "------" + + return { + "Platform": platform.platform(), + "Python": sys.version, + "Version": version, + "Commit": commit, + "Commandline": sys.argv, + } + + +def get_table(title: str, data: dict[str, Any]) -> Table: + table = Table(title=title, highlight=True) + table.add_column(" ", justify="right", style="dim") + table.add_column("Value") + for key, value in data.items(): + if not isinstance(value, str): + value = repr(value) + table.add_row(key, value) + + return table + + +def force_terminal_value(): + try: + from modules.shared import cmd_opts + + return True if hasattr(cmd_opts, "skip_torch_cuda_test") else None + except Exception: + return None + + +def rich_traceback(func: Callable) -> Callable: + force_terminal = force_terminal_value() + + def wrapper(*args, **kwargs): + string = io.StringIO() + width = Console().width + width = width - 4 if width > 4 else None + console = Console(file=string, force_terminal=force_terminal, width=width) + try: + return func(*args, **kwargs) + except Exception as e: + tables = [ + get_table(title, data) + for title, data in [ + ("System info", sys_info()), + ("Inputs", processing(*args)), + ("ADetailer", ad_args(*args)), + ] + if data + ] + tables.append(Traceback()) + + console.print(Panel(Group(*tables))) + output = "\n" + string.getvalue() + + try: + error = e.__class__(output) + except Exception: + error = RuntimeError(output) + raise error from None + + return wrapper diff --git a/adetailer/ui.py b/adetailer/ui.py index 5547578..865ecd1 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -175,11 +175,6 @@ def one_ui_group( with gr.Group(): controlnet(w, n, is_img2img) - for attr in ALL_ARGS.attrs: - widget = getattr(w, attr) - on_change = partial(on_widget_change, attr=attr) - widget.change(fn=on_change, inputs=[state, widget], outputs=state, queue=False) - all_inputs = [state, *w.tolist()] target_button = i2i_button if is_img2img else t2i_button target_button.click( diff --git a/install.py b/install.py index 7193feb..25bfba2 100644 --- a/install.py +++ b/install.py @@ -48,6 +48,7 @@ def install(): ("mediapipe", "0.10.0", None), ("huggingface_hub", None, None), ("pydantic", "1.10.8", None), + ("rich", "13.4.2", None), # mediapipe ("protobuf", "3.20.0", "3.20.9999"), ] diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 753f632..461b9cf 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -5,7 +5,7 @@ import platform import re import sys import traceback -from contextlib import contextmanager, suppress +from contextlib import contextmanager from copy import copy, deepcopy from functools import partial from pathlib import Path @@ -14,6 +14,7 @@ from typing import Any import gradio as gr import torch +from rich import print import modules from adetailer import ( @@ -26,6 +27,7 @@ from adetailer import ( from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, EnableChecker from adetailer.common import PredictOutput from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes +from adetailer.traceback import rich_traceback from adetailer.ui import adui, ordinal, suffix from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models from controlnet_ext.restore import ( @@ -34,6 +36,7 @@ from controlnet_ext.restore import ( cn_restore_unet_hook, ) from sd_webui import images, safe, script_callbacks, scripts, shared +from sd_webui.devices import NansException from sd_webui.paths import data_path, models_path from sd_webui.processing import ( StableDiffusionProcessingImg2Img, @@ -42,10 +45,6 @@ from sd_webui.processing import ( ) from sd_webui.shared import cmd_opts, opts, state -with suppress(ImportError): - from rich import print - - no_huggingface = getattr(cmd_opts, "ad_no_huggingface", False) adetailer_dir = Path(models_path, "adetailer") model_mapping = get_models(adetailer_dir, huggingface=not no_huggingface) @@ -93,6 +92,9 @@ class AfterDetailerScript(scripts.Script): self.cn_script = None self.cn_latest_network = None + def __repr__(self): + return f"{self.__class__.__name__}(version={__version__})" + def title(self): return AFTER_DETAILER @@ -449,12 +451,25 @@ class AfterDetailerScript(scripts.Script): i2i.prompt = prompt i2i.negative_prompt = negative_prompt + @staticmethod + def compare_prompt(p, processed, n: int = 0): + if p.prompt != processed.all_prompts[0]: + print( + f"[-] ADetailer: applied {ordinal(n + 1)} ad_prompt: {processed.all_prompts[0]!r}" + ) + + if p.negative_prompt != processed.all_negative_prompts[0]: + print( + f"[-] ADetailer: applied {ordinal(n + 1)} ad_negative_prompt: {processed.all_negative_prompts[0]!r}" + ) + def is_need_call_process(self, p) -> bool: i = p._idx n_iter = p.iteration bs = p.batch_size return (i == (n_iter + 1) * bs - 1) and (i != len(p.all_prompts) - 1) + @rich_traceback def process(self, p, *args_): if getattr(p, "_disable_adetailer", False): return @@ -519,6 +534,9 @@ class AfterDetailerScript(scripts.Script): if is_mediapipe: print(f"mediapipe: {steps} detected.") + _user_pt = p.prompt + _user_ng = p.negative_prompt + p2 = copy(i2i) for j in range(steps): p2.image_mask = masks[j] @@ -527,8 +545,15 @@ class AfterDetailerScript(scripts.Script): if not re.match(r"^\s*\[SKIP\]\s*$", p2.prompt): if args.ad_controlnet_model == "None": cn_restore_unet_hook(p2, self.cn_latest_network) - processed = process_images(p2) + try: + processed = process_images(p2) + except NansException as e: + msg = f"[-] ADetailer: 'NansException' occurred with {ordinal(n + 1)} settings.\n{e}" + print(msg, file=sys.stderr) + return False + + self.compare_prompt(p2, processed, n=n) p2 = copy(i2i) p2.init_images = [processed.images[0]] @@ -541,6 +566,7 @@ class AfterDetailerScript(scripts.Script): return False + @rich_traceback def postprocess_image(self, p, pp, *args_): if getattr(p, "_disable_adetailer", False): return diff --git a/sd_webui/devices.py b/sd_webui/devices.py new file mode 100644 index 0000000..51d0569 --- /dev/null +++ b/sd_webui/devices.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + + class NansException(Exception): # noqa: N818 + pass + +else: + from modules.devices import NansException diff --git a/sd_webui/script_callbacks.py b/sd_webui/script_callbacks.py index 06f99fa..ebb3ac0 100644 --- a/sd_webui/script_callbacks.py +++ b/sd_webui/script_callbacks.py @@ -5,6 +5,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable + def on_app_started(callback: Callable): + pass + def on_ui_settings(callback: Callable): pass @@ -17,6 +20,7 @@ if TYPE_CHECKING: else: from modules.script_callbacks import ( on_after_component, + on_app_started, on_before_ui, on_ui_settings, )