From f8b66238fac94721e0eed3e2e6e20b9a381988c0 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Wed, 24 Sep 2025 16:52:54 +0200
Subject: [PATCH] Fused matrix multiplications (CUDA and CPU) (#796)

* Quick attempt to fuse the Q, K, V GEMMs

Doesn't do much on the CPU

* Doesn't do much on the GPU either

* Use llm_build_mul_mat_qkv

* This is not needed

* Revert timing code committed by mistake

---------

Co-authored-by: Iwan Kawrakow
---
 ggml/src/ggml-cuda.cu |  87 +++++--
 ggml/src/ggml.c       | 154 ++++--------
 src/llama.cpp         | 567 ++++++++++--------------------------------
 3 files changed, 244 insertions(+), 564 deletions(-)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 998c4a23..51042966 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2143,7 +2143,62 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     }
 }
 
-static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const ggml_cgraph * cgraph, int node_n, bool is_gemv) {
+
+    auto stream = ctx.stream();
+
+    auto ne10_padded = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+    auto nb10_padded = ne10_padded*sizeof(block_q8_1)/QK8_1;
+    auto quantized_size = nb10_padded*ggml_nrows(src1);
+    if (!is_gemv) {
+        quantized_size += get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
+    }
+    ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), quantized_size);
+    if (is_gemv) {
+        quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], src1->ne[1], src1->ne[2], ne10_padded,
+                src0->type, stream);
+        CUDA_CHECK(cudaGetLastError());
+
+        ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
+                0, src0->ne[1], src1->ne[1], ne10_padded, stream);
+        CUDA_CHECK(cudaGetLastError());
+    } else {
+        quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0->type, stream);
+        CUDA_CHECK(cudaGetLastError());
+
+        ggml_cuda_op_mul_mat_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
+                0, src0->ne[1], src1->ne[1], ne10_padded, stream);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    if (!cgraph) return node_n;
+
+    while (node_n + 1 < cgraph->n_nodes) {
+        dst = cgraph->nodes[node_n+1];
+        if (ggml_is_empty(dst) || dst->op == GGML_OP_RESHAPE || dst->op == GGML_OP_TRANSPOSE || dst->op == GGML_OP_VIEW
+                || dst->op == GGML_OP_PERMUTE || dst->op == GGML_OP_NONE) {
+            ++node_n; continue;
+        }
+        if (dst->op != GGML_OP_MUL_MAT || dst->src[1] != src1 || !ggml_is_quantized(dst->src[0]->type)) break;
+        if (!is_gemv && mmq_get_q8_1_ds_layout(src0->type) != mmq_get_q8_1_ds_layout(dst->src[0]->type)) break;
+        if (is_gemv) {
+            ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
+                    (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+        } else {
+            ggml_cuda_op_mul_mat_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
+                    (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+        }
+        CUDA_CHECK(cudaGetLastError());
+        ++node_n;
+    }
+
+    return node_n;
+
+}
+
+static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const ggml_cgraph * cgraph, int node_n) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
     // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
@@ -2188,6 +2243,10 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
     }
 
+    if (!split && (use_mul_mat_vec_q || use_mul_mat_q) && src1->ne[2]*src1->ne[3] == 1) {
+        return ggml_cuda_mul_mat_q(ctx, src0, src1, dst, cgraph, node_n, use_mul_mat_vec_q);
+    }
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -2215,6 +2274,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
     }
+    return node_n;
 }
 
 struct mmid_row_mapping {
@@ -2454,7 +2514,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 src1_row.data = src1_original + i11*nb11 + i12*nb12;
                 dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
 
-                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row, nullptr, 0);
             }
         }
     } else {
@@ -2505,7 +2565,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row, nullptr, 0);
 
             {
                 dim3 block_dims(std::min((unsigned int)ne0, 768u));
@@ -2889,7 +2949,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
                 ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(),
                         (float *)dst_row.data, 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
             } else {
-                ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+                ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row, nullptr, 0);
             }
             CUDA_CHECK(cudaGetLastError());
@@ -2906,7 +2966,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
                 ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(),
                         (float *)dst_row.data, 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
             } else {
-                ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+                ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0);
             }
             CUDA_CHECK(cudaGetLastError());
@@ -2947,8 +3007,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
                     (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]);
             first = false;
         }
-        ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
-        //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst);
+        ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst, nullptr, 0);
         CUDA_CHECK(cudaGetLastError());
 
         dim3 block_dims(std::min((unsigned int)next->ne[0], 768u));
@@ -3031,8 +3090,7 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 
 }
 
-static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next,
-        const ggml_cgraph * cgraph, int & i) {
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
     }
 
@@ -3042,6 +3100,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     int64_t tim1 = ggml_time_us();
 #endif
 
+    auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
+
     switch (dst->op) {
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
@@ -3112,7 +3172,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                     ggml_cuda_op_hardswish(ctx, dst);
                     break;
                 default:
-                    return false;
+                    return -1;
             }
             break;
         case GGML_OP_NORM:
@@ -3148,9 +3208,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT:
            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
+                return -1;
             } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
+                i = ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst, cgraph, i);
             }
             break;
         case GGML_OP_MUL_MAT_ID:
@@ -3569,7 +3629,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
-        ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;
 
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW
                 || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
             continue;
         }
@@ -3604,7 +3663,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         GGML_UNUSED(integrated);
 #endif // NDEBUG
 
-        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, cgraph, i);
+        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
         if (!ok) {
             GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
         }
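The graph lookahead in ggml_cuda_mul_mat_q above is the core of the CUDA-side fusion: after finishing one mat-mul, the loop peeks at the following graph nodes and also runs every further mat-mul that consumes the same (already quantized) src1, skipping pure view ops in between. Here is a minimal, self-contained C++ sketch of that pattern — toy Node/Op types for illustration, not the ggml API:

#include <cstdio>
#include <vector>

// Toy stand-ins for graph nodes; in ggml these are ggml_tensor / ggml_cgraph.
enum class Op { MulMat, Reshape, Other };
struct Node { Op op; int shared_input; };

// Process node `n`, then greedily consume trivially skippable nodes and any
// following mat-mul that shares the same input (whose quantized copy we hold).
// Returns the index of the last node handled, mirroring the int return of
// ggml_cuda_mul_mat_q.
static int run_fused(const std::vector<Node> & g, int n) {
    const int input = g[n].shared_input;
    std::printf("quantize input %d once, run node %d\n", input, n);
    while (n + 1 < (int)g.size()) {
        const Node & next = g[n + 1];
        if (next.op == Op::Reshape) { ++n; continue; }  // no-op w.r.t. the data
        if (next.op != Op::MulMat || next.shared_input != input) break;
        std::printf("reuse quantized input %d for node %d\n", input, n + 1);
        ++n;
    }
    return n;
}

int main() {
    // Q, K, V projections all read activation 0 -> one quantization, three GEMMs.
    std::vector<Node> g = {{Op::MulMat, 0}, {Op::MulMat, 0}, {Op::MulMat, 0}, {Op::Other, 1}};
    for (int i = 0; i < (int)g.size(); ++i) {
        if (g[i].op == Op::MulMat) i = run_fused(g, i);
    }
}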
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 078d7219..c8cd6b17 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -14974,9 +14974,11 @@ static inline uint32_t simple_gcd(uint32_t a, uint32_t b) {
     return a;
 }
 
-static void ggml_compute_forward_mul_mat(
+static int ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
+        struct ggml_tensor * dst,
+        const struct ggml_cgraph * cgraph,
+        int node_n) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -15017,12 +15019,6 @@ static void ggml_compute_forward_mul_mat(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if GGML_USE_LLAMAFILE
-    // broadcast factors
-    const int64_t r2 = ne12 / ne02;
-    const int64_t r3 = ne13 / ne03;
-#endif
-
 #if GGML_USE_IQK_MULMAT
     if (ith == 0) {
         static bool first_time = true;
@@ -15040,34 +15036,10 @@ static void ggml_compute_forward_mul_mat(
                 ne02, ne03, ne12, ne13, nb02, nb03, nb12, nb13, nb2/sizeof(float), nb3/sizeof(float),
                 src0->type, src0->data, nb01,
                 src1->type, src1->data, nb11,
-                (float *)dst->data, nb1/sizeof(float), ith, nth)) return;
+                (float *)dst->data, nb1/sizeof(float), ith, nth)) return node_n;
     }
 #endif
 
-#if GGML_USE_LLAMAFILE
-
-    const bool src1_cont = ggml_is_contiguous(src1);
-
-    if (src1_cont) {
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
-                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
-                                     (const char *)src1->data + i12*nb12 + i13*nb13,
-                                     nb11/ggml_type_size(src1->type),
-                                     (char *)dst->data + i12*nb2 + i13*nb3,
-                                     nb1/ggml_type_size(dst->type),
-                                     ith, nth,
-                                     src0->type,
-                                     src1->type,
-                                     dst->type))
-                    goto UseGgmlGemm1;
-        return;
-    }
-UseGgmlGemm1:;
-#endif
-
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
@@ -15092,51 +15064,27 @@ UseGgmlGemm1:;
 
         } else {
 
-//#ifdef GGML_USE_IQK_MULMAT
-//        int ts = type_traits[vec_dot_type].type_size;
-//        int bs = type_traits[vec_dot_type].blck_size;
-//        int64_t blocks_per_row = ne10/bs;
-//        int64_t num_blocks = ne11*ne12*ne13*blocks_per_row;
-//        int gcd = simple_gcd(128, ts); // 128 is to cover cache line sizes for common architectures without getting involved
-//                                       // with trying to get it from ggml
-//        int64_t num_blocks_gcd = (num_blocks + gcd - 1)/gcd;
-//        int64_t block_per_thread = ((num_blocks_gcd + nth - 1)/nth)*gcd;
-//        int64_t first_block = ith*block_per_thread;
-//        int64_t last_block = MIN(num_blocks, first_block + block_per_thread);
-//        while (first_block < last_block) {
-//            int64_t i13 = first_block/(ne11*ne12*blocks_per_row);
-//            int64_t i12 = (first_block - i13*ne11*ne12*blocks_per_row)/(ne11*blocks_per_row);
-//            int64_t i11 = (first_block - (i13*ne12 + i12)*ne11*blocks_per_row)/blocks_per_row;
-//            int64_t i10 = first_block % blocks_per_row;
-//            int64_t blocks_to_do = MIN(blocks_per_row - i10, last_block - first_block);
-//            from_float((float *)((char *)src1->data + i13*nb13 + i12*nb12 + i11*nb11) + i10*bs,
-//                    (void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + i10*ts), blocks_to_do*bs);
-//            first_block += blocks_to_do;
-//        }
-//#else
-
-        for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                int64_t i11_processed = 0;
+            for (int64_t i13 = 0; i13 < ne13; ++i13) {
+                for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                    int64_t i11_processed = 0;
 #if !GGML_USE_IQK_MULMAT
-                if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
-                    for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-                        from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
-                                          (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
-                                          4, ne10, blck_size_interleave);
+                    if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
+                        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+                            from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                                              (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                                              4, ne10, blck_size_interleave);
+                        }
+                        i11_processed = ne11 - ne11 % 4;
                     }
-                    i11_processed = ne11 - ne11 % 4;
-                }
 #endif
-                for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
-                           (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
-                           ne10);
+                    for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+                        from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                               (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                               ne10);
+                    }
                 }
             }
         }
-//#endif
-        }
 
         ggml_barrier(params->shared);
 
@@ -15145,17 +15093,10 @@ UseGgmlGemm1:;
         if (ith == 0) printf("quantize(%s): %d us\n", dst->name, (int)(t2 - t1));
 #endif
 
-        if (ith == 0) {
-            // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
-            //atomic_store(&params->shared->current_chunk, nth);
-        }
-
-        ggml_barrier(params->shared);
     }
 
     const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
 
-#if GGML_USE_IQK_MULMAT
     if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
         if (iqk_mul_mat_4d(ne01, ne11, ne00,
@@ -15163,32 +15104,27 @@ UseGgmlGemm1:;
                 nb2/sizeof(float), nb3/sizeof(float),
                 src0->type, src0->data, nb01,
                 vec_dot_type, wdata, row_size,
-                (float *)dst->data, nb1/sizeof(float), ith, nth)) return;
+                (float *)dst->data, nb1/sizeof(float), ith, nth)) {
+            while (node_n < cgraph->n_nodes - 1 &&
+                   cgraph->nodes[node_n+1]->op == GGML_OP_MUL_MAT &&
+                   cgraph->nodes[node_n+1]->src[1] == src1 &&
+                   type_traits[cgraph->nodes[node_n+1]->src[0]->type].vec_dot_type == vec_dot_type) {
+                struct ggml_tensor * dst_next  = cgraph->nodes[node_n+1];
+                struct ggml_tensor * src0_next = dst_next->src[0];
+                GGML_ASSERT(dst_next->type == GGML_TYPE_F32);
+                GGML_ASSERT(src0_next->ne[0] == ne00);
+                //if (ith == 0) printf("Fusing %s\n", src0_next->name);
+                if (!iqk_mul_mat_4d(src0_next->ne[1], ne11, ne00,
+                        src0_next->ne[2], src0_next->ne[3], ne12, ne13, src0_next->nb[2], src0_next->nb[3], row_size*ne11, row_size*ne11*ne12,
+                        dst_next->nb[2]/sizeof(float), dst_next->nb[3]/sizeof(float),
+                        src0_next->type, src0_next->data, src0_next->nb[1],
+                        vec_dot_type, wdata, row_size,
+                        (float *)dst_next->data, dst_next->nb[1]/sizeof(float), ith, nth)) break;
+                ++node_n;
+            }
+        }
+        return node_n;
     }
-#endif
-
-#if GGML_USE_LLAMAFILE
-    if (src1->type != vec_dot_type) {
-        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
-
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
-                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
-                                     row_size/ggml_type_size(vec_dot_type),
-                                     (char *)dst->data + i12*nb2 + i13*nb3,
-                                     nb1/ggml_type_size(dst->type),
-                                     ith, nth,
-                                     src0->type,
-                                     vec_dot_type,
-                                     dst->type))
-                    goto UseGgmlGemm2;
-        return;
-    }
-UseGgmlGemm2:;
-#endif
 
     if (ith == 0) {
         atomic_store(&params->shared->current_chunk, nth);
@@ -15243,7 +15179,7 @@ UseGgmlGemm2:;
         int64_t src0_end   = ((ith + 1) * ne01) / nth;
         src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
         src0_end   = (src0_end   % matmul_num_cols) ? src0_end   + matmul_num_cols - (src0_end   % matmul_num_cols): src0_end;
-        if (src0_start >= src0_end) return;
+        if (src0_start >= src0_end) return node_n;
 
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
         if (gemm && (ne11 > 3)) {
@@ -15255,7 +15191,7 @@ UseGgmlGemm2:;
                      (const char *) src0->data + src0_start * nb01,
                      (const char *) src1_wdata + (src1_col_stride * iter), 1, src0_end - src0_start);
         }
-        return;
+        return node_n;
     }
 
     // The first chunk comes from our thread_id, the rest will get auto-assigned.
@@ -15279,6 +15215,8 @@ UseGgmlGemm2:;
 
         current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
     }
+
+    return node_n;
 }
 
 // ggml_compute_forward_mul_mat_id
@@ -20392,7 +20330,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
-        return false;
+        return i;
     }
 
 #if IK_PRINT_TIMING
@@ -20506,7 +20444,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor);
+                i = ggml_compute_forward_mul_mat(params, tensor, cgraph, i);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
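The CPU path mirrors the CUDA lookahead: once src1 has been converted to vec_dot_type in wdata, ggml_compute_forward_mul_mat keeps consuming consecutive GGML_OP_MUL_MAT nodes over the same src1 whose weights expect that same layout, so for token generation the float-to-q8 conversion of the activations runs once for Q, K and V instead of three times. A toy mirror of the reuse condition — simplified struct fields, not the real ggml_tensor/type_traits accessors:

#include <cstdio>

// Toy mirror of the reuse test in ggml_compute_forward_mul_mat above.
enum Op { OP_MUL_MAT, OP_OTHER };
struct Node { Op op; const void * src1; int vec_dot_type; };

// src1 was already converted to vec_dot_type in wdata while computing `cur`;
// `next` may reuse that buffer iff it is another mat-mul over the same
// activations expecting the same quantized layout.
static bool can_reuse_wdata(const Node & cur, const Node & next) {
    return next.op == OP_MUL_MAT && next.src1 == cur.src1 && next.vec_dot_type == cur.vec_dot_type;
}

int main() {
    int activations = 0;  // stand-in for a src1 tensor
    Node q{OP_MUL_MAT, &activations, 8}, k{OP_MUL_MAT, &activations, 8}, v{OP_MUL_MAT, &activations, 8};
    std::printf("K reuses Q's wdata: %d, V reuses K's: %d\n", can_reuse_wdata(q, k), can_reuse_wdata(k, v));
}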
diff --git a/src/llama.cpp b/src/llama.cpp
index 789ee8ce..54b1048f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8861,6 +8861,40 @@ struct llm_build_context {
         return lctx.inp_KQ_mask_cross;
     }
 
+    std::tuple<ggml_tensor *, ggml_tensor *, ggml_tensor *> llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
+            ggml_tensor * wq, ggml_tensor * bq,
+            ggml_tensor * wk, ggml_tensor * bk,
+            ggml_tensor * wv, ggml_tensor * bv,
+            float attention_scale, int il) {
+        auto Qcur = llm_build_lora_mm(lctx, ctx0, wq, cur);
+        cb(Qcur, "Qcur", il);
+        auto Kcur = llm_build_lora_mm(lctx, ctx0, wk, cur);
+        cb(Kcur, "Kcur", il);
+        auto Vcur = llm_build_lora_mm(lctx, ctx0, wv, cur);
+        cb(Vcur, "Vcur", il);
+        ggml_build_forward_expand(gf, Qcur);
+        ggml_build_forward_expand(gf, Kcur);
+        ggml_build_forward_expand(gf, Vcur);
+
+        if (attention_scale != 0) {
+            Qcur = ggml_scale(ctx0, Qcur, attention_scale);
+            cb(Qcur, "Qcur", il);
+        }
+        if (bq) {
+            Qcur = ggml_add(ctx0, Qcur, bq);
+            cb(Qcur, "Qcur", il);
+        }
+        if (bk) {
+            Kcur = ggml_add(ctx0, Kcur, bk);
+            cb(Kcur, "Kcur", il);
+        }
+        if (bv) {
+            Vcur = ggml_add(ctx0, Vcur, bv);
+            cb(Vcur, "Vcur", il);
+        }
+        return {Qcur, Kcur, Vcur};
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
@@ -8912,31 +8946,10 @@ struct llm_build_context {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
                 struct ggml_tensor * rope_factors = build_rope_factors(il);
 
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                if (hparams.f_attention_scale != 0) {
-                    // Why is hparams.f_attention_scale not simply absorbed into model.layers[il].wq ?
-                    Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
-                }
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv,
+                        hparams.f_attention_scale, il);
 
                 if (use_rope) {
                     Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
@@ -9137,27 +9150,10 @@ struct llm_build_context {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
                 struct ggml_tensor * rope_factors = build_rope_factors(il);
 
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv,
+                        0.f, il);
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
@@ -9281,15 +9277,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
                 switch (model.type) {
                     case MODEL_7B:
                         Qcur = ggml_rope_ext(
@@ -9396,15 +9386,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9623,27 +9607,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
@@ -10015,14 +9981,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
 
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 cb(Kcur, "Kcur", il);
@@ -10551,27 +10512,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                 cb(Qcur, "Qcur", il);
@@ -10811,21 +10754,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
@@ -10926,21 +10857,9 @@ struct llm_build_context {
 
             // self_attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
@@ -11071,17 +10990,11 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
 
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur  = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                 Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
                 cb(Qcur, "Qcur_normed", il);
 
@@ -11092,7 +11005,7 @@ struct llm_build_context {
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur  = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
                 cb(Kcur, "Kcur_normed", il);
 
@@ -11189,17 +11102,11 @@ struct llm_build_context {
 
             // self_attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
 
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur  = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                 Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
                 cb(Qcur, "Qcur_normed", il);
 
@@ -11210,7 +11117,7 @@ struct llm_build_context {
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur  = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
                 cb(Kcur, "Kcur_normed", il);
 
@@ -11559,16 +11466,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
                     n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -11878,28 +11778,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                // if (model.layers[il].bq) {
-                //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                //     cb(Qcur, "Qcur", il);
-                // }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                // if (model.layers[il].bk) {
-                //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                //     cb(Kcur, "Kcur", il);
-                // }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                // if (model.layers[il].bv) {
-                //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                //     cb(Vcur, "Vcur", il);
-                // }
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -11996,28 +11877,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12127,27 +11989,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
@@ -12260,16 +12104,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12373,16 +12210,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12517,15 +12347,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
                 Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
@@ -12628,28 +12452,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12896,27 +12701,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 if (model.layers[il].attn_q_norm) {
                     Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
@@ -13444,15 +13231,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr,
+                        model.layers[il].wk, nullptr,
+                        model.layers[il].wv, nullptr, 0, il);
 
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
@@ -14086,24 +13867,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                // Q, K, V projections
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                }
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                }
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                }
-                cb(Vcur, "Vcur", il);
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 // reshape for multi-head
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -14413,29 +14179,10 @@ struct llm_build_context {
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
                 struct ggml_tensor * rope_factors = build_rope_factors(il);
printf("%f\n\n\n\n",((float*)rope_factors->data)[1]); - // compute Q and K and RoPE them - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr, + model.layers[il].wk, nullptr, + model.layers[il].wv, nullptr, 0, il); Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, @@ -14568,27 +14315,9 @@ struct llm_build_context { // rope freq factors for 128k context struct ggml_tensor * rope_factors = build_rope_factors(il); - // compute Q and K and RoPE them - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq, + model.layers[il].wk, model.layers[il].bk, + model.layers[il].wv, model.layers[il].bv, 0.f, il); if (is_sliding) { Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, @@ -14693,14 +14422,9 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr, + model.layers[il].wk, nullptr, + model.layers[il].wv, nullptr, 0, il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); @@ -14828,14 +14552,9 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, nullptr, + model.layers[il].wk, nullptr, + model.layers[il].wv, nullptr, 0, il); llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, 
                 llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
@@ -15817,26 +15536,9 @@ struct llm_build_context {
                 struct ggml_tensor * rope_factors = build_rope_factors(il);
 
                 // compute Q and K and RoPE them
-                ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
@@ -15846,16 +15548,14 @@ struct llm_build_context {
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
 
                 Kcur = ggml_rope_ext(
                     ctx0, Kcur, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
+                cb(Kcur, "Kcur", il);
 
                 Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
                 cb(Kcur, "Kcur_norm", il);
@@ -15966,26 +15666,9 @@ struct llm_build_context {
 
             // self-attention
             {
-                ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
+                        model.layers[il].wk, model.layers[il].bk,
+                        model.layers[il].wv, model.layers[il].bv, 0.f, il);
 
                 Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
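To see why sharing the quantized activations pays off at all, here is a self-contained toy demo of the quantize-once/multiply-three-times idea behind the patch. The int8 scheme below is plain symmetric rounding purely for illustration — it is not ggml's block_q8_1 format:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// The activation row is quantized to int8 once and then reused by the
// Q, K and V weight matrices (the expensive float->int8 pass runs once).
struct Q8Row { std::vector<int8_t> q; float scale; };

static Q8Row quantize(const std::vector<float> & x) {
    float amax = 0.f;
    for (float v : x) amax = std::max(amax, std::fabs(v));
    Q8Row r{std::vector<int8_t>(x.size()), amax / 127.f};
    for (size_t i = 0; i < x.size(); ++i) r.q[i] = (int8_t)std::lround(x[i] / (r.scale ? r.scale : 1.f));
    return r;
}

// One GEMV against an int8 weight matrix with a single per-matrix scale.
static std::vector<float> gemv(const std::vector<std::vector<int8_t>> & w, float wscale, const Q8Row & x) {
    std::vector<float> y(w.size());
    for (size_t i = 0; i < w.size(); ++i) {
        int32_t acc = 0;
        for (size_t j = 0; j < x.q.size(); ++j) acc += (int32_t)w[i][j] * x.q[j];
        y[i] = acc * wscale * x.scale;
    }
    return y;
}

int main() {
    std::vector<float> act = {0.1f, -0.5f, 0.25f, 0.9f};
    Q8Row qact = quantize(act);                       // done once per token...
    std::vector<std::vector<int8_t>> wq(2, {1,2,3,4}), wk(2, {4,3,2,1}), wv(2, {1,-1,1,-1});
    auto Q = gemv(wq, 0.01f, qact);                   // ...and reused three times
    auto K = gemv(wk, 0.01f, qact);
    auto V = gemv(wv, 0.01f, qact);
    std::printf("Q0=%g K0=%g V0=%g\n", Q[0], K[0], V[0]);
}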