diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index e3f7fd0ab7..8511a89abf 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -333,6 +333,7 @@ struct FmhaFwdSplitKVCombineKernel lse_acc_ptr, make_tuple(kargs.num_splits, kargs.seqlen_q), make_tuple(kargs.split_stride_lse_acc, 1), + -numeric::infinity(), number<8>{}, number<1>{}); @@ -421,7 +422,6 @@ struct FmhaFwdSplitKVCombineKernel identity{}, // lse_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func kargs.num_splits, - kargs.seqlen_q, kargs.max_seqlen_q, smem_ptr); } @@ -431,7 +431,6 @@ struct FmhaFwdSplitKVCombineKernel o_acc_dram_window, lse_dram_window, kargs.num_splits, - kargs.seqlen_q, kargs.max_seqlen_q, smem_ptr); }