mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
@@ -1,13 +1,18 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import scripts
|
from modules import scripts
|
||||||
|
from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
|
||||||
from backend.patcher.base import set_model_options_patch_replace
|
from backend.patcher.base import set_model_options_patch_replace
|
||||||
from backend.sampling.sampling_function import calc_cond_uncond_batch
|
from backend.sampling.sampling_function import calc_cond_uncond_batch
|
||||||
|
from modules.ui_components import InputAccordion
|
||||||
|
|
||||||
|
|
||||||
class PerturbedAttentionGuidanceForForge(scripts.Script):
|
class PerturbedAttentionGuidanceForForge(scripts.Script):
|
||||||
sorting_priority = 13
|
sorting_priority = 13
|
||||||
|
|
||||||
|
attenuated_scale = 3.0
|
||||||
|
doPAG = True
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
return "PerturbedAttentionGuidance Integrated"
|
return "PerturbedAttentionGuidance Integrated"
|
||||||
|
|
||||||
@@ -15,43 +20,82 @@ class PerturbedAttentionGuidanceForForge(scripts.Script):
|
|||||||
return scripts.AlwaysVisible
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
def ui(self, *args, **kwargs):
|
def ui(self, *args, **kwargs):
|
||||||
with gr.Accordion(open=False, label=self.title()):
|
with InputAccordion(False, label=self.title()) as enabled:
|
||||||
enabled = gr.Checkbox(label='Enabled', value=False)
|
with gr.Row():
|
||||||
scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
|
scale = gr.Slider(label='Scale', minimum=0.0, maximum=100.0, step=0.1, value=3.0)
|
||||||
|
attenuation = gr.Slider(label='Attenuation (linear, % of scale)', minimum=0.0, maximum=100.0, step=0.1, value=0.0)
|
||||||
|
with gr.Row():
|
||||||
|
start_step = gr.Slider(label='Start step', minimum=0.0, maximum=1.0, step=0.01, value=0.0)
|
||||||
|
end_step = gr.Slider(label='End step', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
|
||||||
|
|
||||||
return enabled, scale
|
self.infotext_fields = [
|
||||||
|
(enabled, lambda d: d.get("pagi_enabled", False)),
|
||||||
|
(scale, "pagi_scale"),
|
||||||
|
(attenuation, "pagi_attenuation"),
|
||||||
|
(start_step, "pagi_start_step"),
|
||||||
|
(end_step, "pagi_end_step"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return enabled, scale, attenuation, start_step, end_step
|
||||||
|
|
||||||
|
def denoiser_callback(self, params):
|
||||||
|
thisStep = (params.sampling_step) / (params.total_sampling_steps - 1)
|
||||||
|
|
||||||
|
if thisStep >= PerturbedAttentionGuidanceForForge.PAG_start and thisStep <= PerturbedAttentionGuidanceForForge.PAG_end:
|
||||||
|
PerturbedAttentionGuidanceForForge.doPAG = True
|
||||||
|
else:
|
||||||
|
PerturbedAttentionGuidanceForForge.doPAG = False
|
||||||
|
|
||||||
def process_before_every_sampling(self, p, *script_args, **kwargs):
|
def process_before_every_sampling(self, p, *script_args, **kwargs):
|
||||||
enabled, scale = script_args
|
enabled, scale, attenuation, start_step, end_step = script_args
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
PerturbedAttentionGuidanceForForge.scale = scale
|
||||||
|
PerturbedAttentionGuidanceForForge.PAG_start = start_step
|
||||||
|
PerturbedAttentionGuidanceForForge.PAG_end = end_step
|
||||||
|
on_cfg_denoiser(self.denoiser_callback)
|
||||||
|
|
||||||
unet = p.sd_model.forge_objects.unet.clone()
|
unet = p.sd_model.forge_objects.unet.clone()
|
||||||
|
|
||||||
def attn_proc(q, k, v, to):
|
def attn_proc(q, k, v, to):
|
||||||
return v
|
return v
|
||||||
|
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
model, cond_denoised, cond, denoised, sigma, x = \
|
denoised = args["denoised"]
|
||||||
args["model"], args["cond_denoised"], args["cond"], args["denoised"], args["sigma"], args["input"]
|
|
||||||
|
|
||||||
new_options = set_model_options_patch_replace(args["model_options"], attn_proc, "attn1", "middle", 0)
|
if PerturbedAttentionGuidanceForForge.scale <= 0.0:
|
||||||
|
|
||||||
if scale == 0:
|
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
if not PerturbedAttentionGuidanceForForge.doPAG:
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
model, cond_denoised, cond, sigma, x, options = \
|
||||||
|
args["model"], args["cond_denoised"], args["cond"], args["sigma"], args["input"], args["model_options"].copy()
|
||||||
|
new_options = set_model_options_patch_replace(options, attn_proc, "attn1", "middle", 0)
|
||||||
|
|
||||||
degraded, _ = calc_cond_uncond_batch(model, cond, None, x, sigma, new_options)
|
degraded, _ = calc_cond_uncond_batch(model, cond, None, x, sigma, new_options)
|
||||||
|
|
||||||
return denoised + (cond_denoised - degraded) * scale
|
result = denoised + (cond_denoised - degraded) * PerturbedAttentionGuidanceForForge.scale
|
||||||
|
PerturbedAttentionGuidanceForForge.scale -= scale * attenuation / 100.0
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
unet.set_model_sampler_post_cfg_function(post_cfg_function)
|
unet.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
p.sd_model.forge_objects.unet = unet
|
p.sd_model.forge_objects.unet = unet
|
||||||
|
|
||||||
p.extra_generation_params.update(dict(
|
p.extra_generation_params.update(dict(
|
||||||
PerturbedAttentionGuidance_enabled=enabled,
|
pagi_enabled = enabled,
|
||||||
PerturbedAttentionGuidance_scale=scale,
|
pagi_scale = scale,
|
||||||
|
pagi_attenuation = attenuation,
|
||||||
|
pagi_start_step = start_step,
|
||||||
|
pagi_end_step = end_step,
|
||||||
))
|
))
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def postprocess(self, params, processed, *args):
|
||||||
|
remove_current_script_callbacks()
|
||||||
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user