mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 02:31:16 +00:00
Update forge_reference.py
This commit is contained in:
@@ -7,14 +7,28 @@ import ldm_patched.ldm.modules.attention as attention
|
|||||||
|
|
||||||
|
|
||||||
def sdp(q, k, v, transformer_options):
|
def sdp(q, k, v, transformer_options):
|
||||||
|
if q.shape[0] == 0:
|
||||||
|
return q
|
||||||
|
|
||||||
return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
|
return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
|
||||||
|
|
||||||
|
|
||||||
def adain(x, target_std, target_mean):
|
def adain(x, target_std, target_mean):
|
||||||
|
if x.shape[0] == 0:
|
||||||
|
return x
|
||||||
|
|
||||||
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
||||||
return (((x - mean) / std) * target_std) + target_mean
|
return (((x - mean) / std) * target_std) + target_mean
|
||||||
|
|
||||||
|
|
||||||
|
def zero_cat(a, b, dim):
|
||||||
|
if a.shape[0] == 0:
|
||||||
|
return b
|
||||||
|
if b.shape[0] == 0:
|
||||||
|
return a
|
||||||
|
return torch.cat([a, b], dim=dim)
|
||||||
|
|
||||||
|
|
||||||
class PreprocessorReference(Preprocessor):
|
class PreprocessorReference(Preprocessor):
|
||||||
def __init__(self, name, use_attn=True, use_adain=True, priority=0):
|
def __init__(self, name, use_attn=True, use_adain=True, priority=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -84,7 +98,7 @@ class PreprocessorReference(Preprocessor):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
channel = int(h.shape[1])
|
channel = int(h.shape[1])
|
||||||
minimal_channel = 1280 - 640 * weight
|
minimal_channel = 1500 - 1000 * weight
|
||||||
|
|
||||||
if channel < minimal_channel:
|
if channel < minimal_channel:
|
||||||
return h
|
return h
|
||||||
@@ -128,7 +142,7 @@ class PreprocessorReference(Preprocessor):
|
|||||||
transformer_options['block_index'])
|
transformer_options['block_index'])
|
||||||
|
|
||||||
channel = int(q.shape[2])
|
channel = int(q.shape[2])
|
||||||
minimal_channel = 1280 - 1280 * weight
|
minimal_channel = 1500 - 1280 * weight
|
||||||
|
|
||||||
if channel < minimal_channel:
|
if channel < minimal_channel:
|
||||||
return sdp(q, k, v, transformer_options)
|
return sdp(q, k, v, transformer_options)
|
||||||
@@ -152,9 +166,9 @@ class PreprocessorReference(Preprocessor):
|
|||||||
|
|
||||||
k_r, v_r = self.recorded_attn1[location]
|
k_r, v_r = self.recorded_attn1[location]
|
||||||
|
|
||||||
o_c = sdp(q_c, torch.cat([k_c, k_r], dim=1), torch.cat([v_c, v_r], dim=1), transformer_options)
|
o_c = sdp(q_c, zero_cat(k_c, k_r, dim=1), zero_cat(v_c, v_r, dim=1), transformer_options)
|
||||||
o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options)
|
o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options)
|
||||||
o_uc_weak = sdp(q_uc, torch.cat([k_uc, k_r], dim=1), torch.cat([v_uc, v_r], dim=1), transformer_options)
|
o_uc_weak = sdp(q_uc, zero_cat(k_uc, k_r, dim=1), zero_cat(v_uc, v_r, dim=1), transformer_options)
|
||||||
o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity
|
o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity
|
||||||
|
|
||||||
recon = []
|
recon = []
|
||||||
|
|||||||
Reference in New Issue
Block a user