attention: use flag-based OOM fallback (#11038)

Exception handlers keep a reference to all local variables (via the
exception's traceback) for the lifetime of the exception context. Just
set a flag in the handler and branch on it with an `if`, so the
exception is dropped before falling back.
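
A minimal sketch of the pattern, assuming PyTorch's
torch.cuda.OutOfMemoryError as the OOM exception; attention_or_fallback
and run_fallback are hypothetical names for illustration, not ComfyUI's
actual functions:

import torch

def run_fallback(q, k, v):
    # Hypothetical stand-in for a slower, memory-lighter attention path
    # (e.g. attention computed in slices).
    ...

def attention_or_fallback(q, k, v):
    # Sketch of the flag-based fallback pattern.
    oom_fallback = False
    try:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    except torch.cuda.OutOfMemoryError:
        # Running the fallback right here would execute it while the
        # in-flight exception's traceback still pins every frame it
        # passed through, keeping the just-OOMed locals reachable.
        oom_fallback = True
    # Leaving the except block drops the exception and its traceback,
    # so that memory can be reclaimed before retrying.
    if oom_fallback:
        out = run_fallback(q, k, v)
    return out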
rattus
2025-12-03 08:24:19 +10:00
committed by GitHub
parent daaceac769
commit 277237ccc1
2 changed files with 6 additions and 0 deletions


@@ -279,6 +279,7 @@ def pytorch_attention(q, k, v):
     orig_shape = q.shape
     B = orig_shape[0]
     C = orig_shape[1]
+    oom_fallback = False
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
         (q, k, v),
@@ -289,6 +290,8 @@ def pytorch_attention(q, k, v):
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        oom_fallback = True
+    if oom_fallback:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out
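
As a standalone check on the premise (not part of the commit; assumes
CPython semantics), the snippet below shows that a callee's locals stay
reachable through the active exception's traceback while the handler
runs, and are freed once the handler exits and the flag path takes over:

import gc
import weakref

class Blob:
    """Stand-in for a large tensor."""

def inner(track):
    scratch = Blob()                    # big temporary local to the callee
    track.append(weakref.ref(scratch))
    raise MemoryError("simulated OOM")

def flag_based(track):
    oom_fallback = False
    try:
        inner(track)
    except MemoryError:
        # inner() has unwound, but the active exception's traceback still
        # references its frame, so `scratch` remains alive in here.
        assert track[0]() is not None
        oom_fallback = True
    gc.collect()               # exception context was dropped at block exit
    assert track[0]() is None  # `scratch` is freed before the fallback runs
    if oom_fallback:
        pass  # the memory-lighter fallback would run here

flag_based([])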