Is this better for DeepSeek-R1?

This commit is contained in:
Iwan Kawrakow
2025-03-24 21:18:06 +02:00
parent f9307d7907
commit be46f3ef14

View File

@@ -14737,6 +14737,34 @@ static void ggml_compute_forward_mul_mat_id(
}
}
#if GGML_USE_IQK_MULMAT
if (ids->ne[1] == 1 && dst->type == GGML_TYPE_F32) {
int gcd = simple_gcd(n_ids, nth);
if (gcd > 1) {
ggml_barrier(params->shared);
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
int counter = 0;
for (int id = 0; id < n_ids; ++id) {
if ((counter++ % gcd) == (ith%gcd)) {
int i02 = *(const int32_t *) ((const char *) ids->data + id*ids->nb[0]);
if (i02 >= 0 && i02 < n_as) {
const char * src0_cur = (const char *) src0->data + i02*nb02;
// i1 = id, i2 = iid1 = 0
if (!iqk_mul_mat(ne01, 1, ne00,
src0->type, (const char *)src0_cur, nb01,
vec_dot_type, (const char *)wdata, row_size,
(float *)((char *)dst->data + id*nb1), nb1,
ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available0;
}
}
}
return;
}
}
IQK_MulMat_Not_Available0:;
#endif
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
if (ith == 0) {
@@ -14990,6 +15018,32 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
}
}
if (ids->ne[1] == 1 && dst->type == GGML_TYPE_F32) {
int gcd = simple_gcd(n_ids, nth);
if (gcd > 1) {
ggml_barrier(params->shared);
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
int counter = 0;
for (int id = 0; id < n_ids; ++id) {
if ((counter++ % gcd) == (ith%gcd)) {
int i02 = *(const int32_t *) ((const char *) ids->data + id*ids->nb[0]);
if (i02 >= 0 && i02 < n_as) {
const char * src0_1_cur = (const char *) src0_1->data + i02*nb02;
const char * src0_2_cur = (const char *) src0_2->data + i02*nb02;
// i1 = id, i2 = iid1 = 0
if (!iqk_moe_fused_up_gate(ne01, 1, ne00, ne11, dst->op_params[0],
type, src0_1_cur, src0_2_cur, nb01,
vec_dot_type, (const char *)wdata, row_size,
(float *)((char *)dst->data + id*nb1), nb1, nb2,
NULL, ith/gcd, nth/gcd)) GGML_ABORT("fatal error");
}
}
}
return;
}
}
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
if (ith == 0) {