mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Bug fixes for completions and prompt caching in server (#906)
* Bug fixes for completions and prompt caching in server * Fix compiler warning about redefinition --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -15,13 +15,7 @@
|
|||||||
// crash the server in debug mode, otherwise send an http 500 error
|
// crash the server in debug mode, otherwise send an http 500 error
|
||||||
#define CPPHTTPLIB_NO_EXCEPTIONS 1
|
#define CPPHTTPLIB_NO_EXCEPTIONS 1
|
||||||
#endif
|
#endif
|
||||||
// increase max payload length to allow use of larger context size
|
|
||||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
|
||||||
// disable Nagle's algorithm
|
|
||||||
#define CPPHTTPLIB_TCP_NODELAY true
|
|
||||||
#include "httplib.h"
|
|
||||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
|
||||||
#define JSON_ASSERT GGML_ASSERT
|
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
#include "index.html.gz.hpp"
|
#include "index.html.gz.hpp"
|
||||||
#include "index_llamacpp.html.gz.hpp"
|
#include "index_llamacpp.html.gz.hpp"
|
||||||
@@ -3050,7 +3044,7 @@ struct server_context {
|
|||||||
GGML_ASSERT(slot.ga_n == 1);
|
GGML_ASSERT(slot.ga_n == 1);
|
||||||
|
|
||||||
// reuse any previously computed tokens that are common with the new prompt
|
// reuse any previously computed tokens that are common with the new prompt
|
||||||
slot.n_past = common_part(slot.cache_tokens.tokens_data(), prompt_tokens.tokens_data());
|
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
|
||||||
|
|
||||||
// push the prompt into the sampling context (do not apply grammar)
|
// push the prompt into the sampling context (do not apply grammar)
|
||||||
for (int i = 0; i < slot.n_past; ++i) {
|
for (int i = 0; i < slot.n_past; ++i) {
|
||||||
@@ -3137,7 +3131,6 @@ struct server_context {
|
|||||||
{
|
{
|
||||||
const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past);
|
const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past);
|
||||||
slot.cache_tokens.push_back(chunk.get()); // copy
|
slot.cache_tokens.push_back(chunk.get()); // copy
|
||||||
fprintf(stdout, slot.cache_tokens.detokenize(ctx, true).c_str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.n_past += n_pos;
|
slot.n_past += n_pos;
|
||||||
@@ -4293,14 +4286,15 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto& prompt = data.at("prompt");
|
const auto& prompt = data.at("prompt");
|
||||||
fprintf(stdout, prompt.get<std::string>().c_str());
|
|
||||||
|
|
||||||
// process prompt
|
// process prompt
|
||||||
std::vector<server_tokens> inputs;
|
std::vector<server_tokens> inputs;
|
||||||
|
|
||||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||||
printFilesInfo(files);
|
#ifndef NDEBUG
|
||||||
|
print_files_info(files);
|
||||||
|
#endif // !NDEBUG
|
||||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@@ -4346,31 +4340,26 @@ int main(int argc, char ** argv) {
|
|||||||
if (!result.error) {
|
if (!result.error) {
|
||||||
result.oaicompat = oaicompat;
|
result.oaicompat = oaicompat;
|
||||||
result.oaicompat_cmpl_id = completion_id;
|
result.oaicompat_cmpl_id = completion_id;
|
||||||
json result_array;
|
json res_json;
|
||||||
if (oaicompat) {
|
if (oaicompat) {
|
||||||
if (result.final_result) {
|
if (result.final_result) {
|
||||||
result_array = result.to_json_final();
|
res_json = result.to_json_final();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
result_array = result.to_json_partial();
|
res_json = result.to_json_partial();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// legacy completions
|
// legacy completions
|
||||||
result_array = result.data;
|
res_json = result.data;
|
||||||
}
|
}
|
||||||
if (result_array.is_array()) {
|
if (res_json.is_array()) {
|
||||||
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
|
// chat completions and oai completions
|
||||||
if (!it->empty()) {
|
for (const auto& res : res_json) {
|
||||||
const std::string str =
|
if (!server_sent_event(sink, res)) {
|
||||||
"data: " +
|
// sending failed (HTTP connection closed), cancel the generation
|
||||||
it->dump(-1, ' ', false, json::error_handler_t::replace) +
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
"\n\n";
|
return false;
|
||||||
LOG_VERBOSE("data stream", { {"to_send", str} });
|
|
||||||
if (!sink.write(str.c_str(), str.size())) {
|
|
||||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (result.stop) {
|
if (result.stop) {
|
||||||
@@ -4378,14 +4367,19 @@ int main(int argc, char ** argv) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
// legacy completions
|
||||||
|
if (!server_sent_event(sink, res_json)) {
|
||||||
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (result.stop) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
const std::string str =
|
if (!server_sent_event(sink, result.data)) {
|
||||||
"error: " +
|
|
||||||
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
||||||
"\n\n";
|
|
||||||
LOG_VERBOSE("data stream", { {"to_send", str} });
|
|
||||||
if (!sink.write(str.c_str(), str.size())) {
|
|
||||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -4436,7 +4430,7 @@ int main(int argc, char ** argv) {
|
|||||||
data,
|
data,
|
||||||
files,
|
files,
|
||||||
res,
|
res,
|
||||||
OAICOMPAT_TYPE_CHAT);
|
OAICOMPAT_TYPE_COMPLETION);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
|||||||
@@ -15,6 +15,16 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
|
// increase max payload length to allow use of larger context size
|
||||||
|
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||||
|
// increase backlog size to avoid connection resets for >> 1 slots
|
||||||
|
#define CPPHTTPLIB_LISTEN_BACKLOG 512
|
||||||
|
// increase max URI length to handle longer prompts in query string
|
||||||
|
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
|
||||||
|
// disable Nagle's algorithm
|
||||||
|
#define CPPHTTPLIB_TCP_NODELAY true
|
||||||
|
#include "httplib.h"
|
||||||
|
|
||||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
@@ -411,6 +421,17 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool server_sent_event(httplib::DataSink& sink, const json& data) {
|
||||||
|
const std::string str =
|
||||||
|
"data: " +
|
||||||
|
data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
||||||
|
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
|
||||||
|
|
||||||
|
LOG_VERBOSE("data stream, to_send: %s", str.c_str());
|
||||||
|
|
||||||
|
return sink.write(str.c_str(), str.size());
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// OAI utils
|
// OAI utils
|
||||||
//
|
//
|
||||||
@@ -1065,7 +1086,6 @@ public:
|
|||||||
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||||
GGML_ASSERT(has_mtmd);
|
GGML_ASSERT(has_mtmd);
|
||||||
const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
|
const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
|
||||||
fprintf(stdout, "n_pos: %d\n", n_pos);
|
|
||||||
llama_pos start_pos = tokens.size();
|
llama_pos start_pos = tokens.size();
|
||||||
for (int i = 0; i < n_pos; ++i) {
|
for (int i = 0; i < n_pos; ++i) {
|
||||||
tokens.emplace_back(LLAMA_TOKEN_NULL);
|
tokens.emplace_back(LLAMA_TOKEN_NULL);
|
||||||
@@ -1209,39 +1229,54 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t get_common_prefix(const server_tokens& b) const {
|
size_t get_common_prefix(const server_tokens& b) const {
|
||||||
size_t max_idx = std::min(tokens.size(), b.tokens.size());
|
const size_t max_idx = std::min(tokens.size(), b.tokens.size());
|
||||||
for (size_t i = 0; i < max_idx; ++i) {
|
|
||||||
auto& ai = tokens[i];
|
|
||||||
auto& bi = b.tokens[i];
|
|
||||||
|
|
||||||
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
|
if (!has_mtmd) {
|
||||||
GGML_ASSERT(has_mtmd);
|
for (size_t i = 0; i < max_idx; ++i) {
|
||||||
const auto& a_chunk = find_chunk(i);
|
if (tokens[i] == b.tokens[i]) {
|
||||||
const auto& b_chunk = b.find_chunk(i);
|
|
||||||
GGML_ASSERT(a_chunk && b_chunk);
|
|
||||||
std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
|
|
||||||
std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
|
|
||||||
size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
|
|
||||||
size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
|
|
||||||
if (ai_id == bi_id && a_pos == b_pos) {
|
|
||||||
GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
|
|
||||||
i += a_pos - 1; // will be +1 by the for loop
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (ai == bi) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
return max_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < max_idx; ++i) {
|
||||||
|
const llama_token ai = tokens[i];
|
||||||
|
const llama_token bi = b.tokens[i];
|
||||||
|
|
||||||
|
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
|
||||||
|
const auto& a_chunk = find_chunk(i);
|
||||||
|
const auto& b_chunk = b.find_chunk(i);
|
||||||
|
|
||||||
|
GGML_ASSERT(a_chunk && b_chunk);
|
||||||
|
|
||||||
|
const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
|
||||||
|
const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
|
||||||
|
|
||||||
|
const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
|
||||||
|
const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
|
||||||
|
|
||||||
|
if (id_ai == id_bi && pos_a == pos_b) {
|
||||||
|
GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
|
||||||
|
i += pos_a - 1; // will be +1 by the for loop
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ai == bi) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
return max_idx; // all tokens are equal
|
return max_idx; // all tokens are equal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// make sure all text tokens are within the vocab range
|
// make sure all text tokens are within the vocab range
|
||||||
bool validate(const struct llama_context* ctx) const {
|
bool validate(const struct llama_context* ctx) const {
|
||||||
const llama_model* model = llama_get_model(ctx);
|
const llama_model* model = llama_get_model(ctx);
|
||||||
@@ -1274,10 +1309,12 @@ public:
|
|||||||
llama_pos n_past,
|
llama_pos n_past,
|
||||||
int32_t seq_id,
|
int32_t seq_id,
|
||||||
llama_pos& n_pos_out) {
|
llama_pos& n_pos_out) {
|
||||||
|
char buffer[512];
|
||||||
auto& chunk = find_chunk(n_past);
|
auto& chunk = find_chunk(n_past);
|
||||||
const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
|
const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
|
||||||
? "image" : "audio";
|
? "image" : "audio";
|
||||||
LOG_INFO("processing %s...\n", name);
|
snprintf(buffer, 512, "processing : %s",name);
|
||||||
|
LOG_INFO(buffer, {});
|
||||||
int32_t n_batch = llama_n_batch(ctx);
|
int32_t n_batch = llama_n_batch(ctx);
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
llama_pos new_n_past = n_past;
|
llama_pos new_n_past = n_past;
|
||||||
@@ -1288,9 +1325,11 @@ public:
|
|||||||
n_batch,
|
n_batch,
|
||||||
true, // logits last
|
true, // logits last
|
||||||
&new_n_past);
|
&new_n_past);
|
||||||
LOG_INFO("processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
snprintf(buffer, 512, "processed in %d ms", ggml_time_ms() - t0);
|
||||||
|
LOG_INFO(buffer, {});
|
||||||
if (result != 0) {
|
if (result != 0) {
|
||||||
LOG_ERROR("mtmd_helper_eval failed with status %d", result);
|
snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result);
|
||||||
|
LOG_ERROR(buffer, {});
|
||||||
n_pos_out = n_past;
|
n_pos_out = n_past;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@@ -1422,7 +1461,7 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* voca
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
// Assuming raw_buffer has .data() and .size() members
|
// Assuming raw_buffer has .data() and .size() members
|
||||||
inline void printFilesInfo(const std::vector<raw_buffer>& files) {
|
inline void print_files_info(const std::vector<raw_buffer>& files) {
|
||||||
for (size_t i = 0; i < files.size(); ++i) {
|
for (size_t i = 0; i < files.size(); ++i) {
|
||||||
const auto& file = files[i];
|
const auto& file = files[i];
|
||||||
std::cout << "File " << i << ": Size = " << file.size() << " bytes\n";
|
std::cout << "File " << i << ": Size = " << file.size() << " bytes\n";
|
||||||
|
|||||||
Reference in New Issue
Block a user