mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
set lane_group_sz=1 for small token decode
This commit is contained in:
@@ -829,23 +829,24 @@ struct MoeSortingKernel
|
||||
// smem_cumdup(0) = 0;
|
||||
}
|
||||
|
||||
constexpr int lane_group_sz = 1;
|
||||
constexpr int lane_group_nm = block_size / lane_group_sz;
|
||||
|
||||
{
|
||||
constexpr int lane_group_sz = 8;
|
||||
int lane_group_id = tid / lane_group_sz;
|
||||
int lane_group_os = tid % lane_group_sz;
|
||||
constexpr int lane_group_nm = block_size / lane_group_sz;
|
||||
|
||||
for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
|
||||
{
|
||||
index_t local_c[Problem::SubTokenTile];
|
||||
index_t cnt = 0;
|
||||
|
||||
for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
|
||||
for(int i = 0; i < sub_tokens; i += lane_group_sz * Problem::SubTokenTile)
|
||||
{
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(int j = 0; j < Problem::SubTokenTile; j++)
|
||||
{
|
||||
local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
|
||||
local_c[j] = smem_tokens(i + j * lane_group_sz + lane_group_os, i_e);
|
||||
if constexpr(Problem::SubTokenOneShot)
|
||||
{
|
||||
local_c[j] = local_c[j] != 0 ? 1 : 0;
|
||||
@@ -855,7 +856,7 @@ struct MoeSortingKernel
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(int j = 0; j < Problem::SubTokenTile; j++)
|
||||
{
|
||||
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[j], f_sum, number<lane_group_sz>{});
|
||||
}
|
||||
}
|
||||
if(lane_group_os == 0)
|
||||
@@ -1022,10 +1023,8 @@ struct MoeSortingKernel
|
||||
}
|
||||
|
||||
{
|
||||
constexpr int lane_group_sz = 8;
|
||||
int lane_group_id = tid / lane_group_sz;
|
||||
int lane_group_os = tid % lane_group_sz;
|
||||
constexpr int lane_group_nm = block_size / lane_group_sz;
|
||||
for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
|
||||
{
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
|
||||
Reference in New Issue
Block a user