iq5_ks: Metal dot product

This commit is contained in:
Iwan Kawrakow
2025-05-15 15:48:29 +03:00
parent cf93e69f0f
commit a7ceba3dc6

View File

@@ -6305,24 +6305,25 @@ void kernel_mul_mv_iq5_ks_f32_impl(
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint row_size = 4 + nb*sizeof(block_iq4_ks);
const uint row_size = 4 + nb*sizeof(block_iq5_ks);
const uint offset0 = (i12/r2)*ne01 + (i13/r3)*(ne01*ne02);
device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size;
device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1;
const int ix = tiisg/16; // 0 or 1
const int it = tiisg%16; // 0...15
const int ib = it/2;
const int il = it%2;
const int ib64 = it/4; // 0...3
const int il64 = it%4; // 0...3
shared_values[tiisg] = kvalues_iq4k_f[tiisg];
shared_values[2*tiisg+0] = kvalues_iq5k_f[2*tiisg+0];
shared_values[2*tiisg+1] = kvalues_iq5k_f[2*tiisg+1];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float2 sumf = 0.f;
float d[2];
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
device const float * yb = y + ix * QK_K + ib64 * 64 + il64 * 8;
uint32_t aux32[2];
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
@@ -6331,43 +6332,46 @@ void kernel_mul_mv_iq5_ks_f32_impl(
device const float * dptr = (device const float *)cx;
d[0] = *dptr;
device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix;
device const block_iq5_ks * x = (device const block_iq5_ks *)(dptr + 1) + ix;
dptr += row_size/4;
d[1] = *dptr;
for (int ibl = ix; ibl < nb; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
yl[0] = y4[0]; yl[1] = y4[8]; yl[2] = y4[1]; yl[3] = y4[9];
device const uint8_t * scales = x->scales;
for (int row = 0; row < 2; ++row) {
threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4);
const float ls = ((scales[ib] & 254) - 127);
threadgroup const float * values1 = shared_values + ((scales[2*ib64+0] & 1) << 5);
threadgroup const float * values2 = shared_values + ((scales[2*ib64+1] & 1) << 5);
const float ls1 = ((scales[2*ib64+0] & 254) - 127);
const float ls2 = ((scales[2*ib64+1] & 254) - 127);
device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il;
device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 8*ib64 + 2*il64;
device const uint32_t * qh = (device const uint32_t *)scales + QK_K/128 + QK_K/8 + 2*il64;
float4 acc1 = {0.f}, acc2 = {0.f};
aux32[0] = q4[0] & 0x0f0f0f0f;
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
uint32_t h = qh[0] >> 2*ib64;
aux32[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010);
aux32[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010);
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32[0] = q4[1] & 0x0f0f0f0f;
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
h = qh[1] >> 2*ib64;
aux32[0] = ((q4[1] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010);
aux32[1] = ((q4[1] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010);
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
acc1 += acc2;
sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
sumf[row] += ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3]);
scales += row_size;
@@ -6379,7 +6383,7 @@ void kernel_mul_mv_iq5_ks_f32_impl(
sumf = simd_sum(sumf);
if (tiisg < 2) {
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg] * d[tiisg];
}
}