ktransformers/kt-kernel/ext_bindings.cpp
mrhaoxx 9544a8960d feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys — eliminates
  _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.
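To make the fused layout concrete, here is a minimal sketch of splitting a `gate_up_proj` buffer of shape [E, 2I, H] into separate gate and up tensors. It assumes row-major contiguous storage and that, within each expert, the first I rows belong to gate and the next I rows to up. Check that ordering against the model before relying on it. `split_gate_up` is a hypothetical helper, not part of kt_kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: split a fused gate_up_proj buffer of shape [E, 2I, H]
// (row-major, contiguous) into gate [E, I, H] and up [E, I, H].
// ASSUMPTION: within each expert, rows [0, I) hold gate and rows [I, 2I) hold up.
void split_gate_up(const std::vector<float>& gate_up, std::size_t E, std::size_t I, std::size_t H,
                   std::vector<float>& gate, std::vector<float>& up) {
  assert(gate_up.size() == E * 2 * I * H);
  const std::size_t half = I * H;  // elements in one [I, H] slice
  gate.resize(E * half);
  up.resize(E * half);
  for (std::size_t e = 0; e < E; ++e) {
    const float* expert = gate_up.data() + e * 2 * half;  // start of expert e's [2I, H] slice
    for (std::size_t k = 0; k < half; ++k) {
      gate[e * half + k] = expert[k];       // first I rows
      up[e * half + k] = expert[half + k];  // next I rows
    }
  }
}
```

With E = 2, I = 1, H = 2 and a fused buffer {1, 2, 3, 4, 5, 6, 7, 8}, gate becomes {1, 2, 5, 6} and up becomes {3, 4, 7, 8}. `down_proj` [E, H, I] is already one matrix per expert and needs no split.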

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00

1005 lines
48 KiB
C++
/**
* @Description :
* @Author : chenht2022, Jianwei Dong
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : Jianwei Dong
* @LastEditTime : 2024-08-26 22:47:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
// Python bindings
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cpptrace/cpptrace.hpp>
#include <csignal>
#include <cstddef>
#include <cstring>
#include "cpu_backend/cpuinfer.h"
#include "cpu_backend/worker_pool.h"
#include "operators/common.hpp"
#if defined(USE_MOE_KERNEL)
#include "operators/moe_kernel/la/kernel.hpp"
#include "operators/moe_kernel/moe.hpp"
#endif
#if defined(__aarch64__) && defined(CPU_USE_KML)
#if defined(KTRANSFORMERS_CPU_MLA)
#include "operators/kml/deepseekv3.hpp"
#include "operators/kml/gate.hpp"
#include "operators/kml/mla.hpp"
#include "operators/kml/mla_int8.hpp"
#endif
#include "operators/kml/moe.hpp"
static const bool _is_plain_ = true;
#else
static const bool _is_plain_ = false;
#endif
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern, with fallback for AVX512F
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support, with fallback for AVX512F+BW
#include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#include "operators/amx/sft_moe.hpp"
#include "operators/moe-sft-tp.hpp"
#endif
// AVX2 backends — always available on x86_64 (no AMX/AVX512 dependency)
#if defined(__x86_64__)
#include "operators/avx2/bf16-moe.hpp"
#include "operators/avx2/fp8-moe.hpp"
#include "operators/avx2/gptq_int4_avxvnni-moe.hpp"
#include "operators/avx2/gptq_int4-moe.hpp"
#endif
#include <pybind11/stl.h> // std::vector/std::pair/std::string conversions
#include <cstdint>
#include <memory>
#include <type_traits>
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
#include "operators/llamafile/mla.hpp"
#include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.hpp"
#include "pybind11/pybind11.h"
namespace py = pybind11;
using namespace pybind11::literals;
py::object to_float_ptr(uintptr_t input_ptr, int size, ggml_type type) {
if (type < 0 || type >= GGML_TYPE_COUNT) {
PyErr_SetString(PyExc_ValueError, "Invalid ggml_type");
throw py::error_already_set();
}
py::module torch = py::module::import("torch");
py::dict kwargs;
kwargs["dtype"] = torch.attr("float32");
py::object tensor = torch.attr("empty")(size, **kwargs);
uintptr_t output_ptr = tensor.attr("data_ptr")().cast<uintptr_t>();
float* output_float_ptr = reinterpret_cast<float*>(output_ptr);
try {
to_float(reinterpret_cast<void*>(input_ptr), output_float_ptr, size, type);
} catch (const std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
throw py::error_already_set();
}
return tensor;
}
py::object from_float_ptr(uintptr_t input_ptr, int size, ggml_type type) {
if (type < 0 || type >= GGML_TYPE_COUNT) {
PyErr_SetString(PyExc_ValueError, "Invalid ggml_type");
throw py::error_already_set();
}
py::module torch = py::module::import("torch");
size_t output_elem_bytes = ggml_type_size(type);
size_t output_elem_count = (size + ggml_blck_size(type) - 1) / ggml_blck_size(type);
size_t total_bytes = output_elem_count * output_elem_bytes;
py::dict kwargs;
kwargs["dtype"] = torch.attr("uint8");
py::object tensor = torch.attr("empty")(total_bytes, **kwargs);
uintptr_t output_ptr = tensor.attr("data_ptr")().cast<uintptr_t>();
void* output_void_ptr = reinterpret_cast<void*>(output_ptr);
try {
from_float(reinterpret_cast<float*>(input_ptr), output_void_ptr, size, type);
} catch (const std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
throw py::error_already_set();
}
return tensor;
}
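`from_float_ptr` sizes its `uint8` output from ggml's block layout: the number of blocks (rounded up so a partial final block still gets storage) times the bytes per block. A sketch of that arithmetic with the block parameters passed explicitly, standing in for `ggml_blck_size(type)` and `ggml_type_size(type)`. The Q8_0/Q4_0 values in the comments reflect our understanding of ggml's block structs and should be confirmed against the ggml headers; `quantized_bytes` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>

// Mirrors the size computation in from_float_ptr: `n` floats are quantized into
// ceil(n / block_elems) blocks of `block_bytes` bytes each.
// block_elems ~ ggml_blck_size(type), block_bytes ~ ggml_type_size(type).
std::size_t quantized_bytes(std::size_t n, std::size_t block_elems, std::size_t block_bytes) {
  assert(block_elems > 0);
  const std::size_t blocks = (n + block_elems - 1) / block_elems;  // round up a partial block
  return blocks * block_bytes;
}

// Assumed ggml block parameters:
//   F32  : 1 element per block,   4 bytes per block
//   Q8_0 : 32 elements per block, 34 bytes per block (fp16 scale + 32 int8)
//   Q4_0 : 32 elements per block, 18 bytes per block (fp16 scale + 16 packed bytes)
```

For example, 100 floats in Q8_0 occupy ceil(100 / 32) = 4 blocks, i.e. 4 * 34 = 136 bytes.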
template <typename T>
std::vector<std::vector<uintptr_t>> void_ptr_nested_to_uint(const std::vector<std::vector<T*>>& input) {
std::vector<std::vector<uintptr_t>> result;
for (const auto& row : input) {
std::vector<uintptr_t> new_row;
for (auto ptr : row) {
new_row.push_back(reinterpret_cast<uintptr_t>(ptr));
}
result.push_back(std::move(new_row));
}
return result;
}
template <typename T>
std::vector<std::vector<T*>> uint_to_void_ptr_nested(const std::vector<std::vector<uintptr_t>>& input) {
std::vector<std::vector<T*>> result;
for (const auto& row : input) {
std::vector<T*> new_row;
for (auto val : row) {
new_row.push_back(reinterpret_cast<T*>(val));
}
result.push_back(std::move(new_row));
}
return result;
}
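The two templates above exist because Python hands raw addresses across the binding boundary as integers. A standalone restatement (renamed `to_uint_table` / `from_uint_table` so it compiles without the rest of this file) shows that the round trip is lossless: a pointer table converted to integers and back yields the same pointers.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Pointer table -> integer table (same logic as void_ptr_nested_to_uint).
template <typename T>
std::vector<std::vector<std::uintptr_t>> to_uint_table(const std::vector<std::vector<T*>>& in) {
  std::vector<std::vector<std::uintptr_t>> out;
  out.reserve(in.size());
  for (const auto& row : in) {
    std::vector<std::uintptr_t> r;
    r.reserve(row.size());
    for (T* p : row) r.push_back(reinterpret_cast<std::uintptr_t>(p));
    out.push_back(std::move(r));
  }
  return out;
}

// Integer table -> pointer table (same logic as uint_to_void_ptr_nested).
template <typename T>
std::vector<std::vector<T*>> from_uint_table(const std::vector<std::vector<std::uintptr_t>>& in) {
  std::vector<std::vector<T*>> out;
  out.reserve(in.size());
  for (const auto& row : in) {
    std::vector<T*> r;
    r.reserve(row.size());
    for (std::uintptr_t v : row) r.push_back(reinterpret_cast<T*>(v));
    out.push_back(std::move(r));
  }
  return out;
}
```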
#define DEF_PTR_PROPERTY(cls, name) \
def_property( \
#name, [](const cls& self) { return reinterpret_cast<uintptr_t>(self.name); }, \
[](cls& self, uintptr_t val) { self.name = reinterpret_cast<void*>(val); })
#define DEF_PTR_2D_PROPERTY(cls, name) \
def_property( \
#name, [](const cls& self) { return void_ptr_nested_to_uint<void>(self.name); }, \
[](cls& self, const std::vector<std::vector<uintptr_t>>& val) { \
self.name = uint_to_void_ptr_nested<void>(val); \
})
template <class T>
class MOEBindings {
public:
class WarmUpBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE<T>* moe;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE<T>::warm_up, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe) {
Args* args = new Args{nullptr, moe.get()};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class LoadWeightsBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE<T>* moe;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE<T>::load_weights, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe,
const uintptr_t physical_to_logical_map = 0) {
Args* args = new Args{nullptr, moe.get()};
if (physical_to_logical_map) {
moe->config.physical_to_logical_map = reinterpret_cast<void*>(physical_to_logical_map);
}
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe) {
return cpuinfer_interface(moe, 0);
}
};
class ForwardBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE<T>* moe;
intptr_t qlen;
int k;
intptr_t expert_ids;
intptr_t weights;
intptr_t input;
intptr_t output;
bool incremental;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE<T>::forward_binding, args_->moe, args_->qlen, args_->k, args_->expert_ids,
args_->weights, args_->input, args_->output, args_->incremental);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe, intptr_t qlen, int k,
intptr_t expert_ids, intptr_t weights, intptr_t input,
intptr_t output, bool incremental = false) {
Args* args = new Args{nullptr, moe.get(), qlen, k, expert_ids, weights, input, output, incremental};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe, intptr_t qlen, int k,
intptr_t expert_ids, intptr_t weights, intptr_t input,
intptr_t output) {
return cpuinfer_interface(moe, qlen, k, expert_ids, weights, input, output, false);
}
};
};
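Every `cpuinfer_interface` above returns a pair of integers, the address of a static `inner` trampoline and a heap-allocated `Args`, instead of doing the work itself; Python then forwards that pair to `CPUInfer::submit`. The sketch below models the round trip with a hypothetical `MiniCPUInfer`, `Counter`, and `AddBinding`. It assumes, as the `Args{nullptr, ...}` initializers suggest but this file does not show, that the submitter writes itself into the leading `cpuinfer` field before calling `inner`. Deleting `Args` inside `inner` is this sketch's own choice; this file does not show who releases the real `Args`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical executor standing in for CPUInfer: enqueue() defers work until sync().
struct MiniCPUInfer {
  std::vector<std::function<void()>> queue;
  template <typename F, typename... A>
  void enqueue(F&& f, A&&... a) {
    queue.emplace_back(std::bind(std::forward<F>(f), std::forward<A>(a)...));
  }
  // ASSUMPTION: every Args struct starts with an executor pointer, which
  // submit() fills in before calling the type-erased entry point.
  void submit(std::pair<std::intptr_t, std::intptr_t> task) {
    auto fn = reinterpret_cast<void (*)(void*)>(task.first);
    void* args = reinterpret_cast<void*>(task.second);
    *static_cast<MiniCPUInfer**>(args) = this;
    fn(args);
  }
  void sync() {
    for (auto& job : queue) job();
    queue.clear();
  }
};

struct Counter {
  int value = 0;
  void add(int n) { value += n; }
};

// Same shape as the *Bindings classes above: Args + static inner + interface.
struct AddBinding {
  struct Args {
    MiniCPUInfer* cpuinfer;  // filled in by submit()
    Counter* counter;
    int n;
  };
  static void inner(void* p) {
    Args* a = static_cast<Args*>(p);
    a->cpuinfer->enqueue(&Counter::add, a->counter, a->n);  // std::bind copies the arguments
    delete a;  // sketch-only ownership choice
  }
  static std::pair<std::intptr_t, std::intptr_t> interface(Counter* c, int n) {
    Args* a = new Args{nullptr, c, n};
    return {reinterpret_cast<std::intptr_t>(&inner), reinterpret_cast<std::intptr_t>(a)};
  }
};
```

Type-erasing every task to `(intptr_t, intptr_t)` keeps the Python side free of C++ types: any operator can be scheduled through the same `submit` call.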
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
template <class T>
class MOESFTBindings {
public:
class WarmUpBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE_SFT<T>* moe;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE_SFT<T>::warm_up, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE_SFT<T>> moe) {
Args* args = new Args{nullptr, moe.get()};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class LoadWeightsBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE_SFT<T>* moe;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE_SFT<T>::load_weights, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE_SFT<T>> moe) {
Args* args = new Args{nullptr, moe.get()};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardSFTBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE_SFT<T>* moe;
intptr_t qlen;
int k;
intptr_t expert_ids;
intptr_t weights;
intptr_t input;
intptr_t output;
bool save_for_backward;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE_SFT<T>::forward_sft_binding, args_->moe, args_->qlen, args_->k,
args_->expert_ids, args_->weights, args_->input, args_->output,
args_->save_for_backward);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE_SFT<T>> moe, intptr_t qlen, int k,
intptr_t expert_ids, intptr_t weights, intptr_t input,
intptr_t output, bool save_for_backward) {
Args* args = new Args{nullptr, moe.get(), qlen, k, expert_ids, weights, input, output, save_for_backward};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class BackwardBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE_SFT<T>* moe;
intptr_t grad_output;
intptr_t grad_input;
intptr_t grad_gate_lora_a;
intptr_t grad_gate_lora_b;
intptr_t grad_up_lora_a;
intptr_t grad_up_lora_b;
intptr_t grad_down_lora_a;
intptr_t grad_down_lora_b;
intptr_t grad_weights;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE_SFT<T>::backward_binding, args_->moe, args_->grad_output, args_->grad_input,
args_->grad_gate_lora_a, args_->grad_gate_lora_b, args_->grad_up_lora_a,
args_->grad_up_lora_b, args_->grad_down_lora_a, args_->grad_down_lora_b,
args_->grad_weights);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE_SFT<T>> moe, intptr_t grad_output,
intptr_t grad_input, intptr_t grad_gate_lora_a,
intptr_t grad_gate_lora_b, intptr_t grad_up_lora_a,
intptr_t grad_up_lora_b, intptr_t grad_down_lora_a,
intptr_t grad_down_lora_b, intptr_t grad_weights) {
Args* args = new Args{nullptr, moe.get(), grad_output, grad_input,
grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b,
grad_down_lora_a, grad_down_lora_b, grad_weights};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class UpdateLoRAWeightsBindings {
public:
struct Args {
CPUInfer* cpuinfer;
TP_MOE_SFT<T>* moe;
intptr_t gate_lora_a;
intptr_t gate_lora_b;
intptr_t up_lora_a;
intptr_t up_lora_b;
intptr_t down_lora_a;
intptr_t down_lora_b;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE_SFT<T>::update_lora_weights_binding, args_->moe, args_->gate_lora_a,
args_->gate_lora_b, args_->up_lora_a, args_->up_lora_b, args_->down_lora_a,
args_->down_lora_b);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE_SFT<T>> moe, intptr_t gate_lora_a,
intptr_t gate_lora_b, intptr_t up_lora_a,
intptr_t up_lora_b, intptr_t down_lora_a,
intptr_t down_lora_b) {
Args* args =
new Args{nullptr, moe.get(), gate_lora_a, gate_lora_b, up_lora_a, up_lora_b, down_lora_a, down_lora_b};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
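`BackwardBindings` hands the backward pass separate gradient buffers for the A and B factors of the gate, up, and down LoRA adapters (plus `grad_input` and `grad_weights`). As a reference for what such buffers accumulate, here is a minimal single-token, single-projection sketch, assuming the standard LoRA form y = W x + s * B (A x) with s = alpha / r and a frozen base weight W. The fused AMX kernels in sft_moe.hpp are not shown here, so their scaling convention and layout may differ; `LoRALinear` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference LoRA projection for one token: y = W x + s * B (A x).
// Row-major shapes: W [O x H], A [r x H], B [O x r]; s = alpha / r.
struct LoRALinear {
  std::size_t H, O, r;
  float s;
  std::vector<float> W, A, B;

  std::vector<float> forward(const std::vector<float>& x) const {
    assert(x.size() == H);
    std::vector<float> u(r, 0.0f), y(O, 0.0f);
    for (std::size_t j = 0; j < r; ++j)
      for (std::size_t h = 0; h < H; ++h) u[j] += A[j * H + h] * x[h];  // u = A x
    for (std::size_t o = 0; o < O; ++o) {
      for (std::size_t h = 0; h < H; ++h) y[o] += W[o * H + h] * x[h];
      for (std::size_t j = 0; j < r; ++j) y[o] += s * B[o * r + j] * u[j];
    }
    return y;
  }

  // Given dy = dL/dy, accumulate dA and dB and return dx. W is frozen, so no dW.
  std::vector<float> backward(const std::vector<float>& x, const std::vector<float>& dy,
                              std::vector<float>& dA, std::vector<float>& dB) const {
    std::vector<float> u(r, 0.0f), g(r, 0.0f), dx(H, 0.0f);
    for (std::size_t j = 0; j < r; ++j)
      for (std::size_t h = 0; h < H; ++h) u[j] += A[j * H + h] * x[h];
    for (std::size_t o = 0; o < O; ++o)
      for (std::size_t j = 0; j < r; ++j) {
        dB[o * r + j] += s * dy[o] * u[j];  // dL/dB = s * dy u^T
        g[j] += B[o * r + j] * dy[o];       // g = B^T dy
      }
    for (std::size_t j = 0; j < r; ++j)
      for (std::size_t h = 0; h < H; ++h) dA[j * H + h] += s * g[j] * x[h];  // dL/dA = s * g x^T
    for (std::size_t h = 0; h < H; ++h) {
      for (std::size_t o = 0; o < O; ++o) dx[h] += W[o * H + h] * dy[o];  // W^T dy
      for (std::size_t j = 0; j < r; ++j) dx[h] += s * A[j * H + h] * g[j];  // + s * A^T g
    }
    return dx;
  }
};
```

For H = 2, O = 1, r = 1, s = 2, W = {1, 0}, A = {1, 1}, B = {3} and x = {1, 2}, the forward output is 19; with dy = {1} the sketch gives dB = {6}, dA = {6, 12} and dx = {7, 6}.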
template <typename MoeSftTP>
void bind_moe_sft_module(py::module_& moe_module, const char* name) {
using MoeClass = TP_MOE_SFT<MoeSftTP>;
using MoeBindings = MOESFTBindings<MoeSftTP>;
py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name)
.def(py::init<MOESFTConfig>())
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task", &MoeBindings::LoadWeightsBindings::cpuinfer_interface)
.def("forward_sft_task", &MoeBindings::ForwardSFTBindings::cpuinfer_interface)
.def("backward_task", &MoeBindings::BackwardBindings::cpuinfer_interface)
.def("update_lora_weights_task", &MoeBindings::UpdateLoRAWeightsBindings::cpuinfer_interface)
.def("warm_up", &MoeClass::warm_up)
.def("load_weights", &MoeClass::load_weights)
.def("forward_sft", &MoeClass::forward_sft_binding)
.def("backward", &MoeClass::backward_binding)
.def("update_lora_weights", &MoeClass::update_lora_weights_binding)
.def("prepare_and_save_bwd",
[](MoeClass& self, intptr_t gate, intptr_t up, intptr_t down, const std::string& path) {
self.prepare_and_save_bwd((void*)gate, (void*)up, (void*)down, path);
})
.def("submit_backward_repack", &MoeClass::submit_backward_repack)
.def("wait_backward_repack", &MoeClass::wait_backward_repack);
}
#endif // defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
template <typename MoeTP>
void bind_moe_module(py::module_& moe_module, const char* name) {
using MoeClass = TP_MOE<MoeTP>;
using MoeBindings = MOEBindings<MoeTP>;
auto moe_cls = py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name);
moe_cls.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
py::overload_cast<std::shared_ptr<MoeClass>>(&MoeBindings::LoadWeightsBindings::cpuinfer_interface))
.def("load_weights_task",
py::overload_cast<std::shared_ptr<MoeClass>, const uintptr_t>(
&MoeBindings::LoadWeightsBindings::cpuinfer_interface),
py::arg("physical_to_logical_map"))
// .def("forward_task", &MoeBindings::ForwardBindings::cpuinfer_interface)
.def("forward_task",
py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t>(
&MoeBindings::ForwardBindings::cpuinfer_interface))
.def("forward_task",
py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t, bool>(
&MoeBindings::ForwardBindings::cpuinfer_interface))
.def("warm_up", &MoeClass::warm_up)
.def("load_weights", &MoeClass::load_weights)
.def("forward", &MoeClass::forward_binding);
// Bind write_weight_scale_to_buffer_task only for MoE types that provide it.
// A C++20 requires-expression checks whether MoeClass has a write_weight_scale_to_buffer member.
if constexpr (requires { &MoeClass::write_weight_scale_to_buffer; }) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int expert_id;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
std::vector<uintptr_t> w2_scale_ptrs;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
int expert_id, py::list w13_weight_ptrs,
py::list w13_scale_ptrs, py::list w2_weight_ptrs,
py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
}
PYBIND11_MODULE(kt_kernel_ext, m) {
py::class_<WorkerPool>(m, "WorkerPool").def(py::init<int>());
py::class_<WorkerPoolConfig>(m, "WorkerPoolConfig")
.def(py::init<>())
.def_readwrite("subpool_count", &WorkerPoolConfig::subpool_count)
.def_readwrite("subpool_numa_map", &WorkerPoolConfig::subpool_numa_map)
.def_readwrite("subpool_thread_count", &WorkerPoolConfig::subpool_thread_count);
py::class_<CPUInfer>(m, "CPUInfer")
.def(py::init<int>())
.def(py::init<WorkerPoolConfig>())
.def("submit", &CPUInfer::submit)
.def("sync", &CPUInfer::sync, py::arg("allow_n_pending") = 0)
.def_readwrite("backend_", &CPUInfer::backend_)
#ifndef KTRANSFORMERS_CPU_ONLY
.def("sync_with_cuda_stream", &CPUInfer::sync_with_cuda_stream, py::arg("user_cuda_stream"),
py::arg("allow_n_pending") = 0)
.def("submit_with_cuda_stream", &CPUInfer::submit_with_cuda_stream)
#endif
;
auto linear_module = m.def_submodule("linear");
py::class_<LinearConfig>(linear_module, "LinearConfig")
.def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t proj,
int proj_type, int hidden_type) {
return LinearConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)proj, (ggml_type)proj_type,
(ggml_type)hidden_type);
}));
// py::class_<Linear>(linear_module, "Linear")
// .def(py::init<LinearConfig>())
// .def("warm_up", &LinearBindings::WarmUpBindings::cpuinfer_interface)
// .def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface);
auto mlp_module = m.def_submodule("mlp");
py::class_<MLPConfig>(mlp_module, "MLPConfig")
.def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t gate_proj,
intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type,
int hidden_type) {
return MLPConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)gate_proj, (void*)up_proj,
(void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type,
(ggml_type)hidden_type);
}));
// py::class_<MLP>(mlp_module, "MLP")
// .def(py::init<MLPConfig>())
// .def("warm_up", &MLPBindings::WarmUpBindings::cpuinfer_interface)
// .def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface);
py::class_<GeneralConfig>(m, "GeneralConfig")
.def(py::init<>())
.def_readwrite("vocab_size", &GeneralConfig::vocab_size)
.def_readwrite("hidden_size", &GeneralConfig::hidden_size)
.def_readwrite("num_experts_per_tok", &GeneralConfig::num_experts_per_tok)
.def_readwrite("n_routed_experts", &GeneralConfig::n_routed_experts)
.def_readwrite("n_shared_experts", &GeneralConfig::n_shared_experts)
.def_readwrite("max_qlen", &GeneralConfig::max_qlen)
.DEF_PTR_PROPERTY(GeneralConfig, lm_heads_ptr)
.def_readwrite("lm_heads_type", &GeneralConfig::lm_heads_type)
.DEF_PTR_PROPERTY(GeneralConfig, norm_weights_ptr)
.def_readwrite("norm_weights_type", &GeneralConfig::norm_weights_type)
.DEF_PTR_PROPERTY(GeneralConfig, token_embd_ptr)
.def_readwrite("token_embd_type", &GeneralConfig::token_embd_type)
.def_readwrite("pool", &GeneralConfig::pool);
#if defined(__aarch64__) && defined(CPU_USE_KML) && defined(KTRANSFORMERS_CPU_MLA)
py::class_<DeepseekV3ForCausalLM, std::shared_ptr<DeepseekV3ForCausalLM>>(m, "DeepseekV3ForCausalLM")
.def(py::init([](GeneralConfig config) { return std::make_shared<DeepseekV3ForCausalLM>(config); }))
.def_readwrite("model", &DeepseekV3ForCausalLM::model)
.def("forward", &DeepseekV3ForCausalLM::forward_binding);
py::class_<DeepseekV3Model, std::shared_ptr<DeepseekV3Model>>(m, "DeepseekV3Model")
.def(py::init([](GeneralConfig config) { return std::make_shared<DeepseekV3Model>(config); }))
.def_readwrite("layers", &DeepseekV3Model::layers);
py::class_<DeepseekV3DecoderLayer, std::shared_ptr<DeepseekV3DecoderLayer>>(m, "DeepseekV3DecoderLayer")
.def(py::init([](GeneralConfig config, size_t layer_idx) {
return std::make_shared<DeepseekV3DecoderLayer>(config, layer_idx);
}))
.def("load_norm", &DeepseekV3DecoderLayer::load_norm_binding)
.def_readwrite("self_attn", &DeepseekV3DecoderLayer::self_attn)
.def_readwrite("gate", &DeepseekV3DecoderLayer::gate)
.def_readwrite("ffn", &DeepseekV3DecoderLayer::ffn);
#endif
auto mla_module = m.def_submodule("mla");
py::class_<GeneralMLAConfig>(mla_module, "MLAConfig")
.def(py::init([](size_t hidden_size, size_t q_lora_rank, size_t num_heads, size_t nope_size, size_t rope_size,
size_t kv_lora_rank) {
return GeneralMLAConfig(hidden_size, q_lora_rank, num_heads, nope_size, rope_size, kv_lora_rank);
}))
.def_readwrite("layer_idx", &GeneralMLAConfig::layer_idx)
.def_readwrite("pool", &GeneralMLAConfig::pool)
.def_readwrite("token_count_in_page", &GeneralMLAConfig::token_count_in_page)
.def_readwrite("max_qlen", &GeneralMLAConfig::max_qlen)
.def_readwrite("max_kvlen", &GeneralMLAConfig::max_kvlen)
.def_readwrite("max_position_embeddings", &GeneralMLAConfig::max_position_embeddings)
.def_readwrite("rope_scaling_factor", &GeneralMLAConfig::rope_scaling_factor)
.def_readwrite("rope_theta", &GeneralMLAConfig::rope_theta)
.def_readwrite("rope_scaling_beta_fast", &GeneralMLAConfig::rope_scaling_beta_fast)
.def_readwrite("rope_scaling_beta_slow", &GeneralMLAConfig::rope_scaling_beta_slow)
.def_readwrite("rope_scaling_mscale", &GeneralMLAConfig::rope_scaling_mscale)
.def_readwrite("rope_scaling_mscale_all_dim", &GeneralMLAConfig::rope_scaling_mscale_all_dim)
.def_readwrite("rope_scaling_original_max_position_embeddings",
&GeneralMLAConfig::rope_scaling_original_max_position_embeddings)
.DEF_PTR_PROPERTY(GeneralMLAConfig, q_a_proj)
.DEF_PTR_PROPERTY(GeneralMLAConfig, q_a_norm)
.DEF_PTR_PROPERTY(GeneralMLAConfig, q_b_proj)
.DEF_PTR_PROPERTY(GeneralMLAConfig, kv_a_proj_with_mqa)
.DEF_PTR_PROPERTY(GeneralMLAConfig, kv_a_norm)
.DEF_PTR_PROPERTY(GeneralMLAConfig, kv_b_proj)
.DEF_PTR_PROPERTY(GeneralMLAConfig, o_proj)
.def_readwrite("q_a_proj_type", &GeneralMLAConfig::q_a_proj_type)
.def_readwrite("q_a_norm_type", &GeneralMLAConfig::q_a_norm_type)
.def_readwrite("q_b_proj_type", &GeneralMLAConfig::q_b_proj_type)
.def_readwrite("kv_a_proj_with_mqa_type", &GeneralMLAConfig::kv_a_proj_with_mqa_type)
.def_readwrite("kv_a_norm_type", &GeneralMLAConfig::kv_a_norm_type)
.def_readwrite("kv_b_proj_type", &GeneralMLAConfig::kv_b_proj_type)
.def_readwrite("w_o_type", &GeneralMLAConfig::w_o_type)
.def_readwrite("page_count", &GeneralMLAConfig::page_count)
;
py::class_<MLA_Interface, std::shared_ptr<MLA_Interface>>(mla_module, "MLA_Interface");
#if defined(__aarch64__) && defined(CPU_USE_KML) && defined(KTRANSFORMERS_CPU_MLA)
py::class_<TP_MLA<KML_MLA_TP<float16_t>>, MLA_Interface, std::shared_ptr<TP_MLA<KML_MLA_TP<float16_t>>>>(mla_module,
"MLA_F16")
.def(py::init<GeneralMLAConfig>())
.def("load_weights", &TP_MLA<KML_MLA_TP<float16_t>>::load_weights)
.def("forward",
[](TP_MLA<KML_MLA_TP<float16_t>>& op, std::vector<int> qlens, std::vector<std::vector<int>> page_tables,
std::vector<int> kvlens, intptr_t input,
intptr_t output) { op.forward(qlens, page_tables, kvlens, (const void*)input, (void*)output); })
.def("set_local_pages", &TP_MLA<KML_MLA_TP<float16_t>>::set_local_pages)
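.def("set_pages", [](TP_MLA<KML_MLA_TP<float16_t>>& op, std::vector<std::vector<intptr_t>> nope_pages,
std::vector<std::vector<intptr_t>> rope_pages) {
// Convert the integer page addresses received from Python into void* tables.
std::vector<std::vector<void*>> nope_pages_ptr;
std::vector<std::vector<void*>> rope_pages_ptr;
for (const auto& row : nope_pages) {
nope_pages_ptr.emplace_back();
for (intptr_t addr : row) nope_pages_ptr.back().push_back(reinterpret_cast<void*>(addr));
}
for (const auto& row : rope_pages) {
rope_pages_ptr.emplace_back();
for (intptr_t addr : row) rope_pages_ptr.back().push_back(reinterpret_cast<void*>(addr));
}
op.set_pages(nope_pages_ptr, rope_pages_ptr);
});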
py::class_<TP_MLA<KML_MLA_TP<float>>, MLA_Interface, std::shared_ptr<TP_MLA<KML_MLA_TP<float>>>>(mla_module,
"MLA_F32")
.def(py::init<GeneralMLAConfig>())
.def("load_weights", &TP_MLA<KML_MLA_TP<float>>::load_weights)
.def("forward",
[](TP_MLA<KML_MLA_TP<float>>& op, std::vector<int> qlens, std::vector<std::vector<int>> page_tables,
std::vector<int> kvlens, intptr_t input,
intptr_t output) { op.forward(qlens, page_tables, kvlens, (const void*)input, (void*)output); })
.def("set_local_pages", &TP_MLA<KML_MLA_TP<float>>::set_local_pages)
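.def("set_pages", [](TP_MLA<KML_MLA_TP<float>>& op, std::vector<std::vector<intptr_t>> nope_pages,
std::vector<std::vector<intptr_t>> rope_pages) {
// Convert the integer page addresses received from Python into void* tables.
std::vector<std::vector<void*>> nope_pages_ptr;
std::vector<std::vector<void*>> rope_pages_ptr;
for (const auto& row : nope_pages) {
nope_pages_ptr.emplace_back();
for (intptr_t addr : row) nope_pages_ptr.back().push_back(reinterpret_cast<void*>(addr));
}
for (const auto& row : rope_pages) {
rope_pages_ptr.emplace_back();
for (intptr_t addr : row) rope_pages_ptr.back().push_back(reinterpret_cast<void*>(addr));
}
op.set_pages(nope_pages_ptr, rope_pages_ptr);
});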
py::class_<TP_MLA<KML_MLA_TP_QUAN<float>>, MLA_Interface, std::shared_ptr<TP_MLA<KML_MLA_TP_QUAN<float>>>>(
mla_module, "MLA_QUAN_F32")
.def(py::init<GeneralMLAConfig>())
.def("load_weights", &TP_MLA<KML_MLA_TP_QUAN<float>>::load_weights)
.def("forward",
[](TP_MLA<KML_MLA_TP_QUAN<float>>& op, std::vector<int> qlens, std::vector<std::vector<int>> page_tables,
std::vector<int> kvlens, intptr_t input,
intptr_t output) { op.forward(qlens, page_tables, kvlens, (const void*)input, (void*)output); })
.def("set_local_pages", &TP_MLA<KML_MLA_TP_QUAN<float>>::set_local_pages)
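.def("set_pages", [](TP_MLA<KML_MLA_TP_QUAN<float>>& op, std::vector<std::vector<intptr_t>> nope_pages,
std::vector<std::vector<intptr_t>> rope_pages) {
auto to_ptrs = [](const std::vector<std::vector<intptr_t>>& pages) {  // Python passes page addresses as integers
std::vector<std::vector<void*>> out(pages.size());
for (size_t i = 0; i < pages.size(); ++i)
for (intptr_t p : pages[i]) out[i].push_back(reinterpret_cast<void*>(p));
return out;
};
std::vector<std::vector<void*>> nope_pages_ptr = to_ptrs(nope_pages);
std::vector<std::vector<void*>> rope_pages_ptr = to_ptrs(rope_pages);
op.set_pages(nope_pages_ptr, rope_pages_ptr);
});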
auto gate_module = m.def_submodule("gate");
py::class_<GeneralGateConfig>(gate_module, "GateConfig")
.def(py::init([](int hidden_size, int num_experts_per_tok, int n_routed_experts, int n_group, int topk_group) {
return GeneralGateConfig(hidden_size, num_experts_per_tok, n_routed_experts, n_group, topk_group);
}))
.def_readwrite("routed_scaling_factor", &GeneralGateConfig::routed_scaling_factor)
.def_readwrite("layer_idx", &GeneralGateConfig::layer_idx)
.def_readwrite("pool", &GeneralGateConfig::pool)
.DEF_PTR_PROPERTY(GeneralGateConfig, weight)
.def_readwrite("weight_type", &GeneralGateConfig::weight_type)
.DEF_PTR_PROPERTY(GeneralGateConfig, e_score_correction_bias)
.def_readwrite("e_score_correction_bias_type", &GeneralGateConfig::e_score_correction_bias_type)
;
py::class_<MoEGate, std::shared_ptr<MoEGate>>(gate_module, "MoEGate")
.def(py::init<GeneralGateConfig>())
.def("forward", &MoEGate::forward_binding);
#endif
py::class_<QuantConfig>(m, "QuantConfig")
.def(py::init<>())
.def_readwrite("quant_method", &QuantConfig::quant_method)
.def_readwrite("bits", &QuantConfig::bits)
.def_readwrite("group_size", &QuantConfig::group_size)
.def_readwrite("zero_point", &QuantConfig::zero_point)
.def_readwrite("per_channel", &QuantConfig::per_channel);
auto moe_module = m.def_submodule("moe");
py::class_<GeneralMOEConfig>(moe_module, "MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) {
return GeneralMOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size);
}))
.def(py::init(
[](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int num_gpu_experts) {
GeneralMOEConfig cfg(expert_num, routed_expert_num, hidden_size, intermediate_size);
cfg.num_gpu_experts = num_gpu_experts;
return cfg;
}))
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size,
uintptr_t gpu_experts_mask_ptr) {
GeneralMOEConfig cfg(expert_num, routed_expert_num, hidden_size, intermediate_size);
cfg.gpu_experts_mask = reinterpret_cast<uint8_t*>(gpu_experts_mask_ptr);
cfg.compute_num_gpu_experts();
return cfg;
}))
// Core config fields (required for Python access after construction)
.def_readwrite("expert_num", &GeneralMOEConfig::expert_num)
.def_readwrite("num_experts_per_tok", &GeneralMOEConfig::num_experts_per_tok)
.def_readwrite("hidden_size", &GeneralMOEConfig::hidden_size)
.def_readwrite("intermediate_size", &GeneralMOEConfig::intermediate_size)
.def_readwrite("layer_idx", &GeneralMOEConfig::layer_idx)
.def_readwrite("pool", &GeneralMOEConfig::pool)
.def_readonly("num_gpu_experts", &GeneralMOEConfig::num_gpu_experts)
.def_property(
"gpu_experts_mask",
[](const GeneralMOEConfig& self) { return reinterpret_cast<uintptr_t>(self.gpu_experts_mask); },
[](GeneralMOEConfig& self, uintptr_t val) { self.gpu_experts_mask = reinterpret_cast<uint8_t*>(val); })
.DEF_PTR_PROPERTY(GeneralMOEConfig, physical_to_logical_map)
.DEF_PTR_PROPERTY(GeneralMOEConfig, gate_proj)
.DEF_PTR_PROPERTY(GeneralMOEConfig, up_proj)
.DEF_PTR_PROPERTY(GeneralMOEConfig, down_proj)
.DEF_PTR_PROPERTY(GeneralMOEConfig, gate_scale)
.DEF_PTR_PROPERTY(GeneralMOEConfig, up_scale)
.DEF_PTR_PROPERTY(GeneralMOEConfig, down_scale)
.DEF_PTR_PROPERTY(GeneralMOEConfig, gate_zero)
.DEF_PTR_PROPERTY(GeneralMOEConfig, up_zero)
.DEF_PTR_PROPERTY(GeneralMOEConfig, down_zero)
.def_readwrite("quant_config", &GeneralMOEConfig::quant_config)
.def_readwrite("max_len", &GeneralMOEConfig::max_len)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_projs)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_projs)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_projs)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_scales)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_scales)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_scales)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_zeros)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_zeros)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_zeros)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_bwd_projs)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_bwd_projs)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_bwd_projs)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_bwd_scales)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_bwd_scales)
.DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_bwd_scales)
.def_readwrite("path", &GeneralMOEConfig::path)
.def_readwrite("save", &GeneralMOEConfig::save)
.def_readwrite("load", &GeneralMOEConfig::load)
.def_readwrite("share_backward_bb", &GeneralMOEConfig::share_backward_bb)
.def_readwrite("share_cache_pool", &GeneralMOEConfig::share_cache_pool)
.def_readwrite("m_block", &GeneralMOEConfig::m_block)
.def_readwrite("group_min_len", &GeneralMOEConfig::group_min_len)
.def_readwrite("group_max_len", &GeneralMOEConfig::group_max_len)
.def_readwrite("gate_type", &GeneralMOEConfig::gate_type)
.def_readwrite("up_type", &GeneralMOEConfig::up_type)
.def_readwrite("down_type", &GeneralMOEConfig::down_type)
.def_readwrite("hidden_type", &GeneralMOEConfig::hidden_type)
.def_readwrite("max_cache_depth", &GeneralMOEConfig::max_cache_depth)
;
// MOESFTConfig - extends GeneralMOEConfig with LoRA support
py::class_<MOESFTConfig, GeneralMOEConfig>(moe_module, "MOESFTConfig")
.def(py::init<>())
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) {
return MOESFTConfig(expert_num, routed_expert_num, hidden_size, intermediate_size);
}))
.def_readwrite("lora_rank", &MOESFTConfig::lora_rank)
.def_readwrite("lora_alpha", &MOESFTConfig::lora_alpha)
.DEF_PTR_PROPERTY(MOESFTConfig, gate_lora_a)
.DEF_PTR_PROPERTY(MOESFTConfig, gate_lora_b)
.DEF_PTR_PROPERTY(MOESFTConfig, up_lora_a)
.DEF_PTR_PROPERTY(MOESFTConfig, up_lora_b)
.DEF_PTR_PROPERTY(MOESFTConfig, down_lora_a)
.DEF_PTR_PROPERTY(MOESFTConfig, down_lora_b);
py::class_<MoE_Interface, std::shared_ptr<MoE_Interface>>(moe_module, "MoE_Interface");
bind_moe_module<LLAMA_MOE_TP>(moe_module, "MOE");
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
#if defined(__AVX512F__)
bind_moe_module<AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>(moe_module, "AMXBF16_MOE");
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");
#endif
// SFT MoE with LoRA support (BF16, INT8, INT4; the INT4_1, AWQ and K2 variants are currently disabled)
bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224BF>>(moe_module, "AMXBF16_SFT_MOE");
bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_SFT_MOE");
bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_SFT_MOE");
// bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_SFT_MOE");
// bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup, AMX_AWQ_MOE_TP>>(moe_module,
// "AMXInt4_1KGroup_SFT_MOE");
// bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4SmallKGroup, AMX_K2_MOE_TP>>(moe_module,
// "AMXInt4_KGroup_SFT_MOE");
// SFT MoE with SkipLoRA=true: backward skips all LoRA computation and only computes grad_input through the base weights
bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224BF, AMX_MOE_TP, true>>(moe_module, "AMXBF16_SFT_MOE_SkipLoRA");
bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int8, AMX_MOE_TP, true>>(moe_module, "AMXInt8_SFT_MOE_SkipLoRA");
bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4, AMX_MOE_TP, true>>(moe_module, "AMXInt4_SFT_MOE_SkipLoRA");
// bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4_1, AMX_MOE_TP, true>>(moe_module,
// "AMXInt4_1_SFT_MOE_SkipLoRA");
// bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup, AMX_AWQ_MOE_TP, true>>(
// moe_module, "AMXInt4_1KGroup_SFT_MOE_SkipLoRA");
// bind_moe_sft_module<AMX_SFT_MOE_TP<amx::GemmKernel224Int4SmallKGroup, AMX_K2_MOE_TP, true>>(
// moe_module, "AMXInt4_KGroup_SFT_MOE_SkipLoRA");
#endif
// AVX2 backends: compiled on every x86_64 build (no AMX/AVX-512 requirement); the CPU still needs AVX2 at runtime
#if defined(__x86_64__)
bind_moe_module<AVX2_BF16_MOE_TP<avx2::GemmKernelAVX2BF16>>(moe_module, "AVX2BF16_MOE");
bind_moe_module<AVX2_FP8_MOE_TP<avx2::GemmKernelAVX2FP8>>(moe_module, "AVX2FP8_MOE");
bind_moe_module<AVX2_GPTQ_INT4_MOE_TP<avx2::GemmKernelAVX2GPTQInt4>>(moe_module, "AVX2GPTQInt4_MOE");
bind_moe_module<AVXVNNI256_GPTQ_INT4_MOE_TP<avxvnni::GemmKernelAVXVNNI256GPTQInt4>>(moe_module,
"AVXVNNI256GPTQInt4_MOE");
#endif
#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
#if defined(__aarch64__) && defined(CPU_USE_KML)
// The INT4 kernel is only bound on aarch64 with KML; the AMD (x86_64) implementation does not exist yet
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt4, _is_plain_>>(moe_module, "Int4_KERNEL_MOE");
#endif
#endif
// Expose kernel tiling/runtime parameters so Python can modify them at runtime
{
auto tiling_module = moe_module.def_submodule("tiling");
#if defined(USE_MOE_KERNEL)
tiling_module.def(
"get_int8",
[]() {
auto t = moe_kernel::GemmKernelInt8::get_tiling();
py::dict d;
d["n_block_up_gate"] = std::get<0>(t);
d["n_block_down"] = std::get<1>(t);
d["n_block"] = std::get<2>(t);
d["m_block"] = std::get<3>(t);
d["k_block"] = std::get<4>(t);
d["n_block_up_gate_prefi"] = std::get<5>(t);
d["n_block_down_prefi"] = std::get<6>(t);
return d;
},
"Get current tiling parameters for INT8 kernel");
tiling_module.def(
"set_int8",
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
int n_block_down_prefi) {
moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
},
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"), "Set tiling parameters for INT8 kernel");
tiling_module.def(
"get_int4",
[]() {
auto t = moe_kernel::GemmKernelInt4::get_tiling();
py::dict d;
d["n_block_up_gate"] = std::get<0>(t);
d["n_block_down"] = std::get<1>(t);
d["n_block"] = std::get<2>(t);
d["m_block"] = std::get<3>(t);
d["k_block"] = std::get<4>(t);
d["n_block_up_gate_prefi"] = std::get<5>(t);
d["n_block_down_prefi"] = std::get<6>(t);
return d;
},
"Get current tiling parameters for INT4 kernel");
tiling_module.def(
"set_int4",
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
int n_block_down_prefi) {
moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
},
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"), "Set tiling parameters for INT4 kernel");
// Convenience: apply the same tiling parameters to both the INT8 and INT4 kernels
tiling_module.def(
"set_all",
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
int n_block_down_prefi) {
moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
},
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"),
"Set tiling parameters for both INT8 and INT4 kernels");
#endif
}
auto kvcache_module = m.def_submodule("kvcache");
py::enum_<AnchorType>(kvcache_module, "AnchorType")
.value("FIXED", AnchorType::FIXED_ANCHOR)
.value("DYNAMIC", AnchorType::DYNAMIC)
.value("QUEST", AnchorType::QUEST)
.value("BLOCK_MAX", AnchorType::BLOCK_MAX)
.value("BLOCK_MEAN", AnchorType::BLOCK_MEAN);
py::enum_<ggml_type>(kvcache_module, "ggml_type")
.value("FP32", GGML_TYPE_F32)
.value("FP16", GGML_TYPE_F16)
.value("Q4_0", GGML_TYPE_Q4_0)
.value("Q4_1", GGML_TYPE_Q4_1)
.value("Q5_0", GGML_TYPE_Q5_0)
.value("Q5_1", GGML_TYPE_Q5_1)
.value("Q8_0", GGML_TYPE_Q8_0)
.value("Q8_1", GGML_TYPE_Q8_1)
.value("Q2_K", GGML_TYPE_Q2_K)
.value("Q3_K", GGML_TYPE_Q3_K)
.value("Q4_K", GGML_TYPE_Q4_K)
.value("Q5_K", GGML_TYPE_Q5_K)
.value("Q6_K", GGML_TYPE_Q6_K)
.value("Q8_K", GGML_TYPE_Q8_K)
.value("IQ2_XXS", GGML_TYPE_IQ2_XXS)
.value("IQ2_XS", GGML_TYPE_IQ2_XS)
.value("IQ3_XXS", GGML_TYPE_IQ3_XXS)
.value("IQ1_S", GGML_TYPE_IQ1_S)
.value("IQ4_NL", GGML_TYPE_IQ4_NL)
.value("IQ3_S", GGML_TYPE_IQ3_S)
.value("IQ2_S", GGML_TYPE_IQ2_S)
.value("IQ4_XS", GGML_TYPE_IQ4_XS)
.value("I8", GGML_TYPE_I8)
.value("I16", GGML_TYPE_I16)
.value("I32", GGML_TYPE_I32)
.value("I64", GGML_TYPE_I64)
.value("F64", GGML_TYPE_F64)
.value("IQ1_M", GGML_TYPE_IQ1_M)
.value("BF16", GGML_TYPE_BF16)
.export_values();
py::enum_<RetrievalType>(kvcache_module, "RetrievalType")
.value("LAYER", RetrievalType::LAYER)
.value("KVHEAD", RetrievalType::KVHEAD)
.value("QHEAD", RetrievalType::QHEAD);
py::class_<KVCacheConfig>(kvcache_module, "KVCacheConfig")
.def(py::init<int, int, int, int, int, int, AnchorType, ggml_type, RetrievalType, int, int, int, int, int, int>())
.def_readwrite("layer_num", &KVCacheConfig::layer_num)
.def_readwrite("kv_head_num", &KVCacheConfig::kv_head_num)
.def_readwrite("q_head_num", &KVCacheConfig::q_head_num)
.def_readwrite("head_dim", &KVCacheConfig::head_dim)
.def_readwrite("block_len", &KVCacheConfig::block_len)
.def_readwrite("anchor_num", &KVCacheConfig::anchor_num)
.def_readwrite("anchor_type", &KVCacheConfig::anchor_type)
.def_readwrite("kv_type", &KVCacheConfig::kv_type)
.def_readwrite("retrieval_type", &KVCacheConfig::retrieval_type)
.def_readwrite("layer_step", &KVCacheConfig::layer_step)
.def_readwrite("token_step", &KVCacheConfig::token_step)
.def_readwrite("layer_offset", &KVCacheConfig::layer_offset)
.def_readwrite("max_block_num", &KVCacheConfig::max_block_num)
.def_readwrite("max_batch_size", &KVCacheConfig::max_batch_size)
.def_readwrite("max_thread_num", &KVCacheConfig::max_thread_num);
py::class_<KVCache>(kvcache_module, "KVCache")
.def(py::init<KVCacheConfig>())
.def("get_cache_total_len", &KVCache::get_cache_total_len)
.def("update_cache_total_len",
[](KVCache& kvcache, int cache_total_len) { kvcache.update_cache_total_len(cache_total_len); });
auto utils = m.def_submodule("utils");
// Register type-conversion functions
utils.def("to_float", &to_float_ptr, "Convert tensor from any GGML type to float32", py::arg("input"),
py::arg("size"), py::arg("type"));
utils.def("from_float", &from_float_ptr, "Convert tensor from float32 to any GGML type", py::arg("input"),
py::arg("size"), py::arg("type"));
}
static void warmup_cpptrace() {
// Exercise cpptrace once up front so the first real call does not trigger lazy loading (malloc, etc.)
cpptrace::frame_ptr buffer[10];
(void)cpptrace::safe_generate_raw_trace(buffer, 10);
cpptrace::safe_object_frame frame{};
cpptrace::get_safe_object_frame(buffer[0], &frame);
}
static void crash_handler(int signo, siginfo_t* /*info*/, void* /*ucontext*/) {
const char* head = "=== crash: signal received ===\n";
write(STDERR_FILENO, head, std::strlen(head));
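// Note: generate_trace() is not async-signal-safe; this is best-effort diagnostics before exiting.
cpptrace::generate_trace().print();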
_exit(128 + signo);
}
__attribute__((constructor)) static void install_handlers() {
struct sigaction sa;
std::memset(&sa, 0, sizeof(sa));
sa.sa_sigaction = &crash_handler;
sa.sa_flags = SA_SIGINFO;
sigemptyset(&sa.sa_mask);
sigaction(SIGSEGV, &sa, nullptr);
sigaction(SIGABRT, &sa, nullptr);
}