iq1bn(no lookup): Metal

In summary, compared to lookup, the multiplication based approach is
* Much better on AVX2
* Slightly better on CUDA
* Slightly worse on Metal
* Much worse on NEON
This commit is contained in:
Kawrakow
2024-07-16 09:12:15 +02:00
parent d0f9d146b8
commit d84748b71b

View File

@@ -5102,30 +5102,31 @@ void kernel_mul_mv_iq1_bn_f32_impl(
const float values[3] = {-1.f, 0.f, 1.f}; const float values[3] = {-1.f, 0.f, 1.f};
constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
for (int ib32 = ix; ib32 < nb32; ib32 += 8) { for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
yl[0] = y4[0]; yl[1] = y4[1]; yl[0] = y4[0]; yl[1] = y4[1];
const int ibl = ib32 / (QK_IQ1BN / 32); const int ibl = ib32 / (QK_IQ1BN / 32);
const int ib = ib32 % (QK_IQ1BN / 32); const int ib = ib32 % (QK_IQ1BN / 32);
const int il = 4*ib + ir;
device const block_iq1_bn * xr = x + ibl; device const block_iq1_bn * xr = x + ibl;
device const uint8_t * extra = (device const uint8_t *)&xr->extra; device const uint8_t * extra = (device const uint8_t *)&xr->extra;
device const uint8_t * ql = xr->ql + 4 * ib + ir; device const uint8_t * ql = xr->ql + il;
device const uint8_t * qh = xr->qh + 2 * ib + ir/2; device const uint8_t * qh = xr->qh + il%4;
for (int row = 0; row < N_DST; row++) { for (int row = 0; row < N_DST; row++) {
uint8_t signs = extra[0] >> (4*ib + ir); uint8_t h = extra[0] >> il;
uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)]; int16_t val = ql[0] | ((qh[0] << (8 - 4*(il/4))) & 0x0f00) | ((extra[0] << (12 - il)) & 4096);
uint32_t v32 = v | (v << 14); float4 acc4 = yl[0] * float4{values[(val*k_mult[0] & 0x1fff)*3 >> 13], values[(val*k_mult[1] & 0x1fff)*3 >> 13],
aux32[0] = v32 & 0x03030303; aux32[1] = (v32 >> 4) & 0x03030303; values[(val*k_mult[2] & 0x1fff)*3 >> 13], values[(val*k_mult[3] & 0x1fff)*3 >> 13]}
float4 acc4 = yl[0] * float4{values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]} + yl[1] * float4{values[(val*k_mult[4] & 0x1fff)*3 >> 13], values[(val*k_mult[5] & 0x1fff)*3 >> 13],
+ yl[1] * float4{values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]}; values[(val*k_mult[6] & 0x1fff)*3 >> 13], values[(val*k_mult[7] & 0x1fff)*3 >> 13]};
float acc = acc4[0] + acc4[1] + acc4[2] + acc4[3]; sumf[row] += acc4[0] + acc4[1] + acc4[2] + acc4[3];
sumf[row] += (signs & 1 ? -acc : acc);
extra += nb*sizeof(block_iq1_bn); extra += nb*sizeof(block_iq1_bn);
ql += nb*sizeof(block_iq1_bn); ql += nb*sizeof(block_iq1_bn);
@@ -5985,32 +5986,22 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
} }
} }
template <typename type4x4> template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3 // il is in 0...3
uint8_t gs = xb->extra >> 2*il; uint8_t gs = xb->extra >> 2*il;
uint16_t idx1 = xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00); constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
uint16_t idx2 = xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00);
uint16_t val1 = gs & 1 ? 0xaaaa - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
uint16_t val2 = gs & 2 ? 0xaaaa - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
uint32_t v = val1 | (val1 << 14); short il1 = 2*il+0, il2 = 2*il+1;
uint32_t aux32; int16_t v1 = xb->ql[il1] | ((xb->qh[il1%4] << (8 - 4*(il1/4))) & 0x0f00) | ((gs << 12) & 4096);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; int16_t v2 = xb->ql[il2] | ((xb->qh[il2%4] << (8 - 4*(il2/4))) & 0x0f00) | ((gs << 11) & 4096);
const half values[3] = {-1.h, 0.h, 1.h}; for (int i = 0; i < 8; ++i) {
reg[i/4+0][i%4] = ((v1*k_mult[i] & 0x1fff)*3 >> 13) - 1;
aux32 = v & 0x03030303; reg[i/4+2][i%4] = ((v2*k_mult[i] & 0x1fff)*3 >> 13) - 1;
reg[0][0] = values[aux8[0]]; reg[0][1] = values[aux8[1]]; reg[0][2] = values[aux8[2]]; reg[0][3] = values[aux8[3]]; }
aux32 = (v >> 4) & 0x03030303;
reg[1][0] = values[aux8[0]]; reg[1][1] = values[aux8[1]]; reg[1][2] = values[aux8[2]]; reg[1][3] = values[aux8[3]];
v = val2 | (val2 << 14);
aux32 = v & 0x03030303;
reg[2][0] = values[aux8[0]]; reg[2][1] = values[aux8[1]]; reg[2][2] = values[aux8[2]]; reg[2][3] = values[aux8[3]];
aux32 = (v >> 4) & 0x03030303;
reg[3][0] = values[aux8[0]]; reg[3][1] = values[aux8[1]]; reg[3][2] = values[aux8[2]]; reg[3][3] = values[aux8[3]];
} }
template <typename type4x4> template <typename type4x4>