qwen3next: add fused delta-net op and wire model path

author yurko
date   2026-02-07 14:32:16 -08:00
parent 5a6c4e8da5
commit 6dd990d15a
6 changed files with 1928 additions and 5 deletions


@@ -676,6 +676,7 @@ extern "C" {
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_SOLVE_TRI,
GGML_OP_DELTA_NET,
GGML_OP_UNARY,
GGML_OP_MAP_UNARY,
@@ -2506,6 +2507,15 @@ extern "C" {
bool lower,
bool uni);
GGML_API struct ggml_tensor * ggml_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q, // [S_k, n_tokens, H_k, n_seqs]
struct ggml_tensor * k, // [S_k, n_tokens, H_k, n_seqs]
struct ggml_tensor * v, // [S_v, n_tokens, H_v, n_seqs]
struct ggml_tensor * g, // [n_tokens, 1, H_k, n_seqs]
struct ggml_tensor * beta, // [1, n_tokens, H_k, n_seqs]
struct ggml_tensor * state); // [S_v, S_v*H_v, 1, n_seqs]
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
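Editor's sketch (not part of the commit): one way to construct the op with the shapes documented above. The helper name and sizes are illustrative; the q/k and v head dims are taken equal here, which is what the kernels below assume.

#include "ggml.h"

// illustrative helper: builds a delta-net node with matching shapes;
// the op requires H_k == H_v, and the kernels assume S_k == S_v
static struct ggml_tensor * build_delta_net(struct ggml_context * ctx,
        int64_t S, int64_t n_tokens, int64_t H, int64_t n_seqs) {
    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_tokens, H, n_seqs);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_tokens, H, n_seqs);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_tokens, H, n_seqs);
    struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_tokens, 1, H, n_seqs);
    struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, n_tokens, H, n_seqs);
    struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, S * H, 1, n_seqs);
    return ggml_delta_net(ctx, q, k, v, g, beta, state);
}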


@@ -50,6 +50,7 @@
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/delta-net.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/multiadd.cuh"
#include "ggml-cuda/hadamard.cuh"
@@ -3675,6 +3676,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
case GGML_OP_DELTA_NET:
ggml_cuda_op_delta_net(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
@@ -3908,6 +3912,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s(%s): disabling CUDA graphs due to unsupported node type %ld %ld\n",
__func__, node->src[0]->name, node->ne[2], node->src[2]->ne[0]);
#endif
}
if (node->op == GGML_OP_DELTA_NET) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to DELTA_NET recurrent state\n", __func__);
#endif
}
if (node->op == GGML_OP_MOE_FUSED_UP_GATE) {
@@ -4527,6 +4537,22 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
op->src[2]->ne[1] == op->src[0]->ne[1] &&
op->src[1]->ne[0] == op->src[0]->ne[1] &&
op->src[3]->ne[0] == op->src[0]->ne[2];
case GGML_OP_DELTA_NET:
return op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->src[2]->type == GGML_TYPE_F32 &&
op->src[3]->type == GGML_TYPE_F32 &&
op->src[4]->type == GGML_TYPE_F32 &&
op->src[5]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op->src[2]) &&
ggml_is_contiguous(op->src[3]) &&
ggml_is_contiguous(op->src[4]) &&
ggml_is_contiguous(op->src[5]) &&
op->src[0]->ne[0] <= 256 &&
op->src[2]->ne[0] <= 256;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
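Editor's note: the CUDA gate above requires all six sources to be F32 and fully contiguous, with both the q/k and v head dims at most 256; nodes that fail the check fall back to the CPU path. A minimal probe at graph-build time, assuming a backend handle and a built node:

// hedged sketch: `backend` and `node` are assumed to exist in the caller
if (!ggml_backend_supports_op(backend, node)) {
    // this backend rejects the DELTA_NET node; schedule it on the CPU backend
}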

File diff suppressed because it is too large
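The suppressed file is the CUDA implementation behind ggml_cuda_op_delta_net. As a reading aid only, here is a naive kernel with the same per-token semantics as the CPU reference further down; the kernel name, launch geometry, and shared-memory layout are this note's assumptions, not the commit's code.

#include <cuda_runtime.h>

// naive sketch: one block per (seq, head), blockDim.x == head_dim (<= 256),
// one thread per row of the per-head [d x d] state, which lives in s_out
// launch: delta_net_naive<<<n_seqs * n_heads, d, 2 * d * sizeof(float)>>>(...)
__global__ void delta_net_naive(const float * q, const float * k, const float * v,
                                const float * g, const float * beta, const float * s_in,
                                float * out, float * s_out,
                                const int d, const int n_tokens, const int n_heads) {
    const int head = blockIdx.x % n_heads;
    const int seq  = blockIdx.x / n_heads;
    const int row  = threadIdx.x;

    // per-(seq, head) base pointers, mirroring the CPU reference layouts
    const float * qh  = q    + ((long long) seq * n_heads + head) * d * n_tokens;
    const float * kh  = k    + ((long long) seq * n_heads + head) * d * n_tokens;
    const float * vh  = v    + ((long long) seq * n_heads + head) * d * n_tokens;
    const float * gh  = g    + ((long long) seq * n_heads + head) * n_tokens;
    const float * bh  = beta + ((long long) seq * n_heads + head) * n_tokens;
    float       * oh  = out  + ((long long) seq * n_heads * n_tokens + head) * d;
    const float * sih = s_in  + ((long long) seq * n_heads + head) * d * d;
    float       * sh  = s_out + ((long long) seq * n_heads + head) * d * d;

    // carry the incoming state over; the recurrence updates sh in place
    for (int i = row; i < d * d; i += blockDim.x) sh[i] = sih[i];

    extern __shared__ float smem[];  // [0,d) = k_hat, [d,2d) = q_hat / sqrt(d)
    float * ks = smem;
    float * qs = smem + d;
    __shared__ float s_decay, s_beta, s_attn;

    const float scale = rsqrtf((float) d);
    for (int t = 0; t < n_tokens; ++t) {
        const float * qt = qh + t * d;
        const float * kt = kh + t * d;
        const float * vt = vh + t * d;
        __syncthreads();  // previous token (or the state copy) must be done
        if (row == 0) {   // per-token scalars + normalized vectors, serial for clarity
            float qn = 1e-12f, kn = 1e-12f, attn = 0.0f;
            for (int i = 0; i < d; ++i) { qn += qt[i] * qt[i]; kn += kt[i] * kt[i]; }
            qn = rsqrtf(qn);
            kn = rsqrtf(kn);
            for (int i = 0; i < d; ++i) {
                ks[i] = kt[i] * kn;
                qs[i] = qt[i] * qn * scale;
                attn += ks[i] * qs[i];
            }
            s_decay = expf(fminf(gh[t], 50.0f));
            s_beta  = 1.0f / (1.0f + expf(-bh[t]));
            s_attn  = attn;
        }
        __syncthreads();
        // each thread owns state row `row`; reads and writes never cross rows
        float vp = 0.0f, ov = 0.0f;
        for (int c = 0; c < d; ++c) {
            const float s = sh[row + c * d];
            vp += s * ks[c];
            ov += s * qs[c];
        }
        const float v_new = vt[row] * s_beta - vp * s_beta * s_decay;
        oh[(long long) t * d * n_heads + row] = ov * s_decay + v_new * s_attn;
        for (int c = 0; c < d; ++c) {
            const float s = s_decay * sh[row + c * d] + v_new * ks[c];
            sh[row + c * d] = fminf(fmaxf(s, -1e6f), 1e6f);
        }
    }
}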


@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);


@@ -4274,6 +4274,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GET_REL_POS",
"ADD_REL_POS",
"SOLVE_TRI",
"DELTA_NET",
"UNARY",
@@ -4298,7 +4299,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"FUSED_NORM",
};
-static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
+static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4391,6 +4392,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"get_rel_pos(x)",
"add_rel_pos(x)",
"solve_tri(x)",
"delta_net(q, k, v, g, beta, state)",
"unary(x)",
@@ -4415,7 +4417,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"norm(x,y)",
};
-static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
+static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -10223,6 +10225,61 @@ struct ggml_tensor * ggml_solve_tri(
return result;
}
// ggml_delta_net
struct ggml_tensor * ggml_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
GGML_ASSERT(q->type == GGML_TYPE_F32);
GGML_ASSERT(k->type == GGML_TYPE_F32);
GGML_ASSERT(v->type == GGML_TYPE_F32);
GGML_ASSERT(g->type == GGML_TYPE_F32);
GGML_ASSERT(beta->type == GGML_TYPE_F32);
GGML_ASSERT(state->type == GGML_TYPE_F32);
const int64_t S_k = q->ne[0];
const int64_t n_tokens = q->ne[1];
const int64_t H_k = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[2];
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_k && g->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
GGML_ASSERT(S_k == S_v); // the CPU and CUDA kernels index v and the square state with the q/k head dim
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size);
result->op = GGML_OP_DELTA_NET;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
result->src[3] = g;
result->src[4] = beta;
result->src[5] = state;
return result;
}
// ggml_ssm_conv
struct ggml_tensor * ggml_ssm_conv(
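Editor's sketch (assumed, not shown in this commit): since the op returns one flat F32 tensor packing [output | updated state], a caller can split it with 1-D views sized per the constructor above.

// r packs S_v*H_v*n_tokens*n_seqs output floats followed by the
// S_v*S_v*H_v*n_seqs updated state; split with 1-D views
struct ggml_tensor * r = ggml_delta_net(ctx, q, k, v, g, beta, state);
const int64_t n_out   = S_v * H_v * n_tokens * n_seqs;
const int64_t n_state = S_v * S_v * H_v * n_seqs;
struct ggml_tensor * y         = ggml_view_1d(ctx, r, n_out, 0);
struct ggml_tensor * new_state = ggml_view_1d(ctx, r, n_state, n_out * ggml_element_size(r));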
@@ -21839,6 +21896,138 @@ static void ggml_compute_forward_flash_attn_back(
}
}
// ggml_compute_forward_delta_net
static void ggml_compute_forward_delta_net_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src2 = dst->src[2];
const struct ggml_tensor * src3 = dst->src[3];
const struct ggml_tensor * src4 = dst->src[4];
const struct ggml_tensor * src5 = dst->src[5];
const int64_t head_dim = src0->ne[0];
const int64_t n_tokens = src0->ne[1];
const int64_t n_heads = src0->ne[2];
const int64_t n_seqs = src0->ne[3];
const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
const float * q_data = (const float *) src0->data;
const float * k_data = (const float *) src1->data;
const float * v_data = (const float *) src2->data;
const float * g_data = (const float *) src3->data;
const float * beta_data = (const float *) src4->data;
const float * state_in = (const float *) src5->data;
float * out_data = (float *) dst->data;
float * state_out = out_data + output_size;
const int ith = params->ith;
const int nth = params->nth;
const int64_t total_heads = n_heads * n_seqs;
const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
const int64_t h_start = ith * heads_per_thread;
const int64_t h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
const float eps = 1e-12f;
const float scale = 1.0f / sqrtf((float) head_dim);
// per-thread scratch for the corrected value vector
float * v_new_buf = (float *) malloc(head_dim * sizeof(float));
if (!v_new_buf) {
GGML_ABORT("delta_net: failed to allocate scratch buffer");
}
for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) {
const int64_t batch_idx = h_idx / n_heads;
const int64_t head_idx = h_idx % n_heads;
const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
const int64_t qkv_token_stride = head_dim;
const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
const int64_t out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
const int64_t out_token_stride = head_dim * n_heads; // output layout is [head_dim, n_heads, n_tokens, n_seqs]
// seed the output state with the incoming state; the recurrence then updates it in place
for (int64_t i = 0; i < head_dim * head_dim; ++i) {
state_out[state_head_offset + i] = state_in[state_head_offset + i];
}
float * state = state_out + state_head_offset;
for (int64_t t = 0; t < n_tokens; ++t) {
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
const float g_val = g_data[g_head_offset + t];
const float beta_raw = beta_data[g_head_offset + t]; // beta flattens to the same per-head [n_tokens] layout as g
float q_norm_sq = 0.0f;
float k_norm_sq = 0.0f;
for (int64_t i = 0; i < head_dim; ++i) {
q_norm_sq += q_t[i] * q_t[i];
k_norm_sq += k_t[i] * k_t[i];
}
const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
const float beta_val = 1.0f / (1.0f + expf(-beta_raw)); // sigmoid write gate
const float decay = expf(fminf(g_val, 50.0f)); // g is a log-space decay; the clamp keeps expf finite
// dot(k_hat, q_hat/sqrt(d)), reused to fold the corrected value into every output row
float attn_score = 0.0f;
for (int64_t i = 0; i < head_dim; ++i) {
attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
}
float * out_t = out_data + out_head_offset + t * out_token_stride;
// one pass over the old state S: v' = beta*decay*(S k_hat), out = decay*(S q_hat/sqrt(d))
for (int64_t row = 0; row < head_dim; ++row) {
float v_prime = 0.0f;
float out_val = 0.0f;
for (int64_t col = 0; col < head_dim; ++col) {
const float k_col = k_t[col] * k_norm_inv;
const float q_col = q_t[col] * q_norm_inv * scale;
const float s = state[row + col * head_dim];
v_prime += s * k_col * beta_val * decay;
out_val += s * q_col * decay;
}
const float v_new = v_t[row] * beta_val - v_prime;
v_new_buf[row] = v_new;
out_t[row] = out_val + v_new * attn_score;
}
// delta-rule state update: S = decay*S + v_new * k_hat^T, clamped for numerical stability
for (int64_t col = 0; col < head_dim; ++col) {
const float k_col = k_t[col] * k_norm_inv;
for (int64_t row = 0; row < head_dim; ++row) {
float s = state[row + col * head_dim];
s = decay * s + v_new_buf[row] * k_col;
state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
}
}
}
}
free(v_new_buf);
}
static void ggml_compute_forward_delta_net(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32:
ggml_compute_forward_delta_net_f32(params, dst);
break;
default:
GGML_ABORT("fatal error");
}
}
// ggml_compute_forward_ssm_conv
static void ggml_compute_forward_ssm_conv_f32(
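For readers of ggml_compute_forward_delta_net_f32 above: with \hat q_t, \hat k_t the per-token L2-normalized vectors, d = head_dim, \alpha_t = e^{\min(g_t,\,50)} and \beta_t = \sigma(\beta^{raw}_t), the loop implements the gated delta rule

\tilde v_t = \beta_t \left( v_t - \alpha_t\, S_{t-1} \hat k_t \right)
o_t = \alpha_t\, S_{t-1} \tfrac{\hat q_t}{\sqrt{d}} + \left( \hat k_t^{\top} \tfrac{\hat q_t}{\sqrt{d}} \right) \tilde v_t
S_t = \operatorname{clamp}\left( \alpha_t\, S_{t-1} + \tilde v_t\, \hat k_t^{\top},\ \pm 10^{6} \right)

so that (clamp aside) o_t = S_t \hat q_t / \sqrt{d}: each token decays the state by \alpha_t and writes the \beta_t-gated correction \tilde v_t along direction \hat k_t.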
@@ -23862,6 +24051,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_ssm_scan(params, tensor);
} break;
case GGML_OP_DELTA_NET:
{
ggml_compute_forward_delta_net(params, tensor);
} break;
case GGML_OP_SOLVE_TRI:
{
ggml_compute_forward_solve_tri(params, tensor);
@@ -24986,6 +25179,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
}
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_DELTA_NET:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
@@ -25722,6 +25916,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_DELTA_NET:
{
n_tasks = n_threads;
} break;
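Scheduling note: DELTA_NET requests n_tasks = n_threads here, and the CPU kernel above then statically slices the flattened n_heads * n_seqs head index across threads. A worked example with assumed sizes:

// assumed: n_heads = 8, n_seqs = 2, nth = 6 threads
const int64_t total_heads      = 8 * 2;                       // 16 (seq, head) pairs
const int64_t heads_per_thread = (total_heads + 6 - 1) / 6;   // ceil(16/6) = 3
// thread 0 -> heads [0, 3), thread 4 -> [12, 15), thread 5 -> [15, 16);
// threads past the last slot simply get an empty range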