mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-15 02:47:22 +00:00
Support Native Kimi K2 Thinking (#1663)
* [feat]: fix k2 prefill
* Update Kimi-K2-Thinking.md
* Create Kimi-K2-Thinking-Native.md
* Update Kimi-K2-Thinking.md
* Update Kimi-K2-Thinking.md
* Update Kimi-K2-Thinking-Native.md
* [perf] optimize K2 MoE weight loading with per-expert pointers
  - Avoid expensive torch.stack().contiguous() in Python (was ~6.6s)
  - Use per-expert pointer arrays (gate_projs) instead of contiguous memory
  - C++ worker pool performs parallel memcpy for TP slicing
  - Add LOAD_TIME_PROFILE for load_weights timing analysis

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
// #define FORWARD_TIME_PROFILE
|
||||
// #define FORWARD_TIME_REPORT
|
||||
#define LOAD_TIME_PROFILE
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
@@ -145,10 +145,6 @@ class AMX_K2_MOE_TP {
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
#ifdef FORWARD_TIME_REPORT
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
|
||||
#endif
|
||||
|
||||
public:
|
||||
using input_t = ggml_bf16_t;
|
||||
using output_t = float;
|
||||
@@ -518,8 +514,229 @@ class AMX_K2_MOE_TP {
|
||||
|
||||
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
|
||||
void* output) {
|
||||
for (int i = 0; i < qlen; i ++)
|
||||
forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
|
||||
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||
auto& quant_config = config_.quant_config;
|
||||
int& group_size = quant_config.group_size;
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
auto last = start_time;
|
||||
// 用于保存各阶段耗时(单位:微秒)
|
||||
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
|
||||
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
|
||||
int max_local_num = 0; // 记录最大的 local num
|
||||
#endif
|
||||
|
||||
int activated_expert = 0;
|
||||
for (int i = 0; i < config_.expert_num; i++) {
|
||||
m_local_num_[i] = 0;
|
||||
}
|
||||
for (int i = 0; i < qlen; i++) {
|
||||
for (int j = 0; j < k; j++) {
|
||||
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < config_.expert_num; i++) {
|
||||
if (m_local_num_[i] > 0) {
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
max_local_num = std::max(max_local_num, m_local_num_[i]);
|
||||
#endif
|
||||
m_expert_id_map_[activated_expert] = i;
|
||||
activated_expert++;
|
||||
}
|
||||
}
|
||||
|
||||
// activated_expert 已经统计完成
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < config_.expert_num; i++) {
|
||||
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
|
||||
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
|
||||
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
|
||||
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
|
||||
offset += m_local_num_[i];
|
||||
}
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
|
||||
for (int j = 0; j < k; j++) {
|
||||
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
|
||||
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
|
||||
}
|
||||
});
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
|
||||
int expert_idx = m_expert_id_map_[task_id];
|
||||
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
|
||||
});
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
int nth = T::recommended_nth(config_.intermediate_size);
|
||||
pool->do_work_stealing_job(
|
||||
nth * activated_expert * 2, [](int _) { T::config(); },
|
||||
[this, nth, qlen](int task_id2) {
|
||||
int& group_size = config_.quant_config.group_size;
|
||||
int task_id = task_id2 / 2;
|
||||
bool do_up = task_id2 % 2;
|
||||
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||
|
||||
int ith = task_id % nth;
|
||||
if (do_up) {
|
||||
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
|
||||
group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx],
|
||||
ith, nth);
|
||||
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
|
||||
} else {
|
||||
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
|
||||
group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx],
|
||||
gate_bc_[expert_idx], ith, nth);
|
||||
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
auto up_gate_fn = [this, nth](int task_id) {
|
||||
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||
int ith = task_id % nth;
|
||||
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
|
||||
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
|
||||
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
|
||||
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
|
||||
for (int j = n_start; j < n_end; j += 32) {
|
||||
__m512 gate_val0, gate_val1, up_val0, up_val1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
|
||||
__m512 result0 = amx::act_fn(gate_val0, up_val0);
|
||||
__m512 result1 = amx::act_fn(gate_val1, up_val1);
|
||||
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
|
||||
}
|
||||
}
|
||||
};
|
||||
DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
pool->do_work_stealing_job(
|
||||
activated_expert, nullptr,
|
||||
[this](int task_id) {
|
||||
int expert_idx = m_expert_id_map_[task_id];
|
||||
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
|
||||
},
|
||||
nullptr);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
nth = T::recommended_nth(config_.hidden_size);
|
||||
pool->do_work_stealing_job(
|
||||
nth * activated_expert, [](int _) { T::config(); },
|
||||
[this, nth, qlen](int task_id) {
|
||||
int& group_size = config_.quant_config.group_size;
|
||||
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||
int ith = task_id % nth;
|
||||
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
|
||||
group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx],
|
||||
ith, nth);
|
||||
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
|
||||
},
|
||||
nullptr);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
pool->do_work_stealing_job(
|
||||
qlen, nullptr,
|
||||
[this, nth, output, k, expert_ids, weights](int i) {
|
||||
for (int e = 0; e < config_.hidden_size; e += 32) {
|
||||
__m512 x0 = _mm512_setzero_ps();
|
||||
__m512 x1 = _mm512_setzero_ps();
|
||||
for (int j = 0; j < k; j++) {
|
||||
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
|
||||
__m512 down_output0, down_output1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
|
||||
m_local_pos_[i][j] * config_.hidden_size + e),
|
||||
&down_output0, &down_output1);
|
||||
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
|
||||
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
|
||||
}
|
||||
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
|
||||
f32out[0] = x0;
|
||||
f32out[1] = x1;
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
|
||||
// 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen
|
||||
printf(
|
||||
"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
|
||||
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
|
||||
"%d, qlen: %d\n",
|
||||
tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
|
||||
down_time, weight_time, forward_total_time, max_local_num, qlen);
|
||||
#endif
|
||||
// for (int i = 0; i < qlen; i ++)
|
||||
// forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
|
||||
}
|
||||
|
||||
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
|
||||
@@ -768,84 +985,169 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
|
||||
auto pool = config.pool;
|
||||
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
|
||||
|
||||
if (config.gate_scale == nullptr) {
|
||||
#ifdef LOAD_TIME_PROFILE
|
||||
auto load_start_time = std::chrono::high_resolution_clock::now();
|
||||
auto load_last = load_start_time;
|
||||
long alloc_and_tp_slice_time = 0, tps_load_time = 0, cleanup_time = 0;
|
||||
#endif
|
||||
|
||||
// Check if using per-expert pointers (gate_projs) or contiguous memory (gate_proj + gate_scale)
|
||||
bool use_per_expert_ptrs = !config.gate_projs.empty();
|
||||
|
||||
if (!use_per_expert_ptrs && config.gate_scale == nullptr) {
|
||||
throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale");
|
||||
}
|
||||
printf("From Packed Int4 with KGroup Scale\n");
|
||||
|
||||
if (use_per_expert_ptrs) {
|
||||
printf("From per-expert pointers (gate_projs)\n");
|
||||
} else {
|
||||
printf("From Packed Int4 with KGroup Scale\n");
|
||||
}
|
||||
|
||||
int& group_size = config.quant_config.group_size;
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
|
||||
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
|
||||
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
|
||||
if (use_per_expert_ptrs) {
|
||||
// Load from per-expert pointers - no need to allocate intermediate buffers
|
||||
// gate_projs[numa_id][expert_id] -> pointer to expert weight
|
||||
// For RAWINT4, numa dimension is 1 (index 0)
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
|
||||
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
|
||||
|
||||
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
// Allocate per-TP buffers
|
||||
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
|
||||
if (tps[i]->config_.load == false) {
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&](int expert_id_) { // weight and scale are all in col majored.
|
||||
[&, i](int expert_id_) {
|
||||
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
|
||||
|
||||
// weight and scale TP-slicing for gate and up
|
||||
// Source pointers from per-expert pointer arrays
|
||||
// gate_projs[0][expert_id] since numa dimension is 1
|
||||
uint8_t* src_gate = (uint8_t*)config.gate_projs[0][expert_id];
|
||||
uint8_t* src_up = (uint8_t*)config.up_projs[0][expert_id];
|
||||
uint8_t* src_down = (uint8_t*)config.down_projs[0][expert_id];
|
||||
ggml_bf16_t* src_gate_scale = (ggml_bf16_t*)config.gate_scales[0][expert_id];
|
||||
ggml_bf16_t* src_up_scale = (ggml_bf16_t*)config.up_scales[0][expert_id];
|
||||
ggml_bf16_t* src_down_scale = (ggml_bf16_t*)config.down_scales[0][expert_id];
|
||||
|
||||
// TP-slicing for gate and up (row-major slicing)
|
||||
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
|
||||
(uint8_t*)config.gate_proj +
|
||||
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
|
||||
((sizeof(uint8_t) * weight_elem_count) >> 1));
|
||||
src_gate + ((i * weight_elem_count) >> 1),
|
||||
(weight_elem_count >> 1));
|
||||
|
||||
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
|
||||
(uint8_t*)config.up_proj +
|
||||
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
|
||||
((sizeof(uint8_t) * weight_elem_count) >> 1));
|
||||
src_up + ((i * weight_elem_count) >> 1),
|
||||
(weight_elem_count >> 1));
|
||||
|
||||
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
|
||||
(ggml_bf16_t*)config.gate_scale +
|
||||
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
|
||||
i * scales_elem_count),
|
||||
sizeof(ggml_bf16_t) * scales_elem_count);
|
||||
src_gate_scale + (i * scales_elem_count),
|
||||
sizeof(ggml_bf16_t) * scales_elem_count);
|
||||
|
||||
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
|
||||
(ggml_bf16_t*)config.up_scale +
|
||||
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
|
||||
i * scales_elem_count),
|
||||
sizeof(ggml_bf16_t) * scales_elem_count);
|
||||
src_up_scale + (i * scales_elem_count),
|
||||
sizeof(ggml_bf16_t) * scales_elem_count);
|
||||
|
||||
// memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count) >> 1),
|
||||
// (uint8_t*)config.down_proj +
|
||||
// ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
|
||||
// ((sizeof(uint8_t) * weight_elem_count) >> 1));
|
||||
|
||||
// memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count),
|
||||
// (ggml_bf16_t*)config.down_scale +
|
||||
// (expert_id * (config.intermediate_size / group_size) * config.hidden_size +
|
||||
// i * scales_elem_count),
|
||||
// sizeof(ggml_bf16_t) * scales_elem_count);
|
||||
|
||||
// weight and scale TP-slicing for down (by column)
|
||||
// TP-slicing for down (by column)
|
||||
for (size_t col = 0; col < config.hidden_size; col++) {
|
||||
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
|
||||
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
|
||||
col * config.intermediate_size + i * tpc.intermediate_size) >>
|
||||
1),
|
||||
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);
|
||||
src_down + ((col * config.intermediate_size + i * tpc.intermediate_size) >> 1),
|
||||
(tpc.intermediate_size >> 1));
|
||||
memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
|
||||
(ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
|
||||
col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
|
||||
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
|
||||
src_down_scale + (col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
|
||||
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
printf("TP %d load weight done.\n", i);
|
||||
}
|
||||
} else {
|
||||
// Original path: load from contiguous memory with gate_proj/gate_scale
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
|
||||
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
|
||||
|
||||
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
|
||||
|
||||
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
|
||||
|
||||
if (tps[i]->config_.load == false) {
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&](int expert_id_) { // weight and scale are all in col majored.
|
||||
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
|
||||
|
||||
// weight and scale TP-slicing for gate and up
|
||||
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
|
||||
(uint8_t*)config.gate_proj +
|
||||
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
|
||||
((sizeof(uint8_t) * weight_elem_count) >> 1));
|
||||
|
||||
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
|
||||
(uint8_t*)config.up_proj +
|
||||
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
|
||||
((sizeof(uint8_t) * weight_elem_count) >> 1));
|
||||
|
||||
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
|
||||
(ggml_bf16_t*)config.gate_scale +
|
||||
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
|
||||
i * scales_elem_count),
|
||||
sizeof(ggml_bf16_t) * scales_elem_count);
|
||||
|
||||
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
|
||||
(ggml_bf16_t*)config.up_scale +
|
||||
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
|
||||
i * scales_elem_count),
|
||||
sizeof(ggml_bf16_t) * scales_elem_count);
|
||||
|
||||
// weight and scale TP-slicing for down (by column)
|
||||
for (size_t col = 0; col < config.hidden_size; col++) {
|
||||
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
|
||||
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
|
||||
col * config.intermediate_size + i * tpc.intermediate_size) >>
|
||||
1),
|
||||
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);
|
||||
memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
|
||||
(ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
|
||||
col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
|
||||
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
printf("TP %d load weight done.\n", i);
|
||||
}
|
||||
printf("TP %d load weight done.\n", i);
|
||||
}
|
||||
|
||||
#ifdef LOAD_TIME_PROFILE
|
||||
{
|
||||
auto load_now_time = std::chrono::high_resolution_clock::now();
|
||||
alloc_and_tp_slice_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
|
||||
load_last = load_now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
DO_TPS_LOAD_WEIGHTS(pool);
|
||||
|
||||
#ifdef LOAD_TIME_PROFILE
|
||||
{
|
||||
auto load_now_time = std::chrono::high_resolution_clock::now();
|
||||
tps_load_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
|
||||
load_last = load_now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
delete[] (uint8_t*)(tpc.gate_proj);
|
||||
@@ -857,6 +1159,18 @@ class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
|
||||
delete[] (ggml_bf16_t*)(tpc.down_scale);
|
||||
}
|
||||
|
||||
#ifdef LOAD_TIME_PROFILE
|
||||
{
|
||||
auto load_now_time = std::chrono::high_resolution_clock::now();
|
||||
cleanup_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
|
||||
}
|
||||
auto load_end_time = std::chrono::high_resolution_clock::now();
|
||||
auto load_total_time = std::chrono::duration_cast<std::chrono::microseconds>(load_end_time - load_start_time).count();
|
||||
printf(
|
||||
"[K2 MoE Load Weights] tp_count: %d, alloc_and_tp_slice: %ld us, tps_load_weights: %ld us, cleanup: %ld us, total: %ld us\n",
|
||||
tp_count, alloc_and_tp_slice_time, tps_load_time, cleanup_time, load_total_time);
|
||||
#endif
|
||||
|
||||
this->weights_loaded = true;
|
||||
}
|
||||
|
||||
|
||||
@@ -2905,7 +2905,7 @@ struct GemmKernel224Int4SmallKGroup {
|
||||
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
for (int m_begin = 0; m_begin < m; m_begin ++) {
|
||||
float* c = bc->get_submat(m, n, m_begin, 0);
|
||||
float* c = bc->get_submat(m, n, m_begin, n_start);
|
||||
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
|
||||
|
||||
for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) {
|
||||
@@ -2929,7 +2929,7 @@ struct GemmKernel224Int4SmallKGroup {
|
||||
WORK_K_BLOCK(k_block + 1);
|
||||
}
|
||||
|
||||
c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
|
||||
c[n_block_begin - n_start] = _mm512_reduce_add_ps(sum) / 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,16 +364,34 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
|
||||
|
||||
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
|
||||
import time
|
||||
|
||||
t0 = time.time()
|
||||
base_key = f"model.layers.{self.layer_idx}"
|
||||
weights = self.loader.load_experts(base_key)
|
||||
t1 = time.time()
|
||||
|
||||
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
|
||||
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
|
||||
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
|
||||
# Keep individual tensors instead of stacking - avoid expensive memory copy
|
||||
# weights["gate"], weights["up"], weights["down"] are lists of tensors per expert
|
||||
self.gate_weights = weights["gate"] # list of tensors
|
||||
self.up_weights = weights["up"]
|
||||
self.down_weights = weights["down"]
|
||||
|
||||
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||
# Convert scales to bf16 individually
|
||||
self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
|
||||
self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
|
||||
self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
|
||||
t2 = time.time()
|
||||
|
||||
# Build pointer lists: [numa_id][expert_id] -> pointer
|
||||
# Since RAWINT4 has no numa sharding, numa dimension is 1
|
||||
gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
|
||||
up_ptrs = [[t.data_ptr() for t in self.up_weights]]
|
||||
down_ptrs = [[t.data_ptr() for t in self.down_weights]]
|
||||
gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
|
||||
up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
|
||||
down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
|
||||
t3 = time.time()
|
||||
|
||||
moe_config = MOEConfig(
|
||||
self.num_experts,
|
||||
@@ -390,17 +408,20 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
moe_config.quant_config.group_size = 32
|
||||
moe_config.quant_config.zero_point = False
|
||||
|
||||
moe_config.gate_proj = self.gate_weights.data_ptr()
|
||||
moe_config.up_proj = self.up_weights.data_ptr()
|
||||
moe_config.down_proj = self.down_weights.data_ptr()
|
||||
moe_config.gate_scale = self.gate_scales.data_ptr()
|
||||
moe_config.up_scale = self.up_scales.data_ptr()
|
||||
moe_config.down_scale = self.down_scales.data_ptr()
|
||||
# Use gate_projs instead of gate_proj for per-expert pointers
|
||||
moe_config.gate_projs = gate_ptrs
|
||||
moe_config.up_projs = up_ptrs
|
||||
moe_config.down_projs = down_ptrs
|
||||
moe_config.gate_scales = gate_scale_ptrs
|
||||
moe_config.up_scales = up_scale_ptrs
|
||||
moe_config.down_scales = down_scale_ptrs
|
||||
|
||||
self.moe = AMXInt4_KGroup_MOE(moe_config)
|
||||
t4 = time.time()
|
||||
|
||||
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
|
||||
self.cpu_infer.sync()
|
||||
t5 = time.time()
|
||||
|
||||
del self.gate_weights
|
||||
del self.up_weights
|
||||
@@ -408,6 +429,18 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
del self.gate_scales
|
||||
del self.up_scales
|
||||
del self.down_scales
|
||||
t6 = time.time()
|
||||
|
||||
print(
|
||||
f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
|
||||
f"load_experts: {(t1-t0)*1000:.1f}ms, "
|
||||
f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
|
||||
f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
|
||||
f"create_moe: {(t4-t3)*1000:.1f}ms, "
|
||||
f"cpp_load_weights: {(t5-t4)*1000:.1f}ms, "
|
||||
f"cleanup: {(t6-t5)*1000:.1f}ms, "
|
||||
f"total: {(t6-t0)*1000:.1f}ms"
|
||||
)
|
||||
|
||||
def submit_write_weight_scale_to_buffer(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user