diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 293b191eaa..e68f3ad85e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -356,7 +356,7 @@ struct HstuAttentionFwdPipelineQRKSVS set_tile_if( sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{}); return !mask.IsTokenPairInsideMask(row, col); }); } @@ -368,26 +368,14 @@ struct HstuAttentionFwdPipelineQRKSVS sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + const auto col = + seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{}); return !mask.IsTokenPairInsideMask(row, col); }); } } pcomp_tiles[i_k1] = cast_tile(sacc_tiles[i_k1]); - - tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]); - - if constexpr(kHasDropout) - { - auto randval_lds_ptr = reinterpret_cast(smem_ptr) + - Policy::template GetSmemSizeKV(); - - dropout.template Run( - randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window); - } - - seqlen_k_curr += kK1; }); // load one k_tile for next iteration @@ -419,16 +407,29 @@ struct HstuAttentionFwdPipelineQRKSVS tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch }; - static_for<0, k1_loops, 1>{}([&](auto i_k1) { - const auto p = [&]() { - if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); - else - return cast_tile( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); - }(); + tile_elementwise_inout(f_silu, pcomp_tiles[I0]); + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tiles[I0], null_randval_window); + } + + seqlen_k_curr += kK1; + + auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0])); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, pcomp_tiles[I0])); + }(); + + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 < k1_loops - NumPrefetchV) { v_tiles[number{}] = load_tile(v_dram_window); @@ -436,31 +437,56 @@ struct HstuAttentionFwdPipelineQRKSVS }; block_sync_lds(); + gemm_1(o_acc, p, v_lds_windows[number{}]); + tile_elementwise_inout(f_silu, pcomp_tiles[number{}]); - if constexpr(i_k1 < k1_loops - 1) + if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, - v_shuffle_tmp)); // store the prefetch - } + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in( + v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the + // prefetch + } + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, + seqlen_k_curr, + pcomp_tiles[number{}], + null_randval_window); + } + + seqlen_k_curr += kK1; + + p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32(tile_elementwise_in( + p_compute_element_func, pcomp_tiles[number{}])); else - { - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in( - v_element_func, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); // store the - // prefetch - } - }; + return cast_tile(tile_elementwise_in( + p_compute_element_func, pcomp_tiles[number{}])); + }(); }); + block_sync_lds(); + gemm_1(o_acc, p, v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); + // the over-lap only occurs when k1_loops is 3/5/7, NumVLdsBuffers is 2 if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) __builtin_amdgcn_s_barrier();