mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
Fuse up and gate gemms in MoE models
Small (~1-2%) but measurable performance gain
This commit is contained in:
161
ggml/src/ggml.c
161
ggml/src/ggml.c
@@ -14581,6 +14581,149 @@ IQK_MulMat_Not_Available:;
|
||||
#undef MMID_MATRIX_ROW
|
||||
}
|
||||
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst1,
|
||||
struct ggml_tensor * dst2) {
|
||||
|
||||
GGML_ASSERT(dst1->src[1] == dst2->src[1]);
|
||||
GGML_ASSERT(dst1->src[2] == dst2->src[2]);
|
||||
GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type);
|
||||
GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32);
|
||||
|
||||
const struct ggml_tensor * src1 = dst1->src[1];
|
||||
const struct ggml_tensor * ids = dst1->src[2];
|
||||
const struct ggml_tensor * src0_1 = dst1->src[0];
|
||||
const struct ggml_tensor * src0_2 = dst2->src[0];
|
||||
const struct ggml_tensor * src0 = src0_1;
|
||||
const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
// row groups
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_expert
|
||||
|
||||
char * wdata_src1_end = (src1->type == vec_dot_type) ?
|
||||
(char *) params->wdata :
|
||||
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t));
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
|
||||
|
||||
if (src1->type != vec_dot_type) {
|
||||
|
||||
ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
|
||||
|
||||
char * wdata = params->wdata;
|
||||
|
||||
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
|
||||
const size_t nbw2 = nbw1*ne11;
|
||||
const size_t nbw3 = nbw2*ne12;
|
||||
|
||||
assert(params->wsize >= ne13*nbw3);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
|
||||
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
|
||||
ne10);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
|
||||
|
||||
if (ith == 0) {
|
||||
// initialize matrix_row_counts
|
||||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||
|
||||
// group rows by src0 matrix
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
||||
for (int id = 0; id < n_ids; ++id) {
|
||||
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
assert(i02 >= 0 && i02 < n_as);
|
||||
|
||||
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
|
||||
matrix_row_counts[i02] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(params->shared);
|
||||
|
||||
// compute each matrix multiplication in sequence
|
||||
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||
|
||||
if (cne1 == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
|
||||
const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = cne1; // src1 rows
|
||||
|
||||
if (nth%2 == 0) {
|
||||
const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur;
|
||||
void * dst_d = ith%2 == 0 ? dst1->data : dst2->data;
|
||||
if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
|
||||
type, src0_d, nb01,
|
||||
vec_dot_type, (const char *)wdata, row_size,
|
||||
(float *)dst_d, nb1, nb2,
|
||||
matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error");
|
||||
|
||||
} else {
|
||||
if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
|
||||
src0_1->type, (const char *)src0_1_cur, nb01,
|
||||
vec_dot_type, (const char *)wdata, row_size,
|
||||
(float *)dst1->data, nb1, nb2,
|
||||
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
|
||||
if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
|
||||
src0_2->type, (const char *)src0_2_cur, nb01,
|
||||
vec_dot_type, (const char *)wdata, row_size,
|
||||
(float *)dst2->data, nb1, nb2,
|
||||
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
#undef MMID_MATRIX_ROW
|
||||
}
|
||||
#endif
|
||||
|
||||
// ggml_compute_forward_out_prod
|
||||
|
||||
static void ggml_compute_forward_out_prod_f32(
|
||||
@@ -19007,17 +19150,18 @@ static void ggml_compute_forward_cross_entropy_loss_back(
|
||||
|
||||
/////////////////////////////////
|
||||
|
||||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
|
||||
GGML_ASSERT(params);
|
||||
|
||||
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
int64_t t1 = ggml_time_us();
|
||||
#endif
|
||||
|
||||
bool skip_next = false;
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_DUP:
|
||||
{
|
||||
@@ -19125,6 +19269,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] &&
|
||||
tensor->src[0]->type == next->src[0]->type) {
|
||||
ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next);
|
||||
skip_next = true;
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
ggml_compute_forward_mul_mat_id(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_OUT_PROD:
|
||||
@@ -19367,6 +19519,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
int64_t t2 = ggml_time_us();
|
||||
if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
|
||||
#endif
|
||||
return skip_next;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -21219,7 +21372,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
|
||||
if (ggml_is_noop(node)) continue;
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
if (ggml_compute_forward(¶ms, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) {
|
||||
++node_n;
|
||||
}
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
|
||||
state->shared->ec = GGML_STATUS_ABORTED;
|
||||
|
||||
Reference in New Issue
Block a user