diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 4ff04a69af..8e2a36dabb 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -165,8 +165,12 @@ struct HstuAttentionFwdPipelineQRKSVS constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - using PcompBlockTileType = decltype(cast_tile(SaccBlockTileType{})); + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + // using PcompBlockTileType = decltype(cast_tile(SaccBlockTileType{})); + using PcompBlockTileType = decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution())); SaccBlockTileType sacc_tile; PcompBlockTileType pcomp_tile; @@ -198,8 +202,14 @@ struct HstuAttentionFwdPipelineQRKSVS move_tile_window(q_dram_window, {kGemmSingleRepM, 0}); }); - auto k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + using k_tile_type = decltype(load_tile(k_dram_window)); + + statically_indexed_array k_tiles; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }); __builtin_amdgcn_sched_barrier(0); @@ -323,14 +333,6 @@ struct HstuAttentionFwdPipelineQRKSVS q_tile_type q_tile; { - constexpr index_t complete_tile_thread_buf_size = q_tile_type::get_thread_buffer_size(); - constexpr index_t splitted_tile_thread_buf_size = - q_reg_tile_type::get_thread_buffer_size(); - - static_assert(complete_tile_thread_buf_size == - kGemmNumRepM * splitted_tile_thread_buf_size, - "Check failed!"); - static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { store_tile(q_lds_write_window, q_dram_tiles[i_rep]); @@ -368,128 +370,142 @@ struct HstuAttentionFwdPipelineQRKSVS using v_tile_type = decltype(load_tile(v_dram_window)); - v_tile_type v_tile; - - store_tile(k_lds_write_windows[number<0>{}], tile_elementwise_in(k_element_func, k_tile)); + statically_indexed_array v_tiles; do { + // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - // load v_tile for current unroll - v_tile = load_tile(v_dram_window); + store_tile(k_lds_write_windows[i_k1], + tile_elementwise_in(k_element_func, k_tiles[i_k1])); + __builtin_amdgcn_sched_barrier(0x00000001); + + // load v_tiles used in current iteration + v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); - // for i_k1 = k1_loop-1, the loading is for next iteration - k_tile = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); - __builtin_amdgcn_sched_barrier(0x00000001); block_sync_lds(); + // execute current unroll of gemm_0 gemm_0(sacc_tile, q_tile, k_lds_read_windows[number{}]); sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); - // STAGE 2, scale_s, add bias, mask, siLU - if constexpr(kHasBias) - { - const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto sacc_tile_tmp = cast_tile(sacc_tile); - tile_elementwise_inout( - [&scale_s, &bias_element_func](auto& x, const auto& y) { - x = x * scale_s - type_convert(bias_element_func(y)); - }, - sacc_tile, - bias_tile); + using pcomp_tile_tmp_type = + decltype(get_slice_tile(pcomp_tile, sequence<0, 0>{}, sequence{})); - move_tile_window(bias_dram_window, {0, kK1}); - } - else - { - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, sacc_tile); - } + pcomp_tile_tmp_type pcomp_tile_tmp; - if(!mask.IsFullTileInsideMask( - q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) - { - constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + pcomp_tile_tmp.get_thread_buffer() = sacc_tile_tmp.get_thread_buffer(); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + set_slice_tile(pcomp_tile, + pcomp_tile_tmp, + sequence<0, i_k1 * kK1>{}, + sequence{}); + }); - sacc_tile(i_j_idx) *= - static_cast(mask.IsTokenPairInsideMask(row, col)); - }); + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 2, scale_s, add bias, mask, siLU + if constexpr(kHasBias) + { + const auto bias_tile = load_tile(bias_dram_window); + + tile_elementwise_inout( + [&scale_s, &bias_element_func](auto& x, const auto& y) { + x = x * scale_s - type_convert(bias_element_func(y)); + }, + pcomp_tile, + bias_tile); + + move_tile_window(bias_dram_window, {0, kK1}); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + } + + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) + { + constexpr auto p_spans = PcompBlockTileType::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + pcomp_tile(i_j_idx) *= + static_cast(mask.IsTokenPairInsideMask(row, col)); }); - } + }); + } - pcomp_tile = cast_tile(sacc_tile); + tile_elementwise_inout(f_silu, pcomp_tile); - tile_elementwise_inout(f_silu, pcomp_tile); + seqlen_k_curr += kN0; - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegTileDistribution()); - shuffle_tile(v_shuffle_tmp, v_tile); + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer - // i+2, No overlap occurs between V and K in the same unroll, and V in current - // unroll and K in next unroll or first unroll in next iteration - store_tile( - v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } - if constexpr(kHasDropout) - { - auto randval_lds_ptr = reinterpret_cast(smem_ptr) + - Policy::template GetSmemSizeKV(); + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); - dropout.template Run( - randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); - } + using v_shuffled_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution())); - auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + statically_indexed_array v_shuffled_tiles; + + static_for<0, k1_loops, 1>{}( + [&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); }); + + // check whether first V-LdsBufer overlap with next K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1])); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // load k_tiles used by next iteration + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + + __builtin_amdgcn_sched_barrier(0x00000001); block_sync_lds(); - gemm_1(o_acc, p, v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); - - seqlen_k_curr += kK1; - - if constexpr(i_k1 < k1_loops - 1) - { - // check whether current V-LdsBufer overlap with next K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((i_k1 + 2) % NumKVLdsBuffers == (i_k1 + 1) % NumKVLdsBuffers) - { - __builtin_amdgcn_s_barrier(); - }; - - store_tile(k_lds_write_windows[number<(i_k1 + 1) % NumKVLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tile)); - - __builtin_amdgcn_sched_barrier(0x00000001); - } - else - { - // check whether last V-LdsBuffer overlap with first K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((i_k1 + 2) % NumKVLdsBuffers == 0) - { - __builtin_amdgcn_s_barrier(); - }; - - store_tile(k_lds_write_windows[number<0>{}], - tile_elementwise_in(k_element_func, k_tile)); - } + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; } while(seqlen_k_curr < seqlen_k_end); tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 0f30bdc167..653c3a1da3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -68,12 +68,20 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy return WG::WarpGemmAttribute::kKPerThread; }; + template + CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeABlockTileDistribution< + Problem::HstuAttentionTileSetting::kM0, + Problem::HstuAttentionTileSetting::kN0>(); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() { - using BlockGemm = remove_cvref_t())>; - - return BlockGemm::MakeCBlockTile().get_tile_distribution(); + return MakePRegTileDistribution(); } template