diff --git a/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py b/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
new file mode 100644
index 00000000..ab44be01
--- /dev/null
+++ b/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
@@ -0,0 +1,57 @@
+import gradio as gr
+
+from modules import scripts
+import ldm_patched.ldm.modules.attention as attention
+
+
+def sdp(q, k, v, transformer_options):
+    return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
+
+
+class StyleAlignForForge(scripts.Script):
+    def title(self):
+        return "StyleAlign Integrated"
+
+    def show(self, is_img2img):
+        # Make this extension visible in both the txt2img and img2img tabs.
+        return scripts.AlwaysVisible
+
+    def ui(self, *args, **kwargs):
+        with gr.Accordion(open=False, label=self.title()):
+            shared_attention = gr.Checkbox(label='Share attention in batch', value=False)
+
+        return [shared_attention]
+
+    def process_before_every_sampling(self, p, *script_args, **kwargs):
+        # This is called before every sampling run.
+        # With highres fix enabled, it is called twice.
+
+        shared_attention = script_args[0]
+
+        if not shared_attention:
+            return
+
+        unet = p.sd_model.forge_objects.unet.clone()
+
+        def join(x):
+            b, f, c = x.shape
+            return x.reshape(1, b * f, c)
+
+        def attn1_proc(q, k, v, transformer_options):
+            b, f, c = q.shape
+            o = sdp(join(q), join(k), join(v), transformer_options)
+            b2, f2, c2 = o.shape
+            o = o.reshape(b, b2 * f2 // b, c2)
+            return o
+
+        unet.set_model_replace_all(attn1_proc, 'attn1')
+
+        p.sd_model.forge_objects.unet = unet
+
+        # The code below adds a note to the generation info shown under the image outputs in the UI.
+        # extra_generation_params does not influence the results.
+        p.extra_generation_params.update(dict(
+            stylealign_enabled=shared_attention,
+        ))
+
+        return
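
For reference, below is a minimal standalone sketch (not part of the diff above) of what the attn1_proc replacement does: it folds every image in the batch into one long token sequence before self-attention, so tokens from each image can attend to the whole batch, then restores the original per-image layout. torch.nn.functional.scaled_dot_product_attention stands in for attention.optimized_attention (single-head, for simplicity), and the tensor sizes are made-up examples.

    # Illustrative sketch only; stand-ins and example sizes are assumptions.
    import torch
    import torch.nn.functional as F

    def join(x):
        # Fold the batch dimension into the sequence dimension:
        # (batch, tokens, channels) -> (1, batch * tokens, channels).
        b, f, c = x.shape
        return x.reshape(1, b * f, c)

    def shared_attn1(q, k, v):
        b, f, c = q.shape
        # Attention runs over the concatenated tokens of every image in the
        # batch, so each image can attend to the others (shared attention).
        o = F.scaled_dot_product_attention(join(q), join(k), join(v))
        b2, f2, c2 = o.shape
        # Split the joined sequence back into the original per-image layout.
        return o.reshape(b, b2 * f2 // b, c2)

    q = k = v = torch.randn(4, 77, 320)   # 4 images, 77 tokens, 320 channels (example sizes)
    print(shared_attn1(q, k, v).shape)    # torch.Size([4, 77, 320])

In the extension itself this processor is registered with unet.set_model_replace_all(attn1_proc, 'attn1') on a cloned UNet, so the shared attention only takes effect when the "Share attention in batch" checkbox is enabled.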