From fdbe5b7aa7192059e1f1253e25009a8ee01cc662 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sun, 4 Feb 2024 19:36:58 -0800 Subject: [PATCH] memory_peak_estimation_modifier --- modules_forge/forge_sampler.py | 4 +++- modules_forge/unet_patcher.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py index c3e447ed..bac7d156 100644 --- a/modules_forge/forge_sampler.py +++ b/modules_forge/forge_sampler.py @@ -86,7 +86,9 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition): def sampling_prepare(unet, x): B, C, H, W = x.shape - unet_inference_memory = unet.memory_required([B * 2, C, H, W]) + memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required) + + unet_inference_memory = memory_estimation_function([B * 2, C, H, W]) additional_inference_memory = unet.extra_preserved_memory_during_sampling additional_model_patchers = unet.extra_model_patchers_during_sampling diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index d94d9a4e..a20ccb70 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -102,6 +102,10 @@ class UnetPatcher(ModelPatcher): self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness) return + def add_memory_peak_estimation_modifier(self, modifier): + self.model_options['memory_peak_estimation_modifier'] = modifier + return + def add_alphas_cumprod_modifier(self, modifier, ensure_uniqueness=False): """