Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (last synced 2026-02-24 15:14:10 +00:00)
iq5_ks: Metal dot product
This commit is contained in:
@@ -6305,24 +6305,25 @@ void kernel_mul_mv_iq5_ks_f32_impl(
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint row_size = 4 + nb*sizeof(block_iq4_ks);
|
||||
const uint row_size = 4 + nb*sizeof(block_iq5_ks);
|
||||
const uint offset0 = (i12/r2)*ne01 + (i13/r3)*(ne01*ne02);
|
||||
device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size;
|
||||
device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
const int ix = tiisg/16; // 0 or 1
|
||||
const int it = tiisg%16; // 0...15
|
||||
const int ib = it/2;
|
||||
const int il = it%2;
|
||||
const int ib64 = it/4; // 0...3
|
||||
const int il64 = it%4; // 0...3
|
||||
|
||||
shared_values[tiisg] = kvalues_iq4k_f[tiisg];
|
||||
shared_values[2*tiisg+0] = kvalues_iq5k_f[2*tiisg+0];
|
||||
shared_values[2*tiisg+1] = kvalues_iq5k_f[2*tiisg+1];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float2 sumf = 0.f;
|
||||
float d[2];
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
device const float * yb = y + ix * QK_K + ib64 * 64 + il64 * 8;
|
||||
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
||||
@@ -6331,43 +6332,46 @@ void kernel_mul_mv_iq5_ks_f32_impl(
|
||||
|
||||
device const float * dptr = (device const float *)cx;
|
||||
d[0] = *dptr;
|
||||
device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix;
|
||||
device const block_iq5_ks * x = (device const block_iq5_ks *)(dptr + 1) + ix;
|
||||
dptr += row_size/4;
|
||||
d[1] = *dptr;
|
||||
|
||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||
yl[0] = y4[0]; yl[1] = y4[8]; yl[2] = y4[1]; yl[3] = y4[9];
|
||||
|
||||
device const uint8_t * scales = x->scales;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4);
|
||||
const float ls = ((scales[ib] & 254) - 127);
|
||||
threadgroup const float * values1 = shared_values + ((scales[2*ib64+0] & 1) << 5);
|
||||
threadgroup const float * values2 = shared_values + ((scales[2*ib64+1] & 1) << 5);
|
||||
const float ls1 = ((scales[2*ib64+0] & 254) - 127);
|
||||
const float ls2 = ((scales[2*ib64+1] & 254) - 127);
|
||||
|
||||
device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il;
|
||||
device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 8*ib64 + 2*il64;
|
||||
device const uint32_t * qh = (device const uint32_t *)scales + QK_K/128 + QK_K/8 + 2*il64;
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
|
||||
aux32[0] = q4[0] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
|
||||
uint32_t h = qh[0] >> 2*ib64;
|
||||
aux32[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010);
|
||||
aux32[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010);
|
||||
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
|
||||
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
acc2 += yl[1] * qf2;
|
||||
|
||||
aux32[0] = q4[1] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
|
||||
h = qh[1] >> 2*ib64;
|
||||
aux32[0] = ((q4[1] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010);
|
||||
aux32[1] = ((q4[1] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010);
|
||||
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
|
||||
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
acc1 += acc2;
|
||||
|
||||
sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
||||
sumf[row] += ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3]);
|
||||
|
||||
scales += row_size;
|
||||
|
||||
@@ -6379,7 +6383,7 @@ void kernel_mul_mv_iq5_ks_f32_impl(
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
if (tiisg < 2) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg] * d[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user