mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
Add n_swa to FA parameters
This commit is contained in:
@@ -1209,7 +1209,7 @@ static __device__ __forceinline__ int warp_reduce_all(int x) {
|
||||
}
|
||||
}
|
||||
|
||||
template <int ncols1>
|
||||
template <int ncols1, bool is_swa>
|
||||
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
|
||||
static __global__ void flash_attn_mask_to_KV_min_max(
|
||||
const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
|
||||
@@ -1250,6 +1250,13 @@ static __global__ void flash_attn_mask_to_KV_min_max(
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!is_swa) {
|
||||
if (threadIdx.x == 0) {
|
||||
KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE};
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE;
|
||||
}
|
||||
@@ -1316,6 +1323,9 @@ void launch_fattn_mma(
|
||||
|
||||
GGML_ASSERT(Q->ne[3] == 1);
|
||||
|
||||
int n_swa;
|
||||
memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int));
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
@@ -1371,7 +1381,7 @@ void launch_fattn_mma(
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
if (mask && (Q->ne[1] >= 1024 || K->ne[1] >= 1024)) {
|
||||
if (mask && (Q->ne[1] >= 1024 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
|
||||
@@ -1379,8 +1389,13 @@ void launch_fattn_mma(
|
||||
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
|
||||
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
|
||||
KV_min_max.alloc(ne_KV_max);
|
||||
flash_attn_mask_to_KV_min_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
||||
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
|
||||
if (n_swa > 0) {
|
||||
flash_attn_mask_to_KV_min_max<ncols1, true><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
||||
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
|
||||
} else {
|
||||
flash_attn_mask_to_KV_min_max<ncols1, false><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
||||
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
@@ -9008,6 +9008,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
float params[] = { scale, max_bias, softcap };
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
ggml_set_op_params_i32(result, 4, 0);
|
||||
|
||||
result->op = GGML_OP_FLASH_ATTN_EXT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = q;
|
||||
|
||||
@@ -7985,7 +7985,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb,
|
||||
int il,
|
||||
ggml_tensor * sinks = nullptr) {
|
||||
ggml_tensor * sinks = nullptr, int n_swa = 0) {
|
||||
const llama_model & model = lctx.model;
|
||||
const llama_hparams & hparams = lctx.model.hparams;
|
||||
const llama_cparams & cparams = lctx.cparams;
|
||||
@@ -8033,6 +8033,9 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
||||
if (n_swa > 0) {
|
||||
((int32_t *)cur->op_params)[4] = n_swa;
|
||||
}
|
||||
|
||||
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
|
||||
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
|
||||
@@ -8190,7 +8193,7 @@ static struct ggml_tensor * llm_build_kv(
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb,
|
||||
int il,
|
||||
ggml_tensor * sinks = nullptr) {
|
||||
ggml_tensor * sinks = nullptr, int n_swa = 0) {
|
||||
const llama_hparams & hparams = lctx.model.hparams;
|
||||
const llama_cparams & cparams = lctx.cparams;
|
||||
|
||||
@@ -8205,7 +8208,7 @@ static struct ggml_tensor * llm_build_kv(
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
|
||||
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
|
||||
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
return cur;
|
||||
@@ -8766,7 +8769,8 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
||||
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
|
||||
this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
@@ -12198,7 +12202,8 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr,
|
||||
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
@@ -12335,7 +12340,8 @@ struct llm_build_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr,
|
||||
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
@@ -14400,7 +14406,8 @@ struct llm_build_context {
|
||||
}
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
|
||||
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
|
||||
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr,
|
||||
is_sliding ? hparams.n_swa : 0);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
@@ -15490,7 +15497,8 @@ struct llm_build_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
|
||||
is_sliding ? hparams.n_swa : 0);
|
||||
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user