Fix model architecture name

junhuihe
2025-04-21 17:19:43 +08:00
committed by Saood Karim
parent 98d1626469
commit 0e247afcac


@@ -227,6 +227,7 @@ enum llm_arch {
     LLM_ARCH_GLM4,
     LLM_ARCH_BITNET,
     LLM_ARCH_BITNET_25,
+    LLM_ARCH_BITNET_B158,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
@@ -281,6 +282,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GLM4,        "glm4"         },
     { LLM_ARCH_BITNET,      "bitnet"       },
     { LLM_ARCH_BITNET_25,   "bitnet-25"    },
+    { LLM_ARCH_BITNET_B158, "bitnet-b1.58" },
     { LLM_ARCH_T5,          "t5"           },
     { LLM_ARCH_T5ENCODER,   "t5encoder"    },
     { LLM_ARCH_JAIS,        "jais"         },
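
For context, LLM_ARCH_NAMES is the table that ties each enum value to the architecture string stored in GGUF metadata ("general.architecture"), so "bitnet-b1.58" is the name converted models must carry to reach the new code paths. A minimal standalone sketch of the forward and reverse lookups (the helper names and the LLM_ARCH_UNKNOWN fallback are illustrative, not this file's exact API):

    #include <map>
    #include <string>

    enum llm_arch { LLM_ARCH_BITNET, LLM_ARCH_BITNET_25, LLM_ARCH_BITNET_B158, LLM_ARCH_UNKNOWN };

    static const std::map<llm_arch, const char *> ARCH_NAMES = {
        { LLM_ARCH_BITNET,      "bitnet"       },
        { LLM_ARCH_BITNET_25,   "bitnet-25"    },
        { LLM_ARCH_BITNET_B158, "bitnet-b1.58" },
    };

    // enum -> name written into GGUF metadata
    static const char * arch_name(llm_arch arch) {
        auto it = ARCH_NAMES.find(arch);
        return it == ARCH_NAMES.end() ? "unknown" : it->second;
    }

    // "general.architecture" string read from a GGUF file -> enum
    static llm_arch arch_from_string(const std::string & name) {
        for (const auto & kv : ARCH_NAMES) {
            if (name == kv.second) {
                return kv.first;
            }
        }
        return LLM_ARCH_UNKNOWN;
    }
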
@@ -1419,6 +1421,34 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_SUB_NORM,  "blk.%d.ffn_sub_norm" },
         },
     },
+    {
+        LLM_ARCH_BITNET_B158,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,    "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,   "output_norm" },
+            { LLM_TENSOR_OUTPUT,        "output" },
+            { LLM_TENSOR_ROPE_FREQS,    "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,     "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,        "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,        "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,        "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,  "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,      "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,      "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,        "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,  "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,  "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,    "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,   "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
+            { LLM_TENSOR_FFN_SUB_NORM,  "blk.%d.ffn_sub_norm" },
+        },
+    },
     {
         LLM_ARCH_T5,
         {
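
The %d placeholders in the new table are printf-style: they are expanded with the block index when a tensor is looked up (and the ".%d.%d" expert entries take a second index). A minimal sketch of that expansion, assuming a snprintf-based formatter in the spirit of the loader's tn() helper:

    #include <cstdio>
    #include <string>

    // Expand a per-layer template such as "blk.%d.attn_sub_norm" for block bid,
    // then append the conventional ".weight" / ".bias" suffix.
    static std::string tensor_name(const char * tmpl, int bid, const char * suffix) {
        char buf[256];
        snprintf(buf, sizeof(buf), tmpl, bid);
        return std::string(buf) + "." + suffix;
    }

    // tensor_name("blk.%d.attn_q", 0, "weight")        -> "blk.0.attn_q.weight"
    // tensor_name("blk.%d.ffn_sub_norm", 11, "weight") -> "blk.11.ffn_sub_norm.weight"
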
@@ -5235,7 +5265,7 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-    if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25) {
+    if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158) {
         if (hparams.n_rot != hparams.n_embd_head_k) {
             throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
         }
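
This holds the new architecture to the same full-head RoPE constraint as LLaMA and Falcon: the n_rot value read from LLM_KV_ROPE_DIMENSION_COUNT must equal n_embd_head_k, which defaults to n_embd / n_head. With hypothetical dimensions n_embd = 2048 and n_head = 16, n_embd_head_k = 2048 / 16 = 128, so a GGUF header declaring any rope dimension count other than 128 would trigger the runtime_error above.
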
@@ -5830,6 +5860,7 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_BITNET_25:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -8350,6 +8381,7 @@ static bool llm_load_tensors(
                     layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
                 }
             } break;
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_BITNET_25:
             {
                 model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -15376,7 +15408,7 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_bitnet_25() {
+    struct ggml_cgraph * build_bitnet_158() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -16599,9 +16631,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_bitnet();
             } break;
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_BITNET_25:
             {
-                result = llm.build_bitnet_25();
+                result = llm.build_bitnet_158();
             } break;
         case LLM_ARCH_COHERE2:
             {
@@ -20293,6 +20326,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_STABLELM:
         case LLM_ARCH_BITNET:
         case LLM_ARCH_BITNET_25:
+        case LLM_ARCH_BITNET_B158:
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_QWEN2MOE:
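
These last cases sit in llama_rope_type's grouped switch, where each architecture label falls through to a shared return; in upstream llama.cpp this particular group returns LLAMA_ROPE_TYPE_NEOX. A minimal sketch of the pattern, with the enum and the default branch simplified to a single alternative:

    enum llm_arch { LLM_ARCH_BITNET, LLM_ARCH_BITNET_25, LLM_ARCH_BITNET_B158, LLM_ARCH_OTHER };
    enum rope_type_t { ROPE_NORM, ROPE_NEOX };

    // Grouped-case dispatch: adding LLM_ARCH_BITNET_B158 is one new label,
    // not a new branch, so it inherits the RoPE style of its neighbors.
    static rope_type_t rope_type_for(llm_arch arch) {
        switch (arch) {
            case LLM_ARCH_BITNET:
            case LLM_ARCH_BITNET_25:
            case LLM_ARCH_BITNET_B158: // new architecture joins the existing group
                return ROPE_NEOX;
            default: // the real function distinguishes several groups; simplified here
                return ROPE_NORM;
        }
    }
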