mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
iq4_ks: faster dot product on Metal
TG-128(LLaMA-3.1-8B) goes to 52.5 t/s up from 48.4 t/s.
This commit is contained in:
@@ -6079,6 +6079,7 @@ void kernel_mul_mv_iq4_ks_f32_impl(
|
||||
|
||||
float4 yl[4];
|
||||
float2 sumf = 0.f;
|
||||
float d[2];
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
|
||||
@@ -6087,22 +6088,25 @@ void kernel_mul_mv_iq4_ks_f32_impl(
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
device const float * dptr = (device const float *)cx;
|
||||
d[0] = *dptr;
|
||||
device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix;
|
||||
dptr += row_size/4;
|
||||
d[1] = *dptr;
|
||||
|
||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||
|
||||
device const float * dptr = (device const float *)cx;
|
||||
device const uint8_t * scales = x->scales;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
//device const float * dptr = (device const float *)(cx + row*row_size);
|
||||
const float d = *dptr;
|
||||
device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1);
|
||||
device const block_iq4_ks & xb = x[ibl];
|
||||
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
||||
threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4);
|
||||
const float ls = ((scales[ib] & 254) - 127);
|
||||
|
||||
threadgroup const float * block_values = shared_values + ((xb.scales[ib] & 1) << 4);
|
||||
device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il;
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
|
||||
@@ -6122,14 +6126,14 @@ void kernel_mul_mv_iq4_ks_f32_impl(
|
||||
|
||||
acc1 += acc2;
|
||||
|
||||
const int ls = (xb.scales[ib] & 254) - 127;
|
||||
sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
||||
sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
||||
|
||||
dptr += row_size/4;
|
||||
scales += row_size;
|
||||
|
||||
}
|
||||
|
||||
yb += 2 * QK_K;
|
||||
x += 2;
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
Reference in New Issue
Block a user