mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 17:14:17 +00:00
Simplify/improve CUDA delta-net
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -401,11 +401,18 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
cb(g, "g_in", il);
|
||||
cb(state,"state_in", il);
|
||||
|
||||
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
|
||||
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
|
||||
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
|
||||
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
|
||||
beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3), 1, n_tokens, H_k, n_seqs);
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
g = ggml_permute(ctx0, g, 2, 0, 3, 1);
|
||||
beta = ggml_permute(ctx0, beta, 2, 0, 1, 3);
|
||||
if (n_seqs > 1) {
|
||||
q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
|
||||
k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
|
||||
v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs);
|
||||
g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_k, n_seqs);
|
||||
beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_k, n_seqs);
|
||||
}
|
||||
|
||||
ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs);
|
||||
if (!ggml_is_contiguous(state_flat)) {
|
||||
|
||||
Reference in New Issue
Block a user