mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Updated the two f8xi4 files to multiply the output by 16
This commit is contained in:
@@ -22,7 +22,26 @@ using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
// C = C * 16. When the assembly instruction is used for the int4->fp8 conversion, the output must be multiplied by 16.
|
||||
// Element-wise functor installed as CElementOp: writes c1 into c, multiplying
// by 16 first when CK_USE_PK4_LAYOUT_SHUFFLE is enabled.
struct MultiplyConst
|
||||
{
|
||||
// Generic call operator: declaration only. No generic definition is visible
// in this hunk; only the CDataType specialization below provides a body.
template <typename C, typename C1>
|
||||
__host__ __device__ constexpr void operator()(C& c, const C1& c1) const;
|
||||
|
||||
// Specialization for CDataType input and output.
// NOTE(review): an explicit specialization written at class scope is not
// accepted by every compiler (see CWG 727) - confirm the supported toolchains.
template <>
|
||||
__host__ __device__ constexpr void operator()<CDataType>(CDataType& c,
|
||||
const CDataType& c1) const
|
||||
{
|
||||
// CK_USE_PK4_LAYOUT_SHUFFLE branch: scale by 16, then convert back to CDataType.
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
c = ck::type_convert<CDataType>(c1 * 16);
|
||||
// Otherwise the value is stored unscaled.
#else
|
||||
c = ck::type_convert<CDataType>(c1);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
using CElementOp = MultiplyConst;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
@@ -30,7 +49,7 @@ static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
// clang-format off
|
||||
#if 0
|
||||
#if 1
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
|
||||
ALayout, BLayout, CLayout,
|
||||
@@ -38,14 +57,14 @@ using DeviceGemmV2Instance =
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128,
|
||||
256, 16, 32,
|
||||
128, 16, 32,
|
||||
32, 32,
|
||||
4, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, 4,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
|
||||
|
||||
#else
|
||||
@@ -284,8 +303,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
float v_b = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
float v_b = i4 - 8;
|
||||
#endif
|
||||
|
||||
float v_b = i4_to_f32_gfx9(i4);
|
||||
b_k_n_f32(k, n) = v_b;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,26 @@ using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
// C = C * 16. When the assembly instruction is used for the int4->fp8 conversion, the output must be multiplied by 16.
|
||||
// Element-wise functor installed as CElementOp: writes c1 into c, multiplying
// by 16 first when CK_USE_PK4_LAYOUT_SHUFFLE is enabled.
struct MultiplyConst
|
||||
{
|
||||
// Generic call operator: declaration only. No generic definition is visible
// in this hunk; only the CDataType specialization below provides a body.
template <typename C, typename C1>
|
||||
__host__ __device__ constexpr void operator()(C& c, const C1& c1) const;
|
||||
|
||||
// Specialization for CDataType input and output.
// NOTE(review): an explicit specialization written at class scope is not
// accepted by every compiler (see CWG 727) - confirm the supported toolchains.
template <>
|
||||
__host__ __device__ constexpr void operator()<CDataType>(CDataType& c,
|
||||
const CDataType& c1) const
|
||||
{
|
||||
// CK_USE_PK4_LAYOUT_SHUFFLE branch: scale by 16, then convert back to CDataType.
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
c = ck::type_convert<CDataType>(c1 * 16);
|
||||
// Otherwise the value is stored unscaled.
#else
|
||||
c = ck::type_convert<CDataType>(c1);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
using CElementOp = MultiplyConst;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
@@ -40,7 +59,7 @@ using DeviceGemmV2Instance =
|
||||
128, 128,
|
||||
KPerBlock, 16, 32,
|
||||
32, 32,
|
||||
2, 2,
|
||||
4, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
@@ -264,7 +283,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
|
||||
float v_b = i4_to_f32_gfx9(i4);
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
float v_b = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
float v_b = i4 - 8;
|
||||
#endif
|
||||
b_k_n_f32(k, n) = v_b;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user