Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-04-21 06:59:21 +00:00)
add dry sampler (#513)
* add dry sampler
* use vocab instead of model in dry_init function
* fix compile error for build test

Co-authored-by: firecoperana <firecoperana>
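For background on what is being added: the DRY ("Don't Repeat Yourself") sampler penalizes tokens that would extend a verbatim repetition of recent output, with a penalty that grows exponentially in the repeat length. A minimal sketch of the core penalty computation, assuming the usual DRY formulation; the parameter names match the sampling params added in this commit, but the function itself is illustrative, not this fork's implementation:

    #include <cmath>
    #include <cstdint>

    // Illustrative DRY penalty: subtracted from a candidate token's logit when
    // sampling it would extend a run of `repeat_len` already-repeated tokens.
    // Repeats up to dry_allowed_length are free; beyond that the penalty grows
    // as dry_multiplier * dry_base^(repeat_len - dry_allowed_length).
    static float dry_penalty(int32_t repeat_len, float dry_multiplier, float dry_base,
                             int32_t dry_allowed_length) {
        if (repeat_len < dry_allowed_length) {
            return 0.0f;
        }
        return dry_multiplier * std::pow(dry_base, (float) (repeat_len - dry_allowed_length));
    }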
@@ -349,7 +349,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams);
 
     while (n_remain != 0 || params.interactive) {
         // predict
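The same one-line change repeats across the examples below: `llama_sampling_init` now takes the model's vocabulary, which DRY needs to turn sequence-breaker strings into tokens. The header itself is not shown in this diff; a sketch of the assumed before/after declarations:

    // before (assumed): the sampling context was built from the params alone
    struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);

    // after (assumed): the vocab comes first, as in the call sites in this diff
    struct llama_sampling_context * llama_sampling_init(const struct llama_vocab * vocab,
                                                        const struct llama_sampling_params & params);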
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     LOG_TEE("\n");
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model), params->sparams);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
@@ -218,7 +218,7 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla
 
     LOG_TEE("\n");
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model), params->sparams);
     return ctx_sampling;
 }
 
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);
@@ -106,7 +106,7 @@ int main(int argc, char ** argv){
 
     bool has_eos = false;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
 
     std::vector<llama_token> draft;
@@ -531,7 +531,7 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
     }
 
     std::vector<llama_token> tokens_system;
@@ -1,4 +1,10 @@
 set(TARGET rpc-server)
 add_executable(${TARGET} rpc-server.cpp)
 target_link_libraries(${TARGET} PRIVATE ggml)
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
+if (MSVC)
+    target_link_options(${TARGET} PRIVATE
+        $<$<CONFIG:DEBUG>:/STACK:20971520,1048576>
+        $<$<CONFIG:RELEASE>:/STACK:20971520,1048576>
+    )
+endif()
@@ -37,7 +37,13 @@ install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
 
+if (MSVC)
+    target_link_options(${TARGET} PRIVATE
+        $<$<CONFIG:DEBUG>:/STACK:20971520,1048576>
+        $<$<CONFIG:RELEASE>:/STACK:20971520,1048576>
+    )
+endif()
 # target_link_libraries(${TARGET} PRIVATE "/STACK:104857600")
 target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
 target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
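In both CMake stanzas, /STACK:20971520,1048576 tells the MSVC linker to reserve 20 MiB of stack with an initial 1 MiB commit (the syntax is /STACK:reserve[,commit], in bytes); the commented-out line above is an earlier 100 MiB (104857600-byte) variant left in place.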
@@ -977,6 +977,10 @@ struct server_context {
         slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
         slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
         slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+        slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
+        slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
+        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
         slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
         slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
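These keys become per-request knobs on the server's completion endpoint. A hypothetical request body built with nlohmann::json (which the server already uses); the key names come from the `json_value` calls above, the values are only examples:

    #include <nlohmann/json.hpp>

    // Hypothetical /completion request body enabling DRY for one request.
    // A dry_multiplier of 0 (the usual default) leaves the sampler off.
    nlohmann::json body = {
        {"prompt",             "Once upon a time"},
        {"dry_multiplier",     0.8},  // overall strength of the penalty
        {"dry_base",           1.75}, // growth factor per extra repeated token
        {"dry_allowed_length", 2},    // repeats up to this length are free
        {"dry_penalty_last_n", -1},   // -1: clamp to the context size (validated below)
    };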
@@ -987,6 +991,42 @@ struct server_context {
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
+        if (slot.sparams.penalty_last_n < -1) {
+            throw std::runtime_error("Error: repeat_last_n must be >= -1");
+        }
+
+        if (slot.sparams.dry_penalty_last_n < -1) {
+            throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+        }
+
+        if (slot.sparams.penalty_last_n == -1) {
+            // note: should be the slot's context and not the full context, but it's ok
+            slot.sparams.penalty_last_n = llama_n_ctx(ctx);
+        }
+
+        if (slot.sparams.dry_penalty_last_n == -1) {
+            slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);
+        }
+
+        if (slot.sparams.dry_base < 1.0f) {
+            slot.sparams.dry_base = default_sparams.dry_base;
+        }
+
+        // sequence breakers for DRY
+        {
+            // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+            // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+            if (data.contains("dry_sequence_breakers")) {
+                slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+                if (slot.sparams.dry_sequence_breakers.empty()) {
+                    send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
+            }
+        }
+
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
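As an illustration of what the breakers do (this is a sketch of the general DRY mechanism, not this fork's code): when the sampler scans backwards for a repeated suffix, hitting a breaker token ends the match, so structural text such as newlines or chat-role prefixes never accumulates a repetition penalty:

    #include <string>
    #include <unordered_set>
    #include <vector>

    // Illustrative: how far back from the end of `history` DRY may match,
    // stopping at the first sequence breaker.
    static size_t dry_matchable_suffix(const std::vector<std::string> & history,
                                       const std::unordered_set<std::string> & breakers) {
        size_t len = 0;
        for (auto it = history.rbegin(); it != history.rend(); ++it) {
            if (breakers.count(*it)) {
                break; // a breaker terminates the repetition window
            }
            ++len;
        }
        return len;
    }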
@@ -1156,7 +1196,7 @@ struct server_context {
         if (slot.ctx_sampling != nullptr) {
             llama_sampling_free(slot.ctx_sampling);
         }
-        slot.ctx_sampling = llama_sampling_init(slot.sparams);
+        slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), slot.sparams);
         if (slot.ctx_sampling == nullptr) {
             // for now, the only error that may happen here is invalid grammar
             send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -1405,6 +1445,11 @@ struct server_context {
             {"frequency_penalty",         slot.sparams.penalty_freq},
             {"penalty_prompt_tokens",     slot.sparams.penalty_prompt_tokens},
             {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
+            {"dry_multiplier",            slot.sparams.dry_multiplier},
+            {"dry_base",                  slot.sparams.dry_base},
+            {"dry_allowed_length",        slot.sparams.dry_allowed_length},
+            {"dry_penalty_last_n",        slot.sparams.dry_penalty_last_n},
+            {"dry_sequence_breakers",     slot.sparams.dry_sequence_breakers},
             {"mirostat",                  slot.sparams.mirostat},
             {"mirostat_tau",              slot.sparams.mirostat_tau},
             {"mirostat_eta",              slot.sparams.mirostat_eta},
@@ -2337,6 +2382,13 @@ struct server_context {
         slot.command = SLOT_COMMAND_NONE;
 
         GGML_ASSERT(batch.n_tokens > 0);
+        llama_sampling_reset(slot.ctx_sampling);
+        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+            llama_token id = slot.prompt_tokens[i];
+            if (id != LLAMA_TOKEN_NULL) {
+                llama_sampling_accept(slot.ctx_sampling, ctx, id, false);
+            }
+        }
 
         // extract the logits only for the last token
         batch.logits[batch.n_tokens - 1] = true;
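The new loop replays the prompt into the freshly reset sampling context; without it, penalty samplers such as DRY would start generation with an empty history and ignore repetition against the prompt. A rough sketch of the idea with an illustrative history type (not the fork's actual `llama_sampling_context`):

    #include <cstdint>
    #include <deque>
    #include <vector>

    using llama_token = int32_t;

    // Illustrative token history a penalty sampler might keep.
    struct sampler_history {
        size_t capacity;                // e.g. dry_penalty_last_n
        std::deque<llama_token> prev;   // most recent tokens, oldest first

        void accept(llama_token id) {
            prev.push_back(id);
            if (prev.size() > capacity) {
                prev.pop_front();
            }
        }
    };

    // After a reset, replay the prompt so it counts toward repetition checks.
    void replay_prompt(sampler_history & hist, const std::vector<llama_token> & prompt_tokens) {
        for (llama_token id : prompt_tokens) {
            hist.accept(id);
        }
    }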
@@ -179,7 +179,7 @@ int main(int argc, char ** argv) {
     bool has_eos = false;
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_tgt), params.sparams);
 
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
@@ -190,7 +190,7 @@ int main(int argc, char ** argv) {
     }
 
     for (int s = 0; s < n_seq_dft; ++s) {
-        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+        drafts[s].ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_dft), params.sparams);
     }
 
     llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
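Note that the target and draft sampling contexts are each initialized with their own model's vocab. Speculative decoding already presumes the two vocabularies line up, since draft tokens are verified by the target model; a hedged sketch of the kind of sanity check this implies (`llama_n_vocab` is the llama.cpp-family accessor; the helper itself is illustrative):

    #include <cstdio>
    #include "llama.h"

    // Illustrative: draft tokens are fed to the target model, so at minimum
    // the two vocab sizes must agree (real checks also compare token text).
    static bool vocabs_compatible(const llama_model * model_tgt, const llama_model * model_dft) {
        const int n_vocab_tgt = llama_n_vocab(model_tgt);
        const int n_vocab_dft = llama_n_vocab(model_dft);
        if (n_vocab_tgt != n_vocab_dft) {
            fprintf(stderr, "draft vocab size (%d) != target vocab size (%d)\n",
                    n_vocab_dft, n_vocab_tgt);
            return false;
        }
        return true;
    }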