diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a556b2a4..d6b76a5b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22568,7 +22568,17 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_ARGSORT: { - ggml_compute_forward_argsort(params, tensor); + if (i + 5 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_VIEW && + cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS && + cgraph->nodes[i+3]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX && + cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) { + iqk_openai_experts(tensor, cgraph->nodes[i+4], params->ith, params->nth); + i += 5; + } else { + ggml_compute_forward_argsort(params, tensor); + } } break; case GGML_OP_ARGSORT_THRESH: { diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 51cdfcc8..ff34abcf 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -345,3 +345,47 @@ void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_vi } } } + +void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth) { + + auto probs = topk->src[0]; + + auto nrows = ggml_nrows(probs); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int ne00 = probs->ne[0]; + int ne0 = softmax->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(ggml_is_contiguous(softmax)); + GGML_ASSERT(ne0 <= ne00); + //if (ith == 0) printf("%s: ne00 = %d, ne0 = %d, topk: %s, softmax: %s\n", __func__, ne00, ne0, ggml_type_name(topk->type), ggml_type_name(softmax->type)); + //if (ith == 0) printf("%s: ne00 = %d, ne0 = %d, topk: %s, %ld x %ld x %ld x %ld, %zu x %zu x %zu x %zu\n", __func__, ne00, ne0, ggml_type_name(topk->type), topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3], topk->nb[0], topk->nb[1], topk->nb[2], topk->nb[3]); + + size_t work_size = ne00; + auto& aux = 
get_work_buffer(work_size); + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]); + for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j }; + if (ne0 < ne00) { + std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + ne00, std::greater<>{}); + } else { + std::sort(aux.begin(), aux.begin() + ne00, std::greater<>{}); + } + auto weights = (float *)((char *)softmax->data + ir*softmax->nb[1]); + auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]); + float max = aux.front().first; + float sum = 0; + for (int j = 0; j < ne0; ++j) { + weights[j] = expf(aux[j].first - max); + ids[j] = aux[j].second; + sum += weights[j]; + } + GGML_ASSERT(sum > 0); + float norm = 1/sum; + for (int j = 0; j < ne0; ++j) weights[j] *= norm; + } +} diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h index 2de3a5cb..ef2bbe1a 100644 --- a/ggml/src/iqk/iqk_cpu_ops.h +++ b/ggml/src/iqk/iqk_cpu_ops.h @@ -22,6 +22,8 @@ void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * top void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth); +void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index e2fccfca..dd0b62be 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -836,10 +836,6 @@ llm_expert_gating_func_type gating_op, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); - if (graph) { - ggml_build_forward_expand(graph, weights); - } - if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens] @@ -847,6 +843,10 @@ llm_expert_gating_func_type 
gating_op, cb(weights, "ffn_moe_weights_softmax", il); } + if (graph) { + ggml_build_forward_expand(graph, weights); + } + if (norm_w) { weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);