mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
iq4_xxs: slightly faster TG on Metal
This commit is contained in:
@@ -6077,7 +6077,7 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
float4 yl[4];
|
float4 yl[4];
|
||||||
float sumf[2]={0.f}, all_sum;
|
float2 sumf = 0.f;
|
||||||
|
|
||||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||||
|
|
||||||
@@ -6091,9 +6091,11 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
|
|||||||
device const float4 * y4 = (device const float4 *)yb;
|
device const float4 * y4 = (device const float4 *)yb;
|
||||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||||
|
|
||||||
|
device const float * dptr = (device const float *)cx;
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
|
||||||
device const float * dptr = (device const float *)(cx + row*row_size);
|
//device const float * dptr = (device const float *)(cx + row*row_size);
|
||||||
const float d = *dptr;
|
const float d = *dptr;
|
||||||
device const block_iq4_xxs * x = (device const block_iq4_xxs *)(dptr + 1);
|
device const block_iq4_xxs * x = (device const block_iq4_xxs *)(dptr + 1);
|
||||||
device const block_iq4_xxs & xb = x[ibl];
|
device const block_iq4_xxs & xb = x[ibl];
|
||||||
@@ -6122,16 +6124,16 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
|
|||||||
const int ls = (xb.scales[ib] & 254) - 127;
|
const int ls = (xb.scales[ib] & 254) - 127;
|
||||||
sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
||||||
|
|
||||||
|
dptr += row_size/4;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
yb += 2 * QK_K;
|
yb += 2 * QK_K;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
sumf = simd_sum(sumf);
|
||||||
all_sum = simd_sum(sumf[row]);
|
if (tiisg < 2) {
|
||||||
if (tiisg == 0) {
|
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
|
||||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user