From 8e80d15930dd17e73e42d01fa53c02b3fbced0d8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 16 Sep 2024 17:32:48 +0200 Subject: [PATCH] Faster BF16 Metal dot product --- ggml/src/ggml-metal.metal | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 156466b5..6553f465 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1701,7 +1701,7 @@ void kernel_mul_mv_bf16_impl( device const uint16_t * x = (device const uint16_t *) (src0 + offset0); - typedef union { uint32_t u; float f; } aux_t; + typedef union { uint32_t u[4]; float f[4]; } aux_t; aux_t aux; for (int row = 0; row < N_MV_T_T; ++row) { @@ -1713,9 +1713,12 @@ void kernel_mul_mv_bf16_impl( device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - aux.u = x[i] << 16; - sumf += aux.f * (float) y[i]; + for (int i = tiisg; i < ne00/4; i += 32) { + aux.u[0] = x[4*i+0] << 16; + aux.u[1] = x[4*i+1] << 16; + aux.u[2] = x[4*i+2] << 16; + aux.u[3] = x[4*i+3] << 16; + sumf += aux.f[0] * (float)y[4*i+0] + aux.f[1] * (float)y[4*i+1] + aux.f[2] * (float)y[4*i+2] + aux.f[3] * (float)y[4*i+3]; } float all_sum = simd_sum(sumf);