diff --git a/common/common.cpp b/common/common.cpp index f61cb93b..942379d6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1533,7 +1533,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "-fdn" || arg == "--fused-delta-net") { CHECK_ARG - fprintf(stderr, "=================== %s has been deprecated\n", arg.c_str()); + params.fused_delta_net = std::stoi(argv[i]); return true; } if (arg == "-smf16" || arg == "--split-mode-f16") { @@ -2276,6 +2276,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-grt, --graph-reduce-type", "Type for data exchange between GPUs (default: %s)", "f32"}); options.push_back({ "*", "-smgs, --split-mode-graph-scheduling,", "Force Split Mode Graph Scheduling (default: %d)", params.split_mode_graph_scheduling}); options.push_back({ "*", "-sas, --scheduler_async,", "Async evaluation of compute graphs: %d)", params.scheduler_async}); + options.push_back({ "*", "-fdn, --fused-delta-net N", "Use fused delta-net when batch size is <= N with recurrent models: %d)", params.fused_delta_net}); options.push_back({ "*", "-vq, --validate-quants", "validate quantized data while loading the model (default: %d)", params.validate_quants}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -3354,6 +3355,7 @@ struct llama_context_params common_context_params_to_llama(const gpt_params & pa cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling; //cparams.split_mode_f16 = params.split_mode_f16; cparams.scheduler_async = params.scheduler_async; + cparams.fused_delta_net = params.fused_delta_net; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.only_active_experts = params.only_active_exps; @@ -4364,6 +4366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l //fprintf(stream, "split_mode_f16: %s # default: true\n", params.split_mode_f16 ? "true" : "false"); fprintf(stream, "reduce_type: %s # default f16\n", params.reduce_type.c_str()); fprintf(stream, "scheduler_async: %s # default: false\n", params.scheduler_async ? "true" : "false"); + fprintf(stream, "fused_delta_net: %d # default: 0\n", params.fused_delta_net ); fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 7efd9ae5..44653a6f 100644 --- a/common/common.h +++ b/common/common.h @@ -359,6 +359,7 @@ struct gpt_params { bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling //bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph + int fused_delta_net = 0; // use fused delta-net if number of tokens in the batch is less than this value bool has_mtp = false; // enable MTP if supported by the model std::string cache_type_k = "f16"; // KV cache data type for the K diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index a6111591..aeb78f89 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -271,6 +271,7 @@ struct cmd_params { bool muge = false; bool rcache = false; bool sas = false; + int fdn = 0; // fdn = fused delta net bool print_overrides = false; output_formats output_format; output_formats output_format_stderr; @@ -316,6 +317,7 @@ static const cmd_params cmd_params_defaults = { /* muge */ false, /* rcache */ false, /* sas */ false, + /* fdn */ 0, /* print_overrides */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, @@ -369,6 +371,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0"); printf(" -no-ooae, --no-offload-only-active-experts <0|1> (default: %s)\n", cmd_params_defaults.no_ooae? "1" : "0"); printf(" -sas, --scheduler-async <0|1> (default: %s)\n", cmd_params_defaults.sas ? "1" : "0"); + printf(" -fdn, --fused-delta-net (default: %d)\n", cmd_params_defaults.fdn); printf(" --print-overrides <0|1> (default: %s)\n", cmd_params_defaults.print_overrides ? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); @@ -810,6 +813,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.sas = std::stoi(argv[i]); + } else if (arg == "-fdn" || arg == "--fused-delta-net") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.fdn = std::stoi(argv[i]); } else if (arg == "-rcache" || arg == "--rope-cache") { if (++i >= argc) { invalid_param = true; @@ -956,6 +965,7 @@ struct cmd_params_instance { bool muge = false; bool rcache = false; bool sas = false; + int fdn = 0; const llama_model_tensor_buft_override* buft_overrides; llama_model_params to_llama_mparams() const { @@ -991,6 +1001,7 @@ struct cmd_params_instance { muge == other.muge && use_thp == other.use_thp && sas == other.sas && + fdn == other.fdn && tensor_split == other.tensor_split; } @@ -1017,6 +1028,7 @@ struct cmd_params_instance { cparams.embeddings = embeddings; cparams.cuda_params = (void *)cuda_params.data(); cparams.scheduler_async = sas; + cparams.fused_delta_net = fdn; return cparams; } @@ -1083,6 +1095,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1126,6 +1139,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1169,6 +1183,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1212,6 +1227,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .sas = */ params.sas, + /* .fdn = */ params.fdn, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1266,6 +1282,7 @@ struct test { bool muge = false; bool rcache = false; bool sas = false; + int fdn = 0; std::string override_tensor; int n_prompt; int n_gen; @@ -1307,6 +1324,7 @@ struct test { ger = inst.ger; rcache = inst.rcache; sas = inst.sas; + fdn = inst.fdn; no_fug = inst.no_fug; use_thp = inst.use_thp; no_ooae = inst.no_ooae; @@ -1411,7 +1429,7 @@ struct test { field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" || - field == "avg_ns" || field == "stddev_ns") { + field == "avg_ns" || field == "stddev_ns" || field == "fdn") { return INT; } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || @@ -1462,7 +1480,7 @@ struct test { std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(mqkv), std::to_string(muge), std::to_string(fmoe), std::to_string(ger), - std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache), std::to_string(sas), + std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache), std::to_string(sas), std::to_string(fdn), cuda_params, override_tensor, std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1483,7 +1501,7 @@ struct test { "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse", "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "muge", "fused_moe", "grouped_er", - "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "sas", "cuda_params", "override_tensor", + "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "sas", "fdn", "cuda_params", "override_tensor", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1673,6 +1691,9 @@ struct markdown_printer : public printer { if (field == "sas") { return 3; } + if (field == "fdn") { + return 4; + } if (field == "use_thp") { return 3; } @@ -1746,6 +1767,9 @@ struct markdown_printer : public printer { if (field == "sas") { return "sas"; } + if (field == "fdn") { + return "fdn"; + } if (field == "use_thp") { return "thp"; } @@ -1856,6 +1880,9 @@ struct markdown_printer : public printer { if (params.sas != cmd_params_defaults.sas) { fields.emplace_back("sas"); } + if (params.fdn != cmd_params_defaults.fdn) { + fields.emplace_back("fdn"); + } if (params.muge != cmd_params_defaults.muge) { fields.emplace_back("muge"); } diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu index a0193b55..acb6f6c4 100644 --- a/ggml/src/ggml-cuda/delta-net.cu +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -41,12 +41,12 @@ __global__ void delta_net_recurrent_f32( const int64_t n_seqs, const int64_t output_offset, // offset where state starts in output const float eps) { - constexpr int warps_per_head = HEAD_DIM/WARP_SIZE; - const int batch_idx = blockIdx.x / (warps_per_head*n_heads); - const int sub_head_idx = blockIdx.x % (warps_per_head*n_heads); - const int head_idx = sub_head_idx / warps_per_head; - const int sub_idx = sub_head_idx % warps_per_head; + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; const int tid = threadIdx.x; + const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads + const int lane_id = tid % WARP_SIZE; // 0-31 + constexpr int NUM_WARPS = block_size/WARP_SIZE; // Strides for input tensors (column-major) // Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs] @@ -83,34 +83,32 @@ __global__ void delta_net_recurrent_f32( extern __shared__ float smem[]; float * sQ = smem; // HEAD_DIM float * sK = sQ + HEAD_DIM; // HEAD_DIM + float * sV = sK + HEAD_DIM; // HEAD_DIM + float * sVNew = sV + HEAD_DIM; // HEAD_DIM const float scale = rsqrtf((float)HEAD_DIM); __shared__ float sum_helper[block_size/WARP_SIZE]; - constexpr int num_warps = block_size/WARP_SIZE; - const int row = tid % WARP_SIZE; - const int col_idx_0 = tid / WARP_SIZE; - const int row_out = row + sub_idx * WARP_SIZE; - - // Keep the state in registers, copy the final state to its destination at the end - float state_local[HEAD_DIM/num_warps]; - for (int i = 0; i < HEAD_DIM/num_warps; ++i) { - int col = num_warps*i + col_idx_0; - state_local[i] = state_src[col*HEAD_DIM + row_out]; + // Copy initial state to output buffer (will be updated in place) + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += block_size) { + state_dst[i] = state_src[i]; } - constexpr int WARP_SIZE_S = WARP_SIZE + 1; - constexpr int num_stored_rows = block_size/WARP_SIZE; - __shared__ float all_sum[2*WARP_SIZE_S*num_stored_rows]; + constexpr int HEAD_DIM_S = HEAD_DIM + 1; + constexpr int num_stored_rows = block_size >= HEAD_DIM && block_size % HEAD_DIM == 0 ? block_size/HEAD_DIM : NUM_WARPS; + __shared__ float all_sum[2*HEAD_DIM_S*num_stored_rows]; auto all_sum1 = all_sum; - auto all_sum2 = all_sum1 + WARP_SIZE_S*num_stored_rows; + auto all_sum2 = all_sum1 + HEAD_DIM_S*num_stored_rows; + // Process each token sequentially for (int64_t t = 0; t < n_tokens; t++) { + float sum_kq = 0.0f; for (int i = tid; i < HEAD_DIM; i += block_size) { sQ[i] = q_ptr[t * qkv_stride_token + i] * scale; sK[i] = k_ptr[t * qkv_stride_token + i]; + sV[i] = v_ptr[t * qkv_stride_token + i]; sum_kq += sK[i] * sQ[i]; } @@ -119,44 +117,281 @@ __global__ void delta_net_recurrent_f32( float beta_val = sigmoid_f(beta_ptr[t]); float decay = expf(fminf(g_ptr[t], 50.0f)); - float sum1 = 0, sum2 = 0; -#pragma unroll - for (int i = 0; i < HEAD_DIM/num_warps; ++i) { - int col = num_warps*i + col_idx_0; - sum1 += state_local[i] * sK[col]; - sum2 += state_local[i] * sQ[col]; - } - all_sum1[col_idx_0*WARP_SIZE_S + row] = sum1; - all_sum2[col_idx_0*WARP_SIZE_S + row] = sum2; + if constexpr (block_size >= HEAD_DIM && block_size % HEAD_DIM == 0) { + int idx = tid / HEAD_DIM; + int row_out = tid % HEAD_DIM; + float sum1 = 0, sum2 = 0; + #pragma unroll + for (int col = idx; col < HEAD_DIM; col += block_size/HEAD_DIM) { + float sval = state_dst[row_out + col * HEAD_DIM]; + sum1 += sval * sK[col]; + sum2 += sval * sQ[col]; + } + all_sum1[idx*HEAD_DIM_S + row_out] = sum1; + all_sum2[idx*HEAD_DIM_S + row_out] = sum2; - __syncthreads(); + __syncthreads(); - sum1 = sum2 = 0; -#pragma unroll - for (int i = 0; i < block_size/WARP_SIZE; ++i) { - sum1 += all_sum1[i*WARP_SIZE_S + row]; - sum2 += all_sum2[i*WARP_SIZE_S + row]; - } - // To be honest, I don't understand why we need this sync. But without it I observe results varying from run to run - __syncthreads(); + if (idx == 0) { + #pragma unroll + for (int i = 1; i < block_size/HEAD_DIM; ++i) { + sum1 += all_sum1[i*HEAD_DIM_S + row_out]; + sum2 += all_sum2[i*HEAD_DIM_S + row_out]; + } + sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay; + float v_attn = sVNew[row_out] * attn_score; + out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn; + } + __syncthreads(); + } else { + for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) { + float sum1 = 0.0f; + float sum2 = 0.0f; + #pragma unroll + for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { + float sval = state_dst[row_out + col * HEAD_DIM]; + sum1 += sval * sK[col]; + sum2 += sval * sQ[col]; + } + all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1; + all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2; + } + __syncthreads(); - float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay); - if (col_idx_0 == 0) { - out_base[t * out_token_stride + row_out] = sum2 * decay + sv_new * attn_score; + for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) { + float sum1 = all_sum1[row_out]; + float sum2 = all_sum2[row_out]; + #pragma unroll + for (int i = 1; i < NUM_WARPS; ++i) { + sum1 += all_sum1[row_out + i*HEAD_DIM_S]; + sum2 += all_sum2[row_out + i*HEAD_DIM_S]; + } + sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay; + float v_attn = sVNew[row_out] * attn_score; + out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn; + } + __syncthreads(); } - for (int i = 0; i < HEAD_DIM/num_warps; ++i) { - int col = num_warps*i + col_idx_0; - float new_state_val = decay * state_local[i] + sv_new * sK[col]; - new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); - state_local[i] = new_state_val; + for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) { + float k_col = sK[out_dim]; + #pragma unroll + for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { + float state_val = state_dst[row + out_dim * HEAD_DIM]; + float new_state_val = decay * state_val + sVNew[row] * k_col; //sK[out_dim]; + new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); + state_dst[row + out_dim * HEAD_DIM] = new_state_val; + } } - } - // Copy the final state to its destination - for (int i = 0; i < HEAD_DIM/num_warps; ++i) { - int col = num_warps*i + col_idx_0; - state_dst[col*HEAD_DIM + row_out] = state_local[i]; +} + +// Generic kernel that handles any HEAD_DIM at runtime (slower but flexible) +__global__ void delta_net_recurrent_generic_f32( + const float * __restrict__ q, + const float * __restrict__ k, + const float * __restrict__ v, + const float * __restrict__ g, + const float * __restrict__ beta_in, + const float * __restrict__ state_in, + float * __restrict__ dst, + const int64_t head_dim, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const int64_t output_offset, + const float eps) { + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + + // Strides (column-major) + const int64_t qkv_stride_token = head_dim; + const int64_t qkv_stride_head = head_dim * n_tokens; + const int64_t qkv_stride_batch = head_dim * n_tokens * n_heads; + + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + + const int64_t state_head_offset = head_idx * head_dim * head_dim; + const int64_t state_batch_stride = head_dim * head_dim * n_heads; + + // Pointers + const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + + // Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs] + float * out_base = dst + batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim; + const int64_t out_token_stride = head_dim * n_heads; + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory for scalars (outside loop) + __shared__ float shared_g_val, shared_beta_val, shared_decay, shared_attn_score; + + // Dynamic shared memory + extern __shared__ float smem[]; + float * sQ = smem; + float * sK = sQ + head_dim; + float * sV = sK + head_dim; + float * sKBeta = sV + head_dim; // plain k for state update + float * sVBeta = sKBeta + head_dim; // v * sigmoid(beta) + float * sOut = sVBeta + head_dim; + float * sKCumdecay = sOut + head_dim; // k * sigmoid(beta) * exp(g) + float * sVPrime = sKCumdecay + head_dim; // state @ k_cumdecay + float * sVNew = sVPrime + head_dim; // v_beta - v_prime + float * sNorm = sVNew + head_dim; + + const float scale = rsqrtf((float)head_dim); + + // Copy initial state to output buffer + for (int i = tid; i < head_dim * head_dim; i += blockDim.x) { + int col = i / head_dim; + int row = i % head_dim; + state_dst[row + col * head_dim] = state_src[row + col * head_dim]; + } + __syncthreads(); + + // Process each token + for (int64_t t = 0; t < n_tokens; t++) { + if (tid < 2) sNorm[tid] = 0.0f; + __syncthreads(); + + // Load Q, K, V + for (int i = tid; i < head_dim; i += blockDim.x) { + sQ[i] = q_ptr[t * qkv_stride_token + i]; + sK[i] = k_ptr[t * qkv_stride_token + i]; + sV[i] = v_ptr[t * qkv_stride_token + i]; + } + __syncthreads(); + + // L2 normalize Q and K + float q_sq = 0.0f, k_sq = 0.0f; + for (int i = tid; i < head_dim; i += blockDim.x) { + q_sq += sQ[i] * sQ[i]; + k_sq += sK[i] * sK[i]; + } + + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq += __shfl_xor_sync(0xffffffff, q_sq, offset); + k_sq += __shfl_xor_sync(0xffffffff, k_sq, offset); + } + + if (tid % WARP_SIZE == 0) { + atomicAdd(&sNorm[0], q_sq); + atomicAdd(&sNorm[1], k_sq); + } + __syncthreads(); + + float q_norm = rsqrtf(sNorm[0] + eps); + float k_norm = rsqrtf(sNorm[1] + eps); + + for (int i = tid; i < head_dim; i += blockDim.x) { + sQ[i] *= q_norm * scale; + sK[i] *= k_norm; + } + __syncthreads(); + + // Load g and beta, compute decay + if (tid == 0) { + shared_g_val = g_ptr[t]; + shared_beta_val = sigmoid_f(beta_ptr[t]); + shared_decay = expf(fminf(shared_g_val, 50.0f)); + } + __syncthreads(); + + float beta_val = shared_beta_val; + float decay = shared_decay; + + // Compute k_beta, v_beta, k_cumdecay + for (int i = tid; i < head_dim; i += blockDim.x) { + sKBeta[i] = sK[i]; + sVBeta[i] = sV[i] * beta_val; + sKCumdecay[i] = sK[i] * beta_val * decay; + } + __syncthreads(); + + // Compute v_prime = state @ k_cumdecay + for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) { + float v_prime_val = 0.0f; + for (int col = 0; col < head_dim; col++) { + // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ k + v_prime_val += state_dst[row_out + col * head_dim] * sKCumdecay[col]; + } + sVPrime[row_out] = v_prime_val; + } + __syncthreads(); + + // Compute v_new = v_beta - v_prime (the value residual) + for (int i = tid; i < head_dim; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + // Compute attn_score = dot(k, q) (L2 normalized vectors) + if (tid == 0) { + float dot_sum = 0.0f; + for (int i = 0; i < head_dim; i++) { + dot_sum += sK[i] * sQ[i]; + } + shared_attn_score = dot_sum; + } + __syncthreads(); + + // Compute output: o[t] = attn_inter + v_attn + // attn_inter = state @ (q * exp(g)) = sum_col(state[row_out, col] * q[col] * exp(g)) + // The decomposed path uses: attn_inter = ggml_mul_mat(state_t, q_g_exp) + // Since ggml_mul_mat(A,B) = A^T @ B, attn_inter = state_t^T @ q_g_exp = state @ (q * exp(g)) + for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) { + float attn_inter = 0.0f; + + for (int col = 0; col < head_dim; col++) { + // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ q + float state_val = state_dst[row_out + col * head_dim]; + attn_inter += sQ[col] * decay * state_val; + } + + // v_attn = v_new * attn_score + float v_attn = sVNew[row_out] * shared_attn_score; + + // Output = attn_inter + v_attn (correct DeltaNet formula) + sOut[row_out] = attn_inter + v_attn; + } + __syncthreads(); + + // Update state: state_new = decay * state + outer(v_new, k) + // Fixed: outer product orientation matches decomposed: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] + // Uses transposed indexing: state_dst[row + out_dim * head_dim] = state[row][out_dim] + // Only protect against NaN/Inf - do NOT clamp decay value + float safe_decay = decay; + if (isnan(safe_decay) || isinf(safe_decay)) { + safe_decay = 1.0f; + } + + for (int out_dim = tid; out_dim < head_dim; out_dim += blockDim.x) { + for (int row = 0; row < head_dim; row++) { + float state_val = state_dst[row + out_dim * head_dim]; + + // state_new[row][out_dim] = decay * state[row][out_dim] + v_new[row] * k[out_dim] + // Fix: outer product matches decomposed path: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] + float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim]; + + // Clamp state to prevent overflow + new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); + state_dst[row + out_dim * head_dim] = new_state_val; + } + } + __syncthreads(); + + // Write output + for (int i = tid; i < head_dim; i += blockDim.x) { + out_base[t * out_token_stride + i] = sOut[i]; + } + __syncthreads(); } } @@ -181,32 +416,24 @@ static void delta_net_f32_cuda( const int64_t output_offset = head_dim * n_tokens * n_heads * n_seqs; - if (head_dim != 64 && head_dim != 128) { - GGML_ABORT("Unsupported delta net head size"); - } + // One block per (batch, head) pair + const int num_blocks = n_seqs * n_heads; + constexpr int threads_per_block = 512; //256; - GGML_ASSERT(head_dim % WARP_SIZE == 0); - const int num_blocks = n_seqs * n_heads * (head_dim/WARP_SIZE); - const size_t smem_size = 2 * head_dim * sizeof(float); + const size_t smem_size = 4 * head_dim * sizeof(float); - if (n_tokens <= 8) { - constexpr int threads_per_block = 256; - if (head_dim == 64) { - delta_net_recurrent_f32<64, threads_per_block><<>>( + // Use templated kernel for common head dimensions, generic for others + if (head_dim == 64) { + delta_net_recurrent_f32<64, threads_per_block><<>>( + q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); + } else if (head_dim == 128) { + GGML_ASSERT(num_blocks % 8 == 0); + delta_net_recurrent_f32<128, threads_per_block><<>>( q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); - } else { - delta_net_recurrent_f32<128, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); - } } else { - constexpr int threads_per_block = 128; - if (head_dim == 64) { - delta_net_recurrent_f32<64, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); - } else { - delta_net_recurrent_f32<128, threads_per_block><<>>( - q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); - } + GGML_ASSERT("Unsupported delta net head size"); + delta_net_recurrent_generic_f32<<>>( + q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps); } CUDA_CHECK(cudaGetLastError()); diff --git a/include/llama.h b/include/llama.h index cd8da575..338381b4 100644 --- a/include/llama.h +++ b/include/llama.h @@ -456,6 +456,7 @@ extern "C" { bool split_mode_graph_scheduling; // if true, force split mode graph scheduling //bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads + int fused_delta_net; bool mtp; // Activate MTP if supported enum llama_mtp_op_type mtp_op_type; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index b178059f..6ac0a3a3 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -43,6 +43,7 @@ struct llama_cparams { bool split_mode_graph_scheduling; //bool split_mode_f16; bool scheduler_async; + int fused_delta_net; int min_experts; float thresh_experts; bool mtp; diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp index b6566d05..5aa6d1c7 100644 --- a/src/llama-delta-net.cpp +++ b/src/llama-delta-net.cpp @@ -9,7 +9,7 @@ #include #include -#define DELTA_CHUNK_SIZE 64 +#define QWEN3NEXT_CHUNK_SIZE 64 delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_lctx), batch(_batch) { auto & model = lctx.model; @@ -74,6 +74,304 @@ delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_ delta_net::~delta_net() = default; +std::pair delta_net::build_delta_net_chunking(ggml_context * ctx0, + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il, const llm_build_cb & cb) { + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_seqs == 1); + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + if (beta->ne[0] != H_v || beta->ne[2] != n_tokens || beta->ne[3] != n_seqs) { + printf("beta: %ld x %ld x %ld, expected %ld x %ld x %ld\n", beta->ne[0], beta->ne[2], beta->ne[3], H_v, n_tokens, n_seqs); + } + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); + GGML_ASSERT(H_k == H_v); + + const float scale = 1.0f / sqrtf(S_v); + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + cb(state,"state_in", il); + + const int64_t chunk_size = QWEN3NEXT_CHUNK_SIZE; + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 0, 2, 1, 3); + g = ggml_permute(ctx0, g, 2, 0, 3, 1); + beta = ggml_permute(ctx0, beta, 2, 0, 1, 3); + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_v * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = + ggml_repeat_4d(ctx0, g_cumsum, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + cb(decay_mask, "decay_mask", il); + + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + cb(decay_mask, "decay_mask_1", il); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask_exp", il); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + cb(decay_mask, "decay_mask_2", il); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + cb(kmulkbeta, "kk_beta", il); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + cb(k_decay, "k_decay_1", il); + k_decay = ggml_mul(ctx0, k_decay, causal_mask); + cb(k_decay, "k_decay_2", il); + ggml_tensor * attn = ggml_neg(ctx0, k_decay); + cb(attn, "attn_pre_solve", il); + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + cb(attn_lower, "attn_lower", il); + ggml_tensor * identity_repeat = + ggml_repeat_4d(ctx0, identity, attn_lower->ne[0], attn_lower->ne[1], attn_lower->ne[2], attn_lower->ne[3]); + ggml_tensor * lhs = ggml_neg(ctx0, ggml_sub(ctx0, attn_lower, identity_repeat)); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + cb(attn, "attn_mul", il); + attn = ggml_add(ctx0, attn, identity); + cb(attn, "attn_solved", il); + + auto v_beta_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)); + cb(v_beta_t, "v_beta_t", il); + v = ggml_mul_mat(ctx0, v_beta_t, attn); + cb(v, "v_beta", il); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + cb(g_cumsum_t, "g_cumsum_t", il); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + cb(gexp, "gexp", il); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + cb(kbeta_gexp, "kbeta_gexp", il); + + auto kbeta_gexp_t = ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)); + cb(kbeta_gexp_t, "kbeta_gexp_t", il); + auto attn_kbeta = ggml_mul_mat(ctx0, attn, kbeta_gexp_t); + cb(attn_kbeta, "attn_kbeta", il); + ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, attn_kbeta)); + cb(k_cumdecay, "k_cumdecay", il); + + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + cb(attn_kq, "attn_kq_pre", il); + attn_kq = ggml_mul(ctx0, decay_mask, attn_kq); + cb(attn_kq, "attn_kq_0", il); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); + + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); + + ggml_tensor * g_last_repeat = + ggml_repeat_4d(ctx0, g_last, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last_repeat)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + cb(g_diff_exp, "g_diff_exp", il); + ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); + cb(key_gdiff, "key_gdiff", il); + + ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); + cb(key_gdiff_t, "key_gdiff_t", il); + + cb(state, "new_state", il); + + auto get_slice_2d = [ctx0](ggml_tensor * t, int64_t c) -> ggml_tensor * { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); + }; + + ggml_tensor * core_attn_out = nullptr; + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * q_chunk = get_slice_2d(q, chunk); + ggml_tensor * v_chunk = get_slice_2d(v, chunk); + ggml_tensor * gexp_chunk = get_slice_2d(gexp, chunk); + ggml_tensor * k_cumdecay_chunk = get_slice_2d(k_cumdecay, chunk); + ggml_tensor * attn_chunk = get_slice_2d(attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + cb(state_t, "state_t", il); + + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); + + ggml_tensor * v_new = ggml_sub(ctx0, v_prime, v_chunk); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); + + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + cb(q_g_exp, "q_g_exp", il); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); + + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); + + ggml_tensor * core_attn_out_chunk = ggml_sub(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); + + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + + ggml_tensor * k_gdiff_t = get_slice_2d(key_gdiff_t, chunk); + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); + cb(kgdmulvnew, "kgdmulvnew", il); + + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(g_last_exp, chunk)); + cb(gexp_last_chunk, "gexp_last_chunk", il); + auto s_mul = ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)); + cb(s_mul, "s_mul", il); + state = ggml_sub(ctx0, s_mul, ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks), + ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0); + cb(output_tokens, "output_tokens", il); + + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); + cb(output_tokens, "output_tokens", il); + + return {output_tokens, state}; +} + +std::pair delta_net::build_delta_net_autoregressive(ggml_context * ctx0, + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + int il, const llm_build_cb & cb) { + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + GGML_ASSERT(n_seqs == 1); + GGML_ASSERT(H_k == H_v); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + + g_t = ggml_exp(ctx0, g_t); + cb(g_t, "g_t", il); + state = ggml_mul(ctx0, state, g_t); + cb(state, "state", il); + + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + cb(kv_mem, "kv_mem", il); + kv_mem = ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)); + cb(kv_mem, "kv_mem_t_cont", il); + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, kv_mem)); + + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); + cb(v_diff, "v_diff", il); + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + cb(delta, "delta", il); + + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + cb(k_t_delta, "k_t_delta", il); + state = ggml_add(ctx0, state, k_t_delta); + + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + cb(state_q, "state_q", il); + state_q = ggml_cont(ctx0, ggml_transpose(ctx0, state_q)); + cb(state_q, "state_q_t_cont", il); + ggml_tensor * core_attn_out = ggml_transpose(ctx0, ggml_sum_rows(ctx0, state_q)); + + cb(core_attn_out, "output_tokens", il); + cb(state, "new_state", il); + + return {core_attn_out, state}; +} + std::pair delta_net::build_fused_delta_net(ggml_context * ctx0, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, @@ -246,7 +544,9 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ const int64_t n_seqs = 1; const int64_t n_seq_tokens = n_tok; - auto [qkv_mixed, z] = build_qkvz(ctx0, cur, il, cb); + auto qkvz = build_qkvz(ctx0, cur, il, cb); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; ggml_tensor *alpha, *beta; if (model.layers[il].ssm_beta_alpha) { @@ -272,20 +572,19 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1); alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1); - cb(beta, "beta", il); - cb(alpha, "alpha", il); } else { beta = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta, cur); - alpha = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_alpha, cur); - ggml_build_forward_expand(gf, beta); - ggml_build_forward_expand(gf, alpha); cb(beta, "beta", il); - cb(alpha, "alpha", il); beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_tok, 1); cb(beta, "beta_reshaped", il); - alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); - cb(alpha, "alpha_reshaped", il); + alpha = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_alpha, cur); + cb(alpha, "alpha", il); + // Why? Don't think this ggml_cont_3d is needed, but lets leave it in for now just in case. + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha_cont", il); } + cb(beta, "beta", il); + cb(alpha, "alpha", il); ggml_build_forward_expand(gf, beta); ggml_build_forward_expand(gf, alpha); @@ -307,13 +606,18 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ state_all = ggml_view_2d(ctx0, state_storage, state_dim, qnext_state_slots, state_row_size, 0); ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id_local * state_row_size); + ggml_tensor * state_f32 = state_dst; + if (state_f32->type != GGML_TYPE_F32) { + state_f32 = ggml_cast(ctx0, state_f32, GGML_TYPE_F32); + } if (reset_state_local) { - state_dst = ggml_scale(ctx0, state_dst, 0.0f); - cb(state_dst, "state_reset", il); + state_f32 = ggml_scale(ctx0, state_f32, 0.0f); + cb(state_f32, "state_reset", il); } - ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_dst, conv_state_dim, 1, state_dst->nb[1], 0); - ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_dst, ssm_state_dim, 1, state_dst->nb[1], conv_state_dim * ggml_element_size(state_dst)); + ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0); + ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_f32, ssm_state_dim, 1, state_f32->nb[1], + conv_state_dim * ggml_element_size(state_f32)); ggml_tensor * conv_states = ggml_reshape_3d(ctx0, conv_state_flat, hparams.ssm_d_conv - 1, conv_dim, 1); ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1); @@ -324,6 +628,8 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext); cb(conv_output_raw, "conv_output_raw", il); + //ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0); + //ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output); ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_raw); cb(conv_output_silu, "conv_output_silu", il); @@ -333,24 +639,27 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ // Extract the convolved Q, K, V from conv_output ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_tok, 1, - ggml_row_size(conv_output_silu->type, head_k_dim), nb1_qkv, nb1_qkv * n_tok, 0); + ggml_row_size(conv_output_silu->type, head_k_dim), + nb1_qkv, nb1_qkv * n_tok, 0); ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_tok, 1, - ggml_row_size(conv_output_silu->type, head_k_dim), nb1_qkv, nb1_qkv * n_tok, + ggml_row_size(conv_output_silu->type, head_k_dim), + nb1_qkv, nb1_qkv * n_tok, head_k_dim * num_k_heads * ggml_element_size(conv_output_silu)); ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_output_silu, head_v_dim, num_v_heads, n_tok, 1, - ggml_row_size(conv_output_silu->type, head_v_dim), nb1_qkv, nb1_qkv * n_tok, + ggml_row_size(conv_output_silu->type, head_v_dim), + nb1_qkv, nb1_qkv * n_tok, ggml_row_size(conv_output_silu->type, 2 * head_k_dim * num_k_heads)); cb(q_conv, "q_conv", il); cb(k_conv, "k_conv", il); cb(v_conv, "v_conv", il); - q_conv = ggml_l2_norm(ctx0, q_conv, hparams.f_norm_rms_eps); - k_conv = ggml_l2_norm(ctx0, k_conv, hparams.f_norm_rms_eps); - cb(q_conv, "q_conv_normed", il); - cb(k_conv, "k_conv_normed", il); + const float eps_norm = hparams.f_norm_rms_eps; + + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads % num_k_heads == 0); @@ -379,8 +688,14 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ GGML_ASSERT(identity != nullptr); GGML_ASSERT(diag_mask != nullptr); - auto [output, new_state] = build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb); + std::pair attn_out; + // The fused delta-net implementation is only faster than chunked for n_tok <= 8, so use it only in that case + attn_out = n_tok <= lctx.cparams.fused_delta_net ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb) : + n_tok == 1 ? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb) + : build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); cb(new_state, "new_state", il); @@ -393,7 +708,11 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1); ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state_flat, state_dst)); + ggml_tensor * state_update = new_state_flat; + if (state_dst->type != GGML_TYPE_F32) { + state_update = ggml_cast(ctx0, state_update, state_dst->type); + } + ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst)); ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tok); ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tok); @@ -409,7 +728,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_ ggml_tensor * out = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, final_output); cb(out, "linear_attn_out", il); - return out; + return ggml_reshape_2d(ctx0, out, hparams.n_embd, n_tok); } diff --git a/src/llama-delta-net.h b/src/llama-delta-net.h index 1bd72e2c..e09259e5 100644 --- a/src/llama-delta-net.h +++ b/src/llama-delta-net.h @@ -8,6 +8,17 @@ struct delta_net { delta_net(llama_context & lctx, const llama_batch & batch); ~delta_net(); + static std::pair build_delta_net_chunking(ggml_context * ctx0, + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il, const llm_build_cb & cb); + + static std::pair build_delta_net_autoregressive(ggml_context * ctx0, + ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, + int il, const llm_build_cb & cb); + static std::pair build_fused_delta_net(ggml_context * ctx0, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, diff --git a/src/llama.cpp b/src/llama.cpp index 3ccae37a..facceb0d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1512,7 +1512,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); @@ -4395,6 +4394,7 @@ struct llama_context_params llama_context_default_params() { /*.split_mode_graph_scheduling =*/ false, // /*.split_mode_f16 =*/ true, /*.scheduler_async =*/ false, + /*.fused_delta_net =*/ 0, /*.mtp =*/ false, /*.mtp_op_type =*/ MTP_OP_NONE, /*.abort_callback =*/ nullptr, @@ -4766,6 +4766,7 @@ struct llama_context * llama_init_from_model( cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling; //cparams.split_mode_f16 = params.split_mode_f16; cparams.scheduler_async = params.scheduler_async; + cparams.fused_delta_net = params.fused_delta_net; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.cuda_params = params.cuda_params; @@ -4872,6 +4873,7 @@ struct llama_context * llama_init_from_model( //LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16); LLAMA_LOG_INFO("%s: reduce_type = %s\n", __func__, ggml_type_name(cparams.reduce_type)); LLAMA_LOG_INFO("%s: sched_async = %d\n", __func__, cparams.scheduler_async); + LLAMA_LOG_INFO("%s: fused_delta = %d\n", __func__, cparams.fused_delta_net); LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);