mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
All Forge Spaces Now Pass 4GB VRAM
and they all 100% reproduce author results
This commit is contained in:
@@ -500,3 +500,52 @@ def automatic_memory_management():
|
||||
|
||||
print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
|
||||
return
|
||||
|
||||
|
||||
class DynamicSwapInstaller:
|
||||
@staticmethod
|
||||
def _install_module(module: torch.nn.Module, target_device: torch.device):
|
||||
original_class = module.__class__
|
||||
module.__dict__['forge_backup_original_class'] = original_class
|
||||
|
||||
def hacked_get_attr(self, name: str):
|
||||
if '_parameters' in self.__dict__:
|
||||
_parameters = self.__dict__['_parameters']
|
||||
if name in _parameters:
|
||||
p = _parameters[name]
|
||||
if p is None:
|
||||
return None
|
||||
if p.__class__ == torch.nn.Parameter:
|
||||
return torch.nn.Parameter(p.to(target_device), requires_grad=p.requires_grad)
|
||||
else:
|
||||
return p.to(target_device)
|
||||
if '_buffers' in self.__dict__:
|
||||
_buffers = self.__dict__['_buffers']
|
||||
if name in _buffers:
|
||||
return _buffers[name].to(target_device)
|
||||
return super(original_class, self).__getattr__(name)
|
||||
|
||||
module.__class__ = type('DynamicSwapInstance', (original_class,), {
|
||||
'__getattr__': hacked_get_attr,
|
||||
})
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _uninstall_module(module: torch.nn.Module):
|
||||
if 'forge_backup_original_class' in module.__dict__:
|
||||
module.__class__ = module.__dict__.pop('forge_backup_original_class')
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def install_model(model: torch.nn.Module, target_device: torch.device):
|
||||
for m in model.modules():
|
||||
DynamicSwapInstaller._install_module(m, target_device)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def uninstall_model(model: torch.nn.Module):
|
||||
for m in model.modules():
|
||||
DynamicSwapInstaller._uninstall_module(m)
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user