mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 18:10:02 +00:00
iq2_kl: Metal GEMV - slightly better (44.5 t/s -> 46.5 t/s)
This commit is contained in:
@@ -7286,11 +7286,11 @@ void kernel_mul_mv_iq2_kl_f32_impl(
|
||||
cx += row_size;
|
||||
}
|
||||
|
||||
threadgroup float * all_values = (threadgroup float *)shared_values + 64*sgitg;
|
||||
threadgroup float2 * all_values = (threadgroup float2 *)shared_values + 32*sgitg;
|
||||
{
|
||||
constant const int8_t * val = (constant const int8_t *)iq2kl_values;
|
||||
all_values[2*tiisg + 0] = val[2*tiisg + 0];
|
||||
all_values[2*tiisg + 1] = val[2*tiisg + 1];
|
||||
all_values[tiisg][0] = val[2*tiisg + 0];
|
||||
all_values[tiisg][1] = val[2*tiisg + 1];
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
@@ -7303,20 +7303,15 @@ void kernel_mul_mv_iq2_kl_f32_impl(
|
||||
yl[i+16] = y4[i+32];
|
||||
}
|
||||
|
||||
//device const block_iq2_kl * x = (device const block_iq2_kl *)cx + ib;
|
||||
//device const uint16_t * ql = (device const uint16_t *)x->qs + 8*iq + 4*ir;
|
||||
//device const uint16_t * qh = (device const uint16_t *)x->qh + 4*ir;
|
||||
//device const uint8_t * sl = x->scales_l;
|
||||
//device const uint16_t * sh = &x->scales_h;
|
||||
|
||||
device const char * cx = cx0;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
device const block_iq2_kl * x = (device const block_iq2_kl *)cx + ib;
|
||||
|
||||
int8_t ls1 = int8_t(((x->scales_l[(2*iq+0)%4] >> 4*((2*iq+0)/4)) & 0xf) | (((x->scales_h >> (4*iq+0)) & 0x03) << 4)) - 32;
|
||||
int8_t ls2 = int8_t(((x->scales_l[(2*iq+1)%4] >> 4*((2*iq+1)/4)) & 0xf) | (((x->scales_h >> (4*iq+2)) & 0x03) << 4)) - 32;
|
||||
uint16_t h = x->scales_h >> 4*iq;
|
||||
int8_t ls1 = int8_t(((x->scales_l[(2*iq+0)%4] >> 4*((2*iq+0)/4)) & 0xf) | ((h & 0x03) << 4)) - 32;
|
||||
int8_t ls2 = int8_t(((x->scales_l[(2*iq+1)%4] >> 4*((2*iq+1)/4)) & 0xf) | ((h & 0x0c) << 2)) - 32;
|
||||
|
||||
device const uint16_t * ql = (device const uint16_t *)x->qs + 8*iq + 4*ir;
|
||||
device const uint16_t * qh = (device const uint16_t *)x->qh + 4*ir;
|
||||
@@ -7327,8 +7322,8 @@ void kernel_mul_mv_iq2_kl_f32_impl(
|
||||
aux16[0] = ((ql[l] >> 0) & 0x0f0f) | ((h & 0x0101) << 4);
|
||||
aux16[1] = ((ql[l] >> 4) & 0x0f0f) | ((h & 0x0202) << 3);
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
threadgroup const float * val1 = all_values + 2*aux8[j+0];
|
||||
threadgroup const float * val2 = all_values + 2*aux8[j+2];
|
||||
threadgroup const float2 & val1 = all_values[aux8[j+0]];
|
||||
threadgroup const float2 & val2 = all_values[aux8[j+2]];
|
||||
acc[0] += yl[4*l+2*j+ 0] * val1[0] + yl[4*l+2*j+ 1] * val1[1];
|
||||
acc[1] += yl[4*l+2*j+16] * val2[0] + yl[4*l+2*j+17] * val2[1];
|
||||
}
|
||||
@@ -7338,11 +7333,6 @@ void kernel_mul_mv_iq2_kl_f32_impl(
|
||||
|
||||
cx += row_size;
|
||||
|
||||
//ql += row_size/2;
|
||||
//qh += row_size/2;
|
||||
//sl += row_size;
|
||||
//sh += row_size/2;
|
||||
|
||||
}
|
||||
|
||||
y4 += 4 * QK_K;
|
||||
|
||||
Reference in New Issue
Block a user