From 463a19859a36840950ebb2026d2cb04f00e3c6be Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 22 Jun 2025 11:29:03 +0000 Subject: [PATCH] Completely remove the dependency to include/ck_tile/ops/fmha/ops headers --- ...stu_attention_batched_forward_dispatch.hpp | 13 +- .../hstu_attention_fwd_pipeline.hpp | 22 +- ..._attention_fwd_pipeline_default_policy.hpp | 192 ++++++++++-------- .../hstu_attention_fwd_setting.hpp | 64 +++--- ...hstu_attention_jagged_forward_dispatch.hpp | 11 +- .../hstu_attention_pipeline_problem.hpp | 18 +- .../hstu_attention_tile_setting_define.hpp | 67 ++++++ 7 files changed, 240 insertions(+), 147 deletions(-) create mode 100644 example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 5e59736e8d..52bd5222e8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include "hstu_attention_bool_switch.hpp" #include "hstu_attention_fwd_type_config.hpp" @@ -32,8 +31,8 @@ template struct batched_forward_causal_local_bias_dropout_dispatch { - using HstuAttentionShape = typename HstuAttentionFwdShape::Type; - using HstuMask = typename ck_tile::HstuBlockMasking::Type; + using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting::Type; + using HstuMask = typename ck_tile::HstuBlockMasking::Type; template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< @@ -45,16 +44,16 @@ struct batched_forward_causal_local_bias_dropout_dispatch kHasBias, kHasDropout, HstuMask, - HstuAttentionShape, + HstuAttentionTileSetting, HstuTraits>; static void Run(HstuAttentionFwdParams& param, hipStream_t stream) { constexpr ck_tile::index_t occupancy = -1; - const bool pad_seqlen_k = !(param.seqlen % HstuAttentionShape::kN0 == 0); - const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0); - const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0); + const bool pad_seqlen_k = !(param.seqlen % HstuAttentionTileSetting::kN0 == 0); + const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0); // no need to check seqlen_q since it is not used as fastest dim, // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 08e70ada3f..05288b0d5d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -23,20 +23,18 @@ struct HstuAttentionFwdPipelineQRKSVS using ODataType = remove_cvref_t; using HstuMask = remove_cvref_t; - using HstuAttentionTileShape = remove_cvref_t; - using VLayout = remove_cvref_t; - static constexpr bool kQLoadOnce = true; - static_assert(kQLoadOnce == Policy::QLoadOnce); + using HstuAttentionTileSetting = remove_cvref_t; + using VLayout = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = HstuAttentionTileShape::kM0; - static constexpr index_t kN0 = HstuAttentionTileShape::kN0; - static constexpr index_t kK0 = HstuAttentionTileShape::kK0; - static constexpr index_t kN1 = HstuAttentionTileShape::kN1; - static constexpr index_t kK1 = HstuAttentionTileShape::kK1; - static constexpr index_t kQKHeaddim = HstuAttentionTileShape::kQKHeaddim; - static constexpr index_t kSubQKHeaddim = HstuAttentionTileShape::kSubQKHeaddim; + static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; + static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; + static constexpr index_t kK0 = HstuAttentionTileSetting::kK0; + static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; + static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; + static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -283,7 +281,7 @@ struct HstuAttentionFwdPipelineQRKSVS make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + Policy::template MakeBiasDramTileDistribution()); auto null_randval_window = [&]() { if constexpr(kHasDropout) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 94a015998c..b4349a4032 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -4,7 +4,16 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" #include "block_gemm_areg_bsmem_creg_v2_hack_0.hpp" #include "block_gemm_areg_bsmem_creg_v2_hack_1.hpp" @@ -12,10 +21,6 @@ namespace ck_tile { struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy - : BlockFmhaPipelineQXKSVSCustomPolicy { template CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers() @@ -29,8 +34,9 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using BlockGemm = remove_cvref_t())>; constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM(); - return BlockGemm:: - template MakeABlockTileDistribution(); + return BlockGemm::template MakeABlockTileDistribution< + kBlockGemmM, + Problem::HstuAttentionTileSetting::kQKHeaddim>(); } template @@ -39,8 +45,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using BlockGemm = remove_cvref_t())>; return BlockGemm::template MakeABlockTileDistribution< - Problem::BlockFmhaShape::kM0, - Problem::BlockFmhaShape::kQKHeaddim>(); + Problem::HstuAttentionTileSetting::kM0, + Problem::HstuAttentionTileSetting::kQKHeaddim>(); } template @@ -63,6 +69,24 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return WG::WarpGemmAttribute::kKPerThread; }; + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::MakeCBlockTile().get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() { @@ -75,11 +99,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - using QDataType = remove_cvref_t; + using QDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; @@ -99,11 +123,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() { - using KDataType = remove_cvref_t; + using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; @@ -123,12 +147,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VLayout = remove_cvref_t; - using VDataType = remove_cvref_t; + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; @@ -155,8 +179,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -177,10 +201,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() { - using VLayout = remove_cvref_t; + using VLayout = remove_cvref_t; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords if constexpr(std::is_same_v) @@ -222,7 +246,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); constexpr index_t kKVector = GetAlignmentQ(); @@ -282,7 +306,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM(); - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); @@ -311,8 +335,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); @@ -394,8 +418,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy using QKVDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim; constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; @@ -421,12 +445,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() { - using VLayout = remove_cvref_t; + using VLayout = remove_cvref_t; constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords if constexpr(std::is_same_v) @@ -556,11 +580,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() { - using VLayout = remove_cvref_t; + using VLayout = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords if constexpr(std::is_same_v) @@ -614,12 +638,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution() { // This tile-distribuiton only used when V layout is RowMajor - using VLayout = remove_cvref_t; + using VLayout = remove_cvref_t; static_assert(std::is_same_v); constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; + constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; constexpr index_t N1 = GetAlignmentV(); constexpr index_t N0 = kNPerBlock / N1; @@ -644,27 +668,29 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM() { - return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) * - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + return Problem::HstuAttentionTileSetting::Gemm0BlockWarps::at(number<0>{}) * + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}); }; template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using GemmProblem = - BlockGemmProblem, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = BlockGemmProblem< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::kNumGemm0Warps * get_warp_size(), + TileGemmShape, + typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps, + typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { - constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); - constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); + constexpr index_t WarpGemmM = + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}); static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); if constexpr(std::is_same_v && @@ -709,12 +735,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy } // TODO - bf8_t }(); - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV2CustomPolicy; + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps, + decltype(warp_gemm)>; if constexpr(1 < Problem::kNumGemm0Warps) return BlockGemmARegBSmemCRegV2Hack_0{}; @@ -725,23 +751,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN() { - return Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) * - Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + return Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}) * + Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{}); }; template CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { - using GemmProblem = - BlockGemmProblem, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + using GemmProblem = BlockGemmProblem< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + Problem::kNumGemm1Warps * get_warp_size(), + TileGemmShape, + typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps, + typename Problem::HstuAttentionTileSetting::Gemm1WarpTile>>; auto warp_gemm = [&]() { if constexpr(std::is_same_v && @@ -759,21 +785,21 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy typename Problem::QKVDataType, typename Problem::QKVDataType, typename Problem::GemmAccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}), + Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}), true>{}; } }(); using WarpGemm = remove_cvref_t; - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV2CustomPolicy; + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy< + typename Problem::QKVDataType, + typename Problem::QKVDataType, + typename Problem::GemmAccDataType, + typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps, + WarpGemm>; return BlockGemmARegBSmemCRegV2Hack_1{}; } @@ -803,10 +829,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy sizeof(typename Problem::QKVDataType); }; + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() + { + return 0; + }; + template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return max(GetSmemSizeKV() + GetSmemSizeDropout(0), + return max(GetSmemSizeKV() + GetSmemSizeDropout(), GetSmemSizeQ()); } }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 65eeafbd0f..d50e2ffde8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -7,9 +7,9 @@ #pragma once #include -#include #include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_tile_setting_define.hpp" template struct HstuAttentionFwdBlockTile; @@ -51,48 +51,52 @@ struct HstuAttentionFwdBlockTile<256> using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>; template -struct HstuAttentionFwdShape; +struct HstuAttentionFwdTileSetting; template <> -struct HstuAttentionFwdShape<32> +struct HstuAttentionFwdTileSetting<32> { - using Type = ck_tile::TileFmhaShape::type, - typename HstuAttentionFwdBlockTile<32>::gemm0_warps, - HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<32>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionFwdBlockTile<32>::type, + typename HstuAttentionFwdBlockTile<32>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<32>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; }; template <> -struct HstuAttentionFwdShape<64> +struct HstuAttentionFwdTileSetting<64> { - using Type = ck_tile::TileFmhaShape::type, - typename HstuAttentionFwdBlockTile<64>::gemm0_warps, - HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<64>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionFwdBlockTile<64>::type, + typename HstuAttentionFwdBlockTile<64>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<64>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; }; template <> -struct HstuAttentionFwdShape<128> +struct HstuAttentionFwdTileSetting<128> { - using Type = ck_tile::TileFmhaShape::type, - typename HstuAttentionFwdBlockTile<128>::gemm0_warps, - HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<128>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionFwdBlockTile<128>::type, + typename HstuAttentionFwdBlockTile<128>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<128>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; }; template <> -struct HstuAttentionFwdShape<256> +struct HstuAttentionFwdTileSetting<256> { - using Type = ck_tile::TileFmhaShape::type, - typename HstuAttentionFwdBlockTile<256>::gemm0_warps, - HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<256>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionFwdBlockTile<256>::type, + typename HstuAttentionFwdBlockTile<256>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<256>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 9a093b7663..a4d27b7eff 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include "hstu_attention_bool_switch.hpp" #include "hstu_attention_fwd_type_config.hpp" @@ -32,8 +31,8 @@ template struct jagged_forward_causal_local_bias_dropout_dispatch { - using HstuAttentionShape = typename HstuAttentionFwdShape::Type; - using HstuMask = typename ck_tile::HstuBlockMasking::Type; + using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting::Type; + using HstuMask = typename ck_tile::HstuBlockMasking::Type; template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< @@ -45,15 +44,15 @@ struct jagged_forward_causal_local_bias_dropout_dispatch kHasBias, kHasDropout, HstuMask, - HstuAttentionShape, + HstuAttentionTileSetting, HstuTraits>; static void Run(HstuAttentionFwdParams& param, hipStream_t stream) { constexpr ck_tile::index_t occupancy = -1; - const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0); - const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0); + const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0); // no need to check seqlen_q since it is not used as fastest dim, // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index bec43a83a2..fb6d9d2a24 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -21,7 +21,7 @@ template struct HstuAttentionFwdPipelineProblem { @@ -35,10 +35,6 @@ struct HstuAttentionFwdPipelineProblem using BiasDataType = remove_cvref_t; // to be compatible with ck_tile existing policy codes - using QDataType = QKVDataType; - using KDataType = QKVDataType; - using VDataType = QKVDataType; - using SaccDataType = GemmAccDataType; using OaccDataType = GemmAccDataType; using PDataType = QKVDataType; @@ -48,15 +44,13 @@ struct HstuAttentionFwdPipelineProblem using HstuMask = remove_cvref_t; - using HstuAttentionTileShape = remove_cvref_t; + using HstuAttentionTileSetting = remove_cvref_t; - // Keep the name compatible with ck_tile existing policy codes, to be changed - using BlockFmhaShape = HstuAttentionTileShape; - using Traits = remove_cvref_t; + using Traits = remove_cvref_t; - static constexpr index_t kNumGemm0Warps = AttentionTileShape_::NumGemm0Warps; - static constexpr index_t kNumGemm1Warps = AttentionTileShape_::NumGemm1Warps; - static constexpr index_t kBlockSize = AttentionTileShape_::NumWarps * get_warp_size(); + static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps; + static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size(); }; } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp new file mode 100644 index 0000000000..a44c33fddd --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len) +{ + if(len == 96) + return 128; + if(len == 160) + return 256; + if(len == 192) + return 192; + + // only length of 96, 160, 192 and power-of-two is supported + if(!(len & (len - 1))) + return len; + + return 0; +}; + +template +struct HstuAttentionFwdTileSettingClass +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + + static constexpr index_t NumGemm0Warps = + reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + static constexpr index_t NumGemm1Warps = + reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}); + static_assert(NumGemm1Warps % NumGemm0Warps == 0); + + static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); + + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll + static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim + static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll + static constexpr index_t kQKHeaddim = + BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at + // once (or repeately load Q as a whole tile) + static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); + + static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim); + + // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen + static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; + using VLayout = std::conditional_t; +}; + +} // namespace ck_tile