feat(scripts): add scheduler option

This commit is contained in:
Dowon
2024-04-14 01:21:03 +09:00
parent a40c4c77db
commit 38e369305e
3 changed files with 51 additions and 3 deletions

View File

@@ -82,6 +82,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_vae: Optional[str] = None
ad_use_sampler: bool = False
ad_sampler: str = "DPM++ 2M Karras"
ad_scheduler: str = "Use same scheduler"
ad_use_noise_multiplier: bool = False
ad_noise_multiplier: confloat(ge=0.5, le=1.5) = 1.0
ad_use_clip_skip: bool = False
@@ -160,8 +161,13 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
)
ppop(
"ADetailer use separate sampler",
["ADetailer use separate sampler", "ADetailer sampler"],
[
"ADetailer use separate sampler",
"ADetailer sampler",
"ADetailer scheduler",
],
)
ppop("ADetailer scheduler", cond="Use same scheduler")
ppop(
"ADetailer use separate noise multiplier",
["ADetailer use separate noise multiplier", "ADetailer noise multiplier"],
@@ -225,6 +231,7 @@ _all_args = [
("ad_vae", "ADetailer VAE"),
("ad_use_sampler", "ADetailer use separate sampler"),
("ad_sampler", "ADetailer sampler"),
("ad_scheduler", "ADetailer scheduler"),
("ad_use_noise_multiplier", "ADetailer use separate noise multiplier"),
("ad_noise_multiplier", "ADetailer noise multiplier"),
("ad_use_clip_skip", "ADetailer use separate CLIP skip"),

View File

@@ -51,6 +51,7 @@ class Widgets(SimpleNamespace):
class WebuiInfo:
ad_model_list: list[str]
sampler_names: list[str]
scheduler_names: list[str]
t2i_button: gr.Button
i2i_button: gr.Button
checkpoints_list: list[str]
@@ -545,10 +546,23 @@ def inpainting(w: Widgets, n: int, is_img2img: bool, webui_info: WebuiInfo):
elem_id=eid("ad_sampler"),
)
scheduler_names = [
"Use same scheduler",
"Automatic",
*webui_info.scheduler_names,
]
w.ad_scheduler = gr.Dropdown(
label="ADetailer scheduler" + suffix(n),
choices=webui_info.scheduler_names,
value=webui_info.scheduler_names[0],
visible=len(scheduler_names) > 2,
elem_id=eid("ad_scheduler"),
)
w.ad_use_sampler.change(
gr_interactive,
lambda value: (gr_interactive(value), gr_interactive(value)),
inputs=w.ad_use_sampler,
outputs=w.ad_sampler,
outputs=[w.ad_sampler, w.ad_scheduler],
queue=False,
)

View File

@@ -75,6 +75,15 @@ except ImportError:
return image.convert("L")
try:
from modules.sd_schedulers import sd_schedulers
scheduler_available = True
except ImportError:
sd_schedulers = []
scheduler_available = False
if TYPE_CHECKING:
from fastapi import FastAPI
@@ -118,6 +127,7 @@ class AfterDetailerScript(scripts.Script):
num_models = opts.data.get("ad_max_models", 2)
ad_model_list = list(model_mapping.keys())
sampler_names = [sampler.name for sampler in all_samplers]
scheduler_names = [x.label for x in sd_schedulers]
try:
checkpoint_list = modules.sd_models.checkpoint_tiles(use_shorts=True)
@@ -128,6 +138,7 @@ class AfterDetailerScript(scripts.Script):
webui_info = WebuiInfo(
ad_model_list=ad_model_list,
sampler_names=sampler_names,
scheduler_names=scheduler_names,
t2i_button=txt2img_submit_button,
i2i_button=img2img_submit_button,
checkpoints_list=checkpoint_list,
@@ -373,6 +384,17 @@ class AfterDetailerScript(scripts.Script):
return p._ad_orig.sampler_name
return p.sampler_name
def get_scheduler(self, p, args: ADetailerArgs) -> dict[str, str]:
"webui >= 1.9.0"
if not args.ad_use_sampler:
return {}
if args.ad_scheduler == "Use same scheduler":
value = getattr(p, "scheduler", "Automatic")
else:
value = args.ad_scheduler
return {"scheduler": value}
def get_override_settings(self, p, args: ADetailerArgs) -> dict[str, Any]:
d = {}
@@ -470,6 +492,10 @@ class AfterDetailerScript(scripts.Script):
sampler_name = self.get_sampler(p, args)
override_settings = self.get_override_settings(p, args)
version_args = {}
if scheduler_available:
version_args.update(self.get_scheduler(p, args))
i2i = StableDiffusionProcessingImg2Img(
init_images=[image],
resize_mode=0,
@@ -505,6 +531,7 @@ class AfterDetailerScript(scripts.Script):
do_not_save_samples=True,
do_not_save_grid=True,
override_settings=override_settings,
**version_args,
)
i2i.cached_c = [None, None]