diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index 288352ad..853afcee 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -80,12 +80,11 @@ class PreprocessorReference(Preprocessor): if self.is_recording_style: self.recorded_h[location] = h + return h else: cond_mark = transformer_options['cond_mark'][:, None, None, None] # cond is 0 rh = self.recorded_h[location] - b = 0 - - return h + return h def attn1_proc(q, k, v, transformer_options): if not self.use_attn: @@ -100,12 +99,11 @@ class PreprocessorReference(Preprocessor): if self.is_recording_style: self.recorded_attn1[location] = (k, v) + return sdp(q, k, v, transformer_options) else: cond_mark = transformer_options['cond_mark'][:, None, None, None] # cond is 0 rk, rv = self.recorded_attn1[location] - b = 0 - - return sdp(q, k, v, transformer_options) + return sdp(q, k, v, transformer_options) unet.add_block_modifier(block_proc) unet.add_conditioning_modifier(conditioning_modifier)