Mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
Synced 2026-01-30 12:59:47 +00:00
49 lines · 1.1 KiB · Python
import torch
|
|
import contextlib
|
|
|
|
|
|
@contextlib.contextmanager
def use_patched_ops(operations):
    """Temporarily swap selected ``torch.nn`` layer classes for custom ones.

    While the context is active, ``torch.nn.Linear``, ``torch.nn.Conv2d``,
    ``torch.nn.Conv3d``, ``torch.nn.GroupNorm`` and ``torch.nn.LayerNorm`` are
    replaced by the same-named attributes of ``operations``. Code that looks
    these names up through the ``torch.nn`` module during the context gets the
    substitutes. Code that bound the classes to its own names beforehand is
    unaffected. The originals are restored on exit, whether the body finishes
    normally or raises.

    Args:
        operations: Any object (module, class or namespace) exposing the
            attributes ``Linear``, ``Conv2d``, ``Conv3d``, ``GroupNorm`` and
            ``LayerNorm``.

    Yields:
        None.

    Raises:
        AttributeError: If ``operations`` lacks one of the required names.
            Every ``torch.nn`` entry is restored before it propagates.
        Any exception raised inside the ``with`` body propagates unchanged.
    """
    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
    # Snapshot every original before touching anything, so the ``finally``
    # below can restore all of them even if patching fails part-way through.
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

    try:
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))
        yield
    finally:
        for op_name in op_names:
            setattr(torch.nn, op_name, backups[op_name])
        # Deliberately no ``return`` here. A ``return`` inside ``finally``
        # discards any in-flight exception, which would make this context
        # manager silently swallow errors raised in the ``with`` body.
|
|
|
|
|
|
@contextlib.contextmanager
def capture_model():
    """Collect the parameter-owning ``torch.nn.Module``s built in the body.

    ``torch.nn.Module.__init__`` is temporarily wrapped so that every module
    instance whose initialisation reaches ``Module.__init__`` is recorded.
    The original ``__init__`` is restored on exit, even if the body raises.

    Usage::

        with capture_model() as captured:
            build_something()
        # ``captured`` now holds the modules that own parameters directly.

    Yields:
        torch.nn.ModuleList: Empty while the body runs. When the body exits
        normally, it is filled in construction order with the recorded
        modules whose own ``_parameters`` dict is non-empty. It stays empty
        if there were none, and is left unfilled if the body raises.

    Note:
        A ``@contextmanager`` generator's ``return`` value is discarded by
        ``contextlib``. The result is therefore handed out through the
        yielded container rather than returned.
    """
    # Created before patching so the container does not record itself.
    captured = torch.nn.ModuleList()
    module_list = []
    backup_init = torch.nn.Module.__init__

    def patched_init(self, *args, **kwargs):
        # Record the instance, then run the real initializer unchanged.
        module_list.append(self)
        return backup_init(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        yield captured
    finally:
        torch.nn.Module.__init__ = backup_init

    # Keep only modules that hold parameters themselves. Containers whose
    # parameters live solely in child modules are skipped. The ``getattr``
    # default guards against an instance whose real ``__init__`` failed
    # before ``_parameters`` was created.
    captured.extend([
        item for item in module_list
        if len(getattr(item, '_parameters', [])) > 0
    ])
|