mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
model : add grok-2 support (#782)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -116,9 +116,9 @@ struct gpt_params {
|
||||
float rope_freq_base = 0.0f; // RoPE base frequency
|
||||
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
||||
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
|
||||
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
|
||||
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
|
||||
float yarn_beta_fast = -1.0f; // YaRN low correction dim
|
||||
float yarn_beta_slow = -1.0f; // YaRN high correction dim
|
||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
||||
|
||||
|
||||
@@ -555,6 +555,9 @@ class Model:
|
||||
# NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script
|
||||
# or pull the latest version of the model from Huggingface
|
||||
# don't edit the hashes manually!
|
||||
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
|
||||
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
|
||||
res = "grok-2"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
@@ -1905,12 +1908,20 @@ class BitnetModel(Model):
|
||||
return tensors
|
||||
|
||||
|
||||
@Model.register("GrokForCausalLM")
|
||||
@Model.register("GrokForCausalLM", "Grok1ForCausalLM")
|
||||
class GrokModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GROK
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
if (self.dir_model / 'tokenizer.model').is_file():
|
||||
self._set_vocab_sentencepiece()
|
||||
return
|
||||
|
||||
if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
|
||||
logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
|
||||
sys.exit(1)
|
||||
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1918,11 +1929,46 @@ class GrokModel(Model):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
|
||||
self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
|
||||
if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
|
||||
self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
|
||||
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
|
||||
# Treat "original" as "yarn", seems to have been a mistake
|
||||
if self.hparams.get("rope_type") in ("yarn", "original"):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
|
||||
|
||||
if temp_len := self.hparams.get("attn_temperature_len"):
|
||||
self.gguf_writer.add_attn_temperature_length(temp_len)
|
||||
|
||||
self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
|
||||
self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
|
||||
self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
|
||||
|
||||
_experts: list[dict[str, list[Tensor]]] | None = None
|
||||
_cur_expert = ""
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
|
||||
|
||||
if not is_expert:
|
||||
tensors.append((self.map_tensor_name(name), data_torch))
|
||||
|
||||
# process the experts separately
|
||||
if name.find(".moe.") != -1:
|
||||
if is_expert or self._cur_expert:
|
||||
n_experts = self.hparams["num_local_experts"]
|
||||
|
||||
assert bid is not None
|
||||
@@ -1930,32 +1976,41 @@ class GrokModel(Model):
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for wid in ["linear", "linear_1", "linear_v"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
# concatenate split tensors
|
||||
if name in self._experts[bid]:
|
||||
self._cur_expert = name
|
||||
self._experts[bid][name].append(data_torch)
|
||||
return []
|
||||
elif is_expert:
|
||||
self._cur_expert = name
|
||||
self._experts[bid][name] = [data_torch]
|
||||
return []
|
||||
else:
|
||||
self._cur_expert = ""
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
for bid in range(self.block_count):
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
# merge the experts into a single 3d tensor
|
||||
for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
|
||||
if ename not in self._experts[bid]:
|
||||
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
|
||||
tensor_list = self._experts[bid][ename]
|
||||
datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
yield from tensors
|
||||
|
||||
|
||||
@Model.register("DbrxForCausalLM")
|
||||
|
||||
@@ -99,6 +99,7 @@ models = [
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2", },
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902", },
|
||||
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890", },
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -97,6 +97,7 @@ class Keys:
|
||||
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
||||
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
|
||||
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
|
||||
ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
@@ -112,16 +113,22 @@ class Keys:
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
OUTPUT_SCALE = "{arch}.attention.output_scale"
|
||||
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
FREQ_BASE = "{arch}.rope.freq_base"
|
||||
SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
|
||||
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
|
||||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
FREQ_BASE = "{arch}.rope.freq_base"
|
||||
SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
|
||||
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
|
||||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
|
||||
SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor"
|
||||
SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
|
||||
SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
|
||||
SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow"
|
||||
|
||||
class Split:
|
||||
LLM_KV_SPLIT_NO = "split.no"
|
||||
@@ -540,6 +547,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.GPTNEOX: [
|
||||
|
||||
@@ -656,6 +656,9 @@ class GGUFWriter:
|
||||
def add_attn_logit_softcapping(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
|
||||
|
||||
def add_router_logit_softcapping(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
|
||||
|
||||
def add_final_logit_softcapping(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
|
||||
|
||||
@@ -701,6 +704,12 @@ class GGUFWriter:
|
||||
def add_sliding_window(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
|
||||
|
||||
def add_attn_output_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_attn_temperature_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
|
||||
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
@@ -728,6 +737,18 @@ class GGUFWriter:
|
||||
def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
|
||||
|
||||
def add_ssm_conv_kernel(self, value: int) -> None:
|
||||
self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class TensorNameMap:
|
||||
"backbone.embedding", # mamba
|
||||
"backbone.embeddings", # mamba-hf
|
||||
"transformer.in_out_embed", # Grok
|
||||
"model.layers.{bid}.pre_attn_norm", # grok-2
|
||||
"embedding.word_embeddings", # chatglm
|
||||
"transformer.token_embeddings", # openelm
|
||||
"shared", # t5
|
||||
@@ -202,6 +203,7 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||
"encoder.layers.{bid}.norm1", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
|
||||
"model.layers.{bid}.post_attn_norm", # grok-2
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
|
||||
),
|
||||
|
||||
@@ -230,6 +232,7 @@ class TensorNameMap:
|
||||
"h.{bid}.ln_2", # gpt2
|
||||
"model.layers.{bid}.ffn_norm", # internlm2
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"model.layers.{bid}.pre_moe_norm", # grok-2
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
),
|
||||
@@ -242,6 +245,7 @@ class TensorNameMap:
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2
|
||||
"model.layers.{bid}.post_moe_norm", # grok-2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
|
||||
@@ -99,6 +99,7 @@ enum llm_kv {
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
LLM_KV_DECODER_START_TOKEN_ID,
|
||||
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
|
||||
LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
|
||||
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
|
||||
LLM_KV_SWIN_NORM,
|
||||
LLM_KV_RESCALE_EVERY_N_LAYERS,
|
||||
@@ -123,7 +124,8 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
@@ -134,6 +136,11 @@ enum llm_kv {
|
||||
LLM_KV_ROPE_SCALING_FINETUNED,
|
||||
LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
|
||||
|
||||
LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
|
||||
LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
|
||||
|
||||
LLM_KV_SPLIT_NO,
|
||||
LLM_KV_SPLIT_COUNT,
|
||||
LLM_KV_SPLIT_TENSORS_COUNT,
|
||||
|
||||
@@ -433,6 +433,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_GROK_2:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@@ -1973,6 +1980,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "kimi-k2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
|
||||
clean_spaces = false;
|
||||
}
|
||||
else if (
|
||||
tokenizer_pre == "grok-2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
|
||||
clean_spaces = false;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
|
||||
LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
|
||||
LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
|
||||
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
|
||||
};
|
||||
|
||||
struct LLM_KV;
|
||||
|
||||
164
src/llama.cpp
164
src/llama.cpp
@@ -289,6 +289,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
|
||||
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
|
||||
{ LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" },
|
||||
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
|
||||
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
|
||||
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
|
||||
@@ -309,6 +310,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
@@ -319,6 +322,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
|
||||
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
|
||||
{ LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
|
||||
{ LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" },
|
||||
{ LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" },
|
||||
{ LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" },
|
||||
{ LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" },
|
||||
|
||||
{ LLM_KV_SPLIT_NO, "split.no" },
|
||||
{ LLM_KV_SPLIT_COUNT, "split.count" },
|
||||
@@ -507,9 +514,13 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
|
||||
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
@@ -1558,6 +1569,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
|
||||
LLM_CHAT_TEMPLATE_KIMI_K2,
|
||||
LLM_CHAT_TEMPLATE_OPENAI_MOE,
|
||||
LLM_CHAT_TEMPLATE_GROK_2,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -1599,6 +1611,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
||||
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
|
||||
{ "bitnet", LLM_CHAT_TEMPLATE_BITNET },
|
||||
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
|
||||
};
|
||||
|
||||
// helper to handle gguf constants
|
||||
@@ -1930,6 +1943,7 @@ struct llama_hparams {
|
||||
float f_norm_rms_eps;
|
||||
|
||||
float f_attn_logit_softcapping = 50.0f;
|
||||
float f_router_logit_softcapping = 30.0f;
|
||||
float f_final_logit_softcapping = 30.0f;
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
@@ -1940,6 +1954,11 @@ struct llama_hparams {
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul;
|
||||
|
||||
float yarn_ext_factor = -1.0f;
|
||||
float yarn_attn_factor = 1.0f;
|
||||
float yarn_beta_fast = 32.0f;
|
||||
float yarn_beta_slow = 1.0f;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
uint32_t ssm_d_inner = 0;
|
||||
@@ -1955,6 +1974,10 @@ struct llama_hparams {
|
||||
float f_embedding_scale = 0.0f;
|
||||
float f_attention_scale = 0.0f;
|
||||
|
||||
// grok-2
|
||||
float f_attn_out_scale = 0.0f;
|
||||
uint32_t attn_temp_length = 0;
|
||||
|
||||
bool causal_attn = true;
|
||||
bool use_alibi = false;
|
||||
bool attn_soft_cap = false;
|
||||
@@ -3644,7 +3667,30 @@ static void llm_load_hparams(
|
||||
} break;
|
||||
case LLM_ARCH_GROK:
|
||||
{
|
||||
// defaults for old GGUFs
|
||||
hparams.yarn_beta_fast = 8.0f;
|
||||
hparams.f_logit_scale = 0.5773502691896257f;
|
||||
hparams.f_embedding_scale = 78.38367176906169f;
|
||||
hparams.f_attn_out_scale = 0.08838834764831845f;
|
||||
hparams.f_attn_logit_softcapping = 30.0f;
|
||||
hparams.f_router_logit_softcapping = 30.0f;
|
||||
// no final_logit_softcapping in grok-1
|
||||
hparams.f_final_logit_softcapping = 0.0f;
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false);
|
||||
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false);
|
||||
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
|
||||
ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false);
|
||||
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 64: model.type = e_model::MODEL_314B; break;
|
||||
@@ -5238,7 +5284,7 @@ static bool llm_load_tensors(
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
@@ -5256,12 +5302,16 @@ static bool llm_load_tensors(
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (layer.ffn_gate_exps) {
|
||||
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert});
|
||||
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
|
||||
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
|
||||
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert });
|
||||
} else {
|
||||
// merge split expert into a single tensor for compatibility with older models
|
||||
// requires disabling mmap
|
||||
@@ -5287,7 +5337,10 @@ static bool llm_load_tensors(
|
||||
}
|
||||
}
|
||||
|
||||
layer.layer_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_post_norm = create_tensor(ctx_layer,tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
if (!layer.ffn_post_norm) {
|
||||
layer.ffn_post_norm = create_tensor(ctx_layer,tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DBRX:
|
||||
@@ -8197,7 +8250,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
|
||||
if (model.arch == LLM_ARCH_GROK) {
|
||||
// need to do the following:
|
||||
// multiply by attn_output_multiplyer of 0.08838834764831845
|
||||
// multiply by attn_output_multiplier
|
||||
// and then :
|
||||
// kq = 30 * tanh(kq / 30)
|
||||
// before the softmax below
|
||||
@@ -8208,7 +8261,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
//kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
|
||||
//kq = ggml_scale(ctx, kq, 30);
|
||||
|
||||
kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
|
||||
kq = ggml_softcap(ctx, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
|
||||
}
|
||||
|
||||
if (hparams.attn_soft_cap) {
|
||||
@@ -8254,7 +8307,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
|
||||
}
|
||||
if (model.arch == LLM_ARCH_GROK) {
|
||||
kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f);
|
||||
kq_i = ggml_softcap(ctx, kq_i, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
|
||||
}
|
||||
if (hparams.attn_soft_cap) {
|
||||
kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
@@ -9553,15 +9606,11 @@ struct llm_build_context {
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// multiply by embedding_multiplier_scale of 78.38367176906169
|
||||
inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
@@ -9623,26 +9672,23 @@ struct llm_build_context {
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// Grok
|
||||
// if attn_out_norm is present then apply it before adding the input
|
||||
if (model.layers[il].attn_out_norm) {
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].attn_out_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_out_norm", il);
|
||||
}
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].attn_out_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_out_norm", il);
|
||||
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
// MoE branch
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
// MoE branch
|
||||
ggml_tensor* moe_out = llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
@@ -9653,17 +9699,29 @@ struct llm_build_context {
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il, gf);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
// Grok
|
||||
// if layer_out_norm is present then apply it before adding the input
|
||||
// Idea: maybe ffn_out_norm is a better name
|
||||
if (model.layers[il].layer_out_norm) {
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].layer_out_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "layer_out_norm", il);
|
||||
if (model.layers[il].ffn_up) {
|
||||
ggml_tensor* ffn_out = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
||||
cb(ffn_out, "ffn_out", il);
|
||||
|
||||
cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
else {
|
||||
cur = moe_out;
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].ffn_post_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
@@ -9685,11 +9743,15 @@ struct llm_build_context {
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
|
||||
// Grok
|
||||
// multiply logits by output_multiplier_scale of 0.5773502691896257
|
||||
|
||||
cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
|
||||
// final logit soft-capping
|
||||
if (hparams.f_final_logit_softcapping) {
|
||||
/*cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
|
||||
cur = ggml_tanh(ctx0, cur);
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);*/
|
||||
cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
|
||||
|
||||
}
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
@@ -19393,9 +19455,9 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.rope_freq_base =*/ 0.0f,
|
||||
/*.rope_freq_scale =*/ 0.0f,
|
||||
/*.yarn_ext_factor =*/ -1.0f,
|
||||
/*.yarn_attn_factor =*/ 1.0f,
|
||||
/*.yarn_beta_fast =*/ 32.0f,
|
||||
/*.yarn_beta_slow =*/ 1.0f,
|
||||
/*.yarn_attn_factor =*/ -1.0f,
|
||||
/*.yarn_beta_fast =*/ -1.0f,
|
||||
/*.yarn_beta_slow =*/ -1.0f,
|
||||
/*.yarn_orig_ctx =*/ 0,
|
||||
/*.defrag_thold =*/ -1.0f,
|
||||
/*.cb_eval =*/ nullptr,
|
||||
@@ -19607,10 +19669,10 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
@@ -21987,6 +22049,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
|
||||
} else if (tmpl_contains("<|im_middle|>") && tmpl_contains("<|im_end|>")) {
|
||||
return LLM_CHAT_TEMPLATE_KIMI_K2;
|
||||
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
|
||||
return LLM_CHAT_TEMPLATE_GROK_2;
|
||||
} else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
|
||||
return LLM_CHAT_TEMPLATE_OPENAI_MOE;
|
||||
}
|
||||
@@ -22459,6 +22523,22 @@ static int32_t llama_chat_apply_template_internal(
|
||||
if (add_ass) {
|
||||
ss << "<|start|>assistant";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "System: " << trim(message->content) << "<|separator|>\n\n";
|
||||
}
|
||||
else if (role == "user") {
|
||||
ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
|
||||
}
|
||||
else if (role == "assistant") {
|
||||
ss << "Assistant: " << message->content << "<|separator|>\n\n";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "Assistant:";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
||||
Reference in New Issue
Block a user