diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp index 731810364c..9bb80cf258 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp @@ -57,7 +57,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - const index_t iNWarp = get_warp_id() % NWarp; + const index_t iNWarp = get_warp_id() % NWarp; static_assert(NWarp == 1, "Check failed!"); diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp index 0038d725a3..a7a21ef311 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_1.hpp @@ -57,7 +57,7 @@ struct BlockGemmARegBSmemCRegV2Hack_1 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - const index_t iNWarp = get_warp_id() % NWarp; + const index_t iNWarp = get_warp_id() % NWarp; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp index f2ca10f4e8..359e55a0e0 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp @@ -57,7 +57,7 @@ struct BlockGemmARegBSmemTrLoadCRegV2Hack_1 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - const index_t iNWarp = get_warp_id() % NWarp; + const index_t iNWarp = get_warp_id() % NWarp; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 4b65a51c8b..51dddcde5b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -212,6 +212,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0); + // provide partition_index for LDS tile window so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + // Q tile in LDS QKVDataType* q_lds_ptr = static_cast(smem_ptr); auto q_lds = make_tensor_view( @@ -223,7 +226,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}, - Policy::template MakeQRegSingleRepMTileDistribution()); + Policy::template MakeQRegSingleRepMTileDistribution(), + partition_index); // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); @@ -333,7 +337,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS { static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - store_tile(q_lds_write_window, q_dram_tiles[i_rep]); + store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index); // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written // by each wavefront is read by itself @@ -375,7 +379,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1]); + store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1], partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -458,7 +462,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile); + store_tile( + v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -501,7 +506,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS shuffle_tile(v_shuffled_tile, v_tiles[number{}]); store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], - v_shuffled_tile); + v_shuffled_tile, + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index b887dd832d..9547ec21de 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -205,6 +205,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0); + // provide partition_index for LDS tile window so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + // Q tile in LDS QKVDataType* q_lds_ptr = static_cast(smem_ptr); auto q_lds = make_tensor_view( @@ -216,7 +219,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}, - Policy::template MakeQRegTileDistribution()); + Policy::template MakeQRegTileDistribution(), + partition_index); // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); @@ -315,7 +319,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - store_tile(q_lds_write_window, q_dram_tile); + store_tile(q_lds_write_window, q_dram_tile, partition_index); clear_tile(o_acc); @@ -344,7 +348,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], k_tiles[i_k1]); + store_tile(k_lds_write_windows[number{}], + k_tiles[i_k1], + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -439,7 +445,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - v_tiles[number{}]); + v_tiles[number{}], + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 20ea006e58..b71cecbece 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -229,6 +229,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0); + // provide partition_index for LDS tile window so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + // Q tile in LDS QKVDataType* q_lds_ptr = static_cast(smem_ptr); auto q_lds = make_tensor_view( @@ -240,7 +243,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}, - Policy::template MakeQRegSingleRepMTileDistribution()); + Policy::template MakeQRegSingleRepMTileDistribution(), + partition_index); // K tile in LDS QKVDataType* k_lds_ptr = static_cast(smem_ptr); @@ -347,7 +351,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS { static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) { - store_tile(q_lds_write_window, q_dram_tiles[i_rep]); + store_tile(q_lds_write_window, q_dram_tiles[i_rep], partition_index); // no need to call __builtin_amdgcn_s_barrier() since the tile-slice written // by each wavefront is read by itself @@ -393,7 +397,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { store_tile(k_lds_write_windows[number{}], - k_tiles[number{}]); + k_tiles[number{}], + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -500,7 +505,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile); + store_tile( + v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -581,7 +587,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]); - store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile); + store_tile( + v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -609,7 +616,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS shuffle_tile(v_shuffled_tile, v_tiles[number{}]); store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], - v_shuffled_tile); + v_shuffled_tile, + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 94135d9e84..fce1809ce3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -222,6 +222,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0); + // provide partition_index for LDS tile window so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + // Q tile in LDS QKVDataType* q_lds_ptr = static_cast(smem_ptr); auto q_lds = make_tensor_view( @@ -329,7 +332,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); }(); - store_tile(q_lds_write_window, q_dram_tile); + store_tile(q_lds_write_window, q_dram_tile, partition_index); clear_tile(o_acc); @@ -362,7 +365,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { store_tile(k_lds_write_windows[number{}], - k_tiles[number{}]); + k_tiles[number{}], + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -464,7 +468,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_tiles[number<0>{}]); + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + v_tiles[number<0>{}], + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -550,7 +556,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_tiles[number<1>{}]); + store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], + v_tiles[number<1>{}], + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); @@ -577,7 +585,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0x00000001); store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], - v_tiles[number{}]); + v_tiles[number{}], + partition_index); __builtin_amdgcn_sched_barrier(0x00000001); };