[fix]: fix moe hpp bug. (#1780)

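Fix the MoE .hpp init bug: move derived-class initialization out of the constructors into a derived_init() hook that the AMX_MOE_BASE constructor calls after init().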
This commit is contained in:
Oql
2026-01-04 19:32:56 +08:00
committed by GitHub
parent ad7674a6d5
commit dc6394e501
5 changed files with 58 additions and 51 deletions

View File

@@ -31,16 +31,16 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
private:
using Base = AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>>;
using Base::config_;
using Base::tp_part_idx;
using Base::gate_bb_;
using Base::up_bb_;
using Base::down_bb_;
using Base::gate_up_ba_;
using Base::gate_bc_;
using Base::up_bc_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
std::filesystem::path prefix;
@@ -265,7 +265,7 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
config_.quant_config.group_size)) != 0) {
printf("verify error\n");
for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,
config_.quant_config.group_size);
config_.quant_config.group_size);
++i) {
if (verify_bb[i] != check_bb[i]) {
printf("Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\n", i, compare_expers, i,
@@ -393,19 +393,21 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
AMX_AWQ_MOE_TP() = default;
AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}
void derived_init() {
auto& quant_config = config_.quant_config;
if (quant_config.group_size == 0 || !quant_config.zero_point) {
throw std::runtime_error("AWQ-Quantization AMX MoE only support KGroup Int4_1");
}
printf("Creating AMX_AWQ_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
printf("Creating AMX_AWQ_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
auto& load = config_.load;
auto& save = config_.save;
prefix = config_.path;
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx_));
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
if (save) {
std::cout << "Creating " << prefix << std::endl;
std::filesystem::create_directories(prefix);
@@ -431,9 +433,7 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
// Returns the size that T::BufferB::required_size reports for the given (n, k),
// passing along the group size stored in config_.quant_config.
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k, config_.quant_config.group_size);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const {
return T::BufferC::required_size(m, n);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, config_.quant_config.group_size, data);

View File

@@ -13,16 +13,6 @@
// #define DEBUG_FP8_MOE
#include <immintrin.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "la/amx_raw_buffers.hpp"
#include "la/amx_raw_kernels.hpp"
#include "moe_base.hpp"
@@ -57,12 +47,15 @@ class AMX_FP8_MOE_TP : public AMX_MOE_BASE<T, AMX_FP8_MOE_TP<T>> {
AMX_FP8_MOE_TP() = default;
// Forwards config and tp_part_idx_ (default 0) to the Base constructor; the body is
// intentionally empty.
AMX_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
// Initialization now happens in derived_init() which is called by base constructor
}
void derived_init() {
auto& quant_config = config_.quant_config;
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("KT-Kernel fp8 MoE only support block-wise FP8. group_size = %d, zero_point = %d",
quant_config.group_size, quant_config.zero_point);
throw std::runtime_error("KT-Kernel fp8 MoE only support block-wise FP8");
}
printf("Created AMX_FP8_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
printf("Created AMX_FP8_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
}
~AMX_FP8_MOE_TP() = default;

View File

@@ -45,12 +45,14 @@ class AMX_K2_MOE_TP : public AMX_MOE_BASE<T, AMX_K2_MOE_TP<T>> {
AMX_K2_MOE_TP() = default;
AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}
void derived_init() {
auto& quant_config = config_.quant_config;
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4");
}
printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
}
~AMX_K2_MOE_TP() = default;

View File

@@ -21,18 +21,16 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
private:
using Base = AMX_MOE_BASE<T, AMX_MOE_TP<T>>;
using Base::config_;
using Base::tp_part_idx;
using Base::gate_bb_;
using Base::up_bb_;
using Base::down_bb_;
using Base::gate_up_ba_;
using Base::gate_bc_;
using Base::up_bc_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
std::filesystem::path prefix;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
// quantized)]
@@ -140,11 +138,15 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
AMX_MOE_TP() = default;
// Forwards config and tp_part_idx (default 0) to the Base constructor; the body is
// intentionally empty.
AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) {
// Initialization now happens in derived_init() which is called by base constructor
}
void derived_init() {
printf("Creating AMX_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
auto& load = config_.load;
auto& save = config_.save;
prefix = config_.path;
std::filesystem::path prefix = config_.path;
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
if (save) {
std::cout << "Creating " << prefix << std::endl;
@@ -169,15 +171,9 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
// CRTP buffer creation - no group_size
// ============================================================================
size_t buffer_a_required_size_impl(size_t m, size_t k) const {
return T::BufferA::required_size(m, k);
}
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const {
return T::BufferC::required_size(m, n);
}
size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
size_t buffer_b_required_size_impl(size_t n, size_t k) const { return T::BufferB::required_size(n, k); }
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, data);
@@ -260,6 +256,9 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
} else {
int nth = T::recommended_nth(config_.intermediate_size);
static uint8_t mat_type_all = 3, mat_split = 1;
std::filesystem::path prefix = config_.path;
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
if (config_.load) {
std::cout << "Loading from " << prefix << std::endl;
for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {
@@ -335,7 +334,7 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
if (config_.save) {
pool->do_work_stealing_job(
config_.expert_num * mat_type_all, nullptr,
[this, physical_to_logical_map](int task_id) {
[this, physical_to_logical_map, prefix](int task_id) {
int64_t expert_idx = task_id / mat_type_all;
expert_idx = expert_map(physical_to_logical_map, expert_idx);
uint8_t mat_class = task_id % mat_type_all;
@@ -426,7 +425,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>> {
this->weights_loaded = true;
} else if (config.path != "") {
printf("TP Load from file\n");
printf("TP Load from file %s\n", config.path.c_str());
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else {

View File

@@ -23,6 +23,7 @@
#include <cstring>
#include <filesystem>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
@@ -78,7 +79,10 @@ class AMX_MOE_BASE {
using output_t = float;
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) { init(); }
// Stores tp_part_idx_ in tp_part_idx and copies config into config_, then runs init()
// followed by Derived's derived_init() hook, reached through derived() (static dispatch).
// NOTE(review): the hook executes inside this base-class constructor, i.e. before any
// non-static data members declared by Derived have been constructed. derived_init()
// implementations should therefore rely only on base-class state - confirm that no
// derived class assigns to its own members there.
AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) {
init();
derived()->derived_init();
}
void init() {
if (config_.load && config_.path == "") {
@@ -639,6 +643,15 @@ class AMX_MOE_BASE {
// CRTP accessor: views this object as Derived via static_cast (no runtime type check).
Derived* derived() { return static_cast<Derived*>(this); }
// Const counterpart of derived(), usable from const member functions.
const Derived* derived_const() const { return static_cast<const Derived*>(this); }
// ============================================================================
// Derived class initialization hook
// Invoked from the AMX_MOE_BASE constructor right after init() returns, so a
// derived class can run setup that depends on the state prepared by init().
// This is not a virtual function: the constructor reaches it through derived(),
// so a derived class customizes it by declaring its own derived_init().
// ============================================================================
void derived_init() {
// Default implementation does nothing - used when Derived declares no derived_init() of its own
}
// ============================================================================
// Virtual points for buffer creation and size calculation
// Default implementations use group_size (for KGroup quantization like K2)