diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py index 5c8720f5..4b53fec6 100644 --- a/modules_forge/initialization.py +++ b/modules_forge/initialization.py @@ -2,6 +2,28 @@ import os import sys +MONITOR_MODEL_MOVING = False + + +def monitor_module_moving(): + if not MONITOR_MODEL_MOVING: + return + + import torch + import traceback + + old_to = torch.nn.Module.to + + def new_to(*args, **kwargs): + traceback.print_stack() + print('Model Movement') + + return old_to(*args, **kwargs) + + torch.nn.Module.to = new_to + return + + def initialize_forge(): bad_list = ['--lowvram', '--medvram', '--medvram-sdxl'] @@ -21,6 +43,8 @@ def initialize_forge(): import ldm_patched.modules.model_management as model_management import torch + monitor_module_moving() + device = model_management.get_torch_device() torch.zeros((1, 1)).to(device, torch.float32) model_management.soft_empty_cache()