mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fix int4 moe 16x16x32 type perf issue
This commit is contained in:
@@ -12,6 +12,10 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1)
|
||||
target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -125,10 +125,10 @@ using CDEElementOp = MulABScaleExpertWeight;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 1;
|
||||
static constexpr ck::index_t MXDLPerWave = 8;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t CShuffleNLane = 32;
|
||||
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
|
||||
@@ -148,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
uint32_t fp8x4_0;
|
||||
uint32_t fp8x4_1;
|
||||
float tmp_0, tmp_1, tmp_2;
|
||||
|
||||
|
||||
asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
|
||||
@@ -83,8 +83,11 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
[v_dst_1] "+v"(fp8x4_1),
|
||||
[v_src] "+v"(i4x8)
|
||||
:);
|
||||
|
||||
return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
|
||||
|
||||
vector_type<f8_t, 8> out;
|
||||
out.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(fp8x4_0);
|
||||
out.template AsType<f8x4_t>()(Number<1>{}) = bit_cast<f8x4_t>(fp8x4_1);
|
||||
return out.template AsType<f8x8_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
|
||||
Reference in New Issue
Block a user