diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py index c3e447ed..bac7d156 100644 --- a/modules_forge/forge_sampler.py +++ b/modules_forge/forge_sampler.py @@ -86,7 +86,9 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition): def sampling_prepare(unet, x): B, C, H, W = x.shape - unet_inference_memory = unet.memory_required([B * 2, C, H, W]) + memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required) + + unet_inference_memory = memory_estimation_function([B * 2, C, H, W]) additional_inference_memory = unet.extra_preserved_memory_during_sampling additional_model_patchers = unet.extra_model_patchers_during_sampling diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index d94d9a4e..a20ccb70 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -102,6 +102,10 @@ class UnetPatcher(ModelPatcher): self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness) return + def add_memory_peak_estimation_modifier(self, modifier): + self.model_options['memory_peak_estimation_modifier'] = modifier + return + def add_alphas_cumprod_modifier(self, modifier, ensure_uniqueness=False): """