From ce16d34d034b0e961136384c633880a099c638db Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Thu, 15 Aug 2024 00:55:49 -0700
Subject: [PATCH] disable xformers for t5

---
 backend/attention.py | 14 ++------------
 backend/nn/t5.py     |  2 +-
 2 files changed, 3 insertions(+), 13 deletions(-)

diff --git a/backend/attention.py b/backend/attention.py
index 6cf13ff2..58fee278 100644
--- a/backend/attention.py
+++ b/backend/attention.py
@@ -284,18 +284,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     b, _, dim_head = q.shape
     dim_head //= heads
 
-    disabled_xformers = False
-
-    if BROKEN_XFORMERS:
-        if b * heads > 65535:
-            disabled_xformers = True
-
-    if not disabled_xformers:
-        if torch.jit.is_tracing() or torch.jit.is_scripting():
-            disabled_xformers = True
-
-    if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask)
+    if BROKEN_XFORMERS and b * heads > 65535:
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
 
     if skip_reshape:
         q, k, v = map(
diff --git a/backend/nn/t5.py b/backend/nn/t5.py
index 8a9cc9b5..d867f0de 100644
--- a/backend/nn/t5.py
+++ b/backend/nn/t5.py
@@ -1,7 +1,7 @@
 import torch
 import math
 
-from backend.attention import attention_function
+from backend.attention import attention_pytorch as attention_function
 
 
 activations = {