From 1673a5ac2d993e53a05164a25362dcc48082dfbe Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Fri, 26 Jul 2024 15:00:31 -0700
Subject: [PATCH] Perturbed Attention Guidance Integrated

---
 .../scripts/forge_perturbed_attention.py      | 57 +++++++++++++++++++
 modules_forge/unet_patcher.py                 | 13 +++++
 2 files changed, 70 insertions(+)
 create mode 100644 extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py

diff --git a/extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py b/extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py
new file mode 100644
index 00000000..1fe5be68
--- /dev/null
+++ b/extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py
@@ -0,0 +1,57 @@
+import gradio as gr
+import ldm_patched.modules.samplers
+
+from modules import scripts
+from modules_forge.unet_patcher import copy_and_update_model_options
+
+
+class PerturbedAttentionGuidanceForForge(scripts.Script):
+    sorting_priority = 13
+
+    def title(self):
+        return "Perturbed Attention Guidance Integrated"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def ui(self, *args, **kwargs):
+        with gr.Accordion(open=False, label=self.title()):
+            enabled = gr.Checkbox(label='Enabled', value=False)
+            scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
+
+        return enabled, scale
+
+    def process_before_every_sampling(self, p, *script_args, **kwargs):
+        enabled, scale = script_args
+
+        if not enabled:
+            return
+
+        unet = p.sd_model.forge_objects.unet.clone()
+
+        def attn_proc(q, k, v, to):
+            return v
+
+        def post_cfg_function(args):
+            model, cond_denoised, cond, denoised, sigma, x = \
+                args["model"], args["cond_denoised"], args["cond"], args["denoised"], args["sigma"], args["input"]
+
+            new_options = copy_and_update_model_options(args["model_options"], attn_proc, "attn1", "middle", 0)
+
+            if scale == 0:
+                return denoised
+
+            degraded, _ = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, cond, None, x, sigma, new_options)
+
+            return denoised + (cond_denoised - degraded) * scale
+
+        unet.set_model_sampler_post_cfg_function(post_cfg_function)
+
+        p.sd_model.forge_objects.unet = unet
+
+        p.extra_generation_params.update(dict(
+            PerturbedAttentionGuidance_enabled=enabled,
+            PerturbedAttentionGuidance_scale=scale,
+        ))
+
+        return

diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py
index 275e0e96..01164c09 100644
--- a/modules_forge/unet_patcher.py
+++ b/modules_forge/unet_patcher.py
@@ -196,3 +196,16 @@ class UnetPatcher(ModelPatcher):

         self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
         return
+
+
+def copy_and_update_model_options(model_options, patch, name, block_name, number, transformer_index=None):
+    model_options = model_options.copy()
+    transformer_options = model_options.get("transformer_options", {}).copy()
+    patches_replace = transformer_options.get("patches_replace", {}).copy()
+    name_patches = patches_replace.get(name, {}).copy()
+    block = (block_name, number, transformer_index) if transformer_index is not None else (block_name, number)
+    name_patches[block] = patch
+    patches_replace[name] = name_patches
+    transformer_options["patches_replace"] = patches_replace
+    model_options["transformer_options"] = transformer_options
+    return model_options
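
A minimal usage sketch of the new copy_and_update_model_options helper, assuming only the dictionary layout defined in the unet_patcher.py hunk above; it is not part of the commit, and the example model_options dict here is illustrative.

    # Sketch: how the helper nests a replacement attention processor,
    # mirroring the call made in post_cfg_function above.
    from modules_forge.unet_patcher import copy_and_update_model_options

    def attn_proc(q, k, v, to):
        # Perturbed attention: discard q/k and return v unchanged,
        # i.e. an identity attention map for the patched block.
        return v

    model_options = {"transformer_options": {}}  # illustrative input
    new_options = copy_and_update_model_options(model_options, attn_proc, "attn1", "middle", 0)

    # new_options["transformer_options"]["patches_replace"] is now
    #     {"attn1": {("middle", 0): attn_proc}}
    # while the original model_options is left untouched, because the helper
    # shallow-copies each nested level before inserting the patch.
    assert "patches_replace" not in model_options["transformer_options"]

The post_cfg_function then passes new_options to calc_cond_uncond_batch, so the extra denoising pass runs with the middle-block attn1 replaced by attn_proc, and the result is mixed back in as denoised + (cond_denoised - degraded) * scale.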