Fused delta-net (AVX512) (#1362)

This commit is contained in:
Kawrakow
2026-03-05 07:55:05 +01:00
committed by GitHub
parent 2add439e43
commit 8fb002207a
2 changed files with 52 additions and 3 deletions

View File

@@ -134,12 +134,12 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;
-ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result,
+auto output_tokens = ggml_view_4d(ctx0, fused_result,
S_v, H_v, n_tokens, n_seqs,
ggml_row_size(fused_result->type, S_v),
ggml_row_size(fused_result->type, S_v * H_v),
ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0);
-output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
+//output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size,
output_size * ggml_element_size(fused_result));