Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-01-26 17:20:01 +00:00
Various fused ops around expert selection (#840)
* Fuse sigmoid+add+grouped_topk+get_rows (CPU)

* Fix CPU + CUDA

  but CUDA is somehow not 100% correct as I get a slightly different PPL (lower!)

* Minor

* Fuse sigmoid+add+topk+get_rows (CUDA)

* Fuse sigmoid+add+topk+get_rows (CPU)

* Fuse topk+view+get_rows+reshape+softmax (CPU)

* Fuse topk+view+get_rows+reshape+softmax (CUDA)

* cpu: turn off the openai topk fusing for now

  Something is not right and I don't see the bug. On the CPU one doesn't gain much if anything, so not a big loss.

* Also fuse sum_rows and div

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
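For orientation, the following standalone CPU sketch reproduces, unfused, what the fused sequences below compute for one row of router logits: sigmoid the logits, add the per-expert bias, rank experts by the biased score, keep the unbiased sigmoid of the selected experts as their weights, and (as with the sum_rows+div fusion) normalize the kept weights to sum to 1. This is an illustrative sketch only, not code from this commit; select_experts_reference and every other name in it are invented for the example, and whether the final normalization step is present depends on the model graph.

// Reference (unfused) sketch of the expert-selection sequence the fused
// kernels implement. Illustrative only; names are invented for this example.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// One row of router logits: rank experts by sigmoid(logit) + bias, but report
// the *unbiased* sigmoid of each selected expert as its weight.
static void select_experts_reference(const std::vector<float> & logits,
                                     const std::vector<float> & bias,
                                     int n_top,
                                     std::vector<int> & ids,
                                     std::vector<float> & weights) {
    const int n = (int)logits.size();
    std::vector<std::pair<float,int>> scored(n);
    for (int j = 0; j < n; ++j) {
        float s = 1.0f/(1.0f + std::exp(-logits[j]));   // sigmoid of the logit
        scored[j] = { s + bias[j], j };                  // biased score, used only for ranking
    }
    std::partial_sort(scored.begin(), scored.begin() + n_top, scored.end(),
                      std::greater<std::pair<float,int>>{});
    ids.resize(n_top);
    weights.resize(n_top);
    float sum = 0.0f;
    for (int j = 0; j < n_top; ++j) {
        ids[j]     = scored[j].second;
        weights[j] = 1.0f/(1.0f + std::exp(-logits[ids[j]]));  // unbiased sigmoid
        sum += weights[j];
    }
    // The trailing sum_rows + div pair amounts to normalizing the kept weights.
    float norm = sum > 0.0f ? 1.0f/sum : 0.0f;
    for (float & w : weights) w *= norm;
}

int main() {
    std::vector<float> logits = { 0.3f, -1.2f, 2.0f, 0.1f, -0.5f, 1.4f, 0.0f, -2.0f };
    std::vector<float> bias   = { 0.0f,  0.5f, 0.0f, 0.2f,  0.0f, 0.0f, 0.1f,  0.0f };
    std::vector<int> ids;
    std::vector<float> weights;
    select_experts_reference(logits, bias, /*n_top=*/4, ids, weights);
    for (size_t j = 0; j < ids.size(); ++j) {
        std::printf("expert %d  weight %.4f\n", ids[j], weights[j]);
    }
    return 0;
}

Compiled with any C++11-or-newer compiler, this prints the four selected expert ids and their normalized weights; the diff below does the same work in single fused CPU/CUDA routines instead of separate graph ops.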
@@ -3173,7 +3173,25 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_relu(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
ggml_cuda_op_sigmoid(ctx, dst);
if (i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) {
cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
i += 5;
}
else if (i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) {
cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+4]);
i += 4;
} else {
ggml_cuda_op_sigmoid(ctx, dst);
}
break;
case GGML_UNARY_OP_HARDSIGMOID:
ggml_cuda_op_hardsigmoid(ctx, dst);
@@ -3315,10 +3333,28 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_pool2d(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
if (i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_DIV &&
cgraph->nodes[i+1]->src[1] == dst &&
cgraph->nodes[i+1]->src[0] == dst->src[0]) {
ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]);
++i;
} else {
ggml_cuda_op_sum_rows(ctx, dst);
}
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
if (i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]);
i += 5;
} else {
ggml_cuda_op_argsort(ctx, dst);
}
break;
case GGML_OP_ARGSORT_THRESH:
ggml_cuda_op_argsort_thresh(ctx, dst);

@@ -25,25 +25,8 @@ struct store {
constexpr static bool has_thresh = false;
};

template<ggml_sort_order order, typename Store, typename dst_t>
static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) {
// int min_experts, float thresh_experts) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

if (col >= ncols_pad) {
return;
}

const float * x_row = x + row * ncols;
extern __shared__ int dst_row[];

// initialize indices
dst_row[col] = col;

__syncthreads();

template<ggml_sort_order order>
static __device__ __forceinline__ void sort(int ncols_pad, int ncols, int col, const float * x_row, int * dst_row) {
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
@@ -69,6 +52,28 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n
__syncthreads();
}
}
}

template<ggml_sort_order order, typename Store, typename dst_t>
static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) {
// int min_experts, float thresh_experts) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

if (col >= ncols_pad) {
return;
}

const float * x_row = x + row * ncols;
extern __shared__ int dst_row[];

// initialize indices
dst_row[col] = col;

__syncthreads();

sort<order>(ncols_pad, ncols, col, x_row, dst_row);

if constexpr (Store::has_thresh) {
__syncthreads();
@@ -92,7 +97,111 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n
}

template<ggml_sort_order order>
static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, int ncols_pad, int ntop,
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

if (col >= ncols_pad) {
return;
}

const float * x_row = x_biased + row * ncols;
extern __shared__ int dst_row[];

// initialize indices
dst_row[col] = col;

__syncthreads();

sort<order>(ncols_pad, ncols, col, x_row, dst_row);

if (col < ntop) {
weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]]));
auto row_ids = (int *)((char *)ids + row*nb_ids);
row_ids[col] = dst_row[col];
}
}

template<ggml_sort_order order>
static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float * bias, float * weights, int * ids, const int ncols, int ncols_pad, int ntop,
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

if (col >= ncols_pad) {
return;
}

extern __shared__ int dst_row[];
auto x_row = (float *)(dst_row + ncols_pad);

// initialize indices
dst_row[col] = col;
x_row[col] = col < ncols ? 1/(1 + expf(-x[row*ncols + col])) + bias[col] : -INFINITY;

__syncthreads();

sort<order>(ncols_pad, ncols, col, x_row, dst_row);

if (col < ntop) {
weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]]));
auto row_ids = (int *)((char *)ids + row*nb_ids);
row_ids[col] = dst_row[col];
}
}

template<ggml_sort_order order>
static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, int * ids, const int ncols, int ncols_pad, int ntop,
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

if (col >= ncols_pad) {
return;
}

extern __shared__ int dst_row[];
auto x_row = x + row*ncols;

// initialize indices
dst_row[col] = col;

__syncthreads();

sort<order>(ncols_pad, ncols, col, x_row, dst_row);

float max = x_row[dst_row[0]];
float val = col < ntop ? expf(x_row[dst_row[col]] - max) : 0.0f;
float sum = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
__syncthreads();
float * s_sum = (float *)(dst_row + ncols_pad);
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = sum;
}
__syncthreads();
sum = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
sum = s_sum[lane_id];
}
sum = warp_reduce_sum(sum);
}
float norm = 1/sum;
if (col < ntop) {
weights[row * ntop + col] = norm*val;
auto row_ids = (int *)((char *)ids + row*nb_ids);
row_ids[col] = dst_row[col];
}
}

template<ggml_sort_order order>
static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
@@ -106,40 +215,23 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols,

// initialize indices
dst_row[col] = col;
if (bias && x_p) {
float * x_p_row = x_p + row * ncols;
if (col < ncols) {
x_p_row[col] = 1/(1 + expf(-x_row[col])) + bias[col];
}
x_row = x_p_row;
}

__syncthreads();

for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}
sort<order>(ncols_pad, ncols, col, x_row, dst_row);

float val = col < n_top_k ? x_row[dst_row[col]] : 0;
val = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
__syncthreads();
float * s_sum = (float *)dst_row;
float * s_sum = (float *)(dst_row + ncols_pad);
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
@@ -208,6 +300,75 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co
}
}

static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, const int nrows, int ntop,
size_t nb_ids, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);

const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);

// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x_biased, x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x_biased, x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else {
GGML_ABORT("fatal error");
}
}

static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias, float * weights, int * ids, const int ncols, const int nrows, int ntop,
size_t nb_ids, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);

const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float));

// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

if (order == GGML_SORT_ORDER_ASC) {
k_argsort_biased_f32_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, bias, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_biased_f32_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, bias, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else {
GGML_ABORT("fatal error");
}
}

static void argsort_openai_f32_f32_i32_cuda(const float * x, float * weights, int * ids, const int ncols, const int nrows, int ntop,
size_t nb_ids, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);

const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = (ncols_pad + ncols_pad > WARP_SIZE ? WARP_SIZE : 0) * sizeof(int);

// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

if (order == GGML_SORT_ORDER_ASC) {
k_openai_f32_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else if (order == GGML_SORT_ORDER_DESC) {
k_openai_f32_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else {
GGML_ABORT("fatal error");
}
}

void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -246,7 +407,8 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
}

static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, const float * bias, float * src_p, float * dst,
int ncols, int nrows, int n_top_k) {

GGML_ASSERT(n_top_k <= ncols);

@@ -254,10 +416,10 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float *

const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = std::max(ncols_pad, WARP_SIZE) * sizeof(int);
const size_t shared_mem = (ncols_pad + WARP_SIZE) * sizeof(int);
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, bias, src_p, dst, ncols, ncols_pad, n_top_k);
}

void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -291,7 +453,7 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
CUDA_CHECK(cudaGetLastError());
#else
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
ggml_cuda_op_topk_sum(ctx, (float *)src->data, nullptr, nullptr, group_scores.get(), n_per_group, nrows*n_groups, nk);
CUDA_CHECK(cudaGetLastError());
#endif

@@ -310,3 +472,92 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
argsort_f32_T_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());

}

void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk) {
auto topk_src = topk->src[0];
auto probs = topk_src->src[0]->src[0];
auto bias = topk_src->src[1];

auto nrows = ggml_nrows(probs);

int n_groups = topk->op_params[0];
int n_top_groups = topk->op_params[1];
int nk = topk->op_params[2];

int ne00 = probs->ne[0];
int ne0 = topk->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(bias->ne[1] == 1);
GGML_ASSERT(bias->ne[0] == probs->ne[0]);
GGML_ASSERT(ne0 == dst->ne[1]);
GGML_ASSERT(ne0 <= ne00);
GGML_ASSERT(ne00%n_groups == 0);
int n_per_group = ne00/n_groups;
GGML_ASSERT(nk <= n_per_group);
GGML_ASSERT(n_top_groups <= n_groups);
int n_discarded_groups = n_groups - n_top_groups;

ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
ggml_cuda_op_topk_sum(ctx, (const float *)probs->data, (const float *)bias->data, (float *)topk_src->data, group_scores.get(),
n_per_group, nrows*n_groups, nk);
CUDA_CHECK(cudaGetLastError());

ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
CUDA_CHECK(cudaGetLastError());

{
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(1, nrows, 1);
k_apply_mask<<<block_nums, block_dims, 0, ctx.stream()>>>((float *)topk_src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
CUDA_CHECK(cudaGetLastError());
}

argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, (float *)dst->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());

}

void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk_view) {
GGML_ASSERT(topk_view->op == GGML_OP_VIEW);
auto topk = topk_view->src[0];
auto topk_src = topk->src[0];
auto probs = topk_src->src[0]->src[0];
auto bias = topk_src->src[1];

auto nrows = ggml_nrows(probs);

int ne00 = probs->ne[0];
int ne0 = topk_view->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(bias->ne[1] == 1);
GGML_ASSERT(bias->ne[0] == probs->ne[0]);
GGML_ASSERT(ne0 == dst->ne[1]);
GGML_ASSERT(ne0 <= ne00);

argsort_biased_f32_f32_i32_cuda((const float *)probs->data, (const float *)bias->data, (float *)dst->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());

}

void cuda_openai_experts(ggml_backend_cuda_context & ctx, ggml_tensor * topk, ggml_tensor * softmax) {

auto probs = topk->src[0];
int ntop = topk->op_params[1];

auto nrows = ggml_nrows(probs);
int ne00 = probs->ne[0];
int ne0 = softmax->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(ggml_is_contiguous(softmax));
GGML_ASSERT(ne0 <= ne00);
if (ntop != ne0) {
printf("Oops: ntop = %d, ne0 = %d\n", ntop, ne0);
GGML_ASSERT(false);
}
//GGML_ASSERT(ne0 == ntop);

argsort_openai_f32_f32_i32_cuda((const float *)probs->data, (float *)softmax->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());

}

@@ -11,3 +11,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk);

void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk);

void cuda_openai_experts(ggml_backend_cuda_context & ctx, ggml_tensor * topk, ggml_tensor * softmax);

@@ -16,12 +16,38 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
}
}

static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;

float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}

sum = warp_reduce_sum(sum);

float norm = sum > 0 ? 1/sum : 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
dst[row * ncols + i] = x[row * ncols + i] * norm;
}
//for (int i = col; i < ncols; i += blockDim.x) {
// dst[row * ncols + i] = x[row * ncols + i] / sum;
//}
}

void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}

static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_div_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}

void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -38,3 +64,19 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}

void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));

const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

sum_rows_div_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}

@@ -3,3 +3,5 @@
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);

void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -22357,7 +22357,15 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
} break;
case GGML_OP_SUM_ROWS:
{
ggml_compute_forward_sum_rows(params, tensor);
if (i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_DIV &&
cgraph->nodes[i+1]->src[1] == tensor &&
cgraph->nodes[i+1]->src[0] == tensor->src[0]) {
iqk_sumrows_div(cgraph->nodes[i+1], params->ith, params->nth);
++i;
} else {
ggml_compute_forward_sum_rows(params, tensor);
}
} break;
case GGML_OP_MEAN:
{
@@ -22568,7 +22576,17 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
} break;
case GGML_OP_ARGSORT:
{
ggml_compute_forward_argsort(params, tensor);
if (false && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
iqk_openai_experts(tensor, cgraph->nodes[i+4], params->ith, params->nth);
i += 5;
} else {
ggml_compute_forward_argsort(params, tensor);
}
} break;
case GGML_OP_ARGSORT_THRESH:
{
@@ -22611,7 +22629,26 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
} break;
case GGML_OP_UNARY:
{
ggml_compute_forward_unary(params, tensor);
const enum ggml_unary_op unary_op = ggml_get_unary_op(tensor);
if (unary_op == GGML_UNARY_OP_SIGMOID && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) {
iqk_glm45moe_experts(cgraph->nodes[i+5], cgraph->nodes[i+4], params->ith, params->nth);
i += 5;
}
else if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) {
iqk_bailingmoev2_experts(cgraph->nodes[i+4], cgraph->nodes[i+3], params->ith, params->nth);
i += 4;
} else {
ggml_compute_forward_unary(params, tensor);
}
} break;
case GGML_OP_GLU:
{

@@ -5,6 +5,7 @@
//

#include "iqk_cpu_ops.h"
#include "iqk_utils.h"
#include "ggml.h"

#include <cstdint>
@@ -39,6 +40,110 @@ inline std::vector<std::pair<float,int>> & get_work_buffer(size_t size) {
return buffer;

}
#ifdef __ARM_NEON
inline float32x4_t v_sigmoid(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t zero = vdupq_n_f32(0.0f);
const float32x4_t neg_x = vsubq_f32(zero, x);
const float32x4_t exp_neg_x = v_expf(neg_x);
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vdivq_f32(one, one_plus_exp_neg_x);
}
#endif
#ifdef __AVX2__
inline __m256 v_sigmoid(__m256 x) {
const __m256 one = _mm256_set1_ps(1);
const __m256 zero = _mm256_setzero_ps();
const __m256 neg_x = _mm256_sub_ps(zero, x);
const __m256 exp_neg_x = v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_div_ps(one, one_plus_exp_neg_x);
}
#endif
#if defined __AVX512F__ && defined __AVX512DQ__
inline __m512 v_sigmoid(__m512 x) {
const __m512 one = _mm512_set1_ps(1);
const __m512 zero = _mm512_setzero_ps();
const __m512 neg_x = _mm512_sub_ps(zero, x);
const __m512 exp_neg_x = v_expf(neg_x);
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_div_ps(one, one_plus_exp_neg_x);
}
#endif
inline void biased_sigmoid(int n, const float * x, const float * bias, float * y, float * z) {
int i = 0;
#if defined __AVX512F__ && defined __AVX512DQ__
for (; i + 15 < n; i += 16) {
auto v = v_sigmoid(_mm512_loadu_ps(x + i));
_mm512_storeu_ps(y + i, _mm512_add_ps(v, _mm512_loadu_ps(bias + i)));
_mm512_storeu_ps(z + i, v);
}
#endif
#if defined __AVX2__ && defined __FMA__
for (; i + 7 < n; i += 8) {
auto v = v_sigmoid(_mm256_loadu_ps(x + i));
_mm256_storeu_ps(y + i, _mm256_add_ps(v, _mm256_loadu_ps(bias + i)));
_mm256_storeu_ps(z + i, v);
}
#endif
#ifdef __ARM_NEON
for (; i + 3 < n; i += 4) {
auto v = v_sigmoid(vld1q_f32(x + i));
vst1q_f32(y + i, vaddq_f32(v, vld1q_f32(bias + i)));
vst1q_f32(z + i, v);
}
#endif
for (; i < n; ++i) {
z[i] = 1/(1 + expf(-x[i]));
y[i] = y[i] + bias[i];
}
}
inline void biased_sigmoid(int n, const float * x, const float * bias, float * y) {
int i = 0;
#if defined __AVX512F__ && defined __AVX512DQ__
for (; i + 15 < n; i += 16) {
auto v = v_sigmoid(_mm512_loadu_ps(x + i));
_mm512_storeu_ps(y + i, _mm512_add_ps(v, _mm512_loadu_ps(bias + i)));
}
#endif
#if defined __AVX2__ && defined __FMA__
for (; i + 7 < n; i += 8) {
auto v = v_sigmoid(_mm256_loadu_ps(x + i));
_mm256_storeu_ps(y + i, _mm256_add_ps(v, _mm256_loadu_ps(bias + i)));
}
#endif
#ifdef __ARM_NEON
for (; i + 3 < n; i += 4) {
auto v = v_sigmoid(vld1q_f32(x + i));
vst1q_f32(y + i, vaddq_f32(v, vld1q_f32(bias + i)));
}
#endif
for (; i < n; ++i) {
y[i] = 1/(1 + expf(-x[i])) + bias[i];
}
}
}

void iqk_sumrows_div(struct ggml_tensor * div, int ith, int nth) {
auto src = div->src[0];
GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(div->type == GGML_TYPE_F32);

int ne00 = src->ne[0];
int nrows = ggml_nrows(src);
int npt = (nrows + nth - 1)/nth;
int first = ith*npt;
int last = std::min(first + npt, nrows);
if (last < first) return;

for (int ir = first; ir < last; ++ir) {
auto values = (const float *)((const char *)src->data + ir*src->nb[1]);
float sum = 0;
for (int j = 0; j < ne00; ++j) sum += values[j];
float norm = sum > 0 ? 1/sum : 0.0f;
auto result = (float *)((char *)div->data + ir*div->nb[1]);
for (int j = 0; j < ne00; ++j) result[j] = values[j]*norm;
}
}

void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) {
@@ -126,15 +231,15 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
for (int j = 0; j < ne00; ++j) aux[j] = {data[j], j};
if (nk < ne00) {
if (order == GGML_SORT_ORDER_DESC) {
std::partial_sort(aux.begin(), aux.begin() + nk, aux.end(), std::greater<std::pair<float,int>>{});
std::partial_sort(aux.begin(), aux.begin() + nk, aux.begin() + ne00, std::greater<std::pair<float,int>>{});
} else {
std::partial_sort(aux.begin(), aux.begin() + nk, aux.end());
std::partial_sort(aux.begin(), aux.begin() + nk, aux.begin() + ne00);
}
} else {
if (order == GGML_SORT_ORDER_DESC) {
std::sort(aux.begin(), aux.end(), std::greater<std::pair<float,int>>{});
std::sort(aux.begin(), aux.begin() + ne00, std::greater<std::pair<float,int>>{});
} else {
std::sort(aux.begin(), aux.end());
std::sort(aux.begin(), aux.begin() + ne00);
}
}
auto y = (int32_t *)((char *)dst->data + ir*dst->nb[1]);
@@ -143,3 +248,164 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) {

}

void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth) {
auto topk_src = topk->src[0];
auto probs = topk_src->src[0]->src[0];
auto t_bias = topk_src->src[1];

auto nrows = ggml_nrows(probs);
auto npt = (nrows + nth - 1)/nth;
auto first = npt*ith;
auto last = std::min(first + npt, nrows);
if (last <= first) return;

int n_groups = topk->op_params[0];
int n_top_groups = topk->op_params[1];
int nk = topk->op_params[2];

int ne00 = probs->ne[0];
int ne0 = topk->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(t_bias->ne[1] == 1);
GGML_ASSERT(t_bias->ne[0] == probs->ne[0]);
GGML_ASSERT(ne0 == dst->ne[1]);
GGML_ASSERT(ne0 <= ne00);
GGML_ASSERT(ne00%n_groups == 0);
int n_per_group = ne00/n_groups;
GGML_ASSERT(nk <= n_per_group);
GGML_ASSERT(n_top_groups <= n_groups);

size_t work_size = n_groups + n_per_group*n_top_groups + ne00;
auto& aux = get_work_buffer(work_size);

auto groups = aux.data() + n_per_group*n_top_groups;
auto biased_values = (float *)(groups + n_groups);
auto values = biased_values + ne00;

auto bias = (const float *)t_bias->data;

for (int ir = first; ir < last; ++ir) {
auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]);
biased_sigmoid(ne00, data, bias, biased_values, values);
//for (int j = 0; j < ne00; ++j) values[j] = 1/(1 + expf(-data[j])) + bias[j];
auto weights = (float *)((char *)dst->data + ir*dst->nb[2]);
auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]);
if (ne0 > n_per_group*n_top_groups) {
for (int j = 0; j < ne0; ++j) {
weights[j] = values[j];
ids[j] = j;
}
continue;
}
if (n_top_groups < n_groups) {
for (int ig = 0; ig < n_groups; ++ig) {
groups[ig] = { group_score(n_per_group, nk, biased_values + ig*n_per_group, (float *)aux.data()), ig };
}
std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater<std::pair<float,int>>{});

for (int ig = 0; ig < n_top_groups; ++ig) {
int i0 = n_per_group * ig;
int j0 = n_per_group * groups[ig].second;
for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { biased_values[j0 + j], j0 + j };
}
} else {
for (int j = 0; j < ne00; ++j) aux[j] = { biased_values[j], j };
}
std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater<std::pair<float,int>>{});
for (int j = 0; j < ne0; ++j) {
weights[j] = values[aux[j].second];
ids[j] = aux[j].second;
}

}
}

void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth) {
GGML_ASSERT(topk_view->op == GGML_OP_VIEW);
auto topk = topk_view->src[0];
auto topk_src = topk->src[0];
auto probs = topk_src->src[0]->src[0];
auto t_bias = topk_src->src[1];

auto nrows = ggml_nrows(probs);
auto npt = (nrows + nth - 1)/nth;
auto first = npt*ith;
auto last = std::min(first + npt, nrows);
if (last <= first) return;

int ne00 = probs->ne[0];
int ne0 = topk_view->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(t_bias->ne[1] == 1);
GGML_ASSERT(t_bias->ne[0] == probs->ne[0]);
GGML_ASSERT(ne0 == dst->ne[1]);
GGML_ASSERT(ne0 <= ne00);

size_t work_size = 2*ne00;
auto& aux = get_work_buffer(work_size);

auto biased_values = (float *)(aux.data() + ne00);
//auto values = biased_values + ne00;

auto bias = (const float *)t_bias->data;

for (int ir = first; ir < last; ++ir) {
auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]);
//biased_sigmoid(ne00, data, bias, biased_values, values);
biased_sigmoid(ne00, data, bias, biased_values);
auto weights = (float *)((char *)dst->data + ir*dst->nb[2]);
auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]);
for (int j = 0; j < ne00; ++j) aux[j] = { biased_values[j], j };
if (ne0 < ne00) {
std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + ne00, std::greater<std::pair<float,int>>{});
} else {
std::sort(aux.begin(), aux.begin() + ne00, std::greater<std::pair<float,int>>{});
}
for (int j = 0; j < ne0; ++j) {
weights[j] = 1/(1 + expf(-data[aux[j].second]));
ids[j] = aux[j].second;
}
}
}

void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth) {

auto probs = topk->src[0];

auto nrows = ggml_nrows(probs);
auto npt = (nrows + nth - 1)/nth;
auto first = npt*ith;
auto last = std::min(first + npt, nrows);
if (last <= first) return;

int ne00 = probs->ne[0];
int ne0 = softmax->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(ggml_is_contiguous(softmax));
GGML_ASSERT(ne0 <= ne00);

size_t work_size = ne00;
auto& aux = get_work_buffer(work_size);

for (int ir = first; ir < last; ++ir) {
auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]);
for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j };
if (ne0 < ne00) {
std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + ne00, std::greater<std::pair<float,int>>{});
} else {
std::sort(aux.begin(), aux.begin() + ne00, std::greater<std::pair<float,int>>{});
}
auto weights = (float *)((char *)softmax->data + ir*softmax->nb[1]);
auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]);
float max = aux.front().first;
float sum = 0;
for (int j = 0; j < ne0; ++j) {
weights[j] = expf(aux[j].first - max);
ids[j] = aux[j].second;
sum += weights[j];
}
GGML_ASSERT(sum > 0);
float norm = 1/sum;
for (int j = 0; j < ne0; ++j) weights[j] *= norm;
}
}

@@ -14,10 +14,18 @@ extern "C" {

struct ggml_tensor;

void iqk_sumrows_div(struct ggml_tensor * div, int ith, int nth);

void iqk_grouped_top_k(struct ggml_tensor * dst, int ith, int nth);

void iqk_argsort(struct ggml_tensor * dst, int ith, int nth);

void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth);

void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth);

void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth);

#ifdef __cplusplus
}
#endif