mirror of
https://github.com/wildminder/ComfyUI-VibeVoice.git
synced 2026-03-13 20:39:49 +00:00
Merge pull request #35 from Jarvik7/patch-1
Update sage_attention_patch.py
This commit is contained in:
@@ -35,10 +35,14 @@ def get_sage_attention_function_and_params():
|
||||
attn_func = None
|
||||
pv_accum_dtype = "fp32"
|
||||
|
||||
if arch_code >= 90: # Hopper
|
||||
if arch_code >= 120: # Blackwell
|
||||
pv_accum_dtype = "fp32+fp32"
|
||||
attn_func = sageattn_qk_int8_pv_fp8_cuda
|
||||
logger.info(f"SageAttention: Using SM120 (Blackwell) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
|
||||
elif arch_code >= 90: # Hopper
|
||||
pv_accum_dtype = "fp32+fp32"
|
||||
attn_func = sageattn_qk_int8_pv_fp8_cuda_sm90
|
||||
logger.info(f"SageAttention: Using SM90+ (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
|
||||
logger.info(f"SageAttention: Using SM90 (Hopper) FP8 kernel with pv_accum_dtype='{pv_accum_dtype}'.")
|
||||
elif arch_code == 89: # Ada Lovelace
|
||||
pv_accum_dtype = "fp32+fp32"
|
||||
attn_func = sageattn_qk_int8_pv_fp8_cuda
|
||||
@@ -141,4 +145,4 @@ def set_sage_attention(model):
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, Qwen2Attention):
|
||||
module.forward = sage_attention_forward.__get__(module, Qwen2Attention)
|
||||
module.forward = sage_attention_forward.__get__(module, Qwen2Attention)
|
||||
|
||||
Reference in New Issue
Block a user