mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-27 01:29:51 +00:00
* Add new webui from llama.cpp * Add new webui * feat: Improve mobile UI for Settings Dialog (#16084) * feat: Improve mobile UI for Settings Dialog * chore: update webui build output * fix: Linting errors * chore: update webui build output # Conflicts: # examples/server/webui_llamacpp/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte # examples/server/webui_llamacpp/src/lib/components/app/chat/ChatSettings/ChatSettingsSection.svelte # tools/server/public/index.html.gz * webui : fix handling incomplete chunks (#16107) * Always show message actions for mobile UI + improvements for user message sizing (#16076) # Conflicts: # .gitignore # examples/server/webui_llamacpp/package.json # examples/server/webui_llamacpp/scripts/dev.sh # tools/server/webui/scripts/post-build.sh * webui: switch to hash-based routing (alternative of #16079) (#16157) * Switched web UI to hash-based routing * Added hash to missed goto function call * Removed outdated SPA handling code * Fixed broken sidebar home link # Conflicts: # examples/server/webui_llamacpp/src/routes/+layout.ts # tools/server/server.cpp * Allow viewing conversations even when llama server is down (#16255) * webui: allow viewing conversations and sending messages even if llama-server is down - Cached llama.cpp server properties in browser localStorage on startup, persisting successful fetches and reloading them when refresh attempts fail so the chat UI continues to render while the backend is unavailable. - Cleared the stored server properties when resetting the store to prevent stale capability data after cache-backed operation. - Kept the original error-splash behavior when no cached props exist so fresh installs still surface a clear failure state instead of rendering stale data. * feat: Add UI for `props` endpoint unavailable + cleanup logic * webui: extend cached props fallback to offline errors Treat connection failures (refused, DNS, timeout, fetch) the same way as server 5xx so the warning banner shows up when cache is available, instead of falling back to a full error screen. * webui: Left the chat form enabled when a server warning is present so operators can keep sending messages e.g., to restart the backend over llama-swap, even while cached /props data is in use * chore: update webui build output --------- Co-authored-by: Pascal <admin@serveurperso.com> # Conflicts: # examples/server/webui_llamacpp/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte # examples/server/webui_llamacpp/src/lib/constants/localstorage-keys.ts * Enhance text file detection logic for file attachments (#16199) * feat: Enhances text file detection logic * chore: Build static `webui` output * chore: update webui build output # Conflicts: # examples/server/webui_llamacpp/src/lib/constants/binary-detection.ts * Show message actions by default (#16289) * fix: preserved zero values in chat settings inputs and textareas by switching to nullish coalescing for field values and default placeholders (#16312) * Improve Mobile UI for dialogs and action dropdowns (#16222) * fix: Always show conversation item actions * feat: Improve Alert Dialog and Dialog mobile UI * feat: Add settings reset to default confirmation * fix: Close Edit dialog on save * chore: update webui build output * webui: implement proper z-index system and scroll management - Add CSS variable for centralized z-index control - Fix dropdown positioning with Settings dialog conflicts - Prevent external scroll interference with proper event handling - Clean up hardcoded z-index values for maintainable architecture * webui: ensured the settings dialog enforces dynamic viewport height on mobile while retaining existing desktop sizing overrides * feat: Use `dvh` instead of computed px height for dialogs max height on mobile * chore: update webui build output * feat: Improve Settings fields UI * chore: update webui build output * chore: update webui build output --------- Co-authored-by: Pascal <admin@serveurperso.com> * Fix thinking blocks with quotes + add handling `[THINK]...[/THINK]` blocks (#16326) * fix: prevent reasoning blocks with quotes from being truncated * chore: update webui build output * feat: Improve thinking content parsing * test: Adds ChatMessage component stories for different thinking blocks * chore: update webui build output * fix: ChatMessage story fix --------- Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Chatapi ignore empty sampling (#16330) * fix: skip empty sampling fields instead of coercing to 0 in chat API options * chore: update webui build output * webui: Remove running `llama-server` within WebUI `dev.sh` script (#16363) * Add optional setting for showing "Model used:" information (#16337) * feat: Add a setting to include model name used to generate the message * feat: UI improvements * feat: Save model info along with the database message entry creation * chore: Build webui static output * Improve code block color theming (#16325) * feat: Improve code block theming * chore: update webui build output * chore: Update webui static build * Conversation action dialogs as singletons from Chat Sidebar + apply conditional rendering for Actions Dropdown for Chat Conversation Items (#16369) * fix: Render Conversation action dialogs as singletons from Chat Sidebar level * chore: update webui build output * fix: Render Actions Dropdown conditionally only when user hovers conversation item + remove unused markup * chore: Update webui static build * fix: Always truncate conversation names * chore: Update webui static build * fix: track viewportHeight via window.innerHeight to avoid unwanted scrolling (#16356) Use <svelte:window bind:innerHeight> instead of manual resize listener Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * webui : Fix messages payload sent to chat completions (#16402) * fix: Include just the currently active message branches instead of all in chat completions request * chore: Build webui static output * chore: Formatting * chore: update webui build output * Capture model name only after first token (streaming) or completed request (#16405) * feat: Capture model name only after first token (streaming) or completed request (non-streaming) * chore: update webui build output * chore: update webui build output * Fix missing messages on sibling navigation (#16408) * fix: resolve message disappearing issue when navigating between regenerated siblings by using current leaf nodes instead of cached sibling IDs * chore: update webui build output * chore: update webui build output * webui : added download action (#13552) (#16282) * webui : added download action (#13552) * webui : import and export (for all conversations) * webui : fixed download-format, import of one conversation * webui : add ExportedConversations type for chat import/export * feat: Update naming & order * chore: Linting * webui : Updated static build output --------- Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * refactor: centralize CoT parsing in backend for streaming mode (#16394) * refactor: unify reasoning handling via backend reasoning_content, drop frontend tag parsing - Updated the chat message component to surface backend-supplied reasoning via message.thinking while showing the raw assistant content without inline tag scrubbing - Simplified chat streaming to append content chunks directly, stream reasoning into the message model, and persist any partial reasoning when generation stops - Refactored the chat service SSE handler to rely on server-provided reasoning_content, removing legacy <think> parsing logic - Refreshed Storybook data and streaming flows to populate the thinking field explicitly for static and streaming assistant messages * refactor: implement streaming-aware universal reasoning parser Remove the streaming mode limitation from --reasoning-format by refactoring try_parse_reasoning() to handle incremental parsing of <think> tags across all formats. - Rework try_parse_reasoning() to track whitespace, partial tags, and multiple reasoning segments, allowing proper separation of reasoning_content and content in streaming mode - Parse reasoning tags before tool call handling in content-only and Llama 3.x formats to ensure inline <think> blocks are captured correctly - Change default reasoning_format from 'auto' to 'deepseek' for consistent behavior - Add 'deepseek-legacy' option to preserve old inline behavior when needed - Update CLI help and documentation to reflect streaming support - Add parser tests for inline <think>...</think> segments The parser now continues processing content after </think> closes instead of stopping, enabling proper message.reasoning_content and message.content separation in both streaming and non-streaming modes. Fixes the issue where streaming responses would dump everything (including post-thinking content) into reasoning_content while leaving content empty. * refactor: address review feedback from allozaur - Passed the assistant message content directly to ChatMessageAssistant to drop the redundant derived state in the chat message component - Simplified chat streaming updates by removing unused partial-thinking handling and persisting partial responses straight from currentResponse - Refreshed the ChatMessage stories to cover standard and reasoning scenarios without the old THINK-tag parsing examples Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * refactor: restore forced reasoning prefix to pass test-chat ([chat] All tests passed) - store the exact sequence seen on input when 'thinking_forced_open' enforces a reasoning block - inject this prefix before the first accumulated segment in 'reasoning_content', then clear it to avoid duplication - repeat the capture on every new 'start_think' detection to properly handle partial/streaming flows * refactor: address review feedback from ngxson * debug: say goodbye to curl -N, hello one-click raw stream - adds a new checkbox in the WebUI to display raw LLM output without backend parsing or frontend Markdown rendering * Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * webui: add Storybook example for raw LLM output and scope reasoning format toggle per story - Added a Storybook example that showcases the chat message component in raw LLM output mode with the provided trace sample - Updated every ChatMessage story to toggle the disableReasoningFormat setting so the raw-output rendering remains scoped to its own example * npm run format * chat-parser: address review feedback from ngxson Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> --------- Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> # Conflicts: # common/arg.cpp # examples/server/webui_llamacpp/src/lib/utils/thinking.ts # tools/server/README.md * No markdown in cot (#16483) * fix: let the model think in plaintext * chore: npm run format + npm run build * webui: updated the chat service to only include max_tokens in the req… (#16489) * webui: updated the chat service to only include max_tokens in the request payload when the setting is explicitly provided, while still mapping explicit zero or null values to the infinite-token sentinel * chore: update webui build output * feat: render user content as markdown option (#16358) * feat: render user content as markdown option - Add a persisted 'renderUserContentAsMarkdown' preference to the settings defaults and info metadata so the choice survives reloads like other options - Surface the new 'Render user content as Markdown' checkbox in the General section of the chat settings dialog, beneath the PDF toggle - Render user chat messages with 'MarkdownContent' when the new setting is enabled, matching assistant formatting while preserving the existing card styling otherwise - chore: update webui build output * chore: update webui build output * webui: remove client-side context pre-check and rely on backend for limits (#16506) * fix: make SSE client robust to premature [DONE] in agentic proxy chains * webui: remove client-side context pre-check and rely on backend for limits Removed the client-side context window pre-check and now simply sends messages while keeping the dialog imports limited to core components, eliminating the maximum context alert path Simplified streaming and non-streaming chat error handling to surface a generic 'No response received from server' error whenever the backend returns no content Removed the obsolete maxContextError plumbing from the chat store so state management now focuses on the core message flow without special context-limit cases * webui: cosmetic rename of error messages * Update tools/server/webui/src/lib/stores/chat.svelte.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/stores/chat.svelte.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * chore: update webui build output --------- Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> # Conflicts: # examples/server/webui_llamacpp/src/lib/components/app/dialogs/ChatErrorDialog.svelte # examples/server/webui_llamacpp/src/lib/components/app/dialogs/MaximumContextAlertDialog.svelte # examples/server/webui_llamacpp/src/lib/services/context.ts * fix: add remark plugin to render raw HTML as literal text (#16505) * fix: add remark plugin to render raw HTML as literal text Implemented a missing MDAST stage to neutralize raw HTML like major LLM WebUIs do ensuring consistent and safe Markdown rendering Introduced 'remarkLiteralHtml', a plugin that converts raw HTML nodes in the Markdown AST into plain-text equivalents while preserving indentation and line breaks. This ensures consistent rendering and prevents unintended HTML execution, without altering valid Markdown structure Kept 'remarkRehype' in the pipeline since it performs the required conversion from MDAST to HAST for KaTeX, syntax highlighting, and HTML serialization Refined the link-enhancement logic to skip unnecessary DOM rewrites, fixing a subtle bug where extra paragraphs were injected after the first line due to full innerHTML reconstruction, and ensuring links open in new tabs only when required Final pipeline: remarkGfm -> remarkMath -> remarkBreaks -> remarkLiteralHtml -> remarkRehype -> rehypeKatex -> rehypeHighlight -> rehypeStringify * fix: address review feedback from allozaur * chore: update webui build output # Conflicts: # examples/server/webui_llamacpp/src/lib/constants/literal-html.ts * Add server-driven parameter defaults and syncing (#16515) # Conflicts: # examples/server/webui_llamacpp/src/lib/components/app/chat/ChatSettings/ParameterSourceIndicator.svelte # examples/server/webui_llamacpp/src/lib/constants/precision.ts # examples/server/webui_llamacpp/src/lib/services/parameter-sync.spec.ts # examples/server/webui_llamacpp/src/lib/services/parameter-sync.ts # examples/server/webui_llamacpp/src/lib/utils/config-helpers.ts # examples/server/webui_llamacpp/src/lib/utils/precision.ts * fix: added a normalization step for MathJax-style \[\] and \(\) delimiters (#16599) * fix: added a normalization step for MathJax-style \[\] and \(\) delimiters So inline and block equations are converted before KaTeX rendering, enabling proper display of model-generated LaTeX in the WebUI * chore: update webui build output * webui: reorganize settings layout (#16607) * webui: reorganize settings layout * chore: update webui build output * fix: remove unused variable * chore: update webui build output * Enable per-conversation loading states to allow having parallel conversations (#16327) * feat: Per-conversation loading states and tracking streaming stats * chore: update webui build output * refactor: Chat state management Consolidates loading state management by using a global `isLoading` store synchronized with individual conversation states. This change ensures proper reactivity and avoids potential race conditions when updating the UI based on the loading status of different conversations. It also improves the accuracy of statistics displayed. Additionally, slots service methods are updated to use conversation IDs for per-conversation state management, avoiding global state pollution. * feat: Adds loading indicator to conversation items * chore: update webui build output * fix: Fix aborting chat streaming Improves the chat stream abortion process by ensuring that partial responses are saved before the abort signal is sent. This avoids a race condition where the onError callback could clear the streaming state before the partial response is saved. Additionally, the stream reading loop and callbacks are now checked for abort signals to prevent further processing after abortion. * refactor: Remove redundant comments * chore: build webui static output * refactor: Cleanup * chore: update webui build output * chore: update webui build output * fix: Conversation loading indicator for regenerating messages * chore: update webui static build * feat: Improve configuration * feat: Install `http-server` as dev dependency to not need to rely on `npx` in CI * Import/Export UX improvements (#16619) * webui : added download action (#13552) * webui : import and export (for all conversations) * webui : fixed download-format, import of one conversation * webui : add ExportedConversations type for chat import/export * feat: Update naming & order * chore: Linting * feat: Import/Export UX improvements * chore: update webui build output * feat: Update UI placement of Import/Export tab in Chat Settings Dialog * refactor: Cleanup chore: update webui build output * feat: Enable shift-click multiple conversation items selection * chore: update webui static build * chore: update webui static build --------- Co-authored-by: Sascha Rogmann <github@rogmann.org> # Conflicts: # examples/server/webui_llamacpp/src/lib/components/app/chat/ChatSettings/ConversationSelectionDialog.svelte # examples/server/webui_llamacpp/src/lib/components/app/chat/ChatSettings/ImportExportTab.svelte # examples/server/webui_llamacpp/src/lib/utils/conversation-utils.ts * Prevent premature submission on IME input (#16673) * fix: Prevent premature submission on IME input * chore: update webui static build * refactor: Put IME completion checker in a helper function and add checking for `KeyboardEvent.eventKey === 229` * chore: update webui static build * chore: update webui static build * chore: update webui static build # Conflicts: # examples/server/webui_llamacpp/src/lib/utils/is-ime-composing.ts * Handle legacy 'context' attachments (#16687) * webui: introduce OpenAI-compatible model selector in JSON payload (#16562) * webui: introduce OpenAI-compatible model selector in JSON payload * webui: restore OpenAI-Compatible model source of truth and unify metadata capture This change re-establishes a single, reliable source of truth for the active model: fully aligned with the OpenAI-Compat API behavior It introduces a unified metadata flow that captures the model field from both streaming and non-streaming responses, wiring a new onModel callback through ChatService The model name is now resolved directly from the API payload rather than relying on server /props or UI assumptions ChatStore records and persists the resolved model for each assistant message during streaming, ensuring consistency across the UI and database Type definitions for API and settings were also extended to include model metadata and the onModel callback, completing the alignment with OpenAI-Compat semantics * webui: address review feedback from allozaur * webui: move model selector into ChatForm (idea by @allozaur) * webui: make model selector more subtle and integrated into ChatForm * webui: replaced the Flowbite selector with a native Svelte dropdown * webui: add developer setting to toggle the chat model selector * webui: address review feedback from allozaur Normalized streamed model names during chat updates by trimming input and removing directory components before saving or persisting them, so the conversation UI shows only the filename Forced model names within the chat form selector dropdown to render as a single-line, truncated entry with a tooltip revealing the full name * webui: toggle displayed model source for legacy vs OpenAI-Compat modes When the selector is disabled, it falls back to the active server model name from /props When the model selector is enabled, the displayed model comes from the message metadata (the one explicitly selected and sent in the request) * Update tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/constants/localstorage-keys.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormModelSelector.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/services/chat.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/services/chat.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * webui: refactor model selector and persistence helpers - Replace inline portal and event listeners with proper Svelte bindings - Introduce 'persisted' store helper for localStorage sync without runes - Extract 'normalizeModelName' utils + Vitest coverage - Simplify ChatFormModelSelector structure and cleanup logic Replaced the persisted store helper's use of '$state/$effect' runes with a plain TS implementation to prevent orphaned effect runtime errors outside component context Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * webui: document normalizeModelName usage with inline examples * Update tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormModelSelector.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/stores/models.svelte.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * Update tools/server/webui/src/lib/stores/models.svelte.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * webui: extract ModelOption type into dedicated models.d.ts Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * webui: refine ChatMessageAssistant displayedModel source logic * webui: stabilize dropdown, simplify model extraction, and init assistant model field * chore: update webui static build * Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * chore: npm format, update webui static build * webui: align sidebar trigger position, remove z-index glitch * chore: update webui build output --------- Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> # Conflicts: # examples/server/webui_llamacpp/src/lib/components/app/chat/ChatForm/ChatFormModelSelector.svelte # examples/server/webui_llamacpp/src/lib/services/models.ts # examples/server/webui_llamacpp/src/lib/stores/models.svelte.ts # examples/server/webui_llamacpp/src/lib/stores/persisted.svelte.ts # examples/server/webui_llamacpp/src/lib/types/models.d.ts # examples/server/webui_llamacpp/src/lib/utils/model-names.test.ts # examples/server/webui_llamacpp/src/lib/utils/model-names.ts # examples/server/webui_llamacpp/src/lib/utils/portal-to-body.ts * webui: support q URL parameter (#16728) * webui: support q URL parameter Fixes #16722 I’ve checked that it works with Firefox’s AI tools * webui: apply suggestions from code review Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * chore: update webui static build --------- Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> * build fix --------- Co-authored-by: firecoperana <firecoperana> Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com> Co-authored-by: Quentin Bramas <quentin.bramas@gmail.com> Co-authored-by: Isaac McFadyen <isaac@imcf.me> Co-authored-by: Pascal <admin@serveurperso.com> Co-authored-by: Sascha Rogmann <59577610+srogmann@users.noreply.github.com> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> Co-authored-by: Sascha Rogmann <github@rogmann.org> Co-authored-by: Florian Badie <florianbadie@odrling.xyz>
5173 lines
207 KiB
C++
5173 lines
207 KiB
C++
#pragma warning(disable : 4996)
|
|
#include "chat.h"
|
|
#include "utils.hpp"
|
|
|
|
#include "common.h"
|
|
#include "speculative.h"
|
|
#include "sampling.h"
|
|
#include "json-schema-to-grammar.h"
|
|
#include "llama.h"
|
|
#include "grammar-parser.h"
|
|
#include "llama-vocab.h"
|
|
|
|
#ifndef NDEBUG
|
|
// crash the server in debug mode, otherwise send an http 500 error
|
|
#define CPPHTTPLIB_NO_EXCEPTIONS 1
|
|
#endif
|
|
// increase max payload length to allow use of larger context size
|
|
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
|
// disable Nagle's algorithm
|
|
#define CPPHTTPLIB_TCP_NODELAY true
|
|
#include "httplib.h"
|
|
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
|
#define JSON_ASSERT GGML_ASSERT
|
|
#include <nlohmann/json.hpp>
|
|
#include "index.html.gz.hpp"
|
|
#include "index_llamacpp.html.gz.hpp"
|
|
#include "loading.html.hpp"
|
|
|
|
#include <atomic>
|
|
#include <chrono>
|
|
#include <condition_variable>
|
|
#include <cstddef>
|
|
#include <set>
|
|
#include <mutex>
|
|
#include <thread>
|
|
#include <signal.h>
|
|
#include <memory>
|
|
#include <random>
|
|
#include <algorithm>
|
|
#include <src/llama-impl.h>
|
|
#ifdef SQLITE3_MODERN_CPP_SUPPORT
|
|
#include <sqlite_modern_cpp.h>
|
|
|
|
struct DatabaseHandle {
|
|
sqlite::database db;
|
|
|
|
DatabaseHandle(const std::string& path) : db(path) {
|
|
db << "CREATE TABLE IF NOT EXISTS sessions (key TEXT PRIMARY KEY, data TEXT)";
|
|
db << "CREATE TABLE IF NOT EXISTS templates (key TEXT PRIMARY KEY, data TEXT)";
|
|
db << "CREATE TABLE IF NOT EXISTS names (key TEXT PRIMARY KEY, data TEXT)";
|
|
}
|
|
};
|
|
#endif
|
|
|
|
using json = nlohmann::ordered_json;
|
|
|
|
bool server_verbose = false;
|
|
bool server_log_json = true;
|
|
|
|
|
|
|
|
enum stop_type {
|
|
STOP_TYPE_NONE,
|
|
STOP_TYPE_EOS,
|
|
STOP_TYPE_WORD,
|
|
STOP_TYPE_LIMIT,
|
|
};
|
|
enum slot_state {
|
|
SLOT_STATE_IDLE,
|
|
SLOT_STATE_PROCESSING,
|
|
};
|
|
|
|
enum slot_command {
|
|
SLOT_COMMAND_NONE,
|
|
SLOT_COMMAND_LOAD_PROMPT,
|
|
SLOT_COMMAND_RELEASE,
|
|
};
|
|
|
|
enum server_state {
|
|
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
|
|
SERVER_STATE_READY, // Server is ready and model is loaded
|
|
SERVER_STATE_ERROR // An error occurred, load_model failed
|
|
};
|
|
|
|
enum server_task_type {
|
|
SERVER_TASK_TYPE_COMPLETION,
|
|
SERVER_TASK_TYPE_CANCEL,
|
|
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
|
SERVER_TASK_TYPE_METRICS,
|
|
SERVER_TASK_TYPE_SLOT_SAVE,
|
|
SERVER_TASK_TYPE_SLOT_RESTORE,
|
|
SERVER_TASK_TYPE_SLOT_ERASE,
|
|
SERVER_TASK_TYPE_SET_LORA,
|
|
};
|
|
|
|
enum oaicompat_type {
|
|
OAICOMPAT_TYPE_NONE,
|
|
OAICOMPAT_TYPE_CHAT,
|
|
OAICOMPAT_TYPE_COMPLETION,
|
|
OAICOMPAT_TYPE_EMBEDDING,
|
|
};
|
|
|
|
struct result_timings {
|
|
int32_t prompt_n = -1;
|
|
double prompt_ms;
|
|
double prompt_per_token_ms;
|
|
double prompt_per_second;
|
|
|
|
int32_t predicted_n = -1;
|
|
double predicted_ms;
|
|
double predicted_per_token_ms;
|
|
double predicted_per_second;
|
|
|
|
// Optional speculative metrics - only included when > 0
|
|
int32_t draft_n = 0;
|
|
int32_t draft_n_accepted = 0;
|
|
|
|
json to_json() const {
|
|
json base = {
|
|
{"prompt_n", prompt_n},
|
|
{"prompt_ms", prompt_ms},
|
|
{"prompt_per_token_ms", prompt_per_token_ms},
|
|
{"prompt_per_second", prompt_per_second},
|
|
|
|
{"predicted_n", predicted_n},
|
|
{"predicted_ms", predicted_ms},
|
|
{"predicted_per_token_ms", predicted_per_token_ms},
|
|
{"predicted_per_second", predicted_per_second},
|
|
};
|
|
|
|
if (draft_n > 0) {
|
|
base["draft_n"] = draft_n;
|
|
base["draft_n_accepted"] = draft_n_accepted;
|
|
}
|
|
|
|
return base;
|
|
}
|
|
};
|
|
|
|
struct server_task {
|
|
int id = -1; // to be filled by server_queue
|
|
int id_multi = -1;
|
|
int id_target = -1;
|
|
|
|
server_task_type type;
|
|
json data;
|
|
|
|
bool infill = false;
|
|
bool embedding = false;
|
|
};
|
|
|
|
struct server_task_result {
|
|
int id = -1;
|
|
int id_multi = -1;
|
|
|
|
json data;
|
|
|
|
bool stop;
|
|
bool error;
|
|
bool final_result = false;
|
|
result_timings timings;
|
|
// OAI-compat fields
|
|
//bool verbose = false;
|
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
|
std::string oaicompat_model;
|
|
std::string oaicompat_cmpl_id;
|
|
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
common_chat_msg oaicompat_msg;
|
|
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
|
|
|
int index = 0;
|
|
|
|
std::string content;
|
|
std::vector<llama_token> tokens;
|
|
|
|
bool stream;
|
|
bool include_usage;
|
|
std::string prompt;
|
|
//slot_params generation_params;
|
|
|
|
bool truncated;
|
|
int32_t n_decoded;
|
|
int32_t n_prompt_tokens;
|
|
int32_t n_tokens_cached;
|
|
bool has_new_line;
|
|
std::string stopping_word;
|
|
|
|
bool post_sampling_probs = false;
|
|
std::vector<completion_token_output> probs_output;
|
|
std::vector<std::string> response_fields;
|
|
|
|
//slot_params generation_params;
|
|
|
|
bool verbose = false;
|
|
|
|
|
|
int get_index() {
|
|
return index;
|
|
}
|
|
|
|
bool is_stop() {
|
|
return true; // in stream mode, final responses are considered stop
|
|
}
|
|
|
|
json to_json_final() {
|
|
switch (oaicompat) {
|
|
case OAICOMPAT_TYPE_NONE:
|
|
return to_json_non_oaicompat_final();
|
|
case OAICOMPAT_TYPE_COMPLETION:
|
|
return to_json_oaicompat_final();
|
|
case OAICOMPAT_TYPE_CHAT:
|
|
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat_final();
|
|
default:
|
|
GGML_ASSERT(false && "Invalid oaicompat_type");
|
|
}
|
|
}
|
|
|
|
json to_json_partial() {
|
|
switch (oaicompat) {
|
|
case OAICOMPAT_TYPE_NONE:
|
|
return to_json_non_oaicompat_partial();
|
|
case OAICOMPAT_TYPE_COMPLETION:
|
|
return to_json_oaicompat_partial();
|
|
case OAICOMPAT_TYPE_CHAT:
|
|
return to_json_oaicompat_chat_partial();
|
|
default:
|
|
GGML_ASSERT(false && "Invalid oaicompat_type");
|
|
}
|
|
}
|
|
|
|
json to_json_non_oaicompat_partial() {
|
|
// non-OAI-compat JSON
|
|
json res = json{
|
|
{"index", index},
|
|
{"content", content},
|
|
{"tokens", tokens},
|
|
{"stop", false},
|
|
{"id_slot", id_multi},
|
|
{"tokens_predicted", n_decoded},
|
|
{"tokens_evaluated", n_prompt_tokens},
|
|
};
|
|
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
|
if (timings.prompt_n > 0) {
|
|
res.push_back({ "timings", timings.to_json() });
|
|
}
|
|
if (!probs_output.empty()) {
|
|
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
json to_json_non_oaicompat_final() {
|
|
json res = json{
|
|
{"index", index},
|
|
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
|
{"tokens", stream ? std::vector<llama_token> {} : tokens},
|
|
{"id_slot", id_multi},
|
|
{"stop", true},
|
|
{"model", oaicompat_model},
|
|
{"tokens_predicted", n_decoded},
|
|
{"tokens_evaluated", n_prompt_tokens},
|
|
//{"generation_settings", default_generation_settings_for_props.to_json()},
|
|
{"prompt", prompt},
|
|
{"has_new_line", has_new_line},
|
|
{"truncated", truncated},
|
|
//{"stop_type", stop_type_to_str(STOP_TYPE_EOS)},
|
|
{"stopping_word", stopping_word},
|
|
{"tokens_cached", n_tokens_cached},
|
|
{"timings", timings.to_json()},
|
|
};
|
|
if (!stream && !probs_output.empty()) {
|
|
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
|
}
|
|
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
|
}
|
|
|
|
json to_json_oaicompat_partial() {
|
|
std::time_t t = std::time(0);
|
|
json logprobs = json(nullptr); // OAI default to null
|
|
if (probs_output.size() > 0) {
|
|
logprobs = json{
|
|
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
};
|
|
}
|
|
json res = json{
|
|
{"choices", json::array({
|
|
json{
|
|
{"text", content},
|
|
{"index", index},
|
|
{"logprobs", logprobs},
|
|
{"finish_reason", nullptr},
|
|
}
|
|
})},
|
|
{"created", t},
|
|
{"model", oaicompat_model},
|
|
{"object", "text_completion"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens}
|
|
}},
|
|
{"id", oaicompat_cmpl_id}
|
|
};
|
|
|
|
// extra fields for debugging purposes
|
|
if (verbose) {
|
|
res["__verbose"] = to_json_non_oaicompat_partial();
|
|
}
|
|
if (timings.prompt_n >= 0) {
|
|
res.push_back({ "timings", timings.to_json() });
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
json to_json_oaicompat_final() {
|
|
std::time_t t = std::time(0);
|
|
json logprobs = json(nullptr); // OAI default to null
|
|
if (!stream && probs_output.size() > 0) {
|
|
logprobs = json{
|
|
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
};
|
|
}
|
|
json finish_reason = "length";
|
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
finish_reason = "stop";
|
|
}
|
|
json res = json{
|
|
{"choices", json::array({
|
|
json{
|
|
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
|
{"index", index},
|
|
{"logprobs", logprobs},
|
|
{"finish_reason", finish_reason},
|
|
}
|
|
})},
|
|
{"created", t},
|
|
{"model", oaicompat_model},
|
|
{"object", "text_completion"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens}
|
|
}},
|
|
{"id", oaicompat_cmpl_id}
|
|
};
|
|
|
|
// extra fields for debugging purposes
|
|
if (verbose) {
|
|
res["__verbose"] = to_json_non_oaicompat_final();
|
|
}
|
|
if (timings.prompt_n >= 0) {
|
|
res.push_back({ "timings", timings.to_json() });
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
json to_json_oaicompat_chat_partial() {
|
|
bool first = n_decoded == 1;
|
|
std::time_t t = std::time(0);
|
|
json choices;
|
|
|
|
std::vector<json> deltas;
|
|
auto add_delta = [&](const json& delta) {
|
|
deltas.push_back({
|
|
{"choices", json::array({
|
|
json {
|
|
{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta", delta},
|
|
},
|
|
})},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"object", "chat.completion.chunk"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens},
|
|
}},
|
|
});
|
|
};
|
|
// We have to send an initial update to conform to openai behavior
|
|
if (first) {
|
|
add_delta({
|
|
{"role", "assistant"},
|
|
{"content", nullptr},
|
|
});
|
|
}
|
|
|
|
for (const auto& diff : oaicompat_msg_diffs) {
|
|
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
|
|
}
|
|
|
|
if (!deltas.empty()) {
|
|
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
|
|
|
|
if (probs_output.size() > 0) {
|
|
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json{
|
|
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
};
|
|
}
|
|
|
|
if (timings.prompt_n >= 0) {
|
|
deltas[deltas.size() - 1].push_back({ "timings", timings.to_json() });
|
|
}
|
|
}
|
|
|
|
return deltas;
|
|
}
|
|
|
|
json to_json_oaicompat_chat_final() {
|
|
std::string finish_reason = "length";
|
|
common_chat_msg msg;
|
|
if (!oaicompat_msg.empty()) {
|
|
msg = oaicompat_msg;
|
|
}
|
|
else {
|
|
msg.role = "assistant";
|
|
msg.content = content;
|
|
}
|
|
if (stop) {
|
|
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
|
}
|
|
|
|
|
|
json choice{
|
|
{"finish_reason", finish_reason},
|
|
{"index", 0},
|
|
{"message", msg.to_json_oaicompat<json>()},
|
|
};
|
|
|
|
if (!stream && probs_output.size() > 0) {
|
|
choice["logprobs"] = json{
|
|
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
};
|
|
}
|
|
|
|
std::time_t t = std::time(0);
|
|
|
|
json res = json{
|
|
{"choices", json::array({choice})},
|
|
{"created", t},
|
|
{"model", oaicompat_model},
|
|
{"object", "chat.completion"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens}
|
|
}},
|
|
{"id", oaicompat_cmpl_id}
|
|
};
|
|
|
|
// extra fields for debugging purposes
|
|
if (verbose) {
|
|
res["__verbose"] = to_json_non_oaicompat_final();
|
|
}
|
|
if (timings.prompt_n >= 0) {
|
|
res.push_back({ "timings", timings.to_json() });
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
json to_json_oaicompat_chat_stream() {
|
|
std::time_t t = std::time(0);
|
|
std::string finish_reason = "length";
|
|
if (stop) {
|
|
//if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
|
|
}
|
|
|
|
json deltas = json::array();
|
|
for (const auto& diff : oaicompat_msg_diffs) {
|
|
deltas.push_back({
|
|
{"choices", json::array({
|
|
json {
|
|
{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
|
|
},
|
|
})},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"object", "chat.completion.chunk"},
|
|
});
|
|
}
|
|
|
|
deltas.push_back({
|
|
{"choices", json::array({
|
|
json {
|
|
{"finish_reason", finish_reason},
|
|
{"index", 0},
|
|
{"delta", json::object()},
|
|
},
|
|
})},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"object", "chat.completion.chunk"},
|
|
});
|
|
if (include_usage) {
|
|
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
|
|
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
|
|
deltas.push_back({
|
|
{"choices", json::array()},
|
|
{"created", t},
|
|
{"id", oaicompat_cmpl_id},
|
|
{"model", oaicompat_model},
|
|
{"object", "chat.completion.chunk"},
|
|
{"usage", json {
|
|
{"completion_tokens", n_decoded},
|
|
{"prompt_tokens", n_prompt_tokens},
|
|
{"total_tokens", n_decoded + n_prompt_tokens},
|
|
}},
|
|
});
|
|
}
|
|
if (timings.prompt_n >= 0) {
|
|
deltas.back().push_back({ "timings", timings.to_json() });
|
|
}
|
|
// extra fields for debugging purposes
|
|
if (verbose && !deltas.empty()) {
|
|
deltas.front()["__verbose"] = to_json_non_oaicompat_final();
|
|
}
|
|
|
|
return deltas;
|
|
}
|
|
};
|
|
|
|
inline std::string stop_type_to_str(stop_type type) {
|
|
switch (type) {
|
|
case STOP_TYPE_EOS: return "eos";
|
|
case STOP_TYPE_WORD: return "word";
|
|
case STOP_TYPE_LIMIT: return "limit";
|
|
default: return "none";
|
|
}
|
|
}
|
|
|
|
|
|
struct server_task_multi {
|
|
int id = -1;
|
|
|
|
std::set<int> subtasks_remaining;
|
|
std::vector<server_task_result> results;
|
|
};
|
|
|
|
struct slot_params {
|
|
bool stream = true;
|
|
bool include_usage = false;
|
|
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
|
|
|
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
|
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
|
int32_t n_predict = -1; // new tokens to predict
|
|
|
|
std::vector<std::string> antiprompt;
|
|
|
|
bool timings_per_token = false;
|
|
bool post_sampling_probs = false;
|
|
json input_prefix;
|
|
json input_suffix;
|
|
|
|
// speculative decoding parameters
|
|
struct {
|
|
int n_max = 16; // max drafted tokens
|
|
int n_min = 0; // min drafted tokens to accept
|
|
float p_min = 0.75f; // min probability required to accept a token in the draft
|
|
} speculative;
|
|
|
|
// OAI-compat fields
|
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
|
std::string oaicompat_model;
|
|
std::string oaicompat_cmpl_id;
|
|
common_chat_syntax oaicompat_chat_syntax;
|
|
|
|
};
|
|
|
|
struct server_slot {
|
|
int id;
|
|
int id_task = -1;
|
|
int id_multi = -1;
|
|
|
|
struct slot_params params;
|
|
|
|
slot_state state = SLOT_STATE_IDLE;
|
|
slot_command command = SLOT_COMMAND_NONE;
|
|
|
|
// used to determine the slot that has been used the longest
|
|
int64_t t_last_used = -1;
|
|
|
|
// generation props
|
|
int32_t n_ctx = 0; // context size per slot
|
|
int32_t n_past = 0;
|
|
int32_t n_decoded = 0;
|
|
int32_t n_remaining = -1;
|
|
int32_t i_batch = -1;
|
|
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
|
|
|
int32_t n_prompt_tokens = 0;
|
|
int32_t n_prompt_tokens_processed = 0;
|
|
|
|
json prompt; // can be either a string, array of strings or array of token ids
|
|
|
|
// when a task is submitted, we first tokenize the prompt and store it here
|
|
std::vector<llama_token> prompt_tokens;
|
|
|
|
std::string generated_text;
|
|
std::vector<llama_token> cache_tokens;
|
|
std::vector<completion_token_output> generated_token_probs;
|
|
common_chat_msg chat_msg;
|
|
|
|
bool infill = false;
|
|
bool embedding = false;
|
|
bool has_next_token = true;
|
|
bool truncated = false;
|
|
bool stopped_eos = false;
|
|
bool stopped_word = false;
|
|
bool stopped_limit = false;
|
|
|
|
bool oaicompat = false;
|
|
|
|
std::string oaicompat_model;
|
|
std::string stopping_word;
|
|
stop_type stop;
|
|
// sampling
|
|
llama_token sampled;
|
|
struct llama_sampling_params sparams;
|
|
llama_sampling_context * ctx_sampling = nullptr;
|
|
json json_schema;
|
|
|
|
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
std::vector<std::string> generated_tool_call_ids;
|
|
|
|
int32_t ga_i = 0; // group-attention state
|
|
int32_t ga_n = 1; // group-attention factor
|
|
int32_t ga_w = 512; // group-attention width
|
|
|
|
// speculative decoding
|
|
struct llama_speculative * spec = nullptr;
|
|
llama_context * ctx_dft = nullptr;
|
|
llama_batch batch_spec = {};
|
|
|
|
// speculative decoding stats
|
|
int32_t n_draft_total = 0; // Total draft tokens generated
|
|
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
|
|
|
int32_t n_past_se = 0; // self-extend
|
|
|
|
// stats
|
|
size_t n_sent_text = 0; // number of sent text character
|
|
size_t n_sent_token_probs = 0;
|
|
|
|
int64_t t_start_process_prompt;
|
|
int64_t t_start_generation;
|
|
|
|
double t_prompt_processing; // ms
|
|
double t_token_generation; // ms
|
|
|
|
void reset() {
|
|
n_prompt_tokens = 0;
|
|
generated_text = "";
|
|
truncated = false;
|
|
stopped_eos = false;
|
|
stopped_word = false;
|
|
stopped_limit = false;
|
|
stopping_word = "";
|
|
n_past = 0;
|
|
n_sent_text = 0;
|
|
n_sent_token_probs = 0;
|
|
infill = false;
|
|
ga_i = 0;
|
|
n_past_se = 0;
|
|
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
|
|
generated_token_probs.clear();
|
|
|
|
|
|
// Reset speculative decoding stats
|
|
n_draft_total = 0;
|
|
n_draft_accepted = 0;
|
|
chat_msg = {};
|
|
json_schema = json();
|
|
generated_tool_call_ids.clear();
|
|
}
|
|
|
|
bool has_budget(gpt_params &global_params) {
|
|
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
|
return true; // limitless
|
|
}
|
|
|
|
n_remaining = -1;
|
|
|
|
if (params.n_predict != -1) {
|
|
n_remaining = params.n_predict - n_decoded;
|
|
} else if (global_params.n_predict != -1) {
|
|
n_remaining = global_params.n_predict - n_decoded;
|
|
}
|
|
|
|
return n_remaining > 0; // no budget
|
|
}
|
|
|
|
bool available() const {
|
|
return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
|
|
}
|
|
|
|
bool is_processing() const {
|
|
return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
|
|
}
|
|
|
|
void add_token_string(const completion_token_output & token) {
|
|
if (command == SLOT_COMMAND_RELEASE) {
|
|
return;
|
|
}
|
|
generated_token_probs.push_back(token);
|
|
}
|
|
|
|
void release() {
|
|
if (state == SLOT_STATE_PROCESSING) {
|
|
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
|
|
command = SLOT_COMMAND_RELEASE;
|
|
}
|
|
}
|
|
|
|
json get_formated_timings() const {
|
|
return json {
|
|
{"prompt_n", n_prompt_tokens_processed},
|
|
{"prompt_ms", t_prompt_processing},
|
|
{"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed},
|
|
{"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed},
|
|
|
|
{"predicted_n", n_decoded},
|
|
{"predicted_ms", t_token_generation},
|
|
{"predicted_per_token_ms", t_token_generation / n_decoded},
|
|
{"predicted_per_second", 1e3 / t_token_generation * n_decoded},
|
|
};
|
|
}
|
|
|
|
result_timings get_timings() const {
|
|
result_timings timings;
|
|
timings.prompt_n = n_prompt_tokens_processed;
|
|
timings.prompt_ms = t_prompt_processing;
|
|
timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
|
|
timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
|
|
|
|
timings.predicted_n = n_decoded;
|
|
timings.predicted_ms = t_token_generation;
|
|
timings.predicted_per_token_ms = t_token_generation / n_decoded;
|
|
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
|
|
|
|
// Add speculative metrics
|
|
if (n_draft_total > 0) {
|
|
timings.draft_n = n_draft_total;
|
|
timings.draft_n_accepted = n_draft_accepted;
|
|
}
|
|
|
|
return timings;
|
|
}
|
|
|
|
const common_chat_msg& update_chat_msg(std::vector<common_chat_msg_diff>& diffs) {
|
|
auto previous_msg = chat_msg;
|
|
auto new_msg = common_chat_parse(
|
|
generated_text,
|
|
/* is_partial= */ stop != STOP_TYPE_EOS,
|
|
params.oaicompat_chat_syntax);
|
|
if (!new_msg.empty()) {
|
|
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
|
|
chat_msg = new_msg;
|
|
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
|
|
}
|
|
//LLAMA_LOG_DEBUG("Parsing chat message: %s\n", generated_text.c_str());
|
|
//LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.reasoning_content.c_str());
|
|
//LLAMA_LOG_DEBUG("Parsing chat message: %s\n", chat_msg.content.c_str());
|
|
return chat_msg;
|
|
}
|
|
|
|
|
|
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
|
|
size_t stop_pos = std::string::npos;
|
|
|
|
for (const std::string & word : params.antiprompt) {
|
|
size_t pos;
|
|
|
|
if (is_full_stop) {
|
|
const size_t tmp = word.size() + last_token_size;
|
|
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
|
|
|
|
pos = text.find(word, from_pos);
|
|
} else {
|
|
pos = string_find_partial_stop(text, word);
|
|
}
|
|
|
|
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
|
if (is_full_stop) {
|
|
stopped_word = true;
|
|
stopping_word = word;
|
|
has_next_token = false;
|
|
}
|
|
stop_pos = pos;
|
|
}
|
|
}
|
|
|
|
return stop_pos;
|
|
}
|
|
|
|
void print_timings() const {
|
|
char buffer[512];
|
|
|
|
double t_token = t_prompt_processing / n_prompt_tokens_processed;
|
|
double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
|
|
|
|
snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
|
|
t_prompt_processing, n_prompt_tokens_processed,
|
|
t_token, n_tokens_second);
|
|
|
|
LOG_INFO(buffer, {
|
|
{"id_slot", id},
|
|
{"id_task", id_task},
|
|
{"t_prompt_processing", t_prompt_processing},
|
|
{"n_prompt_tokens_processed", n_prompt_tokens_processed},
|
|
{"t_token", t_token},
|
|
{"n_tokens_second", n_tokens_second},
|
|
});
|
|
|
|
t_token = t_token_generation / n_decoded;
|
|
n_tokens_second = 1e3 / t_token_generation * n_decoded;
|
|
|
|
snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
|
|
t_token_generation, n_decoded,
|
|
t_token, n_tokens_second);
|
|
|
|
LOG_INFO(buffer, {
|
|
{"id_slot", id},
|
|
{"id_task", id_task},
|
|
{"t_token_generation", t_token_generation},
|
|
{"n_decoded", n_decoded},
|
|
{"t_token", t_token},
|
|
{"n_tokens_second", n_tokens_second},
|
|
});
|
|
|
|
snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
|
|
|
|
LOG_INFO(buffer, {
|
|
{"id_slot", id},
|
|
{"id_task", id_task},
|
|
{"t_prompt_processing", t_prompt_processing},
|
|
{"t_token_generation", t_token_generation},
|
|
{"t_total", t_prompt_processing + t_token_generation},
|
|
});
|
|
}
|
|
};
|
|
|
|
struct server_metrics {
|
|
int64_t t_start = 0;
|
|
|
|
uint64_t n_prompt_tokens_processed_total = 0;
|
|
uint64_t t_prompt_processing_total = 0;
|
|
uint64_t n_tokens_predicted_total = 0;
|
|
uint64_t t_tokens_generation_total = 0;
|
|
|
|
uint64_t n_prompt_tokens_processed = 0;
|
|
uint64_t t_prompt_processing = 0;
|
|
|
|
uint64_t n_tokens_predicted = 0;
|
|
uint64_t t_tokens_generation = 0;
|
|
|
|
void init() {
|
|
t_start = ggml_time_us();
|
|
}
|
|
|
|
void on_prompt_eval(const server_slot & slot) {
|
|
n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
|
|
n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
|
|
t_prompt_processing += slot.t_prompt_processing;
|
|
t_prompt_processing_total += slot.t_prompt_processing;
|
|
}
|
|
|
|
void on_prediction(const server_slot & slot) {
|
|
n_tokens_predicted_total += slot.n_decoded;
|
|
n_tokens_predicted += slot.n_decoded;
|
|
t_tokens_generation += slot.t_token_generation;
|
|
t_tokens_generation_total += slot.t_token_generation;
|
|
}
|
|
|
|
void reset_bucket() {
|
|
n_prompt_tokens_processed = 0;
|
|
t_prompt_processing = 0;
|
|
n_tokens_predicted = 0;
|
|
t_tokens_generation = 0;
|
|
}
|
|
};
|
|
|
|
struct server_queue {
|
|
int id = 0;
|
|
bool running;
|
|
|
|
// queues
|
|
std::vector<server_task> queue_tasks;
|
|
std::vector<server_task> queue_tasks_deferred;
|
|
|
|
std::vector<server_task_multi> queue_multitasks;
|
|
|
|
std::mutex mutex_tasks;
|
|
std::condition_variable condition_tasks;
|
|
|
|
// callback functions
|
|
std::function<void(server_task &)> callback_new_task;
|
|
std::function<void(server_task_multi &)> callback_finish_multitask;
|
|
std::function<void(void)> callback_update_slots;
|
|
|
|
// Add a new task to the end of the queue
|
|
int post(server_task task) {
|
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
if (task.id == -1) {
|
|
task.id = id++;
|
|
LOG_VERBOSE("new task id", {{"new_id", task.id}});
|
|
}
|
|
queue_tasks.push_back(std::move(task));
|
|
condition_tasks.notify_one();
|
|
return task.id;
|
|
}
|
|
|
|
// Add a new task, but defer until one slot is available
|
|
void defer(server_task task) {
|
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
queue_tasks_deferred.push_back(std::move(task));
|
|
}
|
|
|
|
// Get the next id for creating anew task
|
|
int get_new_id() {
|
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
int new_id = id++;
|
|
LOG_VERBOSE("new task id", {{"new_id", new_id}});
|
|
return new_id;
|
|
}
|
|
|
|
// Register function to process a new task
|
|
void on_new_task(std::function<void(server_task &)> callback) {
|
|
callback_new_task = std::move(callback);
|
|
}
|
|
|
|
// Register function to process a multitask when it is finished
|
|
void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
|
|
callback_finish_multitask = std::move(callback);
|
|
}
|
|
|
|
// Register the function to be called when all slots data is ready to be processed
|
|
void on_update_slots(std::function<void(void)> callback) {
|
|
callback_update_slots = std::move(callback);
|
|
}
|
|
|
|
// Call when the state of one slot is changed
|
|
void notify_slot_changed() {
|
|
// move deferred tasks back to main loop
|
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
for (auto & task : queue_tasks_deferred) {
|
|
queue_tasks.push_back(std::move(task));
|
|
}
|
|
queue_tasks_deferred.clear();
|
|
}
|
|
|
|
// end the start_loop routine
|
|
void terminate() {
|
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
running = false;
|
|
condition_tasks.notify_all();
|
|
}
|
|
|
|
/**
|
|
* Main loop consists of these steps:
|
|
* - Wait until a new task arrives
|
|
* - Process the task (i.e. maybe copy data into slot)
|
|
* - Check if multitask is finished
|
|
* - Update all slots
|
|
*/
|
|
void start_loop() {
|
|
running = true;
|
|
|
|
while (true) {
|
|
LOG_VERBOSE("new task may arrive", {});
|
|
|
|
while (true) {
|
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
if (queue_tasks.empty()) {
|
|
lock.unlock();
|
|
break;
|
|
}
|
|
server_task task = queue_tasks.front();
|
|
queue_tasks.erase(queue_tasks.begin());
|
|
lock.unlock();
|
|
LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
|
|
callback_new_task(task);
|
|
}
|
|
|
|
LOG_VERBOSE("update_multitasks", {});
|
|
|
|
// check if we have any finished multitasks
|
|
auto queue_iterator = queue_multitasks.begin();
|
|
while (queue_iterator != queue_multitasks.end()) {
|
|
if (queue_iterator->subtasks_remaining.empty()) {
|
|
// all subtasks done == multitask is done
|
|
server_task_multi current_multitask = *queue_iterator;
|
|
callback_finish_multitask(current_multitask);
|
|
// remove this multitask
|
|
queue_iterator = queue_multitasks.erase(queue_iterator);
|
|
} else {
|
|
++queue_iterator;
|
|
}
|
|
}
|
|
|
|
// all tasks in the current loop is processed, slots data is now ready
|
|
LOG_VERBOSE("callback_update_slots", {});
|
|
|
|
callback_update_slots();
|
|
|
|
LOG_VERBOSE("wait for new task", {});
|
|
{
|
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
if (queue_tasks.empty()) {
|
|
if (!running) {
|
|
LOG_VERBOSE("ending start_loop", {});
|
|
return;
|
|
}
|
|
condition_tasks.wait(lock, [&]{
|
|
return (!queue_tasks.empty() || !running);
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
//
|
|
// functions to manage multitasks
|
|
//
|
|
|
|
// add a multitask by specifying the id of all subtask (subtask is a server_task)
|
|
void add_multitask(int id_multi, std::vector<int> & sub_ids) {
|
|
std::lock_guard<std::mutex> lock(mutex_tasks);
|
|
server_task_multi multi;
|
|
multi.id = id_multi;
|
|
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
|
|
queue_multitasks.push_back(multi);
|
|
}
|
|
|
|
// updatethe remaining subtasks, while appending results to multitask
|
|
void update_multitask(int id_multi, int id_sub, server_task_result & result) {
|
|
std::lock_guard<std::mutex> lock(mutex_tasks);
|
|
for (auto & multitask : queue_multitasks) {
|
|
if (multitask.id == id_multi) {
|
|
multitask.subtasks_remaining.erase(id_sub);
|
|
multitask.results.push_back(result);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
struct server_response {
|
|
typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
|
|
callback_multitask_t callback_update_multitask;
|
|
|
|
// for keeping track of all tasks waiting for the result
|
|
std::set<int> waiting_task_ids;
|
|
|
|
// the main result queue
|
|
std::vector<server_task_result> queue_results;
|
|
|
|
std::mutex mutex_results;
|
|
std::condition_variable condition_results;
|
|
|
|
// add the id_task to the list of tasks waiting for response
|
|
void add_waiting_task_id(int id_task) {
|
|
LOG_VERBOSE("waiting for task id", {{"id_task", id_task}});
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
waiting_task_ids.insert(id_task);
|
|
}
|
|
|
|
// when the request is finished, we can remove task associated with it
|
|
void remove_waiting_task_id(int id_task) {
|
|
LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
waiting_task_ids.erase(id_task);
|
|
}
|
|
|
|
// This function blocks the thread until there is a response for this id_task
|
|
server_task_result recv(int id_task) {
|
|
while (true) {
|
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
condition_results.wait(lock, [&]{
|
|
return !queue_results.empty();
|
|
});
|
|
|
|
for (int i = 0; i < (int) queue_results.size(); i++) {
|
|
if (queue_results[i].id == id_task) {
|
|
assert(queue_results[i].id_multi == -1);
|
|
server_task_result res = queue_results[i];
|
|
queue_results.erase(queue_results.begin() + i);
|
|
return res;
|
|
}
|
|
}
|
|
}
|
|
|
|
// should never reach here
|
|
}
|
|
|
|
// Register the function to update multitask
|
|
void on_multitask_update(callback_multitask_t callback) {
|
|
callback_update_multitask = std::move(callback);
|
|
}
|
|
|
|
// Send a new result to a waiting id_task
|
|
void send(server_task_result result) {
|
|
LOG_VERBOSE("send new result", {{"id_task", result.id}});
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
for (const auto & id_task : waiting_task_ids) {
|
|
// LOG_TEE("waiting task id %i \n", id_task);
|
|
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
|
|
if (result.id_multi == id_task) {
|
|
LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
|
|
callback_update_multitask(id_task, result.id, result);
|
|
continue;
|
|
}
|
|
|
|
if (result.id == id_task) {
|
|
LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
|
|
queue_results.push_back(result);
|
|
condition_results.notify_all();
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
struct server_context {
|
|
llama_model * model = nullptr;
|
|
llama_context * ctx = nullptr;
|
|
std::vector<llama_lora_adapter_container> lora_adapters;
|
|
|
|
gpt_params params;
|
|
|
|
llama_batch batch;
|
|
|
|
bool clean_kv_cache = true;
|
|
bool add_bos_token = true;
|
|
|
|
// For speculative decoding
|
|
llama_model * model_draft = nullptr;
|
|
llama_context * ctx_draft = nullptr;
|
|
llama_context_params cparams_dft;
|
|
|
|
int32_t n_ctx; // total context for all clients / slots
|
|
|
|
// system prompt
|
|
bool system_need_update = false;
|
|
|
|
std::string system_prompt;
|
|
std::vector<llama_token> system_tokens;
|
|
|
|
// slots / clients
|
|
std::vector<server_slot> slots;
|
|
json default_generation_settings_for_props;
|
|
|
|
server_queue queue_tasks;
|
|
server_response queue_results;
|
|
|
|
server_metrics metrics;
|
|
|
|
common_chat_templates_ptr chat_templates;
|
|
oaicompat_parser_options oai_parser_opt;
|
|
// Necessary similarity of prompt for slot selection
|
|
float slot_prompt_similarity = 0.0f;
|
|
|
|
~server_context() {
|
|
if (ctx) {
|
|
llama_free(ctx);
|
|
ctx = nullptr;
|
|
}
|
|
|
|
if (model) {
|
|
llama_free_model(model);
|
|
model = nullptr;
|
|
}
|
|
|
|
// Free draft model and context if they exist
|
|
if (ctx_draft) {
|
|
llama_free(ctx_draft);
|
|
ctx_draft = nullptr;
|
|
}
|
|
if (model_draft) {
|
|
llama_free_model(model_draft);
|
|
model_draft = nullptr;
|
|
}
|
|
|
|
// Clear any sampling context
|
|
for (server_slot & slot : slots) {
|
|
if (slot.ctx_sampling != nullptr) {
|
|
llama_sampling_free(slot.ctx_sampling);
|
|
}
|
|
if (slot.ctx_dft) {
|
|
llama_free(slot.ctx_dft);
|
|
}
|
|
if (slot.spec) {
|
|
llama_speculative_free(slot.spec);
|
|
}
|
|
llama_batch_free(slot.batch_spec);
|
|
}
|
|
|
|
llama_batch_free(batch);
|
|
}
|
|
|
|
bool load_model(const gpt_params & params_) {
|
|
params = params_;
|
|
|
|
// dedicate one sequence to the system prompt
|
|
params.n_parallel += 1;
|
|
|
|
llama_init_result llama_init = llama_init_from_gpt_params(params);
|
|
|
|
model = llama_init.model;
|
|
ctx = llama_init.context;
|
|
lora_adapters = llama_init.lora_adapters;
|
|
params.n_parallel -= 1; // but be sneaky about it
|
|
if (model == nullptr) {
|
|
LOG_ERROR("unable to load model", {{"model", params.model}});
|
|
return false;
|
|
}
|
|
|
|
n_ctx = llama_n_ctx(ctx);
|
|
|
|
add_bos_token = llama_should_add_bos_token(model);
|
|
GGML_ASSERT(llama_add_eos_token(model) != 1);
|
|
|
|
chat_templates = common_chat_templates_init(model, params.chat_template);
|
|
try {
|
|
common_chat_format_example(chat_templates.get(), params.use_jinja, {});
|
|
}
|
|
catch (const std::exception& e) {
|
|
LOG_WARNING("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
|
chat_templates = common_chat_templates_init(model, "chatml");
|
|
}
|
|
|
|
|
|
// Load draft model for speculative decoding if specified
|
|
if (!params.model_draft.empty()) {
|
|
LOG_INFO("loading draft model", {{"model", params.model_draft}});
|
|
|
|
gpt_params params_dft;
|
|
params_dft.model = params.model_draft;
|
|
params_dft.n_ctx = params.n_ctx_draft == 0 ? params.n_ctx / params.n_parallel : params.n_ctx_draft;
|
|
params_dft.n_gpu_layers = params.n_gpu_layers_draft;
|
|
params_dft.n_parallel = 1;
|
|
params_dft.cache_type_k = params.cache_type_k_draft.empty() ? params.cache_type_k : params.cache_type_k_draft;
|
|
params_dft.cache_type_v = params.cache_type_v_draft.empty() ? params.cache_type_v : params.cache_type_v_draft;
|
|
params_dft.flash_attn = params.flash_attn;
|
|
|
|
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
|
|
|
|
llama_model * model_dft = llama_init_dft.model;
|
|
if (model_dft == nullptr) {
|
|
LOG_ERROR("failed to load draft model", {{"model", params.model_draft}});
|
|
return false;
|
|
}
|
|
|
|
if (!llama_speculative_are_compatible(ctx, llama_init_dft.context)) {
|
|
LOG_INFO("the draft model is not compatible with the target model. tokens will be translated between the draft and target models.", {{}});
|
|
}
|
|
|
|
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
|
|
|
|
cparams_dft = llama_context_params_from_gpt_params(params_dft);
|
|
cparams_dft.n_batch = n_ctx_dft;
|
|
|
|
model_draft = llama_init_dft.model;
|
|
ctx_draft = llama_init_dft.context;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
|
|
void init() {
|
|
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
|
|
|
|
LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
|
|
|
|
for (int i = 0; i < params.n_parallel; i++) {
|
|
server_slot slot;
|
|
|
|
slot.id = i;
|
|
slot.n_ctx = n_ctx_slot;
|
|
slot.n_predict = params.n_predict;
|
|
|
|
LOG_INFO("new slot", {
|
|
{"id_slot", slot.id},
|
|
{"n_ctx_slot", slot.n_ctx}
|
|
});
|
|
|
|
const int ga_n = params.grp_attn_n;
|
|
const int ga_w = params.grp_attn_w;
|
|
|
|
if (ga_n != 1) {
|
|
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
|
|
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
|
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
|
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
|
|
|
LOG_INFO("slot self-extend", {
|
|
{"id_slot", slot.id},
|
|
{"ga_n", ga_n},
|
|
{"ga_w", ga_w}
|
|
});
|
|
}
|
|
|
|
slot.ga_i = 0;
|
|
slot.ga_n = ga_n;
|
|
slot.ga_w = ga_w;
|
|
|
|
slot.sparams = params.sparams;
|
|
|
|
// Initialize speculative decoding if a draft model is loaded
|
|
if (ctx_draft) {
|
|
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
|
|
|
|
slot.ctx_dft = llama_new_context_with_model(model_draft, cparams_dft);
|
|
if (slot.ctx_dft == nullptr) {
|
|
LOG_ERROR("failed to create draft context", {});
|
|
return;
|
|
}
|
|
|
|
slot.spec = llama_speculative_init(ctx, slot.ctx_dft);
|
|
if (slot.spec == nullptr) {
|
|
LOG_ERROR("failed to create speculator", {});
|
|
return;
|
|
}
|
|
for (auto & pair : params.replacements_draft) {
|
|
llama_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
|
|
}
|
|
|
|
}
|
|
|
|
slot.reset();
|
|
|
|
slots.push_back(slot);
|
|
}
|
|
|
|
default_generation_settings_for_props = get_formated_generation(slots.front());
|
|
default_generation_settings_for_props["seed"] = -1;
|
|
|
|
// the update_slots() logic will always submit a maximum of n_batch tokens
|
|
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
|
{
|
|
const int32_t n_batch = llama_n_batch(ctx);
|
|
|
|
// only a single seq_id per token is needed
|
|
batch = llama_batch_init(n_batch, 0, 1);
|
|
}
|
|
|
|
metrics.init();
|
|
|
|
// thinking is enabled if:
|
|
// 1. It's not explicitly disabled (reasoning_budget == 0)
|
|
// 2. The chat template supports it
|
|
const bool enable_thinking = params.use_jinja && params.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
|
|
//LLAMA_LOG_INFO("Enable thinking? %d\n", enable_thinking);
|
|
|
|
oai_parser_opt = {
|
|
/* use_jinja */ params.use_jinja,
|
|
/* prefill_assistant */ params.prefill_assistant,
|
|
/* reasoning_format */ params.reasoning_format,
|
|
/* chat_template_kwargs */ params.default_template_kwargs,
|
|
/* common_chat_templates */ chat_templates.get(),
|
|
/* allow_image */ false,
|
|
/* allow_audio */ false,
|
|
/* enable_thinking */ enable_thinking,
|
|
};
|
|
}
|
|
|
|
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
|
|
// TODO: currently, we tokenize using special tokens by default
|
|
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
|
|
// but it's better compared to completely ignoring ChatML and other chat templates
|
|
const bool TMP_FORCE_SPECIAL = true;
|
|
|
|
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
|
// or the first element of the json_prompt array is a string.
|
|
std::vector<llama_token> prompt_tokens;
|
|
|
|
if (json_prompt.is_array()) {
|
|
bool first = true;
|
|
for (const auto & p : json_prompt) {
|
|
if (p.is_string()) {
|
|
auto s = p.template get<std::string>();
|
|
|
|
std::vector<llama_token> p;
|
|
if (first) {
|
|
p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
|
|
first = false;
|
|
} else {
|
|
p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
|
|
}
|
|
|
|
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
|
} else {
|
|
if (first) {
|
|
first = false;
|
|
}
|
|
|
|
prompt_tokens.push_back(p.template get<llama_token>());
|
|
}
|
|
}
|
|
} else {
|
|
auto s = json_prompt.template get<std::string>();
|
|
prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
|
|
}
|
|
|
|
return prompt_tokens;
|
|
}
|
|
|
|
server_slot * get_slot_by_id(int id) {
|
|
for (server_slot & slot : slots) {
|
|
if (slot.id == id) {
|
|
return &slot;
|
|
}
|
|
}
|
|
|
|
return nullptr;
|
|
}
|
|
|
|
server_slot * get_available_slot(const std::string & prompt) {
|
|
server_slot * ret = nullptr;
|
|
|
|
// find the slot that has at least n% prompt similarity
|
|
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
|
|
int max_lcp_len = 0;
|
|
float similarity = 0;
|
|
|
|
for (server_slot & slot : slots) {
|
|
// skip the slot if it is not available
|
|
if (!slot.available()) {
|
|
continue;
|
|
}
|
|
|
|
// skip the slot if it does not contains prompt
|
|
if (!slot.prompt.is_string()) {
|
|
continue;
|
|
}
|
|
|
|
// current slot's prompt
|
|
std::string slot_prompt = slot.prompt.get<std::string>();
|
|
|
|
// length of the current slot's prompt
|
|
int slot_prompt_len = slot_prompt.size();
|
|
|
|
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
|
int lcp_len = common_part(slot_prompt, prompt);
|
|
|
|
// fraction of the common substring length compared to the current slot's prompt length
|
|
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
|
|
|
// select the current slot if the criteria match
|
|
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
|
max_lcp_len = lcp_len;
|
|
ret = &slot;
|
|
}
|
|
}
|
|
|
|
if (ret != nullptr) {
|
|
LOG_VERBOSE("selected slot by lcp similarity", {
|
|
{"id_slot", ret->id},
|
|
{"max_lcp_len", max_lcp_len},
|
|
{"similarity", similarity},
|
|
});
|
|
}
|
|
}
|
|
|
|
// find the slot that has been least recently used
|
|
if (ret == nullptr) {
|
|
int64_t t_last = ggml_time_us();
|
|
for (server_slot & slot : slots) {
|
|
// skip the slot if it is not available
|
|
if (!slot.available()) {
|
|
continue;
|
|
}
|
|
|
|
// select the current slot if the criteria match
|
|
if (slot.t_last_used < t_last) {
|
|
t_last = slot.t_last_used;
|
|
ret = &slot;
|
|
}
|
|
}
|
|
|
|
if (ret != nullptr) {
|
|
LOG_VERBOSE("selected slot by lru", {
|
|
{"id_slot", ret->id},
|
|
{"t_last", t_last},
|
|
});
|
|
}
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
|
slot_params default_params;
|
|
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
|
llama_sampling_params default_sparams = params.sparams;
|
|
auto & data = task.data;
|
|
|
|
if (data.count("__oaicompat") != 0) {
|
|
slot.oaicompat = true;
|
|
slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
|
} else {
|
|
slot.oaicompat = false;
|
|
slot.oaicompat_model = "";
|
|
}
|
|
slot.params.timings_per_token = json_value(data, "timings_per_token", false);
|
|
slot.params.stream = json_value(data, "stream", false);
|
|
auto stream_opt = json_value(data, "stream_options", json::object());
|
|
slot.params.include_usage = json_value(stream_opt, "include_usage", false);
|
|
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
|
|
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
|
|
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
|
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
|
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
|
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
|
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
|
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
|
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
|
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
|
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
|
|
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
|
|
slot.sparams.top_n_sigma = json_value(data, "top_n_sigma", default_sparams.top_n_sigma);
|
|
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
|
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
|
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
|
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
|
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
|
|
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
|
|
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
|
|
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
|
|
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
|
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
|
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
|
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
|
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
|
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
|
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
|
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
|
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
|
|
|
slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
|
|
|
|
// speculative decoding parameters
|
|
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
|
|
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
|
|
slot.params.speculative.p_min = json_value(data, "speculative.p_min", params.p_draft_min);
|
|
|
|
// Clamp speculative parameters
|
|
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
|
|
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 0);
|
|
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
|
|
|
|
if (slot.sparams.penalty_last_n < -1) {
|
|
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
|
}
|
|
|
|
if (slot.sparams.dry_penalty_last_n < -1) {
|
|
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
|
}
|
|
|
|
if (slot.sparams.penalty_last_n == -1) {
|
|
// note: should be the slot's context and not the full context, but it's ok
|
|
slot.sparams.penalty_last_n = llama_n_ctx(ctx);
|
|
}
|
|
|
|
if (slot.sparams.dry_penalty_last_n == -1) {
|
|
slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);
|
|
|
|
}
|
|
if (slot.sparams.dry_base < 1.0f)
|
|
{
|
|
slot.sparams.dry_base = default_sparams.dry_base;
|
|
}
|
|
|
|
// sequence breakers for DRY
|
|
{
|
|
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
|
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
|
|
|
if (data.contains("dry_sequence_breakers")) {
|
|
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
|
if (slot.sparams.dry_sequence_breakers.empty()) {
|
|
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
// process "json_schema" and "grammar"
|
|
if (data.contains("json_schema") && !data.contains("grammar")) {
|
|
try {
|
|
auto schema = json_value(data, "json_schema", json::object());
|
|
LLAMA_LOG_DEBUG("JSON schema: %s\n", schema.dump(2).c_str());
|
|
slot.sparams.grammar = json_schema_to_grammar(schema);
|
|
LLAMA_LOG_DEBUG("Converted grammar: %s\n", slot.sparams.grammar.c_str());
|
|
}
|
|
catch (const std::exception& e) {
|
|
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
|
}
|
|
}
|
|
else {
|
|
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
|
LLAMA_LOG_DEBUG("Grammar: %s\n", slot.sparams.grammar.c_str());
|
|
slot.sparams.grammar_lazy = json_value(data, "grammar_lazy", default_sparams.grammar_lazy);
|
|
LLAMA_LOG_DEBUG("Grammar lazy: %s\n", slot.sparams.grammar_lazy ? "true" : "false");
|
|
}
|
|
|
|
if (slot.params.cache_prompt && slot.ga_n != 1) {
|
|
LOG_WARNING("cache_prompt is not supported with group-attention", {});
|
|
slot.params.cache_prompt = false;
|
|
}
|
|
|
|
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
|
// Might be better to reject the request with a 400 ?
|
|
LOG_WARNING("Max tokens to predict exceeds server configuration", {
|
|
{"params.n_predict", slot.params.n_predict},
|
|
{"slot.n_predict", slot.n_predict},
|
|
});
|
|
slot.params.n_predict = slot.n_predict;
|
|
}
|
|
|
|
// infill
|
|
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
|
|
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
|
|
|
|
// get prompt
|
|
if (!task.infill) {
|
|
const auto & prompt = data.find("prompt");
|
|
if (prompt == data.end()) {
|
|
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
|
|
if ((prompt->is_string()) ||
|
|
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
|
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
|
slot.prompt = *prompt;
|
|
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
|
slot.prompt = prompt->at(0);
|
|
} else {
|
|
send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// penalize user-provided tokens
|
|
{
|
|
slot.sparams.penalty_prompt_tokens.clear();
|
|
slot.sparams.use_penalty_prompt_tokens = false;
|
|
|
|
const auto & penalty_prompt = data.find("penalty_prompt");
|
|
|
|
if (penalty_prompt != data.end()) {
|
|
if (penalty_prompt->is_string()) {
|
|
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
|
|
slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
|
|
|
|
if (slot.params.n_predict > 0) {
|
|
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
|
|
}
|
|
slot.sparams.use_penalty_prompt_tokens = true;
|
|
|
|
LOG_VERBOSE("penalty_prompt_tokens", {
|
|
{"id_slot", slot.id},
|
|
{"tokens", slot.sparams.penalty_prompt_tokens},
|
|
});
|
|
}
|
|
else if (penalty_prompt->is_array()) {
|
|
const auto n_tokens = penalty_prompt->size();
|
|
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
|
|
|
|
const int n_vocab = llama_n_vocab(model);
|
|
for (const auto & penalty_token : *penalty_prompt) {
|
|
if (penalty_token.is_number_integer()) {
|
|
const auto tok = penalty_token.get<llama_token>();
|
|
if (tok >= 0 && tok < n_vocab) {
|
|
slot.sparams.penalty_prompt_tokens.push_back(tok);
|
|
}
|
|
}
|
|
}
|
|
slot.sparams.use_penalty_prompt_tokens = true;
|
|
|
|
LOG_VERBOSE("penalty_prompt_tokens", {
|
|
{"id_slot", slot.id},
|
|
{"tokens", slot.sparams.penalty_prompt_tokens},
|
|
});
|
|
}
|
|
}
|
|
}
|
|
{
|
|
auto it = data.find("chat_format");
|
|
if (it != data.end()) {
|
|
slot.params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
|
|
LLAMA_LOG_DEBUG("Chat format: %s\n", common_chat_format_name(slot.params.oaicompat_chat_syntax.format));
|
|
}
|
|
else {
|
|
slot.params.oaicompat_chat_syntax.format = default_params.oaicompat_chat_syntax.format;
|
|
}
|
|
common_reasoning_format reasoning_format = params.reasoning_format;
|
|
if (data.contains("reasoning_format")) {
|
|
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
|
|
}
|
|
slot.params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
|
|
slot.params.oaicompat_chat_syntax.reasoning_in_content = slot.params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
|
slot.params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
|
|
|
slot.params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
|
}
|
|
{
|
|
|
|
const auto preserved_tokens = data.find("preserved_tokens");
|
|
if (preserved_tokens != data.end()) {
|
|
for (const auto& t : *preserved_tokens) {
|
|
auto ids = llama_tokenize(model, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
|
|
if (ids.size() == 1) {
|
|
LOG("Preserved token: %d\n", ids[0]);
|
|
slot.sparams.preserved_tokens.insert(ids[0]);
|
|
}
|
|
else {
|
|
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
|
|
LOG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
|
|
}
|
|
}
|
|
}
|
|
const auto grammar_triggers = data.find("grammar_triggers");
|
|
if (grammar_triggers != data.end()) {
|
|
for (const auto& t : *grammar_triggers) {
|
|
server_grammar_trigger ct(t);
|
|
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
|
const auto& word = ct.value.value;
|
|
auto ids = llama_tokenize(model, word, /* add_special= */ false, /* parse_special= */ true);
|
|
if (ids.size() == 1) {
|
|
auto token = ids[0];
|
|
if (std::find(slot.sparams.preserved_tokens.begin(), slot.sparams.preserved_tokens.end(), (llama_token)token) == slot.sparams.preserved_tokens.end()) {
|
|
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
|
|
}
|
|
LOG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
|
|
common_grammar_trigger trigger;
|
|
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
|
trigger.value = word;
|
|
trigger.token = token;
|
|
slot.sparams.grammar_triggers.push_back(std::move(trigger));
|
|
}
|
|
else {
|
|
LOG("Grammar trigger word: `%s`\n", word.c_str());
|
|
slot.sparams.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word });
|
|
}
|
|
}
|
|
else {
|
|
//slot.sparams.grammar_triggers.push_back(ct);
|
|
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
|
|
LLAMA_LOG_DEBUG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
|
|
}
|
|
else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
|
|
LLAMA_LOG_DEBUG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
|
|
}
|
|
else {
|
|
throw std::runtime_error("Unknown grammar trigger type");
|
|
}
|
|
slot.sparams.grammar_triggers.emplace_back(std::move(ct.value));
|
|
}
|
|
}
|
|
}
|
|
|
|
if (slot.sparams.grammar_lazy && slot.sparams.grammar_triggers.empty()) {
|
|
throw std::runtime_error("Error: no triggers set for lazy grammar!");
|
|
}
|
|
}
|
|
|
|
{
|
|
slot.sparams.logit_bias.clear();
|
|
|
|
if (json_value(data, "ignore_eos", false)) {
|
|
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
|
}
|
|
|
|
const auto & logit_bias = data.find("logit_bias");
|
|
if (logit_bias != data.end() && logit_bias->is_array()) {
|
|
const int n_vocab = llama_n_vocab(model);
|
|
for (const auto & el : *logit_bias) {
|
|
// TODO: we may want to throw errors here, in case "el" is incorrect
|
|
if (el.is_array() && el.size() == 2) {
|
|
float bias;
|
|
if (el[1].is_number()) {
|
|
bias = el[1].get<float>();
|
|
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
|
bias = -INFINITY;
|
|
} else {
|
|
continue;
|
|
}
|
|
|
|
if (el[0].is_number_integer()) {
|
|
llama_token tok = el[0].get<llama_token>();
|
|
if (tok >= 0 && tok < n_vocab) {
|
|
slot.sparams.logit_bias[tok] = bias;
|
|
}
|
|
} else if (el[0].is_string()) {
|
|
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
|
for (auto tok : toks) {
|
|
slot.sparams.logit_bias[tok] = bias;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
slot.params.antiprompt.clear();
|
|
|
|
const auto & stop = data.find("stop");
|
|
if (stop != data.end() && stop->is_array()) {
|
|
for (const auto & word : *stop) {
|
|
if (!word.empty()) {
|
|
slot.params.antiprompt.push_back(word);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
const auto samplers = data.find("samplers");
|
|
if (samplers != data.end()) {
|
|
if (samplers->is_array()) {
|
|
slot.sparams.samplers_sequence = llama_sampling_types_from_names(*samplers, false);
|
|
}
|
|
else if (samplers->is_string()) {
|
|
slot.sparams.samplers_sequence = llama_sampling_types_from_chars(samplers->get<std::string>());
|
|
}
|
|
else {
|
|
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
if (slot.ctx_sampling != nullptr) {
|
|
llama_sampling_free(slot.ctx_sampling);
|
|
}
|
|
slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model),slot.sparams);
|
|
if (slot.ctx_sampling == nullptr) {
|
|
// for now, the only error that may happen here is invalid grammar
|
|
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
slot.command = SLOT_COMMAND_LOAD_PROMPT;
|
|
slot.prompt_tokens.clear();
|
|
|
|
LOG_INFO("slot is processing task", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
});
|
|
|
|
return true;
|
|
}
|
|
|
|
void kv_cache_clear() {
|
|
LOG_VERBOSE("clearing KV cache", {});
|
|
|
|
// clear the entire KV cache
|
|
llama_kv_cache_clear(ctx);
|
|
clean_kv_cache = false;
|
|
}
|
|
|
|
void system_prompt_update() {
|
|
LOG_VERBOSE("system prompt update", {
|
|
{"system_prompt", system_prompt},
|
|
});
|
|
|
|
kv_cache_clear();
|
|
system_tokens.clear();
|
|
|
|
if (!system_prompt.empty()) {
|
|
system_tokens = ::llama_tokenize(ctx, system_prompt, true);
|
|
|
|
llama_batch_clear(batch);
|
|
|
|
for (int i = 0; i < (int)system_tokens.size(); ++i) {
|
|
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
|
}
|
|
|
|
const int32_t n_batch = llama_n_batch(ctx);
|
|
|
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
|
const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
|
|
llama_batch batch_view = {
|
|
n_tokens,
|
|
batch.token + i,
|
|
nullptr,
|
|
batch.pos + i,
|
|
batch.n_seq_id + i,
|
|
batch.seq_id + i,
|
|
batch.logits + i,
|
|
0, 0, 0, // unused
|
|
};
|
|
|
|
if (llama_decode(ctx, batch_view) != 0) {
|
|
LOG_ERROR("llama_decode() failed", {});
|
|
return;
|
|
}
|
|
}
|
|
|
|
// assign the system KV cache to all parallel sequences
|
|
for (int32_t i = 1; i <= params.n_parallel; ++i) {
|
|
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
|
}
|
|
}
|
|
|
|
system_need_update = false;
|
|
}
|
|
|
|
bool system_prompt_set(const std::string & sys_prompt) {
|
|
system_prompt = sys_prompt;
|
|
|
|
LOG_VERBOSE("system prompt process", {
|
|
{"system_prompt", system_prompt},
|
|
});
|
|
|
|
// release all slots
|
|
for (server_slot & slot : slots) {
|
|
slot.release();
|
|
}
|
|
|
|
system_need_update = true;
|
|
return true;
|
|
}
|
|
|
|
bool process_token(completion_token_output & result, server_slot & slot) {
|
|
// remember which tokens were sampled - used for repetition penalties during sampling
|
|
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
|
slot.sampled = result.tok;
|
|
|
|
// search stop word and delete it
|
|
slot.generated_text += token_str;
|
|
slot.has_next_token = true;
|
|
|
|
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
|
|
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
|
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
|
}
|
|
|
|
// check if there is incomplete UTF-8 character at the end
|
|
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
|
|
|
|
if (!incomplete) {
|
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
|
|
|
const std::string str_test = slot.generated_text.substr(pos);
|
|
bool send_text = true;
|
|
|
|
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
|
|
if (stop_pos != std::string::npos) {
|
|
slot.generated_text.erase(
|
|
slot.generated_text.begin() + pos + stop_pos,
|
|
slot.generated_text.end());
|
|
pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
|
}
|
|
else if (slot.has_next_token && !llama_token_is_eog(model, result.tok)) {
|
|
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
|
|
send_text = stop_pos == std::string::npos;
|
|
}
|
|
|
|
// check if there is any token to predict
|
|
if (send_text) {
|
|
// no send the stop word in the response
|
|
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
|
slot.n_sent_text += result.text_to_send.size();
|
|
// add the token to slot queue and cache
|
|
} else {
|
|
result.text_to_send = "";
|
|
}
|
|
|
|
slot.add_token_string(result);
|
|
if (slot.params.stream) {
|
|
send_partial_response(slot, result);
|
|
}
|
|
}
|
|
|
|
if (incomplete) {
|
|
slot.has_next_token = true;
|
|
}
|
|
|
|
// check the limits
|
|
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
|
|
slot.stopped_limit = true;
|
|
slot.has_next_token = false;
|
|
|
|
LOG_VERBOSE("stopped by limit", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_decoded", slot.n_decoded},
|
|
{"n_predict", slot.params.n_predict},
|
|
});
|
|
}
|
|
|
|
if (llama_token_is_eog(model, result.tok)) {
|
|
slot.stopped_eos = true;
|
|
slot.has_next_token = false;
|
|
|
|
LOG_VERBOSE("eos token found", {});
|
|
}
|
|
|
|
auto n_ctx_train = llama_n_ctx_train(model);
|
|
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
|
|
&& slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
|
|
LOG_WARNING("n_predict is not set and self-context extend is disabled."
|
|
" Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
|
|
{ "id_slot", slot.id },
|
|
{ "params.n_predict", slot.params.n_predict },
|
|
{ "slot.n_prompt_tokens", slot.n_prompt_tokens },
|
|
{ "slot.n_decoded", slot.n_decoded },
|
|
{ "slot.n_predict", slot.n_predict },
|
|
{ "n_slots", params.n_parallel },
|
|
{ "slot.n_ctx", slot.n_ctx },
|
|
{ "n_ctx", n_ctx },
|
|
{ "n_ctx_train", n_ctx_train },
|
|
{ "ga_n", slot.ga_n },
|
|
});
|
|
slot.truncated = true;
|
|
slot.stopped_limit = true;
|
|
slot.has_next_token = false; // stop prediction
|
|
}
|
|
|
|
LOG_VERBOSE("next token", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"token", result.tok},
|
|
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
|
|
{"has_next_token", slot.has_next_token},
|
|
{"n_remain", slot.n_remaining},
|
|
{"n_decoded", slot.n_decoded},
|
|
{"stopped_eos", slot.stopped_eos},
|
|
{"stopped_word", slot.stopped_word},
|
|
{"stopped_limit", slot.stopped_limit},
|
|
{"stopping_word", slot.stopping_word},
|
|
});
|
|
|
|
return slot.has_next_token; // continue
|
|
}
|
|
|
|
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
|
|
size_t n_probs = slot.sparams.n_probs;
|
|
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
|
|
if (post_sampling) {
|
|
const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
|
|
const size_t max_probs = cur_p->size;
|
|
|
|
// set probability for sampled token
|
|
for (size_t i = 0; i < max_probs; i++) {
|
|
if (cur_p->data[i].id == result.tok) {
|
|
result.prob = cur_p->data[i].p;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// set probability for top n_probs tokens
|
|
result.probs.reserve(max_probs);
|
|
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
|
result.probs.push_back({
|
|
cur_p->data[i].id,
|
|
llama_detokenize(ctx, {cur_p->data[i].id}, special),
|
|
cur_p->data[i].p
|
|
});
|
|
}
|
|
} else {
|
|
auto&&[sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);
|
|
|
|
// set probability for sampled token
|
|
result.prob = sampled_token_p;
|
|
|
|
// set probability for top n_probs tokens
|
|
result.probs.reserve(n_probs);
|
|
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
|
result.probs.push_back({
|
|
cur[i].id,
|
|
llama_detokenize(ctx, {cur[i].id}, special),
|
|
cur[i].p
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
json get_formated_generation(const server_slot & slot) const {
|
|
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
|
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
|
|
|
std::vector<std::string> samplers_sequence;
|
|
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
|
|
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
|
|
samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
|
|
}
|
|
|
|
auto grammar_triggers = json::array();
|
|
for (const auto& trigger : slot.sparams.grammar_triggers) {
|
|
grammar_triggers.push_back(trigger.to_json<json>());
|
|
}
|
|
|
|
return json {
|
|
{"n_ctx", slot.n_ctx},
|
|
{"n_predict", slot.n_predict},
|
|
{"model", params.model_alias},
|
|
{"seed", slot.sparams.seed},
|
|
{"temperature", slot.sparams.temp},
|
|
{"dynatemp_range", slot.sparams.dynatemp_range},
|
|
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
|
{"top_k", slot.sparams.top_k},
|
|
{"top_p", slot.sparams.top_p},
|
|
{"min_p", slot.sparams.min_p},
|
|
{"tfs_z", slot.sparams.tfs_z},
|
|
{"typical_p", slot.sparams.typical_p},
|
|
{"repeat_last_n", slot.sparams.penalty_last_n},
|
|
{"repeat_penalty", slot.sparams.penalty_repeat},
|
|
{"presence_penalty", slot.sparams.penalty_present},
|
|
{"frequency_penalty", slot.sparams.penalty_freq},
|
|
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
|
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
|
{"dry_multiplier", slot.sparams.dry_multiplier},
|
|
{"dry_base", slot.sparams.dry_base},
|
|
{"dry_allowed_length", slot.sparams.dry_allowed_length},
|
|
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
|
|
{"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
|
|
{"mirostat", slot.sparams.mirostat},
|
|
{"mirostat_tau", slot.sparams.mirostat_tau},
|
|
{"mirostat_eta", slot.sparams.mirostat_eta},
|
|
{"penalize_nl", slot.sparams.penalize_nl},
|
|
{"stop", slot.params.antiprompt},
|
|
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
|
|
{"n_keep", slot.params.n_keep},
|
|
{"n_discard", slot.params.n_discard},
|
|
{"ignore_eos", ignore_eos},
|
|
{"stream", slot.params.stream},
|
|
{"logit_bias", slot.sparams.logit_bias},
|
|
{"n_probs", slot.sparams.n_probs},
|
|
{"min_keep", slot.sparams.min_keep},
|
|
{"grammar", slot.sparams.grammar},
|
|
{"grammar_triggers", grammar_triggers},
|
|
{"preserved_tokens", slot.sparams.preserved_tokens},
|
|
{"chat_format", common_chat_format_name(slot.params.oaicompat_chat_syntax.format)},
|
|
{"reasoning_format", common_reasoning_format_name(slot.params.oaicompat_chat_syntax.reasoning_format)},
|
|
{"reasoning_in_content", slot.params.oaicompat_chat_syntax.reasoning_in_content},
|
|
{"thinking_forced_open", slot.params.oaicompat_chat_syntax.thinking_forced_open},
|
|
{"samplers", samplers_sequence}
|
|
};
|
|
}
|
|
|
|
void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
send_error(task.id, task.id_multi, error, type);
|
|
}
|
|
|
|
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
send_error(slot.id_task, slot.id_multi, error, type);
|
|
}
|
|
|
|
void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
LOG_ERROR("task error", {
|
|
{"id_multi", id_multi},
|
|
{"id_task", id_task},
|
|
{"error", error},
|
|
});
|
|
|
|
server_task_result res;
|
|
res.id = id_task;
|
|
res.id_multi = id_multi;
|
|
res.stop = false;
|
|
res.error = true;
|
|
res.data = format_error_response(error, type);
|
|
|
|
queue_results.send(res);
|
|
}
|
|
|
|
void send_partial_response(server_slot & slot, completion_token_output tkn) {
|
|
server_task_result res;
|
|
res.final_result = false;
|
|
res.id = slot.id_task;
|
|
res.id_multi = slot.id_multi;
|
|
res.error = false;
|
|
res.stop = false;
|
|
res.stream = slot.params.stream;
|
|
res.content = tkn.text_to_send;
|
|
res.post_sampling_probs = slot.params.post_sampling_probs;
|
|
res.oaicompat = slot.params.oaicompat;
|
|
res.oaicompat_model = slot.params.oaicompat_model;
|
|
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
|
res.n_decoded = slot.n_decoded;
|
|
res.n_prompt_tokens = slot.n_prompt_tokens;
|
|
res.data = json {
|
|
{"content", tkn.text_to_send},
|
|
{"stop", false},
|
|
{"id_slot", slot.id},
|
|
{"multimodal", false}
|
|
};
|
|
slot.update_chat_msg(res.oaicompat_msg_diffs);
|
|
|
|
// populate res.probs_output
|
|
if (slot.sparams.n_probs > 0) {
|
|
res.probs_output = {tkn}; // copy the token probs
|
|
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
|
|
}
|
|
|
|
if (slot.oaicompat) {
|
|
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
|
res.data["model"] = slot.oaicompat_model;
|
|
}
|
|
|
|
// populate timings if this is final response or timings_per_token is enabled
|
|
if (slot.params.timings_per_token) {
|
|
res.timings = slot.get_timings();
|
|
}
|
|
queue_results.send(std::move(res));
|
|
}
|
|
|
|
void send_final_response(server_slot& slot) {
|
|
server_task_result res;
|
|
res.final_result = true;
|
|
res.id = slot.id_task;
|
|
res.id_multi = slot.id_multi;
|
|
res.error = false;
|
|
res.stop = true; // to do: set value
|
|
res.stream = slot.params.stream;
|
|
res.include_usage = slot.params.include_usage;
|
|
res.content = slot.generated_text;
|
|
res.timings = slot.get_timings();
|
|
res.post_sampling_probs = slot.params.post_sampling_probs;
|
|
res.oaicompat = slot.params.oaicompat;
|
|
res.oaicompat_model = slot.params.oaicompat_model;
|
|
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
|
res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs);
|
|
res.n_decoded = slot.n_decoded;
|
|
res.n_prompt_tokens = slot.n_prompt_tokens;
|
|
res.oaicompat_model = slot.oaicompat_model;
|
|
res.data = json {
|
|
{"content", !slot.params.stream ? slot.generated_text : ""},
|
|
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
|
|
{"id_slot", slot.id},
|
|
{"stop", true},
|
|
{"model", params.model_alias},
|
|
{"tokens_predicted", slot.n_decoded},
|
|
{"tokens_evaluated", slot.n_prompt_tokens},
|
|
{"generation_settings", get_formated_generation(slot)},
|
|
{"prompt", slot.prompt},
|
|
{"truncated", slot.truncated},
|
|
{"stopped_eos", slot.stopped_eos},
|
|
{"stopped_word", slot.stopped_word},
|
|
{"stopped_limit", slot.stopped_limit},
|
|
{"stopping_word", slot.stopping_word},
|
|
{"tokens_cached", slot.n_past},
|
|
{"timings", slot.get_formated_timings()},
|
|
//{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
|
|
};
|
|
|
|
// populate res.probs_output
|
|
if (slot.sparams.n_probs > 0) {
|
|
if (!slot.params.stream && slot.stopped_word) {
|
|
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
|
|
|
|
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
|
res.probs_output = std::vector<completion_token_output>(
|
|
slot.generated_token_probs.begin(),
|
|
slot.generated_token_probs.end() - safe_offset);
|
|
} else {
|
|
res.probs_output = std::vector<completion_token_output>(
|
|
slot.generated_token_probs.begin(),
|
|
slot.generated_token_probs.end());
|
|
}
|
|
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
|
|
}
|
|
|
|
if (slot.oaicompat) {
|
|
res.data["oaicompat_token_ctr"] = slot.n_decoded;
|
|
res.data["model"] = slot.oaicompat_model;
|
|
}
|
|
|
|
queue_results.send(std::move(res));
|
|
}
|
|
|
|
void send_embedding(const server_slot & slot, const llama_batch & batch) {
|
|
server_task_result res;
|
|
res.id = slot.id_task;
|
|
res.id_multi = slot.id_multi;
|
|
res.error = false;
|
|
res.stop = true;
|
|
|
|
const int n_embd = llama_n_embd(model);
|
|
|
|
std::vector<float> embd_res(n_embd, 0.0f);
|
|
|
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
|
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
|
|
continue;
|
|
}
|
|
|
|
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
|
if (embd == NULL) {
|
|
embd = llama_get_embeddings_ith(ctx, i);
|
|
}
|
|
|
|
if (embd == NULL) {
|
|
LOG_ERROR("failed to get embeddings", {
|
|
{"token", batch.token [i]},
|
|
{"seq_id", batch.seq_id[i][0]}
|
|
});
|
|
|
|
res.data = json {
|
|
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
|
};
|
|
|
|
continue;
|
|
}
|
|
|
|
llama_embd_normalize(embd, embd_res.data(), n_embd);
|
|
|
|
res.data = json {
|
|
{"embedding", embd_res},
|
|
};
|
|
}
|
|
|
|
queue_results.send(res);
|
|
}
|
|
|
|
void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
|
|
server_task task;
|
|
task.id = id_task;
|
|
task.id_multi = id_multi;
|
|
task.id_target = 0;
|
|
task.data = std::move(data);
|
|
task.infill = infill;
|
|
task.embedding = embedding;
|
|
task.type = SERVER_TASK_TYPE_COMPLETION;
|
|
|
|
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
|
// otherwise, it's a single-prompt task, we actually queue it
|
|
// if there's numbers in the prompt array it will be treated as an array of tokens
|
|
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
|
|
bool numbers = false;
|
|
for (const auto & e : task.data.at("prompt")) {
|
|
if (e.is_number()) {
|
|
numbers = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
|
|
// it will completely stall the server. I don't know where the bug for this is.
|
|
//
|
|
// if there are numbers, it needs to be treated like a single prompt,
|
|
// queue_tasks handles a mix of strings and numbers just fine.
|
|
if (numbers) {
|
|
queue_tasks.post(task);
|
|
} else {
|
|
split_multiprompt_task(id_task, task);
|
|
}
|
|
} else {
|
|
queue_tasks.post(task);
|
|
}
|
|
}
|
|
|
|
void request_cancel(int id_task) {
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_CANCEL;
|
|
task.id_target = id_task;
|
|
|
|
queue_tasks.post(task);
|
|
}
|
|
|
|
void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
|
|
const int prompt_count = multiprompt_task.data.at("prompt").size();
|
|
if (prompt_count <= 1) {
|
|
send_error(multiprompt_task, "error while handling multiple prompts");
|
|
return;
|
|
}
|
|
|
|
// generate all the ID for subtask
|
|
std::vector<int> subtask_ids(prompt_count);
|
|
for (int i = 0; i < prompt_count; i++) {
|
|
subtask_ids[i] = queue_tasks.get_new_id();
|
|
}
|
|
|
|
// queue up the multitask so we can track its subtask progression
|
|
queue_tasks.add_multitask(id_multi, subtask_ids);
|
|
|
|
// add subtasks
|
|
for (int i = 0; i < prompt_count; i++) {
|
|
json subtask_data = multiprompt_task.data;
|
|
subtask_data["prompt"] = subtask_data.at("prompt")[i];
|
|
|
|
// subtasks inherit everything else (infill mode, embedding mode, etc.)
|
|
request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
|
|
}
|
|
}
|
|
|
|
void process_single_task(const server_task & task) {
|
|
switch (task.type) {
|
|
case SERVER_TASK_TYPE_COMPLETION:
|
|
{
|
|
const int id_slot = json_value(task.data, "id_slot", -1);
|
|
|
|
server_slot * slot;
|
|
|
|
if (id_slot != -1) {
|
|
slot = get_slot_by_id(id_slot);
|
|
} else {
|
|
std::string prompt;
|
|
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
|
|
prompt = json_value(task.data, "prompt", std::string());
|
|
}
|
|
|
|
slot = get_available_slot(prompt);
|
|
}
|
|
|
|
if (slot == nullptr) {
|
|
// if no slot is available, we defer this task for processing later
|
|
LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
|
|
queue_tasks.defer(task);
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
|
queue_tasks.defer(task);
|
|
break;
|
|
}
|
|
|
|
if (task.data.contains("system_prompt")) {
|
|
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
|
|
system_prompt_set(sys_prompt);
|
|
|
|
for (server_slot & slot : slots) {
|
|
slot.n_past = 0;
|
|
slot.n_past_se = 0;
|
|
}
|
|
}
|
|
|
|
slot->reset();
|
|
|
|
slot->id_task = task.id;
|
|
slot->id_multi = task.id_multi;
|
|
slot->infill = task.infill;
|
|
slot->embedding = task.embedding;
|
|
|
|
if (!launch_slot_with_task(*slot, task)) {
|
|
LOG_ERROR("error while launching slot", task.data);
|
|
break;
|
|
}
|
|
} break;
|
|
case SERVER_TASK_TYPE_CANCEL:
|
|
{
|
|
// release slot linked with the task id
|
|
for (auto & slot : slots) {
|
|
if (slot.id_task == task.id_target) {
|
|
slot.release();
|
|
break;
|
|
}
|
|
}
|
|
} break;
|
|
case SERVER_TASK_TYPE_NEXT_RESPONSE:
|
|
{
|
|
// do nothing
|
|
} break;
|
|
case SERVER_TASK_TYPE_METRICS:
|
|
{
|
|
json slots_data = json::array();
|
|
|
|
int n_idle_slots = 0;
|
|
int n_processing_slots = 0;
|
|
|
|
for (server_slot & slot : slots) {
|
|
json slot_data = get_formated_generation(slot);
|
|
slot_data["id"] = slot.id;
|
|
slot_data["id_task"] = slot.id_task;
|
|
slot_data["state"] = slot.state;
|
|
slot_data["prompt"] = slot.prompt;
|
|
slot_data["next_token"] = {
|
|
{"has_next_token", slot.has_next_token},
|
|
{"n_remain", slot.n_remaining},
|
|
{"n_decoded", slot.n_decoded},
|
|
{"stopped_eos", slot.stopped_eos},
|
|
{"stopped_word", slot.stopped_word},
|
|
{"stopped_limit", slot.stopped_limit},
|
|
{"stopping_word", slot.stopping_word},
|
|
};
|
|
|
|
if (slot_data["state"] == SLOT_STATE_IDLE) {
|
|
n_idle_slots++;
|
|
} else {
|
|
n_processing_slots++;
|
|
}
|
|
|
|
slots_data.push_back(slot_data);
|
|
}
|
|
LOG_INFO("slot data", {
|
|
{"id_task", task.id},
|
|
{"n_idle_slots", n_idle_slots},
|
|
{"n_processing_slots", n_processing_slots}
|
|
});
|
|
|
|
LOG_VERBOSE("slot data", {
|
|
{"id_task", task.id},
|
|
{"n_idle_slots", n_idle_slots},
|
|
{"n_processing_slots", n_processing_slots},
|
|
{"slots", slots_data}
|
|
});
|
|
|
|
server_task_result res;
|
|
res.id = task.id;
|
|
res.id_multi = task.id_multi;
|
|
res.stop = true;
|
|
res.error = false;
|
|
res.data = {
|
|
{ "idle", n_idle_slots },
|
|
{ "processing", n_processing_slots },
|
|
{ "deferred", queue_tasks.queue_tasks_deferred.size() },
|
|
{ "t_start", metrics.t_start},
|
|
|
|
{ "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
|
|
{ "t_tokens_generation_total", metrics.t_tokens_generation_total},
|
|
{ "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
|
|
{ "t_prompt_processing_total", metrics.t_prompt_processing_total},
|
|
|
|
{ "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
|
|
{ "t_prompt_processing", metrics.t_prompt_processing},
|
|
{ "n_tokens_predicted", metrics.n_tokens_predicted},
|
|
{ "t_tokens_generation", metrics.t_tokens_generation},
|
|
|
|
{ "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
|
|
{ "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
|
|
|
|
{ "slots", slots_data },
|
|
};
|
|
|
|
if (json_value(task.data, "reset_bucket", false)) {
|
|
metrics.reset_bucket();
|
|
}
|
|
queue_results.send(res);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SLOT_SAVE:
|
|
{
|
|
int id_slot = task.data.at("id_slot");
|
|
server_slot * slot = get_slot_by_id(id_slot);
|
|
if (slot == nullptr) {
|
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
|
queue_tasks.defer(task);
|
|
break;
|
|
}
|
|
|
|
const size_t token_count = slot->cache_tokens.size();
|
|
const int64_t t_start = ggml_time_us();
|
|
|
|
std::string filename = task.data.at("filename");
|
|
std::string filepath = task.data.at("filepath");
|
|
|
|
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
|
|
|
|
const int64_t t_end = ggml_time_us();
|
|
const double t_save_ms = (t_end - t_start) / 1000.0;
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
result.data = json {
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "n_saved", token_count }, // tokens saved
|
|
{ "n_written", nwrite }, // bytes written
|
|
{ "timings", {
|
|
{ "save_ms", t_save_ms }
|
|
} }
|
|
};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
|
{
|
|
int id_slot = task.data.at("id_slot");
|
|
server_slot * slot = get_slot_by_id(id_slot);
|
|
if (slot == nullptr) {
|
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
|
queue_tasks.defer(task);
|
|
break;
|
|
}
|
|
|
|
const int64_t t_start = ggml_time_us();
|
|
|
|
std::string filename = task.data.at("filename");
|
|
std::string filepath = task.data.at("filepath");
|
|
|
|
slot->cache_tokens.resize(slot->n_ctx);
|
|
size_t token_count = 0;
|
|
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
|
|
if (nread == 0) {
|
|
slot->cache_tokens.resize(0);
|
|
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
slot->cache_tokens.resize(token_count);
|
|
|
|
const int64_t t_end = ggml_time_us();
|
|
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
result.data = json {
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "n_restored", token_count }, // tokens restored
|
|
{ "n_read", nread }, // bytes read
|
|
{ "timings", {
|
|
{ "restore_ms", t_restore_ms }
|
|
} }
|
|
};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SLOT_ERASE:
|
|
{
|
|
int id_slot = task.data.at("id_slot");
|
|
server_slot * slot = get_slot_by_id(id_slot);
|
|
if (slot == nullptr) {
|
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
break;
|
|
}
|
|
if (!slot->available()) {
|
|
// if requested slot is unavailable, we defer this task for processing later
|
|
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
|
|
queue_tasks.defer(task);
|
|
break;
|
|
}
|
|
|
|
// Erase token cache
|
|
const size_t n_erased = slot->cache_tokens.size();
|
|
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
|
|
slot->cache_tokens.clear();
|
|
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
result.data = json {
|
|
{ "id_slot", id_slot },
|
|
{ "n_erased", n_erased }
|
|
};
|
|
queue_results.send(result);
|
|
} break;
|
|
case SERVER_TASK_TYPE_SET_LORA:
|
|
{
|
|
llama_lora_adapters_apply(ctx, lora_adapters);
|
|
server_task_result result;
|
|
result.id = task.id;
|
|
result.data = json{{ "success", true }};
|
|
queue_results.send(result);
|
|
} break;
|
|
}
|
|
}
|
|
|
|
void on_finish_multitask(const server_task_multi & multitask) {
|
|
// all subtasks done == multitask is done
|
|
server_task_result result;
|
|
result.id = multitask.id;
|
|
result.stop = true;
|
|
result.error = false;
|
|
|
|
// collect json results into one json result
|
|
std::vector<json> result_jsons;
|
|
for (const auto & subres : multitask.results) {
|
|
result_jsons.push_back(subres.data);
|
|
result.error = result.error && subres.error;
|
|
}
|
|
result.data = json {
|
|
{ "results", result_jsons }
|
|
};
|
|
|
|
queue_results.send(result);
|
|
}
|
|
|
|
void update_slots() {
|
|
if (system_need_update) {
|
|
system_prompt_update();
|
|
}
|
|
|
|
// release slots
|
|
for (auto & slot : slots) {
|
|
if (slot.command == SLOT_COMMAND_RELEASE) {
|
|
slot.state = SLOT_STATE_IDLE;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.t_last_used = ggml_time_us();
|
|
|
|
LOG_INFO("slot released", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", n_ctx},
|
|
{"n_past", slot.n_past},
|
|
{"n_system_tokens", system_tokens.size()},
|
|
{"n_cache_tokens", slot.cache_tokens.size()},
|
|
{"truncated", slot.truncated}
|
|
});
|
|
|
|
queue_tasks.notify_slot_changed();
|
|
}
|
|
}
|
|
|
|
// check if all slots are idle
|
|
{
|
|
bool all_idle = true;
|
|
|
|
for (auto & slot : slots) {
|
|
if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
|
|
all_idle = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (all_idle) {
|
|
LOG_INFO("all slots are idle", {});
|
|
if (system_prompt.empty() && clean_kv_cache) {
|
|
kv_cache_clear();
|
|
}
|
|
|
|
return;
|
|
}
|
|
}
|
|
|
|
{
|
|
LOG_VERBOSE("posting NEXT_RESPONSE", {});
|
|
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
|
|
task.id_target = -1;
|
|
|
|
queue_tasks.post(task);
|
|
}
|
|
|
|
// apply context-shift if needed
|
|
// TODO: simplify and improve
|
|
for (server_slot & slot : slots) {
|
|
if (slot.ga_n == 1) {
|
|
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
|
|
// Shift context
|
|
const int n_keep = slot.params.n_keep + add_bos_token;
|
|
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
|
|
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
|
|
|
LOG_INFO("slot context shift", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_keep", n_keep},
|
|
{"n_left", n_left},
|
|
{"n_discard", n_discard},
|
|
{"n_ctx", n_ctx},
|
|
{"n_past", slot.n_past},
|
|
{"n_system_tokens", system_tokens.size()},
|
|
{"n_cache_tokens", slot.cache_tokens.size()}
|
|
});
|
|
|
|
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
|
|
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
|
|
|
if (slot.params.cache_prompt) {
|
|
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
|
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
|
|
}
|
|
|
|
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
|
|
}
|
|
|
|
slot.n_past -= n_discard;
|
|
|
|
slot.truncated = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
// start populating the batch for this iteration
|
|
llama_batch_clear(batch);
|
|
|
|
auto accept_special_token = [&](server_slot& slot, llama_token token) {
|
|
return params.special || slot.sparams.preserved_tokens.find(token) != slot.sparams.preserved_tokens.end();
|
|
};
|
|
|
|
// frist, add sampled tokens from any ongoing sequences
|
|
for (auto & slot : slots) {
|
|
if (slot.state == SLOT_STATE_IDLE) {
|
|
continue;
|
|
}
|
|
|
|
slot.i_batch = batch.n_tokens;
|
|
|
|
const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
|
|
|
// TODO: we always have to take into account the "system_tokens"
|
|
// this is not great and needs to be improved somehow
|
|
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
|
|
|
|
slot.n_past += 1;
|
|
|
|
if (slot.params.cache_prompt) {
|
|
slot.cache_tokens.push_back(slot.sampled);
|
|
}
|
|
|
|
LOG_VERBOSE("slot decode token", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", n_ctx},
|
|
{"n_past", slot.n_past},
|
|
{"n_system_tokens", system_tokens.size()},
|
|
{"n_cache_tokens", slot.cache_tokens.size()},
|
|
{"truncated", slot.truncated}
|
|
});
|
|
}
|
|
|
|
// process in chunks of params.n_batch
|
|
int32_t n_batch = llama_n_batch(ctx);
|
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
|
|
|
// track if this is an embedding or non-embedding batch
|
|
// if we've added sampled tokens above, we are in non-embedding mode
|
|
// -1: none, 0: non-embedding, 1: embedding
|
|
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
|
|
|
// next, batch any pending prompts without exceeding n_batch
|
|
if (params.cont_batching || batch.n_tokens == 0) {
|
|
for (auto & slot : slots) {
|
|
// this slot still has a prompt to be processed
|
|
if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
|
|
auto & prompt_tokens = slot.prompt_tokens;
|
|
|
|
// we haven't tokenized the prompt yet - do it now:
|
|
if (prompt_tokens.empty()) {
|
|
LOG_VERBOSE("tokenizing prompt", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task}
|
|
});
|
|
|
|
slot.t_start_process_prompt = ggml_time_us();
|
|
slot.t_start_generation = 0;
|
|
|
|
if (slot.infill) {
|
|
const bool add_bos = llama_should_add_bos_token(model);
|
|
bool suff_rm_leading_spc = true;
|
|
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
|
|
params.input_suffix.erase(0, 1);
|
|
suff_rm_leading_spc = false;
|
|
}
|
|
|
|
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
|
|
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
|
|
|
|
const int space_token = 29871; // TODO: this should not be hardcoded
|
|
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
|
|
suffix_tokens.erase(suffix_tokens.begin());
|
|
}
|
|
|
|
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
|
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
|
|
|
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
|
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
|
|
if (add_bos) {
|
|
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
|
}
|
|
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
|
|
|
const llama_token middle_token = llama_token_middle(model);
|
|
if (middle_token >= 0) {
|
|
embd_inp.push_back(middle_token);
|
|
}
|
|
|
|
prompt_tokens = embd_inp;
|
|
} else {
|
|
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
|
}
|
|
|
|
slot.n_past = 0;
|
|
slot.n_prompt_tokens = prompt_tokens.size();
|
|
|
|
LOG_VERBOSE("prompt tokenized", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", slot.n_ctx},
|
|
{"n_keep", slot.params.n_keep},
|
|
{"n_prompt_tokens", slot.n_prompt_tokens},
|
|
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
|
});
|
|
|
|
// empty prompt passed -> release the slot and send empty response
|
|
if (prompt_tokens.empty()) {
|
|
LOG_INFO("empty prompt - releasing slot", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task}
|
|
});
|
|
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.release();
|
|
slot.print_timings();
|
|
send_final_response(slot);
|
|
continue;
|
|
}
|
|
|
|
if (slot.embedding) {
|
|
// this prompt is too large to process - discard it
|
|
if (slot.n_prompt_tokens > n_ubatch) {
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.release();
|
|
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
|
|
continue;
|
|
}
|
|
} else {
|
|
if (slot.params.n_keep < 0) {
|
|
slot.params.n_keep = slot.n_prompt_tokens;
|
|
}
|
|
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
|
|
|
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
|
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
|
|
const int n_left = slot.n_ctx - slot.params.n_keep;
|
|
|
|
const int n_block_size = n_left / 2;
|
|
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
|
|
|
std::vector<llama_token> new_tokens(
|
|
prompt_tokens.begin(),
|
|
prompt_tokens.begin() + slot.params.n_keep);
|
|
|
|
new_tokens.insert(
|
|
new_tokens.end(),
|
|
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
|
prompt_tokens.end());
|
|
|
|
prompt_tokens = std::move(new_tokens);
|
|
|
|
slot.truncated = true;
|
|
slot.n_prompt_tokens = prompt_tokens.size();
|
|
|
|
LOG_VERBOSE("input truncated", {
|
|
{"id_slot", slot.id},
|
|
{"id_task", slot.id_task},
|
|
{"n_ctx", slot.n_ctx},
|
|
{"n_keep", slot.params.n_keep},
|
|
{"n_left", n_left},
|
|
{"n_prompt_tokens", slot.n_prompt_tokens},
|
|
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
|
});
|
|
|
|
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
|
}
|
|
|
|
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
|
|
|
|
if (!slot.params.cache_prompt) {
|
|
slot.n_past_se = 0;
|
|
slot.ga_i = 0;
|
|
} else {
|
|
GGML_ASSERT(slot.ga_n == 1);
|
|
|
|
// reuse any previously computed tokens that are common with the new prompt
|
|
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
|
|
|
// push the prompt into the sampling context (do not apply grammar)
|
|
for (int i = 0; i < slot.n_past; ++i) {
|
|
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
|
|
// we have to evaluate at least 1 token to generate logits.
|
|
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
|
|
{ "id_slot", slot.id },
|
|
{ "id_task", slot.id_task }
|
|
});
|
|
|
|
slot.n_past--;
|
|
if (slot.ga_i > 0) {
|
|
slot.n_past_se--;
|
|
}
|
|
}
|
|
|
|
slot.n_prompt_tokens_processed = 0;
|
|
}
|
|
|
|
if (slot.embedding) {
|
|
// cannot fit the prompt in the current batch - will try next iter
|
|
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// check that we are in the right batch_type, if not defer the slot
|
|
bool slot_type = slot.embedding ? 1 : 0;
|
|
if (batch_type == -1) {
|
|
batch_type = slot_type;
|
|
} else if (batch_type != slot_type) {
|
|
continue;
|
|
}
|
|
|
|
// keep only the common part
|
|
int p0 = (int) system_tokens.size() + slot.n_past;
|
|
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
|
|
// could not partially delete (likely using a non-Transformer model)
|
|
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
|
|
|
|
p0 = (int) system_tokens.size();
|
|
if (p0 != 0) {
|
|
// copy over the system prompt when there is one
|
|
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
|
|
}
|
|
|
|
// there is no common part left (except for the system prompt)
|
|
slot.n_past = 0;
|
|
slot.n_past_se = 0;
|
|
slot.ga_i = 0;
|
|
// TODO: is the system prompt ever in the sampling context?
|
|
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
|
|
}
|
|
|
|
// remove the non-common part from the cache
|
|
slot.cache_tokens.resize(slot.n_past);
|
|
|
|
LOG_INFO("kv cache rm [p0, end)", {
|
|
{ "id_slot", slot.id },
|
|
{ "id_task", slot.id_task },
|
|
{ "p0", p0 }
|
|
});
|
|
|
|
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
|
|
|
int32_t ga_i = slot.ga_i;
|
|
int32_t ga_n = slot.ga_n;
|
|
int32_t ga_w = slot.ga_w;
|
|
|
|
// add prompt tokens for processing in the current batch
|
|
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
|
|
for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
|
|
if (slot.ga_n != 1) {
|
|
while (slot_npast >= ga_i + ga_w) {
|
|
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
|
slot_npast -= bd;
|
|
ga_i += ga_w/ga_n;
|
|
}
|
|
}
|
|
|
|
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
|
|
|
|
if (slot.params.cache_prompt) {
|
|
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
|
}
|
|
|
|
slot.n_prompt_tokens_processed++;
|
|
slot_npast++;
|
|
}
|
|
|
|
LOG_VERBOSE("prompt processing progress", {
|
|
{"id_slot", slot.id},
|
|
{"n_past", slot.n_past},
|
|
{"n_ctx", n_ctx},
|
|
{"n_tokens", batch.n_tokens},
|
|
{"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
|
|
});
|
|
|
|
// entire prompt has been processed - start decoding new tokens
|
|
if (slot.n_past == slot.n_prompt_tokens) {
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
|
|
GGML_ASSERT(batch.n_tokens > 0);
|
|
llama_sampling_reset(llama_get_model_vocab(model), slot.ctx_sampling);
|
|
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
|
|
llama_token id = slot.prompt_tokens[i];
|
|
if (id != LLAMA_TOKEN_NULL) {
|
|
llama_sampling_accept(slot.ctx_sampling, ctx, id, false);
|
|
}
|
|
}
|
|
|
|
// extract the logits only for the last token
|
|
batch.logits[batch.n_tokens - 1] = true;
|
|
|
|
slot.n_decoded = 0;
|
|
slot.i_batch = batch.n_tokens - 1;
|
|
|
|
LOG_VERBOSE("prompt done", {
|
|
{"id_slot", slot.id},
|
|
{"n_past", slot.n_past},
|
|
{"n_ctx", n_ctx},
|
|
{"n_tokens", batch.n_tokens},
|
|
});
|
|
}
|
|
}
|
|
|
|
if (batch.n_tokens >= n_batch) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (batch.n_tokens == 0) {
|
|
LOG_VERBOSE("no tokens to decode", {});
|
|
return;
|
|
}
|
|
|
|
LOG_VERBOSE("decoding batch", {
|
|
{"n_tokens", batch.n_tokens},
|
|
});
|
|
|
|
// make sure we're in the right embedding mode
|
|
llama_set_embeddings(ctx, batch_type == 1);
|
|
|
|
// process the created batch of tokens
|
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
|
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
|
|
|
for (auto & slot : slots) {
|
|
if (slot.ga_n != 1) {
|
|
// context extension via Self-Extend
|
|
// TODO: simplify and/or abstract this
|
|
while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
|
|
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
|
|
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
|
|
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
|
|
|
|
LOG_TEE("\n");
|
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
|
|
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
|
|
|
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
|
|
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
|
|
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
|
|
|
|
slot.n_past_se -= bd;
|
|
|
|
slot.ga_i += slot.ga_w / slot.ga_n;
|
|
|
|
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
|
|
}
|
|
|
|
slot.n_past_se += n_tokens;
|
|
}
|
|
}
|
|
|
|
llama_batch batch_view = {
|
|
n_tokens,
|
|
batch.token + i,
|
|
nullptr,
|
|
batch.pos + i,
|
|
batch.n_seq_id + i,
|
|
batch.seq_id + i,
|
|
batch.logits + i,
|
|
0, 0, 0, // unused
|
|
};
|
|
|
|
const int ret = llama_decode(ctx, batch_view);
|
|
|
|
if (ret != 0) {
|
|
if (n_batch == 1 || ret < 0) {
|
|
// if you get here, it means the KV cache is full - try increasing it via the context size
|
|
LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
|
|
{"i", i},
|
|
{"n_batch", ret},
|
|
{"ret", ret},
|
|
});
|
|
for (auto & slot : slots) {
|
|
slot.state = SLOT_STATE_PROCESSING;
|
|
slot.command = SLOT_COMMAND_NONE;
|
|
slot.release();
|
|
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
|
|
}
|
|
break; // break loop of n_batch
|
|
}
|
|
|
|
// retry with half the batch size to try to find a free slot in the KV cache
|
|
n_batch /= 2;
|
|
i -= n_batch;
|
|
|
|
LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
|
|
{"i", i},
|
|
{"n_batch", n_batch},
|
|
{"ret", ret},
|
|
});
|
|
|
|
continue; // continue loop of n_batch
|
|
}
|
|
|
|
for (auto & slot : slots) {
|
|
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
|
|
continue; // continue loop of slots
|
|
}
|
|
|
|
// prompt evaluated for embedding
|
|
if (slot.embedding) {
|
|
send_embedding(slot, batch_view);
|
|
slot.release();
|
|
slot.i_batch = -1;
|
|
continue; // continue loop of slots
|
|
}
|
|
|
|
completion_token_output result;
|
|
const int tok_idx = slot.i_batch - i;
|
|
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
|
|
|
|
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
|
|
|
slot.n_decoded += 1;
|
|
|
|
const int64_t t_current = ggml_time_us();
|
|
|
|
if (slot.n_decoded == 1) {
|
|
slot.t_start_generation = ggml_time_us();
|
|
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
|
metrics.on_prompt_eval(slot);
|
|
}
|
|
|
|
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
|
|
|
|
result.tok = id;
|
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
|
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
|
|
|
if (slot.sparams.n_probs > 0) {
|
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
|
|
}
|
|
|
|
if (!process_token(result, slot)) {
|
|
slot.release();
|
|
slot.print_timings();
|
|
send_final_response(slot);
|
|
metrics.on_prediction(slot);
|
|
}
|
|
|
|
slot.i_batch = -1;
|
|
}
|
|
|
|
// Do speculative decoding
|
|
for (auto & slot : slots) {
|
|
if (!slot.is_processing() || !slot.spec) {
|
|
continue;
|
|
}
|
|
|
|
if (slot.state != SLOT_STATE_PROCESSING) {
|
|
continue;
|
|
}
|
|
|
|
// determine the max draft that fits the current slot state
|
|
int n_draft_max = slot.params.speculative.n_max;
|
|
|
|
// note: n_past is not yet increased for the `id` token sampled above
|
|
// also, need to leave space for 1 extra token to allow context shifts
|
|
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
|
|
|
|
if (slot.n_predict > 0) {
|
|
n_draft_max = std::min(n_draft_max, slot.n_predict - slot.n_decoded - 1);
|
|
}
|
|
|
|
LOG_VERBOSE("max possible draft", {
|
|
{"id_slot", slot.id},
|
|
{"n_draft_max", n_draft_max}
|
|
});
|
|
|
|
if (n_draft_max < slot.params.speculative.n_min) {
|
|
LOG_VERBOSE("the max possible draft is too small", {
|
|
{"id_slot", slot.id},
|
|
{"n_draft_max", n_draft_max},
|
|
{"n_min", slot.params.speculative.n_min}
|
|
});
|
|
continue;
|
|
}
|
|
|
|
llama_token id = slot.sampled;
|
|
|
|
struct llama_speculative_params params_spec;
|
|
params_spec.n_draft = n_draft_max;
|
|
params_spec.n_reuse = cparams_dft.n_ctx - slot.params.speculative.n_max;
|
|
params_spec.p_min = slot.params.speculative.p_min;
|
|
|
|
const std::vector<llama_token> & cached_text_tokens = slot.cache_tokens;
|
|
std::vector<llama_token> draft = llama_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
|
|
|
|
// ignore small drafts
|
|
if (slot.params.speculative.n_min > (int) draft.size()) {
|
|
LOG_VERBOSE("ignoring small draft", {
|
|
{"id_slot", slot.id},
|
|
{"draft_size", (int) draft.size()},
|
|
{"n_min", slot.params.speculative.n_min}
|
|
});
|
|
continue;
|
|
}
|
|
|
|
// keep track of total number of drafted tokens tested
|
|
slot.n_draft_total += draft.size();
|
|
|
|
// construct the speculation batch
|
|
llama_batch_clear(slot.batch_spec);
|
|
llama_batch_add(slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true);
|
|
|
|
for (size_t i = 0; i < draft.size(); ++i) {
|
|
llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true);
|
|
}
|
|
|
|
LOG_VERBOSE("decoding speculative batch", {
|
|
{"id_slot", slot.id},
|
|
{"size", slot.batch_spec.n_tokens}
|
|
});
|
|
|
|
llama_decode(ctx, slot.batch_spec);
|
|
|
|
// the accepted tokens from the speculation
|
|
std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(slot.ctx_sampling, ctx, draft);
|
|
|
|
slot.n_past += ids.size();
|
|
slot.n_decoded += ids.size();
|
|
|
|
// update how many tokens out of those tested were accepted
|
|
slot.n_draft_accepted += ids.size() - 1;
|
|
|
|
slot.cache_tokens.push_back(id);
|
|
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
|
|
|
|
llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1);
|
|
|
|
for (size_t i = 0; i < ids.size(); ++i) {
|
|
completion_token_output result;
|
|
|
|
result.tok = ids[i];
|
|
result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
|
|
result.prob = 1.0f; // set later
|
|
|
|
if (slot.sparams.n_probs > 0) {
|
|
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
|
|
}
|
|
|
|
if (!process_token(result, slot)) {
|
|
// release slot because of stop condition
|
|
slot.release();
|
|
slot.print_timings();
|
|
send_final_response(slot);
|
|
metrics.on_prediction(slot);
|
|
break;
|
|
}
|
|
}
|
|
|
|
LOG_VERBOSE("speculative decoding result", {
|
|
{"id_slot", slot.id},
|
|
{"accepted", (int) ids.size() - 1},
|
|
{"total", (int) draft.size()},
|
|
{"new_n_past", slot.n_past}
|
|
});
|
|
}
|
|
}
|
|
|
|
LOG_VERBOSE("run slots completed", {});
|
|
}
|
|
|
|
json model_meta() const {
|
|
return json {
|
|
{"vocab_type", llama_vocab_type (model)},
|
|
{"n_vocab", llama_n_vocab (model)},
|
|
{"n_ctx_train", llama_n_ctx_train (model)},
|
|
{"n_embd", llama_n_embd (model)},
|
|
{"n_params", llama_model_n_params(model)},
|
|
{"size", llama_model_size (model)},
|
|
};
|
|
}
|
|
};
|
|
|
|
static json format_final_response_oaicompat(const json& request, json result, const std::string& completion_id, bool streaming = false) {
|
|
bool stopped_word = result.count("stopped_word") != 0;
|
|
bool stopped_eos = json_value(result, "stopped_eos", false);
|
|
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
|
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
|
std::string content = json_value(result, "content", std::string(""));
|
|
|
|
std::string finish_reason = "length";
|
|
if (stopped_word || stopped_eos) {
|
|
finish_reason = "stop";
|
|
}
|
|
|
|
json choices =
|
|
streaming ? json::array({ json{{"finish_reason", finish_reason},
|
|
{"index", 0},
|
|
{"delta", json::object()}} })
|
|
: json::array({ json{{"finish_reason", finish_reason},
|
|
{"index", 0},
|
|
{"message", json{{"content", content},
|
|
{"role", "assistant"}}}} });
|
|
|
|
std::time_t t = std::time(0);
|
|
|
|
json res = json{
|
|
{"choices", choices},
|
|
{"created", t},
|
|
{"model",
|
|
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
|
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
|
|
{"usage", json {
|
|
{"completion_tokens", num_tokens_predicted},
|
|
{"prompt_tokens", num_prompt_tokens},
|
|
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
|
}},
|
|
{"id", completion_id}
|
|
};
|
|
|
|
if (server_verbose) {
|
|
res["__verbose"] = result;
|
|
}
|
|
|
|
if (result.contains("completion_probabilities")) {
|
|
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
// return value is vector as there is one case where we might need to generate two responses
|
|
static std::vector<json> format_partial_response_oaicompat(server_task_result task_result, const std::string& completion_id) {
|
|
json result = task_result.data;
|
|
std::cout << result.dump(4) << std::endl;
|
|
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
|
return std::vector<json>({ result });
|
|
}
|
|
|
|
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
|
|
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
|
|
|
bool stopped_word = json_value(result, "stopped_word", false);
|
|
bool stopped_eos = json_value(result, "stopped_eos", false);
|
|
bool stopped_limit = json_value(result, "stopped_limit", false);
|
|
std::string content = json_value(result, "content", std::string(""));
|
|
|
|
std::string finish_reason;
|
|
if (stopped_word || stopped_eos) {
|
|
finish_reason = "stop";
|
|
}
|
|
if (stopped_limit) {
|
|
finish_reason = "length";
|
|
}
|
|
|
|
std::time_t t = std::time(0);
|
|
|
|
json choices;
|
|
|
|
if (!finish_reason.empty()) {
|
|
choices = json::array({ json{{"finish_reason", finish_reason},
|
|
{"index", 0},
|
|
{"delta", json::object()}} });
|
|
}
|
|
else {
|
|
if (first) {
|
|
if (content.empty()) {
|
|
choices = json::array({ json{{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta", json{{"role", "assistant"}}}} });
|
|
}
|
|
else {
|
|
// We have to send this as two updates to conform to openai behavior
|
|
json initial_ret = json{ {"choices", json::array({json{
|
|
{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta", json{
|
|
{"role", "assistant"}
|
|
}}}})},
|
|
{"created", t},
|
|
{"id", completion_id},
|
|
{"model", modelname},
|
|
{"object", "chat.completion.chunk"} };
|
|
|
|
json second_ret = json{
|
|
{"choices", json::array({json{{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta", json{
|
|
{"content", content}}}
|
|
}})},
|
|
{"created", t},
|
|
{"id", completion_id},
|
|
{"model", modelname},
|
|
{"object", "chat.completion.chunk"} };
|
|
|
|
return std::vector<json>({ initial_ret, second_ret });
|
|
}
|
|
}
|
|
else {
|
|
// Some idiosyncrasy in task processing logic makes several trailing calls
|
|
// with empty content, we ignore these at the calee site.
|
|
if (content.empty()) {
|
|
return std::vector<json>({ json::object() });
|
|
}
|
|
|
|
choices = json::array({ json{
|
|
{"finish_reason", nullptr},
|
|
{"index", 0},
|
|
{"delta",
|
|
json{
|
|
{"content", content},
|
|
}},
|
|
} });
|
|
}
|
|
}
|
|
|
|
json ret = json{
|
|
{"choices", choices},
|
|
{"created", t},
|
|
{"id", completion_id},
|
|
{"model", modelname},
|
|
{"object", "chat.completion.chunk"}
|
|
};
|
|
|
|
if (task_result.timings.prompt_n != -1) {
|
|
ret.push_back({ "timings", task_result.timings.to_json() });
|
|
}
|
|
|
|
//
|
|
if (!finish_reason.empty()) {
|
|
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
|
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
|
ret.push_back({ "usage", json {
|
|
{"completion_tokens", num_tokens_predicted},
|
|
{"prompt_tokens", num_prompt_tokens},
|
|
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
|
} });
|
|
}
|
|
|
|
return std::vector<json>({ ret });
|
|
}
|
|
|
|
|
|
static json format_embeddings_response_oaicompat(const json& request, const json& embeddings) {
|
|
json data = json::array();
|
|
int i = 0;
|
|
for (auto& elem : embeddings) {
|
|
data.push_back(json{
|
|
{"embedding", json_value(elem, "embedding", json::array())},
|
|
{"index", i++},
|
|
{"object", "embedding"}
|
|
});
|
|
}
|
|
|
|
json res = json{
|
|
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
|
{"object", "list"},
|
|
{"usage", json {
|
|
{"prompt_tokens", 0},
|
|
{"total_tokens", 0}
|
|
}},
|
|
{"data", data}
|
|
};
|
|
|
|
return res;
|
|
}
|
|
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
|
|
// skip GH copilot requests when using default port
|
|
if (req.path == "/v1/health" || req.path == "/v1/completions") {
|
|
return;
|
|
}
|
|
|
|
LOG_INFO("request", {
|
|
{"remote_addr", req.remote_addr},
|
|
{"remote_port", req.remote_port},
|
|
{"status", res.status},
|
|
{"method", req.method},
|
|
{"path", req.path},
|
|
{"params", req.params},
|
|
});
|
|
|
|
LOG_VERBOSE("request", {
|
|
{"request", req.body},
|
|
{"response", res.body},
|
|
});
|
|
}
|
|
|
|
std::function<void(int)> shutdown_handler;
|
|
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
|
|
|
inline void signal_handler(int signal) {
|
|
if (is_terminating.test_and_set()) {
|
|
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
|
|
// this is for better developer experience, we can remove when the server is stable enough
|
|
fprintf(stderr, "Received second interrupt, terminating immediately.\n");
|
|
exit(1);
|
|
}
|
|
|
|
shutdown_handler(signal);
|
|
}
|
|
|
|
int main(int argc, char ** argv) {
|
|
#if SERVER_VERBOSE != 1
|
|
log_disable();
|
|
#endif
|
|
// own arguments required by this example
|
|
gpt_params params;
|
|
|
|
if (!gpt_params_parse(argc, argv, params)) {
|
|
gpt_params_print_usage(argc, argv, params);
|
|
return 1;
|
|
}
|
|
|
|
// TODO: not great to use extern vars
|
|
server_log_json = params.log_json;
|
|
server_verbose = params.verbosity > 0;
|
|
|
|
|
|
// struct that contains llama context and inference
|
|
server_context ctx_server;
|
|
|
|
if (!params.system_prompt.empty()) {
|
|
ctx_server.system_prompt_set(params.system_prompt);
|
|
}
|
|
|
|
if (params.model_alias == "unknown") {
|
|
params.model_alias = params.model;
|
|
}
|
|
|
|
llama_backend_init();
|
|
llama_numa_init(params.numa);
|
|
|
|
LOG_INFO("build info", {
|
|
{"build", LLAMA_BUILD_NUMBER},
|
|
{"commit", LLAMA_COMMIT}
|
|
});
|
|
|
|
LOG_INFO("system info", {
|
|
{"n_threads", params.n_threads},
|
|
{"n_threads_batch", params.n_threads_batch},
|
|
{"total_threads", std::thread::hardware_concurrency()},
|
|
{"system_info", llama_print_system_info()},
|
|
});
|
|
|
|
std::unique_ptr<httplib::Server> svr;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
|
LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}});
|
|
svr.reset(
|
|
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
|
|
);
|
|
} else {
|
|
LOG_INFO("Running without SSL", {});
|
|
svr.reset(new httplib::Server());
|
|
}
|
|
#else
|
|
svr.reset(new httplib::Server());
|
|
#endif
|
|
|
|
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
|
|
|
|
svr->set_default_headers({{"Server", "llama.cpp"}});
|
|
|
|
// CORS preflight
|
|
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
res.set_header("Access-Control-Allow-Credentials", "true");
|
|
res.set_header("Access-Control-Allow-Methods", "POST");
|
|
res.set_header("Access-Control-Allow-Headers", "*");
|
|
return res.set_content("", "application/json; charset=utf-8");
|
|
});
|
|
|
|
svr->set_logger(log_server_request);
|
|
|
|
auto res_error = [](httplib::Response & res, json error_data) {
|
|
json final_response {{"error", error_data}};
|
|
res.set_content(final_response.dump(), "application/json; charset=utf-8");
|
|
res.status = json_value(error_data, "code", 500);
|
|
};
|
|
|
|
auto res_ok = [](httplib::Response& res, const json& data) {
|
|
res.set_content(data.dump(), "application/json; charset=utf-8");
|
|
res.status = 200;
|
|
};
|
|
|
|
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
|
|
std::string message;
|
|
try {
|
|
std::rethrow_exception(std::move(ep));
|
|
} catch (std::exception & e) {
|
|
message = e.what();
|
|
} catch (...) {
|
|
message = "Unknown Exception";
|
|
}
|
|
|
|
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
|
LOG_VERBOSE("Got exception", formatted_error);
|
|
res_error(res, formatted_error);
|
|
});
|
|
|
|
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
|
|
if (res.status == 404) {
|
|
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
|
}
|
|
// for other error codes, we skip processing here because it's already done by res_error()
|
|
});
|
|
|
|
// set timeouts and change hostname and port
|
|
svr->set_read_timeout (params.timeout_read);
|
|
svr->set_write_timeout(params.timeout_write);
|
|
|
|
if (!svr->bind_to_port(params.hostname, params.port)) {
|
|
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
|
|
return 1;
|
|
}
|
|
|
|
std::unordered_map<std::string, std::string> log_data;
|
|
|
|
log_data["hostname"] = params.hostname;
|
|
log_data["port"] = std::to_string(params.port);
|
|
|
|
if (params.api_keys.size() == 1) {
|
|
auto key = params.api_keys[0];
|
|
log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
|
|
} else if (params.api_keys.size() > 1) {
|
|
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
|
|
}
|
|
|
|
// Necessary similarity of prompt for slot selection
|
|
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
|
|
#ifdef SQLITE3_MODERN_CPP_SUPPORT
|
|
auto db_handle = std::make_shared<DatabaseHandle>(params.sql_save_file);
|
|
bool sqlite_extension_loaded = false;
|
|
if (!params.sqlite_zstd_ext_file.empty()) {
|
|
auto* conn = db_handle->db.connection().get();
|
|
sqlite3_enable_load_extension(conn, 1);
|
|
char* errmsg = nullptr;
|
|
const int rc = sqlite3_load_extension(
|
|
conn,
|
|
params.sqlite_zstd_ext_file.c_str(),
|
|
nullptr,
|
|
&errmsg
|
|
);
|
|
if(rc != SQLITE_OK) {
|
|
const std::string err = errmsg ? errmsg : "Unknown extension error";
|
|
sqlite3_free(errmsg);
|
|
LOG_WARNING("Failed to load extension", {{"err", err}});
|
|
}
|
|
else {
|
|
sqlite_extension_loaded = true;
|
|
}
|
|
sqlite3_enable_load_extension(conn, 0);
|
|
}
|
|
#else
|
|
auto db_handle = false;
|
|
#endif
|
|
// load the model
|
|
if (!ctx_server.load_model(params)) {
|
|
state.store(SERVER_STATE_ERROR);
|
|
return 1;
|
|
} else {
|
|
ctx_server.init();
|
|
state.store(SERVER_STATE_READY);
|
|
}
|
|
|
|
LOG_INFO("model loaded", {});
|
|
|
|
const auto model_meta = ctx_server.model_meta();
|
|
|
|
// print sample chat example to make it clear which template is used
|
|
|
|
LOG_INFO("chat template", {
|
|
{"chat_template", common_chat_templates_source(ctx_server.chat_templates.get())},
|
|
});
|
|
|
|
LOG_INFO("chat template", {
|
|
{"chat_example", common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params.use_jinja, {}).c_str()
|
|
},
|
|
{"built_in", params.chat_template.empty()},
|
|
});
|
|
//
|
|
// Middlewares
|
|
//
|
|
|
|
auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
|
// TODO: should we apply API key to all endpoints, including "/health" and "/models"?
|
|
static const std::set<std::string> protected_endpoints = {
|
|
"/props",
|
|
"/completion",
|
|
"/completions",
|
|
"/v1/completions",
|
|
"/chat/completions",
|
|
"/v1/chat/completions",
|
|
"/infill",
|
|
"/tokenize",
|
|
"/detokenize",
|
|
"/embedding",
|
|
"/embeddings",
|
|
"/v1/embeddings",
|
|
};
|
|
|
|
// If API key is not set, skip validation
|
|
if (params.api_keys.empty()) {
|
|
return true;
|
|
}
|
|
|
|
// If path is not in protected_endpoints list, skip validation
|
|
if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
|
|
return true;
|
|
}
|
|
|
|
// Check for API key in the header
|
|
auto auth_header = req.get_header_value("Authorization");
|
|
|
|
std::string prefix = "Bearer ";
|
|
if (auth_header.substr(0, prefix.size()) == prefix) {
|
|
std::string received_api_key = auth_header.substr(prefix.size());
|
|
if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) {
|
|
return true; // API key is valid
|
|
}
|
|
}
|
|
|
|
// API key is invalid or not provided
|
|
// TODO: make another middleware for CORS related logic
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
|
|
|
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
|
|
|
return false;
|
|
};
|
|
|
|
// register server middlewares
|
|
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
|
|
if (!middleware_validate_api_key(req, res)) {
|
|
return httplib::Server::HandlerResponse::Handled;
|
|
}
|
|
return httplib::Server::HandlerResponse::Unhandled;
|
|
});
|
|
|
|
//
|
|
// Route handlers (or controllers)
|
|
//
|
|
|
|
const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
|
|
server_state current_state = state.load();
|
|
switch (current_state) {
|
|
case SERVER_STATE_READY:
|
|
{
|
|
// request slots data using task queue
|
|
server_task task;
|
|
task.id = ctx_server.queue_tasks.get_new_id();
|
|
task.type = SERVER_TASK_TYPE_METRICS;
|
|
task.id_target = -1;
|
|
|
|
ctx_server.queue_results.add_waiting_task_id(task.id);
|
|
ctx_server.queue_tasks.post(task);
|
|
|
|
// get the result
|
|
server_task_result result = ctx_server.queue_results.recv(task.id);
|
|
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
|
|
|
const int n_idle_slots = result.data.at("idle");
|
|
const int n_processing_slots = result.data.at("processing");
|
|
|
|
json health = {
|
|
{"status", "ok"},
|
|
{"slots_idle", n_idle_slots},
|
|
{"slots_processing", n_processing_slots}
|
|
};
|
|
|
|
res.status = 200; // HTTP OK
|
|
if (params.endpoint_slots && req.has_param("include_slots")) {
|
|
health["slots"] = result.data.at("slots");
|
|
}
|
|
|
|
if (n_idle_slots == 0) {
|
|
health["status"] = "no slot available";
|
|
if (req.has_param("fail_on_no_slot")) {
|
|
res.status = 503; // HTTP Service Unavailable
|
|
}
|
|
}
|
|
|
|
res.set_content(health.dump(), "application/json");
|
|
break;
|
|
}
|
|
case SERVER_STATE_LOADING_MODEL:
|
|
{
|
|
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
|
} break;
|
|
case SERVER_STATE_ERROR:
|
|
{
|
|
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
|
|
} break;
|
|
}
|
|
};
|
|
|
|
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
|
|
if (!params.endpoint_slots) {
|
|
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
|
return;
|
|
}
|
|
|
|
// request slots data using task queue
|
|
server_task task;
|
|
task.id = ctx_server.queue_tasks.get_new_id();
|
|
task.id_multi = -1;
|
|
task.id_target = -1;
|
|
task.type = SERVER_TASK_TYPE_METRICS;
|
|
|
|
ctx_server.queue_results.add_waiting_task_id(task.id);
|
|
ctx_server.queue_tasks.post(task);
|
|
|
|
// get the result
|
|
server_task_result result = ctx_server.queue_results.recv(task.id);
|
|
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
|
|
|
res.set_content(result.data.at("slots").dump(), "application/json");
|
|
res.status = 200; // HTTP OK
|
|
};
|
|
|
|
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
|
if (!params.endpoint_metrics) {
|
|
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
|
return;
|
|
}
|
|
|
|
// request slots data using task queue
|
|
server_task task;
|
|
task.id = ctx_server.queue_tasks.get_new_id();
|
|
task.id_multi = -1;
|
|
task.id_target = -1;
|
|
task.type = SERVER_TASK_TYPE_METRICS;
|
|
task.data.push_back({{"reset_bucket", true}});
|
|
|
|
ctx_server.queue_results.add_waiting_task_id(task.id);
|
|
ctx_server.queue_tasks.post(task);
|
|
|
|
// get the result
|
|
server_task_result result = ctx_server.queue_results.recv(task.id);
|
|
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
|
|
|
json data = result.data;
|
|
|
|
const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed");
|
|
const uint64_t t_prompt_processing = data.at("t_prompt_processing");
|
|
|
|
const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
|
|
const uint64_t t_tokens_generation = data.at("t_tokens_generation");
|
|
|
|
const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
|
|
|
|
// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
|
|
json all_metrics_def = json {
|
|
{"counter", {{
|
|
{"name", "prompt_tokens_total"},
|
|
{"help", "Number of prompt tokens processed."},
|
|
{"value", (uint64_t) data.at("n_prompt_tokens_processed_total")}
|
|
}, {
|
|
{"name", "prompt_seconds_total"},
|
|
{"help", "Prompt process time"},
|
|
{"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3}
|
|
}, {
|
|
{"name", "tokens_predicted_total"},
|
|
{"help", "Number of generation tokens processed."},
|
|
{"value", (uint64_t) data.at("n_tokens_predicted_total")}
|
|
}, {
|
|
{"name", "tokens_predicted_seconds_total"},
|
|
{"help", "Predict process time"},
|
|
{"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
|
|
}}},
|
|
{"gauge", {{
|
|
{"name", "prompt_tokens_seconds"},
|
|
{"help", "Average prompt throughput in tokens/s."},
|
|
{"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
|
|
},{
|
|
{"name", "predicted_tokens_seconds"},
|
|
{"help", "Average generation throughput in tokens/s."},
|
|
{"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
|
|
},{
|
|
{"name", "kv_cache_usage_ratio"},
|
|
{"help", "KV-cache usage. 1 means 100 percent usage."},
|
|
{"value", 1. * kv_cache_used_cells / params.n_ctx}
|
|
},{
|
|
{"name", "kv_cache_tokens"},
|
|
{"help", "KV-cache tokens."},
|
|
{"value", (uint64_t) data.at("kv_cache_tokens_count")}
|
|
},{
|
|
{"name", "requests_processing"},
|
|
{"help", "Number of request processing."},
|
|
{"value", (uint64_t) data.at("processing")}
|
|
},{
|
|
{"name", "requests_deferred"},
|
|
{"help", "Number of request deferred."},
|
|
{"value", (uint64_t) data.at("deferred")}
|
|
}}}
|
|
};
|
|
|
|
std::stringstream prometheus;
|
|
|
|
for (const auto & el : all_metrics_def.items()) {
|
|
const auto & type = el.key();
|
|
const auto & metrics_def = el.value();
|
|
|
|
for (const auto & metric_def : metrics_def) {
|
|
const std::string name = metric_def.at("name");
|
|
const std::string help = metric_def.at("help");
|
|
|
|
auto value = json_value(metric_def, "value", 0.);
|
|
prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
|
|
<< "# TYPE llamacpp:" << name << " " << type << "\n"
|
|
<< "llamacpp:" << name << " " << value << "\n";
|
|
}
|
|
}
|
|
|
|
const int64_t t_start = data.at("t_start");
|
|
res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
|
|
|
|
res.set_content(prometheus.str(), "text/plain; version=0.0.4");
|
|
res.status = 200; // HTTP OK
|
|
};
|
|
|
|
const auto handle_slots_save = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
|
json request_data = json::parse(req.body);
|
|
std::string filename = request_data.at("filename");
|
|
if (!fs_validate_filename(filename)) {
|
|
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
|
return;
|
|
}
|
|
std::string filepath = params.slot_save_path + filename;
|
|
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_SLOT_SAVE;
|
|
task.data = {
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "filepath", filepath }
|
|
};
|
|
|
|
const int id_task = ctx_server.queue_tasks.post(task);
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
|
|
if (result.error) {
|
|
res_error(res, result.data);
|
|
} else {
|
|
res.set_content(result.data.dump(), "application/json");
|
|
}
|
|
};
|
|
|
|
const auto handle_slots_restore = [&ctx_server, &res_error, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
|
json request_data = json::parse(req.body);
|
|
std::string filename = request_data.at("filename");
|
|
if (!fs_validate_filename(filename)) {
|
|
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
|
return;
|
|
}
|
|
std::string filepath = params.slot_save_path + filename;
|
|
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
|
|
task.data = {
|
|
{ "id_slot", id_slot },
|
|
{ "filename", filename },
|
|
{ "filepath", filepath }
|
|
};
|
|
|
|
const int id_task = ctx_server.queue_tasks.post(task);
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
|
|
if (result.error) {
|
|
res_error(res, result.data);
|
|
} else {
|
|
res.set_content(result.data.dump(), "application/json");
|
|
}
|
|
};
|
|
|
|
const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
|
|
task.data = {
|
|
{ "id_slot", id_slot },
|
|
};
|
|
|
|
const int id_task = ctx_server.queue_tasks.post(task);
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
|
|
if (result.error) {
|
|
res_error(res, result.data);
|
|
} else {
|
|
res.set_content(result.data.dump(), "application/json");
|
|
}
|
|
};
|
|
|
|
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
|
|
std::string id_slot_str = req.path_params.at("id_slot");
|
|
int id_slot;
|
|
|
|
try {
|
|
id_slot = std::stoi(id_slot_str);
|
|
} catch (const std::exception &) {
|
|
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
|
return;
|
|
}
|
|
|
|
std::string action = req.get_param_value("action");
|
|
|
|
if (action == "save") {
|
|
handle_slots_save(req, res, id_slot);
|
|
} else if (action == "restore") {
|
|
handle_slots_restore(req, res, id_slot);
|
|
} else if (action == "erase") {
|
|
handle_slots_erase(req, res, id_slot);
|
|
} else {
|
|
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
|
}
|
|
};
|
|
|
|
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
|
std::string template_key = "tokenizer.chat_template", curr_tmpl;
|
|
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
|
|
if (tlen > 0) {
|
|
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
|
|
if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
|
|
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
|
|
}
|
|
}
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
json data = {
|
|
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
|
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
|
{ "total_slots", ctx_server.params.n_parallel },
|
|
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
|
|
{ "bos_token", llama_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), /* special= */ true)},
|
|
{ "eos_token", llama_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), /* special= */ true)},
|
|
{ "model_path", ctx_server.params.model },
|
|
{ "n_ctx", ctx_server.n_ctx }
|
|
|
|
};
|
|
|
|
if (ctx_server.params.use_jinja) {
|
|
if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
|
|
data["chat_template_tool_use"] = tool_use_src;
|
|
}
|
|
}
|
|
res.set_content(data.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
|
if (ctx_server.params.embedding) {
|
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
|
return;
|
|
}
|
|
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
auto data = json::parse(req.body);
|
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
|
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
ctx_server.request_completion(id_task, -1, data, false, false);
|
|
|
|
if (!json_value(data, "stream", false)) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
if (!result.error && result.stop) {
|
|
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
|
} else {
|
|
res_error(res, result.data);
|
|
}
|
|
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
} else {
|
|
const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
|
|
while (true) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
if (!result.error) {
|
|
const std::string str =
|
|
"data: " +
|
|
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
"\n\n";
|
|
|
|
LOG_VERBOSE("data stream", {
|
|
{ "to_send", str }
|
|
});
|
|
|
|
if (!sink.write(str.c_str(), str.size())) {
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return false;
|
|
}
|
|
|
|
if (result.stop) {
|
|
break;
|
|
}
|
|
} else {
|
|
const std::string str =
|
|
"error: " +
|
|
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
"\n\n";
|
|
|
|
LOG_VERBOSE("data stream", {
|
|
{ "to_send", str }
|
|
});
|
|
|
|
if (!sink.write(str.c_str(), str.size())) {
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return false;
|
|
}
|
|
|
|
break;
|
|
}
|
|
}
|
|
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
sink.done();
|
|
|
|
return true;
|
|
};
|
|
|
|
auto on_complete = [id_task, &ctx_server] (bool) {
|
|
// cancel
|
|
ctx_server.request_cancel(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
};
|
|
|
|
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
|
}
|
|
};
|
|
|
|
const auto handle_completions_oai = [&ctx_server, &res_error](const httplib::Request& req, httplib::Response& res) {
|
|
if (ctx_server.params.embedding) {
|
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
|
return;
|
|
}
|
|
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
auto body = json::parse(req.body);
|
|
json data = oaicompat_chat_params_parse(body);
|
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
|
const auto completion_id = gen_chatcmplid();
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
ctx_server.request_completion(id_task, -1, data, false, false);
|
|
|
|
if (!json_value(data, "stream", false)) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
if (!result.error && result.stop) {
|
|
result.oaicompat_cmpl_id = completion_id;
|
|
result.oaicompat = OAICOMPAT_TYPE_COMPLETION;
|
|
json result_oai = result.to_json_final();
|
|
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
|
}
|
|
else {
|
|
res_error(res, result.data);
|
|
}
|
|
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
}
|
|
else {
|
|
const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink& sink) {
|
|
while (true) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
result.oaicompat = OAICOMPAT_TYPE_COMPLETION;
|
|
json result_oai;
|
|
if (result.final_result) {
|
|
result_oai = result.to_json_final();
|
|
}
|
|
else {
|
|
result_oai = result.to_json_partial(); // format_final_response_oaicompat(data, result.data, completion_id);
|
|
}
|
|
if (!result.error) {
|
|
const std::string str =
|
|
"data: " +
|
|
result_oai.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
"\n\n";
|
|
|
|
LOG_VERBOSE("data stream", {
|
|
{ "to_send", str }
|
|
});
|
|
|
|
if (!sink.write(str.c_str(), str.size())) {
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return false;
|
|
}
|
|
|
|
if (result.stop) {
|
|
break;
|
|
}
|
|
}
|
|
else {
|
|
const std::string str =
|
|
"error: " +
|
|
result_oai.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
"\n\n";
|
|
|
|
LOG_VERBOSE("data stream", {
|
|
{ "to_send", str }
|
|
});
|
|
|
|
if (!sink.write(str.c_str(), str.size())) {
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return false;
|
|
}
|
|
|
|
break;
|
|
}
|
|
}
|
|
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
sink.done();
|
|
|
|
return true;
|
|
};
|
|
|
|
auto on_complete = [id_task, &ctx_server](bool) {
|
|
// cancel
|
|
ctx_server.request_cancel(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
};
|
|
|
|
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
|
}
|
|
};
|
|
|
|
const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
|
|
json models = {
|
|
{"object", "list"},
|
|
{"data", {
|
|
{
|
|
{"id", params.model_alias},
|
|
{"object", "model"},
|
|
{"created", std::time(0)},
|
|
{"owned_by", "llamacpp"},
|
|
{"meta", model_meta}
|
|
},
|
|
}}
|
|
};
|
|
|
|
res.set_content(models.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
|
|
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
|
if (ctx_server.params.embedding) {
|
|
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
|
return;
|
|
}
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
|
|
auto body = json::parse(req.body);
|
|
json data = oaicompat_chat_params_parse(ctx_server.model, body, ctx_server.oai_parser_opt);
|
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
|
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
ctx_server.request_completion(id_task, -1, data, false, false);
|
|
const auto completion_id = gen_chatcmplid();
|
|
if (!json_value(data, "stream", false)) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
result.oaicompat = OAICOMPAT_TYPE_CHAT;
|
|
result.oaicompat_cmpl_id = completion_id;
|
|
json result_oai;
|
|
if (result.final_result) {
|
|
result_oai = result.to_json_final();
|
|
}
|
|
else {
|
|
result_oai = result.to_json_partial();
|
|
}
|
|
if (!result.error && result.stop) {
|
|
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
|
} else {
|
|
res_error(res, result_oai);
|
|
}
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
} else {
|
|
const auto chunked_content_provider = [id_task, &ctx_server, completion_id, send_done = params.send_done](size_t, httplib::DataSink & sink) {
|
|
bool successful_completion = false;
|
|
while (true) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
if (!result.error) {
|
|
result.oaicompat = OAICOMPAT_TYPE_CHAT;
|
|
result.oaicompat_cmpl_id = completion_id;
|
|
json result_array;
|
|
if (result.final_result) {
|
|
result_array = result.to_json_final();
|
|
}
|
|
else {
|
|
result_array = result.to_json_partial();
|
|
}
|
|
if (result_array.is_array()) {
|
|
for (auto it = result_array.begin(); it != result_array.end(); ++it) {
|
|
if (!it->empty()) {
|
|
const std::string str =
|
|
"data: " +
|
|
it->dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
"\n\n";
|
|
LOG_VERBOSE("data stream", {{"to_send", str}});
|
|
if (!sink.write(str.c_str(), str.size())) {
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
if (result.stop) {
|
|
successful_completion = true;
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
const std::string str =
|
|
"error: " +
|
|
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
"\n\n";
|
|
LOG_VERBOSE("data stream", {{"to_send", str}});
|
|
if (!sink.write(str.c_str(), str.size())) {
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return false;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
bool ok = true;
|
|
if (successful_completion) {
|
|
static const std::string done_message = "data: [DONE]\n\n";
|
|
LOG_VERBOSE("data stream", {{"to_send", done_message}});
|
|
if (!sink.write(done_message.c_str(), done_message.size())) {
|
|
// If writing [DONE] fails, the stream is likely already problematic.
|
|
ok = false;
|
|
}
|
|
}
|
|
sink.done();
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return ok;
|
|
};
|
|
|
|
auto on_complete = [id_task, &ctx_server](bool) {
|
|
// cancel request
|
|
ctx_server.request_cancel(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
};
|
|
|
|
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
|
}
|
|
};
|
|
|
|
// same with handle_chat_completions, but without inference part
|
|
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request& req, httplib::Response& res) {
|
|
auto body = json::parse(req.body);
|
|
json data = oaicompat_chat_params_parse(ctx_server.model, body,ctx_server.oai_parser_opt);
|
|
res_ok(res, { { "prompt", std::move(data.at("prompt")) } });
|
|
};
|
|
|
|
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
|
if (ctx_server.params.embedding) {
|
|
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
|
return;
|
|
}
|
|
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
|
|
json data = json::parse(req.body);
|
|
|
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
|
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
ctx_server.request_completion(id_task, -1, data, true, false);
|
|
|
|
if (!json_value(data, "stream", false)) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
if (!result.error && result.stop) {
|
|
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
|
} else {
|
|
res_error(res, result.data);
|
|
}
|
|
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
} else {
|
|
const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
|
|
while (true) {
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
if (!result.error) {
|
|
const std::string str =
|
|
"data: " +
|
|
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
"\n\n";
|
|
|
|
LOG_VERBOSE("data stream", {
|
|
{ "to_send", str }
|
|
});
|
|
|
|
if (!sink.write(str.c_str(), str.size())) {
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
return false;
|
|
}
|
|
|
|
if (result.stop) {
|
|
break;
|
|
}
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
sink.done();
|
|
|
|
return true;
|
|
};
|
|
|
|
auto on_complete = [id_task, &ctx_server] (bool) {
|
|
ctx_server.request_cancel(id_task);
|
|
};
|
|
|
|
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
|
}
|
|
};
|
|
|
|
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
const json body = json::parse(req.body);
|
|
|
|
std::vector<llama_token> tokens;
|
|
if (body.count("content") != 0) {
|
|
const bool add_special = json_value(body, "add_special", false);
|
|
tokens = ctx_server.tokenize(body.at("content"), add_special);
|
|
}
|
|
const json data = format_tokenizer_response(tokens);
|
|
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
const json body = json::parse(req.body);
|
|
|
|
std::string content;
|
|
if (body.count("tokens") != 0) {
|
|
const std::vector<llama_token> tokens = body.at("tokens");
|
|
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
|
|
}
|
|
|
|
const json data = format_detokenized_response(content);
|
|
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
|
|
const json body = json::parse(req.body);
|
|
bool is_openai = false;
|
|
|
|
// an input prompt can be a string or a list of tokens (integer)
|
|
json prompt;
|
|
if (body.count("input") != 0) {
|
|
is_openai = true;
|
|
prompt = body.at("input");
|
|
} else if (body.count("content") != 0) {
|
|
// with "content", we only support single prompt
|
|
prompt = std::vector<std::string>{body.at("content")};
|
|
} else {
|
|
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
|
return;
|
|
}
|
|
|
|
// create and queue the task
|
|
json responses;
|
|
{
|
|
const int id_task = ctx_server.queue_tasks.get_new_id();
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
|
|
|
|
// get the result
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
if (!result.error) {
|
|
if (result.data.count("results")) {
|
|
// result for multi-task
|
|
responses = result.data.at("results");
|
|
} else {
|
|
// result for single task
|
|
responses = std::vector<json>{result.data};
|
|
}
|
|
} else {
|
|
// error received, ignore everything else
|
|
res_error(res, result.data);
|
|
return;
|
|
}
|
|
}
|
|
|
|
// write JSON response
|
|
json root = is_openai
|
|
? format_embeddings_response_oaicompat(body, responses)
|
|
: responses[0];
|
|
return res.set_content(root.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
json result = json::array();
|
|
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
|
|
auto & la = ctx_server.lora_adapters[i];
|
|
result.push_back({
|
|
{"id", i},
|
|
{"path", la.path},
|
|
{"scale", la.scale},
|
|
});
|
|
}
|
|
res.set_content(result.dump(), "application/json");
|
|
res.status = 200; // HTTP OK
|
|
};
|
|
|
|
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
|
|
const std::vector<json> body = json::parse(req.body);
|
|
int max_idx = ctx_server.lora_adapters.size();
|
|
|
|
// clear existing value
|
|
for (auto & la : ctx_server.lora_adapters) {
|
|
la.scale = 0.0f;
|
|
}
|
|
|
|
// set value
|
|
for (auto entry : body) {
|
|
int id = entry.at("id");
|
|
float scale = entry.at("scale");
|
|
if (0 <= id && id < max_idx) {
|
|
ctx_server.lora_adapters[id].scale = scale;
|
|
} else {
|
|
throw std::runtime_error("invalid adapter id");
|
|
}
|
|
}
|
|
|
|
server_task task;
|
|
task.type = SERVER_TASK_TYPE_SET_LORA;
|
|
const int id_task = ctx_server.queue_tasks.post(task);
|
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
|
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
|
|
res.set_content(result.data.dump(), "application/json");
|
|
res.status = 200; // HTTP OK
|
|
};
|
|
|
|
const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
json response = json::array();
|
|
namespace fs = std::filesystem;
|
|
|
|
try {
|
|
for (const auto& entry : fs::directory_iterator(params.slot_save_path)) {
|
|
if (!entry.is_regular_file() || entry.file_size() < 12) {
|
|
continue;
|
|
}
|
|
|
|
std::ifstream file(entry.path(), std::ios::binary);
|
|
if (!file) continue;
|
|
|
|
uint32_t magic, version, n_token_count;
|
|
file.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
|
file.read(reinterpret_cast<char*>(&version), sizeof(version));
|
|
file.read(reinterpret_cast<char*>(&n_token_count), sizeof(n_token_count));
|
|
|
|
if (magic != LLAMA_STATE_SEQ_MAGIC ||
|
|
version != LLAMA_STATE_SEQ_VERSION ||
|
|
entry.file_size() < (12 + (n_token_count * sizeof(llama_token)))) {
|
|
continue;
|
|
}
|
|
|
|
std::vector<llama_token> tokens(n_token_count);
|
|
file.read(reinterpret_cast<char*>(tokens.data()), tokens.size() * sizeof(llama_token));
|
|
|
|
//C++17 is not modern enough to have a nice and portable way to get the mtime of a file
|
|
//so the following seems to be needed
|
|
auto ftime = fs::last_write_time(entry.path());
|
|
auto system_time = std::chrono::time_point_cast<std::chrono::system_clock::duration>(
|
|
ftime - fs::file_time_type::clock::now() + std::chrono::system_clock::now()
|
|
);
|
|
std::time_t c_time = std::chrono::system_clock::to_time_t(system_time);
|
|
std::tm tm_struct;
|
|
#if defined(_WIN32)
|
|
localtime_s(&tm_struct, &c_time);
|
|
#else
|
|
localtime_r(&c_time, &tm_struct);
|
|
#endif
|
|
std::ostringstream oss;
|
|
oss << std::put_time(&tm_struct, "%Y-%m-%d %H:%M:%S");
|
|
auto str_time = oss.str();
|
|
|
|
|
|
response.push_back({
|
|
{"filename", entry.path().filename().string()},
|
|
{"filesize", entry.file_size()},
|
|
{"mtime", str_time},
|
|
{"token_count", n_token_count},
|
|
{"prompt", tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend())}
|
|
});
|
|
}
|
|
} catch (const std::exception& e) {
|
|
res.status = 500;
|
|
response = {{"error", e.what()}};
|
|
}
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
const auto list_slot_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
json response = json::array();
|
|
for (server_slot & slot : ctx_server.slots) {
|
|
response.push_back({
|
|
{"slot_id", slot.id},
|
|
{"token_count", slot.cache_tokens.size()},
|
|
{"prompt", tokens_to_str(ctx_server.ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())}
|
|
});
|
|
}
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
|
|
const auto delete_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
json response;
|
|
namespace fs = std::filesystem;
|
|
|
|
try {
|
|
const json body = json::parse(req.body);
|
|
const std::string filename_str = body.at("filename");
|
|
|
|
// prevent directory traversal attacks
|
|
if (filename_str.find("..") != std::string::npos || filename_str.find('/') != std::string::npos || filename_str.find('\\') != std::string::npos) {
|
|
res.status = 400;
|
|
response = {{"error", "Invalid filename format."}};
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
return;
|
|
}
|
|
|
|
const fs::path file_to_delete = fs::path(params.slot_save_path) / fs::path(filename_str);
|
|
|
|
if (!fs::exists(file_to_delete) || !fs::is_regular_file(file_to_delete)) {
|
|
res.status = 404;
|
|
response = {{"error", "File not found."}};
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
return;
|
|
}
|
|
|
|
if (fs::remove(file_to_delete)) {
|
|
response = {
|
|
{"status", "deleted"},
|
|
{"filename", filename_str}
|
|
};
|
|
} else {
|
|
res.status = 500;
|
|
response = {{"error", "Failed to delete the file."}};
|
|
}
|
|
} catch (const json::parse_error& e) {
|
|
res.status = 400;
|
|
response = {{"error", "Invalid JSON request body."}};
|
|
} catch (const json::out_of_range& e) {
|
|
res.status = 400;
|
|
response = {{"error", "Missing 'filename' key in request body."}};
|
|
} catch (const std::exception& e) {
|
|
res.status = 500;
|
|
response = {{"error", e.what()}};
|
|
}
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
const auto rename_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void {
|
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
json response;
|
|
namespace fs = std::filesystem;
|
|
|
|
try {
|
|
const json body = json::parse(req.body);
|
|
const std::string old_filename_str = body.at("old_filename");
|
|
const std::string new_filename_str = body.at("new_filename");
|
|
|
|
if (old_filename_str.find("..") != std::string::npos || old_filename_str.find_first_of("/\\") != std::string::npos ||
|
|
new_filename_str.find("..") != std::string::npos || new_filename_str.find_first_of("/\\") != std::string::npos) {
|
|
res.status = 400;
|
|
response = {{"error", "Invalid filename format."}};
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
return;
|
|
}
|
|
|
|
const fs::path old_path = fs::path(params.slot_save_path) / old_filename_str;
|
|
const fs::path new_path = fs::path(params.slot_save_path) / new_filename_str;
|
|
|
|
if (!fs::exists(old_path) || !fs::is_regular_file(old_path)) {
|
|
res.status = 404;
|
|
response = {{"error", "Source file not found."}};
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
return;
|
|
}
|
|
|
|
if (fs::exists(new_path)) {
|
|
res.status = 409;
|
|
response = {{"error", "Destination filename already exists."}};
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
return;
|
|
}
|
|
|
|
std::error_code ec;
|
|
fs::rename(old_path, new_path, ec);
|
|
|
|
if (ec) {
|
|
res.status = 500;
|
|
response = {{"error", "Failed to rename file: " + ec.message()}};
|
|
} else {
|
|
response = {
|
|
{"status", "renamed"},
|
|
{"old_filename", old_filename_str},
|
|
{"new_filename", new_filename_str}
|
|
};
|
|
}
|
|
|
|
} catch (const json::parse_error& e) {
|
|
res.status = 400;
|
|
response = {{"error", "Invalid JSON request body."}};
|
|
} catch (const json::out_of_range& e) {
|
|
res.status = 400;
|
|
response = {{"error", "Missing 'old_filename' or 'new_filename' in request body."}};
|
|
} catch (const std::exception& e) {
|
|
res.status = 500;
|
|
response = {{"error", e.what()}};
|
|
}
|
|
|
|
res.set_content(response.dump(), "application/json; charset=utf-8");
|
|
};
|
|
|
|
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
|
|
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
|
|
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
|
|
return false;
|
|
};
|
|
};
|
|
#ifdef SQLITE3_MODERN_CPP_SUPPORT
|
|
const auto handle_version = [¶ms, sqlite_extension_loaded](const httplib::Request&, httplib::Response& res) {
|
|
res.set_content(
|
|
json{{"version", 4},
|
|
{"features", {{"sql", !params.sql_save_file.empty()}, {"zstd_compression", sqlite_extension_loaded}}}}.dump(),
|
|
"application/json"
|
|
);
|
|
};
|
|
#else
|
|
const auto handle_version = [](const httplib::Request&, httplib::Response& res)-> void {
|
|
res.set_content(
|
|
json{{"version", 4},
|
|
{"features", {{"sql", false}, {"zstd_compression", false}}}}.dump(),
|
|
"application/json"
|
|
);
|
|
};
|
|
#endif
|
|
|
|
#ifdef SQLITE3_MODERN_CPP_SUPPORT
|
|
auto db_handler = [db_handle](auto func) {
|
|
return [func, db_handle](const httplib::Request& req, httplib::Response& res) {
|
|
res.set_header("Access-Control-Allow-Origin", "*");
|
|
try {
|
|
const json body = !req.body.empty() ? json::parse(req.body) : json::object();
|
|
func(*db_handle, body, req, res);
|
|
} catch(const std::exception& e) {
|
|
res.status = 500;
|
|
res.set_content(
|
|
json{{"ok", false}, {"message", e.what()}}.dump(),
|
|
"application/json"
|
|
);
|
|
}
|
|
};
|
|
};
|
|
#else
|
|
auto db_handler = [db_handle](auto func) {
|
|
return [func, db_handle](const httplib::Request& req, httplib::Response& res) {
|
|
res.set_header("Access-Control-Allow-Origin", "*");
|
|
res.status = 500;
|
|
res.set_content(
|
|
json{{"ok", false}, {"message", "Sqlite3 support was not enabled. Recompile with '-DLLAMA_SERVER_SQLITE3=ON'"}}.dump(),
|
|
"application/json"
|
|
);
|
|
};
|
|
};
|
|
#endif
|
|
|
|
const auto normalize_store_name = [](const std::string& storeName) {
|
|
if(storeName.empty()) return std::string("sessions");
|
|
|
|
std::string normalized;
|
|
normalized.reserve(storeName.size());
|
|
|
|
for(char c : storeName) {
|
|
if(std::isalpha(static_cast<unsigned char>(c))) {
|
|
normalized.push_back(std::tolower(static_cast<unsigned char>(c)));
|
|
}
|
|
}
|
|
|
|
return normalized.empty() ? "sessions" : normalized;
|
|
};
|
|
|
|
const auto get_key_string = [](const json& j) {
|
|
return j.is_string() ? j.get<std::string>() : j.dump();
|
|
};
|
|
|
|
|
|
const auto handle_load = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) {
|
|
std::string data;
|
|
const std::string store = normalize_store_name(body["storeName"]);
|
|
db.db << "SELECT data FROM " + store + " WHERE key = ?" << get_key_string(body["key"]) >> data;
|
|
if(data.empty()) {
|
|
res.status = 404;
|
|
res.set_content(json{{"ok", false}, {"message", "Key not found"}}.dump(), "application/json");
|
|
} else {
|
|
json response{{"ok", true}};
|
|
response["result"] = (store == "names") ? json(data) : json::parse(data);
|
|
res.set_content(response.dump(), "application/json");
|
|
}
|
|
});
|
|
|
|
const auto handle_save = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) {
|
|
const std::string store = normalize_store_name(body["storeName"]);
|
|
const std::string data = (store == "names") ? body["data"].get<std::string>() : body["data"].dump();
|
|
db.db << "INSERT OR REPLACE INTO " + store + " (key, data) VALUES (?, ?)" << get_key_string(body["key"]) << data;
|
|
res.set_content(json{{"ok", true}, {"result", "Data saved successfully"}}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_rename = db_handler([get_key_string](auto& db, const json& body, auto&, auto& res) {
|
|
db.db << "UPDATE names SET data = ? WHERE key = ?"
|
|
<< body["newName"].get<std::string>()
|
|
<< get_key_string(body["key"]);
|
|
res.set_content(json{{"ok", true}, {"result", "Session renamed successfully"}}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_all = db_handler([normalize_store_name](auto& db, const json& body, auto&, auto& res) {
|
|
json result = json::object();
|
|
db.db << "SELECT key, data FROM " + normalize_store_name(body["storeName"]) >>
|
|
[&](const std::string& key, const std::string& data) {
|
|
result[key] = json::parse(data);
|
|
};
|
|
res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_sessions = db_handler([](auto& db, const json& body, auto&, auto& res) {
|
|
json result = json::object();
|
|
db.db << "SELECT key, data FROM names" >> [&](const std::string& key, const std::string& data) {
|
|
result[key] = data;
|
|
};
|
|
res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_delete = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) {
|
|
db.db << "DELETE FROM " + normalize_store_name(body["storeName"]) + " WHERE key = ?"
|
|
<< get_key_string(body["key"]);
|
|
res.set_content(json{{"ok", true}, {"result", "Session deleted successfully"}}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_vacuum = db_handler([](auto& db, const json& body, auto&, auto& res) {
|
|
json result = json::object();
|
|
db.db << "VACUUM";
|
|
res.set_content(json{"ok", true}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_zstd_get_configs = db_handler([](auto& db, const json& body, auto&, auto& res) {
|
|
json result = json::object();
|
|
db.db << "SELECT id, config FROM _zstd_configs" >> [&](const std::string id, const std::string& config) {
|
|
result[id] = config;
|
|
};
|
|
res.set_content(json{{"ok", true}, {"configs", result}}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_zstd_maintenance = db_handler([](auto& db, const json& body, auto&, auto& res) {
|
|
std::string data;
|
|
if (body["duration"].is_null()) {
|
|
db.db << "select zstd_incremental_maintenance(?, ?)" << nullptr << body["db_load"].get<double>() >> data;
|
|
}
|
|
else {
|
|
db.db << "select zstd_incremental_maintenance(?, ?)" << body["duration"].get<double>() << body["db_load"].get<double>() >> data;
|
|
}
|
|
json response{{"ok", true}};
|
|
response["result"] = json::parse(data);
|
|
res.set_content(response.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_zstd_enable = db_handler([](auto& db, const json& body, auto&, auto& res) {
|
|
db.db << "select zstd_enable_transparent('{\"table\": \"" + body["table"].get<std::string>() + "\",\"column\": \"" + body["column"].get<std::string>() + "\", \"compression_level\": " + std::to_string(body["compression_level"].get<int>()) + ", \"dict_chooser\": \"''a''\", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get<int>()) + "}')";
|
|
res.set_content(json{"ok", true}.dump(), "application/json");
|
|
});
|
|
|
|
const auto handle_zstd_config_update = db_handler([](auto& db, const json& body, auto&, auto& res) {
|
|
std::string patch_json = "{\"compression_level\": " + std::to_string(body["compression_level"].get<int>()) + ", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get<int>()) + "}";
|
|
db.db << "update _zstd_configs set config = json_patch(config, '" + patch_json + "')";
|
|
res.set_content(json{{"ok", true}}.dump(), "application/json");
|
|
});
|
|
|
|
//
|
|
// Router
|
|
//
|
|
if (params.webui == COMMON_WEBUI_NONE) {
|
|
LLAMA_LOG_INFO("Web UI is disabled\n");
|
|
}
|
|
else {
|
|
// register static assets routes
|
|
if (!params.public_path.empty()) {
|
|
// Set the base directory for serving static files
|
|
svr->set_base_dir(params.public_path);
|
|
}
|
|
|
|
{
|
|
// register static assets routes
|
|
if (!params.public_path.empty()) {
|
|
// Set the base directory for serving static files
|
|
bool is_found = svr->set_mount_point("/", params.public_path);
|
|
if (!is_found) {
|
|
GGML_ABORT("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
|
|
return 1;
|
|
}
|
|
}
|
|
else {
|
|
|
|
// using embedded static index.html
|
|
svr->Get("/", [params](const httplib::Request& req, httplib::Response& res) {
|
|
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
|
res.set_content("Error: gzip is not supported by this browser", "text/plain");
|
|
}
|
|
else {
|
|
res.set_header("Content-Encoding", "gzip");
|
|
// COEP and COOP headers, required by pyodide (python interpreter)
|
|
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
|
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
|
|
if (params.webui == COMMON_WEBUI_AUTO) {
|
|
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
|
|
}
|
|
else if (params.webui == COMMON_WEBUI_LLAMACPP) {
|
|
res.set_content(reinterpret_cast<const char*>(index_llamacpp_html_gz), index_llamacpp_html_gz_len, "text/html; charset=utf-8");
|
|
}
|
|
else {
|
|
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
|
|
}
|
|
}
|
|
return false;
|
|
});
|
|
}
|
|
}
|
|
}
|
|
// register API routes
|
|
svr->Get ("/health", handle_health);
|
|
svr->Get ("/metrics", handle_metrics);
|
|
svr->Get ("/props", handle_props);
|
|
svr->Get ("/v1/models", handle_models);
|
|
svr->Post("/completion", handle_completions); // legacy
|
|
svr->Post("/completions", handle_completions); // legacy
|
|
svr->Post("/v1/completions", handle_completions_oai);
|
|
svr->Post("/chat/completions", handle_chat_completions);
|
|
svr->Post("/v1/chat/completions", handle_chat_completions);
|
|
svr->Post("/infill", handle_infill);
|
|
svr->Post("/embedding", handle_embeddings); // legacy
|
|
svr->Post("/embeddings", handle_embeddings);
|
|
svr->Post("/v1/embeddings", handle_embeddings);
|
|
svr->Post("/tokenize", handle_tokenize);
|
|
svr->Post("/detokenize", handle_detokenize);
|
|
// LoRA adapters hotswap
|
|
svr->Get ("/lora-adapters", handle_lora_adapters_list);
|
|
svr->Post("/lora-adapters", handle_lora_adapters_apply);
|
|
// Save & load slots
|
|
svr->Get ("/slots", handle_slots);
|
|
svr->Get ("/slots/list", list_slot_prompts);
|
|
if (!params.slot_save_path.empty()) {
|
|
// these endpoints rely on slot_save_path existing
|
|
svr->Post("/slots/:id_slot", handle_slots_action);
|
|
svr->Get ("/list", list_saved_prompts);
|
|
svr->Post("/delete_prompt", delete_saved_prompt);
|
|
svr->Post("/rename_prompt", rename_saved_prompt);
|
|
|
|
}
|
|
|
|
svr->Get ("/version", handle_version);
|
|
if (!params.sql_save_file.empty()) {
|
|
// these endpoints rely on sql_save_file existing
|
|
svr->Post("/load", handle_load);
|
|
svr->Post("/save", handle_save);
|
|
svr->Post("/rename", handle_rename);
|
|
svr->Post("/all", handle_all);
|
|
svr->Post("/sessions", handle_sessions);
|
|
svr->Get ("/sessions", handle_sessions);
|
|
svr->Post("/delete", handle_delete);
|
|
//VACUUM is there for the extension but does not require the extension
|
|
svr->Get ("/vacuum", handle_vacuum);
|
|
#ifdef SQLITE3_MODERN_CPP_SUPPORT
|
|
if (sqlite_extension_loaded) {
|
|
svr->Get ("/zstd_get_configs", handle_zstd_get_configs);
|
|
svr->Post("/zstd_incremental_maintenance", handle_zstd_maintenance);
|
|
svr->Post("/zstd_enable_transparent", handle_zstd_enable);
|
|
svr->Post("/zstd_update_transparent", handle_zstd_config_update);
|
|
}
|
|
#endif
|
|
}
|
|
//
|
|
// Start the server
|
|
//
|
|
if (params.n_threads_http < 1) {
|
|
// +2 threads for monitoring endpoints
|
|
params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
|
|
}
|
|
log_data["n_threads_http"] = std::to_string(params.n_threads_http);
|
|
svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); };
|
|
|
|
LOG_INFO("HTTP server listening", log_data);
|
|
|
|
// run the HTTP server in a thread - see comment below
|
|
std::thread t([&]() {
|
|
if (!svr->listen_after_bind()) {
|
|
state.store(SERVER_STATE_ERROR);
|
|
return 1;
|
|
}
|
|
|
|
return 0;
|
|
});
|
|
|
|
ctx_server.queue_tasks.on_new_task(std::bind(
|
|
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
|
ctx_server.queue_tasks.on_finish_multitask(std::bind(
|
|
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
|
|
ctx_server.queue_tasks.on_update_slots(std::bind(
|
|
&server_context::update_slots, &ctx_server));
|
|
ctx_server.queue_results.on_multitask_update(std::bind(
|
|
&server_queue::update_multitask,
|
|
&ctx_server.queue_tasks,
|
|
std::placeholders::_1,
|
|
std::placeholders::_2,
|
|
std::placeholders::_3
|
|
));
|
|
|
|
shutdown_handler = [&](int) {
|
|
ctx_server.queue_tasks.terminate();
|
|
};
|
|
|
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
|
struct sigaction sigint_action;
|
|
sigint_action.sa_handler = signal_handler;
|
|
sigemptyset (&sigint_action.sa_mask);
|
|
sigint_action.sa_flags = 0;
|
|
sigaction(SIGINT, &sigint_action, NULL);
|
|
sigaction(SIGTERM, &sigint_action, NULL);
|
|
#elif defined (_WIN32)
|
|
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
|
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
|
};
|
|
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
|
#endif
|
|
|
|
ctx_server.queue_tasks.start_loop();
|
|
|
|
svr->stop();
|
|
t.join();
|
|
|
|
llama_backend_free();
|
|
|
|
return 0;
|
|
}
|