Variable-Strength StyleAlign (#1387)

* adding efficiency

* adding variable strength

* Revert "adding efficiency"

This reverts commit 6d0ad98c06.

* adding explicit strength = 0 and strength = 1 cases

---------

Co-authored-by: T. Warren de Wit <tww0007@uah.edu>
Author: twarrendewit
Date: 2024-08-22 14:52:41 -05:00
Committed by: GitHub
Parent: d169cd5881
Commit: f7ab23b1cb

@@ -22,14 +22,15 @@ class StyleAlignForForge(scripts.Script):
     def ui(self, *args, **kwargs):
         with gr.Accordion(open=False, label=self.title()):
             shared_attention = gr.Checkbox(label='Share attention in batch', value=False)
+            strength = gr.Slider(label='Strength', minimum=0.0, maximum=1.0, value=1.0)
 
-        return [shared_attention]
+        return [shared_attention, strength]
 
     def process_before_every_sampling(self, p, *script_args, **kwargs):
         # This will be called before every sampling.
         # If you use highres fix, this will be called twice.
 
-        shared_attention = script_args[0]
+        shared_attention, strength = script_args
 
         if not shared_attention:
             return
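For context on the `script_args` change above: Forge (like A1111) passes the current values of the components returned from `ui()` positionally into the script callbacks, so once a second component is added, the single-element lookup `script_args[0]` no longer suffices. A minimal sketch of that contract, using hardcoded stand-in values instead of the real Gradio/Forge plumbing:

    # Sketch only: ui() here returns plain values; in the real script it
    # returns Gradio components whose current values Forge forwards as
    # script_args at generation time.
    def ui():
        shared_attention = True   # stand-in for the Checkbox value
        strength = 0.5            # stand-in for the Slider value
        return [shared_attention, strength]

    def process_before_every_sampling(p, *script_args, **kwargs):
        shared_attention, strength = script_args  # positional, in ui() order
        print(shared_attention, strength)

    process_before_every_sampling(None, *ui())  # prints: True 0.5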
@@ -60,9 +61,26 @@ class StyleAlignForForge(scripts.Script):
                     indices = uncond_indices
 
                 if len(indices) > 0:
                     bq, bk, bv = q[indices], k[indices], v[indices]
-                    bo = aligned_attention(bq, bk, bv, transformer_options)
-                    results.append(bo)
+
+                    if strength < 0.01:
+                        # At strength = 0, use the original attention only.
+                        original_attention = sdp(bq, bk, bv, transformer_options)
+                        results.append(original_attention)
+                    elif strength > 0.99:
+                        # At strength = 1, use the aligned attention only.
+                        aligned_attention_result = aligned_attention(bq, bk, bv, transformer_options)
+                        results.append(aligned_attention_result)
+                    else:
+                        # In between, blend original and aligned attention by strength.
+                        original_attention = sdp(bq, bk, bv, transformer_options)
+                        aligned_attention_result = aligned_attention(bq, bk, bv, transformer_options)
+                        blended_attention = (1.0 - strength) * original_attention + strength * aligned_attention_result
+                        results.append(blended_attention)
 
             results = torch.cat(results, dim=0)
             return results
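The added branches make the StyleAlign effect continuous: the output is the linear interpolation (1 - strength) * sdp + strength * aligned_attention, and the < 0.01 / > 0.99 cutoffs skip the unused attention pass at the extremes, so pure settings cost no more than before. A runnable sketch of the same interpolation, where `sdp` and `aligned_attention` are simplified stand-ins for Forge's implementations (assumptions for illustration, not the extension's actual code):

    import torch
    import torch.nn.functional as F

    def sdp(q, k, v):
        # Stand-in for ordinary scaled-dot-product attention.
        return F.scaled_dot_product_attention(q, k, v)

    def aligned_attention(q, k, v):
        # Stand-in for shared attention: every batch item attends over
        # keys/values concatenated across the whole batch.
        b = q.shape[0]
        shared_k = k.reshape(1, -1, k.shape[-1]).expand(b, -1, -1)
        shared_v = v.reshape(1, -1, v.shape[-1]).expand(b, -1, -1)
        return F.scaled_dot_product_attention(q, shared_k, shared_v)

    def blended_attention(q, k, v, strength):
        # Skip the unused pass at the extremes; otherwise lerp the outputs.
        if strength < 0.01:
            return sdp(q, k, v)
        if strength > 0.99:
            return aligned_attention(q, k, v)
        return (1.0 - strength) * sdp(q, k, v) + strength * aligned_attention(q, k, v)

    q = k = v = torch.randn(2, 8, 64)
    for s in (0.0, 0.5, 1.0):
        print(s, blended_attention(q, k, v, s).shape)  # torch.Size([2, 8, 64])

Blending the attention *outputs*, rather than the weights, keeps the two paths fully independent and lets the extreme cases fall back to a single attention call.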
@@ -75,6 +93,7 @@ class StyleAlignForForge(scripts.Script):
         # The extra_generation_params does not influence results.
 
         p.extra_generation_params.update(dict(
             stylealign_enabled=shared_attention,
+            stylealign_strength=strength,
         ))
 
         return