mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Attention mask tweaks for better long context performance (#825)
* Parallelize mask We see non-negligible PP gains for long contexts. More importantly, the strange drop in performance observed for GPT-OSS for context >= 32k tokens is gone. * Whith FA on, create mask as f16 directly * WIP * Reduce KQ mask padding to 16 Why was it 64 in the first place? I don't observe any issues, while TG performance for long contexts improves by 2-4%. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -2235,7 +2235,7 @@ extern "C" {
|
||||
int min_entries,
|
||||
float thresh);
|
||||
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
#define GGML_KQ_MASK_PAD 16
|
||||
|
||||
// q: [n_embd, n_batch, n_head, 1]
|
||||
// k: [n_embd, n_kv, n_head_kv, 1]
|
||||
|
||||
@@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
|
||||
if (causal && flash_attn) {
|
||||
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
|
||||
ggml_set_input(lctx.inp_KQ_mask);
|
||||
return lctx.inp_KQ_mask;
|
||||
}
|
||||
lctx.inp_KQ_mask = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
@@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
|
||||
|
||||
ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) {
|
||||
GGML_ASSERT(hparams.n_swa > 0);
|
||||
if (causal && flash_attn) {
|
||||
lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(lctx.inp_KQ_mask_swa);
|
||||
return lctx.inp_KQ_mask_swa;
|
||||
}
|
||||
|
||||
lctx.inp_KQ_mask_swa = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
|
||||
288
src/llama.cpp
288
src/llama.cpp
@@ -1924,27 +1924,41 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
// set input data
|
||||
//
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
|
||||
if (batch.token) {
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(token): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
|
||||
if (batch.embd) {
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(embd): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
|
||||
if (batch.pos && lctx.inp_pos) {
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
const int n_pos_per_embd = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
|
||||
if (batch.token && n_pos_per_embd == 4) {
|
||||
@@ -1959,9 +1973,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
} else {
|
||||
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(lctx.inp_pos));
|
||||
}
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(pos): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
|
||||
if (lctx.inp_pos && lctx.inp_scale) {
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
int n_tokens = batch.n_tokens;
|
||||
GGML_ASSERT(ggml_nelements(lctx.inp_scale) >= n_tokens);
|
||||
if (int(lctx.scale_data.size()) < n_tokens) lctx.scale_data.resize(n_tokens);
|
||||
@@ -1970,9 +1991,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
lctx.scale_data[i] = std::log(std::floor((batch.pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f;
|
||||
}
|
||||
ggml_backend_tensor_set(lctx.inp_scale, lctx.scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(lctx.inp_scale));
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(scale): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
|
||||
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
@@ -1998,6 +2026,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
} else {
|
||||
GGML_ASSERT(lctx.n_outputs == 0);
|
||||
}
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(outputs): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
|
||||
GGML_ASSERT(
|
||||
@@ -2008,6 +2040,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
);
|
||||
|
||||
if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||
if (cparams.causal_attn && !lctx.is_encoding) {
|
||||
const int64_t n_kv = kv_self.n;
|
||||
@@ -2016,17 +2051,135 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
ggml_half * data_f16 = nullptr;
|
||||
ggml_half * data_swa_f16 = nullptr;
|
||||
|
||||
if (lctx.inp_KQ_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
||||
data = (float *) lctx.inp_KQ_mask->data;
|
||||
if (cparams.flash_attn) {
|
||||
data_f16 = (ggml_half *)lctx.inp_KQ_mask->data;
|
||||
} else {
|
||||
data = (float *) lctx.inp_KQ_mask->data;
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.inp_KQ_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
|
||||
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
|
||||
if (cparams.flash_attn) {
|
||||
data_swa_f16 = (ggml_half *) lctx.inp_KQ_mask_swa->data;
|
||||
} else {
|
||||
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
|
||||
}
|
||||
}
|
||||
|
||||
auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
ggml_half h_zero = ggml_fp32_to_fp16(0.f);
|
||||
for (int i = first; i < last; ++i) {
|
||||
ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? h_inf : h_zero;
|
||||
if (data_f16) data_f16[j*n_kv + i] = h;
|
||||
if (data_swa_f16) {
|
||||
if (h != h_inf) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
h = h_inf;
|
||||
}
|
||||
} else {
|
||||
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
h = h_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
data_swa_f16[j*n_kv + i] = h;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (n_kv >= 1024 && n_tokens >= 32) {
|
||||
int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
|
||||
int npt = (n_kv + n_thread - 1)/n_thread;
|
||||
auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
|
||||
int first = ith * npt;
|
||||
int last = std::min(int(n_kv), first + npt);
|
||||
if (last <= first) return;
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
|
||||
if (!hparams.use_alibi && cparams.flash_attn) {
|
||||
noalibi_f16(j, pos, seq_id, first, last);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int i = first; i < last; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(lctx.kv_self.cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
data[j*n_kv + i] = f;
|
||||
}
|
||||
if (data_f16) {
|
||||
data_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa || data_swa_f16) {
|
||||
if (f > -INFINITY) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (data_swa) {
|
||||
data_swa[j*n_kv + i] = f;
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
data_swa_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
std::vector<std::thread> workers(n_thread-1);
|
||||
int it = 0;
|
||||
for (auto& w : workers) w = std::thread(compute, it++);
|
||||
compute(it);
|
||||
for (auto& w : workers) w.join();
|
||||
int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
|
||||
if (n_tokens_padded > n_tokens) {
|
||||
if (data) {
|
||||
std::fill(data + int64_t(n_tokens)*n_kv, data + n_tokens_padded*n_kv, -INFINITY);
|
||||
}
|
||||
if (data_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
std::fill(data_f16 + int64_t(n_tokens)*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf);
|
||||
}
|
||||
if (data_swa) {
|
||||
std::fill(data_swa + int64_t(n_tokens)*n_kv, data_swa + n_tokens_padded*n_kv, -INFINITY);
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
std::fill(data_swa_f16 + int64_t(n_tokens)*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// For causal attention, use only the previous KV cells
|
||||
// of the correct sequence for each token of the batch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
@@ -2035,6 +2188,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
|
||||
if (!hparams.use_alibi && cparams.flash_attn) {
|
||||
noalibi_f16(j, pos, seq_id, 0, n_kv);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
@@ -2050,9 +2208,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
if (data_f16) {
|
||||
data_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (data_swa || data_swa_f16) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
@@ -2063,27 +2224,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
}
|
||||
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
if (data_swa) {
|
||||
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
data_swa_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
int64_t n_tokens_padded = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
|
||||
if (n_tokens_padded > n_tokens) {
|
||||
if (data) {
|
||||
std::fill(data + int64_t(n_tokens)*n_kv, data + n_tokens_padded*n_kv, -INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
if (data_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
std::fill(data_f16 + int64_t(n_tokens)*n_kv, data_f16 + n_tokens_padded*n_kv, h_inf);
|
||||
}
|
||||
if (data_swa) {
|
||||
std::fill(data_swa + int64_t(n_tokens)*n_kv, data_swa + n_tokens_padded*n_kv, -INFINITY);
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
std::fill(data_swa_f16 + int64_t(n_tokens)*n_kv, data_swa_f16 + n_tokens_padded*n_kv, h_inf);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(mask1): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
} else {
|
||||
// when using kv cache, the mask needs to match the kv cache size
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
@@ -2118,6 +2291,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
}
|
||||
#if IK_PRINT_TIMING == 2
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("set_inputs(mask2): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2315,10 +2492,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("%s(...): %d us\n", __func__, int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Make sure enough space is available for outputs.
|
||||
@@ -2434,6 +2607,9 @@ static int llama_decode_internal(
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
|
||||
return -1;
|
||||
}
|
||||
#if IK_PRINT_TIMING > 2
|
||||
printf("===== %s: %ld\n", __func__, ggml_time_us());
|
||||
#endif
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
@@ -2502,6 +2678,9 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||
llama_batch u_batch = {
|
||||
/* .n_tokens = */ (int32_t) n_tokens,
|
||||
@@ -2592,13 +2771,31 @@ static int llama_decode_internal(
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
}
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("prelude(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("sched_reset(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
ggml_cgraph * gf = llm_build_context::llama_build_graph(lctx, u_batch, false);
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("build_graph(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
@@ -2624,11 +2821,33 @@ static int llama_decode_internal(
|
||||
}
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("sched_alloc_graph(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
#if IK_PRINT_TIMING == 1
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
llama_set_inputs(lctx, u_batch);
|
||||
#if IK_PRINT_TIMING == 1
|
||||
tim2 = ggml_time_us();
|
||||
printf("set_inputs(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
llama_graph_compute(lctx, gf, n_threads);
|
||||
#if IK_PRINT_TIMING
|
||||
llama_synchronize(&lctx);
|
||||
tim2 = ggml_time_us();
|
||||
printf("graph_compute(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
// update the kv ring buffer
|
||||
{
|
||||
@@ -2647,6 +2866,9 @@ static int llama_decode_internal(
|
||||
|
||||
// extract logits
|
||||
if (res) {
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(lctx.logits != nullptr);
|
||||
@@ -2659,10 +2881,17 @@ static int llama_decode_internal(
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
|
||||
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("get_result(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (embd) {
|
||||
#if IK_PRINT_TIMING
|
||||
tim1 = ggml_time_us();
|
||||
#endif
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
@@ -2702,6 +2931,10 @@ static int llama_decode_internal(
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
tim2 = ggml_time_us();
|
||||
printf("get_embedding(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
}
|
||||
n_outputs_prev += lctx.n_outputs;
|
||||
}
|
||||
@@ -2718,7 +2951,7 @@ static int llama_decode_internal(
|
||||
|
||||
// queue defragmentation for next llama_kv_cache_update
|
||||
if (fragmentation > cparams.defrag_thold) {
|
||||
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
|
||||
LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
|
||||
|
||||
llama_kv_cache_defrag(kv_self);
|
||||
}
|
||||
@@ -2726,7 +2959,14 @@ static int llama_decode_internal(
|
||||
|
||||
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
||||
// overlap with device computation.
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim1 = ggml_time_us();
|
||||
#endif
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
#if IK_PRINT_TIMING
|
||||
auto tim2 = ggml_time_us();
|
||||
printf("sched_reset(...): %d us\n", int(tim2-tim1));
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user