diff --git a/modules_forge/ops.py b/modules_forge/ops.py index ee0e7756..f05a1376 100644 --- a/modules_forge/ops.py +++ b/modules_forge/ops.py @@ -17,3 +17,32 @@ def use_patched_ops(operations): for op_name in op_names: setattr(torch.nn, op_name, backups[op_name]) return + + +@contextlib.contextmanager +def capture_model(): + module_list = [] + backup_init = torch.nn.Module.__init__ + + def patched_init(self, *args, **kwargs): + module_list.append(self) + return backup_init(self, *args, **kwargs) + + try: + torch.nn.Module.__init__ = patched_init + yield + finally: + torch.nn.Module.__init__ = backup_init + + results = [] + for item in module_list: + item_params = getattr(item, '_parameters', []) + if len(item_params) > 0: + results.append(item) + + if len(results) == 0: + return None + + captured_model = torch.nn.ModuleList(results) + + return captured_model