Fuse experts bias in top_k_moe kernel (#1170)

* GLM-4.7-Flash support

* Model type

* Make FA work for mla != 0

* Fuse bias in top_k_moe kernel if present
This commit is contained in:
Kawrakow
2026-01-20 15:38:51 +02:00
committed by GitHub
parent 996e77047a
commit 6f1a69352f
3 changed files with 37 additions and 16 deletions

View File

@@ -3342,7 +3342,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_diag_mask_inf(ctx, dst); ggml_cuda_op_diag_mask_inf(ctx, dst);
break; break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
if (fusion && i + 4 < cgraph->n_nodes && if (fusion && i + 8 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS &&
cgraph->nodes[i+6]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+7]->op == GGML_OP_SUM_ROWS &&
cgraph->nodes[i+8]->op == GGML_OP_DIV) {
ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+8], cgraph->nodes[i+4], cgraph->nodes[i+2]->src[1]);
i += 8;
}
else if (fusion && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ARGSORT && cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+3]->op == GGML_OP_VIEW && cgraph->nodes[i+3]->op == GGML_OP_VIEW &&

View File

@@ -14,6 +14,7 @@ template <size_t n_experts, bool normalize>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights, float * weights,
int32_t * ids, int32_t * ids,
const float * bias,
const int n_rows, const int n_rows,
const int n_expert_used) { const int n_expert_used) {
const int row = blockIdx.x * blockDim.y + threadIdx.y; const int row = blockIdx.x * blockDim.y + threadIdx.y;
@@ -32,7 +33,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
#pragma unroll #pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) { for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x; const int expert = i + threadIdx.x;
logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY; logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] + (bias ? bias[expert] : 0.0f) : -INFINITY;
} }
float max_val = logits_r[0]; float max_val = logits_r[0];
@@ -154,6 +155,7 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits, const float * logits,
float * weights, float * weights,
int32_t * ids, int32_t * ids,
const float * bias,
const int n_rows, const int n_rows,
const int n_expert, const int n_expert,
const int n_expert_used) { const int n_expert_used) {
@@ -169,34 +171,34 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) { switch (n_expert) {
case 1: case 1:
topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 2: case 2:
topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 4: case 4:
topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 8: case 8:
topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 16: case 16:
topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 32: case 32:
topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 64: case 64:
topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 128: case 128:
topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 256: case 256:
topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
case 512: case 512:
topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used); topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
break; break;
default: default:
GGML_ASSERT(false && "fatal error"); GGML_ASSERT(false && "fatal error");
@@ -207,17 +209,23 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits, const ggml_tensor * logits,
ggml_tensor * weights, ggml_tensor * weights,
ggml_tensor * ids) { ggml_tensor * ids,
ggml_tensor * bias) {
GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
if (bias) {
GGML_ASSERT(logits->ne[0] == bias->ne[0] && ggml_nrows(bias) == 1 && bias->type == GGML_TYPE_F32);
}
const int n_experts = logits->ne[0]; const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1]; const int n_rows = logits->ne[1];
const float * logits_d = (const float *) logits->src[0]->data; const float * logits_d = (const float *) logits->src[0]->data;
float * weights_d = (float *) weights->data; float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data; int32_t * ids_d = (int32_t *) ids->data;
const float * bias_d = bias ? (const float *)bias->data : nullptr;
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
@@ -225,10 +233,10 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
if (weights->op == GGML_OP_DIV) { if (weights->op == GGML_OP_DIV) {
const int n_expert_used = weights->ne[0]; const int n_expert_used = weights->ne[0];
launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used);
} else { } else {
const int n_expert_used = weights->ne[1]; const int n_expert_used = weights->ne[1];
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used);
} }
} }

View File

@@ -3,6 +3,7 @@
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits, const ggml_tensor * logits,
ggml_tensor * weights, ggml_tensor * weights,
ggml_tensor * top_k); ggml_tensor * top_k,
ggml_tensor * bias = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);