mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-02 03:49:59 +00:00
Merge branch 'dev' into main
This commit is contained in:
@@ -1,5 +1,14 @@
|
||||
# Changelog
|
||||
|
||||
## 2023-05-26
|
||||
|
||||
- v23.5.19
|
||||
- 1번째 탭에도 `None` 옵션을 추가함
|
||||
- api로 ad controlnet model에 inpaint가 아닌 다른 컨트롤넷 모델을 사용하지 못하도록 막음
|
||||
- adetailer 진행중에 total tqdm 진행바 업데이트를 멈춤
|
||||
- state.inturrupted 상태에서 adetailer 과정을 중지함
|
||||
- 컨트롤넷 process를 각 batch가 끝난 순간에만 호출하도록 변경
|
||||
|
||||
### 2023-05-25
|
||||
|
||||
- v23.5.18
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "23.5.18"
|
||||
__version__ = "23.5.19"
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import (
|
||||
NonNegativeInt,
|
||||
PositiveInt,
|
||||
confloat,
|
||||
constr,
|
||||
validator,
|
||||
)
|
||||
|
||||
@@ -54,7 +55,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||
ad_use_cfg_scale: bool = False
|
||||
ad_cfg_scale: NonNegativeFloat = 7.0
|
||||
ad_restore_face: bool = False
|
||||
ad_controlnet_model: str = "None"
|
||||
ad_controlnet_model: constr(regex=r".*inpaint.*|^None$") = "None"
|
||||
ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
|
||||
|
||||
@validator("ad_conf", pre=True)
|
||||
|
||||
@@ -105,7 +105,7 @@ def one_ui_group(
|
||||
eid = partial(elem_id, n=n, is_img2img=is_img2img)
|
||||
|
||||
with gr.Row():
|
||||
model_choices = model_list if n == 0 else ["None"] + model_list
|
||||
model_choices = model_list + ["None"] if n == 0 else ["None"] + model_list
|
||||
|
||||
w.ad_model = gr.Dropdown(
|
||||
label="ADetailer model" + suffix(n),
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from modules import img2img, processing
|
||||
from contextlib import contextmanager
|
||||
|
||||
from modules import img2img, processing, shared
|
||||
|
||||
|
||||
def cn_restore_unet_hook(p, cn_latest_network):
|
||||
@@ -31,3 +33,17 @@ class CNHijackRestore:
|
||||
processing.process_images_inner = self.orig_process
|
||||
if self.img2img:
|
||||
img2img.process_batch = self.orig_img2img
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cn_allow_script_control():
|
||||
orig = False
|
||||
if "control_net_allow_script_control" in shared.opts.data:
|
||||
try:
|
||||
orig = shared.opts.data["control_net_allow_script_control"]
|
||||
shared.opts.data["control_net_allow_script_control"] = True
|
||||
yield
|
||||
finally:
|
||||
shared.opts.data["control_net_allow_script_control"] = orig
|
||||
else:
|
||||
yield
|
||||
|
||||
@@ -5,6 +5,7 @@ import platform
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
@@ -26,7 +27,11 @@ from adetailer.common import PredictOutput
|
||||
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
|
||||
from adetailer.ui import adui, ordinal, suffix
|
||||
from controlnet_ext import ControlNetExt, controlnet_exists
|
||||
from controlnet_ext.restore import CNHijackRestore, cn_restore_unet_hook
|
||||
from controlnet_ext.restore import (
|
||||
CNHijackRestore,
|
||||
cn_allow_script_control,
|
||||
cn_restore_unet_hook,
|
||||
)
|
||||
from sd_webui import images, safe, script_callbacks, scripts, shared
|
||||
from sd_webui.paths import data_path, models_path
|
||||
from sd_webui.processing import (
|
||||
@@ -34,7 +39,7 @@ from sd_webui.processing import (
|
||||
create_infotext,
|
||||
process_images,
|
||||
)
|
||||
from sd_webui.shared import cmd_opts, opts
|
||||
from sd_webui.shared import cmd_opts, opts, state
|
||||
|
||||
try:
|
||||
from rich import print
|
||||
@@ -62,13 +67,24 @@ print(
|
||||
)
|
||||
|
||||
|
||||
class ChangeTorchLoad:
|
||||
def __enter__(self):
|
||||
self.orig = torch.load
|
||||
@contextmanager
|
||||
def change_torch_load():
|
||||
orig = torch.load
|
||||
try:
|
||||
torch.load = safe.unsafe_torch_load
|
||||
yield
|
||||
finally:
|
||||
torch.load = orig
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
torch.load = self.orig
|
||||
|
||||
@contextmanager
|
||||
def pause_total_tqdm():
|
||||
orig = opts.data.get("multiple_tqdm", True)
|
||||
try:
|
||||
opts.data["multiple_tqdm"] = False
|
||||
yield
|
||||
finally:
|
||||
opts.data["multiple_tqdm"] = orig
|
||||
|
||||
|
||||
class AfterDetailerScript(scripts.Script):
|
||||
@@ -311,6 +327,9 @@ class AfterDetailerScript(scripts.Script):
|
||||
if hasattr(obj, "input_mode"):
|
||||
obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
|
||||
|
||||
elif isinstance(obj, dict) and "module" in obj:
|
||||
obj["enabled"] = False
|
||||
|
||||
def get_i2i_p(self, p, args: ADetailerArgs, image):
|
||||
seed, subseed = self.get_seed(p)
|
||||
width, height = self.get_width_height(p, args)
|
||||
@@ -361,6 +380,9 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
if args.ad_controlnet_model != "None":
|
||||
self.update_controlnet_args(i2i, args)
|
||||
else:
|
||||
i2i.control_net_enabled = False
|
||||
|
||||
return i2i
|
||||
|
||||
def save_image(self, p, image, *, condition: str, suffix: str) -> None:
|
||||
@@ -415,6 +437,12 @@ class AfterDetailerScript(scripts.Script):
|
||||
i2i.prompt = prompt
|
||||
i2i.negative_prompt = negative_prompt
|
||||
|
||||
def is_need_call_process(self, p):
|
||||
i = p._idx
|
||||
n_iter = p.iteration
|
||||
bs = p.batch_size
|
||||
return (i == (n_iter + 1) * bs - 1) and (i != len(p.all_prompts) - 1)
|
||||
|
||||
def process(self, p, *args_):
|
||||
if getattr(p, "_disable_adetailer", False):
|
||||
return
|
||||
@@ -424,6 +452,8 @@ class AfterDetailerScript(scripts.Script):
|
||||
extra_params = self.extra_params(arg_list)
|
||||
p.extra_generation_params.update(extra_params)
|
||||
|
||||
p._idx = -1
|
||||
|
||||
def _postprocess_image(self, p, pp, args: ADetailerArgs, *, n: int = 0) -> bool:
|
||||
"""
|
||||
Returns
|
||||
@@ -432,6 +462,9 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
`True` if image was processed, `False` otherwise.
|
||||
"""
|
||||
if state.interrupted:
|
||||
return False
|
||||
|
||||
i = p._idx
|
||||
|
||||
i2i = self.get_i2i_p(p, args, pp.image)
|
||||
@@ -449,7 +482,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
ad_model = self.get_ad_model(args.ad_model)
|
||||
kwargs["device"] = self.ultralytics_device
|
||||
|
||||
with ChangeTorchLoad():
|
||||
with change_torch_load():
|
||||
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)
|
||||
|
||||
masks = self.pred_preprocessing(pred, args)
|
||||
@@ -469,6 +502,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
|
||||
steps = len(masks)
|
||||
processed = None
|
||||
state.job_count += steps
|
||||
|
||||
if is_mediapipe:
|
||||
print(f"mediapipe: {steps} detected.")
|
||||
@@ -507,10 +541,10 @@ class AfterDetailerScript(scripts.Script):
|
||||
arg_list = self.get_args(*args_)
|
||||
|
||||
is_processed = False
|
||||
for n, args in enumerate(arg_list):
|
||||
if args.ad_model == "None":
|
||||
continue
|
||||
with CNHijackRestore():
|
||||
with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
|
||||
for n, args in enumerate(arg_list):
|
||||
if args.ad_model == "None":
|
||||
continue
|
||||
is_processed |= self._postprocess_image(p, pp, args, n=n)
|
||||
|
||||
if is_processed:
|
||||
@@ -518,8 +552,8 @@ class AfterDetailerScript(scripts.Script):
|
||||
p, init_image, condition="ad_save_images_before", suffix="-ad-before"
|
||||
)
|
||||
|
||||
if self.cn_script is not None:
|
||||
self.cn_script.process(p)
|
||||
if self.cn_script is not None and self.is_need_call_process(p):
|
||||
self.cn_script.process(p)
|
||||
|
||||
try:
|
||||
if p._idx == len(p.all_prompts) - 1:
|
||||
|
||||
@@ -7,6 +7,29 @@ if TYPE_CHECKING:
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
skipped: bool = False
|
||||
interrupted: bool = False
|
||||
job: str = ""
|
||||
job_no: int = 0
|
||||
job_count: int = 0
|
||||
processing_has_refined_job_count: bool = False
|
||||
job_timestamp: str = "0"
|
||||
sampling_step: int = 0
|
||||
sampling_steps: int = 0
|
||||
current_latent: torch.Tensor | None = None
|
||||
current_image: Image.Image | None = None
|
||||
current_image_sampling_step: int = 0
|
||||
id_live_preview: int = 0
|
||||
textinfo: str | None = None
|
||||
time_start: float | None = None
|
||||
need_restart: bool = False
|
||||
server_start: float | None = None
|
||||
|
||||
@dataclass
|
||||
class OptionInfo:
|
||||
default: Any = None
|
||||
@@ -37,6 +60,7 @@ if TYPE_CHECKING:
|
||||
|
||||
opts = Option()
|
||||
cmd_opts = argparse.Namespace()
|
||||
state = State()
|
||||
|
||||
else:
|
||||
from modules.shared import OptionInfo, cmd_opts, opts
|
||||
from modules.shared import OptionInfo, cmd_opts, opts, state
|
||||
|
||||
Reference in New Issue
Block a user