diff --git a/spaces.py b/spaces.py index 16709dc6..d592722d 100644 --- a/spaces.py +++ b/spaces.py @@ -114,24 +114,17 @@ class GPUObject: def __enter__(self): self.original_init = torch.nn.Module.__init__ - self.original_to = torch.nn.Module.to def patched_init(module, *args, **kwargs): self.module_list.append(module) return self.original_init(module, *args, **kwargs) - def patched_to(module, *args, **kwargs): - self.module_list.append(module) - return self.original_to(module, *args, **kwargs) - torch.nn.Module.__init__ = patched_init - torch.nn.Module.to = patched_to return self def __exit__(self, exc_type, exc_val, exc_tb): torch.nn.Module.__init__ = self.original_init - torch.nn.Module.to = self.original_to - self.module_list = set(self.module_list) + self.module_list = list(set(self.module_list)) self.to(device=torch.device('cpu')) memory_management.soft_empty_cache() return