Add command-line argument for fused delta net

This commit is contained in:
Kawrakow
2026-02-24 05:40:26 +00:00
parent a350f1b96f
commit 28b31a66b2
6 changed files with 18 additions and 6 deletions

View File

@@ -43,6 +43,7 @@ struct llama_cparams {
bool split_mode_graph_scheduling;
//bool split_mode_f16;
bool scheduler_async;
bool fused_delta_net;
int min_experts;
float thresh_experts;
bool mtp;

View File

@@ -679,10 +679,9 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
GGML_ASSERT(identity != nullptr);
GGML_ASSERT(diag_mask != nullptr);
attn_out = n_tok == 1
//? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
: build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb);
attn_out = n_tok == 1 ? lctx.cparams.fused_delta_net ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
: build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
: build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);

View File

@@ -4378,8 +4378,9 @@ struct llama_context_params llama_context_default_params() {
/*.only_active_experts =*/ false,
/*.k_cache_hadamard =*/ false,
/*.split_mode_graph_scheduling =*/ false,
// /*.split_mode_f16 =*/ true,
// /*.split_mode_f16 =*/ true,
/*.scheduler_async =*/ false,
/*.fused_delta_net =*/ false,
/*.mtp =*/ false,
/*.mtp_op_type =*/ MTP_OP_NONE,
/*.abort_callback =*/ nullptr,
@@ -4750,6 +4751,7 @@ struct llama_context * llama_init_from_model(
cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling;
//cparams.split_mode_f16 = params.split_mode_f16;
cparams.scheduler_async = params.scheduler_async;
cparams.fused_delta_net = params.fused_delta_net;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;
@@ -4835,7 +4837,7 @@ struct llama_context * llama_init_from_model(
cparams.mtp = 0;
}
cparams.mtp_op_type = params.mtp_op_type;
cparams.mtp_op_type = params.mtp_op_type;
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
@@ -4856,6 +4858,7 @@ struct llama_context * llama_init_from_model(
//LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16);
LLAMA_LOG_INFO("%s: reduce_type = %s\n", __func__, ggml_type_name(cparams.reduce_type));
LLAMA_LOG_INFO("%s: sched_async = %d\n", __func__, cparams.scheduler_async);
LLAMA_LOG_INFO("%s: fused_delta = %d\n", __func__, cparams.fused_delta_net);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);