diff --git a/extensions-builtin/sd_forge_sag/scripts/forge_sag.py b/extensions-builtin/sd_forge_sag/scripts/forge_sag.py new file mode 100644 index 00000000..90e246a4 --- /dev/null +++ b/extensions-builtin/sd_forge_sag/scripts/forge_sag.py @@ -0,0 +1,45 @@ +import gradio as gr + +from modules import scripts +from ldm_patched.contrib.external_sag import SelfAttentionGuidance + + +opSelfAttentionGuidance = SelfAttentionGuidance() + + +class SAGForForge(scripts.Script): + def title(self): + return "SelfAttentionGuidance Integrated" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, *args, **kwargs): + with gr.Accordion(open=False, label=self.title()): + enabled = gr.Checkbox(label='Enabled', value=False) + scale = gr.Slider(label='Scale', minimum=-2.0, maximum=5.0, step=0.01, value=0.5) + blur_sigma = gr.Slider(label='Blur Sigma', minimum=0.0, maximum=10.0, step=0.01, value=2.0) + + return enabled, scale, blur_sigma + + def process_batch(self, p, *script_args, **kwargs): + enabled, scale, blur_sigma = script_args + + if not enabled: + return + + unet = p.sd_model.forge_objects.unet + + unet = opSelfAttentionGuidance.patch(unet, scale, blur_sigma)[0] + + p.sd_model.forge_objects.unet = unet + + # Below codes will add some logs to the texts below the image outputs on UI. + # The extra_generation_params does not influence results. + p.extra_generation_params.update(dict( + sag_enabled=enabled, + sag_scale=scale, + sag_blur_sigma=blur_sigma + )) + + return