From c77ec4b8b80ada495c40d888cc52e40c14bac547 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 25 Feb 2026 14:12:48 +0100 Subject: [PATCH] Fused delta-net (#1315) * Revive fused delta-net * Add command line argument for fused delta net * Simplify/improve CUDA delta-net * Add -fdn to llama-bench * More CUDA fused delta net optimizations * CPU optimizations * Much faster fused delta-net on the CPU It seems it is faster than the chunked implementation! * Change meaning of fdn from bool flag to threshold value * Use eps = 1e-6 * Give some nodes a name --- common/common.cpp | 8 + common/common.h | 1 + examples/llama-bench/llama-bench.cpp | 34 +- ggml/include/ggml.h | 10 + ggml/src/ggml-cuda.cu | 6 + ggml/src/ggml-cuda/delta-net.cu | 485 +++++++++++++++++++++++++++ ggml/src/ggml-cuda/delta-net.cuh | 3 + ggml/src/ggml.c | 200 ++++++++++- ggml/src/iqk/iqk_mul_mat.cpp | 156 +++++++++ ggml/src/iqk/iqk_mul_mat.h | 4 + include/llama.h | 1 + src/llama-cparams.h | 1 + src/llama-delta-net.cpp | 94 +++++- src/llama-delta-net.h | 5 + src/llama.cpp | 7 +- 15 files changed, 1002 insertions(+), 13 deletions(-) create mode 100644 ggml/src/ggml-cuda/delta-net.cu create mode 100644 ggml/src/ggml-cuda/delta-net.cuh diff --git a/common/common.cpp b/common/common.cpp index f0a4d3dd..6486c097 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1531,6 +1531,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.scheduler_async = true; return true; } + if (arg == "-fdn" || arg == "--fused-delta-net") { + CHECK_ARG + params.fused_delta_net = std::stoi(argv[i]); + return true; + } if (arg == "-smf16" || arg == "--split-mode-f16") { params.reduce_type = "f16"; //params.split_mode_f16 = true; @@ -2258,6 +2263,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-grt, --graph-reduce-type", "Type for data exchange between GPUs (default: %s)", "f32"}); options.push_back({ "*", "-smgs, 
--split-mode-graph-scheduling,", "Force Split Mode Graph Scheduling (default: %d)", params.split_mode_graph_scheduling}); options.push_back({ "*", "-sas, --scheduler_async,", "Async evaluation of compute graphs: %d)", params.scheduler_async}); + options.push_back({ "*", "-fdn, --fused-delta-net N", "Use fused delta-net when batch size is <= N with recurrent models: %d)", params.fused_delta_net}); options.push_back({ "*", "-vq, --validate-quants", "validate quantized data while loading the model (default: %d)", params.validate_quants}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -3336,6 +3342,7 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling; //cparams.split_mode_f16 = params.split_mode_f16; cparams.scheduler_async = params.scheduler_async; + cparams.fused_delta_net = params.fused_delta_net; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.only_active_experts = params.only_active_exps; @@ -4346,6 +4353,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l //fprintf(stream, "split_mode_f16: %s # default: true\n", params.split_mode_f16 ? "true" : "false"); fprintf(stream, "reduce_type: %s # default f16\n", params.reduce_type.c_str()); fprintf(stream, "scheduler_async: %s # default: false\n", params.scheduler_async ? 
"true" : "false"); + fprintf(stream, "fused_delta_net: %d # default: 0\n", params.fused_delta_net ); fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 26cd520e..c14d2e22 100644 --- a/common/common.h +++ b/common/common.h @@ -357,6 +357,7 @@ struct gpt_params { bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling //bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph + int fused_delta_net = 0; // use fused delta-net if number of tokens in the batch is less than this value bool has_mtp = false; // enable MTP if supported by the model std::string cache_type_k = "f16"; // KV cache data type for the K diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index a4420fd6..aeb78f89 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -271,6 +271,7 @@ struct cmd_params { bool muge = false; bool rcache = false; bool sas = false; + int fdn = 0; // fdn = fused delta net bool print_overrides = false; output_formats output_format; output_formats output_format_stderr; @@ -316,6 +317,7 @@ static const cmd_params cmd_params_defaults = { /* muge */ false, /* rcache */ false, /* sas */ false, + /* fdn */ 0, /* print_overrides */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, @@ -369,6 +371,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0"); printf(" -no-ooae, --no-offload-only-active-experts <0|1> (default: %s)\n", cmd_params_defaults.no_ooae? 
"1" : "0"); printf(" -sas, --scheduler-async <0|1> (default: %s)\n", cmd_params_defaults.sas ? "1" : "0"); + printf(" -fdn, --fused-delta-net (default: %d)\n", cmd_params_defaults.fdn); printf(" --print-overrides <0|1> (default: %s)\n", cmd_params_defaults.print_overrides ? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); @@ -810,6 +813,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.sas = std::stoi(argv[i]); + } else if (arg == "-fdn" || arg == "--fused-delta-net") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.fdn = std::stoi(argv[i]); } else if (arg == "-rcache" || arg == "--rope-cache") { if (++i >= argc) { invalid_param = true; @@ -956,6 +965,7 @@ struct cmd_params_instance { bool muge = false; bool rcache = false; bool sas = false; + int fdn = 0; const llama_model_tensor_buft_override* buft_overrides; llama_model_params to_llama_mparams() const { @@ -990,6 +1000,8 @@ struct cmd_params_instance { mqkv == other.mqkv && muge == other.muge && use_thp == other.use_thp && + sas == other.sas && + fdn == other.fdn && tensor_split == other.tensor_split; } @@ -1016,6 +1028,7 @@ struct cmd_params_instance { cparams.embeddings = embeddings; cparams.cuda_params = (void *)cuda_params.data(); cparams.scheduler_async = sas; + cparams.fused_delta_net = fdn; return cparams; } @@ -1082,6 +1095,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1125,6 +1139,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; 
instances.push_back(instance); @@ -1168,6 +1183,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1211,6 +1227,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1265,6 +1282,7 @@ struct test { bool muge = false; bool rcache = false; bool sas = false; + int fdn = 0; std::string override_tensor; int n_prompt; int n_gen; @@ -1306,6 +1324,7 @@ struct test { ger = inst.ger; rcache = inst.rcache; sas = inst.sas; + fdn = inst.fdn; no_fug = inst.no_fug; use_thp = inst.use_thp; no_ooae = inst.no_ooae; @@ -1410,7 +1429,7 @@ struct test { field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" || - field == "avg_ns" || field == "stddev_ns") { + field == "avg_ns" || field == "stddev_ns" || field == "fdn") { return INT; } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || @@ -1461,7 +1480,7 @@ struct test { std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(mqkv), std::to_string(muge), std::to_string(fmoe), std::to_string(ger), - std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache), std::to_string(sas), + std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache), std::to_string(sas), std::to_string(fdn), cuda_params, override_tensor, 
std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1482,7 +1501,7 @@ struct test { "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse", "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "muge", "fused_moe", "grouped_er", - "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "sas", "cuda_params", "override_tensor", + "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "sas", "fdn", "cuda_params", "override_tensor", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1672,6 +1691,9 @@ struct markdown_printer : public printer { if (field == "sas") { return 3; } + if (field == "fdn") { + return 4; + } if (field == "use_thp") { return 3; } @@ -1745,6 +1767,9 @@ struct markdown_printer : public printer { if (field == "sas") { return "sas"; } + if (field == "fdn") { + return "fdn"; + } if (field == "use_thp") { return "thp"; } @@ -1855,6 +1880,9 @@ struct markdown_printer : public printer { if (params.sas != cmd_params_defaults.sas) { fields.emplace_back("sas"); } + if (params.fdn != cmd_params_defaults.fdn) { + fields.emplace_back("fdn"); + } if (params.muge != cmd_params_defaults.muge) { fields.emplace_back("muge"); } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index cb447eaa..2056b6a9 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -678,6 +678,7 @@ extern "C" { GGML_OP_TRI, GGML_OP_FILL, GGML_OP_SOLVE_TRI, + GGML_OP_DELTA_NET, GGML_OP_MAP_UNARY, GGML_OP_MAP_BINARY, @@ -2508,6 +2509,15 @@ extern "C" { bool lower, bool uni); + GGML_API struct ggml_tensor * ggml_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); diff --git 
a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 6ab058ac..93a71c15 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -55,6 +55,7 @@ #include "ggml-cuda/hadamard.cuh" #include "ggml-cuda/reduce.cuh" #include "ggml-cuda/tri.cuh" +#include "ggml-cuda/delta-net.cuh" #include #include @@ -3698,6 +3699,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOLVE_TRI: ggml_cuda_op_solve_tri(ctx, dst); break; + case GGML_OP_DELTA_NET: + ggml_cuda_op_delta_net(ctx, dst); + break; case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; @@ -4557,6 +4561,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons op->src[2]->ne[1] == op->src[0]->ne[1] && op->src[1]->ne[0] == op->src[0]->ne[1] && op->src[3]->ne[0] == op->src[0]->ne[2]; + case GGML_OP_DELTA_NET: + return true; case GGML_OP_FLASH_ATTN_EXT: #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128; diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu new file mode 100644 index 00000000..367ae67d --- /dev/null +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -0,0 +1,485 @@ +#include "common.cuh" +#include "delta-net.cuh" +#include +#include + +// Delta Net Linear Attention Kernel for Qwen3-Next (HEAD_DIM=128) +// State layout: [S_v, S_v*H_v, 1, n_seqs] (column-major) + +__device__ __forceinline__ float sigmoid_f(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +template +__device__ __forceinline__ float reduce_sum(float x, float * s) { + x = warp_reduce_sum(x); + if constexpr (block_size > WARP_SIZE) { + //__shared__ float s[block_size/WARP_SIZE]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s[warp_id] = x; + } + __syncthreads(); + x = lane_id < block_size/WARP_SIZE ? 
s[lane_id] : 0.0f; + x = warp_reduce_sum(x); + } + return x; +} + +template +__global__ void delta_net_recurrent_f32( + const float * __restrict__ q, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ k, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ v, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ g, // [n_tokens, 1, n_heads, n_seqs] + const float * __restrict__ beta_in, // [1, n_tokens, n_heads, n_seqs] + const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] + float * __restrict__ dst, // output + new_state concatenated + const int64_t n_heads, + const int64_t n_tokens, + const int64_t n_seqs, + const int64_t output_offset, // offset where state starts in output + const float eps) { + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads + const int lane_id = tid % WARP_SIZE; // 0-31 + constexpr int NUM_WARPS = block_size/WARP_SIZE; + + // Strides for input tensors (column-major) + // Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs] + const int64_t qkv_stride_token = HEAD_DIM; + const int64_t qkv_stride_head = HEAD_DIM * n_tokens; + const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; + + // G/Beta: [n_tokens, 1, n_heads, n_seqs] / [1, n_tokens, n_heads, n_seqs] + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + + // State: [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] + // For head h: columns h*HEAD_DIM to (h+1)*HEAD_DIM + // state[row, col] for head h = state[row, h*HEAD_DIM + col] + // Linear index: row + (h*HEAD_DIM + col) * HEAD_DIM = row + h*HEAD_DIM^2 + col*HEAD_DIM + const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; + const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + + // Pointers for this batch/head + const float * q_ptr = q + batch_idx * qkv_stride_batch + 
head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + + // Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs] + // For [dim, head, token, batch]: index = dim + head*S_v + token*S_v*H_v + batch*S_v*H_v*n_tokens + float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; + const int64_t out_token_stride = HEAD_DIM * n_heads; // stride between tokens + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory for current token's Q, K, V (normalized), and intermediate results + extern __shared__ float smem[]; + float * sQ = smem; // HEAD_DIM + float * sK = sQ + HEAD_DIM; // HEAD_DIM + float * sV = sK + HEAD_DIM; // HEAD_DIM + float * sKBeta = sV + HEAD_DIM; // HEAD_DIM (plain k for state update) + float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta)) + float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM + float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g)) + float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime) + + const float scale = rsqrtf((float)HEAD_DIM); + + __shared__ float sum_helper[block_size/WARP_SIZE]; + + // Copy initial state to output buffer (will be updated in place) + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + state_dst[i] = state_src[i]; + } + __syncthreads(); + + // Process each token sequentially + for (int64_t t = 0; t < n_tokens; t++) { + + float q_sq = 0.0f; + float k_sq = 0.0f; + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sQ[i] = q_ptr[t * qkv_stride_token 
+ i]; + sK[i] = k_ptr[t * qkv_stride_token + i]; + sV[i] = v_ptr[t * qkv_stride_token + i]; + q_sq += sQ[i] * sQ[i]; + k_sq += sK[i] * sK[i]; + } + + q_sq = reduce_sum(q_sq, sum_helper); + k_sq = reduce_sum(k_sq, sum_helper); + + float q_norm = rsqrtf(q_sq + eps); + float k_norm = rsqrtf(k_sq + eps); + + float beta_val = sigmoid_f(beta_ptr[t]); + float decay = expf(fminf(g_ptr[t], 50.0f)); + + float sum = 0; + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sQ[i] = sQ[i] * q_norm * scale; + sK[i] = sK[i] * k_norm; + sKBeta[i] = sK[i]; + sVBeta[i] = sV[i] * beta_val; + sKCumdecay[i] = sK[i] * beta_val * decay; + sum += sK[i] * sQ[i]; + } + float attn_score = reduce_sum(sum, sum_helper); + //__syncthreads(); + + for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { + float sum1 = 0.0f; + float sum2 = 0.0f; + #pragma unroll + for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { + float sval = state_dst[row_out + col * HEAD_DIM]; + sum1 += sval * sKCumdecay[col]; + sum2 += sval * sQ[col]; + } + sum1 = warp_reduce_sum(sum1); + sum2 = warp_reduce_sum(sum2); + if (lane_id == 0) { + sVNew[row_out] = sVBeta[row_out] - sum1; + float v_attn = sVNew[row_out] * attn_score; + //sOut[row_out] = sum2 * decay + v_attn; + out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn; + } + } + __syncthreads(); + + for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) { + #pragma unroll + for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { + float state_val = state_dst[row + out_dim * HEAD_DIM]; + float safe_decay = decay; + if (isnan(safe_decay) || isinf(safe_decay)) { + safe_decay = 1.0f; + } + float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim]; + new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); + state_dst[row + out_dim * HEAD_DIM] = new_state_val; + } + } + if (t < n_tokens - 1) { + __syncthreads(); + } + + } +} + +// Generic kernel that handles any HEAD_DIM at runtime (slower but flexible) 
+__global__ void delta_net_recurrent_generic_f32( + const float * __restrict__ q, + const float * __restrict__ k, + const float * __restrict__ v, + const float * __restrict__ g, + const float * __restrict__ beta_in, + const float * __restrict__ state_in, + float * __restrict__ dst, + const int64_t head_dim, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const int64_t output_offset, + const float eps) { + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + + // Strides (column-major) + const int64_t qkv_stride_token = head_dim; + const int64_t qkv_stride_head = head_dim * n_tokens; + const int64_t qkv_stride_batch = head_dim * n_tokens * n_heads; + + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + + const int64_t state_head_offset = head_idx * head_dim * head_dim; + const int64_t state_batch_stride = head_dim * head_dim * n_heads; + + // Pointers + const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + + // Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs] + float * out_base = dst + batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim; + const int64_t out_token_stride = head_dim * n_heads; + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory for scalars (outside loop) + __shared__ float shared_g_val, shared_beta_val, shared_decay, shared_attn_score; + + // Dynamic shared 
memory + extern __shared__ float smem[]; + float * sQ = smem; + float * sK = sQ + head_dim; + float * sV = sK + head_dim; + float * sKBeta = sV + head_dim; // plain k for state update + float * sVBeta = sKBeta + head_dim; // v * sigmoid(beta) + float * sOut = sVBeta + head_dim; + float * sKCumdecay = sOut + head_dim; // k * sigmoid(beta) * exp(g) + float * sVPrime = sKCumdecay + head_dim; // state @ k_cumdecay + float * sVNew = sVPrime + head_dim; // v_beta - v_prime + float * sNorm = sVNew + head_dim; + + const float scale = rsqrtf((float)head_dim); + + // Copy initial state to output buffer + for (int i = tid; i < head_dim * head_dim; i += blockDim.x) { + int col = i / head_dim; + int row = i % head_dim; + state_dst[row + col * head_dim] = state_src[row + col * head_dim]; + } + __syncthreads(); + + // Process each token + for (int64_t t = 0; t < n_tokens; t++) { + if (tid < 2) sNorm[tid] = 0.0f; + __syncthreads(); + + // Load Q, K, V + for (int i = tid; i < head_dim; i += blockDim.x) { + sQ[i] = q_ptr[t * qkv_stride_token + i]; + sK[i] = k_ptr[t * qkv_stride_token + i]; + sV[i] = v_ptr[t * qkv_stride_token + i]; + } + __syncthreads(); + + // L2 normalize Q and K + float q_sq = 0.0f, k_sq = 0.0f; + for (int i = tid; i < head_dim; i += blockDim.x) { + q_sq += sQ[i] * sQ[i]; + k_sq += sK[i] * sK[i]; + } + + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq += __shfl_xor_sync(0xffffffff, q_sq, offset); + k_sq += __shfl_xor_sync(0xffffffff, k_sq, offset); + } + + if (tid % WARP_SIZE == 0) { + atomicAdd(&sNorm[0], q_sq); + atomicAdd(&sNorm[1], k_sq); + } + __syncthreads(); + + float q_norm = rsqrtf(sNorm[0] + eps); + float k_norm = rsqrtf(sNorm[1] + eps); + + for (int i = tid; i < head_dim; i += blockDim.x) { + sQ[i] *= q_norm * scale; + sK[i] *= k_norm; + } + __syncthreads(); + + // Load g and beta, compute decay + if (tid == 0) { + shared_g_val = g_ptr[t]; + shared_beta_val = sigmoid_f(beta_ptr[t]); + shared_decay = 
expf(fminf(shared_g_val, 50.0f)); + } + __syncthreads(); + + float beta_val = shared_beta_val; + float decay = shared_decay; + + // Compute k_beta, v_beta, k_cumdecay + for (int i = tid; i < head_dim; i += blockDim.x) { + sKBeta[i] = sK[i]; + sVBeta[i] = sV[i] * beta_val; + sKCumdecay[i] = sK[i] * beta_val * decay; + } + __syncthreads(); + + // Compute v_prime = state @ k_cumdecay + for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) { + float v_prime_val = 0.0f; + for (int col = 0; col < head_dim; col++) { + // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ k + v_prime_val += state_dst[row_out + col * head_dim] * sKCumdecay[col]; + } + sVPrime[row_out] = v_prime_val; + } + __syncthreads(); + + // Compute v_new = v_beta - v_prime (the value residual) + for (int i = tid; i < head_dim; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + // Compute attn_score = dot(k, q) (L2 normalized vectors) + if (tid == 0) { + float dot_sum = 0.0f; + for (int i = 0; i < head_dim; i++) { + dot_sum += sK[i] * sQ[i]; + } + shared_attn_score = dot_sum; + } + __syncthreads(); + + // Compute output: o[t] = attn_inter + v_attn + // attn_inter = state @ (q * exp(g)) = sum_col(state[row_out, col] * q[col] * exp(g)) + // The decomposed path uses: attn_inter = ggml_mul_mat(state_t, q_g_exp) + // Since ggml_mul_mat(A,B) = A^T @ B, attn_inter = state_t^T @ q_g_exp = state @ (q * exp(g)) + for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) { + float attn_inter = 0.0f; + + for (int col = 0; col < head_dim; col++) { + // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ q + float state_val = state_dst[row_out + col * head_dim]; + attn_inter += sQ[col] * decay * state_val; + } + + // v_attn = v_new * attn_score + float v_attn = sVNew[row_out] * shared_attn_score; + + // Output = attn_inter + v_attn (correct DeltaNet formula) + sOut[row_out] = attn_inter + v_attn; + } + 
__syncthreads(); + + // Update state: state_new = decay * state + outer(v_new, k) + // Fixed: outer product orientation matches decomposed: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] + // Uses transposed indexing: state_dst[row + out_dim * head_dim] = state[row][out_dim] + // Only protect against NaN/Inf - do NOT clamp decay value + float safe_decay = decay; + if (isnan(safe_decay) || isinf(safe_decay)) { + safe_decay = 1.0f; + } + + for (int out_dim = tid; out_dim < head_dim; out_dim += blockDim.x) { + for (int row = 0; row < head_dim; row++) { + float state_val = state_dst[row + out_dim * head_dim]; + + // state_new[row][out_dim] = decay * state[row][out_dim] + v_new[row] * k[out_dim] + // Fix: outer product matches decomposed path: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] + float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim]; + + // Clamp state to prevent overflow + new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); + state_dst[row + out_dim * head_dim] = new_state_val; + } + } + __syncthreads(); + + // Write output + for (int i = tid; i < head_dim; i += blockDim.x) { + out_base[t * out_token_stride + i] = sOut[i]; + } + __syncthreads(); + } +} + +static void delta_net_f32_cuda( + const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * state_in, + float * dst, + const int64_t head_dim, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const float eps, + const int device_id, + const int cc, // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0) + cudaStream_t stream) { + GGML_UNUSED(device_id); + GGML_UNUSED(cc); + + const int64_t output_offset = head_dim * n_tokens * n_heads * n_seqs; + + // One block per (batch, head) pair + const int num_blocks = n_seqs * n_heads; + constexpr int threads_per_block = 512; //256; + + // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew) + // Plus 6 floats for 
Norm[2], g_val, beta_val, decay, attn_score + const size_t smem_size = (9 * head_dim + 6) * sizeof(float); + + // Use templated kernel for common head dimensions, generic for others + if (head_dim == 64) { + delta_net_recurrent_f32<64, threads_per_block><<>>( + q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); + } else if (head_dim == 128) { + GGML_ASSERT(num_blocks % 8 == 0); + delta_net_recurrent_f32<128, threads_per_block><<>>( + q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); + } else { + delta_net_recurrent_generic_f32<<>>( + q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps); + } + + CUDA_CHECK(cudaGetLastError()); + +} + +void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // q + const ggml_tensor * src1 = dst->src[1]; // k + const ggml_tensor * src2 = dst->src[2]; // v + const ggml_tensor * src3 = dst->src[3]; // g + const ggml_tensor * src4 = dst->src[4]; // beta + const ggml_tensor * src5 = dst->src[5]; // state + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t head_dim = src0->ne[0]; + const int64_t n_tokens = src0->ne[1]; + const int64_t n_heads = src0->ne[2]; + const int64_t n_seqs = src0->ne[3]; + + // Dimension validation + // Q/K: [head_dim, n_tokens, n_heads, n_seqs] + GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads && src1->ne[3] == n_seqs); + // V: [head_dim, n_tokens, n_heads, n_seqs] + GGML_ASSERT(src2->ne[0] == head_dim && src2->ne[1] == n_tokens && src2->ne[2] == n_heads && src2->ne[3] == n_seqs); + // G: [n_tokens, 1, n_heads, n_seqs] + GGML_ASSERT(src3->ne[0] == n_tokens && src3->ne[1] == 1 && src3->ne[2] == n_heads && src3->ne[3] == n_seqs); + // Beta: [1, n_tokens, n_heads, n_seqs] + GGML_ASSERT(src4->ne[0] == 1 && src4->ne[1] == n_tokens && src4->ne[2] == n_heads && 
src4->ne[3] == n_seqs); + // State: [head_dim, head_dim*n_heads, 1, n_seqs] + GGML_ASSERT(src5->ne[0] == head_dim && src5->ne[1] == head_dim * n_heads && src5->ne[2] == 1 && src5->ne[3] == n_seqs); + + // Verify output tensor size + const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs; + const int64_t state_size = head_dim * head_dim * n_heads * n_seqs; + GGML_ASSERT(ggml_nelements(dst) == output_size + state_size); + + const float eps = 1e-6f; + + GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory + + // Get device info from ctx (avoids calling CUDA runtime APIs inside dispatch) + const int device_id = ctx.device; + const int cc = ggml_cuda_info().devices[device_id].cc; + + delta_net_f32_cuda( + (const float *)src0->data, + (const float *)src1->data, + (const float *)src2->data, + (const float *)src3->data, + (const float *)src4->data, + (const float *)src5->data, + (float *)dst->data, + head_dim, n_tokens, n_heads, n_seqs, eps, + device_id, cc, + ctx.stream()); + +} diff --git a/ggml/src/ggml-cuda/delta-net.cuh b/ggml/src/ggml-cuda/delta-net.cuh new file mode 100644 index 00000000..a9b223c6 --- /dev/null +++ b/ggml/src/ggml-cuda/delta-net.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 03e9e3e1..059b589c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4277,6 +4277,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TRI", "FILL", "SOLVE_TRI", + "DELTA_NET", "MAP_UNARY", "MAP_BINARY", @@ -4299,7 +4300,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FUSED_NORM", }; -static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100"); +static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4395,6 +4396,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "tri(x)", "fill(x)", "solve_tri(x)", + "delta_net", "f(x)", 
"f(x,y)", @@ -4417,7 +4419,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "norm(x,y)", }; -static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100"); +static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -9869,6 +9871,59 @@ struct ggml_tensor * ggml_tri( return result; } +struct ggml_tensor * ggml_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + GGML_ASSERT(q->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + const int64_t S_k = q->ne[0]; + const int64_t n_tokens = q->ne[1]; + const int64_t H_k = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[2]; + + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs); + GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_k && g->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + GGML_ASSERT(H_k == H_v); + + const int64_t output_size = S_v * H_v * n_tokens * n_seqs; + const int64_t state_size = S_v * S_v * H_v * n_seqs; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + 
state_size); + + result->op = GGML_OP_DELTA_NET; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = g; + result->src[4] = beta; + result->src[5] = state; + + return result; +} + // ggml_fill static struct ggml_tensor * ggml_fill_impl( @@ -22476,6 +22531,141 @@ static void ggml_compute_forward_solve_tri(const struct ggml_compute_params * pa } } +// ggml_compute_forward_delta_net + +static void ggml_compute_forward_delta_net_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; + const struct ggml_tensor * src3 = dst->src[3]; + const struct ggml_tensor * src4 = dst->src[4]; + const struct ggml_tensor * src5 = dst->src[5]; + + const int64_t head_dim = src0->ne[0]; + const int64_t n_tokens = src0->ne[1]; + const int64_t n_heads = src0->ne[2]; + const int64_t n_seqs = src0->ne[3]; + + const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs; + + const float * q_data = (const float *) src0->data; + const float * k_data = (const float *) src1->data; + const float * v_data = (const float *) src2->data; + const float * g_data = (const float *) src3->data; + const float * beta_data = (const float *) src4->data; + const float * state_in = (const float *) src5->data; + float * out_data = (float *) dst->data; + float * state_out = out_data + output_size; + + const int ith = params->ith; + const int nth = params->nth; + + if (iqk_fused_delta_net(head_dim, n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, + out_data, state_out, ith, nth)) { + return; + } + + const int64_t total_heads = n_heads * n_seqs; + const int64_t heads_per_thread = (total_heads + nth - 1) / nth; + const int64_t h_start = ith * heads_per_thread; + const int64_t h_end = (h_start + heads_per_thread < total_heads) ? 
h_start + heads_per_thread : total_heads; + + const float eps = 1e-6f; + const float scale = 1.0f / sqrtf((float) head_dim); + + float * v_new_buf = (float *) malloc(head_dim * sizeof(float)); + GGML_ASSERT(v_new_buf); + + for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) { + const int64_t batch_idx = h_idx / n_heads; + const int64_t head_idx = h_idx % n_heads; + + const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int64_t qkv_token_stride = head_dim; + const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens; + const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim); + const int64_t out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim; + const int64_t out_token_stride = head_dim * n_heads; + + for (int64_t i = 0; i < head_dim * head_dim; ++i) { + state_out[state_head_offset + i] = state_in[state_head_offset + i]; + } + + float * state = state_out + state_head_offset; + + for (int64_t t = 0; t < n_tokens; ++t) { + const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride; + const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride; + const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride; + + const float g_val = g_data[g_head_offset + t]; + const float beta_raw = beta_data[g_head_offset + t]; + + float q_norm_sq = 0.0f; + float k_norm_sq = 0.0f; + for (int64_t i = 0; i < head_dim; ++i) { + q_norm_sq += q_t[i] * q_t[i]; + k_norm_sq += k_t[i] * k_t[i]; + } + const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps); + const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps); + + const float beta_val = 1.0f / (1.0f + expf(-beta_raw)); + const float decay = expf(fminf(g_val, 50.0f)); + + float attn_score = 0.0f; + for (int64_t i = 0; i < head_dim; ++i) { + attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale); + } + + float * out_t = 
out_data + out_head_offset + t * out_token_stride; + + for (int64_t row = 0; row < head_dim; ++row) { + float v_prime = 0.0f; + float out_val = 0.0f; + + for (int64_t col = 0; col < head_dim; ++col) { + const float k_col = k_t[col]; + const float q_col = q_t[col]; + const float s = state[row + col * head_dim]; + + v_prime += s * k_col; + out_val += s * q_col; + } + + const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv; + v_new_buf[row] = v_new; + out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score; + } + + for (int64_t col = 0; col < head_dim; ++col) { + const float k_col = k_t[col] * k_norm_inv; + for (int64_t row = 0; row < head_dim; ++row) { + float s = state[row + col * head_dim]; + s = decay * s + v_new_buf[row] * k_col; + state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f); + } + } + } + } + + free(v_new_buf); +} + +static void ggml_compute_forward_delta_net( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: + ggml_compute_forward_delta_net_f32(params, dst); + break; + default: + GGML_ABORT("fatal error"); + } +} + // ggml_compute_forward_win_part static void ggml_compute_forward_win_part_f32( @@ -24202,6 +24392,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml { ggml_compute_forward_solve_tri(params, tensor); } break; + case GGML_OP_DELTA_NET: + { + ggml_compute_forward_delta_net(params, tensor); + } break; case GGML_OP_WIN_PART: { ggml_compute_forward_win_part(params, tensor); @@ -25260,6 +25454,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_TRI: case GGML_OP_FILL: case GGML_OP_SOLVE_TRI: + case GGML_OP_DELTA_NET: { GGML_ABORT("fatal error"); // TODO: not implemented } @@ -25990,6 +26185,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FUSED_UP_GATE: case GGML_OP_OUT_PROD: case GGML_OP_SOLVE_TRI: + 
case GGML_OP_DELTA_NET: { n_tasks = n_threads; } break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index c580edbf..956b33e0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1383,6 +1383,155 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k } #endif +namespace { +template +void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, + const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data, + const float * state_in, float * out_data, float * state_out, int ith, int nth) { + const int total_heads = n_heads * n_seqs; + const int heads_per_thread = (total_heads + nth - 1) / nth; + const int h_start = ith * heads_per_thread; + const int h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads; + +#ifdef __AVX2__ + static_assert(head_dim % 8 == 0); +#endif + + const float eps = 1e-6f; + const float scale = 1.0f / sqrtf((float) head_dim); + + float v_new_buf[head_dim]; + float v_prime[head_dim], out_val[head_dim]; + + for (int h_idx = h_start; h_idx < h_end; ++h_idx) { + const int batch_idx = h_idx / n_heads; + const int head_idx = h_idx % n_heads; + + const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int qkv_token_stride = head_dim; + const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens; + const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim); + const int out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim; + const int out_token_stride = head_dim * n_heads; + + for (int i = 0; i < head_dim * head_dim; ++i) { + state_out[state_head_offset + i] = state_in[state_head_offset + i]; + } + + float * state = state_out + state_head_offset; + + for (int t = 0; t < n_tokens; ++t) { + const float * q_t = q_data + qkv_head_offset + t * 
qkv_token_stride; + const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride; + const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride; + + const float g_val = g_data[g_head_offset + t]; + const float beta_raw = beta_data[g_head_offset + t]; + + float q_norm_sq = 0.0f; + float k_norm_sq = 0.0f; + float kq_sum = 0.0f; +#ifdef __AVX2__ + auto vqsum = _mm256_setzero_ps(); + auto vksum = _mm256_setzero_ps(); + auto vqksum = _mm256_setzero_ps(); + for (int i = 0; i < head_dim; i += 8) { + auto vq = _mm256_loadu_ps(q_t + i); + auto vk = _mm256_loadu_ps(k_t + i); + vqsum = _mm256_fmadd_ps(vq, vq, vqsum); + vksum = _mm256_fmadd_ps(vk, vk, vksum); + vqksum = _mm256_fmadd_ps(vk, vq, vqksum); + } + q_norm_sq = hsum_float_8(vqsum); + k_norm_sq = hsum_float_8(vksum); + kq_sum = hsum_float_8(vqksum); +#else + for (int i = 0; i < head_dim; ++i) { + q_norm_sq += q_t[i] * q_t[i]; + k_norm_sq += k_t[i] * k_t[i]; + kq_sum += k_t[i] * q_t[i]; + } +#endif + const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps); + const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps); + + const float beta_val = 1.0f / (1.0f + expf(-beta_raw)); + const float decay = expf(fminf(g_val, 50.0f)); + + float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale; + + //float attn_score = 0.0f; + //for (int i = 0; i < head_dim; ++i) { + // attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale); + //} + + float * out_t = out_data + out_head_offset + t * out_token_stride; + + std::memset(v_prime, 0, head_dim*sizeof(float)); + std::memset(out_val, 0, head_dim*sizeof(float)); + for (int col = 0; col < head_dim; ++col) { + const float k_col = k_t[col]; + const float q_col = q_t[col]; + for (int row = 0; row < head_dim; ++row) { + const float s = state[row + col * head_dim]; + v_prime[row] += s * k_col; + out_val[row] += s * q_col; + } + } + for (int row = 0; row < head_dim; ++row) { + const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv; + 
v_new_buf[row] = v_new; + out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score; + } + +#ifdef __AVX2__ + auto vd = _mm256_set1_ps(decay); + auto vmin = _mm256_set1_ps(-1e6f); + auto vmax = _mm256_set1_ps( 1e6f); + for (int col = 0; col < head_dim; ++col) { + auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv); + for (int row = 0; row < head_dim; row += 8) { + auto vs = _mm256_loadu_ps(state + col * head_dim + row); + auto vn = _mm256_loadu_ps(v_new_buf + row); + vs = _mm256_fmadd_ps(vn, vk, _mm256_mul_ps(vs, vd)); + auto mask_l = _mm256_cmp_ps(vs, vmin, _CMP_LT_OQ); + auto mask_u = _mm256_cmp_ps(vs, vmax, _CMP_GT_OQ); + vs = _mm256_or_ps(_mm256_and_ps(mask_l, vmin), _mm256_andnot_ps(mask_l, vs)); + vs = _mm256_or_ps(_mm256_and_ps(mask_u, vmax), _mm256_andnot_ps(mask_u, vs)); + _mm256_storeu_ps(state + col * head_dim + row, vs); + } + } +#else + for (int col = 0; col < head_dim; ++col) { + const float k_col = k_t[col] * k_norm_inv; + for (int row = 0; row < head_dim; ++row) { + float s = state[row + col * head_dim]; + s = decay * s + v_new_buf[row] * k_col; + state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f); + } + } +#endif + } + } +} +} + +bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs, + const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data, + const float * state_in, float * out_data, float * state_out, int ith, int nth) { + if (head_dim != 64 && head_dim != 128) { + return false; + } + if (head_dim == 64) { + iqk_fused_delta_net_impl<64>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, + out_data, state_out, ith, nth); + } else { + iqk_fused_delta_net_impl<128>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, + out_data, state_out, ith, nth); + } + return true; +} + #else // IQK_IMPLEMENT #include "ggml-impl.h" @@ -1416,4 +1565,11 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long 
/*Nx*/, long /*Ny*/, long /*n return false; } +bool iqk_fused_delta_net(int, int, int, int, + const float *, const float *, const float *, const float *, const float *, + const float *, float *, float *, int, int) { + return false; +} + + #endif diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 904b55ae..440bc815 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -73,6 +73,10 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits, float * weights, int32_t * ids, int ith, int nth); +IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs, + const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data, + const float * state_in, float * out_data, float * state_out, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/include/llama.h b/include/llama.h index d73440fa..aec65af3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -456,6 +456,7 @@ extern "C" { bool split_mode_graph_scheduling; // if true, force split mode graph scheduling //bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads + int fused_delta_net; bool mtp; // Activate MTP if supported enum llama_mtp_op_type mtp_op_type; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index b178059f..6ac0a3a3 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -43,6 +43,7 @@ struct llama_cparams { bool split_mode_graph_scheduling; //bool split_mode_f16; bool scheduler_async; + int fused_delta_net; int min_experts; float thresh_experts; bool mtp; diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp index 6826e11c..5aa6d1c7 100644 --- a/src/llama-delta-net.cpp +++ b/src/llama-delta-net.cpp 
@@ -372,6 +372,83 @@ std::pair delta_net::build_delta_net_autoregressiv return {core_attn_out, state}; } +std::pair delta_net::build_fused_delta_net(ggml_context * ctx0, + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + int il, const llm_build_cb & cb) { + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); + GGML_ASSERT(H_k == H_v); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + cb(state,"state_in", il); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 0, 2, 1, 3); + g = ggml_permute(ctx0, g, 2, 0, 3, 1); + beta = ggml_permute(ctx0, beta, 2, 0, 1, 3); + if (n_seqs > 1 || n_tokens > 1) { + q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs); + k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs); + v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_k, n_seqs); + beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_k, n_seqs); + } + + ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs); + if (!ggml_is_contiguous(state_flat)) { + state_flat = ggml_cont_4d(ctx0, state_flat, S_v, S_v * H_v, 1, n_seqs); + } + + cb(q, 
"q_fused", il); + cb(k, "k_fused", il); + cb(v, "v_fused", il); + cb(g, "g_fused", il); + cb(beta, "beta_fused", il); + cb(state_flat,"state_fused", il); + + ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat); + cb(fused_result, "delta_net_fused_raw", il); + + const int64_t output_size = S_v * H_v * n_tokens * n_seqs; + const int64_t state_size = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(fused_result->type, S_v), + ggml_row_size(fused_result->type, S_v * H_v), + ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0); + output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs); + + ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size, + output_size * ggml_element_size(fused_result)); + ggml_tensor * new_state = ggml_reshape_4d(ctx0, new_state_flat, S_v, S_v, H_v, n_seqs); + + cb(output_tokens, "output_tokens", il); + cb(new_state, "new_state", il); + + return {output_tokens, new_state}; +} + std::pair delta_net::build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb) const { auto & model = lctx.model; const int64_t n_tok = input->ne[1]; @@ -497,10 +574,14 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1); } else { beta = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta, cur); + cb(beta, "beta", il); beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_tok, 1); + cb(beta, "beta_reshaped", il); alpha = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_alpha, cur); - // Why??? + cb(alpha, "alpha", il); + // Why? Don't think this ggml_cont_3d is needed, but lets leave it in for now just in case. 
alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha_cont", il); } cb(beta, "beta", il); cb(alpha, "alpha", il); @@ -603,15 +684,16 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - std::pair attn_out; - GGML_ASSERT(causal_mask != nullptr); GGML_ASSERT(identity != nullptr); GGML_ASSERT(diag_mask != nullptr); - attn_out = n_tok == 1 - ? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb) - : build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb); + std::pair attn_out; + // The fused delta-net implementation is only faster than chunked for n_tok <= 8, so use it only in that case + attn_out = n_tok <= lctx.cparams.fused_delta_net ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb) : + n_tok == 1 ? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb) + : build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb); + ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); diff --git a/src/llama-delta-net.h b/src/llama-delta-net.h index 360aecc9..e09259e5 100644 --- a/src/llama-delta-net.h +++ b/src/llama-delta-net.h @@ -19,6 +19,11 @@ struct delta_net { ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, int il, const llm_build_cb & cb); + static std::pair build_fused_delta_net(ggml_context * ctx0, + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + int il, const llm_build_cb & cb); + std::pair build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb) const; ggml_tensor * build_layer_attn_linear_core(ggml_context * ctx0, ggml_cgraph * gf, diff --git a/src/llama.cpp 
b/src/llama.cpp index da2ab8f2..a5ceeb4f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4378,8 +4378,9 @@ struct llama_context_params llama_context_default_params() { /*.only_active_experts =*/ false, /*.k_cache_hadamard =*/ false, /*.split_mode_graph_scheduling =*/ false, - // /*.split_mode_f16 =*/ true, + // /*.split_mode_f16 =*/ true, /*.scheduler_async =*/ false, + /*.fused_delta_net =*/ 0, /*.mtp =*/ false, /*.mtp_op_type =*/ MTP_OP_NONE, /*.abort_callback =*/ nullptr, @@ -4751,6 +4752,7 @@ struct llama_context * llama_init_from_model( cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling; //cparams.split_mode_f16 = params.split_mode_f16; cparams.scheduler_async = params.scheduler_async; + cparams.fused_delta_net = params.fused_delta_net; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.cuda_params = params.cuda_params; @@ -4836,7 +4838,7 @@ struct llama_context * llama_init_from_model( cparams.mtp = 0; } - cparams.mtp_op_type = params.mtp_op_type; + cparams.mtp_op_type = params.mtp_op_type; LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); @@ -4857,6 +4859,7 @@ struct llama_context * llama_init_from_model( //LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16); LLAMA_LOG_INFO("%s: reduce_type = %s\n", __func__, ggml_type_name(cparams.reduce_type)); LLAMA_LOG_INFO("%s: sched_async = %d\n", __func__, cparams.scheduler_async); + LLAMA_LOG_INFO("%s: fused_delta = %d\n", __func__, cparams.fused_delta_net); LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);