diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index d5013b3b..cdf4893b 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2802,7 +2802,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten } } - bool is_first = true; for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = moe_counts[i02]; @@ -2891,17 +2890,12 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten } } else { if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { - //if (is_first) { - // printf("Doing ggml_swiglu_oai_cuda_f32: %ld %zu %ld %ld %ld\n", dst->ne[0], ggml_nelements(&dst_row)/2, dst_row.ne[0], src0_1->ne[1], num_src1_rows); - // is_first = false; - //} ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get() + dst->ne[0], (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row)/2, dst->ne[0], src0_1->ne[1], src0_1->ne[1], 1.702f, 7.0f, stream); } else { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get()); + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row)/2, dst->ne[0], + (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); } dst_row.data = dst_gate_contiguous.get(); dst_row.ne[0] /= 2; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 659a0992..2152210e 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -70,7 +70,7 @@ static __global__ void fused_mul_silu_f32(const float * x, float * dst, const in int row = i / ne0; int j = i % ne0; auto x_row = x + 2*row*ne0; - dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x[j])); + dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j + ne0])); } static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) { @@ -91,7 +91,7 @@ static __global__ void fused_mul_relu_f32(const float * x, float * dst, const in int row = i / ne0; int j = i % ne0; auto x_row = x + 2*row*ne0; - dst[i] = fmaxf(x_row[j], 0) * x_row[j + ne0]; + dst[i] = fmaxf(x_row[j + ne0], 0) * x_row[j]; } static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) { @@ -117,8 +117,8 @@ static __global__ void fused_mul_gelu_f32(const float * x, float * dst, const in int row = i / ne0; int j = i % ne0; auto x_row = x + 2*row*ne0; - float xi = x_row[j]; - dst[i] = 0.5f*xi*x_row[j+ne0]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); + float xi = x_row[j + ne0]; + dst[i] = 0.5f*xi*x_row[j]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); } static __global__ void tanh_f32(const float * x, float * dst, int k) {