diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index c64bcb1a93..17b6aa7350 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -27,7 +27,10 @@ template struct batched_forward_causal_softmax_bias_dropout_dispatch { - using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting::Type; + using HstuAttentionTileSetting = + typename std::conditional_t, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index db413b14b4..66e28db661 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -8,20 +8,17 @@ #include "hstu_attention_fwd_type_config.hpp" #include "hstu_attention_tile_setting_define.hpp" -template -struct HstuAttentionFwdBlockTile; - using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>; using HstuAttentionFwdWarpTile2 = ck_tile::sequence<16, 16, 32>; -template -struct HstuAttentionFwdTileSetting; - #if !defined(BUILD_HSTU_FOR_GFX95_ONLY) +template +struct HstuAttentionNoSoftmaxFwdBlockTile; + // Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> -struct HstuAttentionFwdBlockTile<32> +struct HstuAttentionNoSoftmaxFwdBlockTile<32> { using type = ck_tile::sequence<64, 64, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; @@ -29,7 +26,7 @@ struct HstuAttentionFwdBlockTile<32> }; template <> -struct HstuAttentionFwdBlockTile<64> +struct HstuAttentionNoSoftmaxFwdBlockTile<64> { using type = ck_tile::sequence<128, 64, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -37,7 +34,7 @@ struct HstuAttentionFwdBlockTile<64> }; template <> -struct HstuAttentionFwdBlockTile<128> +struct HstuAttentionNoSoftmaxFwdBlockTile<128> { using type = ck_tile::sequence<128, 32, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -45,63 +42,153 @@ struct HstuAttentionFwdBlockTile<128> }; template <> -struct HstuAttentionFwdBlockTile<256> +struct HstuAttentionNoSoftmaxFwdBlockTile<256> { using type = ck_tile::sequence<128, 32, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; +template +struct HstuAttentionWithSoftmaxFwdBlockTile; + +// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) +// template <> -struct HstuAttentionFwdTileSetting<32> +struct HstuAttentionWithSoftmaxFwdBlockTile<32> +{ + using type = ck_tile::sequence<64, 64, 32, 32, 32>; + using gemm0_warps = ck_tile::sequence<2, 1, 1>; + using gemm1_warps = ck_tile::sequence<2, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<64> +{ + using type = ck_tile::sequence<128, 64, 64, 32, 64>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<128> +{ + using type = ck_tile::sequence<128, 32, 128, 16, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<256> +{ + using type = ck_tile::sequence<128, 32, 256, 16, 256>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template +struct HstuAttentionNoSoftmaxFwdTileSetting; + +template <> +struct HstuAttentionNoSoftmaxFwdTileSetting<32> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<32>::type, - typename HstuAttentionFwdBlockTile<32>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps, HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<32>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; template <> -struct HstuAttentionFwdTileSetting<64> +struct HstuAttentionNoSoftmaxFwdTileSetting<64> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<64>::type, - typename HstuAttentionFwdBlockTile<64>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps, HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<64>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; template <> -struct HstuAttentionFwdTileSetting<128> +struct HstuAttentionNoSoftmaxFwdTileSetting<128> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<128>::type, - typename HstuAttentionFwdBlockTile<128>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps, HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<128>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; template <> -struct HstuAttentionFwdTileSetting<256> +struct HstuAttentionNoSoftmaxFwdTileSetting<256> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<256>::type, - typename HstuAttentionFwdBlockTile<256>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps, HstuAttentionFwdWarpTile1, - typename HstuAttentionFwdBlockTile<256>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template +struct HstuAttentionWithSoftmaxFwdTileSetting; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<32> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<64> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<128> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<256> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; #endif #if defined(BUILD_HSTU_FOR_GFX95_ONLY) +template +struct HstuAttentionNoSoftmaxFwdBlockTile; + // Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) // template <> -struct HstuAttentionFwdBlockTile<32> +struct HstuAttentionNoSoftmaxFwdBlockTile<32> { using type = ck_tile::sequence<64, 64, 32, 16, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; @@ -109,7 +196,7 @@ struct HstuAttentionFwdBlockTile<32> }; template <> -struct HstuAttentionFwdBlockTile<64> +struct HstuAttentionNoSoftmaxFwdBlockTile<64> { using type = ck_tile::sequence<128, 64, 64, 32, 64>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -117,7 +204,7 @@ struct HstuAttentionFwdBlockTile<64> }; template <> -struct HstuAttentionFwdBlockTile<128> +struct HstuAttentionNoSoftmaxFwdBlockTile<128> { using type = ck_tile::sequence<128, 32, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; @@ -125,54 +212,141 @@ struct HstuAttentionFwdBlockTile<128> }; template <> -struct HstuAttentionFwdBlockTile<256> +struct HstuAttentionNoSoftmaxFwdBlockTile<256> { using type = ck_tile::sequence<128, 32, 256, 16, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; +template +struct HstuAttentionWithSoftmaxFwdBlockTile; + +// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0) +// template <> -struct HstuAttentionFwdTileSetting<32> +struct HstuAttentionWithSoftmaxFwdBlockTile<32> +{ + using type = ck_tile::sequence<64, 64, 32, 16, 32>; + using gemm0_warps = ck_tile::sequence<2, 1, 1>; + using gemm1_warps = ck_tile::sequence<2, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<64> +{ + using type = ck_tile::sequence<128, 64, 64, 32, 64>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<128> +{ + using type = ck_tile::sequence<128, 32, 128, 16, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdBlockTile<256> +{ + using type = ck_tile::sequence<128, 32, 256, 16, 256>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template +struct HstuAttentionNoSoftmaxFwdTileSetting; + +template <> +struct HstuAttentionNoSoftmaxFwdTileSetting<32> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<32>::type, - typename HstuAttentionFwdBlockTile<32>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps, HstuAttentionFwdWarpTile2, - typename HstuAttentionFwdBlockTile<32>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; template <> -struct HstuAttentionFwdTileSetting<64> +struct HstuAttentionNoSoftmaxFwdTileSetting<64> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<64>::type, - typename HstuAttentionFwdBlockTile<64>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps, HstuAttentionFwdWarpTile2, - typename HstuAttentionFwdBlockTile<64>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; template <> -struct HstuAttentionFwdTileSetting<128> +struct HstuAttentionNoSoftmaxFwdTileSetting<128> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<128>::type, - typename HstuAttentionFwdBlockTile<128>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps, HstuAttentionFwdWarpTile2, - typename HstuAttentionFwdBlockTile<128>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; template <> -struct HstuAttentionFwdTileSetting<256> +struct HstuAttentionNoSoftmaxFwdTileSetting<256> { using Type = ck_tile::HstuAttentionFwdTileSettingClass< - typename HstuAttentionFwdBlockTile<256>::type, - typename HstuAttentionFwdBlockTile<256>::gemm0_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type, + typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps, HstuAttentionFwdWarpTile2, - typename HstuAttentionFwdBlockTile<256>::gemm1_warps, + typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template +struct HstuAttentionWithSoftmaxFwdTileSetting; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<32> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps, + HstuAttentionFwdWarpTile2, + typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<64> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps, + HstuAttentionFwdWarpTile2, + typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<128> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps, + HstuAttentionFwdWarpTile2, + typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps, + HstuAttentionFwdWarpTile1>; +}; + +template <> +struct HstuAttentionWithSoftmaxFwdTileSetting<256> +{ + using Type = ck_tile::HstuAttentionFwdTileSettingClass< + typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type, + typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps, + HstuAttentionFwdWarpTile2, + typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps, HstuAttentionFwdWarpTile1>; }; #endif diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 2165cd79ec..e15192781b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -27,7 +27,10 @@ template struct jagged_forward_causal_softmax_bias_dropout_dispatch { - using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting::Type; + using HstuAttentionTileSetting = + typename std::conditional_t, + HstuAttentionNoSoftmaxFwdTileSetting>::Type; template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<