From 5e8bb724ce53e25ead40e42e384938673ef7599c Mon Sep 17 00:00:00 2001
From: firecoperana <18252262+firecoperana@users.noreply.github.com>
Date: Sun, 5 Apr 2026 01:41:04 -0500
Subject: [PATCH] server: support slot save/restore/erase for mtmd tokens and checkpoints (#1584)

Co-authored-by: firecoperana
---
Review notes (revised patch):
 - server-context.cpp: the checkpoint/token file helpers now verify that the
   file opened and that every read succeeded. The checkpoint file is written
   even for an empty list so a stale file cannot be restored later, and the
   checkpoint list is cleared before loading so a missing file cannot leave
   checkpoints of an unrelated prompt in the slot.
 - server-context.cpp: the token file is parsed without exceptions; restore
   falls back to the previous resize(token_count) behaviour when no usable
   token file exists (e.g. a save set written before this patch).
 - server-common.cpp / server-task.h: removed a shadowed variable and two
   duplicated statements.
 - mtmd.cpp: const json lookups use at() instead of operator[].
 - Context lines were re-indented from a whitespace-collapsed copy of the
   original patch, so apply with `git apply --ignore-whitespace` if needed.
   The system header named on the context line following the new includes in
   server-context.cpp could not be recovered and must be filled in by hand.

 examples/mtmd/mtmd.cpp             | 142 ++++++++++++++++++++++++++
 examples/mtmd/mtmd.h               |   5 +
 examples/server/server-common.cpp  |  42 ++++++++
 examples/server/server-common.h    |   4 +
 examples/server/server-context.cpp | 180 ++++++++++++++++++++++++++++++---
 examples/server/server-task.h      |  31 ++++++
 include/llama.h                    |   3 +
 7 files changed, 391 insertions(+), 16 deletions(-)

diff --git a/examples/mtmd/mtmd.cpp b/examples/mtmd/mtmd.cpp
index 77a0bbab..530b334d 100644
--- a/examples/mtmd/mtmd.cpp
+++ b/examples/mtmd/mtmd.cpp
@@ -1054,6 +1054,18 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     return image_tokens->n_tokens();
 }
 
+mtmd_input_chunk * mtmd_create_input_chunk() {
+    auto * chunk = new mtmd_input_chunk{
+        MTMD_INPUT_CHUNK_TYPE_TEXT,
+        std::vector<llama_token>{},
+        nullptr,
+        nullptr
+    };
+    return chunk;
+}
+
+
+
 // test function
 
 mtmd_input_chunks * mtmd_test_create_input_chunks() {
@@ -1088,3 +1100,133 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
 
     return chunks;
 }
+
+static json mtmd_clip_image_f32_to_json(const clip_image_f32 & clip) {
+    json j;
+    j["nx"] = clip.nx;
+    j["ny"] = clip.ny;
+    j["buf"] = clip.buf;
+    return j;
+}
+
+static clip_image_f32 * mtmd_clip_image_f32_from_json(const json & j) {
+    clip_image_f32_ptr clip(new clip_image_f32);
+    clip->nx = j.at("nx");
+    clip->ny = j.at("ny");
+    clip->buf = j.at("buf").get<std::vector<float>>();
+    return clip.release();
+}
+
+static json mtmd_clip_image_f32_batch_to_json(const clip_image_f32_batch & batch, bool full = false) {
+    json j;
+    j["is_audio"] = batch.is_audio;
+    j["grid_x"] = batch.grid_x;
+    j["grid_y"] = batch.grid_y;
+
+    if (full) {
+        std::vector<json> entries;
+        for (const auto & entry : batch.entries) {
+            entries.push_back(mtmd_clip_image_f32_to_json(*entry));
+        }
+        j["entries"] = entries;
+    }
+
+    return j;
+}
+
+static clip_image_f32_batch mtmd_clip_image_f32_batch_from_json(const json & j, bool full = false) {
+    clip_image_f32_batch batch;
+    if (j.contains("is_audio")) {
+        batch.is_audio = j.at("is_audio");
+        batch.grid_x = j.at("grid_x");
+        batch.grid_y = j.at("grid_y");
+        if (full && j.contains("entries")) {
+            const auto & entries = j.at("entries");
+            if (entries.is_array()) {
+                for (const auto & entry : entries) {
+                    clip_image_f32 * clip = mtmd_clip_image_f32_from_json(entry);
+                    batch.entries.push_back(clip_image_f32_ptr(clip));
+                }
+            }
+        }
+
+    }
+    return batch;
+}
+
+static mtmd_audio_tokens mtmd_audio_tokens_from_json(json & j) {
+    return mtmd_audio_tokens{
+        j.value("n_tokens", 0),
+        mtmd_clip_image_f32_batch_from_json(j.value("batch_f32", json{})),
+        j.value("id","")
+    };
+}
+
+static mtmd_image_tokens mtmd_image_tokens_from_json(json & j) {
+    return mtmd_image_tokens{
+        j.value("nx", 0),
+        j.value("ny", 0),
+        j.value("use_mrope_pos",false),
+        mtmd_clip_image_f32_batch_from_json(j.value("batch_f32", json{})),
+        j.value("id","")
+    };
+}
+
+static json mtmd_audio_tokens_to_json(mtmd_audio_tokens * chunk) {
+    json j;
+    if (chunk) {
+        j["n_tokens"] = chunk->n_tokens;
+        j["id"] = chunk->id;
+        j["batch_f32"] = mtmd_clip_image_f32_batch_to_json(chunk->batch_f32);
+    }
+    return j;
+}
+
+static json mtmd_image_tokens_to_json(mtmd_image_tokens * chunk) {
+    json j;
+    if (chunk) {
+        j["nx"] = chunk->nx;
+        j["ny"] = chunk->ny;
+        j["use_mrope_pos"] = chunk->use_mrope_pos;
+        j["batch_f32"] = mtmd_clip_image_f32_batch_to_json(chunk->batch_f32);
+        j["id"] = chunk->id;
+    }
+    return j;
+}
+
+mtmd_input_chunk * mtmd_input_chunk_from_json(json & j) {
+    mtmd_input_chunk * chunk = mtmd_create_input_chunk();
+    chunk->type = j.value("type", MTMD_INPUT_CHUNK_TYPE_TEXT);
+    chunk->tokens_text = j.value("tokens_text", chunk->tokens_text);
+    chunk->tokens_image = nullptr;
+    chunk->tokens_audio = nullptr;
+    if (j.contains("tokens_image")) {
+        chunk->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
+        auto image_json = j.value("tokens_image", json::array());
+        *chunk->tokens_image = mtmd_image_tokens_from_json(image_json);
+    }
+    if (j.contains("tokens_audio")) {
+        chunk->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
+        *chunk->tokens_audio = mtmd_audio_tokens_from_json(j.at("tokens_audio"));
+    }
+    return chunk;
+}
+
+void mtmd_input_chunk_to_json(mtmd_input_chunk * chunk, json & j) {
+    j.clear();
+    if (chunk) {
+        j["type"] = chunk->type;
+        j["tokens_text"] = chunk->tokens_text;
+        if (chunk->tokens_image) {
+            j["tokens_image"] = mtmd_image_tokens_to_json(chunk->tokens_image.get());
+        }
+        if (chunk->tokens_audio) {
+            j["tokens_audio"] = mtmd_audio_tokens_to_json(chunk->tokens_audio.get());
+        }
+    }
+}
+
+
+
+
+
diff --git a/examples/mtmd/mtmd.h b/examples/mtmd/mtmd.h
index 85084abc..3285f24b 100644
--- a/examples/mtmd/mtmd.h
+++ b/examples/mtmd/mtmd.h
@@ -13,6 +13,8 @@
 #include <vector>
 #include <cinttypes>
 #include <memory>
+#include <nlohmann/json.hpp>
+using json = nlohmann::ordered_json;
 #endif
 
 /**
@@ -215,6 +217,9 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
 // the reading size (in bytes) is equal to:
 //   llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
+MTMD_API mtmd_input_chunk * mtmd_create_input_chunk(void);
+MTMD_API mtmd_input_chunk * mtmd_input_chunk_from_json(json & j);
+MTMD_API void mtmd_input_chunk_to_json(mtmd_input_chunk * chunk, json & j);
 
 /////////////////////////////////////////
 
diff --git a/examples/server/server-common.cpp b/examples/server/server-common.cpp
index b1549449..8c1e363a 100644
--- a/examples/server/server-common.cpp
+++ b/examples/server/server-common.cpp
@@ -2164,6 +2164,48 @@ server_tokens server_tokens::clone() const {
     return res;
 }
 
+// Serialises the token list together with its multimodal chunks.
+json server_tokens::to_json() const
+{
+    json j;
+    std::vector<nlohmann::json> media_array;
+    for (const auto & [idx, chunk_ptr] : map_idx_to_media) {
+        if (chunk_ptr) {
+            nlohmann::json obj;
+            obj["index"] = idx;
+            json chunk_json;
+            mtmd_input_chunk_to_json(chunk_ptr.get(), chunk_json);
+            obj["chunk"] = chunk_json;
+            media_array.push_back(std::move(obj));
+        }
+    }
+    j = nlohmann::json{
+        {"has_mtmd", has_mtmd},
+        {"map_idx_to_media", media_array},
+        {"tokens", tokens}
+    };
+    return j;
+}
+
+// Replaces the current contents with the state stored in `j` (see to_json()).
+void server_tokens::from_json(const json & j) {
+    clear();
+    map_idx_to_media.clear();
+    has_mtmd = j.value("has_mtmd", has_mtmd);
+    tokens = j.value("tokens", tokens);
+    json media_array = j.at("map_idx_to_media");
+    if (media_array.is_array()) {
+        for (const auto & entry : media_array) {
+            size_t idx = entry.at("index");
+            json chunk_json = entry.at("chunk");
+            mtmd_input_chunk * chunk = mtmd_input_chunk_from_json(chunk_json);
+            map_idx_to_media[idx] = mtmd::input_chunk_ptr(chunk);
+        }
+    }
+
+}
+
+
 
 // Keep the first n_keep and remove n_discard tokens from tokens
 void server_tokens::discard_n_tokens(int32_t n_keep, int32_t n_discard) {
diff --git a/examples/server/server-common.h b/examples/server/server-common.h
index 4d451d8c..03dfbdbf 100644
--- a/examples/server/server-common.h
+++ b/examples/server/server-common.h
@@ -349,6 +349,10 @@ public:
 
     server_tokens(const llama_tokens& tokens, bool has_mtmd);
 
+    json to_json() const;
+
+    void from_json(const json & j);
+
     // the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
     llama_pos pos_next(int64_t n_tokens = -1) const;
 
diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp
index 3daf86af..d1c233c6 100644
--- a/examples/server/server-context.cpp
+++ b/examples/server/server-context.cpp
@@ -11,6 +11,9 @@
 #include "mtmd.h"
 #include "mtmd-helper.h"
 
+#include <fstream>
+#include <iterator>
+#include <list>
 #include
 
 static void log_text(const gpt_params & params_base, const std::string & text) {
@@ -1995,6 +1998,149 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
     }
 }
 
+
+
+// Writes the prompt checkpoints of a slot to `filename`: magic, version and
+// count, then for every checkpoint its four positions, payload size and payload.
+// The file is (re)written even for an empty list, so that a stale file left by
+// an earlier save under the same name can never be restored by mistake.
+// Returns the number of bytes written, or 0 if the file could not be written.
+static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        return 0;
+    }
+    const uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
+    file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
+    const uint32_t version = LLAMA_STATE_SEQ_VERSION;
+    file.write(reinterpret_cast<const char *>(&version), sizeof(version));
+    const size_t count = checkpoints.size();
+    file.write(reinterpret_cast<const char *>(&count), sizeof(count));
+
+    for (const auto & checkpoint : checkpoints) {
+        file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
+        file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
+        file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
+        file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
+        const size_t data_len = checkpoint.data.size();
+        file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
+        if (data_len > 0) {
+            file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
+        }
+    }
+    // tellp() yields -1 once the stream has failed; never report that as a size
+    const size_t pos = file.good() ? static_cast<size_t>(file.tellp()) : 0;
+    file.close();
+    return pos;
+}
+
+// Reads the checkpoints written by save_checkpoints_to_file() into `checkpoints`.
+// The list is emptied up front: a missing, foreign or truncated file must leave
+// the slot without checkpoints rather than with ones from an unrelated prompt.
+// Returns the number of bytes read, or 0 on failure.
+static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
+    checkpoints.clear();
+    std::ifstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        return 0;
+    }
+    // version checks
+    {
+        uint32_t magic = 0;
+        file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
+        uint32_t version = 0;
+        file.read(reinterpret_cast<char *>(&version), sizeof(version));
+
+        if (!file || magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+    }
+    // load the checkpoints
+    {
+        size_t count = 0;
+        file.read(reinterpret_cast<char *>(&count), sizeof(count));
+
+        for (size_t i = 0; i < count; i++) {
+            server_prompt_checkpoint checkpoint;
+            file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
+            file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
+            file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
+            file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
+
+            size_t data_len = 0;
+            file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
+            if (file && data_len > 0) {
+                checkpoint.data.resize(data_len);
+                file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
+            }
+            if (!file) {
+                break;
+            }
+            checkpoints.push_back(std::move(checkpoint));
+        }
+        if (!file) {
+            // a short read means the file is truncated; a partial list is worse than none
+            LLAMA_LOG_ERROR("%s: checkpoint file is truncated or corrupt: %s\n", __func__, filename.c_str());
+            checkpoints.clear();
+            return 0;
+        }
+    }
+    const size_t pos = static_cast<size_t>(file.tellg());
+    file.close();
+    return pos;
+}
+
+// Serialises the slot's tokens (text ids plus multimodal chunks) to `filename`
+// as JSON tagged with LLAMA_SERVER_MAGIC / LLAMA_SERVER_VERSION.
+// Returns the number of bytes written, or 0 if the file could not be written.
+static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        return 0;
+    }
+    json token_json = tokens.to_json();
+    token_json["magic"] = LLAMA_SERVER_MAGIC;
+    token_json["version"] = LLAMA_SERVER_VERSION;
+    file << token_json;
+    const size_t pos = file.good() ? static_cast<size_t>(file.tellp()) : 0;
+    file.close();
+    return pos;
+}
+
+// Restores `tokens` from a file written by save_server_tokens_to_file().
+// Nothing is touched unless the file parses and carries the expected magic and
+// version; if deserialisation itself then fails, `tokens` may be left partially
+// filled and the caller has to reset it. Returns bytes read, or 0 on failure.
+static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
+    std::ifstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        return 0;
+    }
+    const std::string text((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
+    file.close();
+    // parse without exceptions: a malformed file must not take the server down
+    const json token_json = json::parse(text, nullptr, false);
+    if (!token_json.is_object()) {
+        LLAMA_LOG_ERROR("%s: token file is not a JSON object: %s\n", __func__, filename.c_str());
+        return 0;
+    }
+    try {
+        const uint32_t magic = token_json.value("magic", 0u);
+        const uint32_t version = token_json.value("version", 0u);
+        if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+        tokens.from_json(token_json);
+    } catch (const json::exception & e) {
+        LLAMA_LOG_ERROR("%s: invalid token file %s: %s\n", __func__, filename.c_str(), e.what());
+        return 0;
+    }
+
+    return text.size();
+}
+
 void server_context::process_single_task(server_task&& task) {
     switch (task.type) {
         case SERVER_TASK_TYPE_COMPLETION:
@@ -2153,14 +2299,14 @@ void server_context::process_single_task(server_task&& task) {
                     queue_tasks.defer(std::move(task));
                     break;
                 }
-                if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
-                    break;
-                }
+
                 const size_t token_count = slot->cache_tokens.size();
                 const int64_t t_start = ggml_time_us();
 
                 std::string filename = task.data.at("filename");
                 std::string filepath = task.data.at("filepath");
+                save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
+                size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
 
                 const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
 
@@ -2175,7 +2321,7 @@ void server_context::process_single_task(server_task&& task) {
                     { "id_slot", id_slot },
                     { "filename", filename },
                     { "n_saved", token_count }, // tokens saved
-                    { "n_written", nwrite }, // bytes written
+                    { "n_written", nwrite + saved }, // bytes written
                     { "timings", {
                         { "save_ms", t_save_ms }
                     } }
@@ -2196,9 +2342,6 @@ void server_context::process_single_task(server_task&& task) {
                     queue_tasks.defer(std::move(task));
                     break;
                 }
-                if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
-                    break;
-                }
                 const int64_t t_start = ggml_time_us();
 
                 std::string filename = task.data.at("filename");
@@ -2212,10 +2355,16 @@ void server_context::process_single_task(server_task&& task) {
                     send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                     break;
                 }
-                slot->cache_tokens.resize(token_count);
-                if (mctx) {
-                    slot->cache_tokens.has_mtmd = true;
-                }
+                if (load_server_tokens_from_file(filepath + ".tokens.json", slot->cache_tokens) == 0) {
+                    // no usable token file, e.g. a save made before token files existed:
+                    // keep the previous behaviour so that such saves still restore
+                    slot->cache_tokens.resize(token_count);
+                    if (mctx) {
+                        slot->cache_tokens.has_mtmd = true;
+                    }
+                }
+                load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
+
                 const int64_t t_end = ggml_time_us();
                 const double t_restore_ms = (t_end - t_start) / 1000.0;
 
@@ -2248,14 +2397,13 @@ void server_context::process_single_task(server_task&& task) {
                     queue_tasks.defer(std::move(task));
                     break;
                 }
-                if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
-                    break;
-                }
                 // Erase token cache
                 const size_t n_erased = slot->cache_tokens.size();
                 llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
-                slot->cache_tokens.clear();
-
+                slot->cache_tokens.keep_first(0);
+                slot->server_cached_prompt.checkpoints.clear();
+                slot->server_cached_prompt.data.clear();
+
                 server_task_result result;
                 result.id = task.id;
                 result.stop = true;
diff --git a/examples/server/server-task.h b/examples/server/server-task.h
index 89232b3d..9529261f 100644
--- a/examples/server/server-task.h
+++ b/examples/server/server-task.h
@@ -355,6 +355,22 @@ struct server_prompt_checkpoint {
     size_t size() const {
         return data.size();
     }
+
+    json to_json() const {
+        json j;
+        j["pos_min"] = pos_min;
+        j["pos_max"] = pos_max;
+        j["pos_min_prompt"] = pos_min_prompt;
+        j["pos_max_prompt"] = pos_max_prompt;
+        return j;
+    }
+
+    void from_json(const json & j) {
+        pos_min = j.value("pos_min", 0);
+        pos_max = j.value("pos_max", 0);
+        pos_min_prompt = j.value("pos_min_prompt", 0);
+        pos_max_prompt = j.value("pos_max_prompt", 0);
+    }
 };
 
 
@@ -384,6 +400,21 @@ struct server_prompt {
             checkpoints
         };
     }
+
+    json to_json()
+    {
+        json j;
+        j["tokens"] = tokens.to_json();
+        j["n_kept_prompt"] = n_kept_prompt;
+        j["n_discarded_prompt"] = n_discarded_prompt;
+        return j;
+    }
+
+    void from_json(const json & j) {
+        tokens.from_json(j.at("tokens"));
+        n_kept_prompt = j.value("n_kept_prompt", 0);
+        n_discarded_prompt = j.value("n_discarded_prompt", 0);
+    }
 };
 
 struct server_prompt_cache {
diff --git a/include/llama.h b/include/llama.h
index 44817d9f..6b8a207a 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -52,6 +52,9 @@
 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 3
 
+#define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv'
+#define LLAMA_SERVER_VERSION 1
+
 #ifdef __cplusplus
 extern "C" {
 #endif