From eea6cc443349610fc3869651499e91bdc8262252 Mon Sep 17 00:00:00 2001
From: firecoperana
Date: Mon, 10 Nov 2025 07:51:07 +0000
Subject: [PATCH] Server: Add --draft-params to set draft model parameters via
 command line args (#932)

* Add command line argument for draft model

* Remove second context creation for draft model

* Format print output

* Print usage if parsing -draft fails

---------

Co-authored-by: firecoperana
---
 common/common.cpp          | 54 +++++++++++++++++++++++++++++++++++++-
 common/common.h            |  3 +++
 examples/server/server.cpp | 30 +++++++++++++++------
 3 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index c626397b..5d549d00 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -282,6 +282,53 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string&
 // CLI argument parsing
 //
 
+std::pair<int, char**> parse_command_line(const std::string& commandLine) {
+    std::vector<std::string> tokens;
+    std::string current;
+    bool inQuotes = false;
+
+    for (size_t i = 0; i < commandLine.length(); i++) {
+        char c = commandLine[i];
+
+        if (c == '\"') {
+            inQuotes = !inQuotes;
+        }
+        else if (c == ' ' && !inQuotes) {
+            if (!current.empty()) {
+                tokens.push_back(current);
+                current.clear();
+            }
+        }
+        else {
+            current += c;
+        }
+    }
+
+    if (!current.empty()) {
+        tokens.push_back(current);
+    }
+
+    int argc = static_cast<int>(tokens.size());
+    char** argv = new char* [static_cast<size_t>(argc) + 1];
+
+    for (int i = 0; i < argc; i++) {
+        argv[i] = new char[tokens[i].length() + 1];
+        std::strcpy(argv[i], tokens[i].c_str());
+    }
+    argv[argc] = nullptr;
+    return { argc, argv };
+}
+
+void free_command_line(int argc, char** argv) {
+    if (argv == nullptr) return;
+
+    for (int i = 0; i < argc; i++) {
+        delete[] argv[i];
+    }
+    delete[] argv;
+}
+
+
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
@@ -1254,6 +1301,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cuda_params = argv[i];
         return true;
     }
+    if (arg == "-draft" || arg == "--draft-params") {
+        CHECK_ARG
+        params.draft_params = argv[i];
+        return true;
+    }
     if (arg == "--cpu-moe" || arg == "-cmoe") {
         params.tensor_buft_overrides.push_back({strdup("\\.ffn_(up|down|gate)_exps\\.weight"), ggml_backend_cpu_buffer_type()});
         return true;
@@ -2081,7 +2133,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "backend" });
     options.push_back({ "*",        "       --rpc SERVERS",          "comma separated list of RPC servers" });
     options.push_back({ "*",        "-cuda,  --cuda-params",         "comma separate list of cuda parameters" });
-
+    options.push_back({ "*",        "-draft, --draft-params",        "space-separated list of draft model parameters" });
     if (llama_supports_mlock()) {
         options.push_back({ "*",        "       --mlock",                "force system to keep model in RAM rather than swapping or compressing" });
     }
diff --git a/common/common.h b/common/common.h
index 4ad5908d..776fa255 100644
--- a/common/common.h
+++ b/common/common.h
@@ -130,6 +130,7 @@ struct model_paths {
 struct gpt_params {
     std::string devices;
     std::string devices_draft;
+    std::string draft_params;
 
     uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
 
@@ -375,6 +376,8 @@ struct gpt_params {
 
 };
 
+std::pair<int, char**> parse_command_line(const std::string& commandLine);
+void free_command_line(int argc, char** argv);
 void gpt_params_handle_hf_token(gpt_params & params);
 
 void gpt_params_parse_from_env(gpt_params & params);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9279a045..e4f03470 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1250,6 +1250,7 @@ struct server_context {
             chat_templates = common_chat_templates_init(model, "chatml");
         }
 
+        bool has_draft_model = !params.model_draft.empty() || !params.draft_params.empty();
         std::string & mmproj_path = params.mmproj.path;
         if (!mmproj_path.empty()) {
             mtmd_context_params mparams = mtmd_context_params_default();
@@ -1274,24 +1275,37 @@ struct server_context {
            //    SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
            //}
 
-            if (!params.model_draft.empty()) {
+            if (has_draft_model) {
                 LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal");
                 return false;
             }
         }
 
         // Load draft model for speculative decoding if specified
-        if (!params.model_draft.empty()) {
-            LOG_INFO("loading draft model", {{"model", params.model_draft}});
+        if (has_draft_model) {
+            LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n");
             gpt_params params_dft;
             params_dft.devices = params.devices_draft;
             params_dft.model = params.model_draft;
-            params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft;
             params_dft.n_gpu_layers = params.n_gpu_layers_draft;
-            params_dft.n_parallel = 1;
             params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
             params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
             params_dft.flash_attn = params.flash_attn;
+            if (!params.draft_params.empty()) {
+                auto [argc, argv] = parse_command_line("llama-server " + params.draft_params);
+                if (!gpt_params_parse(argc, argv, params_dft)) {
+                    gpt_params_print_usage(argc, argv, params_dft);
+                    free_command_line(argc, argv);
+                    return false;
+                }
+                free_command_line(argc, argv);
+            }
+            LOG_INFO("loading draft model", {{"model", params_dft.model}});
+            if (params_dft.n_ctx == 0) {
+                params_dft.n_ctx = params.n_ctx_draft;
+            }
+            params_dft.n_ctx = params_dft.n_ctx == 0 ? params.n_ctx / params.n_parallel : params_dft.n_ctx;
+            params_dft.n_parallel = 1;
 
             llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
 
@@ -1361,8 +1375,8 @@ struct server_context {
         // Initialize speculative decoding if a draft model is loaded
         if (ctx_draft) {
             slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
-
-            slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
+            // slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft); // initialized twice
+            slot.ctx_dft = ctx_draft;
             if (slot.ctx_dft == nullptr) {
                 LOG_ERROR("failed to create draft context", {});
                 return;
@@ -3010,7 +3024,7 @@ struct server_context {
                 for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                     new_tokens[i - n_discard] = new_tokens[i];
                 }
-                new_tokens.resize((int) prompt_tokens.size() - n_discard);
+                new_tokens.resize(prompt_tokens.size() - n_discard);
                 prompt_tokens.clear();
                 prompt_tokens.insert(new_tokens);
                 slot.truncated = true;
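
Usage note (illustration, not part of the patch): --draft-params takes the
draft model's arguments as a single quoted string; the server prepends
"llama-server " and re-tokenizes it with parse_command_line() before handing
it to gpt_params_parse(). The model paths below are hypothetical:

    ./llama-server -m main-model.gguf -c 8192 \
        -draft "-m draft-model.gguf -ngl 99 -c 4096"

Draft settings not set this way keep the values copied into params_dft from
the existing draft flags (e.g. --model-draft), since that copy happens before
the -draft string is parsed.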
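For reviewers, a minimal self-contained sketch of the tokenizer round trip
added in common/common.cpp; only the two helpers from this patch are assumed,
and the input string is made up. It shows that double-quoted arguments survive
as single argv entries and that free_command_line() releases everything
parse_command_line() allocates:

    #include <cstdio>
    #include "common.h"

    int main() {
        // Quotes keep "draft model.gguf" together as one token.
        auto [argc, argv] = parse_command_line(
            "llama-server -m \"draft model.gguf\" -ngl 99");
        for (int i = 0; i < argc; i++) {
            std::printf("argv[%d] = %s\n", i, argv[i]);
        }
        free_command_line(argc, argv);  // frees each argv[i], then the array
        return 0;
    }

Expected argv: { "llama-server", "-m", "draft model.gguf", "-ngl", "99" }.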