From 1cdb6993eecf140ee9642bd31580da0432173183 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 25 Sep 2024 13:08:55 +0300 Subject: [PATCH] Use fp32 for K*Q in Metal FA implementation (#62) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-metal.metal | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 6553f465..259fa609 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2904,11 +2904,11 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - half4 mq[D4]; + float4 mq[D4]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; - mq[i] = sq4[i]; + mq[i] = (float4)sq4[i]; } // pointer to the mask @@ -2934,11 +2934,11 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; - half4x4 mk; - mk[0] = pk4[i + 0*(nb11/8)]; - mk[1] = pk4[i + 1*(nb11/8)]; - mk[2] = pk4[i + 2*(nb11/8)]; - mk[3] = pk4[i + 3*(nb11/8)]; + float4x4 mk; + mk[0] = (float4)pk4[i + 0*(nb11/8)]; + mk[1] = (float4)pk4[i + 1*(nb11/8)]; + mk[2] = (float4)pk4[i + 2*(nb11/8)]; + mk[3] = (float4)pk4[i + 3*(nb11/8)]; mqk += (float4) (mq[i] * mk); } @@ -2960,6 +2960,7 @@ kernel void kernel_flash_attn_ext_vec_f16( ss4[cc] = mqk; } + } }