[fix]: fix weight loading in k2-moe.hpp (#1830)

This commit is contained in:
Oql
2026-02-03 11:28:49 +08:00
committed by GitHub
parent 794c04fae4
commit c28cfcb26e

View File

@@ -484,19 +484,19 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
int& group_size = config.quant_config.group_size;
if (use_per_expert_ptrs) {
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
if (use_per_expert_ptrs) {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&, i](int expert_id_) {
@@ -533,26 +533,11 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
}
},
nullptr);
printf("TP %d load weight done.\n", i);
}
} else {
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
} else {
if (tpc.load == false) {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
[&, i](int expert_id_) {
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
@@ -593,9 +578,9 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
},
nullptr);
}
printf("TP %d load weight done.\n", i);
}
}
printf("TP %d load weight done.\n", i);
});
#ifdef LOAD_TIME_PROFILE
{
@@ -616,7 +601,7 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
}
#endif
for (auto i = 0; i < tp_count; i++) {
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)(tpc.gate_proj);
delete[] (uint8_t*)(tpc.up_proj);
@@ -625,7 +610,7 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
delete[] (ggml_bf16_t*)(tpc.gate_scale);
delete[] (ggml_bf16_t*)(tpc.up_scale);
delete[] (ggml_bf16_t*)(tpc.down_scale);
}
});
#ifdef LOAD_TIME_PROFILE
{