mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-07 23:10:10 +00:00
Add an endpoint that lists all the saved prompt caches to server (#502)
This commit is contained in:
@@ -3390,6 +3390,48 @@ int main(int argc, char ** argv) {
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
json response = json::array();
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
try {
|
||||
for (const auto& entry : fs::directory_iterator(params.slot_save_path)) {
|
||||
if (!entry.is_regular_file() || entry.file_size() < 12) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::ifstream file(entry.path(), std::ios::binary);
|
||||
if (!file) continue;
|
||||
|
||||
uint32_t magic, version, n_token_count;
|
||||
file.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
file.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
file.read(reinterpret_cast<char*>(&n_token_count), sizeof(n_token_count));
|
||||
|
||||
if (magic != LLAMA_STATE_SEQ_MAGIC ||
|
||||
version != LLAMA_STATE_SEQ_VERSION ||
|
||||
entry.file_size() < (12 + (n_token_count * sizeof(llama_token)))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens(n_token_count);
|
||||
file.read(reinterpret_cast<char*>(tokens.data()), tokens.size() * sizeof(llama_token));
|
||||
|
||||
response.push_back({
|
||||
{"filename", entry.path().filename().string()},
|
||||
{"filesize", entry.file_size()},
|
||||
{"token_count", n_token_count},
|
||||
{"prompt", tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend())}
|
||||
});
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
response = {{"error", e.what()}};
|
||||
}
|
||||
res.set_content(response.dump(), "application/json; charset=utf-8");
|
||||
};
|
||||
|
||||
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
|
||||
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
|
||||
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
|
||||
@@ -3448,8 +3490,9 @@ int main(int argc, char ** argv) {
|
||||
// Save & load slots
|
||||
svr->Get ("/slots", handle_slots);
|
||||
if (!params.slot_save_path.empty()) {
|
||||
// only enable slot endpoints if slot_save_path is set
|
||||
// these endpoints rely on slot_save_path existing
|
||||
svr->Post("/slots/:id_slot", handle_slots_action);
|
||||
svr->Get ("/list", list_saved_prompts);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user