Remove K0 from tile setting since it is not used

This commit is contained in:
Qianfeng Zhang
2025-10-13 16:01:50 +00:00
parent 22a7b31865
commit 2072e53d1e
3 changed files with 21 additions and 18 deletions

View File

@@ -28,7 +28,6 @@ struct HstuAttentionFwdPipelineQRKSVS
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
static constexpr index_t kK0 = HstuAttentionTileSetting::kK0;
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
@@ -522,8 +521,8 @@ struct HstuAttentionFwdPipelineQRKSVS
typename BiasDramBlockWindowTmp,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
HstuMask mask,

View File

@@ -18,12 +18,12 @@ template <ck_tile::index_t MaxK>
struct HstuAttentionFwdTileSetting;
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0)
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 16, 32, 32, 32>;
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
@@ -31,7 +31,7 @@ struct HstuAttentionFwdBlockTile<32>
template <>
struct HstuAttentionFwdBlockTile<64>
{
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -39,7 +39,7 @@ struct HstuAttentionFwdBlockTile<64>
template <>
struct HstuAttentionFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 16, 128, 16, 128>;
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -47,7 +47,7 @@ struct HstuAttentionFwdBlockTile<128>
template <>
struct HstuAttentionFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>;
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -98,12 +98,12 @@ struct HstuAttentionFwdTileSetting<256>
#endif
#if defined(BUILD_HSTU_FOR_GFX95_ONLY)
// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0)
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 32, 16, 32>;
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
@@ -111,7 +111,7 @@ struct HstuAttentionFwdBlockTile<32>
template <>
struct HstuAttentionFwdBlockTile<64>
{
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -119,7 +119,7 @@ struct HstuAttentionFwdBlockTile<64>
template <>
struct HstuAttentionFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 32, 128, 16, 128>;
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
@@ -127,7 +127,7 @@ struct HstuAttentionFwdBlockTile<128>
template <>
struct HstuAttentionFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 32, 256, 16, 256>;
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};

View File

@@ -36,6 +36,12 @@ struct HstuAttentionFwdTileSettingClass
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static_assert(BlockTile::size() == 5, "Check failed!");
static_assert(Gemm0BlockWarps::size() == 3, "Check failed!");
static_assert(Gemm0WarpTile::size() == 3, "Check failed!");
static_assert(Gemm1BlockWarps::size() == 3, "Check failed!");
static_assert(Gemm1WarpTile::size() == 3, "Check failed!");
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumGemm1Warps =
@@ -46,13 +52,11 @@ struct HstuAttentionFwdTileSettingClass
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kN1 = BlockTile::at(number<2>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<3>{}); // tile size along kv gemm unroll
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
BlockTile::at(number<4>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
};