From 900acdc2dbcaebefe6707013466f7050866d0688 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 28 Mar 2025 00:04:31 +0800 Subject: [PATCH] ckmoe: change cmake; use smaller shape for i4 (#2027) * change cmake; use smaller shape for i4 * fix pki4 run * fix typo * fix runtime arch logic for moe_gemm2 example --------- Co-authored-by: coderfeli Co-authored-by: illsilin [ROCm/composable_kernel commit: 36d50de50e30e92950070c3449b99d78143fb221] --- example/65_gemm_multiply_multiply/CMakeLists.txt | 4 ++-- .../65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp | 10 +++++----- .../65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 38b42fefc4..95fd8bace8 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -3,14 +3,14 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_mult add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -# add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) +add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) list(APPEND gpu_list gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - # add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) + add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) set(target 1) endif() diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 17f4cd8a3f..1102ce1054 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -191,14 +191,14 @@ int main(int argc, char* argv[]) // experts = 8 // per expert: // GEMM shape - ck::index_t N = 14336 * 2; - ck::index_t K = 4096; + ck::index_t N = 4096 * 2; + ck::index_t K = 6144; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 64; + ck::index_t tokens = 644; ck::index_t topk = 2; if(argc == 1) @@ -440,8 +440,8 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument) || + !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 8441862004..528503a2c4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -407,8 +407,8 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument) || + !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does "