Model: Fix flash-attn checks

If flash attention is already turned off by exllamaV2 itself, don't
try creating a paged generator. Also condense all the redundant
logic into one if statement.

Also check arch_compat_overrides to see if flash attention should
be disabled for a model arch (ex. Gemma 2)

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-07-06 20:57:10 -04:00
parent 27d2d5f3d2
commit 773639ea89
2 changed files with 30 additions and 8 deletions

View File

@@ -94,3 +94,18 @@ def supports_paged_attn():
return False
else:
return True
def exllama_disabled_flash_attn(no_flash_attn: bool):
unsupported_message = (
"ExllamaV2 has disabled Flash Attention. \n"
"Please see the above logs for warnings/errors. \n"
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
)
if no_flash_attn:
logger.warning(unsupported_message)
return no_flash_attn