Update sage_attention_patch.py

Fix an inappropriate assert that broke SageAttention on Blackwell (SM120): SM120 previously fell into the arch_code >= 90 branch and called the SM90-specific FP8 kernel, whose assert fails on Blackwell, so it is now dispatched to the generic FP8 kernel instead.
Tested with Torch 2.9 nightly. Saves an average of 2 s compared to SDPA.
Jarvik7
2025-09-05 10:19:13 -03:00
committed by GitHub
parent 4f7d167d59
commit dce37a234c


@@ -35,10 +35,14 @@ def get_sage_attention_function_and_params():
     attn_func = None
     pv_accum_dtype = "fp32"
-    if arch_code >= 90: # Hopper
+    if arch_code >= 120: # Blackwell
+        pv_accum_dtype = "fp32+fp32"
+        attn_func = sageattn_qk_int8_pv_fp8_cuda
+        logger.info(f"SageAttention: Using SM120 (Blackwell) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
+    elif arch_code >= 90: # Hopper
         pv_accum_dtype = "fp32+fp32"
         attn_func = sageattn_qk_int8_pv_fp8_cuda_sm90
-        logger.info(f"SageAttention: Using SM90+ (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
+        logger.info(f"SageAttention: Using SM90 (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
     elif arch_code == 89: # Ada Lovelace
         pv_accum_dtype = "fp32+fp32"
         attn_func = sageattn_qk_int8_pv_fp8_cuda
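
For context, a minimal sketch of how an arch_code value like the one dispatched on above could be derived; the derivation is not shown in this hunk and is an assumption, though torch.cuda.get_device_capability() is a real PyTorch API:

import torch

# Assumed derivation: pack the CUDA compute capability as major*10 + minor,
# e.g. (8, 9) -> 89 (Ada), (9, 0) -> 90 (Hopper), (12, 0) -> 120 (Blackwell).
major, minor = torch.cuda.get_device_capability()
arch_code = major * 10 + minor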
@@ -141,4 +145,4 @@ def set_sage_attention(model):
     for module in model.modules():
         if isinstance(module, Qwen2Attention):
-            module.forward = sage_attention_forward.__get__(module, Qwen2Attention)
+            module.forward = sage_attention_forward.__get__(module, Qwen2Attention)
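
The __get__ idiom above binds the replacement function to each module instance so Python passes self automatically, without touching other instances of the class. A minimal standalone sketch of the pattern (Attention and fast_forward are illustrative names, not from the patch):

class Attention:
    def forward(self, x):
        return x

def fast_forward(self, x):
    # Stand-in for sage_attention_forward in the patch.
    return x * 2

m = Attention()
# Bind fast_forward to this instance only; the descriptor protocol turns the
# plain function into a bound method, so m.forward(3) passes m as self.
m.forward = fast_forward.__get__(m, Attention)
assert m.forward(3) == 6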