Merge pull request #35 from Jarvik7/patch-1

Update sage_attention_patch.py
This commit is contained in:
WildAi
2025-09-07 19:28:25 +03:00
committed by GitHub

View File

@@ -35,10 +35,14 @@ def get_sage_attention_function_and_params():
 attn_func = None
 pv_accum_dtype = "fp32"
-if arch_code >= 90: # Hopper
+if arch_code >= 120: # Blackwell
+pv_accum_dtype = "fp32+fp32"
+attn_func = sageattn_qk_int8_pv_fp8_cuda
+logger.info(f"SageAttention: Using SM120 (Blackwell) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
+elif arch_code >= 90: # Hopper
 pv_accum_dtype = "fp32+fp32"
 attn_func = sageattn_qk_int8_pv_fp8_cuda_sm90
-logger.info(f"SageAttention: Using SM90+ (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
+logger.info(f"SageAttention: Using SM90 (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
 elif arch_code == 89: # Ada Lovelace
 pv_accum_dtype = "fp32+fp32"
 attn_func = sageattn_qk_int8_pv_fp8_cuda
@@ -141,4 +145,4 @@ def set_sage_attention(model):
 for module in model.modules():
 if isinstance(module, Qwen2Attention):
-module.forward = sage_attention_forward.__get__(module, Qwen2Attention)
+module.forward = sage_attention_forward.__get__(module, Qwen2Attention)