diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 4226502a0f..01e2eac7b3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -534,7 +534,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, - bool has_padded_seqlen_k=false) + bool has_padded_seqlen_k = false) { ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q; ck_tile::index_t max_seqlen_q_ = @@ -545,20 +545,18 @@ struct FmhaFwdSplitKVKernel { // TODO: this may need tuning return dim3(nhead_, - batch_size, - ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits); + batch_size, + ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits); } else { // TODO: this may need tuning return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, nhead_, batch_size); } - - } CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)