From f968a7e4422787a25834bf0e24b1b705bcd548c1 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 11 Jun 2024 18:36:23 +0000 Subject: [PATCH] Remove more debug code in combine pipeline --- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 112 +----------------- 1 file changed, 2 insertions(+), 110 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index d456717ff7..e1b8ae12cb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -99,15 +99,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto lse_acc = load_tile(lse_acc_dram_window); // [kMaxSplits, kM0] -#if !defined(TID) -#define TID 0 -#endif - -#if defined(ENABLE_DEBUG_STMTS) -#define DEBUG_STMTS if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID) -#else -#define DEBUG_STMTS if(false) -#endif // copy lse_acc to LDS { using DataType = LSEDataType; @@ -169,22 +160,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline }); } -#if defined(PRINT_LSE_ACCUM) - DEBUG_STMTS - { - printf("(%2d, %2d)\n", num_splits, kM0); - for(index_t row = 0; row < num_splits; ++row) - { - printf("[POYENC][DEVICE] lse_acc[%2d] = ", row); - for(index_t col = 0; col < kM0; ++col) - { - printf("%11.7f", lse_acc_lds_ptr[row + col * kMaxSplits]); - } - printf("\n"); - } - } -#endif - // calculate row_max of lse_accum const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; @@ -193,25 +168,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_accum, sequence<1>{}, f_max, -numeric::infinity()); block_tile_reduce_sync(lse_max, f_max, bool_constant{}); -#if defined(PRINT_LSE_MAX) - DEBUG_STMTS - { - constexpr auto out_spans = - static_distributed_tensor:: - get_distributed_spans(); - sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { - constexpr auto distributed_indices = make_tuple(idx0); - const auto x_indices = get_x_indices_from_distributed_indices( - lse_max.get_tile_distribution(), distributed_indices); - - const auto row = x_indices.at(number<0>{}); - - printf( - "[POYENC][DEVICE] lse_max[%2d]: %11.7f\n", row, lse_max(distributed_indices)); - }); - } -#endif - static const auto get_validated_m = [](LSEDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration @@ -234,16 +190,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); -#if 0 - // from dist tensor - p_compute(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); -#else - if (col < num_splits) { + if(col < num_splits) + { // from shared memory p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] - get_validated_m(lse_max(i_idx))); } -#endif }); }); } @@ -253,25 +205,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline p_compute, sequence<1>{}, f_sum, type_convert(0)); block_tile_reduce_sync(lse_sum, f_sum, bool_constant{}); -#if defined(PRINT_LSE_SUM) - DEBUG_STMTS - { - constexpr auto out_spans = - static_distributed_tensor:: - get_distributed_spans(); - sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { - constexpr auto distributed_indices = make_tuple(idx0); - const auto x_indices = get_x_indices_from_distributed_indices( - lse_sum.get_tile_distribution(), distributed_indices); - - const auto row = x_indices.at(number<0>{}); - - printf( - "[POYENC][DEVICE] lse_sum[%2d]: %11.7f\n", row, lse_sum(distributed_indices)); - }); - } -#endif - decltype(lse_max) lse_logsum; { constexpr auto out_spans = static_distributed_tensor< @@ -290,19 +223,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) + get_validated_m(lse_max(distributed_indices)); } - -#if defined(PRINT_LSE_LOGSUM) - DEBUG_STMTS - { - const auto x_indices = get_x_indices_from_distributed_indices( - lse_logsum.get_tile_distribution(), distributed_indices); - - const auto row = x_indices.at(number<0>{}); - printf("[POYENC][DEVICE] lse_logsum[%d]: %11.7f\n", - row, - lse_logsum(distributed_indices)); - } -#endif }); } @@ -327,21 +247,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline } block_sync_lds(); -#if defined(PRINT_LSE_SCALE) - DEBUG_STMTS - { - for(index_t row = 0; row < 32; ++row) - { - printf("[POYENC][DEVICE] lse_scale[%2d] = ", row); - for(index_t col = 0; col < num_splits; ++col) - { - printf("%11.7f", lse_acc_lds_ptr[col + row * kMaxSplits]); - } - printf("\n"); - } - } -#endif - if constexpr(kStoreLSE) { constexpr auto out_spans = static_distributed_tensor< @@ -388,19 +293,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits]; o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices); -#if 0 - DEBUG_STMTS - { - const auto col = x_indices.at(number<1>{}); - printf("[POYENC][DEVICE] [%3d,%3d], o_acc(%11.7f) = lse_scale(%11.7f) " - "* o_tile(%11.7f)\n", - row, - col, - o_acc(distributed_indices), - lse_scale, - o_tile(distributed_indices)); - } -#endif }); }); }