diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 4c801e76..04e5a142 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -6501,28 +6501,23 @@ ggml_cgraph * llm_build_context::build_cohere2() { // rope freq factors for 128k context struct ggml_tensor * rope_factors = build_rope_factors(il); - auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq, + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, + model.layers[il].wqkv, model.layers[il].bqkv, + model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, - model.layers[il].wv, model.layers[il].bv, 0.f, il); + model.layers[il].wv, model.layers[il].bv, nullptr, nullptr, 0.f, il); if (is_sliding) { - Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - } else { - // For non-sliding layers, just reshape without applying RoPE - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur", il); - - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cb(Kcur, "Kcur", il); - } + }; cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr, @@ -6564,6 +6559,7 @@ ggml_cgraph * llm_build_context::build_cohere2() { // lm_head cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "output", -1); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale);