mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-01 22:09:46 +00:00
Update forge_reference.py
This commit is contained in:
@@ -7,14 +7,28 @@ import ldm_patched.ldm.modules.attention as attention
|
||||
|
||||
|
||||
def sdp(q, k, v, transformer_options):
|
||||
if q.shape[0] == 0:
|
||||
return q
|
||||
|
||||
return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
|
||||
|
||||
|
||||
def adain(x, target_std, target_mean):
|
||||
if x.shape[0] == 0:
|
||||
return x
|
||||
|
||||
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
||||
return (((x - mean) / std) * target_std) + target_mean
|
||||
|
||||
|
||||
def zero_cat(a, b, dim):
|
||||
if a.shape[0] == 0:
|
||||
return b
|
||||
if b.shape[0] == 0:
|
||||
return a
|
||||
return torch.cat([a, b], dim=dim)
|
||||
|
||||
|
||||
class PreprocessorReference(Preprocessor):
|
||||
def __init__(self, name, use_attn=True, use_adain=True, priority=0):
|
||||
super().__init__()
|
||||
@@ -84,7 +98,7 @@ class PreprocessorReference(Preprocessor):
|
||||
return h
|
||||
|
||||
channel = int(h.shape[1])
|
||||
minimal_channel = 1280 - 640 * weight
|
||||
minimal_channel = 1500 - 1000 * weight
|
||||
|
||||
if channel < minimal_channel:
|
||||
return h
|
||||
@@ -128,7 +142,7 @@ class PreprocessorReference(Preprocessor):
|
||||
transformer_options['block_index'])
|
||||
|
||||
channel = int(q.shape[2])
|
||||
minimal_channel = 1280 - 1280 * weight
|
||||
minimal_channel = 1500 - 1280 * weight
|
||||
|
||||
if channel < minimal_channel:
|
||||
return sdp(q, k, v, transformer_options)
|
||||
@@ -152,9 +166,9 @@ class PreprocessorReference(Preprocessor):
|
||||
|
||||
k_r, v_r = self.recorded_attn1[location]
|
||||
|
||||
o_c = sdp(q_c, torch.cat([k_c, k_r], dim=1), torch.cat([v_c, v_r], dim=1), transformer_options)
|
||||
o_c = sdp(q_c, zero_cat(k_c, k_r, dim=1), zero_cat(v_c, v_r, dim=1), transformer_options)
|
||||
o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options)
|
||||
o_uc_weak = sdp(q_uc, torch.cat([k_uc, k_r], dim=1), torch.cat([v_uc, v_r], dim=1), transformer_options)
|
||||
o_uc_weak = sdp(q_uc, zero_cat(k_uc, k_r, dim=1), zero_cat(v_uc, v_r, dim=1), transformer_options)
|
||||
o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity
|
||||
|
||||
recon = []
|
||||
|
||||
Reference in New Issue
Block a user