revise attention alignment in stylealign #100

A mistake in the day-0 release is that attention is aligned across the cond and uncond items of a batch when it should not be.
After aligning the cond and uncond sub-batches separately, StyleAlign now works and gives the same results as the legacy sd-webui-cnet implementation.
lllyasviel
2024-02-07 19:11:53 -08:00
parent f63917a323
commit 42dd258c8d


@@ -1,3 +1,4 @@
+import torch
 import gradio as gr
 
 from modules import scripts
@@ -37,13 +38,33 @@ class StyleAlignForForge(scripts.Script):
                 b, f, c = x.shape
                 return x.reshape(1, b * f, c)
 
-            def attn1_proc(q, k, v, transformer_options):
+            def aligned_attention(q, k, v, transformer_options):
                 b, f, c = q.shape
                 o = sdp(join(q), join(k), join(v), transformer_options)
                 b2, f2, c2 = o.shape
                 o = o.reshape(b, b2 * f2 // b, c2)
                 return o
 
+            def attn1_proc(q, k, v, transformer_options):
+                cond_indices = transformer_options['cond_indices']
+                uncond_indices = transformer_options['uncond_indices']
+                cond_or_uncond = transformer_options['cond_or_uncond']
+                results = []
+
+                for cx in cond_or_uncond:
+                    if cx == 0:
+                        indices = cond_indices
+                    else:
+                        indices = uncond_indices
+
+                    if len(indices) > 0:
+                        bq, bk, bv = q[indices], k[indices], v[indices]
+                        bo = aligned_attention(bq, bk, bv, transformer_options)
+                        results.append(bo)
+
+                results = torch.cat(results, dim=0)
+                return results
+
             unet.set_model_replace_all(attn1_proc, 'attn1')
 
             p.sd_model.forge_objects.unet = unet
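
For readers skimming the diff, here is a minimal standalone sketch of the idea. It is illustrative only: toy shapes, torch.nn.functional.scaled_dot_product_attention standing in for Forge's sdp helper, and the bookkeeping that Forge normally passes via transformer_options supplied as plain arguments.

import torch
import torch.nn.functional as F


def join(x):
    # Fold the batch dimension into the token dimension so every image in
    # the group attends to every other image's tokens (style alignment).
    b, f, c = x.shape
    return x.reshape(1, b * f, c)


def aligned_attention(q, k, v):
    # Attention over the joined group, then unfold back to the
    # original per-image layout.
    b, f, c = q.shape
    o = F.scaled_dot_product_attention(join(q), join(k), join(v))
    return o.reshape(b, f, c)


def attn1_proc(q, k, v, cond_indices, uncond_indices, cond_or_uncond):
    # The fix: align the cond sub-batch and the uncond sub-batch
    # separately, never across the cond/uncond boundary.
    results = []
    for cx in cond_or_uncond:
        indices = cond_indices if cx == 0 else uncond_indices
        if len(indices) > 0:
            results.append(aligned_attention(q[indices], k[indices], v[indices]))
    return torch.cat(results, dim=0)


# Toy batch: 2 cond images followed by 2 uncond images, 64 tokens, 320 channels.
q = k = v = torch.randn(4, 64, 320)
out = attn1_proc(
    q, k, v,
    cond_indices=torch.tensor([0, 1]),
    uncond_indices=torch.tensor([2, 3]),
    cond_or_uncond=[0, 1],
)
print(out.shape)  # torch.Size([4, 64, 320])

Note that joining folds b * f tokens into a single sequence, so attention cost grows with the square of the group size; splitting cond and uncond into separate groups also keeps each join smaller.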