Support --device and --device-draft parameter (#866)

* add --device and --device-draft parameter

* don't print debug message in release mode

* fix

* bug fix to throw exception when no device specified

* add const

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2025-10-27 16:13:28 +00:00
committed by GitHub
parent eb8116b097
commit 904e994bfb
12 changed files with 283 additions and 40 deletions

View File

@@ -91,10 +91,10 @@ bool llama_speculative_are_compatible(
const struct llama_vocab * vocab_dft = llama_get_model_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LLAMA_LOG_INFO("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
LLAMA_LOG_DEBUG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft);
LLAMA_LOG_INFO("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
LLAMA_LOG_DEBUG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LLAMA_LOG_INFO("%s: draft model vocab type must match target model to use speculation but ", __func__);
@@ -203,13 +203,13 @@ std::vector<llama_token> llama_speculative_gen_draft(
std::string text;
text = llama_detokenize(ctx_tgt, prompt_tgt_main_model, true);
text = replace_to_dft(spec, text);
LLAMA_LOG_INFO("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
LLAMA_LOG_DEBUG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
prompt_tgt_draft_model = llama_tokenize(ctx_dft, text, false, true);
// convert id_last to draft vocab
std::vector<llama_token> id_last_vec(1, id_last);
text = llama_detokenize(ctx_tgt, id_last_vec);
LLAMA_LOG_INFO("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
LLAMA_LOG_DEBUG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
id_last = llama_tokenize(ctx_dft, text, false, true)[0];
}
// prompt_tgt's tokens will always be compatible with ctx_dft
@@ -233,8 +233,7 @@ std::vector<llama_token> llama_speculative_gen_draft(
reuse_n = cur;
}
}
LLAMA_LOG_INFO("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
LLAMA_LOG_DEBUG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
std::vector<llama_token> result;
result.reserve(params.n_draft);
@@ -344,7 +343,7 @@ std::vector<llama_token> llama_speculative_gen_draft(
if (!spec->vocab_dft_compatible) {
std::string detokenized = llama_detokenize(ctx_dft, result, true);
detokenized = replace_to_tgt(spec, detokenized);
LLAMA_LOG_INFO("draft->main detokenized string: '%s'\n", detokenized.c_str());
LLAMA_LOG_DEBUG("draft->main detokenized string: '%s'\n", detokenized.c_str());
result = llama_tokenize(ctx_tgt, detokenized, false, true);
if (result.size() > (size_t)params.n_draft) {
result.resize(params.n_draft);