mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Using separate settings for gfx942 and gfx950
This commit is contained in:
@@ -16,6 +16,10 @@ if (DEFINED ENV{ASSUME_HIGHLY_VARIED_SEQLEN})
|
||||
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -DHSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=0)
|
||||
endif()
|
||||
|
||||
# Define BUILD_HSTU_FOR_GFX95_ONLY when GPU_TARGETS matches "gfx95" and matches
# neither "gfx94" nor "gfx90".
if(GPU_TARGETS MATCHES "gfx95" AND NOT (GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx90"))
    list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -DBUILD_HSTU_FOR_GFX95_ONLY)
endif()
|
||||
|
||||
# Apply the options accumulated in EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS to the
# ${EXAMPLE_HSTU_ATTENTION} target only (PRIVATE: not propagated to dependents).
target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
|
||||
@@ -14,6 +14,13 @@
|
||||
// Primary template, declared only: a definition exists solely through the
// explicit specializations for individual MaxK values.
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdBlockTile;
|
||||
|
||||
// Shared three-element sequences that the HstuAttentionFwdTileSetting
// specializations pass to HstuAttentionFwdTileSettingClass.
// NOTE(review): presumably per-warp tile shapes, going by the names -- confirm.
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
|
||||
using HstuAttentionFwdWarpTile2 = ck_tile::sequence<16, 16, 32>;
|
||||
|
||||
// Primary template, declared only: each supported MaxK provides its own
// explicit specialization.
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdTileSetting;
|
||||
|
||||
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
|
||||
// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
@@ -48,11 +55,6 @@ struct HstuAttentionFwdBlockTile<256>
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
// Three-element sequence <16, 16, 16>.
// NOTE(review): presumably a per-warp tile shape, going by the name -- confirm.
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// Primary template, declared only: specialized explicitly per MaxK.
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdTileSetting;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<32>
|
||||
{
|
||||
@@ -100,3 +102,88 @@ struct HstuAttentionFwdTileSetting<256>
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(BUILD_HSTU_FOR_GFX95_ONLY)
|
||||
// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 16, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 32, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 32, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<32>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<32>::type,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<64>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<64>::type,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<128>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<128>::type,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<256>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<256>::type,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user