diff --git a/backend/attention.py b/backend/attention.py
index 6cf13ff2..58fee278 100644
--- a/backend/attention.py
+++ b/backend/attention.py
@@ -284,18 +284,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     b, _, dim_head = q.shape
     dim_head //= heads
 
-    disabled_xformers = False
-
-    if BROKEN_XFORMERS:
-        if b * heads > 65535:
-            disabled_xformers = True
-
-    if not disabled_xformers:
-        if torch.jit.is_tracing() or torch.jit.is_scripting():
-            disabled_xformers = True
-
-    if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask)
+    if BROKEN_XFORMERS and b * heads > 65535:
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
 
     if skip_reshape:
         q, k, v = map(
diff --git a/backend/nn/t5.py b/backend/nn/t5.py
index 8a9cc9b5..d867f0de 100644
--- a/backend/nn/t5.py
+++ b/backend/nn/t5.py
@@ -1,7 +1,7 @@
 import torch
 import math
 
-from backend.attention import attention_function
+from backend.attention import attention_pytorch as attention_function
 
 
 activations = {