Fix K2 MoE decode bug in buffer management (#1686)

This commit is contained in:
Oql
2025-12-08 21:08:28 +08:00
committed by GitHub
parent 8139c092bf
commit ac69ea891e

View File

@@ -836,7 +836,6 @@ class AMX_K2_MOE_TP {
offset += qlen;
}
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
void* gate_bc_pool_ptr = gate_bc_pool_;
void* up_bc_pool_ptr = up_bc_pool_;
void* down_ba_pool_ptr = down_ba_pool_;
@@ -850,11 +849,6 @@ class AMX_K2_MOE_TP {
auto expert_idx = m_expert_id_map_[i];
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
gate_up_ba_[expert_idx]->max_m = max_m;
gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);
size_t ba_size = align64(T::BufferA::required_size(max_m, config_.hidden_size, group_size));
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
gate_bc_[expert_idx]->max_m = max_m;
gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);
size_t bc_gate_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size));
@@ -876,19 +870,19 @@ class AMX_K2_MOE_TP {
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);
used_pool_m += max_m;
used_pool_bytes_a += ba_size;
used_pool_bytes_bc_gate += bc_gate_size;
used_pool_bytes_bc_up += bc_up_size;
used_pool_bytes_ba_down += ba_down_size;
used_pool_bytes_bc_down += bc_down_size;
}
assert(used_pool_m <= pool_count_);
assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_);
assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);
gate_up_ba_[0]->max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
gate_up_ba_[0]->set_data(gate_up_ba_pool_);
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
#ifdef FORWARD_TIME_PROFILE