From 2072e53d1e0d5ec144633f3bf2092945ee02fba7 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 13 Oct 2025 16:01:50 +0000 Subject: [PATCH] Remove K0 from tile setting since it is not used --- .../hstu_attention_fwd_pipeline.hpp | 5 ++--- .../hstu_attention_fwd_setting.hpp | 20 +++++++++---------- .../hstu_attention_tile_setting_define.hpp | 14 ++++++++----- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 8e2a36dabb..7896fac39e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -28,7 +28,6 @@ struct HstuAttentionFwdPipelineQRKSVS static constexpr index_t kM0 = HstuAttentionTileSetting::kM0; static constexpr index_t kN0 = HstuAttentionTileSetting::kN0; - static constexpr index_t kK0 = HstuAttentionTileSetting::kK0; static constexpr index_t kN1 = HstuAttentionTileSetting::kN1; static constexpr index_t kK1 = HstuAttentionTileSetting::kK1; static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim; @@ -522,8 +521,8 @@ struct HstuAttentionFwdPipelineQRKSVS typename BiasDramBlockWindowTmp, typename HstuMask> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile HstuMask mask, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index b3c8dfdc30..db413b14b4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -18,12 +18,12 @@ template struct HstuAttentionFwdTileSetting; #if !defined(BUILD_HSTU_FOR_GFX95_ONLY) -// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0) +// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> struct HstuAttentionFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 16, 32, 32, 32>; + using type = ck_tile::sequence<64, 64, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -31,7 +31,7 @@ struct HstuAttentionFwdBlockTile<32> template <> struct HstuAttentionFwdBlockTile<64> { - using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; + using type = ck_tile::sequence<128, 64, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -39,7 +39,7 @@ struct HstuAttentionFwdBlockTile<64> template <> struct HstuAttentionFwdBlockTile<128> { - using type = ck_tile::sequence<128, 32, 16, 128, 16, 128>; + using type = ck_tile::sequence<128, 32, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -47,7 +47,7 @@ struct HstuAttentionFwdBlockTile<128> template <> struct HstuAttentionFwdBlockTile<256> { - using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>; + using type = ck_tile::sequence<128, 32, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -98,12 +98,12 @@ struct HstuAttentionFwdTileSetting<256> #endif #if defined(BUILD_HSTU_FOR_GFX95_ONLY) -// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0) +// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> struct HstuAttentionFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 32, 32, 16, 32>; + using type = ck_tile::sequence<64, 64, 32, 16, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -111,7 +111,7 @@ struct HstuAttentionFwdBlockTile<32> template <> struct HstuAttentionFwdBlockTile<64> { - using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; + using type = ck_tile::sequence<128, 64, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -119,7 +119,7 @@ struct HstuAttentionFwdBlockTile<64> template <> struct HstuAttentionFwdBlockTile<128> { - using type = ck_tile::sequence<128, 32, 32, 128, 16, 128>; + using type = ck_tile::sequence<128, 32, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -127,7 +127,7 @@ struct HstuAttentionFwdBlockTile<128> template <> struct HstuAttentionFwdBlockTile<256> { - using type = ck_tile::sequence<128, 32, 32, 256, 16, 256>; + using type = ck_tile::sequence<128, 32, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp index d0b8920997..d19dfc91c2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp @@ -36,6 +36,12 @@ struct HstuAttentionFwdTileSettingClass using Gemm1BlockWarps = remove_cvref_t; using Gemm1WarpTile = remove_cvref_t; + static_assert(BlockTile::size() == 5, "Check failed!"); + static_assert(Gemm0BlockWarps::size() == 3, "Check failed!"); + static_assert(Gemm0WarpTile::size() == 3, "Check failed!"); + static_assert(Gemm1BlockWarps::size() == 3, "Check failed!"); + static_assert(Gemm1WarpTile::size() == 3, "Check failed!"); + static constexpr index_t NumGemm0Warps = reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); static constexpr index_t NumGemm1Warps = @@ -46,13 +52,11 @@ struct HstuAttentionFwdTileSettingClass static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen - static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll - static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim - static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll + static constexpr index_t kN1 = BlockTile::at(number<2>{}); // tile size along v head_dim + static constexpr index_t kK1 = BlockTile::at(number<3>{}); // tile size along kv gemm unroll static constexpr index_t kQKHeaddim = - BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at + BlockTile::at(number<4>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) - static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim); };