Files
ktransformers/kt-kernel/operators/amx/k2-moe.hpp
Oql dc6394e501 [fix]: fix moe hpp bug. (#1780)
fix moe hpp init bug.
2026-01-04 19:32:56 +08:00

673 lines
32 KiB
C++

/**
* @Description : K2 AMX MoE operator for Kimi-K2 native inference
* @Author : oql, Codex and Claude
* @Date : 2025-12-09
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* This file implements K2 Int4 MoE using CRTP pattern, inheriting from moe_base.hpp.
* K2 weights are stored with group-wise scales (KGroup Int4).
**/
#ifndef CPUINFER_OPERATOR_AMX_K2_MOE_H
#define CPUINFER_OPERATOR_AMX_K2_MOE_H
// #define LOAD_TIME_PROFILE
#include "moe_base.hpp"
/**
* @brief K2 Int4 MoE operator using CRTP pattern
* @tparam T Kernel type, defaults to amx::GemmKernel224Int4SmallKGroup
*
* This class provides K2-specific GEMM implementations:
* - do_gate_up_gemm: Int4 weight with KGroup scale + AMX GEMM
* - do_down_gemm: Same Int4 KGroup GEMM
* - load_weights: Load Int4 weights with group-wise scales
*/
template <class T = amx::GemmKernel224Int4SmallKGroup>
class AMX_K2_MOE_TP : public AMX_MOE_BASE<T, AMX_K2_MOE_TP<T>> {
using Base = AMX_MOE_BASE<T, AMX_K2_MOE_TP<T>>;
using Base::config_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
public:
using typename Base::input_t;
using typename Base::output_t;
AMX_K2_MOE_TP() = default;
AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}
void derived_init() {
auto& quant_config = config_.quant_config;
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4");
}
printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
}
~AMX_K2_MOE_TP() = default;
// ============================================================================
// CRTP buffer creation - with group_size
// ============================================================================
size_t buffer_a_required_size_impl(size_t m, size_t k) const {
return T::BufferA::required_size(m, k, config_.quant_config.group_size);
}
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k, config_.quant_config.group_size);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, config_.quant_config.group_size, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// ============================================================================
// CRTP virtual points - GEMM dispatch
// ============================================================================
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
auto& group_size = config_.quant_config.group_size;
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
// Dispatch based on qlen threshold
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
} else {
amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
}
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
auto& group_size = config_.quant_config.group_size;
int m = m_local_num_[expert_idx];
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
} else {
amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
}
}
/**
* @brief Load Int4 weights from contiguous memory layout
*
* Loads weights from config_.gate_proj, up_proj, down_proj with scales
* from config_.gate_scale, up_scale, down_scale.
*/
void load_weights() {
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("Kimi AVX MOE only support KGroup Int4.");
}
if (config_.gate_scale == nullptr) {
throw std::runtime_error("Kimi AVX MOE only support load native weight.");
}
// load weight
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.gate_proj +
((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
ith, nth);
// up part
up_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.up_proj + ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
ith, nth);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.down_proj +
((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1),
ith, nth);
},
nullptr);
pool->do_work_stealing_job(
config_.expert_num, nullptr,
[this, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
size_t scale_elem_count = (config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;
// convert scales from BF16 to FP32
convert_or_copy(gate_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count), scale_elem_count);
convert_or_copy(up_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.up_scale + (logical_expert_id * scale_elem_count), scale_elem_count);
convert_or_copy(down_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count), scale_elem_count);
},
nullptr);
#ifdef DEBUG_K2_MOE
dump_buffer_b("native", 0, "down", down_bb_[0].get());
#endif
}
static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {
uint8_t* d = (uint8_t*)dst;
const uint8_t* s = (const uint8_t*)src;
// Main loop: 512-bit (64-byte) SIMD copies
size_t chunks = bytes / 64;
for (size_t i = 0; i < chunks; i++) {
__m512i data = _mm512_loadu_si512((__m512i*)s);
_mm512_storeu_si512((__m512i*)d, data);
d += 64;
s += 64;
}
bytes -= chunks * 64;
// Handle remaining bytes
if (bytes > 0) {
std::memcpy(d, s, bytes);
}
}
// Optimized SIMD float32 to bf16 conversion
static inline void fast_fp32_to_bf16(ggml_bf16_t* __restrict dst, const float* __restrict src, size_t count) {
size_t i = 0;
// Process 32 elements at a time (2x __m512, output 1x __m512i = 32 bf16)
for (; i + 32 <= count; i += 32) {
__m512 v0 = _mm512_loadu_ps(src + i);
__m512 v1 = _mm512_loadu_ps(src + i + 16);
// Convert to bf16 using truncation (shift right 16 bits)
__m512i i0 = _mm512_srli_epi32(_mm512_castps_si512(v0), 16);
__m512i i1 = _mm512_srli_epi32(_mm512_castps_si512(v1), 16);
// Pack 32-bit values to 16-bit
__m512i packed = _mm512_packus_epi32(i0, i1);
// Reorder due to packus lane behavior:
// packus outputs interleaved: [i0[0-3], i1[0-3], i0[4-7], i1[4-7], i0[8-11], i1[8-11], i0[12-15], i1[12-15]]
// We need sequential: [i0[0-15], i1[0-15]] = [i0[0-3], i0[4-7], i0[8-11], i0[12-15], i1[0-3], i1[4-7], i1[8-11],
// i1[12-15]] Permutation: [0, 2, 4, 6, 1, 3, 5, 7] (qword indices)
__m512i permuted = _mm512_permutexvar_epi64(_mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0), packed);
_mm512_storeu_si512((__m512i*)(dst + i), permuted);
}
// Handle remaining elements with scalar conversion
for (; i < count; i++) {
dst[i] = ggml_fp32_to_bf16(src[i]);
}
}
// Write a single expert's weights to the output buffers
// The caller provides pointers that already point to the target expert's location (no offset needed)
// expert_id: the index of the expert to write
// Optimized for maximum memory bandwidth using streaming stores
void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int expert_id, const GeneralMOEConfig& full_config,
const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) const {
const int group_size = config_.quant_config.group_size;
auto pool = config_.pool->get_subpool(tp_part_idx);
// Calculate sizes for CPU TP part (this instance)
size_t cpu_tp_weight_elem_count = (size_t)config_.intermediate_size * config_.hidden_size;
size_t cpu_tp_weight_bytes = cpu_tp_weight_elem_count / 2; // int4 packing
size_t cpu_tp_scale_elem_count = cpu_tp_weight_elem_count / group_size;
// Calculate sizes for GPU TP part
size_t gpu_tp_weight_elem_count = (size_t)full_config.intermediate_size * full_config.hidden_size / gpu_tp_count;
size_t gpu_tp_weight_bytes = gpu_tp_weight_elem_count / 2; // int4 packing
size_t gpu_tp_scale_elem_count = gpu_tp_weight_elem_count / group_size;
// Determine mapping: which GPU TP parts should this CPU TP part write to?
// Since weights are col-major and we slice directly by memory order:
// - If cpu_tp_count >= gpu_tp_count: multiple(or one) CPU TPs write to one GPU TP
// - If cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
if (cpu_tp_count >= gpu_tp_count) {
// Multiple CPU TPs map to one GPU TP
int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count);
int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count);
// Get pointers for this GPU TP part (already pointing to target expert's location)
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp];
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp];
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp];
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp];
// Calculate offset within the GPU TP buffer (for CPU TP slice within GPU TP)
size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes;
size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count;
// Optimized task layout for maximum bandwidth:
// - Larger chunks to reduce task overhead
// - Separate large contiguous copies (gate_w, up_w) from strided copies (down)
// - Scale conversions are relatively small, merge with weight tasks
// Use fewer, larger tasks for better efficiency
constexpr int NUM_WEIGHT_TASKS = 8; // Fewer tasks, larger chunks
constexpr int MIN_COLS_PER_TASK = 128;
int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK);
num_down_tasks = std::min(num_down_tasks, 32);
// Total tasks: gate_weight + up_weight + down_weight_scale + gate_scale + up_scale
int total_tasks = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2;
size_t weight_chunk_size = (cpu_tp_weight_bytes + NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS;
// Align chunk size to 64 bytes for optimal streaming stores
weight_chunk_size = (weight_chunk_size + 63) & ~63ULL;
pool->do_work_stealing_job(
total_tasks, nullptr,
[&, this, num_down_tasks, expert_id, weight_chunk_size](int task_id) {
if (task_id < NUM_WEIGHT_TASKS) {
// Gate weight copy - chunked
int chunk_idx = task_id;
size_t start = chunk_idx * weight_chunk_size;
size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes);
if (start < end) {
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b;
fast_memcpy(w13_weight_dst + offset_in_gpu_weight + start, gate_weight_src + start, end - start);
}
} else if (task_id < NUM_WEIGHT_TASKS * 2) {
// Up weight copy - chunked
int chunk_idx = task_id - NUM_WEIGHT_TASKS;
size_t start = chunk_idx * weight_chunk_size;
size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes);
if (start < end) {
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b;
fast_memcpy(w13_weight_dst + offset_in_gpu_weight + gpu_tp_weight_bytes + start, up_weight_src + start,
end - start);
}
} else if (task_id < NUM_WEIGHT_TASKS * 2 + num_down_tasks) {
// Down columns - split by column chunks
// Each task handles multiple consecutive columns for better cache locality
int chunk_idx = task_id - NUM_WEIGHT_TASKS * 2;
size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks;
size_t col_start = chunk_idx * cols_per_chunk;
size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size);
size_t weight_per_col = config_.intermediate_size >> 1;
size_t scale_per_col = config_.intermediate_size / group_size;
size_t gpu_weight_stride = (full_config.intermediate_size / gpu_tp_count) >> 1;
size_t gpu_scale_stride = (full_config.intermediate_size / gpu_tp_count) / group_size;
size_t gpu_weight_slice_offset = local_idx * weight_per_col;
size_t gpu_scale_slice_offset = local_idx * scale_per_col;
for (size_t col = col_start; col < col_end; col++) {
fast_memcpy(w2_weight_dst + col * gpu_weight_stride + gpu_weight_slice_offset,
(uint8_t*)down_bb_[expert_id]->b + col * weight_per_col, weight_per_col);
fast_fp32_to_bf16(w2_scale_dst + col * gpu_scale_stride + gpu_scale_slice_offset,
down_bb_[expert_id]->d + col * scale_per_col, scale_per_col);
}
} else if (task_id == NUM_WEIGHT_TASKS * 2 + num_down_tasks) {
// Gate scale convert
float* gate_scale_src = gate_bb_[expert_id]->d;
fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale, gate_scale_src, cpu_tp_scale_elem_count);
} else {
// Up scale convert
float* up_scale_src = up_bb_[expert_id]->d;
fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale + gpu_tp_scale_elem_count, up_scale_src,
cpu_tp_scale_elem_count);
}
},
nullptr);
} else {
// cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count;
int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp;
// Size of data per GPU TP within this CPU TP
size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp;
size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp;
// Optimized task layout
constexpr int NUM_WEIGHT_TASKS = 8;
constexpr int MIN_COLS_PER_TASK = 128;
int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK);
num_down_tasks = std::min(num_down_tasks, 32);
int tasks_per_gpu_tp = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2;
int total_tasks = tasks_per_gpu_tp * gpu_tps_per_cpu_tp;
size_t weight_chunk_size = (data_per_gpu_tp_weight + NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS;
weight_chunk_size = (weight_chunk_size + 63) & ~63ULL;
pool->do_work_stealing_job(
total_tasks, nullptr,
[&, this, gpu_tps_per_cpu_tp, start_gpu_tp, data_per_gpu_tp_weight, data_per_gpu_tp_scale, num_down_tasks,
tasks_per_gpu_tp, expert_id, weight_chunk_size](int task_id) {
int local_gpu_idx = task_id / tasks_per_gpu_tp;
int task_type = task_id % tasks_per_gpu_tp;
int gpu_tp_idx = start_gpu_tp + local_gpu_idx;
// Get pointers for this GPU TP part
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[gpu_tp_idx];
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[gpu_tp_idx];
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[gpu_tp_idx];
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[gpu_tp_idx];
// Calculate offsets within CPU TP buffers
size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight;
size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale;
if (task_type < NUM_WEIGHT_TASKS) {
// Gate weight copy - chunked
int chunk_idx = task_type;
size_t start = chunk_idx * weight_chunk_size;
size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight);
if (start < end) {
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight;
fast_memcpy(w13_weight_dst + start, gate_weight_src + start, end - start);
}
} else if (task_type < NUM_WEIGHT_TASKS * 2) {
// Up weight copy - chunked
int chunk_idx = task_type - NUM_WEIGHT_TASKS;
size_t start = chunk_idx * weight_chunk_size;
size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight);
if (start < end) {
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight;
fast_memcpy(w13_weight_dst + gpu_tp_weight_bytes + start, up_weight_src + start, end - start);
}
} else if (task_type < NUM_WEIGHT_TASKS * 2 + num_down_tasks) {
// Down columns - split by column chunks
int chunk_idx = task_type - NUM_WEIGHT_TASKS * 2;
size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks;
size_t col_start = chunk_idx * cols_per_chunk;
size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size);
size_t weight_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) >> 1;
size_t scale_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size;
for (size_t col = col_start; col < col_end; col++) {
size_t col_offset_weight = (col * config_.intermediate_size / 2) +
(local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size);
size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) +
(local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size);
fast_memcpy(w2_weight_dst + col * weight_per_gpu_col,
(uint8_t*)down_bb_[expert_id]->b + col_offset_weight, weight_per_gpu_col);
fast_fp32_to_bf16(w2_scale_dst + col * scale_per_gpu_col, down_bb_[expert_id]->d + col_offset_scale,
scale_per_gpu_col);
}
} else if (task_type == NUM_WEIGHT_TASKS * 2 + num_down_tasks) {
// Gate scale convert
float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale;
fast_fp32_to_bf16(w13_scale_dst, gate_scale_src, data_per_gpu_tp_scale);
} else {
// Up scale convert
float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale;
fast_fp32_to_bf16(w13_scale_dst + gpu_tp_scale_elem_count, up_scale_src, data_per_gpu_tp_scale);
}
},
nullptr);
}
}
};
// ============================================================================
// TP_MOE specialization for AMX_K2_MOE_TP
// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation
// ============================================================================
template <typename K>
class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>> {
public:
using Base = TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>>;
using Base::Base;
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
#ifdef LOAD_TIME_PROFILE
auto load_start_time = std::chrono::high_resolution_clock::now();
auto load_last = load_start_time;
long alloc_and_tp_slice_time = 0, tps_load_time = 0, cleanup_time = 0;
#endif
bool use_per_expert_ptrs = !config.gate_projs.empty();
if (config.gate_projs.empty() && config.gate_scale == nullptr) {
throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale");
}
if (use_per_expert_ptrs) {
printf("From per-expert pointers (gate_projs)\n");
} else {
printf("From Packed Int4 with KGroup Scale\n");
}
int& group_size = config.quant_config.group_size;
if (use_per_expert_ptrs) {
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&, i](int expert_id_) {
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
uint8_t* src_gate = (uint8_t*)config.gate_projs[0][expert_id];
uint8_t* src_up = (uint8_t*)config.up_projs[0][expert_id];
uint8_t* src_down = (uint8_t*)config.down_projs[0][expert_id];
ggml_bf16_t* src_gate_scale = (ggml_bf16_t*)config.gate_scales[0][expert_id];
ggml_bf16_t* src_up_scale = (ggml_bf16_t*)config.up_scales[0][expert_id];
ggml_bf16_t* src_down_scale = (ggml_bf16_t*)config.down_scales[0][expert_id];
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
src_gate + ((i * weight_elem_count) >> 1), (weight_elem_count >> 1));
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
src_up + ((i * weight_elem_count) >> 1), (weight_elem_count >> 1));
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
src_gate_scale + (i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count);
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
src_up_scale + (i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count);
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
src_down + ((col * config.intermediate_size + i * tpc.intermediate_size) >> 1),
(tpc.intermediate_size >> 1));
memcpy((ggml_bf16_t*)tpc.down_scale +
(expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
src_down_scale +
(col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
}
},
nullptr);
printf("TP %d load weight done.\n", i);
}
} else {
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
if (tpc.load == false) {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.gate_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.up_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.gate_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.up_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
col * config.intermediate_size + i * tpc.intermediate_size) >>
1),
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);
memcpy((ggml_bf16_t*)tpc.down_scale +
(expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
(ggml_bf16_t*)config.down_scale +
((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
}
},
nullptr);
}
printf("TP %d load weight done.\n", i);
}
}
#ifdef LOAD_TIME_PROFILE
{
auto load_now_time = std::chrono::high_resolution_clock::now();
alloc_and_tp_slice_time =
std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
load_last = load_now_time;
}
#endif
DO_TPS_LOAD_WEIGHTS(pool);
#ifdef LOAD_TIME_PROFILE
{
auto load_now_time = std::chrono::high_resolution_clock::now();
tps_load_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
load_last = load_now_time;
}
#endif
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)(tpc.gate_proj);
delete[] (uint8_t*)(tpc.up_proj);
delete[] (uint8_t*)(tpc.down_proj);
delete[] (ggml_bf16_t*)(tpc.gate_scale);
delete[] (ggml_bf16_t*)(tpc.up_scale);
delete[] (ggml_bf16_t*)(tpc.down_scale);
}
#ifdef LOAD_TIME_PROFILE
{
auto load_now_time = std::chrono::high_resolution_clock::now();
cleanup_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
}
auto load_end_time = std::chrono::high_resolution_clock::now();
auto load_total_time =
std::chrono::duration_cast<std::chrono::microseconds>(load_end_time - load_start_time).count();
printf(
"[K2 MoE Load Weights] tp_count: %d, alloc_and_tp_slice: %ld us, tps_load_weights: %ld us, cleanup: %ld us, "
"total: %ld us\n",
tp_count, alloc_and_tp_slice_time, tps_load_time, cleanup_time, load_total_time);
#endif
this->weights_loaded = true;
}
void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
if (this->weights_loaded == false) {
throw std::runtime_error("Not Loaded");
}
if (this->tps.empty()) {
throw std::runtime_error("No TP parts initialized");
}
if (w13_weight_ptrs.size() != gpu_tp_count || w13_scale_ptrs.size() != gpu_tp_count ||
w2_weight_ptrs.size() != gpu_tp_count || w2_scale_ptrs.size() != gpu_tp_count) {
throw std::runtime_error("Pointer arrays size must match gpu_tp_count");
}
this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,
w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
});
}
// merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>>
};
#endif // CPUINFER_OPERATOR_AMX_K2_MOE_H