diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index cd16c9a0b3..950e5323f5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -68,6 +68,9 @@ struct HstuAttentionFwdPipelineQRKSVS static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); + static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; + static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::Traits::kBlockPerCu != -1) return Problem::Traits::kBlockPerCu; @@ -173,10 +176,11 @@ struct HstuAttentionFwdPipelineQRKSVS using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQRegTileDistribution()); + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramSingleRepMTileDistribution()); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -188,15 +192,31 @@ struct HstuAttentionFwdPipelineQRKSVS {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - auto q_tile = load_tile(q_dram_window); + using q_dram_tile_type = decltype(load_tile(q_dram_window)); + statically_indexed_array q_dram_tiles; - clear_tile(o_acc); + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + q_dram_tiles[i_rep] = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); + }); auto k_tile = load_tile(k_dram_window); 
move_tile_window(k_dram_window, {kK1, 0}); __builtin_amdgcn_sched_barrier(0); + // Q tile in LDS + QKVDataType* q_lds_ptr = static_cast(smem_ptr); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_write_window = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + auto q_lds_read_window = + make_tile_window(q_lds, + Policy::template MakeQLdsBlockDescriptor().get_lengths(), + {0, 0}, + Policy::template MakeQRegSingleRepMTileDistribution()); + // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( @@ -284,12 +304,54 @@ struct HstuAttentionFwdPipelineQRKSVS return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); + using q_reg_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegSingleRepMTileDistribution())); + statically_indexed_array q_reg_tiles; + + using q_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())); + + q_tile_type q_tile; + + { + clear_tile(o_acc); + + constexpr index_t complete_tile_thread_buf_size = q_tile_type::get_thread_buffer_size(); + constexpr index_t splitted_tile_thread_buf_size = + q_reg_tile_type::get_thread_buffer_size(); + + static_assert(complete_tile_thread_buf_size == + kGemmNumRepM * splitted_tile_thread_buf_size, + "Check failed!"); + + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + store_tile(q_lds_write_window, q_dram_tiles[i_rep]); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written + // by each wavefront is read by itself + __builtin_amdgcn_s_waitcnt(0xc07f); + + q_reg_tiles[i_rep] = load_tile(q_lds_read_window); + + static_for<0, splitted_tile_thread_buf_size, 1>{}([&](auto i_buf) { + q_tile.get_thread_buffer()[i_rep * splitted_tile_thread_buf_size + i_buf] = + q_reg_tiles[i_rep].get_thread_buffer()[i_buf]; + }); + + // no need to call __builtin_amdgcn_s_barrier() since 
the tile-slice read + // by each wavefront is over-written by itself + }); + }; + q_tile = tile_elementwise_in(q_element_func, q_tile); auto seqlen_k_curr = seqlen_k_start; index_t i_loop = 0; + // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile + __builtin_amdgcn_s_barrier(); + while(i_loop < num_loops) { static_for<0, k1_loops, 1>{}([&](auto i_k1) { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 4829bacb54..489b377dd2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -23,6 +23,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return 3; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM(); + + return BlockGemm:: + template MakeABlockTileDistribution(); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -43,6 +53,30 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return WG::WarpGemmAttribute::kKPerThread; }; + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + + return min(MaxVectorSize, ElemPerThread); + } + template 
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() { @@ -114,6 +148,95 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy GetVSingleSmemElementSpaceSize()); }; + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kKVector = GetAlignmentQ(); + + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return q_lds_block_desc; + } + else + { + static_assert(kKVector % kKPack == 0); + + constexpr auto q_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<2>{}, sequence<0, 1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return q_lds_block_desc; + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution() + { + using QKVDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); + constexpr index_t kKPerBlock = 
Problem::BlockFmhaShape::kQKHeaddim; + + constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); + + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + + // for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { @@ -345,6 +468,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy } } + template + CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM() + { + return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) * + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + }; + template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { @@ -488,6 +618,13 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return (k1_loops - 1 + 1) % num_kv_lds_buffers == 0; }; + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + return MakeQLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::QKVDataType); + }; + template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { @@ -500,7 +637,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return GetSmemSizeKV() + GetSmemSizeDropout(0); + return max(GetSmemSizeKV() + GetSmemSizeDropout(0), + GetSmemSizeQ()); } };