diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 9a57a31b..eef83572 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3342,7 +3342,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            if (fusion && i + 4 < cgraph->n_nodes &&
+            if (fusion && i + 8 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
+                cgraph->nodes[i+2]->op == GGML_OP_ADD &&
+                cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
+                cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS &&
+                cgraph->nodes[i+6]->op == GGML_OP_RESHAPE &&
+                cgraph->nodes[i+7]->op == GGML_OP_SUM_ROWS &&
+                cgraph->nodes[i+8]->op == GGML_OP_DIV) {
+                ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+8], cgraph->nodes[i+4], cgraph->nodes[i+2]->src[1]);
+                i += 8;
+            }
+            else if (fusion && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
                 cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index c71a1f64..aee3d813 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -14,6 +14,7 @@ template <size_t n_experts, bool normalize>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float *       weights,
                                                                   int32_t *     ids,
+                                                                  const float * bias,
                                                                   const int     n_rows,
                                                                   const int     n_expert_used) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
@@ -32,7 +33,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
         const int expert = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] + (bias ? bias[expert] : 0.0f) : -INFINITY;
     }
 
     float max_val = logits_r[0];
@@ -154,6 +155,7 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
                                  int32_t *                   ids,
+                                 const float *               bias,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used) {
@@ -169,34 +171,34 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 64:
-            topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
            break;
         case 128:
-            topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -207,17 +209,23 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
-                           ggml_tensor *               ids) {
+                           ggml_tensor *               ids,
+                           ggml_tensor *               bias) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
+    if (bias) {
+        GGML_ASSERT(logits->ne[0] == bias->ne[0] && ggml_nrows(bias) == 1 && bias->type == GGML_TYPE_F32);
+    }
+
     const int n_experts = logits->ne[0];
     const int n_rows    = logits->ne[1];
 
     const float * logits_d  = (const float *) logits->src[0]->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
+    const float * bias_d    = bias ? (const float *) bias->data : nullptr;
 
     GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
@@ -225,10 +233,10 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     if (weights->op == GGML_OP_DIV) {
         const int n_expert_used = weights->ne[0];
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used);
     } else {
         const int n_expert_used = weights->ne[1];
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used);
     }
 }
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
index 03f4ad56..696ed31f 100644
--- a/ggml/src/ggml-cuda/topk-moe.cuh
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -3,6 +3,7 @@
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
-                           ggml_tensor *               top_k);
+                           ggml_tensor *               top_k,
+                           ggml_tensor *               bias = nullptr);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
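
Note on the argument mapping in the new fused branch: nodes[i] is the SOFT_MAX whose src[0] holds the raw logits, nodes[i+8] (the trailing DIV) is the tensor the normalized weights are written to, nodes[i+4] (the VIEW over the ARGSORT output) receives the selected expert ids, and nodes[i+2]->src[1] is the bias operand of the ADD, which the kernel now folds into the logits before the softmax.

For reference, the following CPU sketch spells out the per-row math the updated kernel performs. It is an illustration only, not part of the patch: the helper name topk_moe_ref and the use of std:: algorithms are assumptions, and the normalize flag stands for the launch_topk_moe_cuda<true> path taken when weights->op == GGML_OP_DIV.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>

// Reference sketch of what topk_moe_cuda computes for a single row of logits.
static void topk_moe_ref(const float * logits, const float * bias, // bias may be nullptr
                         float * weights, int32_t * ids,
                         int n_experts, int n_expert_used, bool normalize) {
    // 1. bias-adjusted scores: logits[expert] + (bias ? bias[expert] : 0.0f)
    std::vector<float> p(n_experts);
    for (int e = 0; e < n_experts; ++e) {
        p[e] = logits[e] + (bias ? bias[e] : 0.0f);
    }

    // 2. softmax over the adjusted scores, with the same max-subtraction
    //    the kernel uses for numerical stability
    const float max_val = *std::max_element(p.begin(), p.end());
    float sum = 0.0f;
    for (float & v : p) { v = std::exp(v - max_val); sum += v; }
    for (float & v : p) { v /= sum; }

    // 3. keep the n_expert_used largest probabilities; a sort stands in here
    //    for the kernel's warp-level selection
    std::vector<int32_t> order(n_experts);
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + n_expert_used, order.end(),
                      [&](int32_t a, int32_t b) { return p[a] > p[b]; });

    float topk_sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) {
        ids[k]     = order[k];
        weights[k] = p[order[k]];
        topk_sum  += weights[k];
    }

    // 4. optional renormalization, matching the SUM_ROWS + DIV tail of the
    //    fused graph pattern
    if (normalize) {
        for (int k = 0; k < n_expert_used; ++k) {
            weights[k] /= topk_sum;
        }
    }
}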