diff --git a/comfy/model_management.py b/comfy/model_management.py index 81550c790..81c89b180 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,10 +270,15 @@ try: except: OOM_EXCEPTION = Exception +try: + ACCELERATOR_ERROR = torch.AcceleratorError +except AttributeError: + ACCELERATOR_ERROR = RuntimeError + def is_oom(e): if isinstance(e, OOM_EXCEPTION): return True - if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()): discard_cuda_async_error() return True return False @@ -1275,7 +1280,7 @@ def discard_cuda_async_error(): b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b synchronize() - except torch.AcceleratorError: + except RuntimeError: #Dump it! We already know about it from the synchronous return pass