diff --git a/src/llama.cpp b/src/llama.cpp
index 0ee01f1e..010fc358 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13425,31 +13425,46 @@ struct llm_build_context {
                         ggml_row_size(q->type, n_embd_head_qk_nope));
                 cb(q_rope, "q_rope", il);
 
+                q_rope = ggml_rope_ext(
+                    ctx0, q_rope, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                );
+                cb(q_rope, "q_rope", il);
+
                 // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
                 struct ggml_tensor * kv_rope_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
                 cb(kv_rope_compresseed, "kv_rope_compresseed", il);
 
+                // and {n_embd_head_qk_rope, n_tokens}
+                struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_rope_compresseed->nb[1],
+                        kv_rope_compresseed->nb[1],
+                        ggml_row_size(kv_rope_compresseed->type, kv_lora_rank));
+                cb(k_rope, "k_rope", il);
+
+                // shared RoPE key
+                k_rope = ggml_rope_ext(
+                    ctx0, k_rope, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                );
+                cb(k_rope, "k_rope", il);
+
                 // split into {kv_lora_rank, n_tokens}
                 struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_rope_compresseed, kv_lora_rank, n_tokens,
                         kv_rope_compresseed->nb[1],
                         0);
                 cb(kv_compressed, "kv_compressed", il);
 
+                //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(kv_compressed, "kv_compressed", il);
+
                 if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
-                    // and {n_embd_head_qk_rope, n_tokens}
-                    struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                            kv_rope_compresseed->nb[1],
-                            kv_rope_compresseed->nb[1],
-                            ggml_row_size(kv_rope_compresseed->type, kv_lora_rank));
-                    cb(k_rope, "k_rope", il);
-
-                    //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
-                    kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
-                            model.layers[il].attn_kv_a_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(kv_compressed, "kv_compressed", il);
-
                     struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank,
                             ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
                     cb(kv_cache_view, "kv_cache_view", il);
@@ -13476,21 +13491,6 @@ struct llm_build_context {
                             0);
                     cb(kv_cache_trans, "kv_cache_trans", il);
 
-                    q_rope = ggml_rope_ext(
-                        ctx0, q_rope, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                    );
-                    cb(q_rope, "q_rope", il);
-
-                    // shared RoPE key
-                    k_rope = ggml_rope_ext(
-                        ctx0, k_rope, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                    );
-                    cb(k_rope, "k_rope", il);
-
                     struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope,
                             ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
                     cb(kr_cache_view, "kr_cache_view", il);
@@ -13504,10 +13504,6 @@ struct llm_build_context {
                             0);
                     cb(kr_cache, "kr_cache", il);
 
-                    printf("kv_lora_rank = %d, n_head = %d, n_embd_head_qk_nope = %d, n_embd_head_v = %d\n", kv_lora_rank, (int)n_head, n_embd_head_qk_nope, (int)n_embd_head_v);
-                    printf("wk_b: %d x %d x %d x %d, wkv_b: %d x %d x %d x %d\n",
-                            (int)model.layers[il].wk_b->ne[0], (int)model.layers[il].wk_b->ne[1], (int)model.layers[il].wk_b->ne[2], (int)model.layers[il].wk_b->ne[3],
-                            (int)model.layers[il].wkv_b->ne[0], (int)model.layers[il].wkv_b->ne[1], (int)model.layers[il].wkv_b->ne[2], (int)model.layers[il].wkv_b->ne[3]);
                     struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head,
                             ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank) * n_embd_head_qk_nope, 0);
                     cb(wk_b, "wk_b", il);
@@ -13517,7 +13513,7 @@ struct llm_build_context {
                     //ggml_tensor * wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, kv_lora_rank, n_embd_head_qk_nope*n_head,
                     //        ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank), 0);
                     //ggml_tensor * ik1 = ggml_mul_mat(ctx0, wkv_b, kv_cache);
-                    //ggml_tensor * ik2 = ggml_view_3d(ctx0, ik1, n_embd_head_qk_nope, 
+                    //ggml_tensor * ik2 = ggml_view_3d(ctx0, ik1, n_embd_head_qk_nope,
 
                     struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
                     cb(q_nope2, "q_nope2", il);
@@ -13528,10 +13524,10 @@ struct llm_build_context {
                     }
                     struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
                     cb(kq_nope, "kq_nope", il);
-                    printf("kq_nope = kv_cache(%d x %d x %d x %d) * [wk_b (%d x %d x %d x %d) * q_nope (%d x %d x %d x %d)]\n",
-                            (int)kv_cache->ne[0], (int)kv_cache->ne[1], (int)kv_cache->ne[2], (int)kv_cache->ne[3],
-                            (int)wk_b->ne[0], (int)wk_b->ne[1], (int)wk_b->ne[2], (int)wk_b->ne[3],
-                            (int)q_nope->ne[0], (int)q_nope->ne[1], (int)q_nope->ne[2], (int)q_nope->ne[3]);
+                    //printf("kq_nope = kv_cache(%d x %d x %d x %d) * [wk_b (%d x %d x %d x %d) * q_nope (%d x %d x %d x %d)]\n",
+                    //        (int)kv_cache->ne[0], (int)kv_cache->ne[1], (int)kv_cache->ne[2], (int)kv_cache->ne[3],
+                    //        (int)wk_b->ne[0], (int)wk_b->ne[1], (int)wk_b->ne[2], (int)wk_b->ne[3],
+                    //        (int)q_nope->ne[0], (int)q_nope->ne[1], (int)q_nope->ne[2], (int)q_nope->ne[3]);
 
                     if (!pp_opt) {
                         kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
@@ -13589,19 +13585,6 @@ struct llm_build_context {
 
                 } else {
-                    // and {n_embd_head_qk_rope, n_tokens}
-                    struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                            kv_rope_compresseed->nb[1],
-                            kv_rope_compresseed->nb[1],
-                            ggml_row_size(kv_rope_compresseed->type, kv_lora_rank));
-                    cb(k_rope, "k_pe", il);
-
-                    //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
-                    kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
-                            model.layers[il].attn_kv_a_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(kv_compressed, "kv_compressed", il);
-
                     // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
                     struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
                     cb(kv, "kv", il);
 
@@ -13628,23 +13611,6 @@ struct llm_build_context {
                             0);
                     cb(v_states, "v_states", il);
 
-                    //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
-                    q_rope = ggml_rope_ext(
-                        ctx0, q_rope, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                    );
-                    cb(q_rope, "q_rope", il);
-
-                    // shared RoPE key
-                    //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
-                    k_rope = ggml_rope_ext(
-                        ctx0, k_rope, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                    );
-                    cb(k_rope, "k_rope", il);
-
                     struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_rope, 0);
                     cb(q_states, "q_states", il);
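
Note on the refactor above: both attention paths rotate the same two tensors, so the patch hoists the two ggml_rope_ext calls (and the kv_compressed RMS norm) out of the mla_attn branches and builds them once, right after the q_rope/k_rope views are taken; the debug printfs are removed or commented out. For reference, below is a minimal standalone sketch of the ggml_rope_ext call pattern being hoisted. It is an illustration, not part of the patch: it assumes the single-context CPU ggml API (ggml_init + ggml_graph_compute_with_ctx), and the sizes and float arguments are placeholder values, not the model's hyperparameters.

// rope_sketch.c -- illustrative only; build against ggml, e.g. cc rope_sketch.c -lggml
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // stand-in for a q_rope/k_rope view: {n_rot, n_head, n_tokens}
    const int n_rot = 16, n_head = 2, n_tokens = 4;
    struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_rot, n_head, n_tokens);
    for (int64_t i = 0; i < ggml_nelements(x); ++i) {
        ((float *) x->data)[i] = 0.01f*(float)i;
    }

    // stand-in for inp_pos: one position per token
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        ((int32_t *) pos->data)[i] = i;
    }

    // the call the patch hoists: rotate once, before any branching
    // (NEOX mode as used by DeepSeek2; the float arguments are placeholders)
    struct ggml_tensor * y = ggml_rope_ext(
        ctx, x, pos, NULL,
        n_rot, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig =*/ 4096,
        /*freq_base  =*/ 10000.0f, /*freq_scale =*/ 1.0f,
        /*ext_factor =*/ 0.0f, /*attn_factor =*/ 1.0f,
        /*beta_fast  =*/ 32.0f, /*beta_slow   =*/ 1.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("y[0] = %f (first element after RoPE)\n", ((float *) y->data)[0]);

    ggml_free(ctx);
    return 0;
}

Since ggml_rope_ext returns a new tensor, rotating before the branch means both the MLA and the non-MLA path now consume the same rotated q_rope/k_rope instead of each building its own rotation node.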