fix int4 moe 16x16x32 type perf issue

This commit is contained in:
mtgu0705
2025-03-19 07:15:57 +00:00
parent 1228d112ef
commit 045e7d65b3
3 changed files with 14 additions and 7 deletions

View File

@@ -12,6 +12,10 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
set(EXAMPLE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1)
target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
set(target 1)
endif()
endforeach()

View File

@@ -125,10 +125,10 @@ using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t NXDLPerWave = 1;
static constexpr ck::index_t MXDLPerWave = 8;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t CShuffleNLane = 32;
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
@@ -148,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
MXDLPerWave, NXDLPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
// clang-format on

View File

@@ -62,7 +62,7 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
uint32_t fp8x4_0;
uint32_t fp8x4_1;
float tmp_0, tmp_1, tmp_2;
asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
@@ -83,8 +83,11 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
[v_dst_1] "+v"(fp8x4_1),
[v_src] "+v"(i4x8)
:);
return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
vector_type<f8_t, 8> out;
out.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(fp8x4_0);
out.template AsType<f8x4_t>()(Number<1>{}) = bit_cast<f8x4_t>(fp8x4_1);
return out.template AsType<f8x8_t>()[Number<0>{}];
}
// c0 += inner_product(a, b0)