diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 036bd8a8..d00f50a3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14737,6 +14737,34 @@ static void ggml_compute_forward_mul_mat_id( } } +#if GGML_USE_IQK_MULMAT + if (ids->ne[1] == 1 && dst->type == GGML_TYPE_F32) { /* fast path: taken only when ids has a single column (ne[1] == 1) and dst is F32 */ + int gcd = simple_gcd(n_ids, nth); /* simple_gcd is defined outside this hunk; presumably the greatest common divisor of the id count and the thread count - confirm */ + if (gcd > 1) { /* when gcd is not greater than 1 this block is skipped and execution continues below the #endif */ + ggml_barrier(params->shared); + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; /* use src1->data when src1 already has vec_dot_type, otherwise params->wdata */ + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + int counter = 0; /* incremented exactly once per loop iteration, so it equals id at the test below */ + for (int id = 0; id < n_ids; ++id) { + if ((counter++ % gcd) == (ith%gcd)) { /* this thread takes id only when id % gcd == ith % gcd */ + int i02 = *(const int32_t *) ((const char *) ids->data + id*ids->nb[0]); /* 32-bit index read from ids at byte offset id*nb[0]; used below as the i02*nb02 offset into src0->data */ + if (i02 >= 0 && i02 < n_as) { /* indices outside [0, n_as) are skipped; nothing is written to dst for them here */ + const char * src0_cur = (const char *) src0->data + i02*nb02; + // i1 = id, i2 = iid1 = 0 + if (!iqk_mul_mat(ne01, 1, ne00, + src0->type, (const char *)src0_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)((char *)dst->data + id*nb1), nb1, + ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available0; /* output pointer is dst->data + id*nb1; a false return abandons this fast path. The last two arguments are presumably the thread index and thread count within the group sharing this id - confirm against iqk_mul_mat */ + } + } + } + return; /* leave the function; the code below the #endif is not run on this path */ + } + } +IQK_MulMat_Not_Available0:; /* reached by the goto above or by falling through; execution continues below the #endif */ +#endif + #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] if (ith == 0) { @@ -14990,6 +15018,32 @@ static void ggml_compute_forward_mul_mat_id_up_gate( } } + if (ids->ne[1] == 1 && dst->type == GGML_TYPE_F32) { /* fast path: taken only when ids has a single column (ne[1] == 1) and dst is F32. NOTE(review): no #if GGML_USE_IQK_MULMAT guard is visible in this hunk - confirm the surrounding code is already guarded */ + int gcd = simple_gcd(n_ids, nth); /* simple_gcd is defined outside this hunk; presumably the greatest common divisor of the id count and the thread count - confirm */ + if (gcd > 1) { /* when gcd is not greater than 1 this block is skipped and execution continues below it */ + ggml_barrier(params->shared); + const void * wdata = (src1->type == vec_dot_type) ?
src1->data : params->wdata; /* use src1->data when src1 already has vec_dot_type, otherwise params->wdata */ + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + int counter = 0; /* incremented exactly once per loop iteration, so it equals id at the test below */ + for (int id = 0; id < n_ids; ++id) { + if ((counter++ % gcd) == (ith%gcd)) { /* this thread takes id only when id % gcd == ith % gcd */ + int i02 = *(const int32_t *) ((const char *) ids->data + id*ids->nb[0]); /* 32-bit index read from ids at byte offset id*nb[0] */ + if (i02 >= 0 && i02 < n_as) { /* indices outside [0, n_as) are skipped; nothing is written to dst for them here */ + const char * src0_1_cur = (const char *) src0_1->data + i02*nb02; + const char * src0_2_cur = (const char *) src0_2->data + i02*nb02; /* both source tensors are offset by the same i02*nb02 */ + // i1 = id, i2 = iid1 = 0 + if (!iqk_moe_fused_up_gate(ne01, 1, ne00, ne11, dst->op_params[0], + type, src0_1_cur, src0_2_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)((char *)dst->data + id*nb1), nb1, nb2, + NULL, ith/gcd, nth/gcd)) GGML_ABORT("fatal error"); /* output pointer is dst->data + id*nb1; a false return aborts, no fallback is attempted here */ + } + } + } + return; /* leave the function; the code below this block is not run on this path */ + } + } + #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] if (ith == 0) {