mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
Attempt at grouped topk
This commit is contained in:
@@ -650,6 +650,7 @@ extern "C" {
|
||||
GGML_OP_TIMESTEP_EMBEDDING,
|
||||
GGML_OP_ARGSORT,
|
||||
GGML_OP_ARGSORT_THRESH,
|
||||
GGML_OP_GROUPED_TOPK,
|
||||
GGML_OP_LEAKY_RELU,
|
||||
GGML_OP_SOFTCAP,
|
||||
GGML_OP_SOFT_CAP_MAX,
|
||||
@@ -2265,6 +2266,12 @@ extern "C" {
|
||||
int k,
|
||||
int min_entries,
|
||||
float thresh);
|
||||
GGML_API struct ggml_tensor * ggml_grouped_topk(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int num_groups,
|
||||
int num_top_groups,
|
||||
int nk);
|
||||
|
||||
#define GGML_KQ_MASK_PAD 16
|
||||
|
||||
|
||||
@@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"TIMESTEP_EMBEDDING",
|
||||
"ARGSORT",
|
||||
"ARGSORT_THRESH",
|
||||
"GROUPED_TOPK",
|
||||
"LEAKY_RELU",
|
||||
"SOFTCAP",
|
||||
"SOFT_CAP_MAX",
|
||||
@@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"timestep_embedding(timesteps, dim, max_period)",
|
||||
"argsort(x)",
|
||||
"argsort_thresh(x)",
|
||||
"grouped_topk(x)",
|
||||
"leaky_relu(x)",
|
||||
"k2*tanh(k1*x)",
|
||||
"soft_max(k2*tanh(k1*x))",
|
||||
@@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"glu(x),"
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -9439,6 +9441,35 @@ struct ggml_tensor * ggml_argsort_thresh(
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_grouped_topk(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int num_groups,
|
||||
int num_top_groups,
|
||||
int nk) {
|
||||
|
||||
GGML_ASSERT(num_top_groups < num_groups);
|
||||
GGML_ASSERT(a->ne[0] % num_groups == 0);
|
||||
int64_t n_per_group = a->ne[0] / num_groups;
|
||||
GGML_ASSERT(n_per_group >= nk);
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
//struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, num_groups);
|
||||
ggml_set_op_params_i32(result, 1, num_top_groups);
|
||||
ggml_set_op_params_i32(result, 2, nk);
|
||||
|
||||
result->op = GGML_OP_GROUPED_TOPK;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// ggml_top_k
|
||||
|
||||
struct ggml_tensor * ggml_top_k(
|
||||
@@ -20024,6 +20055,24 @@ static void ggml_compute_forward_argsort_thresh(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_grouped_topk(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
iqk_grouped_top_k(dst, params->ith, params->nth);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_flash_attn_ext
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
@@ -22521,6 +22570,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
|
||||
{
|
||||
ggml_compute_forward_argsort_thresh(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_GROUPED_TOPK:
|
||||
{
|
||||
ggml_compute_forward_grouped_topk(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
ggml_compute_forward_leaky_relu(params, tensor);
|
||||
@@ -23539,6 +23592,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
{
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
}
|
||||
case GGML_OP_GROUPED_TOPK:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
}
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
@@ -24281,6 +24338,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ARGSORT_THRESH:
|
||||
case GGML_OP_GROUPED_TOPK:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
case GGML_OP_SSM_CONV:
|
||||
|
||||
@@ -4,8 +4,64 @@
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
// CPU kernel for GGML_OP_GROUPED_TOPK.
//
// For every row of src, split the ne[0] entries into n_groups equal groups
// and copy the nk largest entries of each group into dst; every other entry
// of dst is set to -INFINITY. Rows are partitioned across the nth threads.
//
// NOTE(review): selecting only the n_top_groups best groups (the second half
// of the op semantics) is not implemented yet — the original scaffolding for
// it was commented out, so op_params[1] is currently unused. TODO: implement
// group selection (sum of top-nk per group, keep the n_top_groups best).
void iqk_grouped_top_k([[maybe_unused]] ggml_tensor * dst, [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
    auto src = dst->src[0];
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_are_same_shape(src, dst));

    // this thread processes rows [first, last)
    auto nrows = ggml_nrows(src);
    auto npt   = (nrows + nth - 1)/nth;
    auto first = npt*ith;
    auto last  = std::min(first + npt, nrows);
    if (last <= first) return;

    int n_groups     = dst->op_params[0];
    int n_top_groups = dst->op_params[1];
    int nk           = dst->op_params[2];
    (void)n_top_groups; // unused until group selection is implemented, see note above

    int ne00 = src->ne[0];
    GGML_ASSERT(ne00%n_groups == 0);
    int n_per_group = ne00/n_groups;
    GGML_ASSERT(nk <= n_per_group);

    // scratch: the first n_per_group entries hold (value, index) pairs of the
    // group being processed; the trailing n_groups entries are reserved for
    // per-group scores once group selection is implemented
    thread_local std::vector<std::pair<float,int>> aux;
    if ((int)aux.size() < n_per_group + n_groups) aux.resize(n_per_group + n_groups);

    for (int64_t ir = first; ir < last; ++ir) { // int64_t: row bounds are 64-bit
        auto data   = (const float *)((const char *)src->data + ir*src->nb[1]);
        auto result = (float *)((char *)dst->data + ir*dst->nb[1]);
        for (int j = 0; j < ne00; ++j) result[j] = -INFINITY;
        for (int ig = 0; ig < n_groups; ++ig) {
            for (int j = 0; j < n_per_group; ++j) {
                int jj = ig*n_per_group + j;
                aux[j] = { data[jj], jj };
            }
            // Bug fix: the sort range must end at n_per_group. The original
            // code sorted up to aux.end(), which includes the trailing scratch
            // entries ({0.f, 0} after resize, stale afterwards); with all-negative
            // group values those bogus pairs won the top-nk selection and the
            // wrong indices were written to dst.
            std::partial_sort(aux.begin(), aux.begin() + nk, aux.begin() + n_per_group,
                              std::greater<std::pair<float,int>>{});
            for (int j = 0; j < nk; ++j) result[aux[j].second] = data[aux[j].second];
        }
    }
}
|
||||
|
||||
void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
|
||||
|
||||
@@ -820,36 +820,40 @@ llm_expert_gating_func_type gating_op,
|
||||
selection_probs = logits;
|
||||
}
|
||||
|
||||
if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
|
||||
if (true && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
|
||||
auto& hparams = lctx.model.hparams;
|
||||
const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
|
||||
|
||||
// organize experts into n_expert_groups
|
||||
ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]
|
||||
#if 0
|
||||
ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups]
|
||||
group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
|
||||
|
||||
// get top n_group_used expert groups
|
||||
group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
|
||||
#else
|
||||
// Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead
|
||||
ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]
|
||||
group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups]
|
||||
|
||||
// get top n_group_used expert groups
|
||||
group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1]
|
||||
#endif
|
||||
ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1]
|
||||
cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
|
||||
cb(expert_groups, "ffn_moe_group_topk", il);
|
||||
|
||||
// mask out the other groups
|
||||
selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
|
||||
selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
|
||||
selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
|
||||
selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens]
|
||||
selection_probs = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2);
|
||||
cb(selection_probs, "ffn_moe_probs_masked", il);
|
||||
|
||||
// const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
|
||||
//
|
||||
// // organize experts into n_expert_groups
|
||||
// ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]
|
||||
//#if 0
|
||||
// ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups]
|
||||
// group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
|
||||
//
|
||||
// // get top n_group_used expert groups
|
||||
// group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
|
||||
//#else
|
||||
// // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead
|
||||
// ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]
|
||||
// group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups]
|
||||
//
|
||||
// // get top n_group_used expert groups
|
||||
// group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1]
|
||||
//#endif
|
||||
// ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1]
|
||||
// cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
|
||||
// cb(expert_groups, "ffn_moe_group_topk", il);
|
||||
//
|
||||
// // mask out the other groups
|
||||
// selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
|
||||
// selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
|
||||
// selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
|
||||
// selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens]
|
||||
// cb(selection_probs, "ffn_moe_probs_masked", il);
|
||||
}
|
||||
|
||||
// select experts
|
||||
|
||||
Reference in New Issue
Block a user