All Forge Spaces Now Pass 4GB VRAM

and they all 100% reproduce author results
This commit is contained in:
layerdiffusion
2024-08-20 08:01:10 -07:00
parent f136f86fee
commit 5452bc6ac3
4 changed files with 75 additions and 5 deletions

View File

@@ -500,3 +500,52 @@ def automatic_memory_management():
print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
return
class DynamicSwapInstaller:
@staticmethod
def _install_module(module: torch.nn.Module, target_device: torch.device):
original_class = module.__class__
module.__dict__['forge_backup_original_class'] = original_class
def hacked_get_attr(self, name: str):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
p = _parameters[name]
if p is None:
return None
if p.__class__ == torch.nn.Parameter:
return torch.nn.Parameter(p.to(target_device), requires_grad=p.requires_grad)
else:
return p.to(target_device)
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name].to(target_device)
return super(original_class, self).__getattr__(name)
module.__class__ = type('DynamicSwapInstance', (original_class,), {
'__getattr__': hacked_get_attr,
})
return
@staticmethod
def _uninstall_module(module: torch.nn.Module):
if 'forge_backup_original_class' in module.__dict__:
module.__class__ = module.__dict__.pop('forge_backup_original_class')
return
@staticmethod
def install_model(model: torch.nn.Module, target_device: torch.device):
for m in model.modules():
DynamicSwapInstaller._install_module(m, target_device)
return
@staticmethod
def uninstall_model(model: torch.nn.Module):
for m in model.modules():
DynamicSwapInstaller._uninstall_module(m)
return