Faster iq4_k: Metal

This commit is contained in:
Iwan Kawrakow
2024-11-05 09:46:43 +01:00
parent 9d713516cd
commit da9fdf57d8

View File

@@ -6492,73 +6492,59 @@ void kernel_mul_mv_iq4_k_f32_impl(
const int ix = tiisg/16; // 0 or 1
const int it = tiisg%16; // 0...15
const int ib = it/2;
const int il = it%2;
shared_values[tiisg] = kvalues_iq4k_f[tiisg];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[2]={0.f}, all_sum;
float2 sumf = 0.f;
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
device const float * yb = y + ix * QK_K + it * 4;
uint32_t aux32[2];
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
float4 qf1, qf2;
float4 qf;
for (int ibl = ix; ibl < nb; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
//float2 sumy;
//sumy[0] = -4.f*(yl[0][0] + yl[0][1] + yl[0][2] + yl[0][3] + yl[2][0] + yl[2][1] + yl[2][2] + yl[2][3]);
//sumy[1] = -4.f*(yl[1][0] + yl[1][1] + yl[1][2] + yl[1][3] + yl[3][0] + yl[3][1] + yl[3][2] + yl[3][3]);
yl[0] = y4[0]; yl[1] = y4[16]; yl[2] = y4[32]; yl[3] = y4[48];
for (int row = 0; row < 2; ++row) {
device const block_iq4_k & xb = x[row*nb + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 4*it);
uint16_t extra = xb.extra >> 2*ib;
threadgroup const float * values1 = shared_values + 16*(extra & 1);
threadgroup const float * values2 = shared_values + 8*(extra & 2);
uint16_t extra = xb.extra >> it;
threadgroup const float * values = shared_values + ((extra & 1) << 4);
float4 acc1 = {0.f}, acc2 = {0.f};
aux32 = q4[0] & 0x0f0f0f0f;
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
float4 acc = yl[0] * qf;
aux32 = (q4[0] >> 4) & 0x0f0f0f0f;
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
acc += yl[1] * qf;
aux32[0] = q4[0] & 0x0f0f0f0f;
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32 = q4[16] & 0x0f0f0f0f;
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
acc += yl[2] * qf;
aux32 = (q4[16] >> 4) & 0x0f0f0f0f;
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
acc += yl[3] * qf;
aux32[0] = q4[1] & 0x0f0f0f0f;
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
const uint8_t h = xb.scales_h[ib/2] >> 4*(ib%2);
const int ls1 = ((xb.scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
const int ls2 = ((xb.scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3]));
//uint16_t extra = xb.extra >> 2*ib;
//sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3] + (extra & 1 ? sumy[0] : 0)) +
// ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3] + (extra & 2 ? sumy[1] : 0)));
const uint8_t h = xb.scales_h[it/4] >> 2*(it%4);
const int ls = (((xb.scales_l[it/2] >> 4*(it%2)) & 0xf) | ((h << 4) & 0x30)) - 32;
sumf[row] += (float)xb.d * ls * (acc[0] + acc[1] + acc[2] + acc[3]);
}
yb += 2 * QK_K;
}
for (int row = 0; row < 2; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
sumf = simd_sum(sumf);
if (tiisg < 2) {
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
}
}
@@ -7692,22 +7678,26 @@ void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 &
template <typename type4x4>
void dequantize_iq4_k(device const block_iq4_k * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
const int l = il%2;
// l = 0 or 1. l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
const int ls = ((xb->scales_l[ib32] >> 4*l) & 0xf) | (((xb->scales_h[il/4] >> 2*(il%4)) & 3) << 4);
const float d = (float)xb->d * (ls - 32);
int i16 = 4*(il%4);
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 16*(il/8) + i16;
const int shift = 4*((il/4)%2);
uint16_t extra = xb->extra >> i16;
const float d = xb->d;
float4 dl;
dl[0] = d * (((xb->scales_l[2*(il%4)+0] & 0xf) | ((xb->scales_h[il%4] << 4) & 0x30)) - 32);
dl[1] = d * (((xb->scales_l[2*(il%4)+0] >> 4) | ((xb->scales_h[il%4] << 2) & 0x30)) - 32);
dl[2] = d * (((xb->scales_l[2*(il%4)+1] & 0xf) | ((xb->scales_h[il%4] >> 0) & 0x30)) - 32);
dl[3] = d * (((xb->scales_l[2*(il%4)+1] >> 4) | ((xb->scales_h[il%4] >> 2) & 0x30)) - 32);
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
constant float * values = kvalues_iq4k_f + 16*((xb->extra >> il) & 1);
for (int i = 0; i < 4; ++i) {
aux32 = (q4[i] >> 4*l) & 0x0f0f0f0f;
reg[i][0] = d * values[q8[0]];
reg[i][1] = d * values[q8[1]];
reg[i][2] = d * values[q8[2]];
reg[i][3] = d * values[q8[3]];
constant float * values = kvalues_iq4k_f + ((extra & 1) << 4);
aux32 = (q4[i] >> shift) & 0x0f0f0f0f;
reg[i][0] = dl[i] * values[q8[0]];
reg[i][1] = dl[i] * values[q8[1]];
reg[i][2] = dl[i] * values[q8[2]];
reg[i][3] = dl[i] * values[q8[3]];
extra >>= 1;
}
}