diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 023de7e655..08e70ada3f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -487,18 +487,18 @@ struct HstuAttentionFwdPipelineQRKSVS shuffle_tile(v_shuffle_tmp, v_tile); // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer - // i+1, No overlap occurs between V and K in the same unroll, and V in current - // unroll and K in next unroll or first unrool in next iteration + // i+2, No overlap occurs between V and K in the same unroll, and V in current + // unroll and K in next unroll or first unroll in next iteration store_tile( - v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch } else { // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer - // i+1, No overlap occurs between V and K in the same unroll, and V in current - // unroll and K in next unroll or first unrool in next iteration - store_tile(v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], + // i+2, No overlap occurs between V and K in the same unroll, and V in current + // unroll and K in next unroll or first unroll in next iteration + store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], tile_elementwise_in(v_element_func, v_tile)); // store the prefetch }; @@ -520,14 +520,29 @@ struct HstuAttentionFwdPipelineQRKSVS block_sync_lds(); - gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}]); + gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); seqlen_k_curr += kK1; - }); - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3 - if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) - 
__builtin_amdgcn_s_barrier(); + if constexpr(i_k1 < k1_loops - 1) + { + // check whether current V-LdsBuffer overlaps with next K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((i_k1 + 2) % NumKVLdsBuffers == (i_k1 + 1) % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + } + else + { + // check whether last V-LdsBuffer overlaps with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((i_k1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } + }); } while(i_loop++ < num_loops); tile_elementwise_inout( diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index feb458673f..94a015998c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -20,7 +20,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers() { - return 3; + return 4; } template @@ -787,17 +787,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return WG::WarpGemmAttribute::Impl::kCM1PerLane; } - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1; - constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers(); - - return (k1_loops - 1 + 1) % num_kv_lds_buffers == 0; - }; - template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() {