mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
tests: add backend-op coverage for ggml_delta_net
This commit is contained in:
@@ -1061,6 +1061,34 @@ struct test_mul_mat_id : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_DELTA_NET
|
||||
struct test_delta_net : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t n_heads;
|
||||
const int64_t head_dim;
|
||||
const int64_t n_tokens;
|
||||
const int64_t n_seqs;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, n_heads, head_dim, n_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_delta_net(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t n_heads = 8, int64_t head_dim = 64, int64_t n_tokens = 32, int64_t n_seqs = 2)
|
||||
: type(type), n_heads(n_heads), head_dim(head_dim), n_tokens(n_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
|
||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
|
||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
|
||||
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, n_tokens, 1, n_heads, n_seqs);
|
||||
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, n_tokens, n_heads, n_seqs);
|
||||
ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_dim, head_dim * n_heads, 1, n_seqs);
|
||||
return ggml_delta_net(ctx, q, k, v, g, beta, state);
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SQR
|
||||
struct test_sqr : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -2436,6 +2464,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 2));
|
||||
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 128, 2));
|
||||
|
||||
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
|
||||
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
|
||||
|
||||
Reference in New Issue
Block a user