mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
[fix]: fix k2-moe.hpp load weight (#1830)
This commit is contained in:
@@ -484,19 +484,19 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
|
||||
|
||||
int& group_size = config.quant_config.group_size;
|
||||
|
||||
if (use_per_expert_ptrs) {
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
|
||||
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
|
||||
pool->dispense_backend()->do_numa_job([&, this](int i) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
|
||||
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
|
||||
|
||||
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
|
||||
if (use_per_expert_ptrs) {
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&, i](int expert_id_) {
|
||||
@@ -533,26 +533,11 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
printf("TP %d load weight done.\n", i);
|
||||
}
|
||||
} else {
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
|
||||
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
|
||||
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
|
||||
|
||||
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
|
||||
} else {
|
||||
if (tpc.load == false) {
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&](int expert_id_) {
|
||||
[&, i](int expert_id_) {
|
||||
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
|
||||
|
||||
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
|
||||
@@ -593,9 +578,9 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
printf("TP %d load weight done.\n", i);
|
||||
}
|
||||
}
|
||||
printf("TP %d load weight done.\n", i);
|
||||
});
|
||||
|
||||
#ifdef LOAD_TIME_PROFILE
|
||||
{
|
||||
@@ -616,7 +601,7 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
|
||||
}
|
||||
#endif
|
||||
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
pool->dispense_backend()->do_numa_job([&, this](int i) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
delete[] (uint8_t*)(tpc.gate_proj);
|
||||
delete[] (uint8_t*)(tpc.up_proj);
|
||||
@@ -625,7 +610,7 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>
|
||||
delete[] (ggml_bf16_t*)(tpc.gate_scale);
|
||||
delete[] (ggml_bf16_t*)(tpc.up_scale);
|
||||
delete[] (ggml_bf16_t*)(tpc.down_scale);
|
||||
}
|
||||
});
|
||||
|
||||
#ifdef LOAD_TIME_PROFILE
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user