From 07fd9328fa8dc9ffc7f90bce81cd178da13abac4 Mon Sep 17 00:00:00 2001 From: mrhaoxx Date: Wed, 8 Apr 2026 23:07:41 +0800 Subject: [PATCH] refactor(sft): move SFT logic into kt_kernel.sft submodule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create python/sft/ with 11 modules: base, amx, arch, autograd, layer, lora, weights, wrapper, dist_utils, config, __init__ - Move BaseSFTMoEWrapper + buffer management into sft/base.py (template method pattern: subclass provides _make_forward/backward_task) - Move AMXSFTMoEWrapper into sft/amx.py (thinner, no buffer logic) - Move from accelerate kt_moe.py: KTMoEFunction, KTMoELayerWrapper, MOEArchConfig, PEFT LoRA adaptation, weight extraction, wrapping - Add KTConfig dataclass (DeepSpeed pattern: opaque config passthrough) - Add _get_kt_config() with old→new field name compat conversion - Rename forward_sft→forward, submit_forward_sft→submit_forward, sync_forward_sft→sync_forward (Python only, C++ binding names unchanged) - Delete dump utilities from sft_moe.hpp (-526) and moe-sft-tp.hpp (-78) - Delete experts_sft.py and utils/amx_sft.py (moved to sft/) - Remove SFT stubs from BaseMoEWrapper (experts_base.py) - Lazy SFT import in __init__.py and experts.py (inference isolation) - Delete all lifecycle/debug logging (~500 lines) Verified: Qwen3-235B 4GPU AMXBF16 training, 3 steps loss converges. --- kt-kernel/examples/compare_tp_dumps.py | 2 +- kt-kernel/examples/test_moe_sft_tp_debug.py | 16 +- kt-kernel/examples/test_moe_sft_wrapper.py | 10 +- kt-kernel/examples/test_skip_lora.py | 4 +- kt-kernel/operators/amx/sft_moe.hpp | 526 +-------- kt-kernel/operators/moe-sft-tp.hpp | 78 -- kt-kernel/pyproject.toml | 2 + kt-kernel/python/__init__.py | 12 +- kt-kernel/python/experts.py | 14 +- kt-kernel/python/experts_base.py | 28 - kt-kernel/python/experts_sft.py | 431 ------- kt-kernel/python/sft/__init__.py | 83 ++ kt-kernel/python/sft/amx.py | 434 +++++++ kt-kernel/python/sft/arch.py | 265 +++++ kt-kernel/python/sft/autograd.py | 256 ++++ kt-kernel/python/sft/base.py | 402 +++++++ kt-kernel/python/sft/config.py | 124 ++ kt-kernel/python/sft/dist_utils.py | 184 +++ kt-kernel/python/sft/layer.py | 407 +++++++ kt-kernel/python/sft/lora.py | 688 +++++++++++ kt-kernel/python/sft/weights.py | 488 ++++++++ kt-kernel/python/sft/wrapper.py | 610 ++++++++++ kt-kernel/python/utils/amx_sft.py | 1162 ------------------- 23 files changed, 3975 insertions(+), 2251 deletions(-) delete mode 100644 kt-kernel/python/experts_sft.py create mode 100644 kt-kernel/python/sft/__init__.py create mode 100644 kt-kernel/python/sft/amx.py create mode 100644 kt-kernel/python/sft/arch.py create mode 100644 kt-kernel/python/sft/autograd.py create mode 100644 kt-kernel/python/sft/base.py create mode 100644 kt-kernel/python/sft/config.py create mode 100644 kt-kernel/python/sft/dist_utils.py create mode 100644 kt-kernel/python/sft/layer.py create mode 100644 kt-kernel/python/sft/lora.py create mode 100644 kt-kernel/python/sft/weights.py create mode 100644 kt-kernel/python/sft/wrapper.py delete mode 100644 kt-kernel/python/utils/amx_sft.py diff --git a/kt-kernel/examples/compare_tp_dumps.py b/kt-kernel/examples/compare_tp_dumps.py index b8dbb73d..17be5609 100644 --- a/kt-kernel/examples/compare_tp_dumps.py +++ b/kt-kernel/examples/compare_tp_dumps.py @@ -766,7 +766,7 @@ def run_cpp_forward_with_dump(wrapper, input_tensor, expert_ids, routing_weights os.environ["SFT_MOE_DUMP_DIR"] = dump_dir # Run forward with save_for_backward=True to enable backward - output = wrapper.forward_sft(input_tensor, expert_ids, routing_weights, save_for_backward=True) + output = wrapper.forward(input_tensor, expert_ids, routing_weights, save_for_backward=True) # Clean up environment del os.environ["SFT_MOE_DUMP"] diff --git a/kt-kernel/examples/test_moe_sft_tp_debug.py b/kt-kernel/examples/test_moe_sft_tp_debug.py index 92c35c12..0d30dce2 100644 --- a/kt-kernel/examples/test_moe_sft_tp_debug.py +++ b/kt-kernel/examples/test_moe_sft_tp_debug.py @@ -33,7 +33,7 @@ from typing import Dict, List, Optional, Tuple # Try to import kt_kernel try: from kt_kernel.experts import KTMoEWrapper - from kt_kernel.experts_sft import KExpertsSFTBuffer, BaseSFTMoEWrapper + from kt_kernel.sft.base import KExpertsSFTBuffer, BaseSFTMoEWrapper HAS_KT_KERNEL = True except ImportError: @@ -41,7 +41,7 @@ except ImportError: # Alternative import path (for development) sys.path.insert(0, os.path.dirname(__file__) + "/../python") from experts import KTMoEWrapper - from experts_sft import KExpertsSFTBuffer, BaseSFTMoEWrapper + from kt_kernel.sft.base import KExpertsSFTBuffer, BaseSFTMoEWrapper HAS_KT_KERNEL = True except ImportError as e: @@ -1460,7 +1460,7 @@ def test_tp_vs_cpp_wrapper(quant_mode: str = "AMXBF16_SFT", tp_count: int = TP_C py_output, py_intermediates = simulator.forward_moe(input_data, expert_ids, weights, dump_intermediates=True) # C++ wrapper forward - cpp_output = wrapper.forward_sft(input_data, expert_ids, weights, save_for_backward=False) + cpp_output = wrapper.forward(input_data, expert_ids, weights, save_for_backward=False) # Compare results diff = torch.mean(torch.abs(cpp_output - py_output)) / (torch.mean(torch.abs(py_output)) + 1e-8) @@ -1585,8 +1585,8 @@ def test_tp_vs_no_tp_cpp(quant_mode: str = "AMXBF16_SFT"): input_data = torch.randn((test_qlen, test_hidden_size), dtype=torch.bfloat16).contiguous() / 100 # Forward passes - output_tp = wrapper_tp.forward_sft(input_data, expert_ids, weights, save_for_backward=False) - output_no_tp = wrapper_no_tp.forward_sft(input_data, expert_ids, weights, save_for_backward=False) + output_tp = wrapper_tp.forward(input_data, expert_ids, weights, save_for_backward=False) + output_no_tp = wrapper_no_tp.forward(input_data, expert_ids, weights, save_for_backward=False) # Compare diff = torch.mean(torch.abs(output_tp - output_no_tp)) / (torch.mean(torch.abs(output_no_tp)) + 1e-8) @@ -1838,7 +1838,7 @@ def test_tp_backward_vs_cpp(quant_mode: str = "AMXBF16_SFT", tp_count: int = TP_ grad_output = torch.randn((test_qlen, test_hidden_size), dtype=torch.bfloat16).contiguous() / 100 # C++ forward (with save_for_backward=True) - cpp_output = wrapper.forward_sft(input_data, expert_ids, weights, save_for_backward=True) + cpp_output = wrapper.forward(input_data, expert_ids, weights, save_for_backward=True) # C++ backward cpp_grad_input, cpp_grad_loras = wrapper.backward(grad_output) @@ -2042,7 +2042,7 @@ def test_comprehensive_backward_with_dump( # C++ Forward Pass print("\n[Running C++ Forward Pass]") - cpp_output = wrapper.forward_sft(input_tensor, expert_ids, routing_weights, save_for_backward=True) + cpp_output = wrapper.forward(input_tensor, expert_ids, routing_weights, save_for_backward=True) cpp_fwd_has_nan = check_nan(cpp_output, "cpp_forward_output") # C++ Backward Pass @@ -2252,7 +2252,7 @@ def test_moe_backward_full(quant_mode: str = "AMXBF16_SFT", tp_count: int = TP_C # C++ forward + backward print("\n[Running C++ Forward + Backward]") - cpp_output = wrapper.forward_sft(input_tensor, expert_ids, routing_weights, save_for_backward=True) + cpp_output = wrapper.forward(input_tensor, expert_ids, routing_weights, save_for_backward=True) cpp_grad_input, cpp_grad_loras = wrapper.backward(output_grad) # PyTorch forward + backward using non-TP reference diff --git a/kt-kernel/examples/test_moe_sft_wrapper.py b/kt-kernel/examples/test_moe_sft_wrapper.py index 58f7e20a..f2938401 100644 --- a/kt-kernel/examples/test_moe_sft_wrapper.py +++ b/kt-kernel/examples/test_moe_sft_wrapper.py @@ -24,7 +24,7 @@ import torch.nn.functional as F # Try to import kt_kernel try: from kt_kernel.experts import KTMoEWrapper - from kt_kernel.experts_sft import KExpertsSFTBuffer, BaseSFTMoEWrapper + from kt_kernel.sft.base import KExpertsSFTBuffer, BaseSFTMoEWrapper HAS_KT_KERNEL = True except ImportError: @@ -32,7 +32,7 @@ except ImportError: # Alternative import path (for development) sys.path.insert(0, os.path.dirname(__file__) + "/../python") from experts import KTMoEWrapper - from experts_sft import KExpertsSFTBuffer, BaseSFTMoEWrapper + from kt_kernel.sft.base import KExpertsSFTBuffer, BaseSFTMoEWrapper HAS_KT_KERNEL = True except ImportError as e: @@ -513,7 +513,7 @@ def test_wrapper_forward(quant_mode: str = "AMXBF16_SFT", tp_count: int = TP_COU ) # Wrapper forward - output = wrapper.forward_sft(input_data, expert_ids, weights, save_for_backward=False) + output = wrapper.forward(input_data, expert_ids, weights, save_for_backward=False) # Compare results diff = torch.mean(torch.abs(output - torch_output)) / (torch.mean(torch.abs(torch_output)) + 1e-8) @@ -639,7 +639,7 @@ def test_wrapper_backward(quant_mode: str = "AMXBF16_SFT", tp_count: int = TP_CO ) # Wrapper forward (with save_for_backward=True) - output = wrapper.forward_sft(input_data, expert_ids, weights, save_for_backward=True) + output = wrapper.forward(input_data, expert_ids, weights, save_for_backward=True) # Wrapper backward grad_input, grad_loras = wrapper.backward(grad_output) @@ -794,7 +794,7 @@ def test_wrapper_training_loop(quant_mode: str = "AMXBF16_SFT", tp_count: int = target = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100 # Forward pass - output = wrapper.forward_sft(input_data, expert_ids, weights, save_for_backward=True) + output = wrapper.forward(input_data, expert_ids, weights, save_for_backward=True) # Compute loss loss = torch.mean((output.float() - target.float()) ** 2) diff --git a/kt-kernel/examples/test_skip_lora.py b/kt-kernel/examples/test_skip_lora.py index 1e58e897..56e1879e 100644 --- a/kt-kernel/examples/test_skip_lora.py +++ b/kt-kernel/examples/test_skip_lora.py @@ -429,8 +429,8 @@ def test_skip_lora(tp_count, threshold): # Run forward on both print("\n[7] Running C++ forward...") - output_normal = wrapper_normal.forward_sft(input_tensor, expert_ids, routing_weights, save_for_backward=True) - output_skip = wrapper_skip.forward_sft(input_tensor, expert_ids, routing_weights, save_for_backward=True) + output_normal = wrapper_normal.forward(input_tensor, expert_ids, routing_weights, save_for_backward=True) + output_skip = wrapper_skip.forward(input_tensor, expert_ids, routing_weights, save_for_backward=True) print(f" Normal forward output mean: {output_normal.float().mean():.6e}") print(f" SkipLoRA forward output mean: {output_skip.float().mean():.6e}") diff --git a/kt-kernel/operators/amx/sft_moe.hpp b/kt-kernel/operators/amx/sft_moe.hpp index dd65bdf5..bbd8eeea 100644 --- a/kt-kernel/operators/amx/sft_moe.hpp +++ b/kt-kernel/operators/amx/sft_moe.hpp @@ -158,197 +158,6 @@ inline bool is_nan_check_enabled() { return enabled == 1; } -// ===================================================== -// Dump Utility Functions for debugging -// Controlled by SFT_MOE_DUMP environment variable -// ===================================================== -inline bool is_dump_enabled() { - return false; - static int enabled = -1; - if (enabled < 0) { - const char* env = getenv("SFT_MOE_DUMP"); - enabled = (env && env[0] != '0') ? 1 : 0; - } - return enabled == 1; -} - -inline const char* get_dump_dir() { - static const char* dir = nullptr; - if (dir == nullptr) { - dir = getenv("SFT_MOE_DUMP_DIR"); - if (dir == nullptr) { - dir = "./cpp_dump"; - } - } - return dir; -} - -// Dump BF16 matrix to binary file (format: rows(int32), cols(int32), data(float32)) -// tp_idx: TP partition index (-1 for no TP suffix) -// expert_id: Expert index (-1 for no expert suffix) -inline void dump_bf16_matrix(const ggml_bf16_t* data, int rows, int cols, const char* name, int tp_idx = -1, - int expert_id = -1) { - if (!is_dump_enabled()) return; - - char filename[512]; - if (tp_idx >= 0 && expert_id >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_tp%d_e%d.bin", get_dump_dir(), name, tp_idx, expert_id); - } else if (tp_idx >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_tp%d.bin", get_dump_dir(), name, tp_idx); - } else if (expert_id >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_e%d.bin", get_dump_dir(), name, expert_id); - } else { - snprintf(filename, sizeof(filename), "%s/%s.bin", get_dump_dir(), name); - } - - // Create directory if needed - char mkdir_cmd[600]; - snprintf(mkdir_cmd, sizeof(mkdir_cmd), "mkdir -p %s", get_dump_dir()); - system(mkdir_cmd); - - FILE* f = fopen(filename, "wb"); - if (!f) { - printf("[DUMP ERROR] Cannot open file: %s\n", filename); - return; - } - - // Write header - int32_t dims[2] = {rows, cols}; - fwrite(dims, sizeof(int32_t), 2, f); - - // Convert BF16 to FP32 and write - for (int i = 0; i < rows * cols; i++) { - float val = GGML_BF16_TO_FP32(data[i]); - fwrite(&val, sizeof(float), 1, f); - } - - fclose(f); - printf("[CPP DUMP] Saved %s: [%d x %d]\n", filename, rows, cols); -} - -// Dump BF16 matrix with scaling factor (for LoRA contributions that need lora_scaling applied) -inline void dump_bf16_matrix_scaled(const ggml_bf16_t* data, int rows, int cols, float scale, const char* name, - int tp_idx = -1, int expert_id = -1) { - if (!is_dump_enabled()) return; - - char filename[512]; - if (tp_idx >= 0 && expert_id >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_tp%d_e%d.bin", get_dump_dir(), name, tp_idx, expert_id); - } else if (tp_idx >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_tp%d.bin", get_dump_dir(), name, tp_idx); - } else if (expert_id >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_e%d.bin", get_dump_dir(), name, expert_id); - } else { - snprintf(filename, sizeof(filename), "%s/%s.bin", get_dump_dir(), name); - } - - // Create directory if needed - char mkdir_cmd[600]; - snprintf(mkdir_cmd, sizeof(mkdir_cmd), "mkdir -p %s", get_dump_dir()); - system(mkdir_cmd); - - FILE* f = fopen(filename, "wb"); - if (!f) { - printf("[DUMP ERROR] Cannot open file: %s\n", filename); - return; - } - - // Write header - int32_t dims[2] = {rows, cols}; - fwrite(dims, sizeof(int32_t), 2, f); - - // Convert BF16 to FP32, apply scale, and write - for (int i = 0; i < rows * cols; i++) { - float val = GGML_BF16_TO_FP32(data[i]) * scale; - fwrite(&val, sizeof(float), 1, f); - } - - fclose(f); - printf("[CPP DUMP] Saved %s: [%d x %d] (scaled by %.2f)\n", filename, rows, cols, scale); -} - -// Dump FP32 matrix to binary file -inline void dump_fp32_matrix(const float* data, int rows, int cols, const char* name, int tp_idx = -1, - int expert_id = -1) { - if (!is_dump_enabled()) return; - - char filename[512]; - if (tp_idx >= 0 && expert_id >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_tp%d_e%d.bin", get_dump_dir(), name, tp_idx, expert_id); - } else if (tp_idx >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_tp%d.bin", get_dump_dir(), name, tp_idx); - } else if (expert_id >= 0) { - snprintf(filename, sizeof(filename), "%s/%s_e%d.bin", get_dump_dir(), name, expert_id); - } else { - snprintf(filename, sizeof(filename), "%s/%s.bin", get_dump_dir(), name); - } - - // Create directory if needed - char mkdir_cmd[600]; - snprintf(mkdir_cmd, sizeof(mkdir_cmd), "mkdir -p %s", get_dump_dir()); - system(mkdir_cmd); - - FILE* f = fopen(filename, "wb"); - if (!f) { - printf("[DUMP ERROR] Cannot open file: %s\n", filename); - return; - } - - // Write header - int32_t dims[2] = {rows, cols}; - fwrite(dims, sizeof(int32_t), 2, f); - - // Write data - fwrite(data, sizeof(float), rows * cols, f); - - fclose(f); - printf("[CPP DUMP] Saved %s: [%d x %d]\n", filename, rows, cols); -} - -// Dump routing info to binary file -inline void dump_routing_info(int qlen, int k, const int64_t* expert_ids, const float* weights, int num_experts, - const std::vector& m_local_num, int tp_idx = -1) { - if (!is_dump_enabled()) return; - - char filename[512]; - if (tp_idx >= 0) { - snprintf(filename, sizeof(filename), "%s/routing_info_tp%d.bin", get_dump_dir(), tp_idx); - } else { - snprintf(filename, sizeof(filename), "%s/routing_info.bin", get_dump_dir()); - } - - // Create directory if needed - char mkdir_cmd[600]; - snprintf(mkdir_cmd, sizeof(mkdir_cmd), "mkdir -p %s", get_dump_dir()); - system(mkdir_cmd); - - FILE* f = fopen(filename, "wb"); - if (!f) { - printf("[DUMP ERROR] Cannot open file: %s\n", filename); - return; - } - - // Write qlen, k - int32_t dims[2] = {qlen, k}; - fwrite(dims, sizeof(int32_t), 2, f); - - // Write expert_ids [qlen * k] - fwrite(expert_ids, sizeof(int64_t), qlen * k, f); - - // Write weights [qlen * k] - fwrite(weights, sizeof(float), qlen * k, f); - - // Write num_experts and m_local_num - int32_t ne = num_experts; - fwrite(&ne, sizeof(int32_t), 1, f); - for (int i = 0; i < num_experts; i++) { - int32_t cnt = m_local_num[i]; - fwrite(&cnt, sizeof(int32_t), 1, f); - } - - fclose(f); - printf("[CPP DUMP] Saved %s: qlen=%d, k=%d\n", filename, qlen, k); -} // ===================================================== // Pool Memory Logger — writes per-call alloc/free events to file @@ -1289,18 +1098,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, "fwd_pack_input", 1); - // DUMP: Routing info and packed input - if (is_dump_enabled()) { - dump_routing_info(qlen, k, expert_ids, weights, config_.expert_num, m_local_num_, tp_part_idx); - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - if (m_local_num_[expert_idx] > 0) { - dump_bf16_matrix(m_local_input_ptr_[expert_idx], m_local_num_[expert_idx], config_.hidden_size, - "packed_input", tp_part_idx, expert_idx); - } - } - } - // NaN Check: Step 3 - Packed input if (is_nan_check_enabled()) { for (int i = 0; i < activated_expert; i++) { @@ -1342,19 +1139,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, nullptr, "fwd_gate_up_gemm", 1); - // DUMP: Gate/Up base output (before LoRA) - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - if (m_local_num_[expert_idx] > 0) { - dump_bf16_matrix(m_local_gate_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.intermediate_size, - "gate_base_output", tp_part_idx, expert_idx); - dump_bf16_matrix(m_local_up_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.intermediate_size, - "up_base_output", tp_part_idx, expert_idx); - } - } - } - // NaN Check: Step 5 - Gate/Up GEMM output (before LoRA) if (is_nan_check_enabled()) { for (int i = 0; i < activated_expert; i++) { @@ -1378,21 +1162,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { compute_lora_gate_up(qlen, activated_expert); } - // DUMP: Gate/Up output (after LoRA, before activation) - if (is_dump_enabled() && gate_lora_a_ != nullptr) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - if (m_local_num_[expert_idx] > 0) { - // Note: After LoRA, gate/up outputs have been updated in-place - // These now include base + lora - dump_bf16_matrix(m_local_gate_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.intermediate_size, - "gate_lora_output", tp_part_idx, expert_idx); - dump_bf16_matrix(m_local_up_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.intermediate_size, - "up_lora_output", tp_part_idx, expert_idx); - } - } - } - // NaN Check: Step 5.5 - Gate/Up output (after LoRA) if (is_nan_check_enabled()) { for (int i = 0; i < activated_expert; i++) { @@ -1457,18 +1226,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } - // DUMP: Activation input (gate_out and up_out before activation) - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - if (m_local_num_[expert_idx] > 0) { - dump_bf16_matrix(m_local_gate_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.intermediate_size, - "activation_input_gate", tp_part_idx, expert_idx); - dump_bf16_matrix(m_local_up_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.intermediate_size, - "activation_input_up", tp_part_idx, expert_idx); - } - } - } + // Step 6: Activation (silu(gate) * up) { @@ -1478,18 +1236,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { sft_timer::add_kernel_trace("apply_activation", act_start, act_end, tp_part_idx, 0); } - // DUMP: Activation output (silu(gate) * up) - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - if (m_local_num_[expert_idx] > 0) { - // After activation, result is stored in m_local_gate_output_ptr_ - dump_bf16_matrix(m_local_gate_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.intermediate_size, - "activation_output", tp_part_idx, expert_idx); - } - } - } - // NaN Check: Step 6 - Activation output (silu(gate) * up) if (is_nan_check_enabled()) { for (int i = 0; i < activated_expert; i++) { @@ -1563,17 +1309,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, nullptr, "fwd_down_gemm", 1); - // DUMP: Down base output (before LoRA) - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - if (m_local_num_[expert_idx] > 0) { - dump_bf16_matrix(m_local_down_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.hidden_size, - "down_base_output", tp_part_idx, expert_idx); - } - } - } - // NaN Check: Step 8 - Down GEMM output (before LoRA) if (is_nan_check_enabled()) { for (int i = 0; i < activated_expert; i++) { @@ -1594,20 +1329,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { compute_lora_down(qlen, activated_expert, cache_ptr); } - // DUMP: Down output (after LoRA, before merge) - if (is_dump_enabled() && down_lora_a_ != nullptr) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - if (m_local_num_[expert_idx] > 0) { - dump_bf16_matrix(m_local_down_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.hidden_size, - "down_lora_output", tp_part_idx, expert_idx); - // down_total_output is same as down_lora_output (lora is added in-place) - dump_bf16_matrix(m_local_down_output_ptr_[expert_idx], m_local_num_[expert_idx], config_.hidden_size, - "down_total_output", tp_part_idx, expert_idx); - } - } - } - // NaN Check: Step 8.5 - Down output (after LoRA) if (is_nan_check_enabled()) { for (int i = 0; i < activated_expert; i++) { @@ -1654,12 +1375,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, nullptr, "fwd_merge"); - // DUMP: Final output (after weighted merge) - // Note: Each TP partition outputs a partial result that gets summed later - if (is_dump_enabled()) { - dump_fp32_matrix((const float*)output, qlen, config_.hidden_size, "final_output", tp_part_idx); - } - // NaN Check: Step 9 - Final output (after weighted merge) if (is_nan_check_enabled()) { char label[128]; @@ -3708,21 +3423,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, nullptr, "fwd_lora_gu_a"); - // DUMP: LoRA intermediate (input @ lora_A^T) for gate and up - // Note: Use padded_lora_rank_ as stride since to_mat writes with this stride - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - dump_bf16_matrix(lora_gate_intermediate_ptr_[expert_idx], m, padded_lora_rank_, "gate_lora_intermediate", - tp_part_idx, expert_idx); - dump_bf16_matrix(lora_up_intermediate_ptr_[expert_idx], m, padded_lora_rank_, "up_lora_intermediate", - tp_part_idx, expert_idx); - } - } - } - // ===================================================== // Step 2: Quantize lora_intermediate to BufferA // Need to quantize BOTH gate and up intermediates separately @@ -3745,10 +3445,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { // Step 3a: lora_intermediate @ lora_B^T -> lora_output (GEMM only) // ===================================================== nth = T::recommended_nth(config_.intermediate_size); - if (is_dump_enabled()) { - printf("[DEBUG] Step 3a GEMM: nth=%d, activated_expert=%d, total_tasks=%d\n", nth, activated_expert, - nth * activated_expert * 2); - } pool->do_work_stealing_job( nth * activated_expert * 2, [](int _) { T::config(); }, [this, nth](int task_id2) { @@ -3765,73 +3461,11 @@ class AMX_SFT_MOE_TP : public BaseMOE { auto& bb = do_up ? up_lora_b_bb_[expert_idx] : gate_lora_b_bb_[expert_idx]; auto& bc = do_up ? lora_up_out_bc_[expert_idx] : lora_gate_out_bc_[expert_idx]; - if (is_dump_enabled() && !do_up && expert_idx == 0) { - printf("[DEBUG] GEMM task START: expert=%d, ith=%d, nth=%d, m=%d, n=%d, k=%d\n", expert_idx, ith, nth, m, - config_.intermediate_size, padded_lora_rank_); - } - // GEMM: [m, padded_lora_rank] @ [intermediate_size, padded_lora_rank]^T -> [m, intermediate_size] amx::mat_mul(m, config_.intermediate_size, padded_lora_rank_, ba, bb, bc, ith, nth); - - if (is_dump_enabled() && !do_up && expert_idx == 0) { - // Check raw BufferC data immediately after this GEMM task - float* raw_c = bc->get_submat(m, config_.intermediate_size, 0, ith * T::N_BLOCK); - printf("[DEBUG] GEMM task DONE: expert=%d, ith=%d, raw_c[0]=%.6f, raw_c[1]=%.6f\n", expert_idx, ith, - raw_c[0], raw_c[1]); - } }, nullptr, "fwd_lora_gu_gemm"); - // DUMP: Pure gate/up LoRA GEMM output (before scaling and add) - // Note: to_mat with (ith, nth) only reads one N_BLOCK chunk, so we need to loop - if (is_dump_enabled()) { - int dump_nth = T::recommended_nth(config_.intermediate_size); - printf("[DEBUG] gate/up GEMM dump: intermediate_size=%d, N_BLOCK=%d, dump_nth=%d\n", config_.intermediate_size, - T::N_BLOCK, dump_nth); - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - printf("[DEBUG] expert=%d, m=%d, BufferC.n=%d\n", expert_idx, m, lora_gate_out_bc_[expert_idx]->n); - // Convert BufferC to FP32 and dump for gate - std::vector gate_lora_fp32(m * config_.intermediate_size); - std::vector gate_lora_bf16(m * config_.intermediate_size); - // Initialize to a known pattern to detect if to_mat writes anything - for (size_t idx = 0; idx < gate_lora_bf16.size(); idx++) { - gate_lora_bf16[idx] = GGML_FP32_TO_BF16(999.0f); - } - for (int ith = 0; ith < dump_nth; ith++) { - printf("[DEBUG] calling to_mat with ith=%d, dump_nth=%d\n", ith, dump_nth); - lora_gate_out_bc_[expert_idx]->to_mat(m, gate_lora_bf16.data(), ith, dump_nth); - // Check what was written - float val_at_0 = GGML_BF16_TO_FP32(gate_lora_bf16[0]); - float val_at_256 = GGML_BF16_TO_FP32(gate_lora_bf16[256]); - float val_at_512 = (m > 1) ? GGML_BF16_TO_FP32(gate_lora_bf16[512]) : 0; - float val_at_768 = (m > 1) ? GGML_BF16_TO_FP32(gate_lora_bf16[768]) : 0; - printf("[DEBUG] after ith=%d: buf[0]=%.6f, buf[256]=%.6f, buf[512]=%.6f, buf[768]=%.6f\n", ith, val_at_0, - val_at_256, val_at_512, val_at_768); - } - for (int j = 0; j < m * config_.intermediate_size; j++) { - gate_lora_fp32[j] = GGML_BF16_TO_FP32(gate_lora_bf16[j]); - } - dump_fp32_matrix(gate_lora_fp32.data(), m, config_.intermediate_size, "gate_lora_gemm_output", tp_part_idx, - expert_idx); - - // Convert BufferC to FP32 and dump for up - std::vector up_lora_fp32(m * config_.intermediate_size); - std::vector up_lora_bf16(m * config_.intermediate_size); - for (int ith = 0; ith < dump_nth; ith++) { - lora_up_out_bc_[expert_idx]->to_mat(m, up_lora_bf16.data(), ith, dump_nth); - } - for (int j = 0; j < m * config_.intermediate_size; j++) { - up_lora_fp32[j] = GGML_BF16_TO_FP32(up_lora_bf16[j]); - } - dump_fp32_matrix(up_lora_fp32.data(), m, config_.intermediate_size, "up_lora_gemm_output", tp_part_idx, - expert_idx); - } - } - } - // ===================================================== // Step 3b: Add LoRA output to main output with scaling // ===================================================== @@ -3949,19 +3583,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, nullptr, "fwd_lora_down_a"); - // DUMP: LoRA intermediate (intermediate @ down_lora_A^T) for down - // Note: Use padded_lora_rank_ as stride since to_mat writes with this stride - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - // Down reuses lora_gate_intermediate_ptr_ - dump_bf16_matrix(lora_gate_intermediate_ptr_[expert_idx], m, padded_lora_rank_, "down_lora_intermediate", - tp_part_idx, expert_idx); - } - } - } + // ===================================================== // Step 2: Quantize lora_intermediate to BufferA @@ -4000,32 +3622,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, nullptr, "fwd_lora_down_gemm", 1); - // DUMP: Pure down LoRA GEMM output (before scaling and add) - // Note: to_mat with (ith, nth) only reads one N_BLOCK chunk, so we need to loop - if (is_dump_enabled()) { - int dump_nth = T::recommended_nth(config_.hidden_size); - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - // Convert BufferC to FP32 matrix and dump - std::vector lora_out_fp32(m * config_.hidden_size); - auto& bc = lora_down_out_bc_[expert_idx]; - // Use to_mat to convert, but we need BF16 temp buffer - std::vector lora_out_bf16(m * config_.hidden_size); - // Loop over all N_BLOCK chunks - for (int ith = 0; ith < dump_nth; ith++) { - bc->to_mat(m, lora_out_bf16.data(), ith, dump_nth); - } - for (int j = 0; j < m * config_.hidden_size; j++) { - lora_out_fp32[j] = GGML_BF16_TO_FP32(lora_out_bf16[j]); - } - dump_fp32_matrix(lora_out_fp32.data(), m, config_.hidden_size, "down_lora_gemm_output", tp_part_idx, - expert_idx); - } - } - } - // ===================================================== // Step 3b: Add LoRA output to main output with scaling // ===================================================== @@ -4825,23 +4421,7 @@ class AMX_SFT_MOE_TP : public BaseMOE { "bwd_down_lora_to_inter"); } - // DUMP: backward grad_output and grad_intermediate (base) after GEMM - if (is_dump_enabled()) { - size_t offset = 0; - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - // Dump scattered grad_output - dump_bf16_matrix(grad_output_bf16_ptr_[expert_idx], m, config_.hidden_size, "backward_grad_output", - tp_part_idx, expert_idx); - // Dump grad_intermediate (base, before LoRA) - dump_bf16_matrix(grad_intermediate_ + offset, m, config_.intermediate_size, "backward_down_base", tp_part_idx, - expert_idx); - } - offset += m * config_.intermediate_size; - } - } + // ===================================================== // Step 5: LoRA gradient computation (parallelized across blocks) @@ -5365,36 +4945,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, "bwd_act_silu"); - // DUMP: backward activation inputs and outputs - // Bug #18b fix: offset is accumulated as elements (tokens * intermediate_size), - // so don't multiply by intermediate_size again when accessing cache - if (is_dump_enabled()) { - size_t offset = 0; - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - // Dump cached gate_output and up_output used in activation backward - // Note: offset is already in elements, no need to multiply by intermediate_size - ggml_bf16_t* gate_out_cached = cache.gate_output_cache + offset; - ggml_bf16_t* up_out_cached = cache.up_output_cache + offset; - dump_bf16_matrix(gate_out_cached, m, config_.intermediate_size, "backward_act_gate_cache", tp_part_idx, - expert_idx); - dump_bf16_matrix(up_out_cached, m, config_.intermediate_size, "backward_act_up_cache", tp_part_idx, - expert_idx); - // Dump grad_intermediate (input to activation backward) - dump_bf16_matrix(grad_intermediate_ + offset, m, config_.intermediate_size, "backward_grad_intermediate", - tp_part_idx, expert_idx); - // Dump grad_gate_out - dump_bf16_matrix(grad_gate_output_ + offset, m, config_.intermediate_size, "backward_grad_gate_out", - tp_part_idx, expert_idx); - // Dump grad_up_out - dump_bf16_matrix(grad_up_output_ + offset, m, config_.intermediate_size, "backward_grad_up_out", tp_part_idx, - expert_idx); - } - offset += m * config_.intermediate_size; - } - } } /** @@ -5515,21 +5065,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { } } - bool dump_enabled = is_dump_enabled(); - - // Accumulation buffers for per-expert grad_input dump (only when dump is enabled) - // Maps expert_idx -> FP32 accumulation buffer [m_local_num x hidden_size] - std::unordered_map> expert_grad_accum; - if (dump_enabled) { - for (int task_id = 0; task_id < activated_expert; task_id++) { - int expert_idx = m_expert_id_map_[task_id]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - expert_grad_accum[expert_idx].resize(m * config_.hidden_size, 0.0f); - } - } - } - auto scatter_to_grad_input = [&](float scale, const char* task_name) { ggml_bf16_t* grad_input_bf16 = (ggml_bf16_t*)grad_input; const int hidden = config_.hidden_size; @@ -5550,17 +5085,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { int pos = cache.m_local_pos_cache[token_id][j]; ggml_bf16_t* contrib = grad_output_bf16_ptr_[expert_idx] + pos * config_.hidden_size; - // Accumulate per-expert grad_input for dumps (no routing weights) - if (dump_enabled) { - auto it = expert_grad_accum.find(expert_idx); - if (it != expert_grad_accum.end()) { - float* accum = it->second.data() + pos * hidden; - for (int h = 0; h < hidden; h++) { - accum[h] += GGML_BF16_TO_FP32(contrib[h]) * scale; - } - } - } - int h = 0; for (; h < hidden_vec_end; h += 32) { __m512 x0, x1, cur0, cur1; @@ -5619,18 +5143,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { }, nullptr, gemm_name, 1); - // DUMP: base backward output before scatter - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - const char* name = do_up ? "backward_up_base" : "backward_gate_base"; - dump_bf16_matrix(grad_output_bf16_ptr_[expert_idx], m, config_.hidden_size, name, tp_part_idx, expert_idx); - } - } - } - scatter_to_grad_input(1.0f, "bwd_gu_scatter_base"); }; @@ -5935,23 +5447,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { gb_gradin_name); } - // DUMP: LoRA contribution before scatter - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - ggml_bf16_t* inter_ptr = - do_up ? lora_up_intermediate_ptr_[expert_idx] : lora_gate_intermediate_ptr_[expert_idx]; - const char* inter_name = do_up ? "backward_up_lora_inter" : "backward_gate_lora_inter"; - dump_bf16_matrix(inter_ptr, m, padded_lora_rank_, inter_name, tp_part_idx, expert_idx); - const char* lora_name = do_up ? "backward_up_lora" : "backward_gate_lora"; - dump_bf16_matrix_scaled(grad_output_bf16_ptr_[expert_idx], m, config_.hidden_size, lora_scaling_, lora_name, - tp_part_idx, expert_idx); - } - } - } - scatter_to_grad_input(lora_scaling_, "bwd_gu_scatter_lora"); // Step 6: grad_A = G_B^T @ X @@ -6058,21 +5553,6 @@ class AMX_SFT_MOE_TP : public BaseMOE { lora_pass_remainder(false); // gate: gb_gradin_fused, scatter, gradA lora_pass_remainder(true); // up: gb_gradin_fused, scatter, gradA - // DUMP: backward grad_input per expert (accumulated sum of gate_base + gate_lora + up_base + up_lora) - if (is_dump_enabled()) { - for (int i = 0; i < activated_expert; i++) { - int expert_idx = m_expert_id_map_[i]; - int m = m_local_num_[expert_idx]; - if (m > 0) { - auto it = expert_grad_accum.find(expert_idx); - if (it != expert_grad_accum.end()) { - // Dump accumulated per-expert grad_input (sum of all 4 contributions) - dump_fp32_matrix(it->second.data(), m, config_.hidden_size, "backward_grad_input_expert", tp_part_idx, - expert_idx); - } - } - } - } } }; diff --git a/kt-kernel/operators/moe-sft-tp.hpp b/kt-kernel/operators/moe-sft-tp.hpp index fe3ca2eb..35c3e9ec 100644 --- a/kt-kernel/operators/moe-sft-tp.hpp +++ b/kt-kernel/operators/moe-sft-tp.hpp @@ -29,77 +29,6 @@ static constexpr int kMoeSftTpVersion = 3; #include "amx/la/amx.hpp" #include "moe-tp.hpp" -// Dump utilities for TP backward debugging -namespace tp_dump { -inline bool is_dump_enabled() { - static int enabled = -1; - if (enabled < 0) { - const char* env = getenv("SFT_MOE_DUMP"); - enabled = (env != nullptr && env[0] == '1') ? 1 : 0; - } - return enabled == 1; -} - -inline const char* get_dump_dir() { - static const char* dir = nullptr; - if (dir == nullptr) { - dir = getenv("SFT_MOE_DUMP_DIR"); - if (dir == nullptr) dir = "./cpp_dump"; - } - return dir; -} - -inline void dump_bf16_matrix(const ggml_bf16_t* data, int rows, int cols, const char* name, int tp_idx) { - if (!is_dump_enabled()) return; - char filename[256]; - snprintf(filename, sizeof(filename), "%s/%s_tp%d.bin", get_dump_dir(), name, tp_idx); - std::ofstream file(filename, std::ios::binary); - if (!file.is_open()) return; - file.write(reinterpret_cast(&rows), sizeof(int)); - file.write(reinterpret_cast(&cols), sizeof(int)); - for (int i = 0; i < rows * cols; i++) { - float val = GGML_BF16_TO_FP32(data[i]); - file.write(reinterpret_cast(&val), sizeof(float)); - } -} - -inline void dump_bf16_matrix_final(const ggml_bf16_t* data, int rows, int cols, const char* name) { - if (!is_dump_enabled()) return; - char filename[256]; - snprintf(filename, sizeof(filename), "%s/%s.bin", get_dump_dir(), name); - std::ofstream file(filename, std::ios::binary); - if (!file.is_open()) return; - file.write(reinterpret_cast(&rows), sizeof(int)); - file.write(reinterpret_cast(&cols), sizeof(int)); - for (int i = 0; i < rows * cols; i++) { - float val = GGML_BF16_TO_FP32(data[i]); - file.write(reinterpret_cast(&val), sizeof(float)); - } -} - -inline void dump_fp32_matrix(const float* data, int rows, int cols, const char* name, int tp_idx) { - if (!is_dump_enabled()) return; - char filename[256]; - snprintf(filename, sizeof(filename), "%s/%s_tp%d.bin", get_dump_dir(), name, tp_idx); - std::ofstream file(filename, std::ios::binary); - if (!file.is_open()) return; - file.write(reinterpret_cast(&rows), sizeof(int)); - file.write(reinterpret_cast(&cols), sizeof(int)); - file.write(reinterpret_cast(data), sizeof(float) * rows * cols); -} - -inline void dump_fp32_matrix_final(const float* data, int rows, int cols, const char* name) { - if (!is_dump_enabled()) return; - char filename[256]; - snprintf(filename, sizeof(filename), "%s/%s.bin", get_dump_dir(), name); - std::ofstream file(filename, std::ios::binary); - if (!file.is_open()) return; - file.write(reinterpret_cast(&rows), sizeof(int)); - file.write(reinterpret_cast(&cols), sizeof(int)); - file.write(reinterpret_cast(data), sizeof(float) * rows * cols); -} -} // namespace tp_dump - struct TPBf16Stats { double abs_mean = 0.0; double abs_max = 0.0; @@ -809,11 +738,6 @@ class TP_MOE_SFT : public TP_MOE { // } // } - // DUMP: per-TP grad_input before merge - // for (int i = 0; i < tp_count; i++) { - // tp_dump::dump_bf16_matrix(part_grad_input[i], qlen, hidden_size, "backward_grad_input", i); - // } - // Bug #22 fix: Merge grad_input from all NUMA nodes (sum them together) auto start_sum = sft_timer::get_trace_timestamp(); { @@ -863,8 +787,6 @@ class TP_MOE_SFT : public TP_MOE { nullptr, "merge_grad_input"); } auto end_sum = sft_timer::get_trace_timestamp(); - // DUMP: final merged grad_input - tp_dump::dump_bf16_matrix_final((ggml_bf16_t*)grad_input, qlen, hidden_size, "backward_grad_input_final"); // Merge reduce-type LoRA gradients: sparse FP32 sum across TPs → BF16 final output // Copy-type grads (gate/up_lora_b, down_lora_a) were written directly — no merge needed. diff --git a/kt-kernel/pyproject.toml b/kt-kernel/pyproject.toml index 4c9e55ec..ef5b52aa 100644 --- a/kt-kernel/pyproject.toml +++ b/kt-kernel/pyproject.toml @@ -53,6 +53,7 @@ Homepage = "https://github.com/kvcache-ai" packages = [ "kt_kernel", "kt_kernel.utils", + "kt_kernel.sft", "kt_kernel.cli", "kt_kernel.cli.commands", "kt_kernel.cli.config", @@ -64,6 +65,7 @@ include-package-data = true [tool.setuptools.package-dir] kt_kernel = "python" "kt_kernel.utils" = "python/utils" +"kt_kernel.sft" = "python/sft" "kt_kernel.cli" = "python/cli" "kt_kernel.cli.commands" = "python/cli/commands" "kt_kernel.cli.config" = "python/cli/config" diff --git a/kt-kernel/python/__init__.py b/kt-kernel/python/__init__.py index 168fda24..fceb72d3 100644 --- a/kt-kernel/python/__init__.py +++ b/kt-kernel/python/__init__.py @@ -51,10 +51,14 @@ kt_kernel_ext = _kt_kernel_ext # Import main API from .experts import KTMoEWrapper -try: - from .utils.amx_sft import AMXSFTMoEWrapper -except (ImportError, AttributeError): - AMXSFTMoEWrapper = None +def __getattr__(name): + if name == "AMXSFTMoEWrapper": + try: + from .sft.amx import AMXSFTMoEWrapper + return AMXSFTMoEWrapper + except (ImportError, AttributeError): + return None + raise AttributeError(f"module 'kt_kernel' has no attribute {name!r}") # Read version from package metadata (preferred) or fallback to project root try: diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py index 46dbcd91..83cba5c9 100644 --- a/kt-kernel/python/experts.py +++ b/kt-kernel/python/experts.py @@ -23,17 +23,11 @@ from typing import List, Optional, Union # Import base infrastructure for inference from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer -# Import base infrastructure for SFT -from .experts_sft import BaseSFTMoEWrapper, KExpertsSFTBuffer - # Import inference backend implementations from .utils.amx import AMXMoEWrapper, NativeMoEWrapper from .utils.llamafile import LlamafileMoEWrapper from .utils.moe_kernel import GeneralMoEWrapper -# Import SFT backend implementations -from .utils.amx_sft import AMXSFTMoEWrapper - # Valid methods for each mode INFERENCE_METHODS = frozenset( @@ -138,7 +132,7 @@ class KTMoEWrapper: # Quantization config (for K-Group SFT methods) group_size: int = 128, zero_point: bool = True, - ) -> Union[BaseMoEWrapper, BaseSFTMoEWrapper]: + ): """ Factory method to create the appropriate backend implementation. @@ -265,6 +259,7 @@ class KTMoEWrapper: This frees up memory by clearing the SFT buffer cache. Useful when you want to reset the buffer state or free memory during SFT. """ + from .sft.base import KExpertsSFTBuffer KExpertsSFTBuffer.clear_cache() @@ -345,7 +340,7 @@ def _create_sft_wrapper( max_cache_depth: int, group_size: int, zero_point: bool, -) -> BaseSFTMoEWrapper: +): """ Create an SFT wrapper based on the method. @@ -355,8 +350,9 @@ def _create_sft_wrapper( Returns: BaseSFTMoEWrapper instance """ + from .sft.amx import AMXSFTMoEWrapper + # Currently only AMX SFT methods are supported - # All SFT methods use AMXSFTMoEWrapper with different quantization return AMXSFTMoEWrapper( layer_idx=layer_idx, num_experts=num_experts, diff --git a/kt-kernel/python/experts_base.py b/kt-kernel/python/experts_base.py index 076086a9..be8f6877 100644 --- a/kt-kernel/python/experts_base.py +++ b/kt-kernel/python/experts_base.py @@ -451,31 +451,3 @@ class BaseMoEWrapper(_MoEBase, ABC): KExpertsCPUBuffer.temp_bs = 0 KExpertsCPUBuffer.temp_buffer = tuple() - # ========== SFT methods (not available in inference mode) ========== - - def forward_sft(self, *args, **kwargs): - """SFT forward is not available in inference mode.""" - raise RuntimeError( - "forward_sft() is not available in inference mode. " - "Use forward() instead, or create wrapper with mode='sft'." - ) - - def backward(self, *args, **kwargs): - """Backward pass is not available in inference mode.""" - raise RuntimeError( - "backward() is not available in inference mode. " "Create wrapper with mode='sft' to use SFT features." - ) - - def init_lora_weights(self, *args, **kwargs): - """LoRA weight initialization is not available in inference mode.""" - raise RuntimeError( - "init_lora_weights() is not available in inference mode. " - "Create wrapper with mode='sft' to use SFT features." - ) - - def update_lora_weights(self, *args, **kwargs): - """LoRA weight update is not available in inference mode.""" - raise RuntimeError( - "update_lora_weights() is not available in inference mode. " - "Create wrapper with mode='sft' to use SFT features." - ) diff --git a/kt-kernel/python/experts_sft.py b/kt-kernel/python/experts_sft.py deleted file mode 100644 index f1cdb9bb..00000000 --- a/kt-kernel/python/experts_sft.py +++ /dev/null @@ -1,431 +0,0 @@ -# SFT MoE Wrapper classes for CPU-based fine-tuning operations -# SPDX-License-Identifier: Apache-2.0 - -""" -SFT (Supervised Fine-Tuning) MoE Wrapper classes and buffer management. - -This module provides: -- KExpertsSFTBuffer: Buffer management for SFT forward/backward passes -- BaseSFTMoEWrapper: Abstract base class for SFT MoE wrappers - -Key differences from inference wrappers: -- Supports forward_sft() with gradient caching for backward pass -- Supports backward() for computing LoRA gradients -- Uses synchronous execution (no double buffering) -- Independent from inference forward() logic to ensure gradient correctness -""" - -from __future__ import annotations - -import torch -from typing import Dict, Optional, Tuple -from abc import ABC, abstractmethod - -from .experts_base import _MoEBase - - -class KExpertsSFTBuffer: - """ - CPU buffer management for SFT expert computation. - - Unlike inference KExpertsCPUBuffer: - - No double buffering (SFT requires synchronous execution) - - Includes gradient buffers for backward pass - - Includes 6 LoRA gradient buffers - - Buffer contents: - - Forward: input_cpu, expert_ids_cpu, weights_cpu, output_cpu - - Backward: grad_output_cpu, grad_input_cpu - - LoRA gradients: grad_gate_lora_a/b, grad_up_lora_a/b, grad_down_lora_a/b - """ - - # Single grow-only buffer (never shrinks). Replaces the old per-qlen - # dict that leaked one buffer per unique sequence length. - _shared_buffer: Optional["KExpertsSFTBuffer"] = None - - def __init__( - self, - qlen: int, - hidden_size: int, - moe_intermediate_size: int, - num_experts: int, - num_experts_per_tok: int, - lora_rank: int, - dtype: torch.dtype = torch.bfloat16, - ): - self.qlen = qlen - self.hidden_size = hidden_size - self.moe_intermediate_size = moe_intermediate_size - self.num_experts = num_experts - self.num_experts_per_tok = num_experts_per_tok - self.lora_rank = lora_rank - self.dtype = dtype - - pin_memory = False - - # ========== Forward buffers ========== - self.input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) - self.expert_ids_cpu = torch.empty( - (qlen, num_experts_per_tok), dtype=torch.int64, device="cpu", pin_memory=pin_memory - ) - self.weights_cpu = torch.empty( - (qlen, num_experts_per_tok), dtype=torch.float32, device="cpu", pin_memory=pin_memory - ) - self.output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) - - # ========== Backward buffers ========== - self.grad_output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) - self.grad_input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) - - # Routing weights gradient [qlen, num_experts_per_tok] (FP32) - self.grad_weights = torch.empty((qlen, num_experts_per_tok), dtype=torch.float32, device="cpu") - - # Batch size tensor for C++ interface - self.bsz_tensor = torch.tensor([qlen], dtype=torch.int32, device="cpu") - - @classmethod - def get_buffer( - cls, - qlen: int, - hidden_size: int, - moe_intermediate_size: int, - num_experts: int, - num_experts_per_tok: int, - lora_rank: int, - dtype: torch.dtype = torch.bfloat16, - ) -> "KExpertsSFTBuffer": - """ - Get or grow the single shared buffer. - - Only reallocates when qlen exceeds current capacity. - Callers must use [:qlen] slicing for copy/return since the - buffer may be larger than the current batch. - """ - buf = cls._shared_buffer - if buf is not None and qlen <= buf.qlen: - return buf - - # Need a (re)allocation — grow to new qlen - cls._shared_buffer = cls( - qlen=qlen, - hidden_size=hidden_size, - moe_intermediate_size=moe_intermediate_size, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - lora_rank=lora_rank, - dtype=dtype, - ) - return cls._shared_buffer - - @classmethod - def clear_cache(cls) -> None: - """Clear the shared buffer.""" - cls._shared_buffer = None - - -class BaseSFTMoEWrapper(_MoEBase, ABC): - """ - Base class for SFT MoE CPU operations. - - Provides LoRA fine-tuning functionality including: - - forward_sft(): Forward pass with gradient caching - - backward(): Backward pass computing LoRA gradients - - update_lora_weights(): Sync LoRA weights to C++ backend - - Key differences from BaseMoEWrapper (inference): - - Uses synchronous execution (no double buffering) - - Maintains forward cache for backward pass - - Independent forward_sft() implementation (not sharing inference forward()) - - Design Decision (forward_sft vs forward relationship): - forward_sft() is implemented independently from forward() because: - 1. Different requirements: inference optimizes for latency, SFT requires gradient correctness - 2. Safety: inference optimizations (deferred experts, async execution) would break SFT gradients - 3. Most reusable optimizations are already in C++ layer (via inheritance) - 4. Manual copying of useful optimizations is safer and more maintainable - - Attributes: - lora_rank: LoRA low-rank matrix rank - lora_alpha: LoRA scaling factor - lora_scaling: Actual scaling value (lora_alpha / lora_rank) - max_cache_depth: Maximum forward cache depth for gradient checkpointing - """ - - def __init__( - self, - layer_idx: int, - num_experts: int, - num_experts_per_tok: int, - hidden_size: int, - moe_intermediate_size: int, - num_gpu_experts: int, - cpuinfer_threads: int, - threadpool_count: int, - weight_path: str, - chunked_prefill_size: int, - # SFT-specific parameters - lora_rank: int = 16, - lora_alpha: float = 32.0, - max_cache_depth: int = 1, - ): - """ - Initialize SFT MoE Wrapper. - - Args: - layer_idx: Layer index - num_experts: Total number of experts - num_experts_per_tok: Number of experts per token (top-k) - hidden_size: Hidden dimension size - moe_intermediate_size: MoE intermediate size - num_gpu_experts: Number of experts on GPU (usually 0 for SFT) - cpuinfer_threads: Number of CPU inference threads - threadpool_count: Number of NUMA subpools (TP count) - weight_path: Path to weights - chunked_prefill_size: Maximum prefill chunk size - lora_rank: LoRA rank (r) - lora_alpha: LoRA scaling factor (alpha) - max_cache_depth: Maximum forward cache depth - """ - # Get shared CPUInfer instance - self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count) - - # Validate basic configuration - self._validate_base_config( - num_experts=num_experts, - hidden_size=hidden_size, - moe_intermediate_size=moe_intermediate_size, - num_experts_per_tok=num_experts_per_tok, - ) - - # Validate SFT-specific parameters - self._validate_sft_config(lora_rank, lora_alpha, max_cache_depth) - - # Save configuration - self.layer_idx = layer_idx - self.num_experts = num_experts - self.num_experts_per_tok = num_experts_per_tok - self.hidden_size = hidden_size - self.moe_intermediate_size = moe_intermediate_size - self.num_gpu_experts = num_gpu_experts - self.weight_path = weight_path - self.chunked_prefill_size = chunked_prefill_size - self.threadpool_count = threadpool_count - - # SFT-specific configuration - self.lora_rank = lora_rank - self.lora_alpha = lora_alpha - self.lora_scaling = lora_alpha / lora_rank - self.max_cache_depth = max_cache_depth - - # LoRA weight placeholders (set via init_lora_weights) - self.gate_lora_a: Optional[torch.Tensor] = None - self.gate_lora_b: Optional[torch.Tensor] = None - self.up_lora_a: Optional[torch.Tensor] = None - self.up_lora_b: Optional[torch.Tensor] = None - self.down_lora_a: Optional[torch.Tensor] = None - self.down_lora_b: Optional[torch.Tensor] = None - - # State tracking - self._weights_loaded: bool = False - self._lora_initialized: bool = False - self._cache_depth: int = 0 - - # Backend-specific initialization happens in subclasses - self.moe = None - - @staticmethod - def _validate_sft_config(lora_rank: int, lora_alpha: float, max_cache_depth: int) -> None: - """ - Validate SFT-specific parameters. - - Raises: - ValueError: If parameters are invalid - """ - if lora_rank <= 0: - raise ValueError(f"lora_rank must be positive, got {lora_rank}") - if lora_alpha <= 0: - raise ValueError(f"lora_alpha must be positive, got {lora_alpha}") - if max_cache_depth <= 0: - raise ValueError(f"max_cache_depth must be positive, got {max_cache_depth}") - - @abstractmethod - def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: - """ - Load base weights for this layer. - - Args: - physical_to_logical_map_cpu: Mapping from physical to logical expert IDs - """ - pass - - @abstractmethod - def init_lora_weights( - self, - gate_lora_a: torch.Tensor, - gate_lora_b: torch.Tensor, - up_lora_a: torch.Tensor, - up_lora_b: torch.Tensor, - down_lora_a: torch.Tensor, - down_lora_b: torch.Tensor, - grad_gate_lora_a: torch.Tensor, - grad_gate_lora_b: torch.Tensor, - grad_up_lora_a: torch.Tensor, - grad_up_lora_b: torch.Tensor, - grad_down_lora_a: torch.Tensor, - grad_down_lora_b: torch.Tensor, - ) -> None: - """ - Initialize LoRA weights. - - LoRA output formula: - lora_output = (input @ A.T @ B.T) * (lora_alpha / lora_rank) - output = base_output + lora_output - - Args: - gate_lora_a: Gate LoRA A matrix [num_experts, lora_rank, hidden_size] - gate_lora_b: Gate LoRA B matrix [num_experts, intermediate_size, lora_rank] - up_lora_a: Up LoRA A matrix [num_experts, lora_rank, hidden_size] - up_lora_b: Up LoRA B matrix [num_experts, intermediate_size, lora_rank] - down_lora_a: Down LoRA A matrix [num_experts, lora_rank, intermediate_size] - down_lora_b: Down LoRA B matrix [num_experts, hidden_size, lora_rank] - """ - pass - - @abstractmethod - def forward_sft( - self, - hidden_states: torch.Tensor, - expert_ids: torch.Tensor, - weights: torch.Tensor, - save_for_backward: bool = True, - output_device: Optional[torch.device] = None, - ) -> torch.Tensor: - """ - SFT forward pass with optional gradient caching. - - Optimized for minimal data copying: - - Accepts GPU tensors directly, copies to pinned buffer in one step - - Returns directly to output_device without intermediate clone - - Args: - hidden_states: Input hidden states [qlen, hidden_size] (any device) - expert_ids: Expert IDs [qlen, num_experts_per_tok] (any device) - weights: Expert weights [qlen, num_experts_per_tok] (any device) - save_for_backward: Whether to save activations for backward pass - output_device: Target device for output (None = clone CPU tensor) - - Returns: - Output hidden states [qlen, hidden_size] - """ - pass - - @abstractmethod - def backward( - self, - grad_output: torch.Tensor, - output_device: Optional[torch.device] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Backward pass computing gradients. - - Must be called after forward_sft(save_for_backward=True). - - Optimized for minimal data copying: - - Accepts GPU tensors directly - - Returns grad_input directly to output_device without intermediate clone - - LoRA gradients are returned in grad_loras dict (no clone needed) - - Args: - grad_output: Gradient from upstream [qlen, hidden_size] (any device) - lora_params: Optional dict of LoRA parameters (kept for compatibility). - If provided, gradients are still returned in grad_loras. - Keys: gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b - output_device: Target device for grad_input (None = clone CPU tensor) - - Returns: - grad_input: Input gradient [qlen, hidden_size] - grad_loras: LoRA gradients dict (e.g., grad_gate_lora_a, grad_gate_lora_b, ...) - grad_weights: Routing weights gradient [qlen, num_experts_per_tok] - """ - pass - - @abstractmethod - def update_lora_weights(self) -> None: - """ - Sync LoRA weights to C++ backend. - - Call this after using an external optimizer to update LoRA weights. - This is a zero-copy operation that passes Python tensor pointers. - - Typical usage: - # 1. Forward + backward - output = wrapper.forward_sft(input, expert_ids, weights) - grad_input, grad_loras = wrapper.backward(grad_output) - - # 2. Update LoRA weights with optimizer - optimizer.step() - - # 3. Sync to C++ - wrapper.update_lora_weights() - """ - pass - - @abstractmethod - def submit_forward_sft( - self, - hidden_states: torch.Tensor, - expert_ids: torch.Tensor, - weights: torch.Tensor, - save_for_backward: bool = True, - ) -> None: - """ - Submit SFT forward pass asynchronously (non-blocking). - - This method submits the CPU MoE computation without waiting for completion, - allowing GPU computation (shared_experts, lora_experts) to proceed in parallel. - - Must be followed by sync_forward_sft() to retrieve results. - - Args: - hidden_states: Input hidden states [qlen, hidden_size] - expert_ids: Expert IDs [qlen, num_experts_per_tok] - weights: Expert weights [qlen, num_experts_per_tok] - save_for_backward: Whether to save activations for backward pass - """ - pass - - @abstractmethod - def sync_forward_sft(self, output_device: Optional[torch.device] = None) -> torch.Tensor: - """ - Synchronize and retrieve SFT forward results. - - Must be called after submit_forward_sft(). - - Args: - output_device: Target device for output (None = clone CPU tensor) - - Returns: - Output hidden states [qlen, hidden_size] - """ - pass - - # ========== Inference methods (not available in SFT mode) ========== - - def forward(self, *args, **kwargs): - """Inference forward is not available in SFT mode.""" - raise RuntimeError("forward() is not available in SFT mode. " "Use forward_sft() instead.") - - def submit_forward(self, *args, **kwargs): - """Async submit is not available in SFT mode.""" - raise RuntimeError("submit_forward() is not available in SFT mode. " "Use submit_forward_sft() instead.") - - def sync_forward(self, *args, **kwargs): - """Async sync is not available in SFT mode.""" - raise RuntimeError("sync_forward() is not available in SFT mode. " "Use sync_forward_sft() instead.") - - def select_deferred_experts(self, *args, **kwargs): - """Deferred experts is not available in SFT mode.""" - raise RuntimeError( - "select_deferred_experts() is not available in SFT mode. " - "SFT requires all experts for gradient computation." - ) diff --git a/kt-kernel/python/sft/__init__.py b/kt-kernel/python/sft/__init__.py new file mode 100644 index 00000000..7cab43bd --- /dev/null +++ b/kt-kernel/python/sft/__init__.py @@ -0,0 +1,83 @@ +# SFT (Supervised Fine-Tuning) submodule for kt-kernel +# SPDX-License-Identifier: Apache-2.0 + +""" +SFT training support for KT-Kernel MoE. + +This submodule adds training capabilities (forward/backward, LoRA, autograd, +distributed) on top of the inference-only kt_kernel base package. + +Additional dependencies beyond base kt_kernel: torch.nn, torch.distributed, peft (optional). +""" + +from .config import KTConfig +from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer +from .amx import AMXSFTMoEWrapper +from .arch import ( + MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device, + KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError, +) +from .autograd import KTMoEFunction +from .layer import KTMoELayerWrapper +from .weights import ( + extract_moe_weights, + load_experts_from_checkpoint_files, + load_experts_from_kt_weight_path, + INT8ExpertWeights, +) +from .lora import ( + kt_adapt_peft_lora, + get_kt_lora_params, + update_kt_lora_pointers, + sync_kt_lora_gradients, + save_lora_experts_to_adapter, + save_kt_moe_to_adapter, + load_lora_experts_from_adapter, + load_kt_moe_from_adapter, + LoRAExpertMLP, + LoRAExperts, +) +from .wrapper import ( + wrap_moe_layers_with_kt_wrapper, + build_kt_device_map, + build_kt_device_map_simplified, + get_kt_loading_kwargs, + load_kt_model, +) + +__all__ = [ + "KTConfig", + "BaseSFTMoEWrapper", + "KExpertsSFTBuffer", + "AMXSFTMoEWrapper", + "MOEArchConfig", + "get_moe_arch_config", + "get_moe_module", + "move_non_experts_to_gpu", + "get_expert_device", + "KTAMXError", + "KTAMXNotAvailableError", + "KTAMXModelNotSupportedError", + "KTAMXConfigError", + "KTMoEFunction", + "KTMoELayerWrapper", + "extract_moe_weights", + "load_experts_from_checkpoint_files", + "load_experts_from_kt_weight_path", + "INT8ExpertWeights", + "kt_adapt_peft_lora", + "get_kt_lora_params", + "update_kt_lora_pointers", + "sync_kt_lora_gradients", + "save_lora_experts_to_adapter", + "save_kt_moe_to_adapter", + "load_lora_experts_from_adapter", + "load_kt_moe_from_adapter", + "LoRAExpertMLP", + "LoRAExperts", + "wrap_moe_layers_with_kt_wrapper", + "build_kt_device_map", + "build_kt_device_map_simplified", + "get_kt_loading_kwargs", + "load_kt_model", +] diff --git a/kt-kernel/python/sft/amx.py b/kt-kernel/python/sft/amx.py new file mode 100644 index 00000000..3f3270f0 --- /dev/null +++ b/kt-kernel/python/sft/amx.py @@ -0,0 +1,434 @@ +# AMX SFT MoE Wrapper implementation +# SPDX-License-Identifier: Apache-2.0 + +""" +AMX-based SFT MoE Wrapper. Forward/backward buffer management is in base class; +this file handles weight loading, LoRA init, and C++ task construction. +""" + +from __future__ import annotations + +import ctypes +import os +import glob as _glob +import torch +from typing import Optional, List + +from kt_kernel_ext.moe import MOESFTConfig + +from ..utils.loader import BF16SafeTensorLoader, SafeTensorLoader + +try: + from kt_kernel_ext.moe import ( + AMXBF16_SFT_MOE, + AMXInt8_SFT_MOE, + AMXInt4_SFT_MOE, + AMXBF16_SFT_MOE_SkipLoRA, + AMXInt8_SFT_MOE_SkipLoRA, + AMXInt4_SFT_MOE_SkipLoRA, + ) + + _HAS_AMX_SFT_SUPPORT = True +except (ImportError, AttributeError): + _HAS_AMX_SFT_SUPPORT = False + AMXBF16_SFT_MOE = None + AMXInt8_SFT_MOE = None + AMXInt4_SFT_MOE = None + AMXBF16_SFT_MOE_SkipLoRA = None + AMXInt8_SFT_MOE_SkipLoRA = None + AMXInt4_SFT_MOE_SkipLoRA = None + +from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer + + +# Mapping from method string to C++ SFT MOE class +_SFT_METHOD_TO_CLASS = { + "AMXBF16_SFT": AMXBF16_SFT_MOE, + "AMXINT8_SFT": AMXInt8_SFT_MOE, + "AMXINT4_SFT": AMXInt4_SFT_MOE, + "AMXBF16_SFT_SkipLoRA": AMXBF16_SFT_MOE_SkipLoRA, + "AMXINT8_SFT_SkipLoRA": AMXInt8_SFT_MOE_SkipLoRA, + "AMXINT4_SFT_SkipLoRA": AMXInt4_SFT_MOE_SkipLoRA, +} + + +class AMXSFTMoEWrapper(BaseSFTMoEWrapper): + """ + AMX-based SFT MoE wrapper. + + Supports BF16, INT8, INT4, and SkipLoRA variants. + Forward/backward buffer management is in BaseSFTMoEWrapper; + this class implements weight loading and C++ task construction. + """ + + def __init__( + self, + layer_idx: int, + num_experts: int, + num_experts_per_tok: int, + hidden_size: int, + moe_intermediate_size: int, + num_gpu_experts: int, + cpuinfer_threads: int, + threadpool_count: int, + weight_path: str, + chunked_prefill_size: int, + lora_rank: int = 16, + lora_alpha: float = 32.0, + max_cache_depth: int = 1, + method: str = "AMXBF16_SFT", + group_size: int = 128, + zero_point: bool = True, + ): + if not _HAS_AMX_SFT_SUPPORT: + raise RuntimeError( + "AMX SFT backend not available. kt_kernel_ext was not compiled with AMX SFT support.\n" + "Please recompile with AMX SFT enabled." + ) + + super().__init__( + layer_idx=layer_idx, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + num_gpu_experts=num_gpu_experts, + cpuinfer_threads=cpuinfer_threads, + threadpool_count=threadpool_count, + weight_path=weight_path, + chunked_prefill_size=chunked_prefill_size, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + max_cache_depth=max_cache_depth, + ) + + self.method = method + self._is_skip_lora = "SkipLoRA" in method + self.group_size = group_size + self.zero_point = zero_point + + if method not in _SFT_METHOD_TO_CLASS: + raise ValueError(f"Unknown SFT method: {method}. Supported: {list(_SFT_METHOD_TO_CLASS.keys())}") + + moe_class = _SFT_METHOD_TO_CLASS[method] + if moe_class is None: + raise RuntimeError(f"AMX SFT method '{method}' not available in current build.") + + self.gate_proj: Optional[torch.Tensor] = None + self.up_proj: Optional[torch.Tensor] = None + self.down_proj: Optional[torch.Tensor] = None + + self._moe_class = moe_class + + # ========== Template method: C++ task construction ========== + + def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool): + return self.moe.forward_sft_task( + buffer.bsz_tensor.data_ptr(), + self.num_experts_per_tok, + buffer.expert_ids_cpu.data_ptr(), + buffer.weights_cpu.data_ptr(), + buffer.input_cpu.data_ptr(), + buffer.output_cpu.data_ptr(), + save_for_backward, + ) + + def _make_backward_task(self, buffer: KExpertsSFTBuffer): + if self._is_skip_lora: + return self.moe.backward_task( + buffer.grad_output_cpu.data_ptr(), + buffer.grad_input_cpu.data_ptr(), + 0, 0, 0, 0, 0, 0, + buffer.grad_weights.data_ptr(), + ) + return self.moe.backward_task( + buffer.grad_output_cpu.data_ptr(), + buffer.grad_input_cpu.data_ptr(), + self.grad_gate_lora_a.data_ptr(), + self.grad_gate_lora_b.data_ptr(), + self.grad_up_lora_a.data_ptr(), + self.grad_up_lora_b.data_ptr(), + self.grad_down_lora_a.data_ptr(), + self.grad_down_lora_b.data_ptr(), + buffer.grad_weights.data_ptr(), + ) + + # ========== Weight loading ========== + + def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: + if self._weights_loaded: + return + + if self.gate_proj is None and not getattr(self, "_use_projs_path", False): + self._load_base_weights_from_file() + + config = MOESFTConfig() + config.expert_num = self.num_experts + config.num_experts_per_tok = self.num_experts_per_tok + config.hidden_size = self.hidden_size + config.intermediate_size = self.moe_intermediate_size + config.lora_rank = self.lora_rank + config.lora_alpha = self.lora_alpha + config.max_cache_depth = self.max_cache_depth + config.max_len = self.chunked_prefill_size + config.layer_idx = self.layer_idx + config.share_backward_bb = getattr(self, "share_backward_bb", False) + config.share_cache_pool = getattr(self, "share_cache_pool", False) + + if getattr(self, "_use_kt_direct_load", False): + config.load = True + config.path = self.weight_path + elif getattr(self, "_use_projs_path", False): + config.gate_projs = self._gate_projs_ptrs + config.up_projs = self._up_projs_ptrs + config.down_projs = self._down_projs_ptrs + config.gate_scales = self._gate_scale_ptrs + config.up_scales = self._up_scale_ptrs + config.down_scales = self._down_scale_ptrs + if getattr(self, "_bf16_gate_proj", None) is not None: + config.gate_proj = self._bf16_gate_proj.data_ptr() + config.up_proj = self._bf16_up_proj.data_ptr() + config.down_proj = self._bf16_down_proj.data_ptr() + if getattr(self, "_has_bwd_projs", False): + config.gate_bwd_projs = self._gate_bwd_projs_ptrs + config.up_bwd_projs = self._up_bwd_projs_ptrs + config.down_bwd_projs = self._down_bwd_projs_ptrs + config.gate_bwd_scales = self._gate_bwd_scale_ptrs + config.up_bwd_scales = self._up_bwd_scale_ptrs + config.down_bwd_scales = self._down_bwd_scale_ptrs + else: + config.gate_proj = self.gate_proj.data_ptr() + config.up_proj = self.up_proj.data_ptr() + config.down_proj = self.down_proj.data_ptr() + + if self._lora_initialized: + config.gate_lora_a = self.gate_lora_a.data_ptr() + config.gate_lora_b = self.gate_lora_b.data_ptr() + config.up_lora_a = self.up_lora_a.data_ptr() + config.up_lora_b = self.up_lora_b.data_ptr() + config.down_lora_a = self.down_lora_a.data_ptr() + config.down_lora_b = self.down_lora_b.data_ptr() + + config.pool = self.cpu_infer.backend_ + + if self.method in ("AMXINT4_KGroup_SFT", "AMXINT4_1KGroup_SFT"): + config.quant_config.group_size = self.group_size + config.quant_config.zero_point = self.zero_point + + self.moe = self._moe_class(config) + + self.cpu_infer.submit(self.moe.load_weights_task()) + self.cpu_infer.sync() + + self.cpu_infer.submit(self.moe.warm_up_task()) + self.cpu_infer.sync() + + # Release Python-side weight tensors (C++ copied them) + self.gate_proj = None + self.up_proj = None + self.down_proj = None + + if getattr(self, "_bf16_gate_proj", None) is not None: + self._bf16_gate_proj = None + self._bf16_up_proj = None + self._bf16_down_proj = None + + if getattr(self, "_use_projs_path", False): + for attr in [ + "_gate_weights_per_numa", "_up_weights_per_numa", "_down_weights_per_numa", + "_gate_scales_per_numa", "_up_scales_per_numa", "_down_scales_per_numa", + "_gate_projs_ptrs", "_up_projs_ptrs", "_down_projs_ptrs", + "_gate_scale_ptrs", "_up_scale_ptrs", "_down_scale_ptrs", + ]: + setattr(self, attr, None) + if getattr(self, "_has_bwd_projs", False): + for attr in [ + "_gate_bwd_weights_per_numa", "_up_bwd_weights_per_numa", "_down_bwd_weights_per_numa", + "_gate_bwd_scales_per_numa", "_up_bwd_scales_per_numa", "_down_bwd_scales_per_numa", + "_gate_bwd_projs_ptrs", "_up_bwd_projs_ptrs", "_down_bwd_projs_ptrs", + "_gate_bwd_scale_ptrs", "_up_bwd_scale_ptrs", "_down_bwd_scale_ptrs", + ]: + setattr(self, attr, None) + + self._weights_loaded = True + + def load_weights_from_tensors( + self, + gate_proj: torch.Tensor, + up_proj: torch.Tensor, + down_proj: torch.Tensor, + physical_to_logical_map_cpu: torch.Tensor, + ) -> None: + self.gate_proj = gate_proj.contiguous() + self.up_proj = up_proj.contiguous() + self.down_proj = down_proj.contiguous() + self.load_weights(physical_to_logical_map_cpu) + del gate_proj, up_proj, down_proj + + def _load_base_weights_from_file(self) -> None: + if not hasattr(self, "weight_path") or self.weight_path is None: + raise RuntimeError( + "weight_path not set. Cannot load weights from file. " + "Either set weight_path or call load_weights_from_tensors() instead." + ) + + kt_layer_dir = os.path.join(self.weight_path, f"_layer_{self.layer_idx}") + if os.path.isdir(kt_layer_dir): + kt_files = _glob.glob(os.path.join(kt_layer_dir, "_numa_0", "*.kt")) + if kt_files: + self._use_kt_direct_load = True + return + + if "BF16" in self.method: + loader = BF16SafeTensorLoader(self.weight_path) + base_key = f"model.layers.{self.layer_idx}" + else: + loader = SafeTensorLoader(self.weight_path) + base_key = f"blk.{self.layer_idx}" + + experts_data = loader.load_experts(base_key, device="cpu") + + gate_weights: List[torch.Tensor] = experts_data["gate"] + up_weights: List[torch.Tensor] = experts_data["up"] + down_weights: List[torch.Tensor] = experts_data["down"] + + if "BF16" in self.method: + self.gate_proj = torch.stack(gate_weights, dim=0).contiguous() + self.up_proj = torch.stack(up_weights, dim=0).contiguous() + self.down_proj = torch.stack(down_weights, dim=0).contiguous() + else: + def _make_ptrs(arrays_per_numa): + return [ + [ + ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) + for et in numa_array + ] + for numa_array in arrays_per_numa + ] + + self._gate_weights_per_numa = gate_weights + self._up_weights_per_numa = up_weights + self._down_weights_per_numa = down_weights + self._gate_scales_per_numa = experts_data["gate_scale"] + self._up_scales_per_numa = experts_data["up_scale"] + self._down_scales_per_numa = experts_data["down_scale"] + + self._gate_projs_ptrs = _make_ptrs(gate_weights) + self._up_projs_ptrs = _make_ptrs(up_weights) + self._down_projs_ptrs = _make_ptrs(down_weights) + self._gate_scale_ptrs = _make_ptrs(experts_data["gate_scale"]) + self._up_scale_ptrs = _make_ptrs(experts_data["up_scale"]) + self._down_scale_ptrs = _make_ptrs(experts_data["down_scale"]) + + if "gate_bwd" in experts_data: + self._gate_bwd_weights_per_numa = experts_data["gate_bwd"] + self._up_bwd_weights_per_numa = experts_data["up_bwd"] + self._down_bwd_weights_per_numa = experts_data["down_bwd"] + self._gate_bwd_scales_per_numa = experts_data["gate_bwd_scale"] + self._up_bwd_scales_per_numa = experts_data["up_bwd_scale"] + self._down_bwd_scales_per_numa = experts_data["down_bwd_scale"] + + self._gate_bwd_projs_ptrs = _make_ptrs(experts_data["gate_bwd"]) + self._up_bwd_projs_ptrs = _make_ptrs(experts_data["up_bwd"]) + self._down_bwd_projs_ptrs = _make_ptrs(experts_data["down_bwd"]) + self._gate_bwd_scale_ptrs = _make_ptrs(experts_data["gate_bwd_scale"]) + self._up_bwd_scale_ptrs = _make_ptrs(experts_data["up_bwd_scale"]) + self._down_bwd_scale_ptrs = _make_ptrs(experts_data["down_bwd_scale"]) + self._has_bwd_projs = True + else: + self._has_bwd_projs = False + + self.gate_proj = None + self.up_proj = None + self.down_proj = None + self._use_projs_path = True + + loader.close_all_handles() + + # ========== LoRA ========== + + def init_lora_weights( + self, + gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor, + up_lora_a: torch.Tensor, up_lora_b: torch.Tensor, + down_lora_a: torch.Tensor, down_lora_b: torch.Tensor, + grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor, + grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor, + grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor, + ) -> None: + expected_shapes = { + "gate_lora_a": (self.num_experts, self.lora_rank, self.hidden_size), + "gate_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank), + "up_lora_a": (self.num_experts, self.lora_rank, self.hidden_size), + "up_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank), + "down_lora_a": (self.num_experts, self.lora_rank, self.moe_intermediate_size), + "down_lora_b": (self.num_experts, self.hidden_size, self.lora_rank), + } + provided = { + "gate_lora_a": gate_lora_a, "gate_lora_b": gate_lora_b, + "up_lora_a": up_lora_a, "up_lora_b": up_lora_b, + "down_lora_a": down_lora_a, "down_lora_b": down_lora_b, + } + for name, tensor in provided.items(): + expected = expected_shapes[name] + if tensor.shape != expected: + raise ValueError(f"{name} shape mismatch: expected {expected}, got {tuple(tensor.shape)}") + + self.gate_lora_a = gate_lora_a.contiguous() + self.gate_lora_b = gate_lora_b.contiguous() + self.up_lora_a = up_lora_a.contiguous() + self.up_lora_b = up_lora_b.contiguous() + self.down_lora_a = down_lora_a.contiguous() + self.down_lora_b = down_lora_b.contiguous() + + self.grad_gate_lora_a = grad_gate_lora_a.contiguous() + self.grad_gate_lora_b = grad_gate_lora_b.contiguous() + self.grad_up_lora_a = grad_up_lora_a.contiguous() + self.grad_up_lora_b = grad_up_lora_b.contiguous() + self.grad_down_lora_a = grad_down_lora_a.contiguous() + self.grad_down_lora_b = grad_down_lora_b.contiguous() + + self._lora_initialized = True + + if self._weights_loaded and self.moe is not None: + self.update_lora_weights() + + def update_lora_weights(self) -> None: + if not self._weights_loaded: + raise RuntimeError("Weights not loaded. Call load_weights() first.") + if self._is_skip_lora: + return + if not self._lora_initialized: + raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.") + + self.cpu_infer.submit( + self.moe.update_lora_weights_task( + self.gate_lora_a.data_ptr(), + self.gate_lora_b.data_ptr(), + self.up_lora_a.data_ptr(), + self.up_lora_b.data_ptr(), + self.down_lora_a.data_ptr(), + self.down_lora_b.data_ptr(), + ) + ) + self.cpu_infer.sync() + + def save_backward_weights_from_tensors( + self, + gate_proj: torch.Tensor, + up_proj: torch.Tensor, + down_proj: torch.Tensor, + physical_to_logical_map: torch.Tensor, + output_path: str, + ) -> None: + if not self._weights_loaded: + raise RuntimeError("Weights not loaded. Call load_weights() first.") + gate_proj = gate_proj.contiguous() + up_proj = up_proj.contiguous() + down_proj = down_proj.contiguous() + self.moe.prepare_and_save_bwd( + gate_proj.data_ptr(), + up_proj.data_ptr(), + down_proj.data_ptr(), + output_path, + ) diff --git a/kt-kernel/python/sft/arch.py b/kt-kernel/python/sft/arch.py new file mode 100644 index 00000000..80c88136 --- /dev/null +++ b/kt-kernel/python/sft/arch.py @@ -0,0 +1,265 @@ +# MoE architecture configuration and model utilities +# SPDX-License-Identifier: Apache-2.0 + +""" +MoE architecture detection and model navigation utilities. + +This is a leaf module — no imports from other sft/ submodules. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Exceptions +# ============================================================================= + + +class KTAMXError(Exception): + """Base exception for KT AMX errors.""" + + +class KTAMXNotAvailableError(KTAMXError): + """kt_kernel not installed or AMX not supported.""" + + +class KTAMXModelNotSupportedError(KTAMXError): + """Model architecture not supported.""" + + +class KTAMXConfigError(KTAMXError): + """Configuration error.""" + + +# ============================================================================= +# MoE Configuration +# ============================================================================= + + +@dataclass +class MOEArchConfig: + """MoE architecture configuration for different model types.""" + + moe_layer_attr: str + router_attr: str + experts_attr: str + weight_names: tuple[str, str, str] + expert_num: int + intermediate_size: int + num_experts_per_tok: int + has_shared_experts: bool = False + router_type: str = "linear" + + +def get_moe_arch_config(config) -> MOEArchConfig: + """ + Get MoE architecture configuration based on model type. + + Args: + config: HuggingFace model configuration + + Returns: + MOEArchConfig for the model + + Raises: + KTAMXModelNotSupportedError: If model architecture is not supported + """ + arch = config.architectures[0] if getattr(config, "architectures", None) else "" + + if "DeepseekV2" in arch: + return MOEArchConfig( + moe_layer_attr="mlp", + router_attr="gate", + experts_attr="experts", + weight_names=("gate_proj", "up_proj", "down_proj"), + expert_num=config.n_routed_experts, + intermediate_size=config.moe_intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + has_shared_experts=getattr(config, "n_shared_experts", 0) > 0, + router_type="deepseek_gate", + ) + if "DeepseekV3" in arch: + return MOEArchConfig( + moe_layer_attr="mlp", + router_attr="gate", + experts_attr="experts", + weight_names=("gate_proj", "up_proj", "down_proj"), + expert_num=config.n_routed_experts, + intermediate_size=config.moe_intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + has_shared_experts=getattr(config, "n_shared_experts", 0) > 0, + router_type="deepseek_gate", + ) + if "Qwen2Moe" in arch or "Qwen3Moe" in arch: + return MOEArchConfig( + moe_layer_attr="mlp", + router_attr="gate", + experts_attr="experts", + weight_names=("gate_proj", "up_proj", "down_proj"), + expert_num=config.num_experts, + intermediate_size=config.moe_intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + has_shared_experts=getattr(config, "shared_expert_intermediate_size", 0) > 0, + ) + if "Mixtral" in arch: + return MOEArchConfig( + moe_layer_attr="block_sparse_moe", + router_attr="gate", + experts_attr="experts", + weight_names=("w1", "w3", "w2"), + expert_num=config.num_local_experts, + intermediate_size=config.intermediate_size, + num_experts_per_tok=config.num_experts_per_tok, + has_shared_experts=False, + ) + + raise KTAMXModelNotSupportedError( + f"Model architecture {arch} not supported for KT AMX. " + "Supported architectures: DeepseekV2, DeepseekV3, Qwen2Moe, Qwen3Moe, Mixtral" + ) + + +def get_moe_module(layer: nn.Module, moe_config: MOEArchConfig) -> nn.Module | None: + """Get MoE module from transformer layer.""" + moe_module = getattr(layer, moe_config.moe_layer_attr, None) + if moe_module is None: + return None + if not hasattr(moe_module, moe_config.experts_attr): + return None + return moe_module + + +def _get_layers_prefix(config) -> str: + arch = config.architectures[0] if getattr(config, "architectures", None) else "" + if any(x in arch for x in ["Deepseek", "Qwen", "Mixtral", "Llama"]): + return "model.layers" + return "model.layers" + + +def _get_model_container_and_layers(model: nn.Module, *, purpose: str) -> tuple[nn.Module, any]: + """ + Resolve the transformer layer container for KT integration. + + KT expects the transformer block stack to be accessible as `.layers`. + Handles PEFT PeftModel, TRL value-head models, DDP wrappers. + """ + to_visit: list[nn.Module] = [model] + visited: set[int] = set() + visited_types: list[str] = [] + + while to_visit: + current = to_visit.pop(0) + if id(current) in visited: + continue + visited.add(id(current)) + visited_types.append(type(current).__name__) + + layers = getattr(current, "layers", None) + if layers is not None and isinstance(layers, (list, tuple, nn.ModuleList)): + return current, layers + + for attr in ("model", "base_model", "pretrained_model", "module"): + child = getattr(current, attr, None) + if isinstance(child, nn.Module) and child is not current: + to_visit.append(child) + + get_base_model = getattr(current, "get_base_model", None) + if callable(get_base_model): + try: + base = get_base_model() + except Exception: + base = None + if isinstance(base, nn.Module) and base is not current: + to_visit.append(base) + + visited_preview = ", ".join(visited_types[:6]) + if len(visited_types) > 6: + visited_preview += ", ..." + + raise KTAMXConfigError( + f"Model does not expose a .model.layers or .layers attribute for KT {purpose}. " + "Tried unwrapping via model/base_model/pretrained_model/module/get_base_model; " + f"visited: {visited_preview}" + ) + + +def move_non_experts_to_gpu( + model: nn.Module, + moe_config: MOEArchConfig | None = None, + device: str = "cuda:0", +) -> None: + """Move non-expert parameters to GPU after loading (experts stay on CPU).""" + if moe_config is None: + config = getattr(model, "config", None) + if config is None: + raise KTAMXConfigError("Model config is required to infer MoE architecture.") + moe_config = get_moe_arch_config(config) + + container, layers = _get_model_container_and_layers(model, purpose="placement") + + if hasattr(container, "embed_tokens"): + container.embed_tokens.to(device) + if hasattr(container, "norm"): + container.norm.to(device) + if hasattr(model, "lm_head"): + model.lm_head.to(device) + + for layer in layers: + if hasattr(layer, "self_attn"): + layer.self_attn.to(device) + + if hasattr(layer, "input_layernorm"): + layer.input_layernorm.to(device) + if hasattr(layer, "post_attention_layernorm"): + layer.post_attention_layernorm.to(device) + + moe_module = getattr(layer, moe_config.moe_layer_attr, None) + if moe_module is None or not hasattr(moe_module, moe_config.experts_attr): + if hasattr(layer, "mlp"): + layer.mlp.to(device) + continue + + router = getattr(moe_module, moe_config.router_attr, None) + if router is not None: + router.to(device) + + if hasattr(moe_module, "shared_experts") and moe_module.shared_experts is not None: + moe_module.shared_experts.to(device) + + logger.info(f"Moved non-expert parameters to {device}") + + +def get_expert_device(model: nn.Module, moe_config: MOEArchConfig | None = None) -> str: + """Get the device type of MoE experts.""" + if moe_config is None: + config = getattr(model, "config", None) + if config is None: + return "unknown" + moe_config = get_moe_arch_config(config) + + try: + _, layers = _get_model_container_and_layers(model, purpose="expert device probing") + except KTAMXConfigError: + return "unknown" + + for layer in layers: + moe_module = getattr(layer, moe_config.moe_layer_attr, None) + if moe_module is None: + continue + experts = getattr(moe_module, moe_config.experts_attr, None) + if not experts: + continue + first_expert = experts[0] + gate_name = moe_config.weight_names[0] + gate_proj = getattr(first_expert, gate_name, None) + if gate_proj is not None: + return str(gate_proj.weight.device.type) + + return "unknown" diff --git a/kt-kernel/python/sft/autograd.py b/kt-kernel/python/sft/autograd.py new file mode 100644 index 00000000..36981735 --- /dev/null +++ b/kt-kernel/python/sft/autograd.py @@ -0,0 +1,256 @@ +# Autograd function for KT MoE SFT training +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import os +from typing import Any + +import torch + +from .dist_utils import ( + _all_gather_qlens, + _qlen_offsets, + _dist_gather_varlen_to_rank0, + _dist_scatter_varlen_from_rank0, + _checkpoint_hook_mode, + _is_in_checkpoint_first_forward, +) + +_KT_SFT_DEBUG = os.environ.get("KT_SFT_DEBUG", "0") == "1" + +logger = logging.getLogger(__name__) + + +class KTMoEFunction(torch.autograd.Function): + """Unified autograd function for KTMoE forward/backward.""" + + @staticmethod + def forward( + ctx, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + wrapper: Any, + lora_ref: torch.Tensor, + hidden_size: int, + num_experts_per_tok: int, + layer_idx: int, + training: bool, + train_lora: bool, + all_qlens: list[int] | tuple[int, ...] | None, + ) -> torch.Tensor: + + if _KT_SFT_DEBUG: + logging.debug( + "KTMoEFunction.forward: layer=%d training=%s train_lora=%s", + layer_idx, training, train_lora, + ) + + original_device = hidden_states.device + original_dtype = hidden_states.dtype + batch_size, seq_len, _ = hidden_states.shape + qlen = batch_size * seq_len + + import torch.distributed as dist + dist_on = dist.is_initialized() and dist.get_world_size() > 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist_on else 1 + + ctx.use_broadcast = wrapper is None + + # ---- Sync CPU expert result and distribute ---- + if dist_on: + if all_qlens is None: + all_qlens_list = _all_gather_qlens(qlen, original_device, world_size) + else: + all_qlens_list = [int(q) for q in all_qlens] + if len(all_qlens_list) != world_size: + raise RuntimeError( + f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}" + ) + if int(all_qlens_list[rank]) != qlen: + raise RuntimeError( + f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}" + ) + total_qlen = sum(all_qlens_list) + + # Rank 0: sync CPU result and split by real lengths + if rank == 0: + cpu_output = wrapper.sync_forward(output_device=original_device) + cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, hidden_size) + offsets = _qlen_offsets(all_qlens_list) + scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)] + else: + scatter_list = None + + output_flat = _dist_scatter_varlen_from_rank0( + rank0_chunks=scatter_list, + all_qlens=all_qlens_list, + rank=rank, + world_size=world_size, + feature_shape=(hidden_size,), + device=original_device, + dtype=original_dtype, + ) + output = output_flat.view(batch_size, seq_len, hidden_size) + del output_flat + elif wrapper is not None: + # Single-GPU: sync directly + cpu_output = wrapper.sync_forward(output_device=original_device) + output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype) + else: + # Broadcast-only rank (no wrapper) + output = torch.empty( + batch_size, seq_len, hidden_size, device=original_device, dtype=original_dtype + ) + + ctx.wrapper = wrapper + ctx.hidden_size = hidden_size + ctx.qlen = qlen + ctx.batch_size = batch_size + ctx.seq_len = seq_len + ctx.original_device = original_device + ctx.original_dtype = original_dtype + ctx.weights_shape = topk_weights.shape + ctx.weights_dtype = topk_weights.dtype + ctx.weights_device = topk_weights.device + ctx.dist_on = dist_on + ctx.world_size = world_size + ctx.all_qlens = all_qlens_list if dist_on else None + ctx.num_experts_per_tok = num_experts_per_tok + ctx.layer_idx = layer_idx + + # Save a sentinel tensor so non-reentrant checkpoint's saved_tensors + # hooks can intercept it. When backward accesses ctx.saved_tensors, + # the checkpoint unpack hook triggers a full recompute of the decoder + # layer — which re-runs the MoE forward with save_for_backward=True, + # populating the C++ cache BEFORE this backward proceeds. + # Without this, MoE backward runs before the recompute (MoE comes + # after attention in forward order → its backward runs first), and + # the C++ cache is empty when first-forward cache-skip is active. + ctx.save_for_backward(hidden_states.new_empty(())) + + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Wait for any in-flight async repack before recompute forward uses the pool + if getattr(ctx.wrapper, 'share_backward_bb', False): + ctx.wrapper.wait_backward_repack() + + # Access saved_tensors FIRST — under non-reentrant checkpoint this + # triggers the unpack hook which runs a full decoder-layer recompute, + # populating the C++ cache before we call wrapper.backward(). + _ = ctx.saved_tensors + + qlen = ctx.qlen + hidden_size = ctx.hidden_size + batch_size = ctx.batch_size + seq_len = ctx.seq_len + dist_on = ctx.dist_on + world_size = ctx.world_size + num_experts_per_tok = ctx.num_experts_per_tok + + import torch.distributed as dist + rank = dist.get_rank() if dist.is_initialized() else 0 + + if _KT_SFT_DEBUG: + logging.debug( + "KTMoEFunction.backward: layer=%d dist_on=%s qlen=%d", + getattr(ctx, "layer_idx", -1), dist_on, qlen, + ) + + if dist_on: + all_qlens = getattr(ctx, "all_qlens", None) + if all_qlens is None or len(all_qlens) != world_size: + all_qlens = _all_gather_qlens(qlen, ctx.original_device, world_size) + else: + all_qlens = [int(q) for q in all_qlens] + if int(all_qlens[rank]) != qlen: + raise RuntimeError( + f"Backward qlen mismatch on rank {rank}: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}" + ) + + grad_out_flat = grad_output.view(qlen, hidden_size).contiguous() + + gathered_go = _dist_gather_varlen_to_rank0( + grad_out_flat, + all_qlens=all_qlens, + rank=rank, + world_size=world_size, + ) + if rank == 0: + all_go = torch.cat(gathered_go, dim=0) + total_qlen = int(all_go.shape[0]) + + backward_out = ctx.wrapper.backward( + all_go, + output_device=ctx.original_device, + ) + if isinstance(backward_out, tuple) and len(backward_out) == 2: + all_grad_input, all_grad_weights = backward_out + elif isinstance(backward_out, tuple) and len(backward_out) == 3: + all_grad_input, _, all_grad_weights = backward_out + else: + raise ValueError("KTMoEWrapper.backward returned unexpected format.") + + all_grad_input = all_grad_input.to(dtype=ctx.original_dtype).view(total_qlen, hidden_size) + all_grad_weights = all_grad_weights.to(dtype=torch.bfloat16).view(total_qlen, num_experts_per_tok) + + offsets = _qlen_offsets(all_qlens) + scatter_gi = [all_grad_input[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)] + scatter_gw = [all_grad_weights[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)] + else: + scatter_gi = None + scatter_gw = None + + grad_input_flat = _dist_scatter_varlen_from_rank0( + rank0_chunks=scatter_gi, + all_qlens=all_qlens, + rank=rank, + world_size=world_size, + feature_shape=(hidden_size,), + device=ctx.original_device, + dtype=ctx.original_dtype, + ) + grad_weights_flat = _dist_scatter_varlen_from_rank0( + rank0_chunks=scatter_gw, + all_qlens=all_qlens, + rank=rank, + world_size=world_size, + feature_shape=(num_experts_per_tok,), + device=ctx.weights_device, + dtype=torch.bfloat16, + ) + grad_input = grad_input_flat.view(batch_size, seq_len, hidden_size) + grad_weights = grad_weights_flat.view(ctx.weights_shape).to(dtype=ctx.weights_dtype) + + elif not ctx.use_broadcast: + # ---- Single-GPU path ---- + grad_output_flat = grad_output.view(qlen, hidden_size) + backward_out = ctx.wrapper.backward( + grad_output_flat, + output_device=ctx.original_device, + ) + ctx.wrapper._kt_has_cached_forward = False + if isinstance(backward_out, tuple) and len(backward_out) == 2: + grad_input, grad_weights = backward_out + elif isinstance(backward_out, tuple) and len(backward_out) == 3: + grad_input, _, grad_weights = backward_out + else: + raise ValueError("KTMoEWrapper.backward returned unexpected format.") + grad_input = grad_input.view(batch_size, seq_len, hidden_size).to(dtype=ctx.original_dtype) + grad_weights = grad_weights.to(dtype=torch.bfloat16) + else: + # No wrapper, no dist — shouldn't happen in normal flow + grad_input = torch.zeros(batch_size, seq_len, hidden_size, device=ctx.original_device, dtype=ctx.original_dtype) + grad_weights = torch.zeros(ctx.weights_shape, device=ctx.weights_device, dtype=ctx.weights_dtype) + + # Trigger async repack for next MoE layer in backward order + next_bwd = getattr(ctx.wrapper, '_next_backward_wrapper', None) + if next_bwd is not None and getattr(next_bwd, 'share_backward_bb', False): + next_bwd.submit_backward_repack() + + return grad_input, None, grad_weights, None, None, None, None, None, None, None, None diff --git a/kt-kernel/python/sft/base.py b/kt-kernel/python/sft/base.py new file mode 100644 index 00000000..25b0e2cb --- /dev/null +++ b/kt-kernel/python/sft/base.py @@ -0,0 +1,402 @@ +# Base classes for SFT MoE operations +# SPDX-License-Identifier: Apache-2.0 + +""" +SFT (Supervised Fine-Tuning) MoE base classes and buffer management. + +Provides: +- KExpertsSFTBuffer: Grow-only shared buffer for forward/backward passes +- BaseSFTMoEWrapper: Abstract base with concrete buffer management (template method pattern) +""" + +from __future__ import annotations + +import torch +from typing import Optional, Tuple +from abc import ABC, abstractmethod + +from ..experts_base import _MoEBase + + +class KExpertsSFTBuffer: + """ + CPU buffer management for SFT expert computation. + + Single grow-only buffer (never shrinks). Callers must use [:qlen] slicing + since the buffer may be larger than the current batch. + """ + + _shared_buffer: Optional["KExpertsSFTBuffer"] = None + + def __init__( + self, + qlen: int, + hidden_size: int, + moe_intermediate_size: int, + num_experts: int, + num_experts_per_tok: int, + lora_rank: int, + dtype: torch.dtype = torch.bfloat16, + ): + self.qlen = qlen + self.hidden_size = hidden_size + self.moe_intermediate_size = moe_intermediate_size + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.lora_rank = lora_rank + self.dtype = dtype + + pin_memory = False + + # Forward buffers + self.input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) + self.expert_ids_cpu = torch.empty( + (qlen, num_experts_per_tok), dtype=torch.int64, device="cpu", pin_memory=pin_memory + ) + self.weights_cpu = torch.empty( + (qlen, num_experts_per_tok), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) + self.output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) + + # Backward buffers + self.grad_output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) + self.grad_input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) + self.grad_weights = torch.empty((qlen, num_experts_per_tok), dtype=torch.float32, device="cpu") + + # Batch size tensor for C++ interface + self.bsz_tensor = torch.tensor([qlen], dtype=torch.int32, device="cpu") + + @classmethod + def get_buffer( + cls, + qlen: int, + hidden_size: int, + moe_intermediate_size: int, + num_experts: int, + num_experts_per_tok: int, + lora_rank: int, + dtype: torch.dtype = torch.bfloat16, + ) -> "KExpertsSFTBuffer": + """Get or grow the single shared buffer. Only reallocates when qlen exceeds capacity.""" + buf = cls._shared_buffer + if buf is not None and qlen <= buf.qlen: + return buf + cls._shared_buffer = cls( + qlen=qlen, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + lora_rank=lora_rank, + dtype=dtype, + ) + return cls._shared_buffer + + @classmethod + def clear_cache(cls) -> None: + """Clear the shared buffer.""" + cls._shared_buffer = None + + +class BaseSFTMoEWrapper(_MoEBase, ABC): + """ + Base class for SFT MoE CPU operations with concrete buffer management. + + Subclasses implement: + - _make_forward_task(buffer, save_for_backward) -> C++ task object + - _make_backward_task(buffer) -> C++ task object + - load_weights(physical_to_logical_map_cpu) + - init_lora_weights(...) + - update_lora_weights() + """ + + def __init__( + self, + layer_idx: int, + num_experts: int, + num_experts_per_tok: int, + hidden_size: int, + moe_intermediate_size: int, + num_gpu_experts: int, + cpuinfer_threads: int, + threadpool_count: int, + weight_path: str, + chunked_prefill_size: int, + lora_rank: int = 16, + lora_alpha: float = 32.0, + max_cache_depth: int = 1, + ): + self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count) + + self._validate_base_config( + num_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + num_experts_per_tok=num_experts_per_tok, + ) + self._validate_sft_config(lora_rank, lora_alpha, max_cache_depth) + + self.layer_idx = layer_idx + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.hidden_size = hidden_size + self.moe_intermediate_size = moe_intermediate_size + self.num_gpu_experts = num_gpu_experts + self.weight_path = weight_path + self.chunked_prefill_size = chunked_prefill_size + self.threadpool_count = threadpool_count + + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha + self.lora_scaling = lora_alpha / lora_rank + self.max_cache_depth = max_cache_depth + + self.gate_lora_a: Optional[torch.Tensor] = None + self.gate_lora_b: Optional[torch.Tensor] = None + self.up_lora_a: Optional[torch.Tensor] = None + self.up_lora_b: Optional[torch.Tensor] = None + self.down_lora_a: Optional[torch.Tensor] = None + self.down_lora_b: Optional[torch.Tensor] = None + + self._weights_loaded: bool = False + self._lora_initialized: bool = False + self._cache_depth: int = 0 + self._is_skip_lora: bool = False + + self.moe = None + + @staticmethod + def _validate_sft_config(lora_rank: int, lora_alpha: float, max_cache_depth: int) -> None: + if lora_rank <= 0: + raise ValueError(f"lora_rank must be positive, got {lora_rank}") + if lora_alpha <= 0: + raise ValueError(f"lora_alpha must be positive, got {lora_alpha}") + if max_cache_depth <= 0: + raise ValueError(f"max_cache_depth must be positive, got {max_cache_depth}") + + # ========== Abstract methods for subclasses ========== + + @abstractmethod + def _make_forward_task(self, buffer: KExpertsSFTBuffer, save_for_backward: bool): + """Construct the C++ forward task object. Backend-specific.""" + ... + + @abstractmethod + def _make_backward_task(self, buffer: KExpertsSFTBuffer): + """Construct the C++ backward task object. Backend-specific.""" + ... + + @abstractmethod + def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: + ... + + @abstractmethod + def init_lora_weights( + self, + gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor, + up_lora_a: torch.Tensor, up_lora_b: torch.Tensor, + down_lora_a: torch.Tensor, down_lora_b: torch.Tensor, + grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor, + grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor, + grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor, + ) -> None: + ... + + @abstractmethod + def update_lora_weights(self) -> None: + ... + + # ========== Buffer helpers ========== + + def _get_buffer(self, qlen: int) -> KExpertsSFTBuffer: + return KExpertsSFTBuffer.get_buffer( + qlen=qlen, + hidden_size=self.hidden_size, + moe_intermediate_size=self.moe_intermediate_size, + num_experts=self.num_experts, + num_experts_per_tok=self.num_experts_per_tok, + lora_rank=self.lora_rank, + dtype=torch.bfloat16, + ) + + def _validate_forward_inputs(self, hidden_states: torch.Tensor, expert_ids: torch.Tensor, weights: torch.Tensor): + if not self._weights_loaded: + raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.") + if not self._lora_initialized and not self._is_skip_lora: + raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.") + qlen = hidden_states.shape[0] + if qlen > self.chunked_prefill_size: + raise ValueError( + f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). " + "Increase chunked_prefill_size or reduce qlen to avoid buffer overrun." + ) + if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok: + raise ValueError( + f"expert_ids shape {tuple(expert_ids.shape)} must be ({qlen}, {self.num_experts_per_tok})." + ) + if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok: + raise ValueError( + f"weights shape {tuple(weights.shape)} must be ({qlen}, {self.num_experts_per_tok})." + ) + + def _copy_inputs_to_buffer(self, buffer: KExpertsSFTBuffer, hidden_states: torch.Tensor, + expert_ids: torch.Tensor, weights: torch.Tensor, qlen: int) -> torch.device: + """Copy inputs to CPU buffer, return input device.""" + input_device = hidden_states.device + buffer.input_cpu[:qlen].copy_(hidden_states.to(torch.bfloat16), non_blocking=True) + buffer.expert_ids_cpu[:qlen].copy_(expert_ids.to(torch.int64), non_blocking=True) + buffer.weights_cpu[:qlen].copy_(weights.to(torch.float32), non_blocking=True) + buffer.bsz_tensor[0] = qlen + if input_device.type == "cuda": + torch.cuda.synchronize(input_device) + return input_device + + def _copy_grad_output_to_cpu(self, buffer: KExpertsSFTBuffer, grad_output: torch.Tensor, qlen: int): + """Copy grad_output to CPU buffer.""" + input_device = grad_output.device + if input_device.type == "cuda": + torch.cuda.synchronize(input_device) + buffer.grad_output_cpu[:qlen].copy_(grad_output.to(torch.bfloat16)) + + def _return_output(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]): + if output_device is not None: + return buffer.output_cpu[:qlen].to(device=output_device, non_blocking=True) + else: + return buffer.output_cpu[:qlen].clone() + + def _return_grads(self, buffer: KExpertsSFTBuffer, qlen: int, output_device: Optional[torch.device]): + if output_device is not None: + grad_input = buffer.grad_input_cpu[:qlen].to(device=output_device, non_blocking=True) + grad_weights = buffer.grad_weights[:qlen].to(device=output_device, non_blocking=True) + else: + grad_input = buffer.grad_input_cpu[:qlen].clone() + grad_weights = buffer.grad_weights[:qlen].clone() + return grad_input, grad_weights + + # ========== Concrete forward/backward ========== + + def forward( + self, + hidden_states: torch.Tensor, + expert_ids: torch.Tensor, + weights: torch.Tensor, + save_for_backward: bool = True, + output_device: Optional[torch.device] = None, + ) -> torch.Tensor: + """Synchronous forward pass with optional gradient caching.""" + self._validate_forward_inputs(hidden_states, expert_ids, weights) + qlen = hidden_states.shape[0] + buffer = self._get_buffer(qlen) + self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen) + + self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward)) + self.cpu_infer.sync() + + if save_for_backward and self._cache_depth == 0: + self._cache_depth += 1 + + return self._return_output(buffer, qlen, output_device) + + def backward( + self, + grad_output: torch.Tensor, + output_device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward pass computing grad_input and grad_weights.""" + if self._cache_depth <= 0: + raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.") + + qlen = grad_output.shape[0] + buffer = self._get_buffer(qlen) + self._copy_grad_output_to_cpu(buffer, grad_output, qlen) + + self.cpu_infer.submit(self._make_backward_task(buffer)) + self.cpu_infer.sync() + + self._cache_depth -= 1 + return self._return_grads(buffer, qlen, output_device) + + # ========== Async forward ========== + + def submit_forward( + self, + hidden_states: torch.Tensor, + expert_ids: torch.Tensor, + weights: torch.Tensor, + save_for_backward: bool = True, + ) -> None: + """Submit forward pass asynchronously (non-blocking). Call sync_forward() to get results.""" + self._validate_forward_inputs(hidden_states, expert_ids, weights) + qlen = hidden_states.shape[0] + buffer = self._get_buffer(qlen) + self._copy_inputs_to_buffer(buffer, hidden_states, expert_ids, weights, qlen) + + self._pending_buffer = buffer + self._pending_save_for_backward = save_for_backward + self._pending_qlen = qlen + + self.cpu_infer.submit(self._make_forward_task(buffer, save_for_backward)) + + def sync_forward(self, output_device: Optional[torch.device] = None) -> torch.Tensor: + """Synchronize and retrieve forward results. Must be called after submit_forward().""" + if not hasattr(self, "_pending_buffer") or self._pending_buffer is None: + raise RuntimeError("No pending forward. Call submit_forward() first.") + + self.cpu_infer.sync() + + buffer = self._pending_buffer + save_for_backward = self._pending_save_for_backward + qlen = self._pending_qlen + + if save_for_backward and self._cache_depth == 0: + self._cache_depth += 1 + + self._pending_buffer = None + self._pending_save_for_backward = None + self._pending_qlen = None + + return self._return_output(buffer, qlen, output_device) + + # ========== Async backward ========== + + def submit_backward_async( + self, + grad_output: torch.Tensor, + output_device: Optional[torch.device] = None, + ) -> None: + """Submit backward task without waiting. Call sync_backward() for results.""" + if self._cache_depth <= 0: + raise RuntimeError("No forward cache available. Call forward(save_for_backward=True) first.") + + qlen = grad_output.shape[0] + buffer = self._get_buffer(qlen) + self._copy_grad_output_to_cpu(buffer, grad_output, qlen) + + self.cpu_infer.submit(self._make_backward_task(buffer)) + self._async_bwd_qlen = qlen + self._async_bwd_output_device = output_device + + def sync_backward(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Wait for async backward and return results.""" + self.cpu_infer.sync() + + qlen = self._async_bwd_qlen + output_device = self._async_bwd_output_device + buffer = self._get_buffer(qlen) + + self._cache_depth -= 1 + return self._return_grads(buffer, qlen, output_device) + + # ========== Backward repack (optional, subclasses may override) ========== + + def submit_backward_repack(self): + if not self._weights_loaded or self.moe is None: + return + if hasattr(self.moe, 'submit_backward_repack'): + self.moe.submit_backward_repack() + + def wait_backward_repack(self): + if not self._weights_loaded or self.moe is None: + return + if hasattr(self.moe, 'wait_backward_repack'): + self.moe.wait_backward_repack() diff --git a/kt-kernel/python/sft/config.py b/kt-kernel/python/sft/config.py new file mode 100644 index 00000000..82869227 --- /dev/null +++ b/kt-kernel/python/sft/config.py @@ -0,0 +1,124 @@ +# KT-Kernel SFT configuration +# SPDX-License-Identifier: Apache-2.0 + +""" +KTConfig: kt-kernel's own configuration dataclass. + +This is the kt-kernel equivalent of DeepSpeed's JSON config — +it holds all kt-kernel-specific settings and is passed through +KTransformersPlugin.kt_config (similar to DeepSpeedPlugin.hf_ds_config). +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Any, Callable + + +def _env_int(key: str, default: int | None) -> int | None: + value = os.environ.get(key, None) + if value is None or value == "": + return default + return int(value) + + +def _env_float(key: str, default: float | None) -> float | None: + value = os.environ.get(key, None) + if value is None or value == "": + return default + return float(value) + + +def _env_bool(key: str, default: bool) -> bool: + value = os.environ.get(key, None) + if value is None or value == "": + return default + return value.lower() in ("1", "true", "yes") + + +@dataclass +class KTConfig: + """ + KT-Kernel configuration for SFT training. + + All kt-kernel-specific settings live here. Accelerate's KTransformersPlugin + holds a reference to this via its `kt_config` field (similar to + DeepSpeedPlugin.hf_ds_config). + + Can be created from: + - Direct construction: KTConfig(backend="AMXBF16", weight_path="/path/...") + - Dict: KTConfig(**config_dict) + - Environment variables: KTConfig() reads ACCELERATE_KT_* env vars as defaults + """ + + # Backend selection + backend: str | None = None + num_threads: int | None = None + tp_enabled: bool | None = None + threadpool_count: int | None = None + + # Weight loading + weight_path: str | None = None + expert_checkpoint_path: str | None = None + num_gpu_experts: int | None = None + skip_expert_loading: bool | None = None + share_backward_bb: bool | None = None + + # Cache + max_cache_depth: int | None = None + model_max_length: int | None = None + + # LoRA + lora_rank: int | None = None + lora_alpha: float | None = None + + # LoRA Experts (GPU-side extra experts) + use_lora_experts: bool | None = None + lora_expert_num: int | None = None + lora_expert_intermediate_size: int | None = None + + # Runtime state (set during wrapping, not by user) + checkpoint_files: list[str] | None = None + sharded_metadata: dict | None = None + + # Custom wrapping + wrap_fn: Callable[..., Any] | None = None + wrap_kwargs: dict[str, Any] | None = None + + def __post_init__(self): + if self.backend is None: + self.backend = os.environ.get("ACCELERATE_KT_BACKEND", "AMXBF16") + if self.num_threads is None: + self.num_threads = _env_int("ACCELERATE_KT_NUM_THREADS", 1) + if self.tp_enabled is None: + self.tp_enabled = _env_bool("ACCELERATE_KT_TP_ENABLED", False) + if self.threadpool_count is None: + self.threadpool_count = _env_int("ACCELERATE_KT_THREADPOOL_COUNT", 1) + if self.weight_path is None: + self.weight_path = os.environ.get("ACCELERATE_KT_WEIGHT_PATH", None) + if self.expert_checkpoint_path is None: + self.expert_checkpoint_path = os.environ.get("ACCELERATE_KT_EXPERT_CHECKPOINT_PATH", None) + if self.num_gpu_experts is None: + self.num_gpu_experts = _env_int("ACCELERATE_KT_NUM_GPU_EXPERTS", 0) + if self.max_cache_depth is None: + self.max_cache_depth = _env_int("ACCELERATE_KT_MAX_CACHE_DEPTH", 2) + if self.share_backward_bb is None: + self.share_backward_bb = _env_bool("ACCELERATE_KT_SHARE_BACKWARD_BB", False) + if self.use_lora_experts is None: + self.use_lora_experts = _env_bool("ACCELERATE_KT_USE_LORA_EXPERTS", False) + if self.lora_expert_num is None: + self.lora_expert_num = _env_int("ACCELERATE_KT_LORA_EXPERT_NUM", None) + if self.lora_expert_intermediate_size is None: + self.lora_expert_intermediate_size = _env_int("ACCELERATE_KT_LORA_EXPERT_INTERMEDIATE_SIZE", None) + if self.lora_rank is None: + self.lora_rank = _env_int("ACCELERATE_KT_LORA_RANK", None) + if self.lora_alpha is None: + self.lora_alpha = _env_float("ACCELERATE_KT_LORA_ALPHA", None) + if self.lora_alpha is None and self.lora_rank is not None: + self.lora_alpha = float(self.lora_rank * 2) + if self.model_max_length is None: + self.model_max_length = _env_int("ACCELERATE_KT_MODEL_MAX_LENGTH", None) + if self.skip_expert_loading is None: + if "ACCELERATE_KT_SKIP_EXPERT_LOADING" in os.environ: + self.skip_expert_loading = _env_bool("ACCELERATE_KT_SKIP_EXPERT_LOADING", True) diff --git a/kt-kernel/python/sft/dist_utils.py b/kt-kernel/python/sft/dist_utils.py new file mode 100644 index 00000000..831d2d3b --- /dev/null +++ b/kt-kernel/python/sft/dist_utils.py @@ -0,0 +1,184 @@ +# Distributed and checkpoint utilities for SFT +# SPDX-License-Identifier: Apache-2.0 + +""" +Shared distributed communication and gradient-checkpoint detection helpers. + +This is a leaf module — no imports from other sft/ submodules. +""" + +from __future__ import annotations + +import inspect +from contextlib import nullcontext +from typing import Any + +import torch + + +def _all_gather_qlens(local_qlen: int, device: torch.device, world_size: int) -> list[int]: + import torch.distributed as dist + + local_qlen_t = torch.tensor([int(local_qlen)], device=device, dtype=torch.int64) + gathered = [torch.empty(1, device=device, dtype=torch.int64) for _ in range(world_size)] + dist.all_gather(gathered, local_qlen_t) + return [int(t.item()) for t in gathered] + + +def _qlen_offsets(all_qlens: list[int]) -> list[int]: + offsets = [0] + for q in all_qlens: + offsets.append(offsets[-1] + int(q)) + return offsets + + +def _dist_gather_varlen_to_rank0( + local_tensor: torch.Tensor, + *, + all_qlens: list[int], + rank: int, + world_size: int, +) -> list[torch.Tensor] | None: + import torch.distributed as dist + + local_tensor = local_tensor.contiguous() + local_expected = int(all_qlens[rank]) + if local_tensor.shape[0] != local_expected: + raise RuntimeError( + f"Local leading dim mismatch on rank {rank}: got {local_tensor.shape[0]}, expected {local_expected}" + ) + + if rank == 0: + gathered: list[torch.Tensor | None] = [None] * world_size + gathered[0] = local_tensor + ops: list[dist.P2POp] = [] + for src in range(1, world_size): + qlen_src = int(all_qlens[src]) + recv_shape = (qlen_src, *local_tensor.shape[1:]) + recv = torch.empty(recv_shape, device=local_tensor.device, dtype=local_tensor.dtype) + gathered[src] = recv + if qlen_src > 0: + ops.append(dist.P2POp(dist.irecv, recv, src)) + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + out: list[torch.Tensor] = [] + for idx, t in enumerate(gathered): + if t is None: + raise RuntimeError(f"Missing gathered tensor for rank {idx} on rank0.") + out.append(t) + return out + + if local_expected > 0: + reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, local_tensor, 0)]) + for req in reqs: + req.wait() + return None + + +def _dist_scatter_varlen_from_rank0( + *, + rank0_chunks: list[torch.Tensor] | None, + all_qlens: list[int], + rank: int, + world_size: int, + feature_shape: tuple[int, ...], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + import torch.distributed as dist + + local_qlen = int(all_qlens[rank]) + local_out = torch.empty((local_qlen, *feature_shape), device=device, dtype=dtype) + + if rank == 0: + if rank0_chunks is None or len(rank0_chunks) != world_size: + raise RuntimeError("rank0_chunks must contain one chunk per rank on rank0.") + if int(rank0_chunks[0].shape[0]) != local_qlen: + raise RuntimeError( + f"Rank0 local chunk mismatch: got {rank0_chunks[0].shape[0]}, expected {local_qlen}" + ) + if local_qlen > 0: + local_out.copy_(rank0_chunks[0]) + ops: list[dist.P2POp] = [] + for dst in range(1, world_size): + qlen_dst = int(all_qlens[dst]) + if qlen_dst <= 0: + continue + chunk = rank0_chunks[dst].contiguous() + if int(chunk.shape[0]) != qlen_dst: + raise RuntimeError( + f"Rank{dst} chunk mismatch on rank0: got {chunk.shape[0]}, expected {qlen_dst}" + ) + ops.append(dist.P2POp(dist.isend, chunk, dst)) + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + return local_out + + if local_qlen > 0: + reqs = dist.batch_isend_irecv([dist.P2POp(dist.irecv, local_out, 0)]) + for req in reqs: + req.wait() + return local_out + + +def _is_in_checkpoint_first_forward() -> bool: + """Best-effort detection for non-reentrant checkpoint first forward.""" + try: + for frame_info in inspect.stack(context=0): + fn = frame_info.function + file = frame_info.filename or "" + if fn == "custom_gradient_checkpointing_func" and file.endswith("checkpointing.py"): + return True + except Exception: + return False + return False + + +def _checkpoint_hook_mode() -> str: + """Infer checkpoint phase from current saved_tensors_hooks top. + + Returns one of: + - "first_forward": non-reentrant checkpoint's _checkpoint_hook + - "recompute": non-reentrant checkpoint's _recomputation_hook + - "none": no default saved_tensors_hooks on top + - "other": unknown hook stack entry + - "error": failed to query hook stack + """ + try: + top = torch._C._autograd._top_saved_tensors_default_hooks(False) + except Exception: + return "error" + if top is None: + return "none" + try: + pack_fn, _ = top + mod = getattr(pack_fn, "__module__", "") + qual = getattr(pack_fn, "__qualname__", getattr(pack_fn, "__name__", "")) + tag = f"{mod}.{qual}" + except Exception: + return "other" + if "_recomputation_hook.__init__..pack_hook" in tag: + return "recompute" + if "_checkpoint_hook.__init__..pack_hook" in tag: + return "first_forward" + return "other" + + +def _maybe_zero3_gathered_parameters(params: list[torch.nn.Parameter]): + if not params: + return nullcontext() + try: + from transformers.integrations import is_deepspeed_zero3_enabled + except Exception: + return nullcontext() + if not is_deepspeed_zero3_enabled(): + return nullcontext() + try: + import deepspeed # type: ignore + except Exception: + return nullcontext() + return deepspeed.zero.GatheredParameters(params, modifier_rank=0) diff --git a/kt-kernel/python/sft/layer.py b/kt-kernel/python/sft/layer.py new file mode 100644 index 00000000..cea7b07e --- /dev/null +++ b/kt-kernel/python/sft/layer.py @@ -0,0 +1,407 @@ +# KTMoELayerWrapper — nn.Module replacing HF MoE layers for SFT +# SPDX-License-Identifier: Apache-2.0 + +""" +KTMoELayerWrapper: drop-in nn.Module replacement for HuggingFace MoE layers. + +Delegates expert computation to the C++ KTMoEWrapper backend, with support +for gradient checkpointing, PEFT LoRA on experts, LoRA Experts (separate +small MLPs on GPU), shared experts, and multi-GPU rank-0-only execution. +""" + +from __future__ import annotations + +import logging +import os +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .arch import MOEArchConfig +from .autograd import KTMoEFunction +from .dist_utils import ( + _all_gather_qlens, + _checkpoint_hook_mode, + _dist_gather_varlen_to_rank0, + _dist_scatter_varlen_from_rank0, + _is_in_checkpoint_first_forward, + _qlen_offsets, +) + +logger = logging.getLogger(__name__) +_KT_SFT_DEBUG = os.environ.get("KT_SFT_DEBUG", "0") == "1" + + +class KTMoELayerWrapper(nn.Module): + """Wrapper for MoE layer using KTMoEWrapper.""" + + def __init__( + self, + original_moe: nn.Module, + wrapper: Any, + lora_params: dict[str, nn.Parameter] | None, # Kept for backward compatibility, but ignored + moe_config: MOEArchConfig, + hidden_size: int, + layer_idx: int, + lora_experts: "LoRAExperts | None" = None, + ): + super().__init__() + self._is_kt_moe_wrapper = True + + self.wrapper = wrapper + self.moe_config = moe_config + self.hidden_size = hidden_size + self.layer_idx = layer_idx + self.router_type = moe_config.router_type + + # IMPORTANT: Register submodules in the SAME ORDER as original MoE module + # so that PEFT's named_modules() traversal order matches baseline. + # This ensures kaiming_uniform_ calls happen in the same sequence. + # Qwen3MoeSparseMoeBlock order: gate FIRST, then experts. + + # 1. gate/router FIRST - keep original attribute name for PEFT compatibility + router_attr = moe_config.router_attr # "gate" for Qwen3/DeepSeek + setattr(self, router_attr, getattr(original_moe, router_attr, None)) + self._router_attr = router_attr + + # 2. experts SECOND (this is what PEFT targets for LoRA) + experts_attr = moe_config.experts_attr # typically "experts" + setattr(self, experts_attr, getattr(original_moe, experts_attr, None)) + self._experts_attr = experts_attr + + # 3. shared_experts (if any) + if moe_config.has_shared_experts and hasattr(original_moe, "shared_experts"): + self.shared_experts = original_moe.shared_experts + else: + self.shared_experts = None + + # 4. lora_experts (separate LoRA expert MLPs, different from PEFT LoRA on experts) + self.lora_experts = lora_experts + + # PEFT LoRA tracking (set by kt_adapt_peft_lora) + # _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}} + self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None + self._peft_lora_rank: int = 0 + self._peft_lora_alpha: float = 0.0 + self._skip_lora: bool = False # True when using SkipLoRA backend (no LoRA on experts) + + self._lora_pointers_dirty = False + + def _apply(self, fn, recurse=True): + # Protect experts from device transfer (PEFT LoRA should stay on CPU for KT) + saved_experts = None + experts_attr = getattr(self, '_experts_attr', None) + + if experts_attr is not None and getattr(self, experts_attr, None) is not None: + saved_experts = getattr(self, experts_attr) + self._modules.pop(experts_attr, None) + + result = super()._apply(fn, recurse) + + if saved_experts is not None: + self._modules[experts_attr] = saved_experts + + return result + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + import torch.distributed as dist + dist_on = dist.is_initialized() and dist.get_world_size() > 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + + # Check if we need to use distributed broadcast (only rank 0 has KT kernel) + use_broadcast = dist_on and self.wrapper is None + + topk_ids, topk_weights = self._compute_routing(hidden_states) + + train_lora = self._peft_lora_modules is not None and len(self._peft_lora_modules) > 0 + + save_for_backward = ( + self.training + and torch.is_grad_enabled() + and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora) + ) + ckpt_hook_mode = _checkpoint_hook_mode() + in_ckpt_recompute = ckpt_hook_mode == "recompute" + in_ckpt_first_forward = ckpt_hook_mode == "first_forward" + if ckpt_hook_mode in ("none", "other", "error"): + # Fallback for environments where hook-top probing is unavailable. + in_ckpt_first_forward = _is_in_checkpoint_first_forward() + if in_ckpt_recompute: + # Recompute must be treated as non-first-forward in diagnostics. + in_ckpt_first_forward = False + # Keep KT autograd path whenever backward is needed. Disabling it in + # checkpoint first-forward prevents KTMoEFunction.backward from running. + use_autograd_path = save_for_backward + save_for_backward_submit = use_autograd_path + # Only suppress cache when we have high-confidence first_forward detection + # via the saved_tensors_hooks stack. The stack-walk fallback is too fragile + # for a correctness-critical decision — it only logs. + if ckpt_hook_mode == "first_forward": + save_for_backward_submit = False + + if train_lora and self._lora_pointers_dirty: + self.update_lora_pointers() + self._lora_pointers_dirty = False + + gpu_output, all_qlens = self._submit_and_compute_gpu( + hidden_states, + topk_ids, + topk_weights, + save_for_backward_submit, + ) + + # Use KTMoEFunction whenever backward is needed so KT backward and LoRA + # gradient paths remain connected. + if use_autograd_path: + lora_ref = hidden_states.new_empty(()) + if train_lora and self._peft_lora_modules: + for expert_loras in self._peft_lora_modules.values(): + for lora_A, lora_B in expert_loras.values(): + if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad: + lora_ref = lora_A.weight + break + if lora_ref.numel() > 0: + break + + moe_output = KTMoEFunction.apply( + hidden_states, + topk_ids, + topk_weights, + self.wrapper, + lora_ref, + self.hidden_size, + self.moe_config.num_experts_per_tok, + self.layer_idx, + save_for_backward, + train_lora, + all_qlens, + ) + else: + moe_output = self._sync_forward_output_no_autograd( + hidden_states=hidden_states, + all_qlens=all_qlens, + ) + + if gpu_output is not None: + moe_output = moe_output + gpu_output + + return moe_output + + def _sync_forward_output_no_autograd( + self, + hidden_states: torch.Tensor, + all_qlens: list[int] | tuple[int, ...] | None, + ) -> torch.Tensor: + """Sync CPU expert output without creating KTMoEFunction autograd nodes.""" + import torch.distributed as dist + + original_device = hidden_states.device + original_dtype = hidden_states.dtype + batch_size, seq_len, _ = hidden_states.shape + qlen = batch_size * seq_len + + dist_on = dist.is_initialized() and dist.get_world_size() > 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist_on else 1 + + if dist_on: + if all_qlens is None: + all_qlens_list = _all_gather_qlens(qlen, original_device, world_size) + else: + all_qlens_list = [int(q) for q in all_qlens] + if len(all_qlens_list) != world_size: + raise RuntimeError( + f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}" + ) + if int(all_qlens_list[rank]) != qlen: + raise RuntimeError( + f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}" + ) + total_qlen = sum(all_qlens_list) + + if rank == 0: + if self.wrapper is None: + raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.") + cpu_output = self.wrapper.sync_forward(output_device=original_device) + cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size) + offsets = _qlen_offsets(all_qlens_list) + scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)] + else: + scatter_list = None + + output_flat = _dist_scatter_varlen_from_rank0( + rank0_chunks=scatter_list, + all_qlens=all_qlens_list, + rank=rank, + world_size=world_size, + feature_shape=(self.hidden_size,), + device=original_device, + dtype=original_dtype, + ) + output = output_flat.view(batch_size, seq_len, self.hidden_size) + del output_flat + return output + + if self.wrapper is not None: + cpu_output = self.wrapper.sync_forward(output_device=original_device) + output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype) + return output + + return torch.empty(batch_size, seq_len, self.hidden_size, device=original_device, dtype=original_dtype) + + def _compute_routing(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # Run routing under no_grad to avoid creating autograd nodes whose + # SavedVariables become orphan holders inside gradient checkpoint. + # The gate is frozen during LoRA fine-tuning and the main gradient + # flows through KTMoEFunction.backward()'s grad_input, so the + # routing gradient contribution to hidden_states can be safely dropped. + with torch.no_grad(): + router = getattr(self, self._router_attr) + if self.router_type == "deepseek_gate": + # DeepSeek V3's MoEGate has `assert not self.training` in its noaux_tc + # routing path because the HF model is an inference-only port. + # For LoRA fine-tuning the router is frozen, so eval() is safe. + was_training = router.training + if was_training: + router.eval() + router_output = router(hidden_states) + if was_training: + router.train() + if len(router_output) == 2: + topk_ids, topk_weights = router_output + else: + topk_ids, topk_weights = router_output[0], router_output[1] + if topk_weights.is_floating_point(): + topk_weights = topk_weights.to(torch.bfloat16) + return topk_ids, topk_weights + + router_logits = router(hidden_states.view(-1, self.hidden_size)) + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1) + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(torch.bfloat16) + return topk_ids, topk_weights + + def _submit_and_compute_gpu( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + save_for_backward: bool, + ) -> tuple[torch.Tensor | None, list[int] | None]: + import torch.distributed as dist + + batch_size, seq_len, _ = hidden_states.shape + original_device = hidden_states.device + original_dtype = hidden_states.dtype + + dist_on = dist.is_initialized() and dist.get_world_size() > 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist_on else 1 + + qlen = batch_size * seq_len + + if dist_on: + all_qlens = _all_gather_qlens(qlen, original_device, world_size) + if int(all_qlens[rank]) != qlen: + raise RuntimeError( + f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}" + ) + total_qlen = sum(all_qlens) + + hs_flat = hidden_states.view(qlen, self.hidden_size).contiguous() + expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok).contiguous() + weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok).contiguous() + + submit_hs = hs_flat.detach() + submit_ids = expert_ids.detach() + submit_wts = weights.detach() + + gathered_hs = _dist_gather_varlen_to_rank0( + submit_hs, + all_qlens=all_qlens, + rank=rank, + world_size=world_size, + ) + gathered_ids = _dist_gather_varlen_to_rank0( + submit_ids, + all_qlens=all_qlens, + rank=rank, + world_size=world_size, + ) + gathered_wts = _dist_gather_varlen_to_rank0( + submit_wts, + all_qlens=all_qlens, + rank=rank, + world_size=world_size, + ) + + if rank == 0: + all_hs = torch.cat(gathered_hs, dim=0) + all_ids = torch.cat(gathered_ids, dim=0) + all_wts = torch.cat(gathered_wts, dim=0) + self.wrapper.submit_forward( + all_hs, + all_ids, + all_wts, + save_for_backward=save_for_backward, + ) + + # Keep shared/lora experts local to avoid qlen_max-style amplification. + gpu_output = None + if self.shared_experts is not None: + gpu_output = self.shared_experts(hidden_states) + gpu_output = gpu_output.to(dtype=original_dtype) + + if self.lora_experts is not None: + lora_out = self.lora_experts(hidden_states) + gpu_output = lora_out if gpu_output is None else gpu_output + lora_out + + return gpu_output, all_qlens + + else: + # ---- Single-GPU path: submit + GPU compute ---- + input_flat = hidden_states.view(qlen, self.hidden_size) + expert_ids = topk_ids.view(qlen, self.moe_config.num_experts_per_tok) + weights = topk_weights.view(qlen, self.moe_config.num_experts_per_tok) + + # Avoid passing graph-attached tensors into C++ cache. + submit_hs = input_flat.detach() + submit_ids = expert_ids.detach() + submit_wts = weights.detach() + self.wrapper.submit_forward( + submit_hs, + submit_ids, + submit_wts, + save_for_backward=save_for_backward, + ) + + # GPU compute: shared_experts + lora_experts + gpu_output = None + if self.shared_experts is not None: + gpu_output = self.shared_experts(hidden_states) + if self.lora_experts is not None: + lora_out = self.lora_experts(hidden_states) + gpu_output = lora_out if gpu_output is None else gpu_output + lora_out + + return gpu_output, None + + def update_lora_pointers(self): + """Sync PEFT LoRA weights to C++ kernel after optimizer update.""" + # Skip if wrapper is None (non-rank-0 processes) + if self.wrapper is None: + return + # Skip if wrapper is not properly initialized + if not getattr(self.wrapper, "_weights_loaded", False): + logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - weights not loaded") + return + if not getattr(self.wrapper, "_lora_initialized", False): + logger.warning(f"Layer {self.layer_idx}: Skipping update_lora_pointers - LoRA not initialized") + return + + # PEFT weights are views into wrapper's contiguous buffers — + # optimizer.step() already updated them in-place, just re-sync to C++. + self.wrapper.update_lora_weights() diff --git a/kt-kernel/python/sft/lora.py b/kt-kernel/python/sft/lora.py new file mode 100644 index 00000000..d949edf9 --- /dev/null +++ b/kt-kernel/python/sft/lora.py @@ -0,0 +1,688 @@ +# PEFT LoRA adaptation utilities for SFT +# SPDX-License-Identifier: Apache-2.0 + +""" +PEFT LoRA integration for KT-Kernel MoE training. + +Handles: +- LoRA Expert modules (LoRAExpertMLP, LoRAExperts) +- PEFT LoRA adaptation onto KT wrappers (contiguous buffer views, grad buffers) +- LoRA parameter collection for optimizer injection +- Checkpoint save/load for lora_experts +""" + +from __future__ import annotations + +import logging +import math +import os +import re + +import torch +import torch.nn as nn + +from .arch import MOEArchConfig + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# LoRA Experts Modules +# ============================================================================= + + +class LoRAExpertMLP(nn.Module): + """Single LoRA Expert with SwiGLU activation structure.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.le_gate = nn.Linear(hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.le_up = nn.Linear(hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.le_down = nn.Linear(intermediate_size, hidden_size, bias=False, device=device, dtype=dtype) + self.act_fn = nn.SiLU() + + nn.init.zeros_(self.le_down.weight) + nn.init.kaiming_uniform_(self.le_gate.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.le_up.weight, a=math.sqrt(5)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.le_down(self.act_fn(self.le_gate(x)) * self.le_up(x)) + + +class LoRAExperts(nn.Module): + """LoRA Experts module containing multiple LoRA Expert MLPs.""" + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.experts = nn.ModuleList( + [LoRAExpertMLP(hidden_size, intermediate_size, device, dtype) for _ in range(num_experts)] + ) + self.num_experts = num_experts + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + output = torch.zeros_like(hidden_states) + for expert in self.experts: + output = output + expert(hidden_states) + return output / self.num_experts + + +# ============================================================================= +# LoRA Parameter Collection +# ============================================================================= + + +def _find_kt_wrappers(model: nn.Module): + """Find _kt_wrappers on model, unwrapping PEFT/other wrappers if needed.""" + wrappers = getattr(model, "_kt_wrappers", None) + if wrappers is None: + base_model = model + for attr in ("base_model", "model"): + if hasattr(base_model, attr): + base_model = getattr(base_model, attr) + wrappers = getattr(base_model, "_kt_wrappers", None) + if wrappers: + break + return wrappers + + +def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]: + """Get all MoE LoRA parameters from KT model. + + Returns PEFT LoRA parameters from expert modules and lora_experts parameters. + """ + params: list[nn.Parameter] = [] + + wrappers = _find_kt_wrappers(model) + + if wrappers: + for wrapper in wrappers: + # PEFT LoRA parameters (from _peft_lora_modules) + peft_lora_modules = getattr(wrapper, "_peft_lora_modules", None) + if peft_lora_modules is not None: + for expert_loras in peft_lora_modules.values(): + for lora_A, lora_B in expert_loras.values(): + if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad: + params.append(lora_A.weight) + if hasattr(lora_B, 'weight') and lora_B.weight.requires_grad: + params.append(lora_B.weight) + # lora_experts parameters (separate feature) + if getattr(wrapper, "lora_experts", None) is not None: + params.extend(wrapper.lora_experts.parameters()) + + return params + + +# ============================================================================= +# PEFT LoRA Adaptation +# ============================================================================= + + +def kt_adapt_peft_lora(model: nn.Module) -> None: + """ + Adapt PEFT LoRA on expert modules for KT kernel. + + After PEFT injects LoRA adapters onto expert Linear modules, this function: + 1. Detects PEFT LoRA presence and rank on each wrapper's experts + 2. Stores references to PEFT LoRA modules on the wrapper (for backward gradient writing) + 3. Syncs initial PEFT LoRA weights to the C++ KT kernel (rank 0 only) + + PEFT LoRA remains active and is managed by PEFT. No separate KT lora_params created. + Optimizer updates PEFT LoRA directly, and KT kernel reads from PEFT LoRA on each forward. + + Should be called after PEFT LoRA injection and before create_optimizer. + """ + import torch.distributed as dist + + wrappers = _find_kt_wrappers(model) + + if not wrappers: + logger.info("[kt_adapt_peft_lora] No _kt_wrappers found, skipping") + return + + is_rank_0 = True + if dist.is_initialized(): + is_rank_0 = dist.get_rank() == 0 + + adapted_count = 0 + for wrapper in wrappers: + moe_config = wrapper.moe_config + layer_idx = wrapper.layer_idx + experts_attr = getattr(wrapper, "_experts_attr", "experts") + experts = getattr(wrapper, experts_attr, None) + + if experts is None or len(experts) == 0: + continue + + # Collect references to PEFT LoRA modules for each expert + # Structure: {expert_idx: {proj_name: (lora_A_module, lora_B_module)}} + peft_lora_modules = {} + gate_name, up_name, down_name = moe_config.weight_names + + for expert_idx, expert in enumerate(experts): + expert_loras = {} + for proj_name in (gate_name, up_name, down_name): + proj = getattr(expert, proj_name, None) + if proj is None: + continue + lora_A = getattr(proj, "lora_A", None) + lora_B = getattr(proj, "lora_B", None) + if lora_A is not None and lora_B is not None: + # Get the actual Linear modules (inside ModuleDict if using adapters) + if isinstance(lora_A, nn.ModuleDict): + adapter_name = "default" + active = getattr(proj, "active_adapter", ["default"]) + if isinstance(active, (list, tuple)) and active: + adapter_name = active[0] + # ModuleDict doesn't have .get(), use [] with in check + lora_A = lora_A[adapter_name] if adapter_name in lora_A else None + lora_B = lora_B[adapter_name] if adapter_name in lora_B else None + if lora_A is not None and lora_B is not None: + expert_loras[proj_name] = (lora_A, lora_B) + if expert_loras: + peft_lora_modules[expert_idx] = expert_loras + + # Store PEFT LoRA references on wrapper + wrapper._peft_lora_modules = peft_lora_modules + + # SkipLoRA mode: if no LoRA found on experts, skip buffer creation + if not peft_lora_modules: + if getattr(wrapper, '_skip_lora', False): + logger.info( + f"[kt_adapt_peft_lora] Layer {layer_idx}: SkipLoRA mode, " + f"no PEFT LoRA on experts — skipping LoRA buffer creation" + ) + adapted_count += 1 + continue + else: + raise RuntimeError( + f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. " + f"If you intend to train without expert LoRA, use a SkipLoRA backend " + f"(e.g., kt_backend: AMXINT8_SkipLoRA)." + ) + + # Allocate contiguous bf16 buffers and populate with initial PEFT values (all ranks) + lora_buffers = _create_lora_view_buffers(peft_lora_modules, moe_config, torch.bfloat16) + lora_grad_buffers = _create_lora_grad_buffers(peft_lora_modules, moe_config) + + # Rank 0: pass buffers to C++ wrapper (init_lora_weights stores them via .contiguous() no-op) + if is_rank_0 and wrapper.wrapper is not None: + # concat lora_buffers and lora_grad_buffers into single dict + lora_buffers.update(lora_grad_buffers) + wrapper.wrapper.init_lora_weights(**lora_buffers) + logger.info(f"[kt_adapt_peft_lora] Layer {layer_idx}: synced PEFT LoRA to C++ kernel") + + # All ranks: replace PEFT weights with views into the contiguous buffers + _replace_peft_weights_with_views(peft_lora_modules, lora_buffers, lora_grad_buffers, moe_config) + + adapted_count += 1 + + # After collecting all LoRA references, shrink expert base weight parameters + # from their original shape (e.g. [768, 2048]) to scalar (1,). + # These base weights were already replaced with tiny-storage stride=[0] placeholders + # by _clear_original_expert_weights(). They have correct shape but serve no purpose + # after PEFT injection. FSDP2 broadcasts ALL non-DTensor params, and uses + # torch.empty(param.size()) on non-rank-0 — with the original shape this wastes + # ~28GB+. Shrinking to (1,) reduces broadcast cost to ~30KB total. + shrunk_count = 0 + shrunk_saved_bytes = 0 + for wrapper in wrappers: + experts_attr = getattr(wrapper, "_experts_attr", "experts") + experts = getattr(wrapper, experts_attr, None) + if experts is None: + continue + for expert in experts: + for param_name, param in list(expert.named_parameters()): + if param.requires_grad: + continue # Skip trainable params (LoRA weights) + try: + storage_bytes = param.data.untyped_storage().nbytes() + except Exception: + continue + if storage_bytes > 2: + continue # Skip non-placeholder params + + # This is a tiny-storage placeholder (base weight) — replace with + # a scalar (1,) parameter so FSDP broadcasts only 1 element. + original_numel = param.nelement() + parts = param_name.split(".") + container = expert + for p in parts[:-1]: + container = getattr(container, p) + local_name = parts[-1] + container_params = getattr(container, "_parameters", {}) + if isinstance(container_params, dict) and local_name in container_params: + scalar_param = nn.Parameter( + torch.empty(1, dtype=param.dtype, device="cpu"), + requires_grad=False, + ) + container_params[local_name] = scalar_param + shrunk_count += 1 + shrunk_saved_bytes += (original_numel - 1) * param.element_size() + + if shrunk_count > 0: + logger.info( + f"[kt_adapt_peft_lora] Shrunk {shrunk_count} expert base weight params " + f"to shape (1,), FSDP broadcast savings={shrunk_saved_bytes / 1024 / 1024:.1f} MB" + ) + + logger.info(f"[kt_adapt_peft_lora] Adapted {adapted_count} layers (PEFT LoRA mode)") + + +# ============================================================================= +# Contiguous Buffer Creation +# ============================================================================= + + +def _create_lora_view_buffers( + peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]], + moe_config: MOEArchConfig, + dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + """ + Allocate contiguous buffers and populate with initial PEFT LoRA values. + + Returns dict with gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, + down_lora_a, down_lora_b — each shape [num_experts, ...]. + """ + gate_name, up_name, down_name = moe_config.weight_names + num_experts = moe_config.expert_num + + first_expert_loras = peft_lora_modules.get(0, {}) + if not first_expert_loras: + raise RuntimeError("No PEFT LoRA found on expert 0") + gate_lora = first_expert_loras.get(gate_name) + if gate_lora is None: + raise RuntimeError(f"No PEFT LoRA found on expert 0 {gate_name}") + + lora_rank = gate_lora[0].weight.shape[0] + hidden_size = gate_lora[0].weight.shape[1] + intermediate_size = gate_lora[1].weight.shape[0] + + buffers = { + "gate_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"), + "gate_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"), + "up_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"), + "up_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"), + "down_lora_a": torch.zeros(num_experts, lora_rank, intermediate_size, dtype=dtype, device="cpu"), + "down_lora_b": torch.zeros(num_experts, hidden_size, lora_rank, dtype=dtype, device="cpu"), + } + + proj_to_keys = { + gate_name: ("gate_lora_a", "gate_lora_b"), + up_name: ("up_lora_a", "up_lora_b"), + down_name: ("down_lora_a", "down_lora_b"), + } + for expert_idx in range(num_experts): + expert_loras = peft_lora_modules.get(expert_idx, {}) + for proj_name, (key_a, key_b) in proj_to_keys.items(): + if proj_name in expert_loras: + lora_A, lora_B = expert_loras[proj_name] + buffers[key_a][expert_idx].copy_(lora_A.weight.data.to(dtype=dtype)) + buffers[key_b][expert_idx].copy_(lora_B.weight.data.to(dtype=dtype)) + + return buffers + + +def _create_lora_grad_buffers( + peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]], + moe_config: MOEArchConfig, + dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + """ + Allocate contiguous gradient buffers for PEFT LoRA. + + Returns dict with grad_gate_lora_a, grad_gate_lora_b, etc. — each shape [num_experts, ...]. + """ + gate_name, up_name, down_name = moe_config.weight_names + num_experts = moe_config.expert_num + + first_expert_loras = peft_lora_modules.get(0, {}) + if not first_expert_loras: + raise RuntimeError("No PEFT LoRA found on expert 0") + gate_lora = first_expert_loras.get(gate_name) + if gate_lora is None: + raise RuntimeError(f"No PEFT LoRA found on expert 0 {gate_name}") + + lora_rank = gate_lora[0].weight.shape[0] + hidden_size = gate_lora[0].weight.shape[1] + intermediate_size = gate_lora[1].weight.shape[0] + + buffers = { + "grad_gate_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"), + "grad_gate_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"), + "grad_up_lora_a": torch.zeros(num_experts, lora_rank, hidden_size, dtype=dtype, device="cpu"), + "grad_up_lora_b": torch.zeros(num_experts, intermediate_size, lora_rank, dtype=dtype, device="cpu"), + "grad_down_lora_a": torch.zeros(num_experts, lora_rank, intermediate_size, dtype=dtype, device="cpu"), + "grad_down_lora_b": torch.zeros(num_experts, hidden_size, lora_rank, dtype=dtype, device="cpu"), + } + + return buffers + + +# ============================================================================= +# PEFT Weight View Replacement +# ============================================================================= + + +def _replace_peft_weights_with_views( + peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]], + buffers: dict[str, torch.Tensor], + grad_buffers: dict[str, torch.Tensor], + moe_config: MOEArchConfig, +) -> None: + """ + Replace each PEFT LoRA module's .weight with a view into the contiguous buffer. + + After this, optimizer.step() updates the buffer in-place via the view — + no copy needed to sync with C++. + """ + gate_name, up_name, down_name = moe_config.weight_names + num_experts = moe_config.expert_num + + proj_to_keys = { + gate_name: ("gate_lora_a", "gate_lora_b"), + up_name: ("up_lora_a", "up_lora_b"), + down_name: ("down_lora_a", "down_lora_b"), + } + + _replaced = 0 + _first_logged = False + for expert_idx in range(num_experts): + expert_loras = peft_lora_modules.get(expert_idx, {}) + for proj_name, (key_a, key_b) in proj_to_keys.items(): + if proj_name not in expert_loras: + continue + lora_A, lora_B = expert_loras[proj_name] + + # Log before/after for first replacement to verify .data assignment + if not _first_logged: + _old_id_a = id(lora_A.weight) + _old_ptr_a = lora_A.weight.data_ptr() + + # Use .data assignment to keep the same Parameter objects. + # This preserves optimizer references (which point to these objects). + # Creating new nn.Parameter() would break the optimizer link. + lora_A.weight.data = buffers[key_a][expert_idx] + lora_B.weight.data = buffers[key_b][expert_idx] + lora_A.weight.requires_grad_(True) + lora_B.weight.requires_grad_(True) + lora_A.weight.grad = grad_buffers["grad_" + key_a][expert_idx] + lora_B.weight.grad = grad_buffers["grad_" + key_b][expert_idx] + + if not _first_logged: + _new_id_a = id(lora_A.weight) + _new_ptr_a = lora_A.weight.data_ptr() + _buf_ptr_a = buffers[key_a][expert_idx].data_ptr() + _has_grad = lora_A.weight.grad is not None + logger.info( + "[_replace_peft_weights_with_views] first param: " + "id %s->%s (same=%s) data_ptr %s->%s buf_ptr=%s (match=%s) " + "has_grad=%s requires_grad=%s shape=%s", + _old_id_a, _new_id_a, _old_id_a == _new_id_a, + _old_ptr_a, _new_ptr_a, _buf_ptr_a, _new_ptr_a == _buf_ptr_a, + _has_grad, lora_A.weight.requires_grad, tuple(lora_A.weight.shape), + ) + _first_logged = True + _replaced += 1 + + logger.info("[_replace_peft_weights_with_views] replaced %d param pairs", _replaced) + + +# ============================================================================= +# Runtime LoRA Pointer Updates +# ============================================================================= + + +def update_kt_lora_pointers(model: nn.Module): + """Mark KT wrapper LoRA pointers as dirty after optimizer.step().""" + wrappers = _find_kt_wrappers(model) + + if wrappers: + for wrapper in wrappers: + wrapper._lora_pointers_dirty = True + + +# ============================================================================= +# Cross-Rank Gradient Synchronization +# ============================================================================= + + +def sync_kt_lora_gradients(model: nn.Module) -> None: + """ + Synchronize KT-managed LoRA gradients across ranks. + + KT computes expert LoRA gradients only on rank 0 (gather/scatter path). This function broadcasts the + per-layer contiguous grad buffers from rank 0 to all ranks so that: + - gradient clipping sees identical grads on every rank + - optimizer.step() applies identical updates + """ + import torch.distributed as dist + + if not (dist.is_initialized() and dist.get_world_size() > 1): + return + + world_size = dist.get_world_size() + if world_size <= 1: + return + + params = get_kt_lora_params(model) + if not params: + return + + for param in params: + if param.grad is not None: + # Move grad to the same device as the parameter for all-reduce + # Then move back to CPU + original_device = param.grad.device + if original_device.type == "cpu": + # All-reduce on CPU might be slow; consider using a GPU buffer + grad_gpu = param.grad.cuda() + dist.all_reduce(grad_gpu, op=dist.ReduceOp.SUM) + grad_gpu.div_(world_size) + param.grad.copy_(grad_gpu.cpu()) + else: + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad.div_(world_size) + + +# ============================================================================= +# Checkpoint Save/Load +# ============================================================================= + + +def save_lora_experts_to_adapter(model: nn.Module, output_dir: str) -> None: + """ + Save LoRA Experts weights to adapter file by merging with existing Attention LoRA. + """ + from safetensors import safe_open + from safetensors.torch import save_file + + wrappers = getattr(model, "_kt_wrappers", []) + if not wrappers: + base_model = model + for attr in ["base_model", "model"]: + if hasattr(base_model, attr): + base_model = getattr(base_model, attr) + wrappers = getattr(base_model, "_kt_wrappers", []) + if wrappers: + break + if not wrappers: + logger.warning("No KT wrappers found, skipping LoRA Experts saving") + return + + adapter_file = os.path.join(output_dir, "adapter_model.safetensors") + if not os.path.exists(adapter_file): + adapter_file_bin = os.path.join(output_dir, "adapter_model.bin") + if os.path.exists(adapter_file_bin): + state_dict = torch.load(adapter_file_bin, map_location="cpu", weights_only=True) + else: + logger.warning(f"No existing adapter file found at {output_dir}, creating new one") + state_dict = {} + else: + state_dict = {} + with safe_open(adapter_file, framework="pt") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + + lora_expert_count = 0 + for wrapper in wrappers: + if wrapper.lora_experts is None: + continue + + layer_idx = wrapper.layer_idx + for expert_idx, expert in enumerate(wrapper.lora_experts.experts): + base_key = f"base_model.model.model.layers.{layer_idx}.mlp.lora_experts.{expert_idx}" + state_dict[f"{base_key}.le_gate.weight"] = expert.le_gate.weight.data.cpu().clone() + state_dict[f"{base_key}.le_up.weight"] = expert.le_up.weight.data.cpu().clone() + state_dict[f"{base_key}.le_down.weight"] = expert.le_down.weight.data.cpu().clone() + lora_expert_count += 3 + + logger.debug(f"Added LoRA Experts for layer {layer_idx} ({len(wrapper.lora_experts.experts)} experts)") + + output_file = os.path.join(output_dir, "adapter_model.safetensors") + save_file(state_dict, output_file, metadata={"format": "pt"}) + + logger.info( + f"Saved LoRA Experts to {output_file}: " + f"{len(wrappers)} layers, {lora_expert_count} LoRA Expert tensors added, " + f"{len(state_dict)} total tensors" + ) + + +def save_kt_moe_to_adapter(model: nn.Module, output_dir: str) -> None: + """ + Unified function to save KT MoE weights to adapter file. + Note: Per-expert PEFT LoRA is saved by PEFT directly, not here. + This function only handles lora_experts (a separate feature). + """ + wrappers = getattr(model, "_kt_wrappers", []) + if not wrappers: + base_model = model + for attr in ["base_model", "model"]: + if hasattr(base_model, attr): + base_model = getattr(base_model, attr) + wrappers = getattr(base_model, "_kt_wrappers", []) + if wrappers: + break + if not wrappers: + logger.info("[save_kt_moe] No KT wrappers found, skipping") + return + + has_lora_experts = any(w.lora_experts is not None for w in wrappers) + + if has_lora_experts: + save_lora_experts_to_adapter(model, output_dir) + else: + logger.info("[save_kt_moe] No lora_experts in KT wrappers") + + +def load_lora_experts_from_adapter(model: nn.Module, adapter_path: str) -> None: + """ + Load LoRA Experts weights from adapter file into KT wrappers. + """ + from safetensors import safe_open + + wrappers = getattr(model, "_kt_wrappers", []) + if not wrappers: + base_model = model + for attr in ["base_model", "model"]: + if hasattr(base_model, attr): + base_model = getattr(base_model, attr) + wrappers = getattr(base_model, "_kt_wrappers", []) + if wrappers: + break + if not wrappers: + logger.warning("No KT wrappers found, skipping LoRA Experts loading") + return + + wrapper_map = {w.layer_idx: w for w in wrappers if w.lora_experts is not None} + if not wrapper_map: + logger.warning("No LoRA Experts found in KT wrappers, skipping") + return + + # Prefer dedicated lora_experts file, fallback to adapter file + adapter_file = os.path.join(adapter_path, "lora_experts.safetensors") + if not os.path.exists(adapter_file): + adapter_file = os.path.join(adapter_path, "adapter_model.safetensors") + if not os.path.exists(adapter_file): + adapter_file = os.path.join(adapter_path, "adapter_model.bin") + if not os.path.exists(adapter_file): + logger.warning(f"No lora_experts or adapter file found at {adapter_path}") + return + + logger.info(f"Loading LoRA Experts from {adapter_file}") + + lora_expert_pattern = re.compile( + r"base_model\.model\.model\.layers\.(\d+)\.mlp\.lora_experts\.(\d+)\.(le_gate|le_up|le_down)\.weight" + ) + + layer_weights = {} + with safe_open(adapter_file, framework="pt") as f: + for key in f.keys(): + match = lora_expert_pattern.match(key) + if match: + layer_idx = int(match.group(1)) + expert_idx = int(match.group(2)) + proj_name = match.group(3) + layer_weights.setdefault(layer_idx, {}).setdefault(expert_idx, {})[proj_name] = f.get_tensor(key) + + loaded_count = 0 + for layer_idx, experts_dict in layer_weights.items(): + if layer_idx not in wrapper_map: + logger.warning(f"No LoRA Experts for layer {layer_idx}, skipping") + continue + + wrapper = wrapper_map[layer_idx] + for expert_idx, proj_dict in experts_dict.items(): + if expert_idx >= len(wrapper.lora_experts.experts): + continue + expert = wrapper.lora_experts.experts[expert_idx] + if "le_gate" in proj_dict: + expert.le_gate.weight.data.copy_(proj_dict["le_gate"].to(expert.le_gate.weight.device)) + if "le_up" in proj_dict: + expert.le_up.weight.data.copy_(proj_dict["le_up"].to(expert.le_up.weight.device)) + if "le_down" in proj_dict: + expert.le_down.weight.data.copy_(proj_dict["le_down"].to(expert.le_down.weight.device)) + loaded_count += 1 + + logger.info(f"Loaded LoRA Experts for {loaded_count} experts from {adapter_path}") + + +def load_kt_moe_from_adapter(model: nn.Module, adapter_path: str) -> None: + """ + Unified function to load KT MoE weights from adapter file. + Note: Per-expert PEFT LoRA is loaded by PEFT directly, not here. + This function only handles lora_experts (a separate feature). + """ + wrappers = getattr(model, "_kt_wrappers", []) + if not wrappers: + base_model = model + for attr in ["base_model", "model"]: + if hasattr(base_model, attr): + base_model = getattr(base_model, attr) + wrappers = getattr(base_model, "_kt_wrappers", []) + if wrappers: + break + if not wrappers: + logger.warning("No KT wrappers found, skipping KT MoE loading") + return + + has_lora_experts = any(w.lora_experts is not None for w in wrappers) + + if has_lora_experts: + load_lora_experts_from_adapter(model, adapter_path) + else: + logger.info("No lora_experts in KT wrappers (PEFT LoRA is loaded by PEFT directly)") diff --git a/kt-kernel/python/sft/weights.py b/kt-kernel/python/sft/weights.py new file mode 100644 index 00000000..b0bbb6a2 --- /dev/null +++ b/kt-kernel/python/sft/weights.py @@ -0,0 +1,488 @@ +# Weight extraction and loading utilities for SFT +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import logging +import os +import time +from contextlib import nullcontext +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from .arch import MOEArchConfig +from .dist_utils import _maybe_zero3_gathered_parameters + +logger = logging.getLogger(__name__) + +try: + from safetensors import safe_open + + SAFETENSORS_AVAILABLE = True +except ImportError: + SAFETENSORS_AVAILABLE = False + safe_open = None + + +# ============================================================================= +# Weight Extraction +# ============================================================================= + + +def extract_moe_weights( + moe_module: nn.Module, moe_config: MOEArchConfig +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Extract MoE expert weights from the module. + + Returns (gate_proj, up_proj, down_proj) with shape + [expert_num, out_features, in_features]. + """ + experts = getattr(moe_module, moe_config.experts_attr) + gate_name, up_name, down_name = moe_config.weight_names + + gather_params: list[torch.nn.Parameter] = [] + for expert in experts: + for weight_name in (gate_name, up_name, down_name): + proj = getattr(expert, weight_name, None) + if proj is not None and hasattr(proj, "weight"): + # Handle PEFT LoRA wrapped modules + weight = proj.weight + if isinstance(weight, torch.Tensor): + gather_params.append(weight) + elif hasattr(weight, "data"): + gather_params.append(weight.data) + + with _maybe_zero3_gathered_parameters(gather_params): + gate_weights = [] + up_weights = [] + down_weights = [] + + for expert in experts: + # Handle PEFT LoRA wrapped modules - get weight tensor properly + gate_proj = getattr(expert, gate_name) + up_proj_mod = getattr(expert, up_name) + down_proj_mod = getattr(expert, down_name) + + # Get weight tensors, handling both regular Linear and PEFT LoRA wrapped + def get_weight_tensor(mod): + weight = mod.weight + if isinstance(weight, torch.Tensor): + return weight.data + elif hasattr(weight, "data"): + return weight.data + else: + raise ValueError(f"Cannot extract weight from {type(mod)}, weight type={type(weight)}") + + gate_weights.append(get_weight_tensor(gate_proj)) + up_weights.append(get_weight_tensor(up_proj_mod)) + down_weights.append(get_weight_tensor(down_proj_mod)) + + gate_proj = torch.stack(gate_weights, dim=0) + up_proj = torch.stack(up_weights, dim=0) + down_proj = torch.stack(down_weights, dim=0) + + return gate_proj, up_proj, down_proj + + +def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchConfig) -> None: + """ + Clear original expert weights to free memory after KT weights are loaded. + """ + experts = getattr(moe_module, moe_config.experts_attr, None) + if experts is None: + return + + def _iter_weight_params(): + for expert in experts: + for weight_name in moe_config.weight_names: + proj = getattr(expert, weight_name, None) + if proj is None or not hasattr(proj, "weight"): + continue + + parametrizations = getattr(proj, "parametrizations", None) + parametrized_weight = getattr(parametrizations, "weight", None) if parametrizations is not None else None + if parametrized_weight is not None: + original = getattr(parametrized_weight, "original", None) + if isinstance(original, torch.nn.Parameter): + yield proj, parametrized_weight, "original", original + continue + + direct_weight = getattr(proj, "_parameters", {}).get("weight") + if isinstance(direct_weight, torch.nn.Parameter): + yield proj, proj, "weight", direct_weight + continue + + # Fallback: `weight` can be a non-settable property (e.g. parametrizations) or a non-Parameter. + weight_attr = getattr(proj, "weight", None) + if isinstance(weight_attr, torch.nn.Parameter): + yield proj, proj, "weight", weight_attr + + gather_params: list[torch.nn.Parameter] = [] + for _, _, _, weight_param in _iter_weight_params(): + gather_params.append(weight_param) + + replaced_count = 0 + + with _maybe_zero3_gathered_parameters(gather_params): + for proj, container, param_name, weight_param in _iter_weight_params(): + original_dtype = weight_param.dtype + + # Create a CPU tensor with the correct shape but NO physical memory. + # torch.empty(shape, device="cpu") unfortunately touches pages via the + # allocator, consuming real RSS. Instead, allocate a 1-byte storage and + # use set_ to give it the original shape with zero strides. The tensor + # is "valid" (correct dtype, device, shape) so PEFT can discover + # in/out features, but its storage is essentially zero-cost. + # NOTE: reading element values from this tensor is undefined -- it is + # only used for shape/dtype discovery by PEFT. + tiny_storage = torch.UntypedStorage(1, device="cpu") + fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_( + tiny_storage, storage_offset=0, size=weight_param.shape, + stride=[0] * len(weight_param.shape), + ) + new_param = nn.Parameter(fake_tensor, requires_grad=False) + replaced_count += 1 + + # Avoid `KeyError: attribute 'weight' already exists` for parametrized modules + # where `weight` is a property and the real parameter lives elsewhere. + container_params = getattr(container, "_parameters", {}) + if isinstance(container_params, dict) and param_name in container_params: + container_params[param_name] = new_param + continue + + if hasattr(container, param_name): + logger.debug( + f"Skipping clearing expert weight {type(proj).__name__}.{param_name}: " + "attribute exists but is not a registered parameter." + ) + continue + + try: + setattr(container, param_name, new_param) + except Exception as exc: + logger.warning( + f"Failed to clear expert weight {type(proj).__name__}.{param_name}: {exc}" + ) + + logger.info(f"Replaced {replaced_count} expert weight params") + + +# ============================================================================= +# kt_weight_path Loading Functions +# ============================================================================= + + +@dataclass +class INT8ExpertWeights: + """Container for INT8 expert weights with scales.""" + + gate_proj: torch.Tensor + gate_scale: torch.Tensor + up_proj: torch.Tensor + up_scale: torch.Tensor + down_proj: torch.Tensor + down_scale: torch.Tensor + + +def _find_safetensor_files(kt_weight_path: str) -> list[str]: + if not os.path.isdir(kt_weight_path): + raise FileNotFoundError(f"kt_weight_path directory not found: {kt_weight_path}") + + safetensor_files = [] + for file in sorted(os.listdir(kt_weight_path)): + if file.endswith(".safetensors"): + safetensor_files.append(os.path.join(kt_weight_path, file)) + + if not safetensor_files: + raise FileNotFoundError(f"No safetensors files found in {kt_weight_path}") + + return safetensor_files + + +def _load_kt_weight_index(kt_weight_path: str) -> dict[str, str]: + if not SAFETENSORS_AVAILABLE: + raise ImportError("safetensors is required for loading kt_weight_path") + + index = {} + safetensor_files = _find_safetensor_files(kt_weight_path) + + for file_path in safetensor_files: + with safe_open(file_path, framework="pt") as f: + for key in f.keys(): + index[key] = file_path + + logger.info(f"Indexed {len(index)} tensors from {len(safetensor_files)} safetensors files") + return index + + +def _dequant_fp8_experts(weights: list[torch.Tensor], scales: list[torch.Tensor | None], block_size: tuple[int, int]) -> torch.Tensor: + """Dequantize a list of FP8 expert weights and stack them (batched, vectorized). + + Args: + weights: list of [out, in] float8_e4m3fn tensors (one per expert) + scales: list of [out//bs_m, in//bs_n] scale_inv tensors (one per expert, may be None) + block_size: (bs_m, bs_n) + + Returns: + Stacked BF16 tensor of shape [num_experts, out, in] + """ + has_scales = scales[0] is not None + if not has_scales: + return torch.stack(weights, dim=0).to(torch.bfloat16).cpu().contiguous() + + bs_m, bs_n = block_size + n = len(weights) + out_features, in_features = weights[0].shape + + # Stack all experts: [N, out, in] fp8 -> reshape to blocks -> bf16 + w = torch.stack(weights, dim=0) # [N, out, in] fp8 + w = w.reshape(n, out_features // bs_m, bs_m, in_features // bs_n, bs_n) + w = w.to(torch.bfloat16) + + # Stack all scales: [N, out//bs_m, in//bs_n] -> bf16, broadcast multiply + s = torch.stack(scales, dim=0).to(torch.bfloat16) # [N, out//bs_m, in//bs_n] + w = w * s[:, :, None, :, None] + + return w.reshape(n, out_features, in_features).contiguous() + + +def load_experts_from_checkpoint_files( + checkpoint_files: list[str], + sharded_metadata: dict | None, + layers_prefix: str, + moe_config: MOEArchConfig, + layer_idx: int, + block_size: tuple[int, int] | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not SAFETENSORS_AVAILABLE: + raise ImportError("safetensors is required for loading experts from checkpoint files") + + if not checkpoint_files: + raise FileNotFoundError("checkpoint_files is empty") + + t0 = time.time() + + weight_map = None + base_dir = os.path.dirname(checkpoint_files[0]) + if sharded_metadata is not None: + weight_map = sharded_metadata.get("weight_map", None) + + gate_name, up_name, down_name = moe_config.weight_names + keys = [] + for expert_idx in range(moe_config.expert_num): + base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}" + keys.append(f"{base}.{gate_name}.weight") + keys.append(f"{base}.{gate_name}.weight_scale_inv") + keys.append(f"{base}.{up_name}.weight") + keys.append(f"{base}.{up_name}.weight_scale_inv") + keys.append(f"{base}.{down_name}.weight") + keys.append(f"{base}.{down_name}.weight_scale_inv") + + keys_by_file: dict[str, list[str]] = {} + mapped_count = 0 + unmapped_count = 0 + for key in keys: + if weight_map is not None: + filename = weight_map.get(key) + if filename is None: + unmapped_count += 1 + continue + mapped_count += 1 + file_path = os.path.join(base_dir, filename) + else: + file_path = checkpoint_files[0] + keys_by_file.setdefault(file_path, []).append(key) + + print( + f"[kt_moe] Layer {layer_idx}: key mapping done in {time.time()-t0:.1f}s — " + f"total_keys={len(keys)}, mapped={mapped_count}, unmapped={unmapped_count}, " + f"files_to_open={len(keys_by_file)}", + flush=True, + ) + + t1 = time.time() + tensor_map: dict[str, torch.Tensor] = {} + for file_idx, (file_path, file_keys) in enumerate(keys_by_file.items()): + with safe_open(file_path, framework="pt") as f: + available_keys = set(f.keys()) + for key in file_keys: + if key in available_keys: + tensor_map[key] = f.get_tensor(key) + if file_idx == 0: + print( + f"[kt_moe] Layer {layer_idx}: first file loaded ({os.path.basename(file_path)}, " + f"{len(file_keys)} keys) in {time.time()-t1:.1f}s", + flush=True, + ) + + print( + f"[kt_moe] Layer {layer_idx}: all files loaded in {time.time()-t1:.1f}s — " + f"tensor_map has {len(tensor_map)} tensors", + flush=True, + ) + + gate_weights = [] + up_weights = [] + down_weights = [] + gate_scales = [] + up_scales = [] + down_scales = [] + for expert_idx in range(moe_config.expert_num): + base = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}.{expert_idx}" + gate_key = f"{base}.{gate_name}.weight" + up_key = f"{base}.{up_name}.weight" + down_key = f"{base}.{down_name}.weight" + if gate_key not in tensor_map or up_key not in tensor_map or down_key not in tensor_map: + raise FileNotFoundError(f"Missing expert weights for layer {layer_idx}, expert {expert_idx}") + gate_weights.append(tensor_map[gate_key]) + up_weights.append(tensor_map[up_key]) + down_weights.append(tensor_map[down_key]) + gate_scales.append(tensor_map.get(f"{base}.{gate_name}.weight_scale_inv")) + up_scales.append(tensor_map.get(f"{base}.{up_name}.weight_scale_inv")) + down_scales.append(tensor_map.get(f"{base}.{down_name}.weight_scale_inv")) + + # Check if weights are FP8 and need dequantization + t2 = time.time() + is_fp8 = gate_weights[0].dtype == torch.float8_e4m3fn + if is_fp8: + if block_size is None: + block_size = (128, 128) + print( + f"[kt_moe] Layer {layer_idx}: FP8 expert weights detected, " + f"dequantizing with block_size={block_size} " + f"(has_scales={gate_scales[0] is not None})", + flush=True, + ) + gate_proj = _dequant_fp8_experts(gate_weights, gate_scales, block_size) + up_proj = _dequant_fp8_experts(up_weights, up_scales, block_size) + down_proj = _dequant_fp8_experts(down_weights, down_scales, block_size) + else: + gate_proj = torch.stack(gate_weights, dim=0).cpu().to(torch.bfloat16).contiguous() + up_proj = torch.stack(up_weights, dim=0).cpu().to(torch.bfloat16).contiguous() + down_proj = torch.stack(down_weights, dim=0).cpu().to(torch.bfloat16).contiguous() + + print( + f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, shape={gate_proj.shape}, " + f"dequant={time.time()-t2:.1f}s, total={time.time()-t0:.1f}s", + flush=True, + ) + return gate_proj, up_proj, down_proj + + +def load_experts_from_kt_weight_path( + kt_weight_path: str, + layer_idx: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, +) -> INT8ExpertWeights: + """Load INT8 preprocessed expert weights from kt_weight_path for a specific layer.""" + if not SAFETENSORS_AVAILABLE: + raise ImportError("safetensors is required for loading kt_weight_path") + + index = _load_kt_weight_index(kt_weight_path) + + numa_count = 0 + test_key_prefix = f"blk.{layer_idx}.ffn_gate_exps.0.numa." + for key in index.keys(): + if key.startswith(test_key_prefix) and key.endswith(".weight"): + numa_idx = int(key.split("numa.")[1].split(".")[0]) + numa_count = max(numa_count, numa_idx + 1) + + if numa_count == 0: + raise FileNotFoundError( + f"No weights found for layer {layer_idx} in {kt_weight_path}. " + f"Expected keys like 'blk.{layer_idx}.ffn_gate_exps.0.numa.0.weight'" + ) + + logger.info( + f"Loading INT8 weights for layer {layer_idx}: {num_experts} experts, {numa_count} NUMA partitions" + ) + + gate_weights_list = [] + gate_scales_list = [] + up_weights_list = [] + up_scales_list = [] + down_weights_list = [] + down_scales_list = [] + + for expert_idx in range(num_experts): + gate_w_parts = [] + gate_s_parts = [] + for numa_idx in range(numa_count): + w_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.weight" + s_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.scale" + + if w_key not in index: + raise FileNotFoundError(f"Weight key not found: {w_key}") + + with safe_open(index[w_key], framework="pt") as f: + gate_w_parts.append(f.get_tensor(w_key)) + gate_s_parts.append(f.get_tensor(s_key)) + + gate_w = torch.cat(gate_w_parts, dim=0) + gate_s = torch.cat(gate_s_parts, dim=0) + gate_w = gate_w.view(intermediate_size, hidden_size) + + gate_weights_list.append(gate_w) + gate_scales_list.append(gate_s) + + up_w_parts = [] + up_s_parts = [] + for numa_idx in range(numa_count): + w_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.weight" + s_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.scale" + + if w_key not in index: + raise FileNotFoundError(f"Weight key not found: {w_key}") + + with safe_open(index[w_key], framework="pt") as f: + up_w_parts.append(f.get_tensor(w_key)) + up_s_parts.append(f.get_tensor(s_key)) + + up_w = torch.cat(up_w_parts, dim=0) + up_s = torch.cat(up_s_parts, dim=0) + up_w = up_w.view(intermediate_size, hidden_size) + + up_weights_list.append(up_w) + up_scales_list.append(up_s) + + down_w_parts = [] + down_s_parts = [] + for numa_idx in range(numa_count): + w_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.weight" + s_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.scale" + + if w_key not in index: + raise FileNotFoundError(f"Weight key not found: {w_key}") + + with safe_open(index[w_key], framework="pt") as f: + down_w_parts.append(f.get_tensor(w_key)) + down_s_parts.append(f.get_tensor(s_key)) + + down_w = torch.cat(down_w_parts, dim=0) + down_s = torch.cat(down_s_parts, dim=0) + down_w = down_w.view(hidden_size, intermediate_size) + + down_weights_list.append(down_w) + down_scales_list.append(down_s) + + gate_proj = torch.stack(gate_weights_list, dim=0) + gate_scale = torch.stack(gate_scales_list, dim=0) + up_proj = torch.stack(up_weights_list, dim=0) + up_scale = torch.stack(up_scales_list, dim=0) + down_proj = torch.stack(down_weights_list, dim=0) + down_scale = torch.stack(down_scales_list, dim=0) + + return INT8ExpertWeights( + gate_proj=gate_proj, + gate_scale=gate_scale, + up_proj=up_proj, + up_scale=up_scale, + down_proj=down_proj, + down_scale=down_scale, + ) diff --git a/kt-kernel/python/sft/wrapper.py b/kt-kernel/python/sft/wrapper.py new file mode 100644 index 00000000..c762d00d --- /dev/null +++ b/kt-kernel/python/sft/wrapper.py @@ -0,0 +1,610 @@ +# Model wrapping entry points for SFT +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import gc +import importlib.util as _u +import logging +import os +from typing import Any + +import torch +import torch.nn as nn + +from .arch import ( + KTAMXConfigError, + KTAMXNotAvailableError, + MOEArchConfig, + _get_layers_prefix, + _get_model_container_and_layers, + get_moe_arch_config, + get_moe_module, +) +from .layer import KTMoELayerWrapper +from .lora import LoRAExperts +from .weights import ( + _clear_original_expert_weights, + extract_moe_weights, + load_experts_from_checkpoint_files, +) + +logger = logging.getLogger(__name__) + +KT_KERNEL_AVAILABLE = _u.find_spec("kt_kernel") is not None + +if KT_KERNEL_AVAILABLE: + try: + from kt_kernel.experts import KTMoEWrapper + except Exception: + KTMoEWrapper = None + KT_KERNEL_AVAILABLE = False +else: + KTMoEWrapper = None + + +# ============================================================================= +# Device-map builders +# ============================================================================= + + +def _get_kt_config(kt_plugin: Any): + """Extract KTConfig from a KTransformersPlugin or compatible object. + + Handles three cases: + 1. KTransformersPlugin with .kt_config (new style) → return kt_config + 2. Object with old field names (kt_num_threads etc.) → convert to KTConfig + 3. KTConfig directly → return as-is + """ + from .config import KTConfig + + # New-style KTransformersPlugin + kt_config = getattr(kt_plugin, "kt_config", None) + if kt_config is not None and isinstance(kt_config, KTConfig): + return kt_config + + # Already a KTConfig + if isinstance(kt_plugin, KTConfig): + return kt_plugin + + # Old-style object (HfTrainerKTConfig, old KTransformersPlugin, dict-like) — convert + # Map old field names (kt_xxx) to new field names (xxx) + _OLD_TO_NEW = { + "kt_backend": "backend", "kt_num_threads": "num_threads", + "kt_tp_enabled": "tp_enabled", "kt_threadpool_count": "threadpool_count", + "kt_weight_path": "weight_path", "kt_expert_checkpoint_path": "expert_checkpoint_path", + "kt_num_gpu_experts": "num_gpu_experts", "kt_max_cache_depth": "max_cache_depth", + "kt_use_lora_experts": "use_lora_experts", "kt_lora_expert_num": "lora_expert_num", + "kt_lora_expert_intermediate_size": "lora_expert_intermediate_size", + "kt_skip_expert_loading": "skip_expert_loading", + "kt_share_backward_bb": "share_backward_bb", + "kt_checkpoint_files": "checkpoint_files", + "kt_sharded_metadata": "sharded_metadata", + } + kwargs = {} + for old_name, new_name in _OLD_TO_NEW.items(): + val = getattr(kt_plugin, old_name, None) + if val is not None: + kwargs[new_name] = val + # Fields that don't have kt_ prefix + for name in ("lora_rank", "lora_alpha", "model_max_length", "wrap_fn", "wrap_kwargs"): + val = getattr(kt_plugin, name, None) + if val is not None: + kwargs[name] = val + return KTConfig(**kwargs) + + +def build_kt_device_map(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]: + """ + Build device_map for KT model loading with hybrid GPU/CPU expert placement. + """ + moe_config = get_moe_arch_config(config) + layers_prefix = _get_layers_prefix(config) + num_layers = config.num_hidden_layers + num_experts = moe_config.expert_num + cfg = _get_kt_config(kt_plugin) + num_gpu_experts = getattr(cfg, "num_gpu_experts", 0) or 0 + + device_map: dict[str, str | int] = {} + + device_map["model.embed_tokens"] = device + device_map["model.norm"] = device + device_map["lm_head"] = device + + for layer_idx in range(num_layers): + layer_prefix = f"{layers_prefix}.{layer_idx}" + device_map[layer_prefix] = device + moe_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}" + + for expert_idx in range(num_experts): + expert_key = f"{moe_prefix}.{moe_config.experts_attr}.{expert_idx}" + if expert_idx < num_gpu_experts: + device_map[expert_key] = device + else: + device_map[expert_key] = "cpu" + + logger.info( + f"Built KT device_map: {num_gpu_experts} GPU experts, {num_experts - num_gpu_experts} CPU experts" + ) + + return device_map + + +def build_kt_device_map_simplified(config, kt_plugin, device: str = "cuda:0") -> dict[str, str | int]: + """ + Simplified device_map builder: map full layers to GPU, override routed experts to CPU. + """ + moe_config = get_moe_arch_config(config) + layers_prefix = _get_layers_prefix(config) + num_layers = config.num_hidden_layers + cfg = _get_kt_config(kt_plugin) + num_gpu_experts = getattr(cfg, "num_gpu_experts", 0) or 0 + + device_map: dict[str, str | int] = {} + + device_map["model.embed_tokens"] = device + device_map["model.norm"] = device + device_map["lm_head"] = device + + for layer_idx in range(num_layers): + layer_prefix = f"{layers_prefix}.{layer_idx}" + device_map[layer_prefix] = device + + experts_prefix = f"{layer_prefix}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}" + + if num_gpu_experts == 0: + device_map[experts_prefix] = "cpu" + else: + return build_kt_device_map(config, kt_plugin, device=device) + + logger.info("Built simplified KT device_map: all layers on GPU, routed experts on CPU") + return device_map + + +# ============================================================================= +# MoE layer wrapping +# ============================================================================= + + +def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KTMoELayerWrapper]: + """ + Replace model's MoE layers with KTMoEWrapper-based wrappers. + + Loads expert weights into the C++ KT kernel. No LoRA initialization --- + LoRA is handled by PEFT and later adapted via kt_adapt_peft_lora(). + Only rank 0 initializes KT kernel and loads weights. + """ + import torch.distributed as dist + + if not KT_KERNEL_AVAILABLE: + raise KTAMXNotAvailableError("kt_kernel not found. Please install kt_kernel to enable KT MoE support.") + + # Only rank 0 should initialize KT and load weights + is_rank_0 = True + if dist.is_initialized(): + is_rank_0 = dist.get_rank() == 0 + + moe_config = get_moe_arch_config(model.config) + hidden_size = model.config.hidden_size + + cfg = _get_kt_config(kt_plugin) + + # Read lora_rank/lora_alpha for C++ wrapper initialization (buffer allocation only) + lora_rank = getattr(cfg, "lora_rank", 1) or 1 + lora_alpha = getattr(cfg, "lora_alpha", 1.0) or 1.0 + + # Read LoRA Experts configuration + _raw_le = getattr(cfg, "use_lora_experts", None) + use_lora_experts = bool(_raw_le) if _raw_le is not None else False + lora_expert_num = getattr(cfg, "lora_expert_num", 2) or 2 + lora_expert_intermediate_size = getattr(cfg, "lora_expert_intermediate_size", 1024) or 1024 + + if is_rank_0: + logger.info( + f"LoRA Experts config: use_lora_experts={use_lora_experts}, " + f"num={lora_expert_num}, intermediate_size={lora_expert_intermediate_size}" + ) + + wrappers: list[KTMoELayerWrapper] = [] + moe_layer_count = 0 + + kt_backend_map = { + "AMXBF16": "AMXBF16_SFT", + "AMXINT8": "AMXINT8_SFT", + "AMXINT4": "AMXINT4_SFT", + "AMXBF16_SkipLoRA": "AMXBF16_SFT_SkipLoRA", + "AMXINT8_SkipLoRA": "AMXINT8_SFT_SkipLoRA", + "AMXINT4_SkipLoRA": "AMXINT4_SFT_SkipLoRA", + } + # Build case-insensitive lookup to handle common typos like "SkipLora" vs "SkipLoRA" + _kt_backend_map_lower = {k.lower(): v for k, v in kt_backend_map.items()} + kt_backend = getattr(cfg, "backend", "AMXBF16") + kt_method = kt_backend_map.get(kt_backend) or _kt_backend_map_lower.get(kt_backend.lower(), "AMXBF16_SFT") + if kt_method != kt_backend_map.get(kt_backend): + logger.warning( + f"kt_backend '{kt_backend}' matched via case-insensitive lookup -> '{kt_method}'. " + f"Please use the exact name from: {list(kt_backend_map.keys())}" + ) + + if "SkipLoRA" in kt_method: + logger.info(f"Using SkipLoRA backend: {kt_method} (MoE LoRA gradients will be skipped)") + + threadpool_count = getattr(cfg, "threadpool_count", 1) if getattr(cfg, "tp_enabled", False) else 1 + + kt_weight_path = getattr(cfg, "weight_path", None) + use_kt_weight_path = kt_weight_path is not None + if use_kt_weight_path: + logger.info(f"Loading INT8 weights from kt_weight_path: {kt_weight_path}") + + checkpoint_files = getattr(cfg, "checkpoint_files", None) + sharded_metadata = getattr(cfg, "sharded_metadata", None) + + # When kt_expert_checkpoint_path is set, always resolve from it (overrides any existing + # checkpoint_files which may come from AttnOnlyBf16 and lack expert weights). + kt_expert_checkpoint_path = getattr(cfg, "expert_checkpoint_path", None) + if kt_expert_checkpoint_path: + logger.info(f"Resolving expert checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}") + resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=kt_expert_checkpoint_path) + if resolved_files and all(f.endswith(".safetensors") for f in resolved_files): + checkpoint_files = resolved_files + sharded_metadata = resolved_meta + cfg.checkpoint_files = checkpoint_files + cfg.sharded_metadata = sharded_metadata + logger.info(f"Resolved {len(checkpoint_files)} checkpoint files from kt_expert_checkpoint_path") + else: + logger.warning(f"Failed to resolve checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}") + + use_checkpoint_files = bool(checkpoint_files) and not use_kt_weight_path + + logger.debug( + f"Weight source: kt_weight_path={kt_weight_path!r}, " + f"kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}, " + f"checkpoint_files count={len(checkpoint_files) if checkpoint_files else 0}, " + f"use_kt_weight_path={use_kt_weight_path}, use_checkpoint_files={use_checkpoint_files}" + ) + + if use_checkpoint_files: + logger.info("Loading expert weights from checkpoint files (online conversion).") + elif use_kt_weight_path and bool(checkpoint_files): + logger.info("BF16 checkpoint files available for backward gradient computation.") + elif (not use_kt_weight_path) and bool(getattr(cfg, "skip_expert_loading", False)): + # If HF expert weights were skipped during `from_pretrained`, we must source expert weights externally. + model_name_or_path = getattr(getattr(model, "config", None), "name_or_path", None) + if model_name_or_path: + resolved_files, resolved_meta = _resolve_checkpoint_files(model_name_or_path=model_name_or_path) + if resolved_files and all(f.endswith(".safetensors") for f in resolved_files): + checkpoint_files = resolved_files + sharded_metadata = resolved_meta + cfg.checkpoint_files = checkpoint_files + cfg.sharded_metadata = sharded_metadata + use_checkpoint_files = True + logger.info("KT skip_expert_loading enabled; using checkpoint files for online expert loading.") + + if not use_checkpoint_files: + raise KTAMXConfigError( + "KT skip_expert_loading is enabled but no `kt_weight_path` was provided and no safetensors checkpoint " + "files could be resolved for on-the-fly expert loading." + ) + + import torch.distributed as _dist + _rank = _dist.get_rank() if _dist.is_initialized() else 0 + + model_container, layers = _get_model_container_and_layers(model, purpose="wrapping") + logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}") + + for layer_idx, layer in enumerate(layers): + moe_module = get_moe_module(layer, moe_config) + if moe_module is None: + continue + + logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method})") + + # Only rank 0 loads weights and initializes KT kernel + gate_proj, up_proj, down_proj = None, None, None + wrapper = None + + if is_rank_0: + # Get block_size from quantization_config if available (for FP8 dequant) + _quant_cfg = getattr(model.config, "quantization_config", None) + _block_size = None + if _quant_cfg is not None: + _block_size = getattr(_quant_cfg, "weight_block_size", None) + + if use_kt_weight_path: + logger.debug(f"Layer {layer_idx}: forward + backward from kt_weight_path (.kt files)") + elif use_checkpoint_files: + layers_prefix = _get_layers_prefix(model.config) + gate_proj, up_proj, down_proj = load_experts_from_checkpoint_files( + checkpoint_files=checkpoint_files, + sharded_metadata=sharded_metadata, + layers_prefix=layers_prefix, + moe_config=moe_config, + layer_idx=layer_idx, + block_size=_block_size, + ) + else: + gate_proj, up_proj, down_proj = extract_moe_weights(moe_module, moe_config) + gate_proj = gate_proj.cpu().to(torch.bfloat16).contiguous() + up_proj = up_proj.cpu().to(torch.bfloat16).contiguous() + down_proj = down_proj.cpu().to(torch.bfloat16).contiguous() + + chunked_prefill_size = getattr(cfg, "model_max_length", None) + if chunked_prefill_size is None: + chunked_prefill_size = getattr(model.config, "max_position_embeddings", 4096) + + # Only rank 0 creates KTMoEWrapper and loads weights + if is_rank_0: + wrapper = KTMoEWrapper( + layer_idx=layer_idx, + num_experts=moe_config.expert_num, + num_experts_per_tok=moe_config.num_experts_per_tok, + hidden_size=hidden_size, + moe_intermediate_size=moe_config.intermediate_size, + num_gpu_experts=0, + cpuinfer_threads=getattr(cfg, "num_threads", 1), + threadpool_count=threadpool_count, + weight_path=kt_weight_path or "", + chunked_prefill_size=chunked_prefill_size, + method=kt_method, + mode="sft", + lora_rank=lora_rank, + lora_alpha=lora_alpha, + max_cache_depth=getattr(cfg, "max_cache_depth", 2), + ) + + # Set share_backward_bb BEFORE load_weights (config is built during load) + share_backward_bb = getattr(cfg, "share_backward_bb", None) + if share_backward_bb is None: + share_backward_bb = os.environ.get("ACCELERATE_KT_SHARE_BACKWARD_BB", "").lower() in ("true", "1", "yes") + wrapper.share_backward_bb = share_backward_bb + + physical_to_logical_map = torch.arange(moe_config.expert_num, dtype=torch.int64, device="cpu") + + if use_kt_weight_path: + logger.debug(f"Layer {layer_idx}: calling wrapper.load_weights() (C++ direct .kt load)") + wrapper.load_weights(physical_to_logical_map) + else: + logger.debug( + f"Layer {layer_idx}: calling wrapper.load_weights_from_tensors() " + f"(BF16 tensor path, gate_proj shape={gate_proj.shape if gate_proj is not None else None})" + ) + wrapper.load_weights_from_tensors( + gate_proj=gate_proj, + up_proj=up_proj, + down_proj=down_proj, + physical_to_logical_map_cpu=physical_to_logical_map, + ) + + wrapper.gate_proj = None + wrapper.up_proj = None + wrapper.down_proj = None + + # Create LoRA Experts if enabled + lora_experts = None + if use_lora_experts: + lora_experts = LoRAExperts( + num_experts=lora_expert_num, + hidden_size=hidden_size, + intermediate_size=lora_expert_intermediate_size, + device="cuda", + dtype=torch.bfloat16, + ) + + layer_wrapper = KTMoELayerWrapper( + original_moe=moe_module, + wrapper=wrapper, + lora_params=None, + moe_config=moe_config, + hidden_size=hidden_size, + layer_idx=layer_idx, + lora_experts=lora_experts, + ) + layer_wrapper._skip_lora = "SkipLoRA" in kt_method + + setattr(layer, moe_config.moe_layer_attr, layer_wrapper) + # Base weights have been copied into the C++ kernel's internal BufferB format. + # Do not hold a Python-side reference --- it wastes ~1 GB/layer. + del gate_proj, up_proj, down_proj + + wrappers.append(layer_wrapper) + moe_layer_count += 1 + + # Replace original expert weights with meta placeholders. + # Experts remain in the model tree (via wrapper.experts) so PEFT can discover them. + # Rank 0 already copied weights to C++ kernel via load_weights_from_tensors. + _clear_original_expert_weights(moe_module, moe_config) + + logger.info(f"Wrapped {moe_layer_count} MoE layers with KTMoEWrapper") + + # Link wrappers for async backward repack (higher layer triggers repack for lower) + for i in range(1, len(wrappers)): + if wrappers[i].wrapper is not None and wrappers[i - 1].wrapper is not None: + wrappers[i].wrapper._next_backward_wrapper = wrappers[i - 1].wrapper + if wrappers and wrappers[0].wrapper is not None: + wrappers[0].wrapper._next_backward_wrapper = None + + gc.collect() + return wrappers + + +# ============================================================================= +# Plugin builder +# ============================================================================= + + +def _build_kt_plugin_from_args(model_args: Any, finetuning_args: Any | None = None): + """ + Build a KTransformersPlugin from model_args and optional finetuning_args. + + Imported here to avoid circular dependency --- callers that need the plugin + class should import it from the appropriate dataclasses module. + """ + from .config import KTConfig + from accelerate.utils.dataclasses import KTransformersPlugin + + kt_config = KTConfig( + backend=getattr(model_args, "kt_backend", None), + num_threads=getattr(model_args, "kt_num_threads", None), + tp_enabled=getattr(model_args, "kt_tp_enabled", None), + threadpool_count=getattr(model_args, "kt_threadpool_count", None), + max_cache_depth=getattr(model_args, "kt_max_cache_depth", None), + num_gpu_experts=getattr(model_args, "kt_num_gpu_experts", None), + weight_path=getattr(model_args, "kt_weight_path", None), + expert_checkpoint_path=getattr(model_args, "kt_expert_checkpoint_path", None), + use_lora_experts=getattr(model_args, "kt_use_lora_experts", None), + lora_expert_num=getattr(model_args, "kt_lora_expert_num", None), + lora_expert_intermediate_size=getattr(model_args, "kt_lora_expert_intermediate_size", None), + lora_rank=getattr(finetuning_args, "lora_rank", None) if finetuning_args else None, + lora_alpha=getattr(finetuning_args, "lora_alpha", None) if finetuning_args else None, + model_max_length=getattr(model_args, "model_max_length", None), + ) + return KTransformersPlugin(enabled=True, kt_config=kt_config) + + +def get_kt_loading_kwargs( + config, + kt_plugin, + torch_dtype: torch.dtype | str | None = torch.bfloat16, + trust_remote_code: bool | None = None, + token: str | None = None, +) -> dict[str, Any]: + """Get kwargs for AutoModel.from_pretrained() for KT loading.""" + kwargs: dict[str, Any] = { + "config": config, + "torch_dtype": torch_dtype, + "device_map": "cpu", + "low_cpu_mem_usage": True, + } + if trust_remote_code is not None: + kwargs["trust_remote_code"] = trust_remote_code + if token is not None: + kwargs["token"] = token + return kwargs + + +def _resolve_checkpoint_files( + model_name_or_path: str, + cache_dir: str | None = None, + revision: str | None = None, + token: str | None = None, + trust_remote_code: bool | None = None, +) -> tuple[list[str] | None, dict | None]: + """Resolve HF checkpoint files. Depends on transformers internals.""" + try: + from transformers.modeling_utils import _get_resolved_checkpoint_files + except Exception: + return None, None + try: + checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files( + pretrained_model_name_or_path=model_name_or_path, + subfolder="", + variant=None, + gguf_file=None, + from_tf=False, + from_flax=False, + use_safetensors=None, + cache_dir=cache_dir, + force_download=False, + proxies=None, + local_files_only=False, + token=token, + user_agent={"file_type": "model", "framework": "pytorch"}, + revision=revision or "main", + commit_hash=None, + is_remote_code=bool(trust_remote_code), + transformers_explicit_filename=None, + ) + except Exception: + return None, None + return checkpoint_files, sharded_metadata + + +def load_kt_model( + config, + model_args: Any | None = None, + finetuning_args: Any | None = None, + kt_plugin=None, + model_name_or_path: str | None = None, + trust_remote_code: bool | None = None, + token: str | None = None, + torch_dtype: torch.dtype | str | None = torch.bfloat16, + **kwargs, +) -> nn.Module: + """Load model with KTMoEWrapper backend.""" + from .arch import get_moe_arch_config, move_non_experts_to_gpu, get_expert_device, KTAMXNotAvailableError, KTAMXConfigError + + if kt_plugin is None: + if model_args is None: + raise KTAMXConfigError("Either kt_plugin or model_args must be provided to load_kt_model().") + kt_plugin = _build_kt_plugin_from_args(model_args, finetuning_args) + + if model_name_or_path is None and model_args is not None: + model_name_or_path = getattr(model_args, "model_name_or_path", None) + if model_name_or_path is None: + raise KTAMXConfigError("model_name_or_path is required to load_kt_model().") + + if trust_remote_code is None and model_args is not None: + trust_remote_code = getattr(model_args, "trust_remote_code", None) + if token is None and model_args is not None: + token = getattr(model_args, "hf_hub_token", None) + cache_dir = getattr(model_args, "cache_dir", None) if model_args is not None else None + revision = getattr(model_args, "revision", None) if model_args is not None else None + + _ = get_moe_arch_config(config) + + logger.info("Loading model with KTMoEWrapper backend") + + from transformers import AutoModelForCausalLM + from transformers.integrations.kt import set_kt_config, unset_kt_config + + loading_kwargs = get_kt_loading_kwargs( + config, kt_plugin, torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, token=token, + ) + if model_args is not None: + for key in ("cache_dir", "revision"): + value = getattr(model_args, key, None) + if value is not None: + loading_kwargs[key] = value + loading_kwargs.update(kwargs) + + cfg = _get_kt_config(kt_plugin) + + if getattr(cfg, "skip_expert_loading", None) is None: + checkpoint_files, sharded_metadata = _resolve_checkpoint_files( + model_name_or_path=model_name_or_path, + cache_dir=cache_dir, revision=revision, + token=token, trust_remote_code=trust_remote_code, + ) + if checkpoint_files and all(f.endswith(".safetensors") for f in checkpoint_files): + if getattr(cfg, "weight_path", None) is None: + cfg.skip_expert_loading = True + else: + cfg.skip_expert_loading = False + cfg.checkpoint_files = checkpoint_files + cfg.sharded_metadata = sharded_metadata + else: + cfg.skip_expert_loading = False + + set_kt_config(kt_plugin) + try: + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **loading_kwargs) + finally: + unset_kt_config() + + moe_config = get_moe_arch_config(config) + move_non_experts_to_gpu(model, moe_config, device="cuda:0") + + existing_wrappers = getattr(model, "_kt_wrappers", None) + if existing_wrappers: + logger.info(f"MoE layers already wrapped ({len(existing_wrappers)} layers), skipping re-wrap") + wrappers = existing_wrappers + else: + wrappers = wrap_moe_layers_with_kt_wrapper(model, kt_plugin) + + model._kt_wrappers = wrappers + model._kt_tp_enabled = bool(getattr(cfg, "tp_enabled", False)) + model._kt_use_lora_experts = bool(getattr(cfg, "use_lora_experts", False)) + + logger.info("Model loaded with KTMoEWrapper backend successfully") + return model diff --git a/kt-kernel/python/utils/amx_sft.py b/kt-kernel/python/utils/amx_sft.py deleted file mode 100644 index 52e7a2cc..00000000 --- a/kt-kernel/python/utils/amx_sft.py +++ /dev/null @@ -1,1162 +0,0 @@ -# AMX SFT MoE Wrapper classes for CPU-based fine-tuning operations -# SPDX-License-Identifier: Apache-2.0 - -""" -AMX-based SFT MoE Wrapper implementation. - -Supports quantization methods: -- AMXBF16_SFT: BF16 precision training -- AMXINT8_SFT: INT8 quantization training -- AMXINT4_SFT: INT4 quantization training -- AMXINT4_KGroup_SFT: INT4 K-Group quantization training (AWQ/K2) -""" - -import ctypes -import os -from datetime import datetime -import torch -from typing import Dict, Tuple, Optional, List - -from kt_kernel_ext.moe import MOESFTConfig - -from .loader import BF16SafeTensorLoader, SafeTensorLoader - -try: - from kt_kernel_ext.moe import ( - AMXBF16_SFT_MOE, - AMXInt8_SFT_MOE, - AMXInt4_SFT_MOE, - # AMXInt4_1_SFT_MOE, - # AMXInt4_1KGroup_SFT_MOE, - # AMXInt4_KGroup_SFT_MOE, - # SkipLoRA variants (skip all LoRA computation in backward) - AMXBF16_SFT_MOE_SkipLoRA, - AMXInt8_SFT_MOE_SkipLoRA, - AMXInt4_SFT_MOE_SkipLoRA, - # AMXInt4_1_SFT_MOE_SkipLoRA, - # AMXInt4_1KGroup_SFT_MOE_SkipLoRA, - # AMXInt4_KGroup_SFT_MOE_SkipLoRA, - ) - - _HAS_AMX_SFT_SUPPORT = True -except (ImportError, AttributeError): - _HAS_AMX_SFT_SUPPORT = False - AMXBF16_SFT_MOE = None - AMXInt8_SFT_MOE = None - AMXInt4_SFT_MOE = None - # AMXInt4_1_SFT_MOE = None - # AMXInt4_1KGroup_SFT_MOE = None - # AMXInt4_KGroup_SFT_MOE = None - # SkipLoRA variants - AMXBF16_SFT_MOE_SkipLoRA = None - AMXInt8_SFT_MOE_SkipLoRA = None - AMXInt4_SFT_MOE_SkipLoRA = None - # AMXInt4_1_SFT_MOE_SkipLoRA = None - # AMXInt4_1KGroup_SFT_MOE_SkipLoRA = None - # AMXInt4_KGroup_SFT_MOE_SkipLoRA = None - -from ..experts_sft import BaseSFTMoEWrapper, KExpertsSFTBuffer - -SFT_LIFECYCLE_LOG = os.environ.get("ACCELERATE_KT_LIFECYCLE_LOG", "0") == "1" - - -def _sft_lifecycle_log(tag: str, **stats) -> None: - if not SFT_LIFECYCLE_LOG: - return - try: - rank = int(os.environ.get("RANK", "0")) - except Exception: - rank = 0 - path_tpl = os.environ.get("ACCELERATE_KT_MEM_LOG_FILE", "kt_mem_rank{rank}.log") - path = path_tpl.format(rank=rank) - try: - pieces = [] - for k, v in stats.items(): - if isinstance(v, int) and ("bytes" in k or "nbytes" in k): - pieces.append(f"{k}={v/1024/1024:.2f}MB") - else: - pieces.append(f"{k}={v}") - line = ( - f"{datetime.now().isoformat()} pid={os.getpid()} rank={rank} " f"tag=sft_{tag} " + " ".join(pieces) + "\n" - ) - with open(path, "a", encoding="utf-8") as f: - f.write(line) - except Exception: - return - - -# Mapping from method string to C++ SFT MOE class -_SFT_METHOD_TO_CLASS = { - "AMXBF16_SFT": AMXBF16_SFT_MOE, - "AMXINT8_SFT": AMXInt8_SFT_MOE, - "AMXINT4_SFT": AMXInt4_SFT_MOE, - # "AMXINT4_1_SFT": AMXInt4_1_SFT_MOE, - # "AMXINT4_KGroup_SFT": AMXInt4_KGroup_SFT_MOE, - # "AMXINT4_1KGroup_SFT": AMXInt4_1KGroup_SFT_MOE, - # SkipLoRA variants (skip all LoRA computation in backward, only compute base weight grad_input) - "AMXBF16_SFT_SkipLoRA": AMXBF16_SFT_MOE_SkipLoRA, - "AMXINT8_SFT_SkipLoRA": AMXInt8_SFT_MOE_SkipLoRA, - "AMXINT4_SFT_SkipLoRA": AMXInt4_SFT_MOE_SkipLoRA, - # "AMXINT4_1_SFT_SkipLoRA": AMXInt4_1_SFT_MOE_SkipLoRA, - # "AMXINT4_KGroup_SFT_SkipLoRA": AMXInt4_KGroup_SFT_MOE_SkipLoRA, - # "AMXINT4_1KGroup_SFT_SkipLoRA": AMXInt4_1KGroup_SFT_MOE_SkipLoRA, -} - - -class AMXSFTMoEWrapper(BaseSFTMoEWrapper): - """ - AMX-based SFT MoE wrapper implementation. - - Supports BF16, INT8, INT4, and INT4 K-Group quantization methods - for supervised fine-tuning with LoRA adapters. - - Design Note (forward_sft vs forward): - forward_sft() is implemented independently from inference forward() because: - 1. Different requirements: inference optimizes for latency, SFT requires gradient correctness - 2. Safety: inference optimizations (deferred experts, async execution) would break SFT gradients - 3. Most reusable optimizations are already in C++ layer (via inheritance) - 4. Manual copying of useful optimizations is safer and more maintainable - """ - - def __init__( - self, - layer_idx: int, - num_experts: int, - num_experts_per_tok: int, - hidden_size: int, - moe_intermediate_size: int, - num_gpu_experts: int, - cpuinfer_threads: int, - threadpool_count: int, - weight_path: str, - chunked_prefill_size: int, - # SFT-specific parameters - lora_rank: int = 16, - lora_alpha: float = 32.0, - max_cache_depth: int = 1, - method: str = "AMXBF16_SFT", - # Quantization config (for K-Group methods) - group_size: int = 128, - zero_point: bool = True, - ): - """ - Initialize AMX SFT MoE Wrapper. - - Args: - layer_idx: Layer index - num_experts: Total number of experts - num_experts_per_tok: Number of experts per token (top-k) - hidden_size: Hidden dimension size - moe_intermediate_size: MoE intermediate size - num_gpu_experts: Number of experts on GPU (usually 0 for SFT) - cpuinfer_threads: Number of CPU inference threads - threadpool_count: Number of NUMA subpools (TP count) - weight_path: Path to weights - chunked_prefill_size: Maximum prefill chunk size - lora_rank: LoRA rank (r) - lora_alpha: LoRA scaling factor (alpha) - max_cache_depth: Maximum forward cache depth - method: AMX quantization method for SFT - group_size: Quantization group size (for K-Group methods) - zero_point: Whether to use zero point quantization (for K-Group methods) - """ - if not _HAS_AMX_SFT_SUPPORT: - raise RuntimeError( - "AMX SFT backend not available. kt_kernel_ext was not compiled with AMX SFT support.\n" - "Please recompile with AMX SFT enabled." - ) - - # Initialize base class - super().__init__( - layer_idx=layer_idx, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - hidden_size=hidden_size, - moe_intermediate_size=moe_intermediate_size, - num_gpu_experts=num_gpu_experts, - cpuinfer_threads=cpuinfer_threads, - threadpool_count=threadpool_count, - weight_path=weight_path, - chunked_prefill_size=chunked_prefill_size, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - max_cache_depth=max_cache_depth, - ) - - # Store method and quantization config - self.method = method - self._is_skip_lora = "SkipLoRA" in method - self.group_size = group_size - self.zero_point = zero_point - - # Dedicated CUDA stream for GPU→CPU grad_output copy. - # Avoids cuda.synchronize() which blocks on all pending GPU work. - self._copy_stream = None # lazily created on first CUDA backward - - # Validate method - if method not in _SFT_METHOD_TO_CLASS: - raise ValueError( - f"Unknown SFT method: {method}. " f"Supported methods: {list(_SFT_METHOD_TO_CLASS.keys())}" - ) - - # Get the C++ class for this method - moe_class = _SFT_METHOD_TO_CLASS[method] - if moe_class is None: - raise RuntimeError(f"AMX SFT method '{method}' not available in current build.") - - # Base weight storage (set via load_weights_from_tensors or loaded from file) - self.gate_proj: Optional[torch.Tensor] = None - self.up_proj: Optional[torch.Tensor] = None - self.down_proj: Optional[torch.Tensor] = None - - # MoE instance will be created during load_weights - self._moe_class = moe_class - - def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: - """ - Load base weights for this layer. - - Supports two loading modes: - 1. From tensors: Call load_weights_from_tensors() first, then load_weights() - 2. From files: Automatically load from weight_path if base weights not set - - AMXBF16_SFT: Uses BF16SafeTensorLoader (HuggingFace format) - - AMXINT8_SFT/AMXINT4_SFT: Uses SafeTensorLoader (pre-quantized format) - - Args: - physical_to_logical_map_cpu: Mapping from physical to logical expert IDs - """ - if self._weights_loaded: - return - - # If base weights not set, try to load from file - if self.gate_proj is None and not getattr(self, "_use_projs_path", False): - self._load_base_weights_from_file() - - # Create MOE SFT config - config = MOESFTConfig() - config.expert_num = self.num_experts - config.num_experts_per_tok = self.num_experts_per_tok - config.hidden_size = self.hidden_size - config.intermediate_size = self.moe_intermediate_size - config.lora_rank = self.lora_rank - config.lora_alpha = self.lora_alpha - config.max_cache_depth = self.max_cache_depth - config.max_len = self.chunked_prefill_size - config.layer_idx = self.layer_idx - config.share_backward_bb = getattr(self, "share_backward_bb", False) - config.share_cache_pool = getattr(self, "share_cache_pool", False) - print( - f"[amx_sft] layer {self.layer_idx}: share_backward_bb={config.share_backward_bb}, " - f"share_cache_pool={config.share_cache_pool}, " - f"attr={getattr(self, 'share_backward_bb', 'MISSING')}", - flush=True, - ) - - # Set base weight pointers - if getattr(self, "_use_kt_direct_load", False): - # .kt directory: C++ reads forward + backward .kt files directly. - # Do NOT set config.gate_proj here — that would trigger the BF16 online-quant - # branch in moe-sft-tp.hpp instead of the .kt file-read branch. - config.load = True - config.path = self.weight_path - elif getattr(self, "_use_projs_path", False): - # Pre-quantized per-NUMA per-expert path (INT8/INT4) - config.gate_projs = self._gate_projs_ptrs - config.up_projs = self._up_projs_ptrs - config.down_projs = self._down_projs_ptrs - config.gate_scales = self._gate_scale_ptrs - config.up_scales = self._up_scale_ptrs - config.down_scales = self._down_scale_ptrs - # Also provide BF16 weight pointers for backward gradient computation. - # C++ backward needs BF16 base weights to compute gate/up LoRA B gradients - # through the gated MLP chain (grad_hidden = down_proj^T @ grad_output). - if getattr(self, "_bf16_gate_proj", None) is not None: - config.gate_proj = self._bf16_gate_proj.data_ptr() - config.up_proj = self._bf16_up_proj.data_ptr() - config.down_proj = self._bf16_down_proj.data_ptr() - # Set pre-quantized backward weight pointers if available - if getattr(self, "_has_bwd_projs", False): - config.gate_bwd_projs = self._gate_bwd_projs_ptrs - config.up_bwd_projs = self._up_bwd_projs_ptrs - config.down_bwd_projs = self._down_bwd_projs_ptrs - config.gate_bwd_scales = self._gate_bwd_scale_ptrs - config.up_bwd_scales = self._up_bwd_scale_ptrs - config.down_bwd_scales = self._down_bwd_scale_ptrs - else: - # Flat BF16 buffer path - config.gate_proj = self.gate_proj.data_ptr() - config.up_proj = self.up_proj.data_ptr() - config.down_proj = self.down_proj.data_ptr() - - # Set LoRA weight pointers (if initialized) - if self._lora_initialized: - config.gate_lora_a = self.gate_lora_a.data_ptr() - config.gate_lora_b = self.gate_lora_b.data_ptr() - config.up_lora_a = self.up_lora_a.data_ptr() - config.up_lora_b = self.up_lora_b.data_ptr() - config.down_lora_a = self.down_lora_a.data_ptr() - config.down_lora_b = self.down_lora_b.data_ptr() - - # Set thread pool - config.pool = self.cpu_infer.backend_ - - # Set quantization config for K-Group methods - if self.method in ("AMXINT4_KGroup_SFT", "AMXINT4_1KGroup_SFT"): - config.quant_config.group_size = self.group_size - config.quant_config.zero_point = self.zero_point - - # Create MoE instance - self.moe = self._moe_class(config) - - # Load weights - self.cpu_infer.submit(self.moe.load_weights_task()) - self.cpu_infer.sync() - - # Warm up - self.cpu_infer.submit(self.moe.warm_up_task()) - self.cpu_infer.sync() - - # Release Python-side base weight tensors. C++ has already copied/transformed - # them into internal BufferB format (backward_bb_pool_) and no longer needs - # the original bf16 data. Holding these wastes ~1 GB/layer. - self.gate_proj = None - self.up_proj = None - self.down_proj = None - - if getattr(self, "_bf16_gate_proj", None) is not None: - self._bf16_gate_proj = None - self._bf16_up_proj = None - self._bf16_down_proj = None - - # Release pre-quantized per-NUMA numpy arrays. C++ has already copied - # them into internal BufferB format via memcpy in load_weights(). - if getattr(self, "_use_projs_path", False): - self._gate_weights_per_numa = None - self._up_weights_per_numa = None - self._down_weights_per_numa = None - self._gate_scales_per_numa = None - self._up_scales_per_numa = None - self._down_scales_per_numa = None - self._gate_projs_ptrs = None - self._up_projs_ptrs = None - self._down_projs_ptrs = None - self._gate_scale_ptrs = None - self._up_scale_ptrs = None - self._down_scale_ptrs = None - # Release backward weight arrays - if getattr(self, "_has_bwd_projs", False): - self._gate_bwd_weights_per_numa = None - self._up_bwd_weights_per_numa = None - self._down_bwd_weights_per_numa = None - self._gate_bwd_scales_per_numa = None - self._up_bwd_scales_per_numa = None - self._down_bwd_scales_per_numa = None - self._gate_bwd_projs_ptrs = None - self._up_bwd_projs_ptrs = None - self._down_bwd_projs_ptrs = None - self._gate_bwd_scale_ptrs = None - self._up_bwd_scale_ptrs = None - self._down_bwd_scale_ptrs = None - - self._weights_loaded = True - - def load_weights_from_tensors( - self, - gate_proj: torch.Tensor, - up_proj: torch.Tensor, - down_proj: torch.Tensor, - physical_to_logical_map_cpu: torch.Tensor, - ) -> None: - """ - Load weights from BF16/FP16 tensors. - - This is the recommended way to load weights for SFT, as it supports - online quantization from full-precision weights. - - Args: - gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size] - up_proj: Up projection weights [num_experts, intermediate_size, hidden_size] - down_proj: Down projection weights [num_experts, hidden_size, intermediate_size] - physical_to_logical_map_cpu: Mapping from physical to logical expert IDs - """ - # Store tensors as instance variables to keep them alive - self.gate_proj = gate_proj.contiguous() - self.up_proj = up_proj.contiguous() - self.down_proj = down_proj.contiguous() - - # Now load weights - self.load_weights(physical_to_logical_map_cpu) - - del gate_proj - del up_proj - del down_proj - - def _load_base_weights_from_file(self) -> None: - """ - Load base MoE weights from file based on the SFT method. - - Loading strategy: - - .kt directory structure: Let C++ read .kt files directly (fastest) - - AMXBF16_SFT: Use BF16SafeTensorLoader (HuggingFace format, no scales) - - AMXINT8_SFT/AMXINT4_SFT: Use SafeTensorLoader (pre-quantized format with scales) - """ - if not hasattr(self, "weight_path") or self.weight_path is None: - raise RuntimeError( - "weight_path not set. Cannot load weights from file. " - "Either set weight_path or call load_weights_from_tensors() instead." - ) - - # Check if weight_path contains .kt directory structure (_layer_*/_numa_*/*.kt) - # If so, skip Python loading entirely — C++ reads .kt files directly via config.load - import os, glob as _glob - - kt_layer_dir = os.path.join(self.weight_path, f"_layer_{self.layer_idx}") - if os.path.isdir(kt_layer_dir): - kt_files = _glob.glob(os.path.join(kt_layer_dir, "_numa_0", "*.kt")) - if kt_files: - print( - f"[AMXSFTMoEWrapper] Detected .kt directory for layer {self.layer_idx}, " - f"C++ will load directly from {self.weight_path}" - ) - self._use_kt_direct_load = True - return - - print( - f"[AMXSFTMoEWrapper] Loading base weights for layer {self.layer_idx} " - f"from {self.weight_path} using method {self.method}" - ) - - # Determine loader and base key format based on method - if "BF16" in self.method: - # BF16 mode: Load from HuggingFace model path - loader = BF16SafeTensorLoader(self.weight_path) - base_key = f"model.layers.{self.layer_idx}" - else: - # INT8/INT4 mode: Load from pre-quantized path - # Note: SafeTensorLoader expects GGUF-style naming (blk.X) - loader = SafeTensorLoader(self.weight_path) - base_key = f"blk.{self.layer_idx}" - - # Load expert weights - experts_data = loader.load_experts(base_key, device="cpu") - - # Extract weights (list of tensors per expert -> stacked tensor) - gate_weights: List[torch.Tensor] = experts_data["gate"] - up_weights: List[torch.Tensor] = experts_data["up"] - down_weights: List[torch.Tensor] = experts_data["down"] - - # Stack expert weights: [num_experts, ...] - # For BF16: weights are already tensors - # For SafeTensorLoader: weights might be numpy arrays in nested lists - if "BF16" in self.method: - # BF16SafeTensorLoader returns list of tensors - self.gate_proj = torch.stack(gate_weights, dim=0).contiguous() - self.up_proj = torch.stack(up_weights, dim=0).contiguous() - self.down_proj = torch.stack(down_weights, dim=0).contiguous() - else: - # SafeTensorLoader returns nested lists [numa_id][expert_id] -> numpy array - # Keep per-NUMA per-expert arrays for gate_projs/gate_scales path - import numpy as np - - num_numa = len(gate_weights) - - # Store raw per-NUMA per-expert numpy arrays (keep references alive) - self._gate_weights_per_numa = gate_weights # [numa_id][expert_id] -> np array - self._up_weights_per_numa = up_weights - self._down_weights_per_numa = down_weights - self._gate_scales_per_numa = experts_data["gate_scale"] - self._up_scales_per_numa = experts_data["up_scale"] - self._down_scales_per_numa = experts_data["down_scale"] - - # Build pointer arrays: [[ptr_expert_0, ptr_expert_1, ...], ...] per NUMA - def _make_ptrs(arrays_per_numa): - return [ - [ - ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents) - for et in numa_array - ] - for numa_array in arrays_per_numa - ] - - self._gate_projs_ptrs = _make_ptrs(gate_weights) - self._up_projs_ptrs = _make_ptrs(up_weights) - self._down_projs_ptrs = _make_ptrs(down_weights) - self._gate_scale_ptrs = _make_ptrs(experts_data["gate_scale"]) - self._up_scale_ptrs = _make_ptrs(experts_data["up_scale"]) - self._down_scale_ptrs = _make_ptrs(experts_data["down_scale"]) - - # Build backward weight pointer arrays if available - if "gate_bwd" in experts_data: - self._gate_bwd_weights_per_numa = experts_data["gate_bwd"] - self._up_bwd_weights_per_numa = experts_data["up_bwd"] - self._down_bwd_weights_per_numa = experts_data["down_bwd"] - self._gate_bwd_scales_per_numa = experts_data["gate_bwd_scale"] - self._up_bwd_scales_per_numa = experts_data["up_bwd_scale"] - self._down_bwd_scales_per_numa = experts_data["down_bwd_scale"] - - self._gate_bwd_projs_ptrs = _make_ptrs(experts_data["gate_bwd"]) - self._up_bwd_projs_ptrs = _make_ptrs(experts_data["up_bwd"]) - self._down_bwd_projs_ptrs = _make_ptrs(experts_data["down_bwd"]) - self._gate_bwd_scale_ptrs = _make_ptrs(experts_data["gate_bwd_scale"]) - self._up_bwd_scale_ptrs = _make_ptrs(experts_data["up_bwd_scale"]) - self._down_bwd_scale_ptrs = _make_ptrs(experts_data["down_bwd_scale"]) - self._has_bwd_projs = True - else: - self._has_bwd_projs = False - - # Set gate_proj to None so load_weights() uses gate_projs path - self.gate_proj = None - self.up_proj = None - self.down_proj = None - self._use_projs_path = True - - # Close loader handles - loader.close_all_handles() - - if getattr(self, "_use_projs_path", False): - num_numa = len(self._gate_weights_per_numa) - num_experts = len(self._gate_weights_per_numa[0]) - print( - f"[AMXSFTMoEWrapper] Loaded pre-quantized weights: " - f"{num_numa} NUMA nodes, {num_experts} experts per NUMA" - ) - else: - print( - f"[AMXSFTMoEWrapper] Loaded weights: gate_proj={self.gate_proj.shape}, " - f"up_proj={self.up_proj.shape}, down_proj={self.down_proj.shape}" - ) - - def init_lora_weights( - self, - gate_lora_a: torch.Tensor, - gate_lora_b: torch.Tensor, - up_lora_a: torch.Tensor, - up_lora_b: torch.Tensor, - down_lora_a: torch.Tensor, - down_lora_b: torch.Tensor, - grad_gate_lora_a: torch.Tensor, - grad_gate_lora_b: torch.Tensor, - grad_up_lora_a: torch.Tensor, - grad_up_lora_b: torch.Tensor, - grad_down_lora_a: torch.Tensor, - grad_down_lora_b: torch.Tensor, - ) -> None: - """ - Initialize LoRA weights. - - LoRA output formula: - lora_output = (input @ A.T @ B.T) * (lora_alpha / lora_rank) - output = base_output + lora_output - - Args: - gate_lora_a: Gate LoRA A matrix [num_experts, lora_rank, hidden_size] - gate_lora_b: Gate LoRA B matrix [num_experts, intermediate_size, lora_rank] - up_lora_a: Up LoRA A matrix [num_experts, lora_rank, hidden_size] - up_lora_b: Up LoRA B matrix [num_experts, intermediate_size, lora_rank] - down_lora_a: Down LoRA A matrix [num_experts, lora_rank, intermediate_size] - down_lora_b: Down LoRA B matrix [num_experts, hidden_size, lora_rank] - """ - # Validate shapes - expected_shapes = { - "gate_lora_a": (self.num_experts, self.lora_rank, self.hidden_size), - "gate_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank), - "up_lora_a": (self.num_experts, self.lora_rank, self.hidden_size), - "up_lora_b": (self.num_experts, self.moe_intermediate_size, self.lora_rank), - "down_lora_a": (self.num_experts, self.lora_rank, self.moe_intermediate_size), - "down_lora_b": (self.num_experts, self.hidden_size, self.lora_rank), - } - - provided_tensors = { - "gate_lora_a": gate_lora_a, - "gate_lora_b": gate_lora_b, - "up_lora_a": up_lora_a, - "up_lora_b": up_lora_b, - "down_lora_a": down_lora_a, - "down_lora_b": down_lora_b, - } - - for name, tensor in provided_tensors.items(): - expected = expected_shapes[name] - if tensor.shape != expected: - raise ValueError(f"{name} shape mismatch: expected {expected}, got {tuple(tensor.shape)}") - - # Store LoRA weights (contiguous for C++ access) - self.gate_lora_a = gate_lora_a.contiguous() - self.gate_lora_b = gate_lora_b.contiguous() - self.up_lora_a = up_lora_a.contiguous() - self.up_lora_b = up_lora_b.contiguous() - self.down_lora_a = down_lora_a.contiguous() - self.down_lora_b = down_lora_b.contiguous() - - self.grad_gate_lora_a = grad_gate_lora_a.contiguous() - self.grad_gate_lora_b = grad_gate_lora_b.contiguous() - self.grad_up_lora_a = grad_up_lora_a.contiguous() - self.grad_up_lora_b = grad_up_lora_b.contiguous() - self.grad_down_lora_a = grad_down_lora_a.contiguous() - self.grad_down_lora_b = grad_down_lora_b.contiguous() - - self._lora_initialized = True - - # If weights already loaded, update LoRA pointers in C++ - if self._weights_loaded and self.moe is not None: - self.update_lora_weights() - - def forward_sft( - self, - hidden_states: torch.Tensor, - expert_ids: torch.Tensor, - weights: torch.Tensor, - save_for_backward: bool = True, - output_device: Optional[torch.device] = None, - ) -> torch.Tensor: - """ - SFT forward pass with optional gradient caching. - - Optimized for minimal data copying: - - Accepts GPU tensors directly, copies to pinned buffer in one step - - Returns directly to output_device without intermediate clone - - Args: - hidden_states: Input hidden states [qlen, hidden_size] (any device, will be converted to bf16) - expert_ids: Expert IDs [qlen, num_experts_per_tok] (any device, will be converted to int64) - weights: Expert weights [qlen, num_experts_per_tok] (any device, will be converted to float32) - save_for_backward: Whether to save activations for backward pass - output_device: Target device for output (None = return CPU tensor without clone, caller must copy immediately) - - Returns: - Output hidden states [qlen, hidden_size] - """ - if not self._weights_loaded: - raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.") - - if not self._lora_initialized and not self._is_skip_lora: - raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.") - - qlen = hidden_states.shape[0] - if qlen > self.chunked_prefill_size: - raise ValueError( - f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). " - "Increase chunked_prefill_size or reduce qlen to avoid buffer overrun." - ) - if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok: - raise ValueError( - f"expert_ids shape {tuple(expert_ids.shape)} must be " f"({qlen}, {self.num_experts_per_tok})." - ) - if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok: - raise ValueError(f"weights shape {tuple(weights.shape)} must be " f"({qlen}, {self.num_experts_per_tok}).") - - # Get or create buffer (always bf16 for computation) - buffer = KExpertsSFTBuffer.get_buffer( - qlen=qlen, - hidden_size=self.hidden_size, - moe_intermediate_size=self.moe_intermediate_size, - num_experts=self.num_experts, - num_experts_per_tok=self.num_experts_per_tok, - lora_rank=self.lora_rank, - dtype=torch.bfloat16, - ) - - # Copy input data directly to pinned CPU buffers (works for both CPU and GPU tensors) - # For GPU tensors: this is a single GPU->pinned copy (faster than GPU->CPU->pinned) - # For CPU tensors: this is a CPU->pinned copy - input_device = hidden_states.device - # Buffer may be larger than qlen (grow-only pool), slice for copy - buffer.input_cpu[:qlen].copy_(hidden_states.to(torch.bfloat16), non_blocking=True) - buffer.expert_ids_cpu[:qlen].copy_(expert_ids.to(torch.int64), non_blocking=True) - buffer.weights_cpu[:qlen].copy_(weights.to(torch.float32), non_blocking=True) - buffer.bsz_tensor[0] = qlen - - # Synchronize CUDA stream if input was on GPU to ensure data has arrived - if input_device.type == "cuda": - torch.cuda.synchronize(input_device) - - # Submit forward task — always pass the real save_for_backward. - # C++ will overwrite the existing cache entry (if any) instead of - # pushing a duplicate, so checkpoint recomputes are safe. - # data_ptr() is the same for [:qlen] slice (offset 0), C++ uses bsz_tensor for actual size. - self.cpu_infer.submit( - self.moe.forward_sft_task( - buffer.bsz_tensor.data_ptr(), - self.num_experts_per_tok, - buffer.expert_ids_cpu.data_ptr(), - buffer.weights_cpu.data_ptr(), - buffer.input_cpu.data_ptr(), - buffer.output_cpu.data_ptr(), - save_for_backward, - ) - ) - self.cpu_infer.sync() - - # Track cache depth (only increment on first push, not on overwrites) - if save_for_backward and self._cache_depth == 0: - self._cache_depth += 1 - _sft_lifecycle_log( - "forward_sync", - layer=self.layer_idx, - qlen=qlen, - save_for_backward=save_for_backward, - cache_depth=self._cache_depth, - max_cache_depth=self.max_cache_depth, - output_device=str(output_device) if output_device is not None else "None", - ) - - # Return output: slice to actual qlen - if output_device is not None: - return buffer.output_cpu[:qlen].to(device=output_device, non_blocking=True) - else: - return buffer.output_cpu[:qlen].clone() - - def backward( - self, - grad_output: torch.Tensor, - output_device: Optional[torch.device] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Backward pass computing gradients. - - Must be called after forward_sft(save_for_backward=True). - - Optimized for minimal data copying: - - Accepts GPU tensors directly - - Returns directly to output_device without intermediate clone - - LoRA gradients are returned in grad_loras dict (no clone needed) - - Args: - grad_output: Gradient from upstream [qlen, hidden_size] (any device, will be converted to bf16) - lora_params: Optional dict of LoRA parameters (kept for compatibility). - If provided, gradients are still returned in grad_loras. - Keys: gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b - output_device: Target device for grad_input output (None = clone CPU tensors for safety) - - Returns: - grad_input: Input gradient [qlen, hidden_size] - grad_loras: LoRA gradients dict (e.g., grad_gate_lora_a, grad_gate_lora_b, ...) - grad_weights: Routing weights gradient [qlen, num_experts_per_tok] - """ - if self._cache_depth <= 0: - raise RuntimeError("No forward cache available. Call forward_sft(save_for_backward=True) first.") - - qlen = grad_output.shape[0] - - # Get buffer (should exist from forward pass, always bf16) - buffer = KExpertsSFTBuffer.get_buffer( - qlen=qlen, - hidden_size=self.hidden_size, - moe_intermediate_size=self.moe_intermediate_size, - num_experts=self.num_experts, - num_experts_per_tok=self.num_experts_per_tok, - lora_rank=self.lora_rank, - dtype=torch.bfloat16, - ) - - # Copy gradient to CPU buffer using dedicated stream (avoids blocking on all GPU work) - self._copy_grad_output_to_cpu(buffer, grad_output, qlen) - - # Submit backward task - # SkipLoRA: grad LoRA pointers unused by C++ (SkipLoRA template skips all LoRA grad writes), pass 0 - if self._is_skip_lora: - _gl = 0 - self.cpu_infer.submit( - self.moe.backward_task( - buffer.grad_output_cpu.data_ptr(), - buffer.grad_input_cpu.data_ptr(), - _gl, - _gl, - _gl, - _gl, - _gl, - _gl, - buffer.grad_weights.data_ptr(), - ) - ) - else: - self.cpu_infer.submit( - self.moe.backward_task( - buffer.grad_output_cpu.data_ptr(), - buffer.grad_input_cpu.data_ptr(), - self.grad_gate_lora_a.data_ptr(), - self.grad_gate_lora_b.data_ptr(), - self.grad_up_lora_a.data_ptr(), - self.grad_up_lora_b.data_ptr(), - self.grad_down_lora_a.data_ptr(), - self.grad_down_lora_b.data_ptr(), - buffer.grad_weights.data_ptr(), - ) - ) - self.cpu_infer.sync() - - # Decrease cache depth - self._cache_depth -= 1 - _sft_lifecycle_log( - "backward", - layer=self.layer_idx, - qlen=qlen, - cache_depth=self._cache_depth, - max_cache_depth=self.max_cache_depth, - grad_output_device=str(grad_output.device), - output_device=str(output_device) if output_device is not None else "None", - ) - - # Return gradients: slice to actual qlen - if output_device is not None: - grad_input = buffer.grad_input_cpu[:qlen].to(device=output_device, non_blocking=True) - grad_weights = buffer.grad_weights[:qlen].to(device=output_device, non_blocking=True) - else: - grad_input = buffer.grad_input_cpu[:qlen].clone() - grad_weights = buffer.grad_weights[:qlen].clone() - - return grad_input, grad_weights - - def _copy_grad_output_to_cpu(self, buffer, grad_output, qlen): - """Copy grad_output from GPU to CPU pinned buffer. - - Calls cuda.synchronize() first to drain pending GPU work (e.g. lm_head - backward) so that time is attributed to autograd, not to MoE backward. - The actual DMA copy with pinned memory takes <1ms. - """ - input_device = grad_output.device - if input_device.type == "cuda": - torch.cuda.synchronize(input_device) - buffer.grad_output_cpu[:qlen].copy_(grad_output.to(torch.bfloat16)) - # Data is now on CPU, ready for C++ kernel - - def submit_backward_async( - self, - grad_output: torch.Tensor, - output_device: Optional[torch.device] = None, - ) -> None: - """ - Submit backward task without waiting for completion. - Call sync_backward() later to get results. - - Args: - grad_output: Gradient from upstream [qlen, hidden_size] - output_device: Target device for results (stored for sync_backward) - """ - if self._cache_depth <= 0: - raise RuntimeError("No forward cache available. Call forward_sft(save_for_backward=True) first.") - - qlen = grad_output.shape[0] - - buffer = KExpertsSFTBuffer.get_buffer( - qlen=qlen, - hidden_size=self.hidden_size, - moe_intermediate_size=self.moe_intermediate_size, - num_experts=self.num_experts, - num_experts_per_tok=self.num_experts_per_tok, - lora_rank=self.lora_rank, - dtype=torch.bfloat16, - ) - - self._copy_grad_output_to_cpu(buffer, grad_output, qlen) - - if self._is_skip_lora: - _gl = 0 - self.cpu_infer.submit( - self.moe.backward_task( - buffer.grad_output_cpu.data_ptr(), - buffer.grad_input_cpu.data_ptr(), - _gl, - _gl, - _gl, - _gl, - _gl, - _gl, - buffer.grad_weights.data_ptr(), - ) - ) - else: - self.cpu_infer.submit( - self.moe.backward_task( - buffer.grad_output_cpu.data_ptr(), - buffer.grad_input_cpu.data_ptr(), - self.grad_gate_lora_a.data_ptr(), - self.grad_gate_lora_b.data_ptr(), - self.grad_up_lora_a.data_ptr(), - self.grad_up_lora_b.data_ptr(), - self.grad_down_lora_a.data_ptr(), - self.grad_down_lora_b.data_ptr(), - buffer.grad_weights.data_ptr(), - ) - ) - # DO NOT sync — store state for sync_backward() - self._async_bwd_qlen = qlen - self._async_bwd_output_device = output_device - - def sync_backward(self) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Wait for async backward to complete and return results. - Must be called after submit_backward_async(). - - Returns: - grad_input: Input gradient [qlen, hidden_size] - grad_weights: Routing weights gradient [qlen, num_experts_per_tok] - """ - self.cpu_infer.sync() - - qlen = self._async_bwd_qlen - output_device = self._async_bwd_output_device - - buffer = KExpertsSFTBuffer.get_buffer( - qlen=qlen, - hidden_size=self.hidden_size, - moe_intermediate_size=self.moe_intermediate_size, - num_experts=self.num_experts, - num_experts_per_tok=self.num_experts_per_tok, - lora_rank=self.lora_rank, - dtype=torch.bfloat16, - ) - - self._cache_depth -= 1 - _sft_lifecycle_log( - "sync_backward", - layer=self.layer_idx, - qlen=qlen, - cache_depth=self._cache_depth, - max_cache_depth=self.max_cache_depth, - output_device=str(output_device) if output_device is not None else "None", - ) - - if output_device is not None: - grad_input = buffer.grad_input_cpu[:qlen].to(device=output_device, non_blocking=True) - grad_weights = buffer.grad_weights[:qlen].to(device=output_device, non_blocking=True) - else: - grad_input = buffer.grad_input_cpu[:qlen].clone() - grad_weights = buffer.grad_weights[:qlen].clone() - - return grad_input, grad_weights - - def submit_backward_repack(self): - """Start async backward weight repacking (non-blocking).""" - if not self._weights_loaded or self.moe is None: - return - self.moe.submit_backward_repack() - - def wait_backward_repack(self): - """Wait for async backward weight repacking to complete.""" - if not self._weights_loaded or self.moe is None: - return - self.moe.wait_backward_repack() - - def save_backward_weights_from_tensors( - self, - gate_proj: "torch.Tensor", - up_proj: "torch.Tensor", - down_proj: "torch.Tensor", - physical_to_logical_map: "torch.Tensor", - output_path: str, - ) -> None: - """ - Prepare backward weights from BF16 tensors and save to disk. - - The C++ side transposes + quantizes the weights, then writes to .kt files. - This can be called offline to pre-compute backward weights. - - Args: - gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size] - up_proj: Up projection weights [num_experts, intermediate_size, hidden_size] - down_proj: Down projection weights [num_experts, hidden_size, intermediate_size] - physical_to_logical_map: Mapping from physical to logical expert IDs - output_path: Directory to save backward weight files - """ - if not self._weights_loaded: - raise RuntimeError("Weights not loaded. Call load_weights() first.") - - gate_proj = gate_proj.contiguous() - up_proj = up_proj.contiguous() - down_proj = down_proj.contiguous() - - self.moe.prepare_and_save_bwd( - gate_proj.data_ptr(), - up_proj.data_ptr(), - down_proj.data_ptr(), - output_path, - ) - - def update_lora_weights(self) -> None: - """ - Sync LoRA weights to C++ backend. - - Call this after using an external optimizer to update LoRA weights. - This is needed because TP mode partitions weights internally. - - Typical usage: - # 1. Forward + backward - output = wrapper.forward_sft(input, expert_ids, weights) - grad_input, grad_loras = wrapper.backward(grad_output) - - # 2. Update LoRA weights with optimizer - optimizer.step() - - # 3. Sync to C++ - wrapper.update_lora_weights() - """ - if not self._weights_loaded: - raise RuntimeError("Weights not loaded. Call load_weights() first.") - - if self._is_skip_lora: - return # SkipLoRA mode: no LoRA weights to update - - if not self._lora_initialized: - raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.") - - # Submit update task - self.cpu_infer.submit( - self.moe.update_lora_weights_task( - self.gate_lora_a.data_ptr(), - self.gate_lora_b.data_ptr(), - self.up_lora_a.data_ptr(), - self.up_lora_b.data_ptr(), - self.down_lora_a.data_ptr(), - self.down_lora_b.data_ptr(), - ) - ) - self.cpu_infer.sync() - - def submit_forward_sft( - self, - hidden_states: torch.Tensor, - expert_ids: torch.Tensor, - weights: torch.Tensor, - save_for_backward: bool = True, - ) -> None: - """ - Submit SFT forward pass asynchronously (non-blocking). - - This method submits the CPU MoE computation without waiting for completion, - allowing GPU computation (shared_experts, lora_experts) to proceed in parallel. - - Must be followed by sync_forward_sft() to retrieve results. - - Optimized: accepts GPU tensors directly, copies to pinned buffer in one step. - - Args: - hidden_states: Input hidden states [qlen, hidden_size] (any device, will be converted to bf16) - expert_ids: Expert IDs [qlen, num_experts_per_tok] (any device, will be converted to int64) - weights: Expert weights [qlen, num_experts_per_tok] (any device, will be converted to float32) - save_for_backward: Whether to save activations for backward pass - """ - if not self._weights_loaded: - raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.") - - if not self._lora_initialized and not self._is_skip_lora: - raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.") - - qlen = hidden_states.shape[0] - if qlen > self.chunked_prefill_size: - raise ValueError( - f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). " - "Increase chunked_prefill_size or reduce qlen to avoid buffer overrun." - ) - if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok: - raise ValueError( - f"expert_ids shape {tuple(expert_ids.shape)} must be " f"({qlen}, {self.num_experts_per_tok})." - ) - if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok: - raise ValueError(f"weights shape {tuple(weights.shape)} must be " f"({qlen}, {self.num_experts_per_tok}).") - - # Get or create buffer (always bf16) - buffer = KExpertsSFTBuffer.get_buffer( - qlen=qlen, - hidden_size=self.hidden_size, - moe_intermediate_size=self.moe_intermediate_size, - num_experts=self.num_experts, - num_experts_per_tok=self.num_experts_per_tok, - lora_rank=self.lora_rank, - dtype=torch.bfloat16, - ) - - # Buffer may be larger than qlen (grow-only pool), slice for copy - input_device = hidden_states.device - buffer.input_cpu[:qlen].copy_(hidden_states.to(torch.bfloat16), non_blocking=True) - buffer.expert_ids_cpu[:qlen].copy_(expert_ids.to(torch.int64), non_blocking=True) - buffer.weights_cpu[:qlen].copy_(weights.to(torch.float32), non_blocking=True) - buffer.bsz_tensor[0] = qlen - - # Synchronize CUDA stream if input was on GPU to ensure data has arrived - if input_device.type == "cuda": - torch.cuda.synchronize(input_device) - - # Store buffer reference and save_for_backward flag for sync_forward_sft. - # Always pass the real save_for_backward — C++ will overwrite the - # existing cache entry (if any) instead of pushing a duplicate. - self._pending_buffer = buffer - self._pending_save_for_backward = save_for_backward - self._pending_qlen = qlen - _sft_lifecycle_log( - "submit_pending_set", - layer=self.layer_idx, - qlen=qlen, - pending_buffer_id=id(buffer), - save_for_backward=save_for_backward, - cache_depth=self._cache_depth, - max_cache_depth=self.max_cache_depth, - hidden_states_device=str(hidden_states.device), - ) - - # Submit forward task (non-blocking) - self.cpu_infer.submit( - self.moe.forward_sft_task( - buffer.bsz_tensor.data_ptr(), - self.num_experts_per_tok, - buffer.expert_ids_cpu.data_ptr(), - buffer.weights_cpu.data_ptr(), - buffer.input_cpu.data_ptr(), - buffer.output_cpu.data_ptr(), - save_for_backward, - ) - ) - - def sync_forward_sft(self, output_device: Optional[torch.device] = None) -> torch.Tensor: - """ - Synchronize and retrieve SFT forward results. - - Must be called after submit_forward_sft(). - - Args: - output_device: Target device for output (None = clone CPU tensor for safety) - - Returns: - Output hidden states [qlen, hidden_size] - """ - if not hasattr(self, "_pending_buffer") or self._pending_buffer is None: - raise RuntimeError("No pending forward. Call submit_forward_sft() first.") - - # Wait for completion - self.cpu_infer.sync() - - buffer = self._pending_buffer - save_for_backward = self._pending_save_for_backward - qlen = self._pending_qlen - - # Track cache depth (only increment on first push, not on overwrites) - if save_for_backward and self._cache_depth == 0: - self._cache_depth += 1 - _sft_lifecycle_log( - "sync_before_clear", - layer=self.layer_idx, - qlen=qlen, - pending_buffer_id=id(buffer), - save_for_backward=save_for_backward, - cache_depth=self._cache_depth, - max_cache_depth=self.max_cache_depth, - output_device=str(output_device) if output_device is not None else "None", - ) - - # Clear pending state - self._pending_buffer = None - self._pending_save_for_backward = None - self._pending_qlen = None - _sft_lifecycle_log( - "sync_after_clear", - layer=self.layer_idx, - pending_exists=False, - cache_depth=self._cache_depth, - max_cache_depth=self.max_cache_depth, - ) - - # Return output: slice to actual qlen - if output_device is not None: - return buffer.output_cpu[:qlen].to(device=output_device, non_blocking=True) - else: - return buffer.output_cpu[:qlen].clone()