mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
BailingMoE2 conversion
This commit is contained in:
@@ -654,6 +654,9 @@ class Model:
|
|||||||
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
|
if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
|
||||||
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
|
# ref: https://huggingface.co/moonshotai/Kimi-K2-Base
|
||||||
res = "kimi-k2"
|
res = "kimi-k2"
|
||||||
|
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||||
|
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
|
||||||
|
res = "bailingmoe2"
|
||||||
|
|
||||||
if res is None:
|
if res is None:
|
||||||
logger.warning("\n")
|
logger.warning("\n")
|
||||||
@@ -4461,6 +4464,103 @@ class ChatGLMModel(Model):
|
|||||||
name = name.removeprefix("transformer.")
|
name = name.removeprefix("transformer.")
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
@Model.register("BailingMoeV2ForCausalLM")
class BailingMoeV2Model(Model):
    """Converter for BailingMoeV2 (Ling 2.0 / BailingMoE2) checkpoints.

    Handles the MoE-specific hyperparameters, merges per-expert projection
    weights into stacked 3D tensors, and accounts for the extra
    next-token-prediction (NextN/MTP) layers appended after the regular
    transformer blocks.
    """
    model_arch = gguf.MODEL_ARCH.BAILINGMOE2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NextN/MTP layers sit beyond num_hidden_layers, so the block count
        # and the tensor-name map must be rebuilt to cover them.
        if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
            self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_vocab(self):
        """The model ships a GPT-2 style BPE tokenizer."""
        self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Emit the MoE / RoPE hyperparameters into the GGUF header."""
        super().set_gguf_parameters()
        hparams = self.hparams

        # Derive the per-head dimension when not given explicitly.
        rope_dim = hparams.get("head_dim")
        if rope_dim is None:
            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]

        # Only a fraction of each head is rotated (partial RoPE); default 0.5.
        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))

        rope_scaling = self.hparams.get("rope_scaling") or {}
        scaling_kind = rope_scaling.get("rope_type", rope_scaling.get("type"))
        if scaling_kind == "yarn" and "factor" in rope_scaling:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
        else:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)

        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["moe_shared_expert_intermediate_size"])
        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
        self.gguf_writer.add_expert_count(hparams["num_experts"])
        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
        self.gguf_writer.add_expert_group_count(hparams["n_group"])
        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

        # Router scoring function; anything else is unsupported.
        if hparams["score_function"] == "sigmoid":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
        elif hparams["score_function"] == "softmax":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
        else:
            raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")

        nextn_layers = self.hparams.get("num_nextn_predict_layers")
        if nextn_layers is not None:
            self.gguf_writer.add_nextn_predict_layers(nextn_layers)

    # Per-layer staging area for expert weights until a full set arrives.
    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename tensors; buffer per-expert weights and emit them stacked.

        Expert tensors are held back until all `num_experts * 3` projections
        of a layer have been seen, then merged into one 3D tensor per
        projection kind. Non-expert tensors pass straight through the
        standard name mapping.
        """
        if "mlp.experts" in name:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            out: list[tuple[str, Tensor]] = []

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            # Emit nothing until every expert's down/gate/up weight is buffered.
            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    parts: list[Tensor] = []
                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        parts.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(parts, dim=0)
                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
                    out.append((self.map_tensor_name(merged_name), data_torch))

            return out

        # The router bias has no ".bias" suffix in the checkpoint; add it so
        # the generic tensor map resolves it.
        if name.endswith(".expert_bias"):
            name = name.replace(".expert_bias", ".expert_bias.bias")

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        """Finalize conversion; fail loudly if any expert weight was left unmerged."""
        super().prepare_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -88,6 +88,8 @@ class Keys:
|
|||||||
EXPERT_COUNT = "{arch}.expert_count"
|
EXPERT_COUNT = "{arch}.expert_count"
|
||||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||||
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
||||||
|
EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
|
||||||
|
EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
|
||||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||||
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
||||||
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
||||||
@@ -245,6 +247,7 @@ class MODEL_ARCH(IntEnum):
|
|||||||
DOTS1 = auto()
|
DOTS1 = auto()
|
||||||
ERNIE4_5 = auto()
|
ERNIE4_5 = auto()
|
||||||
ERNIE4_5_MOE = auto()
|
ERNIE4_5_MOE = auto()
|
||||||
|
BAILINGMOE2 = auto()
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
TOKEN_EMBD = auto()
|
TOKEN_EMBD = auto()
|
||||||
@@ -390,6 +393,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||||||
MODEL_ARCH.DOTS1: "dots1",
|
MODEL_ARCH.DOTS1: "dots1",
|
||||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||||
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
|
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
|
||||||
|
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
@@ -1291,6 +1295,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.BAILINGMOE2: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_K_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_QKV,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_GATE_INP,
|
||||||
|
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
MODEL_TENSOR.FFN_GATE_EXP,
|
||||||
|
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
|
MODEL_TENSOR.FFN_UP_EXP,
|
||||||
|
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||||
|
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||||
|
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||||
|
MODEL_TENSOR.NEXTN_EH_PROJ,
|
||||||
|
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
|
||||||
|
MODEL_TENSOR.NEXTN_ENORM,
|
||||||
|
MODEL_TENSOR.NEXTN_HNORM,
|
||||||
|
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||||
|
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||||
|
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||||
|
],
|
||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -671,6 +671,12 @@ class GGUFWriter:
|
|||||||
def add_expert_shared_count(self, count: int) -> None:
|
def add_expert_shared_count(self, count: int) -> None:
|
||||||
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
|
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
|
||||||
|
|
||||||
|
def add_expert_group_count(self, count: int) -> None:
    """Write the number of expert routing groups as a uint32 KV entry."""
    key = Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch)
    self.add_uint32(key, count)
|
||||||
|
|
||||||
|
def add_expert_group_used_count(self, count: int) -> None:
    """Write the number of expert groups selected per token as a uint32 KV entry."""
    key = Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch)
    self.add_uint32(key, count)
|
||||||
|
|
||||||
def add_expert_weights_scale(self, value: float) -> None:
|
def add_expert_weights_scale(self, value: float) -> None:
|
||||||
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class TensorNameMap:
|
|||||||
"model.layers.{bid}.pre_attn_norm", # grok-2
|
"model.layers.{bid}.pre_attn_norm", # grok-2
|
||||||
"embedding.word_embeddings", # chatglm
|
"embedding.word_embeddings", # chatglm
|
||||||
"transformer.token_embeddings", # openelm
|
"transformer.token_embeddings", # openelm
|
||||||
|
"model.word_embeddings", # bailingmoe
|
||||||
"shared", # t5
|
"shared", # t5
|
||||||
),
|
),
|
||||||
|
|
||||||
@@ -129,6 +130,7 @@ class TensorNameMap:
|
|||||||
"h.{bid}.self_attention.query_key_value", # bloom
|
"h.{bid}.self_attention.query_key_value", # bloom
|
||||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
||||||
"model.layers.{bid}.self_attn.query_key_value", # persimmon
|
"model.layers.{bid}.self_attn.query_key_value", # persimmon
|
||||||
|
"model.layers.{bid}.attention.query_key_value", # bailingmoe2
|
||||||
"h.{bid}.attn.c_attn", # gpt2
|
"h.{bid}.attn.c_attn", # gpt2
|
||||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||||
@@ -187,6 +189,7 @@ class TensorNameMap:
|
|||||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||||
"model.layers.{bid}.self_attn.dense", # persimmon
|
"model.layers.{bid}.self_attn.dense", # persimmon
|
||||||
|
"model.layers.{bid}.attention.dense", # bailingmoe2
|
||||||
"h.{bid}.attn.c_proj", # gpt2
|
"h.{bid}.attn.c_proj", # gpt2
|
||||||
"transformer.h.{bid}.mixer.out_proj", # phi2
|
"transformer.h.{bid}.mixer.out_proj", # phi2
|
||||||
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||||
@@ -263,6 +266,7 @@ class TensorNameMap:
|
|||||||
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
||||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
||||||
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
||||||
|
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
|
||||||
),
|
),
|
||||||
|
|
||||||
# Feed-forward up
|
# Feed-forward up
|
||||||
@@ -382,6 +386,7 @@ class TensorNameMap:
|
|||||||
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
||||||
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
|
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
|
||||||
"transformer.layers.{bid}.attn.q_norm", # openelm
|
"transformer.layers.{bid}.attn.q_norm", # openelm
|
||||||
|
"model.layers.{bid}.attention.query_layernorm", # bailingmoe2
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ATTN_K_NORM: (
|
MODEL_TENSOR.ATTN_K_NORM: (
|
||||||
@@ -391,6 +396,7 @@ class TensorNameMap:
|
|||||||
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
||||||
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
|
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
|
||||||
"transformer.layers.{bid}.attn.k_norm", # openelm
|
"transformer.layers.{bid}.attn.k_norm", # openelm
|
||||||
|
"model.layers.{bid}.attention.key_layernorm", # bailingmoe2
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.ROPE_FREQS: (
|
MODEL_TENSOR.ROPE_FREQS: (
|
||||||
@@ -403,6 +409,7 @@ class TensorNameMap:
|
|||||||
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
|
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
|
||||||
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
|
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
|
||||||
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code
|
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code
|
||||||
|
"model.layers.{bid}.final_layernorm", # bailingmoe2
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.SSM_IN: (
|
MODEL_TENSOR.SSM_IN: (
|
||||||
|
|||||||
Reference in New Issue
Block a user