Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-01-26 17:20:01 +00:00
Fuse experts bias in top_k_moe kernel (#1170)
* GLM-4.7-Flash support
* Model type
* Make FA work for mla != 0
* Fuse bias in top_k_moe kernel if present
@@ -3342,7 +3342,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            if (fusion && i + 4 < cgraph->n_nodes &&
+            if (fusion && i + 8 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                cgraph->nodes[i+2]->op == GGML_OP_ADD &&
+                cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
+                cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS &&
+                cgraph->nodes[i+6]->op == GGML_OP_RESHAPE &&
+                cgraph->nodes[i+7]->op == GGML_OP_SUM_ROWS &&
+                cgraph->nodes[i+8]->op == GGML_OP_DIV) {
+                ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+8], cgraph->nodes[i+4], cgraph->nodes[i+2]->src[1]);
+                i += 8;
+            }
+            else if (fusion && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
                 cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
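
The nine-node pattern matched above is biased top-k gating, presumably for the GLM-4.7-Flash support this commit adds: softmax over the router logits, a per-expert bias added with GGML_OP_ADD, an argsort to rank the experts, a view of the top n_expert_used ids, a GET_ROWS gather of the selected scores, and a SUM_ROWS/DIV pair that renormalizes them. The fused call forwards the softmax node, the final DIV (weights), the VIEW (ids), and the bias tensor, i.e. the second source of the ADD. Below is a rough host-side C++ sketch of that sequence with illustrative names; whether the gathered weights keep the bias depends on graph edges the diff does not show, and the sketch simply reuses the biased scores.

// Hypothetical reference for the fused sequence:
// SOFT_MAX -> ADD(bias) -> ARGSORT -> VIEW(top-k) -> GET_ROWS -> SUM_ROWS -> DIV.
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

static void topk_moe_reference(const std::vector<float> & logits, // router logits for one row
                               const std::vector<float> & bias,   // per-expert gate bias
                               int n_expert_used,
                               std::vector<float> & weights,      // out: normalized top-k weights
                               std::vector<int> & ids) {          // out: selected expert ids
    const int n_experts = (int) logits.size();

    // SOFT_MAX over the logits.
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> scores(n_experts);
    float sum = 0.0f;
    for (int e = 0; e < n_experts; ++e) { scores[e] = std::exp(logits[e] - max_logit); sum += scores[e]; }
    for (float & s : scores) s /= sum;

    // ADD: bias the per-expert scores used for ranking.
    for (int e = 0; e < n_experts; ++e) scores[e] += bias[e];

    // ARGSORT (descending) + VIEW: keep the first n_expert_used ids.
    std::vector<int> order(n_experts);
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](int a, int b) { return scores[a] > scores[b]; });
    ids.assign(order.begin(), order.begin() + n_expert_used);

    // GET_ROWS + SUM_ROWS + DIV: gather the selected scores and renormalize.
    weights.resize(n_expert_used);
    float wsum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) { weights[k] = scores[ids[k]]; wsum += weights[k]; }
    for (float & w : weights) w /= wsum;
}

int main() {
    std::vector<float> weights;
    std::vector<int> ids;
    topk_moe_reference({0.1f, 2.0f, 0.3f, 1.5f}, {0.0f, 0.0f, 0.5f, 0.0f}, 2, weights, ids);
    // ids now holds the two best experts; weights sum to 1.
    return 0;
}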
@@ -14,6 +14,7 @@ template <size_t n_experts, bool normalize>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,
+                                                                  const float * bias,
                                                                   const int n_rows,
                                                                   const int n_expert_used) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
@@ -32,7 +33,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
         const int expert = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] + (bias ? bias[expert] : 0.0f) : -INFINITY;
     }
 
     float max_val = logits_r[0];
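
In the kernel itself the fusion is one extra add on a value that is already being loaded: lane threadIdx.x of each 32-wide warp covers experts threadIdx.x, threadIdx.x + 32, and so on, out-of-range slots are padded with -INFINITY so they can never win the top-k, and the optional bias rides along in the same register load instead of costing a separate GGML_OP_ADD pass over global memory. A minimal host-side simulation of the per-lane indexing follows (assumed names, no claim to match the rest of the kernel).

// Simulates what one warp lane keeps in registers for n_experts gate values.
#include <cmath>
#include <vector>

constexpr int WARP_SIZE = 32;

std::vector<float> lane_registers(const std::vector<float> & logits,
                                  const std::vector<float> * bias, // may be null
                                  int n_experts, int lane) {
    std::vector<float> regs;
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert = i + lane;
        // Fused bias: one add on the value already in flight.
        regs.push_back(expert < n_experts
                           ? logits[expert] + (bias ? (*bias)[expert] : 0.0f)
                           : -INFINITY);
    }
    return regs;
}

int main() {
    std::vector<float> logits(40, 1.0f), bias(40, 0.25f);
    auto regs = lane_registers(logits, &bias, 40, /*lane=*/12);
    // regs[0] covers expert 12 (1.25f); regs[1] covers expert 44, padded to -INFINITY.
    (void) regs;
    return 0;
}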
@@ -154,6 +155,7 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float * weights,
                                  int32_t * ids,
+                                 const float * bias,
                                  const int n_rows,
                                  const int n_expert,
                                  const int n_expert_used) {
@@ -169,34 +171,34 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 64:
-            topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
            break;
         case 128:
-            topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
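
Every arm of this switch had to be edited just to thread the new argument through, since n_expert is a run-time value but the kernel wants it as a template constant for full unrolling. Below is a hedged sketch of one way to centralize such a dispatch so the next signature change touches a single call site; dispatch_n_expert is hypothetical, not ik_llama.cpp code.

#include <cstdio>
#include <type_traits>
#include <utility>

// Walk the supported power-of-two expert counts at compile time; the functor
// runs once, with the matching count as a std::integral_constant.
template <typename F, size_t... Ns>
bool dispatch_n_expert(size_t n, F && f, std::index_sequence<Ns...>) {
    return (... || (n == (1u << Ns) && (f(std::integral_constant<size_t, (1u << Ns)>{}), true)));
}

int main() {
    const size_t n_expert = 64;
    const bool ok = dispatch_n_expert(n_expert, [](auto N) {
        // Here one would launch topk_moe_cuda<N.value, normalize><<<...>>>(...).
        std::printf("compile-time n_experts = %zu\n", (size_t) N.value);
    }, std::make_index_sequence<10>{}); // candidates 1, 2, 4, ..., 512
    if (!ok) std::printf("unsupported n_expert\n");
    return 0;
}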
@@ -207,17 +209,23 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
-                           ggml_tensor * ids) {
+                           ggml_tensor * ids,
+                           ggml_tensor * bias) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
+    if (bias) {
+        GGML_ASSERT(logits->ne[0] == bias->ne[0] && ggml_nrows(bias) == 1 && bias->type == GGML_TYPE_F32);
+    }
+
     const int n_experts = logits->ne[0];
     const int n_rows = logits->ne[1];
 
     const float * logits_d = (const float *) logits->src[0]->data;
     float * weights_d = (float *) weights->data;
     int32_t * ids_d = (int32_t *) ids->data;
+    const float * bias_d = bias ? (const float *)bias->data : nullptr;
 
     GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
 
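
Two details here: the device pointer for the gate values comes from logits->src[0], the input of the matched softmax node, so the fused kernel recomputes the softmax itself instead of reading the graph's intermediate output; and a supplied bias must be a single F32 row whose length matches the expert count, with bias_d simply staying nullptr on the unbiased path. A hedged ggml sketch of a tensor that satisfies the new assertion (context size and expert count are illustrative):

#include "ggml.h"

int main() {
    ggml_init_params params = { 16*1024*1024, nullptr, false };
    ggml_context * ctx = ggml_init(params);
    const int n_experts = 64; // must equal logits->ne[0]
    // One F32 row of length n_experts: ne[0] == n_experts, ggml_nrows() == 1.
    ggml_tensor * bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_experts);
    (void) bias;
    ggml_free(ctx);
    return 0;
}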
@@ -225,10 +233,10 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     if (weights->op == GGML_OP_DIV) {
         const int n_expert_used = weights->ne[0];
-        launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used);
     } else {
         const int n_expert_used = weights->ne[1];
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used);
     }
 }
 
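
The normalize template flag is inferred from the matched graph: when weights is the DIV node from the long pattern, the gathered top-k scores are rescaled to sum to 1, w_k = s_k / (s_1 + ... + s_n), and n_expert_used sits in ne[0]; on the shorter pattern there is no renormalization and the count lives in ne[1]. A tiny worked example of the normalize path:

#include <cstdio>

int main() {
    float s[2] = { 0.60f, 0.30f };          // top-2 gathered scores
    const float sum = s[0] + s[1];          // SUM_ROWS
    for (float & w : s) w /= sum;           // DIV
    std::printf("%.3f %.3f\n", s[0], s[1]); // prints 0.667 0.333
    return 0;
}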
@@ -3,6 +3,7 @@
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
-                           ggml_tensor * top_k);
+                           ggml_tensor * top_k,
+                           ggml_tensor * bias = nullptr);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
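
Defaulting the new parameter to nullptr keeps the header change source-compatible: existing call sites for the unbiased fusion compile unchanged, and only the new nine-node path passes a bias. For illustration, using the declaration's own parameter names:

// Illustration only; both shapes compile against the updated declaration.
ggml_cuda_op_topk_moe(ctx, logits, weights, top_k);       // legacy path, bias == nullptr
ggml_cuda_op_topk_moe(ctx, logits, weights, top_k, bias); // biased path from the new fusion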