From 5d68e4eb359c24e906ffdd8871a0d45d7e45092f Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Wed, 26 Nov 2025 09:27:12 +0000
Subject: [PATCH] WIP: it runs with wrong result

But it also looks like the backend scheduler is not going to help:
* It copies the mask and the input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed with its share of the attention evaluation, GPU 1 must wait
  for GPU 0 to finish its entire attention calculation
* Same with the FFN. The rms_norm gets scheduled on GPU 0, hence GPU 1 must
  wait for GPU 0 to finish its entire FFN calculation before it can start
  (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduler
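For reference, the per-layer pattern the graph construction below is meant to
build (a simplified sketch of llm_build_ffn in this patch, not the exact code):
up/gate weights are split along dim 1 (output rows), the down weights along
dim 0 (input columns), so each device produces a partial sum of the full FFN
output and the partials are combined with ggml_add. The lm_head is split along
rows instead, so its per-device results are combined with ggml_concat.

    // per-device partial FFN; `inp` stands for the layer input after rms_norm
    std::vector<ggml_tensor *> ffn;
    for (int id = 0; id < n_device; ++id) {
        auto t = ggml_fused_up_gate(ctx, up->splits[id], gate->splits[id], inp, unary_op);
        t = llm_build_lora_mm(lctx, ctx, down->splits[id], t);   // partial result on device id
        ffn.push_back(t);
    }
    auto out = ffn.front();
    for (size_t id = 1; id < ffn.size(); ++id)
        out = ggml_add(ctx, out, ffn[id]);   // cross-device reduction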
---
 ggml/src/ggml-backend.cpp   |   2 +
 ggml/src/ggml-cuda.cu       |  78 +++++++++-
 src/llama-build-context.cpp | 261 ++++++++++++++++++++++++++++++------
 src/llama-build-context.h   |   7 +-
 src/llama-load-tensors.cpp  |   4 +-
 src/llama.cpp               |  36 ++++-
 6 files changed, 322 insertions(+), 66 deletions(-)

diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index e42b05ec..4d70de00 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -1216,8 +1216,10 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
         return -1;
     }
 
+    //printf("%s: have %d backends, buffer is %s\n", __func__, sched->n_backends, ggml_backend_buffer_name(buffer));
     // find highest prio backend that supports the buffer type and the op
     for (int i = 0; i < sched->n_backends; i++) {
+        //printf("    Checking backend %d (%s)\n", i, ggml_backend_name(sched->backends[i]));
         if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&
             ggml_backend_supports_op(sched->backends[i], op)) {
             return i;
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index d2f669f6..51b7ddb6 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -602,7 +602,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
-    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,
+    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -811,6 +811,10 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor([[maybe_unused]
             printf("    allocated %zu bytes for tensor %s of type %s, dim = %ld x %ld x %ld. padding: %zu\n",
                     padded_size, split->name, ggml_type_name(split->type), split->ne[0], split->ne[1], split->ne[2], padded_size - size);
             split->data = buf;
+            auto ctx = new ggml_backend_cuda_buffer_context(i, buf);
+            auto buft = ggml_backend_cuda_buffer_type(i);
+            split->buffer = ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, padded_size);
+            ggml_backend_buffer_set_usage(split->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
         }
         return;
@@ -862,7 +866,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor([[maybe_unused]
     //tensor->extra = extra;
 }
 
-GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]] ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     if (!tensor->extra) return;
     printf("%s(%s)\n", __func__, tensor->name);
@@ -873,19 +877,64 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buf
     auto extra = (ggml_split_tensor_t *)tensor->extra;
     GGML_ASSERT(extra->n_device <= ggml_backend_cuda_get_device_count());
 
-    if (extra->split_dim != 0) {
-        fprintf(stderr, "Split tensor copy not yet immplemented for dim 0\n");
-        return;
-    }
-
-    size_t cur_offset = 0;
     for (int i = 0; i < extra->n_device; ++i) {
         auto split = extra->splits[i];
         if (!split) continue;
-        auto size = ggml_nbytes(split);
-        const char * buf_host = (const char *)data + cur_offset;
-        CUDA_CHECK(cudaMemcpyAsync(split->data, buf_host, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
-        cur_offset += size;
+        printf("    Split %d: %p, %p, %s\n", i, (void *)split->data, (void *)split->buffer, split->buffer ? ggml_backend_buffer_name(split->buffer) : "none");
+    }
+
+    if (extra->split_dim == 0) {
+        if (tensor->type >= GGML_TYPE_Q4_0_R8) {
+            GGML_ABORT("Dim 0 copy of row-interleaved quants is not supported yet");
+        }
+        auto tt = ggml_internal_get_type_traits(tensor->type);
+        //if (tt.row_meta_size > 0) {
+        //    GGML_ABORT("Dim 0 copy is not implemented for tensors with row meta data\n");
+        //}
+        GGML_ASSERT(ggml_is_contiguous(tensor));
+        int nrows = ggml_nrows(tensor);
+        auto bs = tt.blck_size;
+        auto ts = tt.type_size;
+        auto row_size = ggml_row_size(tensor->type, tensor->ne[0]);
+        int ne = 0;
+        for (int i = 0; i < extra->n_device; ++i) {
+            auto split = extra->splits[i];
+            if (!split) continue;
+            ggml_cuda_set_device(i);
+            GGML_ASSERT(split->type == tensor->type);
+            GGML_ASSERT((int)ggml_nrows(split) == nrows);
+            GGML_ASSERT(split->ne[0] % bs == 0);
+            auto source_offset = tt.row_meta_size + (ne / bs) * ts;
+            auto chost0 = (const char *)data;
+            //auto chost = (const char *)data + source_offset;
+            auto split_row_size = ggml_row_size(split->type, split->ne[0]);
+            for (int ir = 0; ir < nrows; ++ir) {
+                auto dst = (char *)split->data + ir*split_row_size;
+                if (tt.row_meta_size > 0) {
+                    CUDA_CHECK(cudaMemcpyAsync(dst, chost0, tt.row_meta_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+                }
+                CUDA_CHECK(cudaMemcpyAsync(dst + tt.row_meta_size, chost0 + source_offset,
+                            split_row_size - tt.row_meta_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+                chost0 += row_size;
+            }
+            ne += split->ne[0];
+        }
+    }
+    else if (extra->split_dim == 1) {
+        size_t cur_offset = 0;
+        for (int i = 0; i < extra->n_device; ++i) {
+            auto split = extra->splits[i];
+            if (!split) continue;
+            ggml_cuda_set_device(i);
+            auto size = ggml_nbytes(split);
+            const char * buf_host = (const char *)data + cur_offset;
+            CUDA_CHECK(cudaMemcpyAsync(split->data, buf_host, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+            cur_offset += size;
+        }
+    }
+    else {
+        fprintf(stderr, "%s: not implemented for split dim %d\n", __func__, extra->split_dim);
+        GGML_ABORT("fatal error");
     }
 
     for (int i = 0; i < extra->n_device; ++i) {
@@ -3023,6 +3072,7 @@ static inline bool ops_are_same_device(const ggml_cgraph * cgraph, int first, in
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
     // why is this here instead of mul_mat?
     if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
+        printf("%s: split buffer for %s(%s)\n", __func__, ggml_op_name(dst->op), dst->name);
         ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
     }
 
@@ -3034,7 +3084,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
 
     auto fusion = ctx.fusion;
 
-    //printf("%4d %s(%s)\n", i, ggml_op_name(dst->op), dst->name);
+    printf("%4d %s(%s) on device %d. time = %ld\n", i, ggml_op_name(dst->op), dst->name, ctx.device, ggml_time_us());
     switch (dst->op) {
         case GGML_OP_ARGMAX:
             ggml_cuda_argmax(ctx, dst);
@@ -3759,6 +3809,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     // TODO
     const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;
 
+    printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us());
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
         // With the use of CUDA graphs, the execution will be performed by the graph launch.
@@ -4183,6 +4234,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 }
 
 GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    //printf("%s(%s, %s): %d, %d\n", __func__, ggml_backend_name(backend), ggml_backend_buft_name(buft), ggml_backend_buft_is_cuda_split(buft), ggml_backend_buft_is_cuda(buft));
     if (ggml_backend_buft_is_cuda_split(buft)) {
         return true;
     }
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 9588d1b9..f89d4a2f 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -636,6 +636,44 @@ ggml_tensor * llm_build_context::llm_build_ffn(
         llm_ffn_gate_type   type_gate,
         const llm_build_cb & cb,
         int   il) {
+    if (!up_b && !up_s && !gate_b && !gate_s && !down_b && !down_s &&
+        up->extra && gate->extra && down->extra && type_gate == LLM_FFN_PAR &&
+        (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
+        auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
+                        type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
+        auto u = (ggml_split_tensor_t *)up->extra;
+        auto g = (ggml_split_tensor_t *)gate->extra;
+        auto d = (ggml_split_tensor_t *)down->extra;
+        GGML_ASSERT(u->n_device == g->n_device && u->n_device == d->n_device);
+        std::vector<ggml_tensor *> ffn;
+        ffn.reserve(u->n_device);
+        for (int id = 0; id < u->n_device; ++id) {
+            int il_cb = 1000*id + il;
+            auto split_u = u->splits[id];
+            auto split_g = g->splits[id];
+            auto split_d = d->splits[id];
+            GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
+            if (!split_u) continue;
+            auto tmp = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
+            cb(tmp, "ffn_up_gate", il_cb);
+            tmp = llm_build_lora_mm(lctx, ctx, split_d, tmp);
+            cb(tmp, "ffn_down", il_cb);
+            if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
+                // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+                ggml_mul_mat_set_prec(tmp, GGML_PREC_F32);
+            }
+            ffn.push_back(tmp);
+        }
+        if (ffn.size() == 1) return ffn.front();
+        cur = ggml_add(ctx, ffn[0], ffn[1]);
+        cb(cur, "combine_ffn", il);
+        for (int id = 2; id < int(ffn.size()); ++id) {
+            cur = ggml_add(ctx, cur, ffn[id]);
+            cb(cur, "combine_ffn", il);
+        }
+        return cur;
+    }
+
     if (lctx.cparams.fused_up_gate && up && gate && !up_b && !up_s && !gate_b && !gate_s &&
         type_gate == LLM_FFN_PAR &&
         (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
@@ -1243,7 +1281,7 @@ std::tuple llm_build_context::llm_buil
         ggml_tensor * wq, ggml_tensor * bq,
         ggml_tensor * wk, ggml_tensor * bk,
         ggml_tensor * wv, ggml_tensor * bv,
-        float attention_scale, int il) {
+        float attention_scale, int il) const {
     auto Qcur = llm_build_lora_mm(lctx, ctx0, wq, cur);
     cb(Qcur, "Qcur", il);
     auto Kcur = llm_build_lora_mm(lctx, ctx0, wk, cur);
@@ -1282,7 +1320,7 @@ std::tuple llm_build_context::llm_buil
         ggml_tensor * wq, ggml_tensor * bq,
         ggml_tensor * wk, ggml_tensor * bk,
         ggml_tensor * wv, ggml_tensor * bv,
-        ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il) {
+        ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il) const {
     const int64_t n_embd_head = hparams.n_embd_head_v;
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
     if (wqkv) {
@@ -1351,13 +1389,13 @@ std::tuple llm_build_context::llm_buil
     }
 
     auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il);
-    auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, n_head, n_tokens);
+    auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, Q->ne[0]/n_embd_head, n_tokens);
     if (q_norm) {
         Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
         cb(Qcur, "Qcur_normed", il);
     }
-    auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, n_head_kv, n_tokens);
+    auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, K->ne[0]/n_embd_head, n_tokens);
     if (k_norm) {
         Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
         cb(Kcur, "Kcur_normed", il);
@@ -1405,15 +1443,20 @@ ggml_cgraph * llm_build_context::build_llama() {
         bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true;
         auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ? KQ_mask_swa : KQ_mask;
+        int this_n_swa = this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0;
 
         // norm
         cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
         cb(cur, "attn_norm", il);
 
+        // rope freq factors for llama3; may return nullptr for llama2 and other models
+        auto rope_factors = build_rope_factors(il);
+
         // self-attention
-        {
-            // rope freq factors for llama3; may return nullptr for llama2 and other models
-            struct ggml_tensor * rope_factors = build_rope_factors(il);
+        if (use_rope) {
+            cur = build_std_attention(gf, cur, inp_pos, rope_factors, this_KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il);
+        }
+        else {
 
             auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
                     model.layers[il].wqkv, model.layers[il].bqkv,
@@ -1450,7 +1493,7 @@ ggml_cgraph * llm_build_context::build_llama() {
             cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                     model.layers[il].wo, model.layers[il].bo,
                     Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
-                    this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0);
+                    this_n_swa);
         }
 
         if (il == n_layer - 1) {
@@ -1555,7 +1598,22 @@ ggml_cgraph * llm_build_context::build_llama() {
     cb(cur, "result_norm", -1);
 
     // lm_head
-    cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+    if (model.output->extra) {
+        auto output = (ggml_split_tensor_t *)model.output->extra;
+        std::vector<ggml_tensor *> o;
+        o.reserve(output->n_device);
+        for (int id = 0; id < output->n_device; ++id) {
+            auto split = output->splits[id];
+            if (!split) continue;
+            o.push_back(llm_build_lora_mm(lctx, ctx0, split, cur));
+        }
+        cur = o.front();
+        for (int id = 1; id < int(o.size()); ++id) {
+            cur = ggml_concat(ctx0, cur, o[id], 0);
+        }
+    } else {
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+    }
 
     // For Granite architecture
     if (hparams.f_logit_scale) {
@@ -3514,9 +3573,6 @@ ggml_cgraph * llm_build_context::build_qwen3() {
 ggml_cgraph * llm_build_context::build_qwen3moe() {
     struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-    // mutable variable, needed during the last layer of the computation to skip unused tokens
-    int32_t n_tokens = this->n_tokens;
-
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -3532,10 +3588,6 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
     struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-    auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
-        ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
-
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;
 
@@ -3543,35 +3595,11 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
         cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
         cb(cur, "attn_norm", il);
 
-        // self_attention
-        {
-            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
-                    model.layers[il].wqkv, nullptr,
-                    model.layers[il].wqk, nullptr,
-                    model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
-                    model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
-
-            if (rope_cache) {
-                Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
-                Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
-            } else {
-                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-            }
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-
-            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-        }
+        cur = build_std_attention(gf, cur, inp_pos, nullptr, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il);
 
         if (il == n_layer - 1) {
             // skip computing output for unused tokens
             struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-            n_tokens = n_outputs;
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
@@ -3583,8 +3611,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
         cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
         cb(cur, "ffn_norm", il);
 
-        cur =
-            llm_build_moe_ffn(ctx0, lctx, cur,
+        cur = llm_build_moe_ffn(ctx0, lctx, cur,
                 model.layers[il].ffn_gate_inp,
                 model.layers[il].ffn_up_exps,
                 model.layers[il].ffn_gate_exps,
@@ -9010,3 +9037,151 @@ ggml_cgraph * llm_build_context::llama_build_graph(
 
     return result;
 }
+
+ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
+        ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il) {
+    if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
+        model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
+        if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
+            auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
+            auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
+            auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
+            auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
+            GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
+            auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
+            auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
+            GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
+            std::vector<ggml_tensor *> attn; attn.reserve(wq->n_device);
+            for (int id = 0; id < wq->n_device; ++id) {
+                int il_cb = 1000*id + il;
+                auto split_wq = wq->splits[id];
+                auto split_wk = wk->splits[id];
+                auto split_wv = wv->splits[id];
+                auto split_wo = wo->splits[id];
+                auto split_kl = kl->splits[id];
+                auto split_vl = vl->splits[id];
+                GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
+                            ( split_wq &&  split_wk &&  split_wv &&  split_wo &&  split_kl &&  split_vl));
+                if (!split_wq) continue;
+                auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
+                        split_wq, nullptr, split_wk, nullptr, split_wv, nullptr,
+                        model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il_cb);
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il_cb);
+                cb(Kcur, "Kcur", il_cb);
+                ggml_build_forward_expand(gf, Qcur);
+                ggml_build_forward_expand(gf, Kcur);
+                ggml_build_forward_expand(gf, Vcur);
+
+                const int64_t n_embd_head_k = hparams.n_embd_head_k;
+                const int64_t n_head_kv = split_wk->ne[1] / n_embd_head_k;
+
+                GGML_ASSERT(kv_self.size == cparams.n_ctx);
+
+                GGML_ASSERT(2*il+1 < (int)lctx.cache_copies.size());
+                auto k_row_size = ggml_row_size(split_kl->type, n_embd_head_k);
+                ggml_tensor * k_cache_view = ggml_view_2d(ctx0, split_kl, n_embd_head_k, n_tokens*n_head_kv,
+                        k_row_size, k_row_size*n_head_kv*kv_head);
+
+                lctx.cache_copies[2*il+0].cpy  = ggml_cpy(ctx0, Kcur, k_cache_view);
+                lctx.cache_copies[2*il+0].step = k_row_size*n_head_kv;
+
+                // note: storing RoPE-ed version of K in the KV cache
+                ggml_build_forward_expand(gf, lctx.cache_copies[2*il+0].cpy);
+
+                struct ggml_tensor * v_cache_view = nullptr;
+
+                if (cparams.flash_attn) {
+                    v_cache_view = ggml_view_1d(ctx0, split_vl, n_tokens*split_wv->ne[1],
+                            kv_head*ggml_row_size(split_vl->type, split_wv->ne[1]));
+                    lctx.cache_copies[2*il+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
+                } else {
+                    // note: the V cache is transposed when not using flash attention
+                    v_cache_view = ggml_view_2d(ctx0, split_vl, n_tokens, split_wv->ne[1],
+                            (  n_ctx)*ggml_element_size(split_vl),
+                            (kv_head)*ggml_element_size(split_vl));
+                    lctx.cache_copies[2*il+1].step = ggml_element_size(split_vl);
+
+                    Vcur = ggml_transpose(ctx0, Vcur);
+                }
+                cb(v_cache_view, "v_cache_view", il_cb);
+
+                lctx.cache_copies[2*il+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
+                ggml_build_forward_expand(gf, lctx.cache_copies[2*il+1].cpy);
+
+                auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                cb(q, "q", il_cb);
+
+                auto k = ggml_view_3d(ctx0, split_kl, n_embd_head_k, n_kv, n_head_kv,
+                        ggml_row_size(split_kl->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
+                        ggml_row_size(split_kl->type, n_embd_head_k), 0);
+                cb(k, "k", il_cb);
+
+                auto v = ggml_view_3d(ctx0, split_vl, n_embd_head_v, n_kv, n_head_kv,
+                        ggml_row_size(split_vl->type, split_wv->ne[1]),
+                        ggml_row_size(split_vl->type, n_embd_head_v), 0);
+                cb(v, "v", il_cb);
+
+#ifdef GGML_USE_VULKAN
+                constexpr bool use_f32_precision = true;
+#else
+                constexpr bool use_f32_precision = false;
+#endif
+                auto tmp = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
+                        hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+                ggml_flash_attn_ext_add_sinks(tmp, sinks);
+                if (n_swa > 0) {
+                    ((int32_t *)tmp->op_params)[4] = n_swa;
+                }
+
+                // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
+                if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
+                    (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
+                     model.arch == LLM_ARCH_GLM4_MOE) {
+                    ggml_flash_attn_ext_set_prec(tmp, GGML_PREC_F32);
+                }
+
+                tmp = ggml_reshape_2d(ctx0, tmp, split_wo->ne[0], n_tokens);
+
+                tmp = llm_build_lora_mm(lctx, ctx0, split_wo, tmp);
+                if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
+                    // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+                    ggml_mul_mat_set_prec(tmp, GGML_PREC_F32);
+                }
+                cb(tmp, "kqv_wo", il_cb);
+                // TODO: wo_b
+                attn.push_back(tmp);
+            }
+            if (attn.size() == 1) return attn.front();
+            cur = ggml_add(ctx0, attn[0], attn[1]);
+            cb(cur, "combine_attn", il);
+            for (int id = 2; id < (int)attn.size(); ++id) {
+                cur = ggml_add(ctx0, cur, attn[id]);
+                cb(cur, "combine_attn", il);
+            }
+            return cur;
+        }
+    }
+
+    auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
+            model.layers[il].wqkv, model.layers[il].bqkv,
+            model.layers[il].wqk, model.layers[il].bqk,
+            model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv,
+            model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
+
+    Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+    Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow);
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+
+    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+            model.layers[il].wo, model.layers[il].bo,
+            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
+
+    return cur;
+}
diff --git a/src/llama-build-context.h b/src/llama-build-context.h
index a96a49d4..271dd689 100644
--- a/src/llama-build-context.h
+++ b/src/llama-build-context.h
@@ -148,7 +148,7 @@ struct llm_build_context {
             ggml_tensor * wq, ggml_tensor * bq,
             ggml_tensor * wk, ggml_tensor * bk,
             ggml_tensor * wv, ggml_tensor * bv,
-            float attention_scale, int il);
+            float attention_scale, int il) const;
 
     std::tuple<ggml_tensor *, ggml_tensor *, ggml_tensor *> llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
             ggml_tensor * wqkv, ggml_tensor * bqkv,
@@ -156,7 +156,7 @@ struct llm_build_context {
             ggml_tensor * wq, ggml_tensor * bq,
             ggml_tensor * wk, ggml_tensor * bk,
             ggml_tensor * wv, ggml_tensor * bv,
-            ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il);
+            ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il) const;
 
     ggml_cgraph * build_llama();
 
@@ -383,4 +383,7 @@ llm_expert_gating_func_type gating_op,
     static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case);
 
+    ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
+            ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il);
+
 };
diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index b120b36e..6672a246 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -2750,13 +2750,11 @@ static void prepare_split_tensors(int split_dim, ggml_context * ctx, ggml_tensor
     std::string name{tensor->name};
     split_tensor.tensor_splits.resize(splits.size());
     if (split_dim == 1) {
-        size_t offset = 0;
         for (int i = 0; i < int(splits.size()); ++i) {
             if (splits[i] > 0) {
-                split_tensor.tensor_splits[i] = ggml_view_3d(ctx, tensor, tensor->ne[0], splits[i], tensor->ne[2], tensor->nb[1], tensor->nb[2], offset);
+                split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], splits[i], tensor->ne[2]);
                 auto name_i = name + '.' + std::to_string(i);
                 ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
-                offset += tensor->nb[1]*splits[i];
             } else {
                 split_tensor.tensor_splits[i] = nullptr;
             }
diff --git a/src/llama.cpp b/src/llama.cpp
index b214ab2c..ceabcd31 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -740,8 +740,12 @@ static bool llama_kv_cache_init(
         else {
             k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
             v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-            ggml_format_name(k, "cache_k_l%d", i);
-            ggml_format_name(v, "cache_v_l%d", i);
+            auto k_name = std::string{"cache_k_l"} + std::to_string(i);
+            auto v_name = std::string{"cache_v_l"} + std::to_string(i);
+            ggml_set_name(k, k_name.c_str());
+            ggml_set_name(v, v_name.c_str());
+            //ggml_format_name(k, "cache_k_l%d", i);
+            //ggml_format_name(v, "cache_v_l%d", i);
             cache.k_l.push_back(k);
             cache.v_l.push_back(v);
             if (split_cache) {
@@ -758,6 +762,8 @@ static bool llama_kv_cache_init(
                     auto split = extra_K->splits[is];
                     if (!split) continue;
                     split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, split->ne[1]/n_embd_head_k * kv_size);
+                    auto split_name = k_name + '.' + std::to_string(is);
+                    ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
                 }
                 split_k_l.ggml.n_device = extra_K->n_device;
                 split_k_l.ggml.split_dim = 0;
@@ -766,6 +772,8 @@ static bool llama_kv_cache_init(
                     auto split = extra_V->splits[is];
                     if (!split) continue;
                     split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
+                    auto split_name = v_name + '.' + std::to_string(is);
+                    ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
                 }
                 split_v_l.ggml.n_device = extra_V->n_device;
                 split_v_l.ggml.split_dim = 0;
@@ -798,6 +806,25 @@ static bool llama_kv_cache_init(
         cache.bufs.push_back(buf);
     }
 
+    for (int il = 0; il < n_layer; ++il) {
+        if (cache.k_l[il]->extra) {
+            printf("Layer %2d, K-buffer: %p:", il, (void *)cache.k_l[il]->buffer);
+            auto split_kl = (ggml_split_tensor_t *)cache.k_l[il]->extra;
+            for (int id = 0; id < split_kl->n_device; ++id) {
+                if (split_kl->splits[id]) printf("  %p,%p", (void *)split_kl->splits[id]->data, (void *)split_kl->splits[id]->buffer);
+            }
+            printf("\n");
+        }
+        if (cache.v_l[il]->extra) {
+            printf("Layer %2d, V-buffer: %p:", il, (void *)cache.v_l[il]->buffer);
+            auto split_vl = (ggml_split_tensor_t *)cache.v_l[il]->extra;
+            for (int id = 0; id < split_vl->n_device; ++id) {
+                if (split_vl->splits[id]) printf("  %p,%p", (void *)split_vl->splits[id]->data, (void *)split_vl->splits[id]->buffer);
+            }
+            printf("\n");
+        }
+    }
+
     return true;
 }
 
@@ -4350,7 +4377,7 @@ struct llama_context * llama_new_context_with_model(
             ggml_backend_add_from_device(ctx, ctx->backend_metal);
         }
 #elif defined(GGML_USE_CUDA)
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+        if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
+            // with split_mode LLAMA_SPLIT_MODE_NONE, only the main GPU backend is used
             ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu, cparams.cuda_params);
             if (backend == nullptr) {
@@ -4361,7 +4388,7 @@ struct llama_context * llama_new_context_with_model(
             ggml_backend_add_from_device(ctx, backend);
         }
         else {
-            // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+            // LLAMA_SPLIT_MODE_LAYER and LLAMA_SPLIT_MODE_ROW require a backend for each GPU
             for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
                 ggml_backend_t backend = ggml_backend_cuda_init(device, cparams.cuda_params);
                 if (backend == nullptr) {
@@ -4370,7 +4397,6 @@ struct llama_context * llama_new_context_with_model(
                     return nullptr;
                 }
                 ggml_backend_add_from_device(ctx, backend);
-            }
             }
 #elif defined(GGML_USE_VULKAN)
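-- 

A note on the split_dim == 0 host-to-device copy above (restating its layout
assumptions; the helper below is hypothetical and only illustrates the offset
arithmetic used in ggml_backend_cuda_split_buffer_set_tensor): a full host row
of a block-quantized type is laid out as [row meta][block 0]...[block n-1].
Each device re-copies the shared row meta and then its contiguous slice of
blocks, so the source offset of a device's slice within one row is

    // hypothetical helper; tt carries the blck_size/type_size/row_meta_size
    // fields used by the copy loop above
    size_t split0_source_offset(ggml_type type, int64_t ne0_before) {
        auto tt = ggml_internal_get_type_traits(type);
        // skip the per-row meta, then the whole blocks preceding this device's slice
        return tt.row_meta_size + (ne0_before / tt.blck_size) * tt.type_size;
    }

e.g. for two equal splits of a row with ne0 columns, device 1's slice starts
at row_meta_size + (ne0/2)/blck_size * type_size, and each device receives
row_meta_size plus half of the row's blocks.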