mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 06:34:13 +00:00
WIP
This commit is contained in:
@@ -2802,7 +2802,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
}
|
||||
}
|
||||
|
||||
bool is_first = true;
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
int64_t num_src1_rows = moe_counts[i02];
|
||||
|
||||
@@ -2891,17 +2890,12 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
}
|
||||
} else {
|
||||
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
//if (is_first) {
|
||||
// printf("Doing ggml_swiglu_oai_cuda_f32: %ld %zu %ld %ld %ld\n", dst->ne[0], ggml_nelements(&dst_row)/2, dst_row.ne[0], src0_1->ne[1], num_src1_rows);
|
||||
// is_first = false;
|
||||
//}
|
||||
ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get() + dst->ne[0], (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row)/2, dst->ne[0], src0_1->ne[1], src0_1->ne[1],
|
||||
1.702f, 7.0f, stream);
|
||||
} else {
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get());
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row)/2, dst->ne[0],
|
||||
(const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
|
||||
}
|
||||
dst_row.data = dst_gate_contiguous.get();
|
||||
dst_row.ne[0] /= 2;
|
||||
|
||||
@@ -70,7 +70,7 @@ static __global__ void fused_mul_silu_f32(const float * x, float * dst, const in
|
||||
int row = i / ne0;
|
||||
int j = i % ne0;
|
||||
auto x_row = x + 2*row*ne0;
|
||||
dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x[j]));
|
||||
dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j + ne0]));
|
||||
}
|
||||
|
||||
static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
|
||||
@@ -91,7 +91,7 @@ static __global__ void fused_mul_relu_f32(const float * x, float * dst, const in
|
||||
int row = i / ne0;
|
||||
int j = i % ne0;
|
||||
auto x_row = x + 2*row*ne0;
|
||||
dst[i] = fmaxf(x_row[j], 0) * x_row[j + ne0];
|
||||
dst[i] = fmaxf(x_row[j + ne0], 0) * x_row[j];
|
||||
}
|
||||
|
||||
static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) {
|
||||
@@ -117,8 +117,8 @@ static __global__ void fused_mul_gelu_f32(const float * x, float * dst, const in
|
||||
int row = i / ne0;
|
||||
int j = i % ne0;
|
||||
auto x_row = x + 2*row*ne0;
|
||||
float xi = x_row[j];
|
||||
dst[i] = 0.5f*xi*x_row[j+ne0]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
|
||||
float xi = x_row[j + ne0];
|
||||
dst[i] = 0.5f*xi*x_row[j]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
|
||||
}
|
||||
|
||||
static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
||||
|
||||
Reference in New Issue
Block a user