Updated the two f8xi4 files to scale the output by a factor of 16

This commit is contained in:
mtgu0705
2025-03-14 13:37:15 +08:00
parent 8dfb6e82e4
commit bedaaf2747
2 changed files with 54 additions and 8 deletions

View File

@@ -22,7 +22,26 @@ using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
// MultiplyConst: functor whose call operator writes `c` from `c1`.
// C = C * 16: when the assembly instruction is used for the int4->fp8 conversion,
// the result needs to be multiplied by 16.
struct MultiplyConst
{
// Generic call operator: declared only; no definition is provided in this struct.
template <typename C, typename C1>
__host__ __device__ constexpr void operator()(C& c, const C1& c1) const;
// Explicit specialization for CDataType operands.
// NOTE(review): an explicit specialization at class scope is not accepted by every
// compiler -- confirm the toolchain used for this file supports it.
template <>
__host__ __device__ constexpr void operator()<CDataType>(CDataType& c,
const CDataType& c1) const
{
// With CK_USE_PK4_LAYOUT_SHUFFLE enabled, store c1 scaled by 16;
// otherwise store c1 converted to CDataType without scaling.
#if CK_USE_PK4_LAYOUT_SHUFFLE
c = ck::type_convert<CDataType>(c1 * 16);
#else
c = ck::type_convert<CDataType>(c1);
#endif
}
};
using CElementOp = MultiplyConst;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -30,7 +49,7 @@ static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
// clang-format off
#if 0
#if 1
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout,
@@ -38,14 +57,14 @@ using DeviceGemmV2Instance =
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
128, 128,
256, 16, 32,
128, 16, 32,
32, 32,
4, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 4,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
#else
@@ -284,8 +303,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE
float v_b = i4_to_f32_gfx9(i4) * 16;
#else
float v_b = i4 - 8;
#endif
float v_b = i4_to_f32_gfx9(i4);
b_k_n_f32(k, n) = v_b;
}
}

View File

@@ -22,7 +22,26 @@ using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
// MultiplyConst: functor whose call operator writes `c` from `c1`.
// C = C * 16: when the assembly instruction is used for the int4->fp8 conversion,
// the result needs to be multiplied by 16.
struct MultiplyConst
{
// Generic call operator: declared only; no definition is provided in this struct.
template <typename C, typename C1>
__host__ __device__ constexpr void operator()(C& c, const C1& c1) const;
// Explicit specialization for CDataType operands.
// NOTE(review): an explicit specialization at class scope is not accepted by every
// compiler -- confirm the toolchain used for this file supports it.
template <>
__host__ __device__ constexpr void operator()<CDataType>(CDataType& c,
const CDataType& c1) const
{
// With CK_USE_PK4_LAYOUT_SHUFFLE enabled, store c1 scaled by 16;
// otherwise store c1 converted to CDataType without scaling.
#if CK_USE_PK4_LAYOUT_SHUFFLE
c = ck::type_convert<CDataType>(c1 * 16);
#else
c = ck::type_convert<CDataType>(c1);
#endif
}
};
using CElementOp = MultiplyConst;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -40,7 +59,7 @@ using DeviceGemmV2Instance =
128, 128,
KPerBlock, 16, 32,
32, 32,
2, 2,
4, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
@@ -264,7 +283,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
else
i4 = (i4x2.data >> 4) & 0xf;
float v_b = i4_to_f32_gfx9(i4);
#if CK_USE_PK4_LAYOUT_SHUFFLE
float v_b = i4_to_f32_gfx9(i4) * 16;
#else
float v_b = i4 - 8;
#endif
b_k_n_f32(k, n) = v_b;
}
}