mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-04 23:19:57 +00:00
Update forge_reference.py
This commit is contained in:
@@ -10,6 +10,11 @@ def sdp(q, k, v, transformer_options):
|
||||
return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
|
||||
|
||||
|
||||
def adain(x, target_std, target_mean):
|
||||
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
||||
return (((x - mean) / std) * target_std) + target_mean
|
||||
|
||||
|
||||
class PreprocessorReference(Preprocessor):
|
||||
def __init__(self, name, use_attn=True, use_adain=True, priority=0):
|
||||
super().__init__()
|
||||
@@ -79,13 +84,31 @@ class PreprocessorReference(Preprocessor):
|
||||
return h
|
||||
|
||||
if self.is_recording_style:
|
||||
self.recorded_h[location] = h
|
||||
self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0)
|
||||
return h
|
||||
else:
|
||||
cond_indices = transformer_options['cond_indices']
|
||||
uncond_indices = transformer_options['uncond_indices']
|
||||
rh = self.recorded_h[location]
|
||||
return h
|
||||
cond_or_uncond = transformer_options['cond_or_uncond']
|
||||
r_std, r_mean = self.recorded_h[location]
|
||||
|
||||
h_c = h[cond_indices]
|
||||
h_uc = h[uncond_indices]
|
||||
|
||||
o_c = adain(h_c, r_std, r_mean)
|
||||
o_uc_strong = h_uc
|
||||
o_uc_weak = adain(h_uc, r_std, r_mean)
|
||||
o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity
|
||||
|
||||
recon = []
|
||||
for cx in cond_or_uncond:
|
||||
if cx == 0:
|
||||
recon.append(o_c)
|
||||
else:
|
||||
recon.append(o_uc)
|
||||
|
||||
o = torch.cat(recon, dim=0)
|
||||
return o
|
||||
|
||||
def attn1_proc(q, k, v, transformer_options):
|
||||
if not self.use_attn:
|
||||
|
||||
Reference in New Issue
Block a user