modified the int4->fp8 conversion using original permute.

This commit is contained in:
mtgu0705
2025-04-09 11:32:00 +08:00
parent ce06c53bce
commit 4efdfc79f8
8 changed files with 61 additions and 60 deletions

View File

@@ -187,7 +187,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
#if CK_USE_PK4_LAYOUT_SHUFFLE
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
@@ -204,7 +204,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
// permute 01234567->20643175
{
int hi = input[4];
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
@@ -212,16 +212,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
{
int hi = input[5];
int lo = input[1];
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_preshuffled(j + 2, i) = i4x2;
}
{
int hi = input[6];
int lo = input[2];
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_preshuffled(j + 4, i) = i4x2;
@@ -229,7 +229,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
int hi = input[7];
int lo = input[3];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_preshuffled(j + 6, i) = i4x2;
@@ -287,7 +287,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
#if CK_USE_PK4_LAYOUT_SHUFFLE
float v_b = i4_to_f32_gfx9(i4) * 16;
#else
float v_b = i4 - 8;

View File

@@ -164,7 +164,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
#if CK_USE_PK4_LAYOUT_SHUFFLE
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
@@ -179,9 +179,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->04152637
// permute 01234567->02461357
{
int hi = input[4];
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
@@ -189,16 +189,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
{
int hi = input[5];
int lo = input[1];
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[6];
int lo = input[2];
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
@@ -206,7 +206,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
int hi = input[7];
int lo = input[3];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
@@ -266,7 +266,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
else
i4 = (i4x2.data >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
#if CK_USE_PK4_LAYOUT_SHUFFLE
float v_b = i4_to_f32_gfx9(i4) * 16;
#else
float v_b = i4 - 8;

View File

@@ -353,7 +353,7 @@ int main(int argc, char* argv[])
}
#endif
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
#if CK_USE_PK4_LAYOUT_SHUFFLE
// vector pk_i4x4 permute
for(int e = 0; e < experts; e++)
{
@@ -372,7 +372,7 @@ int main(int argc, char* argv[])
// permute 01234567->04152637
{
int hi = input[4];
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
@@ -380,16 +380,16 @@ int main(int argc, char* argv[])
}
{
int hi = input[5];
int lo = input[1];
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 2, i) = i4x2;
}
{
int hi = input[6];
int lo = input[2];
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 4, i) = i4x2;
@@ -397,7 +397,7 @@ int main(int argc, char* argv[])
{
int hi = input[7];
int lo = input[3];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 6, i) = i4x2;
@@ -505,7 +505,6 @@ int main(int argc, char* argv[])
c_t_k_n(t, topk_id, n),
d0_t_n(t, n),
d1_e_n(e, n));
e_t_n_host_result(t, topk_id, n) *= 16; // the result need to multiply by 16
}
}

View File

@@ -336,7 +336,7 @@ int main(int argc, char* argv[])
K,
device_op.GetPreShuffleParameters());
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
#if CK_USE_PK4_LAYOUT_SHUFFLE
// vector pk_i4x4 permute
for(int e = 0; e < experts; e++)
{
@@ -355,7 +355,7 @@ int main(int argc, char* argv[])
// permute 01234567->20643175
{
int hi = input[4];
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
@@ -363,16 +363,16 @@ int main(int argc, char* argv[])
}
{
int hi = input[5];
int lo = input[1];
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 2, i) = i4x2;
}
{
int hi = input[6];
int lo = input[2];
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 4, i) = i4x2;
@@ -380,7 +380,7 @@ int main(int argc, char* argv[])
{
int hi = input[7];
int lo = input[3];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b0_preshuffled(e, j + 6, i) = i4x2;

View File

@@ -173,10 +173,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// hip solution for shuffle pk_i4 values during conversion to optimize number of binary
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE_V2 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1

View File

@@ -79,7 +79,7 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
return res.template AsType<half4_t>()[Number<0>{}];
}
__device__ inline f8x4_t i4_to_f8x4(int q)
__device__ inline uint32_t i4_to_f8x4(int q)
{
// register values [0, 1, 2, 3]
static constexpr uint32_t reg0 = 0x4C484000;
@@ -103,20 +103,35 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
final_sel = (sign & 0x04040404) | 0x03020100;
#endif
vector_type<f8_t, 4> res;
// vector_type<f8_t, 4> res;
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
tmp_res = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
// res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
return res.template AsType<f8x4_t>()[Number<0>{}];
// return res.template AsType<f8x4_t>()[Number<0>{}];
return tmp_res;
}
__device__ inline f8x8_t i4_to_fp8x8(int q)
{
return amd_assembly_i4_to_fp8x8(q);
{
#if 1
vector_type<f8_t, 8> result;
uint32_t res_lo = i4_to_f8x4(bit_cast<int>(q));
uint32_t res_hi = i4_to_f8x4(bit_cast<int>(q) >> 4);
result.template AsType<f8x4_t>()(Number<0>{}) =
bit_cast<f8x4_t>(__builtin_amdgcn_perm(res_hi, res_lo, 0x06040200));
result.template AsType<f8x4_t>()(Number<1>{}) =
bit_cast<f8x4_t>(__builtin_amdgcn_perm(res_hi, res_lo, 0x07050301));
return result.template AsType<f8x8_t>()[Number<0>{}];
#else
return amd_assembly_i4_to_fp8x8(q);
#endif
}
__device__ inline bhalf4_t i4_to_bhalf4(int q)
@@ -184,14 +199,8 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
vector_type<f8_t, 8> result;
result.template AsType<f8x4_t>()(Number<0>{}) = i4_to_f8x4(bit_cast<int>(x));
result.template AsType<f8x4_t>()(Number<1>{}) = i4_to_f8x4(bit_cast<int>(x) >> 4);
y = result.template AsType<f8x8_t>()[Number<0>{}];
#if CK_USE_PK4_LAYOUT_SHUFFLE
y= i4_to_fp8x8(bit_cast<int>(x));
#else
// Added pk_i4_t to f8x2_fnuz_t conversion
vector_type<f8_t, 8> dst;

View File

@@ -93,8 +93,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
v_a = i4_to_f32_gfx9(i4);
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_a = i4_to_f32_gfx9(i4) * 16;
#else
v_a = i4 - 8;
#endif
@@ -112,8 +112,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
v_b = i4_to_f32_gfx9(i4);
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_b = i4_to_f32_gfx9(i4) * 16;
#else
v_b = i4 - 8;
#endif

View File

@@ -108,8 +108,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
v_a = i4_to_f32_gfx9(i4);
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_a = i4_to_f32_gfx9(i4) * 16;
#else
v_a = i4 - 8;
#endif
@@ -126,7 +126,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_b = i4_to_f32_gfx9(i4) * 16;
#else
v_b = i4 - 8;
@@ -145,9 +145,6 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
D0DataType v_d1 = arg.d1_(e, n); // b
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
arg.c_t_n_(t, n) += v_c;
// #if CK_USE_PK4_LAYOUT_SHUFFLE_V2
// arg.c_t_n_(t, n) *= 16;
// #endif
}
};