Support Native Kimi K2 Thinking (#1663)

* [feat] Fix K2 prefill

* Update Kimi-K2-Thinking.md

* Create Kimi-K2-Thinking-Native.md

* Update Kimi-K2-Thinking.md

* Update Kimi-K2-Thinking.md

* Update Kimi-K2-Thinking-Native.md

* [perf] optimize K2 MoE weight loading with per-expert pointers

- Avoid expensive torch.stack().contiguous() in Python (was ~6.6s)
- Use per-expert pointer arrays (gate_projs) instead of contiguous memory
- C++ worker pool performs parallel memcpy for TP slicing
- Add LOAD_TIME_PROFILE for load_weights timing analysis

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
ErvinXie
2025-12-05 21:53:05 +08:00
committed by GitHub
parent 4850424345
commit 71f683acec
5 changed files with 419 additions and 70 deletions

View File

@@ -16,7 +16,7 @@
#include <cstdint>
#include <cstring>
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT
#define LOAD_TIME_PROFILE
#include <immintrin.h>
@@ -145,10 +145,6 @@ class AMX_K2_MOE_TP {
fflush(stdout);
}
#ifdef FORWARD_TIME_REPORT
std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
#endif
public:
using input_t = ggml_bf16_t;
using output_t = float;
@@ -518,8 +514,229 @@ class AMX_K2_MOE_TP {
// Batched MoE prefill over the CPU-resident experts.
//
// Tokens are grouped by routed expert, then each expert runs quantized AMX
// GEMMs for its gate/up projections, the fused activation, and the down
// projection; finally each token's expert outputs are combined with the
// routing weights into `output`.
//
// qlen       - number of tokens in the prefill batch
// k          - experts routed per token
// expert_ids - [qlen * k] routed expert ids; ids below num_gpu_experts
//              (GPU-handled) or >= expert_num are skipped here
// weights    - [qlen * k] routing weights, parallel to expert_ids
// input      - bf16 activations, qlen x hidden_size
// output     - fp32 destination, qlen x hidden_size (fully overwritten)
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
                     void* output) {
  auto pool = config_.pool->get_subpool(tp_part_idx);
  auto& quant_config = config_.quant_config;
  int& group_size = quant_config.group_size;
#ifdef FORWARD_TIME_PROFILE
  auto start_time = std::chrono::high_resolution_clock::now();
  auto last = start_time;
  // Per-stage elapsed times (microseconds).
  long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
  long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
  int max_local_num = 0;  // largest number of tokens routed to a single expert
#endif
  // Stage 1: count tokens per CPU expert and record, for each (token, slot)
  // pair, its row position inside that expert's local batch.
  int activated_expert = 0;
  for (int i = 0; i < config_.expert_num; i++) {
    m_local_num_[i] = 0;
  }
  for (int i = 0; i < qlen; i++) {
    for (int j = 0; j < k; j++) {
      if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
        continue;
      }
      m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
    }
  }
  for (int i = 0; i < config_.expert_num; i++) {
    if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
      max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
      m_expert_id_map_[activated_expert] = i;
      activated_expert++;
    }
  }
  // activated_expert is now final; carve per-expert slices out of the shared
  // scratch buffers so each expert sees a contiguous token batch.
  size_t offset = 0;
  for (int i = 0; i < config_.expert_num; i++) {
    m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
    m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
    m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
    m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
    offset += m_local_num_[i];
  }
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
#endif
  // Stage 2: gather each token's activations into its expert's input slice.
  DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
    for (int j = 0; j < k; j++) {
      if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
        continue;
      }
      memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
             (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
    }
  });
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
#endif
  // Stage 3: quantize the gathered inputs for the gate/up GEMMs.
  DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
    int expert_idx = m_expert_id_map_[task_id];
    gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
  });
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
#endif
  // Stage 4: gate and up projections, interleaved as even/odd task ids so
  // both matmuls share the work-stealing pool.
  int nth = T::recommended_nth(config_.intermediate_size);
  pool->do_work_stealing_job(
      nth * activated_expert * 2, [](int _) { T::config(); },
      [this, nth, qlen](int task_id2) {
        // Re-resolve group_size here: the lambda captures only `this`.
        int& group_size = config_.quant_config.group_size;
        int task_id = task_id2 / 2;
        bool do_up = task_id2 % 2;
        int expert_idx = m_expert_id_map_[task_id / nth];
        int ith = task_id % nth;
        if (do_up) {
          MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                                          group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx],
                                          ith, nth);
          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
        } else {
          MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                                          group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx],
                                          gate_bc_[expert_idx], ith, nth);
          gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
        }
      },
      nullptr);
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
#endif
  // Stage 5: fused activation act_fn(gate) * up, 32 bf16 lanes at a time,
  // written back in place into the gate-output buffer.
  auto up_gate_fn = [this, nth](int task_id) {
    int expert_idx = m_expert_id_map_[task_id / nth];
    int ith = task_id % nth;
    auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
    for (int i = 0; i < m_local_num_[expert_idx]; i++) {
      ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
      ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
      for (int j = n_start; j < n_end; j += 32) {
        __m512 gate_val0, gate_val1, up_val0, up_val1;
        avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
        avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
        __m512 result0 = amx::act_fn(gate_val0, up_val0);
        __m512 result1 = amx::act_fn(gate_val1, up_val1);
        avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
      }
    }
  };
  DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
#endif
  // Stage 6: quantize the activation result for the down GEMM.
  pool->do_work_stealing_job(
      activated_expert, nullptr,
      [this](int task_id) {
        int expert_idx = m_expert_id_map_[task_id];
        down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
      },
      nullptr);
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
#endif
  // Stage 7: down projection back to hidden_size.
  nth = T::recommended_nth(config_.hidden_size);
  pool->do_work_stealing_job(
      nth * activated_expert, [](int _) { T::config(); },
      [this, nth, qlen](int task_id) {
        int& group_size = config_.quant_config.group_size;
        int expert_idx = m_expert_id_map_[task_id / nth];
        int ith = task_id % nth;
        MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
                                        group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx],
                                        ith, nth);
        down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
      },
      nullptr);
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
#endif
  // Stage 8: per token, accumulate routing-weighted expert outputs into the
  // fp32 destination. Skipped (GPU/out-of-range) slots contribute zero.
  pool->do_work_stealing_job(
      qlen, nullptr,
      [this, nth, output, k, expert_ids, weights](int i) {
        for (int e = 0; e < config_.hidden_size; e += 32) {
          __m512 x0 = _mm512_setzero_ps();
          __m512 x1 = _mm512_setzero_ps();
          for (int j = 0; j < k; j++) {
            if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
              continue;
            }
            __m512 weight = _mm512_set1_ps(weights[i * k + j]);
            __m512 down_output0, down_output1;
            avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                                 m_local_pos_[i][j] * config_.hidden_size + e),
                                      &down_output0, &down_output1);
            x0 = _mm512_fmadd_ps(down_output0, weight, x0);
            x1 = _mm512_fmadd_ps(down_output1, weight, x1);
          }
          auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
          f32out[0] = x0;
          f32out[1] = x1;
        }
      },
      nullptr);
#ifdef FORWARD_TIME_PROFILE
  {
    auto now_time = std::chrono::high_resolution_clock::now();
    weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
    last = now_time;
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
  // Print all stage timings once at the end, with max_local_num and qlen.
  printf(
      "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
      "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
      "%d, qlen: %d\n",
      tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
      down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
  // NOTE(review): removed the stale per-token forward_decode fallback loop
  // that preceded this batched implementation — it duplicated all the work
  // before `output` was overwritten by the batched path.
}
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
@@ -768,84 +985,169 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_scale == nullptr) {
#ifdef LOAD_TIME_PROFILE
auto load_start_time = std::chrono::high_resolution_clock::now();
auto load_last = load_start_time;
long alloc_and_tp_slice_time = 0, tps_load_time = 0, cleanup_time = 0;
#endif
// Check if using per-expert pointers (gate_projs) or contiguous memory (gate_proj + gate_scale)
bool use_per_expert_ptrs = !config.gate_projs.empty();
if (!use_per_expert_ptrs && config.gate_scale == nullptr) {
throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale");
}
printf("From Packed Int4 with KGroup Scale\n");
if (use_per_expert_ptrs) {
printf("From per-expert pointers (gate_projs)\n");
} else {
printf("From Packed Int4 with KGroup Scale\n");
}
int& group_size = config.quant_config.group_size;
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
if (use_per_expert_ptrs) {
// Load from per-expert pointers - no need to allocate intermediate buffers
// gate_projs[numa_id][expert_id] -> pointer to expert weight
// For RAWINT4, numa dimension is 1 (index 0)
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
// Allocate per-TP buffers
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
if (tps[i]->config_.load == false) {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) { // weight and scale are all in col majored.
[&, i](int expert_id_) {
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
// weight and scale TP-slicing for gate and up
// Source pointers from per-expert pointer arrays
// gate_projs[0][expert_id] since numa dimension is 1
uint8_t* src_gate = (uint8_t*)config.gate_projs[0][expert_id];
uint8_t* src_up = (uint8_t*)config.up_projs[0][expert_id];
uint8_t* src_down = (uint8_t*)config.down_projs[0][expert_id];
ggml_bf16_t* src_gate_scale = (ggml_bf16_t*)config.gate_scales[0][expert_id];
ggml_bf16_t* src_up_scale = (ggml_bf16_t*)config.up_scales[0][expert_id];
ggml_bf16_t* src_down_scale = (ggml_bf16_t*)config.down_scales[0][expert_id];
// TP-slicing for gate and up (row-major slicing)
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.gate_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
src_gate + ((i * weight_elem_count) >> 1),
(weight_elem_count >> 1));
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.up_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
src_up + ((i * weight_elem_count) >> 1),
(weight_elem_count >> 1));
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.gate_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
src_gate_scale + (i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.up_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
src_up_scale + (i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
// memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count) >> 1),
// (uint8_t*)config.down_proj +
// ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
// ((sizeof(uint8_t) * weight_elem_count) >> 1));
// memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count),
// (ggml_bf16_t*)config.down_scale +
// (expert_id * (config.intermediate_size / group_size) * config.hidden_size +
// i * scales_elem_count),
// sizeof(ggml_bf16_t) * scales_elem_count);
// weight and scale TP-slicing for down (by column)
// TP-slicing for down (by column)
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
col * config.intermediate_size + i * tpc.intermediate_size) >>
1),
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);
src_down + ((col * config.intermediate_size + i * tpc.intermediate_size) >> 1),
(tpc.intermediate_size >> 1));
memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
(ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
src_down_scale + (col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
}
},
nullptr);
printf("TP %d load weight done.\n", i);
}
} else {
// Original path: load from contiguous memory with gate_proj/gate_scale
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
if (tps[i]->config_.load == false) {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) { // weight and scale are all in col majored.
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
// weight and scale TP-slicing for gate and up
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.gate_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.up_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.gate_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.up_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
// weight and scale TP-slicing for down (by column)
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
col * config.intermediate_size + i * tpc.intermediate_size) >>
1),
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);
memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
(ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
}
},
nullptr);
}
printf("TP %d load weight done.\n", i);
}
printf("TP %d load weight done.\n", i);
}
#ifdef LOAD_TIME_PROFILE
{
auto load_now_time = std::chrono::high_resolution_clock::now();
alloc_and_tp_slice_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
load_last = load_now_time;
}
#endif
DO_TPS_LOAD_WEIGHTS(pool);
#ifdef LOAD_TIME_PROFILE
{
auto load_now_time = std::chrono::high_resolution_clock::now();
tps_load_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
load_last = load_now_time;
}
#endif
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)(tpc.gate_proj);
@@ -857,6 +1159,18 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
delete[] (ggml_bf16_t*)(tpc.down_scale);
}
#ifdef LOAD_TIME_PROFILE
{
auto load_now_time = std::chrono::high_resolution_clock::now();
cleanup_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
}
auto load_end_time = std::chrono::high_resolution_clock::now();
auto load_total_time = std::chrono::duration_cast<std::chrono::microseconds>(load_end_time - load_start_time).count();
printf(
"[K2 MoE Load Weights] tp_count: %d, alloc_and_tp_slice: %ld us, tps_load_weights: %ld us, cleanup: %ld us, total: %ld us\n",
tp_count, alloc_and_tp_slice_time, tps_load_time, cleanup_time, load_total_time);
#endif
this->weights_loaded = true;
}

View File

@@ -2905,7 +2905,7 @@ struct GemmKernel224Int4SmallKGroup {
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
auto [n_start, n_end] = split_range_n(n, ith, nth);
for (int m_begin = 0; m_begin < m; m_begin ++) {
float* c = bc->get_submat(m, n, m_begin, 0);
float* c = bc->get_submat(m, n, m_begin, n_start);
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) {
@@ -2929,7 +2929,7 @@ struct GemmKernel224Int4SmallKGroup {
WORK_K_BLOCK(k_block + 1);
}
c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
c[n_block_begin - n_start] = _mm512_reduce_add_ps(sum) / 16;
}
}
}

View File

@@ -364,16 +364,34 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
import time
t0 = time.time()
base_key = f"model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
t1 = time.time()
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
# Keep individual tensors instead of stacking - avoid expensive memory copy
# weights["gate"], weights["up"], weights["down"] are lists of tensors per expert
self.gate_weights = weights["gate"] # list of tensors
self.up_weights = weights["up"]
self.down_weights = weights["down"]
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
# Convert scales to bf16 individually
self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
t2 = time.time()
# Build pointer lists: [numa_id][expert_id] -> pointer
# Since RAWINT4 has no numa sharding, numa dimension is 1
gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
up_ptrs = [[t.data_ptr() for t in self.up_weights]]
down_ptrs = [[t.data_ptr() for t in self.down_weights]]
gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
t3 = time.time()
moe_config = MOEConfig(
self.num_experts,
@@ -390,17 +408,20 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
moe_config.quant_config.group_size = 32
moe_config.quant_config.zero_point = False
moe_config.gate_proj = self.gate_weights.data_ptr()
moe_config.up_proj = self.up_weights.data_ptr()
moe_config.down_proj = self.down_weights.data_ptr()
moe_config.gate_scale = self.gate_scales.data_ptr()
moe_config.up_scale = self.up_scales.data_ptr()
moe_config.down_scale = self.down_scales.data_ptr()
# Use gate_projs instead of gate_proj for per-expert pointers
moe_config.gate_projs = gate_ptrs
moe_config.up_projs = up_ptrs
moe_config.down_projs = down_ptrs
moe_config.gate_scales = gate_scale_ptrs
moe_config.up_scales = up_scale_ptrs
moe_config.down_scales = down_scale_ptrs
self.moe = AMXInt4_KGroup_MOE(moe_config)
t4 = time.time()
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
t5 = time.time()
del self.gate_weights
del self.up_weights
@@ -408,6 +429,18 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
del self.gate_scales
del self.up_scales
del self.down_scales
t6 = time.time()
print(
f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
f"load_experts: {(t1-t0)*1000:.1f}ms, "
f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
f"create_moe: {(t4-t3)*1000:.1f}ms, "
f"cpp_load_weights: {(t5-t4)*1000:.1f}ms, "
f"cleanup: {(t6-t5)*1000:.1f}ms, "
f"total: {(t6-t0)*1000:.1f}ms"
)
def submit_write_weight_scale_to_buffer(
self,