mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
modified the int4->fp8 conversion using original permute.
This commit is contained in:
@@ -187,7 +187,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
}
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
@@ -204,7 +204,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[4];
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
@@ -212,16 +212,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[5];
|
||||
int lo = input[1];
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_preshuffled(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[2];
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_preshuffled(j + 4, i) = i4x2;
|
||||
@@ -229,7 +229,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[3];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_preshuffled(j + 6, i) = i4x2;
|
||||
@@ -287,7 +287,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
float v_b = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
float v_b = i4 - 8;
|
||||
|
||||
@@ -164,7 +164,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
}
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
@@ -179,9 +179,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->04152637
|
||||
// permute 01234567->02461357
|
||||
{
|
||||
int hi = input[4];
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
@@ -189,16 +189,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[5];
|
||||
int lo = input[1];
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[2];
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 4, i) = i4x2;
|
||||
@@ -206,7 +206,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[3];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 6, i) = i4x2;
|
||||
@@ -266,7 +266,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
float v_b = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
float v_b = i4 - 8;
|
||||
|
||||
@@ -353,7 +353,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
// vector pk_i4x4 permute
|
||||
for(int e = 0; e < experts; e++)
|
||||
{
|
||||
@@ -372,7 +372,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
// permute 01234567->04152637
|
||||
{
|
||||
int hi = input[4];
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
@@ -380,16 +380,16 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[5];
|
||||
int lo = input[1];
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b0_preshuffled(e, j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[2];
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b0_preshuffled(e, j + 4, i) = i4x2;
|
||||
@@ -397,7 +397,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[3];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b0_preshuffled(e, j + 6, i) = i4x2;
|
||||
@@ -505,7 +505,6 @@ int main(int argc, char* argv[])
|
||||
c_t_k_n(t, topk_id, n),
|
||||
d0_t_n(t, n),
|
||||
d1_e_n(e, n));
|
||||
e_t_n_host_result(t, topk_id, n) *= 16; // the result need to multiply by 16
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -336,7 +336,7 @@ int main(int argc, char* argv[])
|
||||
K,
|
||||
device_op.GetPreShuffleParameters());
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
// vector pk_i4x4 permute
|
||||
for(int e = 0; e < experts; e++)
|
||||
{
|
||||
@@ -355,7 +355,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[4];
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
@@ -363,16 +363,16 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[5];
|
||||
int lo = input[1];
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b0_preshuffled(e, j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[2];
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b0_preshuffled(e, j + 4, i) = i4x2;
|
||||
@@ -380,7 +380,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[3];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b0_preshuffled(e, j + 6, i) = i4x2;
|
||||
|
||||
@@ -173,10 +173,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
// operations
|
||||
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
|
||||
|
||||
// hip solution for shuffle pk_i4 values during conversion to optimize number of binary
|
||||
// operations
|
||||
#define CK_USE_PK4_LAYOUT_SHUFFLE_V2 1
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
|
||||
return res.template AsType<half4_t>()[Number<0>{}];
|
||||
}
|
||||
|
||||
__device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
__device__ inline uint32_t i4_to_f8x4(int q)
|
||||
{
|
||||
// register values [0, 1, 2, 3]
|
||||
static constexpr uint32_t reg0 = 0x4C484000;
|
||||
@@ -103,20 +103,35 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
#endif
|
||||
|
||||
vector_type<f8_t, 4> res;
|
||||
// vector_type<f8_t, 4> res;
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
|
||||
// res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
|
||||
|
||||
return res.template AsType<f8x4_t>()[Number<0>{}];
|
||||
// return res.template AsType<f8x4_t>()[Number<0>{}];
|
||||
return tmp_res;
|
||||
}
|
||||
|
||||
__device__ inline f8x8_t i4_to_fp8x8(int q)
|
||||
{
|
||||
return amd_assembly_i4_to_fp8x8(q);
|
||||
{
|
||||
#if 1
|
||||
vector_type<f8_t, 8> result;
|
||||
|
||||
uint32_t res_lo = i4_to_f8x4(bit_cast<int>(q));
|
||||
uint32_t res_hi = i4_to_f8x4(bit_cast<int>(q) >> 4);
|
||||
|
||||
result.template AsType<f8x4_t>()(Number<0>{}) =
|
||||
bit_cast<f8x4_t>(__builtin_amdgcn_perm(res_hi, res_lo, 0x06040200));
|
||||
result.template AsType<f8x4_t>()(Number<1>{}) =
|
||||
bit_cast<f8x4_t>(__builtin_amdgcn_perm(res_hi, res_lo, 0x07050301));
|
||||
|
||||
return result.template AsType<f8x8_t>()[Number<0>{}];
|
||||
#else
|
||||
return amd_assembly_i4_to_fp8x8(q);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline bhalf4_t i4_to_bhalf4(int q)
|
||||
@@ -184,14 +199,8 @@ struct PassThroughPack8
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
|
||||
{
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
vector_type<f8_t, 8> result;
|
||||
|
||||
result.template AsType<f8x4_t>()(Number<0>{}) = i4_to_f8x4(bit_cast<int>(x));
|
||||
result.template AsType<f8x4_t>()(Number<1>{}) = i4_to_f8x4(bit_cast<int>(x) >> 4);
|
||||
|
||||
y = result.template AsType<f8x8_t>()[Number<0>{}];
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
y= i4_to_fp8x8(bit_cast<int>(x));
|
||||
#else
|
||||
// Added pk_i4_t to f8x2_fnuz_t conversion
|
||||
vector_type<f8_t, 8> dst;
|
||||
|
||||
@@ -93,8 +93,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
v_a = i4_to_f32_gfx9(i4);
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_a = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
v_a = i4 - 8;
|
||||
#endif
|
||||
@@ -112,8 +112,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_b = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
v_b = i4 - 8;
|
||||
#endif
|
||||
|
||||
@@ -108,8 +108,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
v_a = i4_to_f32_gfx9(i4);
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_a = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
v_a = i4 - 8;
|
||||
#endif
|
||||
@@ -126,7 +126,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_b = i4_to_f32_gfx9(i4) * 16;
|
||||
#else
|
||||
v_b = i4 - 8;
|
||||
@@ -145,9 +145,6 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
D0DataType v_d1 = arg.d1_(e, n); // b
|
||||
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
|
||||
arg.c_t_n_(t, n) += v_c;
|
||||
// #if CK_USE_PK4_LAYOUT_SHUFFLE_V2
|
||||
// arg.c_t_n_(t, n) *= 16;
|
||||
// #endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user