Simplify/improve CUDA delta-net

This commit is contained in:
Kawrakow
2026-02-24 07:41:09 +00:00
parent 28b31a66b2
commit dc44a37ca2
2 changed files with 65 additions and 1258 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -401,11 +401,18 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
cb(g, "g_in", il);
cb(state,"state_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3), 1, n_tokens, H_k, n_seqs);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
g = ggml_permute(ctx0, g, 2, 0, 3, 1);
beta = ggml_permute(ctx0, beta, 2, 0, 1, 3);
if (n_seqs > 1) {
q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_k, n_seqs);
beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_k, n_seqs);
}
ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs);
if (!ggml_is_contiguous(state_flat)) {