revise space

This commit is contained in:
layerdiffusion
2024-08-30 16:48:11 -07:00
parent ec7917bd16
commit 8e9b124456

View File

@@ -114,24 +114,17 @@ class GPUObject:
def __enter__(self):
self.original_init = torch.nn.Module.__init__
self.original_to = torch.nn.Module.to
def patched_init(module, *args, **kwargs):
self.module_list.append(module)
return self.original_init(module, *args, **kwargs)
def patched_to(module, *args, **kwargs):
self.module_list.append(module)
return self.original_to(module, *args, **kwargs)
torch.nn.Module.__init__ = patched_init
torch.nn.Module.to = patched_to
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch.nn.Module.__init__ = self.original_init
torch.nn.Module.to = self.original_to
self.module_list = set(self.module_list)
self.module_list = list(set(self.module_list))
self.to(device=torch.device('cpu'))
memory_management.soft_empty_cache()
return