mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-11 02:19:59 +00:00
Perturbed Attention Guidance Integrated
This commit is contained in:
@@ -0,0 +1,57 @@
|
||||
import gradio as gr
|
||||
import ldm_patched.modules.samplers
|
||||
|
||||
from modules import scripts
|
||||
from modules_forge.unet_patcher import copy_and_update_model_options
|
||||
|
||||
|
||||
class PerturbedAttentionGuidanceForForge(scripts.Script):
|
||||
sorting_priority = 13
|
||||
|
||||
def title(self):
|
||||
return "Perturbed Attention Guidance Integrated"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, *args, **kwargs):
|
||||
with gr.Accordion(open=False, label=self.title()):
|
||||
enabled = gr.Checkbox(label='Enabled', value=False)
|
||||
scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
|
||||
|
||||
return enabled, scale
|
||||
|
||||
def process_before_every_sampling(self, p, *script_args, **kwargs):
|
||||
enabled, scale = script_args
|
||||
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
unet = p.sd_model.forge_objects.unet.clone()
|
||||
|
||||
def attn_proc(q, k, v, to):
|
||||
return v
|
||||
|
||||
def post_cfg_function(args):
|
||||
model, cond_denoised, cond, denoised, sigma, x = \
|
||||
args["model"], args["cond_denoised"], args["cond"], args["denoised"], args["sigma"], args["input"]
|
||||
|
||||
new_options = copy_and_update_model_options(args["model_options"], attn_proc, "attn1", "middle", 0)
|
||||
|
||||
if scale == 0:
|
||||
return denoised
|
||||
|
||||
degraded, _ = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, cond, None, x, sigma, new_options)
|
||||
|
||||
return denoised + (cond_denoised - degraded) * scale
|
||||
|
||||
unet.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
p.sd_model.forge_objects.unet = unet
|
||||
|
||||
p.extra_generation_params.update(dict(
|
||||
PerturbedAttentionGuidance_enabled=enabled,
|
||||
PerturbedAttentionGuidance_scale=scale,
|
||||
))
|
||||
|
||||
return
|
||||
@@ -196,3 +196,16 @@ class UnetPatcher(ModelPatcher):
|
||||
|
||||
self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||
return
|
||||
|
||||
|
||||
def copy_and_update_model_options(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
model_options = model_options.copy()
|
||||
transformer_options = model_options.get("transformer_options", {}).copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {}).copy()
|
||||
name_patches = patches_replace.get(name, {}).copy()
|
||||
block = (block_name, number, transformer_index) if transformer_index is not None else (block_name, number)
|
||||
name_patches[block] = patch
|
||||
patches_replace[name] = name_patches
|
||||
transformer_options["patches_replace"] = patches_replace
|
||||
model_options["transformer_options"] = transformer_options
|
||||
return model_options
|
||||
|
||||
Reference in New Issue
Block a user