diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 43207799eb..ea46255d06 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -243,7 +243,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<128> { - using type = ck_tile::sequence<128, 64, 128, 16, 128>; + using type = ck_tile::sequence<128, 64, 128, 32, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; @@ -336,7 +336,7 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<128> typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps, - HstuAttentionFwdWarpTile1>; + HstuAttentionFwdWarpTile2>; }; template <> diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index fce1809ce3..d396e22b87 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -208,7 +208,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad using k_tile_type = decltype(load_tile(k_dram_window)); - constexpr index_t NumPrefetchK = 2; + constexpr index_t NumPrefetchK = (k1_loops <= 3) ? 1 : 2; static_assert(k1_loops >= NumPrefetchK, "Check failed!");