diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index b71cecbece..e8fa211222 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -161,6 +161,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS constexpr index_t k1_loops = kN0 / kK1; + static_assert(k1_loops >= 2, + "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); // Block GEMM @@ -585,6 +588,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + // k1_loops >= 2 required: v_tiles[number<1>{}] (the second V tile) is shuffled below shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]); store_tile( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index adb6032317..9af1a1b610 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -160,6 +160,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad constexpr index_t k1_loops = kN0 / kK1; + static_assert(k1_loops >= 2, + "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); // Block GEMM