From 6f95239229c807608c168445bbd58285ee65f470 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 24 Jul 2024 03:40:29 +0000 Subject: [PATCH] Use different rotary_cos/rotary_sin distr for Q/Knew --- .../block_fmha_fwd_appendkv_pipeline.hpp | 16 ++++++-- ...a_fwd_appendkv_pipeline_default_policy.hpp | 38 ++++++++++++------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index efd13ed503..737c1b5feb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -172,11 +172,15 @@ struct BlockFmhaFwdAppendKVPipeline { auto rotary_cos_window = make_tile_window(knew_rotary_cos_dram_block_window, - Policy::template MakeRotaryCosSinTileDistribution()); + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/false>()); auto rotary_sin_window = make_tile_window(knew_rotary_sin_dram_block_window, - Policy::template MakeRotaryCosSinTileDistribution()); + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/false>()); // We assume that each thread owns contiguous elements on head dimention. And we // will use the distribution to enable/disable threads in order to override partial @@ -220,11 +224,15 @@ struct BlockFmhaFwdAppendKVPipeline auto rotary_cos_window = make_tile_window(q_rotary_cos_dram_block_window, - Policy::template MakeRotaryCosSinTileDistribution()); + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/true>()); auto rotary_sin_window = make_tile_window(q_rotary_sin_dram_block_window, - Policy::template MakeRotaryCosSinTileDistribution()); + Policy::template MakeRotaryCosSinTileDistribution< + Problem, + /*IsRotaryCosSinForQ=*/true>()); // We assume that each thread owns contiguous elements on head dimention. And we // will use the distribution to enable/disable threads in order to override partial diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index c2c41dbe60..76066befeb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -236,25 +236,35 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy } } - template + template + CK_TILE_DEVICE static constexpr auto GetRotaryCosSinTileSize() + { + constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0); + + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + return make_tuple(number{}, number{}); + } + else // Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED + { + return make_tuple(number{}, number{}); + } + } + + template CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution() { - using KDataType = remove_cvref_t; + using DataType = std::conditional_t; + + constexpr auto TileSize = GetRotaryCosSinTileSize(); constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::kN0; - constexpr index_t kKPerBlock = [&]() { - if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) - { - return Problem::kK0; - } - else - { - return Problem::kK0 / 2; - } - }(); + constexpr index_t kNPerBlock = TileSize[number<0>{}]; + constexpr index_t kKPerBlock = TileSize[number<1>{}]; - constexpr index_t KPerThread = 8 / sizeof(KDataType); + constexpr index_t KPerThread = 8 / sizeof(DataType); constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread; constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock; constexpr index_t NumWarps = kBlockSize / get_warp_size();