From 1dbed18555dd15804d8d9302db079263fbe37f56 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Jul 2024 03:11:31 +0000 Subject: [PATCH] Remove constness from q_ptr --- example/ck_tile/01_fmha/fmha_fwd.hpp | 2 +- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 12 +++---- .../block_fmha_fwd_appendkv_pipeline.hpp | 26 ++++++++++++-- ...a_fwd_appendkv_pipeline_default_policy.hpp | 34 +++++++++++++++++++ 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 8b8f08fc20..089ea923f6 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -156,7 +156,7 @@ struct fmha_fwd_args struct fmha_fwd_appendkv_args { - const void* q_ptr; + void* q_ptr; void* k_ptr; const void* knew_ptr; void* v_ptr; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index b64aeb48d7..4d65bfed62 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -79,7 +79,7 @@ struct FmhaFwdAppendKVKernel // user need to use MakeKargs() function to create kargs. 
struct CommonKargs { - const void* q_ptr; + void* q_ptr; void* k_ptr; const void* knew_ptr; void* v_ptr; @@ -139,7 +139,7 @@ struct FmhaFwdAppendKVKernel template __host__ static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, + MakeKargs(void* q_ptr, void* k_ptr, const void* knew_ptr, void* v_ptr, @@ -211,7 +211,7 @@ struct FmhaFwdAppendKVKernel template __host__ static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, + MakeKargs(void* q_ptr, void* k_ptr, const void* knew_ptr, void* v_ptr, @@ -384,9 +384,9 @@ struct FmhaFwdAppendKVKernel } // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_q + - batch_offset_q; + QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 77a960dab6..d200d457b3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -87,7 +87,7 @@ struct BlockFmhaFwdAppendKVPipeline typename RotaryCosDramBlockWindow, typename RotarySinDramBlockWindow> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindow& q_dram_block_window, // M0*K0 tile + operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile const QElementFunction& q_element_func, KDramBlockWindow& k_dram_block_window, // N0*K0 tile const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile @@ -261,6 +261,28 @@ struct BlockFmhaFwdAppendKVPipeline return tile_elementwise_in(vnew_element_func, vnew); }(); store_tile(v_dram_block_window, vnew_tile); + + if constexpr(RotaryEnum != 
BlockRotaryEmbeddingEnum::NONE) + { + auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto q_tile = [&]() { + auto q = load_tile(q_window); + return tile_elementwise_in(q_element_func, q); + }(); + + // We assume that each thread owns contiguous elements on head dimension. And we will + // use the distribution to enable/disable threads in order to override knew_tile content + if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {} + else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED + { + } + + store_tile(q_dram_block_window, q_tile); + } } template CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindow& q_dram_block_window, + operator()(QDramBlockWindow& q_dram_block_window, KDramBlockWindow& k_dram_block_window, const KnewDramBlockWindow& knew_dram_block_window, VDramBlockWindow& v_dram_block_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index 344f3077ac..a12379aa05 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -57,6 +57,40 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD); } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + using QDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kTileSizeS; + constexpr index_t kKPerBlock = Problem::kTileSizeD; + + constexpr index_t KPerThread = [&]() { + if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED) + { + return 8 / 
sizeof(QDataType); + } + else + { + return 16 / sizeof(QDataType); + } + }(); + constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution() {