diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 156466b5..6553f465 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1701,7 +1701,7 @@ void kernel_mul_mv_bf16_impl( device const uint16_t * x = (device const uint16_t *) (src0 + offset0); - typedef union { uint32_t u; float f; } aux_t; + typedef union { uint32_t u[4]; float f[4]; } aux_t; aux_t aux; for (int row = 0; row < N_MV_T_T; ++row) { @@ -1713,9 +1713,12 @@ void kernel_mul_mv_bf16_impl( device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - aux.u = x[i] << 16; - sumf += aux.f * (float) y[i]; + for (int i = tiisg; i < ne00/4; i += 32) { + aux.u[0] = x[4*i+0] << 16; + aux.u[1] = x[4*i+1] << 16; + aux.u[2] = x[4*i+2] << 16; + aux.u[3] = x[4*i+3] << 16; + sumf += aux.f[0] * (float)y[4*i+0] + aux.f[1] * (float)y[4*i+1] + aux.f[2] * (float)y[4*i+2] + aux.f[3] * (float)y[4*i+3]; } float all_sum = simd_sum(sumf);