diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 513bb53f75..7df548daeb 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -27,13 +27,14 @@ template + ck_tile::index_t MaxK, + ck_tile::index_t MTile> struct batched_forward_causal_softmax_bias_dropout_dispatch { using HstuAttentionTileSetting = typename std::conditional_t, - HstuAttentionNoSoftmaxFwdTileSetting>::Type; + HstuAttentionWithSoftmaxFwdTileSetting, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY static constexpr bool kUseTrLoad = true; @@ -189,10 +190,20 @@ template ::Run(param, stream); + if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128) + batched_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + else + batched_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 7a0641f7e2..c1e6c3c941 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -7,27 +7,28 @@ #include "hstu_attention_fwd_type_config.hpp" #include "hstu_attention_tile_setting_define.hpp" +#include "hstu_attention_util.hpp" using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>; using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>; using WarpTile_32x32x16 = ck_tile::sequence<32, 32, 16>; #if !defined(BUILD_HSTU_FOR_GFX95_ONLY) -template +template struct HstuAttentionNoSoftmaxFwdBlockTile; // Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // -template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<64> +template +struct HstuAttentionNoSoftmaxFwdBlockTile<64, MTile> { using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<96> +template +struct HstuAttentionNoSoftmaxFwdBlockTile<96, MTile> { using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -35,36 +36,44 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<96> }; template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<128> +struct HstuAttentionNoSoftmaxFwdBlockTile<128, 64> +{ + using type = ck_tile::sequence<64, 32, 16, 128, 16, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionNoSoftmaxFwdBlockTile<128, 128> { using type = ck_tile::sequence<128, 32, 16, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<256> +template +struct HstuAttentionNoSoftmaxFwdBlockTile<256, MTile> { using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template +template struct HstuAttentionWithSoftmaxFwdBlockTile; // Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // -template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<64> +template +struct HstuAttentionWithSoftmaxFwdBlockTile<64, MTile> { using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<96> +template +struct HstuAttentionWithSoftmaxFwdBlockTile<96, MTile> { using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -72,26 +81,34 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<96> }; template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<128> +struct HstuAttentionWithSoftmaxFwdBlockTile<128, 64> +{ + using type = ck_tile::sequence<64, 64, 16, 128, 16, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<128, 128> { using type = ck_tile::sequence<128, 64, 16, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<256> +template +struct HstuAttentionWithSoftmaxFwdBlockTile<256, MTile> { using type = ck_tile::sequence<128, 32, 16, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template +template struct HstuAttentionNoSoftmaxFwdTileSetting; -template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<64> +template +struct HstuAttentionNoSoftmaxFwdTileSetting<64, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type, @@ -101,8 +118,11 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<64> WarpTile_16x16x16>; }; -template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<96> +template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 64>; +template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 128>; + +template +struct HstuAttentionNoSoftmaxFwdTileSetting<96, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionNoSoftmaxFwdBlockTile<96>::type, @@ -112,19 +132,33 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<96> WarpTile_16x16x16>; }; +template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 64>; +template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 128>; + template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<128> +struct HstuAttentionNoSoftmaxFwdTileSetting<128, 64> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type, - typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm0_warps, WarpTile_16x16x16, - typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm1_warps, WarpTile_16x16x16>; }; template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<256> +struct HstuAttentionNoSoftmaxFwdTileSetting<128, 128> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm0_warps, + WarpTile_16x16x16, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm1_warps, + WarpTile_16x16x16>; +}; + +template +struct HstuAttentionNoSoftmaxFwdTileSetting<256, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type, @@ -134,11 +168,14 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<256> WarpTile_16x16x16>; }; -template +template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 64>; +template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 128>; + +template struct HstuAttentionWithSoftmaxFwdTileSetting; -template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<64> +template +struct HstuAttentionWithSoftmaxFwdTileSetting<64, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type, @@ -148,8 +185,11 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64> WarpTile_32x32x16>; }; -template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<96> +template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 64>; +template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 128>; + +template +struct HstuAttentionWithSoftmaxFwdTileSetting<96, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionWithSoftmaxFwdBlockTile<96>::type, @@ -159,19 +199,33 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<96> WarpTile_32x32x16>; }; +template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 64>; +template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 128>; + template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<128> +struct HstuAttentionWithSoftmaxFwdTileSetting<128, 64> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type, - typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm0_warps, WarpTile_16x16x16, - typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm1_warps, WarpTile_16x16x16>; }; template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<256> +struct HstuAttentionWithSoftmaxFwdTileSetting<128, 128> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm0_warps, + WarpTile_16x16x16, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm1_warps, + WarpTile_16x16x16>; +}; + +template +struct HstuAttentionWithSoftmaxFwdTileSetting<256, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type, @@ -180,32 +234,27 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256> typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps, WarpTile_16x16x16>; }; + +template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 64>; +template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 128>; #endif #if defined(BUILD_HSTU_FOR_GFX95_ONLY) -template +template struct HstuAttentionNoSoftmaxFwdBlockTile; // Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // -template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<32> -{ - using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>; - using gemm0_warps = ck_tile::sequence<2, 1, 1>; - using gemm1_warps = ck_tile::sequence<2, 1, 1>; -}; - -template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<64> +template +struct HstuAttentionNoSoftmaxFwdBlockTile<64, MTile> { using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<96> +template +struct HstuAttentionNoSoftmaxFwdBlockTile<96, MTile> { using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -213,44 +262,44 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<96> }; template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<128> +struct HstuAttentionNoSoftmaxFwdBlockTile<128, 64> +{ + using type = ck_tile::sequence<64, 32, 32, 128, 32, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionNoSoftmaxFwdBlockTile<128, 128> { using type = ck_tile::sequence<128, 32, 32, 128, 32, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionNoSoftmaxFwdBlockTile<256> +template +struct HstuAttentionNoSoftmaxFwdBlockTile<256, MTile> { using type = ck_tile::sequence<128, 32, 32, 256, 32, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template +template struct HstuAttentionWithSoftmaxFwdBlockTile; // Tile-sizes: M N0 N0Sub N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // -template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<32> -{ - using type = ck_tile::sequence<64, 64, 32, 32, 32, 32>; - using gemm0_warps = ck_tile::sequence<2, 1, 1>; - using gemm1_warps = ck_tile::sequence<2, 1, 1>; -}; - -template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<64> +template +struct HstuAttentionWithSoftmaxFwdBlockTile<64, MTile> { using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<96> +template +struct HstuAttentionWithSoftmaxFwdBlockTile<96, MTile> { using type = ck_tile::sequence<128, 64, 32, 128, 32, 96>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -258,37 +307,34 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<96> }; template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<128> +struct HstuAttentionWithSoftmaxFwdBlockTile<128, 64> +{ + using type = ck_tile::sequence<64, 64, 32, 128, 32, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<128, 128> { using type = ck_tile::sequence<128, 64, 32, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template <> -struct HstuAttentionWithSoftmaxFwdBlockTile<256> +template +struct HstuAttentionWithSoftmaxFwdBlockTile<256, MTile> { using type = ck_tile::sequence<128, 64, 16, 256, 32, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; -template +template struct HstuAttentionNoSoftmaxFwdTileSetting; -template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<32> -{ - using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type, - typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps, - WarpTile_16x16x32, - typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps, - WarpTile_16x16x32>; -}; - -template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<64> +template +struct HstuAttentionNoSoftmaxFwdTileSetting<64, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type, @@ -298,8 +344,11 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<64> WarpTile_16x16x32>; }; -template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<96> +template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 64>; +template struct HstuAttentionNoSoftmaxFwdTileSetting<64, 128>; + +template +struct HstuAttentionNoSoftmaxFwdTileSetting<96, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionNoSoftmaxFwdBlockTile<96>::type, @@ -309,19 +358,33 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<96> WarpTile_16x16x32>; }; +template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 64>; +template struct HstuAttentionNoSoftmaxFwdTileSetting<96, 128>; + template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<128> +struct HstuAttentionNoSoftmaxFwdTileSetting<128, 64> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type, - typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm0_warps, WarpTile_16x16x32, - typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 64>::gemm1_warps, WarpTile_16x16x32>; }; template <> -struct HstuAttentionNoSoftmaxFwdTileSetting<256> +struct HstuAttentionNoSoftmaxFwdTileSetting<128, 128> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm0_warps, + WarpTile_16x16x32, + typename HstuAttentionNoSoftmaxFwdBlockTile<128, 128>::gemm1_warps, + WarpTile_16x16x32>; +}; + +template +struct HstuAttentionNoSoftmaxFwdTileSetting<256, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type, @@ -331,22 +394,14 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<256> WarpTile_16x16x32>; }; -template +template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 64>; +template struct HstuAttentionNoSoftmaxFwdTileSetting<256, 128>; + +template struct HstuAttentionWithSoftmaxFwdTileSetting; -template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<32> -{ - using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type, - typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps, - WarpTile_16x16x32, - typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps, - WarpTile_16x16x32>; -}; - -template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<64> +template +struct HstuAttentionWithSoftmaxFwdTileSetting<64, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type, @@ -356,8 +411,11 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64> WarpTile_32x32x16>; }; -template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<96> +template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 64>; +template struct HstuAttentionWithSoftmaxFwdTileSetting<64, 128>; + +template +struct HstuAttentionWithSoftmaxFwdTileSetting<96, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionWithSoftmaxFwdBlockTile<96>::type, @@ -367,19 +425,33 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<96> WarpTile_32x32x16>; }; +template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 64>; +template struct HstuAttentionWithSoftmaxFwdTileSetting<96, 128>; + template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<128> +struct HstuAttentionWithSoftmaxFwdTileSetting<128, 64> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type, - typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps, - WarpTile_32x32x16, - typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps, - WarpTile_32x32x16>; + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm0_warps, + WarpTile_16x16x32, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 64>::gemm1_warps, + WarpTile_16x16x32>; }; template <> -struct HstuAttentionWithSoftmaxFwdTileSetting<256> +struct HstuAttentionWithSoftmaxFwdTileSetting<128, 128> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm0_warps, + WarpTile_32x32x16, + typename HstuAttentionWithSoftmaxFwdBlockTile<128, 128>::gemm1_warps, + WarpTile_32x32x16>; +}; + +template +struct HstuAttentionWithSoftmaxFwdTileSetting<256, MTile> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type, @@ -388,4 +460,24 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256> typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps, WarpTile_16x16x32>; }; + +template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 64>; +template struct HstuAttentionWithSoftmaxFwdTileSetting<256, 128>; + #endif + +static int get_hstu_attention_fwd_mtile(int num_batches, int num_heads, int max_seqlen_q) +{ + int num_CUs = get_number_of_cu(); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + + int nbatch_nhead_mblocks = num_batches * num_heads * ceildiv(max_seqlen_q, 128); + + // assuming each CU is assigned two work-groups + if(nbatch_nhead_mblocks >= static_cast(0.85f * num_CUs * 2.0f)) + return 128; + + // currently, only hdim-128 actually uses mtile-64, for other hdim, the settings for + // mtile-64 can be added through tuning/verification + return 64; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp index c3ea10c74b..58886d05b2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp @@ -27,13 +27,14 @@ template + ck_tile::index_t MaxK, + ck_tile::index_t MTile> struct group_forward_causal_softmax_bias_dropout_dispatch { using HstuAttentionTileSetting = typename std::conditional_t, - HstuAttentionNoSoftmaxFwdTileSetting>::Type; + HstuAttentionWithSoftmaxFwdTileSetting, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY static constexpr bool kUseTrLoad = true; @@ -175,10 +176,20 @@ template ::Run(param, stream); + if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128) + group_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + else + group_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index e37a839606..bf7dda05e0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -27,13 +27,14 @@ template + ck_tile::index_t MaxK, + ck_tile::index_t MTile> struct jagged_forward_causal_softmax_bias_dropout_dispatch { using HstuAttentionTileSetting = typename std::conditional_t, - HstuAttentionNoSoftmaxFwdTileSetting>::Type; + HstuAttentionWithSoftmaxFwdTileSetting, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; #ifdef BUILD_HSTU_FOR_GFX95_ONLY static constexpr bool kUseTrLoad = true; @@ -178,10 +179,20 @@ template ::Run(param, stream); + if(get_hstu_attention_fwd_mtile(param.num_batch, param.num_head, param.max_seqlen) == 128) + jagged_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); + else + jagged_forward_causal_softmax_bias_dropout_dispatch::Run(param, stream); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp index 5df4f70242..0bfc2c565c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp @@ -8,6 +8,8 @@ #include #include +#include "ck_tile/host/hip_check_error.hpp" + #define HSTU_CHECK(COND, ERR) \ if(!(COND)) \ { \ @@ -15,3 +17,16 @@ ostr << "'" #COND "' failed: " << ERR; \ throw std::runtime_error(ostr.str()); \ } + +static inline int get_number_of_cu() +{ + int device; + + HIP_CHECK_ERROR(hipGetDevice(&device)); + + hipDeviceProp_t props; + + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device)); + + return props.multiProcessorCount; +}