diff --git a/CHANGELOG.md b/CHANGELOG.md index a468a82..d1e485a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## 2023-05-26 + +- v23.5.19 +- 1번째 탭에도 `None` 옵션을 추가함 +- api로 ad controlnet model에 inpaint가 아닌 다른 컨트롤넷 모델을 사용하지 못하도록 막음 +- adetailer 진행중에 total tqdm 진행바 업데이트를 멈춤 +- state.interrupted 상태에서 adetailer 과정을 중지함 +- 컨트롤넷 process를 각 batch가 끝난 순간에만 호출하도록 변경 + ### 2023-05-25 - v23.5.18 diff --git a/adetailer/__version__.py b/adetailer/__version__.py index 220c1d2..fce398f 100644 --- a/adetailer/__version__.py +++ b/adetailer/__version__.py @@ -1 +1 @@ -__version__ = "23.5.18" +__version__ = "23.5.19" diff --git a/adetailer/args.py b/adetailer/args.py index 394731e..fabbef1 100644 --- a/adetailer/args.py +++ b/adetailer/args.py @@ -12,6 +12,7 @@ from pydantic import ( NonNegativeInt, PositiveInt, confloat, + constr, validator, ) @@ -54,7 +55,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid): ad_use_cfg_scale: bool = False ad_cfg_scale: NonNegativeFloat = 7.0 ad_restore_face: bool = False - ad_controlnet_model: str = "None" + ad_controlnet_model: constr(regex=r".*inpaint.*|^None$") = "None" ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0 @validator("ad_conf", pre=True) diff --git a/adetailer/ui.py b/adetailer/ui.py index 359523e..8fccaf3 100644 --- a/adetailer/ui.py +++ b/adetailer/ui.py @@ -105,7 +105,7 @@ def one_ui_group( eid = partial(elem_id, n=n, is_img2img=is_img2img) with gr.Row(): - model_choices = model_list if n == 0 else ["None"] + model_list + model_choices = model_list + ["None"] if n == 0 else ["None"] + model_list w.ad_model = gr.Dropdown( label="ADetailer model" + suffix(n), diff --git a/controlnet_ext/restore.py b/controlnet_ext/restore.py index c218e07..5b9bfa6 100644 --- a/controlnet_ext/restore.py +++ b/controlnet_ext/restore.py @@ -1,6 +1,8 @@ from __future__ import annotations -from modules import img2img, processing +from contextlib import contextmanager + +from modules import img2img, processing, shared 
def cn_restore_unet_hook(p, cn_latest_network): @@ -31,3 +33,17 @@ class CNHijackRestore: processing.process_images_inner = self.orig_process if self.img2img: img2img.process_batch = self.orig_img2img + + +@contextmanager +def cn_allow_script_control(): + orig = False + if "control_net_allow_script_control" in shared.opts.data: + try: + orig = shared.opts.data["control_net_allow_script_control"] + shared.opts.data["control_net_allow_script_control"] = True + yield + finally: + shared.opts.data["control_net_allow_script_control"] = orig + else: + yield diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index d5dda67..754a644 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -5,6 +5,7 @@ import platform import re import sys import traceback +from contextlib import contextmanager from copy import copy, deepcopy from pathlib import Path from textwrap import dedent @@ -26,7 +27,11 @@ from adetailer.common import PredictOutput from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes from adetailer.ui import adui, ordinal, suffix from controlnet_ext import ControlNetExt, controlnet_exists -from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook +from controlnet_ext.restore import ( + CNHijackRestore, + cn_allow_script_control, + cn_restore_unet_hook, +) from sd_webui import images, safe, script_callbacks, scripts, shared from sd_webui.paths import data_path, models_path from sd_webui.processing import ( @@ -34,7 +39,7 @@ from sd_webui.processing import ( create_infotext, process_images, ) -from sd_webui.shared import cmd_opts, opts +from sd_webui.shared import cmd_opts, opts, state try: from rich import print @@ -62,13 +67,24 @@ print( ) -class ChangeTorchLoad: - def __enter__(self): - self.orig = torch.load +@contextmanager +def change_torch_load(): + orig = torch.load + try: torch.load = safe.unsafe_torch_load + yield + finally: + torch.load = orig - def __exit__(self, *args, **kwargs): - torch.load = self.orig + 
+@contextmanager +def pause_total_tqdm(): + orig = opts.data.get("multiple_tqdm", True) + try: + opts.data["multiple_tqdm"] = False + yield + finally: + opts.data["multiple_tqdm"] = orig class AfterDetailerScript(scripts.Script): @@ -311,6 +327,9 @@ class AfterDetailerScript(scripts.Script): if hasattr(obj, "input_mode"): obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple") + elif isinstance(obj, dict) and "module" in obj: + obj["enabled"] = False + def get_i2i_p(self, p, args: ADetailerArgs, image): seed, subseed = self.get_seed(p) width, height = self.get_width_height(p, args) @@ -361,6 +380,9 @@ class AfterDetailerScript(scripts.Script): if args.ad_controlnet_model != "None": self.update_controlnet_args(i2i, args) + else: + i2i.control_net_enabled = False + return i2i def save_image(self, p, image, *, condition: str, suffix: str) -> None: @@ -415,6 +437,12 @@ class AfterDetailerScript(scripts.Script): i2i.prompt = prompt i2i.negative_prompt = negative_prompt + def is_need_call_process(self, p): + i = p._idx + n_iter = p.iteration + bs = p.batch_size + return (i == (n_iter + 1) * bs - 1) and (i != len(p.all_prompts) - 1) + def process(self, p, *args_): if getattr(p, "_disable_adetailer", False): return @@ -424,6 +452,8 @@ class AfterDetailerScript(scripts.Script): extra_params = self.extra_params(arg_list) p.extra_generation_params.update(extra_params) + p._idx = -1 + def _postprocess_image(self, p, pp, args: ADetailerArgs, *, n: int = 0) -> bool: """ Returns @@ -432,6 +462,9 @@ class AfterDetailerScript(scripts.Script): `True` if image was processed, `False` otherwise. 
""" + if state.interrupted: + return False + i = p._idx i2i = self.get_i2i_p(p, args, pp.image) @@ -449,7 +482,7 @@ class AfterDetailerScript(scripts.Script): ad_model = self.get_ad_model(args.ad_model) kwargs["device"] = self.ultralytics_device - with ChangeTorchLoad(): + with change_torch_load(): pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs) masks = self.pred_preprocessing(pred, args) @@ -469,6 +502,7 @@ class AfterDetailerScript(scripts.Script): steps = len(masks) processed = None + state.job_count += steps if is_mediapipe: print(f"mediapipe: {steps} detected.") @@ -507,10 +541,10 @@ class AfterDetailerScript(scripts.Script): arg_list = self.get_args(*args_) is_processed = False - for n, args in enumerate(arg_list): - if args.ad_model == "None": - continue - with CNHijackRestore(): + with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control(): + for n, args in enumerate(arg_list): + if args.ad_model == "None": + continue is_processed |= self._postprocess_image(p, pp, args, n=n) if is_processed: @@ -518,8 +552,8 @@ class AfterDetailerScript(scripts.Script): p, init_image, condition="ad_save_images_before", suffix="-ad-before" ) - if self.cn_script is not None: - self.cn_script.process(p) + if self.cn_script is not None and self.is_need_call_process(p): + self.cn_script.process(p) try: if p._idx == len(p.all_prompts) - 1: diff --git a/sd_webui/shared.py b/sd_webui/shared.py index e71260a..18b0cd0 100644 --- a/sd_webui/shared.py +++ b/sd_webui/shared.py @@ -7,6 +7,29 @@ if TYPE_CHECKING: from dataclasses import dataclass from typing import Any, Callable + import torch + from PIL import Image + + @dataclass + class State: + skipped: bool = False + interrupted: bool = False + job: str = "" + job_no: int = 0 + job_count: int = 0 + processing_has_refined_job_count: bool = False + job_timestamp: str = "0" + sampling_step: int = 0 + sampling_steps: int = 0 + current_latent: torch.Tensor | None = None + current_image: Image.Image | None = None + 
current_image_sampling_step: int = 0 + id_live_preview: int = 0 + textinfo: str | None = None + time_start: float | None = None + need_restart: bool = False + server_start: float | None = None + @dataclass class OptionInfo: default: Any = None @@ -37,6 +60,7 @@ if TYPE_CHECKING: opts = Option() cmd_opts = argparse.Namespace() + state = State() else: - from modules.shared import OptionInfo, cmd_opts, opts + from modules.shared import OptionInfo, cmd_opts, opts, state