From 47a74f282d359bc408af39c20731a8a40baa0196 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 24 Jul 2024 03:23:18 +0000 Subject: [PATCH] Extract Q/Knew vector size to helper methods --- ...a_fwd_appendkv_pipeline_default_policy.hpp | 80 ++++++++++--------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index f716e1eed1..c2c41dbe60 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -57,6 +57,21 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy return sizeof(KDataType) * Problem::kN0 * (Problem::kK0); } + template + CK_TILE_DEVICE static constexpr auto GetQNumElemsPerRead() + { + using DataType = typename Problem::QDataType; + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + return 8 / sizeof(DataType); + } + else + { + return 16 / sizeof(DataType); + } + } + template CK_TILE_DEVICE static auto GetQThreadRangeAlongK() { @@ -64,43 +79,32 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) { - constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType); + constexpr index_t KPerThread = GetQNumElemsPerRead(); static_assert(Problem::kK0 % KPerThread == 0); constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; - index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; - return make_tuple(start_x, start_x + KPerThread); + return make_tuple(start_pos, start_pos + KPerThread); } else { - constexpr index_t KPerThread = 8 / sizeof(typename Problem::QDataType); + constexpr index_t KPerThread = GetQNumElemsPerRead(); static_assert(Problem::kK0 % KPerThread == 0); constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; - index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; - return make_tuple(start_x, start_x + KPerThread); + return make_tuple(start_pos, start_pos + KPerThread); } } template CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() { - using QDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kKPerBlock = Problem::kK0; - constexpr index_t KPerThread = [&]() { - if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) - { - return 8 / sizeof(QDataType); - } - else - { - return 16 / sizeof(QDataType); - } - }(); + constexpr index_t KPerThread = GetQNumElemsPerRead(); constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock; constexpr index_t NumWarps = kBlockSize / get_warp_size(); @@ -116,6 +120,21 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy sequence<0, 1>>{}); } + template + CK_TILE_DEVICE static constexpr auto GetKnewNumElemsPerRead() + { + using DataType = typename Problem::KDataType; + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + return 8 / sizeof(DataType); + } + else + { + return 16 / sizeof(DataType); + } + } + template CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK() { @@ -123,41 +142,30 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) { - constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType); + constexpr index_t KPerThread = GetKnewNumElemsPerRead(); constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; - return make_tuple(start_x, start_x + KPerThread); + return make_tuple(start_pos, start_pos + KPerThread); } else { - constexpr index_t KPerThread = 8 / sizeof(typename Problem::KDataType); + constexpr index_t KPerThread = GetKnewNumElemsPerRead(); constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread; - return make_tuple(start_x, start_x + KPerThread); + return make_tuple(start_pos, start_pos + KPerThread); } } template CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution() { - using KDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::kN0; constexpr index_t kKPerBlock = Problem::kK0; - constexpr index_t KPerThread = [&]() { - if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) - { - return 8 / sizeof(KDataType); - } - else - { - return 16 / sizeof(KDataType); - } - }(); + constexpr index_t KPerThread = GetKnewNumElemsPerRead(); constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock; constexpr index_t NumWarps = kBlockSize / get_warp_size();