From e30198a5534c6c6d445e078ebcce1eddb003bb7b Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Mon, 16 Feb 2026 06:50:28 +0100
Subject: [PATCH] WIP: Qwen3Next (#1266)

* qwen3next: add architecture support and recurrent-state fixes
* qwen3next: optimize broadcast sub and single-seq ssm conv
* cuda: build MoE row mapping on device in mul_mat_id
* cuda: add guarded multi-seq fast path for ssm_conv
* docs: update qwen3next perf report for cuda MoE/SSM tuning
* cuda: reduce qwen3next moe/ssm sync overhead and refresh eval
* qwen3next: split cpu/cuda eval builds and tune PP scheduling
* qwen3next: harden seq-state flow and support optional dense FFN layers
* qwen3next: trim delta-net graph overhead in chunking path
* qwen3next: remove redundant v_conv cont in delta path
* qwen3next: avoid extra cont on linear attention output
* qwen3next: drop redundant cont before recurrent state flatten
* qwen3next: keep recurrent state in 4d layout through delta path
* qwen3next: add fused delta-net op and wire model path
* tests: add backend-op coverage for ggml_delta_net
* qwen3next: add runtime switch for fused delta-net path
* docs: refresh qwen3next perf review and benchmark matrix
* qwen3next: default fused delta-net off and document quality checks
* qwen3next: add decode-only fused delta mode
* qwen3next: make fused delta safe by default and fix fused tensor layout
* qwen3next: warn when forcing fused decode mode
* qwen3next: add fused-delta regression runner script
* qwen3next: integrate fused regression into eval harness
* qwen3next: clean up chunked delta-net shape handling
* qwen3next: add absolute sanity guards to fused regression
* qwen3next: add unified regression runner script
* qwen3next: disable flash-attn for cpu-only contexts
* docs: reconcile qwen3next status and remaining upstream gaps
* common: add qwen3next fused-delta runtime flag
* cuda: add qwen3next delta-net kernel dispatch override
* docs: update qwen3next quality and serving baseline findings
* qwen3next: keep fused delta on safe path and remove PR artifacts
* qwen3next: align autoregressive delta-net decode layout
* Revert "qwen3next: align autoregressive delta-net decode layout"

This reverts commit 9241164a5ea9e032a2456fbf2dd0bf798b264fd7.

* cuda: port solve-tri fast-paths for qwen3next delta-net
* qwen3next: add fused-delta runtime flag and drop env toggle
* qwen3next: make fused delta single-flag and default on
* Account for GPU arch differences
* Revert "cuda: build MoE row mapping on device in mul_mat_id"

This reverts commit 89e9ecfa840b04e88699ab3803eb732cd78727f9.

* qwen3next: drop non-essential MoE scheduling and split heuristics
* qwen3next: avoid generic ggml_sub broadcast changes
* llama: restore only_active_experts log message
* Remove unnecessary hacks, disable fusion for now.
* qwen3next: port hybrid recurrent state memory semantics
* qwen3next: clean up recurrent state slot plumbing
* qwen3next: fix hybrid V-cache layout plumbing
* qwen3next: guard recurrent state slots against kv capacity
* qwen3next: persist recurrent state in session data

- serialize/restore qwen3next cache.s_l in state/session paths
- bump session and sequence-state file versions for format change
- fallback to single-token chunking for mixed repeated seq_id batches

* qwen3next: drop unused fused-delta builder path

- remove dead build_delta_net_fused lambda
- remove unused llm_build_context::fused_delta member

* qwen3next: remove unused fused-delta CLI/context plumbing

- drop -fd/-no-fd options and related YAML dump field
- remove fused_delta fields from public/internal context params
- remove fused_delta assignment and logging in context init

* ggml: remove unused DELTA_NET operator stack
* Missing include
* Reorder ops/unary ops

So we don't change the enum values of the mul mat ops again

* Minor
* Discard unnecessary changes in llama-build-context.cpp
* Minor
* Revert "Discard unnecessary changes in llama-build-context.cpp"

This reverts commit edadb80ed68c4c0831e9c22609a9a3af19be9735.

* Increase GGML_SCHED_MAX_SPLITS - required for larger u-batches
* Fix CPU concat in the TG case: 7.25 -> 10.5 t/s for Qwen3Next
* Fix CPU sum_rows: 10.5 -> 13.6 t/s for Qwen3Next

It was single-threaded and was taking ~25% of the computation time during TG. It is now down to 2%. Strangely enough, I measure 13.6 t/s with llama-bench, but if I let the model give me an actual response with llama-cli, I get close to 17 t/s.

* Fix CPU scale: 13.6 -> 16.7 t/s for Qwen3Next

For Qwen3Next there is a scale op on a largish tensor (548k elements) that has a single row for TG, so it was done in a single thread. We now simply use blocks of 1024 elements.

* Optimize CPU mul: 16.7 -> 17.6 t/s for Qwen3Next
* CPU: fuse transpose -> cont -> sum_rows -> transpose: 17.6 -> 23.1 t/s for Qwen3Next
* Optimize CPU repeat: 176 -> 200 t/s for Qwen3Next PP-512
* Multithreading for OP_SUB
* Don't commit with timing trace on
* Multithread neg and sigmoid
* Be able to turn on/off fusion more easily (CPU)
* Name the mul_mat ops so we know where the time goes
* WIP
* Much better PP on CUDA
* CUDA: fuse transpose -> cont -> sum_rows -> transpose

Needs a non-contiguous variant of sum_rows. On the CPU this gave a 30+% improvement in TG performance; on CUDA it is a disappointing 6-7%. I guess this is because Georgi's cont CPU implementation was so bad that skipping it made such a big difference.
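As an illustration of the blocking scheme mentioned in the "Fix CPU scale" entry above, here is a minimal sketch of splitting a single large row into 1024-element blocks that are distributed over threads. The helper name scale_row_blocked is hypothetical and only mirrors the idea described in the commit message; the actual implementation in ggml.c differs in detail, though ith/nth follows the usual ggml thread-index convention.

#include <algorithm>
#include <cstdint>

// Hypothetical helper: scale one row of n floats by s, with thread ith of nth
// taking every nth block of 1024 elements (the block size named in the commit).
static void scale_row_blocked(float * x, int64_t n, float s, int ith, int nth) {
    constexpr int64_t kBlock = 1024;
    const int64_t nblocks = (n + kBlock - 1) / kBlock;
    for (int64_t ib = ith; ib < nblocks; ib += nth) {
        const int64_t i0 = ib * kBlock;
        const int64_t i1 = std::min(i0 + kBlock, n);
        for (int64_t i = i0; i < i1; ++i) {
            x[i] *= s;
        }
    }
}

With 548k elements this gives roughly 536 blocks, enough to keep a typical thread count busy instead of leaving the whole row to a single thread.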
* CUDA: faster mul for special case relevant for Qwen3Next Worth 1% in TG * Fix CPU OP_CONT --------- Co-authored-by: yurko Co-authored-by: Yurko Co-authored-by: yurko Co-authored-by: Yurko Hoshko --- common/common.cpp | 2 + ggml/include/ggml.h | 81 +++ ggml/src/ggml-backend.cpp | 2 +- ggml/src/ggml-cuda.cu | 111 +++- ggml/src/ggml-cuda/binbcast.cu | 29 +- ggml/src/ggml-cuda/binbcast.cuh | 1 + ggml/src/ggml-cuda/cumsum.cu | 76 +++ ggml/src/ggml-cuda/cumsum.cuh | 3 + ggml/src/ggml-cuda/fill.cu | 34 + ggml/src/ggml-cuda/fill.cuh | 3 + ggml/src/ggml-cuda/norm.cu | 126 ++++ ggml/src/ggml-cuda/norm.cuh | 2 + ggml/src/ggml-cuda/solve_tri.cu | 914 +++++++++++++++++++++++++++ ggml/src/ggml-cuda/solve_tri.cuh | 3 + ggml/src/ggml-cuda/ssm-conv.cu | 608 ++++++++++++++++++ ggml/src/ggml-cuda/ssm-conv.cuh | 3 + ggml/src/ggml-cuda/sumrows.cu | 42 +- ggml/src/ggml-cuda/sumrows.cuh | 2 + ggml/src/ggml-cuda/tri.cu | 120 ++++ ggml/src/ggml-cuda/tri.cuh | 5 + ggml/src/ggml-cuda/unary.cu | 9 +- ggml/src/ggml-cuda/unary.cuh | 2 + ggml/src/ggml.c | 1019 ++++++++++++++++++++++++++---- include/llama.h | 4 +- src/llama-arch.cpp | 3 +- src/llama-arch.h | 5 + src/llama-build-context.cpp | 903 +++++++++++++++++++++++++- src/llama-build-context.h | 2 + src/llama-context.h | 2 + src/llama-hparams.cpp | 25 +- src/llama-hparams.h | 26 + src/llama-load-tensors.cpp | 97 +++ src/llama-model.cpp | 34 + src/llama-model.h | 3 + src/llama.cpp | 531 +++++++++++++--- 35 files changed, 4600 insertions(+), 232 deletions(-) create mode 100644 ggml/src/ggml-cuda/cumsum.cu create mode 100644 ggml/src/ggml-cuda/cumsum.cuh create mode 100644 ggml/src/ggml-cuda/fill.cu create mode 100644 ggml/src/ggml-cuda/fill.cuh create mode 100644 ggml/src/ggml-cuda/solve_tri.cu create mode 100644 ggml/src/ggml-cuda/solve_tri.cuh create mode 100644 ggml/src/ggml-cuda/ssm-conv.cu create mode 100644 ggml/src/ggml-cuda/ssm-conv.cuh create mode 100644 ggml/src/ggml-cuda/tri.cu create mode 100644 ggml/src/ggml-cuda/tri.cuh diff --git a/common/common.cpp b/common/common.cpp index 1e75e328..fc96c124 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -494,6 +495,7 @@ void gpt_params_parse_from_env(gpt_params & params) { get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching); get_env("LLAMA_ARG_HOST", params.hostname); get_env("LLAMA_ARG_PORT", params.port); + } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fe7cb166..57ec82ec 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -673,6 +673,12 @@ extern "C" { GGML_OP_ADD_REL_POS, GGML_OP_UNARY, + GGML_OP_CUMSUM, + GGML_OP_L2_NORM, + GGML_OP_TRI, + GGML_OP_FILL, + GGML_OP_SOLVE_TRI, + GGML_OP_MAP_UNARY, GGML_OP_MAP_BINARY, @@ -713,6 +719,8 @@ extern "C" { GGML_UNARY_OP_SWIGLU, GGML_UNARY_OP_SWIGLU_OAI, GGML_UNARY_OP_GELU, + GGML_UNARY_OP_EXP, + GGML_UNARY_OP_SOFTPLUS, GGML_UNARY_OP_COUNT, }; @@ -739,6 +747,13 @@ extern "C" { GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; + enum ggml_tri_type { + GGML_TRI_TYPE_LOWER, + GGML_TRI_TYPE_UPPER, + GGML_TRI_TYPE_LOWER_DIAG, + GGML_TRI_TYPE_UPPER_DIAG, + }; + // ggml object struct ggml_object { size_t offs; @@ -1189,6 +1204,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * 
ggml_softplus_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // return scalar GGML_API struct ggml_tensor * ggml_sum( struct ggml_context * ctx, @@ -1199,6 +1222,10 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a); + // mean along rows GGML_API struct ggml_tensor * ggml_mean( struct ggml_context * ctx, @@ -1217,6 +1244,15 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // repeat a to specified shape + GGML_API struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + // sums repetitions in a into shape of b GGML_API struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, @@ -1455,6 +1491,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_exp( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_exp_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, @@ -1514,6 +1558,17 @@ extern "C" { int n_groups, float eps); + // l2 normalize along rows + GGML_API struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + // a - x // b - dy GGML_API struct ggml_tensor * ggml_rms_norm_back( @@ -2283,6 +2338,23 @@ extern "C" { int dim, int max_period); + // convert matrix to triangular form by zeroing values outside selected half + GGML_API struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type type); + + // fill tensor with constant c + GGML_API struct ggml_tensor * ggml_fill( + struct ggml_context * ctx, + struct ggml_tensor * a, + float c); + + GGML_API struct ggml_tensor * ggml_fill_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float c); + // sort rows enum ggml_sort_order { GGML_SORT_ORDER_ASC, @@ -2426,6 +2498,15 @@ extern "C" { struct ggml_tensor * pw, struct ggml_tensor * ph); + // Solve Ax = B where A is triangular + GGML_API struct ggml_tensor * ggml_solve_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool left, + bool lower, + bool uni); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 38fd5692..09f876d0 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1103,7 +1103,7 @@ static bool ggml_is_view_op(enum ggml_op op) { #endif #ifndef GGML_SCHED_MAX_SPLITS -#define GGML_SCHED_MAX_SPLITS 2048 +#define GGML_SCHED_MAX_SPLITS 4096 #endif #ifndef GGML_SCHED_MAX_SPLIT_INPUTS diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index b8a8ebb6..1a197972 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -18,9 +18,11 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/convert.cuh" #include "ggml-cuda/cpy.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" +#include "ggml-cuda/fill.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" @@ -46,10 +48,13 @@ #include 
"ggml-cuda/conv2d.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/solve_tri.cuh" +#include "ggml-cuda/ssm-conv.cuh" #include "ggml-cuda/argmax.cuh" #include "ggml-cuda/multiadd.cuh" #include "ggml-cuda/hadamard.cuh" #include "ggml-cuda/reduce.cuh" +#include "ggml-cuda/tri.cuh" #include #include @@ -2011,9 +2016,11 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct const int64_t r3 = ne13/ne03; if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) { + //printf("Using cublasGemmStridedBatchedEx for %s\n", dst->name); // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3: const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00; - const int64_t smb = ne12 == 1 ? s13 : s12; + //const int64_t smb = ne12 == 1 ? s13 : s12; + const int64_t smb = ne12 == 1 ? nb13/nb10 : nb12/nb10; // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx @@ -2027,6 +2034,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } else { + //printf("Using cublasGemmBatchedEx for %s\n", dst->name); + //printf(" src0: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n",src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); + //printf(" src1: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n",src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); // use cublasGemmBatchedEx const int64_t ne23 = ne12*ne13; @@ -2238,22 +2248,29 @@ static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + //printf("%s(%s): using ggml_cuda_mul_mat_vec_p021\n", __func__, dst->name); // FP32 precision KQ single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); } else if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + //printf("%s(%s): using ggml_cuda_mul_mat_vec_nc\n", __func__, dst->name); // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) + } else if ((src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32) && (src1->type == src0->type || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + //printf("%s(%s): ggml_cuda_mul_mat_batched_cublas\n", __func__, dst->name); // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { + //printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_dequantize_mul_mat_vec)\n", __func__, dst->name); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { + //printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_vec_q)\n", __func__, dst->name); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, 
quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { + //printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_q)\n", __func__, dst->name); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); } else { + //printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_cublas)\n", __func__, dst->name); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } return node_n; @@ -2822,11 +2839,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten return i; } - std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - ggml_tensor src0_1_row = *src0_1; ggml_tensor src0_2_row; if (src0_2) src0_2_row = *src0_2; ggml_tensor src1_row = *src1; @@ -3199,7 +3211,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg } break; case GGML_OP_CONT: - ggml_cuda_dup(ctx, dst); + if (fusion && i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_SUM_ROWS && + cgraph->nodes[i+2]->op == GGML_OP_TRANSPOSE && + dst->src[0]->op == GGML_OP_TRANSPOSE) { + ggml_cuda_op_sum_rows_nc(ctx, cgraph->nodes[i+1]); + i += 2; + } else { + ggml_cuda_dup(ctx, dst); + } break; case GGML_OP_ADD: if (fusion && i + 2 < cgraph->n_nodes && @@ -3242,6 +3262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_acc(ctx, dst); break; case GGML_OP_MUL: + //printf("mul(%s): %d, %d, %d, %ld x %ld x %ld x %ld * %ld x %ld x %ld x %ld\n", dst->name, ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_is_contiguous(dst), + // dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3], + // dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]); ggml_cuda_op_mul(ctx, dst); break; case GGML_OP_FUSED_MUL_UNARY: @@ -3250,6 +3273,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_DIV: ggml_cuda_op_div(ctx, dst); break; + case GGML_OP_SUB: + ggml_cuda_op_sub(ctx, dst); + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_GELU: @@ -3273,6 +3299,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_RELU: ggml_cuda_op_relu(ctx, dst); break; + case GGML_UNARY_OP_NEG: + ggml_cuda_op_neg(ctx, dst); + break; case GGML_UNARY_OP_SIGMOID: if (fusion && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && @@ -3305,6 +3334,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_HARDSWISH: ggml_cuda_op_hardswish(ctx, dst); break; + case GGML_UNARY_OP_EXP: + ggml_cuda_op_exp(ctx, dst); + break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_softplus(ctx, dst); + break; default: return -1; } @@ -3339,6 +3374,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GROUP_NORM: ggml_cuda_op_group_norm(ctx, dst); break; + case GGML_OP_L2_NORM: + ggml_cuda_op_l2_norm(ctx, dst); + break; case GGML_OP_CONCAT: if (fusion && i + 2 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_VIEW && @@ -3554,6 +3592,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_sum_rows(ctx, dst); } break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; case 
GGML_OP_ARGSORT: if (fusion && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_VIEW && @@ -3573,6 +3614,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GROUPED_TOPK: ggml_cuda_op_grouped_topk(ctx, dst); break; + case GGML_OP_SSM_CONV: + ggml_cuda_op_ssm_conv(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cuda_op_tri(ctx, dst); + break; + case GGML_OP_FILL: + ggml_cuda_op_fill(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_cuda_op_solve_tri(ctx, dst); + break; case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; @@ -3594,6 +3647,10 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg } #if IK_PRINT_TIMING + if (auto err = cudaStreamSynchronize(ctx.stream()); err != cudaSuccess) { + GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst)); + CUDA_CHECK(err); + } int64_t tim2 = ggml_time_us(); printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1)); #endif @@ -4149,6 +4206,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_NEG: return ggml_is_contiguous(op->src[0]); default: return false; @@ -4342,6 +4402,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_NORM: case GGML_OP_RMS_NORM: return true; + case GGML_OP_L2_NORM: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; break; @@ -4356,6 +4418,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_MUL_MULTI_ADD: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_SUB: case GGML_OP_FUSED_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SOFTCAP: @@ -4389,6 +4452,38 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: return true; + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_TRI: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[0]->type == op->type; + case GGML_OP_FILL: + return ggml_is_contiguous(op) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + case GGML_OP_SOLVE_TRI: + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op) && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + op->src[0]->ne[0] == op->src[0]->ne[1] && + op->src[0]->ne[1] == op->src[1]->ne[1] && + op->src[0]->ne[2] == op->src[1]->ne[2] && + op->src[0]->ne[3] == op->src[1]->ne[3]; + case GGML_OP_SSM_CONV: + return op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->src[2]->type == GGML_TYPE_F32 && + op->src[3]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_F32 && + op->src[0]->nb[0] == sizeof(float) && + op->src[1]->nb[0] == sizeof(float) && + op->src[2]->nb[0] == sizeof(float) && + op->src[3]->nb[0] == sizeof(int32_t) && + op->src[2]->ne[0] == op->src[0]->ne[0] + 1 && + op->src[2]->ne[1] == op->src[0]->ne[1] && + op->src[1]->ne[0] == op->src[0]->ne[1] && + op->src[3]->ne[0] == op->src[0]->ne[2]; case GGML_OP_FLASH_ATTN_EXT: #if defined(GGML_USE_HIPBLAS) && 
defined(__HIP_PLATFORM_AMD__) return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 96166ceb..5953746b 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -24,6 +24,10 @@ static __device__ __forceinline__ float op_div(const float a, const float b) { return a / b; } +static __device__ __forceinline__ float op_sub(const float a, const float b) { + return a - b; +} + template static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3, @@ -512,14 +516,37 @@ static void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tens scale_f32_cuda_l(src0_d, dst_d, dst->src[1]->data, ggml_nelements(src0), stream); } +static __global__ void k_mul_fast(int ne0, int nelem, const float * x, const float * y, float * z) { + int i = blockDim.x*blockIdx.x + threadIdx.x; + if (i >= nelem) { + return; + } + int i1 = i / ne0; + z[i] = x[i] * y[i1]; +} + void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (ggml_nelements(dst->src[1]) == 1 && dst->src[1]->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32) { ggml_cuda_op_scale_tensor(ctx, dst); return; } - ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); + auto src0 = dst->src[0]; + auto src1 = dst->src[1]; + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src1->ne[0] == 1 && src0->ne[1] == src1->ne[1] && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3]) { + constexpr int kBlockSize = 256; + int nelem = ggml_nelements(src0); + int nblock = (nelem + kBlockSize - 1)/kBlockSize; + k_mul_fast<<>>(src0->ne[0], nelem, (const float *)src0->data, (const float *)src1->data, (float *)dst->data); + return; + } + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0->data, src1->data, dst->data, ctx.stream()); } void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } + +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); +} diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 6fb13413..4b72d026 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -2,6 +2,7 @@ void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu new file mode 100644 index 00000000..ae7bbfbd --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -0,0 +1,76 @@ +#include "cumsum.cuh" + +#define CUDA_CUMSUM_BLOCK_SIZE 256 + +static __global__ void cumsum_f32_kernel( + const float * src, float * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s00, const int64_t s01, const int64_t s02, 
const int64_t s03, + const int64_t d0, const int64_t d1, const int64_t d2, const int64_t d3) { + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + const int64_t i3 = blockIdx.z; + + if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { + return; + } + + const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3; + + extern __shared__ float s_scan[]; + + float carry = 0.0f; + for (int64_t start = 0; start < ne00; start += blockDim.x) { + const int tile_n = (int) ((ne00 - start) < (int64_t) blockDim.x ? (ne00 - start) : (int64_t) blockDim.x); + + float value = 0.0f; + if (threadIdx.x < tile_n) { + value = src_row[(start + threadIdx.x) * s00]; + } + s_scan[threadIdx.x] = value; + __syncthreads(); + + for (int offset = 1; offset < blockDim.x; offset <<= 1) { + float add = 0.0f; + if (threadIdx.x >= offset) { + add = s_scan[threadIdx.x - offset]; + } + __syncthreads(); + if (threadIdx.x >= offset) { + s_scan[threadIdx.x] += add; + } + __syncthreads(); + } + + if (threadIdx.x < tile_n) { + dst_row[(start + threadIdx.x) * d0] = s_scan[threadIdx.x] + carry; + } + + __syncthreads(); + if (threadIdx.x == tile_n - 1) { + carry += s_scan[threadIdx.x]; + } + __syncthreads(); + } +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + int block_size = WARP_SIZE; + while (block_size < src0->ne[0] && block_size < CUDA_CUMSUM_BLOCK_SIZE) { + block_size <<= 1; + } + + dim3 grid_dims(src0->ne[1], src0->ne[2], src0->ne[3]); + cumsum_f32_kernel<<>>( + (const float *) src0->data, + (float *) dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0] / sizeof(float), src0->nb[1] / sizeof(float), src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + dst->nb[0] / sizeof(float), dst->nb[1] / sizeof(float), dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float)); +} diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 00000000..eeb506b0 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fill.cu b/ggml/src/ggml-cuda/fill.cu new file mode 100644 index 00000000..b78029f5 --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cu @@ -0,0 +1,34 @@ +#include "fill.cuh" +#include "convert.cuh" + +#define CUDA_FILL_BLOCK_SIZE 256 + +template +static __global__ void fill_kernel(T * dst, const int64_t k, const T value) { + const int64_t i = (int64_t) blockDim.x * blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = value; +} + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst)); + + float value = 0.0f; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + fill_kernel<<>>((float *) dst->data, k, value); + break; + case GGML_TYPE_F16: + fill_kernel<<>>((half *) dst->data, k, ggml_cuda_cast(value)); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/fill.cuh b/ggml/src/ggml-cuda/fill.cuh new file mode 100644 index 00000000..8443c836 --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cuh @@ -0,0 +1,3 @@ +#include 
"common.cuh" + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index b198b68e..61d90a0c 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -185,6 +185,38 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } +template +static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x * blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row * ncols + col]; + tmp += xi * xi; + } + + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f; + tmp = warp_reduce_sum(tmp); + } + + const float scale = rsqrtf(fmaxf(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[row * ncols + col] = scale * x[row * ncols + col]; + } +} + template static __global__ void rms_norm_f32_nc( const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, @@ -230,6 +262,49 @@ static __global__ void rms_norm_f32_nc( } } +template +static __global__ void l2_norm_f32_nc( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample * stride_sample + channel * stride_channel + row * stride_row; + dst += ((sample * nchannels + channel) * nrows + row) * ncols; + + float tmp = 0.0f; + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + tmp = warp_reduce_sum(tmp); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float scale = rsqrtf(fmaxf(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * x[col]; + } +} + template static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -387,6 +462,31 @@ static void rms_norm_f32_nc_cuda( } } +static void l2_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + constexpr int kBlockSize = 256; + if (ncols < 1024) { + const dim3 block_dims(kBlockSize, 1, 1); + l2_norm_f32<<>>(x, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + l2_norm_f32<1024><<>>(x, dst, ncols, eps); + } +} + +static void l2_norm_f32_nc_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t 
stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + l2_norm_f32_nc<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + l2_norm_f32_nc<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + template static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst, const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) { @@ -527,6 +627,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float eps = 0.0f; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + if (ggml_is_contiguous(src0)) { + const int64_t nrows = ggml_nrows(src0); + l2_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + } else { + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(src0->nb[0] == ts0); + const int64_t s01 = src0->nb[1] / ts0; + const int64_t s02 = src0->nb[2] / ts0; + const int64_t s03 = src0->nb[3] / ts0; + l2_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream); + } +} + void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm) { if (!dst->src[1]) { ggml_cuda_op_rms_norm(ctx, dst); diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 8550633b..0c6c16a6 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm = false); void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu new file mode 100644 index 00000000..726b08b9 --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -0,0 +1,914 @@ +#include "common.cuh" +#include "ggml.h" +#include "solve_tri.cuh" +#include "ggml-cuda.h" +#include +#include + +#define MAX_N_FAST 64 +#define MAX_K_FAST 64 + +// This branch does not carry the fast-div helpers from upstream CUDA common code. +// Keep the PR kernel logic but back it with plain div/mod wrappers. 
+static inline uint3 init_fastdiv_values(uint32_t d) { + return make_uint3(d, 0u, 0u); +} + +static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 d) { + return make_uint2(n / d.x, n % d.x); +} + +// Kernel to set up pointer arrays for batched cuBLAS TRSM +// This avoids host-device copy during CUDA graph capture +static __global__ void setup_trsm_batch_pointers( + const float * A, + float * X, + const float ** A_ptrs, + float ** X_ptrs, + const int64_t ne02, + const int64_t total_batches, + const size_t nb02, // stride for A dim 2 (in floats) + const size_t nb03, // stride for A dim 3 (in floats) + const size_t nb2, // stride for X dim 2 (in floats) + const size_t nb3 // stride for X dim 3 (in floats) +) { + const int64_t batch_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (batch_idx >= total_batches) return; + + // Decompose batch_idx into i02, i03 + const int64_t i02 = batch_idx % ne02; + const int64_t i03 = batch_idx / ne02; + + A_ptrs[batch_idx] = A + i02 * nb02 + i03 * nb03; + X_ptrs[batch_idx] = X + i02 * nb2 + i03 * nb3; +} + +// Latency-optimized kernel for n=64, k=64 (single-token generation) +static __global__ void solve_tri_f32_64x64_latency( + const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3) +{ + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Shared memory: A is 64x64, X is 64x65 (padded for bank conflicts) + __shared__ float sA[64 * 64]; + __shared__ float sX[64 * 65]; + __shared__ float sDiagInv[64]; // Precomputed 1/diagonal + + const int tid = lane + warp_id * WARP_SIZE; + + // Cooperative load of A matrix (4096 elements / 512 threads = 8 per thread) + #pragma unroll 8 + for (int i = tid; i < 64 * 64; i += 512) { + sA[i] = A_batch[i]; + } + + // Cooperative load of B matrix into sX with padding + #pragma unroll 8 + for (int i = tid; i < 64 * 64; i += 512) { + const int row = i / 64; + const int col = i % 64; + sX[row * 65 + col] = B_batch[i]; + } + + __syncthreads(); + + // Precompute diagonal inverses (first 2 warps handle this) + if (warp_id == 0) { + if (lane < 32) { + sDiagInv[lane] = 1.0f / sA[lane * 64 + lane]; + } + } + if (warp_id == 1) { + if (lane < 32) { + sDiagInv[32 + lane] = 1.0f / sA[(32 + lane) * 64 + (32 + lane)]; + } + } + + __syncthreads(); + + // Each warp handles 4 columns: cols = warp_id*4 to warp_id*4+3 + const int col_base = warp_id * 4; + + #pragma unroll 1 + for (int row = 0; row < 64; ++row) { + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f; + + if (row > 0) { + for (int j = lane; j < row; j += WARP_SIZE) { + const float a_val = sA[row * 64 + j]; + sum0 += a_val * sX[j * 65 + col_base + 0]; + sum1 += a_val * sX[j * 65 + col_base + 1]; + sum2 += a_val * sX[j * 65 + col_base + 2]; + sum3 += a_val * sX[j * 65 + col_base + 3]; + } + } + + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + sum2 = warp_reduce_sum(sum2); + sum3 = warp_reduce_sum(sum3); + + if (lane == 0) { + const float inv_diag = sDiagInv[row]; + sX[row * 
65 + col_base + 0] = (sX[row * 65 + col_base + 0] - sum0) * inv_diag; + sX[row * 65 + col_base + 1] = (sX[row * 65 + col_base + 1] - sum1) * inv_diag; + sX[row * 65 + col_base + 2] = (sX[row * 65 + col_base + 2] - sum2) * inv_diag; + sX[row * 65 + col_base + 3] = (sX[row * 65 + col_base + 3] - sum3) * inv_diag; + } + + __syncthreads(); + } + + // Cooperative write results back + #pragma unroll 8 + for (int i = tid; i < 64 * 64; i += 512) { + const int row = i / 64; + const int col = i % 64; + X_batch[i] = sX[row * 65 + col]; + } +} + +static __global__ void solve_tri_f32_64x64_opt(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3) { + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Shared memory: A is 64x64, sXt is 64x65 (padded) + __shared__ float sA[64 * 64]; + __shared__ float sXt[64 * 65]; + + const int tid = lane + warp_id * WARP_SIZE; + + // Cooperative load of A matrix (4096 elements / 1024 threads = 4 per thread) + #pragma unroll 4 + for (int i = tid; i < 64 * 64; i += 1024) { + sA[i] = A_batch[i]; + } + + // Cooperative load of B matrix transposed into sXt + // sXt[col * 65 + row] = B[row * 64 + col] + #pragma unroll 4 + for (int i = tid; i < 64 * 64; i += 1024) { + const int row = i / 64; + const int col = i % 64; + sXt[col * 65 + row] = B_batch[row * 64 + col]; + } + + __syncthreads(); + + // Each warp handles 2 columns: col0 = warp_id*2, col1 = warp_id*2 + 1 + const int col0 = warp_id * 2; + const int col1 = warp_id * 2 + 1; + + // Forward substitution with all columns processed in parallel + // Each row depends on previous rows, but different columns are independent + #pragma unroll 1 + for (int row = 0; row < 64; ++row) { + // Each lane computes partial sum for indices it handles + float sum0 = 0.0f; + float sum1 = 0.0f; + + // Sum over j < row + // For row <= 32: each lane handles at most 1 element + // For row > 32: each lane handles at most 2 elements + if (lane < row) { + const float a_val = sA[row * 64 + lane]; + sum0 = a_val * sXt[col0 * 65 + lane]; + sum1 = a_val * sXt[col1 * 65 + lane]; + } + if (row > WARP_SIZE) { + const int j2 = lane + WARP_SIZE; + if (j2 < row) { + const float a_val2 = sA[row * 64 + j2]; + sum0 += a_val2 * sXt[col0 * 65 + j2]; + sum1 += a_val2 * sXt[col1 * 65 + j2]; + } + } + + // Warp-level reduction + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + + // Lane 0 computes and stores the result + if (lane == 0) { + const float a_diag = sA[row * 64 + row]; + const float inv_diag = 1.0f / a_diag; + sXt[col0 * 65 + row] = (sXt[col0 * 65 + row] - sum0) * inv_diag; + sXt[col1 * 65 + row] = (sXt[col1 * 65 + row] - sum1) * inv_diag; + } + + // Sync within warp to ensure writes are visible before next row reads + __syncwarp(); + } + + __syncthreads(); + + // Cooperative write of results back (transpose sXt to X) + #pragma unroll 4 + for (int i = tid; i < 64 * 64; i += 1024) { + const int row = i / 64; + const int col = i % 64; + X_batch[row * 64 + col] = sXt[col * 65 
+ row]; + } +} + +static __global__ void solve_tri_f32_128x128_opt(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n, + const int k) { + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Shared memory with padding to avoid bank conflicts + // Layout: sA[128][128] + sXt[128][129] + extern __shared__ char smem_raw[]; + float * sA = (float *)smem_raw; // 128×128 (zero-initialized for unused parts) + float * sXt = sA + 128 * 128; // 128×129 (padded) + + const int tid = lane + warp_id * WARP_SIZE; + + // Zero-initialize shared memory first (important for variable n, k) + #pragma unroll 16 + for (int i = tid; i < 128 * 128; i += 1024) { + sA[i] = 0.0f; + } + #pragma unroll 16 + for (int i = tid; i < 128 * 129; i += 1024) { + sXt[i] = 0.0f; + } + __syncthreads(); + + // Cooperative load of A matrix (n×n elements) + for (int i = tid; i < n * n; i += 1024) { + const int row = i / n; + const int col = i % n; + sA[row * 128 + col] = A_batch[row * n + col]; + } + + // Cooperative load of B matrix transposed into sXt + // sXt[col * 129 + row] = B[row * k + col] + for (int i = tid; i < n * k; i += 1024) { + const int row = i / k; + const int col = i % k; + sXt[col * 129 + row] = B_batch[row * k + col]; + } + + __syncthreads(); + + // Each warp handles columns: col_base to col_base+3 + // But only process if col < k + const int col_base = warp_id * 4; + + // Forward substitution with all columns processed in parallel + for (int row = 0; row < n; ++row) { + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f; + + // Sum over j < row - each lane handles multiple elements + for (int j = lane; j < row; j += WARP_SIZE) { + const float a_val = sA[row * 128 + j]; + if (col_base + 0 < k) sum0 += a_val * sXt[(col_base + 0) * 129 + j]; + if (col_base + 1 < k) sum1 += a_val * sXt[(col_base + 1) * 129 + j]; + if (col_base + 2 < k) sum2 += a_val * sXt[(col_base + 2) * 129 + j]; + if (col_base + 3 < k) sum3 += a_val * sXt[(col_base + 3) * 129 + j]; + } + + // Warp-level reduction + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + sum2 = warp_reduce_sum(sum2); + sum3 = warp_reduce_sum(sum3); + + // Lane 0 computes and stores the result + if (lane == 0) { + const float inv_diag = 1.0f / sA[row * 128 + row]; + if (col_base + 0 < k) { + sXt[(col_base + 0) * 129 + row] = (sXt[(col_base + 0) * 129 + row] - sum0) * inv_diag; + } + if (col_base + 1 < k) { + sXt[(col_base + 1) * 129 + row] = (sXt[(col_base + 1) * 129 + row] - sum1) * inv_diag; + } + if (col_base + 2 < k) { + sXt[(col_base + 2) * 129 + row] = (sXt[(col_base + 2) * 129 + row] - sum2) * inv_diag; + } + if (col_base + 3 < k) { + sXt[(col_base + 3) * 129 + row] = (sXt[(col_base + 3) * 129 + row] - sum3) * inv_diag; + } + } + + __syncwarp(); + } + + __syncthreads(); + + // Cooperative write of results back (transpose sXt to X) + for (int i = tid; i < n * k; i += 1024) { + const int row = i / k; + const int col = i % k; + X_batch[row * k + col] = sXt[col * 129 
+ row]; + } +} + +static __global__ void solve_tri_f32_256x256_tiled(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n, + const int k) { + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Tiled approach using 64×64 tiles to fit in shared memory + constexpr int TILE_SIZE = 64; + + extern __shared__ char smem_raw[]; + float * sA_tile = (float *)smem_raw; // 64×64 = 16KB + float * sXt_tile = sA_tile + TILE_SIZE * TILE_SIZE; // 64×65 = 16.25KB (padded) + float * sA_off = sXt_tile + TILE_SIZE * (TILE_SIZE+1); // 64×64 = 16KB (for off-diagonal blocks) + + const int tid = lane + warp_id * WARP_SIZE; + + // Initialize X = B (we'll solve in-place conceptually, using global memory) + for (int i = tid; i < n * k; i += 1024) { + X_batch[i] = B_batch[i]; + } + __syncthreads(); + + // Process tile-by-tile along the diagonal + for (int tile_row = 0; tile_row < n; tile_row += TILE_SIZE) { + const int tile_n = min(TILE_SIZE, n - tile_row); // Actual rows in this tile + + // Zero-init and load diagonal tile of A + for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) { + sA_tile[i] = 0.0f; + } + __syncthreads(); + + for (int i = tid; i < tile_n * tile_n; i += 1024) { + int local_row = i / tile_n; + int local_col = i % tile_n; + sA_tile[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + tile_row + local_col]; + } + __syncthreads(); + + // For each column tile of X + for (int tile_col = 0; tile_col < k; tile_col += TILE_SIZE) { + const int tile_k = min(TILE_SIZE, k - tile_col); // Actual columns in this tile + + // Zero-init and load X tile transposed + for (int i = tid; i < TILE_SIZE * (TILE_SIZE+1); i += 1024) { + sXt_tile[i] = 0.0f; + } + __syncthreads(); + + for (int i = tid; i < tile_n * tile_k; i += 1024) { + int local_row = i / tile_k; + int local_col = i % tile_k; + sXt_tile[local_col * (TILE_SIZE+1) + local_row] = + X_batch[(tile_row + local_row) * k + tile_col + local_col]; + } + __syncthreads(); + + // Apply updates from previous tile rows + for (int prev_tile = 0; prev_tile < tile_row; prev_tile += TILE_SIZE) { + const int prev_n = min(TILE_SIZE, n - prev_tile); + + // Zero-init and load off-diagonal block + for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) { + sA_off[i] = 0.0f; + } + __syncthreads(); + + for (int i = tid; i < tile_n * prev_n; i += 1024) { + int local_row = i / prev_n; + int local_col = i % prev_n; + sA_off[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + prev_tile + local_col]; + } + __syncthreads(); + + // Update: X_tile -= A_off @ X_prev + int col0 = warp_id * 2; + int col1 = warp_id * 2 + 1; + + for (int row = 0; row < tile_n; row++) { + float sum0 = 0.0f, sum1 = 0.0f; + + for (int j = lane; j < prev_n; j += WARP_SIZE) { + float a_val = sA_off[row * TILE_SIZE + j]; + if (col0 < tile_k) { + float x_prev0 = X_batch[(prev_tile + j) * k + tile_col + col0]; + sum0 += a_val * x_prev0; + } + if (col1 < tile_k) { + float x_prev1 = 
X_batch[(prev_tile + j) * k + tile_col + col1]; + sum1 += a_val * x_prev1; + } + } + + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + + if (lane == 0) { + if (col0 < tile_k) { + sXt_tile[col0 * (TILE_SIZE+1) + row] -= sum0; + } + if (col1 < tile_k) { + sXt_tile[col1 * (TILE_SIZE+1) + row] -= sum1; + } + } + __syncwarp(); + } + __syncthreads(); + } + + // Solve the diagonal tile + int col0 = warp_id * 2; + int col1 = warp_id * 2 + 1; + + for (int row = 0; row < tile_n; ++row) { + float sum0 = 0.0f, sum1 = 0.0f; + + if (lane < row) { + float a_val = sA_tile[row * TILE_SIZE + lane]; + if (col0 < tile_k) sum0 = a_val * sXt_tile[col0 * (TILE_SIZE+1) + lane]; + if (col1 < tile_k) sum1 = a_val * sXt_tile[col1 * (TILE_SIZE+1) + lane]; + } + if (row > WARP_SIZE) { + int j2 = lane + WARP_SIZE; + if (j2 < row) { + float a_val2 = sA_tile[row * TILE_SIZE + j2]; + if (col0 < tile_k) sum0 += a_val2 * sXt_tile[col0 * (TILE_SIZE+1) + j2]; + if (col1 < tile_k) sum1 += a_val2 * sXt_tile[col1 * (TILE_SIZE+1) + j2]; + } + } + + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + + if (lane == 0) { + float inv_diag = 1.0f / sA_tile[row * TILE_SIZE + row]; + if (col0 < tile_k) { + sXt_tile[col0 * (TILE_SIZE+1) + row] = + (sXt_tile[col0 * (TILE_SIZE+1) + row] - sum0) * inv_diag; + } + if (col1 < tile_k) { + sXt_tile[col1 * (TILE_SIZE+1) + row] = + (sXt_tile[col1 * (TILE_SIZE+1) + row] - sum1) * inv_diag; + } + } + __syncwarp(); + } + __syncthreads(); + + // Write solved tile back to global memory + for (int i = tid; i < tile_n * tile_k; i += 1024) { + int local_row = i / tile_k; + int local_col = i % tile_k; + X_batch[(tile_row + local_row) * k + tile_col + local_col] = + sXt_tile[local_col * (TILE_SIZE+1) + local_row]; + } + __syncthreads(); + } + } +} + +// When ncols_template == 0 the bounds for the loops in this function are not +// known and can't be unrolled. As we want to keep pragma unroll for all other +// cases we supress the clang transformation warning here. +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +// Template parameters: n_template/k_template are the matrix dimensions when known at compile time (0 = runtime) +// threads_y_template is the number of threads in y dimension (max 32 to stay within 1024 thread limit) +template +static __global__ void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n_arg, + const int k_arg) { + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? k_arg : k_template; + const int threads_y = threads_y_template == 0 ? 
blockDim.y : threads_y_template; + + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; + __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; + + const int offset = threadIdx.x + threadIdx.y * blockDim.x; + const int block_threads = blockDim.x * blockDim.y; + + // Load A matrix into shared memory +#pragma unroll + for (int i = 0; i < n * n; i += block_threads) { + int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; + const int cols_per_thread = (k + threads_y - 1) / threads_y; + + // Load B matrix into shared memory (transposed as sXt) + // Each thread handles multiple columns when k > threads_y + for (int c = 0; c < cols_per_thread; c++) { + const int col_idx = threadIdx.y + c * threads_y; + if (col_idx < k) { +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; + } + } + } + } + + __syncthreads(); + + // Solve for each column this thread handles + for (int c = 0; c < cols_per_thread; c++) { + const int col_idx = threadIdx.y + c * threads_y; + if (col_idx >= k) { + continue; + } + +#pragma unroll + for (int row = 0; row < n; ++row) { + float sum = 0.0f; + + { + int j = lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + if (row >= WARP_SIZE) { + int j = WARP_SIZE + lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + + sum = warp_reduce_sum(sum); + + if (lane == 0) { + const float b_val = sXt[col_idx * n + row]; + const float a_diag = sA[row * n + row]; + // no safeguards for division by zero because that indicates corrupt + // data anyway + sXt[col_idx * n + row] = (b_val - sum) / a_diag; + } + } + + // Sync between columns to ensure writes are visible + if (c + 1 < cols_per_thread) { + __syncwarp(); + } + } + + __syncthreads(); + + // Write results back + for (int c = 0; c < cols_per_thread; c++) { + const int col_idx = threadIdx.y + c * threads_y; + if (col_idx < k) { +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + } + } + } + } +} +#ifdef __clang__ +# pragma clang diagnostic pop +#endif // __clang__ + +// cuBLAS batched TRSM fallback for larger matrices or as robust path +// Solves A * X = B where A is lower triangular +// This function modifies X in-place (X should be initialized with B) +static void solve_tri_f32_cublas( + ggml_backend_cuda_context & ctx, + const float * A, + float * X, // Input: B, Output: solution X (in-place) + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t nb02, + size_t nb03, + size_t nb2, + size_t nb3, + cudaStream_t stream +) { + const int64_t total_batches = ne02 * ne03; + + // Allocate pointer arrays on device + ggml_cuda_pool_alloc A_ptrs(ctx.pool(), total_batches); + ggml_cuda_pool_alloc X_ptrs(ctx.pool(), total_batches); + + // Set up pointer arrays on device (CUDA graph compatible) + { + const int block_size = 256; + const int grid_size = (total_batches + block_size - 
1) / block_size; + setup_trsm_batch_pointers<<>>( + A, X, + A_ptrs.get(), X_ptrs.get(), + ne02, total_batches, + nb02, nb03, nb2, nb3 + ); + CUDA_CHECK(cudaGetLastError()); + } + + // Get cuBLAS handle and set stream + cublasHandle_t handle = ctx.cublas_handle(); + cublasSetStream(handle, stream); + + // Save current math mode and set to default for accuracy + // (TF32 can cause numerical issues with triangular solves) + cublasMath_t prev_math_mode; + cublasGetMathMode(handle, &prev_math_mode); + cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH); + + const float alpha = 1.0f; + + cublasStatus_t status = cublasStrsmBatched( + handle, + CUBLAS_SIDE_RIGHT, // A is on the right: X * A = B + CUBLAS_FILL_MODE_UPPER, // A^T is upper (since A is lower in row-major) + CUBLAS_OP_N, // No additional transpose + CUBLAS_DIAG_NON_UNIT, // Diagonal is not assumed to be 1 + k, // m: rows of X^T (columns of X) + n, // n: columns of X^T (rows of X) = size of A + &alpha, + (const float **)A_ptrs.get(), n, // lda = n (leading dimension) + (float **)X_ptrs.get(), k, // ldb = k (leading dimension of X^T) + total_batches + ); + + // Restore previous math mode + cublasSetMathMode(handle, prev_math_mode); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "cuBLAS batched TRSM failed: %d\n", (int) status); + } +} + +static void solve_tri_f32_cuda(const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t nb02, + size_t nb03, + size_t nb12, + size_t nb13, + size_t nb2, + size_t nb3, + cudaStream_t stream) { + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + dim3 grid(ne02 * ne03); + + // Handle large matrices first (256×256 and 65-128 range) + + // Route sizes 65-256 to the tiled kernel + if (n > 64 || k > 64) { + // Use the tiled kernel which works for any size up to 256 + // and only requires ~48KB shared memory (within standard limits) + dim3 threads_256(WARP_SIZE, 32); // 1024 threads + // Shared memory: 64×64 + 64×65 + 64×64 = 16KB + 16.25KB + 16KB = ~48KB + const size_t smem_size = (64 * 64 + 64 * 65 + 64 * 64) * sizeof(float); + + // Configure extended shared memory for this kernel + static bool smem_configured_tiled = false; + if (!smem_configured_tiled) { + cudaFuncSetAttribute(solve_tri_f32_256x256_tiled, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + smem_configured_tiled = true; + } + + solve_tri_f32_256x256_tiled<<>>( + A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + return; + } + + // Limit threads_y to 32 to ensure we don't exceed 1024 threads per block (32 * 32 = 1024) + const int threads_y = k <= 32 ? 
k : 32; + dim3 threads(WARP_SIZE, threads_y); + + if (n == 64) { + switch (k) { + case 64: + { + // Use optimized kernel for n=64, k=64 case (common in Qwen3 Next DeltaNet) + // Block config: 32x32 = 1024 threads (32 warps) + dim3 threads_64x64(WARP_SIZE, 32); + solve_tri_f32_64x64_opt + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3); + } + break; + case 48: + // k=48 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2) + solve_tri_f32_fast<64, 48, 32> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 40: + // k=40 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2) + solve_tri_f32_fast<64, 40, 32> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 32: + solve_tri_f32_fast<64, 32, 32> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 16: + solve_tri_f32_fast<64, 16, 16> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 14: + solve_tri_f32_fast<64, 14, 14> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 12: + solve_tri_f32_fast<64, 12, 12> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 10: + solve_tri_f32_fast<64, 10, 10> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 8: + solve_tri_f32_fast<64, 8, 8> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 6: + solve_tri_f32_fast<64, 6, 6> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 4: + solve_tri_f32_fast<64, 4, 4> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 2: + solve_tri_f32_fast<64, 2, 2> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 1: + solve_tri_f32_fast<64, 1, 1> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + default: + solve_tri_f32_fast<0, 0, 0> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } + } else { // run general case + solve_tri_f32_fast<0, 0, 0> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } +} + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix) + const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->ne[0] == src0->ne[1]); + GGML_ASSERT(src0->ne[1] == src1->ne[1]); + GGML_ASSERT(src0->ne[2] == src1->ne[2]); + GGML_ASSERT(src0->ne[3] == src1->ne[3]); + + const int n = src0->ne[0]; + const int k = src1->ne[0]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + if (n <= MAX_N_FAST && k <= MAX_K_FAST) { + solve_tri_f32_cuda( + (const float *) src0->data, + (const float *) src1->data, + (float *) dst->data, + n, k, + ne02, ne03, + src0->nb[2] / sizeof(float), + src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), + src1->nb[3] / sizeof(float), + dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), + ctx.stream()); + return; + } + + if (dst->data != src1->data) { + const int64_t total_batches = ne02 * ne03; + const size_t X_size = (size_t) n * (size_t) k * (size_t) total_batches * sizeof(float); 
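For reference, the operation both the shared-memory kernels and the cuBLAS fallback implement is plain per-batch forward substitution. A minimal host-side sketch follows (illustrative only; solve_tri_lower_ref and the flat row-major layout are assumptions chosen to match the kernels above: A is n x n, lower triangular with non-unit diagonal, and B, X are stored as n rows of k contiguous values).

#include <stddef.h>

// Solve A * X = B for X, one batch, by forward substitution.
// A: n x n, row-major, lower triangular, non-unit diagonal.
// B, X: n x k, row-major (k contiguous values per row), non-aliasing.
static void solve_tri_lower_ref(const float * A, const float * B, float * X, int n, int k) {
    for (int col = 0; col < k; ++col) {
        for (int row = 0; row < n; ++row) {
            float sum = 0.0f;
            for (int j = 0; j < row; ++j) {
                sum += A[(size_t) row*n + j] * X[(size_t) j*k + col]; // rows already solved
            }
            // no division-by-zero guard, matching the kernels: a zero diagonal means corrupt data
            X[(size_t) row*k + col] = (B[(size_t) row*k + col] - sum) / A[(size_t) row*n + row];
        }
    }
}

The cuBLAS fallback maps the same problem to column-major storage: row-major lower-triangular A, reinterpreted as column-major, is the upper-triangular A^T, and A * X = B transposes to X^T * A^T = B^T, which is exactly the CUBLAS_SIDE_RIGHT / CUBLAS_FILL_MODE_UPPER call above with m = k and n = n.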
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src1->data, X_size, cudaMemcpyDeviceToDevice, ctx.stream())); + } + + solve_tri_f32_cublas( + ctx, + (const float *) src0->data, + (float *) dst->data, + n, k, + ne02, ne03, + src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float), + ctx.stream()); +} diff --git a/ggml/src/ggml-cuda/solve_tri.cuh b/ggml/src/ggml-cuda/solve_tri.cuh new file mode 100644 index 00000000..63999239 --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu new file mode 100644 index 00000000..a9f72ece --- /dev/null +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -0,0 +1,608 @@ +#include "ssm-conv.cuh" + +#define CUDA_SSM_CONV_BLOCK_SIZE 256 + +template +static __global__ void ssm_conv_single_seq_f32( + const float * src0, + const float * src1, + const float * src2, + float * dst_x, + int nc, + int nr, + int n_t, + int src0_s0, + int src0_s1, + int src1_s1) { + const int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= nr) { + return; + } + + const int t0 = blockIdx.y * split_n_t; + if (t0 >= n_t) { + return; + } + + const float * state_row = src0 + (size_t) row * src0_s1; + const float * c_row = src2 + (size_t) row * nc; + +#pragma unroll + for (int it = 0; it < split_n_t; ++it) { + const int t = t0 + it; + if (t >= n_t) { + break; + } + + float sumf = 0.0f; + for (int j = 0; j < nc; ++j) { + const int idx = t + j; + const float x = idx < nc - 1 + ? state_row[(size_t) idx * src0_s0] + : src1[row + (size_t) (idx - (nc - 1)) * src1_s1]; + + sumf += x * c_row[j]; + } + + dst_x[row + (size_t) t * nr] = sumf; + } +} + +template +static __global__ void ssm_conv_single_seq_f32_nc4( + const float * src0, + const float * src1, + const float * src2, + float * dst_x, + int nr, + int n_t, + int src0_s0, + int src0_s1, + int src1_s1) { + const int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= nr) { + return; + } + + const int t0 = blockIdx.y * split_n_t; + if (t0 >= n_t) { + return; + } + + const float * state_row = src0 + (size_t) row * src0_s1; + const float * c_row = src2 + (size_t) row * 4; + const float c0 = c_row[0]; + const float c1 = c_row[1]; + const float c2 = c_row[2]; + const float c3 = c_row[3]; + +#pragma unroll + for (int it = 0; it < split_n_t; ++it) { + const int t = t0 + it; + if (t >= n_t) { + break; + } + + const int i0 = t; + const int i1 = t + 1; + const int i2 = t + 2; + const int i3 = t + 3; + + const float x0 = i0 < 3 ? state_row[(size_t) i0 * src0_s0] : src1[row + (size_t) (i0 - 3) * src1_s1]; + const float x1 = i1 < 3 ? state_row[(size_t) i1 * src0_s0] : src1[row + (size_t) (i1 - 3) * src1_s1]; + const float x2 = i2 < 3 ? state_row[(size_t) i2 * src0_s0] : src1[row + (size_t) (i2 - 3) * src1_s1]; + const float x3 = i3 < 3 ? 
state_row[(size_t) i3 * src0_s0] : src1[row + (size_t) (i3 - 3) * src1_s1]; + + dst_x[row + (size_t) t * nr] = x0 * c0 + x1 * c1 + x2 * c2 + x3 * c3; + } +} + +static __global__ void ssm_conv_single_seq_final_state_f32( + const float * src0, + const float * src1, + float * dst_state, + int nc, + int nr, + int n_t, + int src0_s0, + int src0_s1, + int src1_s1) { + const int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= nr) { + return; + } + + const float * state_row = src0 + (size_t) row * src0_s1; + float * dst_row = dst_state + (size_t) row * nc; + + for (int j = 0; j < nc; ++j) { + const int idx = n_t - 1 + j; + dst_row[j] = idx < nc - 1 + ? state_row[(size_t) idx * src0_s0] + : src1[row + (size_t) (idx - (nc - 1)) * src1_s1]; + } +} + +static __global__ void ssm_conv_init_states_f32_nc4( + const float * src0, + float * state, + int nr, + int n_kv) { + const int row = blockIdx.x * blockDim.x + threadIdx.x; + const int seq = blockIdx.y; + + if (row >= nr || seq >= n_kv) { + return; + } + + const float * src_row = src0 + (size_t) seq * nr * 3 + (size_t) row * 3; + float * state_row = state + (size_t) seq * nr * 4 + (size_t) row * 4; + + state_row[1] = src_row[0]; + state_row[2] = src_row[1]; + state_row[3] = src_row[2]; +} + +static __global__ void ssm_conv_init_states_f32( + const float * src0, + float * state, + int nc, + int nr, + int n_kv) { + const int row = blockIdx.x * blockDim.x + threadIdx.x; + const int seq = blockIdx.y; + + if (row >= nr || seq >= n_kv) { + return; + } + + const float * src_row = src0 + (size_t) seq * nr * (nc - 1) + (size_t) row * (nc - 1); + float * state_row = state + (size_t) seq * nr * nc + (size_t) row * nc; + + for (int i0 = 0; i0 < nc - 1; ++i0) { + state_row[1 + i0] = src_row[i0]; + } +} + +static __global__ void ssm_conv_validate_unique_seq_map( + const int32_t * src3, + int32_t * seq_ids, + int32_t * seq_seen, + int32_t * fast_path_ok, + int n_t, + int n_kv, + int src3_nb1) { + const int t = blockIdx.x * blockDim.x + threadIdx.x; + if (t >= n_t) { + return; + } + + const int32_t * sq = src3 + (size_t) t * src3_nb1; + const int32_t seq0 = sq[0]; + if (seq0 < 0 || seq0 >= n_kv) { + atomicExch(fast_path_ok, 0); + return; + } + + // Fast path supports one sequence per token (no copy-to-multiple-sequences routing). + if (n_kv > 1) { + const int32_t seq1 = sq[1]; + if (seq1 >= 0 && seq1 < n_kv) { + atomicExch(fast_path_ok, 0); + return; + } + } + + seq_ids[t] = seq0; + if (atomicAdd(seq_seen + seq0, 1) != 0) { + // Sequence is updated by multiple tokens in the same batch => recurrent dependency across t. 
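To spell out the condition this validation kernel enforces, here is an illustrative host-side equivalent (the function name and the flat seq[t*n_kv + s] indexing are assumptions of the sketch; they match the src3 layout asserted in ggml_cuda_op_ssm_conv below, where negative or out-of-range ids mark unused slots). The fast path requires each token to target exactly one valid sequence, no copy of the updated state to a second sequence, and no sequence updated by more than one token in the batch:

#include <stdint.h>
#include <stdlib.h>

static int ssm_conv_unique_seq_fast_path_ok(const int32_t * seq, int n_t, int n_kv) {
    int ok = 1;
    int * seen = (int *) calloc(n_kv, sizeof(int));
    if (!seen) {
        return 0;
    }
    for (int t = 0; t < n_t; ++t) {
        const int32_t s0 = seq[t*n_kv + 0];
        if (s0 < 0 || s0 >= n_kv) { ok = 0; break; }     // invalid primary sequence id
        if (n_kv > 1) {
            const int32_t s1 = seq[t*n_kv + 1];
            if (s1 >= 0 && s1 < n_kv) { ok = 0; break; } // token also writes a second sequence
        }
        if (seen[s0]++ != 0) { ok = 0; break; }          // sequence updated by two tokens => recurrence
    }
    free(seen);
    return ok;
}

Any violation clears fast_path_ok, and the batch falls through to the sequential per-token kernels further down.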
+ atomicExch(fast_path_ok, 0); + } +} + +static __global__ void ssm_conv_multi_seq_unique_f32_kernel( + const float * src0, + const float * src1, + const float * src2, + const int32_t * seq_ids, + const int32_t * fast_path_ok, + float * dst_x, + float * dst_state, + int nc, + int nr, + int n_t, + int src1_nb1) { + if (fast_path_ok != nullptr && fast_path_ok[0] == 0) { + return; + } + + const int row = blockIdx.x * blockDim.x + threadIdx.x; + const int t = blockIdx.y; + + if (row >= nr || t >= n_t) { + return; + } + + const int seq = seq_ids[t]; + const float * src_state_row = src0 + (size_t) seq * nr * (nc - 1) + (size_t) row * (nc - 1); + float * state_row = dst_state + (size_t) seq * nr * nc + (size_t) row * nc; + const float * c_row = src2 + (size_t) row * nc; + + float sumf = 0.0f; + for (int i0 = 0; i0 < nc - 1; ++i0) { + const float v = src_state_row[i0]; + state_row[i0] = v; + sumf += v * c_row[i0]; + } + + const float x = src1[row + (size_t) t * src1_nb1]; + state_row[nc - 1] = x; + sumf += x * c_row[nc - 1]; + dst_x[row + (size_t) t * nr] = sumf; +} + +static __global__ void ssm_conv_multi_seq_unique_f32_kernel_nc4( + const float * src0, + const float * src1, + const float * src2, + const int32_t * seq_ids, + const int32_t * fast_path_ok, + float * dst_x, + float * dst_state, + int nr, + int n_t, + int src1_nb1) { + if (fast_path_ok != nullptr && fast_path_ok[0] == 0) { + return; + } + + const int row = blockIdx.x * blockDim.x + threadIdx.x; + const int t = blockIdx.y; + + if (row >= nr || t >= n_t) { + return; + } + + const int seq = seq_ids[t]; + const float * src_state_row = src0 + (size_t) seq * nr * 3 + (size_t) row * 3; + float * state_row = dst_state + (size_t) seq * nr * 4 + (size_t) row * 4; + const float * c_row = src2 + (size_t) row * 4; + + const float s0 = src_state_row[0]; + const float s1 = src_state_row[1]; + const float s2 = src_state_row[2]; + const float x = src1[row + (size_t) t * src1_nb1]; + + state_row[0] = s0; + state_row[1] = s1; + state_row[2] = s2; + state_row[3] = x; + + dst_x[row + (size_t) t * nr] = s0 * c_row[0] + s1 * c_row[1] + s2 * c_row[2] + x * c_row[3]; +} + +static __global__ void ssm_conv_f32_kernel( + const float * src0, + const float * src1, + const float * src2, + const int32_t * src3, + const int32_t * fast_path_ok, + float * dst_x, + float * dst_state, + int nc, + int nr, + int n_t, + int n_kv, + int src1_nb1, + int src3_nb1) { + if (fast_path_ok != nullptr && fast_path_ok[0] != 0) { + return; + } + + const int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= nr) { + return; + } + + const float * c_row = src2 + (size_t) row * nc; + + for (int t = 0; t < n_t; ++t) { + const int32_t * sq = src3 + (size_t) t * src3_nb1; + const int seq0 = sq[0]; + + if (seq0 < 0 || seq0 >= n_kv) { + continue; + } + + float * state_row = dst_state + (size_t) seq0 * nr * nc + (size_t) row * nc; + const float * src_state_row; + if (t == 0) { + src_state_row = src0 + (size_t) seq0 * nr * (nc - 1) + (size_t) row * (nc - 1); + } else { + src_state_row = state_row + 1; + } + + for (int i0 = 0; i0 < nc - 1; ++i0) { + state_row[i0] = src_state_row[i0]; + } + state_row[nc - 1] = src1[row + (size_t) t * src1_nb1]; + + for (int i3 = 1; i3 < n_kv; ++i3) { + const int seq = sq[i3]; + if (seq < 0 || seq >= n_kv) { + break; + } + + float * state_row_copy = dst_state + (size_t) seq * nr * nc + (size_t) row * nc; + for (int i0 = 0; i0 < nc; ++i0) { + state_row_copy[i0] = state_row[i0]; + } + } + + float sumf = 0.0f; + for (int i0 = 0; i0 < nc; ++i0) { + sumf += 
state_row[i0] * c_row[i0]; + } + dst_x[row + (size_t) t * nr] = sumf; + } +} + +template +static __global__ void ssm_conv_f32_kernel_nc4( + const float * src0, + const float * src1, + const float * src2, + const int32_t * src3, + const int32_t * fast_path_ok, + float * dst_x, + float * dst_state, + int nr, + int n_t, + int n_kv, + int src1_nb1, + int src3_nb1) { + if (fast_path_ok != nullptr && fast_path_ok[0] != 0) { + return; + } + + const int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= nr) { + return; + } + + const float * c_row = src2 + (size_t) row * 4; + const float c0 = c_row[0]; + const float c1 = c_row[1]; + const float c2 = c_row[2]; + const float c3 = c_row[3]; + + for (int t = 0; t < n_t; ++t) { + const int32_t * sq = src3 + (size_t) t * src3_nb1; + const int seq0 = sq[0]; + + if (seq0 < 0 || seq0 >= n_kv) { + continue; + } + + float * state_row = dst_state + (size_t) seq0 * nr * 4 + (size_t) row * 4; + + const float * src_state_row; + if (t == 0) { + src_state_row = src0 + (size_t) seq0 * nr * 3 + (size_t) row * 3; + } else { + src_state_row = state_row + 1; + } + + const float s0 = src_state_row[0]; + const float s1 = src_state_row[1]; + const float s2 = src_state_row[2]; + const float x = src1[row + (size_t) t * src1_nb1]; + + state_row[0] = s0; + state_row[1] = s1; + state_row[2] = s2; + state_row[3] = x; + + if constexpr (has_multi_seq) { + for (int i3 = 1; i3 < n_kv; ++i3) { + const int seq = sq[i3]; + if (seq < 0 || seq >= n_kv) { + break; + } + + float * state_row_copy = dst_state + (size_t) seq * nr * 4 + (size_t) row * 4; + state_row_copy[0] = s0; + state_row_copy[1] = s1; + state_row_copy[2] = s2; + state_row_copy[3] = x; + } + } + + dst_x[row + (size_t) t * nr] = s0 * c0 + s1 * c1 + s2 * c2 + x * c3; + } +} + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // conv_state: [d_conv - 1, d_inner, n_kv] + const ggml_tensor * src1 = dst->src[1]; // x: [d_inner, n_tokens] + const ggml_tensor * src2 = dst->src[2]; // conv1d.weight: [d_conv, d_inner] + const ggml_tensor * src3 = dst->src[3]; // state_seq: [n_kv, n_tokens] + + const int nc = src2->ne[0]; + const int nr = src0->ne[1]; + const int n_t = src1->ne[1]; + const int n_kv = src0->ne[2]; + + GGML_ASSERT((int64_t) nr * n_t + (int64_t) nc * nr * n_kv == ggml_nelements(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src3->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); + GGML_ASSERT(src2->nb[1] == src2->ne[0] * sizeof(float)); + GGML_ASSERT(src2->nb[2] == src2->ne[1] * src2->ne[0] * sizeof(float)); + + GGML_ASSERT(src2->ne[0] == src0->ne[0] + 1); + GGML_ASSERT(src2->ne[1] == src0->ne[1]); + GGML_ASSERT(src1->ne[0] == src0->ne[1]); + GGML_ASSERT(src3->ne[0] == src0->ne[2]); + GGML_ASSERT(src3->ne[1] == src1->ne[1]); + + float * dst_data = (float *) dst->data; + float * dst_x = dst_data; + float * dst_state = dst_data + (size_t) nr * n_t; + + const dim3 block_dims(CUDA_SSM_CONV_BLOCK_SIZE, 1, 1); + const dim3 row_grid((nr + CUDA_SSM_CONV_BLOCK_SIZE - 1) / CUDA_SSM_CONV_BLOCK_SIZE, 1, 1); + ggml_cuda_pool_alloc fast_path_ok_d(ctx.pool()); + const int32_t * 
multi_seq_fast_path_ok = nullptr; + + // Fast path for single-sequence recurrent updates (Qwen3Next prompt/decode path). + // In this case, outputs are independent given the initial conv state, so we parallelize over token blocks. + if (n_kv == 1 && src3->ne[0] == 1) { + GGML_ASSERT(n_t > 0); + + const int src0_s0 = src0->nb[0] / sizeof(float); + const int src0_s1 = src0->nb[1] / sizeof(float); + const int src1_s1 = src1->nb[1] / sizeof(float); + + constexpr int split_n_t = 32; + const dim3 token_grid(row_grid.x, (n_t + split_n_t - 1) / split_n_t, 1); + + if (nc == 4) { + ssm_conv_single_seq_f32_nc4<<>>( + (const float *) src0->data, + (const float *) src1->data, + (const float *) src2->data, + dst_x, + nr, n_t, + src0_s0, src0_s1, src1_s1); + } else { + ssm_conv_single_seq_f32<<>>( + (const float *) src0->data, + (const float *) src1->data, + (const float *) src2->data, + dst_x, + nc, nr, n_t, + src0_s0, src0_s1, src1_s1); + } + + ssm_conv_single_seq_final_state_f32<<>>( + (const float *) src0->data, + (const float *) src1->data, + dst_state, + nc, nr, n_t, + src0_s0, src0_s1, src1_s1); + return; + } + + if (n_kv > 1) { + const dim3 init_grid(row_grid.x, n_kv, 1); + if (nc == 4) { + ssm_conv_init_states_f32_nc4<<>>( + (const float *) src0->data, + dst_state, + nr, n_kv); + } else { + ssm_conv_init_states_f32<<>>( + (const float *) src0->data, + dst_state, + nc, nr, n_kv); + } + + // Fast path for multi-sequence decode-like batches: + // one token per unique sequence, no copy-to-multiple-sequences routing. + ggml_cuda_pool_alloc seq_ids(ctx.pool(), n_t); + ggml_cuda_pool_alloc seq_seen(ctx.pool(), n_kv); + int32_t fast_path_ok = 1; + fast_path_ok_d.alloc(1); + + CUDA_CHECK(cudaMemsetAsync(seq_seen.get(), 0, n_kv * sizeof(int32_t), ctx.stream())); + CUDA_CHECK(cudaMemcpyAsync(fast_path_ok_d.get(), &fast_path_ok, sizeof(int32_t), cudaMemcpyHostToDevice, ctx.stream())); + + constexpr int seq_map_block_size = 256; + const dim3 seq_map_grid((n_t + seq_map_block_size - 1) / seq_map_block_size, 1, 1); + ssm_conv_validate_unique_seq_map<<>>( + (const int32_t *) src3->data, + seq_ids.get(), + seq_seen.get(), + fast_path_ok_d.get(), + n_t, + n_kv, + src3->nb[1] / sizeof(int32_t)); + CUDA_CHECK(cudaGetLastError()); + multi_seq_fast_path_ok = fast_path_ok_d.get(); + + const dim3 token_grid(row_grid.x, n_t, 1); + if (nc == 4) { + ssm_conv_multi_seq_unique_f32_kernel_nc4<<>>( + (const float *) src0->data, + (const float *) src1->data, + (const float *) src2->data, + seq_ids.get(), + multi_seq_fast_path_ok, + dst_x, + dst_state, + nr, n_t, + src1->nb[1] / sizeof(float)); + } else { + ssm_conv_multi_seq_unique_f32_kernel<<>>( + (const float *) src0->data, + (const float *) src1->data, + (const float *) src2->data, + seq_ids.get(), + multi_seq_fast_path_ok, + dst_x, + dst_state, + nc, nr, n_t, + src1->nb[1] / sizeof(float)); + } + } + + if (nc == 4) { + if (n_kv > 1) { + ssm_conv_f32_kernel_nc4<<>>( + (const float *) src0->data, + (const float *) src1->data, + (const float *) src2->data, + (const int32_t *) src3->data, + multi_seq_fast_path_ok, + dst_x, + dst_state, + nr, n_t, n_kv, + src1->nb[1] / sizeof(float), + src3->nb[1] / sizeof(int32_t)); + } else { + ssm_conv_f32_kernel_nc4<<>>( + (const float *) src0->data, + (const float *) src1->data, + (const float *) src2->data, + (const int32_t *) src3->data, + nullptr, + dst_x, + dst_state, + nr, n_t, n_kv, + src1->nb[1] / sizeof(float), + src3->nb[1] / sizeof(int32_t)); + } + } else { + ssm_conv_f32_kernel<<>>( + (const float *) src0->data, + (const float 
*) src1->data, + (const float *) src2->data, + (const int32_t *) src3->data, + multi_seq_fast_path_ok, + dst_x, + dst_state, + nc, nr, n_t, n_kv, + src1->nb[1] / sizeof(float), + src3->nb[1] / sizeof(int32_t)); + } +} diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh new file mode 100644 index 00000000..8e6c1f00 --- /dev/null +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index c86ad4ed..c4929e81 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -16,6 +16,25 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc } } +static __global__ void k_sum_rows_nc_f32(const char * x, char * y, const int ncols, + size_t nb00, size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3) { + + const char * src = x + nb03*blockIdx.z + nb02*blockIdx.y + nb01*blockIdx.x; + float * dst = (float *)(y + nb3*blockIdx.z + nb2*blockIdx.y + nb1*blockIdx.x); + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += *(const float *)(src + i*nb00); + } + + sum = warp_reduce_sum(sum); + + if (col == 0) { + dst[0] = sum; + } +} + static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, float s, float b) { const int row = blockIdx.x; const int col = threadIdx.x; @@ -43,6 +62,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int const dim3 block_nums(nrows, 1, 1); k_sum_rows_f32<<>>(x, dst, ncols); } +static void sum_rows_f32_cuda_nc(const char * x, char * dst, int ne0, int ne1, int ne2, int ne3, + size_t nb00, size_t nb01, size_t nb02, size_t nb03, size_t nb1, size_t nb2, size_t nb3, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(ne1, ne2, ne3); + k_sum_rows_nc_f32<<>>(x, dst, ne0, nb00, nb01, nb02, nb03, nb1, nb2, nb3); +} static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, float s, float b, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); @@ -52,19 +77,30 @@ static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); - + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream); + +} + +void ggml_cuda_op_sum_rows_nc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src = dst->src[0]->src[0]; + GGML_ASSERT(src->op == GGML_OP_TRANSPOSE); + GGML_ASSERT(dst->type == GGML_TYPE_F32 && src->type == GGML_TYPE_F32); + + cudaStream_t stream = ctx.stream(); + + sum_rows_f32_cuda_nc((const char *)src->data, (char *)dst->data, src->ne[0], src->ne[1], src->ne[2], src->ne[3], + src->nb[0], src->nb[1], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3], stream); } void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, 
ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh index b6c0dc26..4c31439b 100644 --- a/ggml/src/ggml-cuda/sumrows.cuh +++ b/ggml/src/ggml-cuda/sumrows.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sum_rows_nc(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu new file mode 100644 index 00000000..2c0ba0ca --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cu @@ -0,0 +1,120 @@ +#include "tri.cuh" +#include "convert.cuh" + +template +static __global__ void tri_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + const int64_t split_point = i1 + add_to_split; + + (void) nb00; + (void) nb0; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03; + T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3; + + if constexpr (prefix_keep) { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = ggml_cuda_cast(0.0f); + } + } else { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = ggml_cuda_cast(0.0f); + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } + } +} + +template +static void tri_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const ggml_tri_type ttype, + cudaStream_t stream) { + dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + const size_t type_size = sizeof(T); + + const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 
1 : 0; + const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG); + + if (prefix_keep) { + if (add_to_split == 0) { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } else { + if (add_to_split == 0) { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } +} + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tri_type ttype = static_cast(((const int32_t *) dst->op_params)[0]); + + GGML_ASSERT(src0->type == dst->type); + + switch (src0->type) { + case GGML_TYPE_F32: + tri_cuda( + (const float *) src0->data, (float *) dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, ctx.stream() + ); + break; + case GGML_TYPE_F16: + tri_cuda( + (const half *) src0->data, (half *) dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, ctx.stream() + ); + break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh new file mode 100644 index 00000000..a4cc6675 --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TRI_BLOCK_SIZE 256 + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index b73a46db..f6ca4a4f 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -735,6 +735,10 @@ static __device__ __forceinline__ float op_exp(float x) { return expf(x); } +static __device__ __forceinline__ float op_softplus(float x) { + return (x > 20.0f) ? 
x : logf(1.0f + expf(x)); +} + static __device__ __forceinline__ float op_sin(float x) { return sinf(x); } @@ -831,6 +835,10 @@ void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + // === gated ops template @@ -942,4 +950,3 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } - diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index ddd776f8..75497aea 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -53,6 +53,8 @@ void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b899678c..e4b345bd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3288,6 +3288,9 @@ inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } +inline static float ggml_compute_softplus_f32(const float x) { return x > 20.0f ? 
x : logf(1.0f + expf(x)); } +inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } +inline static void ggml_vec_softplus_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = ggml_compute_softplus_f32(x[i]); } // TODO: optimize performance inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } @@ -4269,6 +4272,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "UNARY", + "CUMSUM", + "L2_NORM", + "TRI", + "FILL", + "SOLVE_TRI", + "MAP_UNARY", "MAP_BINARY", @@ -4290,7 +4299,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FUSED_NORM", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4381,6 +4390,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "unary(x)", + "cumsum(x)", + "l2_norm(x)", + "tri(x)", + "fill(x)", + "solve_tri(x)", + "f(x)", "f(x,y)", @@ -4402,7 +4417,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "norm(x,y)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4424,9 +4439,11 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "SWIGLU", "SWIGLU_OAI", "GELU_ERF", + "EXP", + "SOFTPLUS", }; -static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); +static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -6712,6 +6729,28 @@ struct ggml_tensor * ggml_sum_rows( return result; } +// ggml_cumsum + +struct ggml_tensor * ggml_cumsum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + GGML_ASSERT(a->type == GGML_TYPE_F32); + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_CUMSUM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + // ggml_mean struct ggml_tensor * ggml_mean( @@ -6780,6 +6819,33 @@ struct ggml_tensor * ggml_repeat( return result; } +struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + const bool can_repeat = ggml_is_empty(a) || ( + (ne0 % a->ne[0] == 0) && + (ne1 % a->ne[1] == 0) && + (ne2 % a->ne[2] == 0) && + (ne3 % a->ne[3] == 0) + ); + GGML_ASSERT(can_repeat); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + + result->op = GGML_OP_REPEAT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + // ggml_repeat_back struct ggml_tensor * ggml_repeat_back( @@ -6978,6 +7044,20 @@ struct ggml_tensor * ggml_sigmoid_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID); } +// ggml_softplus + +struct ggml_tensor * ggml_softplus( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS); +} + +struct ggml_tensor * ggml_softplus_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS); +} + // ggml_gelu struct ggml_tensor * ggml_gelu( @@ -7121,6 +7201,20 @@ struct ggml_tensor * ggml_hardsigmoid( return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID); } +// ggml exp + +struct ggml_tensor * ggml_exp( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_EXP); +} + +struct ggml_tensor * ggml_exp_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); +} + // =============== from mainline begin ===================================== // ggml_glu @@ -7542,6 +7636,45 @@ struct ggml_tensor * ggml_group_norm_inplace( return ggml_group_norm_impl(ctx, a, n_groups, eps, true); } +// ggml_l2_norm + +static struct ggml_tensor * ggml_l2_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_f32(result, 0, eps); + + result->op = GGML_OP_L2_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, true); +} + // ggml_mul_mat struct ggml_tensor * ggml_mul_mat( @@ -9709,6 +9842,74 @@ struct ggml_tensor * ggml_timestep_embedding( return result; } +// ggml_tri + +struct ggml_tensor * ggml_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_tri_type type) { + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->ne[0] == a->ne[1]); + + bool is_node = false; + if (a->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) type); + + result->op = GGML_OP_TRI; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_fill + +static struct ggml_tensor * ggml_fill_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float c, + bool inplace) { + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(a)); + + bool is_node = false; + if (!inplace && a->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_f32(result, 0, c); + + result->op = GGML_OP_FILL; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_fill( + struct ggml_context * ctx, + struct ggml_tensor * a, + float c) { + return ggml_fill_impl(ctx, a, c, false); +} + +struct ggml_tensor * ggml_fill_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float c) { + return ggml_fill_impl(ctx, a, c, true); +} + // ggml_argsort struct ggml_tensor * ggml_argsort( @@ -9991,6 +10192,44 @@ struct ggml_tensor * ggml_flash_attn_back( return result; } +// ggml_solve_tri + +struct ggml_tensor * ggml_solve_tri( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool left, + bool lower, + bool uni) { + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(b->type == GGML_TYPE_F32); + + GGML_ASSERT(a->ne[0] == a->ne[1]); + GGML_ASSERT(a->ne[1] == b->ne[1]); + GGML_ASSERT(a->ne[2] == b->ne[2]); + GGML_ASSERT(a->ne[3] == b->ne[3]); + + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + + GGML_ASSERT(lower && left && !uni); // TODO: support other variants + + bool is_node = false; + if (a->grad || b->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]); + + result->op = GGML_OP_SOLVE_TRI; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_ssm_conv struct ggml_tensor * ggml_ssm_conv( @@ -11702,46 +11941,40 @@ static void ggml_compute_forward_dup_f32( } } -// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. static void ggml_compute_forward_dup_bytes( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); GGML_ASSERT(src0->type == dst->type); - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { + GGML_TENSOR_UNARY_OP_LOCALS; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->nb[0] == ggml_type_size(src0->type)) { ggml_compute_forward_dup_same_cont(params, dst); return; } - GGML_TENSOR_UNARY_OP_LOCALS; - const size_t type_size = ggml_type_size(src0->type); + const int ith = params->ith; // thread index const int nth = params->nth; // number of threads - // parallelize by rows const int nr = ne01; - const int n_packed = ggml_packed_rows(dst->type); - GGML_ASSERT(nr%n_packed == 0); - const int nrp = nr/n_packed; // number of rows per thread - const int drp = (nrp + nth - 1) / nth; - const int dr = drp*n_packed; + const int dr = (nr + nth - 1) / nth; // row range for this thread const int ir0 = dr * ith; - if (ir0 >= nr) return; const int ir1 = MIN(ir0 + dr, nr); if (src0->type == dst->type && - ne00 == ne0 && + ggml_are_same_shape(src0, dst) && nb00 == type_size && nb0 == type_size) { + //if (ith == 0) printf("%s(1): %ld x %ld x %ld x %ld\n", __func__, ne00, ne01, ne02, ne03); // copy by rows - const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size; + const size_t rs = ggml_row_size(src0->type, ne00); for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ir0; i01 < ir1; i01++) { @@ -11758,9 +11991,10 @@ static void ggml_compute_forward_dup_bytes( if (ggml_is_contiguous(dst)) { size_t id = 0; char * dst_ptr = (char *) dst->data; - const size_t rs = ggml_row_size(src0->type, ne00); 
//ne00 * type_size; + const size_t rs = ne00 * type_size; if (nb00 == type_size) { + //if (ith == 0) printf("%s(2): %ld x %ld x %ld x %ld\n", __func__, ne00, ne01, ne02, ne03); // src0 is contigous on first dimension, copy by rows for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -11774,7 +12008,6 @@ static void ggml_compute_forward_dup_bytes( } } } else { - //printf("%s: this is not optimal - fix me\n", __func__); for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -11796,17 +12029,20 @@ static void ggml_compute_forward_dup_bytes( } // dst counters - - int64_t i10 = 0; + int64_t k10 = 0; int64_t i11 = 0; int64_t i12 = 0; int64_t i13 = 0; + // number of blocks in a row + const int64_t nk00 = ne00 / ggml_blck_size(src0->type); + const int64_t nk0 = ne0 / ggml_blck_size(dst->type); + for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; + k10 += nk00 * ir0; + while (k10 >= nk0) { + k10 -= nk0; if (++i11 == ne1) { i11 = 0; if (++i12 == ne2) { @@ -11818,14 +12054,14 @@ static void ggml_compute_forward_dup_bytes( } } for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + for (int64_t k00 = 0; k00 < nk00; k00++) { + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); memcpy(dst_ptr, src0_ptr, type_size); - if (++i10 == ne0) { - i10 = 0; + if (++k10 == nk0) { + k10 = 0; if (++i11 == ne1) { i11 = 0; if (++i12 == ne2) { @@ -11838,9 +12074,9 @@ static void ggml_compute_forward_dup_bytes( } } } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; + k10 += nk00 * (ne01 - ir1); + while (k10 >= nk0) { + k10 -= nk0; if (++i11 == ne1) { i11 = 0; if (++i12 == ne2) { @@ -13262,13 +13498,15 @@ static void ggml_compute_forward_sub_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->ith != 0) { - return; - } - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - const int nr = ggml_nrows(src0); + int ith = params->ith; + int nth = params->nth; + + int nr = ggml_nrows(src0); + int nr_per_thread = (nr + nth - 1)/nth; + int first = ith*nr_per_thread; + int last = MIN(first + nr_per_thread, nr); GGML_TENSOR_BINARY_OP_LOCALS @@ -13276,7 +13514,7 @@ static void ggml_compute_forward_sub_f32( GGML_ASSERT(nb00 == sizeof(float)); if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { + for (int ir = first; ir < last; ++ir) { // src0, src1 and dst are same shape => same indices const int i3 = ir/(ne2*ne1); const int i2 = (ir - i3*ne2*ne1)/ne1; @@ -13307,10 +13545,9 @@ static void ggml_compute_forward_sub_f32( float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + char * src1_ptr = (char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11; for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); - - dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; + dst_ptr[i0] = src0_ptr[i0] - *(float *)(src1_ptr + i0*nb10); } } } @@ -13365,6 +13602,9 @@ static void 
ggml_compute_forward_mul_f32( return; } + //if (ith == 0) printf("%s(%s): %ld x %ld x %ld x %ld * %ld x %ld x %ld x %ld -> %ld x %ld x %ld x %ld\n", __func__, dst->name, + // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[1], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + const int64_t nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -13388,6 +13628,13 @@ static void ggml_compute_forward_mul_f32( float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + if (ne10 == 1) { + if (dst_ptr != src0_ptr) { + memcpy(dst_ptr, src0_ptr, ne00*sizeof(float)); + } + ggml_vec_scale_f32(ne00, dst_ptr, src1_ptr[0]); + } else { + for (int64_t r = 0 ; r < nr0; ++r) { #ifdef GGML_USE_ACCELERATE UNUSED(ggml_vec_mul_f32); @@ -13397,6 +13644,7 @@ static void ggml_compute_forward_mul_f32( ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); #endif } + } } } else { // src1 is not contiguous @@ -13800,6 +14048,62 @@ static void ggml_compute_forward_sum( } } +// ggml_compute_forward_cumsum + +static void ggml_compute_forward_cumsum_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + const int ith = params->ith; + const int nth = params->nth; + const int64_t nr = ne01*ne02*ne03; + + for (int64_t ir = ith; ir < nr; ir += nth) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01; + + const float * src_row = (const float *) ((const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dst_row = ( float *) (( char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + float acc = 0.0f; + for (int64_t i00 = 0; i00 < ne00; ++i00) { + acc += src_row[i00]; + dst_row[i00] = acc; + } + } +} + +static void ggml_compute_forward_cumsum( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cumsum_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_sum_rows static void ggml_compute_forward_sum_rows_f32( @@ -13808,10 +14112,6 @@ static void ggml_compute_forward_sum_rows_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->ith != 0) { - return; - } - GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(dst->nb[0] == sizeof(float)); @@ -13822,22 +14122,32 @@ static void ggml_compute_forward_sum_rows_f32( GGML_ASSERT(ne2 == ne02); GGML_ASSERT(ne3 == ne03); - for (int64_t i3 = 0; i3 < ne03; i3++) { - for (int64_t i2 = 0; i2 < ne02; i2++) { - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); - float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); - float row_sum = 0; - ggml_vec_sum_f32(ne00, &row_sum, src_row); - if (!isfinite(row_sum)) { - fprintf(stderr, "Oops(%s, %s): found %g for i1 = %d, i2 = %d, i3 = %d. 
ne00 = %d\n", __func__, dst->name, - (double)row_sum, (int)i1, (int)i2, (int)i3, (int)ne00); - exit(1); - } - dst_row[0] = row_sum; - } + int ith = params->ith; + int nth = params->nth; + + //if (params->ith == 0) printf("%s(%s): %ld x %ld x %ld x %ld\n", __func__, dst->name, ne00, ne1, ne2, ne3); + + int nrows = ggml_nrows(src0); + int nrows_per_thread = (nrows + nth - 1)/nth; + int first_row = nrows_per_thread*ith; + int last_row = MIN(first_row + nrows_per_thread, nrows); + + for (int ir = first_row; ir < last_row; ++ir) { + int i3 = ir / (ne01*ne02); + int i2 = (ir - i3*ne01*ne02)/ne01; + int i1 = ir - i3*ne01*ne0 - i2*ne01; + const float * src_row = (const float *)((const char *)src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = ( float *)(( char *)dst->data + i1*nb1 + i2*nb2 + i3*nb3); + float row_sum = 0; + ggml_vec_sum_f32(ne00, &row_sum, src_row); + if (!isfinite(row_sum)) { + fprintf(stderr, "Oops(%s, %s): found %g for i1 = %d, i2 = %d, i3 = %d. ne00 = %d\n", __func__, dst->name, + (double)row_sum, (int)i1, (int)i2, (int)i3, (int)ne00); + GGML_ABORT("Fatal error"); } + dst_row[0] = row_sum; } + } static void ggml_compute_forward_sum_rows( @@ -13858,6 +14168,36 @@ static void ggml_compute_forward_sum_rows( } } +static void ggml_compute_forward_sum_rows_f32_nc( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + struct ggml_tensor * src = dst->src[0]->src[0]; + GGML_ASSERT(src->op == GGML_OP_TRANSPOSE); + GGML_ASSERT(dst->type == GGML_TYPE_F32 && src->type == GGML_TYPE_F32); + + int ith = params->ith; + int nth = params->nth; + + int nrows = ggml_nrows(src); + int nrows_per_thread = (nrows + nth - 1)/nth; + int first_row = nrows_per_thread*ith; + int last_row = MIN(first_row + nrows_per_thread, nrows); + + for (int ir = first_row; ir < last_row; ++ir) { + int i3 = ir / (src->ne[1]*src->ne[2]); + int i2 = (ir - i3*src->ne[1]*src->ne[2])/src->ne[1]; + int i1 = ir - i3*src->ne[1]*src->ne[2] - i2*src->ne[1]; + const float * src_row = (const float *)((const char *)src->data + i1*src->nb[1] + i2*src->nb[2] + i3*src->nb[3]); + float * dst_row = ( float *)(( char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]); + float row_sum = 0; + for (int i0 = 0; i0 < (int)src->ne[0]; ++i0) { + row_sum += *(const float *)((const char *)src_row + i0*src->nb[0]); + } + dst_row[0] = row_sum; + } +} + // ggml_compute_forward_mean static void ggml_compute_forward_mean_f32( @@ -13987,6 +14327,58 @@ static void ggml_compute_forward_repeat_f32( return; } + int n_repeated_dims = 0; + int repeated_dim = -1; + for (int idim = 0; idim < GGML_MAX_DIMS; ++idim) { + if (src0->ne[idim] != dst->ne[idim]) { + ++n_repeated_dims; + repeated_dim = idim; + } + } + + int ith = params->ith; + int nth = params->nth; + + if (n_repeated_dims == 1) { + GGML_ASSERT(repeated_dim >= 0 && repeated_dim < GGML_MAX_DIMS); + int nrows = 1; + for (int idim = 0; idim < GGML_MAX_DIMS; ++idim) { + if (src0->ne[idim] == dst->ne[idim]) nrows *= dst->ne[idim]; + } + int nrows_per_thread = (nrows + nth - 1)/nth; + int first_row = ith*nrows_per_thread; + int last_row = MIN(first_row + nrows_per_thread, nrows); + for (int ir = first_row; ir < last_row; ++ir) { + int ii = ir; + int denom = nrows; + size_t offset_src = 0; + size_t offset_dst = 0; + for (int idim = GGML_MAX_DIMS-1; idim >= 0; --idim) { + if (idim == repeated_dim) continue; + denom /= dst->ne[idim]; + int idx = ii / denom; + ii -= idx * denom; + offset_src += idx*src0->nb[idim]; + offset_dst += idx*dst->nb[idim]; + } + char * dst_ptr 
= (char *)dst->data + offset_dst; + const char * src_ptr = (const char *)src0->data + offset_src; + if (src0->ne[repeated_dim] == 1) { + float value = *(const float *)src_ptr; + for (int i = 0; i < (int)dst->ne[repeated_dim]; ++i) { + *(float *)(dst_ptr + i*dst->nb[repeated_dim]) = value; + } + } else { + for (int i = 0; i < (int)dst->ne[repeated_dim]; ++i) { + int i0 = i % src0->ne[repeated_dim]; + float value = *(const float *)(src_ptr + i0*src0->nb[repeated_dim]); + *(float *)(dst_ptr + i*dst->nb[repeated_dim]) = value; + } + } + } + return; + } + if (params->ith != 0) { return; } @@ -14115,6 +14507,9 @@ static void ggml_compute_forward_repeat( const struct ggml_tensor * src0 = dst->src[0]; + //if (params->ith == 0) printf("%s(%s,%s,%s): %ld x %ld x %ld x %ld -> %ld x %ld x %ld x %ld\n", __func__, dst->name, ggml_type_name(src0->type), ggml_type_name(dst->type), + // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + switch (src0->type) { case GGML_TYPE_F16: case GGML_TYPE_BF16: @@ -14238,8 +14633,11 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(dim >= 0 && dim < 4); + //if (ith == 0) printf("%s(%s, dim = %d): %ld x %ld x %ld x %ld + %ld x %ld x %ld x %ld -> %ld x %ld x %ld x %ld\n", __func__, dst->name, dim, + // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) && - (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) { + (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1) || (dim == 0 && dst->ne[1]*dst->ne[2]*dst->ne[3] == 1))) { // simply copy the data const int64_t size_src_0 = ggml_nbytes(src0); const int64_t size_src_1 = ggml_nbytes(src1); @@ -14248,8 +14646,13 @@ static void ggml_compute_forward_concat_f32( for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) { const int64_t start = i_block*block_size; if (start < size_src_0) { - int64_t copy_size = MIN(block_size, size_src_0 - start); - memcpy((char *)dst->data + start, (char *)src0->data + start, copy_size); + if (start + block_size <= size_src_0) { + memcpy((char *)dst->data + start, (char *)src0->data + start, block_size); + } else { + memcpy((char *)dst->data + start, (char *)src0->data + start, size_src_0 - start); + size_t copy_size = MIN(size_src_1, block_size - (size_src_0 - start)); + memcpy((char *)dst->data + size_src_0, (char *)src1->data, copy_size); + } } else { int64_t copy_size = MIN(block_size, size_src_0 + size_src_1 - start); memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, copy_size); @@ -14437,18 +14840,21 @@ static void ggml_compute_forward_neg_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->ith != 0) { - return; - } + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); + int ith = params->ith; + int nth = params->nth; const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - for (int i = 0; i < n; i++) { + int npt = (n + nth - 1)/nth; + int first = ith*npt; + int last = MIN(first + npt, n); + + for (int i = first; i < last; i++) { ggml_vec_neg_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), (float *) ((char *) src0->data + 
i*(src0->nb[1]))); @@ -14616,9 +15022,9 @@ static void ggml_compute_forward_relu_f32( return; } - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); const int n = ggml_nrows(src0); const int nc = src0->ne[0]; @@ -14656,10 +15062,6 @@ static void ggml_compute_forward_sigmoid_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->ith != 0) { - return; - } - assert(ggml_is_contiguous_1(src0)); assert(ggml_is_contiguous_1(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -14667,7 +15069,13 @@ static void ggml_compute_forward_sigmoid_f32( const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - for (int i = 0; i < n; i++) { + int ith = params->ith; + int nth = params->nth; + int npt = (n + nth - 1)/nth; + int first = ith*npt; + int last = MIN(first + npt, n); + + for (int i = first; i < last; i++) { ggml_vec_sigmoid_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), (float *) ((char *) src0->data + i*(src0->nb[1]))); @@ -14692,6 +15100,100 @@ static void ggml_compute_forward_sigmoid( } } +// ggml_compute_forward_exp + +static void ggml_compute_forward_exp_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + const int dr = (nr + nth - 1)/nth; + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_exp_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_exp( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_exp_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_softplus + +static void ggml_compute_forward_softplus_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + const int dr = (nr + nth - 1)/nth; + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_softplus_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_softplus( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_softplus_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu static void ggml_compute_forward_gelu_f32( @@ -14751,6 +15253,102 @@ static void ggml_compute_forward_gelu( } } +// ggml_compute_forward_fill 
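The neg, sigmoid, exp and softplus kernels above all use the same contiguous rows-per-thread split introduced for sum_rows, while the fill/tri kernels below hand rows out round-robin instead. A minimal standalone sketch of the block split and of the reference math for the two new unary ops, using plain arrays rather than ggml tensors (all names here are illustrative; the exact numerics of ggml_vec_softplus_f32 are not shown in this patch):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Rows-per-thread split used by the newly threaded element-wise kernels:
// thread ith of nth gets a contiguous block of whole rows.
template <typename F>
static void apply_rowwise(const float * src, float * dst,
                          int64_t ncols, int64_t nrows, int ith, int nth, F f) {
    const int64_t rows_per_thread = (nrows + nth - 1)/nth;
    const int64_t first = rows_per_thread*ith;
    const int64_t last  = std::min(first + rows_per_thread, nrows);
    for (int64_t r = first; r < last; ++r) {
        for (int64_t c = 0; c < ncols; ++c) {
            dst[r*ncols + c] = f(src[r*ncols + c]);
        }
    }
}

// Reference math of the two unary ops wired in above.
static float sigmoid_ref(float x)  { return 1.0f/(1.0f + std::exp(-x)); }
static float softplus_ref(float x) { return x > 20.0f ? x : std::log1p(std::exp(x)); } // log(1 + e^x), overflow-safe form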
+ +static void ggml_compute_forward_fill_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + const float c = ggml_get_op_params_f32(dst, 0); + + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + const int64_t nr = ne1*ne2*ne3; + + for (int64_t ir = ith; ir < nr; ir += nth) { + const int64_t i03 = ir/(ne2*ne1); + const int64_t i02 = (ir - i03*ne2*ne1)/ne1; + const int64_t i01 = ir - i03*ne2*ne1 - i02*ne1; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + ggml_vec_set_f32(ne0, dst_ptr, c); + } +} + +static void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + ggml_compute_forward_fill_f32(params, dst); +} + +// ggml_compute_forward_tri + +static inline bool ggml_tri_lower_pred(int i, int r) { + return i < r; +} + +static inline bool ggml_tri_lower_diag_pred(int i, int r) { + return i <= r; +} + +static inline bool ggml_tri_upper_pred(int i, int r) { + return i > r; +} + +static inline bool ggml_tri_upper_diag_pred(int i, int r) { + return i >= r; +} + +static void ggml_compute_forward_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + + const enum ggml_tri_type ttype = (enum ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + const int64_t nr = ne01*ne02*ne03; + + bool (*bipred)(int, int); + + switch (ttype) { + case GGML_TRI_TYPE_LOWER: bipred = ggml_tri_lower_pred; break; + case GGML_TRI_TYPE_LOWER_DIAG: bipred = ggml_tri_lower_diag_pred; break; + case GGML_TRI_TYPE_UPPER: bipred = ggml_tri_upper_pred; break; + case GGML_TRI_TYPE_UPPER_DIAG: bipred = ggml_tri_upper_diag_pred; break; + default: GGML_ABORT("invalid tri type"); + } + + for (int64_t ir = ith; ir < nr; ir += nth) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = ir - i03*ne02*ne01 - i02*ne01; + + const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + + for (int i0 = 0; i0 < ne0; ++i0) { + dst_ptr[i0] = bipred(i0, i01) ? 
src_ptr[i0] : 0.0f; + } + } +} + +static void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tri_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu_quick static void ggml_compute_forward_gelu_quick_f32( @@ -16844,6 +17442,65 @@ static void ggml_compute_forward_mul_mat_up_gate( } #endif +// ggml_compute_forward_l2_norm + +static void ggml_compute_forward_l2_norm_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + const float eps = ggml_get_op_params_f32(dst, 0); + + GGML_ASSERT(eps >= 0.0f); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (const float *) ((const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float) (x[i00] * x[i00]); + } + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + + const float scale = 1.0f/fmaxf(sqrtf(sum), eps); + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_l2_norm( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_l2_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_out_prod static void ggml_compute_forward_out_prod_f32( @@ -17164,36 +17821,27 @@ static void ggml_compute_forward_scale_f32( const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); + //if (ith == 0) printf("%s(%s): %ld x %ld x %ld x %ld with %g, %g\n", __func__, dst->name, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (double)s, (double)b); - // rows per thread - const int dr = (nr + nth - 1)/nth; + const int64_t block_size = 1024; + int64_t nelements = ggml_nelements(dst); + int64_t nblocks = (nelements + block_size - 1)/block_size; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const size_t nb01 = src0->nb[1]; - - const size_t nb1 = dst->nb[1]; - - if (b == 0.0f) { - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ib = ith; ib < nblocks; ib += nth) { + const float * src_data = (const float *)src0->data + block_size*ib; + float * dst_data = ( float *)dst->data + block_size*ib; + int n = MIN(block_size, nelements - block_size*ib); + if (b == 0.0f) { if (dst->data != src0->data) { // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + memcpy(dst_data, src_data, n * sizeof(float)); } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); - } - } else { - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_mad1_f32(nc, - (float *) ((char *) dst->data + i1*nb1), - (float *) ((char *) src0->data + i1*nb1), - s, b); + ggml_vec_scale_f32(n, dst_data, s); + } else { + 
ggml_vec_mad1_f32(n, dst_data, src_data, s, b); } } + } static void ggml_compute_forward_scale( @@ -17574,6 +18222,10 @@ static void ggml_compute_forward_cpy( static void ggml_compute_forward_cont( const struct ggml_compute_params * params, struct ggml_tensor * dst) { + //const struct ggml_tensor * src = dst->src[0]; + //if (params->ith == 0) printf("%s(%s): %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu (%s) -> %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu (%s)\n", __func__, dst->name, + // src->ne[0], src->ne[1], src->ne[2], src->ne[3], src->nb[0], src->nb[1], src->nb[2], src->nb[3], ggml_type_name(src->type), + // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], ggml_type_name(dst->type)); ggml_compute_forward_dup(params, dst); } @@ -21578,6 +22230,75 @@ static void ggml_compute_forward_ssm_scan( } } +// ggml_compute_forward_solve_tri + +static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ne00 == ne01); + GGML_ASSERT(ne0 == ne10); + GGML_ASSERT(ne1 == ne11); + + GGML_ASSERT(ne02 == ne12 && ne12 == ne2); + GGML_ASSERT(ne03 == ne13 && ne13 == ne3); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t k = ne10; + const int64_t n = ne11; + const int64_t nr = ne02*ne03*k; + + const int64_t dr = (nr + nth - 1)/nth; + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + const float * A = (const float *) src0->data; + const float * B = (const float *) src1->data; + float * X = ( float *) dst->data; + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*k); + const int64_t i02 = (ir - i03*ne02*k)/k; + const int64_t i01 = ir - i03*ne02*k - i02*k; + + const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float); + const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float); + + float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float); + + for (int64_t i00 = 0; i00 < n; ++i00) { + float sum = 0.0f; + for (int64_t t = 0; t < i00; ++t) { + sum += A_batch[i00*n + t] * X_batch[t*k + i01]; + } + + const float diag = A_batch[i00*n + i00]; + GGML_ASSERT(diag != 0.0f); + + X_batch[i00*k + i01] = (B_batch[i00*k + i01] - sum) / diag; + } + } +} + +static void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + ggml_compute_forward_solve_tri_f32(params, dst); + } else { + GGML_ABORT("fatal error"); + } +} + // ggml_compute_forward_win_part static void ggml_compute_forward_win_part_f32( @@ -21771,6 +22492,14 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_hardsigmoid(params, dst); } break; + case GGML_UNARY_OP_EXP: + { + ggml_compute_forward_exp(params, dst); + } break; + case GGML_UNARY_OP_SOFTPLUS: + { + ggml_compute_forward_softplus(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -22916,6 +23645,8 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml int64_t t1 = ggml_time_us(); #endif + const bool fusion = 
true; + switch (tensor->op) { case GGML_OP_REDUCE: { @@ -22991,7 +23722,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_SUM_ROWS: { - if (i + 1 < cgraph->n_nodes && + if (fusion && i + 1 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_DIV && cgraph->nodes[i+1]->src[1] == tensor && cgraph->nodes[i+1]->src[0] == tensor->src[0]) { @@ -23001,6 +23732,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml ggml_compute_forward_sum_rows(params, tensor); } } break; + case GGML_OP_CUMSUM: + { + ggml_compute_forward_cumsum(params, tensor); + } break; case GGML_OP_MEAN: { ggml_compute_forward_mean(params, tensor); @@ -23049,6 +23784,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml { ggml_compute_forward_group_norm(params, tensor); } break; + case GGML_OP_L2_NORM: + { + ggml_compute_forward_l2_norm(params, tensor); + } break; case GGML_OP_MUL_MAT: { i = ggml_compute_forward_mul_mat(params, tensor, cgraph, i); @@ -23091,6 +23830,15 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_CONT: { + if (i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_SUM_ROWS && + cgraph->nodes[i+2]->op == GGML_OP_TRANSPOSE) { + if (tensor->src[0]->op == GGML_OP_TRANSPOSE) { + ggml_compute_forward_sum_rows_f32_nc(params, cgraph->nodes[i+1]); + i += 2; + break; + } + } ggml_compute_forward_cont(params, tensor); } break; case GGML_OP_RESHAPE: @@ -23136,7 +23884,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_SOFT_MAX: { - if (i + 4 < cgraph->n_nodes && + if (fusion && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ARGSORT && cgraph->nodes[i+3]->op == GGML_OP_VIEW && @@ -23220,6 +23968,14 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml { ggml_compute_forward_timestep_embedding(params, tensor); } break; + case GGML_OP_TRI: + { + ggml_compute_forward_tri(params, tensor); + } break; + case GGML_OP_FILL: + { + ggml_compute_forward_fill(params, tensor); + } break; case GGML_OP_ARGSORT: { if (false && i + 5 < cgraph->n_nodes && @@ -23265,6 +24021,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml { ggml_compute_forward_ssm_scan(params, tensor); } break; + case GGML_OP_SOLVE_TRI: + { + ggml_compute_forward_solve_tri(params, tensor); + } break; case GGML_OP_WIN_PART: { ggml_compute_forward_win_part(params, tensor); @@ -23276,7 +24036,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml case GGML_OP_UNARY: { const enum ggml_unary_op unary_op = ggml_get_unary_op(tensor); - if (unary_op == GGML_UNARY_OP_SIGMOID && i + 5 < cgraph->n_nodes && + if (fusion && unary_op == GGML_UNARY_OP_SIGMOID && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_ARGSORT && @@ -23285,7 +24045,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml iqk_glm45moe_experts(cgraph->nodes[i+5], cgraph->nodes[i+4], params->ith, params->nth); i += 5; } - else if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && + else if (fusion && unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == 
GGML_OP_GROUPED_TOPK && @@ -24292,6 +25052,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_CUMSUM: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } + case GGML_OP_L2_NORM: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_TIMESTEP_EMBEDDING: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -24312,6 +25080,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_TRI: + case GGML_OP_FILL: + case GGML_OP_SOLVE_TRI: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; @@ -24454,6 +25228,24 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_UNARY_OP_EXP: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, tensor, tensor->grad), + zero_table); + } + } break; + case GGML_UNARY_OP_SOFTPLUS: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, tensor->grad, ggml_sigmoid(ctx, src0)), + zero_table); + } + } break; default: GGML_ABORT("fatal error"); } @@ -24933,21 +25725,23 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_ACC: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: + case GGML_OP_FILL: case GGML_OP_MULTI_ADD: case GGML_OP_MUL_MULTI_ADD: case GGML_OP_HADAMARD: + case GGML_OP_REPEAT: + case GGML_OP_SUB: { n_tasks = n_threads; } break; - case GGML_OP_SUB: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_LOG: case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_ARGMAX: - case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: { @@ -24957,20 +25751,22 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_SGN: - case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: { n_tasks = 1; } break; + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_SWIGLU: case GGML_UNARY_OP_SWIGLU_OAI: { @@ -25009,12 +25805,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FUSED_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_MOE_FUSED_UP_GATE: case GGML_OP_FUSED_UP_GATE: case GGML_OP_OUT_PROD: + case GGML_OP_SOLVE_TRI: { n_tasks = n_threads; } break; @@ -25051,13 +25849,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = 1; //TODO } break; - case GGML_OP_SCALE: case GGML_OP_SOFTCAP: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_CAP_MAX: + case GGML_OP_SUM_ROWS: { n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); } break; + case GGML_OP_SCALE: case GGML_OP_IM2COL: case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: diff --git a/include/llama.h b/include/llama.h index 864d50cb..b7998a15 100644 --- a/include/llama.h +++ b/include/llama.h @@ -47,10 +47,10 @@ #define 
LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 8 +#define LLAMA_SESSION_VERSION 9 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ -#define LLAMA_STATE_SEQ_VERSION 2 +#define LLAMA_STATE_SEQ_VERSION 3 #ifdef __cplusplus extern "C" { diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 91ce5af9..f18741ae 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -27,6 +27,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN3, "qwen3" }, { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_PHI2, "phi2" }, @@ -186,6 +187,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -242,4 +244,3 @@ const char * llama_model_arch_name(llm_arch arch) { } return it->second; } - diff --git a/src/llama-arch.h b/src/llama-arch.h index 97915945..21715a57 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -26,6 +26,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, + LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, LLM_ARCH_PHI2, @@ -180,6 +181,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_PRE, @@ -278,8 +280,11 @@ enum llm_tensor { LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_A_NOSCAN, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_SSM_BETA_ALPHA, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 68e5a17a..a18d372b 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -6,6 +6,28 @@ #include "ggml.h" +#include +#include + +static inline uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & kv_self) { + uint32_t n_slots = 0; + + for (const ggml_tensor * t : kv_self.s_l) { + if (t == nullptr) { + continue; + } + + const uint32_t layer_slots = (uint32_t) t->ne[1]; + if (n_slots == 0) { + n_slots = layer_slots; + } else { + GGML_ASSERT(n_slots == layer_slots); + } + } + + return n_slots; +} + llm_build_context::llm_build_context( llama_context & lctx, const llama_batch & batch, @@ -84,6 +106,7 @@ void llm_build_context::init() { lctx.inp_s_copy = nullptr; lctx.inp_s_mask = nullptr; lctx.inp_s_seq = nullptr; + lctx.inp_s_seq_qnext = nullptr; lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -118,6 +141,12 @@ ggml_cgraph * llm_build_context::build_k_shift() { ggml_set_input(lctx.inp_K_shift); for (int il = 0; il < n_layer; ++il) { + if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) { + continue; + } + if (kv_self.k_l[il] == nullptr) { + continue; + } const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); struct ggml_tensor * rope_factors = build_rope_factors(il); @@ -161,21 +190,34 @@ ggml_cgraph * llm_build_context::build_k_shift() { ggml_cgraph * llm_build_context::build_s_copy() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), 
false); - GGML_ASSERT(kv_self.recurrent); + const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self); + const bool has_qnext_state = qnext_state_slots > 0; + GGML_ASSERT(kv_self.recurrent || has_qnext_state); struct ggml_tensor * state_copy = build_inp_s_copy(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); + if (kv_self.recurrent) { + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); + conv_states = ggml_get_rows(ctx0, conv_states, state_copy); + ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - // TODO: name the intermediate tensors with cb() + // TODO: name the intermediate tensors with cb() - ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); + } + + if (kv_self.s_l.size() > (size_t) il && kv_self.s_l[il] != nullptr) { + struct ggml_tensor * qnext_states_all = ggml_reshape_2d(ctx0, kv_self.s_l[il], hparams.n_embd_v_s(), kv_self.s_l[il]->ne[1]); + GGML_ASSERT((uint32_t) qnext_states_all->ne[1] == qnext_state_slots); + struct ggml_tensor * qnext_state_copy = ggml_view_1d(ctx0, state_copy, qnext_state_slots, 0); + struct ggml_tensor * qnext_states = ggml_get_rows(ctx0, qnext_states_all, qnext_state_copy); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, qnext_states, kv_self.s_l[il])); + } } return gf; @@ -198,6 +240,12 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector & ids) } for (int il = 0; il < n_layer; ++il) { + if (model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il)) { + continue; + } + if (kv_self.k_l[il] == nullptr) { + continue; + } const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); @@ -214,7 +262,7 @@ ggml_cgraph * llm_build_context::build_defrag(const std::vector & ids) ggml_tensor * view_v_src = nullptr; ggml_tensor * view_v_dst = nullptr; - if (kv_self.v_l.size() > il) { + if (kv_self.v_l.size() > il && kv_self.v_l[il] != nullptr) { // Note: with MLA the V cache may not be present. 
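The build_s_copy change above rearranges the per-layer Qwen3Next recurrent state (kv_self.s_l) with a get_rows gather followed by a copy back into the cache tensor. A plain-vector sketch of that slot shuffle, assuming one row per sequence slot (names are illustrative, not ggml API):

#include <algorithm>
#include <cstdint>
#include <vector>

// state_copy[i] names the source slot whose contents slot i should hold
// after the copy: a gather over rows followed by a write-back, which is
// what the ggml_get_rows + ggml_cpy pair in build_s_copy expresses.
static void copy_state_slots(std::vector<float> & state,                  // [n_slots * slot_dim]
                             const std::vector<int32_t> & state_copy,     // [n_slots]
                             int64_t slot_dim) {
    const int64_t n_slots = (int64_t) state_copy.size();
    std::vector<float> gathered(state.size());
    for (int64_t slot = 0; slot < n_slots; ++slot) {
        const int64_t src = state_copy[slot];
        std::copy_n(state.begin() + src*slot_dim, slot_dim, gathered.begin() + slot*slot_dim);
    }
    state.swap(gathered);   // the graph version copies back into kv_self.s_l[il]
}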
if (flash_attn) { // NOTE: the V cache is not transposed when using flash attention @@ -509,12 +557,12 @@ void llm_build_context::llm_build_kv_store( struct ggml_tensor * v_cache_view = nullptr; - if (cparams.flash_attn) { + if (!kv.v_trans) { v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); lctx.cache_copies[2*il+1].step = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa); } else { - // note: the V cache is transposed when not using flash attention + // note: the V cache is transposed for legacy non-FA layouts v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); @@ -1454,12 +1502,21 @@ static ggml_tensor * llm_build_kqv( } else { // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], + struct ggml_tensor * v; + if (kv.v_trans) { + v = ggml_view_3d(ctx, kv.v_l[il], n_kv, n_embd_head_v, n_head_kv, ggml_element_size(kv.v_l[il])*n_ctx, ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, 0); + } else { + v = ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_v), + 0); + v = ggml_cont(ctx, ggml_transpose(ctx, v)); + } cb(v, "v", il); auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); @@ -4248,6 +4305,822 @@ ggml_cgraph * llm_build_context::build_qwen3moe() { return gf; } +ggml_cgraph * llm_build_context::build_qwen3next() { + static constexpr int QWEN3NEXT_CHUNK_SIZE = 64; + + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + GGML_ASSERT(batch.n_tokens > 0); + const bool has_explicit_seq_info = batch.n_seq_id != nullptr && batch.seq_id != nullptr; + std::vector token_seq_ids(batch.n_tokens, 0); + for (int i = 0; i < batch.n_tokens; ++i) { + if (has_explicit_seq_info) { + GGML_ASSERT(batch.n_seq_id[i] > 0 && "qwen3next expects each token to belong to at least one sequence"); + GGML_ASSERT(batch.n_seq_id[i] == 1 && "qwen3next does not support multi-sequence tokens yet"); + token_seq_ids[i] = batch.seq_id[i][0]; + } else { + token_seq_ids[i] = 0; + } + } + + const llama_seq_id seq_id = token_seq_ids[0]; + const bool all_same_seq = std::all_of(token_seq_ids.begin(), token_seq_ids.end(), [&](llama_seq_id s) { + return s == seq_id; + }); + + bool has_unique_seq_ids = true; + if (!all_same_seq) { + std::unordered_set seen; + seen.reserve(token_seq_ids.size()); + for (llama_seq_id s : token_seq_ids) { + if (!seen.insert(s).second) { + has_unique_seq_ids = false; + break; + } + } + } + + GGML_ASSERT(hparams.ssm_n_group > 0); + GGML_ASSERT(hparams.ssm_dt_rank > 0); + GGML_ASSERT(hparams.ssm_d_conv > 0); + GGML_ASSERT(hparams.ssm_d_inner % hparams.ssm_dt_rank == 0); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads; + const int64_t key_dim = head_k_dim * num_k_heads; + const int64_t value_dim = head_v_dim * num_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t conv_state_dim = (hparams.ssm_d_conv - 1) * conv_dim; + const int64_t ssm_state_dim = head_v_dim * head_v_dim * num_v_heads; + const int64_t state_dim = conv_state_dim + 
ssm_state_dim; + const uint32_t qnext_state_slots = llama_kv_qnext_state_slots(kv_self); + GGML_ASSERT(qnext_state_slots > 0); + + GGML_ASSERT(hparams.n_embd_v_s() == (uint32_t) state_dim); + + // Reserve-graph builds may not carry explicit sequence IDs, in which case + // the fallback sequence slot is 0. + const uint32_t state_seq_id = (uint32_t) seq_id; + for (llama_seq_id s : token_seq_ids) { + GGML_ASSERT(s >= 0); + GGML_ASSERT((uint32_t) s < qnext_state_slots); + } + + const bool reset_state = batch.pos != nullptr && batch.pos[0] == 0; + + auto get_slice_2d = [&](ggml_tensor * t, int64_t c) -> ggml_tensor * { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); + }; + + auto build_delta_net_chunking = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il) -> std::pair { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_seqs == 1); + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); + GGML_ASSERT(H_k == H_v); + + const float eps_norm = hparams.f_norm_rms_eps; + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_v, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state,"state_in", il); + + const int64_t chunk_size = QWEN3NEXT_CHUNK_SIZE; + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, beta, k->ne[0], beta->ne[1], beta->ne[2], beta->ne[3]), k); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v 
* n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_v * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = + ggml_repeat_4d(ctx0, g_cumsum, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + cb(decay_mask, "decay_mask", il); + + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + cb(kmulkbeta, "kk_beta", il); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + cb(attn, "attn_pre_solve", il); + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * identity_repeat = + ggml_repeat_4d(ctx0, identity, attn_lower->ne[0], attn_lower->ne[1], attn_lower->ne[2], attn_lower->ne[3]); + ggml_tensor * lhs = ggml_neg(ctx0, ggml_sub(ctx0, attn_lower, identity_repeat)); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + cb(attn, "attn_solved", il); + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + cb(v, "v_beta", il); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + cb(g_cumsum_t, "g_cumsum_t", il); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + cb(gexp, "gexp", il); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + cb(kbeta_gexp, "kbeta_gexp", il); + + auto attn_kbeta = ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))); + cb(attn_kbeta, "attn_kbeta", il); + ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, attn_kbeta)); + cb(k_cumdecay, "k_cumdecay", il); + + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + cb(attn_kq, "attn_kq_pre", il); + attn_kq = ggml_mul(ctx0, decay_mask, attn_kq); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); + + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); + + ggml_tensor * g_last_repeat = + ggml_repeat_4d(ctx0, g_last, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last_repeat)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, ggml_repeat_4d(ctx0, g_diff_exp_t, k->ne[0], g_diff_exp_t->ne[1], g_diff_exp_t->ne[2], g_diff_exp_t->ne[3]), k); + cb(key_gdiff, "key_gdiff", il); + + ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); + 
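For reference, the intra-chunk decay mask assembled above reduces, per head, to exp(G[r] - G[c]) for c <= r with zeros above the diagonal, where G is the cumulative sum of the log-decay gate within the chunk. A standalone single-head sketch with plain vectors (illustrative only; the graph version batches this over chunks, heads and sequences):

#include <cmath>
#include <cstdint>
#include <vector>

// mask[r][c] = exp(G[r] - G[c]) when token c may influence token r (c <= r),
// and 0 above the diagonal. G is the running sum of the log-decay gate g.
static std::vector<float> decay_mask(const std::vector<float> & g /* chunk_size */) {
    const int64_t n = (int64_t) g.size();
    std::vector<float> G(n), mask(n*n, 0.0f);
    float acc = 0.0f;
    for (int64_t t = 0; t < n; ++t) { acc += g[t]; G[t] = acc; }
    for (int64_t r = 0; r < n; ++r) {
        for (int64_t c = 0; c <= r; ++c) {
            mask[r*n + c] = std::exp(G[r] - G[c]);
        }
    }
    return mask;
}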
cb(key_gdiff_t, "key_gdiff_t", il); + + cb(state, "new_state", il); + + ggml_tensor * core_attn_out = nullptr; + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * q_chunk = get_slice_2d(q, chunk); + ggml_tensor * v_chunk = get_slice_2d(v, chunk); + ggml_tensor * gexp_chunk = get_slice_2d(gexp, chunk); + ggml_tensor * k_cumdecay_chunk = get_slice_2d(k_cumdecay, chunk); + ggml_tensor * attn_chunk = get_slice_2d(attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + //printf("v_prime_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", state_t->ne[0], state_t->ne[1], state_t->ne[2], state_t->ne[3], ggml_type_name(state_t->type), + // k_cumdecay_chunk->ne[0], k_cumdecay_chunk->ne[1], k_cumdecay_chunk->ne[2], k_cumdecay_chunk->ne[3], ggml_type_name(k_cumdecay_chunk->type)); + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); + + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); + + ggml_tensor * q_g_exp = ggml_mul(ctx0, ggml_repeat_4d(ctx0, gexp_chunk, q_chunk->ne[0], gexp_chunk->ne[1], gexp_chunk->ne[2], gexp_chunk->ne[3]), q_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); + + //printf("v_attn_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type), + // attn_chunk->ne[0], attn_chunk->ne[1], attn_chunk->ne[2], attn_chunk->ne[3], ggml_type_name(attn_chunk->type)); + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); + + core_attn_out = core_attn_out == nullptr + ? 
core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + + ggml_tensor * k_gdiff_t = get_slice_2d(key_gdiff_t, chunk); + //printf("kgdmulvnew: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type), + // k_gdiff_t->ne[0], k_gdiff_t->ne[1], k_gdiff_t->ne[2], k_gdiff_t->ne[3], ggml_type_name(k_gdiff_t->type)); + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); + cb(kgdmulvnew, "kgdmulvnew", il); + + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(g_last_exp, chunk)); + state = ggml_add(ctx0, + ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks), + ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0); + cb(output_tokens, "output_tokens", il); + + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); + + return {output_tokens, state}; + }; + + auto build_delta_net_autoregressive = [&](ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + int il) -> std::pair { + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + GGML_ASSERT(n_seqs == 1); + GGML_ASSERT(H_k == H_v); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); + + const float eps_norm = hparams.f_norm_rms_eps; + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + + g_t = ggml_exp(ctx0, g_t); + state = ggml_mul(ctx0, state, g_t); + + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + kv_mem = ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)); + cb(kv_mem, "kv_mem_t_cont", il); + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, kv_mem)); + + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); + cb(v_diff, "v_diff", il); + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); + + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + state_q = ggml_cont(ctx0, ggml_transpose(ctx0, state_q)); + cb(state_q, "state_q_t_cont", il); + ggml_tensor * core_attn_out = ggml_transpose(ctx0, ggml_sum_rows(ctx0, state_q)); + + cb(core_attn_out, 
"output_tokens", il); + cb(state, "new_state", il); + + return {core_attn_out, state}; + }; + + auto build_qkvz = [&](ggml_tensor * input, int il) -> std::pair { + const int64_t n_tok = input->ne[1]; + if (model.layers[il].wqkv) { + ggml_tensor * qkv_mixed = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, input); + cb(qkv_mixed, "qkv_mixed", il); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tok, 1); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; + } + + ggml_tensor * mixed_qkvz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, input); + cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + + const int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); + ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tok, 1); + + int64_t split_sizes_qkvz[4] = { + head_k_dim, + head_k_dim, + head_v_dim * num_v_heads / num_k_heads, + head_v_dim * num_v_heads / num_k_heads + }; + + ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tok, 1, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); + cb(query, "q", il); + + ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tok, 1, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped)); + cb(key, "k", il); + + ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tok, 1, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped)); + cb(value, "v", il); + + ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tok, 1, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped)); + z = ggml_cont(ctx0, z); + cb(z, "z", il); + + ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_tok, 1); + cb(query_flat, "query_flat", il); + + ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_tok, 1); + cb(key_flat, "key_flat", il); + + ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_tok, 1); + cb(value_flat, "value_flat", il); + + ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); + qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); + cb(qkv_mixed, "qkv_mixed", il); + + return { qkv_mixed, z }; + }; + + auto build_layer_attn = [&](ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * KQ_mask, int il) -> ggml_tensor * { + ggml_tensor * Qcur_full = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur_full, "Qcur_full", il); + + Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); + + ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); + cb(Qcur, "Qcur", il); + 
cb(gate, "gate", il); + + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur_reshaped", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + ggml_tensor * attn = llm_build_kv(ctx0, lctx, kv_self, gf, + nullptr, nullptr, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, + hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale, + cb, il); + cb(attn, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + attn = ggml_mul(ctx0, attn, gate_sigmoid); + cb(attn, "attn_gated", il); + + attn = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn); + cb(attn, "attn_output", il); + + return attn; + }; + + auto build_layer_ffn = [&](ggml_tensor * cur, int il) -> ggml_tensor * { + const bool has_moe = model.layers[il].ffn_gate_inp != nullptr; + const bool has_dense = model.layers[il].ffn_gate != nullptr && model.layers[il].ffn_up != nullptr && model.layers[il].ffn_down != nullptr; + + if (has_moe) { + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, LLM_FFN_SILU, + true, false, 0.0f, LLM_EXPERT_GATING_FUNC_SOFTMAX, + cb, il, gf, false); + cb(moe_out, "ffn_moe_out", il); + + const bool has_shexp = model.layers[il].ffn_up_shexp != nullptr && + model.layers[il].ffn_gate_shexp != nullptr && + model.layers[il].ffn_down_shexp != nullptr && + model.layers[il].ffn_gate_inp_shexp != nullptr; + if (has_shexp) { + ggml_tensor * ffn_shexp = + llm_build_ffn(ctx0, lctx, nullptr, cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + ggml_tensor * shared_gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + if (shared_gate->ne[1] == 1) { + ffn_shexp = ggml_fused_mul_unary(ctx0, shared_gate, ffn_shexp, GGML_UNARY_OP_SIGMOID); + } else { + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + } + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + 
} + cb(cur, "ffn_out", il); + return cur; + } + + GGML_ASSERT(has_dense); + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + return cur; + }; + + auto build_layer_attn_linear_core = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, ggml_tensor * inp_s_seq_qnext, + uint32_t state_seq_id_local, bool reset_state_local, int il) -> ggml_tensor * { + const int64_t n_tok = cur->ne[1]; + + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * mixed_ba = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta_alpha, cur); + cb(mixed_ba, "linear_attn_mixed_ba", il); + + int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; + ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tok, 1); + + int64_t split_sizes_ba[2] = { + num_v_heads / num_k_heads, + num_v_heads / num_k_heads + }; + + ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tok, 1, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); + cb(b, "b", il); + + ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tok, 1, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], + split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); + cb(a, "a", il); + + ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1); + ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); + cb(gate, "gate", il); + + size_t state_row_size = 0; + ggml_tensor * state_all = nullptr; + GGML_ASSERT((size_t) il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr); + ggml_tensor * state_storage = kv_self.s_l[il]; + GGML_ASSERT(state_storage->type == GGML_TYPE_F32); + GGML_ASSERT(state_storage->ne[0] >= state_dim); + GGML_ASSERT((uint32_t) state_storage->ne[1] == qnext_state_slots); + state_row_size = state_storage->nb[1]; + GGML_ASSERT(ggml_nbytes(state_storage) >= state_row_size * qnext_state_slots); + state_all = ggml_view_2d(ctx0, state_storage, state_dim, qnext_state_slots, state_row_size, 0); + + ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id_local * state_row_size); + ggml_tensor * state_f32 = state_dst; + if (state_f32->type != GGML_TYPE_F32) { + state_f32 = ggml_cast(ctx0, state_f32, GGML_TYPE_F32); + } + if (reset_state_local) { + state_f32 = ggml_scale(ctx0, state_f32, 0.0f); + } + + ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0); + ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_f32, ssm_state_dim, 1, state_f32->nb[1], + conv_state_dim * ggml_element_size(state_f32)); + + ggml_tensor * conv_states = ggml_reshape_3d(ctx0, conv_state_flat, hparams.ssm_d_conv - 1, conv_dim, 1); + ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1); + cb(conv_states, "conv_states", il); + cb(state, "state_predelta", il); + + ggml_tensor * conv_output_raw = 
ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext); + cb(conv_output_raw, "conv_output_raw", il); + + ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0); + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1], 0); + ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1], + key_dim * ggml_element_size(conv_output_silu)); + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_output_silu, head_v_dim, num_v_heads, n_tok, 1, + ggml_row_size(conv_output_silu->type, head_v_dim), + conv_output_silu->nb[1], + conv_output_silu->nb[1] * n_tok, + 2 * key_dim * ggml_element_size(conv_output_silu)); + + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tok, 1); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tok, 1); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tok, 1); + cb(q_conv, "q_conv_cont", il); + cb(k_conv, "k_conv_cont", il); + cb(v_conv, "v_conv_cont", il); + + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + const int64_t repeat_factor = num_v_heads / num_k_heads; + + ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_tok); + ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_tok); + + ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1); + ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1); + + q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1); + k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + std::pair attn_out; + + GGML_ASSERT(causal_mask != nullptr); + GGML_ASSERT(identity != nullptr); + GGML_ASSERT(diag_mask != nullptr); + + attn_out = n_tok == 1 + ? 
build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il) + : build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_tensor * new_conv_states = ggml_view_2d(ctx0, conv_output_raw, hparams.ssm_d_conv - 1, conv_dim, + hparams.ssm_d_conv * ggml_element_size(conv_output_raw), + (1 + conv_dim * n_tok) * ggml_element_size(conv_output_raw)); + ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_conv_states), conv_state_dim, 1); + ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1); + ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0); + + ggml_tensor * state_update = new_state_flat; + if (state_dst->type != GGML_TYPE_F32) { + state_update = ggml_cast(ctx0, state_update, state_dst->type); + } + ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst)); + + ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tok); + ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tok); + + ggml_tensor * attn_out_norm = llm_build_norm(ctx0, attn_out_2d, hparams, model.layers[il].ssm_norm, nullptr, LLM_NORM_RMS, cb, il); + ggml_tensor * gated_silu = ggml_silu(ctx0, z_2d); + attn_out_norm = ggml_mul(ctx0, attn_out_norm, gated_silu); + + ggml_tensor * final_output = ggml_reshape_2d(ctx0, attn_out_norm, value_dim, n_tok); + cb(final_output, "final_output", il); + + ggml_tensor * out = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, final_output); + cb(out, "linear_attn_out", il); + + return ggml_reshape_2d(ctx0, out, n_embd, n_tok); + }; + + auto build_layer_attn_linear = [&](ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il) -> ggml_tensor * { + GGML_ASSERT(lctx.inp_s_seq_qnext != nullptr); + + if (all_same_seq) { + return build_layer_attn_linear_core(cur, causal_mask, identity, diag_mask, lctx.inp_s_seq_qnext, state_seq_id, reset_state, il); + } + + GGML_ASSERT(has_unique_seq_ids && "qwen3next mixed-sequence batches require unique sequence IDs per token"); + + ggml_tensor * out = nullptr; + for (int64_t i = 0; i < n_tokens; ++i) { + ggml_tensor * cur_i = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (size_t) i * cur->nb[1]); + ggml_tensor * inp_s_seq_qnext_i = ggml_view_2d(ctx0, lctx.inp_s_seq_qnext, 1, 1, lctx.inp_s_seq_qnext->nb[1], (size_t) i * lctx.inp_s_seq_qnext->nb[1]); + + const bool reset_state_i = batch.pos != nullptr && batch.pos[i] == 0; + const uint32_t state_seq_id_i = (uint32_t) token_seq_ids[i]; + ggml_tensor * out_i = build_layer_attn_linear_core(cur_i, causal_mask, identity, diag_mask, inp_s_seq_qnext_i, state_seq_id_i, reset_state_i, il); + + out = out == nullptr ? out_i : ggml_concat(ctx0, out, out_i, 1); + } + + return out; + }; + + ggml_tensor * inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = n_tokens > 1 ? 
build_inp_out_ids() : nullptr; + ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + lctx.inp_s_seq_qnext = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 1, n_tokens); + cb(lctx.inp_s_seq_qnext, "inp_s_seq_qnext", -1); + ggml_set_input(lctx.inp_s_seq_qnext); + + ggml_tensor * causal_mask = nullptr; + ggml_tensor * identity = nullptr; + ggml_tensor * diag_mask = nullptr; + causal_mask = ggml_tri(ctx0, + ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE, QWEN3NEXT_CHUNK_SIZE), 1.0f), + GGML_TRI_TYPE_LOWER); + identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, QWEN3NEXT_CHUNK_SIZE), 1.0f)); + diag_mask = ggml_add(ctx0, causal_mask, identity); + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); + + ggml_tensor * cur = nullptr; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + GGML_ASSERT(model.layers[il].attn_norm != nullptr); + GGML_ASSERT(model.layers[il].attn_post_norm != nullptr); + + const bool has_moe = model.layers[il].ffn_gate_inp != nullptr; + const bool has_dense = model.layers[il].ffn_gate != nullptr && + model.layers[il].ffn_up != nullptr && + model.layers[il].ffn_down != nullptr; + GGML_ASSERT(has_moe || has_dense); + if (has_moe) { + GGML_ASSERT(model.layers[il].ffn_up_exps != nullptr); + GGML_ASSERT(model.layers[il].ffn_gate_exps != nullptr); + GGML_ASSERT(model.layers[il].ffn_down_exps != nullptr); + } + + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + if (hparams.is_recurrent(il)) { + GGML_ASSERT(model.layers[il].ssm_conv1d != nullptr); + GGML_ASSERT(model.layers[il].ssm_dt != nullptr); + GGML_ASSERT(model.layers[il].ssm_a != nullptr); + GGML_ASSERT(model.layers[il].ssm_beta_alpha != nullptr); + GGML_ASSERT(model.layers[il].ssm_norm != nullptr); + GGML_ASSERT(model.layers[il].ssm_out != nullptr); + GGML_ASSERT(model.layers[il].wqkv != nullptr || model.layers[il].ssm_in != nullptr); + GGML_ASSERT(model.layers[il].wqkv_gate != nullptr || model.layers[il].ssm_in != nullptr); + cur = build_layer_attn_linear(cur, causal_mask, identity, diag_mask, il); + } else { + GGML_ASSERT(model.layers[il].wq != nullptr); + GGML_ASSERT(model.layers[il].wk != nullptr); + GGML_ASSERT(model.layers[il].wv != nullptr); + GGML_ASSERT(model.layers[il].wo != nullptr); + GGML_ASSERT(model.layers[il].attn_q_norm != nullptr); + GGML_ASSERT(model.layers[il].attn_k_norm != nullptr); + cur = build_layer_attn(cur, inp_pos, KQ_mask, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + ggml_tensor * ffn_residual = cur; + ggml_tensor * attn_post_norm = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il); + cb(attn_post_norm, "attn_post_norm", il); + + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = llm_build_norm(ctx0, inpL, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} + ggml_cgraph * 
llm_build_context::build_qwen3vl() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -9273,6 +10146,10 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_qwen3moe(); } break; + case LLM_ARCH_QWEN3NEXT: + { + result = llm.build_qwen3next(); + } break; case LLM_ARCH_QWEN3VL: { result = llm.build_qwen3vl(); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index f1dea92b..b0ab91a4 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -204,6 +204,8 @@ struct llm_build_context { ggml_cgraph * build_qwen3vlmoe(); + ggml_cgraph * build_qwen3next(); + ggml_cgraph * build_phi2(); ggml_cgraph * build_phi3(); diff --git a/src/llama-context.h b/src/llama-context.h index 61ad51e5..b2b755ee 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -56,6 +56,7 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + std::vector s_l; // per layer recurrent state storage (Qwen3Next) std::vector split_k_l; std::vector split_v_l; @@ -202,6 +203,7 @@ struct llama_context { struct ggml_tensor * inp_s_copy; // I32 [kv_size] struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_s_seq_qnext; // I32 [1, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index bbdd83a1..217f0390 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -5,7 +5,7 @@ #include -#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 +#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next static const std::map LLAMA_ROPE_SCALING_TYPES = { { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, @@ -83,6 +83,7 @@ void llm_load_hparams( std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), false); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); @@ -453,6 +454,28 @@ void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN3NEXT: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Upstream convention: every 4th layer is full attention, others are recurrent. 
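// Illustrative sketch only, assuming the 48-layer configuration mapped to
// MODEL_80B_A3B below: with the loop that follows, layers 3, 7, ..., 47 stay
// full attention (48 / 4 = 12 layers) and the remaining 36 become recurrent
// delta-net layers, i.e.
//
//     bool is_full_attention(uint32_t il) { return (il + 1) % 4 == 0; }   // true for il = 3, 7, ..., 47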
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); + } + + switch (hparams.n_layer) { + case 48: model.type = e_model::MODEL_80B_A3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_QWEN3VLMOE: { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 0dd22303..1e6a05f2 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -89,6 +89,10 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; + + // for hybrid state-space models (e.g. qwen3next) + std::array recurrent_layer_arr; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -169,6 +173,8 @@ struct llama_hparams { if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->ssm_n_group != other.ssm_n_group) return true; + if (this->recurrent_layer_arr != other.recurrent_layer_arr) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true; @@ -246,6 +252,10 @@ struct llama_hparams { } uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + if (ssm_n_group > 0) { + // qwen3next keeps all recurrent state in the V-cache tail + return 0; + } // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed @@ -253,10 +263,26 @@ struct llama_hparams { } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + if (ssm_n_group > 0) { + // qwen3next recurrent state packs: + // 1) conv state: (d_conv - 1) * (2 * key_dim + value_dim) + // 2) delta-net state: head_v_dim * head_v_dim * num_v_heads + const uint32_t key_dim = ssm_d_state * ssm_n_group; + const uint32_t value_dim = ssm_d_inner; + const uint32_t conv_dim = 2 * key_dim + value_dim; + const uint32_t conv_state_dim = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * conv_dim; + const uint32_t head_v_dim = ssm_dt_rank > 0 ? ssm_d_inner / ssm_dt_rank : 0; + const uint32_t ssm_state_dim = head_v_dim * head_v_dim * ssm_dt_rank; + return conv_state_dim + ssm_state_dim; + } // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } + bool is_recurrent(uint32_t il) const { + return il < n_layer ? 
recurrent_layer_arr[il] : false; + } + static bool is_float_close(float a, float b, float abs_tol) { // Check for non-negative tolerance if (abs_tol < 0.0) { diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index e2baacbb..8703cca8 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -73,6 +73,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool create_qwen3_moe_tensors(const LLM_TN & tn); + bool create_qwen3next_tensors(const LLM_TN & tn); + bool create_phi2_tensors(const LLM_TN & tn); bool create_phi3_tensors(const LLM_TN & tn); @@ -1291,6 +1293,99 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) { return use_mmap_buffer; } +bool create_tensors_helper::create_qwen3next_tensors(const LLM_TN & tn) { + LOADING_PRELUDE + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + const bool has_moe_hparams = n_expert > 0 && n_expert_used > 0; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : (has_moe_hparams ? n_ff / n_expert_used : n_ff); + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp; + + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = hparams.ssm_d_inner / num_v_heads; + const int64_t key_dim = head_k_dim * num_k_heads; + const int64_t value_dim = head_v_dim * num_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; + const int64_t ba_dim = num_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); + + if (!hparams.is_recurrent(i)) { + // Full-attention layer + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head * 2}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); + } else { + // Recurrent linear-attention layer + layer.ssm_in = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, qkvz_dim}, + llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, key_dim * 2 + value_dim}, + llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv_gate = 
create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, value_dim}, + llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {hparams.ssm_d_conv, conv_dim}); + layer.ssm_dt = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {hparams.ssm_dt_rank}); + layer.ssm_a = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A_NOSCAN, i), {hparams.ssm_dt_rank}); + layer.ssm_beta_alpha = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), {n_embd, ba_dim}); + layer.ssm_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {head_v_dim}); + layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {value_dim, n_embd}); + } + + auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer; + + // Dense FFN path (optional, e.g. mlp_only_layers) + layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // MoE path (optional per-layer) + layer.ffn_gate_inp = nullptr; + if (n_expert > 0) { + layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + + if (layer.ffn_gate_inp != nullptr) { + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 when QWEN3NEXT MoE tensors are present"); + } + use_mmap_buffer &= !create_std_ffn_exps(n_embd, tn, i, llama_model_loader::TENSOR_NOT_REQUIRED, n_ff_exp); + } + + // Shared expert path (optional per-layer) + layer.ffn_gate_inp_shexp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_inp_shexp != nullptr) { + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } + + return use_mmap_buffer; +} + bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) { LOADING_PRELUDE @@ -3221,6 +3316,8 @@ bool create_tensors_helper::create_tensors() { case LLM_ARCH_QWEN3MOE: case LLM_ARCH_QWEN3VLMOE: use_mmap_buffer = create_qwen3_moe_tensors(tn); break; + case LLM_ARCH_QWEN3NEXT: + use_mmap_buffer = create_qwen3next_tensors(tn); break; case LLM_ARCH_PHI2: use_mmap_buffer = create_phi2_tensors(tn); break; case LLM_ARCH_PHI3: diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fa66c1e3..c2eeb85d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -429,6 +429,39 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_QWEN3NEXT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { 
LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_QWEN3VL, { @@ -1648,6 +1681,7 @@ const char * llama_model_type_name(e_model type) { case MODEL_16B_A1B: return "16B.A1B"; case MODEL_21B_A3B: return "21B.A3B"; case MODEL_30B_A3B: return "30B.A3B"; + case MODEL_80B_A3B: return "80B.A3B"; case MODEL_80B_A13B: return "80B.A13B"; case MODEL_100B_A6B: return "100B.A6B"; case MODEL_106B_A12B: return "106B.A12B"; diff --git a/src/llama-model.h b/src/llama-model.h index 702fcae0..04e15c70 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -107,6 +107,7 @@ enum e_model { MODEL_16B_A1B, MODEL_21B_A3B, // Ernie MoE small MODEL_30B_A3B, + MODEL_80B_A3B, // Qwen3-Next MODEL_80B_A13B, MODEL_100B_A6B, MODEL_106B_A12B, @@ -289,6 +290,8 @@ struct llama_layer { struct ggml_tensor * ssm_x = nullptr; struct ggml_tensor * ssm_dt = nullptr; struct ggml_tensor * ssm_out = nullptr; + struct ggml_tensor * ssm_norm = nullptr; + struct ggml_tensor * ssm_beta_alpha = nullptr; // mamba struct ggml_tensor * ssm_conv1d = nullptr; diff --git a/src/llama.cpp b/src/llama.cpp index c1314c13..ba042199 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -568,9 +568,15 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) { bool llama_context::update_cache_copies() { int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2; + auto layer_has_attention_kv = [&](int il) { + return !(model.arch == LLM_ARCH_QWEN3NEXT && model.hparams.is_recurrent(il)); + }; if ((int)kv_self.k_l.size() != n_layer) return false; if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false; for (int il = 0; il < n_layer; ++il) { + if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) { + continue; + } auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra; if (kl) { GGML_ASSERT(model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN); @@ -597,6 +603,9 @@ bool llama_context::update_cache_copies() { } } else { for (int il = 0; il < n_layer; ++il) { + if (!layer_has_attention_kv(il) || kv_self.k_l[il] == nullptr) { + continue; + } auto& c = cache_copies[2*il+0]; if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false; c.cpy->view_offs = kv_self.head*c.step; @@ -605,6 +614,9 @@ bool llama_context::update_cache_copies() { } if (kv_self.v_l.empty()) return true; for (int il = 0; il < n_layer; ++il) { + if 
(!layer_has_attention_kv(il) || kv_self.v_l[il] == nullptr) { + continue; + } auto& c = cache_copies[2*il+1]; if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false; c.cpy->view_offs = kv_self.head*c.step; @@ -640,6 +652,58 @@ llama_context::~llama_context() { // kv cache helpers // +static inline bool llama_qwen3next_is_recurrent_layer( + const llama_model & model, + const llama_hparams & hparams, + uint32_t il) { + return model.arch == LLM_ARCH_QWEN3NEXT && hparams.is_recurrent(il); +} + +static inline uint32_t llama_kv_v_row_embd( + const llama_model & model, + const llama_hparams & hparams, + uint32_t il) { + // qwen3next recurrent state is stored in a dedicated V-cache tail (per sequence), + // so per-token V rows include only attention values. + if (model.arch == LLM_ARCH_QWEN3NEXT) { + return hparams.n_embd_v_gqa(il); + } + + return hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); +} + +static inline uint32_t llama_qwen3next_state_slots(const llama_cparams & cparams, uint32_t kv_size) { + return std::min(std::max(1, cparams.n_seq_max), kv_size); +} + +static inline uint32_t llama_kv_qnext_state_slots(const llama_kv_cache & cache) { + uint32_t n_slots = 0; + + for (const ggml_tensor * t : cache.s_l) { + if (t == nullptr) { + continue; + } + + const uint32_t layer_slots = (uint32_t) t->ne[1]; + if (n_slots == 0) { + n_slots = layer_slots; + } else { + GGML_ASSERT(n_slots == layer_slots); + } + } + + return n_slots; +} + +static inline bool llama_kv_has_qnext_state_storage(const llama_kv_cache & cache) { + return llama_kv_qnext_state_slots(cache) > 0; +} + +static inline bool llama_kv_qnext_seq_id_in_range(const llama_kv_cache & cache, llama_seq_id seq_id) { + const uint32_t n_slots = llama_kv_qnext_state_slots(cache); + return n_slots > 0 && seq_id >= 0 && (uint32_t) seq_id < n_slots; +} + static bool llama_kv_cache_init( struct llama_kv_cache & cache, const llama_context * ctx, @@ -658,7 +722,9 @@ static bool llama_kv_cache_init( // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; - cache.v_trans = !cache.recurrent && !cparams.flash_attn; + // qwen3next uses hybrid recurrent+attention cache semantics. Keep V rows in + // standard layout to match the mainline hybrid path when flash attention is off. 
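// The recurrent layers allocated below get one F32 state row per sequence slot,
// with row width hparams.n_embd_v_s(); a self-contained sketch of that packed
// size, mirroring the formula added in llama-hparams.h (helper and parameter
// names are illustrative; assumes <cstdint>):
static uint32_t qwen3next_state_floats(uint32_t d_conv, uint32_t d_state, uint32_t n_group,
                                       uint32_t d_inner, uint32_t dt_rank /* = num_v_heads */) {
    const uint32_t key_dim    = d_state * n_group;                        // head_k_dim * num_k_heads
    const uint32_t value_dim  = d_inner;                                  // head_v_dim * num_v_heads
    const uint32_t conv_dim   = 2 * key_dim + value_dim;
    const uint32_t conv_state = (d_conv > 0 ? d_conv - 1 : 0) * conv_dim; // (d_conv - 1) shifted columns
    const uint32_t head_v_dim = dt_rank > 0 ? d_inner / dt_rank : 0;
    return conv_state + head_v_dim * head_v_dim * dt_rank;                // conv tail + delta-net state
}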
+ cache.v_trans = !cache.recurrent && !cparams.flash_attn && model.arch != LLM_ARCH_QWEN3NEXT; cache.head = 0; cache.size = kv_size; @@ -670,7 +736,7 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(kv_size); - if (cache.recurrent) { + if (cache.recurrent || model.arch == LLM_ARCH_QWEN3NEXT) { // init state copy sources for (uint32_t i = 0; i < cache.size; ++i) { cache.cells[i].src = i; @@ -750,18 +816,27 @@ static bool llama_kv_cache_init( needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn; } if (needs_v_cache) cache.v_l.reserve(n_layer); + cache.s_l.reserve(n_layer); std::vector mem_split(model.splits.size(), 0); + const uint32_t qnext_state_slots = llama_qwen3next_state_slots(cparams, kv_size); + if (model.arch == LLM_ARCH_QWEN3NEXT && qnext_state_slots < std::max(1, cparams.n_seq_max)) { + LLAMA_LOG_WARN("%s: reducing qwen3next state slots from %u to %u to fit KV cache size\n", + __func__, std::max(1, cparams.n_seq_max), qnext_state_slots); + } + int n_mla = 0; for (int i = 0; i < (int) n_layer; i++) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const bool qnext_recurrent = llama_qwen3next_is_recurrent_layer(model, hparams, i); + const uint32_t n_embd_v_row = llama_kv_v_row_embd(model, hparams, i); const uint32_t n_head_kv = hparams.n_head_kv(i); const uint32_t n_embd_head_k= hparams.n_embd_head_k; struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k; - ggml_tensor * v; + ggml_tensor * k = nullptr; + ggml_tensor * v = nullptr; + ggml_tensor * s = nullptr; if (is_mla_attn && cparams.mla_attn) { // DeepSeek MLA const uint32_t n_embd_head_qk_rope = hparams.n_rot; @@ -792,56 +867,70 @@ static bool llama_kv_cache_init( ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); split_cache_i = false; } - int n_embd_head_v = hparams.n_embd_head_v; - k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); - v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - auto k_name = std::string{"cache_k_l"} + std::to_string(i); - auto v_name = std::string{"cache_v_l"} + std::to_string(i); - ggml_set_name(k, k_name.c_str()); - ggml_set_name(v, v_name.c_str()); - //ggml_format_name(k, "cache_k_l%d", i); - //ggml_format_name(v, "cache_v_l%d", i); + if (qnext_recurrent) { + s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_embd_v_s(), qnext_state_slots); + split_cache_i = false; + } else { + int n_embd_head_v = hparams.n_embd_head_v; + k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); + + int64_t v_ne = int64_t(n_embd_v_row)*kv_size; + v = ggml_new_tensor_1d(ctx, type_v, v_ne); + + auto k_name = std::string{"cache_k_l"} + std::to_string(i); + auto v_name = std::string{"cache_v_l"} + std::to_string(i); + ggml_set_name(k, k_name.c_str()); + ggml_set_name(v, v_name.c_str()); + //ggml_format_name(k, "cache_k_l%d", i); + //ggml_format_name(v, "cache_v_l%d", i); + + if (split_cache_i) { + bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1] ? 
true : false; + auto extra_K = (const ggml_split_tensor_t *)K->extra; + auto extra_V = (const ggml_split_tensor_t *)V->extra; + auto & split_k_l = cache.split_k_l.emplace_back(); + auto & split_v_l = cache.split_v_l.emplace_back(); + split_k_l.tensor_splits.resize(extra_K->n_device, nullptr); + split_v_l.tensor_splits.resize(extra_V->n_device, nullptr); + for (int is = 0; is < extra_K->n_device; ++is) { + auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is]; + if (!split) continue; + int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1]/n_embd_head_k; + if (use_V_for_K) { + LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n", + i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k); + } + split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size); + auto split_name = k_name + '.' + std::to_string(is); + ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str()); + mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]); + } + split_k_l.ggml.n_device = extra_K->n_device; + split_k_l.ggml.split_dim = 0; + split_k_l.ggml.splits = split_k_l.tensor_splits.data(); + for (int is = 0; is < extra_V->n_device; ++is) { + auto split = extra_V->splits[is]; + if (!split) continue; + split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size); + auto split_name = v_name + '.' + std::to_string(is); + ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str()); + mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]); + } + split_v_l.ggml.n_device = extra_V->n_device; + split_v_l.ggml.split_dim = 0; + split_v_l.ggml.splits = split_v_l.tensor_splits.data(); + k->extra = (void *)&split_k_l.ggml; + v->extra = (void *)&split_v_l.ggml; + } + } + if (s) { + auto s_name = std::string{"cache_s_l"} + std::to_string(i); + ggml_set_name(s, s_name.c_str()); + } cache.k_l.push_back(k); cache.v_l.push_back(v); - if (split_cache_i) { - bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1] ? true : false; - auto extra_K = (const ggml_split_tensor_t *)K->extra; - auto extra_V = (const ggml_split_tensor_t *)V->extra; - auto & split_k_l = cache.split_k_l.emplace_back(); - auto & split_v_l = cache.split_v_l.emplace_back(); - split_k_l.tensor_splits.resize(extra_K->n_device, nullptr); - split_v_l.tensor_splits.resize(extra_V->n_device, nullptr); - for (int is = 0; is < extra_K->n_device; ++is) { - auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is]; - if (!split) continue; - int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1]/n_embd_head_k; - if (use_V_for_K) { - LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n", - i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k); - } - split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size); - auto split_name = k_name + '.' + std::to_string(is); - ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str()); - mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]); - } - split_k_l.ggml.n_device = extra_K->n_device; - split_k_l.ggml.split_dim = 0; - split_k_l.ggml.splits = split_k_l.tensor_splits.data(); - for (int is = 0; is < extra_V->n_device; ++is) { - auto split = extra_V->splits[is]; - if (!split) continue; - split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size); - auto split_name = v_name + '.' 
+ std::to_string(is); - ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str()); - mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]); - } - split_v_l.ggml.n_device = extra_V->n_device; - split_v_l.ggml.split_dim = 0; - split_v_l.ggml.splits = split_v_l.tensor_splits.data(); - k->extra = (void *)&split_k_l.ggml; - v->extra = (void *)&split_v_l.ggml; - } } + cache.s_l.push_back(s); } if (is_mla_attn && cparams.mla_attn && n_mla < n_layer && n_mla > 0) { LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer)); @@ -1017,6 +1106,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { static void llama_kv_cache_clear(struct llama_kv_cache & cache) { for (int32_t i = 0; i < (int32_t) cache.size; ++i) { cache.cells[i].pos = -1; + cache.cells[i].src = i; cache.cells[i].seq_id.clear(); } cache.head = 0; @@ -1056,6 +1146,8 @@ static bool llama_kv_cache_seq_rm( } } + const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache); + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { if (seq_id < 0) { @@ -1070,6 +1162,9 @@ static bool llama_kv_cache_seq_rm( if (cache.cells[i].pos >= 0) cache.used--; cache.cells[i].pos = -1; + if (has_qnext_state) { + cache.cells[i].src = i; + } if (new_head == cache.size) new_head = i; } } @@ -1111,6 +1206,21 @@ static void llama_kv_cache_seq_cp( } return; } + + const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache); + if (has_qnext_state && + llama_kv_qnext_seq_id_in_range(cache, seq_id_dst) && + llama_kv_qnext_seq_id_in_range(cache, seq_id_src) && + (uint32_t) seq_id_dst < cache.size && + (uint32_t) seq_id_src < cache.size) { + seq_id_src = cache.cells[seq_id_src].src; + GGML_ASSERT((uint32_t) seq_id_src < cache.size); + + cache.cells[seq_id_dst].src = seq_id_src; + cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + cache.do_copy = true; + } + // otherwise, this is the KV cache of a Transformer-like model cache.head = 0; @@ -1124,11 +1234,15 @@ static void llama_kv_cache_seq_cp( static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { uint32_t new_head = cache.size; + const bool has_qnext_state = llama_kv_has_qnext_state_storage(cache); for (uint32_t i = 0; i < cache.size; ++i) { if (!cache.cells[i].has_seq_id(seq_id)) { if (cache.cells[i].pos >= 0) cache.used--; cache.cells[i].pos = -1; + if (has_qnext_state) { + cache.cells[i].src = i; + } cache.cells[i].seq_id.clear(); if (new_head == cache.size) new_head = i; } else { @@ -2764,6 +2878,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } + if (lctx.inp_s_seq_qnext) { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq_qnext->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_seq_qnext->data; + + for (int64_t j = 0; j < n_tokens; ++j) { + // qwen3next linear-attention path uses a single local recurrent state slot. 
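// For mixed-sequence batches, the decode loop below falls back to single-token
// chunking only when different seq_ids are mixed *and* at least one seq_id
// repeats; a simplified, self-contained restatement of that check (the function
// name is illustrative; assumes <cstdint> and <unordered_set>):
static bool qnext_needs_single_token_fallback(const llama_seq_id * seq_ids, uint32_t n_tokens) {
    std::unordered_set<llama_seq_id> seen;
    bool any_diff = false;
    bool has_dup  = false;
    for (uint32_t i = 0; i < n_tokens; ++i) {
        any_diff = any_diff || (seq_ids[i] != seq_ids[0]);
        has_dup  = has_dup  || !seen.insert(seq_ids[i]).second;
    }
    return any_diff && has_dup;   // same-sequence or all-unique batches keep normal chunking
}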
+ data[j] = 0; + } + } + if (lctx.inp_pos_bucket) { const int64_t n_tokens = batch.n_tokens; @@ -3012,11 +3138,51 @@ static int llama_decode_internal( } } - for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { + bool warned_qnext_mixed_repeat = false; + for (uint32_t cur_token = 0; cur_token < n_tokens_all; ) { #if IK_PRINT_TIMING auto tim1 = ggml_time_us(); #endif - const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); + uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); + if (model.arch == LLM_ARCH_QWEN3NEXT && + n_tokens > 1 && + batch_all.n_seq_id != nullptr && + batch_all.seq_id != nullptr) { + bool can_check = true; + bool any_diff = false; + bool has_dup = false; + llama_seq_id first_seq_id = 0; + std::unordered_set seen_seq_ids; + seen_seq_ids.reserve(n_tokens); + + for (uint32_t i = 0; i < n_tokens; ++i) { + const uint32_t idx = cur_token + i; + if (batch_all.n_seq_id[idx] <= 0 || batch_all.seq_id[idx] == nullptr) { + can_check = false; + break; + } + + const llama_seq_id seq_id_i = batch_all.seq_id[idx][0]; + if (i == 0) { + first_seq_id = seq_id_i; + } else if (seq_id_i != first_seq_id) { + any_diff = true; + } + + if (!seen_seq_ids.insert(seq_id_i).second) { + has_dup = true; + } + } + + if (can_check && any_diff && has_dup) { + n_tokens = 1; + if (!warned_qnext_mixed_repeat) { + LLAMA_LOG_WARN("%s: qwen3next mixed-sequence batch contains repeated seq_id values; falling back to single-token chunking\n", __func__); + warned_qnext_mixed_repeat = true; + } + } + } + llama_batch u_batch = { /* .n_tokens = */ (int32_t) n_tokens, /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, @@ -3293,6 +3459,7 @@ static int llama_decode_internal( #endif } n_outputs_prev += lctx.n_outputs; + cur_token += n_tokens; } // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -3766,7 +3933,7 @@ static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) { } } - if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { + if ((lctx.kv_self.recurrent || llama_kv_has_qnext_state_storage(lctx.kv_self)) && lctx.kv_self.do_copy) { { lctx.reset_scheduler(); @@ -4787,11 +4954,15 @@ struct llama_context * llama_init_from_model( size_t memory_size_v = 0; for (auto & k : ctx->kv_self.k_l) { - memory_size_k += ggml_nbytes(k); + if (k) { + memory_size_k += ggml_nbytes(k); + } } for (auto & v : ctx->kv_self.v_l) { - memory_size_v += ggml_nbytes(v); + if (v) { + memory_size_v += ggml_nbytes(v); + } } if (memory_size_k + memory_size_v > 0) { @@ -4918,7 +5089,7 @@ struct llama_context * llama_init_from_model( } if (params.only_active_experts) { - LLAMA_LOG_INFO("XXXXXXXXXXXXXXXXXXXXX Setting only active experts offload\n"); + LLAMA_LOG_INFO("%s: enabling only_active_experts scheduling\n", __func__); ggml_backend_sched_set_only_active_experts(ctx->sched, true); } if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH && (!model->has_tensor_overrides() || cparams.split_mode_graph_scheduling)) { @@ -5031,6 +5202,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_PHI2: case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: @@ -5586,7 +5758,7 @@ struct llama_data_write { } } - void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { + void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { 
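// Sketch of the qwen3next state section this writer appends after the existing
// K/V blocks (field names are descriptive only; the code below streams the raw
// fields directly): one uint32_t flag, non-zero when any s_l tensor exists,
// then per layer
//
//     int32_t  s_type;       // ggml type of kv_self.s_l[il], or -1 if the layer has no state tensor
//     uint64_t s_size_row;   // bytes per state row (0 when absent)
//     uint32_t s_rows;       // all slots for whole-cache saves, 1 for per-sequence saves, 0 if skipped
//                            // followed by s_rows * s_size_row bytes of raw state data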
const struct llama_kv_cache & kv_self = ctx->kv_self; const struct llama_hparams & hparams = ctx->model.hparams; @@ -5599,23 +5771,30 @@ struct llama_data_write { write(&v_state, sizeof(v_state)); write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; + const bool has_k_cache = kv_self.k_l[il] != nullptr; // Write key type - const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + const int32_t k_type_i = has_k_cache ? (int32_t) kv_self.k_l[il]->type : -1; write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); + const uint64_t k_size_row = has_k_cache + ? ((ctx->cparams.mla_attn == 0) + ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) + : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)) + : 0; write(&k_size_row, sizeof(k_size_row)); + if (!has_k_cache) { + continue; + } + // Read each range of cells of k_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; @@ -5626,16 +5805,21 @@ struct llama_data_write { if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); + const bool has_v_cache = kv_self.v_l[il] != nullptr; // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1; write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + const uint64_t v_size_row = has_v_cache ? ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa) : 0; write(&v_size_row, sizeof(v_size_row)); + if (!has_v_cache) { + continue; + } + // Read each range of cells of v_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; @@ -5648,18 +5832,24 @@ struct llama_data_write { // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); + const bool has_v_cache = kv_self.v_l[il] != nullptr; // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + const int32_t v_type_i = has_v_cache ? (int32_t) kv_self.v_l[il]->type : -1; write(&v_type_i, sizeof(v_type_i)); // Write element size - const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + const uint32_t v_size_el = has_v_cache ? ggml_type_size(kv_self.v_l[il]->type) : 0; write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size - write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + const uint32_t n_embd_v_gqa_write = has_v_cache ? 
n_embd_v_gqa : 0; + write(&n_embd_v_gqa_write, sizeof(n_embd_v_gqa_write)); + + if (!has_v_cache) { + continue; + } // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { @@ -5673,6 +5863,42 @@ struct llama_data_write { } } } + + const uint32_t qnext_state = llama_kv_has_qnext_state_storage(kv_self) ? 1 : 0; + write(&qnext_state, sizeof(qnext_state)); + + if (qnext_state != 0) { + for (uint32_t il = 0; il < n_layer; ++il) { + const bool has_s_cache = il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr; + + const int32_t s_type_i = has_s_cache ? (int32_t) kv_self.s_l[il]->type : -1; + write(&s_type_i, sizeof(s_type_i)); + + const uint64_t s_size_row = has_s_cache ? ggml_row_size(kv_self.s_l[il]->type, kv_self.s_l[il]->ne[0]) : 0; + write(&s_size_row, sizeof(s_size_row)); + + uint32_t s_rows = 0; + size_t s_offset = 0; + if (has_s_cache) { + const uint32_t n_slots = (uint32_t) kv_self.s_l[il]->ne[1]; + if (seq_id == -1) { + s_rows = n_slots; + } else if (llama_kv_qnext_seq_id_in_range(kv_self, seq_id) && (uint32_t) seq_id < kv_self.size) { + llama_seq_id src_seq_id = kv_self.cells[seq_id].src; + if (llama_kv_qnext_seq_id_in_range(kv_self, src_seq_id)) { + s_rows = 1; + s_offset = (size_t) src_seq_id * s_size_row; + } + } + } + + write(&s_rows, sizeof(s_rows)); + + if (has_s_cache && s_rows > 0) { + write_tensor_data(kv_self.s_l[il], s_offset, s_rows * s_size_row, il); + } + } + } } void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { @@ -5711,7 +5937,7 @@ struct llama_data_write { write(&cell_count, sizeof(cell_count)); write_kv_cache_meta(kv_self, cell_ranges, seq_id); - write_kv_cache_data(ctx, cell_ranges); + write_kv_cache_data(ctx, cell_ranges, seq_id); } }; @@ -5922,7 +6148,7 @@ struct llama_data_read { GGML_ASSERT(sum_split_row_size == row_size); } - bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { + bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) { const struct llama_hparams & hparams = ctx->model.hparams; struct llama_kv_cache & kv_self = ctx->kv_self; @@ -5954,20 +6180,35 @@ struct llama_data_read { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; + const bool has_k_cache = kv_self.k_l[il] != nullptr; // Read type of key int32_t k_type_i_ref; read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; + if (!has_k_cache) { + if (k_type_i_ref != -1) { + LLAMA_LOG_ERROR("%s: missing key cache for layer %d\n", __func__, il); + return false; + } + } else { + const int32_t k_type_i = (int32_t) kv_self.k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } } // Read row size of key uint64_t k_size_row_ref; read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + if (!has_k_cache) { + if (k_size_row_ref != 0) { + LLAMA_LOG_ERROR("%s: expected empty key row size for layer %d\n", __func__, il); + return false; + } + continue; + } const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? 
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); @@ -5986,20 +6227,35 @@ struct llama_data_read { if (v_state == 0) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); + const bool has_v_cache = kv_self.v_l[il] != nullptr; // Read type of value int32_t v_type_i_ref; read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; + if (!has_v_cache) { + if (v_type_i_ref != -1) { + LLAMA_LOG_ERROR("%s: missing value cache for layer %d\n", __func__, il); + return false; + } + } else { + const int32_t v_type_i = (int32_t) kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } } // Read row size of value uint64_t v_size_row_ref; read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + if (!has_v_cache) { + if (v_size_row_ref != 0) { + LLAMA_LOG_ERROR("%s: expected empty value row size for layer %d\n", __func__, il); + return false; + } + continue; + } const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); @@ -6019,35 +6275,58 @@ struct llama_data_read { else if (v_state == 1) { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = llama_kv_v_row_embd(ctx->model, hparams, il); + const bool has_v_cache = kv_self.v_l[il] != nullptr; // Read type of value int32_t v_type_i_ref; read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; + if (!has_v_cache) { + if (v_type_i_ref != -1) { + LLAMA_LOG_ERROR("%s: missing transposed value cache for layer %d\n", __func__, il); + return false; + } + } else { + const int32_t v_type_i = (int32_t) kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } } // Read element size of value uint32_t v_size_el_ref; read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; + if (!has_v_cache) { + if (v_size_el_ref != 0) { + LLAMA_LOG_ERROR("%s: expected empty transposed value element size for layer %d\n", __func__, il); + return false; + } + } else { + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + 
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } } // Read GQA embedding size uint32_t n_embd_v_gqa_ref; read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (!has_v_cache) { + if (n_embd_v_gqa_ref != 0) { + LLAMA_LOG_ERROR("%s: expected empty transposed value rows for layer %d\n", __func__, il); + return false; + } + continue; + } if (n_embd_v_gqa != n_embd_v_gqa_ref) { LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); return false; } if (cell_count) { + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); if (kv_self.v_l[il]->extra) { throw std::runtime_error("Transposed V cache is not sypported with split mode 'graph'"); } @@ -6059,6 +6338,76 @@ struct llama_data_read { } } } + + uint32_t qnext_state_ref = 0; + read_to(&qnext_state_ref, sizeof(qnext_state_ref)); + + const bool has_qnext_state = llama_kv_has_qnext_state_storage(kv_self); + if ((qnext_state_ref != 0) != has_qnext_state) { + LLAMA_LOG_ERROR("%s: incompatible qwen3next state cache presence\n", __func__); + return false; + } + + if (qnext_state_ref != 0) { + for (uint32_t il = 0; il < n_layer; ++il) { + const bool has_s_cache = il < kv_self.s_l.size() && kv_self.s_l[il] != nullptr; + + int32_t s_type_i_ref; + read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + if (!has_s_cache) { + if (s_type_i_ref != -1) { + LLAMA_LOG_ERROR("%s: missing qwen3next state cache for layer %d\n", __func__, il); + return false; + } + } else { + const int32_t s_type_i = (int32_t) kv_self.s_l[il]->type; + if (s_type_i != s_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched qwen3next state type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); + return false; + } + } + + uint64_t s_size_row_ref; + read_to(&s_size_row_ref, sizeof(s_size_row_ref)); + + const uint64_t s_size_row = has_s_cache ? ggml_row_size(kv_self.s_l[il]->type, kv_self.s_l[il]->ne[0]) : 0; + if (s_size_row != s_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched qwen3next state row size (%zu != %zu, layer %d)\n", + __func__, (size_t) s_size_row, (size_t) s_size_row_ref, il); + return false; + } + + uint32_t s_rows_ref; + read_to(&s_rows_ref, sizeof(s_rows_ref)); + + uint32_t s_rows = 0; + uint32_t s_dst_row = 0; + if (has_s_cache) { + const uint32_t n_slots = (uint32_t) kv_self.s_l[il]->ne[1]; + if (seq_id == -1) { + s_rows = n_slots; + } else if (llama_kv_qnext_seq_id_in_range(kv_self, seq_id)) { + s_rows = 1; + s_dst_row = (uint32_t) seq_id; + } + } + + if (s_rows_ref != s_rows) { + LLAMA_LOG_ERROR("%s: mismatched qwen3next state row count (%u != %u, layer %d)\n", __func__, s_rows, s_rows_ref, il); + return false; + } + + if (s_rows > 0) { + const size_t s_data_size = s_rows * s_size_row; + const size_t s_dst_offset = (size_t) s_dst_row * s_size_row; + if (kv_self.s_l[il]->extra) { + read_kv_cache_data_split(ctx, kv_self.s_l[il], read(s_data_size), s_dst_row, s_size_row, s_rows, il); + } else { + ggml_backend_tensor_set(kv_self.s_l[il], read(s_data_size), s_dst_offset, s_data_size); + } + } + } + } return true; } @@ -6066,7 +6415,7 @@ struct llama_data_read { uint32_t cell_count; read_to(&cell_count, sizeof(cell_count)); - bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count); + bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count, seq_id); if (!res) { if (seq_id == -1) {