rename moe example

This commit is contained in:
coderfeli
2025-03-05 01:10:40 +00:00
parent 7cca528aa7
commit f55cac672d
4 changed files with 20 additions and 16 deletions

View File

@@ -3,5 +3,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_mult
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)

View File

@@ -169,12 +169,12 @@ int main(int argc, char* argv[])
bool time_kernel = true;
// GEMM shape
ck::index_t N = 4096 * 2;
ck::index_t K = 6144;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t tokens = 544;
ck::index_t sorted_tile_num = 8;
ck::index_t valid_tile_num = 8;
ck::index_t tokens = 128;
ck::index_t topk = 2;
// ck::index_t tokens = batch * topk;
@@ -235,8 +235,10 @@ int main(int argc, char* argv[])
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0,1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];

View File

@@ -181,13 +181,13 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 14336;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 6;
ck::index_t valid_tile_num = 6;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 512;
ck::index_t tokens = 128;
ck::index_t topk = 2;
if(argc == 1)
@@ -232,8 +232,10 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
// max_token_id.mData[0] = valid_size;
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0,1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];

View File

@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
#define DEBUG_LOG 1
#define DEBUG_LOG 0
namespace ck {