From be2694eb68b99ca07ce579d069b1ad8ec9f227fe Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 2 Sep 2025 14:19:22 +0300 Subject: [PATCH] Add n_swa to FA parameters --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 23 +++++++++++++++++++---- ggml/src/ggml.c | 2 ++ src/llama.cpp | 24 ++++++++++++++++-------- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index aa85716e..8be70176 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1209,7 +1209,7 @@ static __device__ __forceinline__ int warp_reduce_all(int x) { } } -template +template __launch_bounds__(FATTN_KQ_STRIDE/2, 1) static __global__ void flash_attn_mask_to_KV_min_max( const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) { @@ -1250,6 +1250,13 @@ static __global__ void flash_attn_mask_to_KV_min_max( } } + if constexpr (!is_swa) { + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE}; + } + return; + } + if (threadIdx.x == 0) { KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE; } @@ -1316,6 +1323,9 @@ void launch_fattn_mma( GGML_ASSERT(Q->ne[3] == 1); + int n_swa; + memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int)); + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -1371,7 +1381,7 @@ void launch_fattn_mma( const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; - if (mask && (Q->ne[1] >= 1024 || K->ne[1] >= 1024)) { + if (mask && (Q->ne[1] >= 1024 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) { const int s31 = mask->nb[1] / sizeof(half2); const int s33 = mask->nb[3] / sizeof(half2); const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); @@ -1379,8 +1389,13 @@ void launch_fattn_mma( const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; KV_min_max.alloc(ne_KV_max); - flash_attn_mask_to_KV_min_max<<>> - ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + if (n_swa > 0) { + flash_attn_mask_to_KV_min_max<<>> + ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + } else { + flash_attn_mask_to_KV_min_max<<>> + ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c6912301..c4227b35 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -9008,6 +9008,8 @@ struct ggml_tensor * ggml_flash_attn_ext( float params[] = { scale, max_bias, softcap }; ggml_set_op_params(result, params, sizeof(params)); + ggml_set_op_params_i32(result, 4, 0); + result->op = GGML_OP_FLASH_ATTN_EXT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = q; diff --git a/src/llama.cpp b/src/llama.cpp index d793d006..49edbb8d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7985,7 +7985,7 @@ static struct ggml_tensor * llm_build_kqv( float kq_scale, const llm_build_cb & cb, int il, - ggml_tensor * sinks = nullptr) { + ggml_tensor * sinks = nullptr, int n_swa = 0) { const llama_model & model = lctx.model; const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; @@ -8033,6 +8033,9 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); ggml_flash_attn_ext_add_sinks(cur, sinks); + if (n_swa > 0) { + ((int32_t *)cur->op_params)[4] = n_swa; + } // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. @@ -8190,7 +8193,7 @@ static struct ggml_tensor * llm_build_kv( float kq_scale, const llm_build_cb & cb, int il, - ggml_tensor * sinks = nullptr) { + ggml_tensor * sinks = nullptr, int n_swa = 0) { const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; @@ -8205,7 +8208,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, - q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks); + q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa); cb(cur, "kqv_out", il); return cur; @@ -8766,7 +8769,8 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr, + this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0); } if (il == n_layer - 1) { @@ -12198,7 +12202,8 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr, + KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0); } cur = llm_build_norm(ctx0, cur, hparams, @@ -12335,7 +12340,8 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr, + KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0); } cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); @@ -14400,7 +14406,8 @@ struct llm_build_context { } cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, - KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr, + is_sliding ? hparams.n_swa : 0); } if (il == n_layer - 1) { @@ -15490,7 +15497,8 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks, + is_sliding ? hparams.n_swa : 0); cb(cur, "attn_out", il); }