Fix and improve the gemm quant pipeline infrastructure (#3245)

This commit is contained in:
Thomas Ning
2025-11-26 18:04:27 -08:00
committed by GitHub
parent 79aae7c7f7
commit a38aeceb21
11 changed files with 96 additions and 272 deletions

View File

@@ -214,22 +214,27 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
// ---- Lower 4 int4 values (even positions) ----
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
uint32_t dict_sel = a & 0x07070707;
uint32_t sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign.
uint32_t sign = a >> 1;
// Build final selector:
// - bit 2 of each byte (0x04) selects negative vs positive table
// - 0x03020100 selects byte lanes [0,1,2,3] in order
final_sel = (sign & 0x04040404) | 0x03020100;
// Lookup positive and negative fp8 codes from the small register tables.
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
// ---- Upper 4 int4 values (odd positions) ----
// Shift to bring the high-nibble int4s into place and repeat the process.
a >>= 4;
dict_sel = a & 0x07070707;
sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
dict_sel = a & 0x07070707;
sign = a >> 1;
final_sel = (sign & 0x04040404) | 0x03020100;
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
@@ -306,22 +311,29 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
// ---- Lower 4 int4 values (even positions) ----
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
uint32_t dict_sel = a & 0x07070707;
uint32_t sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign.
uint32_t sign = a >> 1;
// Build final selector:
// - bit 2 of each byte (0x04) selects negative vs positive table
// - 0x03020100 selects byte lanes [0,1,2,3] in order
final_sel = (sign & 0x04040404) | 0x03020100;
// Lookup positive and negative fp8 codes from the small register tables.
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
// ---- Upper 4 int4 values (odd positions) ----
// Shift to bring the high-nibble int4s into place and repeat the process.
a >>= 4;
dict_sel = a & 0x07070707;
sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
dict_sel = a & 0x07070707;
sign = a >> 1;
final_sel = (sign & 0x04040404) | 0x03020100;
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);