Revert "remove unused codes."

This reverts commit 50763eefab.
This commit is contained in:
mtgu0705
2025-03-26 16:26:12 +08:00
parent c1938b9611
commit 83fd31dac2
5 changed files with 21 additions and 5 deletions

View File

@@ -45,7 +45,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(EXAMPLE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(example_gemm_xdl_fp8_pk_i4_v3 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
target_compile_options(example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
set(target 1)
endif()
endforeach()

View File

@@ -5,7 +5,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp"
using F8 = ck::f8_t;
using F8 = ck::f8_t;
using I4 = ck::pk_i4_t;
using F16 = ck::half_t;
using F32 = float;

View File

@@ -5,8 +5,6 @@ add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
# add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
# add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
list(APPEND gpu_list gfx942)
set(target 0)

View File

@@ -472,7 +472,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}, {BlockGemmPipelineVersion::v3, "v3"}};
{BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceMoeGEmm"

View File

@@ -81,9 +81,27 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
__device__ inline f8x4_t i4_to_f8x4(int q)
{
#if 0
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
int lo = amd_assembly_and_b32(q, LO);
int hi = amd_assembly_and_b32(q, HI);
float f32_0 = amd_assemble_cvt_f32_i4(lo);
float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16);
float f32_2 = amd_assemble_cvt_f32_i4(hi);
float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16);
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
#else
// [0, 1, 2, 3] encoded as FP8
static constexpr uint32_t POS_E4M3s_TABLE1 = 0x2C282000;
// [4, 5, 6, 7] encoded as FP8
static constexpr uint32_t POS_E4M3s_TABLE2 = 0x36343230;
// [-8, -7, -6, -5] encoded as FP8
static constexpr uint32_t NEG_E4M3s_TABLE1 = 0xB2B4B6B8;
// [-4, -3, -2, -1] encoded as FP8
static constexpr uint32_t NEG_E4M3s_TABLE2 = 0xA0A8ACB0;
uint32_t tmp_pos, tmp_neg, tmp_res, final_sel;
@@ -103,6 +121,7 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
return res.template AsType<f8x4_t>()[Number<0>{}];
#endif
}
__device__ inline f8x8_t i4_to_fp8x8(int q)