iq4_xxs: slightly faster TG on Metal

This commit is contained in:
Iwan Kawrakow
2024-10-09 11:06:45 +03:00
parent bb522fb314
commit 0e12b2919c

View File

@@ -6077,7 +6077,7 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4]; float4 yl[4];
float sumf[2]={0.f}, all_sum; float2 sumf = 0.f;
device const float * yb = y + ix * QK_K + ib * 32 + il * 8; device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
@@ -6091,9 +6091,11 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
device const float4 * y4 = (device const float4 *)yb; device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
device const float * dptr = (device const float *)cx;
for (int row = 0; row < 2; ++row) { for (int row = 0; row < 2; ++row) {
device const float * dptr = (device const float *)(cx + row*row_size); //device const float * dptr = (device const float *)(cx + row*row_size);
const float d = *dptr; const float d = *dptr;
device const block_iq4_xxs * x = (device const block_iq4_xxs *)(dptr + 1); device const block_iq4_xxs * x = (device const block_iq4_xxs *)(dptr + 1);
device const block_iq4_xxs & xb = x[ibl]; device const block_iq4_xxs & xb = x[ibl];
@@ -6122,16 +6124,16 @@ void kernel_mul_mv_iq4_xxs_f32_impl(
const int ls = (xb.scales[ib] & 254) - 127; const int ls = (xb.scales[ib] & 254) - 127;
sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
dptr += row_size/4;
} }
yb += 2 * QK_K; yb += 2 * QK_K;
} }
for (int row = 0; row < 2; ++row) { sumf = simd_sum(sumf);
all_sum = simd_sum(sumf[row]); if (tiisg < 2) {
if (tiisg == 0) { dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
} }
} }