diff --git a/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py b/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
index ab44be01..6322fae0 100644
--- a/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
+++ b/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
@@ -1,3 +1,4 @@
+import torch
 import gradio as gr
 
 from modules import scripts
@@ -37,13 +38,33 @@ class StyleAlignForForge(scripts.Script):
             b, f, c = x.shape
             return x.reshape(1, b * f, c)
 
-        def attn1_proc(q, k, v, transformer_options):
+        def aligned_attention(q, k, v, transformer_options):
             b, f, c = q.shape
             o = sdp(join(q), join(k), join(v), transformer_options)
             b2, f2, c2 = o.shape
             o = o.reshape(b, b2 * f2 // b, c2)
             return o
 
+        def attn1_proc(q, k, v, transformer_options):
+            cond_indices = transformer_options['cond_indices']
+            uncond_indices = transformer_options['uncond_indices']
+            cond_or_uncond = transformer_options['cond_or_uncond']
+            results = []
+
+            for cx in cond_or_uncond:
+                if cx == 0:
+                    indices = cond_indices
+                else:
+                    indices = uncond_indices
+
+                if len(indices) > 0:
+                    bq, bk, bv = q[indices], k[indices], v[indices]
+                    bo = aligned_attention(bq, bk, bv, transformer_options)
+                    results.append(bo)
+
+            results = torch.cat(results, dim=0)
+            return results
+
         unet.set_model_replace_all(attn1_proc, 'attn1')
 
         p.sd_model.forge_objects.unet = unet
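
Note: the core of the shared-attention trick above is the `join` reshape, which folds the batch dimension into the token dimension so every image in the batch attends over the concatenated tokens of all images, and the matching reshape in `aligned_attention`, which splits the result back into per-image sequences. Below is a minimal standalone sketch of that trick, assuming plain `(batch, tokens, channels)` tensors and substituting `torch.nn.functional.scaled_dot_product_attention` for Forge's `sdp(q, k, v, transformer_options)` helper; the name `shared_attention` is illustrative and not part of the extension.

import torch
import torch.nn.functional as F

def join(x):
    # Fold the batch dimension into the token dimension so every image
    # in the batch attends over the concatenated tokens of all images.
    b, f, c = x.shape
    return x.reshape(1, b * f, c)

def shared_attention(q, k, v):
    # Hypothetical stand-in for aligned_attention() above; uses PyTorch's
    # built-in SDPA instead of Forge's sdp() helper.
    b, f, c = q.shape
    o = F.scaled_dot_product_attention(join(q), join(k), join(v))
    # Split the joined token axis back into per-image sequences.
    b2, f2, c2 = o.shape
    return o.reshape(b, b2 * f2 // b, c2)

# Example: a batch of 4 latents, 64 tokens each, 8 channels.
q, k, v = (torch.randn(4, 64, 8) for _ in range(3))
print(shared_attention(q, k, v).shape)  # torch.Size([4, 64, 8])

The new `attn1_proc` partitions the batch by `cond_indices` / `uncond_indices` before calling `aligned_attention`, so that only images with the same conditioning kind share attention; without that split, the conditional and unconditional halves of the CFG batch would attend into each other.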