From b86d8024a556d5d2d8029f9acf639a8e47f559d7 Mon Sep 17 00:00:00 2001 From: dungquixote42 <62397442+dungquixote42@users.noreply.github.com> Date: Tue, 3 Feb 2026 00:36:12 -0500 Subject: [PATCH 1/8] Adaptive p: history update fix + temp as flag (#1213) * adaptive_p: fix history update + use current probability for high temp * adaptive_p: fix history update bug, update with current probability if temp is high * replace temp-as-signal with server argument * adaptive_p: rename ema_w_cur_p to updt_w_cur * delete test code --- common/common.cpp | 6 ++++++ common/sampling.cpp | 2 +- common/sampling.h | 1 + include/llama.h | 1 + src/llama-sampling.cpp | 18 +++++++++++++++--- src/llama-sampling.h | 5 ++++- src/llama.cpp | 4 ++-- 7 files changed, 30 insertions(+), 7 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 3192fd37..802fe0df 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -940,6 +940,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.adaptive_decay = std::stof(argv[i]); return true; } + if (arg == "--adaptive-updt-w-cur") { + sparams.adaptive_updt_w_cur = true; + return true; + } if (arg == "--spec-replace") { CHECK_ARG std::string target = argv[i]; @@ -2231,6 +2235,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma}); options.push_back({ "*", " --adaptive-target", "adaptive-p sampling: (default: %.2f, <0.0 = disabled)", (double)sparams.adaptive_target}); options.push_back({ "*", " --adaptive-decay", "adaptive-p sampling: (default: %.2f)", (double)sparams.adaptive_decay}); + options.push_back({ "*", " --adaptive-updt-w-cur", "adaptive-p sampling: (default: %s)", sparams.adaptive_updt_w_cur ? "true" : "false"}); options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); @@ -4227,6 +4232,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "adaptive_target: %f # default: -1.0\n", sparams.adaptive_target); fprintf(stream, "adaptive_decay: %f # default: 0.9\n", sparams.adaptive_decay); + fprintf(stream, "adaptive_updt_w_cur: %s # default: false\n", sparams.adaptive_updt_w_cur ? "true" : "false"); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? 
"true" : "false"); } diff --git a/common/sampling.cpp b/common/sampling.cpp index ba8d3f67..cfec27fd 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -120,7 +120,7 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo { GGML_ASSERT(vocab); auto n_vocab = llama_vocab_n_tokens(vocab); - result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, result->rng()); + result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, params.adaptive_updt_w_cur, result->rng()); break; } default: diff --git a/common/sampling.h b/common/sampling.h index a5420fa7..f2d1b1bf 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -69,6 +69,7 @@ typedef struct llama_sampling_params { float top_n_sigma = 0.0f; // top-n-sigma float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) float adaptive_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation) + bool adaptive_updt_w_cur = false; // update state with current probability bool penalize_nl = false; // consider newlines as a repeatable token uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context diff --git a/include/llama.h b/include/llama.h index 72cb9edd..a0a8e3ac 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1387,6 +1387,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns( LLAMA_API struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float target, const float decay, + const bool updt_w_cur, const uint32_t seed); void llama_prep_adaptive_p(struct llama_context * ctx, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5e26eb20..bb94af7a 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1075,9 +1075,13 @@ llama_token llama_sample_token_adaptive_p_impl( GGML_ASSERT(iter != ctx->cum_probs.end()); const size_t idx = std::distance(ctx->cum_probs.begin(), iter); llama_token id = candidates->data[idx].id; - GGML_ASSERT(id < int(ctx->orig_prob.size())); - if (auto update_prob = ctx->orig_prob[id]; update_prob > 0) { + + // update history + const float update_prob = ctx->updt_w_cur + ? 
candidates->data[idx].p / ctx->cum_cur_p + : ctx->orig_prob[id] / ctx->cum_orig_prob; + if (update_prob > 0) { ctx->weighted_sum = ctx->decay * ctx->weighted_sum + update_prob; ctx->total_weight = ctx->decay * ctx->total_weight + 1.0f; } @@ -1111,6 +1115,7 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_ candidates->data[i].p = prob; cum_sum += prob; } + adapt_p_ctx->cum_cur_p = cum_sum; // compute adapted target probability const float target = std::clamp(adapt_p_ctx->target, 0.0f, 1.0f); @@ -1146,6 +1151,10 @@ void llama_prep_adaptive_p_impl( struct llama_sampling * smpl, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx) { + if (adapt_p_ctx->updt_w_cur) { + // update with current probability, original not needed + return; + } constexpr float kDelta = 30.0f; //16.6f; auto t_start = ggml_time_us(); auto & orig_prob = adapt_p_ctx->orig_prob; @@ -1169,17 +1178,20 @@ void llama_prep_adaptive_p_impl( struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab, const float target, const float decay, + const bool updt_w_cur, const uint32_t seed) { GGML_ASSERT(n_vocab > 0); const float clamped_decay = std::clamp(decay, 0.0f, 0.99f); auto result = new llama_sampler_adaptive_p { /* .target = */ target, /* .decay = */ clamped_decay, + /* .updt_w_cur = */ updt_w_cur, /* .rng = */ std::mt19937(seed), /* .weighted_sum = */ target / (1.0f - clamped_decay), /* .total_weight = */ 1.0f / (1.0f - clamped_decay), /* .orig_prob = */ {}, - /* .cum_orig_prob = */ 0.0f, + /* .cum_orig_prob = */ 1.0f, + /* .cum_cur_p = */ 1.0f, /* .max_xform_logit = */ -INFINITY, /* .cum_probs = */ {}, }; diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 8ebbfb49..0d7e72d2 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -68,15 +68,17 @@ void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_ar struct llama_sampler_adaptive_p { const float target; // target probability (0.0 - 1.0; negative = disabled) const float decay; // EMA decay; history ≈ 1/(1-decay) tokens (0.0 - 0.99) + const bool updt_w_cur; // false=original, true=current std::mt19937 rng; // RNG float weighted_sum; // sum(p_n * decay^N) float total_weight; // sum(decay^i), converges to 1/(1-decay) // first referenced in prep - std::vector orig_prob; // for storing the original proibabilities + std::vector orig_prob; // for storing the original proibabilities float cum_orig_prob; // for normalizing orig_prob in sample_token // first referenced in sample + float cum_cur_p; // cumulative sum of current probabilities float max_xform_logit; // maximum logit found during transform // first referenced in sample_token @@ -86,6 +88,7 @@ struct llama_sampler_adaptive_p { struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab, const float target, const float decay, + const bool updt_w_cur, const uint32_t seed); void llama_prep_adaptive_p_impl( diff --git a/src/llama.cpp b/src/llama.cpp index f6b4eba9..4f551745 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7814,8 +7814,8 @@ void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) } -struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float target, const float decay, const uint32_t seed) { - return llama_init_adaptive_p_impl(n_vocab, target, decay, seed); +struct llama_sampler_adaptive_p * llama_init_adaptive_p(int n_vocab, const float target, const float decay, const bool updt_w_cur, const uint32_t seed) { + return 
llama_init_adaptive_p_impl(n_vocab, target, decay, updt_w_cur, seed); } From 8ba7e2b40cfd2a416484abcc62f38e9f0692fedb Mon Sep 17 00:00:00 2001 From: saood06 Date: Mon, 2 Feb 2026 23:39:45 -0600 Subject: [PATCH 2/8] Add support for Seed-OSS (#1218) * it compiles * Fix constants.py --- convert_hf_to_gguf.py | 3 + gguf-py/gguf/constants.py | 170 +++++++++++++++++++++--------------- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-build-context.cpp | 100 +++++++++++++++++++++ src/llama-build-context.h | 2 + src/llama-hparams.cpp | 9 +- src/llama-load-tensors.cpp | 47 ++++++++++ src/llama-model.cpp | 17 ++++ src/llama.cpp | 13 +++ 10 files changed, 291 insertions(+), 72 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 392d0125..0a3341fa 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4196,6 +4196,9 @@ class SmolLM3Model(LlamaModel): chat_template = tokenizer.chat_template.replace("[:]", "") self.gguf_writer.add_chat_template(chat_template) +@Model.register("SeedOssForCausalLM") +class SeedOssModel(Model): + model_arch = gguf.MODEL_ARCH.SEED_OSS @Model.register("Dots1ForCausalLM") class Dots1Model(Qwen2MoeModel): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 97992af6..d4af03c0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -250,6 +250,7 @@ class MODEL_ARCH(IntEnum): BAILINGMOE2 = auto() MINIMAXM2 = auto() SMOLLM3 = auto() + SEED_OSS = auto() class MODEL_TENSOR(IntEnum): TOKEN_EMBD = auto() @@ -398,6 +399,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.BAILINGMOE2: "bailingmoe2", MODEL_ARCH.MINIMAXM2: "minimax-m2", MODEL_ARCH.SMOLLM3: "smollm3", + MODEL_ARCH.SEED_OSS: "seed_oss", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1362,6 +1364,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.SEED_OSS: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + ], # TODO } @@ -1537,78 +1553,90 @@ class ExpertGatingFuncType(IntEnum): # from llama_ftype in llama.h # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. 
class LlamaFileType(IntEnum): - ALL_F32 = 0 - MOSTLY_F16 = 1 #except 1d tensors - MOSTLY_Q4_0 = 2 #except 1d tensors - MOSTLY_Q4_1 = 3 #except 1d tensors - MOSTLY_Q4_1_SOME_F16 = 4 #tok_embeddings.weight and output.weight are F16 - MOSTLY_Q8_0 = 7 #except 1d tensors - MOSTLY_Q5_0 = 8 #except 1d tensors - MOSTLY_Q5_1 = 9 #except 1d tensors - MOSTLY_Q2_K = 10 #except 1d tensors - MOSTLY_Q3_K = 11 #except 1d tensors - MOSTLY_Q4_K = 12 #except 1d tensors - MOSTLY_Q5_K = 13 #except 1d tensors - MOSTLY_Q6_K = 14 #except 1d tensors - MOSTLY_IQ2_XXS = 15 #except 1d tensors - MOSTLY_IQ2_XS = 16 #except 1d tensors - MOSTLY_IQ3_XXS = 17 #except 1d tensors - MOSTLY_IQ1_S = 18 #except 1d tensors - MOSTLY_IQ4_NL = 19 #except 1d tensors - MOSTLY_IQ3_S = 20 #except 1d tensors - MOSTLY_IQ2_S = 21 #except 1d tensors - MOSTLY_IQ4_XS = 22 #except 1d tensors - MOSTLY_IQ1_M = 23 #except 1d tensors - MOSTLY_BF16 = 24 #except 1d tensors - MOSTLY_MXFP4 = 25 #except 1d tensors - MOSTLY_Q4_0_4_4 = 26 #except 1d tensors - MOSTLY_Q4_0_4_8 = 27 #except 1d tensors - MOSTLY_Q4_0_8_8 = 28 #except 1d tensors - MOSTLY_Q6_0 = 127 #except 1d tensors - MOSTLY_IQ1_BN = 128 #except 1d tensors - MOSTLY_IQ2_BN = 129 #except 1d tensors - MOSTLY_IQ2_K = 130 #except 1d tensors - MOSTLY_IQ3_K = 131 #except 1d tensors - MOSTLY_IQ4_K = 132 #except 1d tensors - MOSTLY_IQ5_K = 133 #except 1d tensors - MOSTLY_IQ6_K = 134 #except 1d tensors - MOSTLY_IQ4_KS = 137 #except 1d tensors - MOSTLY_IQ2_KS = 138 #except 1d tensors - MOSTLY_IQ4_KSS = 139 #except 1d tensors - MOSTLY_Q8_KV = 140 #except 1d tensors - MOSTLY_IQ5_KS = 141 #except 1d tensors - MOSTLY_IQ2_KT = 142 #except 1d tensors - MOSTLY_IQ3_KT = 143 #except 1d tensors - MOSTLY_IQ4_KT = 144 #except 1d tensors - MOSTLY_Q4_0_R8 = 202 #except 1d tensors - MOSTLY_Q8_0_R8 = 207 #except 1d tensors - MOSTLY_Q5_0_R4 = 208 #except 1d tensors - MOSTLY_Q2_K_R4 = 210 #except 1d tensors - MOSTLY_Q3_K_R4 = 211 #except 1d tensors - MOSTLY_Q4_K_R4 = 212 #except 1d tensors - MOSTLY_Q5_K_R4 = 213 #except 1d tensors - MOSTLY_Q6_K_R4 = 214 #except 1d tensors - MOSTLY_IQ2_XXS_R4 = 215 #except 1d tensors - MOSTLY_IQ2_XS_R4 = 216 #except 1d tensors - MOSTLY_IQ3_XXS_R4 = 217 #except 1d tensors - MOSTLY_IQ1_S_R4 = 218 #except 1d tensors - MOSTLY_IQ4_NL_R4 = 219 #except 1d tensors - MOSTLY_IQ3_S_R4 = 220 #except 1d tensors - MOSTLY_IQ2_S_R4 = 221 #except 1d tensors - MOSTLY_IQ4_XS_R8 = 222 #except 1d tensors - MOSTLY_IQ1_M_R4 = 223 #except 1d tensors - MOSTLY_BF16_R16 = 224 #except 1d tensors - MOSTLY_Q6_0_R4 = 227 #except 1d tensors - MOSTLY_IQ2_BN_R4 = 329 #except 1d tensors - MOSTLY_IQ2_K_R4 = 330 #except 1d tensors - MOSTLY_IQ3_K_R4 = 331 #except 1d tensors - MOSTLY_IQ4_K_R4 = 332 #except 1d tensors - MOSTLY_IQ5_K_R4 = 333 #except 1d tensors - MOSTLY_IQ4_KS_R4 = 337 #except 1d tensors - MOSTLY_IQ5_KS_R4 = 341 #except 1d tensors - MOSTLY_Q8_KV_R8 = 398 #except 1d tensors - MOSTLY_Q8_K_R8 = 399 #except 1d tensors + ALL_F32 = 0 + MOSTLY_F16 = 1 #except 1d tensors + MOSTLY_Q4_0 = 2 #except 1d tensors + MOSTLY_Q4_1 = 3 #except 1d tensors + MOSTLY_Q8_0 = 7 #except 1d tensors + MOSTLY_Q5_0 = 8 #except 1d tensors + MOSTLY_Q5_1 = 9 #except 1d tensors + MOSTLY_Q2_K = 10 #except 1d tensors + MOSTLY_Q3_K_S = 11 #except 1d tensors + MOSTLY_Q3_K_M = 12 #except 1d tensors + MOSTLY_Q3_K_L = 13 #except 1d tensors + MOSTLY_Q4_K_S = 14 #except 1d tensors + MOSTLY_Q4_K_M = 15 #except 1d tensors + MOSTLY_Q5_K_S = 16 #except 1d tensors + MOSTLY_Q5_K_M = 17 #except 1d tensors + MOSTLY_Q6_K = 18 #except 1d tensors + MOSTLY_IQ2_XXS = 
19 #except 1d tensors + MOSTLY_IQ2_XS = 20 #except 1d tensors + MOSTLY_Q2_K_S = 21 #except 1d tensors + MOSTLY_IQ3_XS = 22 #except 1d tensors + MOSTLY_IQ3_XXS = 23 #except 1d tensors + MOSTLY_IQ1_S = 24 #except 1d tensors + MOSTLY_IQ4_NL = 25 #except 1d tensors + MOSTLY_IQ3_S = 26 #except 1d tensors + MOSTLY_IQ3_M = 27 #except 1d tensors + MOSTLY_IQ2_S = 28 #except 1d tensors + MOSTLY_IQ2_M = 29 #except 1d tensors + MOSTLY_IQ4_XS = 30 #except 1d tensors + MOSTLY_IQ1_M = 31 #except 1d tensors + MOSTLY_BF16 = 32 #except 1d tensors + MOSTLY_Q4_0_4_4 = 33 #except 1d tensors + MOSTLY_Q4_0_4_8 = 34 #except 1d tensors + MOSTLY_Q4_0_8_8 = 35 #except 1d tensors + MOSTLY_MXFP4 = 38 #except 1d tensors, 38 to be compatible with mainline + MOSTLY_Q6_0 = 135 #except 1d tensors + MOSTLY_IQ1_BN = 136 #except 1d tensors + MOSTLY_IQ2_BN = 137 #except 1d tensors + MOSTLY_IQ2_K = 138 #except 1d tensors + MOSTLY_IQ3_K = 139 #except 1d tensors + MOSTLY_IQ4_K = 140 #except 1d tensors + MOSTLY_IQ5_K = 141 #except 1d tensors + MOSTLY_IQ6_K = 142 #except 1d tensors + MOSTLY_IQ4_KS = 145 #except 1d tensors + MOSTLY_IQ3_KL = 146 #except 1d tensors + MOSTLY_IQ2_KS = 147 #except 1d tensors + MOSTLY_IQ4_KSS = 148 #except 1d tensors + MOSTLY_Q8_KV = 149 #except 1d tensors + MOSTLY_IQ5_KS = 150 #except 1d tensors + MOSTLY_IQ2_KT = 151 #except 1d tensors + MOSTLY_IQ3_KT = 152 #except 1d tensors + MOSTLY_IQ4_KT = 153 #except 1d tensors + MOSTLY_IQ3_KS = 154 #except 1d tensors + MOSTLY_IQ2_KL = 155 #except 1d tensors + MOSTLY_IQ1_KT = 156 #except 1d tensors + + MOSTLY_Q4_0_R8 = 202 #except 1d tensors + MOSTLY_Q8_0_R8 = 207 #except 1d tensors + MOSTLY_Q5_0_R4 = 208 #except 1d tensors + MOSTLY_Q2_K_R4 = 210 #except 1d tensors + MOSTLY_Q3_K_R4 = 211 #except 1d tensors + MOSTLY_Q4_K_R4 = 214 #except 1d tensors + MOSTLY_Q5_K_R4 = 216 #except 1d tensors + MOSTLY_Q6_K_R4 = 218 #except 1d tensors + MOSTLY_IQ2_XXS_R4 = 219 #except 1d tensors + MOSTLY_IQ2_XS_R4 = 220 #except 1d tensors + MOSTLY_IQ3_XXS_R4 = 223 #except 1d tensors + MOSTLY_IQ1_S_R4 = 224 #except 1d tensors + MOSTLY_IQ4_NL_R4 = 225 #except 1d tensors + MOSTLY_IQ3_S_R4 = 226 #except 1d tensors + MOSTLY_IQ2_M_R4 = 229 #except 1d tensors + MOSTLY_IQ4_XS_R8 = 230 #except 1d tensors + MOSTLY_IQ1_M_R4 = 231 #except 1d tensors + MOSTLY_Q6_0_R4 = 335 #except 1d tensors + MOSTLY_BF16_R16 = 232 #except 1d tensors + MOSTLY_IQ2_BN_R4 = 337 #except 1d tensors + MOSTLY_IQ2_K_R4 = 338 #except 1d tensors + MOSTLY_IQ3_K_R4 = 339 #except 1d tensors + MOSTLY_IQ4_K_R4 = 340 #except 1d tensors + MOSTLY_IQ5_K_R4 = 341 #except 1d tensors + MOSTLY_IQ4_KS_R4 = 345 #except 1d tensors + MOSTLY_IQ5_KS_R4 = 350 #except 1d tensors + MOSTLY_Q8_KV_R8 = 398 #except 1d tensors + MOSTLY_Q8_K_R8 = 399 #except 1d tensors GUESSED = 1024 # not specified in the model file diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index cc8fb624..3268e219 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -70,6 +70,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; diff --git a/src/llama-arch.h b/src/llama-arch.h index efcbb577..b4b7df71 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_MISTRAL3, LLM_ARCH_MIMO2, + LLM_ARCH_SEED_OSS, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 097b1bc5..f44b4c2c 100644 --- 
a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -3506,6 +3506,102 @@ ggml_cgraph * llm_build_context::build_stablelm() { return gf; } +ggml_cgraph * llm_build_context::build_seedoss() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq, + model.layers[il].wk, model.layers[il].bk, + model.layers[il].wv, model.layers[il].bv, 0.f, il); + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward forward + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_post_norm", il); + cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_post_norm, ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} + + ggml_cgraph * llm_build_context::build_qwen() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -9299,6 +9395,10 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_mimo2(); } break; + case LLM_ARCH_SEED_OSS: + { + result = llm.build_seedoss(); + } break; 
default: GGML_ABORT("fatal error"); } diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 65b8e9f8..1c60028e 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -278,6 +278,8 @@ struct llm_build_context { ggml_cgraph * build_mimo2(); + ggml_cgraph * build_seedoss(); + // static ggml_tensor * llm_build_lora_mm(llama_context & lctx, ggml_context * ctx0, ggml_tensor * w, ggml_tensor * cur); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 24d07a9a..e5d9f764 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -1107,7 +1107,14 @@ void llm_load_hparams( } } break; - + case LLM_ARCH_SEED_OSS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 64: model.type = e_model::MODEL_36B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 159a37bc..d66333ea 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -139,6 +139,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool create_mimo2_tensors(const LLM_TN & tn); + bool create_seedoss_tensors(const LLM_TN & tn); + llama_model_loader & ml; llama_model & model; @@ -981,6 +983,49 @@ bool create_tensors_helper::create_stablelm_tensors(const LLM_TN & tn) { return use_mmap_buffer; } +bool create_tensors_helper::create_seedoss_tensors(const LLM_TN & tn) { + LOADING_PRELUDE + + const int64_t n_qo_dim = n_head * n_embd_head_k; + const int64_t n_kv_dim = n_head_kv * n_embd_head_k; + + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}); + + // optional bias tensors + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); + + create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split); + } + return use_mmap_buffer; +} + bool create_tensors_helper::create_qwen_tensors(const LLM_TN & tn) { 
LOADING_PRELUDE create_embd_output(tn, n_embd, n_vocab); @@ -3058,6 +3103,8 @@ bool create_tensors_helper::create_tensors() { use_mmap_buffer = create_smollm3_tensors(tn); break; case LLM_ARCH_MIMO2: use_mmap_buffer = create_mimo2_tensors(tn); break; + case LLM_ARCH_SEED_OSS: + use_mmap_buffer = create_seedoss_tensors(tn); break; default: throw std::runtime_error("unknown architecture"); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index be4b9cef..4fc17f9c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1317,6 +1317,23 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_SEED_OSS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama.cpp b/src/llama.cpp index 4f551745..a78fef0d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -224,6 +224,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_BAILING, LLM_CHAT_TEMPLATE_BAILING_THINK, LLM_CHAT_TEMPLATE_BAILING2, + LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -269,6 +270,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK }, { "bailing2", LLM_CHAT_TEMPLATE_BAILING2 }, + { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, }; @@ -5047,6 +5049,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_MINIMAX_M2: case LLM_ARCH_MIMO2: + case LLM_ARCH_SEED_OSS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -7004,6 +7007,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GROK_2; } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_SEED_OSS; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -7533,6 +7538,14 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "Assistant:"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) { + for (auto message: chat) { + std::string role(message->role); + ss << "" << role << "\n" << (role == "assistant" ? 
trim(message->content) : message->content) << ""; + } + if (add_ass) { + ss << "assistant\n"; + } } else { // template not supported return -1; From 7e8d4440338ddfb6b53b500549090e1f30be55f4 Mon Sep 17 00:00:00 2001 From: firecoperana <18252262+firecoperana@users.noreply.github.com> Date: Mon, 2 Feb 2026 23:57:17 -0600 Subject: [PATCH 3/8] llama : add token matching support to llama-grammar (#1220) * llama : add token matching support to llama-grammar llama : add token matching support to llama-grammar (#17816) common/grammar : replace problematic backtracking regex `[\s\S]*` (#18342) * disable tests and fix warnings --------- Co-authored-by: firecoperana --- common/CMakeLists.txt | 2 - common/chat.cpp | 8 +- common/grammar-parser.cpp | 542 --------------------- common/grammar-parser.h | 30 -- common/regex-partial.cpp | 26 +- common/sampling.cpp | 108 +--- common/sampling.h | 10 +- examples/gbnf-validator/gbnf-validator.cpp | 21 +- examples/infill/infill.cpp | 2 +- examples/speculative/speculative.cpp | 4 +- grammars/README.md | 24 + include/llama.h | 73 +-- src/llama-grammar.cpp | 408 ++++++++++++---- src/llama-grammar.h | 96 +++- src/llama-sampling.cpp | 2 +- src/llama.cpp | 27 +- tests/CMakeLists.txt | 22 +- tests/test-grammar-integration.cpp | 111 ++++- tests/test-grammar-parser.cpp | 14 + tests/test-llama-grammar.cpp | 2 +- tests/test-regex-partial.cpp | 28 +- 21 files changed, 644 insertions(+), 916 deletions(-) delete mode 100644 common/grammar-parser.cpp delete mode 100644 common/grammar-parser.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 4d3f462b..e1992589 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -63,8 +63,6 @@ add_library(${TARGET} STATIC sampling.cpp console.h console.cpp - grammar-parser.h - grammar-parser.cpp json-partial.h json-partial.cpp llguidance.cpp diff --git a/common/chat.cpp b/common/chat.cpp index 850340b8..34a48ea5 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1519,7 +1519,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp // Trigger on tool calls that appear in the commentary channel data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|channel\\|>(commentary|analysis) to" + "<\\|channel\\|>(?:commentary|analysis) to" }); // Trigger tool calls that appear in the role section, either at the @@ -1850,17 +1850,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, // If thinking_forced_open, then we capture the tag in the grammar, // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( + std::string(data.thinking_forced_open ? "(\\s*)" : "") + ( "\\s*(" "(?:" "||||)?" 
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" ")" - ")[\\s\\S]*" + ")" ), }); data.preserved_tokens = { diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp deleted file mode 100644 index d7b0fcba..00000000 --- a/common/grammar-parser.cpp +++ /dev/null @@ -1,542 +0,0 @@ -#include "grammar-parser.h" -#include -#include -#include -#include -#include -#include - -namespace grammar_parser { - // NOTE: assumes valid utf8 (but checks for overrun) - // copied from llama.cpp - static std::pair decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = src + len; // may overrun! - const char * pos = src + 1; - for ( ; pos < end && *pos; pos++) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - } - return std::make_pair(value, pos); - } - - static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - auto result = state.symbol_ids.emplace(std::string(src, len), next_id); - return result.first->second; - } - - static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; - return next_id; - } - - static void add_rule( - parse_state & state, - uint32_t rule_id, - const std::vector & rule) { - if (state.rules.size() <= rule_id) { - state.rules.resize(rule_id + 1); - } - state.rules[rule_id] = rule; - } - - static bool is_digit_char(char c) { - return '0' <= c && c <= '9'; - } - - static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); - } - - static std::pair parse_hex(const char * src, int size) { - const char * pos = src; - const char * end = src + size; - uint32_t value = 0; - for ( ; pos < end && *pos; pos++) { - value <<= 4; - char c = *pos; - if ('a' <= c && c <= 'f') { - value += c - 'a' + 10; - } else if ('A' <= c && c <= 'F') { - value += c - 'A' + 10; - } else if ('0' <= c && c <= '9') { - value += c - '0'; - } else { - break; - } - } - if (pos != end) { - throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); - } - return std::make_pair(value, pos); - } - - static const char * parse_space(const char * src, bool newline_ok) { - const char * pos = src; - while (*pos == ' ' || *pos == '\t' || *pos == '#' || - (newline_ok && (*pos == '\r' || *pos == '\n'))) { - if (*pos == '#') { - while (*pos && *pos != '\r' && *pos != '\n') { - pos++; - } - } else { - pos++; - } - } - return pos; - } - - static const char * parse_name(const char * src) { - const char * pos = src; - while (is_word_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting name at ") + src); - } - return pos; - } - - static const char * parse_int(const char * src) { - const char * pos = src; - while (is_digit_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting integer at ") + src); - } - return pos; - } - - static std::pair parse_char(const char * src) { - if (*src == '\\') { - switch (src[1]) { - case 'x': return parse_hex(src + 2, 2); - case 'u': return parse_hex(src + 2, 4); - case 'U': return 
parse_hex(src + 2, 8); - case 't': return std::make_pair('\t', src + 2); - case 'r': return std::make_pair('\r', src + 2); - case 'n': return std::make_pair('\n', src + 2); - case '\\': - case '"': - case '[': - case ']': - return std::make_pair(src[1], src + 2); - default: - throw std::runtime_error(std::string("unknown escape at ") + src); - } - } else if (*src) { - return decode_utf8(src); - } - throw std::runtime_error("unexpected end of input"); - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested); - - static const char * parse_sequence( - parse_state & state, - const char * src, - const std::string & rule_name, - std::vector & out_elements, - bool is_nested) { - size_t last_sym_start = out_elements.size(); - const char * pos = src; - - auto handle_repetitions = [&](int min_times, int max_times) { - - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | - - std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); - if (min_times == 0) { - out_elements.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - std::vector rec_rule(previous_elements); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(previous_elements.size()); - uint32_t rec_rule_id = generate_symbol_id(state, rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = out_elements.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = out_elements.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < out_elements.size() - ? 
LLAMA_GRETYPE_CHAR_ALT - : start_type; - - out_elements.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = out_elements.size(); - // output reference to synthesized rule - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else { - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; - } - } - return pos; - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested) { - std::vector rule; - const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); - while (*pos == '|') { - rule.push_back({LLAMA_GRETYPE_ALT, 0}); - pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, rule, is_nested); - } - rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rule_id, rule); - return pos; - } - - static const char * parse_rule(parse_state & state, const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(state, src, name_len); - const std::string name(src, name_len); - - if (!(pos[0] == ':' 
&& pos[1] == ':' && pos[2] == '=')) { - throw std::runtime_error(std::string("expecting ::= at ") + pos); - } - pos = parse_space(pos + 3, true); - - pos = parse_alternates(state, pos, name, rule_id, false); - - if (*pos == '\r') { - pos += pos[1] == '\n' ? 2 : 1; - } else if (*pos == '\n') { - pos++; - } else if (*pos) { - throw std::runtime_error(std::string("expecting newline or end at ") + pos); - } - return parse_space(pos, true); - } - - parse_state parse(const char * src) { - try { - parse_state state; - const char * pos = parse_space(src, true); - while (*pos) { - pos = parse_rule(state, pos); - } - // Validate the state to ensure that all rules are defined - for (const auto & rule : state.rules) { - if (rule.empty()) { - throw std::runtime_error("Undefined rule"); - } - for (const auto & elem : rule) { - if (elem.type == LLAMA_GRETYPE_RULE_REF) { - // Ensure that the rule at that location exists - if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { - // Get the name of the rule that is missing - for (const auto & kv : state.symbol_ids) { - if (kv.second == elem.value) { - throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); - } - } - } - } - } - } - state.success = true; - return state; - } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); - parse_state state; - state.success = false; - return state; - } - } - - static void print_grammar_char(FILE * file, uint32_t c) { - if (0x20 <= c && c <= 0x7f) { - fprintf(file, "%c", static_cast(c)); - } else { - // cop out of encoding UTF-8 - fprintf(file, "", c); - } - } - - static bool is_char_element(llama_grammar_element elem) { - switch (elem.type) { - case LLAMA_GRETYPE_CHAR: return true; - case LLAMA_GRETYPE_CHAR_NOT: return true; - case LLAMA_GRETYPE_CHAR_ALT: return true; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; - case LLAMA_GRETYPE_CHAR_ANY: return true; - default: return false; - } - } - - static void print_rule_binary(FILE * file, const std::vector & rule) { - for (auto elem : rule) { - switch (elem.type) { - case LLAMA_GRETYPE_END: fprintf(file, "END"); break; - case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; - case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; - case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; - case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; - case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; - } - switch (elem.type) { - case LLAMA_GRETYPE_END: - case LLAMA_GRETYPE_ALT: - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "(%u) ", elem.value); - break; - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "(\""); - print_grammar_char(file, elem.value); - fprintf(file, "\") "); - break; - } - } - fprintf(file, "\n"); - } - - static void print_rule( - FILE * file, - uint32_t rule_id, - const std::vector & rule, - const std::map & symbol_id_names) { - if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { - throw std::runtime_error( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); - } - fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - for (size_t i = 0, end = rule.size() - 1; i < end; i++) { - llama_grammar_element elem = rule[i]; 
- switch (elem.type) { - case LLAMA_GRETYPE_END: - throw std::runtime_error( - "unexpected end of rule: " + std::to_string(rule_id) + "," + - std::to_string(i)); - case LLAMA_GRETYPE_ALT: - fprintf(file, "| "); - break; - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - break; - case LLAMA_GRETYPE_CHAR: - fprintf(file, "["); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_NOT: - fprintf(file, "[^"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - fprintf(file, "-"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ALT: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "."); - break; - } - if (is_char_element(elem)) { - switch (rule[i + 1].type) { - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ANY: - break; - default: - fprintf(file, "] "); - } - } - } - fprintf(file, "\n"); - } - - void print_grammar(FILE * file, const parse_state & state) { - try { - std::map symbol_id_names; - for (const auto & kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; - } - for (size_t i = 0, end = state.rules.size(); i < end; i++) { - // fprintf(file, "%zu: ", i); - // print_rule_binary(file, state.rules[i]); - print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); - // fprintf(file, "\n"); - } - } catch (const std::exception & err) { - fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); - } - } - - std::vector parse_state::c_rules() { - std::vector ret; - ret.reserve(rules.size()); - for (const auto & rule : rules) { - ret.push_back(rule.data()); - } - return ret; - } -} diff --git a/common/grammar-parser.h b/common/grammar-parser.h deleted file mode 100644 index 3939bc30..00000000 --- a/common/grammar-parser.h +++ /dev/null @@ -1,30 +0,0 @@ -// Implements a parser for an extended Backus-Naur form (BNF), producing the -// binary context-free grammar format specified by llama.h. Supports character -// ranges, grouping, and repetition operators. 
As an example, a grammar for -// arithmetic might look like: -// -// root ::= expr -// expr ::= term ([-+*/] term)* -// term ::= num | "(" space expr ")" space -// num ::= [0-9]+ space -// space ::= [ \t\n]* - -#pragma once -#include "llama.h" -#include -#include -#include -#include - -namespace grammar_parser { - struct parse_state { - std::map symbol_ids; - std::vector> rules; - - std::vector c_rules(); - bool success; - }; - - parse_state parse(const char * src); - void print_grammar(FILE * file, const parse_state & state); -} diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp index 4bff6b66..e667a209 100644 --- a/common/regex-partial.cpp +++ b/common/regex-partial.cpp @@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b return res; } std::match_results srmatch; - if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { + if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) { auto group = srmatch[1].str(); if (group.length() != 0) { auto it = srmatch[1].second.base(); @@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b to see if a string ends with a partial regex match, but but it's not in std::regex yet. Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. - - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* - - /a|b/ -> (a|b).* + - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a) + - /a|b/ -> ^(a|b) - /a*?/ -> error, could match "" - - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) - - /.*?ab/ -> ((?:b)?a).* (merge .*) - - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) - - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* - - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* - - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* + - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager) + - /.*?ab/ -> ^((?:b)?a) (omit .*) + - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches) + - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a) + - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a) + - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a) - The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern - (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) + The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern. + All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored. */ std::string regex_to_reversed_partial_regex(const std::string & pattern) { auto it = pattern.begin(); @@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { } } - // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* + // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a) // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group // We'll do the outermost capturing group and final .* in the enclosing function. 
std::vector res_alts; @@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { throw std::runtime_error("Unmatched '(' in pattern"); } - return "(" + res + ")[\\s\\S]*"; + return "^(" + res + ")"; } diff --git a/common/sampling.cpp b/common/sampling.cpp index cfec27fd..7c5accf4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -69,34 +69,10 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); - //if (!grmr) { - // return nullptr; - //} - - // if there is a grammar, parse it - if (!params.grammar.empty()) { - result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - if (result->parsed_grammar.success) { - // will be empty (default) if there are parse errors - if (result->parsed_grammar.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); - delete result; - return nullptr; - } - - // Ensure that there is a "root" node. - if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); - delete result; - return nullptr; - } - if (grmr == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); - } - } - } result->prev.resize(params.n_prev); result->n_valid = 0; + result->grammar_str = params.grammar; + result->grammar_root = "root"; } result->grammar = grmr; llama_sampling_set_rng_seed(result, params.seed); @@ -140,71 +116,27 @@ void common_sampler_free(struct llama_sampling_context * ctx) { delete ctx; } -void common_sampler_reset(const struct llama_vocab* vocab, llama_sampling_context * ctx) { - - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); - ctx->grammar = NULL; +static void llama_grammar_reset(llama_sampling_context * ctx) { + ctx->prev.clear(); + if (!ctx->grammar) { + return; } - struct llama_grammar* grmr; - auto params = ctx->params; - if (params.grammar.compare(0, 11, "%llguidance") == 0) { -#ifdef LLAMA_USE_LLGUIDANCE - grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); -#else - GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); -#endif // LLAMA_USE_LLGUIDANCE + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); + for (auto& trigger_pattern : ctx->grammar->trigger_patterns) { + trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } - else { - std::vector trigger_patterns; - std::vector patterns_anywhere; - std::vector trigger_tokens; - for (const auto& trigger : params.grammar_triggers) { - switch (trigger.type) { - case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: - { - const auto& word = trigger.value; - patterns_anywhere.push_back(regex_escape(word)); - break; - } - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: - { - patterns_anywhere.push_back(trigger.value); - break; - } - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: - { - trigger_patterns.push_back(trigger.value); - break; - } - case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: - { - const auto token = trigger.token; - trigger_tokens.push_back(token); - break; - } - default: - GGML_ASSERT(false && "unknown trigger type"); - } - } - if (!patterns_anywhere.empty()) { - trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); - } - std::vector trigger_patterns_c; - trigger_patterns_c.reserve(trigger_patterns.size()); - for (const 
auto& regex : trigger_patterns) { - trigger_patterns_c.push_back(regex.c_str()); - } + auto* grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); - grmr = params.grammar_lazy - ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", - trigger_patterns_c.data(), trigger_patterns_c.size(), - trigger_tokens.data(), trigger_tokens.size()) - : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); - } + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = grammar_new; +} - ctx->grammar = grmr; +void common_sampler_reset(const struct llama_vocab * vocab, llama_sampling_context * ctx) { + llama_grammar_reset(ctx); llama_sampler_dry_reset(ctx->smpl); } @@ -215,13 +147,15 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s ctx->rng.seed(seed); } -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { +void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst) { if (dst->grammar) { llama_grammar_free(dst->grammar); dst->grammar = nullptr; } if (src->grammar) { + dst->grammar_root = src->grammar_root; + dst->grammar_str = src->grammar_str; dst->grammar = llama_grammar_copy(src->grammar); } diff --git a/common/sampling.h b/common/sampling.h index f2d1b1bf..6ccfca6a 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -1,7 +1,7 @@ #pragma once #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include #include #include @@ -113,10 +113,10 @@ struct llama_sampling_context { // mirostat sampler state float mirostat_mu; - llama_grammar * grammar; + std::string grammar_str; + std::string grammar_root; - // internal - grammar_parser::parse_state parsed_grammar; + llama_grammar * grammar; // TODO: replace with ring-buffer std::vector prev; @@ -148,7 +148,7 @@ void common_sampler_reset(const struct llama_vocab* vocab, llama_sampling_contex void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); // Copy the sampler context -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); +void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst); // Get the last sampled token llama_token llama_sampling_last(llama_sampling_context * ctx); diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index df968413..ae22553c 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,6 +1,6 @@ #define LLAMA_API_INTERNAL -#include "grammar-parser.h" +#include "llama-grammar.h" #include "ggml.h" #include "llama.h" #include "unicode.h" @@ -77,27 +77,30 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } + // Parse the GBNF grammar - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parser; + auto parsed_grammar = parser.parse(grammar_str.c_str()); // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - fprintf(stdout, "%s: failed to parse grammar\n", __func__); + if (!parser.parse(grammar_str.c_str()) || parser.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); return 1; } // Ensure that there is a "root" node. 
- if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) { - fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__); + if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); return 1; } - std::vector grammar_rules(parsed_grammar.c_rules()); + std::vector grammar_rules(parser.c_rules()); // Create the LLAMA grammar - auto grammar = llama_grammar_init( + auto grammar = llama_grammar_init_impl( grammar_rules.data(), - grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + grammar_rules.size(), parser.symbol_ids.at("root")); + if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index f98b2aba..60eed056 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -2,7 +2,7 @@ #include "console.h" #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include #include diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8fb1c7d3..487b02a8 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -434,7 +434,7 @@ int main(int argc, char ** argv) { break; } - llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling); + common_sampler_clone(ctx_sampling, drafts[0].ctx_sampling); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -503,7 +503,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); + common_sampler_clone(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); sa.push_back(n_seq_cur); diff --git a/grammars/README.md b/grammars/README.md index 01b02abb..873925c9 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -67,6 +67,30 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte - `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included) - `{0,n}` repeats the precedent symbol or sequence at most `n` times (included) +## Tokens + +Tokens allow grammars to match specific tokenizer tokens rather than character sequences. This is useful for constraining outputs based on special tokens (like `` or ``). + +Tokens can be specified in two ways: + +1. **Token ID**: Use angle brackets with the token ID in square brackets: `<[token-id]>`. For example, `<[1000]>` matches the token with ID 1000. + +2. **Token string**: Use angle brackets with the token text directly: ``. For example, `` will match the token whose text is exactly ``. This only works if the string tokenizes to exactly one token in the vocabulary, otherwise the grammar will fail to parse. + +You can negate token matches using the `!` prefix: `!<[1000]>` or `!` matches any token *except* the specified one. + +``` +# Match a thinking block: ... 
+# Using token strings (requires these to be single tokens in the vocab) +root ::= thinking .* +thinking ::= !* + +# Equivalent grammar using explicit token IDs +# Assumes token 1000 = , token 1001 = +root ::= <[1000]> thinking <[1001]> .* +thinking ::= !<[1001]>* +``` + ## Comments and newlines Comments can be specified with `#`: diff --git a/include/llama.h b/include/llama.h index a0a8e3ac..73c8c293 100644 --- a/include/llama.h +++ b/include/llama.h @@ -489,39 +489,7 @@ extern "C" { // grammar types struct llama_grammar; - // grammar element type - enum llama_gretype { - // end of rule definition - LLAMA_GRETYPE_END = 0, - // start of alternate definition for rule - LLAMA_GRETYPE_ALT = 1, - - // non-terminal element: reference to rule - LLAMA_GRETYPE_RULE_REF = 2, - - // terminal element: character (code point) - LLAMA_GRETYPE_CHAR = 3, - - // inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_NOT = 4, - - // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to - // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - - // modifies a preceding LLAMA_GRETYPE_CHAR or - // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 6, - - // any character (.) - LLAMA_GRETYPE_CHAR_ANY = 7, - }; - - typedef struct llama_grammar_element { - enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID - } llama_grammar_element; // performance timing information struct llama_timings { @@ -1194,10 +1162,10 @@ extern "C" { /// @param n_rules The number of rules. /// @param start_rule_index The index of the root rule (the starting point of the grammar). /// @return The initialized llama_grammar or nullptr if initialization failed. - LLAMA_API struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); + //LLAMA_API struct llama_grammar * llama_grammar_init( + // const llama_grammar_element ** rules, + // size_t n_rules, + // size_t start_rule_index); struct llama_sampler_grammar; LLAMA_API void llama_grammar_init_lazy(struct llama_sampler_grammar * grammar); @@ -1489,39 +1457,6 @@ const std::vector> & llama_internal struct llama_context * ctx ); -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -using llama_grammar_rule = std::vector< llama_grammar_element>; -using llama_grammar_stack = std::vector; - -using llama_grammar_rules = std::vector; -using llama_grammar_stacks = std::vector; -using llama_grammar_candidates = std::vector; - -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - -void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr); - -std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates); - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start); - - // Randomly selects a token from the candidates based on their probabilities using given std::mt19937. // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. 
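
As a reading aid for the token syntax documented in grammars/README.md above: a rule such as `root ::= <[1000]> !<[1001]> <[1001]>` parses into a flat element sequence built from the LLAMA_GRETYPE_TOKEN and LLAMA_GRETYPE_TOKEN_NOT element types this patch adds to src/llama-grammar.h, as the new case in tests/test-grammar-parser.cpp checks. The following is a minimal standalone sketch, not part of the patch; the local enum and struct merely mirror those definitions.

    // Sketch only: values mirror the llama_gretype additions (TOKEN = 8, TOKEN_NOT = 9)
    // and the expected parse of  root ::= <[1000]> !<[1001]> <[1001]>
    #include <cstdint>
    #include <vector>

    enum gretype : uint32_t { GRE_END = 0, GRE_TOKEN = 8, GRE_TOKEN_NOT = 9 };
    struct gelem { gretype type; uint32_t value; };

    static const std::vector<gelem> rule_root = {
        { GRE_TOKEN,     1000 }, // <[1000]>  : exactly token id 1000
        { GRE_TOKEN_NOT, 1001 }, // !<[1001]> : any single token except id 1001
        { GRE_TOKEN,     1001 }, // <[1001]>  : exactly token id 1001
        { GRE_END,       0    },
    };
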
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 12f6f3dc..5d3a864e 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -31,7 +31,7 @@ static std::pair decode_utf8(const char* src) { // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( +static std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; @@ -188,7 +188,53 @@ static std::pair parse_char(const char* src) { throw std::runtime_error("unexpected end of input"); } -static void print_grammar_char(FILE* file, uint32_t c) { +static std::pair parse_token(const llama_vocab * vocab, const char * src) { + const char * pos = src; + if (*pos != '<') { + throw std::runtime_error(std::string("expecting '<' at ") + pos); + } + pos++; + + // Parse <[id]> + if (*pos == '[') { + pos++; + const char * int_end = parse_int(pos); + uint32_t token_id = std::stoul(std::string(pos, int_end - pos)); + pos = int_end; + if (*pos != ']') { + throw std::runtime_error(std::string("expecting ']' at ") + pos); + } + pos++; + if (*pos != '>') { + throw std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + return std::make_pair(token_id, pos); + } + + if (vocab == nullptr) { + throw std::runtime_error(std::string("no vocab to parse token at ") + src); + } + + // Parse and tokenize to obtain the token id + while (*pos != 0 && *pos != '>') { + pos++; + } + if (*pos != '>') { + throw std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + + llama_token tokens[2]; + int32_t n_tokens = vocab->tokenize(src, static_cast(pos - src), tokens, 2, false, true); + if (n_tokens != 1) { + // must tokenize to exactly 1 token + throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'"); + } + return std::make_pair(tokens[0], pos); +} + +static void print_grammar_char(FILE * file, uint32_t c) { if (0x20 <= c && c <= 0x7f) { fprintf(file, "%c", static_cast(c)); } @@ -220,6 +266,8 @@ static void print_rule_binary(FILE* file, const llama_grammar_rule& rule) { case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break; + case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -236,6 +284,17 @@ static void print_rule_binary(FILE* file, const llama_grammar_rule& rule) { print_grammar_char(file, elem.value); fprintf(file, "\") "); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } } fprintf(file, "\n"); @@ -292,6 +351,17 @@ static void print_rule( case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "."); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { 
@@ -307,6 +377,44 @@ static void print_rule( fprintf(file, "\n"); } +// +// Regex utilities +// + +size_t llama_grammar_trigger_pattern::find(const std::string & input) const { + auto find_start_pos = [](const std::smatch & match) { + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + return start; + }; + + if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') { + // match against the entire input + std::smatch match; + if (std::regex_match(input, match, regex)) { + return find_start_pos(match); + } + } + + // search anywhere + std::smatch match; + if (std::regex_search(input, match, regex)) { + return find_start_pos(match); + } + + return std::string::npos; +} + + // // implementation // @@ -454,9 +562,19 @@ const char* llama_grammar_parser::parse_sequence( } } pos = parse_space(pos + 1, is_nested); - } - else if (is_word_char(*pos)) { // rule reference - const char* name_end = parse_name(pos); + } else if (*pos == '<' || *pos == '!') { // token + auto type = LLAMA_GRETYPE_TOKEN; + if (*pos == '!') { // token inverse + type = LLAMA_GRETYPE_TOKEN_NOT; + pos++; + } + auto token_pair = parse_token(vocab, pos); + const char * token_end = token_pair.second; + last_sym_start = rule.size(); + rule.push_back({type, token_pair.first}); + pos = parse_space(token_end, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); @@ -720,6 +838,21 @@ static bool llama_grammar_match_partial_char( return !is_positive_char; } +// returns true iff token matches the rule at pos (regular or inverse) +// asserts that pos is pointing to a token element +static bool llama_grammar_match_token( + const llama_grammar_element * pos, + const llama_token token) { + GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT); + if (pos->type == LLAMA_GRETYPE_TOKEN) { + return pos->value == static_cast(token); + } + if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return pos->value != static_cast(token); + } + return false; +} + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -768,6 +901,8 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_ANY: + case LLAMA_GRETYPE_TOKEN: + case LLAMA_GRETYPE_TOKEN_NOT: if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have new_stacks.emplace_back(stack); @@ -864,31 +999,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } +static void llama_grammar_accept_chr( + struct llama_grammar & grammar, + const llama_grammar_stack & stack, + uint32_t chr, + llama_grammar_stacks & new_stacks) { + if (stack.empty()) { + return; + } -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions -void 
llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr) { + const llama_grammar_element * pos = stack.back(); + + // ignore if this turns into a token + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return; + } + + auto match = llama_grammar_match_char(pos, chr); + if (match.first) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(match.second)) { + new_stack.push_back(match.second); + } + llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks); + } +} + +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { llama_grammar_stacks stacks_new; stacks_new.reserve(grammar->stacks.size()); for (const auto& stack : grammar->stacks) { - if (stack.empty()) { - continue; - } - - auto match = llama_grammar_match_char(stack.back(), chr); - if (match.first) { - const llama_grammar_element* pos = match.second; - - // update top of stack to next element, if any - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); - } + llama_grammar_accept_chr(*grammar, stack, chr, stacks_new); } grammar->stacks = std::move(stacks_new); @@ -913,6 +1055,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( const llama_grammar_element * stack_pos = stack.back(); + // if the top of the stack is a token rule, then we only need to check the token id + if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached the end of a token consumed by char rules, reject iff it ended + // in a partial response + if (tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } else if (!llama_grammar_match_token(stack_pos, tok.id)) { + rejects.push_back(tok); + } + } + return rejects; + } + llama_grammar_candidates next_candidates; next_candidates.reserve(candidates.size()); @@ -925,7 +1083,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( rejects.push_back(tok); } } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id }); } else { rejects.push_back(tok); } @@ -943,7 +1101,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); for (const auto & tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id }); } return rejects; @@ -969,10 +1127,9 @@ struct llama_grammar* llama_grammar_init_impl( for (size_t i = 0; i < n_rules; i++) { for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { vec_rules[i].push_back(*pos); - } - vec_rules[i].push_back({ LLAMA_GRETYPE_END, 0 }); } - + vec_rules[i].push_back({ LLAMA_GRETYPE_END, 0 }); + } // Check for left recursion std::vector rules_visited(n_rules); std::vector rules_in_progress(n_rules); @@ -1017,12 +1174,13 @@ struct llama_grammar* llama_grammar_init_impl( NULL, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy =*/ false, - /* .awaiting_trigger = */ false, - /* .trigger_buffer = */ "", - /* 
.trigger_tokens = */ {}, - /* .trigger_patterns = */ {}, + /* .partial_utf8 = */ {}, + /* .lazy = */ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, + /* .trigger_tokens = */ {}, + /* .trigger_patterns = */ {}, }; } @@ -1035,7 +1193,7 @@ struct llama_grammar* llama_grammar_init_impl( size_t num_trigger_patterns, const llama_token* trigger_tokens, size_t num_trigger_tokens) { - llama_grammar_parser parser; + llama_grammar_parser parser(vocab); // if there is a grammar, parse it // rules will be empty (default) if there are parse errors @@ -1124,38 +1282,44 @@ struct llama_grammar* llama_grammar_init_impl( vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy = */ lazy, - /* .awaiting_trigger = */ lazy, - /* .trigger_buffer = */ "", + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, std::move(vec_trigger_tokens), std::move(vec_trigger_patterns), }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { +struct llama_grammar* llama_grammar_clone_impl(const struct llama_grammar& grammar) { auto* result = new llama_grammar{ - grammar->vocab, - grammar->rules, - grammar->stacks, - grammar->partial_utf8, - grammar->lazy, - grammar->awaiting_trigger, - grammar->trigger_buffer, - grammar->trigger_tokens, - grammar->trigger_patterns, + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_buffer_positions, + grammar.trigger_tokens, + grammar.trigger_patterns, }; + // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1199,7 +1363,7 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc candidates->data[i].logit = -INFINITY; } else { candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); - candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id }); } } @@ -1208,44 +1372,49 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc candidates->data[reject.index].logit = -INFINITY; } if (!smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; -} + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { +void llama_grammar_accept_impl(struct llama_grammar & grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, 
llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - GGML_ASSERT(grammar->vocab != nullptr); - const auto& piece = grammar->vocab->token_to_piece(token); + GGML_ASSERT(grammar.vocab != nullptr); + const auto& piece = grammar.vocab->token_to_piece(token); - if (grammar->awaiting_trigger) { - if (std::find(grammar->trigger_tokens.begin(), grammar->trigger_tokens.end(), token) != grammar->trigger_tokens.end()) { - grammar->awaiting_trigger = false; - grammar->trigger_buffer.clear(); - llama_grammar_accept_str(grammar, piece); + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_token(grammar, token, piece); LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; - } else { - grammar->trigger_buffer += piece; + } + else { + auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size()); + grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); + grammar.trigger_buffer += piece; - std::smatch match; - for (const auto& trigger_pattern : grammar->trigger_patterns) { - if (std::regex_match(grammar->trigger_buffer, match, trigger_pattern.regex)) { - grammar->awaiting_trigger = false; - // get from the first matched capturing group to the end of the string - size_t start = std::string::npos; - for (auto i = 1u; i < match.size(); i++) { - if (match.length(i) > 0) { - start = match.position(i); - break; + for (const auto& trigger_pattern : grammar.trigger_patterns) { + auto start = trigger_pattern.find(grammar.trigger_buffer); + if (start != std::string::npos) { + grammar.awaiting_trigger = false; + + // replay tokens that overlap with [start, end) + for (const auto& [tok, tok_pos] : grammar.trigger_buffer_positions) { + auto [tok_start, tok_end] = tok_pos; + if (tok_end <= start) { + continue; } + + size_t piece_start = (tok_start < start) ? 
start : tok_start; // allow for partial token pieces + size_t piece_len = tok_end - piece_start; + auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len); + llama_grammar_accept_token(grammar, tok, tok_piece); } - if (start == std::string::npos) { - start = match.position(0); - } - auto constrained_str = grammar->trigger_buffer.substr(start); - // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); - grammar->trigger_buffer.clear(); - llama_grammar_accept_str(grammar, constrained_str); + + auto constrained_str = grammar.trigger_buffer.substr(start); + grammar.trigger_buffer.clear(); + grammar.trigger_buffer_positions.clear(); LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); return; } @@ -1256,7 +1425,7 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc } if (llama_token_is_eog(vocab, token)) { - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } @@ -1264,22 +1433,77 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ABORT("fatal error"); } - llama_grammar_accept_str(grammar, piece); + llama_grammar_accept_token(grammar, token, piece); smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_str(struct llama_grammar* grammar, const std::string& piece) { +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto& code_points = decoded.first; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar, *it); + llama_grammar_accept(&grammar, *it); } - grammar->partial_utf8 = decoded.second; - if (grammar->stacks.empty()) { + grammar.partial_utf8 = decoded.second; + if (grammar.stacks.empty()) { throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); } } +void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) { + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece, grammar.partial_utf8); + const auto & code_points = decoded.first; + + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar.stacks.size()); + + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + continue; + } + + const llama_grammar_element * pos = stack.back(); + + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + if (llama_grammar_match_token(pos, token)) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + new_stack.push_back(pos + 1); + } + llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new); + } + } else { + llama_grammar_stacks current_stacks = {stack}; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_stacks next_stacks; + + for (const auto & cur_stack : current_stacks) { + llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks); + } + + current_stacks = std::move(next_stacks); + if (current_stacks.empty()) { + break; + } + } + + for (auto & surviving_stack : current_stacks) { + if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) { + 
stacks_new.emplace_back(surviving_stack); + } + } + } + } + + grammar.stacks = std::move(stacks_new); + grammar.partial_utf8 = decoded.second; + + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")"); + } +} + diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f13953c1..78f9c5d7 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -9,11 +9,84 @@ struct llama_vocab; struct llama_sampling; +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, + + // terminal element: token (<[token-id]>) + LLAMA_GRETYPE_TOKEN = 8, + + // inverse token (!<[token-id]>) + LLAMA_GRETYPE_TOKEN_NOT = 9, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point, rule ID, or token ID +} llama_grammar_element; + + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t* code_points; + llama_partial_utf8 partial_utf8; + llama_token id; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +const llama_grammar_rules& llama_grammar_get_rules(const struct llama_grammar* grammar); +llama_grammar_stacks& llama_grammar_get_stacks(struct llama_grammar* grammar); + +void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules& rules, + const llama_grammar_stack& stack, + const llama_grammar_candidates& candidates); + struct llama_grammar_parser { + const llama_vocab * vocab; std::map symbol_ids; llama_grammar_rules rules; + llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {} + llama_grammar_stack c_rules() const; uint32_t get_symbol_id(const char* src, size_t len); @@ -42,9 +115,15 @@ struct llama_grammar_parser { struct llama_grammar_trigger_pattern { std::string pattern; std::regex regex; + + size_t find(const std::string & input) const; }; + struct llama_grammar { + // maintain a list of llama_tokens and their positions in the trigger_buffer + using token_pos = std::pair>; + // note: allow null vocab for testing (not great) const llama_vocab* vocab; @@ -60,6 +139,7 @@ struct llama_grammar { bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_buffer_positions; // Tokens buffered by lazy grammar. 
Used to replay when a trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). std::vector trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated @@ -88,7 +168,8 @@ struct llama_grammar* llama_grammar_init_impl( void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); + +struct llama_grammar* llama_grammar_clone_impl(const struct llama_grammar& grammar); void llama_grammar_sample_impl( const struct llama_grammar * grammar, @@ -96,13 +177,18 @@ void llama_grammar_sample_impl( const struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_grammar_accept_token_impl( - struct llama_grammar * grammar, +void llama_grammar_accept_impl( + struct llama_grammar & grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token); void llama_grammar_accept_str( - struct llama_grammar* grammar, - const std::string& piece); + struct llama_grammar & grammar, + const std::string & piece); + +void llama_grammar_accept_token( + struct llama_grammar & grammar, + llama_token token, + const std::string & piece); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index bb94af7a..c8cca55c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1217,7 +1217,7 @@ static const char* llama_sampler_grammar_name(const struct llama_sampler* /*smpl static void llama_sampler_grammar_accept_impl(struct llama_sampler* smpl, llama_token token) { auto* ctx = (llama_sampler_grammar*)smpl->ctx; if (ctx->grammar) { - llama_grammar_accept_token_impl(ctx->grammar,ctx->vocab ,nullptr, token); + llama_grammar_accept_impl(*ctx->grammar,ctx->vocab ,nullptr, token); } } diff --git a/src/llama.cpp b/src/llama.cpp index a78fef0d..36d28c7e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7611,36 +7611,13 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) { // grammar // -struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { - return llama_grammar_init_impl(rules, n_rules, start_rule_index); -} - void llama_grammar_free(struct llama_grammar * grammar) { llama_grammar_free_impl(grammar); } -// -//void llama_grammar_init_lazy(struct llama_sampler* smpl) { -// -// if (!grammar) { -// return; -// } -// std::vector trigger_patterns_c; -// trigger_patterns_c.reserve(grammar.grammar->trigger_patterns.size()); -// for (auto& trigger_pattern : grammar.grammar->trigger_patterns) { -// trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); -// } -// //auto* grammar_new = llama_grammar_init_impl(grammar->vocab, "", "root", -// // grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), -// // grammar->trigger_tokens.data(), grammar->trigger_tokens.size()); -// -//} struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - return llama_grammar_copy_impl(grammar); + return llama_grammar_clone_impl(*grammar); } void llama_grammar_sample( @@ -7661,7 +7638,7 @@ void llama_grammar_accept_token( struct llama_grammar * grammar, struct llama_context * ctx, llama_token token) { - llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token); + llama_grammar_accept_impl(*grammar, &ctx->model.vocab, &ctx->sampling, token); } // diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 
d0334217..18b35616 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -96,16 +96,16 @@ if (NOT WIN32) #llama_target_and_test(test-llama-grammar.cpp) #llama_target_and_test(test-chat.cpp) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 - if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") - llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) - target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server) - endif() + #if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + # llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) + # target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server) + #endif() # build test-tokenizer-1-bpe target once and add many tests - add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp) - target_link_libraries(test-tokenizer-1-bpe PRIVATE common) - install(TARGETS test-tokenizer-1-bpe RUNTIME) + # add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp) + # target_link_libraries(test-tokenizer-1-bpe PRIVATE common) + # install(TARGETS test-tokenizer-1-bpe RUNTIME) # TODO: disabled due to slowness #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf) @@ -118,11 +118,11 @@ if (NOT WIN32) #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf) # build test-tokenizer-1-spm target once and add many tests - add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp) - target_link_libraries(test-tokenizer-1-spm PRIVATE common) - install(TARGETS test-tokenizer-1-spm RUNTIME) + # add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp) + # target_link_libraries(test-tokenizer-1-spm PRIVATE common) + # install(TARGETS test-tokenizer-1-spm RUNTIME) - llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf) + # llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf) #llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf) # llama_target_and_test(test-double-float.cpp) # SLOW diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 503eb43a..b6bf394c 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -44,13 +44,66 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { return grammar_fails; } +struct token_and_piece { + llama_token token; + std::string piece; +}; + +// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order. +static std::string token(llama_token id) { + return std::string{ + static_cast(0xff), + static_cast((id >> 24) & 0xff), + static_cast((id >> 16) & 0xff), + static_cast((id >> 8) & 0xff), + static_cast(id & 0xff) + }; +} + +// parse_tokens() parses the token encodes above and UTF-8 text. 
+static std::vector parse_tokens(const std::string & input) { + std::vector result; + result.reserve(input.size()); + size_t offset = 0; + while (offset < input.size()) { + try { + if (static_cast(input[offset]) == 0xff) { + if (offset + 5 > input.size()) { + throw std::runtime_error("not enough bytes for token id"); + } + uint32_t val = + (static_cast(input[offset + 1]) << 24) | + (static_cast(input[offset + 2]) << 16) | + (static_cast(input[offset + 3]) << 8) | + (static_cast(input[offset + 4])); + auto piece = "<[" + std::to_string(val) + "]>"; + result.push_back({static_cast(val), piece}); + offset += 5; + } else { + uint32_t cpt = unicode_cpt_from_utf8(input, offset); + result.push_back({0, unicode_cpt_to_utf8(cpt)}); + } + } catch (const std::invalid_argument & /*ex*/) { + // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize + ++offset; + result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character + } + } + return result; +} + static bool match_string(const std::string & input, llama_grammar * grammar) { - auto decoded = decode_utf8(input, {}); + const auto parsed = parse_tokens(input); const auto & code_points = decoded.first; - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + for (const auto & in : parsed) { + try { + llama_grammar_accept_token(*grammar, in.token, in.piece); + } catch (const std::runtime_error & /*e*/) { + // normally this shouldn't get hit because of llama_grammar_apply + return false; + } for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy @@ -443,6 +496,30 @@ static void test_simple_grammar() { "12a45", } ); + + // Test case for a simple grammar with tokens + test_grammar( + "simple grammar with tokens", + R"""( + root ::= <[10]> content <[11]> + content ::= (!<[11]>)*)""", + // Passing strings + { + token(10) + "hello world" + token(11), + token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11), + token(10) + token(11), + token(10) + token(12) + token(13) + token(14) + token(15) + token(11), + token(10) + "a" + token(11), + }, + // Failing strings + { + token(10) + "missing end token", + token(10), + "missing start token" + token(11), + token(10) + token(11) + token(11), // double end token + token(11) + "wrong order" + token(10), + } + ); } static void test_complex_grammar() { @@ -504,6 +581,34 @@ static void test_complex_grammar() { "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", } ); + + // Test case for a more complex grammar with tokens + test_grammar( + "complex grammar with tokens", + R"""( + root ::= reasoning+ content tool-call* + reasoning ::= <[10]> (!<[11]>)* <[11]> + content ::= <[20]> (!<[21]>)* <[21]> + tool-call ::= <[12]> name <[13]> args <[14]> + name ::= (!<[13]>)+ + args ::= (!<[14]>)*)""", + // Passing strings + { + token(10) + "I am thinking" + token(11) + token(20) + "hello world!" 
+ token(21) + token(12) + "search" + token(13) + "query=test" + token(14), + token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14), + token(10) + token(11) + token(20) + "content" + token(21), + token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14), + token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14), + }, + // Failing strings + { + token(20) + "content only" + token(21), + token(10) + "no closing reasoning", + token(10) + token(11) + token(20) + "no closing content", + token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool", + token(10) + token(11) + token(11) + token(20) + token(21), + } + ); } static void test_special_chars() { diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 5df5abb2..68a38639 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -511,5 +511,19 @@ int main() {LLAMA_GRETYPE_END, 0}, }); + // <[1000]> = "" + // <[1001]> = "" + verify_parsing(R"""( + root ::= <[1000]> !<[1001]> <[1001]> + )""", { + {"root", 0} + }, { + // root (index 0) + {LLAMA_GRETYPE_TOKEN, 1000}, + {LLAMA_GRETYPE_TOKEN_NOT, 1001}, + {LLAMA_GRETYPE_TOKEN, 1001}, + {LLAMA_GRETYPE_END, 0}, + }); + return 0; } diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 1f3a267b..3814f124 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -204,7 +204,7 @@ int main() uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point cp[0] = 37 + i; cp[1] = 0; - next_candidates[i] = {i, cp, {}}; + next_candidates[i] = {i, cp, {}, 0}; } std::vector>> expected_reject = { diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp index ffad1897..70af6d75 100644 --- a/tests/test-regex-partial.cpp +++ b/tests/test-regex-partial.cpp @@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() { printf("[%s]\n", __func__); assert_equals( - "((?:(?:c)?b)?a)[\\s\\S]*", + "^((?:(?:c)?b)?a)", regex_to_reversed_partial_regex("abc")); assert_equals( - "(a+)[\\s\\S]*", + "^(a+)", regex_to_reversed_partial_regex("a+")); assert_equals( - "(a*)[\\s\\S]*", + "^(a*)", regex_to_reversed_partial_regex("a*")); assert_equals( - "(a?)[\\s\\S]*", + "^(a?)", regex_to_reversed_partial_regex("a?")); assert_equals( - "([a-z])[\\s\\S]*", + "^([a-z])", regex_to_reversed_partial_regex("[a-z]")); assert_equals( - "((?:\\w+)?[a-z])[\\s\\S]*", + "^((?:\\w+)?[a-z])", regex_to_reversed_partial_regex("[a-z]\\w+")); assert_equals( - "((?:a|b))[\\s\\S]*", + "^((?:a|b))", regex_to_reversed_partial_regex("(?:a|b)")); assert_equals( - "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*", + "^((?:(?:(?:d)?c)?b)?a)", regex_to_reversed_partial_regex("abcd")); assert_equals( - "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ?? + "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ?? 
regex_to_reversed_partial_regex("a*b")); assert_equals( - "((?:(?:b)?a)?.*)[\\s\\S]*", + "^((?:(?:b)?a)?.*)", regex_to_reversed_partial_regex(".*?ab")); assert_equals( - "((?:(?:b)?.*)?a)[\\s\\S]*", + "^((?:(?:b)?.*)?a)", regex_to_reversed_partial_regex("a.*?b")); assert_equals( - "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*", + "^((?:(?:d)?(?:(?:c)?b))?a)", regex_to_reversed_partial_regex("a(bc)d")); assert_equals( - "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*", + "^((?:(?:(?:c)?b|(?:e)?d))?a)", regex_to_reversed_partial_regex("a(bc|de)")); assert_equals( - "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*", + "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)", regex_to_reversed_partial_regex("ab{2,4}c")); } From f8acfc2bf0a778bd49523e0c8badbf946290a3c1 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 3 Feb 2026 09:18:46 +0200 Subject: [PATCH 4/8] Better CUDA TG for GQA = 10 (#1221) * Better CUDA TG for GQA = 10 * Cleanup --- ggml/src/ggml-cuda/fattn-new-mma.cu | 14 ++++++-------- ggml/src/ggml-cuda/fattn.cu | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 0e7908c2..f5cb9854 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -2136,21 +2136,19 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - if (K->ne[0] == 128 && (gqa_ratio == 12 || gqa_ratio == 6)) { + if (K->ne[0] == 128) { GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128); - //GGML_ASSERT(Q->ne[1] <= 4); - //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst); if (gqa_ratio == 12) { ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst); - } else { + } else if (gqa_ratio == 6) { ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 8>(ctx, dst); + } else if (gqa_ratio == 10) { + ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst); + } else { + GGML_ABORT("Not implemented"); } return; } - //if (K->ne[0] == 64 && V->ne[0] == 64) { - // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst); - // return; - //} if (K->ne[0] == 192 && V->ne[0] == 128) { GGML_ASSERT(Q->ne[0] == 192); //GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. 
But now we have Mimo2, which has GQA > 1 diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 267968b0..7d47a81d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -90,7 +90,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 && - (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6)) { + (Q->ne[2] / K->ne[2] == 12 || Q->ne[2] / K->ne[2] == 6 || Q->ne[2] / K->ne[2] == 10)) { ggml_cuda_flash_attn_ext_mma_new(ctx, dst); return; } From e5622a2e91c70f6cbd663dac43a5f07ee2d03e60 Mon Sep 17 00:00:00 2001 From: usrlocalben Date: Wed, 4 Feb 2026 04:57:50 -0500 Subject: [PATCH 5/8] Fix Phi-3, Phi-4 (#1226) * fix phi3 tensor setup * avoid SWA for Phi-4 --- src/llama-build-context.cpp | 11 +++++++++-- src/llama-load-tensors.cpp | 6 ++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index f44b4c2c..f5509070 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -4539,7 +4539,14 @@ ggml_cgraph * llm_build_context::build_phi3() { struct ggml_tensor * inp_pos = build_inp_pos(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); + struct ggml_tensor * KQ_mask; + if (hparams.n_swa == 0) { + // Phi-4 does not use SWA + KQ_mask = build_inp_KQ_mask(); + } + else { + KQ_mask = build_inp_KQ_mask_swa(); + } for (int il = 0; il < n_layer; ++il) { auto residual = inpL; @@ -4593,7 +4600,7 @@ ggml_cgraph * llm_build_context::build_phi3() { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index d66333ea..fd4b5162 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -1316,6 +1316,12 @@ bool create_tensors_helper::create_phi3_tensors(const LLM_TN & tn) { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); From 17d101863d5cf447a7d4f4612ce341b49f582904 Mon Sep 17 00:00:00 2001 From: gapeleon Date: Thu, 5 Feb 2026 01:07:18 +1100 Subject: [PATCH 6/8] server: add dynamic control vector management endpoints (#1223) This implements the ability to load, unload, and scale control vectors (representation engineering) mid-inference, following the existing task-queue pattern used by LoRA adapters. New Endpoints: - GET /control-vectors - POST /control-vectors/load - POST /control-vectors/unload - POST /control-vectors/apply (handles scaling) Technical Notes: - Centralizes vector aggregation logic to share implementation between load, unload, and apply tasks. - Vectors are applied globally to the model context. - Enforces dimension validation on load to safely reject incompatible vectors. 
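
Illustrative request bodies for the endpoints listed above (a sketch only, not part of the patch): the field names follow the handlers added in examples/server/server.cpp and server-context.cpp, the .gguf path is a placeholder, and the include assumes the standalone nlohmann/json header rather than the server's vendored copy.

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    // POST /control-vectors/load — "path" is required; the other fields have defaults.
    const json load_req = {
        { "path",        "vectors/example.gguf" }, // placeholder path
        { "scale",       0.8 },
        { "layer_start", 1 },
        { "layer_end",   24 },
    };

    // POST /control-vectors/apply — array of {id, scale}; vectors not listed are reset to scale 0.
    const json apply_req = json::array({
        { { "id", 0 }, { "scale", 1.2 } },
    });

    // POST /control-vectors/unload — removes the vector with the given id and reapplies the rest.
    const json unload_req = { { "id", 0 } };
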
Co-authored-by: Gapeleon --- examples/server/server-common.h | 10 ++ examples/server/server-context.cpp | 196 +++++++++++++++++++++++++++++ examples/server/server-context.h | 4 + examples/server/server-task.h | 3 + examples/server/server.cpp | 100 +++++++++++++++ 5 files changed, 313 insertions(+) diff --git a/examples/server/server-common.h b/examples/server/server-common.h index 52d1e5b3..1b4b2acc 100644 --- a/examples/server/server-common.h +++ b/examples/server/server-common.h @@ -111,6 +111,16 @@ static T json_value(const json& body, const std::string& key, const T& default_v } } +// Control vector container for dynamic management +struct control_vector_container { + std::string path; + float scale; + int32_t layer_start; + int32_t layer_end; + llama_control_vector_data data; + bool applied; +}; + // thin wrapper around common_grammar_trigger with (de)serialization functions struct server_grammar_trigger { common_grammar_trigger value; diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 3c4ff874..00c54d2b 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -1958,9 +1958,205 @@ void server_context::process_single_task(server_task&& task) { result.data = json{ { "success", true } }; queue_results.send(result); } break; + case SERVER_TASK_TYPE_LOAD_CONTROL_VECTOR: + { + // Load control vector from file + std::string path = task.data.at("path"); + float scale = task.data.value("scale", 1.0f); + int32_t layer_start = task.data.value("layer_start", 1); + int32_t layer_end = task.data.value("layer_end", llama_n_layer(model)); + + // Check if already loaded + int cv_id = -1; + for (size_t i = 0; i < control_vectors.size(); i++) { + if (control_vectors[i].path == path) { + control_vectors[i].scale = scale; + control_vectors[i].layer_start = layer_start; + control_vectors[i].layer_end = layer_end; + cv_id = i; + break; + } + } + + if (cv_id == -1) { + control_vector_container new_cv; + new_cv.path = path; + new_cv.scale = scale; + new_cv.layer_start = layer_start; + new_cv.layer_end = layer_end; + new_cv.applied = false; + + // Load the control vector data + llama_control_vector_load_info load_info; + load_info.fname = path; + load_info.strength = 1.0f; // Don't pre-scale here, we'll scale when applying + + std::vector load_infos = { load_info }; + new_cv.data = llama_control_vector_load(load_infos); + + if (new_cv.data.n_embd == -1) { + server_task_result result; + result.id = task.id; + result.error = true; + result.data = json{{ "success", false }, { "error", "Failed to load control vector from " + path }}; + queue_results.send(result); + break; + } + + // Validate dimension to prevent heap corruption + if (new_cv.data.n_embd != llama_model_n_embd(model)) { + server_task_result result; + result.id = task.id; + result.error = true; + result.data = json{{ "success", false }, + { "error", "Vector dimension mismatch" }}; + queue_results.send(result); + break; + } + + control_vectors.push_back(new_cv); + + cv_id = control_vectors.size() - 1; + } + + // Auto-apply control vectors after loading + if (!apply_control_vectors_internal()) { + server_task_result result; + result.id = task.id; + result.error = true; + result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }}; + queue_results.send(result); + break; + } + + server_task_result result; + result.id = task.id; + result.error = false; + result.data = json{{ "success", true }, { "id", cv_id }}; + queue_results.send(result); + } break; + 
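
For context, a standalone sketch, not part of the patch, of the aggregation that the new apply_control_vectors_internal() helper (added further down in this file) performs: an element-wise, scale-weighted sum over the loaded vectors, skipping entries whose scale is zero and assuming all loaded vectors share the same flattened size.

    #include <cstddef>
    #include <vector>

    struct cv_entry {
        std::vector<float> data; // flattened per-layer direction values
        float scale;
    };

    // Combine all active vectors into one buffer; an empty result means "nothing to apply".
    static std::vector<float> combine_control_vectors(const std::vector<cv_entry> & cvs) {
        std::vector<float> combined;
        for (const auto & cv : cvs) {
            if (cv.scale == 0.0f) {
                continue; // inactive vector, same check as the patch
            }
            if (combined.empty()) {
                combined.assign(cv.data.size(), 0.0f);
            }
            for (size_t i = 0; i < cv.data.size(); ++i) {
                combined[i] += cv.data[i] * cv.scale;
            }
        }
        return combined;
    }
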
case SERVER_TASK_TYPE_UNLOAD_CONTROL_VECTOR: + { + // Validate that "id" field exists and is a number + if (!task.data.contains("id") || task.data["id"].is_null() || !task.data["id"].is_number()) { + server_task_result result; + result.id = task.id; + result.error = true; + result.data = json{{ "success", false }, { "error", "Missing or invalid 'id' field" }}; + queue_results.send(result); + break; + } + + int id = task.data.at("id"); + + if (id < 0 || id >= (int)control_vectors.size()) { + server_task_result result; + result.id = task.id; + result.error = true; + result.data = json{{ "success", false }, { "error", "Invalid control vector ID" }}; + queue_results.send(result); + break; + } + + // Remove the control vector from the list + control_vectors.erase(control_vectors.begin() + id); + + // Reapply remaining control vectors + if (!apply_control_vectors_internal()) { + server_task_result result; + result.id = task.id; + result.error = true; + result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }}; + queue_results.send(result); + break; + } + + server_task_result result; + result.id = task.id; + result.error = false; + result.data = json{{ "success", true }}; + queue_results.send(result); + } break; + case SERVER_TASK_TYPE_SET_CONTROL_VECTOR: + { + if (!apply_control_vectors_internal()) { + server_task_result result; + result.id = task.id; + result.error = true; + result.data = json{{ "success", false }, { "error", "Failed to apply control vectors" }}; + queue_results.send(result); + break; + } + + server_task_result result; + result.id = task.id; + result.error = false; + result.data = json{{ "success", true }}; + queue_results.send(result); + } break; } } +bool server_context::apply_control_vectors_internal() { + llama_control_vector_data combined_cv = { -1, {} }; + + // Check if we have anything to apply + bool any_active = false; + for (const auto& cv : control_vectors) { + if (cv.scale != 0.0f) { + any_active = true; + break; + } + } + + if (!any_active) { + // Clear control vectors if nothing is active + llama_control_vector_apply(ctx, nullptr, 0, 0, 0, 0); + return true; + } + + // Aggregate control vectors with scaling + for (auto& cv : control_vectors) { + if (cv.scale == 0.0f) { + cv.applied = false; + continue; + } + + if (combined_cv.n_embd == -1) { + combined_cv.n_embd = cv.data.n_embd; + combined_cv.data.resize(cv.data.data.size(), 0.0f); + } + + for (size_t i = 0; i < cv.data.data.size(); i++) { + combined_cv.data[i] += cv.data.data[i] * cv.scale; + } + cv.applied = true; + } + + // Apply combined control vector + if (combined_cv.n_embd != -1 && !combined_cv.data.empty()) { + int32_t min_layer_start = INT32_MAX; + int32_t max_layer_end = 0; + + for (const auto& cv : control_vectors) { + if (cv.scale != 0.0f) { + min_layer_start = std::min(min_layer_start, cv.layer_start); + max_layer_end = std::max(max_layer_end, cv.layer_end); + } + } + + int err = llama_control_vector_apply(ctx, + combined_cv.data.data(), + combined_cv.data.size(), + combined_cv.n_embd, + min_layer_start, + max_layer_end); + return (err == 0); + } + + return true; +} + void server_context::on_finish_multitask(const server_task_multi& multitask) { // all subtasks done == multitask is done server_task_result result; diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 34493565..4e52999a 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -183,6 +183,7 @@ struct server_context { llama_model* model = 
nullptr; llama_context* ctx = nullptr; std::vector lora_adapters; + std::vector control_vectors; gpt_params params_base; @@ -316,4 +317,7 @@ struct server_context { bool accept_special_token(const server_slot& slot, const llama_token token); json model_meta() const; + + // Re-aggregates all active vectors and updates the model state + bool apply_control_vectors_internal(); }; diff --git a/examples/server/server-task.h b/examples/server/server-task.h index 942097d3..1f4736f9 100644 --- a/examples/server/server-task.h +++ b/examples/server/server-task.h @@ -31,6 +31,9 @@ enum server_task_type { SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE, SERVER_TASK_TYPE_SET_LORA, + SERVER_TASK_TYPE_LOAD_CONTROL_VECTOR, + SERVER_TASK_TYPE_UNLOAD_CONTROL_VECTOR, + SERVER_TASK_TYPE_SET_CONTROL_VECTOR, }; enum oaicompat_type { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fa792f9e..ee8edd7b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1509,6 +1509,101 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; + // Control vector handlers + const auto handle_control_vectors_list = [&](const httplib::Request & req, httplib::Response & res) { + json result = json::array(); + for (size_t i = 0; i < ctx_server.control_vectors.size(); ++i) { + auto & cv = ctx_server.control_vectors[i]; + result.push_back({ + {"id", i}, + {"path", cv.path}, + {"scale", cv.scale}, + {"layer_start", cv.layer_start}, + {"layer_end", cv.layer_end}, + {"applied", cv.applied}, + }); + } + res.set_content(result.dump(), "application/json"); + res.status = 200; // HTTP OK + }; + + const auto handle_control_vectors_load = [&](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + server_task task; + task.type = SERVER_TASK_TYPE_LOAD_CONTROL_VECTOR; + task.data = body; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + res.set_content(result.data.dump(), "application/json"); + res.status = result.error ? 400 : 200; + }; + + const auto handle_control_vectors_unload = [&](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + server_task task; + task.type = SERVER_TASK_TYPE_UNLOAD_CONTROL_VECTOR; + task.data = body; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + res.set_content(result.data.dump(), "application/json"); + res.status = result.error ? 
400 : 200; + }; + + const auto handle_control_vectors_apply = [&](const httplib::Request & req, httplib::Response & res) { + const std::vector body = json::parse(req.body); + int max_idx = ctx_server.control_vectors.size(); + + // Update scales for existing control vectors + for (auto & cv : ctx_server.control_vectors) { + cv.scale = 0.0f; // Reset all scales first + } + + // Set new scales + for (auto entry : body) { + int id = entry.at("id"); + float scale = entry.at("scale"); + if (0 <= id && id < max_idx) { + ctx_server.control_vectors[id].scale = scale; + + // Optionally update layer range + if (entry.contains("layer_start")) { + ctx_server.control_vectors[id].layer_start = entry.at("layer_start"); + } + if (entry.contains("layer_end")) { + ctx_server.control_vectors[id].layer_end = entry.at("layer_end"); + } + } else { + res.set_content(json{{ "success", false }, { "error", "Invalid control vector id" }}.dump(), "application/json"); + res.status = 400; + return; + } + } + + server_task task; + task.type = SERVER_TASK_TYPE_SET_CONTROL_VECTOR; + + const int id_task = ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_results.add_waiting_task_id(id_task); + + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + res.set_content(result.data.dump(), "application/json"); + res.status = result.error ? 400 : 200; + }; + const auto list_saved_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { json response = json::array(); @@ -1925,6 +2020,11 @@ int main(int argc, char ** argv) { // LoRA adapters hotswap svr->Get ("/lora-adapters", handle_lora_adapters_list); svr->Post("/lora-adapters", handle_lora_adapters_apply); + // Control vectors + svr->Get ("/control-vectors", handle_control_vectors_list); + svr->Post("/control-vectors/load", handle_control_vectors_load); + svr->Post("/control-vectors/unload", handle_control_vectors_unload); + svr->Post("/control-vectors/apply", handle_control_vectors_apply); // Save & load slots svr->Get ("/slots", handle_slots); svr->Get ("/slots/list", list_slot_prompts); From b41b8cf813eae5f9d803e49cb220889119a875b8 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 4 Feb 2026 16:07:43 +0200 Subject: [PATCH 7/8] Graph parallel for SEED-OSS (#1222) * Graph parallel for SEED-OSS * Cleanup --- src/llama-build-context.cpp | 84 ++++++++----------------------------- src/llama-load-tensors.cpp | 12 +++--- src/llama.cpp | 1 + 3 files changed, 25 insertions(+), 72 deletions(-) diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index f5509070..427565fc 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -3513,13 +3513,14 @@ ggml_cgraph * llm_build_context::build_seedoss() { GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); - struct ggml_tensor * cur; - struct ggml_tensor * inpL; + ggml_tensor * cur; + ggml_tensor * inpL; inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); // inp_pos - contains the positions - struct ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -3527,57 +3528,16 @@ ggml_cgraph * llm_build_context::build_seedoss() { const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm", il); + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer-1 ? inp_out_ids : nullptr, nullptr, + KQ_mask, nullptr, nullptr, kq_scale, 0.0f, 0, il, true, false, true); - // self-attention - { - auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq, - model.layers[il].wk, model.layers[il].bk, - model.layers[il].wv, model.layers[il].bv, 0.f, il); - - Qcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Qcur, "Qcur", il); - - Kcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Kcur, "Kcur", il); - - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); - } - - if (il == n_layer - 1) { - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward forward - cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "attn_post_norm", il); - cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_post_norm, ffn_inp, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_post_norm, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); - - cur = ggml_add(ctx0, cur, ffn_inp); + LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true); cb(cur, "ffn_out", il); cur = lctx.cvec.apply_to(ctx0, cur, il); @@ -3587,13 +3547,7 @@ ggml_cgraph * llm_build_context::build_seedoss() { inpL = cur; } - cur = inpL; - - cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); - cb(cur, "result_norm", -1); - - // lm_head - cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9478,8 +9432,8 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens GGML_ASSERT(bv->n_device == wq->n_device); } std::vector attn(wq->n_device, nullptr); - int id_last = -1; bool output_bias_added = false; + bool input_added = false; for (int id = 0; id < wq->n_device; ++id) { int il_cb = 1000*(id+1) + il; auto split_wq = wq->splits[id]; @@ -9492,6 +9446,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens (split_wq && split_wk && split_wv && split_wo && split_kl && split_vl)); if (!split_wq) continue; auto cur = get_input_tensor_sm_graph(ctx0, input, id); + auto input_id = cur; cur = do_split_norm(ctx0, cur, the_attn_norm, lctx.model.hparams, cb, id, il_cb, is_norm); auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ? 
((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr; @@ -9614,6 +9569,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens if (inp_out_ids) { // && ggml_nrows(inp_out_ids) > 1) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (add_input && !input_added) { + input_id = ggml_get_rows(ctx0, input_id, inp_out_ids); + } cb(cur, "fa_get_rows", il_cb); } @@ -9628,21 +9586,15 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens cb(cur, "kqv_wo_biased", il_cb); output_bias_added = true; } + if (add_input && !input_added) { + cur = ggml_add(ctx0, cur, input_id); + input_added = true; + } if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) { cur = ggml_cast(ctx0, cur, lctx.cparams.reduce_type); } ggml_build_forward_expand(gf, cur); attn[id] = cur; - id_last = id; - } - GGML_ASSERT(id_last >= 0); - if (add_input) { - if (inp_out_ids) { // && ggml_nrows(inp_out_ids) > 1) { - input = ggml_get_rows(ctx0, input, inp_out_ids); - cb(input, "sainp_get_rows", il); - } - attn[id_last] = ggml_add(ctx0, attn[id_last], input); - cb(attn[id_last], "attn_out_with_input", il); } auto cur = ggml_reduce(ctx0, attn.data(), wq->n_device, GGML_OP_ADD); ggml_build_forward_expand(gf, cur); diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index fd4b5162..e51951f2 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -1002,12 +1002,11 @@ bool create_tensors_helper::create_seedoss_tensors(const LLM_TN & tn) { } for (int i = 0; i < n_layer; ++i) { - ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); auto & layer = model.layers[i]; - layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}); layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}); @@ -1015,11 +1014,12 @@ bool create_tensors_helper::create_seedoss_tensors(const LLM_TN & tn) { layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}); // optional bias tensors - layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); + layer.attn_post_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); + layer.ffn_norm = layer.attn_post_norm; create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split); } diff --git a/src/llama.cpp b/src/llama.cpp index 36d28c7e..16844509 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1755,6 +1755,7 @@ static bool 
is_model_split_supported(const llama_model & model) { LLM_ARCH_OPENAI_MOE, LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_MINIMAX_M2, + LLM_ARCH_SEED_OSS, }; auto it = k_supported.find(model.arch); return it != k_supported.end(); From a335cff6643c7927cdc982cf780cc66e3c04482e Mon Sep 17 00:00:00 2001 From: Michael Militzer Date: Wed, 4 Feb 2026 16:08:00 +0200 Subject: [PATCH 8/8] Fix llama-server-cuda Dockerfile to build ik_llama.cpp correctly (#1224) Co-authored-by: Michael Militzer --- .devops/llama-server-cuda.Dockerfile | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/.devops/llama-server-cuda.Dockerfile b/.devops/llama-server-cuda.Dockerfile index 18424898..cb015476 100644 --- a/.devops/llama-server-cuda.Dockerfile +++ b/.devops/llama-server-cuda.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG CUDA_VERSION=11.7.1 +ARG CUDA_VERSION=12.4.1 # Target the CUDA build image ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} # Target the CUDA runtime image @@ -8,11 +8,13 @@ ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_V FROM ${BASE_CUDA_DEV_CONTAINER} AS build -# Unless otherwise specified, we make a fat build. -ARG CUDA_DOCKER_ARCH=all +# Set targeted arch here as needed, default: 86 (Ampere) and 90 (Hopper) +ARG CUDA_DOCKER_ARCH="86;90" RUN apt-get update && \ - apt-get install -y build-essential git libcurl4-openssl-dev + apt-get install -y build-essential git libcurl4-openssl-dev ninja-build python3-pip \ + && pip3 install --no-cache-dir cmake \ + && rm -rf /var/lib/apt/lists/* WORKDIR /app @@ -27,14 +29,25 @@ ENV LLAMA_CURL=1 # Must be set to 0.0.0.0 so it can listen to requests from host machine ENV LLAMA_ARG_HOST=0.0.0.0 -RUN make -j$(nproc) llama-server +RUN cmake -S . -B build -G Ninja \ + -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CUDA_ARCHITECTURES="${CUDA_DOCKER_ARCH}" \ + -DBUILD_SHARED_LIBS=ON \ + -DCMAKE_C_FLAGS="-fPIC -mcmodel=large" \ + -DCMAKE_CXX_FLAGS="-fPIC -mcmodel=large" \ + && cmake --build build --target llama-server FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime RUN apt-get update && \ apt-get install -y libcurl4-openssl-dev libgomp1 curl -COPY --from=build /app/llama-server /llama-server +COPY --from=build /app/build/bin/llama-server /llama-server + +COPY --from=build /app/build/examples/mtmd/libmtmd.so /usr/local/lib/ +COPY --from=build /app/build/ggml/src/libggml.so /usr/local/lib/ +COPY --from=build /app/build/src/libllama.so /usr/local/lib/ +RUN ldconfig HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]
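
The control-vector endpoints added earlier in this series can be exercised from the command line. The sketch below assumes a server listening on localhost:8080 and a control vector file at /models/cv.gguf (both placeholders); the "id", "scale", "layer_start" and "layer_end" fields follow the handlers shown above, but the payload accepted by /control-vectors/load is not visible in this excerpt, so its fields are an assumption:

    # list loaded control vectors with their ids, scales, layer ranges and applied flags
    curl http://localhost:8080/control-vectors

    # load a control vector; the HTTP handler forwards the JSON body to the load task,
    # so the fields below are assumed, not confirmed by this series
    curl -X POST http://localhost:8080/control-vectors/load \
        -d '{"path": "/models/cv.gguf", "scale": 0.8, "layer_start": 10, "layer_end": 20}'

    # re-apply with new scales; the handler resets all scales first, so any vector
    # omitted from the array ends up at scale 0.0 and is skipped during aggregation
    curl -X POST http://localhost:8080/control-vectors/apply \
        -d '[{"id": 0, "scale": 0.5}]'

    # unload by id; the task handler rejects requests without a numeric "id" field
    curl -X POST http://localhost:8080/control-vectors/unload \
        -d '{"id": 0}'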
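
The reworked CUDA Dockerfile can be checked with a local build. The image tag and model path below are placeholders, and the run line assumes the image's entrypoint is still the /llama-server binary copied above (the entrypoint itself is not part of this hunk):

    # build with the default "86;90" CUDA architectures
    docker build -t ik-llama-server-cuda -f .devops/llama-server-cuda.Dockerfile .

    # or override the target architectures via the build argument
    docker build -t ik-llama-server-cuda \
        --build-arg CUDA_DOCKER_ARCH="89" \
        -f .devops/llama-server-cuda.Dockerfile .

    # run with GPU access; LLAMA_ARG_HOST is already set to 0.0.0.0 in the image,
    # and the health check probes http://localhost:8080/health
    docker run --gpus all -p 8080:8080 -v /models:/models \
        ik-llama-server-cuda -m /models/model.gguf --port 8080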