Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-02-23 22:54:10 +00:00.
Commit: gpt_oss: Implement -fmoe on the CPU
This commit is contained in:
@@ -15485,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
|
||||
const struct ggml_tensor * src1 = dst->src[2];
|
||||
const struct ggml_tensor * ids = dst->src[3];
|
||||
const struct ggml_tensor * up_b = dst->src[4];
|
||||
const struct ggml_tensor * gate_b = dst->src[5];
|
||||
const struct ggml_tensor * src0_1 = dst->src[0];
|
||||
const struct ggml_tensor * src0_2 = dst->src[1];
|
||||
const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
|
||||
@@ -15509,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
// Bias row strides (0 when the corresponding bias tensor is absent).
// Fix: nb51 previously tested up_b before dereferencing gate_b, which
// would null-deref if only the up-projection bias were present.
const size_t nb41 = up_b   ? up_b->nb[1]   : 0;
const size_t nb51 = gate_b ? gate_b->nb[1] : 0;
|
||||
|
||||
// row groups
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_expert
|
||||
@@ -15596,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
|
||||
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
|
||||
const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
|
||||
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
|
||||
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
@@ -15606,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
|
||||
type, src0_1_cur, src0_2_cur, nb01,
|
||||
vec_dot_type, (const char *)wdata, row_size,
|
||||
up_b_cur, gate_b_cur,
|
||||
(float *)dst->data, nb1, nb2,
|
||||
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
|
||||
|
||||
|
||||
@@ -120,16 +120,21 @@ struct MulMat {
|
||||
funcs[n_left-1](n, vx, bx, info, nrc_x);
|
||||
}
|
||||
}
|
||||
inline void gelu(int n, const float * src, float * dst);
|
||||
inline void relu(int n, const float * src, float * dst);
|
||||
inline void silu(int n, const float * src, float * dst);
|
||||
inline void activate(ggml_unary_op op, int n, const float * src, float * dst) {
|
||||
inline static void gelu(int n, const float * src, float * dst);
|
||||
inline static void relu(int n, const float * src, float * dst);
|
||||
inline static void silu(int n, const float * src, float * dst);
|
||||
inline static void swiglu_oai(int n, const float * src, float * dst);
|
||||
inline static void clamp_oai(int n, float *x);
|
||||
inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) {
|
||||
if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst);
|
||||
else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst);
|
||||
else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst);
|
||||
else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst);
|
||||
else GGML_ABORT("fatal error");
|
||||
}
|
||||
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
|
||||
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx,
|
||||
const float * up_b, const float * gate_b,
|
||||
DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
|
||||
#ifdef __aarch64__
|
||||
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
|
||||
#else
|
||||
@@ -137,6 +142,29 @@ struct MulMat {
|
||||
#endif
|
||||
auto op = ggml_unary_op(unary_op);
|
||||
float tmp[k_x_step*16];
|
||||
auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) {
|
||||
func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) {
|
||||
if (gate_b) {
|
||||
auto b = gate_b + ix;
|
||||
auto x = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j];
|
||||
}
|
||||
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep);
|
||||
}
|
||||
func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
if (up_b) {
|
||||
auto b = up_b + ix;
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j];
|
||||
}
|
||||
if (op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
clamp_oai(this_nrc_x, result);
|
||||
}
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j];
|
||||
}
|
||||
};
|
||||
if (func16 && nrc_y >= 16) {
|
||||
int n_step = (nrc_y - info.cur_y)/16;
|
||||
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
||||
@@ -144,15 +172,7 @@ struct MulMat {
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
for (int iy = 0; iy < n_step; ++iy) {
|
||||
func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < 16; ++ky) {
|
||||
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
}
|
||||
func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < 16; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(func16, this_info, ix, this_nrc_x, 16);
|
||||
this_info.cur_y += 16;
|
||||
}
|
||||
}
|
||||
@@ -175,23 +195,11 @@ struct MulMat {
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
for (int iy = 0; iy < my1; ++iy) {
|
||||
funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny1; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1);
|
||||
this_info.cur_y += ny1;
|
||||
}
|
||||
for (int iy = 0; iy < my2; ++iy) {
|
||||
funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny2; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2);
|
||||
this_info.cur_y += ny2;
|
||||
}
|
||||
}
|
||||
@@ -203,13 +211,7 @@ struct MulMat {
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
for (int iy = 0; iy < n_step; ++iy) {
|
||||
funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[ny-1], this_info, ix, this_nrc_x, ny);
|
||||
this_info.cur_y += ny;
|
||||
}
|
||||
}
|
||||
@@ -222,13 +224,7 @@ struct MulMat {
|
||||
auto this_info = info;
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < n_left; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
||||
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
||||
int typeA, const void * Aup, const void * Agate, long strideA,
|
||||
int typeB, const void * B, long strideB,
|
||||
const char * up_b_c, const char * gate_b_c,
|
||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
|
||||
|
||||
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
|
||||
@@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
|
||||
if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
|
||||
auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr;
|
||||
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr;
|
||||
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
|
||||
nrc_x *= num_rows;
|
||||
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
|
||||
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
|
||||
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op);
|
||||
auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr;
|
||||
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x : nullptr;
|
||||
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx,
|
||||
up_b, gate_b, info, nrc_x, Ny, unary_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -993,6 +995,21 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO: these swiglu_oai constants shouldn't be hard coded
|
||||
constexpr float k_swiglu_oai_alpha = 1.702f;
|
||||
constexpr float k_swiglu_oai_limit = 7.f;
|
||||
|
||||
void MulMat::swiglu_oai(int n, const float * x, float * y) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||
y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||
}
|
||||
}
|
||||
|
||||
void MulMat::clamp_oai(int n, float * x) {
|
||||
for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit);
|
||||
}
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
void MulMat::gelu(int n, const float * x, float * y) {
|
||||
constexpr float GELU_COEF_A = 0.044715f;
|
||||
@@ -1040,6 +1057,37 @@ void MulMat::gelu(int n, const float * x, float * y) {
|
||||
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
|
||||
}
|
||||
|
||||
//void MulMat::swiglu_oai(int n, const float * x, float * y) {
|
||||
// int i = 0;
|
||||
//#if defined __AVX512F__ && defined __AVX512DQ__
|
||||
// {
|
||||
// auto limit = _mm512_set1_ps(k_swiglu_oai_limit);
|
||||
// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha);
|
||||
// for (; i + 15 < n; i += 16) {
|
||||
// auto xi = _mm512_loadu_ps(x + i);
|
||||
// auto mask = _mm512_cmp
|
||||
//
|
||||
// }
|
||||
// __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
|
||||
// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
|
||||
// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2));
|
||||
// }
|
||||
//#endif
|
||||
//#if defined __AVX2__ && defined __FMA__
|
||||
// if (i + 7 < n) {
|
||||
// __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
|
||||
// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
|
||||
// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2));
|
||||
//
|
||||
// }
|
||||
//#endif
|
||||
// for (; i < n; ++i) {
|
||||
// auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||
// }
|
||||
//}
|
||||
|
||||
|
||||
void MulMat::silu(int n, const float * x, float * y) {
|
||||
int i = 0;
|
||||
#if defined __AVX512F__ && defined __AVX512DQ__
|
||||
|
||||
@@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
||||
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
||||
int typeA, const void * Aup, const void * Agate, long strideA,
|
||||
int typeB, const void * B, long strideB,
|
||||
const char * up_b, const char * gate_b,
|
||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
|
||||
|
||||
IQK_API int iqk_dequant_type(int type, int Ny);
|
||||
|
||||
Reference in New Issue
Block a user