Fused delta-net (#1315)

* Revive fused delta-net

* Add command line argument for fused delta net

* Simplify/improve CUDA delta-net

* Add -fdn to llama-bench

* More CUDA fused delta net optimizations

* CPU optimizations

* Much faster fused delta-net on the CPU

It seems it is faster than the chunked implementation!

* Change meaning of fdn from bool flag to threshold value

* Use eps = 1e-6

* Give some nodes a name
This commit is contained in:
Kawrakow
2026-02-25 14:12:48 +01:00
committed by GitHub
parent 0bf7043a7b
commit c77ec4b8b8
15 changed files with 1002 additions and 13 deletions

View File

@@ -4378,8 +4378,9 @@ struct llama_context_params llama_context_default_params() {
/*.only_active_experts =*/ false,
/*.k_cache_hadamard =*/ false,
/*.split_mode_graph_scheduling =*/ false,
// /*.split_mode_f16 =*/ true,
// /*.split_mode_f16 =*/ true,
/*.scheduler_async =*/ false,
/*.fused_delta_net =*/ 0,
/*.mtp =*/ false,
/*.mtp_op_type =*/ MTP_OP_NONE,
/*.abort_callback =*/ nullptr,
@@ -4751,6 +4752,7 @@ struct llama_context * llama_init_from_model(
cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling;
//cparams.split_mode_f16 = params.split_mode_f16;
cparams.scheduler_async = params.scheduler_async;
cparams.fused_delta_net = params.fused_delta_net;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;
@@ -4836,7 +4838,7 @@ struct llama_context * llama_init_from_model(
cparams.mtp = 0;
}
cparams.mtp_op_type = params.mtp_op_type;
cparams.mtp_op_type = params.mtp_op_type;
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
@@ -4857,6 +4859,7 @@ struct llama_context * llama_init_from_model(
//LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16);
LLAMA_LOG_INFO("%s: reduce_type = %s\n", __func__, ggml_type_name(cparams.reduce_type));
LLAMA_LOG_INFO("%s: sched_async = %d\n", __func__, cparams.scheduler_async);
LLAMA_LOG_INFO("%s: fused_delta = %d\n", __func__, cparams.fused_delta_net);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);