Enable run-time selection of MTile sizes according to the predicted CU utilization ratio

This commit is contained in:
Qianfeng Zhang
2026-03-19 10:39:37 +00:00
parent 302537c5a8
commit 76da618c85
5 changed files with 279 additions and 139 deletions

View File

@@ -27,13 +27,14 @@ template <typename InOutDataType,
bool kUseSoftmax,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
ck_tile::index_t MaxK,
ck_tile::index_t MTile>
struct batched_forward_causal_softmax_bias_dropout_dispatch
{
using HstuAttentionTileSetting =
typename std::conditional_t<kUseSoftmax,
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
@@ -189,10 +190,20 @@ template <typename InOutDataType,
void run_batched_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream)
{
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
128>::Run(param, stream);
else
batched_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};

View File

@@ -7,27 +7,28 @@
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_tile_setting_define.hpp"
#include "hstu_attention_util.hpp"
using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>;
using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>;
using WarpTile_32x32x16 = ck_tile::sequence<32, 32, 16>;
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
template <ck_tile::index_t MaxK>
template <ck_tile::index_t MaxK, ck_tile::index_t MTile = 0>
struct HstuAttentionNoSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdBlockTile<64, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<96>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdBlockTile<96, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -35,36 +36,44 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<96>
};
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
struct HstuAttentionNoSoftmaxFwdBlockTile<128, 64>
{
using type = ck_tile::sequence<64, 32, 16, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// Block-tile description for the no-softmax forward path (kUseSoftmax == false) with
// MaxK = 128 and MTile = 128, i.e. the variant picked by the run-time dispatch when
// get_hstu_attention_fwd_mtile() returns 128.
// `type` follows the "M N0 N0Sub N1 K1 MaxK" order used by the other tiles in this
// header; compared with the <128, 64> specialization only M changes (128 instead of 64).
// gemm0_warps / gemm1_warps are the warp-layout sequences forwarded to
// HstuAttentionFwdTileSettingClass for the first and second gemm respectively.
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<128, 128>
{
using type = ck_tile::sequence<128, 32, 16, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdBlockTile<256, MTile>
{
using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <ck_tile::index_t MaxK>
template <ck_tile::index_t MaxK, ck_tile::index_t MTile = 0>
struct HstuAttentionWithSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdBlockTile<64, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<96>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdBlockTile<96, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -72,26 +81,34 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<96>
};
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
struct HstuAttentionWithSoftmaxFwdBlockTile<128, 64>
{
using type = ck_tile::sequence<64, 64, 16, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// Block-tile description for the softmax forward path (kUseSoftmax == true) with
// MaxK = 128 and MTile = 128, i.e. the variant picked by the run-time dispatch when
// get_hstu_attention_fwd_mtile() returns 128.
// `type` follows the "M N0 N0Sub N1 K1 MaxK" order used by the other tiles in this
// header; compared with the <128, 64> specialization only M changes (128 instead of 64).
// gemm0_warps / gemm1_warps are the warp-layout sequences forwarded to
// HstuAttentionFwdTileSettingClass for the first and second gemm respectively.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128, 128>
{
using type = ck_tile::sequence<128, 64, 16, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdBlockTile<256, MTile>
{
using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <ck_tile::index_t MaxK>
template <ck_tile::index_t MaxK, ck_tile::index_t MTile = 0>
struct HstuAttentionNoSoftmaxFwdTileSetting;
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<64>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdTileSetting<64, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type,
@@ -101,8 +118,11 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<64>
WarpTile_16x16x16>;
};
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<96>
template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 64>;
template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 128>;
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdTileSetting<96, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<96>::type,
@@ -112,19 +132,33 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<96>
WarpTile_16x16x16>;
};
template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 64>;
template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 128>;
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<128>
struct HstuAttentionNoSoftmaxFwdTileSetting<128, 64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm0_warps,
WarpTile_16x16x16,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm1_warps,
WarpTile_16x16x16>;
};
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<256>
struct HstuAttentionNoSoftmaxFwdTileSetting<128, 128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm0_warps,
WarpTile_16x16x16,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm1_warps,
WarpTile_16x16x16>;
};
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdTileSetting<256, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type,
@@ -134,11 +168,14 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<256>
WarpTile_16x16x16>;
};
template <ck_tile::index_t MaxK>
template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 64>;
template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 128>;
template <ck_tile::index_t MaxK, ck_tile::index_t MTile = 0>
struct HstuAttentionWithSoftmaxFwdTileSetting;
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<64>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdTileSetting<64, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
@@ -148,8 +185,11 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
WarpTile_32x32x16>;
};
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<96>
template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 64>;
template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 128>;
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdTileSetting<96, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<96>::type,
@@ -159,19 +199,33 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<96>
WarpTile_32x32x16>;
};
template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 64>;
template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 128>;
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<128>
struct HstuAttentionWithSoftmaxFwdTileSetting<128, 64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm0_warps,
WarpTile_16x16x16,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm1_warps,
WarpTile_16x16x16>;
};
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<256>
struct HstuAttentionWithSoftmaxFwdTileSetting<128, 128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm0_warps,
WarpTile_16x16x16,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm1_warps,
WarpTile_16x16x16>;
};
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdTileSetting<256, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type,
@@ -180,32 +234,27 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256>
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
WarpTile_16x16x16>;
};
template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 64>;
template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 128>;
#endif
#if defined(BUILD_HSTU_FOR_GFX95_ONLY)
template <ck_tile::index_t MaxK>
template <ck_tile::index_t MaxK, ck_tile::index_t MTile = 0>
struct HstuAttentionNoSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdBlockTile<64, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<96>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdBlockTile<96, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -213,44 +262,44 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<96>
};
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
struct HstuAttentionNoSoftmaxFwdBlockTile<128, 64>
{
using type = ck_tile::sequence<64, 32, 32, 128, 32, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// BUILD_HSTU_FOR_GFX95_ONLY variant of the no-softmax block tile for MaxK = 128 and
// MTile = 128 (selected when get_hstu_attention_fwd_mtile() returns 128).
// `type` follows the "M N0 N0Sub N1 K1 MaxK" order used by the other tiles in this
// header; compared with the <128, 64> specialization only M changes (128 instead of 64).
// gemm0_warps / gemm1_warps are the warp-layout sequences forwarded to
// HstuAttentionFwdTileSettingClass for the first and second gemm respectively.
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<128, 128>
{
using type = ck_tile::sequence<128, 32, 32, 128, 32, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdBlockTile<256, MTile>
{
using type = ck_tile::sequence<128, 32, 32, 256, 32, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <ck_tile::index_t MaxK>
template <ck_tile::index_t MaxK, ck_tile::index_t = 0>
struct HstuAttentionWithSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdBlockTile<64, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<96>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdBlockTile<96, MTile>
{
using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -258,37 +307,34 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<96>
};
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
struct HstuAttentionWithSoftmaxFwdBlockTile<128, 64>
{
using type = ck_tile::sequence<64, 64, 32, 128, 32, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// BUILD_HSTU_FOR_GFX95_ONLY variant of the softmax block tile for MaxK = 128 and
// MTile = 128 (selected when get_hstu_attention_fwd_mtile() returns 128).
// `type` follows the "M N0 N0Sub N1 K1 MaxK" order used by the other tiles in this header.
// NOTE(review): besides M (128 vs 64), K1 is 16 here while the <128, 64> specialization
// and the other GFX95 tiles in this header use K1 = 32 -- confirm this is intentional
// (the matching tile setting pairs it with WarpTile_32x32x16).
// gemm0_warps / gemm1_warps are the warp-layout sequences forwarded to
// HstuAttentionFwdTileSettingClass for the first and second gemm respectively.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128, 128>
{
using type = ck_tile::sequence<128, 64, 32, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdBlockTile<256, MTile>
{
using type = ck_tile::sequence<128, 64, 16, 256, 32, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
template <ck_tile::index_t MaxK>
template <ck_tile::index_t MaxK, ck_tile::index_t MTile = 0>
struct HstuAttentionNoSoftmaxFwdTileSetting;
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<32>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps,
WarpTile_16x16x32,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps,
WarpTile_16x16x32>;
};
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<64>
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdTileSetting<64, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type,
@@ -298,8 +344,11 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<64>
WarpTile_16x16x32>;
};
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<96>
template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 64>;
template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 128>;
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdTileSetting<96, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<96>::type,
@@ -309,19 +358,33 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<96>
WarpTile_16x16x32>;
};
template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 64>;
template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 128>;
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<128>
struct HstuAttentionNoSoftmaxFwdTileSetting<128, 64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm0_warps,
WarpTile_16x16x32,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm1_warps,
WarpTile_16x16x32>;
};
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<256>
struct HstuAttentionNoSoftmaxFwdTileSetting<128, 128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm0_warps,
WarpTile_16x16x32,
typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm1_warps,
WarpTile_16x16x32>;
};
template <ck_tile::index_t MTile>
struct HstuAttentionNoSoftmaxFwdTileSetting<256, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type,
@@ -331,22 +394,14 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<256>
WarpTile_16x16x32>;
};
template <ck_tile::index_t MaxK>
template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 64>;
template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 128>;
template <ck_tile::index_t MaxK, ck_tile::index_t MTile = 0>
struct HstuAttentionWithSoftmaxFwdTileSetting;
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<32>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps,
WarpTile_16x16x32,
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps,
WarpTile_16x16x32>;
};
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<64>
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdTileSetting<64, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
@@ -356,8 +411,11 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
WarpTile_32x32x16>;
};
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<96>
template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 64>;
template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 128>;
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdTileSetting<96, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<96>::type,
@@ -367,19 +425,33 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<96>
WarpTile_32x32x16>;
};
template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 64>;
template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 128>;
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<128>
struct HstuAttentionWithSoftmaxFwdTileSetting<128, 64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
WarpTile_32x32x16,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
WarpTile_32x32x16>;
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm0_warps,
WarpTile_16x16x32,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm1_warps,
WarpTile_16x16x32>;
};
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<256>
struct HstuAttentionWithSoftmaxFwdTileSetting<128, 128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm0_warps,
WarpTile_32x32x16,
typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm1_warps,
WarpTile_32x32x16>;
};
template <ck_tile::index_t MTile>
struct HstuAttentionWithSoftmaxFwdTileSetting<256, MTile>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type,
@@ -388,4 +460,24 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256>
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
WarpTile_16x16x32>;
};
template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 64>;
template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 128>;
#endif
// Picks the M-tile size (128 or 64) for the HSTU attention forward kernels from a
// prediction of how well the launch would fill the compute units of the current device.
//
// num_batches  - number of batches in the launch
// num_heads    - number of attention heads
// max_seqlen_q - maximum query sequence length; split into blocks of 128 rows
//
// Returns 128 when the number of work-groups produced with a 128-row M tile reaches 85%
// of (2 * number of CUs), otherwise 64 so that the smaller tile yields more work-groups.
static inline int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_seqlen_q)
{
    const int num_CUs = get_number_of_cu();

    // The work-group count is accumulated in 64 bits: batches * heads * m-blocks can exceed
    // the range of int for large problems, and signed int overflow is undefined behaviour.
    const long long num_mblocks = (static_cast<long long>(max_seqlen_q) + 128 - 1) / 128;
    const long long nbatch_nhead_mblocks =
        static_cast<long long>(num_batches) * static_cast<long long>(num_heads) * num_mblocks;

    // assuming each CU is assigned two work-groups
    const int min_workgroups_for_mtile_128 = static_cast<int>(0.85f * num_CUs * 2.0f);

    if(nbatch_nhead_mblocks >= min_workgroups_for_mtile_128)
        return 128;

    // currently, only hdim-128 actually uses mtile-64, for other hdim, the settings for
    // mtile-64 can be added through tuning/verification
    return 64;
}

View File

@@ -27,13 +27,14 @@ template <typename InOutDataType,
bool kUseSoftmax,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
ck_tile::index_t MaxK,
ck_tile::index_t MTile>
struct group_forward_causal_softmax_bias_dropout_dispatch
{
using HstuAttentionTileSetting =
typename std::conditional_t<kUseSoftmax,
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
@@ -175,10 +176,20 @@ template <typename InOutDataType,
void run_group_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionGroupFwdParams& param,
hipStream_t stream)
{
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
128>::Run(param, stream);
else
group_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};

View File

@@ -27,13 +27,14 @@ template <typename InOutDataType,
bool kUseSoftmax,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
ck_tile::index_t MaxK,
ck_tile::index_t MTile>
struct jagged_forward_causal_softmax_bias_dropout_dispatch
{
using HstuAttentionTileSetting =
typename std::conditional_t<kUseSoftmax,
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
HstuAttentionWithSoftmaxFwdTileSetting<MaxK, MTile>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
@@ -178,10 +179,20 @@ template <typename InOutDataType,
void run_jagged_forward_causal_softmax_bias_dropout_dispatch(HstuAttentionNoGroupFwdParams& param,
hipStream_t stream)
{
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128)
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
128>::Run(param, stream);
else
jagged_forward_causal_softmax_bias_dropout_dispatch<InOutDataType,
kUseCausal,
kUseSoftmax,
kHasBias,
kHasDropout,
MaxK,
64>::Run(param, stream);
};

View File

@@ -8,6 +8,8 @@
#include <sstream>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"
#define HSTU_CHECK(COND, ERR) \
if(!(COND)) \
{ \
@@ -15,3 +17,16 @@
ostr << "'" #COND "' failed: " << ERR; \
throw std::runtime_error(ostr.str()); \
}
static inline int get_number_of_cu()
{
int device;
HIP_CHECK_ERROR(hipGetDevice(&device));
hipDeviceProp_t props;
HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
return props.multiProcessorCount;
}