From c6a344aed3f8698bc65b18996cff68352109bb15 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 20:40:47 -0800 Subject: [PATCH] Update forge_reference.py --- .../scripts/forge_reference.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index 6f23cefd..fdd6a2bd 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -10,6 +10,11 @@ def sdp(q, k, v, transformer_options): return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None) +def adain(x, target_std, target_mean): + std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0) + return (((x - mean) / std) * target_std) + target_mean + + class PreprocessorReference(Preprocessor): def __init__(self, name, use_attn=True, use_adain=True, priority=0): super().__init__() @@ -79,13 +84,31 @@ class PreprocessorReference(Preprocessor): return h if self.is_recording_style: - self.recorded_h[location] = h + self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0) return h else: cond_indices = transformer_options['cond_indices'] uncond_indices = transformer_options['uncond_indices'] - rh = self.recorded_h[location] - return h + cond_or_uncond = transformer_options['cond_or_uncond'] + r_std, r_mean = self.recorded_h[location] + + h_c = h[cond_indices] + h_uc = h[uncond_indices] + + o_c = adain(h_c, r_std, r_mean) + o_uc_strong = h_uc + o_uc_weak = adain(h_uc, r_std, r_mean) + o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity + + recon = [] + for cx in cond_or_uncond: + if cx == 0: + recon.append(o_c) + else: + recon.append(o_uc) + + o = torch.cat(recon, dim=0) + return o def attn1_proc(q, k, v, transformer_options): if not self.use_attn: