diff --git a/kt-kernel/operators/amx/awq-moe.hpp b/kt-kernel/operators/amx/awq-moe.hpp index 23cef12..a71bde1 100644 --- a/kt-kernel/operators/amx/awq-moe.hpp +++ b/kt-kernel/operators/amx/awq-moe.hpp @@ -31,16 +31,16 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE> { private: using Base = AMX_MOE_BASE>; using Base::config_; - using Base::tp_part_idx; - using Base::gate_bb_; - using Base::up_bb_; - using Base::down_bb_; - using Base::gate_up_ba_; - using Base::gate_bc_; - using Base::up_bc_; using Base::down_ba_; + using Base::down_bb_; using Base::down_bc_; + using Base::gate_bb_; + using Base::gate_bc_; + using Base::gate_up_ba_; using Base::m_local_num_; + using Base::tp_part_idx; + using Base::up_bb_; + using Base::up_bc_; std::filesystem::path prefix; @@ -265,7 +265,7 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE> { config_.quant_config.group_size)) != 0) { printf("verify error\n"); for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, - config_.quant_config.group_size); + config_.quant_config.group_size); ++i) { if (verify_bb[i] != check_bb[i]) { printf("Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\n", i, compare_expers, i, @@ -393,19 +393,21 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE> { AMX_AWQ_MOE_TP() = default; - AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) { + AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {} + + void derived_init() { auto& quant_config = config_.quant_config; if (quant_config.group_size == 0 || !quant_config.zero_point) { throw std::runtime_error("AWQ-Quantization AMX MoE only support KGroup Int4_1"); } - printf("Creating AMX_AWQ_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu())); + printf("Creating AMX_AWQ_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); auto& load = config_.load; auto& save = config_.save; prefix = config_.path; - prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx_)); + prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx)); if (save) { std::cout << "Creating " << prefix << std::endl; std::filesystem::create_directories(prefix); @@ -431,9 +433,7 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE> { size_t buffer_b_required_size_impl(size_t n, size_t k) const { return T::BufferB::required_size(n, k, config_.quant_config.group_size); } - size_t buffer_c_required_size_impl(size_t m, size_t n) const { - return T::BufferC::required_size(m, n); - } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); } std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { return std::make_shared(m, k, config_.quant_config.group_size, data); diff --git a/kt-kernel/operators/amx/fp8-moe.hpp b/kt-kernel/operators/amx/fp8-moe.hpp index 7bb7b83..dd48f8e 100644 --- a/kt-kernel/operators/amx/fp8-moe.hpp +++ b/kt-kernel/operators/amx/fp8-moe.hpp @@ -13,16 +13,6 @@ // #define DEBUG_FP8_MOE -#include - -#include -#include -#include -#include -#include -#include -#include - #include "la/amx_raw_buffers.hpp" #include "la/amx_raw_kernels.hpp" #include "moe_base.hpp" @@ -57,12 +47,15 @@ class AMX_FP8_MOE_TP : public AMX_MOE_BASE> { AMX_FP8_MOE_TP() = default; AMX_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) { + // Initialization now happens in derived_init() which is called by base constructor + } + + void derived_init() { auto& quant_config = config_.quant_config; if (quant_config.group_size == 0 || quant_config.zero_point) { - throw std::runtime_error("KT-Kernel fp8 MoE only support block-wise FP8. group_size = %d, zero_point = %d", - quant_config.group_size, quant_config.zero_point); + throw std::runtime_error("KT-Kernel fp8 MoE only support block-wise FP8"); } - printf("Created AMX_FP8_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu())); + printf("Created AMX_FP8_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); } ~AMX_FP8_MOE_TP() = default; diff --git a/kt-kernel/operators/amx/k2-moe.hpp b/kt-kernel/operators/amx/k2-moe.hpp index 67809a9..3f6f5f6 100644 --- a/kt-kernel/operators/amx/k2-moe.hpp +++ b/kt-kernel/operators/amx/k2-moe.hpp @@ -45,12 +45,14 @@ class AMX_K2_MOE_TP : public AMX_MOE_BASE> { AMX_K2_MOE_TP() = default; - AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) { + AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {} + + void derived_init() { auto& quant_config = config_.quant_config; if (quant_config.group_size == 0 || quant_config.zero_point) { throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4"); } - printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu())); + printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); } ~AMX_K2_MOE_TP() = default; diff --git a/kt-kernel/operators/amx/moe.hpp b/kt-kernel/operators/amx/moe.hpp index 168b04b..08f5354 100644 --- a/kt-kernel/operators/amx/moe.hpp +++ b/kt-kernel/operators/amx/moe.hpp @@ -21,18 +21,16 @@ class AMX_MOE_TP : public AMX_MOE_BASE> { private: using Base = AMX_MOE_BASE>; using Base::config_; - using Base::tp_part_idx; - using Base::gate_bb_; - using Base::up_bb_; - using Base::down_bb_; - using Base::gate_up_ba_; - using Base::gate_bc_; - using Base::up_bc_; using Base::down_ba_; + using Base::down_bb_; using Base::down_bc_; + using Base::gate_bb_; + using Base::gate_bc_; + using Base::gate_up_ba_; using Base::m_local_num_; - - std::filesystem::path prefix; + using Base::tp_part_idx; + using Base::up_bb_; + using Base::up_bc_; void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if // quantized)] @@ -140,11 +138,15 @@ class AMX_MOE_TP : public AMX_MOE_BASE> { AMX_MOE_TP() = default; AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) { + // Initialization now happens in derived_init() which is called by base constructor + } + + void derived_init() { printf("Creating AMX_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); auto& load = config_.load; auto& save = config_.save; - prefix = config_.path; + std::filesystem::path prefix = config_.path; prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx)); if (save) { std::cout << "Creating " << prefix << std::endl; @@ -169,15 +171,9 @@ class AMX_MOE_TP : public AMX_MOE_BASE> { // CRTP buffer creation - no group_size // ============================================================================ - size_t buffer_a_required_size_impl(size_t m, size_t k) const { - return T::BufferA::required_size(m, k); - } - size_t buffer_b_required_size_impl(size_t n, size_t k) const { - return T::BufferB::required_size(n, k); - } - size_t buffer_c_required_size_impl(size_t m, size_t n) const { - return T::BufferC::required_size(m, n); - } + size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); } + size_t buffer_b_required_size_impl(size_t n, size_t k) const { return T::BufferB::required_size(n, k); } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); } std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { return std::make_shared(m, k, data); @@ -260,6 +256,9 @@ class AMX_MOE_TP : public AMX_MOE_BASE> { } else { int nth = T::recommended_nth(config_.intermediate_size); static uint8_t mat_type_all = 3, mat_split = 1; + std::filesystem::path prefix = config_.path; + prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx)); + if (config_.load) { std::cout << "Loading from " << prefix << std::endl; for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) { @@ -335,7 +334,7 @@ class AMX_MOE_TP : public AMX_MOE_BASE> { if (config_.save) { pool->do_work_stealing_job( config_.expert_num * mat_type_all, nullptr, - [this, physical_to_logical_map](int task_id) { + [this, physical_to_logical_map, prefix](int task_id) { int64_t expert_idx = task_id / mat_type_all; expert_idx = expert_map(physical_to_logical_map, expert_idx); uint8_t mat_class = task_id % mat_type_all; @@ -426,7 +425,7 @@ class TP_MOE> : public TP_MOE>> { this->weights_loaded = true; } else if (config.path != "") { - printf("TP Load from file\n"); + printf("TP Load from file %s\n", config.path.c_str()); DO_TPS_LOAD_WEIGHTS(pool); this->weights_loaded = true; } else { diff --git a/kt-kernel/operators/amx/moe_base.hpp b/kt-kernel/operators/amx/moe_base.hpp index e1bb093..09149e0 100644 --- a/kt-kernel/operators/amx/moe_base.hpp +++ b/kt-kernel/operators/amx/moe_base.hpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -78,7 +79,10 @@ class AMX_MOE_BASE { using output_t = float; static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE; - AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) { init(); } + AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) { + init(); + derived()->derived_init(); + } void init() { if (config_.load && config_.path == "") { @@ -639,6 +643,15 @@ class AMX_MOE_BASE { Derived* derived() { return static_cast(this); } const Derived* derived_const() const { return static_cast(this); } + // ============================================================================ + // Derived class initialization hook + // Called after base class init() completes, allows derived classes to perform + // their own initialization that depends on base class being fully initialized + // ============================================================================ + void derived_init() { + // Default implementation does nothing - derived classes can override + } + // ============================================================================ // Virtual points for buffer creation and size calculation // Default implementations use group_size (for KGroup quantization like K2)