diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index c605e1dfca..c9802b46c2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -229,8 +229,8 @@ struct HstuAttentionFwdPipelineQRKSVS using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); using PcompBlockTileType = decltype(cast_tile(SaccBlockTileType{})); - statically_indexed_array sacc_tiles; - statically_indexed_array pcomp_tiles; + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; // reduction function for softmax const auto f_silu = [&](CompDataType& x) { @@ -314,9 +314,9 @@ struct HstuAttentionFwdPipelineQRKSVS block_sync_lds(); // execute current unroll of gemm_0 - gemm_0(sacc_tiles[i_k1], q_tile, k_lds_windows[number{}]); + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); - sacc_tiles[i_k1] = tile_elementwise_in(s_acc_element_func, sacc_tiles[i_k1]); + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); // STAGE 2, scale_s, add bias, mask, siLU if constexpr(kHasBias) @@ -327,15 +327,14 @@ struct HstuAttentionFwdPipelineQRKSVS [&scale_s, &bias_element_func](auto& x, const auto& y) { x = x * scale_s - type_convert(bias_element_func(y)); }, - sacc_tiles[i_k1], + sacc_tile, bias_tile); move_tile_window(bias_dram_window, {0, kK1}); } else { - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, - sacc_tiles[i_k1]); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, sacc_tile); } if constexpr(HstuMask::IsMasking) @@ -346,15 +345,14 @@ struct HstuAttentionFwdPipelineQRKSVS sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tiles[i_k1].get_tile_distribution(), - make_tuple(idx0, idx1)); + sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - sacc_tiles[i_k1](i_j_idx) *= static_cast( + sacc_tile(i_j_idx) *= static_cast( mask.IsTokenPairInsideMask(row, col)); }); }); @@ -370,15 +368,14 @@ struct HstuAttentionFwdPipelineQRKSVS sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tiles[i_k1].get_tile_distribution(), - make_tuple(idx0, idx1)); + sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - sacc_tiles[i_k1](i_j_idx) *= static_cast( + sacc_tile(i_j_idx) *= static_cast( mask.IsTokenPairInsideMask(row, col)); }); }); @@ -393,22 +390,21 @@ struct HstuAttentionFwdPipelineQRKSVS sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tiles[i_k1].get_tile_distribution(), - make_tuple(idx0, idx1)); + sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - sacc_tiles[i_k1](i_j_idx) *= static_cast( + sacc_tile(i_j_idx) *= static_cast( mask.IsTokenPairInsideMask(row, col)); }); }); } } - pcomp_tiles[i_k1] = cast_tile(sacc_tiles[i_k1]); + pcomp_tile = cast_tile(sacc_tile); if constexpr(std::is_same_v) { @@ -438,7 +434,7 @@ struct HstuAttentionFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0); - tile_elementwise_inout(f_silu, pcomp_tiles[i_k1]); + tile_elementwise_inout(f_silu, pcomp_tile); if constexpr(kHasDropout) { @@ -446,11 +442,11 @@ struct HstuAttentionFwdPipelineQRKSVS Policy::template GetSmemSizeKV(); dropout.template Run( - randval_lds_ptr, seqlen_k_curr, pcomp_tiles[i_k1], null_randval_window); + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); } - auto p = cast_tile( - tile_elementwise_in(p_compute_element_func, pcomp_tiles[i_k1])); + auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); block_sync_lds();