diff --git a/extensions-builtin/sd_forge_sag/scripts/forge_sag.py b/extensions-builtin/sd_forge_sag/scripts/forge_sag.py index 8ee4ad51..ab477aa6 100644 --- a/extensions-builtin/sd_forge_sag/scripts/forge_sag.py +++ b/extensions-builtin/sd_forge_sag/scripts/forge_sag.py @@ -6,7 +6,8 @@ from backend.sampling.sampling_function import calc_cond_uncond_batch from backend import attention, memory_management from torch import einsum from einops import rearrange, repeat -from modules import scripts +from modules import scripts, shared +from modules.ui_components import InputAccordion attn_precision = memory_management.force_upcast_attention_dtype() @@ -99,7 +100,7 @@ def gaussian_blur_2d(img, kernel_size, sigma): class SelfAttentionGuidance: - def patch(self, model, scale, blur_sigma): + def patch(self, model, scale, blur_sigma, threshold): m = model.clone() attn_scores = None @@ -129,7 +130,7 @@ class SelfAttentionGuidance: sag_scale = scale sag_sigma = blur_sigma - sag_threshold = 1.0 + sag_threshold = threshold model = args["model"] uncond_pred = args["uncond_denoised"] uncond = args["uncond"] @@ -163,35 +164,52 @@ class SAGForForge(scripts.Script): sorting_priority = 12.5 def title(self): - return "SelfAttentionGuidance Integrated" + return "SelfAttentionGuidance Integrated (SD 1.x, SD 2.x, SDXL)" def show(self, is_img2img): return scripts.AlwaysVisible def ui(self, *args, **kwargs): - with gr.Accordion(open=False, label=self.title()): - enabled = gr.Checkbox(label='Enabled', value=False) + with InputAccordion(False, label=self.title()) as enabled: scale = gr.Slider(label='Scale', minimum=-2.0, maximum=5.0, step=0.01, value=0.5) blur_sigma = gr.Slider(label='Blur Sigma', minimum=0.0, maximum=10.0, step=0.01, value=2.0) + threshold = gr.Slider(label='Blur mask threshold', minimum=0.0, maximum=4.0, step=0.01, value=1.0) - return enabled, scale, blur_sigma + self.infotext_fields = [ + (enabled, lambda d: d.get("sag_enabled", False)), + (scale, "sag_scale"), + (blur_sigma, "sag_blur_sigma"), + (threshold, "sag_threshold"), + ] + + return enabled, scale, blur_sigma, threshold def process_before_every_sampling(self, p, *script_args, **kwargs): - enabled, scale, blur_sigma = script_args + enabled, scale, blur_sigma, threshold = script_args if not enabled: return + # not for FLux + if not shared.sd_model.is_webui_legacy_model(): # ideally would be is_flux + gr.Info ("Self Attention Guidance is not compatible with Flux") + return + # Self Attention Guidance errors if CFG is 1 + if p.cfg_scale == 1: + gr.Info ("Self Attention Guidance requires CFG > 1") + return + unet = p.sd_model.forge_objects.unet - unet = opSelfAttentionGuidance.patch(unet, scale, blur_sigma)[0] + unet = opSelfAttentionGuidance.patch(unet, scale, blur_sigma, threshold)[0] p.sd_model.forge_objects.unet = unet p.extra_generation_params.update(dict( - sag_enabled=enabled, - sag_scale=scale, - sag_blur_sigma=blur_sigma + sag_enabled = enabled, + sag_scale = scale, + sag_blur_sigma = blur_sigma, + sag_threshold = threshold, )) return