mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Fix and improve the gemm quant pipeline infrastructure (#3245)
This commit is contained in:
@@ -214,22 +214,27 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
// ---- Lower 4 int4 values (even positions) ----
|
||||
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// The sign bit is bit[3] of each low nibble (after bias); shifting right by 1 moves it to bit[2] of each byte.
|
||||
uint32_t sign = a >> 1;
|
||||
// Build final selector:
|
||||
// - bit 2 of each byte (0x04) selects negative vs positive table
|
||||
// - 0x03020100 selects byte lanes [0,1,2,3] in order
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
// Lookup positive and negative fp8 codes from the small register tables.
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
// ---- Upper 4 int4 values (odd positions) ----
|
||||
// Shift to bring the high-nibble int4s into place and repeat the process.
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
@@ -306,22 +311,29 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
|
||||
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
// ---- Lower 4 int4 values (even positions) ----
|
||||
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// The sign bit is bit[3] of each low nibble (after bias); shifting right by 1 moves it to bit[2] of each byte.
|
||||
uint32_t sign = a >> 1;
|
||||
// Build final selector:
|
||||
// - bit 2 of each byte (0x04) selects negative vs positive table
|
||||
// - 0x03020100 selects byte lanes [0,1,2,3] in order
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
|
||||
// Look up positive and negative bf8 codes from the small register tables.
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
// ---- Upper 4 int4 values (odd positions) ----
|
||||
// Shift to bring the high-nibble int4s into place and repeat the process.
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
|
||||
Reference in New Issue
Block a user