From 02cae85af5fd0b603a60e85b8dd20ae733d1de71 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 20 Dec 2025 13:35:45 +0000 Subject: [PATCH] Load Q directly from global memory to registers for BlockGemm --- ..._attention_fwd_pipeline_default_policy.hpp | 259 +----------------- ...hstu_attention_no_softmax_fwd_pipeline.hpp | 74 +---- ...tention_no_softmax_fwd_trload_pipeline.hpp | 43 +-- .../hstu_attention_pipeline_problem.hpp | 3 - ...tu_attention_with_softmax_fwd_pipeline.hpp | 82 +----- ...ntion_with_softmax_fwd_trload_pipeline.hpp | 47 +--- 6 files changed, 37 insertions(+), 471 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 711ebe86e9..13da5774b7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -25,17 +25,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return 4; } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSingleRepMTileDistribution() - { - using BlockGemm = remove_cvref_t())>; - constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM(); - - return BlockGemm::template MakeABlockTileDistribution< - kBlockGemmM, - Problem::HstuAttentionTileSetting::kQKHeaddim>(); - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -101,33 +90,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return WG::WarpGemmAttribute::Impl::kCM1PerLane; } - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() - { - if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) - return 8; - else - return 4; - } - template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - if constexpr(Problem::kLoadWholeQTileOnceThroughLds) - { - return Problem::GetQDramTileAccessMaxVectorSize(); - } - else - { - using QDataType = remove_cvref_t; + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QKVDataType); - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; - return detail:: - GetDramTileAccessMaxVectorSize(); - }; + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); } template @@ -238,217 +210,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy GetVSingleSmemElementSpaceSize()); }; - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds - ? Problem::HstuAttentionTileSetting::kM0 - : GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; - constexpr index_t kKPack = GetSmemKPackQ(); - constexpr index_t kKVector = GetAlignmentQ(); - - // for hdim96 and hdim160, use simplest layout - if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim) - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{}), - number{}, - number<1>{}); - } - else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) - { - static_assert(kKVector == kKPack); - - using QDataType = remove_cvref_t; - - constexpr index_t DataTypeSize = sizeof(QDataType); - - // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes - constexpr auto MLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor( - q_lds_block_desc_0, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); - - constexpr auto q_lds_block_desc_k0_mldslayer_m_k1 = transform_tensor_descriptor( - q_lds_block_desc_permuted, - make_tuple(make_pass_through_transform(number{}), - make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - - constexpr auto q_lds_block_desc = transform_tensor_descriptor( - q_lds_block_desc_k0_mldslayer_m_k1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return q_lds_block_desc; - } - else - { - static_assert(kKVector % kKPack == 0); - - constexpr auto q_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto q_lds_block_desc = transform_tensor_descriptor( - q_lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, - number{}, - number{}))), - make_tuple(sequence<2>{}, sequence<0, 1, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return q_lds_block_desc; - }; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQDramSingleRepMTileDistribution() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; - - constexpr index_t kKVector = GetAlignmentQ(); - constexpr index_t OtherK = kKPerBlock / kKVector; - - if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim) - // for kKPerBlock=32,64,128,256 - { - static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = OtherK; - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } - else // for kKPerBlock=96,160 - { - static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); - - // ToDo: need more considieration for hdim72 - constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; - constexpr index_t KThreads = OtherK / KRepPerThread; - - static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 1>>, - sequence<1, 2, 2>, - sequence<1, 0, 2>>{}); - }; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0; - constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; - - constexpr index_t kKVector = GetAlignmentQ(); - constexpr index_t OtherK = kKPerBlock / kKVector; - - if constexpr(kKPerBlock == Problem::HstuAttentionTileSetting::kSubQKHeaddim) - // for kKPerBlock=32,64,128,256 - { - static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - - constexpr index_t KPerThread = kKVector; - constexpr index_t KThreads = OtherK; - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E] - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } - else // for kKPerBlock=96,160 - { - static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); - - // ToDo: need more considieration for hdim72 - constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5; - constexpr index_t KThreads = OtherK / KRepPerThread; - - static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!"); - - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 1>>, - sequence<1, 2, 2>, - sequence<1, 0, 2>>{}); - }; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { @@ -986,13 +747,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return WG::WarpGemmAttribute::Impl::kCM1PerLane; } - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() - { - return MakeQLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::QKVDataType); - }; - template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { @@ -1011,8 +765,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return max(GetSmemSizeKV() + GetSmemSizeDropout(), - GetSmemSizeQ()); + return GetSmemSizeKV() + GetSmemSizeDropout(); } }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 1d3035a43c..8fe2f99561 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -69,9 +69,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); - static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; - // used by NRepetitions2DEpilogue static constexpr index_t kGemm1SingleRepN = Policy::template GetKVBlockGemmSingleRepN(); @@ -181,11 +178,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramSingleRepMTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -197,14 +195,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - using q_dram_tile_type = decltype(load_tile(q_dram_window)); - statically_indexed_array q_dram_tiles; - - static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - q_dram_tiles[i_rep] = load_tile(q_dram_window); - move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); - }); - using k_tile_type = decltype(load_tile(k_dram_window)); statically_indexed_array k_tiles; @@ -219,19 +209,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS // provide partition_index for LDS tile window so that warp_id is in vgpr array partition_index{get_warp_id(), get_lane_id()}; - // Q tile in LDS - QKVDataType* q_lds_ptr = static_cast(smem_ptr); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_write_window = make_tile_window( - q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - auto q_lds_read_window = - make_tile_window(q_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeQRegSingleRepMTileDistribution(), - partition_index); - // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( @@ -317,49 +294,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - using q_reg_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeQRegSingleRepMTileDistribution())); - statically_indexed_array q_reg_tiles; - - using q_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeQRegTileDistribution())); - - q_tile_type q_tile; - - { - static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index); - - // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written - // by each wavefront is read by itself - __builtin_amdgcn_s_waitcnt(0xc07f); - - q_reg_tiles[i_rep] = load_tile(q_lds_read_window); - - // the following codes will not generate actual instructions by the compiler - set_slice_tile(q_tile, - q_reg_tiles[i_rep], - sequence{}, - sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{}); - - // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read - // by each wavefront is over-written by itself - }); - - clear_tile(o_acc); - }; + clear_tile(o_acc); q_tile = tile_elementwise_in(q_element_func, q_tile); auto seqlen_k_curr = seqlen_k_start; - __builtin_amdgcn_sched_barrier(0x00000001); - - // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile - __builtin_amdgcn_s_barrier(); - - __builtin_amdgcn_sched_barrier(0x00000001); - using v_tile_type = decltype(load_tile(v_dram_window)); statically_indexed_array v_tiles; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index ab760cf860..ecb8c8a9bc 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -50,8 +50,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad static constexpr bool kUseTrLoad = true; - static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!"); - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK; @@ -180,11 +178,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -196,8 +195,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - auto q_dram_tile = load_tile(q_dram_window); - using k_tile_type = decltype(load_tile(k_dram_window)); statically_indexed_array k_tiles; @@ -212,19 +209,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // provide partition_index for LDS tile window so that warp_id is in vgpr array partition_index{get_warp_id(), get_lane_id()}; - // Q tile in LDS - QKVDataType* q_lds_ptr = static_cast(smem_ptr); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_write_window = make_tile_window( - q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - auto q_lds_read_window = - make_tile_window(q_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeQRegTileDistribution(), - partition_index); - // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( @@ -310,27 +294,12 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - store_tile(q_lds_write_window, q_dram_tile, partition_index); - clear_tile(o_acc); - __builtin_amdgcn_sched_barrier(0x00000001); - - block_sync_lds(); - - auto q_tile = load_tile(q_lds_read_window); - q_tile = tile_elementwise_in(q_element_func, q_tile); auto seqlen_k_curr = seqlen_k_start; - __builtin_amdgcn_sched_barrier(0x00000001); - - // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile - __builtin_amdgcn_s_barrier(); - - __builtin_amdgcn_sched_barrier(0x00000001); - using v_tile_type = decltype(load_tile(v_dram_window)); statically_indexed_array v_tiles; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index 4a669ee6ec..9ce27427b3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -92,9 +92,6 @@ struct HstuAttentionFwdPipelineProblem static constexpr bool kUseSoftmax = kUseSoftmax_; static constexpr bool kUseTrLoad = kUseTrLoad_; - // ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThrough Lds here ? - static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false; - using HstuAttentionTileSetting = remove_cvref_t; static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 23c8ed4945..f23bffe772 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -69,9 +69,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); - static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; - // used by NRepetitions2DEpilogue static constexpr index_t kGemm1SingleRepN = Policy::template GetKVBlockGemmSingleRepN(); @@ -195,11 +192,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramSingleRepMTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -211,14 +209,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - using q_dram_tile_type = decltype(load_tile(q_dram_window)); - statically_indexed_array q_dram_tiles; - - static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - q_dram_tiles[i_rep] = load_tile(q_dram_window); - move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); - }); - using k_tile_type = decltype(load_tile(k_dram_window)); constexpr index_t NumPrefetchK = 2; @@ -238,20 +228,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // provide partition_index for LDS tile window so that warp_id is in vgpr array partition_index{get_warp_id(), get_lane_id()}; - // Q tile in LDS - QKVDataType* q_lds_ptr = static_cast(smem_ptr); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_write_window = make_tile_window( - q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - // when kQKHeaddim > kQKHeaddim, read window is actually smaller than write window - auto q_lds_read_window = - make_tile_window(q_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeQRegSingleRepMTileDistribution(), - partition_index); - // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( @@ -290,7 +266,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? + {0, seqlen_k_start}, Policy::template MakeVDramTileDistribution()); const auto f_exp = [&](CompDataType x) { @@ -334,52 +310,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - using q_reg_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeQRegSingleRepMTileDistribution())); - statically_indexed_array q_reg_tiles; - - using q_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeQRegTileDistribution())); - - q_tile_type q_tile; - - { - static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index); - - // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written - // by each wavefront is read by itself - __builtin_amdgcn_s_waitcnt(0xc07f); - - q_reg_tiles[i_rep] = load_tile(q_lds_read_window); - - // the following codes will not generate actual instructions by the compiler - set_slice_tile(q_tile, - q_reg_tiles[i_rep], - sequence{}, - sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{}); - - // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read - // by each wavefront is over-written by itself - }); - - clear_tile(o_acc); - - set_tile(m, -numeric::infinity()); - clear_tile(l); - }; + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); q_tile = tile_elementwise_in(q_element_func, q_tile); auto seqlen_k_curr = seqlen_k_start; - __builtin_amdgcn_sched_barrier(0x00000001); - - // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile - __builtin_amdgcn_s_barrier(); - - __builtin_amdgcn_sched_barrier(0x00000001); - using v_tile_type = decltype(load_tile(v_dram_window)); statically_indexed_array v_tiles; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index da05271e04..69161209ec 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -50,8 +50,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static constexpr bool kUseTrLoad = true; - static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!"); - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK; @@ -194,11 +192,12 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); OaccBlockTileType o_acc; - auto q_dram_window = - make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - q_dram_block_window_tmp.get_window_origin(), - Policy::template MakeQDramTileDistribution()); + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + auto q_tile = load_tile(q_dram_window); const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = @@ -210,8 +209,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); - auto q_dram_tile = load_tile(q_dram_window); - using k_tile_type = decltype(load_tile(k_dram_window)); constexpr index_t NumPrefetchK = (n0_loops <= 3) ? 1 : 2; @@ -234,18 +231,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // provide partition_index for LDS tile window so that warp_id is in vgpr array partition_index{get_warp_id(), get_lane_id()}; - // Q tile in LDS - QKVDataType* q_lds_ptr = static_cast(smem_ptr); - auto q_lds = make_tensor_view( - q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_write_window = make_tile_window( - q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); - auto q_lds_read_window = - make_tile_window(q_lds, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeQRegTileDistribution()); - // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( @@ -328,30 +313,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - store_tile(q_lds_write_window, q_dram_tile, partition_index); - clear_tile(o_acc); - - __builtin_amdgcn_sched_barrier(0x00000001); - - block_sync_lds(); - - auto q_tile = load_tile(q_lds_read_window); - - q_tile = tile_elementwise_in(q_element_func, q_tile); - set_tile(m, -numeric::infinity()); clear_tile(l); + q_tile = tile_elementwise_in(q_element_func, q_tile); + auto seqlen_k_curr = seqlen_k_start; - __builtin_amdgcn_sched_barrier(0x00000001); - - // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile - __builtin_amdgcn_s_barrier(); - - __builtin_amdgcn_sched_barrier(0x00000001); - using v_tile_type = decltype(load_tile(v_dram_window)); statically_indexed_array v_tiles;