From 83fd31dac233bddd86efecff05212d290d407c91 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 26 Mar 2025 16:26:12 +0800 Subject: [PATCH] Revert "remove unused codes." This reverts commit 50763eefab263f80d40a34297401e3ee41010480. --- example/01_gemm/CMakeLists.txt | 1 - .../gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp | 2 +- .../65_gemm_multiply_multiply/CMakeLists.txt | 2 -- .../gpu/device/impl/device_moe_gemm.hpp | 2 +- .../element/unary_element_wise_operation.hpp | 19 +++++++++++++++++++ 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index fab970dc23..65f4b7e923 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -45,7 +45,6 @@ foreach(gpu IN LISTS GPU_TARGETS) set(EXAMPLE_COMPILE_OPTIONS) list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) target_compile_options(example_gemm_xdl_fp8_pk_i4_v3 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) - target_compile_options(example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) set(target 1) endif() endforeach() diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index 6302bcaa6c..50071dfe4e 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -5,7 +5,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp" -using F8 = ck::f8_t; +using F8 = ck::f8_t; using I4 = ck::pk_i4_t; using F16 = ck::half_t; using F32 = float; diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index a0010426aa..4d4957b6ba 100755 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -5,8 +5,6 @@ add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) -# add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) -# add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) list(APPEND gpu_list gfx942) set(target 0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 51cf001986..950fe0236d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -472,7 +472,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}, {BlockGemmPipelineVersion::v3, "v3"}}; + {BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}}; // clang-format off str << "DeviceMoeGEmm" diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index e5c1e7beb7..517be925d4 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -81,9 +81,27 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) __device__ inline f8x4_t i4_to_f8x4(int q) { +#if 0 + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + + int lo = amd_assembly_and_b32(q, LO); + int hi = amd_assembly_and_b32(q, HI); + + float f32_0 = amd_assemble_cvt_f32_i4(lo); + float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16); + float f32_2 = amd_assemble_cvt_f32_i4(hi); + float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16); + + return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3); +#else + // [0, 1, 2, 3] encoded as FP8 static constexpr uint32_t POS_E4M3s_TABLE1 = 0x2C282000; + // [4, 5, 6, 7] encoded as FP8 static constexpr uint32_t POS_E4M3s_TABLE2 = 0x36343230; + // [-8, -7, -6, -5] encoded as FP8 static constexpr uint32_t NEG_E4M3s_TABLE1 = 0xB2B4B6B8; + // [-4, -3, -2, -1] encoded as FP8 static constexpr uint32_t NEG_E4M3s_TABLE2 = 0xA0A8ACB0; uint32_t tmp_pos, tmp_neg, tmp_res, final_sel; @@ -103,6 +121,7 @@ __device__ inline f8x4_t i4_to_f8x4(int q) res.template AsType()(Number<0>{}) = bit_cast(tmp_res); return res.template AsType()[Number<0>{}]; +#endif } __device__ inline f8x8_t i4_to_fp8x8(int q)