From e00ff9d24665ad891fd79fe740baa7bb87fef61a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 12 Jun 2024 09:17:04 +0000 Subject: [PATCH] Simplify pipeline source code --- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 87 ++++++++----------- 1 file changed, 36 insertions(+), 51 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 386accfa5a..e4732cc251 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -118,21 +118,17 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto lse_accum = load_tile(lse_acc_lds_read_window, Policy::template MakeLSEaccTDramTileDistribution()); #else - auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution(); - auto lse_accum = make_static_distributed_tensor(lse_accum_dist); + auto lse_accum = make_static_distributed_tensor( + Policy::template MakeLSEaccTDramTileDistribution()); // copy LDS to lse_accum (transpose) { - using DataType = LSEDataType; - using StaticTileDistribution = decltype(lse_accum_dist); - constexpr auto out_spans = - static_distributed_tensor::get_distributed_spans(); + constexpr auto out_spans = decltype(lse_accum)::get_distributed_spans(); sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { - constexpr auto distributed_indices = make_tuple(idx0, idx1); - const auto x_indices = get_x_indices_from_distributed_indices( - StaticTileDistribution{}, distributed_indices); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); @@ -140,11 +136,11 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col)); if(col < num_splits) { - lse_accum(distributed_indices) = lse_acc_lds_ptr[offset]; + lse_accum(i_j_idx) = lse_acc_lds_ptr[offset]; } else { - lse_accum(distributed_indices) = -numeric::infinity(); + lse_accum(i_j_idx) = -numeric::infinity(); } }); }); @@ -164,17 +160,16 @@ struct BlockFmhaFwdSplitKVCombinePipeline : raw_m; }; - auto p_compute = make_static_distributed_tensor( - lse_accum.get_tile_distribution()); // Pcompute{j} - clear_tile(p_compute); + decltype(lse_accum) lse_exp; + clear_tile(lse_exp); { - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + constexpr auto p_spans = decltype(lse_exp)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); const auto x_indices = get_x_indices_from_distributed_indices( - p_compute.get_tile_distribution(), i_j_idx); + lse_exp.get_tile_distribution(), i_j_idx); const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); @@ -183,7 +178,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline if(col < num_splits) { // from shared memory - p_compute(i_j_idx) = + lse_exp(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[offset] - get_validated_m(lse_max(i_idx))); } }); @@ -192,39 +187,34 @@ struct BlockFmhaFwdSplitKVCombinePipeline __syncthreads(); auto lse_sum = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, type_convert(0)); + lse_exp, sequence<1>{}, f_sum, type_convert(0)); block_tile_reduce_sync(lse_sum, f_sum, bool_constant{}); decltype(lse_max) lse_logsum; { - constexpr auto out_spans = static_distributed_tensor< - LSEDataType, - decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans(); + constexpr auto out_spans = decltype(lse_logsum)::get_distributed_spans(); sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { - constexpr auto distributed_indices = make_tuple(idx0); + constexpr auto i_idx = make_tuple(idx0); - if(lse_sum(distributed_indices) == 0.f || - lse_sum(distributed_indices) != lse_sum(distributed_indices)) + if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) { - lse_logsum(distributed_indices) = numeric::infinity(); + lse_logsum(i_idx) = numeric::infinity(); } else { - lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) + - get_validated_m(lse_max(distributed_indices)); + lse_logsum(i_idx) = + ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx)); } }); } // write lse scales into LDS { - constexpr auto out_spans = - static_distributed_tensor:: - get_distributed_spans(); + constexpr auto out_spans = decltype(lse_sum)::get_distributed_spans(); sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { - constexpr auto distributed_indices = make_tuple(idx0); - const auto x_indices = get_x_indices_from_distributed_indices( - lse_sum.get_tile_distribution(), distributed_indices); + constexpr auto i_idx = make_tuple(idx0); + const auto x_indices = + get_x_indices_from_distributed_indices(lse_sum.get_tile_distribution(), i_idx); const auto row = x_indices.at(number<0>{}); @@ -232,7 +222,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline { auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col)); lse_acc_lds_ptr[offset] = - ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices)); + ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(i_idx)); } }); } @@ -240,15 +230,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline if constexpr(kStoreLSE) { - constexpr auto out_spans = static_distributed_tensor< - LSEDataType, - decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans(); + constexpr auto out_spans = decltype(lse_logsum)::get_distributed_spans(); sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { - constexpr auto distributed_indices = make_tuple(idx0); + constexpr auto i_idx = make_tuple(idx0); - if(lse_logsum(distributed_indices) == numeric::infinity()) + if(lse_logsum(i_idx) == numeric::infinity()) { - lse_logsum(distributed_indices) = -numeric::infinity(); + lse_logsum(i_idx) = -numeric::infinity(); } }); @@ -264,21 +252,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto o_acc = make_static_distributed_tensor(o_acc_dist); // Pcompute{j} clear_tile(o_acc); - const index_t new_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; + const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; for(index_t i_split = 0; i_split < num_splits; ++i_split) { auto o_tile = load_tile(o_acc_dram_window); { - using DataType = OaccDataType; - constexpr auto out_spans = - static_distributed_tensor::get_distributed_spans(); + constexpr auto out_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { - constexpr auto distributed_indices = make_tuple(idx0, idx1); - const auto x_indices = - get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + o_acc.get_tile_distribution(), i_j_idx); const auto row = x_indices.at(number<0>{}); @@ -286,12 +271,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, i_split)); LSEDataType lse_scale = lse_acc_lds_ptr[offset]; - o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices); + o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); }); }); } - move_tile_window(o_acc_dram_window, {new_max_seqlen_q, 0}); + move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0}); } o_acc = tile_elementwise_in(o_acc_element_func, o_acc);