Flash attention with softcap: Metal

This commit is contained in:
Iwan Kawrakow
2024-08-26 18:34:43 +02:00
parent 1ad3b25132
commit e4f200098b
2 changed files with 26 additions and 11 deletions

View File

@@ -910,9 +910,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
if (op->src[0]->ne[0] == 256) {
return false;
}
float softcap;
memcpy(&softcap, ((const float *) op->op_params) + 2, sizeof(softcap));
if (softcap != 0.0f) return false;
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
@@ -3004,9 +3001,14 @@ static enum ggml_status ggml_metal_graph_compute(
float scale;
float max_bias;
float softcap;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap));
if (softcap != 0.0f) {
scale /= softcap;
}
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -3080,7 +3082,8 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
[encoder setBytes:&softcap length:sizeof(softcap) atIndex:27];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28];
if (!use_vec_kernel) {
// half8x8 kernel

View File

@@ -2279,6 +2279,7 @@ typedef void (flash_attn_ext_f16_t)(
constant float & max_bias,
constant float & m0,
constant float & m1,
constant float & softcap,
constant uint32_t & n_head_log2,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -2317,6 +2318,7 @@ kernel void kernel_flash_attn_ext_f16(
constant float & max_bias,
constant float & m0,
constant float & m1,
constant float & softcap,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -2446,14 +2448,19 @@ kernel void kernel_flash_attn_ext_f16(
const short tx = tiisg%4;
const short ty = tiisg/4;
// mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
if (softcap != 0.0f) {
ss[8*cc + ty*TF + 2*tx + 0] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
ss[8*cc + ty*TF + 2*tx + 1] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
}
if (mask != q) {
// mqk = mqk*scale + mask*slope
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
} else {
// mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
}
}
}
@@ -2648,6 +2655,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
constant float & max_bias,
constant float & m0,
constant float & m1,
constant float & softcap,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -2783,7 +2791,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
mqk *= scale;
if (softcap != 0.0f) {
mqk = softcap*precise::tanh(mqk);
}
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
ss4[cc] = mqk;
}