From 754eb669d0ed2045dd98d7e5aaa1ef99fbac6aa3 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 19:25:27 -0800 Subject: [PATCH] Update forge_reference.py --- .../forge_preprocessor_reference/scripts/forge_reference.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index ad6b48ab..bb38ceef 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -25,6 +25,8 @@ class PreprocessorReference(Preprocessor): self.do_not_need_model = True self.is_recording_style = False + self.recorded_attn1 = {} + self.recorded_h = {} def process_before_every_sampling(self, process, cond, *args, **kwargs): unit = kwargs['unit'] @@ -46,6 +48,9 @@ class PreprocessorReference(Preprocessor): sigma_max = unet.model.model_sampling.percent_to_sigma(start_percent) sigma_min = unet.model.model_sampling.percent_to_sigma(end_percent) + self.recorded_attn1 = {} + self.recorded_h = {} + def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed): sigma = timestep[0].item() if not (sigma_min <= sigma <= sigma_max):