diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index ef98c32d..b792844d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6305,24 +6305,25 @@ void kernel_mul_mv_iq5_ks_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint row_size = 4 + nb*sizeof(block_iq4_ks); + const uint row_size = 4 + nb*sizeof(block_iq5_ks); const uint offset0 = (i12/r2)*ne01 + (i13/r3)*(ne01*ne02); device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size; device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1; const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; + const int ib64 = it/4; // 0...3 + const int il64 = it%4; // 0...3 - shared_values[tiisg] = kvalues_iq4k_f[tiisg]; + shared_values[2*tiisg+0] = kvalues_iq5k_f[2*tiisg+0]; + shared_values[2*tiisg+1] = kvalues_iq5k_f[2*tiisg+1]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; float2 sumf = 0.f; float d[2]; - device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + device const float * yb = y + ix * QK_K + ib64 * 64 + il64 * 8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; @@ -6331,43 +6332,46 @@ void kernel_mul_mv_iq5_ks_f32_impl( device const float * dptr = (device const float *)cx; d[0] = *dptr; - device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix; + device const block_iq5_ks * x = (device const block_iq5_ks *)(dptr + 1) + ix; dptr += row_size/4; d[1] = *dptr; for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + yl[0] = y4[0]; yl[1] = y4[8]; yl[2] = y4[1]; yl[3] = y4[9]; device const uint8_t * scales = x->scales; for (int row = 0; row < 2; ++row) { - threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4); - const float ls = ((scales[ib] & 254) - 127); + threadgroup const float * values1 = shared_values + ((scales[2*ib64+0] & 1) << 5); + threadgroup const float * values2 = shared_values + ((scales[2*ib64+1] & 1) << 5); + const float ls1 = ((scales[2*ib64+0] & 254) - 127); + const float ls2 = ((scales[2*ib64+1] & 254) - 127); - device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il; + device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 8*ib64 + 2*il64; + device const uint32_t * qh = (device const uint32_t *)scales + QK_K/128 + QK_K/8 + 2*il64; float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[0] & 0x0f0f0f0f; - aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]}; - qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]}; + uint32_t h = qh[0] >> 2*ib64; + aux32[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); + aux32[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; - aux32[0] = q4[1] & 0x0f0f0f0f; - aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]}; - qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]}; + h = qh[1] >> 2*ib64; + aux32[0] = ((q4[1] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010); + aux32[1] = ((q4[1] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010); + qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; + qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; - acc1 += acc2; - - sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + sumf[row] += ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3]); scales += row_size; @@ -6379,7 +6383,7 @@ void kernel_mul_mv_iq5_ks_f32_impl( sumf = simd_sum(sumf); if (tiisg < 2) { - dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg]; + dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg] * d[tiisg]; } }