diff --git a/vibevoice/modular/sage_attention_patch.py b/vibevoice/modular/sage_attention_patch.py index c5bcaf9..4111d62 100644 --- a/vibevoice/modular/sage_attention_patch.py +++ b/vibevoice/modular/sage_attention_patch.py @@ -35,10 +35,14 @@ def get_sage_attention_function_and_params(): attn_func = None pv_accum_dtype = "fp32" - if arch_code >= 90: # Hopper + if arch_code >= 120: # Blackwell + pv_accum_dtype = "fp32+fp32" + attn_func = sageattn_qk_int8_pv_fp8_cuda + logger.info(f"SageAttention: Using SM120 (Blackwell) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") + elif arch_code >= 90: # Hopper pv_accum_dtype = "fp32+fp32" attn_func = sageattn_qk_int8_pv_fp8_cuda_sm90 - logger.info(f"SageAttention: Using SM90+ (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") + logger.info(f"SageAttention: Using SM90 (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.") elif arch_code == 89: # Ada Lovelace pv_accum_dtype = "fp32+fp32" attn_func = sageattn_qk_int8_pv_fp8_cuda @@ -141,4 +145,4 @@ def set_sage_attention(model): for module in model.modules(): if isinstance(module, Qwen2Attention): - module.forward = sage_attention_forward.__get__(module, Qwen2Attention) \ No newline at end of file + module.forward = sage_attention_forward.__get__(module, Qwen2Attention)