mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Remove K0 from tile setting since it is not used
This commit is contained in:
@@ -28,7 +28,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kK0 = HstuAttentionTileSetting::kK0;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
@@ -522,8 +521,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
|
||||
@@ -18,12 +18,12 @@ template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdTileSetting;
|
||||
|
||||
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
|
||||
// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0)
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 16, 32, 32, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -31,7 +31,7 @@ struct HstuAttentionFwdBlockTile<32>
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -39,7 +39,7 @@ struct HstuAttentionFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 16, 128, 16, 128>;
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -47,7 +47,7 @@ struct HstuAttentionFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>;
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -98,12 +98,12 @@ struct HstuAttentionFwdTileSetting<256>
|
||||
#endif
|
||||
|
||||
#if defined(BUILD_HSTU_FOR_GFX95_ONLY)
|
||||
// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0)
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 16, 32>;
|
||||
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
@@ -111,7 +111,7 @@ struct HstuAttentionFwdBlockTile<32>
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -119,7 +119,7 @@ struct HstuAttentionFwdBlockTile<64>
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 32, 128, 16, 128>;
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
@@ -127,7 +127,7 @@ struct HstuAttentionFwdBlockTile<128>
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 32, 256, 16, 256>;
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
@@ -36,6 +36,12 @@ struct HstuAttentionFwdTileSettingClass
|
||||
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static_assert(BlockTile::size() == 5, "Check failed!");
|
||||
static_assert(Gemm0BlockWarps::size() == 3, "Check failed!");
|
||||
static_assert(Gemm0WarpTile::size() == 3, "Check failed!");
|
||||
static_assert(Gemm1BlockWarps::size() == 3, "Check failed!");
|
||||
static_assert(Gemm1WarpTile::size() == 3, "Check failed!");
|
||||
|
||||
static constexpr index_t NumGemm0Warps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t NumGemm1Warps =
|
||||
@@ -46,13 +52,11 @@ struct HstuAttentionFwdTileSettingClass
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
|
||||
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kN1 = BlockTile::at(number<2>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<3>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kQKHeaddim =
|
||||
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
|
||||
BlockTile::at(number<4>{}); // total length of K0, used for pipelines that need to load Q at
|
||||
// once (or repeatedly load Q as a whole tile)
|
||||
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
|
||||
|
||||
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user