diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 44d51d13db..f3ba54a35c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -74,6 +74,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); + static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; + static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) return Problem::kBlockPerCu; @@ -220,10 +223,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQRegTileDistribution()); + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramSingleRepMTileDistribution()); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -235,6 +239,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); + using q_dram_tile_type = decltype(load_tile(q_dram_window)); + statically_indexed_array q_dram_tiles; + + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + q_dram_tiles[i_rep] = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); + }); + using k_tile_type = decltype(load_tile(k_dram_window)); // only prefetch two k tiles to save vgprs consumption @@ -250,9 +262,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch __builtin_amdgcn_sched_barrier(0x00000001); - auto q_tile = load_tile(q_dram_window); + // provide partition_index for LDS tile window with so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; - __builtin_amdgcn_sched_barrier(0x00000001); + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(smem_ptr); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_write_window = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto q_lds_read_window = + make_tile_window(q_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQRegSingleRepMTileDistribution(), + partition_index); // K tile in LDS KDataType* k_lds_ptr = static_cast(smem_ptr); @@ -354,19 +379,58 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); + using q_reg_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegSingleRepMTileDistribution())); + statically_indexed_array q_reg_tiles; + + using q_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())); + + q_tile_type q_tile; + + { + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written + // by each wavefront is read by itself + __builtin_amdgcn_s_waitcnt(0xc07f); + + q_reg_tiles[i_rep] = load_tile(q_lds_read_window); + + __builtin_amdgcn_s_waitcnt(0xc07f); + + // the following codes will not generate actual instructions by the compiler + set_slice_tile(q_tile, + q_reg_tiles[i_rep], + sequence{}, + sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{}); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read + // by each wavefront is over-written by itself + }); + + clear_tile(o_acc); + + set_tile(m, -numeric::infinity()); + clear_tile(l); + }; + q_tile = tile_elementwise_in(q_element_func, q_tile); auto seqlen_k_curr = seqlen_k_start; __builtin_amdgcn_sched_barrier(0x00000001); + // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile + __builtin_amdgcn_s_barrier(); + + __builtin_amdgcn_sched_barrier(0x00000001); + using v_tile_type = decltype(load_tile(v_dram_window)); statically_indexed_array v_tiles; - // provide partition_index for LDS tile window with so that warp_id is in vgpr - array partition_index{get_warp_id(), get_lane_id()}; - do { // STAGE 1, Gemm_0 ( S = Q@K ) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index a1cc228812..046d07909a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -40,6 +40,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy return 4; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM(); + + return BlockGemm:: + template MakeABlockTileDistribution(); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -90,16 +100,28 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy return WG::WarpGemmAttribute::Impl::kCM1PerLane; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; + } + template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + using QDataType = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + + return min(MaxVectorSize, ElemPerThread); } template @@ -209,6 +231,118 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy GetVSingleSmemElementSpaceSize()); }; + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kKVector = GetAlignmentQ(); + + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + using QDataType = remove_cvref_t; + + constexpr index_t DataTypeSize = sizeof(QDataType); + + // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes + constexpr auto MLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor( + q_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_k0_mldslayer_m_k1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return q_lds_block_desc; + } + else + { + static_assert(kKVector % kKPack == 0); + + constexpr auto q_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<2>{}, sequence<0, 1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return q_lds_block_desc; + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; + + constexpr index_t kKVector = GetAlignmentQ(); + + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + // for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { @@ -317,19 +451,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { - using KDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t kKVector = GetAlignmentK(); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KPerThread = kKVector; constexpr index_t KThreads = kKPerBlock / KPerThread; constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; constexpr index_t NumWarps = kBlockSize / get_warp_size(); @@ -446,6 +574,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy sequence<1, 2>>{}); } + template + CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM() + { + return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) * + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + }; + template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { @@ -587,6 +722,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy return BlockGemmARegBSmemCRegOneWarpV1{}; } + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return MakeQLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::QDataType); + }; + template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { @@ -605,7 +747,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return GetSmemSizeKV() + GetSmemSizeDropout(); + return max(GetSmemSizeKV() + GetSmemSizeDropout(), + GetSmemSizeQ()); } };