From dce37a234c4697a6cfa6639eaedaeaf8859dd69c Mon Sep 17 00:00:00 2001 From: Jarvik7 Date: Fri, 5 Sep 2025 10:19:13 -0300 Subject: [PATCH] Update sage_attention_patch.py Fix inappropriate assert on Blackwell (SM120) that broke SageAttention. Tested with a PyTorch 2.9 nightly build. Saves about 2 s on average compared to SDPA. --- vibevoice/modular/sage_attention_patch.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vibevoice/modular/sage_attention_patch.py b/vibevoice/modular/sage_attention_patch.py index c5bcaf9..4111d62 100644 --- a/vibevoice/modular/sage_attention_patch.py +++ b/vibevoice/modular/sage_attention_patch.py @@ -35,10 +35,14 @@ def get_sage_attention_function_and_params(): attn_func = None pv_accum_dtype = "fp32" - if arch_code >= 90: # Hopper + if arch_code >= 120: # Blackwell + pv_accum_dtype = "fp32+fp32" + attn_func = sageattn_qk_int8_pv_fp8_cuda + logger.info(f"SageAttention: Using SM120 (Blackwell) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") + elif arch_code >= 90: # Hopper pv_accum_dtype = "fp32+fp32" attn_func = sageattn_qk_int8_pv_fp8_cuda_sm90 - logger.info(f"SageAttention: Using SM90+ (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") + logger.info(f"SageAttention: Using SM90 (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") elif arch_code == 89: # Ada Lovelace pv_accum_dtype = "fp32+fp32" attn_func = sageattn_qk_int8_pv_fp8_cuda @@ -141,4 +145,4 @@ def set_sage_attention(model): for module in model.modules(): if isinstance(module, Qwen2Attention): - module.forward = sage_attention_forward.__get__(module, Qwen2Attention) \ No newline at end of file + module.forward = sage_attention_forward.__get__(module, Qwen2Attention)