mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Deepseek V3 support added (#176)
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
@@ -590,6 +590,9 @@ class Model:
|
||||
if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249":
|
||||
# ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M
|
||||
res = "smollm"
|
||||
if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
|
||||
# ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
|
||||
res = "deepseek-v3"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
||||
@@ -94,6 +94,7 @@ models = [
|
||||
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
|
||||
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
|
||||
{"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
|
||||
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -89,6 +89,8 @@ class Keys:
|
||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
||||
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
||||
@@ -257,6 +259,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
FFN_GATE_SHEXP = auto()
|
||||
FFN_DOWN_SHEXP = auto()
|
||||
FFN_UP_SHEXP = auto()
|
||||
FFN_EXP_PROBS_B = auto()
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
LAYER_OUT_NORM = auto()
|
||||
@@ -387,6 +390,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
|
||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
|
||||
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
||||
@@ -978,6 +982,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B
|
||||
],
|
||||
MODEL_ARCH.CHATGLM : [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
@@ -1177,6 +1182,10 @@ class GGMLQuantizationType(IntEnum):
|
||||
IQ2_TN = 42,
|
||||
|
||||
|
||||
class ExpertGatingFuncType(IntEnum):
|
||||
SOFTMAX = 1
|
||||
SIGMOID = 2
|
||||
|
||||
|
||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from .constants import (
|
||||
RopeScalingType,
|
||||
PoolingType,
|
||||
TokenType,
|
||||
ExpertGatingFuncType,
|
||||
)
|
||||
|
||||
from .quants import quant_shape_from_byte_shape
|
||||
@@ -670,6 +671,12 @@ class GGUFWriter:
|
||||
def add_expert_weights_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_expert_weights_norm(self, value: bool) -> None:
|
||||
self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
|
||||
|
||||
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
|
||||
|
||||
def add_layer_norm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
|
||||
|
||||
|
||||
@@ -251,6 +251,10 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
|
||||
),
|
||||
|
||||
# Feed-forward up
|
||||
MODEL_TENSOR.FFN_UP: (
|
||||
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
||||
|
||||
@@ -93,6 +93,7 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
|
||||
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
|
||||
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 23, //llama.cpp lists this as 28
|
||||
};
|
||||
|
||||
// note: these values should be synchronized with ggml_rope
|
||||
|
||||
@@ -367,6 +367,13 @@ struct llm_tokenizer_bpe {
|
||||
"\\p{N}+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
|
||||
regex_exprs = {
|
||||
"\\p{N}{1,3}",
|
||||
"[一-龥-ゟ゠-ヿ]+",
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
||||
regex_exprs = {
|
||||
"[\r\n]",
|
||||
|
||||
100
src/llama.cpp
100
src/llama.cpp
@@ -106,7 +106,7 @@
|
||||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
|
||||
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV2
|
||||
|
||||
//
|
||||
// helpers
|
||||
@@ -294,6 +294,8 @@ enum llm_kv {
|
||||
LLM_KV_EXPERT_USED_COUNT,
|
||||
LLM_KV_EXPERT_SHARED_COUNT,
|
||||
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
||||
LLM_KV_EXPERT_WEIGHTS_NORM,
|
||||
LLM_KV_EXPERT_GATING_FUNC,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
LLM_KV_DECODER_START_TOKEN_ID,
|
||||
@@ -399,6 +401,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
|
||||
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
|
||||
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
|
||||
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
|
||||
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
|
||||
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
|
||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
|
||||
@@ -520,6 +524,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_FFN_DOWN_SHEXP,
|
||||
LLM_TENSOR_FFN_GATE_SHEXP,
|
||||
LLM_TENSOR_FFN_UP_SHEXP,
|
||||
LLM_TENSOR_FFN_EXP_PROBS_B,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_LAYER_OUT_NORM,
|
||||
@@ -1211,6 +1216,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -2186,6 +2192,7 @@ enum e_model {
|
||||
MODEL_70B,
|
||||
MODEL_236B,
|
||||
MODEL_314B,
|
||||
MODEL_671B,
|
||||
MODEL_SMALL,
|
||||
MODEL_MEDIUM,
|
||||
MODEL_LARGE,
|
||||
@@ -2203,6 +2210,21 @@ static const size_t kiB = 1024;
|
||||
static const size_t MiB = 1024*kiB;
|
||||
static const size_t GiB = 1024*MiB;
|
||||
|
||||
enum llm_expert_gating_func_type {
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX = 1,
|
||||
LLM_EXPERT_GATING_FUNC_SIGMOID = 2,
|
||||
};
|
||||
|
||||
static const char * llama_expert_gating_func_name(llm_expert_gating_func_type type) {
|
||||
switch (type) {
|
||||
case LLM_EXPERT_GATING_FUNC_SOFTMAX: return "softmax";
|
||||
case LLM_EXPERT_GATING_FUNC_SIGMOID: return "sigmoid";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct llama_hparams {
|
||||
bool vocab_only;
|
||||
bool rope_finetuned;
|
||||
@@ -2232,6 +2254,8 @@ struct llama_hparams {
|
||||
uint32_t n_ff_shexp = 0;
|
||||
uint32_t n_expert_shared = 0;
|
||||
float expert_weights_scale = 0.0;
|
||||
bool expert_weights_norm = false;
|
||||
uint32_t expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX;
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
@@ -2502,6 +2526,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * ffn_down_b = nullptr; // b2
|
||||
struct ggml_tensor * ffn_up_b = nullptr; // b3
|
||||
struct ggml_tensor * ffn_act;
|
||||
struct ggml_tensor * ffn_exp_probs_b = nullptr;
|
||||
|
||||
// mamba proj
|
||||
struct ggml_tensor * ssm_in;
|
||||
@@ -4677,6 +4702,7 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_70B: return "70B";
|
||||
case MODEL_236B: return "236B";
|
||||
case MODEL_314B: return "314B";
|
||||
case MODEL_671B: return "671B";
|
||||
case MODEL_SMALL: return "0.1B";
|
||||
case MODEL_MEDIUM: return "0.4B";
|
||||
case MODEL_LARGE: return "0.8B";
|
||||
@@ -5302,11 +5328,19 @@ static void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
if (hparams.expert_gating_func == 0) {
|
||||
// for compatibility with existing DeepSeek V2 and V2.5 GGUFs
|
||||
// that have no expert_gating_func model parameter set
|
||||
hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX;
|
||||
}
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 27: model.type = e_model::MODEL_16B; break;
|
||||
case 60: model.type = e_model::MODEL_236B; break;
|
||||
case 61: model.type = e_model::MODEL_671B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
@@ -5565,6 +5599,10 @@ static void llm_load_vocab(
|
||||
tokenizer_pre == "deepseek-coder") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
|
||||
vocab.tokenizer_clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "deepseek-v3") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
|
||||
vocab.tokenizer_clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "falcon") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
|
||||
@@ -6075,6 +6113,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((enum llm_expert_gating_func_type) hparams.expert_gating_func));
|
||||
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
||||
}
|
||||
|
||||
@@ -7540,6 +7580,7 @@ static bool llm_load_tensors(
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
} else {
|
||||
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
layer.ffn_exp_probs_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert} );
|
||||
|
||||
GGML_ASSERT(n_expert > 0);
|
||||
GGML_ASSERT(n_expert_used > 0);
|
||||
@@ -8346,12 +8387,14 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
||||
struct ggml_tensor * up_exps,
|
||||
struct ggml_tensor * gate_exps,
|
||||
struct ggml_tensor * down_exps,
|
||||
struct ggml_tensor * exp_probs_b,
|
||||
int64_t n_expert,
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llm_expert_gating_func_type gating_op,
|
||||
const llm_build_cb & cb,
|
||||
int il) {
|
||||
int64_t n_embd = cur->ne[0];
|
||||
@@ -8360,11 +8403,32 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
||||
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
|
||||
cb(logits, "ffn_moe_logits", il);
|
||||
|
||||
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
|
||||
//ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
|
||||
ggml_tensor * probs = nullptr;
|
||||
switch (gating_op) {
|
||||
case LLM_EXPERT_GATING_FUNC_SOFTMAX:
|
||||
{
|
||||
probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
|
||||
} break;
|
||||
case LLM_EXPERT_GATING_FUNC_SIGMOID:
|
||||
{
|
||||
probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
cb(probs, "ffn_moe_probs", il);
|
||||
|
||||
// add experts selection bias - introduced in DeepSeek V3
|
||||
// leave probs unbiased as it's later used to get expert weights
|
||||
ggml_tensor * selection_probs = probs;
|
||||
if (exp_probs_b != nullptr) {
|
||||
selection_probs = ggml_add(ctx, probs, exp_probs_b);
|
||||
cb(selection_probs, "ffn_moe_probs_biased", il);
|
||||
}
|
||||
|
||||
// select experts
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||
cb(selected_experts, "ffn_moe_topk", il);
|
||||
|
||||
@@ -9180,9 +9244,11 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
@@ -9673,9 +9739,11 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_GELU, true,
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -9814,9 +9882,11 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -10944,9 +11014,11 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -13109,9 +13181,11 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLM_EXPERT_GATING_FUNC_SOFTMAX,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -13324,9 +13398,11 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
true, hparams.expert_weights_scale,
|
||||
(enum llm_expert_gating_func_type) hparams.expert_gating_func,
|
||||
cb, il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -18547,6 +18623,7 @@ struct llama_data_read {
|
||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
llama_batch_free(batch);
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
return false;
|
||||
}
|
||||
@@ -19732,6 +19809,21 @@ static int32_t llama_chat_apply_template_internal(
|
||||
if (add_ass) {
|
||||
ss << "Assistant:";
|
||||
}
|
||||
} else if (tmpl == "deepseek3" || tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
|
||||
// DeepSeek-V3
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << message->content << "\n\n";
|
||||
} else if (role == "user") {
|
||||
ss << LU8("<|User|>") << message->content;
|
||||
} else if (role == "assistant") {
|
||||
ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << LU8("<|Assistant|>");
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
||||
@@ -648,18 +648,25 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
||||
{ "\\p{N}", codepoint_flags::NUMBER },
|
||||
{ "\\p{L}", codepoint_flags::LETTER },
|
||||
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
||||
{ "\\p{M}", codepoint_flags::ACCENT_MARK },
|
||||
{ "\\p{S}", codepoint_flags::SYMBOL },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ codepoint_flags::NUMBER, 0xD1 },
|
||||
{ codepoint_flags::LETTER, 0xD2 },
|
||||
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
||||
{ codepoint_flags::ACCENT_MARK, 0xD4 },
|
||||
{ codepoint_flags::SYMBOL, 0xD5 },
|
||||
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}i
|
||||
{ codepoint_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
|
||||
{ codepoint_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
|
||||
Reference in New Issue
Block a user