Add n_swa to FA parameters

This commit is contained in:
Iwan Kawrakow
2025-09-02 14:19:22 +03:00
parent c2500dbb04
commit be2694eb68
3 changed files with 37 additions and 12 deletions

View File

@@ -1209,7 +1209,7 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
}
}
template <int ncols1>
template <int ncols1, bool is_swa>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_min_max(
const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
@@ -1250,6 +1250,13 @@ static __global__ void flash_attn_mask_to_KV_min_max(
}
}
if constexpr (!is_swa) {
if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE};
}
return;
}
if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE;
}
@@ -1316,6 +1323,9 @@ void launch_fattn_mma(
GGML_ASSERT(Q->ne[3] == 1);
int n_swa;
memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int));
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@@ -1371,7 +1381,7 @@ void launch_fattn_mma(
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
if (mask && (Q->ne[1] >= 1024 || K->ne[1] >= 1024)) {
if (mask && (Q->ne[1] >= 1024 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
@@ -1379,8 +1389,13 @@ void launch_fattn_mma(
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_min_max.alloc(ne_KV_max);
flash_attn_mask_to_KV_min_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
if (n_swa > 0) {
flash_attn_mask_to_KV_min_max<ncols1, true><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
} else {
flash_attn_mask_to_KV_min_max<ncols1, false><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
}
CUDA_CHECK(cudaGetLastError());
}

View File

@@ -9008,6 +9008,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
float params[] = { scale, max_bias, softcap };
ggml_set_op_params(result, params, sizeof(params));
ggml_set_op_params_i32(result, 4, 0);
result->op = GGML_OP_FLASH_ATTN_EXT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q;

View File

@@ -7985,7 +7985,7 @@ static struct ggml_tensor * llm_build_kqv(
float kq_scale,
const llm_build_cb & cb,
int il,
ggml_tensor * sinks = nullptr) {
ggml_tensor * sinks = nullptr, int n_swa = 0) {
const llama_model & model = lctx.model;
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
@@ -8033,6 +8033,9 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_add_sinks(cur, sinks);
if (n_swa > 0) {
((int32_t *)cur->op_params)[4] = n_swa;
}
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
@@ -8190,7 +8193,7 @@ static struct ggml_tensor * llm_build_kv(
float kq_scale,
const llm_build_cb & cb,
int il,
ggml_tensor * sinks = nullptr) {
ggml_tensor * sinks = nullptr, int n_swa = 0) {
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
@@ -8205,7 +8208,7 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa);
cb(cur, "kqv_out", il);
return cur;
@@ -8766,7 +8769,8 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0);
}
if (il == n_layer - 1) {
@@ -12198,7 +12202,8 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr,
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
}
cur = llm_build_norm(ctx0, cur, hparams,
@@ -12335,7 +12340,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr,
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
}
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
@@ -14400,7 +14406,8 @@ struct llm_build_context {
}
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr,
is_sliding ? hparams.n_swa : 0);
}
if (il == n_layer - 1) {
@@ -15490,7 +15497,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
is_sliding ? hparams.n_swa : 0);
cb(cur, "attn_out", il);
}