From 5d90f711d450d8b91e6ce0a4e08b86ff05562e9a Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 10 Nov 2025 11:27:18 +0200
Subject: [PATCH] Model loading and compute graph

---
Reviewer note: the line structure and indentation of this patch were
reconstructed from a whitespace-collapsed copy, so leading whitespace may
differ from the tree (apply with `git apply --ignore-whitespace`). Comments
were added to the new functions; the blob ids on the index lines predate them.

 src/llama-arch.cpp          |   1 +
 src/llama-arch.h            |   1 +
 src/llama-build-context.cpp | 102 ++++++++++++++++++++++++++++++++++++
 src/llama-build-context.h   |   2 +
 src/llama-hparams.cpp       |  28 +++++++---
 src/llama-load-tensors.cpp  |  30 +++++++++++
 src/llama-model.cpp         |  17 ++++++
 src/llama.cpp               |   1 +
 8 files changed, 173 insertions(+), 9 deletions(-)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index d680673c..3c717873 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -67,6 +67,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_OPENAI_MOE, "gpt-oss" },
     { LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
     { LLM_ARCH_MINIMAX_M2, "minimax-m2" },
+    { LLM_ARCH_SMOLLM3, "smollm3" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 1b32463b..a421fa6c 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -66,6 +66,7 @@ enum llm_arch {
     LLM_ARCH_OPENAI_MOE,
     LLM_ARCH_BAILINGMOE2,
     LLM_ARCH_MINIMAX_M2,
+    LLM_ARCH_SMOLLM3,
     LLM_ARCH_UNKNOWN,
 };
 
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 0475995a..f7838a66 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -8489,6 +8489,104 @@ ggml_cgraph* llm_build_context::build_minimaxm2() {
     return gf;
 }
 
+// Builds the compute graph for SmolLM3: per layer an RMS-normed self-attention
+// block and an RMS-normed SiLU-gated FFN block, each added back to its input,
+// followed by the final norm and the output projection.
+ggml_cgraph* llm_build_context::build_smollm3() {
+    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    // NOTE(review): GGML_ASSERT(n_embd_head == hparams.n_rot) is disabled; the remark left with it ("wrong in case of minimax, head_dim = 128, n_rot = 64") is about MiniMax - confirm whether the assert holds for SmolLM3
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+    ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+    // an attention scale of 0 is treated as "not set": fall back to 1/sqrt(head size)
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // every n_no_rope_layer_step-th layer (counting from 1) is built without RoPE
+        const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
+
+        // norm
+        cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
+                    model.layers[il].wqkv, model.layers[il].bqkv,
+                    model.layers[il].wqk, model.layers[il].bqk,
+                    model.layers[il].wq, model.layers[il].bq,
+                    model.layers[il].wk, model.layers[il].bk,
+                    model.layers[il].wv, model.layers[il].bv,
+                    model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
+
+            if (use_rope) {
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+            }
+
+            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            cb(cur, "attn_out", il);
+        }
+        // on the last layer keep only the rows selected by inp_out_ids (when it is set)
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = llm_build_ffn(ctx0, lctx, cur,
+                model.layers[il].ffn_up, NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = lctx.cvec.apply_to(ctx0, cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
+    cb(cur, "result_norm", -1);
+
+    // lm_head
+    cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
+
 ggml_cgraph * llm_build_context::llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
     llama_batch dummy;
     dummy.n_tokens = 0;
@@ -8839,6 +8937,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
             {
                 result = llm.build_minimaxm2();
             } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                result = llm.build_smollm3();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
diff --git a/src/llama-build-context.h b/src/llama-build-context.h
index 391cf319..8e9d7adb 100644
--- a/src/llama-build-context.h
+++ b/src/llama-build-context.h
@@ -270,6 +270,8 @@ struct llm_build_context {
 
     ggml_cgraph * build_minimaxm2();
 
+    ggml_cgraph * build_smollm3();
+
     //
     static ggml_tensor * llm_build_lora_mm(llama_context & lctx, ggml_context * ctx0,
             ggml_tensor * w, ggml_tensor * cur);
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 2d7cd439..e14167c1 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -1013,16 +1013,26 @@ void llm_load_hparams(
             } break;
 
         case LLM_ARCH_MINIMAX_M2:
-        {
-            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-            ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
-            ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
 
-            switch (hparams.n_layer) {
-                case 62: model.type = e_model::MODEL_230B_A10B; break;
-                default: model.type = e_model::MODEL_UNKNOWN;
-            }
-        } break;
+                switch (hparams.n_layer) {
+                    case 62: model.type = e_model::MODEL_230B_A10B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                hparams.n_no_rope_layer_step = 4; // hard-coded here: build_smollm3() skips RoPE on every 4th layer
+
+                switch (hparams.n_layer) {
+                    case 36: model.type = e_model::MODEL_3B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index a2fc5803..e921dc9b 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -130,6 +130,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
 
     bool create_minimaxm2_tensors(const LLM_TN & tn);
 
+    bool create_smollm3_tensors(const LLM_TN & tn);
+
     llama_model_loader & ml;
     llama_model & model;
 
@@ -2466,6 +2468,32 @@ bool create_tensors_helper::create_minimaxm2_tensors(const LLM_TN & tn) {
     return use_mmap_buffer;
 }
 
+// Creates the tensors of a SmolLM3 model: the shared tensors are set up by
+// create_embd_output(); each layer gets its attention norm, the Q/K/V tensors
+// handled by merge_qkv(), the attention output, the FFN norm and the tensors
+// from create_std_ffn(). Returns use_mmap_buffer, cleared once merge_qkv() returns true.
+bool create_tensors_helper::create_smollm3_tensors(const LLM_TN & tn) {
+    LOADING_PRELUDE
+
+    create_embd_output(tn, n_embd, n_vocab);
+
+    for (int i = 0; i < n_layer; ++i) {
+        ggml_context* ctx_layer = ctx_for_layer(i);
+        ggml_context* ctx_split = ctx_for_layer_split(i);
+        auto & layer = model.layers[i];
+
+        layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+        use_mmap_buffer &= !merge_qkv(tn, i, 0);
+
+        layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+        layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+        create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split);
+    }
+    return use_mmap_buffer;
+}
+
 bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
     auto& hparams = model.hparams;
     const int64_t n_head = hparams.n_head();
@@ -2699,6 +2727,8 @@ bool create_tensors_helper::create_tensors() {
                 use_mmap_buffer = create_bailingmoe2_tensors(tn); break;
             case LLM_ARCH_MINIMAX_M2:
                 use_mmap_buffer = create_minimaxm2_tensors(tn); break;
+            case LLM_ARCH_SMOLLM3:
+                use_mmap_buffer = create_smollm3_tensors(tn); break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0f43aa50..9d8a6f30 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1249,6 +1249,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         },
     },
+    {
+        LLM_ARCH_SMOLLM3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
diff --git a/src/llama.cpp b/src/llama.cpp
index 50a3a6f4..413d3fae 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4642,6 +4642,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_COHERE2:
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
+        case LLM_ARCH_SMOLLM3:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2