More Qwen3-Next optimizations (#1277)

* Optimizing q3next TG

* Fused add -> softplus -> mul on CUDA

* Remove forgotten debug log

* Increase ggml context size

Required for Qwen3-Next with batch/u-batch size of 4096

* WIP

* Avoid some contiguous ops

* Avoid some repeats

* Avoid some more repeats
This commit is contained in:
Kawrakow
2026-02-17 16:03:51 +01:00
committed by GitHub
parent 88f98c891d
commit cafeef484c
7 changed files with 378 additions and 79 deletions

View File

@@ -2205,6 +2205,45 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
}
template <typename src_t, int block_size = 256>
static __global__ void mul_mat_row(int n, const src_t * x, const float * y, float * z) {
float sum = 0;
for (int i = threadIdx.x; i < n; i += block_size) {
float xi = ggml_cuda_cast<float, src_t>(x[i]);
sum += xi * y[i];
}
sum = warp_reduce_sum(sum);
if constexpr (block_size > WARP_SIZE) {
__shared__ float tmp[block_size/WARP_SIZE];
if (threadIdx.x % WARP_SIZE == 0) {
tmp[threadIdx.x / WARP_SIZE] = sum;
}
__syncthreads();
sum = threadIdx.x < block_size / WARP_SIZE ? tmp[threadIdx.x] : 0.0f;
sum = warp_reduce_sum(sum);
}
if (threadIdx.x == 0) {
z[0] = sum;
}
}
static void mul_mat_1row(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_backend_cuda_context & ctx) {
constexpr int kBlockSize = 256;
GGML_ASSERT(src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F16) {
mul_mat_row<<<1, kBlockSize, 0, ctx.stream()>>>((int)src0->ne[0], (const half *)src0->data, (const float *)src1->data, (float *)dst->data);
}
else if (src0->type == GGML_TYPE_BF16) {
mul_mat_row<<<1, kBlockSize, 0, ctx.stream()>>>((int)src0->ne[0], (const nv_bfloat16 *)src0->data, (const float *)src1->data, (float *)dst->data);
}
else if (src0->type == GGML_TYPE_F32) {
mul_mat_row<<<1, kBlockSize, 0, ctx.stream()>>>((int)src0->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data);
}
else {
GGML_ABORT("Fatal error");
}
}
static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const ggml_cgraph * cgraph, int node_n) {
@@ -2239,6 +2278,12 @@ static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
return ggml_cuda_mul_mat_q(ctx, src0, src1, dst, cgraph, node_n, use_mul_mat_vec_q);
}
if (ggml_nrows(src0) == 1 && ggml_nrows(src1) == 1) { // && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
mul_mat_1row(src0, src1, dst, ctx);
return node_n;
}
bool debug = false; //src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32;
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -2248,29 +2293,29 @@ static int ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
//printf("%s(%s): using ggml_cuda_mul_mat_vec_p021\n", __func__, dst->name);
if (debug) printf("%s(%s): using ggml_cuda_mul_mat_vec_p021\n", __func__, dst->name);
// FP32 precision KQ single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
} else if (any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
//printf("%s(%s): using ggml_cuda_mul_mat_vec_nc\n", __func__, dst->name);
if (debug) printf("%s(%s): using ggml_cuda_mul_mat_vec_nc\n", __func__, dst->name);
// FP32 precision KQV single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if ((src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32) && (src1->type == src0->type || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
//printf("%s(%s): ggml_cuda_mul_mat_batched_cublas\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_mul_mat_batched_cublas\n", __func__, dst->name);
// KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_dequantize_mul_mat_vec)\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_dequantize_mul_mat_vec)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_vec_q)\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_vec_q)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_q)\n", __func__, dst->name);
if (debug) printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_q)\n", __func__, dst->name);
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else {
//printf("%s(%s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_cublas)\n", __func__, dst->name);
if (debug) printf("%s(%s, %s): ggml_cuda_op_mul_mat(ggml_cuda_op_mul_mat_cublas)\n", __func__, dst->name, ggml_type_name(src0->type));
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
return node_n;
@@ -3218,6 +3263,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sum_rows_nc(ctx, cgraph->nodes[i+1]);
i += 2;
} else {
//auto src = dst->src[0];
//printf("cont(%s -> %s): %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n", src->name, dst->name, src->ne[0], src->ne[1], src->ne[2], src->ne[3], src->nb[0], src->nb[1], src->nb[2], src->nb[3]);
ggml_cuda_dup(ctx, dst);
}
break;
@@ -3237,6 +3284,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
}
else if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_UNARY &&
(ggml_unary_op)cgraph->nodes[i+1]->op_params[0] == GGML_UNARY_OP_SOFTPLUS &&
cgraph->nodes[i+2]->op == GGML_OP_MUL &&
cgraph->nodes[i+2]->src[0] == cgraph->nodes[i+1] &&
cgraph->nodes[i+1]->src[0] == cgraph->nodes[i] &&
ggml_nrows(cgraph->nodes[i+0]->src[1]) == 1 &&
ggml_nrows(cgraph->nodes[i+2]->src[1]) == 1) {
ggml_cuda_fused_softplus(ctx, cgraph->nodes[i+2]);
i += 2;
}
else if (false && fusion && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
ggml_is_contiguous(dst->src[0]) &&
@@ -3262,10 +3320,25 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_acc(ctx, dst);
break;
case GGML_OP_MUL:
//printf("mul(%s): %d, %d, %d, %ld x %ld x %ld x %ld * %ld x %ld x %ld x %ld\n", dst->name, ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_is_contiguous(dst),
// dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3],
// dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]);
ggml_cuda_op_mul(ctx, dst);
if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_UNARY &&
cgraph->nodes[i+2]->op == GGML_OP_MUL &&
(ggml_unary_op)cgraph->nodes[i+1]->op_params[0] == GGML_UNARY_OP_EXP &&
cgraph->nodes[i+1]->src[0] == dst &&
cgraph->nodes[i+2]->src[0] == cgraph->nodes[i+1] &&
cgraph->nodes[i+2]->src[1] == dst->src[1]) {
ggml_cuda_fused_mul_exp_mul(ctx, cgraph->nodes[i+2]);
i += 2;
//printf("mul(%s) -> exp(%s) -> mul(%s), %d, %d, %zu, %zu; %ld x %ld x %ld x %ld - %ld x %ld x %ld x %ld\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name,
// ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_nelements(dst->src[0]), ggml_nelements(dst->src[1]),
// dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3],
// dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]);
} else {
//printf("mul(%s): %d, %d, %d, %ld x %ld x %ld x %ld * %ld x %ld x %ld x %ld\n", dst->name, ggml_is_contiguous(dst->src[0]), ggml_is_contiguous(dst->src[1]), ggml_is_contiguous(dst),
// dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->src[0]->ne[3],
// dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], dst->src[1]->ne[3]);
ggml_cuda_op_mul(ctx, dst);
}
break;
case GGML_OP_FUSED_MUL_UNARY:
ggml_cuda_op_fused_mul_unary(ctx, dst);
@@ -3277,6 +3350,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sub(ctx, dst);
break;
case GGML_OP_UNARY:
//printf("unary(%s, %s)\n", dst->name, ggml_unary_op_name((ggml_unary_op)dst->op_params[0]));
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_GELU:
ggml_cuda_op_gelu(ctx, dst);

View File

@@ -33,6 +33,40 @@ static void pad_f32_cuda(const float * x, float * dst,
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
template <int dim>
static __global__ void pad_f32_nc(const char * cx, float * dst, int nelem,
int ne0, int ne1, int ne2, int ne3, int ne00, int ne01, int ne02, int ne03,
size_t nb00, size_t nb01, size_t nb02, size_t nb03) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i >= nelem) {
return;
}
int ii = i;
int i3 = ii/(ne0*ne1*ne2); ii -= i3*ne0*ne1*ne2;
int i2 = ii/(ne0*ne1 ); ii -= i2*ne0*ne1;
int i1 = ii/(ne0 );
int i0 = ii - i1*ne0;
if constexpr (dim == 0) {
dst[i] = i0 < ne00 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 1) {
dst[i] = i1 < ne01 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 2) {
dst[i] = i2 < ne02 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 3) {
dst[i] = i3 < ne03 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
else if constexpr (dim == 4) {
dst[i] = *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03);
}
else {
dst[i] = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03 ? *(const float *)(cx + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -41,9 +75,53 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
GGML_ASSERT(ggml_is_contiguous(dst));
if (ggml_is_contiguous(src0)) {
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
return;
}
int npad = 0; int pad_dim = -1;
for (int i = 0; i < 4; ++i) {
if (dst->ne[i] > src0->ne[i]) {
++npad; pad_dim = i;
}
}
//if (npad == 0) {
// printf("Oops: npad = 0: %ld vs %ld, %ld vx %ld, %ld vs %ld, %ld vs %ld\n", dst->ne[0], src0->ne[0], dst->ne[1], src0->ne[1], dst->ne[2], src0->ne[2], dst->ne[3], src0->ne[3]);
//}
//GGML_ASSERT(npad > 0);
constexpr int kBlockSize = 256;
int nelem = ggml_nelements(dst);
int nblock = (nelem + kBlockSize - 1)/kBlockSize;
if (npad == 0) {
//printf("%s: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n", src0->name, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
pad_f32_nc<4><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (npad == 1) {
if (pad_dim == 0) {
pad_f32_nc<0><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (pad_dim == 1) {
pad_f32_nc<1><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (pad_dim == 2) {
pad_f32_nc<2><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else if (pad_dim == 3) {
pad_f32_nc<3><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
} else {
GGML_ABORT("Fatal error");
}
} else {
pad_f32_nc<-1><<<nblock, kBlockSize, 0, ctx.stream()>>>((const char *)src0->data, (float *)dst->data, nelem,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
}
}

View File

@@ -950,3 +950,59 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_elu>(ctx, dst);
}
static __global__ void k_fused_softplus(int ne0, int nelem, const float * a, const float * b, const float * c, float * dst) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i >= nelem) {
return;
}
int i0 = i % ne0;
dst[i] = c[i0] * op_softplus(a[i] + b[i0]);
}
void ggml_cuda_fused_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int kBlockSize = 256;
auto m = dst;
auto u = dst->src[0];
auto a = u->src[0];
GGML_ASSERT(m->op == GGML_OP_MUL);
GGML_ASSERT(a->op == GGML_OP_ADD);
GGML_ASSERT(u->op == GGML_OP_UNARY && (ggml_unary_op)u->op_params[0] == GGML_UNARY_OP_SOFTPLUS);
GGML_ASSERT(ggml_nrows(m->src[1]) == 1 && m->src[1]->ne[0] == m->src[0]->ne[0]);
GGML_ASSERT(ggml_nrows(a->src[1]) == 1 && a->src[1]->ne[0] == a->src[0]->ne[0]);
GGML_ASSERT(a->type == GGML_TYPE_F32 && u->type == GGML_TYPE_F32 && m->type == GGML_TYPE_F32);
GGML_ASSERT(a->src[0]->type == GGML_TYPE_F32 && a->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(a->src[0]));
int nelem = ggml_nelements(a->src[0]);
int nblock = (nelem + kBlockSize - 1)/kBlockSize;
k_fused_softplus<<<nblock, kBlockSize, 0, ctx.stream()>>>(a->src[0]->ne[0], nelem,
(const float *)a->src[0]->data, (const float *)a->src[1]->data, (const float *)m->src[1]->data, (float *)dst->data);
}
static __global__ void k_fused_mul_exp_mul(int ne0, int nelem, const float * x, const float * y, float * dst) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i >= nelem) {
return;
}
int i0 = i % ne0;
dst[i] = y[i0] * expf(x[i] * y[i0]);
}
void ggml_cuda_fused_mul_exp_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int kBlockSize = 256;
auto m2 = dst;
auto u = dst->src[0];
auto m1 = u->src[0];
GGML_ASSERT(m1->src[0]->type == GGML_TYPE_F32 && m1->src[1]->type == GGML_TYPE_F32 && m2->type == GGML_TYPE_F32);
GGML_ASSERT(m1->src[1] == m2->src[1]);
GGML_ASSERT(ggml_is_contiguous(m1->src[0]) && ggml_is_contiguous(m1->src[1]));
GGML_ASSERT(u->op == GGML_OP_UNARY && (ggml_unary_op)u->op_params[0] == GGML_UNARY_OP_EXP);
auto nelem = ggml_nelements(m1->src[0]);
auto ne0 = ggml_nelements(m1->src[1]);
int nblock = (nelem + kBlockSize - 1)/kBlockSize;
k_fused_mul_exp_mul<<<nblock, kBlockSize, 0, ctx.stream()>>>(ne0, nelem,
(const float *)m1->src[0]->data, (const float *)m1->src[1]->data, (float *)dst->data);
}

View File

@@ -95,3 +95,7 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
int64_t nelements,int64_t ne0, const float * x, float * z, float limit = 0);
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_fused_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_fused_mul_exp_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -21071,7 +21071,6 @@ static void ggml_compute_forward_pad_f32(
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT( dst->nb[0] == sizeof(float));
const int ith = params->ith;
@@ -21083,22 +21082,82 @@ static void ggml_compute_forward_pad_f32(
// TODO: optimize
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if (src0->nb[0] == sizeof(float)) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
}
}
}
}
}
} else {
const int k_block_size = 1024;
int nelem = ggml_nelements(dst);
int nblocks = (nelem + k_block_size - 1)/k_block_size;
for (int ib = ith; ib < nblocks; ib += nth) {
int first = ib*k_block_size;
int last = MIN(first + k_block_size, nelem);
//
//int ii = first;
//int i3 = ii/(ne0*ne1*ne2); ii -= i3*ne0*ne1*ne2;
//int i2 = ii/(ne0*ne1 ); ii -= i2*ne0*ne1;
//int i1 = ii/(ne0 ); ii -= i1*ne0;
//int i0 = ii;
//int i = first;
//bool in_src = i1 < ne01 && i2 < ne02 && i3 < ne03;
//const char * c_src = (const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03;
//while (i < last) {
// if (i0 + last - i <= ne0) {
// for (; i < last; ++i, ++i0) {
// dst_ptr[i] = in_src && i0 < ne00 ? *(const float *)(c_src + i0*nb00) : 0.0f;
// }
// break;
// }
// for (; i0 < ne0; ++i0) {
// dst_ptr[i++] = in_src && i0 < ne00 ? *(const float *)(c_src + i0*nb00) : 0.0f;
// }
// i0 = 0;
// if (++i1 == (int)ne1) {
// i1 = 0;
// if (++i2 == (int)ne2) {
// i2 = 0; ++i3;
// }
// }
// in_src = i1 < ne01 && i2 < ne02 && i3 < ne03;
// c_src = (const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03;
//}
//for (int i = first; i < last; ++i) {
// dst_ptr[i] = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03 ? *(const float *)((const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
// if (++i0 == (int)ne0) {
// i0 = 0;
// if (++i1 == (int)ne1) {
// i1 = 0;
// if (++i2 == (int)ne2) {
// i2 = 0; ++i3;
// }
// }
// }
//}
//
for (int i = first; i < last; ++i) {
int ii = i;
int i3 = ii/(ne0*ne1*ne2); ii -= i3*ne0*ne1*ne2;
int i2 = ii/(ne0*ne1 ); ii -= i2*ne0*ne1;
int i1 = ii/(ne0 ); ii -= i1*ne0;
int i0 = ii;
dst_ptr[i] = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03 ? *(const float *)((const char *)src0->data + i0*nb00 + i1*nb01 + i2*nb02 + i3*nb03) : 0.0f;
}
}
}
}