mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq2_kl: Metal GEMV - pretty slow
This commit is contained in:
@@ -7262,81 +7262,80 @@ void kernel_mul_mv_iq2_kl_f32_impl(
|
||||
|
||||
const uint offset0 = (i12/r2)*(ne01) + (i13/r3)*(ne01*ne02);
|
||||
|
||||
device const char * cx = (device const char *) src0 + (first_row + offset0)*row_size;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
device const char * cx0 = (device const char *) src0 + (first_row + offset0)*row_size;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f};
|
||||
float drow[N_DST];
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int it = tiisg%8; // 0...7
|
||||
const int iq = it/4; // 0 or 1
|
||||
const int ir = it%4; // 0...3
|
||||
const int iq = it/2; // 0...3
|
||||
const int ir = it%2; // 0 or 1
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
device const float * y4 = y + ix * QK_K + 64 * iq + 16 * ir;
|
||||
|
||||
threadgroup float * all_values = (threadgroup float *)shared_values + 32*sgitg;
|
||||
{
|
||||
//int row = tiisg%N_DST;
|
||||
//device const half * dptr = (device const half *)(cx + row*row_size);
|
||||
//const float d = *dptr;
|
||||
//all_values[8*row + tiisg/N_DST] = d*iq2nl_values[tiisg/N_DST];
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
int row = tiisg/8;
|
||||
int pos = tiisg%8;
|
||||
device const half * dptr = (device const half *)(cx + row*row_size);
|
||||
const float d = *dptr;
|
||||
all_values[8*row + pos] = d*kvalues_iq2k_f[pos];
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
uint16_t aux16[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)&aux16;
|
||||
|
||||
device const char * cx = cx0;
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
device const half * dptr = (device const half *)cx;
|
||||
drow[row] = dptr[0];
|
||||
cx += row_size;
|
||||
}
|
||||
|
||||
cx += sizeof(half);
|
||||
|
||||
uint32_t q32[2];
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||
cx0 += sizeof(half);
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0];
|
||||
yl[i+ 8] = y4[i+32];
|
||||
yl[i+16] = y4[i+64];
|
||||
yl[i+24] = y4[i+96];
|
||||
yl[i+16] = y4[i+32];
|
||||
}
|
||||
|
||||
device const block_iq2_ks * x = (device const block_iq2_ks *)cx + ib;
|
||||
device const uint16_t * q16 = (device const uint16_t *)x->qs + 16*iq + 4*ir;
|
||||
device const uint16_t * sc = (device const uint16_t *)x->scales;
|
||||
device const uint16_t * ex = (device const uint16_t *)&x->extra;
|
||||
//device const block_iq2_kl * x = (device const block_iq2_kl *)cx + ib;
|
||||
//device const uint16_t * ql = (device const uint16_t *)x->qs + 8*iq + 4*ir;
|
||||
//device const uint16_t * qh = (device const uint16_t *)x->qh + 4*ir;
|
||||
//device const uint8_t * sl = x->scales_l;
|
||||
//device const uint16_t * sh = &x->scales_h;
|
||||
|
||||
device const char * cx = cx0;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
threadgroup const float * row_values = all_values + 8*row;
|
||||
device const block_iq2_kl * x = (device const block_iq2_kl *)cx + ib;
|
||||
|
||||
uint32_t sc32 = (sc[iq] | (sc[iq] << 12)) & 0x0f0f0f0f;
|
||||
thread const int8_t * s8 = (thread const int8_t *)&sc32;
|
||||
//threadgroup const float * row_values = all_values + 8*row;
|
||||
|
||||
q32[0] = q16[0] | (q16[1] << 16);
|
||||
q32[1] = q16[2] | (q16[3] << 16);
|
||||
int8_t ls1 = int8_t(((x->scales_l[(2*iq+0)%4] >> 4*((2*iq+0)/4)) & 0xf) | (((x->scales_h >> (4*iq+0)) & 0x03) << 4)) - 32;
|
||||
int8_t ls2 = int8_t(((x->scales_l[(2*iq+1)%4] >> 4*((2*iq+1)/4)) & 0xf) | (((x->scales_h >> (4*iq+2)) & 0x03) << 4)) - 32;
|
||||
|
||||
uint8_t extra = ex[0] << 4*(1-iq);
|
||||
device const uint16_t * ql = (device const uint16_t *)x->qs + 8*iq + 4*ir;
|
||||
device const uint16_t * qh = (device const uint16_t *)x->qh + 4*ir;
|
||||
|
||||
float4 acc = {0.f};
|
||||
float2 acc = {0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
threadgroup const float * values = row_values + ((extra >> (2 + l)) & 4);
|
||||
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
|
||||
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
}
|
||||
extra = ex[0] >> (8 + 4*iq);
|
||||
sumf[row] += acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16))
|
||||
+ acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16));
|
||||
uint16_t h = qh[l] >> 2*iq;
|
||||
aux16[0] = ((ql[l] >> 0) & 0x0f0f) | ((h & 0x0101) << 4);
|
||||
aux16[1] = ((ql[l] >> 4) & 0x0f0f) | ((h & 0x0202) << 3);
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
constant const int8_t * val1 = (constant const int8_t *)(iq2kl_values + aux8[j+0]);
|
||||
constant const int8_t * val2 = (constant const int8_t *)(iq2kl_values + aux8[j+2]);
|
||||
acc[0] += yl[4*l+2*j+ 0] * val1[0] + yl[4*l+2*j+ 1] * val1[1];
|
||||
acc[1] += yl[4*l+2*j+16] * val2[0] + yl[4*l+2*j+17] * val2[1];
|
||||
}
|
||||
|
||||
q16 += row_size/2;
|
||||
sc += row_size/2;
|
||||
ex += row_size/2;
|
||||
}
|
||||
sumf[row] += drow[row] * (acc[0] * ls1 + acc[1] * ls2);
|
||||
|
||||
cx += row_size;
|
||||
|
||||
//ql += row_size/2;
|
||||
//qh += row_size/2;
|
||||
//sl += row_size;
|
||||
//sh += row_size/2;
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user