Fuse topk+view+get_rows+reshape+softmax (CPU)

This commit is contained in:
Iwan Kawrakow
2025-10-19 11:45:10 +03:00
parent 18d9f4fc4d
commit c8ed454564
4 changed files with 61 additions and 5 deletions

View File

@@ -22568,7 +22568,17 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
} break;
case GGML_OP_ARGSORT:
{
ggml_compute_forward_argsort(params, tensor);
if (i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
iqk_openai_experts(tensor, cgraph->nodes[i+4], params->ith, params->nth);
i += 5;
} else {
ggml_compute_forward_argsort(params, tensor);
}
} break;
case GGML_OP_ARGSORT_THRESH:
{

View File

@@ -345,3 +345,47 @@ void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_vi
}
}
}
void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth) {
    // Fused CPU implementation of the OpenAI-style MoE expert-selection sequence
    // argsort -> view -> get_rows -> reshape -> soft_max -> reshape.
    // For each row of the gating logits (topk->src[0]) it picks the ne0 largest
    // entries, storing their indices into `topk` (the ARGSORT destination) and
    // their softmax-normalized weights into `softmax` (the SOFT_MAX destination).
    //
    // topk    : int32 output tensor; receives the selected expert ids per row.
    // softmax : float output tensor; receives the normalized weights per row.
    // ith/nth : this thread's index and the total thread count; rows are
    //           partitioned evenly (ceiling division) across threads.
    auto probs = topk->src[0];           // gating logits, one row per token
    auto nrows = ggml_nrows(probs);      // int64_t
    auto npt   = (nrows + nth - 1)/nth;  // rows per thread
    auto first = npt*ith;
    auto last  = std::min(first + npt, nrows);
    if (last <= first) return;           // this thread has no rows to process
    int ne00 = probs->ne[0];             // number of experts (columns of the logits)
    int ne0  = softmax->ne[0];           // number of selected experts (top-k)
    GGML_ASSERT(ggml_is_contiguous(probs));
    GGML_ASSERT(ggml_is_contiguous(softmax));
    GGML_ASSERT(ne0 <= ne00);
    size_t work_size = ne00;
    auto& aux = get_work_buffer(work_size);  // per-thread scratch of (logit, index) pairs
    // NOTE(review): ties between equal logits are broken towards the LARGER
    // expert index, because std::greater<std::pair<float,int>> falls through to
    // .second; ggml's argsort breaks ties towards the smaller index — confirm
    // no caller depends on the tie order.
    // int64_t index: a plain `int` would truncate for row counts > 2^31-1.
    for (int64_t ir = first; ir < last; ++ir) {
        auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]);
        for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j };
        if (ne0 < ne00) {
            // Only the top ne0 entries need to be in sorted order.
            std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + ne00, std::greater<std::pair<float,int>>{});
        } else {
            std::sort(aux.begin(), aux.begin() + ne00, std::greater<std::pair<float,int>>{});
        }
        auto weights = (float *)((char *)softmax->data + ir*softmax->nb[1]);
        auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]);
        // Numerically stable softmax over the selected logits: subtract the max
        // (aux.front() after sorting) before exponentiating.
        float max = aux.front().first;
        float sum = 0;
        for (int j = 0; j < ne0; ++j) {
            weights[j] = expf(aux[j].first - max);
            ids[j] = aux[j].second;
            sum += weights[j];
        }
        GGML_ASSERT(sum > 0);  // guaranteed: expf(0) == 1 for the max entry
        float norm = 1/sum;
        for (int j = 0; j < ne0; ++j) weights[j] *= norm;
    }
}

View File

@@ -22,6 +22,8 @@ void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * top
void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth);
void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth);
#ifdef __cplusplus
}
#endif

View File

@@ -836,10 +836,6 @@ llm_expert_gating_func_type gating_op,
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
if (graph) {
ggml_build_forward_expand(graph, weights);
}
if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
@@ -847,6 +843,10 @@ llm_expert_gating_func_type gating_op,
cb(weights, "ffn_moe_weights_softmax", il);
}
if (graph) {
ggml_build_forward_expand(graph, weights);
}
if (norm_w) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);