Merge branch 'dev' into main

This commit is contained in:
Bingsu
2023-05-26 14:57:05 +09:00
7 changed files with 103 additions and 19 deletions

View File

@@ -1,5 +1,14 @@
# Changelog
## 2023-05-26
- v23.5.19
- 1번째 탭에도 `None` 옵션을 추가함
- api로 ad controlnet model에 inpaint가 아닌 다른 컨트롤넷 모델을 사용하지 못하도록 막음
- adetailer 진행중에 total tqdm 진행바 업데이트를 멈춤
- state.inturrupted 상태에서 adetailer 과정을 중지함
- 컨트롤넷 process를 각 batch가 끝난 순간에만 호출하도록 변경
### 2023-05-25
- v23.5.18

View File

@@ -1 +1 @@
__version__ = "23.5.18"
__version__ = "23.5.19"

View File

@@ -12,6 +12,7 @@ from pydantic import (
NonNegativeInt,
PositiveInt,
confloat,
constr,
validator,
)
@@ -54,7 +55,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_use_cfg_scale: bool = False
ad_cfg_scale: NonNegativeFloat = 7.0
ad_restore_face: bool = False
ad_controlnet_model: str = "None"
ad_controlnet_model: constr(regex=r".*inpaint.*|^None$") = "None"
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
@validator("ad_conf", pre=True)

View File

@@ -105,7 +105,7 @@ def one_ui_group(
eid = partial(elem_id, n=n, is_img2img=is_img2img)
with gr.Row():
model_choices = model_list if n == 0 else ["None"] + model_list
model_choices = model_list + ["None"] if n == 0 else ["None"] + model_list
w.ad_model = gr.Dropdown(
label="ADetailer model" + suffix(n),

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from modules import img2img, processing
from contextlib import contextmanager
from modules import img2img, processing, shared
def cn_restore_unet_hook(p, cn_latest_network):
@@ -31,3 +33,17 @@ class CNHijackRestore:
processing.process_images_inner = self.orig_process
if self.img2img:
img2img.process_batch = self.orig_img2img
@contextmanager
def cn_allow_script_control():
orig = False
if "control_net_allow_script_control" in shared.opts.data:
try:
orig = shared.opts.data["control_net_allow_script_control"]
shared.opts.data["control_net_allow_script_control"] = True
yield
finally:
shared.opts.data["control_net_allow_script_control"] = orig
else:
yield

View File

@@ -5,6 +5,7 @@ import platform
import re
import sys
import traceback
from contextlib import contextmanager
from copy import copy, deepcopy
from pathlib import Path
from textwrap import dedent
@@ -26,7 +27,11 @@ from adetailer.common import PredictOutput
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
from adetailer.ui import adui, ordinal, suffix
from controlnet_ext import ControlNetExt, controlnet_exists
from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook
from controlnet_ext.restore import (
CNHijackRestore,
cn_allow_script_control,
cn_restore_unet_hook,
)
from sd_webui import images, safe, script_callbacks, scripts, shared
from sd_webui.paths import data_path, models_path
from sd_webui.processing import (
@@ -34,7 +39,7 @@ from sd_webui.processing import (
create_infotext,
process_images,
)
from sd_webui.shared import cmd_opts, opts
from sd_webui.shared import cmd_opts, opts, state
try:
from rich import print
@@ -62,13 +67,24 @@ print(
)
class ChangeTorchLoad:
def __enter__(self):
self.orig = torch.load
@contextmanager
def change_torch_load():
orig = torch.load
try:
torch.load = safe.unsafe_torch_load
yield
finally:
torch.load = orig
def __exit__(self, *args, **kwargs):
torch.load = self.orig
@contextmanager
def pause_total_tqdm():
orig = opts.data.get("multiple_tqdm", True)
try:
opts.data["multiple_tqdm"] = False
yield
finally:
opts.data["multiple_tqdm"] = orig
class AfterDetailerScript(scripts.Script):
@@ -311,6 +327,9 @@ class AfterDetailerScript(scripts.Script):
if hasattr(obj, "input_mode"):
obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
elif isinstance(obj, dict) and "module" in obj:
obj["enabled"] = False
def get_i2i_p(self, p, args: ADetailerArgs, image):
seed, subseed = self.get_seed(p)
width, height = self.get_width_height(p, args)
@@ -361,6 +380,9 @@ class AfterDetailerScript(scripts.Script):
if args.ad_controlnet_model != "None":
self.update_controlnet_args(i2i, args)
else:
i2i.control_net_enabled = False
return i2i
def save_image(self, p, image, *, condition: str, suffix: str) -> None:
@@ -415,6 +437,12 @@ class AfterDetailerScript(scripts.Script):
i2i.prompt = prompt
i2i.negative_prompt = negative_prompt
def is_need_call_process(self, p):
i = p._idx
n_iter = p.iteration
bs = p.batch_size
return (i == (n_iter + 1) * bs - 1) and (i != len(p.all_prompts) - 1)
def process(self, p, *args_):
if getattr(p, "_disable_adetailer", False):
return
@@ -424,6 +452,8 @@ class AfterDetailerScript(scripts.Script):
extra_params = self.extra_params(arg_list)
p.extra_generation_params.update(extra_params)
p._idx = -1
def _postprocess_image(self, p, pp, args: ADetailerArgs, *, n: int = 0) -> bool:
"""
Returns
@@ -432,6 +462,9 @@ class AfterDetailerScript(scripts.Script):
`True` if image was processed, `False` otherwise.
"""
if state.interrupted:
return False
i = p._idx
i2i = self.get_i2i_p(p, args, pp.image)
@@ -449,7 +482,7 @@ class AfterDetailerScript(scripts.Script):
ad_model = self.get_ad_model(args.ad_model)
kwargs["device"] = self.ultralytics_device
with ChangeTorchLoad():
with change_torch_load():
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)
masks = self.pred_preprocessing(pred, args)
@@ -469,6 +502,7 @@ class AfterDetailerScript(scripts.Script):
steps = len(masks)
processed = None
state.job_count += steps
if is_mediapipe:
print(f"mediapipe: {steps} detected.")
@@ -507,10 +541,10 @@ class AfterDetailerScript(scripts.Script):
arg_list = self.get_args(*args_)
is_processed = False
for n, args in enumerate(arg_list):
if args.ad_model == "None":
continue
with CNHijackRestore():
with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
for n, args in enumerate(arg_list):
if args.ad_model == "None":
continue
is_processed |= self._postprocess_image(p, pp, args, n=n)
if is_processed:
@@ -518,8 +552,8 @@ class AfterDetailerScript(scripts.Script):
p, init_image, condition="ad_save_images_before", suffix="-ad-before"
)
if self.cn_script is not None:
self.cn_script.process(p)
if self.cn_script is not None and self.is_need_call_process(p):
self.cn_script.process(p)
try:
if p._idx == len(p.all_prompts) - 1:

View File

@@ -7,6 +7,29 @@ if TYPE_CHECKING:
from dataclasses import dataclass
from typing import Any, Callable
import torch
from PIL import Image
@dataclass
class State:
skipped: bool = False
interrupted: bool = False
job: str = ""
job_no: int = 0
job_count: int = 0
processing_has_refined_job_count: bool = False
job_timestamp: str = "0"
sampling_step: int = 0
sampling_steps: int = 0
current_latent: torch.Tensor | None = None
current_image: Image.Image | None = None
current_image_sampling_step: int = 0
id_live_preview: int = 0
textinfo: str | None = None
time_start: float | None = None
need_restart: bool = False
server_start: float | None = None
@dataclass
class OptionInfo:
default: Any = None
@@ -37,6 +60,7 @@ if TYPE_CHECKING:
opts = Option()
cmd_opts = argparse.Namespace()
state = State()
else:
from modules.shared import OptionInfo, cmd_opts, opts
from modules.shared import OptionInfo, cmd_opts, opts, state