diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 838b2b99bf..1957d24a22 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -259,12 +259,14 @@ struct FmhaFwdSplitKVCombineKernel // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_lse_acc = static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); + batch_offset_lse_acc = + static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); batch_offset_o_acc = static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v); if constexpr(kStoreLSE) { - batch_offset_lse = static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); + batch_offset_lse = + static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); } batch_offset_o = query_start * kargs.row_stride_o; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 30368c3f44..2aa4d061ad 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -389,11 +389,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices); const auto row = x_indices.at(number<0>{}); - LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits]; o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices); - #if 0 +#if 0 DEBUG_STMTS { const auto col = x_indices.at(number<1>{}); @@ -405,7 +404,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_scale, o_tile(distributed_indices)); } - #endif +#endif }); }); }