mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 07:20:12 +00:00
iq4_kss: very slightly faster Metal dot product
48.7 t/s -> 49.3 t/s
This commit is contained in:
@@ -6189,14 +6189,14 @@ void kernel_mul_mv_iq4_kss_f32_impl(
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
device const float * dptr = (device const float *)cx;
|
||||
d[0] = *dptr;
|
||||
device const uint32_t * qptr = (device const uint32_t *)(dptr + 1);
|
||||
device const uint32_t * qptr = (device const uint32_t *)(dptr + 1) + ix*(QK_K/8) + 4*ib;
|
||||
dptr += row_size/4;
|
||||
d[1] = *dptr;
|
||||
|
||||
@@ -6205,7 +6205,7 @@ void kernel_mul_mv_iq4_kss_f32_impl(
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||
|
||||
device const uint32_t * q4 = qptr + ibl*(QK_K/8) + 4*ib;
|
||||
device const uint32_t * q4 = qptr;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
@@ -6213,26 +6213,26 @@ void kernel_mul_mv_iq4_kss_f32_impl(
|
||||
int16_t ls = (s32 | (s32 >> 15)) & 0xff;
|
||||
|
||||
threadgroup const float * block_values = shared_values + ((ls & 1) << 4);
|
||||
const int scale = (ls & 254) - 127;
|
||||
const float scale = ((ls & 254) - 127);
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
|
||||
aux32[0] = q4[2*il+0] & 0xfffefffe;
|
||||
aux32[0] ^= (aux32[0] >> 1);
|
||||
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
||||
aux32[0] &= 0x0f0f0f0f;
|
||||
uint32_t v32 = q4[2*il+0] & 0xfffefffe;
|
||||
v32 ^= (v32 >> 1);
|
||||
aux32 = v32 & 0x0f0f0f0f;
|
||||
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
aux32 = (v32 >> 4) & 0x0f0f0f0f;
|
||||
qf2 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
acc2 += yl[1] * qf2;
|
||||
|
||||
aux32[0] = q4[2*il+1] & 0xfffefffe;
|
||||
aux32[0] ^= (aux32[0] >> 1);
|
||||
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
||||
aux32[0] &= 0x0f0f0f0f;
|
||||
v32 = q4[2*il+1] & 0xfffefffe;
|
||||
v32 ^= (v32 >> 1);
|
||||
aux32 = v32 & 0x0f0f0f0f;
|
||||
qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
aux32 = (v32 >> 4) & 0x0f0f0f0f;
|
||||
qf2 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]};
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
acc1 += acc2;
|
||||
@@ -6244,6 +6244,7 @@ void kernel_mul_mv_iq4_kss_f32_impl(
|
||||
}
|
||||
|
||||
yb += 2 * QK_K;
|
||||
qptr += 2 * (QK_K/8);
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
Reference in New Issue
Block a user