mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 10:41:25 +00:00
SelfAttentionGuidance updates (#1725)
* SAG updates * restore settings from infotext * bypass if model is Flux (would error during generation) * bypass if CFG == 1 (would error during generation) * add control of blur mask threshold (previously hardcoded) * change to InputAccordion * forge_sag.py: update title
This commit is contained in:
@@ -6,7 +6,8 @@ from backend.sampling.sampling_function import calc_cond_uncond_batch
|
|||||||
from backend import attention, memory_management
|
from backend import attention, memory_management
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from modules import scripts
|
from modules import scripts, shared
|
||||||
|
from modules.ui_components import InputAccordion
|
||||||
|
|
||||||
|
|
||||||
attn_precision = memory_management.force_upcast_attention_dtype()
|
attn_precision = memory_management.force_upcast_attention_dtype()
|
||||||
@@ -99,7 +100,7 @@ def gaussian_blur_2d(img, kernel_size, sigma):
|
|||||||
|
|
||||||
|
|
||||||
class SelfAttentionGuidance:
|
class SelfAttentionGuidance:
|
||||||
def patch(self, model, scale, blur_sigma):
|
def patch(self, model, scale, blur_sigma, threshold):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
attn_scores = None
|
attn_scores = None
|
||||||
@@ -129,7 +130,7 @@ class SelfAttentionGuidance:
|
|||||||
|
|
||||||
sag_scale = scale
|
sag_scale = scale
|
||||||
sag_sigma = blur_sigma
|
sag_sigma = blur_sigma
|
||||||
sag_threshold = 1.0
|
sag_threshold = threshold
|
||||||
model = args["model"]
|
model = args["model"]
|
||||||
uncond_pred = args["uncond_denoised"]
|
uncond_pred = args["uncond_denoised"]
|
||||||
uncond = args["uncond"]
|
uncond = args["uncond"]
|
||||||
@@ -163,35 +164,52 @@ class SAGForForge(scripts.Script):
|
|||||||
sorting_priority = 12.5
|
sorting_priority = 12.5
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
return "SelfAttentionGuidance Integrated"
|
return "SelfAttentionGuidance Integrated (SD 1.x, SD 2.x, SDXL)"
|
||||||
|
|
||||||
def show(self, is_img2img):
|
def show(self, is_img2img):
|
||||||
return scripts.AlwaysVisible
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
def ui(self, *args, **kwargs):
|
def ui(self, *args, **kwargs):
|
||||||
with gr.Accordion(open=False, label=self.title()):
|
with InputAccordion(False, label=self.title()) as enabled:
|
||||||
enabled = gr.Checkbox(label='Enabled', value=False)
|
|
||||||
scale = gr.Slider(label='Scale', minimum=-2.0, maximum=5.0, step=0.01, value=0.5)
|
scale = gr.Slider(label='Scale', minimum=-2.0, maximum=5.0, step=0.01, value=0.5)
|
||||||
blur_sigma = gr.Slider(label='Blur Sigma', minimum=0.0, maximum=10.0, step=0.01, value=2.0)
|
blur_sigma = gr.Slider(label='Blur Sigma', minimum=0.0, maximum=10.0, step=0.01, value=2.0)
|
||||||
|
threshold = gr.Slider(label='Blur mask threshold', minimum=0.0, maximum=4.0, step=0.01, value=1.0)
|
||||||
|
|
||||||
return enabled, scale, blur_sigma
|
self.infotext_fields = [
|
||||||
|
(enabled, lambda d: d.get("sag_enabled", False)),
|
||||||
|
(scale, "sag_scale"),
|
||||||
|
(blur_sigma, "sag_blur_sigma"),
|
||||||
|
(threshold, "sag_threshold"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return enabled, scale, blur_sigma, threshold
|
||||||
|
|
||||||
def process_before_every_sampling(self, p, *script_args, **kwargs):
|
def process_before_every_sampling(self, p, *script_args, **kwargs):
|
||||||
enabled, scale, blur_sigma = script_args
|
enabled, scale, blur_sigma, threshold = script_args
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# not for FLux
|
||||||
|
if not shared.sd_model.is_webui_legacy_model(): # ideally would be is_flux
|
||||||
|
gr.Info ("Self Attention Guidance is not compatible with Flux")
|
||||||
|
return
|
||||||
|
# Self Attention Guidance errors if CFG is 1
|
||||||
|
if p.cfg_scale == 1:
|
||||||
|
gr.Info ("Self Attention Guidance requires CFG > 1")
|
||||||
|
return
|
||||||
|
|
||||||
unet = p.sd_model.forge_objects.unet
|
unet = p.sd_model.forge_objects.unet
|
||||||
|
|
||||||
unet = opSelfAttentionGuidance.patch(unet, scale, blur_sigma)[0]
|
unet = opSelfAttentionGuidance.patch(unet, scale, blur_sigma, threshold)[0]
|
||||||
|
|
||||||
p.sd_model.forge_objects.unet = unet
|
p.sd_model.forge_objects.unet = unet
|
||||||
|
|
||||||
p.extra_generation_params.update(dict(
|
p.extra_generation_params.update(dict(
|
||||||
sag_enabled=enabled,
|
sag_enabled = enabled,
|
||||||
sag_scale=scale,
|
sag_scale = scale,
|
||||||
sag_blur_sigma=blur_sigma
|
sag_blur_sigma = blur_sigma,
|
||||||
|
sag_threshold = threshold,
|
||||||
))
|
))
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user