cpu: fused softmax+topk (#794)

* cpu: fused softmax+topk

* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Authored by Kawrakow on 2025-09-24 09:02:21 +02:00, committed by GitHub
parent 17f7f1ed18, commit f59b2909d4
3 changed files with 82 additions and 8 deletions


@@ -20386,9 +20386,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 /////////////////////////////////
-static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
+static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor,
+        const struct ggml_cgraph * cgraph, int i) {
     GGML_ASSERT(params);
-    GGML_UNUSED(next);
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
         return false;
@@ -20398,7 +20399,6 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     int64_t t1 = ggml_time_us();
 #endif
-    bool skip_next = false;
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
@@ -20586,7 +20586,21 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor);
+                if (i + 4 < cgraph->n_nodes &&
+                    cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                    cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
+                    cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
+                    cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
+                    cgraph->nodes[i+0]->type == GGML_TYPE_F32 &&
+                    cgraph->nodes[i+4]->type == GGML_TYPE_F32 &&
+                    cgraph->nodes[i+3]->type == GGML_TYPE_I32) {
+                    iqk_topk_moe(cgraph->nodes[i]->ne[0], cgraph->nodes[i+4]->ne[1], cgraph->nodes[i]->ne[1],
+                            (const float *)cgraph->nodes[i]->data, (float *)cgraph->nodes[i+4]->data, (int32_t *)cgraph->nodes[i+3]->data,
+                            params->ith, params->nth);
+                    i += 4;
+                } else {
+                    ggml_compute_forward_soft_max(params, tensor);
+                }
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -20764,7 +20778,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     int64_t t2 = ggml_time_us();
     if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
 #endif
-    return skip_next;
+    return i;
 }
 ////////////////////////////////////////////////////////////////////////////////
@@ -22725,9 +22739,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 #if IK_PRINT_TIMING
         int64_t tim1 = ggml_time_us();
 #endif
-        if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) {
-            ++node_n;
-        }
+        node_n = ggml_compute_forward(&params, node, cgraph, node_n);
 #if IK_PRINT_TIMING
         int64_t tim2 = ggml_time_us();
         t_eval += tim2 - tim1;
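
The fusion above fires when a SOFT_MAX node is followed by the exact RESHAPE, ARGSORT, VIEW, GET_ROWS chain that MoE routing produces, computes all five nodes in one pass, and advances the node index past the fused tail. A minimal standalone sketch of why the fused and unfused computations agree (illustrative only; fused_topk and naive_then_select are made-up names, not ggml API; normalization is over all logits, matching the kernel below):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Fused form: pick the k largest logits first, then normalize against the
// softmax denominator taken over all n logits.
static void fused_topk(int n, int k, const float * logits, float * w, int * id) {
    std::vector<int> idx(n);
    for (int j = 0; j < n; ++j) idx[j] = j;
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
            [logits](int a, int b) { return logits[a] > logits[b]; });
    float max = logits[idx[0]], sum = 0;
    for (int j = 0; j < n; ++j) sum += expf(logits[j] - max);
    for (int j = 0; j < k; ++j) { id[j] = idx[j]; w[j] = expf(logits[idx[j]] - max)/sum; }
}

// Unfused form: softmax over everything, then select the k largest entries,
// as the SOFT_MAX -> ARGSORT -> GET_ROWS node chain would.
static void naive_then_select(int n, int k, const float * logits, float * w, int * id) {
    std::vector<float> p(n);
    float max = *std::max_element(logits, logits + n), sum = 0;
    for (int j = 0; j < n; ++j) { p[j] = expf(logits[j] - max); sum += p[j]; }
    for (int j = 0; j < n; ++j) p[j] /= sum;
    std::vector<int> idx(n);
    for (int j = 0; j < n; ++j) idx[j] = j;
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
            [&p](int a, int b) { return p[a] > p[b]; });
    for (int j = 0; j < k; ++j) { id[j] = idx[j]; w[j] = p[idx[j]]; }
}

int main() {
    const float logits[8] = {0.1f, 2.0f, -1.0f, 0.7f, 3.1f, 0.0f, 1.5f, -0.2f};
    float wf[3], wn[3]; int idf[3], idn[3];
    fused_topk(8, 3, logits, wf, idf);
    naive_then_select(8, 3, logits, wn, idn);
    for (int j = 0; j < 3; ++j)
        printf("expert %d/%d  weight %.6f/%.6f\n", idf[j], idn[j], wf[j], wn[j]);
    return 0;
}

Both paths print identical expert ids and weights, since softmax is monotone and the same denominator is used.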


@@ -14,6 +14,7 @@
 #include <cstring>
 #include <type_traits>
 #include <vector>
+#include <algorithm>
 #include "ggml-impl.h"
 #include "ggml-quants.h"
@@ -1140,6 +1141,64 @@ void MulMat::relu(int n, const float * x, float * y) {
 #endif
 } // namespace
+namespace {
+void iqk_topk_moe(int n_experts, int n_experts_used, const float * logits,
+        float * weights, int32_t * ids, void * work) {
+    if (work) {
+        auto sorted = (std::pair<float, int> *)work;
+        for (int j = 0; j < n_experts; ++j) sorted[j] = {logits[j], j};
+        std::partial_sort(sorted, sorted + n_experts_used, sorted + n_experts, std::greater<std::pair<float,int>>{});
+        float max = sorted[0].first;
+        float sum = 0;
+        for (int j = 0; j < n_experts; ++j) {
+            float p = expf(sorted[j].first - max);
+            weights[j] = p;
+            ids[j] = sorted[j].second;
+            sum += p;
+        }
+        float norm = 1/sum;
+        for (int j = 0; j < n_experts; ++j) weights[j] *= norm;
+    } else {
+        for (int j = 0; j < n_experts; ++j) ids[j] = j;
+        std::partial_sort(ids, ids + n_experts_used, ids + n_experts,
+                [logits] (int i1, int i2) {
+                    return logits[i1] > logits[i2];
+                });
+        float max = logits[ids[0]];
+        float sum = 0;
+        for (int j = 0; j < n_experts_used; ++j) {
+            float p = expf(logits[ids[j]] - max);
+            weights[j] = p;
+            sum += p;
+        }
+        for (int j = n_experts_used; j < n_experts; ++j) {
+            sum += expf(logits[ids[j]] - max);
+        }
+        float norm = 1/sum;
+        for (int j = 0; j < n_experts_used; ++j) weights[j] *= norm;
+    }
+}
+}
+void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
+        float * weights, int32_t * ids, int ith, int nth) {
+    int npt = (nrows + nth - 1)/nth;
+    int first = ith*npt;
+    int last = std::min(nrows, first + npt);
+    for (int row = first; row < last; ++row) {
+        auto row_logits = logits + row*n_experts;
+        auto row_weights = weights + row*n_experts_used;
+        auto row_ids = ids + row*n_experts;
+        iqk_topk_moe(n_experts, n_experts_used, row_logits, row_weights, row_ids, nullptr);
+    }
+}
 #ifdef GGML_IQK_FLASH_ATTENTION
 void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
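
The row-level wrapper added above splits work in the usual ggml ith/nth fashion: each of nth threads takes a contiguous block of ceil(nrows/nth) rows, and trailing threads may get an empty range. A small standalone sketch of that partitioning (illustrative, not part of the commit):

#include <algorithm>
#include <cstdio>

int main() {
    const int nrows = 10, nth = 4;           // e.g. 10 token rows, 4 threads
    for (int ith = 0; ith < nth; ++ith) {
        int npt   = (nrows + nth - 1)/nth;   // ceil(nrows/nth) rows per thread
        int first = ith*npt;
        int last  = std::min(nrows, first + npt);
        printf("thread %d handles rows [%d, %d)\n", ith, first, last);
    }
    return 0;
}

With nrows = 10 and nth = 4 this prints the ranges [0,3), [3,6), [6,9), [9,10), so the last thread simply gets the remainder.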


@@ -65,6 +65,9 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
                                     void * work_buffer, barrier_t barrier, void * barrier_data,
                                     int ith, int nth, int n_swa);
+IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
+                          float * weights, int32_t * ids, int ith, int nth);
 #ifdef __cplusplus
 }
 #endif
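
A hedged usage sketch for the new export, assuming the declaration sits inside the header's extern "C" block as shown, and buffer shapes following the .cpp side of the diff (logits: nrows x n_experts, weights: nrows x n_experts_used, ids: nrows x n_experts); a single-threaded caller passes ith = 0, nth = 1:

#include <cstdint>
#include <vector>

// Declaration matching the header above (assumes linking against the iqk library).
extern "C" void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
                             float * weights, int32_t * ids, int ith, int nth);

int main() {
    const int n_experts = 64, n_used = 8, nrows = 4;
    std::vector<float>   logits(nrows*n_experts, 0.0f);   // router logits, row-major
    std::vector<float>   weights(nrows*n_used);           // top-k softmax weights (output)
    std::vector<int32_t> ids(nrows*n_experts);            // expert indices, best first (output)
    logits[5] = 1.0f;                                     // make one expert stand out in row 0
    iqk_topk_moe(n_experts, n_used, nrows, logits.data(), weights.data(), ids.data(),
                 /*ith=*/0, /*nth=*/1);
    return 0;
}

In the fused graph path, ggml instead calls this with each worker thread's ith/nth so rows are processed in parallel.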