Update forge_reference.py

This commit is contained in:
lllyasviel
2024-01-30 20:40:47 -08:00
parent c2e2d0ea9b
commit c6a344aed3

View File

@@ -10,6 +10,11 @@ def sdp(q, k, v, transformer_options):
return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
def adain(x, target_std, target_mean):
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
return (((x - mean) / std) * target_std) + target_mean
class PreprocessorReference(Preprocessor):
def __init__(self, name, use_attn=True, use_adain=True, priority=0):
super().__init__()
@@ -79,13 +84,31 @@ class PreprocessorReference(Preprocessor):
return h
if self.is_recording_style:
self.recorded_h[location] = h
self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0)
return h
else:
cond_indices = transformer_options['cond_indices']
uncond_indices = transformer_options['uncond_indices']
rh = self.recorded_h[location]
return h
cond_or_uncond = transformer_options['cond_or_uncond']
r_std, r_mean = self.recorded_h[location]
h_c = h[cond_indices]
h_uc = h[uncond_indices]
o_c = adain(h_c, r_std, r_mean)
o_uc_strong = h_uc
o_uc_weak = adain(h_uc, r_std, r_mean)
o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity
recon = []
for cx in cond_or_uncond:
if cx == 0:
recon.append(o_c)
else:
recon.append(o_uc)
o = torch.cat(recon, dim=0)
return o
def attn1_proc(q, k, v, transformer_options):
if not self.use_attn: