diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 66fcab474a..c1e5594ea8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -736,21 +736,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy #ifdef __gfx950__ static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!"); - - return WarpGemmDispatcher< - typename Problem::QKVDataType, - typename Problem::QKVDataType, - typename Problem::GemmAccDataType, - Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}), - Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}), - Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}), - true, - false, - false, - WGAttrNumAccessEnum::Single>{}; #else - static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)), + static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32), "Not supported WarpGemm sizes!"); +#endif return WarpGemmDispatcher< typename Problem::QKVDataType, @@ -763,7 +752,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy false, false, WGAttrNumAccessEnum::Single>{}; -#endif } else { @@ -815,8 +803,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t WarpGemmK = Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}); - static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)), +#ifdef __gfx950__ + static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!"); +#else + static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32), "Not supported WarpGemm sizes!"); +#endif if constexpr(WarpGemmK == 32) return WarpGemmDispatcher< diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index ea46255d06..c861bfd61e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -190,7 +190,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile; template <> struct HstuAttentionNoSoftmaxFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 32, 16, 32>; + using type = ck_tile::sequence<64, 64, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -214,7 +214,7 @@ struct HstuAttentionNoSoftmaxFwdBlockTile<128> template <> struct HstuAttentionNoSoftmaxFwdBlockTile<256> { - using type = ck_tile::sequence<128, 32, 256, 16, 256>; + using type = ck_tile::sequence<128, 32, 256, 32, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -227,7 +227,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile; template <> struct HstuAttentionWithSoftmaxFwdBlockTile<32> { - using type = ck_tile::sequence<64, 64, 32, 16, 32>; + using type = ck_tile::sequence<64, 64, 32, 32, 32>; using gemm0_warps = ck_tile::sequence<2, 1, 1>; using gemm1_warps = ck_tile::sequence<2, 1, 1>; }; @@ -251,7 +251,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<128> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<256> { - using type = ck_tile::sequence<128, 32, 256, 16, 256>; + using type = ck_tile::sequence<128, 64, 256, 32, 256>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -267,7 +267,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<32> typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps, - HstuAttentionFwdWarpTile1>; + HstuAttentionFwdWarpTile2>; }; template <> @@ -278,7 +278,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<64> typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps, - HstuAttentionFwdWarpTile1>; + HstuAttentionFwdWarpTile2>; }; template <> @@ -300,7 +300,7 @@ struct HstuAttentionNoSoftmaxFwdTileSetting<256> typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps, - HstuAttentionFwdWarpTile1>; + HstuAttentionFwdWarpTile2>; }; template @@ -314,7 +314,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<32> typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps, - HstuAttentionFwdWarpTile1>; + HstuAttentionFwdWarpTile2>; }; template <> @@ -325,7 +325,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64> typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps, - HstuAttentionFwdWarpTile1>; + HstuAttentionFwdWarpTile2>; }; template <> @@ -347,6 +347,6 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<256> typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps, - HstuAttentionFwdWarpTile1>; + HstuAttentionFwdWarpTile2>; }; #endif diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index d396e22b87..adb6032317 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -212,6 +212,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad static_assert(k1_loops >= NumPrefetchK, "Check failed!"); + static_assert(k1_loops >= 2, + "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); + // only prefetch two k tiles to save vgprs consumption statically_indexed_array k_tiles; @@ -556,6 +559,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_s_barrier(); }; + // k1_loops >= 2 required store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_tiles[number<1>{}], partition_index);