This commit is contained in:
Kawrakow
2026-01-12 11:39:50 +02:00
parent 3d9ee861f8
commit 4e4fabf0b4
2 changed files with 6 additions and 12 deletions

View File

@@ -2802,7 +2802,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
}
}
bool is_first = true;
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = moe_counts[i02];
@@ -2891,17 +2890,12 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
}
} else {
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
//if (is_first) {
// printf("Doing ggml_swiglu_oai_cuda_f32: %ld %zu %ld %ld %ld\n", dst->ne[0], ggml_nelements(&dst_row)/2, dst_row.ne[0], src0_1->ne[1], num_src1_rows);
// is_first = false;
//}
ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get() + dst->ne[0], (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row)/2, dst->ne[0], src0_1->ne[1], src0_1->ne[1],
1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row)/2, dst->ne[0],
(const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
}
dst_row.data = dst_gate_contiguous.get();
dst_row.ne[0] /= 2;

View File

@@ -70,7 +70,7 @@ static __global__ void fused_mul_silu_f32(const float * x, float * dst, const in
int row = i / ne0;
int j = i % ne0;
auto x_row = x + 2*row*ne0;
dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x[j]));
dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j + ne0]));
}
static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
@@ -91,7 +91,7 @@ static __global__ void fused_mul_relu_f32(const float * x, float * dst, const in
int row = i / ne0;
int j = i % ne0;
auto x_row = x + 2*row*ne0;
dst[i] = fmaxf(x_row[j], 0) * x_row[j + ne0];
dst[i] = fmaxf(x_row[j + ne0], 0) * x_row[j];
}
static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
@@ -117,8 +117,8 @@ static __global__ void fused_mul_gelu_f32(const float * x, float * dst, const in
int row = i / ne0;
int j = i % ne0;
auto x_row = x + 2*row*ne0;
float xi = x_row[j];
dst[i] = 0.5f*xi*x_row[j+ne0]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
float xi = x_row[j + ne0];
dst[i] = 0.5f*xi*x_row[j]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}
static __global__ void tanh_f32(const float * x, float * dst, int k) {