mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 10:51:51 +00:00
Adaptive P sampler: update review logic, delete old code comments, put prep stage after logit bias (#1386)
* simpler n_rewind logic, delete old comments
* use more consistent names, add updt_w_cur to json schema
* align comments
* refactor review logic, update struct/variable names
* revert cosmetic changes
* check enable/disable in llama_prep_adaptive_p_impl()
* delete extra whitespace after statements
* show target in debug prints
* more concise debug print
* delete old comments
* update with loop instead of move()
* comment out all adaptive p debug prints
* more debug prints
* move review() variables: common_sampler struct -> common_sampler_review() args
* match n_unsent type
* fix merge bugs, delete adaptive p references in buffer_and_check_string_ban()
* restore accidental erasure
* Revert "adaptive p: collect probability before logit bias"
This reverts commit 1434878461.
This commit is contained in:
@@ -882,6 +882,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task)
|
||||
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot.sparams.adaptive_target = json_value(data, "adaptive_target", default_sparams.adaptive_target);
|
||||
slot.sparams.adaptive_decay = json_value(data, "adaptive_decay", default_sparams.adaptive_decay);
|
||||
slot.sparams.adaptive_updt_w_cur = json_value(data, "adaptive_updt_w_cur", default_sparams.adaptive_updt_w_cur);
|
||||
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
||||
slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
@@ -1667,6 +1668,7 @@ json server_context::get_formated_generation(const server_slot& slot) const {
|
||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||
{"adaptive_target", slot.sparams.adaptive_target},
|
||||
{"adaptive_decay", slot.sparams.adaptive_decay},
|
||||
{"adaptive_updt_w_cur", slot.sparams.adaptive_updt_w_cur},
|
||||
{"penalize_nl", slot.sparams.penalize_nl},
|
||||
{"stop", slot.params.antiprompt},
|
||||
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
||||
@@ -3332,7 +3334,7 @@ void server_context::speculative_decoding_accept() {
|
||||
buffer_and_check_string_ban(slot, result);
|
||||
}
|
||||
|
||||
common_sampler_review(slot.ctx_sampling);
|
||||
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
||||
}
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int)ids.size() - 1, (int)slot.drafted.size(), slot.n_past);
|
||||
LOG_VERBOSE("speculative decoding result", {
|
||||
@@ -3545,7 +3547,6 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
bool buffer_full = slot.token_buffer.size() >= slot.n_buffer;
|
||||
|
||||
int32_t ban_pos = -1;
|
||||
int32_t n_rewind = 0;
|
||||
bool sent_results = false;
|
||||
|
||||
// Always reset logit bias to base before checking bans
|
||||
@@ -3553,12 +3554,6 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
|
||||
if (slot.ban_phrases.size() > 0 || slot.ban_regex.size() > 0 || slot.ban_regex_ci.size() > 0) {
|
||||
ban_pos = check_ban_phrase(slot);
|
||||
if (ban_pos >= 0 && slot.sparams.adaptive_target >= 0.0f) {
|
||||
int32_t buffer_start_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1;
|
||||
int32_t n_keep_buffer = ban_pos - buffer_start_pos;
|
||||
if (n_keep_buffer < 0) n_keep_buffer = 0;
|
||||
n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer;
|
||||
}
|
||||
}
|
||||
|
||||
bool allow_rewind = true;
|
||||
@@ -3600,17 +3595,11 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_
|
||||
// send 1 token from the front (FIFO)
|
||||
send_token_results(slot.token_buffer, slot, 1);
|
||||
}
|
||||
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||
sent_results = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// buffer the result, wait for more tokens to validate string
|
||||
slot.sampled = result.tok;
|
||||
}
|
||||
if (slot.sparams.adaptive_target >= 0.0f) {
|
||||
slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
|
||||
}
|
||||
}
|
||||
|
||||
void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
@@ -3761,7 +3750,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) {
|
||||
buffer_and_check_string_ban(slot, result);
|
||||
}
|
||||
|
||||
common_sampler_review(slot.ctx_sampling);
|
||||
common_sampler_review(slot.ctx_sampling, slot.token_buffer.size(), slot.rewind_status);
|
||||
|
||||
slot.i_batch = -1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user