diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index 68be9fb9..e4469975 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -7,14 +7,28 @@ import ldm_patched.ldm.modules.attention as attention def sdp(q, k, v, transformer_options): + if q.shape[0] == 0: + return q + return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None) def adain(x, target_std, target_mean): + if x.shape[0] == 0: + return x + std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0) return (((x - mean) / std) * target_std) + target_mean +def zero_cat(a, b, dim): + if a.shape[0] == 0: + return b + if b.shape[0] == 0: + return a + return torch.cat([a, b], dim=dim) + + class PreprocessorReference(Preprocessor): def __init__(self, name, use_attn=True, use_adain=True, priority=0): super().__init__() @@ -84,7 +98,7 @@ class PreprocessorReference(Preprocessor): return h channel = int(h.shape[1]) - minimal_channel = 1280 - 640 * weight + minimal_channel = 1500 - 1000 * weight if channel < minimal_channel: return h @@ -128,7 +142,7 @@ class PreprocessorReference(Preprocessor): transformer_options['block_index']) channel = int(q.shape[2]) - minimal_channel = 1280 - 1280 * weight + minimal_channel = 1500 - 1280 * weight if channel < minimal_channel: return sdp(q, k, v, transformer_options) @@ -152,9 +166,9 @@ class PreprocessorReference(Preprocessor): k_r, v_r = self.recorded_attn1[location] - o_c = sdp(q_c, torch.cat([k_c, k_r], dim=1), torch.cat([v_c, v_r], dim=1), transformer_options) + o_c = sdp(q_c, zero_cat(k_c, k_r, dim=1), zero_cat(v_c, v_r, dim=1), transformer_options) o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options) - o_uc_weak = sdp(q_uc, torch.cat([k_uc, k_r], dim=1), torch.cat([v_uc, v_r], dim=1), transformer_options) + o_uc_weak = sdp(q_uc, zero_cat(k_uc, k_r, dim=1), zero_cat(v_uc, v_r, dim=1), transformer_options) o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity recon = []