From fc26d568ae5921ab1dae29cfd784db2dfbc0ad8f Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 26 Mar 2025 07:09:32 +0000 Subject: [PATCH] updated int4->fp8 two methods. --- .../element/unary_element_wise_operation.hpp | 74 ++++++++++++++++--- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index d53d54bff1..517be925d4 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -96,23 +96,26 @@ __device__ inline f8x4_t i4_to_f8x4(int q) return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3); #else // [0, 1, 2, 3] encoded as FP8 - static constexpr uint32_t POS_E4M3s_REG1 = 0x2C282000; + static constexpr uint32_t POS_E4M3s_TABLE1 = 0x2C282000; // [4, 5, 6, 7] encoded as FP8 - static constexpr uint32_t POS_E4M3s_REG2 = 0x36343230; + static constexpr uint32_t POS_E4M3s_TABLE2 = 0x36343230; // [-8, -7, -6, -5] encoded as FP8 - static constexpr uint32_t NEG_E4M3s_REG1 = 0xB2B4B6B8; + static constexpr uint32_t NEG_E4M3s_TABLE1 = 0xB2B4B6B8; // [-4, -3, -2, -1] encoded as FP8 - static constexpr uint32_t NEG_E4M3s_REG2 = 0xA0A8ACB0; + static constexpr uint32_t NEG_E4M3s_TABLE2 = 0xA0A8ACB0; - uint32_t tmp_pos, tmp_neg, tmp_res; + uint32_t tmp_pos, tmp_neg, tmp_res, final_sel; - uint32_t sign= q & 0x08080808; uint32_t dict_sel = q & 0x07070707; - uint32_t final_sel = 0x03020100 | (sign >> 1); + uint32_t sign = q >> 1; + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(final_sel) + : "v"(sign), "v"(0x04040404), "v"(0x03020100)); + vector_type res; - tmp_pos = __builtin_amdgcn_perm(POS_E4M3s_REG2, POS_E4M3s_REG1, dict_sel); - tmp_neg = __builtin_amdgcn_perm(NEG_E4M3s_REG2, NEG_E4M3s_REG1, dict_sel); + tmp_pos = __builtin_amdgcn_perm(POS_E4M3s_TABLE2, POS_E4M3s_TABLE1, dict_sel); + tmp_neg = __builtin_amdgcn_perm(NEG_E4M3s_TABLE2, NEG_E4M3s_TABLE1, dict_sel); tmp_res = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); res.template AsType()(Number<0>{}) = bit_cast(tmp_res); @@ -121,7 +124,58 @@ __device__ inline f8x4_t i4_to_f8x4(int q) #endif } -__device__ inline f8x8_t i4_to_fp8x8(int q) { return amd_assembly_i4_to_fp8x8(q); } +__device__ inline f8x8_t i4_to_fp8x8(int q) +{ +#if 1 + uint32_t i4x8 = static_cast(q); + uint32_t fp8x4_0 = 0; + uint32_t fp8x4_1 = 0; + float tmp_0, tmp_1; + uint32_t i4x8_hi = i4x8 >> 4; + + // // 0, 1, 2, 3 + // asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8)); + // asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1) : "v"(i4x8)); + // fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_0, false); + + // asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8)); + // asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8)); + // fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_1, false); + + // // 4, 5, 6, 7 + // asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8_hi)); + // asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1) : "v"(i4x8_hi)); + // fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_0, true); + + // asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8_hi)); + // asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8_hi)); + // fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_1, true); + + asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8)); + asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1): "v"(i4x8)); + asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2" : "=v"(fp8x4_0) : "v"(tmp_0), "v"(tmp_1)); + + asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8)); + asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8)); + asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2" : "=v"(fp8x4_1) : "v"(tmp_0), "v"(tmp_1)); + + // 4, 5, 6, 7 + asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8_hi)); + asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1) : "v"(i4x8_hi)); + asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2, op_sel:[0, 0, 1]" : "+v"(fp8x4_0) : "v"(tmp_0), "v"(tmp_1)); + + asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8_hi)); + asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8_hi)); + asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2, op_sel:[0, 0, 1]" : "+v"(fp8x4_1) : "v"(tmp_0), "v"(tmp_1)); + + vector_type out; + out.template AsType()(Number<0>{}) = bit_cast(fp8x4_0); + out.template AsType()(Number<1>{}) = bit_cast(fp8x4_1); + return out.template AsType()[Number<0>{}]; +#else + return amd_assembly_i4_to_fp8x8(q); +#endif +} __device__ inline bhalf4_t i4_to_bhalf4(int q) {