diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 66e28db661..9c9ac286b4 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -73,7 +73,7 @@ struct HstuAttentionWithSoftmaxFwdBlockTile<64> template <> struct HstuAttentionWithSoftmaxFwdBlockTile<128> { - using type = ck_tile::sequence<128, 32, 128, 16, 128>; + using type = ck_tile::sequence<128, 64, 128, 16, 128>; using gemm0_warps = ck_tile::sequence<4, 1, 1>; using gemm1_warps = ck_tile::sequence<4, 1, 1>; }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 498bcece4b..593ce90a31 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -448,6 +448,27 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS }); } + __builtin_amdgcn_sched_barrier(0x00000001); + + using v_shuffled_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution())); + + v_shuffled_tile_type v_shuffled_tile; + + shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffled_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + tile_elementwise_inout(f_silu, pcomp_tile); tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); }, @@ -466,28 +487,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); - using v_shuffled_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeShuffledVRegTileDistribution())); - - statically_indexed_array v_shuffled_tiles; - - static_for<0, k1_loops, 1>{}( - [&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); }); - - // check whether first V-LdsBufer overlap with last K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) - { - __builtin_amdgcn_s_barrier(); - }; - // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1])); - - __builtin_amdgcn_sched_barrier(0x00000001); - // load k_tiles used by next iteration k_tiles[i_k1] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); @@ -500,6 +501,17 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS o_acc, get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 1) + { + __builtin_amdgcn_sched_barrier(0x00000001); + + shuffle_tile(v_shuffled_tile, v_tiles[number{}]); + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffled_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + }; }); // check whether last V-LdsBuffer overlap with first K-LdsBuffer, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 5cd5ec58e9..7bc8445bea 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -219,9 +219,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS using k_tile_type = decltype(load_tile(k_dram_window)); - statically_indexed_array k_tiles; + constexpr index_t NumPrefetchK = 2; - static_for<0, k1_loops, 1>{}([&](auto i_k1) { + static_assert(k1_loops >= NumPrefetchK, "Check failed!"); + + // only prefetch two k tiles to save vgprs consumption + statically_indexed_array k_tiles; + + static_for<0, NumPrefetchK, 1>{}([&](auto i_k1) { k_tiles[i_k1] = load_tile(k_dram_window); move_tile_window(k_dram_window, {kK1, 0}); }); @@ -391,14 +396,23 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[i_k1], - tile_elementwise_in(k_element_func, k_tiles[i_k1])); + store_tile( + k_lds_write_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}])); __builtin_amdgcn_sched_barrier(0x00000001); - // load v_tiles used in current iteration - v_tiles[i_k1] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + if constexpr(i_k1 < k1_loops - NumPrefetchK) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + } + else + { + // load v_tiles used in current iteration + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -417,8 +431,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS sequence{}); }); - __builtin_amdgcn_sched_barrier(0x00000001); - // STAGE 2, scale_s, add bias, mask, siLU if constexpr(kHasBias) { @@ -477,6 +489,35 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS }); }; + __builtin_amdgcn_sched_barrier(0x00000001); + + using v_shuffled_tile_type = decltype(make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution())); + + v_shuffled_tile_type v_shuffled_tile; + + shuffle_tile(v_shuffled_tile, v_tiles[number<0>{}]); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffled_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for{}([&](auto i_k1) { + // load v_tiles used in current iteration + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + auto m_local = block_tile_reduce( pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); block_tile_reduce_sync(m_local, f_max, bool_constant{}); @@ -544,31 +585,21 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); - using v_shuffled_tile_type = decltype(make_static_distributed_tensor( - Policy::template MakeShuffledVRegTileDistribution())); + shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]); - statically_indexed_array v_shuffled_tiles; + store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffled_tile)); - static_for<0, k1_loops, 1>{}( - [&](auto i_k1) { shuffle_tile(v_shuffled_tiles[i_k1], v_tiles[i_k1]); }); - - // check whether first V-LdsBufer overlap with last K-LdsBuffer, - // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 - if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) - { - __builtin_amdgcn_s_barrier(); - }; + __builtin_amdgcn_sched_barrier(0x00000001); // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffled_tiles[i_k1])); - - __builtin_amdgcn_sched_barrier(0x00000001); - - // load k_tiles used by next iteration - k_tiles[i_k1] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {kK1, 0}); + if constexpr(i_k1 < NumPrefetchK) + { + // load k_tiles used by next iteration + k_tiles[i_k1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kK1, 0}); + }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -578,6 +609,17 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS o_acc, get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 2) + { + __builtin_amdgcn_sched_barrier(0x00000001); + + shuffle_tile(v_shuffled_tile, v_tiles[number{}]); + store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffled_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + }; }); // check whether last V-LdsBuffer overlap with first K-LdsBuffer,