mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Updated two int4->fp8 conversion methods.
This commit is contained in:
@@ -96,23 +96,26 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
|
||||
#else
|
||||
// [0, 1, 2, 3] encoded as FP8
|
||||
static constexpr uint32_t POS_E4M3s_REG1 = 0x2C282000;
|
||||
static constexpr uint32_t POS_E4M3s_TABLE1 = 0x2C282000;
|
||||
// [4, 5, 6, 7] encoded as FP8
|
||||
static constexpr uint32_t POS_E4M3s_REG2 = 0x36343230;
|
||||
static constexpr uint32_t POS_E4M3s_TABLE2 = 0x36343230;
|
||||
// [-8, -7, -6, -5] encoded as FP8
|
||||
static constexpr uint32_t NEG_E4M3s_REG1 = 0xB2B4B6B8;
|
||||
static constexpr uint32_t NEG_E4M3s_TABLE1 = 0xB2B4B6B8;
|
||||
// [-4, -3, -2, -1] encoded as FP8
|
||||
static constexpr uint32_t NEG_E4M3s_REG2 = 0xA0A8ACB0;
|
||||
static constexpr uint32_t NEG_E4M3s_TABLE2 = 0xA0A8ACB0;
|
||||
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res;
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res, final_sel;
|
||||
|
||||
uint32_t sign= q & 0x08080808;
|
||||
uint32_t dict_sel = q & 0x07070707;
|
||||
uint32_t final_sel = 0x03020100 | (sign >> 1);
|
||||
uint32_t sign = q >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
vector_type<f8_t, 4> res;
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(POS_E4M3s_REG2, POS_E4M3s_REG1, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(NEG_E4M3s_REG2, NEG_E4M3s_REG1, dict_sel);
|
||||
tmp_pos = __builtin_amdgcn_perm(POS_E4M3s_TABLE2, POS_E4M3s_TABLE1, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(NEG_E4M3s_TABLE2, NEG_E4M3s_TABLE1, dict_sel);
|
||||
tmp_res = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
res.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(tmp_res);
|
||||
@@ -121,7 +124,58 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
#endif
|
||||
}
|
||||
|
||||
// Forwards `q` to amd_assembly_i4_to_fp8x8 and returns its result unchanged.
__device__ inline f8x8_t i4_to_fp8x8(int q)
{
    return amd_assembly_i4_to_fp8x8(q);
}
|
||||
// Converts the 32-bit word `q` into eight FP8 values, returned as one f8x8_t.
//
// Active path (#if 1): each value is first turned into a float with
// v_cvt_off_f32_i4, then pairs of floats are packed into 16 bits of a 32-bit
// result word with v_cvt_pk_fp8_f32. Two result words are produced and stored
// as elements 0 and 1 of the returned vector.
// Fallback path (#else): forwards to amd_assembly_i4_to_fp8x8(q).
//
// NOTE(review): the exact numeric mapping/scaling performed by
// v_cvt_off_f32_i4 is not visible in this file -- confirm against the ISA
// documentation and the callers' expectations.
__device__ inline f8x8_t i4_to_fp8x8(int q)
{
#if 1
    uint32_t packed = static_cast<uint32_t>(q);
    // Copy shifted right by 4 bits, so the upper nibble of every byte lands in
    // that byte's lower-nibble position for the second conversion pass.
    uint32_t packed_hi = packed >> 4;

    // Scratch floats reused by every conversion pair below.
    float f32_a, f32_b;

    // The two 32-bit result words (four FP8 bytes each).
    uint32_t word0 = 0;
    uint32_t word1 = 0;

    // Pass 1 on the unshifted word.
    // Bytes 0 and 2 feed word0. The output is declared write-only ("=v");
    // presumably only the lower 16 bits are produced here and the rest is
    // filled in by the op_sel variant further down -- TODO confirm.
    asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(f32_a) : "v"(packed));
    asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(f32_b) : "v"(packed));
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2" : "=v"(word0) : "v"(f32_a), "v"(f32_b));

    // Bytes 1 and 3 feed word1.
    asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(f32_a) : "v"(packed));
    asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(f32_b) : "v"(packed));
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2" : "=v"(word1) : "v"(f32_a), "v"(f32_b));

    // Pass 2 on the shifted word (elements 4, 5, 6, 7).
    // The outputs use "+v" (read-write) together with op_sel:[0, 0, 1], so
    // the previously written contents of word0/word1 are carried into the
    // instruction; presumably op_sel targets the other 16-bit half -- TODO
    // confirm against the ISA documentation.
    asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(f32_a) : "v"(packed_hi));
    asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_2" : "=v"(f32_b) : "v"(packed_hi));
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2, op_sel:[0, 0, 1]" : "+v"(word0) : "v"(f32_a), "v"(f32_b));

    asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_1" : "=v"(f32_a) : "v"(packed_hi));
    asm volatile("v_cvt_off_f32_i4 %0, %1, src0_sel:BYTE_3" : "=v"(f32_b) : "v"(packed_hi));
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2, op_sel:[0, 0, 1]" : "+v"(word1) : "v"(f32_a), "v"(f32_b));

    // Reinterpret the two packed words as the halves of the 8-wide result.
    vector_type<f8_t, 8> result;
    result.template AsType<f8x4_t>()(Number<0>{}) = bit_cast<f8x4_t>(word0);
    result.template AsType<f8x4_t>()(Number<1>{}) = bit_cast<f8x4_t>(word1);
    return result.template AsType<f8x8_t>()[Number<0>{}];
#else
    return amd_assembly_i4_to_fp8x8(q);
#endif
}
|
||||
|
||||
__device__ inline bhalf4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user