From 71f683acecd1e034cd2f942ed478fcb74dd92e93 Mon Sep 17 00:00:00 2001
From: ErvinXie
Date: Fri, 5 Dec 2025 21:53:05 +0800
Subject: [PATCH] Support Native Kimi K2 Thinking (#1663)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [feat]: fix k2 prefill

* Update Kimi-K2-Thinking.md

* Create Kimi-K2-Thinking-Native.md

* Update Kimi-K2-Thinking.md

* Update Kimi-K2-Thinking.md

* Update Kimi-K2-Thinking-Native.md

* [perf] optimize K2 MoE weight loading with per-expert pointers

- Avoid expensive torch.stack().contiguous() in Python (was ~6.6s)
- Use per-expert pointer arrays (gate_projs) instead of contiguous memory
- C++ worker pool performs parallel memcpy for TP slicing
- Add LOAD_TIME_PROFILE for load_weights timing analysis

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude
---
 doc/en/Kimi-K2-Thinking-Native.md          |   1 +
 doc/en/Kimi-K2-Thinking.md                 |   1 +
 kt-kernel/operators/amx/k2-moe.hpp         | 426 ++++++++++++++++++---
 kt-kernel/operators/amx/la/amx_kernels.hpp |   4 +-
 kt-kernel/python/utils/amx.py              |  57 ++-
 5 files changed, 419 insertions(+), 70 deletions(-)
 create mode 100644 doc/en/Kimi-K2-Thinking-Native.md

diff --git a/doc/en/Kimi-K2-Thinking-Native.md b/doc/en/Kimi-K2-Thinking-Native.md
new file mode 100644
index 0000000..b4ab55e
--- /dev/null
+++ b/doc/en/Kimi-K2-Thinking-Native.md
@@ -0,0 +1 @@
+First write how to install and run, then add a performance section, then link to the doc on how to connect via Claude Code.

diff --git a/doc/en/Kimi-K2-Thinking.md b/doc/en/Kimi-K2-Thinking.md
index bc560bb..c7f3671 100644
--- a/doc/en/Kimi-K2-Thinking.md
+++ b/doc/en/Kimi-K2-Thinking.md
@@ -1,4 +1,5 @@
 # KTransformers+SGLang Inference Deployment
+Please note: this is the quantized deployment. For native Kimi K2 Thinking deployment, please refer to [here](./Kimi-K2-Thinking-Native.md).
 
 ## Installation
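The commit message above describes replacing the stacked, contiguous expert-weight layout with per-expert pointer arrays. A minimal sketch of the two layouts, assuming hypothetical stand-in tensors rather than the real loader output:

```python
import torch

# Hypothetical per-expert tensors standing in for the loader's output.
experts = [torch.randint(0, 255, (1024,), dtype=torch.uint8) for _ in range(4)]

# Old path: one large copy into a contiguous [num_experts, ...] tensor,
# then a single base pointer is handed to the C++ side.
stacked = torch.stack(experts, dim=0).contiguous()
proj_ptr = stacked.data_ptr()

# New path: keep the per-expert tensors and pass one pointer per expert.
# The outer list is the NUMA dimension (length 1 for RAWINT4 in this patch).
proj_ptrs = [[t.data_ptr() for t in experts]]
```

The second form hands the C++ side one raw pointer per expert, so Python never has to materialize the stacked copy.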
diff --git a/kt-kernel/operators/amx/k2-moe.hpp b/kt-kernel/operators/amx/k2-moe.hpp
index c6f4924..ed97bd2 100644
--- a/kt-kernel/operators/amx/k2-moe.hpp
+++ b/kt-kernel/operators/amx/k2-moe.hpp
@@ -16,7 +16,7 @@
 #include
 #include
 // #define FORWARD_TIME_PROFILE
-#define FORWARD_TIME_REPORT
+#define LOAD_TIME_PROFILE
 
 #include
@@ -145,10 +145,6 @@ class AMX_K2_MOE_TP {
     fflush(stdout);
   }
 
-#ifdef FORWARD_TIME_REPORT
-  std::chrono::time_point last_now;
-#endif
-
 public:
  using input_t = ggml_bf16_t;
  using output_t = float;
@@ -518,8 +514,229 @@ class AMX_K2_MOE_TP {
   void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
-    for (int i = 0; i < qlen; i ++)
-      forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
+    auto pool = config_.pool->get_subpool(tp_part_idx);
+    auto& quant_config = config_.quant_config;
+    int& group_size = quant_config.group_size;
+#ifdef FORWARD_TIME_PROFILE
+    auto start_time = std::chrono::high_resolution_clock::now();
+    auto last = start_time;
+    // Elapsed time of each stage (in microseconds)
+    long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
+    long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
+    int max_local_num = 0;  // track the maximum local num
+#endif
+
+    int activated_expert = 0;
+    for (int i = 0; i < config_.expert_num; i++) {
+      m_local_num_[i] = 0;
+    }
+    for (int i = 0; i < qlen; i++) {
+      for (int j = 0; j < k; j++) {
+        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
+          continue;
+        }
+        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
+      }
+    }
+
+    for (int i = 0; i < config_.expert_num; i++) {
+      if (m_local_num_[i] > 0) {
+#ifdef FORWARD_TIME_PROFILE
+        max_local_num = std::max(max_local_num, m_local_num_[i]);
+#endif
+        m_expert_id_map_[activated_expert] = i;
+        activated_expert++;
+      }
+    }
+
+    // activated_expert counting is complete at this point
+
+    size_t offset = 0;
+    for (int i = 0; i < config_.expert_num; i++) {
+      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
+      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
+      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
+      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
+      offset += m_local_num_[i];
+    }
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+#endif
+
+    DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
+      for (int j = 0; j < k; j++) {
+        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
+          continue;
+        }
+        memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
+               (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
+      }
+    });
+
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+#endif
+
+    DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
+      int expert_idx = m_expert_id_map_[task_id];
+      gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
+    });
+
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+#endif
+
+    int nth = T::recommended_nth(config_.intermediate_size);
+    pool->do_work_stealing_job(
+        nth * activated_expert * 2, [](int _) { T::config(); },
+        [this, nth, qlen](int task_id2) {
+          int& group_size = config_.quant_config.group_size;
+          int task_id = task_id2 / 2;
+          bool do_up = task_id2 % 2;
+          int expert_idx = m_expert_id_map_[task_id / nth];
+
+          int ith = task_id % nth;
+          if (do_up) {
+            MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
+                                            group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx],
+                                            ith, nth);
+            up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
+          } else {
+            MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
+                                            group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx],
+                                            gate_bc_[expert_idx], ith, nth);
+            gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
+          }
+        },
+        nullptr);
+
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+#endif
+
+    auto up_gate_fn = [this, nth](int task_id) {
+      int expert_idx = m_expert_id_map_[task_id / nth];
+      int ith = task_id % nth;
+      auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
+      for (int i = 0; i < m_local_num_[expert_idx]; i++) {
+        ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
+        ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
+        for (int j = n_start; j < n_end; j += 32) {
+          __m512 gate_val0, gate_val1, up_val0, up_val1;
+          avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
+          avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
+          __m512 result0 = amx::act_fn(gate_val0, up_val0);
+          __m512 result1 = amx::act_fn(gate_val1, up_val1);
+          avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
+        }
+      }
+    };
+    DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);
+
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+#endif
+
+    pool->do_work_stealing_job(
+        activated_expert, nullptr,
+        [this](int task_id) {
+          int expert_idx = m_expert_id_map_[task_id];
+          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
+        },
+        nullptr);
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+#endif
+
+    nth = T::recommended_nth(config_.hidden_size);
+    pool->do_work_stealing_job(
+        nth * activated_expert, [](int _) { T::config(); },
+        [this, nth, qlen](int task_id) {
+          int& group_size = config_.quant_config.group_size;
+          int expert_idx = m_expert_id_map_[task_id / nth];
+          int ith = task_id % nth;
+          MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
+                                          group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx],
+                                          ith, nth);
+          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
+        },
+        nullptr);
+
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+#endif
+
+    pool->do_work_stealing_job(
+        qlen, nullptr,
+        [this, nth, output, k, expert_ids, weights](int i) {
+          for (int e = 0; e < config_.hidden_size; e += 32) {
+            __m512 x0 = _mm512_setzero_ps();
+            __m512 x1 = _mm512_setzero_ps();
+            for (int j = 0; j < k; j++) {
+              if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
+                continue;
+              }
+              __m512 weight = _mm512_set1_ps(weights[i * k + j]);
+              __m512 down_output0, down_output1;
+              avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                                    m_local_pos_[i][j] * config_.hidden_size + e),
+                                        &down_output0, &down_output1);
+              x0 = _mm512_fmadd_ps(down_output0, weight, x0);
+              x1 = _mm512_fmadd_ps(down_output1, weight, x1);
+            }
+            auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
+            f32out[0] = x0;
+            f32out[1] = x1;
+          }
+        },
+        nullptr);
+
+#ifdef FORWARD_TIME_PROFILE
+    {
+      auto now_time = std::chrono::high_resolution_clock::now();
+      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
+      last = now_time;
+    }
+    auto end_time = std::chrono::high_resolution_clock::now();
+    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
+    // Print all stage timings once at the end of the function, together with max_local_num and qlen
+    printf(
+        "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
+        "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
+        "%d, qlen: %d\n",
+        tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
+        down_time, weight_time, forward_total_time, max_local_num, qlen);
+#endif
+    // for (int i = 0; i < qlen; i ++)
+    //   forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
   }
 
   void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
@@ -768,84 +985,169 @@ class TP_MOE> : public TP_MOE_Common> {
     auto pool = config.pool;
     const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
 
-    if (config.gate_scale == nullptr) {
+#ifdef LOAD_TIME_PROFILE
+    auto load_start_time = std::chrono::high_resolution_clock::now();
+    auto load_last = load_start_time;
+    long alloc_and_tp_slice_time = 0, tps_load_time = 0, cleanup_time = 0;
+#endif
+
+    // Check if using per-expert pointers (gate_projs) or contiguous memory (gate_proj + gate_scale)
+    bool use_per_expert_ptrs = !config.gate_projs.empty();
+
+    if (!use_per_expert_ptrs && config.gate_scale == nullptr) {
       throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale");
     }
-    printf("From Packed Int4 with KGroup Scale\n");
+
+    if (use_per_expert_ptrs) {
+      printf("From per-expert pointers (gate_projs)\n");
+    } else {
+      printf("From Packed Int4 with KGroup Scale\n");
+    }
+
     int& group_size = config.quant_config.group_size;
-    for (auto i = 0; i < tp_count; i++) {
-      auto& tpc = tps[i]->config_;
-      size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
-      tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
-      tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
-      tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
-      size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
+    if (use_per_expert_ptrs) {
+      // Load from per-expert pointers - no need to allocate intermediate buffers
+      // gate_projs[numa_id][expert_id] -> pointer to expert weight
+      // For RAWINT4, numa dimension is 1 (index 0)
+      for (auto i = 0; i < tp_count; i++) {
+        auto& tpc = tps[i]->config_;
+        size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
+        size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
 
-      tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
-      tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
-      tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
+        // Allocate per-TP buffers
+        tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
+        tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
+        tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
+        tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
+        tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
+        tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
 
-      if (tps[i]->config_.load == false) {
         pool->get_subpool(i)->do_work_stealing_job(
             tpc.expert_num, nullptr,
-            [&](int expert_id_) {  // weight and scale are all in col majored.
+            [&, i](int expert_id_) {
               size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
 
-              // weight and scale TP-slicing for gate and up
+              // Source pointers from per-expert pointer arrays
+              // gate_projs[0][expert_id] since numa dimension is 1
+              uint8_t* src_gate = (uint8_t*)config.gate_projs[0][expert_id];
+              uint8_t* src_up = (uint8_t*)config.up_projs[0][expert_id];
+              uint8_t* src_down = (uint8_t*)config.down_projs[0][expert_id];
+              ggml_bf16_t* src_gate_scale = (ggml_bf16_t*)config.gate_scales[0][expert_id];
+              ggml_bf16_t* src_up_scale = (ggml_bf16_t*)config.up_scales[0][expert_id];
+              ggml_bf16_t* src_down_scale = (ggml_bf16_t*)config.down_scales[0][expert_id];
+
+              // TP-slicing for gate and up (row-major slicing)
               memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
-                     (uint8_t*)config.gate_proj +
-                         ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
-                     ((sizeof(uint8_t) * weight_elem_count) >> 1));
+                     src_gate + ((i * weight_elem_count) >> 1),
+                     (weight_elem_count >> 1));
 
               memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
-                     (uint8_t*)config.up_proj +
-                         ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
-                     ((sizeof(uint8_t) * weight_elem_count) >> 1));
+                     src_up + ((i * weight_elem_count) >> 1),
+                     (weight_elem_count >> 1));
 
               memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
-                     (ggml_bf16_t*)config.gate_scale +
-                         (expert_id * (config.hidden_size / group_size) * config.intermediate_size +
-                          i * scales_elem_count),
-                     sizeof(ggml_bf16_t) * scales_elem_count);
+                     src_gate_scale + (i * scales_elem_count),
+                     sizeof(ggml_bf16_t) * scales_elem_count);
 
               memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
-                     (ggml_bf16_t*)config.up_scale +
-                         (expert_id * (config.hidden_size / group_size) * config.intermediate_size +
-                          i * scales_elem_count),
-                     sizeof(ggml_bf16_t) * scales_elem_count);
+                     src_up_scale + (i * scales_elem_count),
+                     sizeof(ggml_bf16_t) * scales_elem_count);
 
-              // memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count) >> 1),
-              //        (uint8_t*)config.down_proj +
-              //            ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
-              //        ((sizeof(uint8_t) * weight_elem_count) >> 1));
-
-              // memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count),
-              //        (ggml_bf16_t*)config.down_scale +
-              //            (expert_id * (config.intermediate_size / group_size) * config.hidden_size +
-              //             i * scales_elem_count),
-              //        sizeof(ggml_bf16_t) * scales_elem_count);
-
-              // weight and scale TP-slicing for down (by column)
+              // TP-slicing for down (by column)
               for (size_t col = 0; col < config.hidden_size; col++) {
                 memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
-                       (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
-                                                      col * config.intermediate_size + i * tpc.intermediate_size) >>
-                                                     1),
-                       (sizeof(uint8_t) * tpc.intermediate_size) >> 1);
+                       src_down + ((col * config.intermediate_size + i * tpc.intermediate_size) >> 1),
+                       (tpc.intermediate_size >> 1));
                 memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
-                       (ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
-                                                          col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
-                       sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
+                       src_down_scale + (col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
+                       sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
               }
             },
            nullptr);
+        printf("TP %d load weight done.\n", i);
+      }
+    } else {
+      // Original path: load from contiguous memory with gate_proj/gate_scale
+      for (auto i = 0; i < tp_count; i++) {
+        auto& tpc = tps[i]->config_;
+        size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
+        tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
+        tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
+        tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
+
+        size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
+
+        tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
+        tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
+        tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
+
+        if (tps[i]->config_.load == false) {
+          pool->get_subpool(i)->do_work_stealing_job(
+              tpc.expert_num, nullptr,
+              [&](int expert_id_) {  // weight and scale are all in col majored.
+                size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
+
+                // weight and scale TP-slicing for gate and up
+                memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
+                       (uint8_t*)config.gate_proj +
+                           ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
+                       ((sizeof(uint8_t) * weight_elem_count) >> 1));
+
+                memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
+                       (uint8_t*)config.up_proj +
+                           ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
+                       ((sizeof(uint8_t) * weight_elem_count) >> 1));
+
+                memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
+                       (ggml_bf16_t*)config.gate_scale +
+                           (expert_id * (config.hidden_size / group_size) * config.intermediate_size +
+                            i * scales_elem_count),
+                       sizeof(ggml_bf16_t) * scales_elem_count);
+
+                memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
+                       (ggml_bf16_t*)config.up_scale +
+                           (expert_id * (config.hidden_size / group_size) * config.intermediate_size +
+                            i * scales_elem_count),
+                       sizeof(ggml_bf16_t) * scales_elem_count);
+
+                // weight and scale TP-slicing for down (by column)
+                for (size_t col = 0; col < config.hidden_size; col++) {
+                  memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
+                         (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
+                                                        col * config.intermediate_size + i * tpc.intermediate_size) >>
+                                                       1),
+                         (sizeof(uint8_t) * tpc.intermediate_size) >> 1);
+                  memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
+                         (ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
+                                                            col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
+                         sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
+                }
+              },
+              nullptr);
+        }
+        printf("TP %d load weight done.\n", i);
       }
-      printf("TP %d load weight done.\n", i);
     }
 
+#ifdef LOAD_TIME_PROFILE
+    {
+      auto load_now_time = std::chrono::high_resolution_clock::now();
+      alloc_and_tp_slice_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
+      load_last = load_now_time;
+    }
+#endif
+
     DO_TPS_LOAD_WEIGHTS(pool);
 
+#ifdef LOAD_TIME_PROFILE
+    {
+      auto load_now_time = std::chrono::high_resolution_clock::now();
+      tps_load_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
+      load_last = load_now_time;
+    }
+#endif
+
     for (auto i = 0; i < tp_count; i++) {
       auto& tpc = tps[i]->config_;
       delete[] (uint8_t*)(tpc.gate_proj);
@@ -857,6 +1159,18 @@ class TP_MOE> : public TP_MOE_Common> {
       delete[] (ggml_bf16_t*)(tpc.down_scale);
     }
 
+#ifdef LOAD_TIME_PROFILE
+    {
+      auto load_now_time = std::chrono::high_resolution_clock::now();
+      cleanup_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();
+    }
+    auto load_end_time = std::chrono::high_resolution_clock::now();
+    auto load_total_time = std::chrono::duration_cast<std::chrono::microseconds>(load_end_time - load_start_time).count();
+    printf(
+        "[K2 MoE Load Weights] tp_count: %d, alloc_and_tp_slice: %ld us, tps_load_weights: %ld us, cleanup: %ld us, total: %ld us\n",
+        tp_count, alloc_and_tp_slice_time, tps_load_time, cleanup_time, load_total_time);
+#endif
+
     this->weights_loaded = true;
   }
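The load_weights path above slices gate/up by giving each TP rank one contiguous block of intermediate rows, while down is sliced column by column. A toy numpy sketch of the same offset arithmetic (illustrative sizes and names, ignoring int4 packing and the scale buffers):

```python
import numpy as np

# Toy sizes; `inter` must be divisible by tp_count. Real buffers are packed int4.
tp_count, hidden, inter = 2, 8, 6
tp_inter = inter // tp_count

# gate/up: TP rank i takes one contiguous block of tp_inter * hidden elements,
# i.e. the `i * weight_elem_count` offset used in the memcpy above.
gate = np.arange(inter * hidden)
gate_slices = [gate[i * tp_inter * hidden:(i + 1) * tp_inter * hidden] for i in range(tp_count)]

# down: laid out per output column; rank i takes a tp_inter-wide stripe of every
# column, copied column by column exactly like the per-`col` memcpy loop above.
down = np.arange(hidden * inter).reshape(hidden, inter)
down_slices = [
    np.concatenate([down[col, i * tp_inter:(i + 1) * tp_inter] for col in range(hidden)])
    for i in range(tp_count)
]
```

The column-wise copy for down is why that slice cannot be a single memcpy per rank: each rank's stripe is strided across every column of the source buffer.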
diff --git a/kt-kernel/operators/amx/la/amx_kernels.hpp b/kt-kernel/operators/amx/la/amx_kernels.hpp
index a89d3dd..6e8673e 100644
--- a/kt-kernel/operators/amx/la/amx_kernels.hpp
+++ b/kt-kernel/operators/amx/la/amx_kernels.hpp
@@ -2905,7 +2905,7 @@ struct GemmKernel224Int4SmallKGroup {
   static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
     auto [n_start, n_end] = split_range_n(n, ith, nth);
     for (int m_begin = 0; m_begin < m; m_begin ++) {
-      float* c = bc->get_submat(m, n, m_begin, 0);
+      float* c = bc->get_submat(m, n, m_begin, n_start);
       __m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
 
       for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) {
@@ -2929,7 +2929,7 @@ struct GemmKernel224Int4SmallKGroup {
           WORK_K_BLOCK(k_block + 1);
         }
 
-        c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
+        c[n_block_begin - n_start] = _mm512_reduce_add_ps(sum) / 16;
      }
    }
  }
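The amx_kernels.hpp change above fixes the destination indexing in the k-group mat-vec: once the output pointer for thread `ith` is taken at `n_start`, stores must use indices relative to `n_start`. A small Python analogue of that indexing rule (toy code, not the AMX kernel itself):

```python
# Each worker owns the output range [n_start, n_end); its output buffer already
# "starts" at n_start, so the store index must be n_block_begin - n_start.
def mat_vec_slice(a, b_cols, n_start, n_end):
    out = [0.0] * (n_end - n_start)
    for n_block_begin in range(n_start, n_end):
        s = sum(x * y for x, y in zip(a, b_cols[n_block_begin]))
        out[n_block_begin - n_start] = s  # was out[n_block_begin]: wrong slot / out of range
    return out

# Example: two workers splitting n = 4 output columns of a 3-element mat-vec.
a = [1.0, 2.0, 3.0]
b_cols = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]
parts = [mat_vec_slice(a, b_cols, 0, 2), mat_vec_slice(a, b_cols, 2, 4)]
assert parts[0] + parts[1] == [1.0, 2.0, 3.0, 6.0]
```

With the old absolute index, every worker except the one owning `n_start == 0` would write past its own sub-block, which is exactly what the one-line fix addresses.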
diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py
index b36ba3d..055c02b 100644
--- a/kt-kernel/python/utils/amx.py
+++ b/kt-kernel/python/utils/amx.py
@@ -364,16 +364,34 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
 
     def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
+        import time
+
+        t0 = time.time()
         base_key = f"model.layers.{self.layer_idx}"
         weights = self.loader.load_experts(base_key)
+        t1 = time.time()
 
-        self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
-        self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
-        self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
+        # Keep individual tensors instead of stacking - avoid expensive memory copy
+        # weights["gate"], weights["up"], weights["down"] are lists of tensors per expert
+        self.gate_weights = weights["gate"]  # list of tensors
+        self.up_weights = weights["up"]
+        self.down_weights = weights["down"]
 
-        self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
-        self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
-        self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
+        # Convert scales to bf16 individually
+        self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
+        self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
+        self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
+        t2 = time.time()
+
+        # Build pointer lists: [numa_id][expert_id] -> pointer
+        # Since RAWINT4 has no numa sharding, numa dimension is 1
+        gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
+        up_ptrs = [[t.data_ptr() for t in self.up_weights]]
+        down_ptrs = [[t.data_ptr() for t in self.down_weights]]
+        gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
+        up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
+        down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
+        t3 = time.time()
 
         moe_config = MOEConfig(
             self.num_experts,
@@ -390,17 +408,20 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         moe_config.quant_config.group_size = 32
         moe_config.quant_config.zero_point = False
 
-        moe_config.gate_proj = self.gate_weights.data_ptr()
-        moe_config.up_proj = self.up_weights.data_ptr()
-        moe_config.down_proj = self.down_weights.data_ptr()
-        moe_config.gate_scale = self.gate_scales.data_ptr()
-        moe_config.up_scale = self.up_scales.data_ptr()
-        moe_config.down_scale = self.down_scales.data_ptr()
+        # Use gate_projs instead of gate_proj for per-expert pointers
+        moe_config.gate_projs = gate_ptrs
+        moe_config.up_projs = up_ptrs
+        moe_config.down_projs = down_ptrs
+        moe_config.gate_scales = gate_scale_ptrs
+        moe_config.up_scales = up_scale_ptrs
+        moe_config.down_scales = down_scale_ptrs
 
         self.moe = AMXInt4_KGroup_MOE(moe_config)
+        t4 = time.time()
 
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
         self.cpu_infer.sync()
+        t5 = time.time()
 
         del self.gate_weights
         del self.up_weights
@@ -408,6 +429,18 @@
         del self.gate_scales
         del self.up_scales
         del self.down_scales
+        t6 = time.time()
+
+        print(
+            f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
+            f"load_experts: {(t1-t0)*1000:.1f}ms, "
+            f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
+            f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
+            f"create_moe: {(t4-t3)*1000:.1f}ms, "
+            f"cpp_load_weights: {(t5-t4)*1000:.1f}ms, "
+            f"cleanup: {(t6-t5)*1000:.1f}ms, "
+            f"total: {(t6-t0)*1000:.1f}ms"
+        )
 
     def submit_write_weight_scale_to_buffer(
         self,
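As a rough sanity check for the new loading path (hypothetical tensors, not the real safetensor loader), the per-expert list plus the `[numa_id][expert_id]` pointer table should describe exactly the same data that `torch.stack()` used to build:

```python
import torch

# Hypothetical expert weights; the real ones come from loader.load_experts().
gate_list = [torch.randn(16, 8).to(torch.bfloat16).contiguous() for _ in range(4)]

# Equivalence with the old stacked layout: row `eid` of the stack is expert `eid`.
stacked = torch.stack(gate_list, dim=0).contiguous()
for eid, t in enumerate(gate_list):
    assert torch.equal(stacked[eid], t)

# Pointer table as consumed via MOEConfig above: indexed [numa_id][expert_id],
# with a single NUMA shard for RAWINT4.
gate_ptrs = [[t.data_ptr() for t in gate_list]]
assert len(gate_ptrs) == 1 and len(gate_ptrs[0]) == len(gate_list)
```

The C++ side then reads each expert directly from its own buffer, so the only remaining copies are the parallel per-expert memcpys done by the worker pool during TP slicing.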