diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 5f2e23ae..577e399d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -5585,14 +5585,14 @@ void kernel_mul_mv_iq1_bn_f32_impl( constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + const int ib = ix % (QK_IQ1BN / 32); + const int i16 = 2*ib + ir; + for (int ib32 = ix; ib32 < nb32; ib32 += 16) { for (int j = 0; j < 16; ++j) yl[j] = y4[j]; const int ibl = ib32 / (QK_IQ1BN / 32); - const int ib = ib32 % (QK_IQ1BN / 32); - const int i16 = 2*ib + ir; - device const block_iq1_bn * xr = x + ibl; device const uint8_t * ql = xr->ql + 3*i16; device const uint8_t * extra = (device const uint8_t *)&xr->extra; @@ -5602,8 +5602,6 @@ void kernel_mul_mv_iq1_bn_f32_impl( float acc = 0; int i = 0; for (int k = 0; k < 3; ++k) { - //constant int8_t * vs = iq1bn_values + 5*ql[k]; - //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j]; uint8_t q = ql[k]; for (int j = 0; j < 5; ++j) { uint8_t v = k_mult[j]*q; @@ -5611,8 +5609,6 @@ void kernel_mul_mv_iq1_bn_f32_impl( acc += yl[i++] * values[v]; } } - //constant int8_t * vs = iq1bn_values + 5*extra[0]; - //acc += yl[15] * vs[i16]; uint8_t v = k_mult[i16]*extra[0]; v = 3*v >> 8; //(v + (v >> 1)) >> 7; acc += yl[15] * values[v]; @@ -5635,7 +5631,6 @@ void kernel_mul_mv_iq1_bn_f32_impl( } } -// TODO: unify with kernel_mul_mv_iq1_bn_f32_impl void kernel_mul_mv_iq1_tn_f32_impl( device const void * src0, device const float * src1, @@ -5661,7 +5656,7 @@ void kernel_mul_mv_iq1_tn_f32_impl( // Why are we not passing in src0->nb[0]? // But because we are not, we need to use this hack - const uint row_size = 2+sizeof(block_iq1_tn)*(ne00/QK_K); + const uint row_size = 2 + sizeof(block_iq1_bn)*(ne00/QK_IQ1BN); const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; @@ -5769,9 +5764,9 @@ void kernel_mul_mv_iq2_bn_f32_impl( device const uint8_t * cx = (device const uint8_t *) src0 + first_row*row_size + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; - float sumf[N_DST]={0.f}; - float scale[N_DST]; + float4 yl[4]; + float sumf[N_DST]={0.f}; + float scale[N_DST]; for (int row = 0; row < N_DST; ++row) { scale[row] = *((device const float *)(cx + row*row_size)); @@ -5780,29 +5775,26 @@ void kernel_mul_mv_iq2_bn_f32_impl( const int ix = tiisg/4; // 0...7 const int ir = tiisg%4; // 0...3 - device const float * y4 = y + QK_IQ1BN * ix + 4 * ir; + device const float4 * y4 = (device const float4 *)(y + QK_IQ1BN * ix + 4 * ir); device const uint8_t * qs0 = cx + sizeof(float) + (QK_IQ1BN/4)*ix + 4*ir; for (int ib = ix; ib < nb; ib += 8) { - float sumy = 0.f; - for (int i = 0; i < 4; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0]; - yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4]; - yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8]; - yl[i+12] = y4[i+48]; sumy += yl[i+12]; - } + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[8]; yl[3] = y4[12]; + float4 tmp = yl[0] + yl[1] + yl[2] + yl[3]; + const float sumy = tmp[0] + tmp[1] + tmp[2] + tmp[3]; device const uint8_t * qs = qs0; for (int row = 0; row < N_DST; row++) { float4 acc = {0.f}; + for (int j = 0; j < 4; ++j) { - acc[0] += yl[j+ 0] * (qs[j] & 0x03); - acc[1] += yl[j+ 4] * (qs[j] & 0x0c); - acc[2] += yl[j+ 8] * (qs[j] & 0x30); - acc[3] += yl[j+12] * (qs[j] & 0xc0); + acc[0] += yl[0][j] * (qs[j] & 0x03); + acc[1] += yl[1][j] * (qs[j] & 0x0c); + acc[2] += yl[2][j] * (qs[j] & 0x30); + acc[3] += yl[3][j] * (qs[j] & 0xc0); } sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy; @@ -5810,7 +5802,7 @@ void kernel_mul_mv_iq2_bn_f32_impl( qs += row_size; } - y4 += QK_IQ1BN * 8; + y4 += QK_IQ1BN * 2; qs0 += QK_IQ1BN * 2; }