This commit is contained in:
lllyasviel
2024-01-29 07:47:56 -08:00
parent bd334d3aff
commit 336eac060f
2 changed files with 27 additions and 25 deletions

View File

@@ -78,17 +78,6 @@ class ControlNetExampleForge(scripts.Script):
input_image = cv2.resize(input_image, (width, height))
canny_image = cv2.cvtColor(cv2.Canny(input_image, 100, 200), cv2.COLOR_GRAY2RGB)
from modules_forge.ops import capture_model
from modules_forge.shared import shared_preprocessors
with capture_model() as captured_model:
canny_image = shared_preprocessors['normalbae'](input_image, 512)
captured_model.cpu()
from ldm_patched.modules import model_management
model_management.soft_empty_cache()
a = 0
# # Or you can get a list of preprocessors in this way
# from modules_forge.shared import shared_preprocessors
# canny_preprocessor = shared_preprocessors['canny']

View File

@@ -1,5 +1,6 @@
import torch
import contextlib
from ldm_patched.modules import model_management
@contextlib.contextmanager
@@ -20,29 +21,41 @@ def use_patched_ops(operations):
@contextlib.contextmanager
def capture_model():
def automatic_memory_management():
model_management.free_memory(
memory_required=3 * 1024 * 1024 * 1024,
device=model_management.get_torch_device()
)
module_list = []
backup_init = torch.nn.Module.__init__
original_init = torch.nn.Module.__init__
original_to = torch.nn.Module.to
def patched_init(self, *args, **kwargs):
module_list.append(self)
return backup_init(self, *args, **kwargs)
return original_init(self, *args, **kwargs)
def patched_to(self, *args, **kwargs):
module_list.append(self)
return original_to(self, *args, **kwargs)
try:
torch.nn.Module.__init__ = patched_init
torch.nn.Module.to = patched_to
yield
finally:
torch.nn.Module.__init__ = backup_init
torch.nn.Module.__init__ = original_init
torch.nn.Module.to = original_to
results = []
for item in module_list:
item_params = getattr(item, '_parameters', [])
if len(item_params) > 0:
results.append(item)
count = 0
for module in set(module_list):
module_params = getattr(module, '_parameters', [])
if len(module_params) > 0:
module.cpu()
count += 1
if len(results) == 0:
return None
print(f'Automatic Memory Management: {count} Modules.')
model_management.soft_empty_cache()
captured_model = torch.nn.ModuleList(results)
return captured_model
return