set lane_group_sz=1 for small token decode

This commit is contained in:
Robin Elbers
2026-04-20 11:13:01 -04:00
parent 574c1c121a
commit b05adebf38

View File

@@ -829,23 +829,24 @@ struct MoeSortingKernel
// smem_cumdup(0) = 0;
}
constexpr int lane_group_sz = 1;
constexpr int lane_group_nm = block_size / lane_group_sz;
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
{
index_t local_c[Problem::SubTokenTile];
index_t cnt = 0;
for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
for(int i = 0; i < sub_tokens; i += lane_group_sz * Problem::SubTokenTile)
{
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
local_c[j] = smem_tokens(i + j * lane_group_sz + lane_group_os, i_e);
if constexpr(Problem::SubTokenOneShot)
{
local_c[j] = local_c[j] != 0 ? 1 : 0;
@@ -855,7 +856,7 @@ struct MoeSortingKernel
#pragma unroll Problem::SubTokenTile
for(int j = 0; j < Problem::SubTokenTile; j++)
{
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
cnt += wave_reduce(local_c[j], f_sum, number<lane_group_sz>{});
}
}
if(lane_group_os == 0)
@@ -1022,10 +1023,8 @@ struct MoeSortingKernel
}
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
int lane_group_os = tid % lane_group_sz;
constexpr int lane_group_nm = block_size / lane_group_sz;
for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
{
if constexpr(Problem::LocalExpertMasking)