Update forge_reference.py

This commit is contained in:
lllyasviel
2024-01-30 21:04:44 -08:00
parent 035ad4836a
commit feddafef45

View File

@@ -7,14 +7,28 @@ import ldm_patched.ldm.modules.attention as attention
def sdp(q, k, v, transformer_options):
if q.shape[0] == 0:
return q
return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
def adain(x, target_std, target_mean):
if x.shape[0] == 0:
return x
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
return (((x - mean) / std) * target_std) + target_mean
def zero_cat(a, b, dim):
if a.shape[0] == 0:
return b
if b.shape[0] == 0:
return a
return torch.cat([a, b], dim=dim)
class PreprocessorReference(Preprocessor):
def __init__(self, name, use_attn=True, use_adain=True, priority=0):
super().__init__()
@@ -84,7 +98,7 @@ class PreprocessorReference(Preprocessor):
return h
channel = int(h.shape[1])
minimal_channel = 1280 - 640 * weight
minimal_channel = 1500 - 1000 * weight
if channel < minimal_channel:
return h
@@ -128,7 +142,7 @@ class PreprocessorReference(Preprocessor):
transformer_options['block_index'])
channel = int(q.shape[2])
minimal_channel = 1280 - 1280 * weight
minimal_channel = 1500 - 1280 * weight
if channel < minimal_channel:
return sdp(q, k, v, transformer_options)
@@ -152,9 +166,9 @@ class PreprocessorReference(Preprocessor):
k_r, v_r = self.recorded_attn1[location]
o_c = sdp(q_c, torch.cat([k_c, k_r], dim=1), torch.cat([v_c, v_r], dim=1), transformer_options)
o_c = sdp(q_c, zero_cat(k_c, k_r, dim=1), zero_cat(v_c, v_r, dim=1), transformer_options)
o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options)
o_uc_weak = sdp(q_uc, torch.cat([k_uc, k_r], dim=1), torch.cat([v_uc, v_r], dim=1), transformer_options)
o_uc_weak = sdp(q_uc, zero_cat(k_uc, k_r, dim=1), zero_cat(v_uc, v_r, dim=1), transformer_options)
o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity
recon = []