tests: add backend-op coverage for ggml_delta_net

This commit is contained in:
yurko
2026-02-07 14:34:56 -08:00
parent 6dd990d15a
commit ed0565f801

View File

@@ -1061,6 +1061,34 @@ struct test_mul_mat_id : public test_case {
}
};
// GGML_OP_DELTA_NET
struct test_delta_net : public test_case {
const ggml_type type;
const int64_t n_heads;
const int64_t head_dim;
const int64_t n_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, n_heads, head_dim, n_tokens, n_seqs);
}
test_delta_net(ggml_type type = GGML_TYPE_F32,
int64_t n_heads = 8, int64_t head_dim = 64, int64_t n_tokens = 32, int64_t n_seqs = 2)
: type(type), n_heads(n_heads), head_dim(head_dim), n_tokens(n_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, n_tokens, 1, n_heads, n_seqs);
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, n_tokens, n_heads, n_seqs);
ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_dim, head_dim * n_heads, 1, n_seqs);
return ggml_delta_net(ctx, q, k, v, g, beta, state);
}
};
// GGML_OP_SQR
struct test_sqr : public test_case {
const ggml_type type;
@@ -2436,6 +2464,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
}
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 1, 1));
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 1));
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 2));
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 128, 2));
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;