Update forge_reference.py

This commit is contained in:
lllyasviel
2024-01-30 18:51:10 -08:00
parent c22c1e6726
commit df673fa6dc

View File

@@ -56,6 +56,9 @@ class PreprocessorReference(Preprocessor):
return model, x, timestep, uncond, cond, cond_scale, model_options, seed
def block_proc(h, flag, transformer_options):
if not self.use_adain:
return h
sigma = transformer_options["sigmas"][0].item()
if not (sigma_min <= sigma <= sigma_max):
return h
@@ -68,6 +71,9 @@ class PreprocessorReference(Preprocessor):
return h
def attn1_proc(q, k, v, transformer_options):
if not self.use_attn:
return q, k, v
sigma = transformer_options["sigmas"][0].item()
if not (sigma_min <= sigma <= sigma_max):
return q, k, v
@@ -80,6 +86,9 @@ class PreprocessorReference(Preprocessor):
return q, k, v
def attn1_output_proc(h, transformer_options):
if not self.use_attn:
return h
sigma = transformer_options["sigmas"][0].item()
if not (sigma_min <= sigma <= sigma_max):
return h