diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index fdd6a2bd..2fa6b26c 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -83,6 +83,8 @@ class PreprocessorReference(Preprocessor): if not (sigma_min <= sigma <= sigma_max): return h + C = int(h.shape[1]) + if self.is_recording_style: self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0) return h @@ -121,6 +123,8 @@ class PreprocessorReference(Preprocessor): location = (transformer_options['block'][0], transformer_options['block'][1], transformer_options['block_index']) + C = int(q.shape[2]) + if self.is_recording_style: self.recorded_attn1[location] = (k, v) return sdp(q, k, v, transformer_options)