From 95419ed393371ce4179c00e58caaafd7f7530734 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 12 Oct 2025 14:23:05 +0300 Subject: [PATCH] With FA on, create mask as f16 directly --- src/llama-build-context.cpp | 12 ++++ src/llama.cpp | 122 +++++++++++++++++++++++++++++++----- 2 files changed, 117 insertions(+), 17 deletions(-) diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index d8533f4f..ef2a6621 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() { } ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) { + if (causal && flash_attn) { + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + cb(lctx.inp_KQ_mask, "KQ_mask", -1); + ggml_set_input(lctx.inp_KQ_mask); + return lctx.inp_KQ_mask; + } lctx.inp_KQ_mask = causal ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); @@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) { ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) { GGML_ASSERT(hparams.n_swa > 0); + if (causal && flash_attn) { + lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(lctx.inp_KQ_mask_swa); + return lctx.inp_KQ_mask_swa; + } lctx.inp_KQ_mask_swa = causal ? 
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) diff --git a/src/llama.cpp b/src/llama.cpp index 634663fa..5d39c9af 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2051,21 +2051,55 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = nullptr; float * data_swa = nullptr; + ggml_half * data_f16 = nullptr; + ggml_half * data_swa_f16 = nullptr; if (lctx.inp_KQ_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); - data = (float *) lctx.inp_KQ_mask->data; + if (cparams.flash_attn) { + data_f16 = (ggml_half *)lctx.inp_KQ_mask->data; + } else { + data = (float *) lctx.inp_KQ_mask->data; + } } if (lctx.inp_KQ_mask_swa) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); - data_swa = (float *) lctx.inp_KQ_mask_swa->data; + if (cparams.flash_attn) { + data_swa_f16 = (ggml_half *) lctx.inp_KQ_mask_swa->data; + } else { + data_swa = (float *) lctx.inp_KQ_mask_swa->data; + } } + auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) { + ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY); + ggml_half h_zero = ggml_fp32_to_fp16(0.f); + for (int i = first; i < last; ++i) { + ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? 
h_inf : h_zero; + if (data_f16) data_f16[j*n_kv + i] = h; + if (data_swa_f16) { + if (h != h_inf) { + if (hparams.n_attn_chunk) { + llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + h = h_inf; + } + } else { + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + h = h_inf; + } + } + } + data_swa_f16[j*n_kv + i] = h; + } + } + }; + if (n_kv >= 1024 && n_tokens >= 32) { int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); int npt = (n_kv + n_thread - 1)/n_thread; - auto compute = [&batch, &lctx, &hparams, n_tokens, n_kv, npt, data, data_swa] (int ith) { + auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) { int first = ith * npt; int last = std::min(int(n_kv), first + npt); if (last <= first) return; @@ -2073,6 +2107,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const llama_pos pos = batch.pos[j]; const llama_seq_id seq_id = batch.seq_id[j][0]; + if (!hparams.use_alibi && cparams.flash_attn) { + noalibi_f16(j, pos, seq_id, first, last); + continue; + } + for (int i = first; i < last; ++i) { float f; if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { @@ -2088,9 +2127,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (data) { data[j*n_kv + i] = f; } + if (data_f16) { + data_f16[j*n_kv + i] = ggml_fp32_to_fp16(f); + } // may need to cut off old tokens for sliding window - if (data_swa) { + if (data_swa || data_swa_f16) { if (f > -INFINITY) { if (hparams.n_attn_chunk) { llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; @@ -2103,7 +2145,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } - data_swa[j*n_kv + i] = f; + if (data_swa) { + data_swa[j*n_kv + i] = f; + } + 
if (data_swa_f16) { + data_swa_f16[j*n_kv + i] = ggml_fp32_to_fp16(f); + } } } } @@ -2120,6 +2167,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } + if (data_f16) { + ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY); + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_f16[i*n_kv + j] = h_inf; + } + } + } if (data_swa) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { @@ -2128,6 +2183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } + if (data_swa_f16) { + ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY); + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa_f16[i*n_kv + j] = h_inf; + } + } + } } else { @@ -2135,14 +2198,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { - auto data_h = data ? data + h*(n_kv*n_tokens) : nullptr; - auto data_swa_h = data_swa ? data_swa + h*(n_kv*n_tokens) : nullptr; for (int j = 0; j < n_tokens; ++j) { - auto data_j = data_h ? data_h + j*n_kv : nullptr; - auto data_swa_j = data_swa_h ? 
data_swa_h + j*n_kv : nullptr; const llama_pos pos = batch.pos[j]; const llama_seq_id seq_id = batch.seq_id[j][0]; + if (!hparams.use_alibi && cparams.flash_attn) { + noalibi_f16(j, pos, seq_id, 0, n_kv); + continue; + } + for (int i = 0; i < n_kv; ++i) { float f; if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { @@ -2155,12 +2219,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (data_j) { - data_j[i] = f; + if (data) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = f; + } + if (data_f16) { + data_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f); } // may need to cut off old tokens for sliding window - if (data_swa_j) { + if (data_swa || data_swa_f16) { if (hparams.n_attn_chunk) { llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { @@ -2171,23 +2238,44 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { f = -INFINITY; } } - data_swa_j[i] = f; + if (data_swa) { + data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; + } + if (data_swa_f16) { + data_swa_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f); + } } } } - if (data_h) { + if (data) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { for (int j = 0; j < n_kv; ++j) { - data_h[i*n_kv + j] = -INFINITY; + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + if (data_f16) { + ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY); + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_f16[h*(n_kv*n_tokens) + i*n_kv + j] = h_inf; } } } - if (data_swa_h) { + if (data_swa) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { for (int j = 0; j < n_kv; ++j) { - data_swa_h[i*n_kv + j] = -INFINITY; + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + if (data_swa_f16) { + ggml_half 
h_inf = ggml_fp32_to_fp16(-INFINITY); + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa_f16[h*(n_kv*n_tokens) + i*n_kv + j] = h_inf; } } }