From 0b6bbe45d64577120e2a43913503558086a7a522 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 21 Apr 2026 15:35:03 +0000 Subject: [PATCH] Remove exposing kUseTrLoad as template parameter of pipeline problem --- ...stu_attention_batched_forward_dispatch.hpp | 7 +- ...ntion_batched_forward_splitkv_dispatch.hpp | 7 +- .../hstu_attention_fwd_kernel.hpp | 4 +- .../hstu_attention_fwd_pipeline_policy.hpp | 86 ++++++++----------- .../hstu_attention_fwd_splitkv_kernel.hpp | 4 +- .../hstu_attention_group_forward_dispatch.hpp | 7 +- ...tention_group_forward_splitkv_dispatch.hpp | 7 +- ...hstu_attention_jagged_forward_dispatch.hpp | 7 +- ...ention_jagged_forward_splitkv_dispatch.hpp | 7 +- ...hstu_attention_no_softmax_fwd_pipeline.hpp | 2 - ...tention_no_softmax_fwd_trload_pipeline.hpp | 37 ++++---- .../hstu_attention_pipeline_problem.hpp | 4 - .../18_hstu_attention/hstu_attention_util.hpp | 26 ++++++ ...tu_attention_with_softmax_fwd_pipeline.hpp | 2 - ...ntion_with_softmax_fwd_trload_pipeline.hpp | 38 ++++---- 15 files changed, 129 insertions(+), 116 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 25679561a1..814c3af0d1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -39,9 +39,9 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch HstuAttentionNoSoftmaxFwdTileSetting>::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY - static constexpr bool kUseTrLoad = true; + static constexpr bool use_trload_pipeline = true; #else - static constexpr bool kUseTrLoad = false; + static constexpr bool use_trload_pipeline = false; #endif template @@ -57,7 +57,6 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, - kUseTrLoad, HstuAttentionTileSetting>; static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) @@ -96,7 +95,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { using HstuPipelineProblem = HstuPipelineProblemTemp; - if constexpr(!kUseTrLoad) + if constexpr(!use_trload_pipeline) { using HstuPipeline = std::conditional_t< kUseSoftmax, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp index 8af4deee3c..fb7b005aaf 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp @@ -44,9 +44,9 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY - static constexpr bool kUseTrLoad = true; + static constexpr bool use_trload_pipeline = true; #else - static constexpr bool kUseTrLoad = false; + static constexpr bool use_trload_pipeline = false; #endif template @@ -62,7 +62,6 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, - kUseTrLoad, HstuAttentionFwdTileSetting>; using OaccDataType = HstuAttentionFwdTypeConfig::OaccDataType; @@ -107,7 +106,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { using HstuPipelineProblem = HstuFwdPipelineProblemTemp; - if constexpr(!kUseTrLoad) + if constexpr(!use_trload_pipeline) { using HstuPipeline = std::conditional_t< kUseSoftmax, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 1ab7178c0a..716741fe0f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -13,6 +13,7 @@ #include #include "hstu_block_masking.hpp" +#include "hstu_attention_util.hpp" #ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM #define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1 @@ -45,13 +46,14 @@ struct HstuAttentionFwdKernel static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias; static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout; static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal; - static constexpr bool kUseTrLoad = HstuAttentionPipeline::Problem::kUseTrLoad; static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK; static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV; + static constexpr bool kUseTrLoad = detail::is_using_trload_v; + template // to avoid duplicated base class problem, introduce an template // arg struct HstuAttentionFwdEmptyKargs diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp index 4fb2feb6e1..49dc92e6bf 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp @@ -45,10 +45,10 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy return WG::WarpGemmAttribute::kKPerThread; }; - template + template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -58,26 +58,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() { - if constexpr(!Problem::kUseTrLoad) - { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; - return BlockGemm::template MakeABlockTileDistribution< - Problem::HstuAttentionTileSetting::kM0, - Problem::HstuAttentionTileSetting::kN0>(); - } - else - { - using BlockGemm = remove_cvref_t())>; + constexpr auto bias_block_dstr_encode = BlockGemm::template MakeCBlockDistributionEncode< + Problem::HstuAttentionTileSetting::kM0, + Problem::HstuAttentionTileSetting::kN0>(); + constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode); - constexpr auto bias_block_dstr_encode = - BlockGemm::template MakeCBlockDistributionEncode< - Problem::HstuAttentionTileSetting::kM0, - Problem::HstuAttentionTileSetting::kN0>(); - constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode); - - return bias_block_dstr; - }; + return bias_block_dstr; } template @@ -117,20 +105,20 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy return Problem::GetKDramTileAccessMaxVectorSize(); } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() { - if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) return 8; else return 4; } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { // special consideration when shuffling is required before storing V to LDS - if constexpr(!Problem::kUseTrLoad) + if constexpr(!kUseTrLoad) { using VDataType = remove_cvref_t; @@ -183,13 +171,13 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy }; }; - template + template CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() { constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - if constexpr(!Problem::kUseTrLoad) + if constexpr(!kUseTrLoad) { constexpr index_t N1 = GetAlignmentV(); constexpr index_t N0 = kNPerBlock / N1; @@ -203,14 +191,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy }; }; - template + template CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() { return max(GetKSingleSmemElementSpaceSize(), - GetVSingleSmemElementSpaceSize()); + GetVSingleSmemElementSpaceSize()); }; - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); @@ -219,6 +207,9 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); + constexpr index_t SingleSmemElementSpaceSize = + GetSingleSmemElementSpaceSize(); + // for hdim96 and hdim160, use simplest layout if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim) { @@ -226,8 +217,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number{}, number{}, number<1>{}), @@ -252,8 +241,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - using KDataType = remove_cvref_t; constexpr index_t DataTypeSize = sizeof(KDataType); @@ -322,8 +309,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(make_tuple(number{}, number{}, @@ -405,7 +390,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy }; } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() { constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); @@ -413,7 +398,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - if constexpr(!Problem::kUseTrLoad) + if constexpr(!kUseTrLoad) { constexpr index_t N1 = GetAlignmentV(); constexpr index_t N0 = kNPerBlock / N1; @@ -456,14 +441,15 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy } else { - constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t kKPack = GetSmemKPackV(); constexpr auto XorGroupSize = Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}); constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; - static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + static_assert(VSingleSmemElementSpaceSize == + GetVSingleSmemElementSpaceSize()); constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(make_tuple(number{}, @@ -497,14 +483,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy }; } - template + template CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - if constexpr(!Problem::kUseTrLoad) + if constexpr(!kUseTrLoad) { constexpr index_t NPerThread = GetAlignmentV(); constexpr index_t NThreads = kNPerBlock / NPerThread; @@ -526,7 +512,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy } else { - constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NPerThread = GetAlignmentV(); constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; @@ -652,7 +638,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{}); }; - template + template CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { using GemmProblem = BlockGemmProblem< @@ -727,7 +713,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps, WarpGemm>; - if constexpr(!Problem::kUseTrLoad) + if constexpr(!kUseTrLoad) { return BlockGemmARegBSmemCRegV2Hack_1{}; } @@ -737,22 +723,22 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy }; } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; return WG::WarpGemmAttribute::Impl::kCM1PerLane; } - template + template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers(); - return num_kv_lds_buffers * GetSingleSmemElementSpaceSize() * + return num_kv_lds_buffers * GetSingleSmemElementSpaceSize() * sizeof(typename Problem::QKVDataType); }; @@ -762,10 +748,10 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy return 0; }; - template + template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return GetSmemSizeKV() + GetSmemSizeDropout(); + return GetSmemSizeKV() + GetSmemSizeDropout(); } }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp index 96c14ebb04..d9ae25a9c1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_kernel.hpp @@ -13,6 +13,7 @@ #include #include "hstu_block_masking.hpp" +#include "hstu_attention_util.hpp" #ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM #define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1 @@ -48,13 +49,14 @@ struct HstuAttentionFwdSplitKVKernel static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias; static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout; static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal; - static constexpr bool kUseTrLoad = HstuAttentionPipeline::Problem::kUseTrLoad; static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK; static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV; + static constexpr bool kUseTrLoad = detail::is_using_trload_v; + template // to avoid duplicated base class problem, introduce an template // arg struct HstuAttentionFwdEmptyKargs diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp index fa69ab7cb8..9e8df40392 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp @@ -39,9 +39,9 @@ struct group_forward_causal_softmax_bias_dropout_dispatch HstuAttentionNoSoftmaxFwdTileSetting>::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY - static constexpr bool kUseTrLoad = true; + static constexpr bool use_trload_pipeline = true; #else - static constexpr bool kUseTrLoad = false; + static constexpr bool use_trload_pipeline = false; #endif template @@ -57,7 +57,6 @@ struct group_forward_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, - kUseTrLoad, HstuAttentionTileSetting>; static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream) @@ -89,7 +88,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { using HstuPipelineProblem = HstuPipelineProblemTemp; - if constexpr(!kUseTrLoad) + if constexpr(!use_trload_pipeline) { using HstuPipeline = std::conditional_t< kUseSoftmax, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp index f7a373d416..cd0591cbfd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp @@ -45,9 +45,9 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch typename HstuAttentionFwdSplitKVCombineTileSetting::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY - static constexpr bool kUseTrLoad = true; + static constexpr bool use_trload_pipeline = true; #else - static constexpr bool kUseTrLoad = false; + static constexpr bool use_trload_pipeline = false; #endif template @@ -63,7 +63,6 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, - kUseTrLoad, HstuAttentionFwdTileSetting>; using OaccDataType = HstuAttentionFwdTypeConfig::OaccDataType; @@ -101,7 +100,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { using HstuPipelineProblem = HstuFwdPipelineProblemTemp; - if constexpr(!kUseTrLoad) + if constexpr(!use_trload_pipeline) { using HstuPipeline = std::conditional_t< kUseSoftmax, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 0e8dadee2f..052eb8e19e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -39,9 +39,9 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch HstuAttentionNoSoftmaxFwdTileSetting>::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY - static constexpr bool kUseTrLoad = true; + static constexpr bool use_trload_pipeline = true; #else - static constexpr bool kUseTrLoad = false; + static constexpr bool use_trload_pipeline = false; #endif template @@ -57,7 +57,6 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, - kUseTrLoad, HstuAttentionTileSetting>; static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream) @@ -89,7 +88,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { using HstuPipelineProblem = HstuPipelineProblemTemp; - if constexpr(!kUseTrLoad) + if constexpr(!use_trload_pipeline) { using HstuPipeline = std::conditional_t< kUseSoftmax, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp index 3021785920..5f6444d31a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp @@ -44,9 +44,9 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY - static constexpr bool kUseTrLoad = true; + static constexpr bool use_trload_pipeline = true; #else - static constexpr bool kUseTrLoad = false; + static constexpr bool use_trload_pipeline = false; #endif template @@ -62,7 +62,6 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch kHasDropout, kUseCausal, kUseSoftmax, - kUseTrLoad, HstuAttentionFwdTileSetting>; using OaccDataType = HstuAttentionFwdTypeConfig::OaccDataType; @@ -100,7 +99,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] { using HstuPipelineProblem = HstuFwdPipelineProblemTemp; - if constexpr(!kUseTrLoad) + if constexpr(!use_trload_pipeline) { using HstuPipeline = std::conditional_t< kUseSoftmax, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index c73f0ecffd..77d32295a1 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -46,8 +46,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasCausal = Problem::kHasCausal; - static_assert(Problem::kUseTrLoad == false, "Check failed!"); - static constexpr bool kUseTrLoad = false; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index d2fc5806c8..81f1192eed 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -46,8 +46,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasCausal = Problem::kHasCausal; - static_assert(Problem::kUseTrLoad == true, "Check failed!"); - static constexpr bool kUseTrLoad = true; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; @@ -62,10 +60,10 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad static constexpr index_t kAlignmentK = kPadHeadDimQK ? 1 : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); @@ -110,7 +108,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return Policy::template GetSmemSize(); + return Policy::template GetSmemSize(); } template (); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] @@ -220,9 +218,13 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + k_lds_ptr, + Policy::template MakeKLdsBlockDescriptor()); auto k_lds_window = make_tile_window( - k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + k_lds, + Policy::template MakeKLdsBlockDescriptor() + .get_lengths(), + {0, 0}); using k_lds_window_type = decltype(get_slice_tile( k_lds_window, sequence<0, 0>{}, sequence{})); @@ -238,9 +240,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // V tile in LDS auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), - Policy::template MakeVLdsBlockDescriptor()); + Policy::template MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( - v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + v_lds, + Policy::template MakeVLdsBlockDescriptor().get_lengths(), + {0, 0}); using v_lds_window_type = decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); @@ -252,11 +256,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad v_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kN1>{}); }); - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}, - Policy::template MakeVDramTileDistribution()); + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); // reduction function for softmax const auto f_silu = [&](CompDataType& x) { @@ -397,7 +401,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad if constexpr(kHasDropout) { auto randval_lds_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeKV(); dropout.template Run( randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index 3dabf8acd3..6a42ba448a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -57,8 +57,6 @@ CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize() // but it also contains other information needed by the pipeline, which includes // TileShape -- which determines how block-layer calculation is done in tiles and // how warps are allocated on dimensions -// Traits -- other information required for running the kernel and pipeline - template struct HstuAttentionFwdPipelineProblem { @@ -94,7 +91,6 @@ struct HstuAttentionFwdPipelineProblem static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kHasCausal = kHasCausal_; static constexpr bool kUseSoftmax = kUseSoftmax_; - static constexpr bool kUseTrLoad = kUseTrLoad_; static_assert(!kUseGroup || (kUseGroup && kIsJagged), "Group HSTU is only used with jagged mode!"); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp index 0bfc2c565c..578aefbc7c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp @@ -30,3 +30,29 @@ static inline int get_number_of_cu() return props.multiProcessorCount; } + +namespace ck_tile { + +namespace detail { + +// A helper struct for detecting kUseTrLoad +// T is the pipeline class used by the kernel instance +template +struct has_use_trload_flag : std::false_type +{ +}; + +template +struct has_use_trload_flag< + T, + std::enable_if_t && T::kUseTrLoad>> + : std::true_type +{ +}; + +template +static inline constexpr bool is_using_trload_v = has_use_trload_flag::value; + +} // namespace detail + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index e598a7439c..fb4cee3648 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -46,8 +46,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasCausal = Problem::kHasCausal; - static_assert(Problem::kUseTrLoad == false, "Check failed!"); - static constexpr bool kUseTrLoad = false; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index b5fdf03d47..ee0be7b7b4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -46,8 +46,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasCausal = Problem::kHasCausal; - static_assert(Problem::kUseTrLoad == true, "Check failed!"); - static constexpr bool kUseTrLoad = true; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; @@ -62,10 +60,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static constexpr index_t kAlignmentK = kPadHeadDimQK ? 1 : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); @@ -110,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return Policy::template GetSmemSize(); + return Policy::template GetSmemSize(); } template (); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + constexpr auto gemm_1 = + Policy::template GetKVBlockGemm(); // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] @@ -230,9 +229,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); auto k_lds = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + k_lds_ptr, + Policy::template MakeKLdsBlockDescriptor()); auto k_lds_window = make_tile_window( - k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + k_lds, + Policy::template MakeKLdsBlockDescriptor() + .get_lengths(), + {0, 0}); using k_lds_window_type = decltype(get_slice_tile( k_lds_window, sequence<0, 0>{}, sequence{})); @@ -248,9 +251,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // V tile in LDS auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), - Policy::template MakeVLdsBlockDescriptor()); + Policy::template MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( - v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + v_lds, + Policy::template MakeVLdsBlockDescriptor().get_lengths(), + {0, 0}); using v_lds_window_type = decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); @@ -262,11 +267,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad v_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kN1>{}); }); - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}, - Policy::template MakeVDramTileDistribution()); + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); const auto f_exp = [&](CompDataType x) { if constexpr(std::is_same_v) @@ -508,7 +513,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad if constexpr(kHasDropout) { auto randval_lds_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeKV(); dropout.template Run( randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);