mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-24 06:35:28 +00:00
server: support slot save/restore/erase for mtmd tokens and checkpoints (#1584)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -11,6 +11,8 @@
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
|
||||
static void log_text(const gpt_params & params_base, const std::string & text) {
|
||||
@@ -1995,6 +1997,117 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
|
||||
if (checkpoints.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
std::ofstream file(filename, std::ios::binary);
|
||||
uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
|
||||
file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
|
||||
uint32_t version = LLAMA_STATE_SEQ_VERSION;
|
||||
file.write(reinterpret_cast<const char *>(&version), sizeof(version));
|
||||
size_t count = checkpoints.size();
|
||||
file.write(reinterpret_cast<const char *>(&count), sizeof(count));
|
||||
|
||||
for (const auto & checkpoint : checkpoints) {
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
||||
size_t data_len = checkpoint.data.size();
|
||||
file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
|
||||
if (data_len > 0) {
|
||||
file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
||||
}
|
||||
}
|
||||
size_t pos = file.tellp();
|
||||
file.close();
|
||||
return pos;
|
||||
}
|
||||
|
||||
static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
|
||||
std::ifstream file(filename, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return 0;
|
||||
}
|
||||
checkpoints.clear();
|
||||
// version checks
|
||||
{
|
||||
uint32_t magic;
|
||||
file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
|
||||
uint32_t version;
|
||||
file.read(reinterpret_cast<char *>(&version), sizeof(version));
|
||||
|
||||
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
|
||||
LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
// load the checkpoints
|
||||
{
|
||||
size_t count;
|
||||
file.read(reinterpret_cast<char *>(&count), sizeof(count));
|
||||
|
||||
for (int i = 0; i < count; i++) {
|
||||
server_prompt_checkpoint checkpoint;
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
||||
|
||||
size_t data_len;
|
||||
file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
|
||||
if (data_len > 0) {
|
||||
checkpoint.data.resize(data_len);
|
||||
file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
||||
}
|
||||
checkpoints.push_back(checkpoint);
|
||||
}
|
||||
}
|
||||
size_t pos = file.tellg();
|
||||
file.close();
|
||||
return pos;
|
||||
}
|
||||
|
||||
static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
|
||||
std::ofstream file(filename, std::ios::binary);
|
||||
json token_json = tokens.to_json();
|
||||
token_json["magic"] = LLAMA_SERVER_MAGIC;
|
||||
token_json["version"] = LLAMA_SERVER_VERSION;
|
||||
size_t pos = 0;
|
||||
if (file.is_open()) {
|
||||
file << token_json;
|
||||
pos = file.tellp();
|
||||
file.close();
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
|
||||
std::ifstream file(filename, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return 0;
|
||||
}
|
||||
size_t pos = 0;
|
||||
json token_json;
|
||||
if (file.is_open()) {
|
||||
file >> token_json;
|
||||
pos = file.tellg();
|
||||
file.close();
|
||||
}
|
||||
uint32_t magic = token_json.value<uint32_t>("magic", 0);
|
||||
uint32_t version = token_json.value<uint32_t>("version", 0);
|
||||
if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
|
||||
LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
|
||||
return 0;
|
||||
}
|
||||
tokens.from_json(token_json);
|
||||
|
||||
return pos;
|
||||
}
|
||||
|
||||
void server_context::process_single_task(server_task&& task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
@@ -2153,14 +2266,14 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t token_count = slot->cache_tokens.size();
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::string filename = task.data.at("filename");
|
||||
std::string filepath = task.data.at("filepath");
|
||||
save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
|
||||
size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
|
||||
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
|
||||
|
||||
@@ -2175,7 +2288,7 @@ void server_context::process_single_task(server_task&& task) {
|
||||
{ "id_slot", id_slot },
|
||||
{ "filename", filename },
|
||||
{ "n_saved", token_count }, // tokens saved
|
||||
{ "n_written", nwrite }, // bytes written
|
||||
{ "n_written", nwrite + saved }, // bytes written
|
||||
{ "timings", {
|
||||
{ "save_ms", t_save_ms }
|
||||
} }
|
||||
@@ -2196,9 +2309,6 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::string filename = task.data.at("filename");
|
||||
@@ -2212,10 +2322,9 @@ void server_context::process_single_task(server_task&& task) {
|
||||
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
}
|
||||
slot->cache_tokens.resize(token_count);
|
||||
if (mctx) {
|
||||
slot->cache_tokens.has_mtmd = true;
|
||||
}
|
||||
load_server_tokens_from_file(filepath+".tokens.json", slot->cache_tokens);
|
||||
size_t loaded = load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
||||
@@ -2248,14 +2357,13 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->cache_tokens.size();
|
||||
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
||||
slot->cache_tokens.clear();
|
||||
|
||||
slot->cache_tokens.keep_first(0);
|
||||
//slot->cache_tokens.clear();
|
||||
slot->server_cached_prompt.checkpoints.clear();
|
||||
slot->server_cached_prompt.data.clear();
|
||||
server_task_result result;
|
||||
result.id = task.id;
|
||||
result.stop = true;
|
||||
|
||||
Reference in New Issue
Block a user