Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-01-26 17:20:01 +00:00)
cpu: fused softmax+topk (#794)

* cpu: fused softmax+topk

* Cleanup

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -20386,9 +20386,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
+static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor,
+                                const struct ggml_cgraph * cgraph, int i) {
+
     GGML_ASSERT(params);
-    GGML_UNUSED(next);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
-        return false;
+        return i;
@@ -20398,7 +20399,6 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     int64_t t1 = ggml_time_us();
 #endif
 
-    bool skip_next = false;
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
@@ -20586,7 +20586,21 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor);
+                if (i + 4 < cgraph->n_nodes &&
+                    cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                    cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
+                    cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
+                    cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
+                    cgraph->nodes[i+0]->type == GGML_TYPE_F32 &&
+                    cgraph->nodes[i+4]->type == GGML_TYPE_F32 &&
+                    cgraph->nodes[i+3]->type == GGML_TYPE_I32) {
+                    iqk_topk_moe(cgraph->nodes[i]->ne[0], cgraph->nodes[i+4]->ne[1], cgraph->nodes[i]->ne[1],
+                                 (const float *)cgraph->nodes[i]->data, (float *)cgraph->nodes[i+4]->data, (int32_t *)cgraph->nodes[i+3]->data,
+                                 params->ith, params->nth);
+                    i += 4;
+                } else {
+                    ggml_compute_forward_soft_max(params, tensor);
+                }
             } break;
        case GGML_OP_SOFT_MAX_BACK:
            {
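For context: the five-node chain matched above (soft_max -> reshape -> argsort -> view -> get_rows) is the standard top-k MoE gating pattern in ggml graphs, and the fused call reproduces it in a single pass. A minimal scalar reference for one row is sketched below; this is illustrative only (topk_softmax_ref is not part of the commit):

    // Reference semantics of the fused softmax+topk for a single row:
    // softmax over all n_experts logits, then keep the k = n_experts_used
    // largest probabilities and their expert indices.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    static void topk_softmax_ref(int n_experts, int k, const float * logits,
                                 float * weights, int32_t * ids) {
        std::vector<int32_t> idx(n_experts);
        std::iota(idx.begin(), idx.end(), 0);
        // top-k expert indices by logit (equivalent to argsort + view)
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [logits](int a, int b) { return logits[a] > logits[b]; });
        // softmax normalizer over *all* experts, computed stably around the max
        float max = logits[idx[0]];
        float sum = 0.0f;
        for (int j = 0; j < n_experts; ++j) sum += expf(logits[j] - max);
        // keep only the k largest probabilities (equivalent to get_rows)
        for (int j = 0; j < k; ++j) {
            ids[j]     = idx[j];
            weights[j] = expf(logits[idx[j]] - max) / sum;
        }
    }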
@@ -20764,7 +20778,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     int64_t t2 = ggml_time_us();
     if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
 #endif
-    return skip_next;
+    return i;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -22725,9 +22739,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 #if IK_PRINT_TIMING
         int64_t tim1 = ggml_time_us();
 #endif
-        if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) {
-            ++node_n;
-        }
+        node_n = ggml_compute_forward(&params, node, cgraph, node_n);
 #if IK_PRINT_TIMING
         int64_t tim2 = ggml_time_us();
         t_eval += tim2 - tim1;
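The dispatcher's contract changes accordingly: rather than returning a bool "skip the next node" flag, ggml_compute_forward now returns the index of the last node it consumed, and the caller resumes after it. A simplified sketch of that contract, assuming a plain sequential loop (the real ggml_graph_compute_thread also synchronizes threads between nodes):

    /* Simplified, sequential view of the node loop (illustrative only). */
    for (int node_n = 0; node_n < cgraph->n_nodes; ++node_n) {
        struct ggml_tensor * node = cgraph->nodes[node_n];
        /* Returns node_n when it computed just this node, or node_n + 4
         * after fusing the soft_max/reshape/argsort/view/get_rows chain,
         * so the fused nodes are skipped on the next iteration. */
        node_n = ggml_compute_forward(&params, node, cgraph, node_n);
    }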
@@ -14,6 +14,7 @@
 #include <cstring>
 #include <type_traits>
 #include <vector>
+#include <algorithm>
 
 #include "ggml-impl.h"
 #include "ggml-quants.h"
@@ -1140,6 +1141,64 @@ void MulMat::relu(int n, const float * x, float * y) {
 #endif
 } // namespace
 
+namespace {
+void iqk_topk_moe(int n_experts, int n_experts_used, const float * logits,
+        float * weights, int32_t * ids, void * work) {
+
+    if (work) {
+        auto sorted = (std::pair<float, int> *)work;
+        for (int j = 0; j < n_experts; ++j) sorted[j] = {logits[j], j};
+
+        std::partial_sort(sorted, sorted + n_experts_used, sorted + n_experts, std::greater<std::pair<float,int>>{});
+
+        float max = sorted[0].first;
+        float sum = 0;
+        for (int j = 0; j < n_experts; ++j) {
+            float p = expf(sorted[j].first - max);
+            weights[j] = p;
+            ids[j] = sorted[j].second;
+            sum += p;
+        }
+        float norm = 1/sum;
+        for (int j = 0; j < n_experts; ++j) weights[j] *= norm;
+    } else {
+        for (int j = 0; j < n_experts; ++j) ids[j] = j;
+
+        std::partial_sort(ids, ids + n_experts_used, ids + n_experts,
+                [logits] (int i1, int i2) {
+                    return logits[i1] > logits[i2];
+                });
+
+        float max = logits[ids[0]];
+        float sum = 0;
+        for (int j = 0; j < n_experts_used; ++j) {
+            float p = expf(logits[ids[j]] - max);
+            weights[j] = p;
+            sum += p;
+        }
+        for (int j = n_experts_used; j < n_experts; ++j) {
+            sum += expf(logits[ids[j]] - max);
+        }
+        float norm = 1/sum;
+        for (int j = 0; j < n_experts_used; ++j) weights[j] *= norm;
+    }
+}
+}
+
+void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
+        float * weights, int32_t * ids, int ith, int nth) {
+
+    int npt = (nrows + nth - 1)/nth;
+    int first = ith*npt;
+    int last = std::min(nrows, first + npt);
+    for (int row = first; row < last; ++row) {
+        auto row_logits = logits + row*n_experts;
+        auto row_weights = weights + row*n_experts_used;
+        auto row_ids = ids + row*n_experts;
+        iqk_topk_moe(n_experts, n_experts_used, row_logits, row_weights, row_ids, nullptr);
+    }
+}
+
 #ifdef GGML_IQK_FLASH_ATTENTION
 
 void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3,
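The exported entry point divides the nrows gating rows evenly across nth threads, each taking a contiguous chunk of npt = (nrows + nth - 1)/nth rows. A hypothetical standalone call, with buffer sizes following the row strides above (the header name in the include is an assumption, not from the commit):

    #include <cstdint>
    #include <vector>
    #include "iqk_mul_mat.h"   // assumed location of the iqk_topk_moe declaration

    int main() {
        // 4 tokens (rows), 8 experts, top-2 gating, single-threaded (ith = 0, nth = 1).
        // logits and ids are nrows*n_experts long, weights is nrows*n_experts_used long;
        // after the call the first n_experts_used entries of each ids row hold the
        // selected expert indices and each weights row holds their softmax weights.
        std::vector<float>   logits(4 * 8, 0.0f);   // router logits would be filled in here
        std::vector<float>   weights(4 * 2);
        std::vector<int32_t> ids(4 * 8);
        iqk_topk_moe(/*n_experts=*/8, /*n_experts_used=*/2, /*nrows=*/4,
                     logits.data(), weights.data(), ids.data(), /*ith=*/0, /*nth=*/1);
        return 0;
    }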
@@ -65,6 +65,9 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
                                     void * work_buffer, barrier_t barrier, void * barrier_data,
                                     int ith, int nth, int n_swa);
 
+IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
+        float * weights, int32_t * ids, int ith, int nth);
+
 #ifdef __cplusplus
 }
 #endif