From ffd2768000328ecf29c4ce7a65b85bf0ac994994 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 6 Jun 2024 04:36:49 +0000 Subject: [PATCH] Format codes --- .../ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp | 6 ++++-- .../pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 838b2b99bf..1957d24a22 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -259,12 +259,14 @@ struct FmhaFwdSplitKVCombineKernel // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_lse_acc = static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); + batch_offset_lse_acc = + static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); batch_offset_o_acc = static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v); if constexpr(kStoreLSE) { - batch_offset_lse = static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); + batch_offset_lse = + static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); } batch_offset_o = query_start * kargs.row_stride_o; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 30368c3f44..2aa4d061ad 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -389,11 +389,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices); const auto row = x_indices.at(number<0>{}); - LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits]; o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices); - #if 0 +#if 0 DEBUG_STMTS { const auto col = x_indices.at(number<1>{}); @@ -405,7 +404,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_scale, o_tile(distributed_indices)); } - #endif +#endif }); }); }