Fuse up and gate gemms in MoE models

Small (~1-2%) but measurable performance gain.
This commit is contained in:
Iwan Kawrakow
2025-02-22 08:26:26 +02:00
parent af790bb5fa
commit 216ea5890d

View File

@@ -14581,6 +14581,149 @@ IQK_MulMat_Not_Available:;
#undef MMID_MATRIX_ROW
}
#if GGML_USE_IQK_MULMAT
// Fused evaluation of two MUL_MAT_ID ops (the "up" and "gate" expert GEMMs in
// MoE models) that share the same activations (src[1]) and expert-id tensor
// (src[2]). dst1 and dst2 are the two MUL_MAT_ID result tensors; their src[0]
// tensors are the two expert weight matrices. Fusing lets the two GEMMs share
// the src1 quantization work and the expert row-grouping, and (when nth is
// even) run concurrently on disjoint halves of the thread pool.
static void ggml_compute_forward_mul_mat_id_up_gate(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst1,
        struct ggml_tensor * dst2) {

    // Fusion preconditions: both ops consume the same activations and expert
    // ids, the two weight tensors have the same quantization type, and both
    // outputs are f32.
    GGML_ASSERT(dst1->src[1] == dst2->src[1]);
    GGML_ASSERT(dst1->src[2] == dst2->src[2]);
    GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type);
    GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32);

    const struct ggml_tensor * src1 = dst1->src[1];
    const struct ggml_tensor * ids = dst1->src[2];
    const struct ggml_tensor * src0_1 = dst1->src[0];
    const struct ggml_tensor * src0_2 = dst2->src[0];
    const struct ggml_tensor * src0 = src0_1;
    const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works

    // NOTE(review): all ne0x/nb0x locals below come from src0_1 (and ne/nb from
    // dst1). src0_2 is later addressed with the same nb01/nb02 strides although
    // only type equality is asserted — assumes both weight tensors are laid out
    // identically; confirm against the graph builder.
    GGML_TENSOR_BINARY_OP_LOCALS

    const int ith = params->ith;   // this thread's index
    const int nth = params->nth;   // total number of threads

    const enum ggml_type type = src0->type;

    // Type src1 must be converted to before the dot products (e.g. Q8 for
    // quantized weights).
    enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    GGML_ASSERT(ne13 == 1);

    // row groups
    const int n_ids = ids->ne[0]; // n_expert_used
    const int n_as = ne02; // n_expert

    // wdata layout: [optional converted src1 copy][matrix_row_counts][matrix_rows].
    // When src1 already has vec_dot_type the converted-copy region is empty.
    char * wdata_src1_end = (src1->type == vec_dot_type) ?
            (char *) params->wdata :
            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));

    // Maps a row within an expert's group back to its (id, token) position.
    struct mmid_row_mapping {
        int32_t i1;   // which of the n_ids expert slots selected this expert
        int32_t i2;   // token (row of ids) the mapping belongs to
    };

    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]

    // Convert src1 to vec_dot_type into wdata, rows striped across threads.
    if (src1->type != vec_dot_type) {
        ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;

        char * wdata = params->wdata;

        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
        const size_t nbw2 = nbw1*ne11;
        const size_t nbw3 = nbw2*ne12;

        assert(params->wsize >= ne13*nbw3);
        GGML_ASSERT(src1->type == GGML_TYPE_F32);

        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                               ne10);
                }
            }
        }
    }

#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]

    // Thread 0 groups the selected rows by expert; everyone else waits at the
    // barrier below before reading matrix_row_counts / matrix_rows.
    if (ith == 0) {
        // initialize matrix_row_counts
        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));

        // group rows by src0 matrix
        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
            for (int id = 0; id < n_ids; ++id) {
                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);

                assert(i02 >= 0 && i02 < n_as);

                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
                matrix_row_counts[i02] += 1;
            }
        }
    }

    // All threads must see the completed row grouping (and converted src1)
    // before starting the GEMMs.
    ggml_barrier(params->shared);

    // compute each matrix multiplication in sequence
    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
        const int64_t cne1 = matrix_row_counts[cur_a];

        if (cne1 == 0) {
            continue; // this expert was not selected by any token
        }

        const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
        // NOTE(review): cur_a*nb02 uses src0_1's expert stride for src0_2 too —
        // relies on identical layout of the two weight tensors (see above).
        const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
        const size_t row_size = ggml_row_size(vec_dot_type, ne10);

        const int64_t nr0 = ne01; // src0 rows
        const int64_t nr1 = cne1; // src1 rows

        if (nth%2 == 0) {
            // Even thread count: even-ranked threads compute the "up" GEMM
            // (dst1), odd-ranked threads the "gate" GEMM (dst2), each half
            // acting as an nth/2-thread pool with index ith/2.
            const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur;
            void * dst_d = ith%2 == 0 ? dst1->data : dst2->data;
            if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
                        type, src0_d, nb01,
                        vec_dot_type, (const char *)wdata, row_size,
                        (float *)dst_d, nb1, nb2,
                        matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error");
        } else {
            // Odd thread count: every thread computes its share of both GEMMs
            // one after the other.
            if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
                        src0_1->type, (const char *)src0_1_cur, nb01,
                        vec_dot_type, (const char *)wdata, row_size,
                        (float *)dst1->data, nb1, nb2,
                        matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
            if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
                        src0_2->type, (const char *)src0_2_cur, nb01,
                        vec_dot_type, (const char *)wdata, row_size,
                        (float *)dst2->data, nb1, nb2,
                        matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
        }
    }
#undef MMID_MATRIX_ROW
}
#endif
// ggml_compute_forward_out_prod
static void ggml_compute_forward_out_prod_f32(
@@ -19007,17 +19150,18 @@ static void ggml_compute_forward_cross_entropy_loss_back(
/////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
GGML_ASSERT(params);
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
return;
return false;
}
#if IK_PRINT_TIMING
int64_t t1 = ggml_time_us();
#endif
bool skip_next = false;
switch (tensor->op) {
case GGML_OP_DUP:
{
@@ -19125,6 +19269,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT_ID:
{
#if GGML_USE_IQK_MULMAT
if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] &&
tensor->src[0]->type == next->src[0]->type) {
ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next);
skip_next = true;
break;
}
#endif
ggml_compute_forward_mul_mat_id(params, tensor);
} break;
case GGML_OP_OUT_PROD:
@@ -19367,6 +19519,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
int64_t t2 = ggml_time_us();
if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
#endif
return skip_next;
}
////////////////////////////////////////////////////////////////////////////////
@@ -21219,7 +21372,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
if (ggml_is_noop(node)) continue;
ggml_compute_forward(&params, node);
if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) {
++node_n;
}
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->ec = GGML_STATUS_ABORTED;