From 035ad4836a816bc2006bb9d33946fc9b658f1e74 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 20:45:18 -0800 Subject: [PATCH] Update forge_reference.py --- .../scripts/forge_reference.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index 2fa6b26c..68be9fb9 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -83,7 +83,11 @@ class PreprocessorReference(Preprocessor): if not (sigma_min <= sigma <= sigma_max): return h - C = int(h.shape[1]) + channel = int(h.shape[1]) + minimal_channel = 1280 - 640 * weight + + if channel < minimal_channel: + return h if self.is_recording_style: self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0) @@ -123,7 +127,11 @@ class PreprocessorReference(Preprocessor): location = (transformer_options['block'][0], transformer_options['block'][1], transformer_options['block_index']) - C = int(q.shape[2]) + channel = int(q.shape[2]) + minimal_channel = 1280 - 1280 * weight + + if channel < minimal_channel: + return sdp(q, k, v, transformer_options) if self.is_recording_style: self.recorded_attn1[location] = (k, v)