gpt_oss: Implement -fmoe on the CPU

This commit is contained in:
Iwan Kawrakow
2025-08-12 18:44:40 +03:00
parent aa5a187a44
commit 8bd983300c
3 changed files with 101 additions and 44 deletions

View File

@@ -15485,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
const struct ggml_tensor * src1 = dst->src[2];
const struct ggml_tensor * ids = dst->src[3];
const struct ggml_tensor * up_b = dst->src[4];
const struct ggml_tensor * gate_b = dst->src[5];
const struct ggml_tensor * src0_1 = dst->src[0];
const struct ggml_tensor * src0_2 = dst->src[1];
const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
@@ -15509,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne13 == 1);
// Per-expert byte strides of the optional bias tensors (0 when the bias is absent).
const size_t nb41 = up_b ? up_b->nb[1] : 0;
// Guard on gate_b itself: the gate bias is looked up independently of the up bias below,
// so keying this on up_b would dereference NULL (up_b set, gate_b not) or leave the
// stride at 0 (gate_b set, up_b not), making every expert read expert 0's gate bias.
const size_t nb51 = gate_b ? gate_b->nb[1] : 0;
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
@@ -15596,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -15606,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
type, src0_1_cur, src0_2_cur, nb01,
vec_dot_type, (const char *)wdata, row_size,
up_b_cur, gate_b_cur,
(float *)dst->data, nb1, nb2,
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");

View File

@@ -120,16 +120,21 @@ struct MulMat {
funcs[n_left-1](n, vx, bx, info, nrc_x);
}
}
inline void gelu(int n, const float * src, float * dst);
inline void relu(int n, const float * src, float * dst);
inline void silu(int n, const float * src, float * dst);
inline void activate(ggml_unary_op op, int n, const float * src, float * dst) {
inline static void gelu(int n, const float * src, float * dst);
inline static void relu(int n, const float * src, float * dst);
inline static void silu(int n, const float * src, float * dst);
inline static void swiglu_oai(int n, const float * src, float * dst);
inline static void clamp_oai(int n, float *x);
// Applies the element-wise activation selected by `op` to the n values in src,
// writing the results to dst. Any op without a handler here aborts.
inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) {
    switch (op) {
        case GGML_UNARY_OP_GELU:
            gelu(n, src, dst);
            break;
        case GGML_UNARY_OP_RELU:
            relu(n, src, dst);
            break;
        case GGML_UNARY_OP_SILU:
            silu(n, src, dst);
            break;
        case GGML_UNARY_OP_SWIGLU_OAI:
            swiglu_oai(n, src, dst);
            break;
        default:
            GGML_ABORT("fatal error");
    }
}
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx,
const float * up_b, const float * gate_b,
DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
#ifdef __aarch64__
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
#else
@@ -137,6 +142,29 @@ struct MulMat {
#endif
auto op = ggml_unary_op(unary_op);
float tmp[k_x_step*16];
// Handles one tile of the fused up/gate product: this_nrc_x rows of the weight matrices
// starting at row ix, against ny rows of the right-hand side.
// func is invoked twice on the same this_info; it is expected to leave its results in
// this_info.dst_row(ky), because those rows are read back immediately after each call.
// tmp holds the activated gate values, one stripe of xstep floats per ky.
auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) {
// Pass 1: gate projection.
func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
if (gate_b) {
// Optional gate bias, added in place before the activation; ix selects this tile's slice.
auto b = gate_b + ix;
auto x = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j];
}
// Stash the activated gate in tmp, since the dst rows are overwritten by pass 2.
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep);
}
// Pass 2: up projection, written to the same dst rows.
func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
auto result = this_info.dst_row(ky);
if (up_b) {
// Optional up bias, added in place.
auto b = up_b + ix;
for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j];
}
if (op == GGML_UNARY_OP_SWIGLU_OAI) {
// SWIGLU_OAI only: the up values get an extra in-place transform before the product.
clamp_oai(this_nrc_x, result);
}
// Final result: up * activation(gate), element-wise.
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j];
}
};
if (func16 && nrc_y >= 16) {
int n_step = (nrc_y - info.cur_y)/16;
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
@@ -144,15 +172,7 @@ struct MulMat {
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < n_step; ++iy) {
func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < 16; ++ky) {
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
}
func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < 16; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(func16, this_info, ix, this_nrc_x, 16);
this_info.cur_y += 16;
}
}
@@ -175,23 +195,11 @@ struct MulMat {
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < my1; ++iy) {
funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny1; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1);
this_info.cur_y += ny1;
}
for (int iy = 0; iy < my2; ++iy) {
funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny2; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2);
this_info.cur_y += ny2;
}
}
@@ -203,13 +211,7 @@ struct MulMat {
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < n_step; ++iy) {
funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[ny-1], this_info, ix, this_nrc_x, ny);
this_info.cur_y += ny;
}
}
@@ -222,13 +224,7 @@ struct MulMat {
auto this_info = info;
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < n_left; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left);
}
}
}
@@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB,
const char * up_b_c, const char * gate_b_c,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
@@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
GGML_ABORT("Fatal error");
}
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr;
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr;
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
}
return true;
@@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
nrc_x *= num_rows;
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op);
auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr;
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x : nullptr;
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx,
up_b, gate_b, info, nrc_x, Ny, unary_op);
return true;
}
@@ -993,6 +995,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
namespace {
// TODO: these swiglu_oai constants shouldn't be hard coded
// Scale applied to the input inside the exponential of the swiglu_oai gate activation.
constexpr float k_swiglu_oai_alpha = 1.702f;
// Clamp bound: upper bound for the gate input, and symmetric +/- bound for the up values.
constexpr float k_swiglu_oai_limit = 7.f;
// Gate half of the swiglu_oai activation: each input is capped from above at
// k_swiglu_oai_limit and then mapped to v / (1 + exp(-alpha * v)).
// Reads n floats from x and writes n floats to y.
void MulMat::swiglu_oai(int n, const float * x, float * y) {
    for (int idx = 0; idx < n; ++idx) {
        const float capped = std::min(x[idx], k_swiglu_oai_limit);
        const float denom  = 1.0f + expf(-capped * k_swiglu_oai_alpha);
        y[idx] = capped / denom;
    }
}
// In-place transform of n floats: clamp each value to
// [-k_swiglu_oai_limit, +k_swiglu_oai_limit], then add 1.
// Note that despite the name, the result is the clamped value plus one.
void MulMat::clamp_oai(int n, float * x) {
    for (int idx = 0; idx < n; ++idx) {
        const float upper_bounded = std::min(x[idx], k_swiglu_oai_limit);
        const float bounded       = std::max(upper_bounded, -k_swiglu_oai_limit);
        x[idx] = 1.f + bounded;
    }
}
#if defined(__ARM_NEON) && defined(__aarch64__)
void MulMat::gelu(int n, const float * x, float * y) {
constexpr float GELU_COEF_A = 0.044715f;
@@ -1040,6 +1057,37 @@ void MulMat::gelu(int n, const float * x, float * y) {
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
}
//void MulMat::swiglu_oai(int n, const float * x, float * y) {
// int i = 0;
//#if defined __AVX512F__ && defined __AVX512DQ__
// {
// auto limit = _mm512_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha);
// for (; i + 15 < n; i += 16) {
// auto xi = _mm512_loadu_ps(x + i);
// auto mask = _mm512_cmp
//
// }
// __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2));
// }
//#endif
//#if defined __AVX2__ && defined __FMA__
// if (i + 7 < n) {
// __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2));
//
// }
//#endif
// for (; i < n; ++i) {
// auto xi = std::min(x[i], k_swiglu_oai_limit);
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
// }
//}
void MulMat::silu(int n, const float * x, float * y) {
int i = 0;
#if defined __AVX512F__ && defined __AVX512DQ__

View File

@@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
// Fused MoE up/gate multiplication with the activation selected by unary_op.
// up_b / gate_b: optional bias data for the up and gate projections; either may be NULL
// to disable that bias. They are passed as const char * but the implementation reads them
// as float arrays indexed by output row.
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB,
const char * up_b, const char * gate_b,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
IQK_API int iqk_dequant_type(int type, int Ny);