mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
764 lines
31 KiB
C++
764 lines
31 KiB
C++
/**
|
|
* @Description : Common AMX MoE base class extracted from K2 implementation.
|
|
* @Author : oql, Codex and Claude
|
|
* @Date : 2025-12-09
|
|
* @Version : 0.1.0
|
|
* @LastEditors : oql, Codex and Claude
|
|
* @LastEditTime : 2025-12-09
|
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
|
**/
|
|
#ifndef CPUINFER_OPERATOR_AMX_MOE_BASE_H
|
|
#define CPUINFER_OPERATOR_AMX_MOE_BASE_H
|
|
|
|
// #define FORWARD_TIME_PROFILE
|
|
|
|
#include <immintrin.h>
|
|
|
|
#include <algorithm>
|
|
#include <chrono>
|
|
#include <cmath>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cstdio>
|
|
#include <cstring>
|
|
#include <filesystem>
|
|
#include <fstream>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "../../cpu_backend/shared_mem_buffer.h"
|
|
#include "../../cpu_backend/worker_pool.h"
|
|
#include "../common.hpp"
|
|
#include "../moe-tp.hpp"
|
|
#include "la/amx.hpp"
|
|
#include "llama.cpp/ggml.h"
|
|
|
|
template <class T, class Derived>
|
|
class AMX_MOE_BASE {
|
|
public:
|
|
int tp_part_idx = 0;
|
|
|
|
ggml_bf16_t* m_local_input_ = nullptr;
|
|
ggml_bf16_t* m_local_gate_output_ = nullptr;
|
|
ggml_bf16_t* m_local_up_output_ = nullptr;
|
|
ggml_bf16_t* m_local_down_output_ = nullptr;
|
|
|
|
std::vector<std::vector<int>> m_local_pos_;
|
|
std::vector<int> m_local_num_;
|
|
std::vector<int> m_expert_id_map_;
|
|
std::vector<ggml_bf16_t*> m_local_input_ptr_;
|
|
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_;
|
|
std::vector<ggml_bf16_t*> m_local_up_output_ptr_;
|
|
std::vector<ggml_bf16_t*> m_local_down_output_ptr_;
|
|
|
|
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
|
|
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
|
|
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
|
|
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
|
|
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
|
|
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
|
|
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
|
|
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
|
|
|
|
size_t pool_count_ = 0;
|
|
size_t gate_up_ba_pool_bytes_ = 0;
|
|
size_t gate_bc_pool_bytes_ = 0;
|
|
size_t up_bc_pool_bytes_ = 0;
|
|
size_t down_ba_pool_bytes_ = 0;
|
|
size_t down_bc_pool_bytes_ = 0;
|
|
void* gate_up_ba_pool_ = nullptr;
|
|
void* gate_bc_pool_ = nullptr;
|
|
void* up_bc_pool_ = nullptr;
|
|
void* down_ba_pool_ = nullptr;
|
|
void* down_bc_pool_ = nullptr;
|
|
|
|
GeneralMOEConfig config_;
|
|
using input_t = ggml_bf16_t;
|
|
using output_t = float;
|
|
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
|
|
|
|
AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) { init(); }
|
|
|
|
void init() {
|
|
if (config_.load && config_.path == "") {
|
|
config_.load = false;
|
|
}
|
|
|
|
MemoryRequest mem_requests;
|
|
mem_requests.append_pointer(
|
|
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
|
|
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
|
config_.max_len * config_.intermediate_size);
|
|
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
|
config_.max_len * config_.intermediate_size);
|
|
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
|
config_.max_len * config_.hidden_size);
|
|
|
|
m_local_pos_.resize(config_.max_len);
|
|
for (int i = 0; i < config_.max_len; i++) {
|
|
m_local_pos_[i].resize(config_.num_experts_per_tok);
|
|
}
|
|
m_expert_id_map_.resize(config_.expert_num);
|
|
m_local_num_.resize(config_.expert_num);
|
|
m_local_input_ptr_.resize(config_.expert_num);
|
|
m_local_gate_output_ptr_.resize(config_.expert_num);
|
|
m_local_up_output_ptr_.resize(config_.expert_num);
|
|
m_local_down_output_ptr_.resize(config_.expert_num);
|
|
|
|
for (size_t i = 0; i < config_.expert_num; i++) {
|
|
gate_up_ba_.push_back(make_buffer_a(config_.max_len, config_.hidden_size, nullptr));
|
|
gate_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
|
|
up_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
|
|
down_ba_.push_back(make_buffer_a(config_.max_len, config_.intermediate_size, nullptr));
|
|
down_bc_.push_back(make_buffer_c(config_.max_len, config_.hidden_size, nullptr));
|
|
|
|
void* gate_bb_ptr =
|
|
std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
|
|
gate_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
|
|
|
|
void* up_bb_ptr = std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
|
|
up_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
|
|
|
|
void* down_bb_ptr =
|
|
std::aligned_alloc(64, buffer_b_required_size(config_.hidden_size, config_.intermediate_size));
|
|
down_bb_.push_back(make_buffer_b(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
|
|
}
|
|
// TODO: need update to all *.hpp
|
|
// (config_.expert_num * T::M_STEP) in pool_count_ is to ensure padding for each experts.
|
|
pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::M_STEP;
|
|
|
|
gate_up_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
|
|
gate_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
|
|
up_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
|
|
down_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
|
|
down_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
|
|
|
|
mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_);
|
|
mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_);
|
|
mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_);
|
|
mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_);
|
|
mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_);
|
|
|
|
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
|
|
}
|
|
|
|
~AMX_MOE_BASE() = default;
|
|
|
|
void warm_up() {
|
|
int qlen = config_.max_len;
|
|
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
|
|
std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
|
|
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
|
|
std::vector<float> weights(qlen * config_.num_experts_per_tok);
|
|
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
|
|
expert_ids[i] = i % config_.expert_num;
|
|
weights[i] = 0.01;
|
|
}
|
|
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
|
|
}
|
|
|
|
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
|
|
if (qlen > 1) {
|
|
forward_prefill(qlen, k, expert_ids, weights, input, output);
|
|
} else {
|
|
forward_decode(k, expert_ids, weights, input, output);
|
|
}
|
|
}
|
|
|
|
template <typename... Args>
|
|
void load_weights(Args&&... args) {
|
|
derived()->load_weights(std::forward<Args>(args)...);
|
|
}
|
|
|
|
template <typename... Args>
|
|
void write_weights_to_buffer(Args&&... args) const {
|
|
derived_const()->write_weights_to_buffer(std::forward<Args>(args)...);
|
|
}
|
|
|
|
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
|
|
void* output) {
|
|
auto pool = config_.pool->get_subpool(tp_part_idx);
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
auto start_time = std::chrono::high_resolution_clock::now();
|
|
auto last = start_time;
|
|
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
|
|
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
|
|
int max_local_num = 0;
|
|
#endif
|
|
|
|
int activated_expert = 0;
|
|
std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
|
|
for (int i = 0; i < qlen; i++) {
|
|
for (int j = 0; j < k; j++) {
|
|
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
|
continue;
|
|
}
|
|
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
|
|
}
|
|
}
|
|
|
|
for (int i = 0; i < config_.expert_num; i++) {
|
|
if (m_local_num_[i] > 0) {
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
max_local_num = std::max(max_local_num, m_local_num_[i]);
|
|
#endif
|
|
m_expert_id_map_[activated_expert] = i;
|
|
activated_expert++;
|
|
}
|
|
}
|
|
|
|
size_t offset = 0;
|
|
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
|
|
void* gate_bc_pool_ptr = gate_bc_pool_;
|
|
void* up_bc_pool_ptr = up_bc_pool_;
|
|
void* down_ba_pool_ptr = down_ba_pool_;
|
|
void* down_bc_pool_ptr = down_bc_pool_;
|
|
constexpr size_t M_STEP = T::M_STEP;
|
|
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
|
|
size_t used_pool_m = 0;
|
|
size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
|
|
used_pool_bytes_bc_down = 0;
|
|
|
|
for (int i = 0; i < config_.expert_num; i++) {
|
|
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
|
|
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
|
|
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
|
|
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
|
|
offset += m_local_num_[i];
|
|
|
|
if (m_local_num_[i] == 0) {
|
|
continue;
|
|
}
|
|
|
|
size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP;
|
|
gate_up_ba_[i]->max_m = max_m;
|
|
gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr);
|
|
size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));
|
|
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
|
|
|
|
gate_bc_[i]->max_m = max_m;
|
|
gate_bc_[i]->set_data(gate_bc_pool_ptr);
|
|
size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
|
|
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);
|
|
|
|
up_bc_[i]->max_m = max_m;
|
|
up_bc_[i]->set_data(up_bc_pool_ptr);
|
|
size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
|
|
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);
|
|
|
|
down_ba_[i]->max_m = max_m;
|
|
down_ba_[i]->set_data(down_ba_pool_ptr);
|
|
size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));
|
|
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);
|
|
|
|
down_bc_[i]->max_m = max_m;
|
|
down_bc_[i]->set_data(down_bc_pool_ptr);
|
|
size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));
|
|
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);
|
|
|
|
used_pool_m += max_m;
|
|
used_pool_bytes_a += ba_size;
|
|
used_pool_bytes_bc_gate += bc_gate_size;
|
|
used_pool_bytes_bc_up += bc_up_size;
|
|
used_pool_bytes_ba_down += ba_down_size;
|
|
used_pool_bytes_bc_down += bc_down_size;
|
|
}
|
|
|
|
assert(used_pool_m <= pool_count_);
|
|
assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_);
|
|
assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
|
|
assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
|
|
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
|
|
assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
auto direct_or_pool = [&](int count, auto&& fn) {
|
|
if (qlen < 10) {
|
|
for (int i = 0; i < count; i++) {
|
|
fn(i);
|
|
}
|
|
} else {
|
|
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
|
|
}
|
|
};
|
|
|
|
direct_or_pool(qlen, [&](int i) {
|
|
for (int j = 0; j < k; j++) {
|
|
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
|
continue;
|
|
}
|
|
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
|
|
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
|
|
}
|
|
});
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
direct_or_pool(activated_expert, [this](int task_id) {
|
|
int expert_idx = m_expert_id_map_[task_id];
|
|
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
|
|
});
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
int nth = T::recommended_nth(config_.intermediate_size);
|
|
pool->do_work_stealing_job(
|
|
nth * activated_expert * 2, [](int _) { T::config(); },
|
|
[this, nth, qlen](int task_id2) {
|
|
int task_id = task_id2 / 2;
|
|
bool do_up = task_id2 % 2;
|
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
|
|
|
int ith = task_id % nth;
|
|
derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
|
|
if (do_up) {
|
|
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
|
|
} else {
|
|
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
|
|
}
|
|
},
|
|
nullptr);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
apply_activation(activated_expert, nth, qlen);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
pool->do_work_stealing_job(
|
|
activated_expert, nullptr,
|
|
[this](int task_id) {
|
|
int expert_idx = m_expert_id_map_[task_id];
|
|
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
|
|
},
|
|
nullptr);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
nth = T::recommended_nth(config_.hidden_size);
|
|
pool->do_work_stealing_job(
|
|
nth * activated_expert, [](int _) { T::config(); },
|
|
[this, nth, qlen](int task_id) {
|
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
|
int ith = task_id % nth;
|
|
derived()->do_down_gemm(expert_idx, ith, nth, qlen);
|
|
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
|
|
},
|
|
nullptr);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
pool->do_work_stealing_job(
|
|
qlen, nullptr,
|
|
[this, output, k, expert_ids, weights](int i) {
|
|
for (int e = 0; e < config_.hidden_size; e += 32) {
|
|
__m512 x0 = _mm512_setzero_ps();
|
|
__m512 x1 = _mm512_setzero_ps();
|
|
for (int j = 0; j < k; j++) {
|
|
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
|
continue;
|
|
}
|
|
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
|
|
__m512 down_output0, down_output1;
|
|
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
|
|
m_local_pos_[i][j] * config_.hidden_size + e),
|
|
&down_output0, &down_output1);
|
|
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
|
|
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
|
|
}
|
|
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
|
|
f32out[0] = x0;
|
|
f32out[1] = x1;
|
|
}
|
|
},
|
|
nullptr);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
auto end_time = std::chrono::high_resolution_clock::now();
|
|
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
|
|
printf(
|
|
"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
|
|
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
|
|
"%d, qlen: %d\n",
|
|
tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
|
|
down_time, weight_time, forward_total_time, max_local_num, qlen);
|
|
#endif
|
|
}
|
|
|
|
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
|
|
int qlen = 1;
|
|
auto pool = config_.pool->get_subpool(tp_part_idx);
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
auto start_time = std::chrono::high_resolution_clock::now();
|
|
auto last = start_time;
|
|
long q_input_time = 0, up_gate_time = 0, act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
|
|
#endif
|
|
|
|
int activated_expert = 0;
|
|
std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
|
|
for (int i = 0; i < k; i++) {
|
|
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
|
|
continue;
|
|
}
|
|
m_expert_id_map_[activated_expert] = expert_ids[i];
|
|
m_local_pos_[0][i] = 0;
|
|
m_local_num_[expert_ids[i]] = qlen;
|
|
activated_expert++;
|
|
}
|
|
|
|
size_t offset = 0;
|
|
for (int i = 0; i < activated_expert; i++) {
|
|
auto expert_idx = m_expert_id_map_[i];
|
|
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
|
|
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
|
|
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
|
|
offset += qlen;
|
|
}
|
|
|
|
void* gate_bc_pool_ptr = gate_bc_pool_;
|
|
void* up_bc_pool_ptr = up_bc_pool_;
|
|
void* down_ba_pool_ptr = down_ba_pool_;
|
|
void* down_bc_pool_ptr = down_bc_pool_;
|
|
constexpr size_t M_STEP = T::M_STEP;
|
|
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
|
|
size_t used_pool_m = 0;
|
|
size_t used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
|
|
used_pool_bytes_bc_down = 0;
|
|
for (int i = 0; i < activated_expert; i++) {
|
|
auto expert_idx = m_expert_id_map_[i];
|
|
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
|
|
|
|
gate_bc_[expert_idx]->max_m = max_m;
|
|
gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);
|
|
size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
|
|
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);
|
|
|
|
up_bc_[expert_idx]->max_m = max_m;
|
|
up_bc_[expert_idx]->set_data(up_bc_pool_ptr);
|
|
size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
|
|
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);
|
|
|
|
down_ba_[expert_idx]->max_m = max_m;
|
|
down_ba_[expert_idx]->set_data(down_ba_pool_ptr);
|
|
size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));
|
|
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);
|
|
|
|
down_bc_[expert_idx]->max_m = max_m;
|
|
down_bc_[expert_idx]->set_data(down_bc_pool_ptr);
|
|
size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));
|
|
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);
|
|
|
|
used_pool_m += max_m;
|
|
used_pool_bytes_bc_gate += bc_gate_size;
|
|
used_pool_bytes_bc_up += bc_up_size;
|
|
used_pool_bytes_ba_down += ba_down_size;
|
|
used_pool_bytes_bc_down += bc_down_size;
|
|
}
|
|
assert(used_pool_m <= pool_count_);
|
|
assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
|
|
assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
|
|
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
|
|
assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);
|
|
|
|
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
|
|
for (int i = 0; i < activated_expert; i++) {
|
|
auto expert_idx = m_expert_id_map_[i];
|
|
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
|
|
gate_up_ba_[expert_idx]->max_m = max_m;
|
|
gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);
|
|
size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));
|
|
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
|
|
gate_up_ba_[expert_idx]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
|
|
}
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
int nth = T::recommended_nth(config_.intermediate_size);
|
|
pool->do_work_stealing_job(
|
|
nth * activated_expert * 2, [](int _) { T::config(); },
|
|
[this, nth, qlen](int task_id2) {
|
|
int task_id = task_id2 / 2;
|
|
bool do_up = task_id2 % 2;
|
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
|
|
|
int ith = task_id % nth;
|
|
derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
|
|
if (do_up) {
|
|
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
|
|
} else {
|
|
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
|
|
}
|
|
},
|
|
nullptr);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
apply_activation(activated_expert, nth, qlen);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
pool->do_work_stealing_job(
|
|
activated_expert, nullptr,
|
|
[this, qlen](int task_id) {
|
|
int expert_idx = m_expert_id_map_[task_id];
|
|
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
|
|
},
|
|
nullptr);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
nth = T::recommended_nth(config_.hidden_size);
|
|
pool->do_work_stealing_job(
|
|
nth * activated_expert, [](int _) { T::config(); },
|
|
[this, nth, qlen](int task_id) {
|
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
|
int ith = task_id % nth;
|
|
derived()->do_down_gemm(expert_idx, ith, nth, qlen);
|
|
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
|
|
},
|
|
nullptr);
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
#endif
|
|
|
|
for (int e = 0; e < config_.hidden_size; e += 32) {
|
|
__m512 x0 = _mm512_setzero_ps();
|
|
__m512 x1 = _mm512_setzero_ps();
|
|
for (int j = 0; j < k; j++) {
|
|
if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) {
|
|
continue;
|
|
}
|
|
__m512 weight = _mm512_set1_ps(weights[j]);
|
|
__m512 down_output0, down_output1;
|
|
avx512_32xbf16_to_32xfp32(
|
|
(__m512i*)(m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e),
|
|
&down_output0, &down_output1);
|
|
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
|
|
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
|
|
}
|
|
auto f32out = (__m512*)((float*)output + e);
|
|
f32out[0] = x0;
|
|
f32out[1] = x1;
|
|
}
|
|
|
|
#ifdef FORWARD_TIME_PROFILE
|
|
{
|
|
auto now_time = std::chrono::high_resolution_clock::now();
|
|
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
|
last = now_time;
|
|
}
|
|
auto end_time = std::chrono::high_resolution_clock::now();
|
|
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
|
|
printf(
|
|
"Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
|
|
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
|
|
tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
|
|
forward_total_time);
|
|
#endif
|
|
}
|
|
|
|
protected:
|
|
Derived* derived() { return static_cast<Derived*>(this); }
|
|
const Derived* derived_const() const { return static_cast<const Derived*>(this); }
|
|
|
|
// ============================================================================
|
|
// Virtual points for buffer creation and size calculation
|
|
// Default implementations use group_size (for KGroup quantization like K2)
|
|
// Derived classes (like moe.hpp) can override to not use group_size
|
|
// ============================================================================
|
|
|
|
size_t buffer_a_required_size(size_t m, size_t k) const { return derived_const()->buffer_a_required_size_impl(m, k); }
|
|
size_t buffer_b_required_size(size_t n, size_t k) const { return derived_const()->buffer_b_required_size_impl(n, k); }
|
|
size_t buffer_c_required_size(size_t m, size_t n) const { return derived_const()->buffer_c_required_size_impl(m, n); }
|
|
|
|
std::shared_ptr<typename T::BufferA> make_buffer_a(size_t m, size_t k, void* data) const {
|
|
return derived_const()->make_buffer_a_impl(m, k, data);
|
|
}
|
|
std::shared_ptr<typename T::BufferB> make_buffer_b(size_t n, size_t k, void* data) const {
|
|
return derived_const()->make_buffer_b_impl(n, k, data);
|
|
}
|
|
std::shared_ptr<typename T::BufferC> make_buffer_c(size_t m, size_t n, void* data) const {
|
|
return derived_const()->make_buffer_c_impl(m, n, data);
|
|
}
|
|
|
|
void apply_activation(int activated_expert, int nth, int qlen) {
|
|
auto pool = config_.pool->get_subpool(tp_part_idx);
|
|
auto fn = [this, nth](int task_id) {
|
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
|
int ith = task_id % nth;
|
|
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
|
|
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
|
|
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
|
|
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
|
|
for (int j = n_start; j < n_end; j += 32) {
|
|
__m512 gate_val0, gate_val1, up_val0, up_val1;
|
|
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
|
|
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
|
|
__m512 result0 = amx::act_fn(gate_val0, up_val0);
|
|
__m512 result1 = amx::act_fn(gate_val1, up_val1);
|
|
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
|
|
}
|
|
}
|
|
};
|
|
|
|
if (activated_expert == 0) {
|
|
return;
|
|
}
|
|
|
|
if (qlen < 10) {
|
|
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
|
|
fn(task_id);
|
|
}
|
|
} else {
|
|
pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr);
|
|
}
|
|
}
|
|
};
|
|
|
|
// ============================================================================
|
|
// TP_MOE specialization for AMX_MOE_BASE derived classes
|
|
// ============================================================================
|
|
|
|
template <class T, class Derived>
|
|
class TP_MOE<AMX_MOE_BASE<T, Derived>> : public TP_MOE_Common<AMX_MOE_BASE<T, Derived>> {
|
|
public:
|
|
using TP_MOE_Common<AMX_MOE_BASE<T, Derived>>::TP_MOE_Common;
|
|
|
|
// Default load_weights implementation - can be overridden by derived TP_MOE classes
|
|
void load_weights() override { throw std::runtime_error("Not Implemented"); }
|
|
|
|
void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
|
|
const std::vector<uintptr_t>& w13_weight_ptrs,
|
|
const std::vector<uintptr_t>& w13_scale_ptrs,
|
|
const std::vector<uintptr_t>& w2_weight_ptrs,
|
|
const std::vector<uintptr_t>& w2_scale_ptrs) {
|
|
throw std::runtime_error("Not Implemented");
|
|
}
|
|
|
|
void merge_results(int qlen, void* output, bool incremental) override {
|
|
auto& config = this->config;
|
|
auto& tp_count = this->tp_count;
|
|
auto& local_output_numa = this->local_output_numa;
|
|
auto& tp_configs = this->tp_configs;
|
|
|
|
auto merge_fn = [this, output, incremental, &config, &tp_count, &local_output_numa, &tp_configs](int token_nth) {
|
|
float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
|
|
if (incremental) {
|
|
for (int e = 0; e < config.hidden_size; e += 32) {
|
|
__m512 x0, x1;
|
|
avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
|
|
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
|
|
*((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
|
|
}
|
|
}
|
|
for (int i = 1; i < tp_count; i++) {
|
|
float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
|
|
for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
|
|
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
|
|
}
|
|
}
|
|
for (int e = 0; e < config.hidden_size; e += 32) {
|
|
__m512 x0 = *(__m512*)(merge_to + e);
|
|
__m512 x1 = *(__m512*)(merge_to + e + 16);
|
|
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
|
|
}
|
|
};
|
|
|
|
auto pool = config.pool;
|
|
|
|
auto direct_or_pool = [&](int count, auto&& fn) {
|
|
if (qlen < 10) {
|
|
for (int i = 0; i < count; i++) {
|
|
fn(i);
|
|
}
|
|
} else {
|
|
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
|
|
}
|
|
};
|
|
|
|
direct_or_pool(qlen, merge_fn);
|
|
}
|
|
|
|
void merge_results(int qlen, void* output) override { merge_results(qlen, output, false); }
|
|
};
|
|
|
|
#endif // CPUINFER_OPERATOR_AMX_MOE_BASE_H
|