mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Mimo-V2-Flash support (#1096)
* Mimo-2 support * Fix bug for head sizes not being the same It still does not solve the Mimo-2 quantized cache issue. * Fix quantized cache * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
|
||||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_MIMO2, "mimo2" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -140,6 +141,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||
@@ -150,6 +152,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
|
||||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
|
||||
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
|
||||
|
||||
@@ -68,6 +68,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MINIMAX_M2,
|
||||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_MIMO2,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -133,6 +134,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_KV_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
@@ -143,6 +145,7 @@ enum llm_kv {
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_FREQ_BASE_SWA,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
|
||||
@@ -1394,13 +1394,20 @@ static ggml_tensor * llm_build_kqv(
|
||||
|
||||
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
|
||||
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
|
||||
//if (n_swa > 0 && k->ne[1] > n_swa + q->ne[1]) {
|
||||
// auto nton = n_swa + q->ne[1];
|
||||
// auto first = k->ne[1] - nton;
|
||||
// k = ggml_view_3d(ctx, k, k->ne[0], nton, k->ne[2], k->nb[1], k->nb[2], k->nb[1]*first);
|
||||
// v = ggml_view_3d(ctx, v, v->ne[0], nton, v->ne[2], v->nb[1], v->nb[2], v->nb[1]*first);
|
||||
// kq_mask = ggml_view_3d(ctx, kq_mask, nton, kq_mask->ne[1], kq_mask->ne[2], kq_mask->nb[1], kq_mask->nb[2], kq_mask->nb[0]*first);
|
||||
//}
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
|
||||
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
|
||||
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) {
|
||||
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE || model.arch == LLM_ARCH_MIMO2) {
|
||||
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
||||
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
@@ -1615,7 +1622,7 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
|
||||
ggml_tensor * wk, ggml_tensor * bk,
|
||||
ggml_tensor * wv, ggml_tensor * bv,
|
||||
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il, bool add_graph_split) const {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
if (wqkv) {
|
||||
auto qkv = llm_build_lora_mm(lctx, ctx0, wqkv, cur);
|
||||
@@ -1627,8 +1634,8 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
|
||||
qkv = ggml_add(ctx0, qkv, bqkv);
|
||||
cb(qkv, "qkv_b", il);
|
||||
}
|
||||
auto Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 0*sizeof(float)*(n_embd));
|
||||
auto Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
|
||||
auto Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head, n_tokens, n_embd_head_k*sizeof(float), qkv->nb[1], 0*sizeof(float)*(n_embd));
|
||||
auto Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k*sizeof(float), qkv->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
|
||||
auto Vcur = ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(Qcur->ne[0]*Qcur->ne[1] + Kcur->ne[0]*Kcur->ne[1]));
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
@@ -1669,8 +1676,8 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
|
||||
}
|
||||
ggml_build_forward_expand(gf, qk);
|
||||
ggml_build_forward_expand(gf, Vcur);
|
||||
auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd));
|
||||
auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
|
||||
auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head_k, n_head, n_tokens, n_embd_head_k*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd));
|
||||
auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (q_norm) {
|
||||
@@ -1689,13 +1696,13 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
|
||||
}
|
||||
|
||||
auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il, add_graph_split);
|
||||
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, Q->ne[0]/n_embd_head, n_tokens);
|
||||
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head_k, Q->ne[0]/n_embd_head_k, n_tokens);
|
||||
if (q_norm) {
|
||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
}
|
||||
|
||||
auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, K->ne[0]/n_embd_head, n_tokens);
|
||||
auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head_k, K->ne[0]/n_embd_head_k, n_tokens);
|
||||
if (k_norm) {
|
||||
Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
@@ -8494,6 +8501,81 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() {
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_cgraph * llm_build_context::build_mimo2() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
//const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
//GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
//GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_sliding = model.hparams.swa_layers[il];
|
||||
auto KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
|
||||
|
||||
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l, model.layers[il].attn_sinks,
|
||||
nullptr, 1.0f/sqrtf(float(n_embd_head_k)), 0.0f, is_sliding ? hparams.n_swa : 0, il, true, false, true);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
auto ffn_inp = cur;
|
||||
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
|
||||
model.layers[il].ffn_gate_inp, nullptr,
|
||||
model.layers[il].ffn_up_exps, nullptr,
|
||||
model.layers[il].ffn_gate_exps, nullptr,
|
||||
model.layers[il].ffn_down_exps, nullptr,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
nullptr, nullptr, // we don't have shared expert biases?
|
||||
nullptr, nullptr,
|
||||
nullptr, nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true, false, 0.0f,
|
||||
LLM_EXPERT_GATING_FUNC_SIGMOID,
|
||||
LLM_FFN_SILU, cb, il, gf, true);
|
||||
}
|
||||
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_cgraph * llm_build_context::build_openai_moe() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
@@ -9317,6 +9399,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
|
||||
{
|
||||
result = llm.build_mistral3();
|
||||
} break;
|
||||
case LLM_ARCH_MIMO2:
|
||||
{
|
||||
result = llm.build_mimo2();
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -9340,6 +9426,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale,
|
||||
int n_swa, int il, bool do_rope, bool add_graph_split, bool add_input, bool is_norm) {
|
||||
|
||||
float freq_base_l = n_swa > 0 ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
||||
float freq_scale_l = n_swa > 0 ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
|
||||
|
||||
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
|
||||
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
|
||||
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
|
||||
@@ -9414,9 +9504,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
rope_factors = extra->splits[id];
|
||||
}
|
||||
if (do_rope) {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il_cb);
|
||||
@@ -9550,9 +9640,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
|
||||
|
||||
if (do_rope) {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
@@ -276,6 +276,8 @@ struct llm_build_context {
|
||||
|
||||
ggml_cgraph * build_smollm3();
|
||||
|
||||
ggml_cgraph * build_mimo2();
|
||||
|
||||
//
|
||||
static ggml_tensor * llm_build_lora_mm(llama_context & lctx, ggml_context * ctx0,
|
||||
ggml_tensor * w, ggml_tensor * cur);
|
||||
|
||||
@@ -1072,6 +1072,23 @@ void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MIMO2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
|
||||
//TODO
|
||||
//hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; // which is the same as OpenAI
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 48: model.type = e_model::MODEL_310B_A15B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
|
||||
} break;
|
||||
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ struct llama_hparams {
|
||||
|
||||
// qwen3vl deepstack
|
||||
uint32_t n_deepstack_layers = 0;
|
||||
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = -1;
|
||||
@@ -122,6 +122,8 @@ struct llama_hparams {
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
if (this->vocab_only != other.vocab_only) return true;
|
||||
if (this->n_vocab != other.n_vocab) return true;
|
||||
|
||||
@@ -133,6 +133,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
bool create_smollm3_tensors(const LLM_TN & tn);
|
||||
|
||||
bool create_mimo2_tensors(const LLM_TN & tn);
|
||||
|
||||
llama_model_loader & ml;
|
||||
llama_model & model;
|
||||
|
||||
@@ -1194,6 +1196,49 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
create_embd_output(tn, n_embd, n_vocab, true, false); //true);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
|
||||
uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
|
||||
uint32_t n_head = hparams.n_head(i);
|
||||
//printf("Layer %2d: n_head = %u, n_embd_head_k = %d, n_embd_head_v = %d, n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", i, n_head, (int)n_embd_head_k, (int)n_embd_head_v, n_embd_k_gqa, n_embd_v_gqa);
|
||||
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_sinks = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
|
||||
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
||||
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0);
|
||||
|
||||
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
|
||||
layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
// non-MoE branch
|
||||
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_exp_probs_b = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_phi2_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
@@ -2947,6 +2992,8 @@ bool create_tensors_helper::create_tensors() {
|
||||
use_mmap_buffer = create_minimaxm2_tensors(tn); break;
|
||||
case LLM_ARCH_SMOLLM3:
|
||||
use_mmap_buffer = create_smollm3_tensors(tn); break;
|
||||
case LLM_ARCH_MIMO2:
|
||||
use_mmap_buffer = create_mimo2_tensors(tn); break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@@ -2955,7 +3002,6 @@ bool create_tensors_helper::create_tensors() {
|
||||
printf("================================ max_gpu = %d\n", model.max_gpu);
|
||||
std::vector<size_t> mem_used(model.splits.size(), 0);
|
||||
const auto & hparams = model.hparams;
|
||||
int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
|
||||
auto cur_splits = model.splits;
|
||||
int adjust_step = std::max(1, int(n_layer / (2*model.splits.size())));
|
||||
if (model.max_gpu > 1 && model.max_gpu < int(cur_splits.size())) {
|
||||
@@ -2976,6 +3022,7 @@ bool create_tensors_helper::create_tensors() {
|
||||
}
|
||||
}
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
int gqa_ratio = hparams.n_head(il) / hparams.n_head_kv(il);
|
||||
if (ggml_backend_buft_is_host(model.buft_layer[il].buft_matrix)) {
|
||||
LLAMA_LOG_INFO("%s: not splitting layer %d because buffer type is host\n", __func__, il);
|
||||
continue;
|
||||
@@ -2996,19 +3043,26 @@ bool create_tensors_helper::create_tensors() {
|
||||
if (layer.attn_norm) {
|
||||
auto split = create_split(ggml_nrows(layer.attn_norm), -1, cur_splits, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_norm, layer.split_attn_norm, split, mem_used);
|
||||
if (layer.attn_sinks) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_sinks, layer.split_attn_sinks, split, mem_used);
|
||||
}
|
||||
}
|
||||
if (layer.rope_freqs) {
|
||||
auto split = create_split(ggml_nrows(layer.rope_freqs), -1, cur_splits, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.rope_freqs, layer.split_rope_freqs, split, mem_used);
|
||||
}
|
||||
if (layer.wo && layer.wq && layer.wk && layer.wv) {
|
||||
int attn_granularity = hparams.n_embd_head_k * gqa_ratio;
|
||||
// TODO: fix this logic. It only works whe K and V head size is the same
|
||||
//printf("Layer %d: q = %ld x %ld, k = %ld x %ld, v = %ld x %ld, qo = %ld x %ld\n", il, layer.wq->ne[0], layer.wq->ne[1],
|
||||
// layer.wk->ne[0], layer.wk->ne[1], layer.wv->ne[0], layer.wv->ne[1], layer.wo->ne[0], layer.wo->ne[1]);
|
||||
int attn_granularity = hparams.n_embd_head_v * gqa_ratio;
|
||||
if (ggml_is_quantized(layer.wo->type)) {
|
||||
auto tt = ggml_internal_get_type_traits(layer.wo->type);
|
||||
if (tt.blck_size > attn_granularity) attn_granularity = tt.blck_size;
|
||||
}
|
||||
GGML_ASSERT(attn_granularity % hparams.n_embd_head_k == 0);
|
||||
GGML_ASSERT(attn_granularity % hparams.n_embd_head_v == 0);
|
||||
auto split = create_split(layer.wo->ne[0], attn_granularity, cur_splits, mem_used);
|
||||
//printf("Split:"); for (auto s : split) printf(" %d", s); printf("\n");
|
||||
prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split, mem_used);
|
||||
if (layer.bo) {
|
||||
|
||||
@@ -1294,6 +1294,29 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MIMO2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@@ -1538,6 +1561,7 @@ const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_106B_A12B: return "106B.A12B";
|
||||
case MODEL_230B_A10B: return "230B.A10B";
|
||||
case MODEL_235B_A22B: return "235B.A22B";
|
||||
case MODEL_310B_A15B: return "310B.A15B";
|
||||
case MODEL_300B_A47B: return "300B.A47B";
|
||||
case MODEL_355B_A32B: return "355B.A32B";
|
||||
case MODEL_E2B: return "E2B";
|
||||
|
||||
@@ -112,6 +112,7 @@ enum e_model {
|
||||
MODEL_106B_A12B,
|
||||
MODEL_230B_A10B, // Minimax M2
|
||||
MODEL_235B_A22B,
|
||||
MODEL_310B_A15B,
|
||||
MODEL_300B_A47B, // Ernie MoE big
|
||||
MODEL_355B_A32B,
|
||||
MODEL_E2B,
|
||||
@@ -184,6 +185,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * bkv = nullptr;
|
||||
|
||||
llama_split_tensor split_attn_norm;
|
||||
llama_split_tensor split_attn_sinks;
|
||||
llama_split_tensor split_wq;
|
||||
llama_split_tensor split_wk;
|
||||
llama_split_tensor split_wv;
|
||||
|
||||
@@ -4905,6 +4905,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_OPENAI_MOE:
|
||||
case LLM_ARCH_BAILINGMOE2:
|
||||
case LLM_ARCH_MINIMAX_M2:
|
||||
case LLM_ARCH_MIMO2:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
Reference in New Issue
Block a user