feat: save image before ad, args length check

This commit is contained in:
Bingsu
2023-05-05 08:30:32 +09:00
parent ccdb62d11c
commit 40d1373de9

View File

@@ -4,6 +4,7 @@ import platform
import sys import sys
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
from textwrap import dedent
import gradio as gr import gradio as gr
import torch import torch
@@ -409,6 +410,21 @@ class AfterDetailerScript(scripts.Script):
self.update_controlnet_args(i2i, args) self.update_controlnet_args(i2i, args)
return i2i return i2i
def save_image(self, p, image, seed, *, condition: str, suffix: str):
i = p._idx
if opts.data.get(condition, False):
images.save_image(
image=image,
path=p.outpath_samples,
basename="",
seed=seed,
prompt=p.all_prompts[i] if i < len(p.all_prompts) else p.prompt,
extension=opts.samples_format,
info=self.infotext(p),
p=p,
suffix=suffix,
)
def get_ad_model(self, name: str): def get_ad_model(self, name: str):
if name not in model_mapping: if name not in model_mapping:
msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}" msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}"
@@ -426,6 +442,9 @@ class AfterDetailerScript(scripts.Script):
) )
def process(self, p, *args_): def process(self, p, *args_):
if getattr(p, "_disable_adetailer", False):
return
args = get_args(*args_) args = get_args(*args_)
if self.is_ad_enabled(args): if self.is_ad_enabled(args):
extra_params = self.extra_params(args) extra_params = self.extra_params(args)
@@ -435,6 +454,15 @@ class AfterDetailerScript(scripts.Script):
if getattr(p, "_disable_adetailer", False): if getattr(p, "_disable_adetailer", False):
return return
if len(args_) != len(ALL_ARGS):
message = f"""
[-] ADetailer: len(args)({len(args_)}) != len(ALL_ARGS)({len(ALL_ARGS)})
Something went wrong. Please reload this extension.
"""
print(dedent(message), file=sys.stderr)
p._disable_adetailer = True
return
args = get_args(*args_) args = get_args(*args_)
if not self.is_ad_enabled(args): if not self.is_ad_enabled(args):
@@ -446,6 +474,10 @@ class AfterDetailerScript(scripts.Script):
i2i = self.get_i2i_p(p, args, pp.image) i2i = self.get_i2i_p(p, args, pp.image)
seed, subseed = self.get_seed(p) seed, subseed = self.get_seed(p)
self.save_image(
p, pp.image, seed, condition="ad_save_images_before", suffix="-ad-before"
)
is_mediapipe = args.ad_model.lower().startswith("mediapipe") is_mediapipe = args.ad_model.lower().startswith("mediapipe")
kwargs = {} kwargs = {}
@@ -466,18 +498,9 @@ class AfterDetailerScript(scripts.Script):
) )
return return
if opts.data.get("ad_save_previews", False): self.save_image(
images.save_image( p, pred.preview, seed, condition="ad_save_previews", suffix="-ad-preview"
image=pred.preview, )
path=p.outpath_samples,
basename="",
seed=seed,
prompt=p.all_prompts[i],
extension=opts.samples_format,
info=self.infotext(p),
p=p,
suffix="-ad-preview",
)
masks = pred.masks masks = pred.masks
steps = len(masks) steps = len(masks)
@@ -519,5 +542,10 @@ def on_ui_settings():
shared.OptionInfo(False, "Save mask previews", section=section), shared.OptionInfo(False, "Save mask previews", section=section),
) )
shared.opts.add_option(
"ad_save_images_before",
shared.OptionInfo(False, "Save images before ADetailer", section=section),
)
script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_settings(on_ui_settings)