feat(scripts): add scheduler option

This commit is contained in:
Dowon
2024-04-14 01:21:03 +09:00
parent a40c4c77db
commit 38e369305e
3 changed files with 51 additions and 3 deletions

View File

@@ -82,6 +82,7 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
ad_vae: Optional[str] = None ad_vae: Optional[str] = None
ad_use_sampler: bool = False ad_use_sampler: bool = False
ad_sampler: str = "DPM++ 2M Karras" ad_sampler: str = "DPM++ 2M Karras"
ad_scheduler: str = "Use same scheduler"
ad_use_noise_multiplier: bool = False ad_use_noise_multiplier: bool = False
ad_noise_multiplier: confloat(ge=0.5, le=1.5) = 1.0 ad_noise_multiplier: confloat(ge=0.5, le=1.5) = 1.0
ad_use_clip_skip: bool = False ad_use_clip_skip: bool = False
@@ -160,8 +161,13 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
) )
ppop( ppop(
"ADetailer use separate sampler", "ADetailer use separate sampler",
["ADetailer use separate sampler", "ADetailer sampler"], [
"ADetailer use separate sampler",
"ADetailer sampler",
"ADetailer scheduler",
],
) )
ppop("ADetailer scheduler", cond="Use same scheduler")
ppop( ppop(
"ADetailer use separate noise multiplier", "ADetailer use separate noise multiplier",
["ADetailer use separate noise multiplier", "ADetailer noise multiplier"], ["ADetailer use separate noise multiplier", "ADetailer noise multiplier"],
@@ -225,6 +231,7 @@ _all_args = [
("ad_vae", "ADetailer VAE"), ("ad_vae", "ADetailer VAE"),
("ad_use_sampler", "ADetailer use separate sampler"), ("ad_use_sampler", "ADetailer use separate sampler"),
("ad_sampler", "ADetailer sampler"), ("ad_sampler", "ADetailer sampler"),
("ad_scheduler", "ADetailer scheduler"),
("ad_use_noise_multiplier", "ADetailer use separate noise multiplier"), ("ad_use_noise_multiplier", "ADetailer use separate noise multiplier"),
("ad_noise_multiplier", "ADetailer noise multiplier"), ("ad_noise_multiplier", "ADetailer noise multiplier"),
("ad_use_clip_skip", "ADetailer use separate CLIP skip"), ("ad_use_clip_skip", "ADetailer use separate CLIP skip"),

View File

@@ -51,6 +51,7 @@ class Widgets(SimpleNamespace):
class WebuiInfo: class WebuiInfo:
ad_model_list: list[str] ad_model_list: list[str]
sampler_names: list[str] sampler_names: list[str]
scheduler_names: list[str]
t2i_button: gr.Button t2i_button: gr.Button
i2i_button: gr.Button i2i_button: gr.Button
checkpoints_list: list[str] checkpoints_list: list[str]
@@ -545,10 +546,23 @@ def inpainting(w: Widgets, n: int, is_img2img: bool, webui_info: WebuiInfo):
elem_id=eid("ad_sampler"), elem_id=eid("ad_sampler"),
) )
scheduler_names = [
"Use same scheduler",
"Automatic",
*webui_info.scheduler_names,
]
w.ad_scheduler = gr.Dropdown(
label="ADetailer scheduler" + suffix(n),
choices=webui_info.scheduler_names,
value=webui_info.scheduler_names[0],
visible=len(scheduler_names) > 2,
elem_id=eid("ad_scheduler"),
)
w.ad_use_sampler.change( w.ad_use_sampler.change(
gr_interactive, lambda value: (gr_interactive(value), gr_interactive(value)),
inputs=w.ad_use_sampler, inputs=w.ad_use_sampler,
outputs=w.ad_sampler, outputs=[w.ad_sampler, w.ad_scheduler],
queue=False, queue=False,
) )

View File

@@ -75,6 +75,15 @@ except ImportError:
return image.convert("L") return image.convert("L")
try:
from modules.sd_schedulers import sd_schedulers
scheduler_available = True
except ImportError:
sd_schedulers = []
scheduler_available = False
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi import FastAPI from fastapi import FastAPI
@@ -118,6 +127,7 @@ class AfterDetailerScript(scripts.Script):
num_models = opts.data.get("ad_max_models", 2) num_models = opts.data.get("ad_max_models", 2)
ad_model_list = list(model_mapping.keys()) ad_model_list = list(model_mapping.keys())
sampler_names = [sampler.name for sampler in all_samplers] sampler_names = [sampler.name for sampler in all_samplers]
scheduler_names = [x.label for x in sd_schedulers]
try: try:
checkpoint_list = modules.sd_models.checkpoint_tiles(use_shorts=True) checkpoint_list = modules.sd_models.checkpoint_tiles(use_shorts=True)
@@ -128,6 +138,7 @@ class AfterDetailerScript(scripts.Script):
webui_info = WebuiInfo( webui_info = WebuiInfo(
ad_model_list=ad_model_list, ad_model_list=ad_model_list,
sampler_names=sampler_names, sampler_names=sampler_names,
scheduler_names=scheduler_names,
t2i_button=txt2img_submit_button, t2i_button=txt2img_submit_button,
i2i_button=img2img_submit_button, i2i_button=img2img_submit_button,
checkpoints_list=checkpoint_list, checkpoints_list=checkpoint_list,
@@ -373,6 +384,17 @@ class AfterDetailerScript(scripts.Script):
return p._ad_orig.sampler_name return p._ad_orig.sampler_name
return p.sampler_name return p.sampler_name
def get_scheduler(self, p, args: ADetailerArgs) -> dict[str, str]:
"webui >= 1.9.0"
if not args.ad_use_sampler:
return {}
if args.ad_scheduler == "Use same scheduler":
value = getattr(p, "scheduler", "Automatic")
else:
value = args.ad_scheduler
return {"scheduler": value}
def get_override_settings(self, p, args: ADetailerArgs) -> dict[str, Any]: def get_override_settings(self, p, args: ADetailerArgs) -> dict[str, Any]:
d = {} d = {}
@@ -470,6 +492,10 @@ class AfterDetailerScript(scripts.Script):
sampler_name = self.get_sampler(p, args) sampler_name = self.get_sampler(p, args)
override_settings = self.get_override_settings(p, args) override_settings = self.get_override_settings(p, args)
version_args = {}
if scheduler_available:
version_args.update(self.get_scheduler(p, args))
i2i = StableDiffusionProcessingImg2Img( i2i = StableDiffusionProcessingImg2Img(
init_images=[image], init_images=[image],
resize_mode=0, resize_mode=0,
@@ -505,6 +531,7 @@ class AfterDetailerScript(scripts.Script):
do_not_save_samples=True, do_not_save_samples=True,
do_not_save_grid=True, do_not_save_grid=True,
override_settings=override_settings, override_settings=override_settings,
**version_args,
) )
i2i.cached_c = [None, None] i2i.cached_c = [None, None]