iq4_xxs: slightly faster TG on Metal

This commit is contained in:
Iwan Kawrakow
2024-10-09 11:06:45 +03:00
parent bb522fb314
commit 0e12b2919c

View File

@@ -6077,7 +6077,7 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[2]={0.f}, all_sum;
float2 sumf = 0.f;
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
@@ -6091,9 +6091,11 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
device const float * dptr = (device const float *)cx;
for (int row = 0; row < 2; ++row) {
device const float * dptr = (device const float *)(cx + row*row_size);
//device const float * dptr = (device const float *)(cx + row*row_size);
const float d = *dptr;
device const block_iq4_xxs * x = (device const block_iq4_xxs *)(dptr + 1);
device const block_iq4_xxs & xb = x[ibl];
@@ -6122,16 +6124,16 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
const int ls = (xb.scales[ib] & 254) - 127;
sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
dptr += row_size/4;
}
yb += 2 * QK_K;
}
for (int row = 0; row < 2; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
sumf = simd_sum(sumf);
if (tiisg < 2) {
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
}
}