Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-01-26 17:20:01 +00:00)
Server: Add --draft-params to set draft model parameter via command line args (#932)
* Add command line argument for draft model
* Remove second context of draft model
* Format print
* Print usage if parsing -draft fails

Co-authored-by: firecoperana <firecoperana>
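For orientation, the new --draft-params flag forwards a flat option string to the same argument parser the server uses for its own flags, so the draft model can be given its own GPU-layer count, context size, cache types, and so on. A hypothetical invocation (model paths and option values are placeholders, not taken from the commit) might look like:

    ./llama-server -m ./models/main.gguf --model-draft ./models/draft.gguf \
        --draft-params "-ngl 99 -c 4096"

Values not mentioned in the string are pre-filled from the server's existing draft-model settings, as the diff below shows.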
@@ -1250,6 +1250,7 @@ struct server_context {
         chat_templates = common_chat_templates_init(model, "chatml");
     }
 
+    bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty();
     std::string & mmproj_path = params.mmproj.path;
     if (!mmproj_path.empty()) {
         mtmd_context_params mparams = mtmd_context_params_default();
@@ -1274,24 +1275,37 @@ struct server_context {
         // SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
         //}
 
-        if (!params.model_draft.empty()) {
+        if (has_draft_model) {
             LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
             return false;
         }
     }
     // Load draft model for speculative decoding if specified
-    if (!params.model_draft.empty()) {
-        LOG_INFO("loading draft model", {{"model", params.model_draft}});
+    if (has_draft_model) {
+        LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
 
         gpt_params params_dft;
         params_dft.devices = params.devices_draft;
         params_dft.model = params.model_draft;
-        params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft;
         params_dft.n_gpu_layers = params.n_gpu_layers_draft;
-        params_dft.n_parallel = 1;
         params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
         params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
         params_dft.flash_attn = params.flash_attn;
+        if (!params.draft_params.empty()) {
+            auto [argc, argv] = parse_command_line("llama-server "+params.draft_params);
+            if (!gpt_params_parse(argc, argv, params_dft)) {
+                gpt_params_print_usage(argc, argv, params_dft);
+                free_command_line(argc, argv);
+                return false;
+            };
+            free_command_line(argc, argv);
+        }
+        LOG_INFO("", { {"model", params_dft.model} });
+        if (params_dft.n_ctx == 0) {
+            params_dft.n_ctx = params.n_ctx_draft;
+        }
+        params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
+        params_dft.n_parallel = 1;
 
         llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
 
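Note on the block above: the --draft-params string is turned into an argc/argv pair and run through the server's normal argument parser, so a malformed string prints the usual usage text and the model load fails (return false). The parse_command_line()/free_command_line() helpers themselves are not part of this excerpt; the following is only a minimal sketch, assuming a plain whitespace split with no quote handling, of what such helpers could look like (names here are made up to avoid confusion with the real ones):

    // Minimal sketch (assumption): turn a flat option string into argc/argv for
    // gpt_params_parse(). The real parse_command_line()/free_command_line() used
    // by this commit are not shown in the diff and may handle quoting differently.
    #include <cstring>
    #include <sstream>
    #include <string>
    #include <utility>
    #include <vector>

    static std::pair<int, char **> split_command_line(const std::string & cmd) {
        std::vector<std::string> tokens;
        std::istringstream iss(cmd);
        for (std::string tok; iss >> tok; ) {
            tokens.push_back(tok);                        // whitespace-separated tokens only
        }
        char ** argv = new char *[tokens.size() + 1];
        for (size_t i = 0; i < tokens.size(); ++i) {
            argv[i] = new char[tokens[i].size() + 1];     // own a copy of each argument
            std::memcpy(argv[i], tokens[i].c_str(), tokens[i].size() + 1);
        }
        argv[tokens.size()] = nullptr;                    // conventional argv terminator
        return { (int) tokens.size(), argv };
    }

    static void free_split_command_line(int argc, char ** argv) {
        for (int i = 0; i < argc; ++i) {
            delete[] argv[i];
        }
        delete[] argv;
    }

After the optional override, the draft context size is resolved from whatever is left in params_dft.n_ctx: if it is still zero it is taken from params.n_ctx_draft, and if that is also zero it falls back to params.n_ctx / params.n_parallel.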
@@ -1361,8 +1375,8 @@ struct server_context {
     // Initialize speculative decoding if a draft model is loaded
     if (ctx_draft) {
         slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
-
-        slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
+        // slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); // initialized twice
+        slot.ctx_dft = ctx_draft;
         if (slot.ctx_dft == nullptr) {
             LOG_ERROR("failed to create draft context", {});
             return;
@@ -3010,7 +3024,7 @@ struct server_context {
     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
         new_tokens[i - n_discard] = new_tokens[i];
     }
-    new_tokens.resize((int) prompt_tokens.size() - n_discard);
+    new_tokens.resize(prompt_tokens.size() - n_discard);
     prompt_tokens.clear();
     prompt_tokens.insert(new_tokens);
     slot.truncated = true;