From f55cac672dcee709b9e580e91c62df9e0d658fe8 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 5 Mar 2025 01:10:40 +0000 Subject: [PATCH] rename moe example --- example/65_gemm_multiply_multiply/CMakeLists.txt | 4 ++-- .../{moe_gemm1.cpp => moe_gemm1_xdl_fp8.cpp} | 16 +++++++++------- .../{moe_gemm2.cpp => moe_gemm2_xdl_fp8.cpp} | 14 ++++++++------ .../gpu/grid/gridwise_moe_gemm.hpp | 2 +- 4 files changed, 20 insertions(+), 16 deletions(-) rename example/65_gemm_multiply_multiply/{moe_gemm1.cpp => moe_gemm1_xdl_fp8.cpp} (97%) rename example/65_gemm_multiply_multiply/{moe_gemm2.cpp => moe_gemm2_xdl_fp8.cpp} (98%) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index a9e886d6db..62a8112a1a 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -3,5 +3,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_mult add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -add_example_executable(example_moe_gemm1 moe_gemm1.cpp) -add_example_executable(example_moe_gemm2 moe_gemm2.cpp) \ No newline at end of file +add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) +add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) \ No newline at end of file diff --git a/example/65_gemm_multiply_multiply/moe_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp similarity index 97% rename from example/65_gemm_multiply_multiply/moe_gemm1.cpp rename to example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 8263d68a71..6569d849ac 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -169,12 +169,12 @@ int main(int argc, char* argv[]) bool time_kernel = true; // GEMM shape - ck::index_t N = 4096 * 2; - ck::index_t K = 6144; + ck::index_t N = 4096; + ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 16; - ck::index_t valid_tile_num = 13; - ck::index_t tokens = 544; + ck::index_t sorted_tile_num = 8; + ck::index_t valid_tile_num = 8; + ck::index_t tokens = 128; ck::index_t topk = 2; // ck::index_t tokens = batch * topk; @@ -235,8 +235,10 @@ int main(int argc, char* argv[]) // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + int eids[] = {0,1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp similarity index 98% rename from example/65_gemm_multiply_multiply/moe_gemm2.cpp rename to example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index bdc62a5b87..309a95c50c 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -181,13 +181,13 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape ck::index_t N = 4096; - ck::index_t K = 14336; + ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 16; - ck::index_t valid_tile_num = 13; + ck::index_t sorted_tile_num = 6; + ck::index_t valid_tile_num = 6; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 512; + ck::index_t tokens = 128; ck::index_t topk = 2; if(argc == 1) @@ -232,8 +232,10 @@ int main(int argc, char* argv[]) Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); // max_token_id.mData[0] = valid_size; - max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + int eids[] = {0,1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 96b2e8e075..d0e06a6c53 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -16,7 +16,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" -#define DEBUG_LOG 1 +#define DEBUG_LOG 0 namespace ck {