allow preserve memory

This commit is contained in:
lllyasviel
2024-01-31 21:57:40 -08:00
parent df5e8c520d
commit a38313368f
2 changed files with 11 additions and 1 deletions

View File

@@ -86,7 +86,7 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition):
def sampling_prepare(unet, x):
B, C, H, W = x.shape
unet_inference_memory = unet.memory_required([B * 2, C, H, W])
unet_inference_memory = unet.memory_required([B * 2, C, H, W]) + unet.extra_preserved_memory
additional_inference_memory = 0
additional_model_patchers = []

View File

@@ -7,6 +7,7 @@ class UnetPatcher(ModelPatcher):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.controlnet_linked_list = None
self.extra_preserved_memory = 0
def clone(self):
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
@@ -20,8 +21,17 @@ class UnetPatcher(ModelPatcher):
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
n.controlnet_linked_list = self.controlnet_linked_list
n.extra_preserved_memory = self.extra_preserved_memory
return n
def add_preserved_memory(self, memory_in_bytes):
# Use this to ask Forge to preserve a certain amount of memory during sampling.
# If GPU VRAM is 8 GB, and memory_in_bytes is 2GB, i.e., memory_in_bytes = 2 * 1024 * 1024 * 1024
# Then the sampling will always use less than 6GB memory by dynamically offload modules to CPU RAM.
# You can estimate this using model_management.module_size(any_pytorch_model) to get size of any pytorch models.
self.extra_preserved_memory += memory_in_bytes
return
def add_patched_controlnet(self, cnet):
cnet.set_previous_controlnet(self.controlnet_linked_list)
self.controlnet_linked_list = cnet