From 2e4bfed5832c132f3bed7ea5f2fdade21c3f32d8 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 20 Nov 2025 13:30:43 +0000 Subject: [PATCH] WIP: try syncing - not working yet --- include/llama.h | 382 +++--- src/llama-context.h | 3 +- src/llama-grammar.cpp | 666 +++++---- src/llama-grammar.h | 169 ++- src/llama-impl.h | 120 +- src/llama-sampling.cpp | 2925 ++++++++++++++++++++++++++++++---------- src/llama-sampling.h | 104 +- src/llama.cpp | 226 +--- 8 files changed, 2887 insertions(+), 1708 deletions(-) diff --git a/include/llama.h b/include/llama.h index 7682951e..f6609543 100644 --- a/include/llama.h +++ b/include/llama.h @@ -280,9 +280,12 @@ extern "C" { } llama_token_data; typedef struct llama_token_data_array { + // TODO: consider SoA + // NOTE: this pointer can be modified by the samplers llama_token_data * data; size_t size; - bool sorted; + int64_t selected; // this is the index in the data array (i.e. not the token id) + bool sorted; // note: do not assume the data is sorted - always check this flag } llama_token_data_array; typedef bool (*llama_progress_callback)(float progress, void * user_data); @@ -471,43 +474,6 @@ extern "C" { void * repack_pattern; // pointer to a vector containing regexes to be used for matching tensor names. Can be null } llama_model_quantize_params; - // grammar types - struct llama_grammar; - - // grammar element type - enum llama_gretype { - // end of rule definition - LLAMA_GRETYPE_END = 0, - - // start of alternate definition for rule - LLAMA_GRETYPE_ALT = 1, - - // non-terminal element: reference to rule - LLAMA_GRETYPE_RULE_REF = 2, - - // terminal element: character (code point) - LLAMA_GRETYPE_CHAR = 3, - - // inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_NOT = 4, - - // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to - // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - - // modifies a preceding LLAMA_GRETYPE_CHAR or - // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 6, - - // any character (.) - LLAMA_GRETYPE_CHAR_ANY = 7, - }; - - typedef struct llama_grammar_element { - enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID - } llama_grammar_element; - // performance timing information struct llama_timings { double t_start_ms; @@ -531,10 +497,15 @@ extern "C" { // lora adapter struct llama_lora_adapter; + typedef struct llama_sampler_chain_params { + bool no_perf; // whether to measure performance timings + } llama_sampler_chain_params; + // Helpers for getting default parameters LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); + LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); // Initialize the llama + ggml backend // If numa is true, use NUMA optimizations @@ -1151,65 +1122,196 @@ extern "C" { typedef void* llama_sampler_context_t; - struct llama_sampler; - // user code can implement the interface below in order to create custom llama_sampler struct llama_sampler_i { - const char* (*name) (const struct llama_sampler*); // can be NULL - void (*accept)(struct llama_sampler*, llama_token); // can be NULL - void (*apply) (struct llama_sampler*, llama_token_data_array*); // required - void (*reset) (struct llama_sampler*); // can be NULL - struct llama_sampler* (*clone) (const struct llama_sampler*); // can be NULL if ctx is NULL - void (*free) (struct llama_sampler* smpl); // can be NULL if ctx is NULL + const char * (*name) (const struct llama_sampler * smpl); // can be NULL + void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL + void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required + void (*reset) ( struct llama_sampler * smpl); // can be NULL + struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL + void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL + + // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph + //void (*apply_ggml) (struct llama_sampler * smpl, ...); }; struct llama_sampler { - struct llama_sampler_i* iface; + const struct llama_sampler_i* iface; llama_sampler_context_t ctx; }; + typedef struct llama_logit_bias { + llama_token token; + float bias; + } llama_logit_bias; + + // mirror of llama_sampler_i: + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); + LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); + // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + // llama_sampler_chain + // a type of llama_sampler that can chain multiple samplers one after another + + LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); + + // important: takes ownership of the sampler object and will free it when llama_sampler_free is called + LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); + + // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed + LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); + + // available samplers: + + LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); + LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + /// Setting k <= 0 makes this a noop + LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); + + /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 + LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); + + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); + + /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf + LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); + + /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. + LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent); + + /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); + + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( + int32_t n_vocab, + uint32_t seed, + float tau, + float eta, + int32_t m); + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( + uint32_t seed, + float tau, + float eta); + + /// @details Intializes a GBNF grammar, see grammars/README.md for details. + /// @param vocab The vocabulary that this grammar will be used with. + /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. + /// @param grammar_root The name of the start symbol for the grammar. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root); + + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens), + "use llama_sampler_init_grammar_lazy_patterns instead"); + + + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 + /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + + + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. + LLAMA_API struct llama_sampler * llama_sampler_init_penalties( + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present); // 0.0 = disabled + + /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 + LLAMA_API struct llama_sampler * llama_sampler_init_dry( + const struct llama_vocab * vocab, + int32_t n_ctx_train, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( + int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); + + // this sampler is meant to be used for fill-in-the-middle infilling + // it's supposed to be used after top_k + top_p sampling // - // Grammar + // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG + // 2. combine probs of tokens that have the same prefix // - - /// Initialize a llama_grammar. - /// - /// @param rules The rule elements of the grammar to initialize. - /// @param n_rules The number of rules. - /// @param start_rule_index The index of the root rule (the starting point of the grammar). - /// @return The initialized llama_grammar or nullptr if initialization failed. - LLAMA_API struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); - - struct llama_sampler_grammar; - LLAMA_API void llama_grammar_init_lazy(struct llama_sampler_grammar * grammar); - - LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); - - LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); - - /// @details Apply constraints from grammar - LLAMA_API void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates); - LLAMA_API DEPRECATED(void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar), - "use llama_grammar_sample instead"); - - /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token); - + // example: // - // Sampling functions + // - before: + // "hel": 0.5 + // "hell": 0.2 + // "hello": 0.1 + // "dummy": 0.1 // + // - after: + // "hel": 0.8 + // "dummy": 0.1 + // + // 3. discard non-EOG tokens with low prob + // 4. if no tokens are left -> pick EOT + // + LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); + + // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise + LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); + + /// @details Sample and accept a token from the idx-th output of the last evaluation + // + // Shorthand for: + // const auto * logits = llama_get_logits_ith(ctx, idx); + // llama_token_data_array cur_p = { ... init from logits ... }; + // llama_sampler_apply(smpl, &cur_p); + // auto token = cur_p.data[cur_p.selected].id; + // llama_sampler_accept(smpl, token); + // return token; + // Returns the sampled token + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); // Sets the current rng seed. LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); @@ -1305,47 +1407,6 @@ extern "C" { LLAMA_API void llama_sampler_reset(struct llama_sampler* smpl); -LLAMA_API struct llama_grammar* llama_sampler_init_grammar( - const struct llama_vocab* vocab, - const char* grammar_str, - - const char* grammar_root); - /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 -/// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. -/// @param trigger_tokens A list of tokens that will trigger the grammar sampler. -DEPRECATED(LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - const char** trigger_words, - size_t num_trigger_words, - const llama_token* trigger_tokens, - size_t num_trigger_tokens), - "use llama_sampler_init_grammar_lazy_patterns instead"); - - -/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 -/// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group. -/// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. -LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - const char** trigger_patterns, - size_t num_trigger_patterns, - const llama_token* trigger_tokens, - size_t num_trigger_tokens); - - /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 - LLAMA_API struct llama_sampler_dry * llama_sampler_init_dry( - const struct llama_vocab* model, - float dry_multiplier, - float dry_base, - int32_t dry_allowed_length, - int32_t dry_penalty_last_n, - const char** seq_breakers, - size_t num_breakers); - //LLAMA_API void llama_sample_dry(struct llama_context* ctx, llama_token_data_array* candidates_p, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers); void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p); @@ -1427,6 +1488,39 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); + // + // Performance utils + // + // NOTE: Used by llama.cpp examples/tools, avoid using in third-party apps. Instead, do your own performance measurements. + // + + struct llama_perf_context_data { + // ms == milliseconds + double t_start_ms; // absolute start time + double t_load_ms; // time needed for loading the model + double t_p_eval_ms; // time needed for processing the prompt + double t_eval_ms; // time needed for generating tokens + + int32_t n_p_eval; // number of prompt tokens + int32_t n_eval; // number of generated tokens + int32_t n_reused; // number of times a ggml compute graph had been reused + }; + + struct llama_perf_sampler_data { + double t_sample_ms; // time needed for sampling in ms + + int32_t n_sample; // number of sampled tokens + }; + + LLAMA_API struct llama_perf_context_data llama_perf_context (const struct llama_context * ctx); + LLAMA_API void llama_perf_context_print(const struct llama_context * ctx); + LLAMA_API void llama_perf_context_reset( struct llama_context * ctx); + + // NOTE: the following work only with samplers constructed via llama_sampler_chain_init + LLAMA_API struct llama_perf_sampler_data llama_perf_sampler (const struct llama_sampler * chain); + LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); + LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); + #ifdef __cplusplus } #endif @@ -1444,42 +1538,6 @@ const std::vector> & llama_internal struct llama_context * ctx ); -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -using llama_grammar_rule = std::vector< llama_grammar_element>; -using llama_grammar_stack = std::vector; - -using llama_grammar_rules = std::vector; -using llama_grammar_stacks = std::vector; -using llama_grammar_candidates = std::vector; - -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks); - -std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates); - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start); - // Randomly selects a token from the candidates based on their probabilities using given std::mt19937. // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); diff --git a/src/llama-context.h b/src/llama-context.h index bb21d880..f0147892 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -10,6 +10,7 @@ struct llama_model; #include #include #include +#include struct llama_kv_cell { llama_pos pos = -1; @@ -120,7 +121,7 @@ struct llama_context { const struct llama_model & model; struct llama_cparams cparams; - struct llama_sampling sampling; + //struct llama_sampling sampling; struct llama_kv_cache kv_self; struct llama_control_vector cvec; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 12f20dae..bed706bb 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1,5 +1,6 @@ #include "llama-grammar.h" -#include "llama-vocab.h" + +#include "llama-impl.h" #include "llama-vocab.h" #include "llama-sampling.h" @@ -7,25 +8,27 @@ #include #include +// +// helpers +// + // NOTE: assumes valid utf8 (but checks for overrun) -static std::pair decode_utf8(const char* src) { +static std::pair decode_utf8(const char * src) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char* end = src + len; // may overrun! - const char* pos = src + 1; - for (; pos < end && *pos; pos++) { + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { value = (value << 6) + (static_cast(*pos) & 0x3F); } return std::make_pair(value, pos); } -// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as -// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( +static std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; @@ -58,7 +61,7 @@ std::pair, llama_partial_utf8> decode_utf8( while (*pos != 0) { uint8_t first_byte = static_cast(*pos); uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; + n_remain = lookup[highbits] - 1; if (n_remain < 0) { // invalid sequence, abort @@ -68,7 +71,7 @@ std::pair, llama_partial_utf8> decode_utf8( } uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; + value = first_byte & mask; ++pos; while (*pos != 0 && n_remain > 0) { @@ -85,7 +88,6 @@ std::pair, llama_partial_utf8> decode_utf8( return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } - static bool is_digit_char(char c) { return '0' <= c && c <= '9'; } @@ -94,23 +96,20 @@ static bool is_word_char(char c) { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); } -static std::pair parse_hex(const char* src, int size) { - const char* pos = src; - const char* end = src + size; +static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; uint32_t value = 0; - for (; pos < end && *pos; pos++) { + for ( ; pos < end && *pos; pos++) { value <<= 4; char c = *pos; if ('a' <= c && c <= 'f') { value += c - 'a' + 10; - } - else if ('A' <= c && c <= 'F') { + } else if ('A' <= c && c <= 'F') { value += c - 'A' + 10; - } - else if ('0' <= c && c <= '9') { + } else if ('0' <= c && c <= '9') { value += c - '0'; - } - else { + } else { break; } } @@ -120,24 +119,23 @@ static std::pair parse_hex(const char* src, int size) { return std::make_pair(value, pos); } -static const char* parse_space(const char* src, bool newline_ok) { - const char* pos = src; +static const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; while (*pos == ' ' || *pos == '\t' || *pos == '#' || - (newline_ok && (*pos == '\r' || *pos == '\n'))) { + (newline_ok && (*pos == '\r' || *pos == '\n'))) { if (*pos == '#') { while (*pos && *pos != '\r' && *pos != '\n') { pos++; } - } - else { + } else { pos++; } } return pos; } -static const char* parse_name(const char* src) { - const char* pos = src; +static const char * parse_name(const char * src) { + const char * pos = src; while (is_word_char(*pos)) { pos++; } @@ -147,8 +145,8 @@ static const char* parse_name(const char* src) { return pos; } -static const char* parse_int(const char* src) { - const char* pos = src; +static const char * parse_int(const char * src) { + const char * pos = src; while (is_digit_char(*pos)) { pos++; } @@ -158,35 +156,33 @@ static const char* parse_int(const char* src) { return pos; } -static std::pair parse_char(const char* src) { +static std::pair parse_char(const char * src) { if (*src == '\\') { switch (src[1]) { - case 'x': return parse_hex(src + 2, 2); - case 'u': return parse_hex(src + 2, 4); - case 'U': return parse_hex(src + 2, 8); - case 't': return std::make_pair('\t', src + 2); - case 'r': return std::make_pair('\r', src + 2); - case 'n': return std::make_pair('\n', src + 2); - case '\\': - case '"': - case '[': - case ']': - return std::make_pair(src[1], src + 2); - default: - throw std::runtime_error(std::string("unknown escape at ") + src); + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); } - } - else if (*src) { + } else if (*src) { return decode_utf8(src); } throw std::runtime_error("unexpected end of input"); } -static void print_grammar_char(FILE* file, uint32_t c) { +static void print_grammar_char(FILE * file, uint32_t c) { if (0x20 <= c && c <= 0x7f) { fprintf(file, "%c", static_cast(c)); - } - else { + } else { // cop out of encoding UTF-8 fprintf(file, "", c); } @@ -194,52 +190,52 @@ static void print_grammar_char(FILE* file, uint32_t c) { static bool is_char_element(llama_grammar_element elem) { switch (elem.type) { - case LLAMA_GRETYPE_CHAR: return true; - case LLAMA_GRETYPE_CHAR_NOT: return true; - case LLAMA_GRETYPE_CHAR_ALT: return true; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; - case LLAMA_GRETYPE_CHAR_ANY: return true; - default: return false; + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; + default: return false; } } -static void print_rule_binary(FILE* file, const llama_grammar_rule& rule) { +static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { for (auto elem : rule) { switch (elem.type) { - case LLAMA_GRETYPE_END: fprintf(file, "END"); break; - case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; - case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; - case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; - case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; - case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; } switch (elem.type) { - case LLAMA_GRETYPE_END: - case LLAMA_GRETYPE_ALT: - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "(%u) ", elem.value); - break; - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "(\""); - print_grammar_char(file, elem.value); - fprintf(file, "\") "); - break; + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; } } fprintf(file, "\n"); } static void print_rule( - FILE* file, - uint32_t rule_id, - const llama_grammar_rule& rule, - const std::map& symbol_id_names) { + FILE * file, + uint32_t rule_id, + const llama_grammar_rule & rule, + const std::map & symbol_id_names) { if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { throw std::runtime_error( "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); @@ -248,53 +244,53 @@ static void print_rule( for (size_t i = 0, end = rule.size() - 1; i < end; i++) { llama_grammar_element elem = rule[i]; switch (elem.type) { - case LLAMA_GRETYPE_END: - throw std::runtime_error( - "unexpected end of rule: " + std::to_string(rule_id) + "," + - std::to_string(i)); - case LLAMA_GRETYPE_ALT: - fprintf(file, "| "); - break; - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - break; - case LLAMA_GRETYPE_CHAR: - fprintf(file, "["); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_NOT: - fprintf(file, "[^"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - if (i == 0 || !is_char_element(rule[i - 1])) { + case LLAMA_GRETYPE_END: throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - fprintf(file, "-"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ALT: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "."); - break; + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ANY: - break; - default: - fprintf(file, "] "); + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: + break; + default: + fprintf(file, "] "); } } } @@ -305,49 +301,49 @@ static void print_rule( // implementation // -uint32_t llama_grammar_parser::get_symbol_id(const char* src, size_t len) { +uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) { uint32_t next_id = static_cast(symbol_ids.size()); auto result = symbol_ids.emplace(std::string(src, len), next_id); return result.first->second; } -uint32_t llama_grammar_parser::generate_symbol_id(const std::string& base_name) { +uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) { uint32_t next_id = static_cast(symbol_ids.size()); symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; return next_id; } -void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule& rule) { +void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) { if (rules.size() <= rule_id) { rules.resize(rule_id + 1); } rules[rule_id] = rule; } -const char* llama_grammar_parser::parse_alternates( - const char* src, - const std::string& rule_name, - uint32_t rule_id, - bool is_nested) { +const char * llama_grammar_parser::parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { llama_grammar_rule rule; - const char* pos = parse_sequence(src, rule_name, rule, is_nested); + const char * pos = parse_sequence(src, rule_name, rule, is_nested); while (*pos == '|') { - rule.push_back({ LLAMA_GRETYPE_ALT, 0 }); + rule.push_back({LLAMA_GRETYPE_ALT, 0}); pos = parse_space(pos + 1, true); pos = parse_sequence(pos, rule_name, rule, is_nested); } - rule.push_back({ LLAMA_GRETYPE_END, 0 }); + rule.push_back({LLAMA_GRETYPE_END, 0}); add_rule(rule_id, rule); return pos; } -const char* llama_grammar_parser::parse_sequence( - const char* src, - const std::string& rule_name, - llama_grammar_rule& rule, - bool is_nested) { +const char * llama_grammar_parser::parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested) { size_t last_sym_start = rule.size(); - const char* pos = src; + const char * pos = src; auto handle_repetitions = [&](int min_times, int max_times) { @@ -375,8 +371,7 @@ const char* llama_grammar_parser::parse_sequence( llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); if (min_times == 0) { rule.resize(last_sym_start); - } - else { + } else { // Repeat the previous elements (min_times - 1) times for (int i = 1; i < min_times; i++) { rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); @@ -389,17 +384,17 @@ const char* llama_grammar_parser::parse_sequence( llama_grammar_rule rec_rule(prev_rule); for (int i = 0; i < n_opt; i++) { rec_rule.resize(prev_rule.size()); - uint32_t rec_rule_id = generate_symbol_id(rule_name); + uint32_t rec_rule_id = generate_symbol_id( rule_name); if (i > 0 || max_times < 0) { - rec_rule.push_back({ LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id }); + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); } - rec_rule.push_back({ LLAMA_GRETYPE_ALT, 0 }); - rec_rule.push_back({ LLAMA_GRETYPE_END, 0 }); - add_rule(rec_rule_id, rec_rule); + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); last_rec_rule_id = rec_rule_id; } if (n_opt > 0) { - rule.push_back({ LLAMA_GRETYPE_RULE_REF, last_rec_rule_id }); + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); } }; @@ -412,12 +407,11 @@ const char* llama_grammar_parser::parse_sequence( throw std::runtime_error("unexpected end of input"); } auto char_pair = parse_char(pos); - pos = char_pair.second; - rule.push_back({ LLAMA_GRETYPE_CHAR, char_pair.first }); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); } pos = parse_space(pos + 1, is_nested); - } - else if (*pos == '[') { // char range(s) + } else if (*pos == '[') { // char range(s) pos++; enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; if (*pos == '^') { @@ -430,67 +424,60 @@ const char* llama_grammar_parser::parse_sequence( throw std::runtime_error("unexpected end of input"); } auto char_pair = parse_char(pos); - pos = char_pair.second; + pos = char_pair.second; enum llama_gretype type = last_sym_start < rule.size() ? LLAMA_GRETYPE_CHAR_ALT : start_type; - rule.push_back({ type, char_pair.first }); + rule.push_back({type, char_pair.first}); if (pos[0] == '-' && pos[1] != ']') { if (!pos[1]) { throw std::runtime_error("unexpected end of input"); } auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - rule.push_back({ LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first }); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); } } pos = parse_space(pos + 1, is_nested); - } - else if (is_word_char(*pos)) { // rule reference - const char* name_end = parse_name(pos); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); - rule.push_back({ LLAMA_GRETYPE_RULE_REF, ref_rule_id }); - } - else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); uint32_t sub_rule_id = generate_symbol_id(rule_name); pos = parse_alternates(pos, rule_name, sub_rule_id, true); last_sym_start = rule.size(); // output reference to synthesized rule - rule.push_back({ LLAMA_GRETYPE_RULE_REF, sub_rule_id }); + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); if (*pos != ')') { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); - } - else if (*pos == '.') { // any char + } else if (*pos == '.') { // any char last_sym_start = rule.size(); - rule.push_back({ LLAMA_GRETYPE_CHAR_ANY, 0 }); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); - } - else if (*pos == '*') { + } else if (*pos == '*') { pos = parse_space(pos + 1, is_nested); handle_repetitions(0, -1); - } - else if (*pos == '+') { + } else if (*pos == '+') { pos = parse_space(pos + 1, is_nested); handle_repetitions(1, -1); - } - else if (*pos == '?') { + } else if (*pos == '?') { pos = parse_space(pos + 1, is_nested); handle_repetitions(0, 1); - } - else if (*pos == '{') { + } else if (*pos == '{') { pos = parse_space(pos + 1, is_nested); if (!is_digit_char(*pos)) { throw std::runtime_error(std::string("expecting an int at ") + pos); } - const char* int_end = parse_int(pos); + const char * int_end = parse_int(pos); int min_times = std::stoul(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); @@ -499,12 +486,11 @@ const char* llama_grammar_parser::parse_sequence( if (*pos == '}') { max_times = min_times; pos = parse_space(pos + 1, is_nested); - } - else if (*pos == ',') { + } else if (*pos == ',') { pos = parse_space(pos + 1, is_nested); if (is_digit_char(*pos)) { - const char* int_end = parse_int(pos); + const char * int_end = parse_int(pos); max_times = std::stoul(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); } @@ -513,24 +499,22 @@ const char* llama_grammar_parser::parse_sequence( throw std::runtime_error(std::string("expecting '}' at ") + pos); } pos = parse_space(pos + 1, is_nested); - } - else { + } else { throw std::runtime_error(std::string("expecting ',' at ") + pos); } handle_repetitions(min_times, max_times); - } - else { + } else { break; } } return pos; } -const char* llama_grammar_parser::parse_rule(const char* src) { - const char* name_end = parse_name(src); - const char* pos = parse_space(name_end, false); +const char * llama_grammar_parser::parse_rule(const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(src, name_len); + uint32_t rule_id = get_symbol_id(src, name_len); const std::string name(src, name_len); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { @@ -542,33 +526,31 @@ const char* llama_grammar_parser::parse_rule(const char* src) { if (*pos == '\r') { pos += pos[1] == '\n' ? 2 : 1; - } - else if (*pos == '\n') { + } else if (*pos == '\n') { pos++; - } - else if (*pos) { + } else if (*pos) { throw std::runtime_error(std::string("expecting newline or end at ") + pos); } return parse_space(pos, true); } -bool llama_grammar_parser::parse(const char* src) { +bool llama_grammar_parser::parse(const char * src) { try { - const char* pos = parse_space(src, true); + const char * pos = parse_space(src, true); while (*pos) { pos = parse_rule(pos); } // Validate the state to ensure that all rules are defined - for (const auto& rule : rules) { + for (const auto & rule : rules) { if (rule.empty()) { throw std::runtime_error("Undefined rule"); } - for (const auto& elem : rule) { + for (const auto & elem : rule) { if (elem.type == LLAMA_GRETYPE_RULE_REF) { // Ensure that the rule at that location exists if (elem.value >= rules.size() || rules[elem.value].empty()) { // Get the name of the rule that is missing - for (const auto& kv : symbol_ids) { + for (const auto & kv : symbol_ids) { if (kv.second == elem.value) { throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); } @@ -577,8 +559,7 @@ bool llama_grammar_parser::parse(const char* src) { } } } - } - catch (const std::exception& err) { + } catch (const std::exception & err) { fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); rules.clear(); return false; @@ -587,10 +568,10 @@ bool llama_grammar_parser::parse(const char* src) { return true; } -void llama_grammar_parser::print(FILE* file) { +void llama_grammar_parser::print(FILE * file) { try { std::map symbol_id_names; - for (const auto& kv : symbol_ids) { + for (const auto & kv : symbol_ids) { symbol_id_names[kv.second] = kv.first; } for (size_t i = 0, end = rules.size(); i < end; i++) { @@ -599,8 +580,7 @@ void llama_grammar_parser::print(FILE* file) { print_rule(file, uint32_t(i), rules[i], symbol_id_names); // fprintf(file, "\n"); } - } - catch (const std::exception& err) { + } catch (const std::exception & err) { fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); } } @@ -608,7 +588,7 @@ void llama_grammar_parser::print(FILE* file) { llama_grammar_stack llama_grammar_parser::c_rules() const { llama_grammar_stack ret; ret.reserve(rules.size()); - for (const auto& rule : rules) { + for (const auto & rule : rules) { ret.push_back(rule.data()); } return ret; @@ -638,13 +618,11 @@ static std::pair llama_grammar_match_char( // inclusive range, e.g. [a-z] found = found || (pos->value <= chr && chr <= pos[1].value); pos += 2; - } - else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { // Any character matches "." found = true; pos += 1; - } - else { + } else { // exact char match, e.g. [a] or "a" found = found || pos->value == chr; pos += 1; @@ -678,8 +656,7 @@ static bool llama_grammar_match_partial_char( if (low == 0) { if (n_remain == 2) { low = 1 << 11; - } - else if (n_remain == 3) { + } else if (n_remain == 3) { low = 1 << 16; } } @@ -691,12 +668,10 @@ static bool llama_grammar_match_partial_char( return is_positive_char; } pos += 2; - } - else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { // Any character matches "." return true; - } - else { + } else { // exact char match, e.g. [a] or "a" if (low <= pos->value && pos->value <= high) { return is_positive_char; @@ -746,8 +721,7 @@ static void llama_grammar_advance_stack( if (subpos->type == LLAMA_GRETYPE_ALT) { // there's another alternate def of this rule to process subpos++; - } - else { + } else { break; } } while (true); @@ -770,9 +744,9 @@ static void llama_grammar_advance_stack( } static llama_grammar_candidates llama_grammar_reject_candidates( - const llama_grammar_rules& rules, - const llama_grammar_stacks& stacks, - const llama_grammar_candidates& candidates) { + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const llama_grammar_candidates & candidates) { GGML_ASSERT(!stacks.empty()); // REVIEW if (candidates.empty()) { @@ -789,18 +763,18 @@ static llama_grammar_candidates llama_grammar_reject_candidates( } static bool llama_grammar_detect_left_recursion( - const llama_grammar_rules& rules, - size_t rule_index, - std::vector* rules_visited, - std::vector* rules_in_progress, - std::vector* rules_may_be_empty) { + const llama_grammar_rules & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { if ((*rules_in_progress)[rule_index]) { return true; } (*rules_in_progress)[rule_index] = true; - const llama_grammar_rule& rule = rules[rule_index]; + const llama_grammar_rule & rule = rules[rule_index]; // First check if the rule might produce the empty string. This could be done combined with the second // step but it's more readable as two steps. @@ -812,8 +786,7 @@ static bool llama_grammar_detect_left_recursion( break; } at_rule_start = true; - } - else { + } else { at_rule_start = false; } } @@ -829,11 +802,9 @@ static bool llama_grammar_detect_left_recursion( if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { recurse_into_nonterminal = false; } - } - else if (llama_grammar_is_end_of_sequence(&rule[i])) { + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { recurse_into_nonterminal = true; - } - else { + } else { recurse_into_nonterminal = false; } } @@ -852,19 +823,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar->stacks.size()); -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks) { - new_stacks.clear(); - - for (const auto & stack : stacks) { + for (const auto & stack : grammar->stacks) { if (stack.empty()) { continue; } @@ -878,11 +841,12 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); } } -} + grammar->stacks = std::move(stacks_new); +} llama_grammar_candidates llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, @@ -939,29 +903,23 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( return rejects; } -// -// implementation -// +//////////////////// -// -// grammar - external -// - - -struct llama_grammar* llama_grammar_init_impl( - const llama_grammar_element** rules, - size_t n_rules, - size_t start_rule_index) { - const llama_grammar_element* pos; +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; // copy rule definitions into vectors llama_grammar_rules vec_rules(n_rules); for (size_t i = 0; i < n_rules; i++) { for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { vec_rules[i].push_back(*pos); - } - vec_rules[i].push_back({ LLAMA_GRETYPE_END, 0 }); } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } // Check for left recursion std::vector rules_visited(n_rules); @@ -970,7 +928,7 @@ struct llama_grammar* llama_grammar_init_impl( for (size_t i = 0; i < n_rules; i++) { if (rules_visited[i]) { continue; - } + } if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); return nullptr; @@ -985,26 +943,25 @@ struct llama_grammar* llama_grammar_init_impl( if (!llama_grammar_is_end_of_sequence(pos)) { // if alternate is nonempty, add to stack stack.push_back(pos); - } + } llama_grammar_advance_stack(vec_rules, stack, stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; - } + } if (pos->type == LLAMA_GRETYPE_ALT) { // there's another alternate def of this rule to process pos++; - } - else { + } else { break; - } + } } while (true); // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ - NULL, + return new llama_grammar { + vocab, std::move(vec_rules), std::move(stacks), /* .partial_utf8 = */ {}, @@ -1016,15 +973,15 @@ struct llama_grammar* llama_grammar_init_impl( }; } -struct llama_grammar* llama_grammar_init_impl( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - bool lazy, - const char** trigger_patterns, - size_t num_trigger_patterns, - const llama_token* trigger_tokens, - size_t num_trigger_tokens) { +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it @@ -1040,7 +997,7 @@ struct llama_grammar* llama_grammar_init_impl( return nullptr; } - std::vector grammar_rules(parser.c_rules()); + std::vector grammar_rules(parser.c_rules()); const size_t n_rules = grammar_rules.size(); const size_t start_rule_index = parser.symbol_ids.at(grammar_root); @@ -1087,8 +1044,7 @@ struct llama_grammar* llama_grammar_init_impl( if (pos->type == LLAMA_GRETYPE_ALT) { // there's another alternate def of this rule to process pos++; - } - else { + } else { break; } } while (true); @@ -1101,16 +1057,15 @@ struct llama_grammar* llama_grammar_init_impl( } for (size_t i = 0; i < num_trigger_patterns; i++) { GGML_ASSERT(trigger_patterns != nullptr); - auto& trigger = vec_trigger_patterns.emplace_back(); + auto & trigger = vec_trigger_patterns.emplace_back(); trigger.pattern = trigger_patterns[i]; trigger.regex = std::regex(trigger.pattern); } - // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), @@ -1124,28 +1079,33 @@ struct llama_grammar* llama_grammar_init_impl( } void llama_grammar_free_impl(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } + delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { - auto* result = new llama_grammar{ - grammar->vocab, - grammar->rules, - grammar->stacks, - grammar->partial_utf8, - grammar->lazy, - grammar->awaiting_trigger, - grammar->trigger_buffer, - grammar->trigger_tokens, - grammar->trigger_patterns, +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { + auto * result = new llama_grammar { + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_tokens, + grammar.trigger_patterns, }; + // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1155,16 +1115,15 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram return result; } -void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(grammar); - GGML_ASSERT(vocab); - if (grammar->awaiting_trigger) { +void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { + GGML_ASSERT(grammar.vocab != nullptr); + + if (grammar.awaiting_trigger) { return; } - int64_t t_start_sample_us = ggml_time_us(); bool allow_eog = false; - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { allow_eog = true; break; @@ -1172,55 +1131,52 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc } std::vector, llama_partial_utf8>> candidates_decoded; - candidates_decoded.reserve(candidates->size); + candidates_decoded.reserve(cur_p->size); llama_grammar_candidates candidates_grammar; - candidates_grammar.reserve(candidates->size); + candidates_grammar.reserve(cur_p->size); - for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->token_to_piece(id); + for (size_t i = 0; i < cur_p->size; ++i) { + const llama_token id = cur_p->data[i].id; + const std::string & piece = grammar.vocab->token_to_piece(id); - if (vocab->is_eog(id)) { + if (grammar.vocab->is_eog(id)) { if (!allow_eog) { - candidates->data[i].logit = -INFINITY; + cur_p->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { - candidates->data[i].logit = -INFINITY; + cur_p->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { - candidates->data[reject.index].logit = -INFINITY; + cur_p->data[reject.index].logit = -INFINITY; } - if (!smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; -} } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { - const int64_t t_start_sample_us = ggml_time_us(); - GGML_ASSERT(grammar->vocab != nullptr); - const auto& piece = grammar->vocab->token_to_piece(token); +void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { + GGML_ASSERT(grammar.vocab != nullptr); - if (grammar->awaiting_trigger) { - if (std::find(grammar->trigger_tokens.begin(), grammar->trigger_tokens.end(), token) != grammar->trigger_tokens.end()) { - grammar->awaiting_trigger = false; - grammar->trigger_buffer.clear(); + const auto & piece = grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, piece); LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; - } else { - grammar->trigger_buffer += piece; + } else { + grammar.trigger_buffer += piece; std::smatch match; - for (const auto& trigger_pattern : grammar->trigger_patterns) { - if (std::regex_match(grammar->trigger_buffer, match, trigger_pattern.regex)) { - grammar->awaiting_trigger = false; + for (const auto & trigger_pattern : grammar.trigger_patterns) { + if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { + grammar.awaiting_trigger = false; // get from the first matched capturing group to the end of the string size_t start = std::string::npos; for (auto i = 1u; i < match.size(); i++) { @@ -1232,21 +1188,21 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc if (start == std::string::npos) { start = match.position(0); } - auto constrained_str = grammar->trigger_buffer.substr(start); + auto constrained_str = grammar.trigger_buffer.substr(start); // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); - grammar->trigger_buffer.clear(); + grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); - //LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); + LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); return; } } - //LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); return; } } - if (llama_token_is_eog(vocab, token)) { - for (const auto & stack : grammar->stacks) { + if (grammar.vocab->is_eog(token)) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } @@ -1255,33 +1211,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc } llama_grammar_accept_str(grammar, piece); - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_str(struct llama_grammar* grammar, const std::string& piece) { - +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks tmp_new_stacks; - for (auto it = code_points.begin(), end = code_points.end()-1; it != end; ++it) { - llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); - // avoid empty grammar stack at the end of the code_points - // mainline has this bug too, reason unknown - if (end == code_points.end() - 1) { - if (tmp_new_stacks.size()) { - grammar->stacks = tmp_new_stacks; - } - } - else { - grammar->stacks = tmp_new_stacks; - } + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_accept(&grammar, *it); } - grammar->partial_utf8 = decoded.second; - if (grammar->stacks.empty()) { + grammar.partial_utf8 = decoded.second; + if (grammar.stacks.empty()) { throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); } - } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 532f9d22..f8c291de 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -1,13 +1,80 @@ #pragma once -#include "llama-impl.h" +#include "llama.h" + #include #include #include #include struct llama_vocab; -struct llama_sampling; + +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID +} llama_grammar_element; + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +// TODO: remove, needed for tests atm +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates); struct llama_grammar_parser { std::map symbol_ids; @@ -16,27 +83,27 @@ struct llama_grammar_parser { llama_grammar_stack c_rules() const; - uint32_t get_symbol_id(const char* src, size_t len); - uint32_t generate_symbol_id(const std::string& base_name); + uint32_t get_symbol_id(const char * src, size_t len); + uint32_t generate_symbol_id(const std::string & base_name); - void add_rule(uint32_t rule_id, const llama_grammar_rule& rule); + void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); - const char* parse_alternates( - const char* src, - const std::string& rule_name, - uint32_t rule_id, - bool is_nested); + const char * parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); - const char* parse_sequence( - const char* src, - const std::string& rule_name, - llama_grammar_rule& rule, - bool is_nested); + const char * parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested); - const char* parse_rule(const char* src); + const char * parse_rule(const char * src); - bool parse(const char* src); - void print(FILE* file); + bool parse(const char * src); + void print(FILE * file); }; struct llama_grammar_trigger_pattern { @@ -46,63 +113,61 @@ struct llama_grammar_trigger_pattern { struct llama_grammar { // note: allow null vocab for testing (not great) - const llama_vocab* vocab; + const llama_vocab * vocab; const llama_grammar_rules rules; // TODO: shared ptr - llama_grammar_stacks stacks; + llama_grammar_stacks stacks; // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; // lazy grammars wait for trigger words or tokens before constraining the sampling. - // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. // (useful e.g. for tool_choice=required) - bool lazy = false; + bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). - std::vector trigger_patterns; - // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated - // string, and the grammar will be given the string from the first match group onwards. + std::vector + trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated + // string, and the grammar will be given the string from the first match group onwards. }; // // internal API // -// note: needed for tests (not great) -struct llama_grammar* llama_grammar_init_impl( - const llama_grammar_element** rules, - size_t n_rules, - size_t start_rule_index); -struct llama_grammar* llama_grammar_init_impl( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - bool lazy, - const char** trigger_patterns, - size_t num_trigger_patterns, - const llama_token* trigger_tokens, - size_t num_trigger_tokens); +// note: needed for tests (not great) +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar); -void llama_grammar_sample_impl( - const struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, - llama_token_data_array * candidates); +// TODO: move the API below as member functions of llama_grammar +void llama_grammar_apply_impl( + const struct llama_grammar & grammar, + llama_token_data_array * cur_p); -void llama_grammar_accept_token_impl( - struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +void llama_grammar_accept_impl( + struct llama_grammar & grammar, llama_token token); - void llama_grammar_accept_str( - struct llama_grammar* grammar, - const std::string& piece); + struct llama_grammar & grammar, + const std::string & piece); diff --git a/src/llama-impl.h b/src/llama-impl.h index 38844506..fc5bc30b 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -67,118 +67,6 @@ static void replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } - -// the ring buffer works similarly to std::deque, but with a fixed capacity -template -struct ring_buffer { - ring_buffer(size_t cap) : capacity(cap), data(cap) {} - - T& front() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[first]; - } - - const T& front() const { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[first]; - } - - T& back() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[pos]; - } - - const T& back() const { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[pos]; - } - - void push_back(const T& value) { - if (capacity == 0) { - throw std::runtime_error("ring buffer: capacity is zero"); - } - - if (sz == capacity) { - // advance the start when buffer is full - first = (first + 1) % capacity; - } - else { - sz++; - } - data[pos] = value; - pos = (pos + 1) % capacity; - } - - T pop_front() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - T value = data[first]; - first = (first + 1) % capacity; - sz--; - return value; - } - - //T & operator[](size_t i) { - // if (i >= sz) { - // throw std::runtime_error("ring buffer: index out of bounds"); - // } - // return data[(first + i) % capacity]; - //} - - //const T & at(size_t i) const { - // if (i >= sz) { - // throw std::runtime_error("ring buffer: index out of bounds"); - // } - // return data[(first + i) % capacity]; - //} - - const T& rat(size_t i) const { - if (i >= sz) { - throw std::runtime_error("ring buffer: index out of bounds"); - } - return data[(first + sz - i - 1) % capacity]; - } - - std::vector to_vector() const { - std::vector result; - result.reserve(sz); - for (size_t i = 0; i < sz; i++) { - result.push_back(data[(first + i) % capacity]); - } - return result; - } - - void clear() { - // here only reset the status of the buffer - sz = 0; - first = 0; - pos = 0; - } - - bool empty() const { - return sz == 0; - } - - size_t size() const { - return sz; - } - - size_t capacity = 0; - size_t sz = 0; - size_t first = 0; - size_t pos = 0; - std::vector data; -}; - LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) { va_list ap; @@ -219,6 +107,14 @@ struct no_init { no_init() { /* do nothing */ } }; +struct time_meas { + time_meas(int64_t & t_acc, bool disable = false); + ~time_meas(); + + const int64_t t_start_us; + + int64_t & t_acc; +}; struct gguf_context; std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 0d23e146..adb3f881 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1,14 +1,250 @@ #include "llama-sampling.h" + +#include "llama-impl.h" #include "llama-vocab.h" #include "llama-grammar.h" +#include #include +#include +#include +#include +#include +#include #include #include -#include #include +#include #include +#include +// the ring buffer works similarly to std::deque, but with a fixed capacity +template +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (capacity == 0) { + throw std::runtime_error("ring buffer: capacity is zero"); + } + + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + //T & operator[](size_t i) { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + //const T & at(size_t i) const { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + + std::vector data; +}; + +// writes result in res, does not mutate cur +static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector & res) { + static const auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); + constexpr float bucket_inter = -bucket_low * bucket_scale; + + std::vector bucket_idx; + std::vector histo(nbuckets, 0); + + std::vector bucket_ptrs; + + bucket_idx.reserve(cur.size); + + for (int i = 0; i < (int)cur.size; ++i) { + const float val = cur.data[i].logit; + int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets - 1, ib)); + bucket_idx.push_back(ib); + ++histo[ib]; + } + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + if (nhave >= npartial) { + break; + } + } + res.resize(nhave); + auto * ptr = res.data(); + bucket_ptrs.reserve(nbuckets - ib); + for (int j = nbuckets - 1; j >= ib; --j) { + bucket_ptrs.push_back(ptr); + ptr += histo[j]; + } + for (int i = 0; i < (int)cur.size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i]; + } + } + + ptr = res.data(); + int ndone = 0; + for (int j = nbuckets - 1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp); +} + +// reduces the size of cur_p to npartial, keeping only the top npartial elements +static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) { + static const auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + + if (npartial <= 128) { + std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp); + + cur_p->size = npartial; + cur_p->sorted = true; + + return; + } + + std::vector tmp; + + llama_token_data_array_partial_sort(*cur_p, npartial, tmp); + + std::copy(tmp.data(), tmp.data() + npartial, cur_p->data); + + cur_p->size = npartial; + cur_p->sorted = true; +} + +static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { + // iterator for the probabilities +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + + struct probs_iterator { + typedef std::input_iterator_tag iterator_category; + typedef float value_type; + typedef float * pointer; + typedef float & reference; + typedef ptrdiff_t difference_type; + + const llama_token_data * data; + + bool operator==(const probs_iterator & other) const { return data == other.data; } + bool operator!=(const probs_iterator & other) const { return data != other.data; } + const float & operator*() const { return data->p; } + probs_iterator & operator++() { ++data; return *this; } + probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; } + }; + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif + + std::discrete_distribution dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size}); + + return dist(rng); +} + +/* static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -22,303 +258,736 @@ static void llama_log_softmax(float * array, size_t size) { array[i] = logf(array[i] / sum); } } +*/ -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { + if (temp <= 0.0f) { + // find the token with the highest logit and set the rest to -inf + size_t max_i = 0; + float max_l = cur_p->data[0].logit; - smpl->rng.seed(seed); -} - -void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(candidates->size > 0); - - const int64_t t_start_sample_us = ggml_time_us(); - - // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - candidates->sorted = true; - } - - float max_l = candidates->data[0].logit; - float cum_sum = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float p = expf(candidates->data[i].logit - max_l); - candidates->data[i].p = p; - cum_sum += p; - } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum; - } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast - // if (k >= (int32_t)candidates->size) { - // return; - // } - - const int64_t t_start_sample_us = ggml_time_us(); - - if (k <= 0) { - k = candidates->size; - } - - k = std::max(k, (int) min_keep); - k = std::min(k, (int) candidates->size); - - // Sort scores in descending order - if (!candidates->sorted) { - auto comp = [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }; - if (k <= 128) { - std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); - } else { - constexpr int nbuckets = 128; - constexpr float bucket_low = -10.0f; - constexpr float bucket_high = 10.0f; - constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); - constexpr float bucker_inter = -bucket_low * bucket_scale; - - std::vector bucket_idx(candidates->size); - std::vector histo(nbuckets, 0); - - for (int i = 0; i < (int)candidates->size; ++i) { - const float val = candidates->data[i].logit; - int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - ib = std::max(0, std::min(nbuckets-1, ib)); - bucket_idx[i] = ib; - ++histo[ib]; + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i ].logit > max_l) { + cur_p->data[max_i].logit = -INFINITY; + max_i = i; + max_l = cur_p->data[i].logit; + } else { + cur_p->data[i].logit = -INFINITY; } - int nhave = 0; - int ib = nbuckets - 1; - for ( ; ib >= 0; --ib) { - nhave += histo[ib]; - if (nhave >= k) break; - } - std::vector tmp_tokens(nhave); - auto ptr = tmp_tokens.data(); - std::vector bucket_ptrs; - bucket_ptrs.reserve(nbuckets - ib); - for (int j = nbuckets - 1; j >= ib; --j) { - bucket_ptrs.push_back(ptr); - ptr += histo[j]; - } - for (int i = 0; i < (int)candidates->size; ++i) { - int j = bucket_idx[i]; - if (j >= ib) { - *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; - } - } - - ptr = tmp_tokens.data(); - int ndone = 0; - for (int j = nbuckets-1; j > ib; --j) { - std::sort(ptr, ptr + histo[j], comp); - ptr += histo[j]; - ndone += histo[j]; - } - std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); - - std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); - } - candidates->sorted = true; - } - candidates->size = k; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { - if (p >= 1.0f) { return; } - llama_sample_softmax_impl(smpl, candidates); + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= temp; + } +} - const int64_t t_start_sample_us = ggml_time_us(); +static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) { + GGML_ASSERT(cur_p->size > 0); + + // Sort the logits in descending order if requested + if (do_sort && !cur_p->sorted) { + llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); + } + + float max_l = cur_p->data[0].logit; + if (!cur_p->sorted) { + for (size_t i = 1; i < cur_p->size; ++i) { + max_l = std::max(max_l, cur_p->data[i].logit); + } + } + + float cum_sum = 0.0f; + + for (size_t i = 0; i < cur_p->size; ++i) { + float p = expf(cur_p->data[i].logit - max_l); + cur_p->data[i].p = p; + cum_sum += p; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum; + } +} + +static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { + // if (k >= (int32_t)cur_p->size) { + // return; + // } + + if (k <= 0) { + return; + } + + k = std::min(k, (int) cur_p->size); + + // Sort scores in descending order + if (!cur_p->sorted) { + llama_token_data_array_partial_sort_inplace(cur_p, k); + } + + cur_p->size = k; +} + +static uint32_t get_rng_seed(uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + // use system clock if std::random_device is not a true RNG + static bool is_rd_prng = std::random_device().entropy() == 0; + if (is_rd_prng) { + return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); + } + std::random_device rd; + return rd(); + } + return seed; +} + +// llama_sampler API + +struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { + return new llama_sampler { + /* .iface = */ iface, + /* .ctx = */ ctx, + }; +} + +const char * llama_sampler_name(const struct llama_sampler * smpl) { + if (!smpl->iface) { + return "(null)"; + } + + return smpl->iface->name(smpl); +} + +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (smpl->iface->accept) { + smpl->iface->accept(smpl, token); + } +} + +void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + GGML_ASSERT(smpl->iface->apply); + smpl->iface->apply(smpl, cur_p); +} + +void llama_sampler_reset(struct llama_sampler * smpl) { + if (smpl->iface->reset) { + smpl->iface->reset(smpl); + } +} + +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (smpl->iface->clone) { + return smpl->iface->clone(smpl); + } + + if (smpl->ctx == nullptr) { + return llama_sampler_init( + /* .iface = */ smpl->iface, + /* .ctx = */ nullptr + ); + } + + GGML_ABORT("the sampler does not support cloning"); +} + +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + if (smpl->iface->free) { + smpl->iface->free(smpl); + } + + delete smpl; +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + // TODO: do not allocate each time + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(smpl, &cur_p); + + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + + auto token = cur_p.data[cur_p.selected].id; + + llama_sampler_accept(smpl, token); + + return token; +} + +// sampler chain + +static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { + return "chain"; +} + +static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept(smpl, token); + } + + chain->n_sample++; +} + +static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply(smpl, cur_p); + } +} + +static void llama_sampler_chain_reset(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset(smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; +} + +static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + } + + return result; +} + +static void llama_sampler_chain_free(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free(smpl); + } + + delete chain; +} + +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, +}; + +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .samplers = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + } + ); +} + +void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { + auto * p = (llama_sampler_chain *) chain->ctx; + p->samplers.push_back(smpl); +} + +struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + if (i < 0 || (size_t) i >= p->samplers.size()) { + return nullptr; + } + + return p->samplers[i]; +} + +struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { + auto * p = (llama_sampler_chain *) chain->ctx; + + if (i < 0 || (size_t) i >= p->samplers.size()) { + return nullptr; + } + + auto * result = p->samplers[i]; + p->samplers.erase(p->samplers.begin() + i); + + return result; +} + +int llama_sampler_chain_n(const struct llama_sampler * chain) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + return p->samplers.size(); +} + +// +// samplers +// + +// greedy + +static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { + return "greedy"; +} + +static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + cur_p->selected = 0; + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { + cur_p->selected = i; + } + } +} + +static struct llama_sampler_i llama_sampler_greedy_i = { + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_greedy() { + return llama_sampler_init( + /* .iface = */ &llama_sampler_greedy_i, + /* .ctx = */ nullptr + ); +} + +// dist + +struct llama_sampler_dist { + const uint32_t seed; + uint32_t seed_cur; + + std::mt19937 rng; +}; + +static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { + return "dist"; +} + +static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + + // edge cases + if (cur_p->size == 0) { + cur_p->selected = -1; + return; + } + + cur_p->selected = 0; + + if (cur_p->size == 1) { + cur_p->data[0].p = 1.0f; + return; + } + + // max logit for numerical stability + float max_l = cur_p->data[0].logit; + if (!cur_p->sorted) { + for (size_t i = 1; i < cur_p->size; ++i) { + max_l = std::max(max_l, cur_p->data[i].logit); + } + } + + // apply softmax to obtain the probabilities + double sum_cum = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + float p = expf(cur_p->data[i].logit - max_l); + cur_p->data[i].p = p; + sum_cum += p; + } + +#if 1 + // sample from the obtained probabilities and normalize the probs in a single pass + // this is ~3x faster on Mac with full gpt-oss vocab than the version below + // + std::uniform_real_distribution dist(0.0f, 1.0f); + const double rnd = dist(ctx->rng); + + double sum_run = 0.0f; + const double sum_tgt = sum_cum*rnd; + + bool found = false; + for (size_t i = 0; i < cur_p->size; ++i) { + if (!found) { + // accumulate probs until we reach the target sum + sum_run += cur_p->data[i].p; + if (sum_run >= sum_tgt) { + cur_p->selected = i; + found = true; + } + } + + // normalize probs + cur_p->data[i].p /= sum_cum; + } + + // fallback to the last token (don't think this can happen) + assert(found); + if (!found) { + cur_p->selected = cur_p->size - 1; + } +#else + // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= sum_cum; + } + + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); +#endif +} + +static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_dist *) smpl->ctx; + auto * result = llama_sampler_init_dist(ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_dist *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static void llama_sampler_dist_free(struct llama_sampler * smpl) { + delete (llama_sampler_dist *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dist_i = { + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, +}; + +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_dist_i, + /* .ctx = */ new llama_sampler_dist { + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// top-k + +struct llama_sampler_top_k { + const int32_t k; +}; + +static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { + return "top-k"; +} + +static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_top_k *) smpl->ctx; + llama_sampler_top_k_impl(cur_p, ctx->k); +} + +static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; + return llama_sampler_init_top_k(ctx->k); +} + +static void llama_sampler_top_k_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_k *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_k_i = { + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, +}; + +struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_k_i, + /* .ctx = */ new llama_sampler_top_k { + /* .k = */ k, + } + ); +} + +// top-p + +struct llama_sampler_top_p { + const float p; + const size_t min_keep; + + std::vector buf_sort; +}; + +static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { + return "top-p"; +} + +static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_top_p *) smpl->ctx; + + if (ctx->p >= 1.0f) { + return; + } + + llama_sampler_softmax_impl(cur_p, false); + + size_t k = cur_p->size; + auto * pdata = cur_p->data; + + auto & buf_sort = ctx->buf_sort; + + // if not sorted, try adaptive top-k sorting + if (!cur_p->sorted && cur_p->size > 1024) { + k = std::min(256, cur_p->size); + llama_token_data_array_partial_sort(*cur_p, k, buf_sort); + pdata = buf_sort.data(); + } else if (!cur_p->sorted) { + // small candidates -> sort inplace + llama_token_data_array_partial_sort_inplace(cur_p, k); + } // Compute the cumulative probabilities float cum_sum = 0.0f; - size_t last_idx = candidates->size; + size_t last_idx = cur_p->size; - for (size_t i = 0; i < candidates->size; ++i) { - cum_sum += candidates->data[i].p; + for (size_t i = 0; i < cur_p->size; ++i) { + cum_sum += pdata[i].p; // Check if the running sum is at least p or if we have kept at least min_keep tokens // we set the last index to i+1 to indicate that the current iterate should be included in the set - if (cum_sum >= p && i + 1 >= min_keep) { + if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) { last_idx = i + 1; break; } + + // we exceeded the current top-k heuristic -> increase k and continue + if (!cur_p->sorted && i == k - 1) { + k = cur_p->size; + llama_token_data_array_partial_sort(*cur_p, k, buf_sort); + pdata = buf_sort.data(); + } } // Resize the output vector to keep only the top-p tokens - candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + if (!cur_p->sorted) { + std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data); + cur_p->sorted = true; } + + cur_p->size = last_idx; } -void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { - if (p <= 0.0f || !candidates->size) { +static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_p *) smpl->ctx; + return llama_sampler_init_top_p(ctx->p, ctx->min_keep); +} + +static void llama_sampler_top_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_p_i = { + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, +}; + +struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_p_i, + /* .ctx = */ new llama_sampler_top_p { + /* .p = */ p, + /* .min_keep = */ min_keep, + /* .buf_sort = */ {}, + } + ); +} + +// min-p + +struct llama_sampler_min_p { + const float p; + const size_t min_keep; +}; + +static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { + return "min-p"; +} + +static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_min_p *) smpl->ctx; + + if (ctx->p <= 0.0f || !cur_p->size) { return; } - const int64_t t_start_sample_us = ggml_time_us(); - bool min_p_applied = false; - // if the candidates aren't sorted, try the unsorted implementation first - if (!candidates->sorted) { + // if the cur_p aren't sorted, try the unsorted implementation first + if (!cur_p->sorted) { std::vector filtered_tokens; float max_logit = -FLT_MAX; - for (size_t i = 0; i < candidates->size; ++i) { - max_logit = std::max(max_logit, candidates->data[i].logit); + for (size_t i = 0; i < cur_p->size; ++i) { + max_logit = std::max(max_logit, cur_p->data[i].logit); } - const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max + const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max - for (size_t i = 0; i < candidates->size; ++i) { - if (candidates->data[i].logit >= min_logit) { - filtered_tokens.push_back(candidates->data[i]); + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit >= min_logit) { + filtered_tokens.push_back(cur_p->data[i]); } } // if we have enough values the operation was a success - if (filtered_tokens.size() >= min_keep) { - memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); - candidates->size = filtered_tokens.size(); + if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) { + std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data); + cur_p->size = filtered_tokens.size(); min_p_applied = true; } } - // if the candidates are sorted or the unsorted implementation failed, use this implementation + // if the cur_p are sorted or the unsorted implementation failed, use this implementation if (!min_p_applied) { // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - candidates->sorted = true; + if (!cur_p->sorted) { + llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); } - const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max + const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max size_t i = 1; // first token always matches - for (; i < candidates->size; ++i) { - if (candidates->data[i].logit < min_logit && i >= min_keep) { + for (; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) { break; // prob too small } } // Resize the output vector to keep only the matching tokens - candidates->size = i; - } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + cur_p->size = i; } } -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { - if (z >= 1.0f || candidates->size <= 2) { - return; - } - - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - const int64_t t_start_sample_us = ggml_time_us(); - - // Compute the first and second derivatives - std::vector first_derivatives(candidates->size - 1); - std::vector second_derivatives(candidates->size - 2); - - for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; - } - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; - } - - // Calculate absolute value of second derivatives - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = std::abs(second_derivatives[i]); - } - - // Normalize the second derivatives - { - const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); - - if (second_derivatives_sum > 1e-6f) { - for (float & value : second_derivatives) { - value /= second_derivatives_sum; - } - } else { - for (float & value : second_derivatives) { - value = 1.0f / second_derivatives.size(); - } - } - } - - float cum_sum = 0.0f; - size_t last_idx = candidates->size; - for (size_t i = 0; i < second_derivatives.size(); ++i) { - cum_sum += second_derivatives[i]; - - // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > z && i >= min_keep) { - last_idx = i; - break; - } - } - - // Resize the output vector to keep only the tokens above the tail location - candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } +static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_min_p *) smpl->ctx; + return llama_sampler_init_min_p(ctx->p, ctx->min_keep); } -void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +static void llama_sampler_min_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_min_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_min_p_i = { + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, +}; + +struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_min_p_i, + /* .ctx = */ new llama_sampler_min_p { + /* .p = */ p, + /* .min_keep = */ min_keep, + } + ); +} + +// typical + +struct llama_sampler_typical { + const float p; + const size_t min_keep; +}; + +static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) { + return "typical"; +} + +static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_typical *) smpl->ctx; + // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr - if (p >= 1.0f) { + if (ctx->p >= 1.0f) { return; } // Compute the softmax of logits and calculate entropy - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampler_softmax_impl(cur_p, true); float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - entropy += -candidates->data[i].p * logf(candidates->data[i].p); + for (size_t i = 0; i < cur_p->size; ++i) { + entropy += -cur_p->data[i].p * logf(cur_p->data[i].p); } // Compute the absolute difference between negative log probability and entropy for each candidate std::vector shifted_scores; - for (size_t i = 0; i < candidates->size; ++i) { - float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + for (size_t i = 0; i < cur_p->size; ++i) { + float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy); shifted_scores.push_back(shifted_score); } // Sort tokens based on the shifted_scores and their corresponding indices - std::vector indices(candidates->size); + std::vector indices(cur_p->size); std::iota(indices.begin(), indices.end(), 0); std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { @@ -331,287 +1000,343 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar for (size_t i = 0; i < indices.size(); ++i) { size_t idx = indices[i]; - cum_sum += candidates->data[idx].p; + cum_sum += cur_p->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > p && i >= min_keep - 1) { + if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) { last_idx = i + 1; break; } } // Resize the output vector to keep only the locally typical tokens - std::vector new_candidates; + std::vector cur_p_new; for (size_t i = 0; i < last_idx; ++i) { size_t idx = indices[i]; - new_candidates.push_back(candidates->data[idx]); + cur_p_new.push_back(cur_p->data[idx]); } - // Replace the data in candidates with the new_candidates data - std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); - candidates->size = new_candidates.size(); - candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } + // Replace the data in cur_p with the cur_p_new data + std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data); + cur_p->size = cur_p_new.size(); + cur_p->sorted = false; } -void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { - const int64_t t_start_sample_us = ggml_time_us(); +static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_typical *) smpl->ctx; + return llama_sampler_init_typical(ctx->p, ctx->min_keep); +} - // no need to do anything if there is only one (or zero) candidates - if(candidates->size <= 1) { - return; - } +static void llama_sampler_typical_free(struct llama_sampler * smpl) { + delete (llama_sampler_typical *) smpl->ctx; +} - // Calculate maximum possible entropy - float max_entropy = -logf(1.0f / candidates->size); +static struct llama_sampler_i llama_sampler_typical_i = { + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, +}; - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - - // Calculate entropy of the softmax probabilities - float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float prob = candidates->data[i].p; - if (prob > 0.0f) { // Ensure no log(0) - entropy -= prob * logf(prob); +struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_typical_i, + /* .ctx = */ new llama_sampler_typical { + /* .p = */ p, + /* .min_keep = */ min_keep, } - } + ); +} - // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above) - float normalized_entropy = entropy / max_entropy; +// temp - // Map the normalized entropy to the desired temperature range using the power function - float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); +struct llama_sampler_temp { + const float temp; +}; -#ifdef DEBUG - LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); - LLAMA_LOG_INFO("Entropy: %f\n", entropy); - LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); - LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); - LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); - LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); -#endif +static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { + return "temp"; +} - // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].logit /= dyn_temp; - } +static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_temp *) smpl->ctx; - // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates->data[0].logit; - double cum_sum_double = 0.0; - for (size_t i = 0; i < candidates->size; ++i) { - double p = exp(candidates->data[i].logit - max_l_double); - candidates->data[i].p = p; // Store the scaled probability - cum_sum_double += p; - } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities - } + llama_sampler_temp_impl(cur_p, ctx->temp); +} -#ifdef DEBUG - // Print the updated top 25 probabilities after temperature scaling - LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); - for (size_t i = 0; i < 25 && i < candidates->size; ++i) { - LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); - } -#endif +static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp *) smpl->ctx; + return llama_sampler_init_temp(ctx->temp); +} - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; +static void llama_sampler_temp_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_temp_i = { + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, +}; + +struct llama_sampler * llama_sampler_init_temp(float temp) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_temp_i, + /* .ctx = */ new llama_sampler_temp { + /*.temp = */ temp, + } + ); +} + +// temp-ext + +struct llama_sampler_temp_ext { + const float temp; + const float delta; + const float exponent; +}; + +static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { + return "temp-ext"; +} + +static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; + if (ctx->delta > 0) { + const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); + const float max_temp = ctx->temp + ctx->delta; + + float exponent_val = ctx->exponent; + + // no need to do anything if there is only one (or zero) candidates + if (cur_p->size <= 1) { + return; + } + + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / cur_p->size); + + llama_sampler_softmax_impl(cur_p, true); + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + float prob = cur_p->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } + } + + // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above) + float normalized_entropy = entropy / max_entropy; + + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); + + #ifdef DEBUG + LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); + LLAMA_LOG_INFO("Entropy: %f\n", entropy); + LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); + LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); + LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); + LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); + #endif + + // Apply the dynamically calculated temperature scaling + llama_sampler_temp_impl(cur_p, dyn_temp); + + // Re-compute softmax probabilities after scaling logits with dynamic temperature + const double max_l_double = cur_p->data[0].logit; + + double cum_sum_double = 0.0; + for (size_t i = 0; i < cur_p->size; ++i) { + double p = exp(cur_p->data[i].logit - max_l_double); + cur_p->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } + + #ifdef DEBUG + // Print the updated top 25 probabilities after temperature scaling + LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + for (size_t i = 0; i < 25 && i < cur_p->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f); + } + #endif + } else { + llama_sampler_temp_impl(cur_p, ctx->temp); } } -void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { - const int64_t t_start_sample_us = ggml_time_us(); - - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].logit /= temp; - } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } +static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx; + return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); } -void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep) { - if (probability <= 0 || threshold > 0.5f || candidates->size < 2) { +static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp_ext *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_temp_ext_i = { + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, +}; + +struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_temp_ext_i, + /* .ctx = */ new llama_sampler_temp_ext { + /* .temp = */ temp, + /* .delta = */ delta, + /* .exponent = */ exponent, + } + ); +} + +// xtc + +struct llama_sampler_xtc { + const float probability; + const float threshold; + const size_t min_keep; + + const uint32_t seed; + uint32_t seed_cur; + + std::mt19937 rng; +}; + +static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { + return "xtc"; +} + +static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + + if (ctx->probability <= 0.0f + || ctx->threshold > 0.5f + || cur_p->size < 2) { return; } - GGML_ASSERT(smpl); - const int64_t t_start_sample_us = ggml_time_us(); - if (probability < 1) { - std::uniform_real_distribution distribution(0.0f, 1.0f); - float chance = distribution(smpl->rng); - if (chance > probability) return; + + std::uniform_real_distribution distribution(0.0f, 1.0f); + float chance = distribution(ctx->rng); + if (chance > ctx->probability) { + return; } - llama_sample_softmax_impl(nullptr, candidates); + llama_sampler_softmax_impl(cur_p, true); int pos_last = 0; - for (size_t i = 0; i < candidates->size; ++i) { - if (candidates->data[i].p >= threshold) { + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].p >= ctx->threshold) { pos_last = i; - } else break; - } - - if (candidates->size - pos_last >= min_keep && pos_last > 0) { - candidates->data += pos_last; - candidates->size -= pos_last; - } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - -} - -void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) { - - if (top_n_sigma <= 0.0f || candidates->size < 4) { - // top_n_sigma <= 0: disabled - // candidates->size < 4: no point in applying the transformation for fewer than 4 logits. - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - float max = candidates->data[0].logit; - float mean = 0; - size_t count = 0; - for (int i = 0; i < (int)candidates->size; ++i) { - // Only count non-negative infinity values - if (candidates->data[i].logit != -INFINITY) { - max = std::max(max, candidates->data[i].logit); - mean += candidates->data[i].logit; - ++count; - } - } - if (count < 4) { - return; // again, tandard deviation is not well defined for so few logits (4 is actually pushing it) - } - mean /= count; - - float sigma2 = 0; - for (int i = 0; i < (int)candidates->size; ++i) { - if (candidates->data[i].logit != -INFINITY) { - float delta = candidates->data[i].logit - mean; - sigma2 += delta*delta; - } - } - float sigma = sqrtf(sigma2/count); - float thresh = max - top_n_sigma*sigma; - - int n_masked = 0; - for (int i = 0; i < (int)candidates->size; ++i) { - if (candidates->data[i].logit != -INFINITY && candidates->data[i].logit < thresh) { - candidates->data[i].logit = -INFINITY; - ++n_masked; - } - } - - // do we really want to compute softmax unconditionally? - // The following coresponds to mainline implementation with the minor optimization - // that we only call the relativly expensive softmax if we masked away some tokens. - if (n_masked > 0 || !candidates->sorted) { - llama_sample_softmax_impl(nullptr, candidates); - } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - } -} - - -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - // Create a frequency map to count occurrences of each token in last_tokens - std::unordered_map token_count; - for (size_t i = 0; i < penalty_last_n; ++i) { - token_count[last_tokens[i]]++; - } - - // Apply frequency and presence penalties to the candidates - for (size_t i = 0; i < candidates->size; ++i) { - const auto token_iter = token_count.find(candidates->data[i].id); - if (token_iter == token_count.end()) { - continue; - } - - const int count = token_iter->second; - - // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. - // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates->data[i].logit <= 0) { - candidates->data[i].logit *= penalty_repeat; } else { - candidates->data[i].logit /= penalty_repeat; + break; } - - candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; } - candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) { + cur_p->data += pos_last; + cur_p->size -= pos_last; } } -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale) { - GGML_ASSERT(smpl); +static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_xtc *) smpl->ctx; + auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed); - const auto t_start_sample_us = ggml_time_us(); - const auto n_vocab = smpl->n_vocab; + // copy the state + { + auto * result_ctx = (llama_sampler_xtc *) result->ctx; - llama_log_softmax(logits, n_vocab); - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - auto & l = logits[i]; - const auto & g = logits_guidance[i]; - - l = scale * (l - g) + g; + result_ctx->rng = ctx->rng; } - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + return result; } -llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - GGML_ASSERT(smpl); +static void llama_sampler_xtc_free(struct llama_sampler * smpl) { + delete (llama_sampler_xtc *) smpl->ctx; +} - const int32_t n_vocab = float(smpl->n_vocab); +static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} - int64_t t_start_sample_us = ggml_time_us(); +static struct llama_sampler_i llama_sampler_xtc_i = { + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, +}; - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_xtc_i, + /* .ctx = */ new llama_sampler_xtc { + /* .probability = */ p, + /* .threshold = */ t, + /* .min_keep = */ min_keep, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// mirostat + +struct llama_sampler_mirostat { + const int32_t n_vocab; + + const uint32_t seed; + uint32_t seed_cur; + + const float tau; + const float eta; + + const int32_t m; + + float mu; + + std::mt19937 rng; +}; + +static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { + return "mirostat"; +} + +static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, true); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; float sum_ti_bi = 0.0; float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) { float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); sum_ti_bi += t_i * b_i; sum_ti_sq += t_i * t_i; } @@ -619,125 +1344,606 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); - // Sample the next word X using top-k sampling - llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1); - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; + llama_sampler_softmax_impl(cur_p, true); + + const int idx = llama_sample_dist(cur_p, ctx->rng); + + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - return X; + ctx->mu = ctx->mu - ctx->eta * e; } -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); +static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; + auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); - llama_sample_softmax_impl(smpl, candidates); + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; - // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > *mu; - })); - - if (candidates->size == 0) { - candidates->size = 1; + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; } - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + return result; +} + +static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_mirostat_i = { + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, +}; + +struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_mirostat_i, + /* .ctx = */ new llama_sampler_mirostat { + /* .n_vocab = */ n_vocab, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .m = */ m, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// mirostat v2 + +struct llama_sampler_mirostat_v2 { + const uint32_t seed; + uint32_t seed_cur; + + const float tau; + const float eta; + + float mu; + + std::mt19937 rng; +}; + +static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { + return "mirostat-v2"; +} + +static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, true); + + // Truncate the words with surprise values greater than mu + cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > ctx->mu; + })); + + if (cur_p->size == 0) { + cur_p->size = 1; } // Normalize the probabilities of the remaining words - llama_sample_softmax_impl(smpl, candidates); + llama_sampler_softmax_impl(cur_p, true); - // Sample the next word X from the remaining words - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + const int idx = llama_sample_dist(cur_p, ctx->rng); - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } - return X; + ctx->mu = ctx->mu - ctx->eta * e; } -llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = ggml_time_us(); - - // Find max element - auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); - - llama_token result = max_iter->id; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - } - return result; +static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); } -llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - GGML_ASSERT(smpl); +static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); - std::vector probs; - probs.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; } - std::discrete_distribution<> dist(probs.begin(), probs.end()); - int idx = dist(rng); - - llama_token result = candidates->data[idx].id; - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - return result; } -llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); +static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat_v2 *) smpl->ctx; } +static struct llama_sampler_i llama_sampler_mirostat_v2_i = { + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, +}; + +struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { + auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( + /* .iface = */ &llama_sampler_mirostat_v2_i, + /* .ctx = */ new llama_sampler_mirostat_v2 { + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), + } + ); +} + +// grammar + +struct llama_sampler_grammar { + const struct llama_vocab * vocab; + + std::string grammar_str; + std::string grammar_root; + + struct llama_grammar * grammar; +}; + +static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) { + return "grammar"; +} + +static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_accept_impl(*ctx->grammar, token); + } +} + +static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_apply_impl(*ctx->grammar, cur_p); + } +} + +// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle. +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns); + +static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (!ctx->grammar) { + return; + } + + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); + for (auto & trigger_pattern : ctx->grammar->trigger_patterns) { + trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); + } + + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); + + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = grammar_new; +} + +static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; + + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0); + GGML_ASSERT(result); + + // copy the state + { + auto * result_ctx = (llama_sampler_grammar *) result->ctx; + + if (ctx->grammar) { + result_ctx->grammar_str = ctx->grammar_str; + result_ctx->grammar_root = ctx->grammar_root; + + result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar); + } + } + + return result; +} + +static void llama_sampler_grammar_free(struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_grammar *) smpl->ctx; + + if (ctx->grammar) { + llama_grammar_free_impl(ctx->grammar); + } + + delete ctx; +} + +static struct llama_sampler_i llama_sampler_grammar_i = { + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, +}; + +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns) { + auto * ctx = new llama_sampler_grammar; + + if (grammar_str != nullptr && grammar_str[0] != '\0') { + std::string trigger_pattern; + llama_grammar * grammar = nullptr; + // TODO: remove trigger_words support. + if (trigger_words != nullptr && num_trigger_words > 0) { + GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0); + trigger_pattern = "[\\s\\S]*?("; + for (size_t i = 0; i < num_trigger_words; ++i) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + if (i > 0) { + trigger_pattern += "|"; + } + trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0"); + } + trigger_pattern += ")[\\s\\S]*"; + + std::array tmp_trigger_patterns = { trigger_pattern.c_str() }; + grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens); + } else { + grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens); + } + *ctx = { + /* .vocab = */ vocab, + /* .grammar_str = */ grammar_str, + /* .grammar_root = */ grammar_root, + /* .grammar = */ grammar, + }; + if (!ctx->grammar) { + delete ctx; + return nullptr; + } + } else { + *ctx = { + /* .vocab = */ vocab, + /* .grammar_str = */ {}, + /* .grammar_root = */ {}, + /* .grammar = */ nullptr, + }; + } + + return llama_sampler_init( + /* .iface = */ &llama_sampler_grammar_i, + /* .ctx = */ ctx + ); +} + +struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns); +} + +// penalties + +struct llama_sampler_penalties { + const int32_t penalty_last_n; + const float penalty_repeat; + const float penalty_freq; + const float penalty_present; + + ring_buffer prev; + + // a frequency map to count token occurrences + std::unordered_map token_count; +}; + +static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { + return "penalties"; +} + +static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + if (ctx->penalty_last_n == 0) { + return; + } + + ctx->token_count[token]++; + + // if the ring buffer is full, remove the oldest token + if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) { + const auto old = ctx->prev.front(); + + ctx->token_count[old]--; + if (ctx->token_count[old] == 0) { + ctx->token_count.erase(old); + } + } + + ctx->prev.push_back(token); + +#if 0 + // sanity check + std::unordered_map tmp; + for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { + tmp[ctx->prev.rat(i)]++; + } + + assert(ctx->token_count == tmp); +#endif +} + +static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + + if ((ctx->penalty_last_n == 0) || + (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { + return; + } + + // Apply frequency and presence penalties to the cur_p + for (size_t i = 0; i < cur_p->size; ++i) { + const auto token_iter = ctx->token_count.find(cur_p->data[i].id); + if (token_iter == ctx->token_count.end()) { + continue; + } + + const int count = token_iter->second; + + assert(count > 0 && count <= ctx->penalty_last_n); + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (cur_p->data[i].logit <= 0) { + cur_p->data[i].logit *= ctx->penalty_repeat; + } else { + cur_p->data[i].logit /= ctx->penalty_repeat; + } + + cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present; + } + + cur_p->sorted = false; +} + +static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + ctx->prev.clear(); + ctx->token_count.clear(); +} + +static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; + auto * result = llama_sampler_init_penalties( + ctx->penalty_last_n, + ctx->penalty_repeat, + ctx->penalty_freq, + ctx->penalty_present); + + // copy the state + { + auto * result_ctx = (llama_sampler_penalties *) result->ctx; + + result_ctx->prev = ctx->prev; + } + + return result; +} + +static void llama_sampler_penalties_free(struct llama_sampler * smpl) { + delete (llama_sampler_penalties *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_penalties_i = { + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, +}; + +struct llama_sampler * llama_sampler_init_penalties( + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + penalty_last_n = std::max(penalty_last_n, 0); + + return llama_sampler_init( + /* .iface = */ &llama_sampler_penalties_i, + /* .ctx = */ new llama_sampler_penalties { + /* .penalty_last_n = */ penalty_last_n, + /* .penalty_repeat = */ penalty_repeat, + /* .penalty_freq = */ penalty_freq, + /* .penalty_present = */ penalty_present, + /* .prev = */ ring_buffer(penalty_last_n), + /* .token_count = */ {}, + } + ); +} + +// top-n-sigma + +struct llama_sampler_top_n_sigma { + const float n; +}; + +static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { + return "top-n-sigma"; +} + +static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + + if (ctx->n <= 0.0f || cur_p->size <= 1) { + return; + } + + // find max logit and calculate mean + float max = cur_p->data[0].logit; + float logits_sum = 0; + size_t valid_count = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + // Only count non-negative infinity values + if (cur_p->data[i].logit != -INFINITY) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + valid_count++; + } + } + float mean = valid_count > 0 ? logits_sum/valid_count : 0; + + // calculate standard deviation + float acc = 0; + for (size_t i = 0; i < cur_p->size; ++i) { + // Skip -infinity in std calculation + if (cur_p->data[i].logit != -INFINITY) { + acc += pow(cur_p->data[i].logit - mean, 2); + } + } + float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; + + // apply mask + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < max - (ctx->n * std)) { + cur_p->data[i].logit = -INFINITY; + } + } + + llama_sampler_softmax_impl(cur_p, true); +} + +static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx; + return llama_sampler_init_top_n_sigma(ctx->n); +} + +static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_n_sigma *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_top_n_sigma_i = { + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, +}; + +struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_top_n_sigma_i, + /* .ctx = */ new llama_sampler_top_n_sigma { + /* .n = */ n, + } + ); +} // DRY +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap> dry_processed_breakers; + std::vector dry_repeat_count; + std::unordered_map dry_max_token_repeat; + ring_buffer last_tokens; +}; + // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) -static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { - for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) { - auto word = vocab.detokenize( { token_id }, true); +static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { + for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) { + std::string word = vocab.detokenize({token_id}, true); if (word.find(str) != std::string::npos) { token_sequences.emplace(token_id, std::vector()); - } - else { - size_t word_len = word.size(), str_len = str.size(); + } else { + size_t word_len = word.size(); + size_t str_len = str.size(); size_t pos = -1; while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { bool match = true; @@ -749,8 +1955,7 @@ static void get_overlapping_token_sequences(const llama_vocab& vocab, const std: } } if (match) { - auto tokenization = vocab.tokenize(str.substr(i), false, false); - //std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); + std::vector tokenization = vocab.tokenize(str.substr(i), false, false); if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { tokenization.resize(max_tail_len); } @@ -773,27 +1978,36 @@ static void get_overlapping_token_sequences(const llama_vocab& vocab, const std: } } -static const char* llama_sampler_dry_name(const struct llama_sampler* /*smpl*/) { +static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) { return "dry"; } +static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + ctx->last_tokens.push_back(token); +} // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) -void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p) { - if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) { +static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { return; } - int32_t effective_dry_penalty_last_n = (smpl->dry_penalty_last_n == -1) ? smpl->total_context_size : std::max(smpl->dry_penalty_last_n, 0); - int last_n_repeat = std::min(std::min((int)smpl->last_tokens.size(), effective_dry_penalty_last_n), smpl->total_context_size); + int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0); + int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size); - if (last_n_repeat <= smpl->dry_allowed_length) { + if (last_n_repeat <= ctx->dry_allowed_length) { return; } - smpl->dry_repeat_count.assign(last_n_repeat, 0); - smpl->dry_max_token_repeat.clear(); + ctx->dry_repeat_count.assign(last_n_repeat, 0); + ctx->dry_max_token_repeat.clear(); // Step 1: Look for restart sequences to limit the maximum repetition length. // Work backwards through the context looking for any token that begins a restart sequence. @@ -819,9 +2033,9 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar int rep_limit = last_n_repeat; for (int i = 0; i < last_n_repeat; ++i) { - llama_token token = smpl->last_tokens.rat(i); - auto its = smpl->dry_processed_breakers.equal_range(token); - if (its.first == smpl->dry_processed_breakers.end()) { + llama_token token = ctx->last_tokens.rat(i); + auto its = ctx->dry_processed_breakers.equal_range(token); + if (its.first == ctx->dry_processed_breakers.end()) { continue; } int longest_match = -1; @@ -835,7 +2049,7 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar bool match = true; for (int offset = 0; offset < seq_len; ++offset) { // The -1 when indexing `last_tokens` is because we already matched the head. - if (it->second[offset] != smpl->last_tokens.rat(i - offset - 1)) { + if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) { match = false; break; } @@ -852,7 +2066,7 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar break; } } - if (rep_limit < smpl->dry_allowed_length) { + if (rep_limit < ctx->dry_allowed_length) { return; } @@ -880,39 +2094,39 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar { const int last = last_n_repeat - 1; - int rt = 0, lt = 0; + + int rt = 0; + int lt = 0; for (int k = 1; k < last_n_repeat; ++k) { if (k > rt) { // If k is outside the current Z-box, do naive computation. int n = 0; - while (n + k < last_n_repeat && smpl->last_tokens.rat(n) == smpl->last_tokens.rat(n + k)) { + while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) { ++n; } - smpl->dry_repeat_count[last - k] = std::min(n, rep_limit); + ctx->dry_repeat_count[last - k] = std::min(n, rep_limit); if (n > 0) { lt = k; rt = k + n - 1; } - } - else { + } else { // If k is inside the current Z-box, consider two cases. int p = k - lt; // Pair index. int right_part_len = rt - k + 1; - if (smpl->dry_repeat_count[last - p] < right_part_len) { - int n = std::min(smpl->dry_repeat_count[last - p], rep_limit); - smpl->dry_repeat_count[last - k] = n; - } - else { + if (ctx->dry_repeat_count[last - p] < right_part_len) { + int n = std::min(ctx->dry_repeat_count[last - p], rep_limit); + ctx->dry_repeat_count[last - k] = n; + } else { int i = rt + 1; - while (i < last_n_repeat && smpl->last_tokens.rat(i) == smpl->last_tokens.rat(i - k)) { + while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) { i += 1; } int n = std::min(i - k, rep_limit); - smpl->dry_repeat_count[last - k] = n; + ctx->dry_repeat_count[last - k] = n; lt = k; rt = i - 1; } @@ -933,16 +2147,16 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar // y: 2 -> 3 (from `b c` to `b c y`) for (int i = 0; i < last_n_repeat - 1; ++i) { - int repeat_len = smpl->dry_repeat_count[i]; - if (repeat_len >= smpl->dry_allowed_length) { + int repeat_len = ctx->dry_repeat_count[i]; + if (repeat_len >= ctx->dry_allowed_length) { // This token ends a repeat, so the next token would continue one. // By convention, the value of `repeat_len` only includes the tokens currently // in the context, not the new token that would be added. - llama_token token = smpl->last_tokens.rat(last_n_repeat - 2 - i); + llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i); // Track the maximum sequence ending in this token. - const auto& it = smpl->dry_max_token_repeat.find(token); - if (it == smpl->dry_max_token_repeat.end() || it->second < repeat_len) { - smpl->dry_max_token_repeat[token] = repeat_len; + const auto& it = ctx->dry_max_token_repeat.find(token); + if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) { + ctx->dry_max_token_repeat[token] = repeat_len; } } } @@ -953,15 +2167,15 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar // Compute it from `penalty_base` and the approximate log of `std::numeric_limits::max()` const float FLOAT_MAX_LOG = 88.7228391f; int max_exponent = 0; - if (smpl->dry_base > 1.000001f) { - max_exponent = FLOAT_MAX_LOG / std::log(smpl->dry_base); + if (ctx->dry_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base); } for (size_t i = 0; i < cur_p->size; ++i) { - const auto& af_kvp = smpl->dry_max_token_repeat.find(cur_p->data[i].id); - if (af_kvp != smpl->dry_max_token_repeat.end()) { + const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id); + if (af_kvp != ctx->dry_max_token_repeat.end()) { // Check all sequence breakers starting with this token - auto range = smpl->dry_processed_breakers.equal_range(cur_p->data[i].id); + auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id); bool is_single_token_breaker = false; for (auto it = range.first; it != range.second; ++it) { @@ -973,11 +2187,11 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar // Apply penalty only if it's not a single-token sequence breaker if (!is_single_token_breaker) { - int repeat_exp = af_kvp->second - smpl->dry_allowed_length; + int repeat_exp = af_kvp->second - ctx->dry_allowed_length; if (max_exponent > 0 && repeat_exp > max_exponent) { repeat_exp = max_exponent; } - float penalty = smpl->dry_multiplier * std::pow(smpl->dry_base, repeat_exp); + float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp); cur_p->data[i].logit -= penalty; } } @@ -986,10 +2200,48 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar cur_p->sorted = false; } +static void llama_sampler_dry_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + ctx->last_tokens.clear(); + ctx->dry_repeat_count.clear(); + ctx->dry_max_token_repeat.clear(); +} +static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_dry *) smpl->ctx; -struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { - int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); + llama_vocab dummy_vocab; + + // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying + auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); + + // Copy the state, including the processed breakers + { + auto * result_ctx = (llama_sampler_dry *) result->ctx; + result_ctx->dry_processed_breakers = ctx->dry_processed_breakers; + result_ctx->dry_repeat_count = ctx->dry_repeat_count; + result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat; + result_ctx->last_tokens = ctx->last_tokens; + } + + return result; +} + +static void llama_sampler_dry_free(struct llama_sampler * smpl) { + delete (llama_sampler_dry *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dry_i = { + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, +}; + +struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0); std::unordered_multimap> processed_breakers; const int MAX_CHAR_LEN = 40; const int MAX_SEQ_LEN = 20; @@ -1015,12 +2267,14 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& sequence_break.resize(MAX_CHAR_LEN); } - get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); + get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); } } - return new llama_sampler_dry { - /* .total_context_size = */ context_size, + return llama_sampler_init( + /* .iface = */ &llama_sampler_dry_i, + /* .ctx = */ new llama_sampler_dry { + /* .total_context_size = */ n_ctx_train, /* .dry_multiplier = */ dry_multiplier, /* .dry_base = */ dry_base, /* .dry_allowed_length = */ dry_allowed_length, @@ -1029,180 +2283,403 @@ struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& /* .dry_repeat_count = */ dry_enabled ? std::vector(effective_dry_penalty_last_n, 0) : std::vector{}, /* .dry_max_token_repeat = */ {}, /* .last_tokens = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0), - }; + } + ); } +// wrapper for test-sampling.cpp +struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers) { + llama_vocab dummy_vocab; + auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); + auto * ctx = (llama_sampler_dry *) result->ctx; -// grammar + // Process the token-based sequence breakers + ctx->dry_processed_breakers.clear(); + if (seq_breakers.empty()) { + LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n"); + } else { + for (const auto& breaker : seq_breakers) { + if (breaker.empty()) { + LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n"); + continue; + } + llama_token head_token = breaker[0]; + std::vector tail_tokens(breaker.begin() + 1, breaker.end()); + ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens)); + } -struct llama_sampler_grammar { - const struct llama_vocab* vocab; + if (ctx->dry_processed_breakers.empty()) { + LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n"); + } + } - std::string grammar_str; - std::string grammar_root; + return result; +} - struct llama_grammar* grammar; +// logit-bias + +struct llama_sampler_logit_bias { + const int32_t n_vocab; + + const std::vector logit_bias; + + std::vector to_search; }; -static const char* llama_sampler_grammar_name(const struct llama_sampler* /*smpl*/) { - return "grammar"; +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { + return "logit-bias"; } -static void llama_sampler_grammar_accept_impl(struct llama_sampler* smpl, llama_token token) { - auto* ctx = (llama_sampler_grammar*)smpl->ctx; - if (ctx->grammar) { - llama_grammar_accept_token_impl(ctx->grammar,ctx->vocab ,nullptr, token); - } -} +static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; -static void llama_sampler_grammar_apply(struct llama_sampler* smpl, llama_token_data_array* cur_p) { - auto* ctx = (llama_sampler_grammar*)smpl->ctx; - if (ctx->grammar) { - llama_grammar_sample_impl(ctx->grammar, ctx->vocab, nullptr, cur_p); - } -} - -void llama_sampler_reset(struct llama_sampler* smpl) { - if (smpl->iface->reset) { - smpl->iface->reset(smpl); - } -} - -// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle. -static struct llama_grammar* llama_sampler_init_grammar_impl( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - bool lazy, - const char** trigger_words, - size_t num_trigger_words, - const llama_token* trigger_tokens, - size_t num_trigger_tokens, - const char** trigger_patterns, - size_t num_trigger_patterns); - -static void llama_sampler_grammar_reset(struct llama_sampler* smpl) { - auto* ctx = (llama_sampler_grammar*)smpl->ctx; - if (!ctx->grammar) { + if (ctx->logit_bias.empty()) { return; } - std::vector trigger_patterns_c; - trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); - for (auto& trigger_pattern : ctx->grammar->trigger_patterns) { - trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); - } - auto* grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), - ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), - ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); + ctx->to_search.clear(); - llama_grammar_free_impl(ctx->grammar); - ctx->grammar = grammar_new; -} - -//static struct llama_sampler* llama_sampler_grammar_clone(const struct llama_sampler* smpl) { -// const auto* ctx = (const llama_sampler_grammar*)smpl->ctx; -// -// auto* result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); -// -// // copy the state -// { -// auto* result_ctx = (llama_sampler_grammar*)result->ctx; -// -// if (ctx->grammar) { -// result_ctx->grammar_str = ctx->grammar_str; -// result_ctx->grammar_root = ctx->grammar_root; -// -// result_ctx->grammar = llama_grammar_copy_impl(ctx->grammar); -// } -// } -// -// return result; -//} - -static void llama_sampler_grammar_free(struct llama_sampler* smpl) { - const auto* ctx = (llama_sampler_grammar*)smpl->ctx; - - if (ctx->grammar) { - llama_grammar_free_impl(ctx->grammar); - } - - delete ctx; -} - -// ? -//static struct llama_sampler_i llama_sampler_grammar_i = { -// /* .name = */ llama_sampler_grammar_name, -// /* .accept = */ llama_sampler_grammar_accept_impl, -// /* .apply = */ llama_sampler_grammar_apply, -// /* .reset = */ llama_sampler_grammar_reset, -// /* .clone = */ NULL, -// /* .free = */ llama_sampler_grammar_free, -//}; - -struct llama_grammar* llama_sampler_init_grammar_impl( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - bool lazy, - const char** trigger_words, - size_t num_trigger_words, - const llama_token* trigger_tokens, - size_t num_trigger_tokens, - const char** trigger_patterns, - size_t num_trigger_patterns) { - // Huh? this is not used and leaks. auto* ctx = new llama_sampler_grammar; - struct llama_grammar* grammar; - if (grammar_str != nullptr && grammar_str[0] != '\0') { - // TODO: remove trigger_words support. - if (trigger_words != nullptr && num_trigger_words > 0) { - GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0); - std::string trigger_pattern("[\\s\\S]*?("); - for (size_t i = 0; i < num_trigger_words; ++i) { - static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); - if (i > 0) { - trigger_pattern += "|"; - } - trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0"); - } - trigger_pattern += ")[\\s\\S]*"; - auto trigger_pattern_c = trigger_pattern.c_str(); - trigger_patterns = &trigger_pattern_c; - num_trigger_patterns = 1; + // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) + for (const auto & lb : ctx->logit_bias) { + if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) { + cur_p->data[lb.token].logit += lb.bias; + } else { + ctx->to_search.push_back(lb); } - grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens); } - else { - grammar = nullptr; + + if (ctx->to_search.empty()) { + return; + } + + // search for the remaining candidates that were not found in the previous step + for (size_t i = 0; i < cur_p->size; ++i) { + for (const auto & lb : ctx->to_search) { + if (cur_p->data[i].id == lb.token) { + cur_p->data[i].logit += lb.bias; + break; + } + } } - return grammar; } -struct llama_grammar* llama_sampler_init_grammar( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root) { - return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0); +static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); } -struct llama_grammar* llama_sampler_init_grammar_lazy( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - const char** trigger_words, - size_t num_trigger_words, - const llama_token* trigger_tokens, - size_t num_trigger_tokens) { - return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0); +static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { + delete (llama_sampler_logit_bias *) smpl->ctx; } -struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( - const struct llama_vocab* vocab, - const char* grammar_str, - const char* grammar_root, - const char** trigger_patterns, - size_t num_trigger_patterns, - const llama_token* trigger_tokens, - size_t num_trigger_tokens) { - return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns); +static struct llama_sampler_i llama_sampler_logit_bias_i = { + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, +}; + +struct llama_sampler * llama_sampler_init_logit_bias( + int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_logit_bias_i, + /* .ctx = */ new llama_sampler_logit_bias { + /* .n_vocab = */ n_vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .to_search = */ {}, + } + ); +} + +// infill + +//#define GGML_DEBUG_SAMPLER_INFILL + +struct llama_sampler_infill { + const struct llama_vocab * vocab; + + std::vector buf0; + std::vector buf1; +}; + +static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) { + return "infill"; +} + +static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_infill *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, true); + +#if defined(GGML_DEBUG_SAMPLER_INFILL) +#define LOG_DBG_CUR LLAMA_LOG_DEBUG +#else +#define LOG_DBG_CUR(...) +#endif + + for (size_t i = 0; i < cur_p->size; ++i) { + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + float p_txt_sum = 0.0f; + float p_eog_sum = 0.0f; + + for (size_t i = 0; i < cur_p->size; ++i) { + if (ctx->vocab->is_eog(cur_p->data[i].id)) { + p_eog_sum += cur_p->data[i].p; + } else { + p_txt_sum += cur_p->data[i].p; + } + } + + const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat); + + LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size); + + if (3*p_eog_sum*cur_p->size > p_txt_sum) { + LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum); + + // keep just the EOG tokens + const auto size_org = cur_p->size; + + cur_p->size = 0; + + float p_sum = 0.0f; + + for (size_t i = 0; i < size_org; ++i) { + if (ctx->vocab->is_eog(cur_p->data[i].id)) { + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + } + + return; + } + + size_t n_combined = 0; GGML_UNUSED(n_combined); + + // combine tokens with common prefix + for (size_t i0 = 0; i0 < cur_p->size; ++i0) { + for (size_t i1 = 0; i1 < cur_p->size; ++i1) { + if (cur_p->data[i0].logit == -INFINITY) { + break; + } + + if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) { + continue; + } + + int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + if (len0 < 0) { + ctx->buf0.resize(len0); + len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + assert(len0 > 0); + } + + int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + if (len1 < 0) { + ctx->buf1.resize(len1); + len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + assert(len1 > 0); + } + + // token i0 is a prefix of token i1 + if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) { + int dst = i0; + int src = i1; + + // merge into the token with higher probability + if (cur_p->data[i1].p > cur_p->data[i0].p) { + std::swap(dst, src); + } + + cur_p->data[dst].p += cur_p->data[src].p; + cur_p->data[src].logit = -INFINITY; + cur_p->data[src].p = 0.0f; + + n_combined++; + } + } + } + + size_t n_non_eog = 0; + + size_t size_org = cur_p->size; + + float p_sum = 0.0f; + float thold = 0.2f; + + cur_p->size = 0; + + LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + if (!is_eog) { + ++n_non_eog; + } + + p_sum += cur_p->data[i].p; + + // keep this token + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog); + + // if no non-EOG tokens are left -> reduce cur_p to single EOT token + if (n_non_eog == 0) { + cur_p->size = 1; + cur_p->data[0].id = ctx->vocab->token_eot(); + if (cur_p->data[0].id == LLAMA_TOKEN_NULL) { + cur_p->data[0].id = ctx->vocab->token_eos(); + } + cur_p->data[0].logit = 1.0f; + + GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL); + + return; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + size_org = cur_p->size; + p_sum = 0.0f; + thold = 1.0/(n_non_eog + 1); + + cur_p->size = 0; + + LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + +#undef LOG_DBG_CUR +} + +static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_infill *) smpl->ctx; + return llama_sampler_init_infill(ctx->vocab); +} + +static void llama_sampler_infill_free(struct llama_sampler * smpl) { + delete (llama_sampler_infill *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_infill_i = { + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, +}; + +struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_infill_i, + /* .ctx = */ new llama_sampler_infill { + /* .vocab = */ vocab, + /* .buf0 = */ std::vector(512), + /* .buf1 = */ std::vector(512), + } + ); +} + +// utils + +uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { + if (smpl->iface == &llama_sampler_dist_i) { + return ((const llama_sampler_dist *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_i) { + return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_v2_i) { + return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_chain_i) { + const auto * ctx = (const llama_sampler_chain *) smpl->ctx; + for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { + const uint32_t seed = llama_sampler_get_seed(*it); + if (seed != LLAMA_DEFAULT_SEED) { + return seed; + } + } + } + + return LLAMA_DEFAULT_SEED; +} + +// perf + +struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) { + struct llama_perf_sampler_data data = {}; + + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + const auto * ctx = (const struct llama_sampler_chain *) chain->ctx; + + data.t_sample_ms = 1e-3 * ctx->t_sample_us; + data.n_sample = std::max(0, ctx->n_sample); + + return data; +} + +void llama_perf_sampler_print(const struct llama_sampler * chain) { + const auto data = llama_perf_sampler(chain); + + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample); +} + +void llama_perf_sampler_reset(struct llama_sampler * chain) { + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + auto * ctx = (struct llama_sampler_chain *) chain->ctx; + + ctx->t_sample_us = ctx->n_sample = 0; } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 855278e2..759dd7dc 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,88 +1,32 @@ #pragma once -#include "llama-impl.h" -#include -struct llama_sampling { - llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} +// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? - std::mt19937 rng; +#include "llama.h" - int32_t n_vocab = 0; +#include - mutable int64_t t_sample_us = 0; - mutable int32_t n_sample = 0; +struct llama_vocab; +struct llama_grammar; - void reset_timings() const { - t_sample_us = 0; - n_sample = 0; - } +// sampler chain + +struct llama_sampler_chain { + llama_sampler_chain_params params; + + std::vector samplers; + + // timing + + mutable int64_t t_sample_us; + + mutable int32_t n_sample; }; -// -// internal API -// - -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed); - -void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); -void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep); -void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma); - -struct llama_sampler_dry { - int32_t total_context_size; - - const float dry_multiplier; - const float dry_base; - const int32_t dry_allowed_length; - const int32_t dry_penalty_last_n; - - std::unordered_multimap> dry_processed_breakers; - std::vector dry_repeat_count; - std::unordered_map dry_max_token_repeat; - ring_buffer last_tokens; -}; - -struct llama_sampler_dry * llama_sampler_init_dry_impl( - const struct llama_vocab & vocab, - int32_t context_size, - float dry_multiplier, - float dry_base, - int32_t dry_allowed_length, - int32_t dry_penalty_last_n, - const char ** seq_breakers, - size_t num_breakers); - -void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p); - - - -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present); - -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale); - -llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); - - - +struct llama_sampler * llama_sampler_init_dry_testing( + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector>& seq_breakers); diff --git a/src/llama.cpp b/src/llama.cpp index 5a374b6d..d9de8c7c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -583,7 +583,7 @@ bool llama_context::update_cache_copies() { } llama_context::llama_context(const llama_model & model) - : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) { + : model(model), t_start_us(model.t_start_us) , t_load_us(model.t_load_us) { const auto & hparams = model.hparams; cache_copies.resize(2*hparams.n_layer); } @@ -4311,7 +4311,6 @@ struct llama_context * llama_new_context_with_model( ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; - ctx->sampling.rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; // build worst-case graph for encoder if a model contains encoder ctx->is_encoding = llama_model_has_encoder(model); @@ -5866,7 +5865,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da data_ctx.write_model_info(ctx); - data_ctx.write_rng(ctx->sampling.rng); + //data_ctx.write_rng(ctx->sampling.rng); // copy outputs data_ctx.write_output_ids(ctx); @@ -5905,9 +5904,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da data_ctx.read_model_info(ctx); - // set rng - data_ctx.read_rng(ctx->sampling.rng); - // set outputs data_ctx.read_output_ids(ctx); data_ctx.read_logits(ctx); @@ -7173,206 +7169,6 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) { } return (int32_t) LLM_CHAT_TEMPLATES.size(); } -// -// grammar -// - -struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { - return llama_grammar_init_impl(rules, n_rules, start_rule_index); -} - -void llama_grammar_free(struct llama_grammar * grammar) { - llama_grammar_free_impl(grammar); -} -// -//void llama_grammar_init_lazy(struct llama_sampler* smpl) { -// -// if (!grammar) { -// return; -// } -// std::vector trigger_patterns_c; -// trigger_patterns_c.reserve(grammar.grammar->trigger_patterns.size()); -// for (auto& trigger_pattern : grammar.grammar->trigger_patterns) { -// trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); -// } -// //auto* grammar_new = llama_grammar_init_impl(grammar->vocab, "", "root", -// // grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), -// // grammar->trigger_tokens.data(), grammar->trigger_tokens.size()); -// -//} - - -struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - return llama_grammar_copy_impl(grammar); -} - -void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates) { - llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates); -} - -void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar) { - llama_grammar_sample(grammar, ctx, candidates); -} - -void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token) { - llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token); -} - -// -// sampling -// - -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - llama_set_rng_seed_impl(&ctx->sampling, seed); -} - -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates); -} - -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep); -} - -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); -} - -void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); -} - -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep); -} - -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); -} - -void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { - llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val); -} - -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp); -} - -void llama_sample_xtc(struct llama_context * ctx, llama_token_data_array * candidates_p, - float probability, float threshold, size_t min_keep) { - llama_sample_xtc_impl(ctx ? &ctx->sampling : nullptr, candidates_p, probability, threshold, min_keep); -} - -void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array * candidates_p, float top_n_sigma) { - llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma); -} - - -void llama_sample_dry([[maybe_unused]] struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p) { - llama_sampler_dry_apply(smpl, candidates_p); -} - -void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); -} - -void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale) { - llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale); -} - -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu); -} - -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu); -} - -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates); -} - -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng); -} - -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng); -} - -int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { - static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; - if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { - return strlen(split_path); - } - return 0; -} - - -struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* vocab, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { - return llama_sampler_init_dry_impl(*vocab, vocab->n_tokens(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers); -} - -void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) { - if (!smpl) { - return; - } - smpl->last_tokens.clear(); - smpl->dry_repeat_count.clear(); - smpl->dry_max_token_repeat.clear(); -} - -void llama_sampler_dry_free(struct llama_sampler_dry* smpl) { - delete smpl; -} - -struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) { - // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying - auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0); - // Copy the state, including the processed breakers - { - auto* result_ctx = smpl; - result_ctx->dry_processed_breakers = smpl->dry_processed_breakers; - result_ctx->dry_repeat_count = smpl->dry_repeat_count; - result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat; - result_ctx->last_tokens = smpl->last_tokens; - } - - return result; -} - -void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) { - if (!smpl) { - return; - } - if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) { - return; - } - smpl->last_tokens.push_back(token); -} int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) { std::string str_split_path(split_path); @@ -7395,11 +7191,11 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_end_ms =*/ 1.00 * ggml_time_ms(), /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us, + /*.t_sample_ms =*/ 0.0, //1e-3 * ctx->sampling.t_sample_us, /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - /*.n_sample =*/ std::max(1, ctx->sampling.n_sample), + /*.n_sample =*/ 0, //std::max(1, ctx->sampling.n_sample), /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_eval =*/ std::max(1, ctx->n_eval), }; @@ -7426,7 +7222,7 @@ void llama_reset_timings(struct llama_context * ctx) { ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; - ctx->sampling.reset_timings(); + //ctx->sampling.reset_timings(); } const char * llama_print_system_info(void) { @@ -7468,21 +7264,21 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { 1.0e-3 * ctx->t_eval_us / ctx->n_eval); fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); + //fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", + // 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); + //fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us); + //fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us); fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); - fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); + //fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", + // 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); } // For internal test use