diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 61112db7..baaac407 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -7396,15 +7396,15 @@ void kernel_mul_mv_iq3_ks_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint row_size = sizeof(half) + nb*sizeof(block_iq3_ks); + device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size; + device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1; threadgroup float * all_values = (threadgroup float *)shared_values + 16*sgitg; { @@ -7414,12 +7414,19 @@ void kernel_mul_mv_iq3_ks_f32_impl( float yl[32]; float sumf[N_DST]={0.f}; + float d[N_DST]; const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 + + device const half * dptr = (device const half *)cx; + d[0] = (float)dptr[0]; + for (int i = 1; i < N_DST; ++i) { + dptr += row_size/2; + d[i] = (float)dptr[0]; + } device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; @@ -7438,16 +7445,22 @@ void kernel_mul_mv_iq3_ks_f32_impl( for (int row = 0; row < N_DST; row++) { - device const block_iq3_k & xb = x[row*nb + ib]; + device const block_iq3_ks * x = (device const block_iq3_ks *)(cx + row_size*row + sizeof(half)); + device const block_iq3_ks & xb = x[ib]; device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir; device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir; - device const uint16_t * sc16 = (device const uint16_t *)xb.scales_l; + device const uint16_t * sc16 = (device const uint16_t *)xb.scales; - uint32_t scales32 = sc16[2*iq+0] | (sc16[2*iq+1] << 16); - scales32 = ((scales32 >> 4*is) & 0x0f0f0f0f) << 1; - thread const int8_t * s8 = (thread const int8_t *)&scales32; - uint16_t extra = (xb.extra >> (8*iq + is)) << 3; - uint16_t signs = xb.scales_h >> (8*iq + is); + uint8_t extra_s = (xb.extra & 0xff) >> 4*iq; + uint8_t extra_v = xb.extra >> (8 + 4*iq); + + uint32_t scales32 = sc16[0] | (sc16[1] << 16); + scales32 = (scales32 >> 4*iq) & 0x0f0f0f0f; + thread int8_t * s8 = (thread int8_t *)&scales32; + s8[0] += ((extra_s << 4) & 0x10) - 16; + s8[1] += ((extra_s << 3) & 0x10) - 16; + s8[2] += ((extra_s << 2) & 0x10) - 16; + s8[3] += ((extra_s << 1) & 0x10) - 16; vl[0] = ql16[0] | ql16[1] << 16; vl[1] = ql16[2] | ql16[3] << 16; @@ -7456,18 +7469,16 @@ void kernel_mul_mv_iq3_ks_f32_impl( float4 acc = {0.f}; for (int l = 0; l < 4; ++l) { - threadgroup const float * values = all_values + (extra & 8); - //constant float * values = kvalues_iq3k_f + (extra & 8); + threadgroup const float * values = all_values + ((extra_v & 1) << 3); aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404); aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404); for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; vl[0] >>= 2; vl[1] >>= 2; vh[0] >>= 1; vh[1] >>= 1; - extra >>= 2; + extra_v >>= 1; } - sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) + - acc[2] * (signs & 0x10 ? -s8[2] : s8[2]) + acc[3] * (signs & 0x40 ? -s8[3] : s8[3])); + sumf[row] += d[row] * (acc[0] * s8[0] + acc[1] * s8[1] + acc[2] * s8[2] + acc[3] * s8[3]); }