mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-20 22:49:31 +00:00
Minor delta-net tweak (#1337)
This commit is contained in:
@@ -151,18 +151,18 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
return {output_tokens, new_state};
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb) const {
|
||||
std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb, ggml_cgraph * gf) const {
|
||||
auto & model = lctx.model;
|
||||
const int64_t n_tok = input->ne[1];
|
||||
if (model.layers[il].wqkv) {
|
||||
ggml_tensor * qkv_mixed = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, input);
|
||||
cb(qkv_mixed, "qkv_mixed", il);
|
||||
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tok, 1);
|
||||
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
|
||||
|
||||
ggml_tensor * z = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, input);
|
||||
cb(z, "z", il);
|
||||
|
||||
ggml_build_forward_expand(gf, qkv_mixed);
|
||||
ggml_build_forward_expand(gf, z);
|
||||
qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_tok, 1);
|
||||
cb(qkv_mixed, "linear_attn_qkv_mixed", il);
|
||||
return { qkv_mixed, z };
|
||||
}
|
||||
|
||||
@@ -246,9 +246,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
const int64_t n_seqs = 1;
|
||||
const int64_t n_seq_tokens = n_tok;
|
||||
|
||||
auto qkvz = build_qkvz(ctx0, cur, il, cb);
|
||||
ggml_tensor * qkv_mixed = qkvz.first;
|
||||
ggml_tensor * z = qkvz.second;
|
||||
auto [qkv_mixed, z] = build_qkvz(ctx0, cur, il, cb, gf);
|
||||
|
||||
ggml_tensor *alpha, *beta;
|
||||
if (model.layers[il].ssm_beta_alpha) {
|
||||
@@ -291,6 +289,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
ggml_build_forward_expand(gf, alpha);
|
||||
|
||||
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
|
||||
cb(alpha_biased, "alpha_biased", il);
|
||||
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
|
||||
cb(alpha_softplus, "a_softplus", il);
|
||||
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);
|
||||
@@ -373,6 +372,8 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
|
||||
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
|
||||
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_tok, 1);
|
||||
cb(q_repeated, "q_repeated", il);
|
||||
cb(k_repeated, "k_repeated", il);
|
||||
|
||||
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
|
||||
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_tok, 1);
|
||||
@@ -403,12 +404,11 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, new_conv_states_cont, conv_state_dim, 1);
|
||||
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1);
|
||||
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
|
||||
cb(new_state_flat, "new_state_flat", il);
|
||||
|
||||
ggml_tensor * state_update = new_state_flat;
|
||||
if (state_dst->type != GGML_TYPE_F32) {
|
||||
state_update = ggml_cast(ctx0, state_update, state_dst->type);
|
||||
}
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst));
|
||||
auto state_cpy = ggml_cpy(ctx0, new_state_flat, state_dst);
|
||||
cb(state_cpy, "state_cpy", il);
|
||||
ggml_build_forward_expand(gf, state_cpy);
|
||||
|
||||
ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tok);
|
||||
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_tok);
|
||||
|
||||
Reference in New Issue
Block a user