server: support slot save/restore/erase for mtmd tokens and checkpoints (#1584)

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-04-05 01:41:04 -05:00
committed by GitHub
parent 0147cf4837
commit 5e8bb724ce
7 changed files with 351 additions and 16 deletions

View File

@@ -11,6 +11,8 @@
#include "mtmd.h"
#include "mtmd-helper.h"
#include <fstream>
#include <iostream>
#include <regex>
static void log_text(const gpt_params & params_base, const std::string & text) {
@@ -1995,6 +1997,117 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
}
}
static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
if (checkpoints.size() == 0) {
return 0;
}
std::ofstream file(filename, std::ios::binary);
uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
uint32_t version = LLAMA_STATE_SEQ_VERSION;
file.write(reinterpret_cast<const char *>(&version), sizeof(version));
size_t count = checkpoints.size();
file.write(reinterpret_cast<const char *>(&count), sizeof(count));
for (const auto & checkpoint : checkpoints) {
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
size_t data_len = checkpoint.data.size();
file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
if (data_len > 0) {
file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
}
}
size_t pos = file.tellp();
file.close();
return pos;
}
static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
std::ifstream file(filename, std::ios::binary);
if (!file.is_open()) {
return 0;
}
checkpoints.clear();
// version checks
{
uint32_t magic;
file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
uint32_t version;
file.read(reinterpret_cast<char *>(&version), sizeof(version));
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
return 0;
}
}
// load the checkpoints
{
size_t count;
file.read(reinterpret_cast<char *>(&count), sizeof(count));
for (int i = 0; i < count; i++) {
server_prompt_checkpoint checkpoint;
file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
size_t data_len;
file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
if (data_len > 0) {
checkpoint.data.resize(data_len);
file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
}
checkpoints.push_back(checkpoint);
}
}
size_t pos = file.tellg();
file.close();
return pos;
}
static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
std::ofstream file(filename, std::ios::binary);
json token_json = tokens.to_json();
token_json["magic"] = LLAMA_SERVER_MAGIC;
token_json["version"] = LLAMA_SERVER_VERSION;
size_t pos = 0;
if (file.is_open()) {
file << token_json;
pos = file.tellp();
file.close();
}
return pos;
}
static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
std::ifstream file(filename, std::ios::binary);
if (!file.is_open()) {
return 0;
}
size_t pos = 0;
json token_json;
if (file.is_open()) {
file >> token_json;
pos = file.tellg();
file.close();
}
uint32_t magic = token_json.value<uint32_t>("magic", 0);
uint32_t version = token_json.value<uint32_t>("version", 0);
if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
return 0;
}
tokens.from_json(token_json);
return pos;
}
void server_context::process_single_task(server_task&& task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
@@ -2153,14 +2266,14 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
std::string filename = task.data.at("filename");
std::string filepath = task.data.at("filepath");
save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
@@ -2175,7 +2288,7 @@ void server_context::process_single_task(server_task&& task) {
{ "id_slot", id_slot },
{ "filename", filename },
{ "n_saved", token_count }, // tokens saved
{ "n_written", nwrite }, // bytes written
{ "n_written", nwrite + saved }, // bytes written
{ "timings", {
{ "save_ms", t_save_ms }
} }
@@ -2196,9 +2309,6 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
const int64_t t_start = ggml_time_us();
std::string filename = task.data.at("filename");
@@ -2212,10 +2322,9 @@ void server_context::process_single_task(server_task&& task) {
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break;
}
slot->cache_tokens.resize(token_count);
if (mctx) {
slot->cache_tokens.has_mtmd = true;
}
load_server_tokens_from_file(filepath+".tokens.json", slot->cache_tokens);
size_t loaded = load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
@@ -2248,14 +2357,13 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();
slot->cache_tokens.keep_first(0);
//slot->cache_tokens.clear();
slot->server_cached_prompt.checkpoints.clear();
slot->server_cached_prompt.data.clear();
server_task_result result;
result.id = task.id;
result.stop = true;