From 9b02dd0405f14d197b7fc6288978f5b91b5e76b4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 12 Oct 2025 13:15:16 +0300 Subject: [PATCH] Parallelize mask We see non-negligible PP gains for long contexts. More importantly, the strange drop in performance observed for GPT-OSS for context >= 32k tokens is gone. --- src/llama.cpp | 212 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 196 insertions(+), 16 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 7bc823a0..634663fa 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1924,27 +1924,41 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // set input data // -#if IK_PRINT_TIMING - auto tim1 = ggml_time_us(); -#endif const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; if (batch.token) { +#if IK_PRINT_TIMING == 2 + auto tim1 = ggml_time_us(); +#endif const int64_t n_tokens = batch.n_tokens; ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); +#if IK_PRINT_TIMING == 2 + auto tim2 = ggml_time_us(); + printf("set_inputs(token): %d us\n", int(tim2-tim1)); +#endif } if (batch.embd) { +#if IK_PRINT_TIMING == 2 + auto tim1 = ggml_time_us(); +#endif const int64_t n_embd = hparams.n_embd; const int64_t n_tokens = batch.n_tokens; ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); +#if IK_PRINT_TIMING == 2 + auto tim2 = ggml_time_us(); + printf("set_inputs(embd): %d us\n", int(tim2-tim1)); +#endif } if (batch.pos && lctx.inp_pos) { +#if IK_PRINT_TIMING == 2 + auto tim1 = ggml_time_us(); +#endif const int64_t n_tokens = batch.n_tokens; const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; if (batch.token && n_pos_per_embd == 4) { @@ -1959,9 +1973,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } else { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos)); } +#if IK_PRINT_TIMING == 2 + auto tim2 = ggml_time_us(); + printf("set_inputs(pos): %d us\n", int(tim2-tim1)); +#endif } if (lctx.inp_pos && lctx.inp_scale) { +#if IK_PRINT_TIMING == 2 + auto tim1 = ggml_time_us(); +#endif int n_tokens = batch.n_tokens; GGML_ASSERT(ggml_nelements(lctx.inp_scale) >= n_tokens); if (int(lctx.scale_data.size()) < n_tokens) lctx.scale_data.resize(n_tokens); @@ -1970,9 +1991,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { lctx.scale_data[i] = std::log(std::floor((batch.pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f; } ggml_backend_tensor_set(lctx.inp_scale, lctx.scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(lctx.inp_scale)); +#if IK_PRINT_TIMING == 2 + auto tim2 = ggml_time_us(); + printf("set_inputs(scale): %d us\n", int(tim2-tim1)); +#endif } if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { +#if IK_PRINT_TIMING == 2 + auto tim1 = ggml_time_us(); +#endif GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -1998,6 +2026,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } else { GGML_ASSERT(lctx.n_outputs == 0); } +#if IK_PRINT_TIMING == 2 + auto tim2 = ggml_time_us(); + printf("set_inputs(outputs): %d us\n", int(tim2-tim1)); +#endif } GGML_ASSERT( @@ -2008,6 +2040,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ); if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) { +#if IK_PRINT_TIMING == 2 + auto tim1 = ggml_time_us(); +#endif // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn && !lctx.is_encoding) { const int64_t n_kv = kv_self.n; @@ -2027,11 +2062,84 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { data_swa = (float *) lctx.inp_KQ_mask_swa->data; } + if (n_kv >= 1024 && n_tokens >= 32) { + int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int npt = (n_kv + n_thread - 1)/n_thread; + auto compute = [&batch, &lctx, &hparams, n_tokens, n_kv, npt, data, data_swa] (int ith) { + int first = ith * npt; + int last = std::min(int(n_kv), first + npt); + if (last <= first) return; + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = first; i < last; ++i) { + float f; + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(lctx.kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + if (data_swa) { + if (f > -INFINITY) { + if (hparams.n_attn_chunk) { + llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + f = -INFINITY; + } + } else { + if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + } + } + data_swa[j*n_kv + i] = f; + } + } + } + }; + std::vector workers(n_thread-1); + int it = 0; + for (auto& w : workers) w = std::thread(compute, it++); + compute(it); + for (auto& w : workers) w.join(); + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[i*n_kv + j] = -INFINITY; + } + } + } + + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[i*n_kv + j] = -INFINITY; + } + } + } + } + else { + // For causal attention, use only the previous KV cells // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { + auto data_h = data ? data + h*(n_kv*n_tokens) : nullptr; + auto data_swa_h = data_swa ? data_swa + h*(n_kv*n_tokens) : nullptr; for (int j = 0; j < n_tokens; ++j) { + auto data_j = data_h ? data_h + j*n_kv : nullptr; + auto data_swa_j = data_swa_h ? data_swa_h + j*n_kv : nullptr; const llama_pos pos = batch.pos[j]; const llama_seq_id seq_id = batch.seq_id[j][0]; @@ -2047,12 +2155,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (data) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; + if (data_j) { + data_j[i] = f; } // may need to cut off old tokens for sliding window - if (data_swa) { + if (data_swa_j) { if (hparams.n_attn_chunk) { llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { @@ -2063,27 +2171,32 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { f = -INFINITY; } } - data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; + data_swa_j[i] = f; } } } - if (data) { + if (data_h) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + data_h[i*n_kv + j] = -INFINITY; } } } - if (data_swa) { + if (data_swa_h) { for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + data_swa_h[i*n_kv + j] = -INFINITY; } } } } + } +#if IK_PRINT_TIMING == 2 + auto tim2 = ggml_time_us(); + printf("set_inputs(mask1): %d us\n", int(tim2-tim1)); +#endif } else { // when using kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; @@ -2118,6 +2231,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } +#if IK_PRINT_TIMING == 2 + auto tim2 = ggml_time_us(); + printf("set_inputs(mask2): %d us\n", int(tim2-tim1)); +#endif } } @@ -2315,10 +2432,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } -#if IK_PRINT_TIMING - auto tim2 = ggml_time_us(); - printf("%s(...): %d us\n", __func__, int(tim2-tim1)); -#endif } // Make sure enough space is available for outputs. @@ -2434,6 +2547,9 @@ static int llama_decode_internal( LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); return -1; } +#if IK_PRINT_TIMING > 2 + printf("===== %s: %ld\n", __func__, ggml_time_us()); +#endif const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -2502,6 +2618,9 @@ static int llama_decode_internal( } for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { +#if IK_PRINT_TIMING + auto tim1 = ggml_time_us(); +#endif const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); llama_batch u_batch = { /* .n_tokens = */ (int32_t) n_tokens, @@ -2592,13 +2711,31 @@ static int llama_decode_internal( //kv_self.n = llama_kv_cache_cell_max(kv_self); } } +#if IK_PRINT_TIMING + auto tim2 = ggml_time_us(); + printf("prelude(...): %d us\n", int(tim2-tim1)); +#endif //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); +#if IK_PRINT_TIMING + tim1 = ggml_time_us(); +#endif ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); +#if IK_PRINT_TIMING + tim2 = ggml_time_us(); + printf("sched_reset(...): %d us\n", int(tim2-tim1)); +#endif +#if IK_PRINT_TIMING + tim1 = ggml_time_us(); +#endif ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false); +#if IK_PRINT_TIMING + tim2 = ggml_time_us(); + printf("build_graph(...): %d us\n", int(tim2-tim1)); +#endif // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; @@ -2624,11 +2761,33 @@ static int llama_decode_internal( } // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); +#if IK_PRINT_TIMING + tim1 = ggml_time_us(); +#endif ggml_backend_sched_alloc_graph(lctx.sched, gf); +#if IK_PRINT_TIMING + tim2 = ggml_time_us(); + printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1)); +#endif +#if IK_PRINT_TIMING == 1 + tim1 = ggml_time_us(); +#endif llama_set_inputs(lctx, u_batch); +#if IK_PRINT_TIMING == 1 + tim2 = ggml_time_us(); + printf("set_inputs(...): %d us\n", int(tim2-tim1)); +#endif +#if IK_PRINT_TIMING + tim1 = ggml_time_us(); +#endif llama_graph_compute(lctx, gf, n_threads); +#if IK_PRINT_TIMING + llama_synchronize(&lctx); + tim2 = ggml_time_us(); + printf("graph_compute(...): %d us\n", int(tim2-tim1)); +#endif // update the kv ring buffer { @@ -2647,6 +2806,9 @@ static int llama_decode_internal( // extract logits if (res) { +#if IK_PRINT_TIMING + tim1 = ggml_time_us(); +#endif ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(lctx.logits != nullptr); @@ -2659,10 +2821,17 @@ static int llama_decode_internal( GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size); ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float)); } +#if IK_PRINT_TIMING + tim2 = ggml_time_us(); + printf("get_result(...): %d us\n", int(tim2-tim1)); +#endif } // extract embeddings if (embd) { +#if IK_PRINT_TIMING + tim1 = ggml_time_us(); +#endif ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); GGML_ASSERT(backend_embd != nullptr); @@ -2702,6 +2871,10 @@ static int llama_decode_internal( GGML_ABORT("unknown pooling type"); } } +#if IK_PRINT_TIMING + tim2 = ggml_time_us(); + printf("get_embedding(...): %d us\n", int(tim2-tim1)); +#endif } n_outputs_prev += lctx.n_outputs; } @@ -2718,7 +2891,7 @@ static int llama_decode_internal( // queue defragmentation for next llama_kv_cache_update if (fragmentation > cparams.defrag_thold) { - //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation); + LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation); llama_kv_cache_defrag(kv_self); } @@ -2726,7 +2899,14 @@ static int llama_decode_internal( // Reset state for the next token before backend sync, to allow the CPU activities in the reset to // overlap with device computation. +#if IK_PRINT_TIMING + auto tim1 = ggml_time_us(); +#endif ggml_backend_sched_reset(lctx.sched); +#if IK_PRINT_TIMING + auto tim2 = ggml_time_us(); + printf("sched_reset(...): %d us\n", int(tim2-tim1)); +#endif return 0; }