From 207e6f10b8e336a7ada5d123ac396e1a2350bd72 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 27 Oct 2025 14:54:36 +0000 Subject: [PATCH] Implementation of hstu attention pipeline using trload for v on mi350 --- ..._gemm_areg_bsmem_trload_creg_v2_hack_1.hpp | 247 +++++++ ...stu_attention_batched_forward_dispatch.hpp | 15 +- .../hstu_attention_fwd_kernel.hpp | 32 +- .../hstu_attention_fwd_pipeline.hpp | 4 + ..._attention_fwd_pipeline_default_policy.hpp | 228 ++++-- .../hstu_attention_fwd_trload_pipeline.hpp | 682 ++++++++++++++++++ ...hstu_attention_jagged_forward_dispatch.hpp | 15 +- .../hstu_attention_pipeline_problem.hpp | 2 + 8 files changed, 1153 insertions(+), 72 deletions(-) create mode 100644 example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp create mode 100644 example/ck_tile/18_hstu_attention/hstu_attention_fwd_trload_pipeline.hpp diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp new file mode 100644 index 0000000000..07e52f1bc9 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// Block GEMM built from warp GEMMs. A is block distributed tensor +// B is block window on shared memory, read per warp with load_tile_transpose() +// C is block distributed tensor +template +struct BlockGemmARegBSmemTrLoadCRegV2Hack_1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B (accumulates in place into c_block_tensor); the B warp tensors of the next N-iteration are loaded before the warp GEMMs of the current one are issued + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + 
sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // construct A-block-tensor from A-block-tensor-tmp (the thread buffer is copied as-is) + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + constexpr auto b_warp_dstr_encode = + typename InputTileDistributionTraits::TransposedDstrEncode{}; + + // construct B-warp-window, offset along N by iNWarp * WG::kN from the block window origin + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN}, + make_static_tile_distribution(b_warp_dstr_encode)); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0))); + + statically_indexed_array, + NIterPerWarp> + b_warp_tensors; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {kIter * KPerBlockPerIter, 0 * 
NPerBlockPerIter}); + b_warp_tensors(I0)(kIter) = load_tile_transpose(b_warp_windows(I0)(kIter)); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter}); + b_warp_tensors(number{})(kIter) = + load_tile_transpose(b_warp_windows(number{})(kIter)); + }); + }; + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM, using the B warp tensor loaded earlier for this (nIter, kIter) + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / 
WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B: builds the C tile with MakeCBlockTile() and forwards to the accumulating overload above. NOTE(review): no clearing of the new C tile is visible before the accumulation; confirm it starts zeroed + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 17b6aa7350..3d32c40535 100644 
--- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -16,6 +16,7 @@ #include "hstu_attention_pipeline_problem.hpp" #include "hstu_attention_traits.hpp" #include "hstu_attention_fwd_pipeline.hpp" +#include "hstu_attention_fwd_trload_pipeline.hpp" #include "hstu_attention_fwd_kernel.hpp" #include "hstu_attention_epilogue.hpp" @@ -32,6 +33,12 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch HstuAttentionWithSoftmaxFwdTileSetting, HstuAttentionNoSoftmaxFwdTileSetting>::Type; +#ifdef BUILD_HSTU_FOR_GFX95_ONLY + static constexpr bool kUseTrLoad = true; +#else + static constexpr bool kUseTrLoad = false; +#endif + template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< InOutDataType, @@ -43,6 +50,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, + kUseTrLoad, HstuAttentionTileSetting, HstuTraits>; @@ -80,8 +88,11 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch kPadSeqLenQ, kPadHeadDimV>>; - using HstuPipeline = - ck_tile::HstuAttentionFwdPipelineQRKSVS; + using HstuPipeline = std::conditional_t< + kUseTrLoad, + ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad, + ck_tile::HstuAttentionFwdPipelineQRKSVS>; + using HstuKernel = ck_tile::HstuAttentionFwdKernel; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index d6d91746a5..0d8282fc5a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -48,6 +48,8 @@ struct HstuAttentionFwdKernel static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout; static constexpr bool kHasCausalMask = HstuAttentionPipeline::kHasCausal; + static constexpr bool kUseTrLoad = HstuAttentionPipeline::kUseTrLoad; + template // to avoid 
duplicated base class problem, introduce an template // arg struct HstuAttentionFwdEmptyKargs @@ -583,17 +585,27 @@ struct HstuAttentionFwdKernel number{}, number<1>{}); - const auto v_dram_transposed = - transform_tensor_view(v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + if constexpr(!kUseTrLoad) + { + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return pad_tensor_view(v_dram_transposed, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view(v_dram_transposed, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(v_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }; }(); auto q_dram_window = diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 5b70d044f2..1245e4e445 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -40,6 +40,10 @@ struct HstuAttentionFwdPipelineQRKSVS static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasCausal = Problem::kHasCausal; + static_assert(Problem::kUseTrLoad == false, "Check failed!"); + + static constexpr bool kUseTrLoad = false; + static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp 
b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 5694b948b5..0f4bea45f0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -16,6 +16,7 @@ #include "block_gemm_areg_bsmem_creg_v2_hack_0.hpp" #include "block_gemm_areg_bsmem_creg_v2_hack_1.hpp" +#include "block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp" namespace ck_tile { @@ -71,11 +72,26 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() { - using BlockGemm = remove_cvref_t())>; + if constexpr(!Problem::kUseTrLoad) + { + using BlockGemm = remove_cvref_t())>; - return BlockGemm::template MakeABlockTileDistribution< - Problem::HstuAttentionTileSetting::kM0, - Problem::HstuAttentionTileSetting::kN0>(); + return BlockGemm::template MakeABlockTileDistribution< + Problem::HstuAttentionTileSetting::kM0, + Problem::HstuAttentionTileSetting::kN0>(); + } + else + { + using BlockGemm = remove_cvref_t())>; + + constexpr auto bias_block_dstr_encode = + BlockGemm::template MakeCBlockDistributionEncode< + Problem::HstuAttentionTileSetting::kM0, + Problem::HstuAttentionTileSetting::kN0>(); + constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode); + + return bias_block_dstr; + }; } template @@ -148,23 +164,34 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { + using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - 
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); - constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) - ? kMaxVecLoad - : (ElemPerThread / kMinVecLoad); + constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (ElemPerThread / kMinVecLoad); - return kVecLoad; + return kVecLoad; + } + else + { + constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + return min(MaxVectorSize, ElemPerThread); + }; } template @@ -195,11 +222,18 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); - return N0 * (N1 * kKPerBlock + kKPack); + return N0 * (N1 * kKPerBlock + kKPack); + } + else + { + return kNPerBlock * kKPerBlock; + }; }; template @@ -470,43 +504,88 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - 
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - // K2 is the vector size for storing shuffled tile to LDS - constexpr index_t K2 = ElemPerThread / N1; + // K2 is the vector size for storing shuffled tile to LDS + constexpr index_t K2 = ElemPerThread / N1; - // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm - constexpr index_t kKPack = GetSmemKPackV(); + // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm + constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack >= K2, "Check failed!"); + static_assert(kKPack >= K2, "Check failed!"); - constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); + constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); - static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}, number{}), - make_tuple(number{}, - number{}, - number{}, - number<1>{}), - number<8>{}, - number<1>{}); + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}, number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_merge_transform( - make_tuple(number{}, number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + 
v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return v_lds_block_desc; + return v_lds_block_desc; + } + else + { + constexpr index_t kKPack = GetSmemKPackV(); + + constexpr auto XorGroupSize = + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}); + + constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; + + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr auto v_lds_block_desc_naive = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }; } template @@ -516,26 +595,51 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t 
ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); + static_assert(ElemPerThread % N1 == 0); - constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t K1 = get_warp_size() / N0; - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t K2 = ElemPerThread / N1; + constexpr index_t K1 = get_warp_size() / N0; + constexpr index_t K0 = kBlockSize / get_warp_size(); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<2, 1>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + static_assert(ElemPerThread % N1 == 0); + + constexpr index_t K2 = ElemPerThread / N1; + constexpr index_t K1 = get_warp_size() / N0; + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + }; } + // used when kUseTrLoad is false template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution() { @@ -717,7 +821,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy typename Problem::GemmAccDataType, typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps, WarpGemm>; - return BlockGemmARegBSmemCRegV2Hack_1{}; + + if constexpr(!Problem::kUseTrLoad) + { + return BlockGemmARegBSmemCRegV2Hack_1{}; + } + else + { + return BlockGemmARegBSmemTrLoadCRegV2Hack_1{}; + }; } template diff --git 
a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_trload_pipeline.hpp new file mode 100644 index 0000000000..140ef8f07b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_trload_pipeline.hpp @@ -0,0 +1,682 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" + +#include "hstu_attention_fwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +template +struct HstuAttentionFwdPipelineQRKSVSTrLoad +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QKVDataType = remove_cvref_t; + using GemmAccDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + using HstuAttentionTileSetting = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; + static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; + static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; + static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsJagged = Problem::kIsJagged; + static constexpr auto kHasBias = Problem::kHasBias; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasCausal = Problem::kHasCausal; + + static_assert(Problem::kUseTrLoad == true, "Check failed!"); + + static constexpr bool kUseTrLoad = true; + + static constexpr bool kPadSeqLenQ = 
Problem::Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK; + static constexpr bool kPadHeadDimV = + (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM(); + static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM; + + // used by NRepetitions2DEpilogue + static constexpr index_t kGemm1SingleRepN = + Policy::template GetKVBlockGemmSingleRepN(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::Traits::kBlockPerCu != -1) + return Problem::Traits::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(kHasBias) + return 2; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_hstu"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + 
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + HstuMask& mask, + float scale_s, // scaling value exerted on the immediate Q@K result + float scale_p, // scaling value exerted on the SiLu result + void* smem_ptr, + DropoutType& dropout) const + { + ignore = q_element_func; + ignore = k_element_func; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr bool kUseSoftmax = Problem::kUseSoftmax; + + constexpr index_t k1_loops = kN0 / kK1; + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using 
CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; + + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramSingleRepMTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + + using q_dram_tile_type = decltype(load_tile(q_dram_window)); + statically_indexed_array q_dram_tiles; + + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + q_dram_tiles[i_rep] = load_tile(q_dram_window); + move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); + }); + + using k_tile_type = decltype(load_tile(k_dram_window)); + + statically_indexed_array k_tiles; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }); + + __builtin_amdgcn_sched_barrier(0); + + // Q tile in LDS + QKVDataType* q_lds_ptr = static_cast(smem_ptr); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto 
q_lds_write_window = make_tile_window( + q_lds, Policy::template MakeQLdsBlockDescriptor().get_lengths(), {0, 0}); + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto q_lds_read_window = + make_tile_window(q_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQRegSingleRepMTileDistribution()); + + // K tile in LDS + QKVDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_write_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + // when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window + auto k_lds_read_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + using k_lds_write_window_type = decltype(get_slice_tile( + k_lds_write_window, sequence<0, 0>{}, sequence{})); + + using k_lds_read_window_type = decltype(get_slice_tile( + k_lds_read_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_write_windows; + statically_indexed_array k_lds_read_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_write_windows[i_buf] = + get_slice_tile(k_lds_write_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{}); + k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window, + sequence{}, + sequence<(i_buf + 1) * kK1, kQKHeaddim>{}); + }); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, 
sequence<(i_buf + 1) * kK1, kN1>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); + + // element-wise SiLU applied in place: x = x * sigmoid(x) + const auto f_silu = [&](CompDataType& x) { + const auto one = ck_tile::type_convert(1.0f); + + if constexpr(std::is_same_v) + { + x = x * __builtin_amdgcn_rcpf(one + __expf(-x)); + } + else + { + x = x / (one + exp(-x)); + } + }; + + const auto f_exp = [&](CompDataType x) { + if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); + } + }; + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + + using q_reg_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegSingleRepMTileDistribution())); + statically_indexed_array q_reg_tiles; + + using q_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())); + + q_tile_type q_tile; + + { + static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { + store_tile(q_lds_write_window, q_dram_tiles[i_rep]); + + // no need to call 
__builtin_amdgcn_s_barrier() since the tile-slice written + // by each wavefront is read by itself + __builtin_amdgcn_s_waitcnt(0xc07f); + + q_reg_tiles[i_rep] = load_tile(q_lds_read_window); + + __builtin_amdgcn_s_waitcnt(0xc07f); + + // the following codes will not generate actual instructions by the compiler + set_slice_tile(q_tile, + q_reg_tiles[i_rep], + sequence{}, + sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{}); + + // no need to call __builtin_amdgcn_s_barrier() since the tile-slice read + // by each wavefront is over-written by itself + }); + + clear_tile(o_acc); + + if constexpr(kUseSoftmax) + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + }; + }; + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + auto seqlen_k_curr = seqlen_k_start; + + __builtin_amdgcn_sched_barrier(0x00000001); + + // ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile + __builtin_amdgcn_s_barrier(); + + __builtin_amdgcn_sched_barrier(0x00000001); + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + do + { + // STAGE 1, Gemm_0 ( S = Q@K ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(k_lds_write_windows[i_k1], + tile_elementwise_in(k_element_func, k_tiles[i_k1])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // load v_tiles used in current iteration + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + // execute current unroll of gemm_0 + gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + + auto tmp_tile = cast_tile(sacc_tile); + + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 2, scale_s, add bias, mask, siLU + if constexpr(kHasBias) + { + const auto bias_tile = 
load_tile(bias_dram_window); + + tile_elementwise_inout( + [&scale_s, &bias_element_func](auto& x, const auto& y) { + x = x * scale_s + type_convert(bias_element_func(y)); + }, + pcomp_tile, + bias_tile); + + move_tile_window(bias_dram_window, {0, kN0}); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + } + + if constexpr(!kUseSoftmax) + { + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(!mask.IsTokenPairInsideMask(row, col)) + { + pcomp_tile(i_j_idx) = type_convert(0.0f); + }; + }); + }); + } + + tile_elementwise_inout(f_silu, pcomp_tile); + + tile_elementwise_inout( + [&](auto& x) { x = x * type_convert(scale_p); }, pcomp_tile); + } + else + { + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; + }); + }); + } 
+ else + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(col >= seqlen_k_end) + { + pcomp_tile(i_j_idx) = -numeric::infinity(); + }; + }); + }); + }; + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + 
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + }; + + seqlen_k_curr += kN0; + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + // check whether the first V-LdsBuffer overlaps with the last K-LdsBuffer; + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[number{}])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // load k_tiles used by next iteration + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + __builtin_amdgcn_sched_barrier(0x00000001); + + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number{}]); + }); + } while(seqlen_k_curr < seqlen_k_end); + + if constexpr(kUseSoftmax) + { + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; + }); + }); + }; + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + // Convenience overload: forwards to the element-function overload of operator() with identity{} supplied for all seven element-wise functors + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& 
q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + HstuMask mask, // forwarded unchanged + float scale_s, // scaling value applied to the immediate Q@K result + float scale_p, // scaling value applied to the SiLU result + void* smem_ptr, // forwarded unchanged; the full overload places its Q/K/V LDS tiles here + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + scale_p, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index e15192781b..42d4ae405c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -16,6 +16,7 @@ #include "hstu_attention_pipeline_problem.hpp" #include "hstu_attention_traits.hpp" #include "hstu_attention_fwd_pipeline.hpp" +#include "hstu_attention_fwd_trload_pipeline.hpp" #include "hstu_attention_fwd_kernel.hpp" #include "hstu_attention_epilogue.hpp" @@ -32,6 +33,12 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch HstuAttentionWithSoftmaxFwdTileSetting, HstuAttentionNoSoftmaxFwdTileSetting>::Type; +#ifdef BUILD_HSTU_FOR_GFX95_ONLY + static constexpr bool kUseTrLoad = true; +#else + static constexpr bool kUseTrLoad = false; +#endif + template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< InOutDataType, @@ -43,6 +50,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, + kUseTrLoad, HstuAttentionTileSetting, HstuTraits>; 
@@ -74,8 +82,11 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch kPadSeqLenQ, kPadHeadDimV>>; - using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS; - using HstuKernel = ck_tile::HstuAttentionFwdKernel; + using HstuPipeline = std::conditional_t< + kUseTrLoad, + ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad, + ck_tile::HstuAttentionFwdPipelineQRKSVS>; + using HstuKernel = ck_tile::HstuAttentionFwdKernel; RunWithKernel(param, stream); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index 6bd1c71dd0..092216bd24 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -22,6 +22,7 @@ template struct HstuAttentionFwdPipelineProblem @@ -44,6 +45,7 @@ struct HstuAttentionFwdPipelineProblem static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kHasCausal = kHasCausal_; static constexpr bool kUseSoftmax = kUseSoftmax_; + static constexpr bool kUseTrLoad = kUseTrLoad_; using HstuAttentionTileSetting = remove_cvref_t;