update case construction

This commit is contained in:
Feng Shijie
2025-08-11 07:56:14 +00:00
parent 8b85fa6cf2
commit 87aed564dc

View File

@@ -101,7 +101,7 @@ int run_moe_gemm_example_with_layouts(int argc,
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
ck_tile::index_t sorted_tile_num = 8;
ck_tile::index_t sorted_tile_num = num_tokens * topk / MPerBlock;
ck_tile::index_t valid_tile_num = sorted_tile_num;
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
@@ -191,16 +191,15 @@ int run_moe_gemm_example_with_layouts(int argc,
per_channel_scale.get_element_space_size_in_bytes());
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
eids[i] = min(eids[i], experts - 1);
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = i / (valid_tile_num / experts);
}
// int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
int token_per_tile = num_tokens * topk / valid_tile_num;
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
// int token_per_tile = num_tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)