diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index 8d5d4250..644345a1 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -56,6 +56,9 @@ class PreprocessorReference(Preprocessor): return model, x, timestep, uncond, cond, cond_scale, model_options, seed def block_proc(h, flag, transformer_options): + if not self.use_adain: + return h + sigma = transformer_options["sigmas"][0].item() if not (sigma_min <= sigma <= sigma_max): return h @@ -68,6 +71,9 @@ class PreprocessorReference(Preprocessor): return h def attn1_proc(q, k, v, transformer_options): + if not self.use_attn: + return q, k, v + sigma = transformer_options["sigmas"][0].item() if not (sigma_min <= sigma <= sigma_max): return q, k, v @@ -80,6 +86,9 @@ class PreprocessorReference(Preprocessor): return q, k, v def attn1_output_proc(h, transformer_options): + if not self.use_attn: + return h + sigma = transformer_options["sigmas"][0].item() if not (sigma_min <= sigma <= sigma_max): return h