[fix]: update MoE's physical-to-logical expert map

This commit is contained in:
KMSorSMS
2025-11-07 16:29:04 +00:00
parent 2641d15383
commit 62d8685699
2 changed files with 26 additions and 26 deletions

View File

@@ -179,8 +179,8 @@ class MOEBindings {
if (physical_to_logical_map) {
// printf("debug physical_to_logical_map in arg:%lu\n", physical_to_logical_map);
moe->config.physical_to_logical_map = reinterpret_cast<void*>(physical_to_logical_map);
printf("moe ptr:%p,confirm: moe->config.physical_to_logical_map:%lu\n", reinterpret_cast<void*>(moe.get()),
reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));
// printf("moe ptr:%p,confirm: moe->config.physical_to_logical_map:%lu\n", reinterpret_cast<void*>(moe.get()),
// reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));
}
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}

View File

@@ -29,10 +29,7 @@
#include "../../cpu_backend/worker_pool.h"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
template <class T>
class AMX_MOE_TP {
@@ -264,8 +261,6 @@ class AMX_MOE_TP {
~AMX_MOE_TP() {
// shared_mem_buffer_numa.dealloc(this);
}
// pack and quant the weights
void pack_weights() {}
void load_weights() {
auto pool = config_.pool->get_subpool(tp_part_idx);
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
@@ -274,7 +269,7 @@ class AMX_MOE_TP {
config_.expert_num, nullptr,
[this, physical_to_logical_map](int expert_id) {
// printf("Load layer %d [%d/%d]\n", config_.layer_idx, expert_id, config_.expert_num);
uint64_t logical_expert_id = expert_id;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id);
{
size_t scale_size = config_.intermediate_size * sizeof(float);
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size;
@@ -312,7 +307,7 @@ class AMX_MOE_TP {
std::cout << "Loading from " << prefix << std::endl;
for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {
int64_t expert_idx = task_id / (mat_type_all * mat_split);
uint64_t logical_expert_id = expert_idx;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
uint8_t mat_class = (task_id % (mat_type_all * mat_split)) / mat_split;
uint8_t mat_split_idex = task_id % mat_split;
if (mat_class == 0) { // the up matrix
@@ -346,31 +341,33 @@ class AMX_MOE_TP {
}
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth](int task_id) {
[this, nth, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.gate_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith,
nth);
gate_bb_[logical_expert_id]->from_mat(
(ggml_bf16_t*)config_.gate_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
ith, nth);
// up part
up_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith,
nth);
up_bb_[logical_expert_id]->from_mat(
(ggml_bf16_t*)config_.up_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
ith, nth);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth](int task_id) {
[this, nth, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.down_proj + expert_idx * config_.hidden_size * config_.intermediate_size, ith,
nth);
// printf("load down, expert %ld, ith %d, total nth %d\n", expert_idx, ith, nth);
down_bb_[logical_expert_id]->from_mat(
(ggml_bf16_t*)config_.down_proj + logical_expert_id * config_.hidden_size * config_.intermediate_size,
ith, nth);
// printf("load idown, expert %ld, ith %d, total nth %d\n", expert_idx, ith, nth);
},
nullptr);
}
@@ -381,8 +378,9 @@ class AMX_MOE_TP {
if (config_.save) {
pool->do_work_stealing_job(
config_.expert_num * mat_type_all, nullptr,
[this](int task_id) {
[this, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / mat_type_all;
expert_idx = expert_map(physical_to_logical_map, expert_idx);
uint8_t mat_class = task_id % mat_type_all;
if (mat_class == 0) { // the up matrix
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
@@ -838,8 +836,8 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_projs.empty() == false) {
printf("TP Load from loader\n");
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
// pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else if (config.gate_proj != nullptr) {
printf("From BF16\n");
@@ -874,7 +872,8 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
}
}
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
// pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
DO_TPS_LOAD_WEIGHTS(pool);
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
@@ -886,7 +885,8 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
this->weights_loaded = true;
} else if (config.path != "") {
printf("TP Load from file\n");
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
// pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else {
throw std::runtime_error("no weight source");