From 8bd983300cd450ded1e6d599d1ce19670e5e4183 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 12 Aug 2025 18:44:40 +0300 Subject: [PATCH] gpt_oss: Implement -fmoe on the CPU --- ggml/src/ggml.c | 8 +++ ggml/src/iqk/iqk_mul_mat.cpp | 136 +++++++++++++++++++++++------------ ggml/src/iqk/iqk_mul_mat.h | 1 + 3 files changed, 101 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 58755fe8..9ec8e518 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -15485,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const struct ggml_tensor * src1 = dst->src[2]; const struct ggml_tensor * ids = dst->src[3]; + const struct ggml_tensor * up_b = dst->src[4]; + const struct ggml_tensor * gate_b = dst->src[5]; const struct ggml_tensor * src0_1 = dst->src[0]; const struct ggml_tensor * src0_2 = dst->src[1]; const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works @@ -15509,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate( GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(ne13 == 1); + const size_t nb41 = up_b ? up_b->nb[1] : 0; + const size_t nb51 = up_b ? gate_b->nb[1] : 0; + // row groups const int n_ids = ids->ne[0]; // n_expert_used const int n_as = ne02; // n_expert @@ -15596,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02; const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02; + const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL; + const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL; const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -15606,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate( if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], type, src0_1_cur, src0_2_cur, nb01, vec_dot_type, (const char *)wdata, row_size, + up_b_cur, gate_b_cur, (float *)dst->data, nb1, nb2, matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 90882591..eb0ca056 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -120,16 +120,21 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } - inline void gelu(int n, const float * src, float * dst); - inline void relu(int n, const float * src, float * dst); - inline void silu(int n, const float * src, float * dst); - inline void activate(ggml_unary_op op, int n, const float * src, float * dst) { + inline static void gelu(int n, const float * src, float * dst); + inline static void relu(int n, const float * src, float * dst); + inline static void silu(int n, const float * src, float * dst); + inline static void swiglu_oai(int n, const float * src, float * dst); + inline static void clamp_oai(int n, float *x); + inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) { if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst); else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst); else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst); + else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst); else GGML_ABORT("fatal error"); } - inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) { + inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, + const float * up_b, const float * gate_b, + DataInfo& info, int nrc_x, int nrc_y, int unary_op) { #ifdef __aarch64__ constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) #else @@ -137,6 +142,29 @@ struct MulMat { #endif auto op = ggml_unary_op(unary_op); float tmp[k_x_step*16]; + auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) { + func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + if (gate_b) { + auto b = gate_b + ix; + auto x = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j]; + } + activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep); + } + func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + auto result = this_info.dst_row(ky); + if (up_b) { + auto b = up_b + ix; + for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j]; + } + if (op == GGML_UNARY_OP_SWIGLU_OAI) { + clamp_oai(this_nrc_x, result); + } + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j]; + } + }; if (func16 && nrc_y >= 16) { int n_step = (nrc_y - info.cur_y)/16; for (int ix = 0; ix < nrc_x; ix += k_x_step) { @@ -144,15 +172,7 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { - func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < 16; ++ky) { - activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - } - func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < 16; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(func16, this_info, ix, this_nrc_x, 16); this_info.cur_y += 16; } } @@ -175,23 +195,11 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < my1; ++iy) { - funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny1; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1); this_info.cur_y += ny1; } for (int iy = 0; iy < my2; ++iy) { - funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny2; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2); this_info.cur_y += ny2; } } @@ -203,13 +211,7 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { - funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny-1], this_info, ix, this_nrc_x, ny); this_info.cur_y += ny; } } @@ -222,13 +224,7 @@ struct MulMat { auto this_info = info; this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; - funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < n_left; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left); } } } @@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, + const char * up_b_c, const char * gate_b_c, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; @@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) { GGML_ABORT("Fatal error"); } - mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op); + auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr; + auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr; + mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op); } return true; @@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n nrc_x *= num_rows; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; - mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op); + auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr; + auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x : nullptr; + mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, + up_b, gate_b, info, nrc_x, Ny, unary_op); return true; } @@ -993,6 +995,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { namespace { +// TODO: these swiglu_oai constants shouldn't be hard coded +constexpr float k_swiglu_oai_alpha = 1.702f; +constexpr float k_swiglu_oai_limit = 7.f; + +void MulMat::swiglu_oai(int n, const float * x, float * y) { + for (int i = 0; i < n; ++i) { + auto xi = std::min(x[i], k_swiglu_oai_limit); + y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); + } +} + +void MulMat::clamp_oai(int n, float * x) { + for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit); +} + #if defined(__ARM_NEON) && defined(__aarch64__) void MulMat::gelu(int n, const float * x, float * y) { constexpr float GELU_COEF_A = 0.044715f; @@ -1040,6 +1057,37 @@ void MulMat::gelu(int n, const float * x, float * y) { for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); } +//void MulMat::swiglu_oai(int n, const float * x, float * y) { +// int i = 0; +//#if defined __AVX512F__ && defined __AVX512DQ__ +// { +// auto limit = _mm512_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha); +// for (; i + 15 < n; i += 16) { +// auto xi = _mm512_loadu_ps(x + i); +// auto mask = _mm512_cmp +// +// } +// __m512 c1 = _mm512_set1_ps(GELU_COEF_A); +// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); +// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2)); +// } +//#endif +//#if defined __AVX2__ && defined __FMA__ +// if (i + 7 < n) { +// __m256 c1 = _mm256_set1_ps(GELU_COEF_A); +// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); +// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2)); +// +// } +//#endif +// for (; i < n; ++i) { +// auto xi = std::min(x[i], k_swiglu_oai_limit); +// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); +// } +//} + + void MulMat::silu(int n, const float * x, float * y) { int i = 0; #if defined __AVX512F__ && defined __AVX512DQ__ diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index bce1a935..b131095b 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, + const char * up_b, const char * gate_b, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); IQK_API int iqk_dequant_type(int type, int Ny);