mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
ckmoe: change cmake; use smaller shape for i4 (#2027)
* change cmake; use smaller shape for i4
* fix pki4 run
* fix typo
* fix runtime arch logic for moe_gemm2 example
---------
Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
[ROCm/composable_kernel commit: 36d50de50e]
This commit is contained in:
@@ -3,14 +3,14 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_mult
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
|
||||
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
|
||||
# add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
# add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
|
||||
@@ -191,14 +191,14 @@ int main(int argc, char* argv[])
|
||||
// experts = 8
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
ck::index_t N = 14336 * 2;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t N = 4096 * 2;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 64;
|
||||
ck::index_t tokens = 644;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
@@ -440,8 +440,8 @@ int main(int argc, char* argv[])
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
|
||||
ck::get_device_name() != "gfx950")
|
||||
if(!device_op.IsSupportedArgument(argument) ||
|
||||
!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
|
||||
@@ -407,8 +407,8 @@ int main(int argc, char* argv[])
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
|
||||
ck::get_device_name() != "gfx950")
|
||||
if(!device_op.IsSupportedArgument(argument) ||
|
||||
!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
|
||||
Reference in New Issue
Block a user