diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index e782c1ad..562e6f58 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2235,7 +2235,7 @@ extern "C" {
             int   min_entries,
             float thresh);
 
-#define GGML_KQ_MASK_PAD 64
+#define GGML_KQ_MASK_PAD 16
 
     // q:    [n_embd, n_batch,     n_head,    1]
     // k:    [n_embd, n_kv,        n_head_kv, 1]
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index d8533f4f..ef2a6621 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() {
 }
 
 ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        ggml_set_input(lctx.inp_KQ_mask);
+        return lctx.inp_KQ_mask;
+    }
     lctx.inp_KQ_mask = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
         : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
@@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
 ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) {
     GGML_ASSERT(hparams.n_swa > 0);
 
+    if (causal && flash_attn) {
+        lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+        return lctx.inp_KQ_mask_swa;
+    }
     lctx.inp_KQ_mask_swa = causal
         ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
diff --git a/src/llama.cpp b/src/llama.cpp
index 7bc823a0..53ebb075 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1924,27 +1924,41 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     // set input data
     //
 
-#if IK_PRINT_TIMING
-    auto tim1 = ggml_time_us();
-#endif
     const auto & hparams = lctx.model.hparams;
     const auto & cparams = lctx.cparams;
     const auto & kv_self = lctx.kv_self;
 
     if (batch.token) {
+#if IK_PRINT_TIMING == 2
+        auto tim1 = ggml_time_us();
+#endif
         const int64_t n_tokens = batch.n_tokens;
 
         ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
+#if IK_PRINT_TIMING == 2
+        auto tim2 = ggml_time_us();
+        printf("set_inputs(token): %d us\n", int(tim2-tim1));
+#endif
     }
 
     if (batch.embd) {
+#if IK_PRINT_TIMING == 2
+        auto tim1 = ggml_time_us();
+#endif
         const int64_t n_embd   = hparams.n_embd;
         const int64_t n_tokens = batch.n_tokens;
 
         ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+#if IK_PRINT_TIMING == 2
+        auto tim2 = ggml_time_us();
+        printf("set_inputs(embd): %d us\n", int(tim2-tim1));
+#endif
    }
 
     if (batch.pos && lctx.inp_pos) {
+#if IK_PRINT_TIMING == 2
+        auto tim1 = ggml_time_us();
+#endif
         const int64_t n_tokens = batch.n_tokens;
         const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
         if (batch.token && n_pos_per_embd == 4) {
@@ -1959,9 +1973,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         } else {
             ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos));
         }
+#if IK_PRINT_TIMING == 2
+        auto tim2 = ggml_time_us();
+        printf("set_inputs(pos): %d us\n", int(tim2-tim1));
+#endif
     }
 
     if (lctx.inp_pos && lctx.inp_scale) {
+#if IK_PRINT_TIMING == 2
+        auto tim1 = ggml_time_us();
+#endif
         int n_tokens = batch.n_tokens;
         GGML_ASSERT(ggml_nelements(lctx.inp_scale) >= n_tokens);
         if (int(lctx.scale_data.size()) < n_tokens) lctx.scale_data.resize(n_tokens);
@@ -1970,9 +1991,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             lctx.scale_data[i] = std::log(std::floor((batch.pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f;
         }
         ggml_backend_tensor_set(lctx.inp_scale, lctx.scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(lctx.inp_scale));
+#if IK_PRINT_TIMING == 2
+        auto tim2 = ggml_time_us();
+        printf("set_inputs(scale): %d us\n", int(tim2-tim1));
+#endif
     }
 
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+#if IK_PRINT_TIMING == 2
+        auto tim1 = ggml_time_us();
+#endif
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
 
         const int64_t n_tokens = batch.n_tokens;
@@ -1998,6 +2026,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         } else {
             GGML_ASSERT(lctx.n_outputs == 0);
         }
+#if IK_PRINT_TIMING == 2
+        auto tim2 = ggml_time_us();
+        printf("set_inputs(outputs): %d us\n", int(tim2-tim1));
+#endif
     }
 
     GGML_ASSERT(
@@ -2008,6 +2040,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     );
 
     if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
+#if IK_PRINT_TIMING == 2
+        auto tim1 = ggml_time_us();
+#endif
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv = kv_self.n;
@@ -2016,17 +2051,135 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
             float * data     = nullptr;
             float * data_swa = nullptr;
+            ggml_half * data_f16     = nullptr;
+            ggml_half * data_swa_f16 = nullptr;
 
             if (lctx.inp_KQ_mask) {
                 GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-                data = (float *) lctx.inp_KQ_mask->data;
+                if (cparams.flash_attn) {
+                    data_f16 = (ggml_half *)lctx.inp_KQ_mask->data;
+                } else {
+                    data = (float *) lctx.inp_KQ_mask->data;
+                }
             }
 
             if (lctx.inp_KQ_mask_swa) {
                 GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
-                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+                if (cparams.flash_attn) {
+                    data_swa_f16 = (ggml_half *) lctx.inp_KQ_mask_swa->data;
+                } else {
+                    data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+                }
             }
+            auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
+                ggml_half h_inf  = ggml_fp32_to_fp16(-INFINITY);
+                ggml_half h_zero = ggml_fp32_to_fp16(0.f);
+                for (int i = first; i < last; ++i) {
+                    ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? h_inf : h_zero;
+                    if (data_f16) data_f16[j*n_kv + i] = h;
+                    if (data_swa_f16) {
+                        if (h != h_inf) {
+                            if (hparams.n_attn_chunk) {
+                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                                if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                                    h = h_inf;
+                                }
+                            } else {
+                                if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                    h = h_inf;
+                                }
+                            }
+                        }
+                        data_swa_f16[j*n_kv + i] = h;
+                    }
+                }
+            };
+
+            if (n_kv >= 1024 && n_tokens >= 32) {
+                int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
+                int npt = (n_kv + n_thread - 1)/n_thread;
+                auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
+                    int first = ith * npt;
+                    int last  = std::min(int(n_kv), first + npt);
+                    if (last <= first) return;
+                    for (int j = 0; j < n_tokens; ++j) {
+                        const llama_pos    pos    = batch.pos[j];
+                        const llama_seq_id seq_id = batch.seq_id[j][0];
+
+                        if (!hparams.use_alibi && cparams.flash_attn) {
+                            noalibi_f16(j, pos, seq_id, first, last);
+                            continue;
+                        }
+
+                        for (int i = first; i < last; ++i) {
+                            float f;
+                            if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                                f = -INFINITY;
+                            } else {
+                                if (hparams.use_alibi) {
+                                    f = -std::abs(lctx.kv_self.cells[i].pos - pos);
+                                } else {
+                                    f = 0.0f;
+                                }
+                            }
+
+                            if (data) {
+                                data[j*n_kv + i] = f;
+                            }
+                            if (data_f16) {
+                                data_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
+                            }
+
+                            // may need to cut off old tokens for sliding window
+                            if (data_swa || data_swa_f16) {
+                                if (f > -INFINITY) {
+                                    if (hparams.n_attn_chunk) {
+                                        llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                                        if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                                            f = -INFINITY;
+                                        }
+                                    } else {
+                                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                            f = -INFINITY;
+                                        }
+                                    }
+                                }
+                                if (data_swa) {
+                                    data_swa[j*n_kv + i] = f;
+                                }
+                                if (data_swa_f16) {
+                                    data_swa_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
+                                }
+                            }
+                        }
+                    }
+                };
+                std::vector<std::thread> workers(n_thread-1);
+                int it = 0;
+                for (auto& w : workers) w = std::thread(compute, it++);
+                compute(it);
+                for (auto& w : workers) w.join();
+                int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
+                if (n_tokens_padded > n_tokens) {
+                    if (data) {
+                        std::fill(data + int64_t(n_tokens)*n_kv, data + n_tokens_padded*n_kv, -INFINITY);
+                    }
+                    if (data_f16) {
+                        ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
+                        std::fill(data_f16 + int64_t(n_tokens)*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf);
+                    }
+                    if (data_swa) {
+                        std::fill(data_swa + int64_t(n_tokens)*n_kv, data_swa + n_tokens_padded*n_kv, -INFINITY);
+                    }
+                    if (data_swa_f16) {
+                        ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
+                        std::fill(data_swa_f16 + int64_t(n_tokens)*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf);
+                    }
+                }
+            }
+            else {
+
                 // For causal attention, use only the previous KV cells
                 // of the correct sequence for each token of the batch.
                 // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
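
Note (not part of the patch): for large batches the hunk above fills the mask rows in parallel, splitting the n_kv columns across threads with a ceiling division, npt = (n_kv + n_thread - 1)/n_thread; the last slice can be short or empty, which is what the `if (last <= first) return;` guard handles. Below is a standalone sketch of the same work split applied to a dummy buffer; it assumes nothing beyond the C++ standard library, and the names row/compute are illustrative only.

// standalone sketch, not from the repository: the same ceiling-division work split
#include <algorithm>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int n_kv = 4096;                  // columns to fill (stand-in for one mask row)
    std::vector<float> row(n_kv, 0.0f);

    const int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
    const int npt = (n_kv + n_thread - 1)/n_thread;   // columns per thread, rounded up

    auto compute = [&](int ith) {
        const int first = ith*npt;
        const int last  = std::min(n_kv, first + npt);
        if (last <= first) return;          // trailing threads may receive an empty slice
        for (int i = first; i < last; ++i) row[i] = float(i);
    };

    std::vector<std::thread> workers(n_thread - 1);
    int it = 0;
    for (auto & w : workers) w = std::thread(compute, it++);
    compute(it);                            // the calling thread takes the last slice
    for (auto & w : workers) w.join();

    printf("filled %d columns with %d thread(s)\n", n_kv, n_thread);
    return 0;
}
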
@@ -2035,6 +2188,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     const llama_pos    pos    = batch.pos[j];
                     const llama_seq_id seq_id = batch.seq_id[j][0];
 
+                    if (!hparams.use_alibi && cparams.flash_attn) {
+                        noalibi_f16(j, pos, seq_id, 0, n_kv);
+                        continue;
+                    }
+
                     for (int i = 0; i < n_kv; ++i) {
                         float f;
                         if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
@@ -2050,9 +2208,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         if (data) {
                             data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                         }
+                        if (data_f16) {
+                            data_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
+                        }
 
                         // may need to cut off old tokens for sliding window
-                        if (data_swa) {
+                        if (data_swa || data_swa_f16) {
                             if (hparams.n_attn_chunk) {
                                 llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
                                 if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
@@ -2063,27 +2224,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                                     f = -INFINITY;
                                 }
                             }
-                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                            if (data_swa) {
+                                data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                            }
+                            if (data_swa_f16) {
+                                data_swa_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
+                            }
                         }
                     }
                 }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+                int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
+                if (n_tokens_padded > n_tokens) {
+                    if (data) {
+                        std::fill(data + int64_t(n_tokens)*n_kv, data + n_tokens_padded*n_kv, -INFINITY);
                     }
-                }
-
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+                    if (data_f16) {
+                        ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
+                        std::fill(data_f16 + int64_t(n_tokens)*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf);
+                    }
+                    if (data_swa) {
+                        std::fill(data_swa + int64_t(n_tokens)*n_kv, data_swa + n_tokens_padded*n_kv, -INFINITY);
+                    }
+                    if (data_swa_f16) {
+                        ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
+                        std::fill(data_swa_f16 + int64_t(n_tokens)*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf);
                     }
                 }
             }
+            }
+#if IK_PRINT_TIMING == 2
+            auto tim2 = ggml_time_us();
+            printf("set_inputs(mask1): %d us\n", int(tim2-tim1));
+#endif
         } else {
             // when using kv cache, the mask needs to match the kv cache size
             const int64_t n_tokens = batch.n_tokens;
@@ -2118,6 +2291,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     }
                 }
             }
+#if IK_PRINT_TIMING == 2
+            auto tim2 = ggml_time_us();
+            printf("set_inputs(mask2): %d us\n", int(tim2-tim1));
+#endif
         }
     }
 
@@ -2315,10 +2492,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     }
-#if IK_PRINT_TIMING
-    auto tim2 = ggml_time_us();
-    printf("%s(...): %d us\n", __func__, int(tim2-tim1));
-#endif
 }
 
 // Make sure enough space is available for outputs.
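
Note (not part of the patch): GGML_PAD rounds the number of query rows up to a multiple of GGML_KQ_MASK_PAD (reduced to 16 by this patch), and the hunks above replace the old per-element padding loops with one std::fill per mask buffer. The sketch below illustrates that layout with a plain float buffer; the local GGML_PAD/GGML_KQ_MASK_PAD definitions are stand-ins so the example compiles on its own.

// standalone sketch, not from the repository: padded mask rows filled with a single std::fill
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

#define GGML_KQ_MASK_PAD 16
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))   // round x up to a multiple of n (power of two)

int main() {
    const int64_t n_kv = 8, n_tokens = 3;
    const int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);   // 16

    // one row per (padded) query token, n_kv entries per row
    std::vector<float> mask(n_tokens_padded*n_kv, 0.0f);

    // rows past the real tokens are fully masked, mirroring the std::fill in the patch
    std::fill(mask.begin() + n_tokens*n_kv, mask.end(), -INFINITY);

    printf("%lld real rows + %lld padding rows of %lld entries\n",
           (long long) n_tokens, (long long) (n_tokens_padded - n_tokens), (long long) n_kv);
    return 0;
}
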
@@ -2434,6 +2607,9 @@ static int llama_decode_internal(
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
     }
+#if IK_PRINT_TIMING > 2
+    printf("===== %s: %ld\n", __func__, ggml_time_us());
+#endif
 
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
@@ -2502,6 +2678,9 @@ static int llama_decode_internal(
     }
 
     for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
+#if IK_PRINT_TIMING
+        auto tim1 = ggml_time_us();
+#endif
         const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
            /* .n_tokens   = */ (int32_t) n_tokens,
@@ -2592,13 +2771,31 @@ static int llama_decode_internal(
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
+#if IK_PRINT_TIMING
+        auto tim2 = ggml_time_us();
+        printf("prelude(...): %d us\n", int(tim2-tim1));
+#endif
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
+#if IK_PRINT_TIMING
+        tim1 = ggml_time_us();
+#endif
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+#if IK_PRINT_TIMING
+        tim2 = ggml_time_us();
+        printf("sched_reset(...): %d us\n", int(tim2-tim1));
+#endif
 
+#if IK_PRINT_TIMING
+        tim1 = ggml_time_us();
+#endif
         ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
+#if IK_PRINT_TIMING
+        tim2 = ggml_time_us();
+        printf("build_graph(...): %d us\n", int(tim2-tim1));
+#endif
 
         // the output is always the last tensor in the graph
         struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@@ -2624,11 +2821,33 @@ static int llama_decode_internal(
         }
 
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+#if IK_PRINT_TIMING
+        tim1 = ggml_time_us();
+#endif
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
+#if IK_PRINT_TIMING
+        tim2 = ggml_time_us();
+        printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
+#endif
 
+#if IK_PRINT_TIMING == 1
+        tim1 = ggml_time_us();
+#endif
         llama_set_inputs(lctx, u_batch);
+#if IK_PRINT_TIMING == 1
+        tim2 = ggml_time_us();
+        printf("set_inputs(...): %d us\n", int(tim2-tim1));
+#endif
 
+#if IK_PRINT_TIMING
+        tim1 = ggml_time_us();
+#endif
         llama_graph_compute(lctx, gf, n_threads);
+#if IK_PRINT_TIMING
+        llama_synchronize(&lctx);
+        tim2 = ggml_time_us();
+        printf("graph_compute(...): %d us\n", int(tim2-tim1));
+#endif
 
         // update the kv ring buffer
         {
@@ -2647,6 +2866,9 @@ static int llama_decode_internal(
 
         // extract logits
         if (res) {
+#if IK_PRINT_TIMING
+            tim1 = ggml_time_us();
+#endif
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
             GGML_ASSERT(backend_res != nullptr);
             GGML_ASSERT(lctx.logits != nullptr);
@@ -2659,10 +2881,17 @@ static int llama_decode_internal(
                 GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
                 ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
             }
+#if IK_PRINT_TIMING
+            tim2 = ggml_time_us();
+            printf("get_result(...): %d us\n", int(tim2-tim1));
+#endif
         }
 
         // extract embeddings
         if (embd) {
+#if IK_PRINT_TIMING
+            tim1 = ggml_time_us();
+#endif
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
             GGML_ASSERT(backend_embd != nullptr);
 
@@ -2702,6 +2931,10 @@ static int llama_decode_internal(
                         GGML_ABORT("unknown pooling type");
                     }
             }
+#if IK_PRINT_TIMING
+            tim2 = ggml_time_us();
+            printf("get_embedding(...): %d us\n", int(tim2-tim1));
+#endif
         }
 
         n_outputs_prev += lctx.n_outputs;
     }
@@ -2718,7 +2951,7 @@ static int llama_decode_internal(
 
         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+            LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
 
             llama_kv_cache_defrag(kv_self);
         }
@@ -2726,7 +2959,14 @@ static int llama_decode_internal(
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
+#if IK_PRINT_TIMING
+    auto tim1 = ggml_time_us();
+#endif
     ggml_backend_sched_reset(lctx.sched);
+#if IK_PRINT_TIMING
+    auto tim2 = ggml_time_us();
+    printf("sched_reset(...): %d us\n", int(tim2-tim1));
+#endif
 
     return 0;
 }
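
Note (not part of the patch): the instrumentation added throughout keys off an IK_PRINT_TIMING macro whose definition is not shown in this diff. From the guards above: any non-zero value times the per-ubatch phases in llama_decode_internal, == 1 additionally times the llama_set_inputs call as a whole, == 2 times the individual tensor uploads inside llama_set_inputs, and > 2 prints a banner at the top of llama_decode_internal. A minimal, self-contained sketch of the same pattern, with ggml_time_us() replaced by a std::chrono helper:

// standalone sketch, not from the repository: the conditional timing pattern used in this patch
#include <chrono>
#include <cstdint>
#include <cstdio>

#define IK_PRINT_TIMING 1   // hypothetical local definition; the real macro is set elsewhere

static int64_t time_us() {
    // stand-in for ggml_time_us() so the example compiles on its own
    using namespace std::chrono;
    return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
}

static void do_work() {
    // stand-in for a timed phase such as llama_set_inputs(...) or llama_graph_compute(...)
    volatile double x = 0.0;
    for (int i = 0; i < 1000000; ++i) x = x + i;
}

int main() {
#if IK_PRINT_TIMING
    auto tim1 = time_us();
#endif
    do_work();
#if IK_PRINT_TIMING
    auto tim2 = time_us();
    printf("do_work(...): %d us\n", int(tim2 - tim1));
#endif
    return 0;
}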