mrhaoxx 9544a8960d feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically
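
The rank-0-only expert pattern above can be pictured with a small single-process simulation. This is a sketch under stated assumptions, not the code in dist_utils.py: the function name `rank0_expert_forward` and the list-based "ranks" are hypothetical, and a real run would presumably use torch.distributed collectives instead of Python lists.

```python
from typing import Callable, List, Sequence


def rank0_expert_forward(
    per_rank_tokens: Sequence[Sequence[float]],
    expert_fn: Callable[[float], float],
) -> List[List[float]]:
    """Simulate gather -> compute on rank 0 -> scatter for MoE expert work.

    per_rank_tokens[i] stands in for the tokens that rank i would send.
    Only the "rank 0" step calls expert_fn, mirroring the idea that a
    single rank owns the expert weights.
    """
    # Gather: remember how many tokens each rank contributed.
    sizes = [len(chunk) for chunk in per_rank_tokens]
    gathered = [tok for chunk in per_rank_tokens for tok in chunk]

    # Compute: only the owner of the experts runs them.
    outputs = [expert_fn(tok) for tok in gathered]

    # Scatter: hand each rank back exactly its own slice, in order.
    result, start = [], 0
    for n in sizes:
        result.append(outputs[start:start + n])
        start += n
    return result
```

Keeping the per-rank sizes from the gather step is what lets the scatter step return results in the original order, even when some ranks contribute no tokens.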

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys — eliminates
  _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled
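
The two defaults described above could look roughly like the following. This is a minimal sketch, not the actual KTConfig: the class name `KTConfigSketch` and the helper `apply_trainer_defaults` are hypothetical stand-ins (the commit only says the real hook is called trainer_config_process), and the real dataclass has more fields.

```python
from dataclasses import dataclass


@dataclass
class KTConfigSketch:
    """Hypothetical two-field subset of KTConfig, using the kt_ prefix."""

    kt_share_backward_bb: bool = True   # on by default (always saves memory)
    kt_share_cache_pool: bool = False   # off unless derived below


def apply_trainer_defaults(
    config: KTConfigSketch, gradient_checkpointing: bool
) -> KTConfigSketch:
    """Derive kt_share_cache_pool from the trainer's checkpointing setting."""
    if gradient_checkpointing:
        config.kt_share_cache_pool = True
    return config
```

Deriving the flag from the trainer setting, rather than reading an environment variable, keeps a single source of truth for whether the cache pool is shared.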

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config
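
Splitting a fused gate_up_proj of shape [E, 2I, H] into separate gate and up weights can be sketched as below. This is an illustration on nested Python lists, not the code in weights.py: the function name is hypothetical, and the assumption that the gate rows come first and the up rows second along the 2I axis should be checked against the actual checkpoint layout.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # one expert's weight, rows x columns


def split_fused_gate_up(
    gate_up_proj: List[Matrix], intermediate_size: int
) -> Tuple[List[Matrix], List[Matrix]]:
    """Split [E][2*I][H] into gate [E][I][H] and up [E][I][H].

    Assumes rows 0..I-1 are the gate projection and rows I..2I-1 are the
    up projection for every expert.
    """
    gate, up = [], []
    for expert_idx, fused in enumerate(gate_up_proj):
        if len(fused) != 2 * intermediate_size:
            raise ValueError(
                f"expert {expert_idx}: expected {2 * intermediate_size} rows, "
                f"got {len(fused)}"
            )
        gate.append(fused[:intermediate_size])
        up.append(fused[intermediate_size:])
    return gate, up
```

The down_proj tensor ([E, H, I]) is already one matrix per expert, so under the same assumptions it only needs to be indexed by expert rather than split.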

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00

KTransformers

A Flexible Framework for Experiencing Cutting-edge LLM Inference/Fine-tune Optimizations

🎯 Overview | 🚀 kt-kernel | 🎓 kt-sft | 🔥 Citation | 🚀 Roadmap(2026Q2)

🎯 Overview

KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into two core modules: kt-kernel and kt-sft.

🔥 Updates

  • May 6, 2026: KTransformers at GOSIM Paris 2026 — "Agentic AI on Edge" track. We'll present KT's inference performance on consumer hardware.
  • Mar 26, 2026: Support AVX2-only CPU backend for KT-Kernel inference. (Tutorial)
  • Feb 13, 2026: MiniMax-M2.5 Day0 Support! (Tutorial)
  • Feb 12, 2026: GLM-5 Day0 Support! (Tutorial)
  • Jan 27, 2026: Kimi-K2.5 Day0 Support! (Tutorial) (SFT Tutorial)
  • Jan 22, 2026: Support CPU-GPU expert scheduling, native BF16 and FP8 per-channel precision, and AutoDL unified fine-tuning and inference.
  • Dec 24, 2025: Support Native MiniMax-M2.1 inference. (Tutorial)
  • Dec 22, 2025: Support RL-DPO fine-tuning with LLaMA-Factory. (Tutorial)
  • Dec 5, 2025: Support Native Kimi-K2-Thinking inference (Tutorial)
  • Nov 6, 2025: Support Kimi-K2-Thinking inference (Tutorial) and fine-tune (Tutorial)
  • Nov 4, 2025: KTransformers Fine-Tuning × LLaMA-Factory Integration. (Tutorial)
  • Oct 27, 2025: Support Ascend NPU. (Tutorial)
  • Oct 10, 2025: Integrating into SGLang. (Roadmap, Blog)
  • Sept 11, 2025: Support Qwen3-Next. (Tutorial)
  • Sept 05, 2025: Support Kimi-K2-0905. (Tutorial)
  • July 26, 2025: Support SmallThinker and GLM4-MoE. (Tutorial)
  • July 11, 2025: Support Kimi-K2. (Tutorial)
  • June 30, 2025: Support 3-layer (GPU-CPU-Disk) prefix cache reuse.
  • May 14, 2025: Support Intel Arc GPU (Tutorial).
  • Apr 29, 2025: Support AMX-Int8, AMX-BF16 and Qwen3MoE (Tutorial)
  • Apr 9, 2025: Experimental support for LLaMA 4 models (Tutorial).
  • Apr 2, 2025: Support Multi-concurrency. (Tutorial).
  • Mar 15, 2025: Support ROCm on AMD GPU (Tutorial).
  • Mar 5, 2025: Support unsloth 1.58/2.51 bits weights and IQ1_S/FP8 hybrid weights. Support a longer 139K context for DeepSeek-V3 and R1 in 24GB VRAM.
  • Feb 25, 2025: Support FP8 GPU kernel for DeepSeek-V3 and R1; Longer Context.
  • Feb 15, 2025: Longer context (from 4K to 8K for 24GB VRAM) and slightly faster speed (+15%, up to 16 tokens/s); update docs and online books.
  • Feb 10, 2025: Support DeepSeek-R1 and V3 on single (24GB VRAM) or multiple GPUs with 382GB DRAM, up to 3~28x speedup. For a detailed showcase and reproduction tutorial, see here.
  • Aug 28, 2024: Decrease DeepSeek-V2's required VRAM from 21GB to 11GB.
  • Aug 15, 2024: Update detailed tutorial for injection and multi-GPU.
  • Aug 14, 2024: Support llamafile as linear backend.
  • Aug 12, 2024: Support multiple GPUs; support new models: Mixtral 8x7B and 8x22B; support q2k, q3k, q5k dequant on GPU.
  • Aug 9, 2024: Support Windows natively.

📦 Core Modules

🚀 kt-kernel - High-Performance Inference Kernels

CPU-optimized kernel operations for heterogeneous LLM inference.


Key Features:

  • AMX/AVX Acceleration: Intel AMX and AVX512/AVX2 optimized kernels for INT4/INT8 quantized inference
  • MoE Optimization: Efficient Mixture-of-Experts inference with NUMA-aware memory management
  • Quantization Support: CPU-side INT4/INT8 quantized weights, GPU-side GPTQ support
  • Easy Integration: Clean Python API for SGLang and other frameworks

Quick Start:

cd kt-kernel
pip install .

Use Cases:

  • CPU-GPU hybrid inference for large MoE models
  • Integration with SGLang for production serving
  • Heterogeneous expert placement (hot experts on GPU, cold experts on CPU)
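
The hot/cold split in the last use case can be sketched as a simple top-k selection over activation counts. This is only an illustration of the idea: the function name `select_hot_experts`, the frequency-based heuristic, and the boolean-list mask format are all assumptions, and kt-kernel's actual scheduling policy and mask representation may differ.

```python
from typing import List, Sequence


def select_hot_experts(
    activation_counts: Sequence[int], num_gpu_experts: int
) -> List[bool]:
    """Return a mask where True means "place this expert on the GPU".

    The num_gpu_experts most frequently activated experts are treated as
    hot; ties are broken by the lower expert index.
    """
    if num_gpu_experts < 0:
        raise ValueError("num_gpu_experts must be non-negative")
    ranked = sorted(
        range(len(activation_counts)),
        key=lambda i: (-activation_counts[i], i),
    )
    hot = set(ranked[:num_gpu_experts])
    return [i in hot for i in range(len(activation_counts))]
```

Experts left as False in the mask would stay on the CPU path, so the GPU budget is spent only on the experts that are routed to most often.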

Performance Examples:

| Model | Hardware Configuration | Total Throughput | Output Throughput |
|---|---|---|---|
| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s (8-way concurrency) |

👉 Full Documentation →


🎓 kt-sft - Fine-Tuning Framework

KTransformers × LLaMA-Factory integration for ultra-large MoE model fine-tuning.


Key Features:

  • Resource Efficient: Fine-tune 671B DeepSeek-V3 with just 70GB GPU memory + 1.3TB RAM
  • LoRA Support: Full LoRA fine-tuning with heterogeneous acceleration
  • LLaMA-Factory Integration: Seamless integration with the popular LLaMA-Factory fine-tuning framework
  • Production Ready: Chat, batch inference, and metrics evaluation
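
As a rough sanity check on the scale of the RAM figure above, the raw size of the weights can be estimated from the parameter count. This sketch assumes all 671B parameters are stored in BF16 (2 bytes each) in host memory, which the README does not state; activations, optimizer state, LoRA buffers, and any quantization are ignored.

```python
def weight_footprint_tb(num_params: int, bytes_per_param: int) -> float:
    """Return the raw weight size in decimal terabytes (1 TB = 1e12 bytes)."""
    return num_params * bytes_per_param / 1e12


DEEPSEEK_V3_PARAMS = 671_000_000_000  # 671B, as quoted in the feature list

# Assumption: BF16 storage, 2 bytes per parameter.
bf16_tb = weight_footprint_tb(DEEPSEEK_V3_PARAMS, 2)
print(f"{bf16_tb:.3f} TB")  # 1.342 TB
```

Under that assumption the weights alone come to about 1.3TB, the same order as the quoted host-memory requirement.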

Performance Examples:

| Model | Configuration | Throughput | GPU Memory |
|---|---|---|---|
| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (multi-GPU) |
| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |

Quick Start:

cd kt-sft
# Install environment following kt-sft/README.md
USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml

👉 Full Documentation →


🔥 Citation

If you use KTransformers in your research, please cite our paper:

@inproceedings{10.1145/3731569.3764843,
  title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},
  author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},
  booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},
  year = {2025}
}

👥 Contributors & Team

Developed and maintained by:

We welcome contributions! Please feel free to submit issues and pull requests.

💬 Community & Support

📦 KT original Code

The original integrated KTransformers framework has been archived to the archive/ directory for reference. The project now focuses on the two core modules above for better modularity and maintainability.

For the original documentation with full quick-start guides and examples, see:

Description
A Flexible Framework for Experiencing Heterogeneous LLM Inference/Fine-tune Optimizations
Readme Apache-2.0 87 MiB