From 045e7d65b32975dde8fcb7c2cef801dbc8f0bd84 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 19 Mar 2025 07:15:57 +0000 Subject: [PATCH] fix int4 moe 16x16x32 type perf issue --- example/65_gemm_multiply_multiply/CMakeLists.txt | 4 ++++ .../65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp | 8 ++++---- include/ck/utility/amd_inline_asm.hpp | 9 ++++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 95fd8bace8..04c068a36a 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -12,6 +12,10 @@ foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) + set(EXAMPLE_COMPILE_OPTIONS) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) set(target 1) endif() endforeach() diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index c80b01d8c5..8d6da91b60 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -125,10 +125,10 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; @@ -148,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MXDLPerWave, NXDLPerWave, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; // clang-format on diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index de59f200f0..1eaea17935 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -62,7 +62,7 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a) uint32_t fp8x4_0; uint32_t fp8x4_1; float tmp_0, tmp_1, tmp_2; - + asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n" "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n" @@ -83,8 +83,11 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a) [v_dst_1] "+v"(fp8x4_1), [v_src] "+v"(i4x8) :); - - return bit_cast(((static_cast(fp8x4_1) << 32) | fp8x4_0)); + + vector_type out; + out.template AsType()(Number<0>{}) = bit_cast(fp8x4_0); + out.template AsType()(Number<1>{}) = bit_cast(fp8x4_1); + return out.template AsType()[Number<0>{}]; } // c0 += inner_product(a, b0)