diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f41bfb2b..078d7219 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -20386,19 +20386,19 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
+static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor,
+        const struct ggml_cgraph * cgraph, int i) {
+
     GGML_ASSERT(params);
-    GGML_UNUSED(next);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
-        return false;
+        return i;
     }
 
 #if IK_PRINT_TIMING
     int64_t t1 = ggml_time_us();
 #endif
 
-    bool skip_next = false;
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
@@ -20586,7 +20586,21 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
            {
-                ggml_compute_forward_soft_max(params, tensor);
+                if (i + 4 < cgraph->n_nodes &&
+                    cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                    cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
+                    cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
+                    cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
+                    cgraph->nodes[i+0]->type == GGML_TYPE_F32 &&
+                    cgraph->nodes[i+4]->type == GGML_TYPE_F32 &&
+                    cgraph->nodes[i+3]->type == GGML_TYPE_I32) {
+                    iqk_topk_moe(cgraph->nodes[i]->ne[0], cgraph->nodes[i+4]->ne[1], cgraph->nodes[i]->ne[1],
+                            (const float *)cgraph->nodes[i]->data, (float *)cgraph->nodes[i+4]->data, (int32_t *)cgraph->nodes[i+3]->data,
+                            params->ith, params->nth);
+                    i += 4;
+                } else {
+                    ggml_compute_forward_soft_max(params, tensor);
+                }
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -20764,7 +20778,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     int64_t t2 = ggml_time_us();
     if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
 #endif
-    return skip_next;
+    return i;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -22725,9 +22739,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 #if IK_PRINT_TIMING
         int64_t tim1 = ggml_time_us();
 #endif
-        if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) {
-            ++node_n;
-        }
+        node_n = ggml_compute_forward(&params, node, cgraph, node_n);
 #if IK_PRINT_TIMING
         int64_t tim2 = ggml_time_us();
         t_eval += tim2 - tim1;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 7d7126f5..44fe0a68 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include <algorithm>
 
 #include "ggml-impl.h"
 #include "ggml-quants.h"
@@ -1140,6 +1141,64 @@ void MulMat::relu(int n, const float * x, float * y) {
 #endif
 } // namespace
 
+namespace {
+void iqk_topk_moe(int n_experts, int n_experts_used, const float * logits,
+        float * weights, int32_t * ids, void * work) {
+
+    if (work) {
+        auto sorted = (std::pair<float, int> *)work;
+        for (int j = 0; j < n_experts; ++j) sorted[j] = {logits[j], j};
+
+        std::partial_sort(sorted, sorted + n_experts_used, sorted + n_experts, std::greater<std::pair<float, int>>{});
+
+        float max = sorted[0].first;
+        float sum = 0;
+        for (int j = 0; j < n_experts; ++j) {
+            float p = expf(sorted[j].first - max);
+            if (j < n_experts_used) weights[j] = p;
+            ids[j] = sorted[j].second;
+            sum += p;
+        }
+        float norm = 1/sum;
+        for (int j = 0; j < n_experts_used; ++j) weights[j] *= norm;
+    } else {
+        for (int j = 0; j < n_experts; ++j) ids[j] = j;
+
+        std::partial_sort(ids, ids + n_experts_used, ids + n_experts,
+                [logits] (int i1, int i2) {
+                    return logits[i1] > logits[i2];
+                });
+
+        float max = logits[ids[0]];
+        float sum = 0;
+        for (int j = 0; j < n_experts_used; ++j) {
+            float p = expf(logits[ids[j]] - max);
+            weights[j] = p;
+            sum += p;
+        }
+        for (int j = n_experts_used; j < n_experts; ++j) {
+            sum += expf(logits[ids[j]] - max);
+        }
+        float norm = 1/sum;
+        for (int j = 0; j < n_experts_used; ++j) weights[j] *= norm;
+    }
+}
+}
+
+void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
+        float * weights, int32_t * ids, int ith, int nth) {
+
+    int npt = (nrows + nth - 1)/nth;
+    int first = ith*npt;
+    int last = std::min(nrows, first + npt);
+    for (int row = first; row < last; ++row) {
+        auto row_logits = logits + row*n_experts;
+        auto row_weights = weights + row*n_experts_used;
+        auto row_ids = ids + row*n_experts;
+        iqk_topk_moe(n_experts, n_experts_used, row_logits, row_weights, row_ids, nullptr);
+    }
+}
+
 #ifdef GGML_IQK_FLASH_ATTENTION
 
 void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index c599281b..3c1250e2 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -65,6 +65,9 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
                     void * work_buffer, barrier_t barrier, void * barrier_data,
                     int ith, int nth, int n_swa);
 
+IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
+                    float * weights, int32_t * ids, int ith, int nth);
+
 #ifdef __cplusplus
 }
 #endif
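
Note on the fused path: both branches of the new kernel divide by the softmax sum taken over all n_experts, so the kernel computes the softmax over all experts first and then gathers the top n_experts_used probabilities without renormalizing, which is exactly the semantics of the unfused SOFT_MAX -> ARGSORT -> VIEW -> GET_ROWS chain it replaces. Below is a minimal, hypothetical check harness (not part of the patch) that compares iqk_topk_moe() against that unfused reference; it assumes iqk_mul_mat.h is on the include path and that the program is linked against iqk_mul_mat.cpp, and the sizes, seed, and tolerance are arbitrary:

// check_topk_moe.cpp -- hypothetical test sketch, not part of the patch.
// Verifies iqk_topk_moe() against the unfused reference: softmax over all
// n_experts logits, then gather the n_experts_used largest probabilities
// without renormalization.
#include <cmath>
#include <cstdio>
#include <cstdint>
#include <vector>
#include <numeric>
#include <algorithm>
#include <random>
#include "iqk_mul_mat.h"

int main() {
    const int n_experts = 64, n_experts_used = 8, nrows = 4;
    std::mt19937 rng(1234);
    std::normal_distribution<float> dist;
    std::vector<float> logits(nrows*n_experts);
    for (auto & x : logits) x = dist(rng);

    std::vector<float>   weights(nrows*n_experts_used);
    std::vector<int32_t> ids(nrows*n_experts);
    // single-threaded invocation: ith = 0, nth = 1
    iqk_topk_moe(n_experts, n_experts_used, nrows, logits.data(), weights.data(), ids.data(), 0, 1);

    int n_bad = 0;
    for (int row = 0; row < nrows; ++row) {
        const float * l = logits.data() + row*n_experts;
        // reference: softmax over all experts ...
        float lmax = *std::max_element(l, l + n_experts);
        std::vector<float> p(n_experts);
        float sum = 0;
        for (int j = 0; j < n_experts; ++j) { p[j] = expf(l[j] - lmax); sum += p[j]; }
        // ... then the indices of the n_experts_used largest logits
        std::vector<int> order(n_experts);
        std::iota(order.begin(), order.end(), 0);
        std::partial_sort(order.begin(), order.begin() + n_experts_used, order.end(),
                [l](int a, int b) { return l[a] > l[b]; });
        for (int j = 0; j < n_experts_used; ++j) {
            float ref = p[order[j]]/sum;
            float got = weights[row*n_experts_used + j];
            if (std::fabs(ref - got) > 1e-5f) {
                printf("mismatch: row %d, slot %d: got %g, expected %g\n", row, j, got, ref);
                ++n_bad;
            }
        }
    }
    printf("%s\n", n_bad == 0 ? "OK" : "FAILED");
    return n_bad == 0 ? 0 : 1;
}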