memory_peak_estimation_modifier

This commit is contained in:
lllyasviel
2024-02-04 19:36:58 -08:00
parent bbdf53c79f
commit fdbe5b7aa7
2 changed files with 7 additions and 1 deletions

View File

@@ -86,7 +86,9 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition):
def sampling_prepare(unet, x):
B, C, H, W = x.shape
unet_inference_memory = unet.memory_required([B * 2, C, H, W])
memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required)
unet_inference_memory = memory_estimation_function([B * 2, C, H, W])
additional_inference_memory = unet.extra_preserved_memory_during_sampling
additional_model_patchers = unet.extra_model_patchers_during_sampling

View File

@@ -102,6 +102,10 @@ class UnetPatcher(ModelPatcher):
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
return
def add_memory_peak_estimation_modifier(self, modifier):
self.model_options['memory_peak_estimation_modifier'] = modifier
return
def add_alphas_cumprod_modifier(self, modifier, ensure_uniqueness=False):
"""