diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 7809d855..a7757dd9 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -6265,6 +6265,11 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
     struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+    ggml_tensor * inp_attn_scale = nullptr;
+    if (hparams.f_attn_temp_scale != 0.0f) {
+        inp_attn_scale = build_input_scale(n_tokens);
+    }
+
     // whether to use n_tokens as the matrix dimension during multiplication or n_head
     // n_tokens is higher during prompt processing, this allows to optimize for this case
     bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head;
@@ -6455,10 +6460,14 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
             }
             cb(k_rope, "k_rope", il);
 
-            //auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
             auto q = ggml_concat(ctx0, q_rope, q_nope, 0);
-            q = ggml_permute(ctx0, q, 0, 2, 1, 3);
             cb(q, "q_concat", il);
+            if (inp_attn_scale) {
+                q = ggml_mul(ctx0, q, inp_attn_scale);
+                cb(q, "q_concat_temp_scaled", il);
+            }
+            q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+            cb(q, "q_concat_permuted", il);
 
             ggml_build_forward_expand(gf, q);
 
@@ -6528,9 +6537,12 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
             struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
             cb(q_nope2, "q_nope2", il);
 
-            //ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
             ggml_tensor * q = ggml_concat(ctx0, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), q_nope2, 0);
             cb(q, "q", il);
+            if (inp_attn_scale) {
+                q = ggml_mul(ctx0, q, inp_attn_scale);
+                cb(q, "q_temp_scales", il);
+            }
 
             if (lctx.cparams.flash_attn && (lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3)) {
                 ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.k_l[il],
@@ -6680,6 +6692,11 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
             struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_rope, 0);
             cb(q_states, "q_states", il);
 
+            if (inp_attn_scale) {
+                q_states = ggml_mul(ctx0, q_states, inp_attn_scale);
+                cb(q_states, "q_states_temp_scales", il);
+            }
+
             struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_rope, q_rope), 0);
             cb(k_states, "k_states", il);
 
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 81afc62b..5fe0b2b7 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -786,6 +786,11 @@ void llm_load_hparams(
                 }
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
 
+                hparams.f_attn_temp_scale = 0;
+                hparams.n_attn_temp_floor_scale = 0;
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);
+
                 switch (hparams.n_layer) {
                     case 27: model.type = e_model::MODEL_16B; break;
                     case 60: model.type = e_model::MODEL_236B; break;
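
Note: the body of `build_input_scale(n_tokens)` is not part of this patch, so the per-token values that get broadcast-multiplied into `q` are not shown above. Below is a minimal standalone sketch of what that input is presumably filled with, assuming the Llama-4-style attention temperature formula that `f_attn_temp_scale` and `n_attn_temp_floor_scale` normally parameterize (logarithmic in position, stepping every `n_attn_temp_floor_scale` positions). The helper name and the example values are illustrative only, not taken from the patch.

```cpp
// Hypothetical sketch, not part of the patch: per-token attention temperature
// scale under the assumed formula
//   scale(pos) = log(floor((pos + 1) / n_attn_temp_floor_scale) + 1) * f_attn_temp_scale + 1
// The tensor created by build_input_scale(n_tokens) would hold one such value
// per token and is multiplied into q (ggml_mul broadcasts it across head dims).
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<float> attn_temp_scales(const std::vector<int32_t> & pos,
                                           float    f_attn_temp_scale,
                                           uint32_t n_attn_temp_floor_scale) {
    std::vector<float> scales(pos.size());
    for (size_t i = 0; i < pos.size(); ++i) {
        scales[i] = std::log(std::floor((pos[i] + 1.0f) / n_attn_temp_floor_scale) + 1.0f)
                  * f_attn_temp_scale + 1.0f;
    }
    return scales;
}

int main() {
    // example positions and hyperparameter values, chosen only for illustration
    const std::vector<int32_t> pos = {0, 1024, 8192, 65536};
    for (float s : attn_temp_scales(pos, 0.1f, 8192)) {
        std::printf("%.4f\n", s);
    }
    return 0;
}
```

Since `f_attn_temp_scale` defaults to 0 when the GGUF keys are absent, `inp_attn_scale` stays `nullptr` and the three `ggml_mul` insertions are skipped, so models without these keys build the exact same graph as before.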