diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py
index bfc8f549..c3e447ed 100644
--- a/modules_forge/forge_sampler.py
+++ b/modules_forge/forge_sampler.py
@@ -86,9 +86,9 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition):
 def sampling_prepare(unet, x):
     B, C, H, W = x.shape
 
-    unet_inference_memory = unet.memory_required([B * 2, C, H, W]) + unet.extra_preserved_memory
-    additional_inference_memory = 0
-    additional_model_patchers = []
+    unet_inference_memory = unet.memory_required([B * 2, C, H, W])
+    additional_inference_memory = unet.extra_preserved_memory_during_sampling
+    additional_model_patchers = unet.extra_model_patchers_during_sampling
 
     if unet.controlnet_linked_list is not None:
         additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py
index cff5c742..66c48c06 100644
--- a/modules_forge/unet_patcher.py
+++ b/modules_forge/unet_patcher.py
@@ -7,7 +7,8 @@ class UnetPatcher(ModelPatcher):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.controlnet_linked_list = None
-        self.extra_preserved_memory = 0
+        self.extra_preserved_memory_during_sampling = 0
+        self.extra_model_patchers_during_sampling = []
 
     def clone(self):
         n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
@@ -21,15 +22,22 @@ class UnetPatcher(ModelPatcher):
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
         n.controlnet_linked_list = self.controlnet_linked_list
-        n.extra_preserved_memory = self.extra_preserved_memory
+        n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
+        n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
         return n
 
-    def add_preserved_memory(self, memory_in_bytes):
+    def add_extra_preserved_memory_during_sampling(self, memory_in_bytes):
         # Use this to ask Forge to preserve a certain amount of memory during sampling.
         # If GPU VRAM is 8 GB, and memory_in_bytes is 2GB, i.e., memory_in_bytes = 2 * 1024 * 1024 * 1024
         # Then the sampling will always use less than 6GB memory by dynamically offload modules to CPU RAM.
         # You can estimate this using model_management.module_size(any_pytorch_model) to get size of any pytorch models.
-        self.extra_preserved_memory += memory_in_bytes
+        self.extra_preserved_memory_during_sampling += memory_in_bytes
+        return
+
+    def add_extra_model_patcher_during_sampling(self, model_patcher):
+        # Use this to ask Forge to move extra model patchers to GPU during sampling.
+        # This method will manage GPU memory perfectly.
+        self.extra_model_patchers_during_sampling.append(model_patcher)
         return
 
     def add_patched_controlnet(self, cnet):
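
For reference, a minimal usage sketch of the two UnetPatcher methods introduced above, as an extension might call them. This is an illustration only: the prepare_unet() helper, the two example modules, and the ldm_patched.modules import paths are assumptions; only add_extra_preserved_memory_during_sampling() and add_extra_model_patcher_during_sampling() come from this change.

    # Illustration only -- not part of the diff above.
    from ldm_patched.modules import model_management            # assumed import path
    from ldm_patched.modules.model_patcher import ModelPatcher  # assumed import path

    def prepare_unet(unet, external_module, managed_module):
        # Work on a clone so the original UnetPatcher stays untouched.
        unet = unet.clone()

        # Case 1: the extension moves `external_module` to the GPU itself,
        # outside Forge's management. Reserve its size so Forge's dynamic
        # offloading keeps at least that much VRAM free during sampling.
        unet.add_extra_preserved_memory_during_sampling(
            model_management.module_size(external_module))

        # Case 2: `managed_module` should be loaded/offloaded by Forge.
        # Wrap it in a ModelPatcher and register it; Forge then accounts for
        # its memory during sampling, so no manual reservation is needed.
        patcher = ModelPatcher(
            managed_module,
            load_device=model_management.get_torch_device(),
            offload_device=model_management.unet_offload_device(),
        )
        unet.add_extra_model_patcher_during_sampling(patcher)
        return unet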