From 06fb3b5dbfbe1f9c19eea271fe91936c646b6464 Mon Sep 17 00:00:00 2001
From: mrhaoxx
Date: Sun, 1 Feb 2026 15:17:47 +0000
Subject: [PATCH] [fix]: prequant weight load

---
Notes: this patch was recovered from a copy whose line breaks and leading
indentation had been collapsed. Indentation of context and added lines is
reconstructed, so apply with `git apply --ignore-whitespace` (or `patch -l`).
The element type of the three temporary vectors (`ggml_bf16_t*`) was restored
from their use: each slot receives `new ggml_bf16_t[...]`.

 kt-kernel/operators/moe-sft-tp.hpp | 58 +++++++++++++++++++++++++++++-
 kt-kernel/python/utils/amx_sft.py  |  7 ++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/kt-kernel/operators/moe-sft-tp.hpp b/kt-kernel/operators/moe-sft-tp.hpp
index 2c240de2..7f57a6ec 100644
--- a/kt-kernel/operators/moe-sft-tp.hpp
+++ b/kt-kernel/operators/moe-sft-tp.hpp
@@ -211,8 +211,64 @@ class TP_MOE_SFT : public TP_MOE {
       // Sub-MOE accesses gate_projs[tp_part_idx] where tp_part_idx == numa_id.
       printf("TP_MOE_SFT: Pre-quantized per-NUMA mode (gate_projs path)\n");
       pool->dispense_backend()->do_numa_job([this](int numa_id) { tps[numa_id]->load_weights(); });
+
+      // Also partition BF16 weights for backward gradient computation if available.
+      // C++ backward needs BF16 base weights to compute gate/up LoRA B gradients
+      // through the gated MLP chain (prepare_backward_weights checks config_.gate_proj).
+      if (config.gate_proj != nullptr) {
+        std::vector<ggml_bf16_t*> temp_gate(tp_count);
+        std::vector<ggml_bf16_t*> temp_up(tp_count);
+        std::vector<ggml_bf16_t*> temp_down(tp_count);
+
+        for (int i = 0; i < tp_count; i++) {
+          auto& tpc = tp_configs[i];
+          size_t gate_up_elcount = (size_t)tpc.intermediate_size * tpc.hidden_size;
+
+          temp_gate[i] = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];
+          temp_up[i] = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];
+          temp_down[i] = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];
+
+          pool->get_subpool(i)->do_work_stealing_job(
+              tpc.expert_num, nullptr,
+              [&, i, gate_up_elcount](int expert_id_) {
+                size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
+
+                memcpy(temp_gate[i] + expert_id * gate_up_elcount,
+                       (ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +
+                           i * gate_up_elcount,
+                       sizeof(ggml_bf16_t) * gate_up_elcount);
+
+                memcpy(temp_up[i] + expert_id * gate_up_elcount,
+                       (ggml_bf16_t*)config.up_proj + expert_id * config.intermediate_size * config.hidden_size +
+                           i * gate_up_elcount,
+                       sizeof(ggml_bf16_t) * gate_up_elcount);
+
+                for (size_t col = 0; col < config.hidden_size; col++) {
+                  memcpy(temp_down[i] + expert_id * tpc.hidden_size * tpc.intermediate_size + col * tpc.intermediate_size,
+                         (ggml_bf16_t*)config.down_proj + expert_id * config.intermediate_size * config.hidden_size +
+                             col * config.intermediate_size + i * tpc.intermediate_size,
+                         sizeof(ggml_bf16_t) * tpc.intermediate_size);
+                }
+              },
+              nullptr);
+        }
+
+        // Set BF16 weight pointers on sub-MOEs for backward
+        for (int i = 0; i < tp_count; i++) {
+          tps[i]->set_base_weight_pointers(temp_gate[i], temp_up[i], temp_down[i]);
+        }
+
+        // Save partitioned BF16 weights for backward pass lifetime
+        partitioned_gate_proj_.resize(tp_count);
+        partitioned_up_proj_.resize(tp_count);
+        partitioned_down_proj_.resize(tp_count);
+        for (int i = 0; i < tp_count; i++) {
+          partitioned_gate_proj_[i] = temp_gate[i];
+          partitioned_up_proj_[i] = temp_up[i];
+          partitioned_down_proj_[i] = temp_down[i];
+        }
+      }
     } else if (is_k2_prequantized) {
-      printf("TP_MOE_SFT: K2 pre-quantized mode (no BF16 partitioning)\n");
       // For K2, weights are already int4-packed with scales
       // tp_configs[i] already has all pointers from config (copied in TP_MOE constructor)
       if (tp_count == 1) {
diff --git a/kt-kernel/python/utils/amx_sft.py b/kt-kernel/python/utils/amx_sft.py
index 89b5cf3b..b4c678a4 100644
--- a/kt-kernel/python/utils/amx_sft.py
+++ b/kt-kernel/python/utils/amx_sft.py
@@ -219,6 +219,13 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
             config.gate_scales = self._gate_scale_ptrs
             config.up_scales = self._up_scale_ptrs
             config.down_scales = self._down_scale_ptrs
+            # Also provide BF16 weight pointers for backward gradient computation.
+            # C++ backward needs BF16 base weights to compute gate/up LoRA B gradients
+            # through the gated MLP chain (grad_hidden = down_proj^T @ grad_output).
+            if getattr(self, "_bf16_gate_proj", None) is not None:
+                config.gate_proj = self._bf16_gate_proj.data_ptr()
+                config.up_proj = self._bf16_up_proj.data_ptr()
+                config.down_proj = self._bf16_down_proj.data_ptr()
         else:
             # Flat BF16 buffer path
             config.gate_proj = self.gate_proj.data_ptr()