diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 52b2b86574..089cbcafb1 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -829,23 +829,24 @@ struct MoeSortingKernel // smem_cumdup(0) = 0; } + constexpr int lane_group_sz = 1; + constexpr int lane_group_nm = block_size / lane_group_sz; + { - constexpr int lane_group_sz = 8; int lane_group_id = tid / lane_group_sz; int lane_group_os = tid % lane_group_sz; - constexpr int lane_group_nm = block_size / lane_group_sz; for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm) { index_t local_c[Problem::SubTokenTile]; index_t cnt = 0; - for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile) + for(int i = 0; i < sub_tokens; i += lane_group_sz * Problem::SubTokenTile) { #pragma unroll Problem::SubTokenTile for(int j = 0; j < Problem::SubTokenTile; j++) { - local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e); + local_c[j] = smem_tokens(i + j * lane_group_sz + lane_group_os, i_e); if constexpr(Problem::SubTokenOneShot) { local_c[j] = local_c[j] != 0 ? 1 : 0; @@ -855,7 +856,7 @@ struct MoeSortingKernel #pragma unroll Problem::SubTokenTile for(int j = 0; j < Problem::SubTokenTile; j++) { - cnt += wave_reduce(local_c[j], f_sum, number<8>{}); + cnt += wave_reduce(local_c[j], f_sum, number{}); } } if(lane_group_os == 0) @@ -1022,10 +1023,8 @@ struct MoeSortingKernel } { - constexpr int lane_group_sz = 8; int lane_group_id = tid / lane_group_sz; int lane_group_os = tid % lane_group_sz; - constexpr int lane_group_nm = block_size / lane_group_sz; for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm) { if constexpr(Problem::LocalExpertMasking)