From ac69ea891eb4ddb0a6f43c7760b686d47d126bec Mon Sep 17 00:00:00 2001 From: Oql <1692110604@qq.com> Date: Mon, 8 Dec 2025 21:08:28 +0800 Subject: [PATCH] Fix K2 MoE decode bug in buffer management (#1686) --- kt-kernel/operators/amx/k2-moe.hpp | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/kt-kernel/operators/amx/k2-moe.hpp b/kt-kernel/operators/amx/k2-moe.hpp index 6fd4203..d9b478a 100644 --- a/kt-kernel/operators/amx/k2-moe.hpp +++ b/kt-kernel/operators/amx/k2-moe.hpp @@ -836,7 +836,6 @@ class AMX_K2_MOE_TP { offset += qlen; } - void* gate_up_ba_pool_ptr = gate_up_ba_pool_; void* gate_bc_pool_ptr = gate_bc_pool_; void* up_bc_pool_ptr = up_bc_pool_; void* down_ba_pool_ptr = down_ba_pool_; @@ -850,11 +849,6 @@ class AMX_K2_MOE_TP { auto expert_idx = m_expert_id_map_[i]; size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP; - gate_up_ba_[expert_idx]->max_m = max_m; - gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr); - size_t ba_size = align64(T::BufferA::required_size(max_m, config_.hidden_size, group_size)); - gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size); - gate_bc_[expert_idx]->max_m = max_m; gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr); size_t bc_gate_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size)); @@ -876,19 +870,19 @@ class AMX_K2_MOE_TP { down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size); used_pool_m += max_m; - used_pool_bytes_a += ba_size; used_pool_bytes_bc_gate += bc_gate_size; used_pool_bytes_bc_up += bc_up_size; used_pool_bytes_ba_down += ba_down_size; used_pool_bytes_bc_down += bc_down_size; } assert(used_pool_m <= pool_count_); - assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_); assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_); assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_); assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_); assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_); + gate_up_ba_[0]->max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP; + gate_up_ba_[0]->set_data(gate_up_ba_pool_); gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1); #ifdef FORWARD_TIME_PROFILE