/** * @Description : * @Author : chenht2022 * @Date : 2024-07-22 02:03:22 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-07-25 10:35:10 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_OPERATOR_AMX_MOE_H #define CPUINFER_OPERATOR_AMX_MOE_H // #define CHECK // #define FORWARD_TIME_PROFILE // #define FORWARD_TIME_REPORT #include "moe_base.hpp" template class AMX_MOE_TP : public AMX_MOE_BASE> { private: using Base = AMX_MOE_BASE>; using Base::config_; using Base::tp_part_idx; using Base::gate_bb_; using Base::up_bb_; using Base::down_bb_; using Base::gate_up_ba_; using Base::gate_bc_; using Base::up_bc_; using Base::down_ba_; using Base::down_bc_; using Base::m_local_num_; std::filesystem::path prefix; void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if // quantized)] void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if // quantized)] void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if // quantized)] #ifdef CHECK char verify_bb[100000000]; char check_bb[100000000]; uint8_t compare_expers = 3; #endif inline void write_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size, size_t scale_size) { // printf("expert %d, size %ld, scale size %ld\n", expert_idx, size, scale_size); // std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_quant_" + ".kt")); std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt")); if (of.is_open() == false) { printf("no such file: %s", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt")) .c_str()); // throw std::runtime_error("No such file"); } of.write((char*)bb, size - scale_size); of.close(); // of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_scale_" + ".kt")); of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(scale_size) + "Byte" + "_scale_" + ".kt")); if (of.is_open() == false) { printf("no such file\n"); // throw std::runtime_error("No such file"); } of.write(((char*)bb) + size - scale_size, scale_size); } inline void read_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size, size_t scale_size, uint8_t mat_split, uint8_t mat_split_idex) { // std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_quant_" + ".kt")); std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt")); if (f.is_open() == false) { printf("no such file: %s\n", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt")) .c_str()); // throw std::runtime_error("No such file"); } f.seekg(mat_split_idex * (size - scale_size) / mat_split); f.read(((char*)bb) + mat_split_idex * (size - scale_size) / mat_split, (size - scale_size) / mat_split); f.close(); // f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_scale_" + ".kt")); f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(scale_size) + "Byte" + "_scale_" + ".kt")); if (f.is_open() == false) { printf("no such file: %s\n", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(scale_size) + "Byte" + "_scale_" + ".kt")) .c_str()); // throw std::runtime_error("No such file"); } f.seekg(mat_split_idex * scale_size / mat_split); f.read((((char*)bb) + size - scale_size) + mat_split_idex * scale_size / mat_split, scale_size / mat_split); } #ifdef CHECK inline void load_check() { memcpy(check_bb, (char*)down_bb_[compare_expers]->b, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)); } void verify_load_right() { // printf("varify down bb_0 %d\n", tp_part_idx); memcpy(verify_bb, (char*)down_bb_[compare_expers]->b, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)); // check if verify_bb_0 equal to check_bb_0 if (memcmp(verify_bb, check_bb, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)) != 0) { printf("verify error\n"); for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); ++i) { if (verify_bb[i] != check_bb[i]) { printf("Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\n", i, compare_expers, i, (unsigned char)verify_bb[i], i, (unsigned char)check_bb[i]); break; // find the first difference and exit } } assert(0); } else { printf("pass verify\n"); // pick out the 100th~150th byte of scale to see printf("numa %d, verify_bb_%d:\n", tp_part_idx, compare_expers); size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); size_t scale_size = config_.hidden_size * sizeof(float); for (size_t i = size - scale_size; i < size - scale_size + 50; ++i) { printf("%02x ", (unsigned char)verify_bb[i]); } printf("\n"); } } #endif #ifdef FORWARD_TIME_REPORT std::chrono::time_point last_now; #endif public: AMX_MOE_TP() = default; AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) { printf("Creating AMX_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); auto& load = config_.load; auto& save = config_.save; prefix = config_.path; prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx)); if (save) { std::cout << "Creating " << prefix << std::endl; std::filesystem::create_directories(prefix); } if (load) { if (std::filesystem::exists(prefix)) { std::cout << "Loading from " << prefix << std::endl; } else { throw std::runtime_error("Path not found: " + prefix.string()); } } gate_proj_ = config_.gate_proj; up_proj_ = config_.up_proj; down_proj_ = config_.down_proj; } ~AMX_MOE_TP() = default; // ============================================================================ // CRTP buffer creation - no group_size // ============================================================================ size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); } size_t buffer_b_required_size_impl(size_t n, size_t k) const { return T::BufferB::required_size(n, k); } size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); } std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { return std::make_shared(m, k, data); } std::shared_ptr make_buffer_b_impl(size_t n, size_t k, void* data) const { return std::make_shared(n, k, data); } std::shared_ptr make_buffer_c_impl(size_t m, size_t n, void* data) const { return std::make_shared(m, n, data); } // ============================================================================ // CRTP virtual points - GEMM dispatch // ============================================================================ void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) { int m = m_local_num_[expert_idx]; auto& ba = gate_up_ba_[expert_idx]; auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx]; auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx]; if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth); } else { amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth); } } void do_down_gemm(int expert_idx, int ith, int nth, int qlen) { int m = m_local_num_[expert_idx]; auto& ba = down_ba_[expert_idx]; auto& bb = down_bb_[expert_idx]; auto& bc = down_bc_[expert_idx]; if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth); } else { amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth); } } void load_weights() { auto pool = config_.pool->get_subpool(tp_part_idx); const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map; if (config_.gate_projs.size()) { pool->do_work_stealing_job( config_.expert_num, nullptr, [this, physical_to_logical_map](int expert_id) { // printf("Load layer %d [%d/%d]\n", config_.layer_idx, expert_id, config_.expert_num); uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id); { size_t scale_size = config_.intermediate_size * sizeof(float); size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size; memcpy(gate_bb_[expert_id]->b, config_.gate_projs[tp_part_idx][logical_expert_id], size); if constexpr (T::BufferB::SCALE) { memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size); } memcpy(up_bb_[expert_id]->b, config_.up_projs[tp_part_idx][logical_expert_id], size); if constexpr (T::BufferB::SCALE) { memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size); } } { size_t scale_size = config_.hidden_size * sizeof(float); size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size) - scale_size; memcpy(down_bb_[expert_id]->b, config_.down_projs[tp_part_idx][logical_expert_id], size); if constexpr (T::BufferB::SCALE) { memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size); } } }, nullptr); } else { int nth = T::recommended_nth(config_.intermediate_size); static uint8_t mat_type_all = 3, mat_split = 1; if (config_.load) { std::cout << "Loading from " << prefix << std::endl; for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) { int64_t expert_idx = task_id / (mat_type_all * mat_split); uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); uint8_t mat_class = (task_id % (mat_type_all * mat_split)) / mat_split; uint8_t mat_split_idex = task_id % mat_split; if (mat_class == 0) { // the up matrix size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); size_t scale_size = config_.intermediate_size * sizeof(float); read_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, logical_expert_id, size, scale_size, mat_split, mat_split_idex); } else if (mat_class == 1) { size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); size_t scale_size = config_.intermediate_size * sizeof(float); read_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, logical_expert_id, size, scale_size, mat_split, mat_split_idex); } else { size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); size_t scale_size = config_.hidden_size * sizeof(float); read_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, logical_expert_id, size, scale_size, mat_split, mat_split_idex); } } } // check process, store down matrix to check #ifdef CHECK load_check(); #endif #ifndef CHECK else #endif { if (tp_part_idx == 0) { std::cout << " online quant from bf16" << std::endl; } pool->do_work_stealing_job( nth * config_.expert_num, nullptr, [this, nth, physical_to_logical_map](int task_id) { int64_t expert_idx = task_id / nth; uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); int ith = task_id % nth; // gate part gate_bb_[logical_expert_id]->from_mat( (ggml_bf16_t*)config_.gate_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size, ith, nth); // up part up_bb_[logical_expert_id]->from_mat( (ggml_bf16_t*)config_.up_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size, ith, nth); }, nullptr); nth = T::recommended_nth(config_.hidden_size); pool->do_work_stealing_job( nth * config_.expert_num, nullptr, [this, nth, physical_to_logical_map](int task_id) { int64_t expert_idx = task_id / nth; uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); int ith = task_id % nth; // down part down_bb_[logical_expert_id]->from_mat( (ggml_bf16_t*)config_.down_proj + logical_expert_id * config_.hidden_size * config_.intermediate_size, ith, nth); // printf("load idown, expert %ld, ith %d, total nth %d\n", expert_idx, ith, nth); }, nullptr); } #ifdef CHECK verify_load_right(); #endif // save process if (config_.save) { pool->do_work_stealing_job( config_.expert_num * mat_type_all, nullptr, [this, physical_to_logical_map](int task_id) { int64_t expert_idx = task_id / mat_type_all; expert_idx = expert_map(physical_to_logical_map, expert_idx); uint8_t mat_class = task_id % mat_type_all; if (mat_class == 0) { // the up matrix size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); size_t scale_size = config_.intermediate_size * sizeof(float); write_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, expert_idx, size, scale_size); } else if (mat_class == 1) { size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size); size_t scale_size = config_.intermediate_size * sizeof(float); write_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, expert_idx, size, scale_size); } else if (mat_class == 2) { size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); size_t scale_size = config_.hidden_size * sizeof(float); write_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, expert_idx, size, scale_size); } }, nullptr); } } } // forward, forward_prefill, forward_decode, warm_up are inherited from Base }; // ============================================================================ // TP_MOE specialization for AMX_MOE_TP // Inherits from TP_MOE> to reuse merge_results implementation // ============================================================================ template class TP_MOE> : public TP_MOE>> { public: using Base = TP_MOE>>; using Base::Base; void load_weights() override { auto& config = this->config; auto& tps = this->tps; auto& tp_count = this->tp_count; auto pool = config.pool; const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map; if (config.gate_projs.empty() == false) { printf("TP Load from loader\n"); DO_TPS_LOAD_WEIGHTS(pool); this->weights_loaded = true; } else if (config.gate_proj != nullptr) { printf("From BF16\n"); for (auto i = 0; i < tp_count; i++) { auto& tpc = tps[i]->config_; size_t gate_up_elcount = tpc.intermediate_size * tpc.hidden_size; tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount]; tpc.up_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount]; tpc.down_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount]; if (tps[i]->config_.load == false) { pool->get_subpool(i)->do_work_stealing_job( tpc.expert_num, nullptr, [&](int expert_id_) { size_t expert_id = expert_map(physical_to_logical_map, expert_id_); memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount, (ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size + i * gate_up_elcount, sizeof(ggml_bf16_t) * gate_up_elcount); memcpy((ggml_bf16_t*)tpc.up_proj + expert_id * gate_up_elcount, (ggml_bf16_t*)config.up_proj + expert_id * config.intermediate_size * config.hidden_size + i * gate_up_elcount, sizeof(ggml_bf16_t) * gate_up_elcount); for (size_t col = 0; col < config.hidden_size; col++) { memcpy((ggml_bf16_t*)tpc.down_proj + expert_id * tpc.hidden_size * tpc.intermediate_size + col * tpc.intermediate_size, (ggml_bf16_t*)config.down_proj + expert_id * config.intermediate_size * config.hidden_size + col * config.intermediate_size + i * tpc.intermediate_size, sizeof(ggml_bf16_t) * tpc.intermediate_size); } }, nullptr); } } DO_TPS_LOAD_WEIGHTS(pool); for (auto i = 0; i < tp_count; i++) { auto& tpc = tps[i]->config_; delete[] (ggml_bf16_t*)(tpc.gate_proj); delete[] (ggml_bf16_t*)(tpc.up_proj); delete[] (ggml_bf16_t*)(tpc.down_proj); } this->weights_loaded = true; } else if (config.path != "") { printf("TP Load from file\n"); DO_TPS_LOAD_WEIGHTS(pool); this->weights_loaded = true; } else { throw std::runtime_error("no weight source"); } } // merge_results is inherited from TP_MOE>> }; #endif