From eb649a2f25e40221b342df8978de18e8959469fb Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Jul 2024 07:21:20 +0000 Subject: [PATCH] Move thread locating logics into policy --- .../block_fmha_fwd_appendkv_pipeline.hpp | 41 +++++----------- ...a_fwd_appendkv_pipeline_default_policy.hpp | 48 +++++++++++++++++++ 2 files changed, 61 insertions(+), 28 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index d8afe24eeb..4d5db4363d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -184,17 +184,16 @@ struct BlockFmhaFwdAppendKVPipeline // We assume that each thread owns contiguous elements on head dimention. And we // will use the distribution to enable/disable threads in order to override // knew_tile content + auto [thread_start, thread_end] = + Policy::template GetKnewThreadRangeAlongK(); + ignore = thread_start; + if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) { auto rotary_cos_tile = load_tile(rotary_cos_window); auto rotary_sin_tile = load_tile(rotary_sin_window); - constexpr index_t KPerThread = 16 / sizeof(KDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - - if((start_x + KPerThread) <= rotary_dim) + if(thread_end <= rotary_dim) { constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); @@ -217,14 +216,9 @@ struct BlockFmhaFwdAppendKVPipeline } else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED { - constexpr index_t KPerThread = 8 / sizeof(KDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - - 
if((start_x + KPerThread) <= rotary_dim) + if(thread_end <= rotary_dim) { - const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + const bool is_left = (thread_end <= (rotary_dim / 2)); auto knew_other_window = knew_window; move_tile_window(knew_other_window, @@ -291,20 +285,17 @@ struct BlockFmhaFwdAppendKVPipeline // We assume that each thread owns contiguous elements on head dimention. And we // will use the distribution to enable/disable threads in order to override q_tile // content + auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK(); + ignore = thread_start; + if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) { auto rotary_cos_tile = load_tile(rotary_cos_window); auto rotary_sin_tile = load_tile(rotary_sin_window); - constexpr index_t KPerThread = 16 / sizeof(QDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - - if((start_x + KPerThread) <= rotary_dim) + if(thread_end <= rotary_dim) { constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); - static_assert(thread_buffer_size % KPerThread == 0); static_for<0, thread_buffer_size, 2>{}([&](auto idx) { const auto left = type_convert(q_tile.thread_buf_[idx]); const auto right = type_convert(q_tile.thread_buf_[idx + 1]); @@ -323,14 +314,9 @@ struct BlockFmhaFwdAppendKVPipeline } else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED { - constexpr index_t KPerThread = 8 / sizeof(QDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; - - if((start_x + KPerThread) <= rotary_dim) + if(thread_end <= rotary_dim) { - const bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + const bool is_left = (thread_end <= (rotary_dim / 2)); auto q_other_window = q_window; 
move_tile_window(q_other_window, @@ -344,7 +330,6 @@ struct BlockFmhaFwdAppendKVPipeline auto rotary_sin_tile = load_tile(rotary_sin_window); constexpr index_t thread_buffer_size = decltype(q_tile.thread_buf_)::size(); - static_assert(thread_buffer_size % KPerThread == 0); static_for<0, thread_buffer_size, 1>{}([&](auto idx) { const auto curr = type_convert(q_tile.thread_buf_[idx]); const auto other = type_convert(q_other_tile.thread_buf_[idx]); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index a12379aa05..4e8bbc0850 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -57,6 +57,31 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD); } + /* Returns make_tuple(start, end) where start is computed from the calling thread's id and end = start + KPerThread. The Q call site in this patch binds the pair as thread_start and thread_end and tests thread_end <= rotary_dim. NOTE(review): the template parameter list after 'template' looks lost in this text; the body depends on a type named Problem - confirm against the real header. */ template + CK_TILE_DEVICE static auto GetQThreadRangeAlongK() + { + static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE); /* compile-time guard: not usable when RotaryEnum is NONE */ + + if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + { + constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType); /* number of QDataType elements in 16 bytes; presumably the per-thread chunk of the Q tile distribution - confirm */ + static_assert(Problem::kTileSizeD % KPerThread == 0); /* kTileSizeD must split into whole chunks */ + constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread; /* number of chunks spanning kTileSizeD */ + index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread; /* chunk index taken from the thread id, scaled to an element offset */ + + return make_tuple(start_x, start_x + KPerThread); + } + else /* every remaining RotaryEnum value; NONE is already rejected by the static_assert above */ + { + constexpr index_t KPerThread = 8 / sizeof(typename Problem::QDataType); /* number of QDataType elements in 8 bytes, half of the INTERLEAVED chunk */ + static_assert(Problem::kTileSizeD % KPerThread == 0); /* kTileSizeD must split into whole chunks */ + constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread; /* number of chunks spanning kTileSizeD */ + index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread; /* same mapping as the INTERLEAVED branch, with the smaller chunk */ + + return make_tuple(start_x, start_x + KPerThread); + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() { @@ -91,6 +116,29 
@@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy sequence<0, 1>>{}); } + /* Returns make_tuple(start, end) where start is computed from the calling thread's id and end = start + KPerThread. The Knew call site in this patch binds the pair as thread_start and thread_end and tests thread_end <= rotary_dim. Kept in step with GetQThreadRangeAlongK: same get_thread_id() accessor and same divisibility guard. NOTE(review): the template parameter list after 'template' looks lost in this text; the body depends on a type named Problem - confirm against the real header. */ template + CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK() + { + static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE); /* compile-time guard: not usable when RotaryEnum is NONE */ + + if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + { + constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType); /* number of KDataType elements in 16 bytes; presumably the per-thread chunk of the Knew tile distribution - confirm */ + static_assert(Problem::kTileSizeD % KPerThread == 0); /* restored guard: without it the division below silently truncates */ + constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread; /* number of chunks spanning kTileSizeD */ + index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread; /* get_thread_id() instead of raw threadIdx.x, matching GetQThreadRangeAlongK */ + return make_tuple(start_x, start_x + KPerThread); + } + else /* every remaining RotaryEnum value; NONE is already rejected by the static_assert above */ + { + constexpr index_t KPerThread = 8 / sizeof(typename Problem::KDataType); /* number of KDataType elements in 8 bytes, half of the INTERLEAVED chunk */ + static_assert(Problem::kTileSizeD % KPerThread == 0); /* restored guard: without it the division below silently truncates */ + constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread; /* number of chunks spanning kTileSizeD */ + index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread; /* get_thread_id() instead of raw threadIdx.x, matching GetQThreadRangeAlongK */ + return make_tuple(start_x, start_x + KPerThread); + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution() {