diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index f3c276c1a4..b1c1cc71c6 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -460,7 +460,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS // check whether first V-LdsBufer overlap with last K-LdsBuffer, // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) { __builtin_amdgcn_s_barrier(); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 538ec053cf..2cceeb2591 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -224,13 +224,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS constexpr index_t NumPrefetchK = 2; - static_assert(k1_loops >= NumPrefetchK, "Check failed!"); + static_assert(n0_loops >= NumPrefetchK, "Check failed!"); // only prefetch two k tiles to save vgprs consumption statically_indexed_array k_tiles; - static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { - k_tiles[i_k1] = load_tile(k_dram_window); + static_for<0, NumPrefetchK, 1>{}([&](auto i_n0) { + k_tiles[i_n0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kN0Sub, 0}); }); @@ -509,7 +509,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // check whether first V-LdsBufer overlap with last K-LdsBuffer, // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + if constexpr((n0_loops - 1) % 
NumKVLdsBuffers == 2 % NumKVLdsBuffers) { __builtin_amdgcn_s_barrier(); }; @@ -519,7 +519,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0x00000001); - static_for{}([&](auto i_k1) { + static_for{}([&](auto i_k1) { // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); @@ -597,6 +597,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // k1_loops >= 2 required shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]); + // check whether the second V-LdsBuffer overlaps with the last K-LdsBuffer; + // this does not occur when n0_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 3 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + store_tile( v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index d8e96289bd..b02944815b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -214,9 +214,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad using k_tile_type = decltype(load_tile(k_dram_window)); - constexpr index_t NumPrefetchK = (k1_loops <= 3) ? 
1 : 2; - static_assert(k1_loops >= NumPrefetchK, "Check failed!"); + static_assert(n0_loops >= NumPrefetchK, "Check failed!"); static_assert(k1_loops >= 2, "k1_loops >= 2 required due to pre-storing two v_tiles to Lds"); @@ -224,8 +224,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // only prefetch two k tiles to save vgprs consumption statically_indexed_array k_tiles; - static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { - k_tiles[i_k1] = load_tile(k_dram_window); + static_for<0, NumPrefetchK, 1>{}([&](auto i_n0) { + k_tiles[i_n0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kN0Sub, 0}); }); @@ -485,7 +485,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - static_for{}([&](auto i_k1) { + static_for{}([&](auto i_k1) { // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {kK1, 0});