Widen OOM_EXCEPTION to AcceleratorError form (#12835)

Pytorch only filters for OOMs in its own allocators however there are
paths that can OOM on allocators made outside the pytorch allocators.
These manifest as an AllocatorError as pytorch does not have universal
error translation to its OOM type on exception. Handle it. A log I have
for this also shows a double report of the error async, so call the
async discarder to cleanup and make these OOMs look like OOMs.
This commit is contained in:
rattus
2026-03-09 21:41:02 -07:00
committed by GitHub
parent a912809c25
commit 535c16ce6e
7 changed files with 27 additions and 8 deletions

View File

@@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:

View File

@@ -258,7 +258,8 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
@@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:

View File

@@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)