diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 514ac935..83bd76f9 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -910,9 +910,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx if (op->src[0]->ne[0] == 256) { return false; } - float softcap; - memcpy(&softcap, ((const float *) op->op_params) + 2, sizeof(softcap)); - if (softcap != 0.0f) return false; return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -3004,9 +3001,14 @@ static enum ggml_status ggml_metal_graph_compute( float scale; float max_bias; + float softcap; memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); + if (softcap != 0.0f) { + scale /= softcap; + } const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -3080,7 +3082,8 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; + [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index b7c03356..f9c88a37 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2279,6 +2279,7 @@ typedef void (flash_attn_ext_f16_t)( constant float & max_bias, constant float & m0, constant float & m1, + constant float & softcap, constant uint32_t & n_head_log2, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], @@ -2317,6 +2318,7 @@ kernel void kernel_flash_attn_ext_f16( constant float & max_bias, constant float & m0, constant float & m1, + constant float & softcap, constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -2446,14 +2448,19 @@ kernel void kernel_flash_attn_ext_f16( const short tx = tiisg%4; const short ty = tiisg/4; + // mqk = mqk*scale + ss[8*cc + ty*TF + 2*tx + 0] *= scale; + ss[8*cc + ty*TF + 2*tx + 1] *= scale; + + if (softcap != 0.0f) { + ss[8*cc + ty*TF + 2*tx + 0] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); + ss[8*cc + ty*TF + 2*tx + 1] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); + } + if (mask != q) { // mqk = mqk*scale + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; - } else { - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; + ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; + ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; } } } @@ -2648,6 +2655,7 @@ kernel void kernel_flash_attn_ext_vec_f16( constant float & max_bias, constant float & m0, constant float & m1, + constant float & softcap, constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -2783,7 +2791,11 @@ kernel void kernel_flash_attn_ext_vec_f16( // mqk = mqk*scale + mask*slope if (tiisg == 0) { - mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); + mqk *= scale; + if (softcap != 0.0f) { + mqk = softcap*precise::tanh(mqk); + } + mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; ss4[cc] = mqk; }