diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index d84d7ae82e..ab1713e507 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -187,7 +187,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } } -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 +#if CK_USE_PK4_LAYOUT_SHUFFLE // vector pk_i4x4 permute for(int i = 0; i < N; i++) { @@ -204,7 +204,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) // permute 01234567->20643175 { - int hi = input[4]; + int hi = input[2]; int lo = input[0]; int i4x2 = (hi << 4) | lo; @@ -212,16 +212,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } { - int hi = input[5]; - int lo = input[1]; + int hi = input[6]; + int lo = input[4]; int i4x2 = (hi << 4) | lo; b_k_n_preshuffled(j + 2, i) = i4x2; } { - int hi = input[6]; - int lo = input[2]; + int hi = input[3]; + int lo = input[1]; int i4x2 = (hi << 4) | lo; b_k_n_preshuffled(j + 4, i) = i4x2; @@ -229,7 +229,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { int hi = input[7]; - int lo = input[3]; + int lo = input[5]; int i4x2 = (hi << 4) | lo; b_k_n_preshuffled(j + 6, i) = i4x2; @@ -287,7 +287,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) i4 = (i4x2.data >> 0) & 0xf; else i4 = (i4x2.data >> 4) & 0xf; -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 +#if CK_USE_PK4_LAYOUT_SHUFFLE float v_b = i4_to_f32_gfx9(i4) * 16; #else float v_b = i4 - 8; diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp index 950bbd6c34..dc4ad64b72 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -164,7 +164,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } } -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 +#if CK_USE_PK4_LAYOUT_SHUFFLE // vector pk_i4x4 permute for(int i = 0; i < N; i++) { @@ -179,9 +179,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) input[k * 2 + 1] = (i4x2 >> 0) & 0xf; } - // permute 01234567->04152637 + // permute 01234567->02461357 { - int hi = input[4]; + int hi = input[2]; int lo = input[0]; int i4x2 = (hi << 4) | lo; @@ -189,16 +189,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } { - int hi = input[5]; - int lo = input[1]; + int hi = input[6]; + int lo = input[4]; int i4x2 = (hi << 4) | lo; b_k_n_permute(j + 2, i) = i4x2; } { - int hi = input[6]; - int lo = input[2]; + int hi = input[3]; + int lo = input[1]; int i4x2 = (hi << 4) | lo; b_k_n_permute(j + 4, i) = i4x2; @@ -206,7 +206,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { int hi = input[7]; - int lo = input[3]; + int lo = input[5]; int i4x2 = (hi << 4) | lo; b_k_n_permute(j + 6, i) = i4x2; @@ -266,7 +266,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) else i4 = (i4x2.data >> 4) & 0xf; -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 +#if CK_USE_PK4_LAYOUT_SHUFFLE float v_b = i4_to_f32_gfx9(i4) * 16; #else float v_b = i4 - 8; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 0f68485321..74d38f3a4f 100755 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -353,7 +353,7 @@ int main(int argc, char* argv[]) } #endif -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 +#if CK_USE_PK4_LAYOUT_SHUFFLE // vector pk_i4x4 permute for(int e = 0; e < experts; e++) { @@ -372,7 +372,7 @@ int main(int argc, char* argv[]) // permute 01234567->04152637 { - int hi = input[4]; + int hi = input[2]; int lo = input[0]; int i4x2 = (hi << 4) | lo; @@ -380,16 +380,16 @@ int main(int argc, char* argv[]) } { - int hi = input[5]; - int lo = input[1]; + int hi = input[6]; + int lo = input[4]; int i4x2 = (hi << 4) | lo; b0_preshuffled(e, j + 2, i) = i4x2; } { - int hi = input[6]; - int lo = input[2]; + int hi = input[3]; + int lo = input[1]; int i4x2 = (hi << 4) | lo; b0_preshuffled(e, j + 4, i) = i4x2; @@ -397,7 +397,7 @@ int main(int argc, char* argv[]) { int hi = input[7]; - int lo = input[3]; + int lo = input[5]; int i4x2 = (hi << 4) | lo; b0_preshuffled(e, j + 6, i) = i4x2; @@ -505,7 +505,6 @@ int main(int argc, char* argv[]) c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n)); - e_t_n_host_result(t, topk_id, n) *= 16; // the result need to multiply by 16 } } diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 6fe1b46ef7..ce8088064d 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -336,7 +336,7 @@ int main(int argc, char* argv[]) K, device_op.GetPreShuffleParameters()); -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 +#if CK_USE_PK4_LAYOUT_SHUFFLE // vector pk_i4x4 permute for(int e = 0; e < experts; e++) { @@ -355,7 +355,7 @@ int main(int argc, char* argv[]) // permute 01234567->20643175 { - int hi = input[4]; + int hi = input[2]; int lo = input[0]; int i4x2 = (hi << 4) | lo; @@ -363,16 +363,16 @@ int main(int argc, char* argv[]) } { - int hi = input[5]; - int lo = input[1]; + int hi = input[6]; + int lo = input[4]; int i4x2 = (hi << 4) | lo; b0_preshuffled(e, j + 2, i) = i4x2; } { - int hi = input[6]; - int lo = input[2]; + int hi = input[3]; + int lo = input[1]; int i4x2 = (hi << 4) | lo; b0_preshuffled(e, j + 4, i) = i4x2; @@ -380,7 +380,7 @@ int main(int argc, char* argv[]) { int hi = input[7]; - int lo = input[3]; + int lo = input[5]; int i4x2 = (hi << 4) | lo; b0_preshuffled(e, j + 6, i) = i4x2; diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 971946ca32..6f510e735e 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -173,10 +173,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // operations #define CK_USE_PK4_LAYOUT_SHUFFLE 1 -// hip solution for shuffle pk_i4 values during conversion to optimize number of binary -// operations -#define CK_USE_PK4_LAYOUT_SHUFFLE_V2 1 - // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index b26c11dc59..1f55eb19f5 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -79,7 +79,7 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) return res.template AsType()[Number<0>{}]; } -__device__ inline f8x4_t i4_to_f8x4(int q) +__device__ inline uint32_t i4_to_f8x4(int q) { // register values [0, 1, 2, 3] static constexpr uint32_t reg0 = 0x4C484000; @@ -103,20 +103,35 @@ __device__ inline f8x4_t i4_to_f8x4(int q) final_sel = (sign & 0x04040404) | 0x03020100; #endif - vector_type res; + // vector_type res; tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); tmp_res = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); - res.template AsType()(Number<0>{}) = bit_cast(tmp_res); + // res.template AsType()(Number<0>{}) = bit_cast(tmp_res); - return res.template AsType()[Number<0>{}]; + // return res.template AsType()[Number<0>{}]; + return tmp_res; } __device__ inline f8x8_t i4_to_fp8x8(int q) -{ - return amd_assembly_i4_to_fp8x8(q); +{ +#if 1 + vector_type result; + + uint32_t res_lo = i4_to_f8x4(bit_cast(q)); + uint32_t res_hi = i4_to_f8x4(bit_cast(q) >> 4); + + result.template AsType()(Number<0>{}) = + bit_cast(__builtin_amdgcn_perm(res_hi, res_lo, 0x06040200)); + result.template AsType()(Number<1>{}) = + bit_cast(__builtin_amdgcn_perm(res_hi, res_lo, 0x07050301)); + + return result.template AsType()[Number<0>{}]; +#else + return amd_assembly_i4_to_fp8x8(q); +#endif } __device__ inline bhalf4_t i4_to_bhalf4(int q) @@ -184,14 +199,8 @@ struct PassThroughPack8 __host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const { -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 - vector_type result; - - result.template AsType()(Number<0>{}) = i4_to_f8x4(bit_cast(x)); - result.template AsType()(Number<1>{}) = i4_to_f8x4(bit_cast(x) >> 4); - - y = result.template AsType()[Number<0>{}]; - +#if CK_USE_PK4_LAYOUT_SHUFFLE + y= i4_to_fp8x8(bit_cast(x)); #else // Added pk_i4_t to f8x2_fnuz_t conversion vector_type dst; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index bb3b98771c..bde19379bd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -93,8 +93,8 @@ struct ReferenceMoeGemm : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 - v_a = i4_to_f32_gfx9(i4); +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_a = i4_to_f32_gfx9(i4) * 16; #else v_a = i4 - 8; #endif @@ -112,8 +112,8 @@ struct ReferenceMoeGemm : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 - v_b = i4_to_f32_gfx9(i4); +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_b = i4_to_f32_gfx9(i4) * 16; #else v_b = i4 - 8; #endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 04cfde88c2..3b928010bb 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -108,8 +108,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 - v_a = i4_to_f32_gfx9(i4); +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_a = i4_to_f32_gfx9(i4) * 16; #else v_a = i4 - 8; #endif @@ -126,7 +126,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; -#if CK_USE_PK4_LAYOUT_SHUFFLE_V2 +#if CK_USE_PK4_LAYOUT_SHUFFLE v_b = i4_to_f32_gfx9(i4) * 16; #else v_b = i4 - 8; @@ -145,9 +145,6 @@ struct ReferenceMoeGemm2 : public device::BaseOperator D0DataType v_d1 = arg.d1_(e, n); // b arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); arg.c_t_n_(t, n) += v_c; -// #if CK_USE_PK4_LAYOUT_SHUFFLE_V2 -// arg.c_t_n_(t, n) *= 16; -// #endif } };