diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 1f76900000..80b637b051 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -315,8 +315,7 @@ struct HstuAttentionFwdPipelineQRKSVS if constexpr(i_k0 == 0) clear_tile(s_acc); - if constexpr(i_k0 < k0_loops - 1) - k_tile = load_tile(k_dram_window); + k_tile = load_tile(k_dram_window); if constexpr(i_k0 < k0_loops - 2) move_tile_window(k_dram_window, {0, kK0}); @@ -389,6 +388,18 @@ struct HstuAttentionFwdPipelineQRKSVS const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); return !mask.IsTokenPairInsideMask(row, col); }); + } + else + { + if(q_origin.at(number<0>{}) + kM0 > mask.max_uih_len) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, type_convert(0), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask.IsTokenPairInsideMask(row, col); + }); + }; }; auto s = cast_tile(s_acc); // S{j} @@ -404,6 +415,8 @@ struct HstuAttentionFwdPipelineQRKSVS randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window); } + // ensure gemm_0 has finished access of k-Lds for all warps + __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0x7f); if constexpr(std::is_same_v) @@ -426,6 +439,12 @@ struct HstuAttentionFwdPipelineQRKSVS const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, s)); + move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); + k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + + __builtin_amdgcn_sched_barrier(0); + // STAGE 3, KV gemm if constexpr(k1_loops > 1) { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index d2ededb305..ad165e7c00 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -305,7 +305,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return BlockGemmARegBSmemCRegOneWarpV1{}; } - // leave some exclusive space so that the second v_lds buffer will nenver overlap with the first + // leave some exclusive space so that the second v_lds buffer will never overlap with the first // k_lds bufffer template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()