updated int4->fp8 two methods.

This commit is contained in:
mtgu0705
2025-03-26 07:09:32 +00:00
parent facd5560e7
commit fc26d568ae

View File

@@ -96,23 +96,26 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
#else
// [0, 1, 2, 3] encoded as FP8
static constexpr uint32_t POS_E4M3s_REG1 = 0x2C282000;
static constexpr uint32_t POS_E4M3s_TABLE1 = 0x2C282000;
// [4, 5, 6, 7] encoded as FP8
static constexpr uint32_t POS_E4M3s_REG2 = 0x36343230;
static constexpr uint32_t POS_E4M3s_TABLE2 = 0x36343230;
// [-8, -7, -6, -5] encoded as FP8
static constexpr uint32_t NEG_E4M3s_REG1 = 0xB2B4B6B8;
static constexpr uint32_t NEG_E4M3s_TABLE1 = 0xB2B4B6B8;
// [-4, -3, -2, -1] encoded as FP8
static constexpr uint32_t NEG_E4M3s_REG2 = 0xA0A8ACB0;
static constexpr uint32_t NEG_E4M3s_TABLE2 = 0xA0A8ACB0;
uint32_t tmp_pos, tmp_neg, tmp_res;
uint32_t tmp_pos, tmp_neg, tmp_res, final_sel;
uint32_t sign= q & 0x08080808;
uint32_t dict_sel = q & 0x07070707;
uint32_t final_sel = 0x03020100 | (sign >> 1);
uint32_t sign = q >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
vector_type<f8_t, 4> res;
tmp_pos = __builtin_amdgcn_perm(POS_E4M3s_REG2, POS_E4M3s_REG1, dict_sel);
tmp_neg = __builtin_amdgcn_perm(NEG_E4M3s_REG2, NEG_E4M3s_REG1, dict_sel);
tmp_pos = __builtin_amdgcn_perm(POS_E4M3s_TABLE2, POS_E4M3s_TABLE1, dict_sel);
tmp_neg = __builtin_amdgcn_perm(NEG_E4M3s_TABLE2, NEG_E4M3s_TABLE1, dict_sel);
tmp_res = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
@@ -121,7 +124,58 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
#endif
}
__device__ inline f8x8_t i4_to_fp8x8(int q) { return amd_assembly_i4_to_fp8x8(q); }
__device__ inline f8x8_t i4_to_fp8x8(int q)
{
#if 1
uint32_t i4x8 = static_cast<uint32_t>(q);
uint32_t fp8x4_0 = 0;
uint32_t fp8x4_1 = 0;
float tmp_0, tmp_1;
uint32_t i4x8_hi = i4x8 >> 4;
// // 0, 1, 2, 3
// asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8));
// asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1) : "v"(i4x8));
// fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_0, false);
// asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8));
// asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8));
// fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_1, false);
// // 4, 5, 6, 7
// asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8_hi));
// asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1) : "v"(i4x8_hi));
// fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_0, true);
// asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8_hi));
// asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8_hi));
// fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(tmp_0, tmp_1, fp8x4_1, true);
asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8));
asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1): "v"(i4x8));
asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2" : "=v"(fp8x4_0) : "v"(tmp_0), "v"(tmp_1));
asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8));
asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8));
asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2" : "=v"(fp8x4_1) : "v"(tmp_0), "v"(tmp_1));
// 4, 5, 6, 7
asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(tmp_0) : "v"(i4x8_hi));
asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(tmp_1) : "v"(i4x8_hi));
asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2, op_sel:[0, 0, 1]" : "+v"(fp8x4_0) : "v"(tmp_0), "v"(tmp_1));
asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(tmp_0) : "v"(i4x8_hi));
asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(tmp_1) : "v"(i4x8_hi));
asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2, op_sel:[0, 0, 1]" : "+v"(fp8x4_1) : "v"(tmp_0), "v"(tmp_1));
vector_type<f8_t, 8> out;
out.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(fp8x4_0);
out.template AsType<f8x4_t>()(Number<1>{}) = bit_cast<f8x4_t>(fp8x4_1);
return out.template AsType<f8x8_t>()[Number<0>{}];
#else
return amd_assembly_i4_to_fp8x8(q);
#endif
}
__device__ inline bhalf4_t i4_to_bhalf4(int q)
{