fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0 (#12874)

* fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0

torch.AcceleratorError was introduced in PyTorch 2.8.0. Accessing it
directly raises AttributeError on older versions. Use a try/except
fallback at module load time, consistent with the existing pattern used
for OOM_EXCEPTION.


* fix: address review feedback for AcceleratorError compat

- Fall back to RuntimeError instead of type(None) for ACCELERATOR_ERROR,
  consistent with OOM_EXCEPTION fallback pattern and valid for except clauses
- Add "out of memory" message introspection for RuntimeError fallback case
- Use RuntimeError directly in discard_cuda_async_error except clause
---------
This commit is contained in:
Adi Borochov
2026-03-11 19:04:13 +02:00
committed by GitHub
parent 3365008dfe
commit 4f4f8659c2

View File

@@ -270,10 +270,15 @@ try:
except:
OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
discard_cuda_async_error()
return True
return False
@@ -1275,7 +1280,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
except torch.AcceleratorError:
except RuntimeError:
#Dump it! We already know about it from the synchronous return
pass