From 1ca2b3d76ceb9f0aca2d51fc8e3604113323a396 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Tue, 8 Oct 2024 10:44:34 +0800 Subject: [PATCH] [CK_TILE] Simplify the codes in splitkv_combine pipeline (#1549) * Simplify the codes in splitkv_combine pipeline * Always set kPadSeqLenK=true for fmha splitkv kernels * Change in Oacc Alignment and TileDistribution to be more adaptable to tile sizes --------- Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: 74d68e3b991dbfff7f14881a572bc77f4954c4fc] --- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 +- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 90 ++++++++++--------- ...plitkv_combine_pipeline_default_policy.hpp | 23 +++-- 3 files changed, 67 insertions(+), 50 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index ba826c8fb3..82cf3a5ab2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -600,8 +600,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 842090afbe..1afe0feab3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_accum, sequence<1>{}, f_max, -numeric::infinity()); block_tile_reduce_sync(lse_max, f_max, bool_constant{}); - static const auto get_validated_m = [](LSEDataType raw_m) { - return raw_m == -numeric::infinity() ? type_convert(0.f) - : raw_m; - }; - decltype(lse_accum) lse_exp; { constexpr auto spans = decltype(lse_exp)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); + if(lse_max[i_idx] == -numeric::infinity()) + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); - lse_exp(i_j_idx) = - ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); - }); + lse_exp(i_j_idx) = ck_tile::type_convert(0.0f); + }); + } + else + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx)); + }); + } }); } @@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline sweep_tile_span(spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) - { - lse_logsum(i_idx) = numeric::infinity(); - } + if(lse_sum[i_idx] == ck_tile::type_convert(0.0f)) + lse_logsum(i_idx) = -numeric::infinity(); else - { - lse_logsum(i_idx) = - ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx)); - } + lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx); }); } @@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); + if(lse_logsum(i_idx) == -numeric::infinity()) + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); - const auto x_indices = get_x_indices_from_distributed_indices( - lse_accum.get_tile_distribution(), i_j_idx); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); - const auto col = x_indices.at(number<1>{}); - if(col < num_splits) - { - const auto row = x_indices.at(number<0>{}); + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); - lse_acc_lds(row, col) = - ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); - } - }); + lse_acc_lds(row, col) = ck_tile::type_convert(0.0f); + } + }); + } + else + { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_acc_lds(row, col) = + ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); + } + }); + } }); } block_sync_lds(); if constexpr(kStoreLSE) { - constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); - sweep_tile_span(spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - - if(lse_logsum(i_idx) == numeric::infinity()) - { - lse_logsum(i_idx) = -numeric::infinity(); - } - }); - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 2eb092f055..3327d4af87 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() { using OaccDataType = remove_cvref_t; - return 16 / sizeof(OaccDataType); + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kN1; + + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = kNPerBlock / N0; + + return min(N1, static_cast(16 / sizeof(OaccDataType))); } template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() { - using ODataType = remove_cvref_t; - return 16 / sizeof(ODataType); + return GetAlignmentOacc(); } template @@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() { - using OaccDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kNPerBlock = Problem::kN1; - constexpr index_t N1 = 16 / sizeof(OaccDataType); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t M2 = get_warp_size() / N0; constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = kNPerBlock / N0; constexpr index_t M0 = kMPerBlock / (M2 * M1); return make_static_tile_distribution(