From f7ab23b1cb5d2e17d33ecc8febe9313bf824d897 Mon Sep 17 00:00:00 2001
From: twarrendewit
Date: Thu, 22 Aug 2024 14:52:41 -0500
Subject: [PATCH] Variable-Strength StyleAlign (#1387)

* adding efficiency

* adding variable strength

* Revert "adding efficiency"

This reverts commit 6d0ad98c06fd7d68b8cf4cfe005886c347f1bfba.

* updating with 0 and 1 cases

---------

Co-authored-by: T. Warren de Wit
---
 .../scripts/forge_stylealign.py | 29 +++++++++++++++----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py b/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
index 901cc0e7..010df4a6 100644
--- a/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
+++ b/extensions-builtin/sd_forge_stylealign/scripts/forge_stylealign.py
@@ -22,14 +22,15 @@ class StyleAlignForForge(scripts.Script):
     def ui(self, *args, **kwargs):
         with gr.Accordion(open=False, label=self.title()):
             shared_attention = gr.Checkbox(label='Share attention in batch', value=False)
+            strength = gr.Slider(label='Strength', minimum=0.0, maximum=1.0, value=1.0)
 
-        return [shared_attention]
+        return [shared_attention, strength]
 
     def process_before_every_sampling(self, p, *script_args, **kwargs):
         # This will be called before every sampling.
         # If you use highres fix, this will be called twice.
 
-        shared_attention = script_args[0]
+        shared_attention, strength = script_args
 
         if not shared_attention:
             return
@@ -60,9 +61,26 @@ class StyleAlignForForge(scripts.Script):
                     indices = uncond_indices
 
                 if len(indices) > 0:
+
                     bq, bk, bv = q[indices], k[indices], v[indices]
-                    bo = aligned_attention(bq, bk, bv, transformer_options)
-                    results.append(bo)
+
+                    if strength < 0.01:
+                        # At strength = 0, use original.
+                        original_attention = sdp(bq, bk, bv, transformer_options)
+                        results.append(original_attention)
+
+                    elif strength > 0.99:
+                        # At strength = 1, use aligned.
+                        aligned_attention_result = aligned_attention(bq, bk, bv, transformer_options)
+                        results.append(aligned_attention_result)
+
+                    else:
+                        # In between, blend original and aligned attention based on strength.
+                        original_attention = sdp(bq, bk, bv, transformer_options)
+                        aligned_attention_result = aligned_attention(bq, bk, bv, transformer_options)
+                        blended_attention = (1.0 - strength) * original_attention + strength * aligned_attention_result
+                        results.append(blended_attention)
+
 
             results = torch.cat(results, dim=0)
             return results
@@ -75,6 +93,7 @@ class StyleAlignForForge(scripts.Script):
         # The extra_generation_params does not influence results.
         p.extra_generation_params.update(dict(
             stylealign_enabled=shared_attention,
+            stylealign_strength=strength,
         ))
 
-        return
+        return
\ No newline at end of file
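
The heart of this patch is a linear interpolation between the two attention
paths, with fast paths at the endpoints: strength = 0 reproduces the plain
attention output and strength = 1 reproduces the fully shared StyleAligned
output, in both cases without computing the unused branch. The sketch below
is a minimal, self-contained illustration of that rule, not the patch's code:
plain_attention and shared_attention are hypothetical stand-ins for Forge's
sdp and aligned_attention helpers (which are not reproduced here), and the
batch-pooled key/value trick is only a rough approximation of StyleAligned's
shared attention.

import torch
import torch.nn.functional as F


def plain_attention(q, k, v):
    # Ordinary per-sample attention: each batch element attends only to
    # its own tokens (stand-in for Forge's `sdp`).
    return F.scaled_dot_product_attention(q, k, v)


def shared_attention(q, k, v):
    # Every sample attends over keys/values pooled across the whole batch;
    # a heavily simplified stand-in for Forge's `aligned_attention`.
    b = q.shape[0]
    k_all = k.reshape(1, -1, k.shape[-1]).repeat(b, 1, 1)
    v_all = v.reshape(1, -1, v.shape[-1]).repeat(b, 1, 1)
    return F.scaled_dot_product_attention(q, k_all, v_all)


def blend(q, k, v, strength):
    # Mirror the patch: take a fast path at the 0 and 1 endpoints so the
    # unused branch is never computed, otherwise interpolate linearly
    # between the two attention outputs.
    if strength < 0.01:
        return plain_attention(q, k, v)
    if strength > 0.99:
        return shared_attention(q, k, v)
    original = plain_attention(q, k, v)
    aligned = shared_attention(q, k, v)
    return (1.0 - strength) * original + strength * aligned


q, k, v = (torch.randn(4, 77, 64) for _ in range(3))  # (batch, tokens, dim)
print(blend(q, k, v, strength=0.5).shape)  # torch.Size([4, 77, 64])

Blending the attention outputs rather than the q/k/v inputs keeps the two
paths independent, so intermediate strengths degrade smoothly between the
endpoint behaviors instead of mixing queries and keys from different regimes.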