From 3f6d26e9a7484318995bdd41629ea97eef7565ba Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 20 Dec 2025 12:50:18 +0000 Subject: [PATCH] Load Q directly from global memory to registers for BlockGemm --- .../pipeline/block_fmha_pipeline_problem.hpp | 3 - ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 72 +---- ..._ks_vs_whole_k_prefetch_default_policy.hpp | 265 +----------------- ...eline_qr_ks_vs_whole_k_prefetch_trload.hpp | 44 +-- 4 files changed, 17 insertions(+), 367 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 90200d9e83..75179270e9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -96,9 +96,6 @@ struct BlockFmhaPipelineProblem static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kUseTrLoad = kUseTrLoad_; - // ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThrough Lds here ? - static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false; - // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index f3577543da..0720802d86 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -75,9 +75,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); - static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; - static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) return Problem::kBlockPerCu; @@ -223,11 +220,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramSingleRepMTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -239,14 +235,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - using q_dram_tile_type = decltype(load_tile(q_dram_window)); - statically_indexed_array q_dram_tiles; - - static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - q_dram_tiles[i_rep] = load_tile(q_dram_window); - move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); - }); - using k_tile_type = decltype(load_tile(k_dram_window)); // only prefetch two k tiles to save vgprs consumption @@ -260,24 +248,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch k_tiles[I0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kN0Sub, 0}); + auto q_tile = load_tile(q_dram_window); + __builtin_amdgcn_sched_barrier(0x00000001); // provide partition_index for LDS tile window with so that warp_id is in vgpr array partition_index{get_warp_id(), get_lane_id()}; - // Q tile in LDS - QDataType* q_lds_ptr = static_cast(smem_ptr); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_write_window = make_tile_window( - q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - auto q_lds_read_window = - make_tile_window(q_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeQRegSingleRepMTileDistribution(), - partition_index); - // K tile in LDS KDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( @@ -368,47 +345,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - using q_reg_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeQRegSingleRepMTileDistribution())); - statically_indexed_array q_reg_tiles; - - using q_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeQRegTileDistribution())); - - q_tile_type q_tile; - - static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index); - - // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written - // by each wavefront is read by itself - __builtin_amdgcn_s_waitcnt(0xc07f); - - q_reg_tiles[i_rep] = load_tile(q_lds_read_window); - - __builtin_amdgcn_s_waitcnt(0xc07f); - - // the following codes will not generate actual instructions by the compiler - set_slice_tile(q_tile, - q_reg_tiles[i_rep], - sequence{}, - sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{}); - - // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read - // by each wavefront is over-written by itself - }); - q_tile = tile_elementwise_in(q_element_func, q_tile); auto seqlen_k_curr = seqlen_k_start; - __builtin_amdgcn_sched_barrier(0x00000001); - - // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile - __builtin_amdgcn_s_barrier(); - - __builtin_amdgcn_sched_barrier(0x00000001); - using v_tile_type = decltype(load_tile(v_dram_window)); statically_indexed_array v_tiles; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index d805075980..cc342c936d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -43,16 +43,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy return 4; } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM(); - - return BlockGemm:: - template MakeABlockTileDistribution(); - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -103,33 +93,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy return WG::WarpGemmAttribute::Impl::kCM1PerLane; } - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() - { - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) - return 8; - else - return 4; - } - template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - if constexpr(Problem::kLoadWholeQTileOnceThroughLds) - { - return Problem::GetQDramTileAccessMaxVectorSize(); - } - else - { - using QDataType = remove_cvref_t; + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; - return detail:: - GetDramTileAccessMaxVectorSize(); - }; + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); } template @@ -257,217 +230,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy GetVSingleSmemElementSpaceSize()); }; - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds - ? Problem::BlockFmhaShape::kM0 - : GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPack = GetSmemKPackQ(); - constexpr index_t kKVector = GetAlignmentQ(); - - // for hdim96 and hdim160, use simplest layout - if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) - { - static_assert(kKVector == kKPack); - - using QDataType = remove_cvref_t; - - constexpr index_t DataTypeSize = sizeof(QDataType); - - // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes - constexpr auto MLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( - q_lds_block_desc_0, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor( - q_lds_block_desc_permuted, - make_tuple(make_pass_through_transform(number{}), - make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - constexpr auto q_lds_block_desc = transform_tensor_descriptor( - q_lds_block_desc_k0_mldslayer_m_k1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return q_lds_block_desc; - } - else - { - static_assert(kKVector % kKPack == 0); - - constexpr auto q_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc = transform_tensor_descriptor( - q_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, - number{}, - number{}))), - make_tuple(sequence<2>{}, sequence<0, 1, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return q_lds_block_desc; - }; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t kKVector = GetAlignmentQ(); - constexpr index_t OtherK = kKPerBlock / kKVector; - - if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim) - // for kKPerBlock=32,64,128,256 - { - static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = OtherK; - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } - else // for kKPerBlock=96,160 - { - static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); - - // ToDo: need more considieration for hdim72 - constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; - constexpr index_t KThreads = OtherK / KRepPerThread; - - static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 1>>, - sequence<1, 2, 2>, - sequence<1, 0, 2>>{}); - }; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t kKVector = GetAlignmentQ(); - constexpr index_t OtherK = kKPerBlock / kKVector; - - if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim) - // for kKPerBlock=32,64,128,256 - { - static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = OtherK; - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } - else // for kKPerBlock=96,160 - { - static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); - - // ToDo: need more considieration for hdim72 - constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; - constexpr index_t KThreads = OtherK / KRepPerThread; - - static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 1>>, - sequence<1, 2, 2>, - sequence<1, 0, 2>>{}); - }; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { @@ -823,13 +585,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy sequence<1, 2>>{}); } - template - CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM() - { - return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) * - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - }; - template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { @@ -976,13 +731,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy return BlockGemmARegBSmemCRegOneWarpV1{}; } - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() - { - return MakeQLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::QDataType); - }; - template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { @@ -1001,8 +749,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return max(GetSmemSizeKV() + GetSmemSizeDropout(), - GetSmemSizeQ()); + return GetSmemSizeKV() + GetSmemSizeDropout(); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp index 4ebfe19bd7..3ca4734ac9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -65,8 +65,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad static constexpr bool kUseTrLoad = true; - static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!"); - // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = @@ -226,11 +224,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramSingleRepMTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -242,7 +239,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - auto q_dram_tile = load_tile(q_dram_window); + auto q_tile = load_tile(q_dram_window); using k_tile_type = decltype(load_tile(k_dram_window)); @@ -262,19 +259,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad // provide partition_index for LDS tile window with so that warp_id is in vgpr array partition_index{get_warp_id(), get_lane_id()}; - // Q tile in LDS - QDataType* q_lds_ptr = static_cast(smem_ptr); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_write_window = make_tile_window( - q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - auto q_lds_read_window = - make_tile_window(q_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeQRegTileDistribution(), - partition_index); - // K tile in LDS KDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( @@ -361,18 +345,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - store_tile(q_lds_write_window, q_dram_tile, partition_index); - clear_tile(o_acc); - - __builtin_amdgcn_sched_barrier(0x00000001); - - block_sync_lds(); - - auto q_tile = load_tile(q_lds_read_window); - - q_tile = tile_elementwise_in(q_element_func, q_tile); - set_tile(m, -numeric::infinity()); clear_tile(l); @@ -380,13 +353,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad auto seqlen_k_curr = seqlen_k_start; - __builtin_amdgcn_sched_barrier(0x00000001); - - // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile - __builtin_amdgcn_s_barrier(); - - __builtin_amdgcn_sched_barrier(0x00000001); - using v_tile_type = decltype(load_tile(v_dram_window)); statically_indexed_array v_tiles;