Files
ktransformers/kt-kernel/operators/amx/moe.hpp
ErvinXie d8046e1bb4 Kt minimax (#1742)
[feat]: fp8 kernel and kt-cli support
2025-12-24 15:39:44 +08:00

441 lines
19 KiB
C++

/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:10
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_AMX_MOE_H
#define CPUINFER_OPERATOR_AMX_MOE_H
// #define CHECK
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT
#include "moe_base.hpp"
// ============================================================================
// AMX_MOE_TP: one tensor-parallel (per-NUMA) shard of an AMX MoE layer.
// Wires the CRTP hooks of AMX_MOE_BASE to the concrete kernel type T
// (buffer creation, GEMM dispatch) and implements weight loading from three
// sources: pre-quantized host buffers, an on-disk cache, or online
// quantization from BF16.
// ============================================================================
template <class T>
class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
 private:
  using Base = AMX_MOE_BASE<T, AMX_MOE_TP<T>>;
  using Base::config_;
  using Base::tp_part_idx;
  using Base::gate_bb_;
  using Base::up_bb_;
  using Base::down_bb_;
  using Base::gate_up_ba_;
  using Base::gate_bc_;
  using Base::up_bc_;
  using Base::down_ba_;
  using Base::down_bc_;
  using Base::m_local_num_;
  // On-disk cache directory: <config_.path>/_layer_<layer_idx>/_numa_<tp_part_idx>
  std::filesystem::path prefix;
  void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if
                     // quantized)]
  void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if
                     // quantized)]
  void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if
                     // quantized)]
#ifdef CHECK
  // Debug-only scratch (100 MB each, compiled in only when CHECK is defined)
  // used to byte-compare one expert's down matrix across the two load paths.
  char verify_bb[100000000];
  char check_bb[100000000];
  uint8_t compare_expers = 3;
#endif
  // Persist one packed expert matrix as two files under `prefix`:
  //   <T>_<class>_<expert>_<q>Byte_quant_.kt  - quantized payload (size - scale_size bytes)
  //   <T>_<class>_<expert>_<s>Byte_scale_.kt  - trailing scales   (scale_size bytes)
  // `bb` points at a packed BufferB whose scales sit at the end of the buffer.
  inline void write_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,
                            size_t scale_size) {
    // printf("expert %d, size %ld, scale size %ld\n", expert_idx, size, scale_size);
    std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" +
                               std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt"));
    if (of.is_open() == false) {
      // An ofstream open fails when the file cannot be created (missing
      // directory, permissions, ...), not because it "does not exist".
      printf("failed to open for writing: %s\n",
             (prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" +
                        std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt"))
                 .c_str());
      // throw std::runtime_error("No such file");
    }
    of.write((char*)bb, size - scale_size);
    of.close();
    of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(scale_size) + "Byte" +
                      "_scale_" + ".kt"));
    if (of.is_open() == false) {
      printf("failed to open for writing: %s\n",
             (prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(scale_size) +
                        "Byte" + "_scale_" + ".kt"))
                 .c_str());
      // throw std::runtime_error("No such file");
    }
    of.write(((char*)bb) + size - scale_size, scale_size);
  }
  // Load one packed expert matrix from the two files written by write_weights.
  // Only the 1/mat_split slice selected by mat_split_idx is read (payload and
  // scales each at their matching offset), so several tasks can fill one
  // buffer cooperatively.
  inline void read_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,
                           size_t scale_size, uint8_t mat_split, uint8_t mat_split_idx) {
    std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" +
                              std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt"));
    if (f.is_open() == false) {
      printf("no such file: %s\n", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" +
                                              std::to_string(size - scale_size) + "Byte" + "_quant_" + ".kt"))
                                       .c_str());
      // throw std::runtime_error("No such file");
    }
    f.seekg(mat_split_idx * (size - scale_size) / mat_split);
    f.read(((char*)bb) + mat_split_idx * (size - scale_size) / mat_split, (size - scale_size) / mat_split);
    f.close();
    f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" + std::to_string(scale_size) + "Byte" +
                     "_scale_" + ".kt"));
    if (f.is_open() == false) {
      printf("no such file: %s\n", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + "_" +
                                              std::to_string(scale_size) + "Byte" + "_scale_" + ".kt"))
                                       .c_str());
      // throw std::runtime_error("No such file");
    }
    f.seekg(mat_split_idx * scale_size / mat_split);
    f.read((((char*)bb) + size - scale_size) + mat_split_idx * scale_size / mat_split, scale_size / mat_split);
  }
#ifdef CHECK
  // Snapshot expert `compare_expers`'s down matrix (taken after the file-load
  // path) so verify_load_right() can compare it against the online-quant path.
  inline void load_check() {
    memcpy(check_bb, (char*)down_bb_[compare_expers]->b,
           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
  }
  // Byte-compare the current down matrix of expert `compare_expers` against
  // the snapshot taken by load_check(); asserts on the first mismatching byte,
  // otherwise dumps the first 50 scale bytes for eyeballing.
  void verify_load_right() {
    // printf("verify down bb_0 %d\n", tp_part_idx);
    memcpy(verify_bb, (char*)down_bb_[compare_expers]->b,
           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
    // check if verify_bb_0 equal to check_bb_0
    if (memcmp(verify_bb, check_bb, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)) != 0) {
      printf("verify error\n");
      for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); ++i) {
        if (verify_bb[i] != check_bb[i]) {
          printf("Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\n", i, compare_expers, i,
                 (unsigned char)verify_bb[i], i, (unsigned char)check_bb[i]);
          break;  // find the first difference and exit
        }
      }
      assert(0);
    } else {
      printf("pass verify\n");
      // pick out the 100th~150th byte of scale to see
      printf("numa %d, verify_bb_%d:\n", tp_part_idx, compare_expers);
      size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);
      size_t scale_size = config_.hidden_size * sizeof(float);
      for (size_t i = size - scale_size; i < size - scale_size + 50; ++i) {
        printf("%02x ", (unsigned char)verify_bb[i]);
      }
      printf("\n");
    }
  }
#endif
#ifdef FORWARD_TIME_REPORT
  std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
#endif
 public:
  AMX_MOE_TP() = default;
  // Builds one TP shard. `config` carries the (already per-shard) dimensions
  // and weight sources; `tp_part_idx` selects the NUMA/TP slot. Creates the
  // cache directory when config.save is set, and requires it to already exist
  // (throws std::runtime_error otherwise) when config.load is set.
  AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) {
    printf("Creating AMX_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
    auto& load = config_.load;
    auto& save = config_.save;
    prefix = config_.path;
    prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
    if (save) {
      std::cout << "Creating " << prefix << std::endl;
      std::filesystem::create_directories(prefix);
    }
    if (load) {
      if (std::filesystem::exists(prefix)) {
        std::cout << "Loading from " << prefix << std::endl;
      } else {
        throw std::runtime_error("Path not found: " + prefix.string());
      }
    }
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;
  }
  ~AMX_MOE_TP() = default;
  // ============================================================================
  // CRTP buffer creation - no group_size
  // ============================================================================
  size_t buffer_a_required_size_impl(size_t m, size_t k) const {
    return T::BufferA::required_size(m, k);
  }
  size_t buffer_b_required_size_impl(size_t n, size_t k) const {
    return T::BufferB::required_size(n, k);
  }
  size_t buffer_c_required_size_impl(size_t m, size_t n) const {
    return T::BufferC::required_size(m, n);
  }
  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
    return std::make_shared<typename T::BufferA>(m, k, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
    return std::make_shared<typename T::BufferB>(n, k, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
    return std::make_shared<typename T::BufferC>(m, n, data);
  }
  // ============================================================================
  // CRTP virtual points - GEMM dispatch
  // ============================================================================
  // Gate or up projection for one expert (selected by `do_up`), computed by
  // worker `ith` of `nth`. `qlen` picks the kernel: batches with more than ~4
  // rows per expert on average take the AMX tile path, smaller ones the
  // vector path.
  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    auto& ba = gate_up_ba_[expert_idx];
    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
    } else {
      amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
    }
  }
  // Down projection for one expert; same mat_mul/vec_mul heuristic as above.
  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    auto& ba = down_ba_[expert_idx];
    auto& bb = down_bb_[expert_idx];
    auto& bc = down_bc_[expert_idx];
    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
    } else {
      amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
    }
  }
  // Populate the packed BufferB weights for every expert from one of:
  //  1) config_.gate_projs/... : pre-quantized per-shard buffers (memcpy),
  //  2) the on-disk cache under `prefix` when config_.load is set,
  //  3) online quantization of the BF16 blobs in config_.gate_proj/... .
  // Optionally writes the cache back to disk when config_.save is set.
  void load_weights() {
    auto pool = config_.pool->get_subpool(tp_part_idx);
    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
    if (config_.gate_projs.size()) {
      // Source 1: quantized payload + scales already prepared by the loader.
      pool->do_work_stealing_job(
          config_.expert_num, nullptr,
          [this, physical_to_logical_map](int expert_id) {
            // printf("Load layer %d [%d/%d]\n", config_.layer_idx, expert_id, config_.expert_num);
            uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id);
            {
              // gate/up: packed [intermediate, hidden], one float scale per row.
              size_t scale_size = config_.intermediate_size * sizeof(float);
              size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size;
              memcpy(gate_bb_[expert_id]->b, config_.gate_projs[tp_part_idx][logical_expert_id], size);
              if constexpr (T::BufferB::SCALE) {
                memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size);
              }
              memcpy(up_bb_[expert_id]->b, config_.up_projs[tp_part_idx][logical_expert_id], size);
              if constexpr (T::BufferB::SCALE) {
                memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size);
              }
            }
            {
              // down: packed [hidden, intermediate].
              size_t scale_size = config_.hidden_size * sizeof(float);
              size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size) - scale_size;
              memcpy(down_bb_[expert_id]->b, config_.down_projs[tp_part_idx][logical_expert_id], size);
              if constexpr (T::BufferB::SCALE) {
                memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size);
              }
            }
          },
          nullptr);
    } else {
      int nth = T::recommended_nth(config_.intermediate_size);
      static uint8_t mat_type_all = 3, mat_split = 1;
      if (config_.load) {
        // Source 2: sequential read of the cache files
        // (tasks iterate experts x {up, gate, down} x split slices).
        std::cout << "Loading from " << prefix << std::endl;
        for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {
          int64_t expert_idx = task_id / (mat_type_all * mat_split);
          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
          uint8_t mat_class = (task_id % (mat_type_all * mat_split)) / mat_split;
          uint8_t mat_split_idx = task_id % mat_split;
          if (mat_class == 0) {  // the up matrix
            size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
            size_t scale_size = config_.intermediate_size * sizeof(float);
            read_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, logical_expert_id, size, scale_size, mat_split,
                         mat_split_idx);
          } else if (mat_class == 1) {  // the gate matrix
            size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
            size_t scale_size = config_.intermediate_size * sizeof(float);
            read_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, logical_expert_id, size, scale_size,
                         mat_split, mat_split_idx);
          } else {  // the down matrix
            size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);
            size_t scale_size = config_.hidden_size * sizeof(float);
            read_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, logical_expert_id, size, scale_size,
                         mat_split, mat_split_idx);
          }
        }
      }
      // check process, store down matrix to check
      // (with CHECK defined the online-quant block below always runs, so the
      // file-load result snapshotted here can be compared against it; without
      // CHECK the block is the `else` of `if (config_.load)`)
#ifdef CHECK
      load_check();
#endif
#ifndef CHECK
      else
#endif
      {
        // Source 3: quantize/pack the BF16 weights online.
        if (tp_part_idx == 0) {
          std::cout << " online quant from bf16" << std::endl;
        }
        pool->do_work_stealing_job(
            nth * config_.expert_num, nullptr,
            [this, nth, physical_to_logical_map](int task_id) {
              int64_t expert_idx = task_id / nth;
              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
              int ith = task_id % nth;
              // NOTE(review): the bb_ arrays are indexed by logical_expert_id
              // here, while the loader/file paths above index them by the
              // physical expert id. The two only agree when
              // physical_to_logical_map is the identity - confirm this is
              // intended for remapped experts.
              // gate part
              gate_bb_[logical_expert_id]->from_mat(
                  (ggml_bf16_t*)config_.gate_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
                  ith, nth);
              // up part
              up_bb_[logical_expert_id]->from_mat(
                  (ggml_bf16_t*)config_.up_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
                  ith, nth);
            },
            nullptr);
        nth = T::recommended_nth(config_.hidden_size);
        pool->do_work_stealing_job(
            nth * config_.expert_num, nullptr,
            [this, nth, physical_to_logical_map](int task_id) {
              int64_t expert_idx = task_id / nth;
              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
              int ith = task_id % nth;
              // down part
              down_bb_[logical_expert_id]->from_mat(
                  (ggml_bf16_t*)config_.down_proj + logical_expert_id * config_.hidden_size * config_.intermediate_size,
                  ith, nth);
              // printf("load idown, expert %ld, ith %d, total nth %d\n", expert_idx, ith, nth);
            },
            nullptr);
      }
#ifdef CHECK
      verify_load_right();
#endif
      // save process
      if (config_.save) {
        pool->do_work_stealing_job(
            config_.expert_num * mat_type_all, nullptr,
            [this, physical_to_logical_map](int task_id) {
              int64_t expert_idx = task_id / mat_type_all;
              // NOTE(review): expert_idx is remapped to the logical id and
              // then used for both the buffer index and the file name - see
              // the logical/physical note in the online-quant path.
              expert_idx = expert_map(physical_to_logical_map, expert_idx);
              uint8_t mat_class = task_id % mat_type_all;
              if (mat_class == 0) {  // the up matrix
                size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
                size_t scale_size = config_.intermediate_size * sizeof(float);
                write_weights(prefix, "_up_", (char*)up_bb_[expert_idx]->b, expert_idx, size, scale_size);
              } else if (mat_class == 1) {
                size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
                size_t scale_size = config_.intermediate_size * sizeof(float);
                write_weights(prefix, "_gate_", (char*)gate_bb_[expert_idx]->b, expert_idx, size, scale_size);
              } else if (mat_class == 2) {
                size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);
                size_t scale_size = config_.hidden_size * sizeof(float);
                write_weights(prefix, "_down_", (char*)down_bb_[expert_idx]->b, expert_idx, size, scale_size);
              }
            },
            nullptr);
      }
    }
  }
  // forward, forward_prefill, forward_decode, warm_up are inherited from Base
};
// ============================================================================
// TP_MOE specialization for AMX_MOE_TP
// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation
// ============================================================================
// Tensor-parallel wrapper over AMX_MOE_TP shards. Only load_weights() is
// specialized here; merge_results and the forward paths come from the base.
template <typename K>
class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>> {
public:
using Base = TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>;
using Base::Base;
// Loads weights into every shard from one of three sources, in priority
// order: pre-quantized per-shard buffers (config.gate_projs), a monolithic
// BF16 blob (config.gate_proj) sliced into per-shard staging buffers here,
// or the on-disk cache (config.path). Throws when no source is available.
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_projs.empty() == false) {
// Source 1: shards copy their pre-quantized buffers directly.
printf("TP Load from loader\n");
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else if (config.gate_proj != nullptr) {
// Source 2: slice the monolithic BF16 weights into per-shard staging
// buffers, let each shard quantize them, then free the staging memory.
// NOTE(review): these raw allocations leak if DO_TPS_LOAD_WEIGHTS throws;
// RAII ownership would be safer.
printf("From BF16\n");
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
// Per-shard gate/up element count; the down staging buffer reuses it
// since hidden_size * intermediate_size is the same product.
size_t gate_up_elcount = tpc.intermediate_size * tpc.hidden_size;
// Staging buffers are allocated even when the shard will load from disk
// (config_.load == true); in that case they stay unfilled and are freed
// below without being used.
tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];
tpc.up_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];
tpc.down_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];
if (tps[i]->config_.load == false) {
// NOTE(review): [&] captures loop locals (i, tpc, gate_up_elcount) by
// reference - assumes do_work_stealing_job finishes before this loop
// iteration ends; confirm it is synchronous.
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
// gate/up: shard i's slice is one contiguous chunk at offset
// i * gate_up_elcount within the expert's full matrix (assumes the
// TP split is along the intermediate dimension with
// tpc.intermediate_size == config.intermediate_size / tp_count
// - TODO confirm against shard config construction).
memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,
(ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +
i * gate_up_elcount,
sizeof(ggml_bf16_t) * gate_up_elcount);
memcpy((ggml_bf16_t*)tpc.up_proj + expert_id * gate_up_elcount,
(ggml_bf16_t*)config.up_proj + expert_id * config.intermediate_size * config.hidden_size +
i * gate_up_elcount,
sizeof(ggml_bf16_t) * gate_up_elcount);
// down: for each of the hidden_size rows, copy shard i's
// tpc.intermediate_size-wide column slice (strided gather).
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((ggml_bf16_t*)tpc.down_proj + expert_id * tpc.hidden_size * tpc.intermediate_size +
col * tpc.intermediate_size,
(ggml_bf16_t*)config.down_proj + expert_id * config.intermediate_size * config.hidden_size +
col * config.intermediate_size + i * tpc.intermediate_size,
sizeof(ggml_bf16_t) * tpc.intermediate_size);
}
},
nullptr);
}
}
// Each shard quantizes/packs its staged BF16 slice (or loads from disk).
DO_TPS_LOAD_WEIGHTS(pool);
// Staging buffers are no longer needed once shards hold packed weights.
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
delete[] (ggml_bf16_t*)(tpc.gate_proj);
delete[] (ggml_bf16_t*)(tpc.up_proj);
delete[] (ggml_bf16_t*)(tpc.down_proj);
}
this->weights_loaded = true;
} else if (config.path != "") {
// Source 3: shards read their packed weights from the on-disk cache.
printf("TP Load from file\n");
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else {
throw std::runtime_error("no weight source");
}
}
// merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>
};
#endif