From bedaaf27477a3052ec0484f0bb5bf5e90b04024f Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Fri, 14 Mar 2025 13:37:15 +0800 Subject: [PATCH] updated the 2 f8xi4 files output to multiply of 16 --- .../gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp | 33 ++++++++++++++++--- example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp | 29 ++++++++++++++-- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index f5c7013698..43512ff4b3 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -22,7 +22,26 @@ using CLayout = Row; using AElementOp = PassThrough; using BElementOp = PassThrough; -using CElementOp = PassThrough; + +// C = C * 16 Using the assembly instruction for int4->fp8 conversion, need to mulitply 16 +struct MultiplyConst +{ + template + __host__ __device__ constexpr void operator()(C& c, const C1& c1) const; + + template <> + __host__ __device__ constexpr void operator()(CDataType& c, + const CDataType& c1) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + c = ck::type_convert(c1 * 16); +#else + c = ck::type_convert(c1); +#endif + } +}; + +using CElementOp = MultiplyConst; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; @@ -30,7 +49,7 @@ static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; // clang-format off -#if 0 +#if 1 using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle< ALayout, BLayout, CLayout, @@ -38,14 +57,14 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, - 256, 16, 32, + 128, 16, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, 4, + 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>; #else @@ -284,8 +303,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) i4 = (i4x2.data >> 0) & 0xf; else i4 = (i4x2.data >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE + float v_b = i4_to_f32_gfx9(i4) * 16; +#else + float v_b = i4 - 8; +#endif - float v_b = i4_to_f32_gfx9(i4); b_k_n_f32(k, n) = v_b; } } diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp index a8101587e8..b52f9d0114 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -22,7 +22,26 @@ using CLayout = Row; using AElementOp = PassThrough; using BElementOp = PassThrough; -using CElementOp = PassThrough; + +// C = C * 16 Using the assembly instruction for int4->fp8 conversion, need to mulitply 16 +struct MultiplyConst +{ + template + __host__ __device__ constexpr void operator()(C& c, const C1& c1) const; + + template <> + __host__ __device__ constexpr void operator()(CDataType& c, + const CDataType& c1) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + c = ck::type_convert(c1 * 16); +#else + c = ck::type_convert(c1); +#endif + } +}; + +using CElementOp = MultiplyConst; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; @@ -40,7 +59,7 @@ using DeviceGemmV2Instance = 128, 128, KPerBlock, 16, 32, 32, 32, - 2, 2, + 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -264,7 +283,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) else i4 = (i4x2.data >> 4) & 0xf; - float v_b = i4_to_f32_gfx9(i4); +#if CK_USE_PK4_LAYOUT_SHUFFLE + float v_b = i4_to_f32_gfx9(i4) * 16; +#else + float v_b = i4 - 8; +#endif b_k_n_f32(k, n) = v_b; } }