More Qwen3-Next optimizations (#1277)

* Optimizing q3next TG

* Fused add -> softplus -> mul on CUDA

* Remove forgotten debug log

* Increase ggml context size

Required for Qwen3-Next with batch/u-batch size of 4096

* WIP

* Avoid some contiguous ops

* Avoid some repeats

* Avoid some more repeats
This commit is contained in:
Kawrakow
2026-02-17 16:03:51 +01:00
committed by GitHub
parent 88f98c891d
commit cafeef484c
7 changed files with 378 additions and 79 deletions

View File

@@ -2205,6 +2205,45 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
}
template <typename src_t, int block_size = 256>
static __global__ void mul_mat_row(int n, const src_t * x, const float * y, float * z) {
float sum = 0;
for (int i = threadIdx.x; i < n; i += block_size) {
float xi = ggml_cuda_cast<float, src_t>(x[i]);
sum += xi * y[i];
}
sum = warp_reduce_sum(sum);
if constexpr (block_size > WARP_SIZE) {
__shared__ float tmp[block_size/WARP_SIZE];
if (threadIdx.x % WARP_SIZE == 0) {
tmp[threadIdx.x / WARP_SIZE] = sum;
}
__syncthreads();
sum = threadIdx.x < block_size / WARP_SIZE ? tmp[threadIdx.x] : 0.0f;
sum = warp_reduce_sum(sum);
}
if (threadIdx.x == 0) {
z[0] = sum;
}
}
static void mul_mat_1row(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_backend_cuda_context & ctx) {
constexpr int kBlockSize = 256;
GGML_ASSERT(src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F16) {
mul_mat_row<<<1, kBlockSize, 0, ctx.stream()>>>((int)src0->ne[0], (const half *)src0->data, (const float *)src1->data, (float *)dst->data);
}
else if (src0->type == GGML_TYPE_BF16) {
mul_mat_row<<<1, kBlockSize, 0, ctx.stream()>>>((int)src0->ne[0], (const nv_bfloat16 *)src0->data, (const float *)src1->data, (float *)dst->data);
}
else if (src0->type == GGML_TYPE_F32) {
mul_mat_row<<<1, kBlockSize, 0, ctx.stream()>>>((int)src0->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data);
}
else {
GGML_ABORT("Fatal error");
}
}
static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const ggml_cgraph * cgraph, int node_n) {
@@ -2239,6 +2278,12 @@ static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
return ggml_cuda_mul_mat_q(ctx, src0, src1, dst, cgraph, node_n, use_mul_mat_vec_q);
}
if (ggml_nrows(src0) == 1 && ggml_nrows(src1) == 1) { // && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
mul_mat_1row(src0, src1, dst, ctx);
return node_n;
}
bool debug = false; //src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32;
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -2248,29 +2293,29 @@ static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
//printf("%s(%s): using ggml_cuda_mul_mat_vec_p021\n", __func__, dst->name);
if (debug) printf("%s(%s): using ggml_cuda_mul_mat_vec_p021\n", __func__, dst->name);
// FP32 precision KQ single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
} else if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
//printf("%s(%s): using ggml_cuda_mul_mat_vec_nc\n", __func__, dst->name);
if (debug) printf("%s(%s): using ggml_cuda_mul_mat_vec_nc\n", __func__, dst->name);
// FP32 precision KQV single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if ((src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32) && (src1->type == src0->type || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
//printf("%s(%s): ggml_cuda_mul_mat_batched_cublas\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_mul_mat_batched_cublas\n", __func__, dst->name);
// KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_dequantize_mul_mat_vec)\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_dequantize_mul_mat_vec)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_vec_q)\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_vec_q)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_q)\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_q)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_cublas)\n", __func__, dst->name);
if (debug) printf("%s(%s, %s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_cublas)\n", __func__, dst->name, ggml_type_name(src0->type));
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
return node_n;
@@ -3218,6 +3263,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sum_rows_nc(ctx, cgraph->nodes[i+1]);
i += 2;
} else {
//auto src = dst->src[0];
//printf("cont(%s -> %s): %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n", src->name, dst->name, src->ne[0], src->ne[1], src->ne[2], src->ne[3], src->nb[0], src->nb[1], src->nb[2], src->nb[3]);
ggml_cuda_dup(ctx, dst);
}
break;
@@ -3237,6 +3284,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
}
else if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_UNARY &&
(ggml_unary_op)cgraph->nodes[i+1]->op_params[0] == GGML_UNARY_OP_SOFTPLUS &&
cgraph->nodes[i+2]->op == GGML_OP_MUL &&
cgraph->nodes[i+2]->src[0] == cgraph->nodes[i+1] &&
cgraph->nodes[i+1]->src[0] == cgraph->nodes[i] &&
ggml_nrows(cgraph->nodes[i+0]->src[1]) == 1 &&
ggml_nrows(cgraph->nodes[i+2]->src[1]) == 1) {
ggml_cuda_fused_softplus(ctx, cgraph->nodes[i+2]);
i += 2;
}
else if (false && fusion && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
ggml_is_contiguous(dst->src[0]) &&
@@ -3262,10 +3320,25 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_acc(ctx, dst);
break;
case GGML_OP_MUL:
//printf("mul(%s): %d, %d, %d, %ld x %ld x %ld x %ld * %ld x %ld x %ld x %ld\n", dst->name, ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_is_contiguous(dst),
// dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3],
// dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]);
ggml_cuda_op_mul(ctx, dst);
if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_UNARY &&
cgraph->nodes[i+2]->op == GGML_OP_MUL &&
(ggml_unary_op)cgraph->nodes[i+1]->op_params[0] == GGML_UNARY_OP_EXP &&
cgraph->nodes[i+1]->src[0] == dst &&
cgraph->nodes[i+2]->src[0] == cgraph->nodes[i+1] &&
cgraph->nodes[i+2]->src[1] == dst->src[1]) {
ggml_cuda_fused_mul_exp_mul(ctx, cgraph->nodes[i+2]);
i += 2;
//printf("mul(%s) -> exp(%s) -> mul(%s), %d, %d, %zu, %zu; %ld x %ld x %ld x %ld - %ld x %ld x %ld x %ld\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name,
// ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_nelements(dst->src[0]), ggml_nelements(dst->src[1]),
// dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3],
// dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]);
} else {
//printf("mul(%s): %d, %d, %d, %ld x %ld x %ld x %ld * %ld x %ld x %ld x %ld\n", dst->name, ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_is_contiguous(dst),
// dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3],
// dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]);
ggml_cuda_op_mul(ctx, dst);
}
break;
case GGML_OP_FUSED_MUL_UNARY:
ggml_cuda_op_fused_mul_unary(ctx, dst);
@@ -3277,6 +3350,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sub(ctx, dst);
break;
case GGML_OP_UNARY:
//printf("unary(%s, %s)\n", dst->name, ggml_unary_op_name((ggml_unary_op)dst->op_params[0]));
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_GELU:
ggml_cuda_op_gelu(ctx, dst);

View File

@@ -33,6 +33,40 @@ static void pad_f32_cuda(const float * x, float * dst,
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
template <int dim>
static __global__ void pad_f32_nc(const char * cx, float * dst, int nelem,
int ne0, int ne1, int ne2, int ne3, int ne00, int ne01, int ne02, int ne03,
size_t nb00, size_t nb01, size_t nb02, size_t nb03) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i >= nelem) {
return;
}
int ii = i;
int i3 = ii/(ne0*ne1*ne2); ii -= i3*ne0*ne1*ne2;
int i2 = ii/(ne0*ne1 ); ii -= i2*ne0*ne1;
int i1 = ii/(ne0 );
int i0 = ii - i1*ne0;
if constexpr (dim == 0) {
dst[i] = i0 < ne00 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 1) {
dst[i] = i1 < ne01 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 2) {
dst[i] = i2 < ne02 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 3) {
dst[i] = i3 < ne03 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 4) {
dst[i] = *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03);
}
else {
dst[i] = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -41,9 +75,53 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
GGML_ASSERT(ggml_is_contiguous(dst));
if (ggml_is_contiguous(src0)) {
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
return;
}
int npad = 0; int pad_dim = -1;
for (int i = 0; i < 4; ++i) {
if (dst->ne[i] > src0->ne[i]) {
++npad; pad_dim = i;
}
}
//if (npad == 0) {
// printf("Oops: npad = 0: %ld vs %ld, %ld vx %ld, %ld vs %ld, %ld vs %ld\n", dst->ne[0], src0->ne[0], dst->ne[1], src0->ne[1], dst->ne[2], src0->ne[2], dst->ne[3], src0->ne[3]);
//}
//GGML_ASSERT(npad > 0);
constexpr int kBlockSize = 256;
int nelem = ggml_nelements(dst);
int nblock = (nelem + kBlockSize - 1)/kBlockSize;
if (npad == 0) {
//printf("%s: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n", src0->name, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
pad_f32_nc<4><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (npad == 1) {
if (pad_dim == 0) {
pad_f32_nc<0><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (pad_dim == 1) {
pad_f32_nc<1><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (pad_dim == 2) {
pad_f32_nc<2><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (pad_dim == 3) {
pad_f32_nc<3><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else {
GGML_ABORT("Fatal error");
}
} else {
pad_f32_nc<-1><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
}
}

View File

@@ -950,3 +950,59 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_elu>(ctx, dst);
}
static __global__ void k_fused_softplus(int ne0, int nelem, const float * a, const float * b, const float * c, float * dst) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i >= nelem) {
return;
}
int i0 = i % ne0;
dst[i] = c[i0] * op_softplus(a[i] + b[i0]);
}
void ggml_cuda_fused_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int kBlockSize = 256;
auto m = dst;
auto u = dst->src[0];
auto a = u->src[0];
GGML_ASSERT(m->op == GGML_OP_MUL);
GGML_ASSERT(a->op == GGML_OP_ADD);
GGML_ASSERT(u->op == GGML_OP_UNARY && (ggml_unary_op)u->op_params[0] == GGML_UNARY_OP_SOFTPLUS);
GGML_ASSERT(ggml_nrows(m->src[1]) == 1 && m->src[1]->ne[0] == m->src[0]->ne[0]);
GGML_ASSERT(ggml_nrows(a->src[1]) == 1 && a->src[1]->ne[0] == a->src[0]->ne[0]);
GGML_ASSERT(a->type == GGML_TYPE_F32 && u->type == GGML_TYPE_F32 && m->type == GGML_TYPE_F32);
GGML_ASSERT(a->src[0]->type == GGML_TYPE_F32 && a->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(a->src[0]));
int nelem = ggml_nelements(a->src[0]);
int nblock = (nelem + kBlockSize - 1)/kBlockSize;
k_fused_softplus<<<nblock, kBlockSize, 0, ctx.stream()>>>(a->src[0]->ne[0], nelem,
(const float *)a->src[0]->data, (const float *)a->src[1]->data, (const float *)m->src[1]->data, (float *)dst->data);
}
static __global__ void k_fused_mul_exp_mul(int ne0, int nelem, const float * x, const float * y, float * dst) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i >= nelem) {
return;
}
int i0 = i % ne0;
dst[i] = y[i0] * expf(x[i] * y[i0]);
}
void ggml_cuda_fused_mul_exp_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int kBlockSize = 256;
auto m2 = dst;
auto u = dst->src[0];
auto m1 = u->src[0];
GGML_ASSERT(m1->src[0]->type == GGML_TYPE_F32 && m1->src[1]->type == GGML_TYPE_F32 && m2->type == GGML_TYPE_F32);
GGML_ASSERT(m1->src[1] == m2->src[1]);
GGML_ASSERT(ggml_is_contiguous(m1->src[0]) && ggml_is_contiguous(m1->src[1]));
GGML_ASSERT(u->op == GGML_OP_UNARY && (ggml_unary_op)u->op_params[0] == GGML_UNARY_OP_EXP);
auto nelem = ggml_nelements(m1->src[0]);
auto ne0 = ggml_nelements(m1->src[1]);
int nblock = (nelem + kBlockSize - 1)/kBlockSize;
k_fused_mul_exp_mul<<<nblock, kBlockSize, 0, ctx.stream()>>>(ne0, nelem,
(const float *)m1->src[0]->data, (const float *)m1->src[1]->data, (float *)dst->data);
}

View File

@@ -95,3 +95,7 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
int64_t nelements,int64_t ne0, const float * x, float * z, float limit = 0);
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_fused_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_fused_mul_exp_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -21071,7 +21071,6 @@ static void ggml_compute_forward_pad_f32(
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT( dst->nb[0] == sizeof(float));
const int ith = params->ith;
@@ -21083,22 +21082,82 @@ static void ggml_compute_forward_pad_f32(
// TODO: optimize
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if (src0->nb[0] == sizeof(float)) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
}
}
}
}
}
} else {
const int k_block_size = 1024;
int nelem = ggml_nelements(dst);
int nblocks = (nelem + k_block_size - 1)/k_block_size;
for (int ib = ith; ib < nblocks; ib += nth) {
int first = ib*k_block_size;
int last = MIN(first + k_block_size, nelem);
//
//int ii = first;
//int i3 = ii/(ne0*ne1*ne2); ii -= i3*ne0*ne1*ne2;
//int i2 = ii/(ne0*ne1 ); ii -= i2*ne0*ne1;
//int i1 = ii/(ne0 ); ii -= i1*ne0;
//int i0 = ii;
//int i = first;
//bool in_src = i1 < ne01 && i2 < ne02 && i3 < ne03;
//const char * c_src = (const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03;
//while (i < last) {
// if (i0 + last - i <= ne0) {
// for (; i < last; ++i, ++i0) {
// dst_ptr[i] = in_src && i0 < ne00 ? *(const float *)(c_src + i0*nb00) : 0.0f;
// }
// break;
// }
// for (; i0 < ne0; ++i0) {
// dst_ptr[i++] = in_src && i0 < ne00 ? *(const float *)(c_src + i0*nb00) : 0.0f;
// }
// i0 = 0;
// if (++i1 == (int)ne1) {
// i1 = 0;
// if (++i2 == (int)ne2) {
// i2 = 0; ++i3;
// }
// }
// in_src = i1 < ne01 && i2 < ne02 && i3 < ne03;
// c_src = (const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03;
//}
//for (int i = first; i < last; ++i) {
// dst_ptr[i] = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03 ? *(const float *)((const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
// if (++i0 == (int)ne0) {
// i0 = 0;
// if (++i1 == (int)ne1) {
// i1 = 0;
// if (++i2 == (int)ne2) {
// i2 = 0; ++i3;
// }
// }
// }
//}
//
for (int i = first; i < last; ++i) {
int ii = i;
int i3 = ii/(ne0*ne1*ne2); ii -= i3*ne0*ne1*ne2;
int i2 = ii/(ne0*ne1 ); ii -= i2*ne0*ne1;
int i1 = ii/(ne0 ); ii -= i1*ne0;
int i0 = ii;
dst_ptr[i] = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03 ? *(const float *)((const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
}
}
}

View File

@@ -4398,10 +4398,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
q = ggml_scale(ctx0, q, scale);
@@ -4412,29 +4408,23 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_v, n_seqs);
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(g, "g_perm", il);
cb(state,"state_in", il);
const int64_t chunk_size = QWEN3NEXT_CHUNK_SIZE;
const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
const int64_t n_chunks = (n_tokens + pad) / chunk_size;
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
g = ggml_permute(ctx0, g, 2, 0, 3, 1);
beta = ggml_permute(ctx0, beta, 2, 0, 1, 3);
q = ggml_pad(ctx0, q, 0, pad, 0, 0);
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
cb(q, "q_pad", il);
cb(k, "k_pad", il);
@@ -4443,7 +4433,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(g, "g_pad", il);
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, beta, k->ne[0], beta->ne[1], beta->ne[2], beta->ne[3]), k);
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
cb(v_beta, "v_beta", il);
cb(k_beta, "k_beta", il);
@@ -4470,27 +4460,37 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(decay_mask, "decay_mask", il);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
cb(decay_mask, "decay_mask_1", il);
decay_mask = ggml_exp(ctx0, decay_mask);
cb(decay_mask, "decay_mask_exp", il);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
cb(decay_mask, "decay_mask_2", il);
ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
cb(kmulkbeta, "kk_beta", il);
ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
cb(k_decay, "k_decay_1", il);
k_decay = ggml_mul(ctx0, k_decay, causal_mask);
cb(k_decay, "k_decay_2", il);
ggml_tensor * attn = ggml_neg(ctx0, k_decay);
cb(attn, "attn_pre_solve", il);
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
cb(attn_lower, "attn_lower", il);
ggml_tensor * identity_repeat =
ggml_repeat_4d(ctx0, identity, attn_lower->ne[0], attn_lower->ne[1], attn_lower->ne[2], attn_lower->ne[3]);
ggml_tensor * lhs = ggml_neg(ctx0, ggml_sub(ctx0, attn_lower, identity_repeat));
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
cb(attn, "attn_mul", il);
attn = ggml_add(ctx0, attn, identity);
cb(attn, "attn_solved", il);
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
auto v_beta_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_beta));
cb(v_beta_t, "v_beta_t", il);
v = ggml_mul_mat(ctx0, v_beta_t, attn);
cb(v, "v_beta", il);
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
@@ -4501,7 +4501,9 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
cb(kbeta_gexp, "kbeta_gexp", il);
auto attn_kbeta = ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)));
auto kbeta_gexp_t = ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp));
cb(kbeta_gexp_t, "kbeta_gexp_t", il);
auto attn_kbeta = ggml_mul_mat(ctx0, attn, kbeta_gexp_t);
cb(attn_kbeta, "attn_kbeta", il);
ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, attn_kbeta));
cb(k_cumdecay, "k_cumdecay", il);
@@ -4509,6 +4511,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
cb(attn_kq, "attn_kq_pre", il);
attn_kq = ggml_mul(ctx0, decay_mask, attn_kq);
cb(attn_kq, "attn_kq_0", il);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il);
@@ -4527,9 +4530,10 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(g_diff, "g_diff", il);
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
cb(g_diff_exp, "g_diff_exp", il);
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
ggml_tensor * key_gdiff = ggml_mul(ctx0, ggml_repeat_4d(ctx0, g_diff_exp_t, k->ne[0], g_diff_exp_t->ne[1], g_diff_exp_t->ne[2], g_diff_exp_t->ne[3]), k);
ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
cb(key_gdiff, "key_gdiff", il);
ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
@@ -4548,26 +4552,24 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
cb(attn_chunk, "attn_chunk", il);
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
cb(state_t, "state_t", il);
//printf("v_prime_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", state_t->ne[0], state_t->ne[1], state_t->ne[2], state_t->ne[3], ggml_type_name(state_t->type),
// k_cumdecay_chunk->ne[0], k_cumdecay_chunk->ne[1], k_cumdecay_chunk->ne[2], k_cumdecay_chunk->ne[3], ggml_type_name(k_cumdecay_chunk->type));
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
cb(v_prime, "v_prime_chunk", il);
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
ggml_tensor * v_new = ggml_sub(ctx0, v_prime, v_chunk);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new_chunk", il);
ggml_tensor * q_g_exp = ggml_mul(ctx0, ggml_repeat_4d(ctx0, gexp_chunk, q_chunk->ne[0], gexp_chunk->ne[1], gexp_chunk->ne[2], gexp_chunk->ne[3]), q_chunk);
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
cb(q_g_exp, "q_g_exp", il);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);
//printf("v_attn_chunk: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type),
// attn_chunk->ne[0], attn_chunk->ne[1], attn_chunk->ne[2], attn_chunk->ne[3], ggml_type_name(attn_chunk->type));
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
cb(v_attn, "v_attn_chunk", il);
ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
ggml_tensor * core_attn_out_chunk = ggml_sub(ctx0, attn_inter, v_attn);
cb(core_attn_out_chunk, "core_attn_out_chunk", il);
core_attn_out = core_attn_out == nullptr
@@ -4575,15 +4577,14 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
ggml_tensor * k_gdiff_t = get_slice_2d(key_gdiff_t, chunk);
//printf("kgdmulvnew: %ld x %ld x %ld x %ld, %s x %ld x %ld x %ld x %ld, %s\n", v_new_t->ne[0], v_new_t->ne[1], v_new_t->ne[2], v_new_t->ne[3], ggml_type_name(v_new_t->type),
// k_gdiff_t->ne[0], k_gdiff_t->ne[1], k_gdiff_t->ne[2], k_gdiff_t->ne[3], ggml_type_name(k_gdiff_t->type));
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
cb(kgdmulvnew, "kgdmulvnew", il);
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(g_last_exp, chunk));
state = ggml_add(ctx0,
ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
cb(gexp_last_chunk, "gexp_last_chunk", il);
auto s_mul = ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs));
cb(s_mul, "s_mul", il);
state = ggml_sub(ctx0, s_mul, ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
}
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
@@ -4595,6 +4596,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);
cb(output_tokens, "output_tokens", il);
return {output_tokens, state};
};
@@ -4614,9 +4616,9 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
GGML_ASSERT(H_k == H_v);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
const float eps_norm = hparams.f_norm_rms_eps;
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
//const float eps_norm = hparams.f_norm_rms_eps;
//q = ggml_l2_norm(ctx0, q, eps_norm);
//k = ggml_l2_norm(ctx0, k, eps_norm);
const float scale = 1.0f / sqrtf(S_v);
@@ -4633,10 +4635,13 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
g_t = ggml_exp(ctx0, g_t);
cb(g_t, "g_t", il);
state = ggml_mul(ctx0, state, g_t);
cb(state, "state", il);
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
cb(kv_mem, "kv_mem", il);
kv_mem = ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem));
cb(kv_mem, "kv_mem_t_cont", il);
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, kv_mem));
@@ -4645,12 +4650,15 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);
cb(v_diff, "v_diff", il);
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
cb(delta, "delta", il);
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
cb(k_t_delta, "k_t_delta", il);
state = ggml_add(ctx0, state, k_t_delta);
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
cb(state_q, "state_q", il);
state_q = ggml_cont(ctx0, ggml_transpose(ctx0, state_q));
cb(state_q, "state_q_t_cont", il);
ggml_tensor * core_attn_out = ggml_transpose(ctx0, ggml_sum_rows(ctx0, state_q));
@@ -4880,6 +4888,8 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1);
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1);
cb(beta, "beta", il);
cb(alpha, "alpha", il);
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
@@ -4919,25 +4929,38 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext);
cb(conv_output_raw, "conv_output_raw", il);
ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0);
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
//ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0);
//ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_raw);
cb(conv_output_silu, "conv_output_silu", il);
ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1], 0);
ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_output_silu, key_dim, n_tok, conv_output_silu->nb[1],
key_dim * ggml_element_size(conv_output_silu));
// Calculate the total conv dimension
int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
int64_t nb1_qkv = ggml_row_size(conv_output_silu->type, qkv_dim);
// Extract the convolved Q, K, V from conv_output
ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_tok, 1,
ggml_row_size(conv_output_silu->type, head_k_dim),
nb1_qkv, nb1_qkv * n_tok, 0);
ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_tok, 1,
ggml_row_size(conv_output_silu->type, head_k_dim),
nb1_qkv, nb1_qkv * n_tok,
head_k_dim * num_k_heads * ggml_element_size(conv_output_silu));
ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_output_silu, head_v_dim, num_v_heads, n_tok, 1,
ggml_row_size(conv_output_silu->type, head_v_dim),
conv_output_silu->nb[1],
conv_output_silu->nb[1] * n_tok,
2 * key_dim * ggml_element_size(conv_output_silu));
nb1_qkv, nb1_qkv * n_tok,
ggml_row_size(conv_output_silu->type, 2 * head_k_dim * num_k_heads));
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_tok, 1);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_tok, 1);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_tok, 1);
cb(q_conv, "q_conv_cont", il);
cb(k_conv, "k_conv_cont", il);
cb(v_conv, "v_conv_cont", il);
cb(q_conv, "q_conv", il);
cb(k_conv, "k_conv", il);
cb(v_conv, "v_conv", il);
const float eps_norm = hparams.f_norm_rms_eps;
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
@@ -4974,7 +4997,9 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * new_conv_states = ggml_view_2d(ctx0, conv_output_raw, hparams.ssm_d_conv - 1, conv_dim,
hparams.ssm_d_conv * ggml_element_size(conv_output_raw),
(1 + conv_dim * n_tok) * ggml_element_size(conv_output_raw));
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, ggml_cont(ctx0, new_conv_states), conv_state_dim, 1);
auto new_conv_states_cont = ggml_cont(ctx0, new_conv_states);
cb(new_conv_states_cont, "new_conv_states_cont", il);
ggml_tensor * new_conv_flat = ggml_reshape_2d(ctx0, new_conv_states_cont, conv_state_dim, 1);
ggml_tensor * new_ssm_flat = ggml_reshape_2d(ctx0, new_state, ssm_state_dim, 1);
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
@@ -4989,7 +5014,9 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * attn_out_norm = llm_build_norm(ctx0, attn_out_2d, hparams, model.layers[il].ssm_norm, nullptr, LLM_NORM_RMS, cb, il);
ggml_tensor * gated_silu = ggml_silu(ctx0, z_2d);
cb(gated_silu, "gated_silu", il);
attn_out_norm = ggml_mul(ctx0, attn_out_norm, gated_silu);
cb(attn_out_norm, "attn_out_norm", il);
ggml_tensor * final_output = ggml_reshape_2d(ctx0, attn_out_norm, value_dim, n_tok);
cb(final_output, "final_output", il);
@@ -6694,6 +6721,7 @@ ggml_cgraph * llm_build_context::build_mamba() {
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
cb(y, "y", il);
// {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);

View File

@@ -418,7 +418,7 @@ struct llama_model {
~llama_model();
// Not actually needed, but left in place for now
size_t max_nodes() const { return 65536; }
size_t max_nodes() const { return 65536 * 2; }
bool has_tensor_overrides() const {
return tensor_overrides;