diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 33a4c70a..92890537 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -7265,7 +7265,7 @@ void kernel_mul_mv_iq2_kl_f32_impl( device const char * cx0 = (device const char *) src0 + (first_row + offset0)*row_size; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[32]; + float2 yl[16]; float sumf[N_DST]={0.f}; float drow[N_DST]; @@ -7298,9 +7298,10 @@ void kernel_mul_mv_iq2_kl_f32_impl( for (int ib = ix; ib < nb; ib += 4) { - for (int i = 0; i < 16; ++i) { - yl[i+ 0] = y4[i+ 0]; - yl[i+16] = y4[i+32]; + device const float2 * y2 = (device const float2 *)y4; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y2[i+ 0]; + yl[i+8] = y2[i+16]; } device const char * cx = cx0; @@ -7316,7 +7317,7 @@ void kernel_mul_mv_iq2_kl_f32_impl( device const uint16_t * ql = (device const uint16_t *)x->qs + 8*iq + 4*ir; device const uint16_t * qh = (device const uint16_t *)x->qh + 4*ir; - float2 acc = {0.f}; + float2 acc[2] = {0.f}; for (int l = 0; l < 4; ++l) { uint16_t h = qh[l] >> 2*iq; aux16[0] = ((ql[l] >> 0) & 0x0f0f) | ((h & 0x0101) << 4); @@ -7324,12 +7325,12 @@ void kernel_mul_mv_iq2_kl_f32_impl( for (int j = 0; j < 2; ++j) { threadgroup const float2 & val1 = all_values[aux8[j+0]]; threadgroup const float2 & val2 = all_values[aux8[j+2]]; - acc[0] += yl[4*l+2*j+ 0] * val1[0] + yl[4*l+2*j+ 1] * val1[1]; - acc[1] += yl[4*l+2*j+16] * val2[0] + yl[4*l+2*j+17] * val2[1]; + acc[0] += yl[2*l+j+0] * val1; + acc[1] += yl[2*l+j+8] * val2; } } - sumf[row] += drow[row] * (acc[0] * ls1 + acc[1] * ls2); + sumf[row] += drow[row] * ((acc[0][0] + acc[0][1]) * ls1 + (acc[1][0] + acc[1][1]) * ls2); cx += row_size;