diff --git a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py
index 359d772d..30f4bd87 100644
--- a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py
+++ b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py
@@ -78,17 +78,6 @@ class ControlNetExampleForge(scripts.Script):
         input_image = cv2.resize(input_image, (width, height))
         canny_image = cv2.cvtColor(cv2.Canny(input_image, 100, 200), cv2.COLOR_GRAY2RGB)
 
-        from modules_forge.ops import capture_model
-        from modules_forge.shared import shared_preprocessors
-
-        with capture_model() as captured_model:
-            canny_image = shared_preprocessors['normalbae'](input_image, 512)
-
-        captured_model.cpu()
-        from ldm_patched.modules import model_management
-        model_management.soft_empty_cache()
-        a = 0
-
         # # Or you can get a list of preprocessors in this way
         # from modules_forge.shared import shared_preprocessors
         # canny_preprocessor = shared_preprocessors['canny']
diff --git a/modules_forge/ops.py b/modules_forge/ops.py
index f05a1376..9ffbb965 100644
--- a/modules_forge/ops.py
+++ b/modules_forge/ops.py
@@ -1,5 +1,6 @@
 import torch
 import contextlib
+from ldm_patched.modules import model_management
 
 
 @contextlib.contextmanager
@@ -20,29 +21,41 @@ def use_patched_ops(operations):
 
 
 @contextlib.contextmanager
-def capture_model():
+def automatic_memory_management():
+    model_management.free_memory(
+        memory_required=3 * 1024 * 1024 * 1024,
+        device=model_management.get_torch_device()
+    )
+
     module_list = []
-    backup_init = torch.nn.Module.__init__
+
+    original_init = torch.nn.Module.__init__
+    original_to = torch.nn.Module.to
 
     def patched_init(self, *args, **kwargs):
         module_list.append(self)
-        return backup_init(self, *args, **kwargs)
+        return original_init(self, *args, **kwargs)
+
+    def patched_to(self, *args, **kwargs):
+        module_list.append(self)
+        return original_to(self, *args, **kwargs)
 
     try:
         torch.nn.Module.__init__ = patched_init
+        torch.nn.Module.to = patched_to
         yield
     finally:
-        torch.nn.Module.__init__ = backup_init
+        torch.nn.Module.__init__ = original_init
+        torch.nn.Module.to = original_to
 
-    results = []
-    for item in module_list:
-        item_params = getattr(item, '_parameters', [])
-        if len(item_params) > 0:
-            results.append(item)
+    count = 0
+    for module in set(module_list):
+        module_params = getattr(module, '_parameters', [])
+        if len(module_params) > 0:
+            module.cpu()
+            count += 1
 
-    if len(results) == 0:
-        return None
+    print(f'Automatic Memory Management: {count} Modules.')
+    model_management.soft_empty_cache()
 
-    captured_model = torch.nn.ModuleList(results)
-
-    return captured_model
+    return