Refactor KTMoEWrapper backend (#1587)

* universal backend for cpu inference
* expert defer
This commit is contained in:
Jiaqi Liao
2025-11-10 20:26:15 +08:00
committed by GitHub
parent 956d19d2d8
commit 9bc00e587b
9 changed files with 1704 additions and 661 deletions

42
kt-kernel/install.sh Executable file
View File

@@ -0,0 +1,42 @@
#!/usr/bin/env bash
# install.sh - build and install kt-kernel in AVX or AMX mode.
#
# Usage: ./install.sh [avx|amx]
#   avx - build with AVX2 instructions, AMX disabled
#   amx - build with AMX512 instructions, AMX enabled
#
# The mode selects the CPUINFER_* environment variables consumed by the
# build, then installs the package in editable mode via pip.

# -e: exit on error; -u: error on unset variables; -o pipefail: a pipeline
# fails if any command in it fails (plain `set -e` misses the last two).
set -euo pipefail

usage() {
  # Diagnostics go to stderr so they do not pollute captured stdout.
  echo "Usage: $0 [avx|amx]" >&2
  exit 1
}

if [ $# -ne 1 ]; then
  usage
fi

MODE="$1"

case "$MODE" in
  avx)
    export CPUINFER_CPU_INSTRUCT=AVX2
    export CPUINFER_ENABLE_AMX=OFF
    ;;
  amx)
    export CPUINFER_CPU_INSTRUCT=AMX512
    export CPUINFER_ENABLE_AMX=ON
    ;;
  *)
    echo "Error: unknown mode '$MODE'" >&2
    usage
    ;;
esac

# Common build settings shared by both modes.
export CPUINFER_BUILD_TYPE=Release
export CPUINFER_PARALLEL=8
export CPUINFER_VERBOSE=1

echo "Building in mode: $MODE"
echo "Environment:"
echo "  CPUINFER_CPU_INSTRUCT=$CPUINFER_CPU_INSTRUCT"
echo "  CPUINFER_ENABLE_AMX=$CPUINFER_ENABLE_AMX"
echo "  CPUINFER_BUILD_TYPE=$CPUINFER_BUILD_TYPE"
echo "  CPUINFER_PARALLEL=$CPUINFER_PARALLEL"
echo "  CPUINFER_VERBOSE=$CPUINFER_VERBOSE"

# Build and install the Python package in editable (development) mode.
pip install -e . -v

View File

@@ -315,21 +315,35 @@ class LLAMA_MOE_TP {
#endif
int activated_expert = 0;
for (int i = 0; i < k; i++) {
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
continue;
}
m_expert_id_map_[activated_expert] = expert_ids[i];
activated_expert++;
}
int nth = config_.intermediate_size / config_.m_block;
pool->do_work_stealing_job(
nth * k, nullptr,
[&](int task_id) {
int expert_idx = task_id / nth;
int64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth;
// Only process activated (CPU) experts; skip GPU experts entirely to keep buffers aligned.
if (activated_expert > 0) {
pool->do_work_stealing_job(
nth * activated_expert, nullptr,
[&](int task_id) {
int act_idx = task_id / nth;
int64_t expert_id = m_expert_id_map_[act_idx];
if (expert_id == -1) {
return;
}
int ith = task_id % nth;
void* gate_proj_ptr =
(uint8_t*)m_local_gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
config_.hidden_size * ggml_type_size((ggml_type)config_.gate_type) /
ggml_blck_size((ggml_type)config_.gate_type);
float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.m_block;
float* gate_output_ptr = s_gate_output_[act_idx] + ith * config_.m_block;
auto ok = llamafile_sgemm(config_.m_block, 1,
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_proj_ptr,
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_input_ptr,
@@ -340,33 +354,29 @@ class LLAMA_MOE_TP {
if (ok == false) [[unlikely]] {
throw std::runtime_error("llamafile not supported");
}
// printf("gate output: ");
// debug_f32(gate_output_ptr);
void* up_proj_ptr =
(uint8_t*)m_local_up_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
config_.hidden_size * ggml_type_size((ggml_type)config_.up_type) /
ggml_blck_size((ggml_type)config_.up_type);
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.m_block;
float* up_output_ptr = s_up_output_[act_idx] + ith * config_.m_block;
llamafile_sgemm(config_.m_block, 1, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type),
up_proj_ptr, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_input_ptr,
config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_output_ptr,
config_.m_block, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.up_type,
ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type, GGML_TYPE_F32,
GGML_PREC_DEFAULT);
// printf("up output: ");
// debug_f32(up_output_ptr);
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
s_intermediate_fp32_[act_idx][i] = act_fn(s_gate_output_[act_idx][i]) * s_up_output_[act_idx][i];
}
if (config_.m_block %
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) ==
0) {
float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.m_block;
float* intermediate_fp32_ptr = s_intermediate_fp32_[act_idx] + ith * config_.m_block;
void* down_input_ptr =
s_down_input_[expert_idx] +
s_down_input_[act_idx] +
ith * config_.m_block *
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
@@ -375,10 +385,11 @@ class LLAMA_MOE_TP {
}
},
nullptr);
}
if (config_.m_block % ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) !=
0) {
for (int i = 0; i < k; i++) {
for (int i = 0; i < activated_expert; i++) {
from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size,
ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
}
@@ -400,8 +411,11 @@ class LLAMA_MOE_TP {
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
output[i] = 0;
}
for (int expert_idx = 0; expert_idx < k; expert_idx++) {
int64_t expert_id = expert_ids[expert_idx];
for (int expert_idx = 0; expert_idx < activated_expert; expert_idx++) {
int64_t expert_id = m_expert_id_map_[expert_idx];
if (expert_id == -1) {
continue;
}
auto expert_offset = expert_id * config_.hidden_size * config_.intermediate_size;
auto m_block_offset = ith * config_.m_block * config_.intermediate_size;
@@ -418,8 +432,16 @@ class LLAMA_MOE_TP {
ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type, GGML_TYPE_F32,
GGML_PREC_DEFAULT);
float expert_weight = 0.0f;
for (int j = 0; j < k; j++) {
if (expert_ids[j] == expert_id) {
expert_weight = weights[j];
break;
}
}
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
output[i] += s_down_output_[expert_idx][i] * weights[expert_idx];
output[i] += s_down_output_[expert_idx][i] * expert_weight;
}
}
},
@@ -452,6 +474,12 @@ class LLAMA_MOE_TP {
}
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
if (expert_ids[i * k + j] == -1) {
continue;
}
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
@@ -539,6 +567,12 @@ class LLAMA_MOE_TP {
}
}
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
if (expert_ids[i * k + j] == -1) {
continue;
}
memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size *
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /
@@ -683,6 +717,12 @@ class LLAMA_MOE_TP {
m_output_fp32_[i][e] = 0;
}
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
if (expert_ids[i * k + j] == -1) {
continue;
}
for (int e = 0; e < config_.hidden_size; e++) {
m_output_fp32_[i][e] +=
m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] *
@@ -739,24 +779,38 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
void load_weights() {
auto pool = this->config.pool;
auto inter = this->config.intermediate_size / this->tp_count;
pool->dispense_backend()->do_numa_job([this, pool, inter](int tp_id) {
this->tps[tp_id]->load_weights(this->config.intermediate_size, tp_id * inter);
std::vector<int> tp_offsets(this->tp_count);
int accumulated_offset = 0;
for (int i = 0; i < this->tp_count; i++) {
tp_offsets[i] = accumulated_offset;
accumulated_offset += this->tp_configs[i].intermediate_size;
}
pool->dispense_backend()->do_numa_job([this, pool, tp_offsets](int tp_id) {
this->tps[tp_id]->load_weights(this->config.intermediate_size, tp_offsets[tp_id]);
});
this->weights_loaded = true;
}
void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
void merge_results(int qlen, void* output, bool incremental) {
if (incremental) {
throw std::runtime_error("Not Implemented");
}
void merge_results(int qlen, void *output, bool incremental) {
auto pool = this->config.pool;
pool->do_work_stealing_job(
qlen, nullptr,
[this, output](int token_nth) {
auto& tp_count = this->tp_count;
[this, output, incremental](int token_nth) {
if (incremental) {
to_float((uint8_t *)output + token_nth * config.hidden_size *
ggml_type_size((ggml_type)config.hidden_type) /
ggml_blck_size((ggml_type)config.hidden_type),
local_output + token_nth * config.hidden_size, config.hidden_size, (ggml_type)config.hidden_type);
for (int e = 0; e < config.hidden_size; e++) {
local_output_numa[0][token_nth * config.hidden_size + e] +=
local_output[token_nth * config.hidden_size + e];
}
}
auto &tp_count = this->tp_count;
for (int i = 1; i < tp_count; i++) {
for (int e = 0; e < config.hidden_size; e++) {
local_output_numa[0][token_nth * config.hidden_size + e] +=
@@ -764,11 +818,12 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
}
}
from_float(local_output_numa[0] + token_nth * config.hidden_size,
(uint8_t*)output + token_nth * config.hidden_size * ggml_type_size((ggml_type)config.hidden_type) /
ggml_blck_size((ggml_type)config.hidden_type),
(uint8_t *)output + token_nth * config.hidden_size *
ggml_type_size((ggml_type)config.hidden_type) /
ggml_blck_size((ggml_type)config.hidden_type),
config.hidden_size, (ggml_type)config.hidden_type);
},
nullptr);
}
};
#endif
#endif

View File

@@ -5,9 +5,12 @@
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include "common.hpp"
// Forward declaration for Llamafile backend type checking
class LLAMA_MOE_TP;
template <typename T>
concept MOE_TP_PART = requires(T t, int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output, GeneralMOEConfig config, int tp_idx) {
@@ -26,6 +29,7 @@ class TP_MOE_Common : public MoE_Interface {
std::vector<std::unique_ptr<T>> tps;
std::vector<typename T::output_t*> local_output_numa;
T::output_t *local_output = nullptr;
bool weights_loaded = false;
@@ -53,11 +57,65 @@ class TP_MOE_Common : public MoE_Interface {
"multiple of NUMA node count");
}
for (auto i = 0; i < tp_count; i++) {
tps.push_back(nullptr);
GeneralMOEConfig tp_config = config;
tp_config.intermediate_size /= tp_count;
tp_configs.push_back(tp_config);
// Check if this is Llamafile backend using compile-time type checking
constexpr bool is_llamafile = std::is_same<T, LLAMA_MOE_TP>::value;
#ifndef QK_K
#define QK_K 256
#endif
if (is_llamafile) {
// For Llamafile backend: use QK_K-aligned TP splitting
if (config.intermediate_size % QK_K != 0) {
printf("intermediate_size %d must be divisible by QK_K %d for Llamafile backend\n",
config.intermediate_size, QK_K);
throw std::runtime_error("intermediate_size must be divisible by QK_K (256) for Llamafile backend");
}
int num_blocks = config.intermediate_size / QK_K;
int base_blocks = num_blocks / tp_count;
int extra_blocks = num_blocks % tp_count;
if (base_blocks == 0) {
printf("intermediate_size %d is too small for tp_count %d (num_blocks=%d)\n",
config.intermediate_size, tp_count, num_blocks);
throw std::runtime_error("intermediate_size too small: cannot distribute blocks to all TP instances");
}
printf("Llamafile TP splitting: intermediate_size=%d, tp_count=%d, QK_K=%d\n",
config.intermediate_size, tp_count, QK_K);
printf(" num_blocks=%d, base_blocks=%d, extra_blocks=%d\n", num_blocks, base_blocks, extra_blocks);
int current_offset = 0;
for (auto i = 0; i < tp_count; i++) {
tps.push_back(nullptr);
GeneralMOEConfig tp_config = config;
// First extra_blocks TPs get one more block
int num_blocks_for_this_tp = base_blocks + (i < extra_blocks ? 1 : 0);
tp_config.intermediate_size = num_blocks_for_this_tp * QK_K;
printf(" TP %d: intermediate_size=%d, offset=%d, blocks=%d\n",
i, tp_config.intermediate_size, current_offset, num_blocks_for_this_tp);
tp_configs.push_back(tp_config);
current_offset += tp_config.intermediate_size;
}
} else {
// For non-Llamafile backends: use simple equal division
if (config.intermediate_size % tp_count != 0) {
printf("intermediate_size %d, tp count %d\n", config.intermediate_size, tp_count);
throw std::runtime_error(
"For TP, intermediate_size must be a "
"multiple of NUMA node count");
}
for (auto i = 0; i < tp_count; i++) {
tps.push_back(nullptr);
GeneralMOEConfig tp_config = config;
tp_config.intermediate_size /= tp_count;
tp_configs.push_back(tp_config);
}
}
config.pool->dispense_backend()->do_numa_job(
@@ -70,6 +128,8 @@ class TP_MOE_Common : public MoE_Interface {
&local_output_numa[i],
(size_t)sizeof(typename T::output_t) * tp_configs[i].max_possible_qlen() * tp_configs[i].hidden_size);
}
mem_requests.append_pointer((void **)&local_output, sizeof(typename T::output_t) * tp_configs[0].max_possible_qlen() *
tp_configs[0].hidden_size);
// printf("local output tp, %d,\n", tp_configs[0].max_possible_qlen());
shared_mem_buffer.alloc(this, mem_requests);
}
@@ -144,7 +204,9 @@ class TP_MOE_Common : public MoE_Interface {
virtual void load_weights() = 0;
virtual void merge_results(int qlen, void* output) = 0;
virtual void merge_results(int qlen, void* output, bool incremental) {
if (incremental == false) {
merge_results(qlen, output);

View File

@@ -23,7 +23,7 @@ Example usage:
from __future__ import annotations
from .experts import AMXMoEWrapper
from .experts import KTMoEWrapper
__version__ = "0.1.0"
__all__ = ["AMXMoEWrapper"]
__all__ = ["KTMoEWrapper"]

View File

@@ -1,220 +1,51 @@
# Wrapper for AMX MoE CPU inference operations
# Wrapper for MoE CPU inference operations
# This module encapsulates CPU inference engine, weight loading, and buffer management
# SPDX-License-Identifier: Apache-2.0
"""
Expert wrappers for CPU-based MoE inference.
This module provides high-level Python wrappers around the low-level C++ kernel
implementations, handling weight loading, buffer management, and forward inference.
This module provides the main factory interface (KTMoEWrapper) that automatically
selects the appropriate backend implementation based on the method parameter.
"""
from __future__ import annotations
import torch
from typing import Dict, List, Optional, Tuple
from safetensors import safe_open
import os
import ctypes
from typing import List, Optional
# Import the C++ extension module (compiled as kt_kernel_ext)
import kt_kernel_ext
from kt_kernel_ext.moe import MOEConfig, AMXInt4_MOE, AMXInt8_MOE
# Import base infrastructure
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
# Import backend implementations
from .utils.amx import AMXMoEWrapper
from .utils.llamafile import LlamafileMoEWrapper
class SafeTensorLoader:
tensor_file_map: dict
tensor_type_map: dict
file_handle_map: dict
tensor_device_map: dict
class KTMoEWrapper:
"""
Factory interface for MoE CPU inference operations.
def __init__(self, file_path: str):
self.__load_tensor_file_map(file_path)
This class serves as the main entry point for external code. It automatically
selects the appropriate backend implementation based on the `method` parameter.
def __load_tensor_file_map(self, file_path: str):
if not os.path.exists(file_path):
raise FileNotFoundError(f"Path not found: {file_path}")
if os.path.isfile(file_path):
folder_path = os.path.dirname(file_path)
else:
folder_path = file_path
self.file_handle_map = {}
self.tensor_file_map = {}
self.tensor_type_map = {}
self.tensor_device_map = {}
found_safetensor = False
for root, _, files in os.walk(folder_path):
files = sorted(files)
for file in files:
if file.endswith(".safetensors"):
found_safetensor = True
file_path = os.path.join(root, file)
if file not in self.file_handle_map:
try:
handle = safe_open(file_path, framework="pt")
self.file_handle_map[file] = handle
except Exception as e:
print(f"Error opening Safetensor file {file_path}: {e}")
continue
f = self.file_handle_map.get(file)
if f is None:
continue
try:
for key in f.keys():
self.tensor_file_map[key] = file
except Exception as e:
print(f"Error reading Safetensor file {file_path}: {e}")
if not found_safetensor:
raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
return tensor.to(device)
def close_all_handles(self):
for handle in self.file_handle_map.values():
handle.close()
self.file_handle_map.clear()
def load_experts(self, base_key: str, device: str = "cpu"):
# base_key: blk.{layer_index}
# blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.weight
up_base_key = f"{base_key}.ffn_up_exps"
gate_base_key = f"{base_key}.ffn_gate_exps"
down_base_key = f"{base_key}.ffn_down_exps"
max_numa_id = -1
max_experts_count = -1
while self.has_tensor(f"{up_base_key}.{max_experts_count+1}.numa.{0}.weight"):
max_experts_count += 1
if max_experts_count == 0:
raise ValueError(f"No experts found for key {base_key}")
while self.has_tensor(f"{up_base_key}.{0}.numa.{max_numa_id+1}.weight"):
max_numa_id += 1
# Initialize empty lists to store tensors for each projection type
up_weights = [[] for _ in range(max_numa_id + 1)]
gate_weights = [[] for _ in range(max_numa_id + 1)]
down_weights = [[] for _ in range(max_numa_id + 1)]
up_scales = [[] for _ in range(max_numa_id + 1)]
gate_scales = [[] for _ in range(max_numa_id + 1)]
down_scales = [[] for _ in range(max_numa_id + 1)]
for numa_id in range(max_numa_id + 1):
for expert_id in range(max_experts_count + 1):
up_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.weight"
gate_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.weight"
down_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.weight"
up_scale_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.scale"
gate_scale_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.scale"
down_scale_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.scale"
# make sure contiguous
up_tensor = self.load_tensor(up_key, device).numpy()
gate_tensor = self.load_tensor(gate_key, device).numpy()
down_tensor = self.load_tensor(down_key, device).numpy()
up_scale_tensor = self.load_tensor(up_scale_key, device).numpy()
gate_scale_tensor = self.load_tensor(gate_scale_key, device).numpy()
down_scale_tensor = self.load_tensor(down_scale_key, device).numpy()
up_weights[numa_id].append(up_tensor)
gate_weights[numa_id].append(gate_tensor)
down_weights[numa_id].append(down_tensor)
up_scales[numa_id].append(up_scale_tensor)
gate_scales[numa_id].append(gate_scale_tensor)
down_scales[numa_id].append(down_scale_tensor)
return {
"up": up_weights,
"gate": gate_weights,
"down": down_weights,
"up_scale": up_scales,
"gate_scale": gate_scales,
"down_scale": down_scales,
}
def has_tensor(self, name: str):
return name in self.tensor_file_map
class KExpertsCPUBuffer:
capture_bs: List = list()
capture_buffers: Dict = dict()
temp_bs: int = 0
temp_buffer: tuple = tuple()
buffer_depth: int = 2
@classmethod
def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
hidden_size = hidden_states.shape[-1]
batch_size = hidden_states.shape[0]
if batch_size in cls.capture_buffers:
return cls.capture_buffers[batch_size]
if batch_size == cls.temp_bs:
return cls.temp_buffer
input_tensor_cpu = [
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
for _ in range(cls.buffer_depth)
]
immediate_experts_ids_cpu = [
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
for _ in range(cls.buffer_depth)
]
deferred_experts_ids_cpu = [
torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True)
for _ in range(cls.buffer_depth)
]
weights_cpu = [
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
for _ in range(cls.buffer_depth)
]
output_cpu = [
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
for _ in range(cls.buffer_depth)
]
bsz_tensor_cpu = [
torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True)
for _ in range(cls.buffer_depth)
]
output_gpu = [
torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
for _ in range(cls.buffer_depth)
]
cur_buffer = (
input_tensor_cpu,
immediate_experts_ids_cpu,
deferred_experts_ids_cpu,
weights_cpu,
output_cpu,
bsz_tensor_cpu,
output_gpu,
Usage:
wrapper = KTMoEWrapper(
layer_idx=0,
num_experts=8,
num_experts_per_tok=2,
hidden_size=4096,
moe_intermediate_size=14336,
num_gpu_experts=2,
cpuinfer_threads=32,
threadpool_count=2,
weight_path="/path/to/weights",
chunked_prefill_size=512,
method="AMXINT4" # or "AMXINT8", "LLAMAFILE"
)
if batch_size in cls.capture_bs:
cls.capture_buffers[batch_size] = cur_buffer
cls.temp_bs = batch_size
cls.temp_buffer = cur_buffer
return cur_buffer
class AMXMoEWrapper:
"""
Wrapper for AMX MoE CPU inference operations.
Manages CPU inference engine, weight loading, and buffer management.
"""
_cpu_infer_instance = None
_safetensor_loader_instance = None
_layer_has_pending_deferred: Dict[int, bool] = {}
def __init__(
self,
def __new__(
cls,
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
@@ -223,14 +54,14 @@ class AMXMoEWrapper:
num_gpu_experts: int,
cpuinfer_threads: int,
threadpool_count: int,
amx_weight_path: str,
weight_path: str,
chunked_prefill_size: int,
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
amx_method: str = "AMXINT4",
method: str = "AMXINT4",
):
"""
Initialize AMX MoE Wrapper.
Factory method to create the appropriate backend implementation.
Args:
layer_idx: Layer index
@@ -241,425 +72,41 @@ class AMXMoEWrapper:
num_gpu_experts: Number of experts to run on GPU
cpuinfer_threads: Number of CPU inference threads
threadpool_count: Number of NUMA subpools
amx_weight_path: Path to AMX weights
weight_path: Path to weights
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
amx_method: AMX quantization method ("AMXINT4" or "AMXINT8")
"""
self.layer_idx = layer_idx
self.num_experts = num_experts
self.num_experts_per_tok = num_experts_per_tok
self.hidden_size = hidden_size
self.moe_intermediate_size = moe_intermediate_size
self.num_gpu_experts = num_gpu_experts
self.amx_weight_path = amx_weight_path
self.chunked_prefill_size = chunked_prefill_size
self.cpu_save = cpu_save
self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
self.amx_method = amx_method
# Initialize CPU inference engine (singleton)
if AMXMoEWrapper._cpu_infer_instance is None:
worker_config = kt_kernel_ext.WorkerPoolConfig()
subpool_numa_map = list(range(threadpool_count))
subpool_thread_count = [
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
for i in range(threadpool_count)
]
worker_config.subpool_count = threadpool_count
worker_config.subpool_numa_map = subpool_numa_map
worker_config.subpool_thread_count = subpool_thread_count
AMXMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)
self.cpu_infer = AMXMoEWrapper._cpu_infer_instance
# Check if we should load merged safetensor weights
self.load_merged_weight = False
import glob
if glob.glob(os.path.join(amx_weight_path, "*.safetensors")):
self.load_merged_weight = True
# Initialize SafeTensor loader (singleton)
if self.load_merged_weight:
if AMXMoEWrapper._safetensor_loader_instance is None:
AMXMoEWrapper._safetensor_loader_instance = SafeTensorLoader(amx_weight_path)
self.safetensor_loader = AMXMoEWrapper._safetensor_loader_instance
self.moe = None
self.gate_weights = None
self.up_weights = None
self.down_weights = None
self.gate_scales = None
self.up_scales = None
self.down_scales = None
def load_weights_from_tensors(
self,
gate_proj: torch.Tensor,
up_proj: torch.Tensor,
down_proj: torch.Tensor,
physical_to_logical_map_cpu: torch.Tensor,
):
"""
Load and quantize weights from BF16/FP16 tensors (online quantization).
Args:
gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
"""
# Store tensors as instance variables to keep them alive
self.gate_proj = gate_proj.contiguous()
self.up_proj = up_proj.contiguous()
self.down_proj = down_proj.contiguous()
# Configure MoE with online quantization (cpu_save mode)
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
# Enable save mode for online quantization
moe_config.save = True
moe_config.load = False
# Set weight pointers
moe_config.gate_proj = self.gate_proj.data_ptr()
moe_config.up_proj = self.up_proj.data_ptr()
moe_config.down_proj = self.down_proj.data_ptr()
# Set output path for quantized weights
moe_config.path = self.amx_weight_path
# Create MoE module based on AMX method
if self.amx_method == "AMXINT4":
self.moe = AMXInt4_MOE(moe_config)
elif self.amx_method == "AMXINT8":
self.moe = AMXInt8_MOE(moe_config)
else:
raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")
# Submit quantization and save task
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
"""
Load weights for this layer and initialize the MoE module.
Args:
physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
"""
gate_ptr = 0
up_ptr = 0
down_ptr = 0
gate_ptrs = []
up_ptrs = []
down_ptrs = []
gate_scale_ptrs = []
up_scale_ptrs = []
down_scale_ptrs = []
if self.load_merged_weight:
base_key = f"blk.{self.layer_idx}"
w = self.safetensor_loader.load_experts(base_key)
self.gate_weights = w["gate"]
self.up_weights = w["up"]
self.down_weights = w["down"]
self.gate_scales = w["gate_scale"]
self.up_scales = w["up_scale"]
self.down_scales = w["down_scale"]
# Get pointers to weight arrays
gate_ptrs = [
[
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
for et in numa_array
]
for numa_array in self.gate_weights
]
up_ptrs = [
[
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
for et in numa_array
]
for numa_array in self.up_weights
]
down_ptrs = [
[
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
for et in numa_array
]
for numa_array in self.down_weights
]
gate_scale_ptrs = [
[
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
for et in numa_array
]
for numa_array in self.gate_scales
]
up_scale_ptrs = [
[
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
for et in numa_array
]
for numa_array in self.up_scales
]
down_scale_ptrs = [
[
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
for et in numa_array
]
for numa_array in self.down_scales
]
# Configure MoE
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
moe_config.gate_proj = gate_ptr
moe_config.up_proj = up_ptr
moe_config.down_proj = down_ptr
moe_config.gate_projs = gate_ptrs
moe_config.up_projs = up_ptrs
moe_config.down_projs = down_ptrs
moe_config.gate_scales = gate_scale_ptrs
moe_config.up_scales = up_scale_ptrs
moe_config.down_scales = down_scale_ptrs
if self.cpu_save:
moe_config.save = True
moe_config.load = False
base_key = f"model.layers.{self.layer_idx}"
w = self.safetensor_loader.load_experts(base_key)
self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous()
moe_config.gate_proj = self.gate_proj.data_ptr()
moe_config.up_proj = self.up_proj.data_ptr()
moe_config.down_proj = self.down_proj.data_ptr()
else:
moe_config.load = True
if not self.load_merged_weight:
moe_config.path = self.amx_weight_path
# Create MoE module based on AMX method
if self.amx_method == "AMXINT4":
self.moe = AMXInt4_MOE(moe_config)
elif self.amx_method == "AMXINT8":
self.moe = AMXInt8_MOE(moe_config)
else:
raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")
# Load weights
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
# Clean up temporary weight storage if using merged weights
if self.load_merged_weight:
del self.gate_weights
del self.up_weights
del self.down_weights
del self.gate_scales
del self.up_scales
del self.down_scales
def select_deferred_experts(
self,
expert_ids: torch.Tensor,
expert_scores: torch.Tensor,
protected_k: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, topk = expert_ids.shape
device = expert_ids.device
protected_k = max(0, min(int(protected_k), topk))
if protected_k == 0:
deferred_ids = expert_ids.clone()
immediate_ids = torch.full_like(expert_ids, -1)
return immediate_ids, deferred_ids
topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)
protected_indices = topk_result.indices
protected_ids = torch.gather(expert_ids, -1, protected_indices)
protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)
protected_flag.scatter_(0, protected_ids.reshape(-1), 1)
protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)
protected_mask = protected_mask_flat.view(batch, topk)
immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)
deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)
return immediate_ids, deferred_ids
def submit_forward(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
cuda_stream,
):
"""
Submit forward inference task to CPU (non-blocking).
Args:
hidden_states: Input hidden states [batch_size, hidden_size]
topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
cuda_stream: CUDA stream for synchronization
"""
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
batch_size = flat_hidden_states.shape[0]
(
input_tensor_cpu,
immediate_experts_ids_cpu,
deferred_experts_ids_cpu,
weights_cpu,
output_cpu,
bsz_tensor_cpu,
_output_gpu,
) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth
bsz_slot_tensor = bsz_tensor_cpu[current_slot]
bsz_slot_tensor.fill_(batch_size)
deferred_experts_ids_cpu[current_slot].fill_(-1)
topk_ids_long = topk_ids.to(torch.long)
immediate_ids: torch.Tensor
deferred_ids: Optional[torch.Tensor]
if self.max_deferred_experts_per_token > 0:
protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token
immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
else:
immediate_ids = topk_ids_long
deferred_ids = None
input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)
incremental = AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
self.cpu_infer.submit_with_cuda_stream(
cuda_stream,
self.moe.forward_task(
bsz_slot_tensor.data_ptr(),
immediate_experts_ids_cpu[current_slot].size(-1),
immediate_experts_ids_cpu[current_slot].data_ptr(),
weights_cpu[current_slot].data_ptr(),
input_tensor_cpu[current_slot].data_ptr(),
output_cpu[current_slot].data_ptr(),
incremental,
),
)
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
if deferred_ids is not None:
deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(
cuda_stream,
self.moe.forward_task(
bsz_slot_tensor.data_ptr(),
deferred_experts_ids_cpu[current_slot].size(-1),
deferred_experts_ids_cpu[current_slot].data_ptr(),
weights_cpu[current_slot].data_ptr(),
input_tensor_cpu[current_slot].data_ptr(),
output_cpu[next_slot].data_ptr(),
False,
),
)
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True
def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
"""
Synchronize and retrieve forward inference results.
Args:
hidden_states: Original input hidden states (for getting buffer)
cuda_stream: CUDA stream for synchronization
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE")
Returns:
output_gpu: Output tensor on GPU
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
"""
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
(
input_tensor_cpu,
immediate_experts_ids_cpu,
_deferred_experts_ids_cpu,
weights_cpu,
output_cpu,
_bsz_tensor_cpu,
output_gpu,
) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
# Select backend based on method
if method in ["AMXINT4", "AMXINT8"]:
backend_cls = AMXMoEWrapper
elif method == "LLAMAFILE":
backend_cls = LlamafileMoEWrapper
else:
raise NotImplementedError(f"Unsupported method: {method}")
current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
allow_pending = 1 if AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
return output_gpu[current_slot]
def forward(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
cuda_stream,
) -> torch.Tensor:
"""
Execute forward inference synchronously (submit + sync).
Args:
hidden_states: Input hidden states [batch_size, hidden_size]
topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
cuda_stream: CUDA stream for synchronization
Returns:
Output tensor on GPU
"""
self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)
return self.sync_forward(hidden_states, cuda_stream)
# Create and return backend instance
return backend_cls(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
)
# Forward static methods to the base class
@staticmethod
def set_capture_batch_sizes(capture_bs: List[int]):
"""
@@ -670,11 +117,8 @@ class AMXMoEWrapper:
Args:
capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])
Example:
>>> AMXMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])
"""
KExpertsCPUBuffer.capture_bs = capture_bs
BaseMoEWrapper.set_capture_batch_sizes(capture_bs)
@staticmethod
def get_capture_batch_sizes() -> List[int]:
@@ -684,7 +128,7 @@ class AMXMoEWrapper:
Returns:
List of batch sizes that are being captured
"""
return KExpertsCPUBuffer.capture_bs
return BaseMoEWrapper.get_capture_batch_sizes()
@staticmethod
def clear_buffer_cache():
@@ -694,6 +138,4 @@ class AMXMoEWrapper:
This frees up memory by clearing the buffer cache. Useful when you want
to reset the buffer state or free memory.
"""
KExpertsCPUBuffer.capture_buffers.clear()
KExpertsCPUBuffer.temp_bs = 0
KExpertsCPUBuffer.temp_buffer = tuple()
BaseMoEWrapper.clear_buffer_cache()

View File

@@ -0,0 +1,394 @@
# Base classes for MoE CPU inference operations
# SPDX-License-Identifier: Apache-2.0
"""
Base infrastructure for CPU-based MoE inference.
This module contains base classes and utilities shared across all backend implementations.
"""
from __future__ import annotations
import torch
from typing import Dict, List, Optional, Tuple
from abc import ABC, abstractmethod
import os
import ctypes
import kt_kernel_ext
class KExpertsCPUBuffer:
    """
    CPU buffer management for expert computation.
    Manages pinned memory buffers for efficient GPU-CPU data transfer.
    """

    # Batch sizes whose buffers are pre-allocated and cached permanently.
    capture_bs: List = list()
    # batch_size -> cached buffer tuple, for batch sizes listed in capture_bs.
    capture_buffers: Dict = dict()
    # Batch size of the single most-recently allocated non-captured buffer.
    temp_bs: int = 0
    # The most-recently allocated non-captured buffer tuple (evicted whenever
    # a new, different batch size is requested).
    temp_buffer: tuple = tuple()
    # Number of pipeline slots per buffer; layers index slots by
    # layer_idx % buffer_depth so consecutive layers alternate slots.
    buffer_depth: int = 2

    @classmethod
    def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
        """Return (and cache) the buffer set for this batch size.

        Returns a 7-tuple of per-slot tensor lists (each list has
        ``buffer_depth`` entries):
        (input_tensor_cpu, immediate_experts_ids_cpu, deferred_experts_ids_cpu,
         weights_cpu, output_cpu, bsz_tensor_cpu, output_gpu).
        Buffers are keyed by batch size only; hidden_size and
        num_experts_per_tok are assumed constant across calls.
        """
        hidden_size = hidden_states.shape[-1]
        batch_size = hidden_states.shape[0]
        # Fast paths: a captured buffer, or the last temporary buffer.
        if batch_size in cls.capture_buffers:
            return cls.capture_buffers[batch_size]
        if batch_size == cls.temp_bs:
            return cls.temp_buffer
        # CPU-side tensors are pinned so copies against a CUDA stream can be
        # issued with non_blocking=True.
        input_tensor_cpu = [
            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
            for _ in range(cls.buffer_depth)
        ]
        immediate_experts_ids_cpu = [
            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        # Deferred expert ids default to -1 ("no expert" sentinel).
        deferred_experts_ids_cpu = [
            torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        weights_cpu = [
            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        output_cpu = [
            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
            for _ in range(cls.buffer_depth)
        ]
        # Scalar batch-size tensor handed to the C++ side by address.
        bsz_tensor_cpu = [
            torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        output_gpu = [
            torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
            for _ in range(cls.buffer_depth)
        ]
        cur_buffer = (
            input_tensor_cpu,
            immediate_experts_ids_cpu,
            deferred_experts_ids_cpu,
            weights_cpu,
            output_cpu,
            bsz_tensor_cpu,
            output_gpu,
        )
        # Captured sizes are cached forever; otherwise only the latest buffer
        # is kept (also recorded as temp even when captured).
        if batch_size in cls.capture_bs:
            cls.capture_buffers[batch_size] = cur_buffer
        cls.temp_bs = batch_size
        cls.temp_buffer = cur_buffer
        return cur_buffer
class BaseMoEWrapper(ABC):
    """
    Base class for MoE CPU inference operations.
    Provides common functionality for all backend implementations.
    """

    # Process-wide CPUInfer engine, shared by all layers (lazy singleton).
    _cpu_infer_instance = None
    # layer_idx -> whether that layer still has a deferred-expert task queued.
    # Written by submit_forward; read by the next layer (as the `incremental`
    # flag) and by sync_forward (as the `allow_pending` flag).
    _layer_has_pending_deferred: Dict[int, bool] = {}

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "AMXINT4",
    ):
        """
        Initialize base MoE Wrapper.
        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size
            num_gpu_experts: Number of experts to run on GPU
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools
            weight_path: Path to weights
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Whether to save weights to CPU memory
            max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
            method: Backend method string
        """
        print(f"Init {self.__class__.__name__}")
        self.layer_idx = layer_idx
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_size = hidden_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_gpu_experts = num_gpu_experts
        self.weight_path = weight_path
        self.chunked_prefill_size = chunked_prefill_size
        self.cpu_save = cpu_save
        # None means defer is disabled (0 experts deferred per token).
        self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
        self.method = method
        # Initialize CPU inference engine (singleton)
        if BaseMoEWrapper._cpu_infer_instance is None:
            worker_config = kt_kernel_ext.WorkerPoolConfig()
            # Subpool i is pinned to NUMA node i; threads are split as evenly
            # as possible, the first (threads % pools) subpools get one extra.
            subpool_numa_map = list(range(threadpool_count))
            subpool_thread_count = [
                cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
                for i in range(threadpool_count)
            ]
            worker_config.subpool_count = threadpool_count
            worker_config.subpool_numa_map = subpool_numa_map
            worker_config.subpool_thread_count = subpool_thread_count
            BaseMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)
        self.cpu_infer = BaseMoEWrapper._cpu_infer_instance
        # Backend-specific initialization happens in subclasses
        self.moe = None

    @abstractmethod
    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).
        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        pass

    @abstractmethod
    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
        """
        Load weights for this layer and initialize the MoE module.
        Args:
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        pass

    def select_deferred_experts(
        self,
        expert_ids: torch.Tensor,
        expert_scores: torch.Tensor,
        protected_k: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split each token's top-k experts into immediate and deferred sets.

        For each token the ``protected_k`` highest-scoring entries are marked
        protected.  Protection is then applied per *expert id* across the whole
        batch: any expert protected for at least one token runs immediately for
        every token that selected it.  In the returned tensors, -1 marks an
        empty slot: non-protected entries become -1 in ``immediate_ids`` and
        protected entries become -1 in ``deferred_ids``.

        Args:
            expert_ids: Top-k expert ids, shape [batch, topk], dtype long.
            expert_scores: Matching routing scores, shape [batch, topk].
            protected_k: Number of per-token experts to keep immediate
                (clamped into [0, topk]).
        Returns:
            (immediate_ids, deferred_ids), each shaped like ``expert_ids``;
            ``deferred_ids`` is never None on this path.
        """
        batch, topk = expert_ids.shape
        device = expert_ids.device
        protected_k = max(0, min(int(protected_k), topk))
        if protected_k == 0:
            # Nothing protected: every selected expert is deferred.
            deferred_ids = expert_ids.clone()
            immediate_ids = torch.full_like(expert_ids, -1)
            return immediate_ids, deferred_ids
        topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)
        protected_indices = topk_result.indices
        protected_ids = torch.gather(expert_ids, -1, protected_indices)
        # Per-expert flag over all num_experts: 1 if protected for any token.
        protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)
        protected_flag.scatter_(0, protected_ids.reshape(-1), 1)
        protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)
        protected_mask = protected_mask_flat.view(batch, topk)
        immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)
        deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)
        return immediate_ids, deferred_ids

    def submit_forward(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        cuda_stream,
    ):
        """
        Submit forward inference task to CPU (non-blocking).
        Args:
            hidden_states: Input hidden states [batch_size, hidden_size]
            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
            topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
            cuda_stream: CUDA stream for synchronization
        """
        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        batch_size = flat_hidden_states.shape[0]
        (
            input_tensor_cpu,
            immediate_experts_ids_cpu,
            deferred_experts_ids_cpu,
            weights_cpu,
            output_cpu,
            bsz_tensor_cpu,
            _output_gpu,
        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
        # Consecutive layers alternate slots (buffer_depth == 2), so this
        # layer's deferred work can target the next layer's output slot.
        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
        next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth
        bsz_slot_tensor = bsz_tensor_cpu[current_slot]
        bsz_slot_tensor.fill_(batch_size)
        # Reset deferred-id staging; -1 marks "no expert".
        deferred_experts_ids_cpu[current_slot].fill_(-1)
        topk_ids_long = topk_ids.to(torch.long)
        immediate_ids: torch.Tensor
        deferred_ids: Optional[torch.Tensor]
        if self.max_deferred_experts_per_token > 0:
            # Protect (run immediately) all but the deferred quota per token.
            protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token
            immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
        else:
            immediate_ids = topk_ids_long
            deferred_ids = None
        # Stage inputs into pinned CPU buffers; non_blocking copies are
        # ordered on `cuda_stream`, which the submit below also waits on.
        input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
        weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
        immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)
        # If the previous layer left a deferred partial result in this output
        # slot, ask the backend to accumulate onto it rather than overwrite.
        incremental = BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
        self.cpu_infer.submit_with_cuda_stream(
            cuda_stream,
            self.moe.forward_task(
                bsz_slot_tensor.data_ptr(),
                immediate_experts_ids_cpu[current_slot].size(-1),
                immediate_experts_ids_cpu[current_slot].data_ptr(),
                weights_cpu[current_slot].data_ptr(),
                input_tensor_cpu[current_slot].data_ptr(),
                output_cpu[current_slot].data_ptr(),
                incremental,
            ),
        )
        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
        if deferred_ids is not None:
            deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
            # Deferred experts write into the *next* slot's output so the next
            # layer's immediate pass can accumulate on top of it.
            self.cpu_infer.submit_with_cuda_stream(
                cuda_stream,
                self.moe.forward_task(
                    bsz_slot_tensor.data_ptr(),
                    deferred_experts_ids_cpu[current_slot].size(-1),
                    deferred_experts_ids_cpu[current_slot].data_ptr(),
                    weights_cpu[current_slot].data_ptr(),
                    input_tensor_cpu[current_slot].data_ptr(),
                    output_cpu[next_slot].data_ptr(),
                    False,
                ),
            )
            BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True

    def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
        """
        Synchronize and retrieve forward inference results.
        Args:
            hidden_states: Original input hidden states (for getting buffer)
            cuda_stream: CUDA stream for synchronization
        Returns:
            output_gpu: Output tensor on GPU
        """
        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        (
            _input_tensor_cpu,
            _immediate_experts_ids_cpu,
            _deferred_experts_ids_cpu,
            _weights_cpu,
            output_cpu,
            _bsz_tensor_cpu,
            output_gpu,
        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
        # allow_pending=1 presumably lets the sync return while this layer's
        # own deferred task is still queued (only the immediate result is
        # needed here) — exact semantics defined by CPUInfer; confirm there.
        allow_pending = 1 if BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
        self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
        output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
        return output_gpu[current_slot]

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        cuda_stream,
    ) -> torch.Tensor:
        """
        Execute forward inference synchronously (submit + sync).
        Args:
            hidden_states: Input hidden states [batch_size, hidden_size]
            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
            topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
            cuda_stream: CUDA stream for synchronization
        Returns:
            Output tensor on GPU
        """
        self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)
        return self.sync_forward(hidden_states, cuda_stream)

    @staticmethod
    def set_capture_batch_sizes(capture_bs: List[int]):
        """
        Set batch sizes to capture and cache buffers for.
        This allows pre-allocation of CPU buffers for specific batch sizes,
        improving performance by avoiding buffer re-allocation during inference.
        Args:
            capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])
        Example:
            >>> BaseMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])
        """
        KExpertsCPUBuffer.capture_bs = capture_bs

    @staticmethod
    def get_capture_batch_sizes() -> List[int]:
        """
        Get currently configured capture batch sizes.
        Returns:
            List of batch sizes that are being captured
        """
        return KExpertsCPUBuffer.capture_bs

    @staticmethod
    def clear_buffer_cache():
        """
        Clear all cached buffers.
        This frees up memory by clearing the buffer cache. Useful when you want
        to reset the buffer state or free memory.
        """
        KExpertsCPUBuffer.capture_buffers.clear()
        KExpertsCPUBuffer.temp_bs = 0
        KExpertsCPUBuffer.temp_buffer = tuple()

View File

@@ -0,0 +1,301 @@
import os
import torch
import ctypes
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
try:
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
_HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
_HAS_AMX_SUPPORT = False
AMXInt4_MOE, AMXInt8_MOE = None, None
from typing import Optional
class AMXMoEWrapper(BaseMoEWrapper):
    """
    AMX-based MoE wrapper implementation.
    Supports AMXINT4 and AMXINT8 quantization methods.
    """

    _safetensor_loader_instance = None  # Singleton SafeTensorLoader

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "AMXINT4",
    ):
        """
        Initialize AMX MoE Wrapper.
        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size
            num_gpu_experts: Number of experts to run on GPU
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools
            weight_path: Path to AMX weights (SafeTensor format)
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Whether to save weights to CPU memory
            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
            method: AMX quantization method ("AMXINT4" or "AMXINT8")
        Raises:
            RuntimeError: If kt_kernel_ext was built without AMX support.
        """
        if not _HAS_AMX_SUPPORT:
            raise RuntimeError(
                "AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
                "Please recompile with AMX enabled."
            )
        # Initialize base class
        super().__init__(
            layer_idx=layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_gpu_experts=num_gpu_experts,
            cpuinfer_threads=cpuinfer_threads,
            threadpool_count=threadpool_count,
            weight_path=weight_path,
            chunked_prefill_size=chunked_prefill_size,
            cpu_save=cpu_save,
            max_deferred_experts_per_token=max_deferred_experts_per_token,
            method=method,
        )
        # AMX-specific: Check if we should load merged safetensor weights
        self.load_merged_weight = False
        import glob
        if glob.glob(os.path.join(weight_path, "*.safetensors")):
            self.load_merged_weight = True
        # Initialize SafeTensor loader (singleton)
        # NOTE(review): self.safetensor_loader is only set on the merged path,
        # but the cpu_save branch of load_weights() also reads it — confirm
        # cpu_save is always used with merged safetensors present.
        if self.load_merged_weight:
            if AMXMoEWrapper._safetensor_loader_instance is None:
                AMXMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path)
            self.safetensor_loader = AMXMoEWrapper._safetensor_loader_instance
        # AMX-specific weight storage; populated by load_weights() on the
        # merged path and kept alive while the C++ side reads the buffers.
        self.gate_weights = None
        self.up_weights = None
        self.down_weights = None
        self.gate_scales = None
        self.up_scales = None
        self.down_scales = None

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).
        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        Raises:
            NotImplementedError: For AMX methods other than AMXINT4/AMXINT8.
        """
        # Store tensors as instance variables to keep them alive
        # (the C++ side only holds raw pointers into these buffers).
        self.gate_proj = gate_proj.contiguous()
        self.up_proj = up_proj.contiguous()
        self.down_proj = down_proj.contiguous()
        # Configure MoE with online quantization (cpu_save mode)
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size
        # Enable save mode for online quantization
        moe_config.save = True
        moe_config.load = False
        # Set weight pointers (raw device addresses of the pinned tensors).
        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()
        # Set output path for quantized weights
        moe_config.path = self.weight_path
        # Create MoE module based on AMX method
        if self.method == "AMXINT4":
            self.moe = AMXInt4_MOE(moe_config)
        elif self.method == "AMXINT8":
            self.moe = AMXInt8_MOE(moe_config)
        else:
            raise NotImplementedError(f"Unsupported AMX method: {self.method}")
        # Submit quantization and save task (blocking: sync right after).
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()

    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
        """
        Load weights for this layer and initialize the MoE module.

        Three paths, selected by flags set in __init__:
        - merged safetensors: pre-quantized per-NUMA arrays + scales are loaded
          and handed to C++ as raw pointers;
        - cpu_save: raw BF16/FP16 weights are loaded and re-saved quantized;
        - otherwise: the backend loads quantized weights itself from weight_path.
        Args:
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        # Null pointer placeholders for paths that do not supply real buffers.
        gate_ptr = 0
        up_ptr = 0
        down_ptr = 0
        gate_ptrs = []
        up_ptrs = []
        down_ptrs = []
        gate_scale_ptrs = []
        up_scale_ptrs = []
        down_scale_ptrs = []
        if self.load_merged_weight:
            base_key = f"blk.{self.layer_idx}"
            w = self.safetensor_loader.load_experts(base_key)
            self.gate_weights = w["gate"]
            self.up_weights = w["up"]
            self.down_weights = w["down"]
            self.gate_scales = w["gate_scale"]
            self.up_scales = w["up_scale"]
            self.down_scales = w["down_scale"]
            # Get pointers to weight arrays
            # The ctypes cast/addressof round-trip below evaluates to the numpy
            # array's raw data address as a plain int (i.e. et.ctypes.data),
            # one pointer per expert tensor per NUMA node.
            gate_ptrs = [
                [
                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                    for et in numa_array
                ]
                for numa_array in self.gate_weights
            ]
            up_ptrs = [
                [
                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                    for et in numa_array
                ]
                for numa_array in self.up_weights
            ]
            down_ptrs = [
                [
                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                    for et in numa_array
                ]
                for numa_array in self.down_weights
            ]
            gate_scale_ptrs = [
                [
                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                    for et in numa_array
                ]
                for numa_array in self.gate_scales
            ]
            up_scale_ptrs = [
                [
                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                    for et in numa_array
                ]
                for numa_array in self.up_scales
            ]
            down_scale_ptrs = [
                [
                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                    for et in numa_array
                ]
                for numa_array in self.down_scales
            ]
        # Configure MoE
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size
        moe_config.gate_proj = gate_ptr
        moe_config.up_proj = up_ptr
        moe_config.down_proj = down_ptr
        moe_config.gate_projs = gate_ptrs
        moe_config.up_projs = up_ptrs
        moe_config.down_projs = down_ptrs
        moe_config.gate_scales = gate_scale_ptrs
        moe_config.up_scales = up_scale_ptrs
        moe_config.down_scales = down_scale_ptrs
        if self.cpu_save:
            # Re-save path: load raw weights and let the backend quantize+save.
            moe_config.save = True
            moe_config.load = False
            # NOTE(review): this branch uses HF-style keys ("model.layers.N",
            # "*_weight") while the merged branch uses GGUF-style keys
            # ("blk.N", "gate"/"gate_scale") — confirm SafeTensorLoader
            # supports both schemes.
            base_key = f"model.layers.{self.layer_idx}"
            w = self.safetensor_loader.load_experts(base_key)
            self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
            self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
            self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous()
            moe_config.gate_proj = self.gate_proj.data_ptr()
            moe_config.up_proj = self.up_proj.data_ptr()
            moe_config.down_proj = self.down_proj.data_ptr()
        else:
            moe_config.load = True
        # Non-merged loading reads quantized weights from disk via path.
        if not self.load_merged_weight:
            moe_config.path = self.weight_path
        # Create MoE module based on AMX method
        if self.method == "AMXINT4":
            self.moe = AMXInt4_MOE(moe_config)
        elif self.method == "AMXINT8":
            self.moe = AMXInt8_MOE(moe_config)
        else:
            raise NotImplementedError(f"Unsupported AMX method: {self.method}")
        # Load weights (blocking: submit then sync).
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()
        # Clean up temporary weight storage if using merged weights
        # (safe once load_weights_task has copied/ingested the buffers).
        if self.load_merged_weight:
            del self.gate_weights
            del self.up_weights
            del self.down_weights
            del self.gate_scales
            del self.up_scales
            del self.down_scales

View File

@@ -0,0 +1,225 @@
import torch
from typing import Optional
import os
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import GGUFLoader
from kt_kernel_ext.moe import MOEConfig
try:
from kt_kernel_ext.moe import MOE
_HAS_LLAMAFILE_SUPPORT = True
except (ImportError, AttributeError):
_HAS_LLAMAFILE_SUPPORT = False
MOE = None
from kt_kernel_ext.kvcache import ggml_type
class LlamafileMoEWrapper(BaseMoEWrapper):
    """
    Llamafile-based MoE wrapper implementation.
    Supports GGUF quantized weights with llamafile backend.
    """

    _gguf_loader_instance = None  # Singleton GGUFLoader

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "LLAMAFILE",
    ):
        """
        Initialize Llamafile MoE Wrapper.
        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size
            num_gpu_experts: Number of experts to run on GPU
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools (TP count)
            weight_path: Path to GGUF weights
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Not supported for Llamafile backend
            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
            method: Should be "LLAMAFILE"
        Raises:
            RuntimeError: If kt_kernel_ext lacks Llamafile support.
            FileNotFoundError: If weight_path does not exist.
            ValueError: If intermediate_size cannot be split across TPs.
        """
        if not _HAS_LLAMAFILE_SUPPORT:
            raise RuntimeError(
                "Llamafile backend not available. kt_kernel_ext was not compiled with Llamafile support.\n"
                "Please recompile with Llamafile enabled."
            )
        if not os.path.exists(weight_path):
            raise FileNotFoundError(f"GGUF weight path not found: {weight_path}")
        # Initialize GGUF loader (singleton)
        if LlamafileMoEWrapper._gguf_loader_instance is None:
            LlamafileMoEWrapper._gguf_loader_instance = GGUFLoader(weight_path)
        self.gguf_loader = LlamafileMoEWrapper._gguf_loader_instance
        # Validate TP configuration with QK_K alignment
        # (QK_K is the GGUF super-block size; TP splits must be multiples of it).
        QK_K = 256
        # Check if intermediate_size is divisible by QK_K
        if moe_intermediate_size % QK_K != 0:
            raise ValueError(
                f"intermediate_size ({moe_intermediate_size}) must be divisible by QK_K ({QK_K}) "
                f"for Llamafile backend"
            )
        # Calculate TP splits with QK_K alignment: each TP gets base_blocks,
        # and the first extra_blocks TPs get one additional block.
        num_blocks = moe_intermediate_size // QK_K
        base_blocks = num_blocks // threadpool_count
        extra_blocks = num_blocks % threadpool_count
        # Validate that we have enough blocks
        if base_blocks == 0:
            valid_tp_counts = list(range(1, num_blocks + 1))
            raise ValueError(
                f"intermediate_size ({moe_intermediate_size}) is too small for threadpool_count ({threadpool_count}).\n"
                f"Total blocks: {num_blocks} (intermediate_size / QK_K)\n"
                f"Cannot distribute to {threadpool_count} TPs (each TP needs at least 1 block).\n"
                f"Valid threadpool_count values: {valid_tp_counts}"
            )
        # Log TP split information
        print(f"[LlamafileMoEWrapper] Layer {layer_idx} TP configuration:")
        print(f"  intermediate_size: {moe_intermediate_size}")
        print(f"  threadpool_count: {threadpool_count}")
        print(f"  QK_K: {QK_K}")
        print(f"  Total blocks: {num_blocks}")
        print(f"  Base blocks per TP: {base_blocks}")
        print(f"  Extra blocks (distributed to first TPs): {extra_blocks}")
        current_offset = 0
        for tp_id in range(threadpool_count):
            tp_blocks = base_blocks + (1 if tp_id < extra_blocks else 0)
            tp_size = tp_blocks * QK_K
            print(f"  TP {tp_id}: size={tp_size}, offset={current_offset}, blocks={tp_blocks}")
            current_offset += tp_size
        # Initialize base class
        super().__init__(
            layer_idx=layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_gpu_experts=num_gpu_experts,
            cpuinfer_threads=cpuinfer_threads,
            threadpool_count=threadpool_count,
            weight_path=weight_path,
            chunked_prefill_size=chunked_prefill_size,
            cpu_save=cpu_save,
            max_deferred_experts_per_token=max_deferred_experts_per_token,
            method=method,
        )
        # Holds the raw GGUF tensors loaded by load_weights(); the C++ side
        # only keeps pointers into them, so they must stay referenced here.
        self.weights_to_keep = None

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Online quantization is not supported for Llamafile backend.
        Use pre-quantized GGUF weights instead.
        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Llamafile backend does not support online quantization (load_weights_from_tensors).\n"
            "Please use pre-quantized GGUF weights and call load_weights() instead."
        )

    def load_weights(self, physical_to_logical_map_cpu: Optional[torch.Tensor] = None):
        """
        Load weights for this layer from GGUF files and initialize the MoE module.
        Args:
            physical_to_logical_map_cpu: Optional mapping from physical to logical expert IDs
                                         Shape: [num_experts], dtype: int32
                                         If None, uses identity mapping [0, 1, 2, ..., num_experts-1]
        Raises:
            RuntimeError: If kt_kernel_ext lacks Llamafile support.
        """
        if not _HAS_LLAMAFILE_SUPPORT:
            raise RuntimeError(
                "Llamafile backend not available. kt_kernel_ext was not compiled with Llamafile support.\n"
                "Please recompile with Llamafile enabled."
            )
        if physical_to_logical_map_cpu is None:
            # Identity mapping: physical expert i is logical expert i.
            physical_to_logical_map_cpu = torch.arange(
                self.num_experts,
                dtype=torch.int32,
                device="cpu"
            )
            print(f"  Using default identity mapping for {self.num_experts} experts")
        base_key = f"blk.{self.layer_idx}"
        # Load quantized tensors from GGUF (raw, not dequantized — the
        # backend consumes the GGML-packed bytes directly).
        gate_data, gate_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
            f"{base_key}.ffn_gate_exps.weight"
        )
        up_data, up_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
            f"{base_key}.ffn_up_exps.weight"
        )
        down_data, down_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
            f"{base_key}.ffn_down_exps.weight"
        )
        # Keep tensors alive
        self.weights_to_keep = (gate_data, up_data, down_data)
        # Activations are exchanged with the backend as BF16.
        hidden_type = ggml_type.BF16
        # Configure MoE
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        # Llamafile-specific configuration
        moe_config.m_block = 32  # Parallel block size
        moe_config.group_min_len = 10  # Use forward_one when qlen < 10
        moe_config.max_len = self.chunked_prefill_size
        moe_config.group_max_len = max(1, int(self.chunked_prefill_size))
        # Set weight pointers (raw addresses; tensors kept alive above).
        moe_config.gate_proj = gate_data.data_ptr()
        moe_config.up_proj = up_data.data_ptr()
        moe_config.down_proj = down_data.data_ptr()
        # Set quantization types
        moe_config.gate_type = gate_type
        moe_config.up_type = up_type
        moe_config.down_type = down_type
        moe_config.hidden_type = hidden_type
        # Create MoE module
        self.moe = MOE(moe_config)
        # Load weights (blocking: submit then sync).
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()

View File

@@ -0,0 +1,522 @@
"""
Weight loaders for different formats.
This module provides loaders for:
- SafeTensor format (for AMX quantized weights)
- GGUF format (for Llamafile quantized weights)
"""
from __future__ import annotations
import os
import numpy as np
import torch
from enum import IntEnum
from safetensors import safe_open
from gguf.gguf_reader import GGUFReader
class GGMLQuantizationType(IntEnum):
    """GGML quantization type enumeration.

    Numeric ids of GGML tensor storage/quantization formats as they appear in
    GGUF tensor metadata.  Member order is preserved as-declared (enum
    iteration order is observable); ids 4 and 5 are intentionally absent.
    """
    F32 = 0  # 32-bit float (unquantized)
    F16 = 1  # 16-bit float (unquantized)
    Q4_0 = 2  # 4-bit block quant, variant 0
    Q4_1 = 3  # 4-bit block quant, variant 1
    Q5_0 = 6  # 5-bit block quant, variant 0
    Q5_1 = 7  # 5-bit block quant, variant 1
    Q8_0 = 8  # 8-bit block quant, variant 0
    Q8_1 = 9  # 8-bit block quant, variant 1
    Q2_K = 10  # 2-bit K-quant
    Q3_K = 11  # 3-bit K-quant
    Q4_K = 12  # 4-bit K-quant
    Q5_K = 13  # 5-bit K-quant
    Q6_K = 14  # 6-bit K-quant
    Q8_K = 15  # 8-bit K-quant
    IQ2_XXS = 16  # 2-bit IQ quant, extra-extra-small
    IQ2_XS = 17  # 2-bit IQ quant, extra-small
    IQ3_XXS = 18  # 3-bit IQ quant, extra-extra-small
    IQ1_S = 19  # 1-bit IQ quant, small
    IQ4_NL = 20  # 4-bit IQ quant, non-linear
    IQ3_S = 21  # 3-bit IQ quant, small
    IQ2_S = 22  # 2-bit IQ quant, small
    IQ4_XS = 23  # 4-bit IQ quant, extra-small
    I8 = 24  # 8-bit integer (unquantized)
    I16 = 25  # 16-bit integer (unquantized)
    I32 = 26  # 32-bit integer (unquantized)
    I64 = 27  # 64-bit integer (unquantized)
    F64 = 28  # 64-bit float (unquantized)
    IQ1_M = 29  # 1-bit IQ quant, medium
    BF16 = 30  # bfloat16 (unquantized)
def translate_name_to_gguf(name):
    """
    Translate a PyTorch/HuggingFace-style tensor name to its GGUF equivalent.

    Applies a fixed sequence of substring substitutions; the order matters
    because later generic patterns (e.g. stripping ``.mlp.experts``) operate
    on the output of earlier, more specific ones.

    Args:
        name: Tensor name in PyTorch convention,
            e.g. "model.layers.3.self_attn.q_proj.weight".

    Returns:
        The translated GGUF tensor name, e.g. "blk.3.attn_q.weight".
    """
    # Top-level / embedding tensors
    name = name.replace("lm_head.", "output.")
    name = name.replace("model.embed_tokens.", "token_embd.")
    name = name.replace("model.norm.", "output_norm.")
    name = name.replace("model.layers.", "blk.")
    # Per-layer norms and dense MLP projections
    name = name.replace(".input_layernorm", ".attn_norm")
    name = name.replace(".mlp.down_proj", ".ffn_down")
    name = name.replace(".mlp.gate_proj", ".ffn_gate")
    name = name.replace(".mlp.up_proj", ".ffn_up")
    name = name.replace(".post_attention_layernorm", ".ffn_norm")
    # Attention projections (including MLA-style q_a/kv_a variants)
    name = name.replace(".self_attn.q_proj", ".attn_q")
    name = name.replace(".self_attn.k_proj", ".attn_k")
    name = name.replace(".self_attn.v_proj", ".attn_v")
    name = name.replace(".self_attn.o_proj", ".attn_output")
    name = name.replace(".self_attn.qkv_proj", ".attn_qkv")
    name = name.replace(".self_attn.kv_a_proj_with_mqa", ".attn_kv_a_mqa")
    name = name.replace(".self_attn.kv_a_layernorm", ".attn_kv_a_norm")
    name = name.replace(".self_attn.kv_b_proj", ".attn_kv_b")
    name = name.replace(".self_attn.q_a_proj", ".attn_q_a")
    name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm")
    name = name.replace(".self_attn.q_b_proj", ".attn_q_b")
    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
    name = name.replace(".self_attn.k_norm", ".attn_k_norm")
    # Shared experts (normalize singular form first)
    name = name.replace(".shared_expert.", ".shared_experts.")
    name = name.replace(".shared_expert_", ".shared_experts_")
    # NOTE(review): this drops the trailing dot ("x.gate_up_proj.weight" ->
    # "x.up_projweight") — looks intentional upstream, but verify callers.
    name = name.replace(".gate_up_proj.", ".up_proj")
    name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp")
    name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias")
    name = name.replace(".mlp.gate", ".ffn_gate_inp")
    name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
    name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
    name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")
    # Strip expert-container segments. The previous explicit replacements for
    # ".mlp.experts.ffn_{down,gate,up}_exps" were dead code: this generic
    # strip runs first, so those longer patterns could never match afterward.
    name = name.replace(".mlp.experts", "")
    name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
    name = name.replace(".block_sparse_moe.experts", "")
    name = name.replace(".feed_forward.experts", "")
    name = name.replace(".feed_forward.router", ".ffn_gate_inp")
    name = name.replace(".feed_forward.shared_experts.down_proj", ".ffn_down_shexp")
    name = name.replace(".feed_forward.shared_experts.gate_proj", ".ffn_gate_shexp")
    name = name.replace(".feed_forward.shared_experts.up_proj", ".ffn_up_shexp")
    return name
class SafeTensorLoader:
    """
    SafeTensor format loader for AMX quantized weights.

    Scans a file's directory (or a directory tree) for ``*.safetensors``
    files and builds a tensor-name -> file index so individual tensors and
    NUMA-sharded expert weights can be loaded lazily on demand.
    """

    # tensor name -> safetensors file (basename) that contains it
    tensor_file_map: dict
    tensor_type_map: dict
    # file basename -> open safetensors handle
    file_handle_map: dict
    tensor_device_map: dict

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        """Index every tensor in all .safetensors files under ``file_path``.

        Args:
            file_path: A .safetensors file (its parent directory is scanned)
                or a directory to scan recursively.

        Raises:
            FileNotFoundError: if the path does not exist or no
                .safetensors files are found.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path

        self.file_handle_map = {}
        self.tensor_file_map = {}
        self.tensor_type_map = {}
        self.tensor_device_map = {}

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            for file in sorted(files):
                if not file.endswith(".safetensors"):
                    continue
                found_safetensor = True
                # Use a distinct name; the original shadowed the
                # `file_path` parameter here.
                full_path = os.path.join(root, file)
                if file not in self.file_handle_map:
                    try:
                        handle = safe_open(full_path, framework="pt")
                        self.file_handle_map[file] = handle
                    except Exception as e:
                        # Best-effort: a corrupt file is reported and skipped.
                        print(f"Error opening Safetensor file {full_path}: {e}")
                        continue
                f = self.file_handle_map.get(file)
                if f is None:
                    continue
                try:
                    for key in f.keys():
                        self.tensor_file_map[key] = file
                except Exception as e:
                    print(f"Error reading Safetensor file {full_path}: {e}")
        if not found_safetensor:
            raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str = "cpu"):
        """
        Load a single tensor by name.

        Args:
            key: Tensor name as indexed by the loader.
            device: Target torch device.

        Raises:
            KeyError: if the key is not indexed.
            FileNotFoundError: if the owning file handle is missing.
        """
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        """Close every open safetensors handle and clear the handle map."""
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_experts(self, base_key: str, device: str = "cpu"):
        """
        Load expert weights from SafeTensor files.

        Expected tensor naming:
            blk.{layer}.ffn_[up|down|gate]_exps.{expert_id}.numa.{numa_id}.weight
            blk.{layer}.ffn_[up|down|gate]_exps.{expert_id}.numa.{numa_id}.scale

        Args:
            base_key: Base key like "blk.{layer_index}"
            device: Target device for tensors

        Returns:
            Dict with keys: up, gate, down, up_scale, gate_scale, down_scale.
            Each value is a nested list: [numa_id][expert_id] -> numpy array.

        Raises:
            ValueError: if no expert tensors exist under ``base_key``.
        """
        up_base_key = f"{base_key}.ffn_up_exps"
        gate_base_key = f"{base_key}.ffn_gate_exps"
        down_base_key = f"{base_key}.ffn_down_exps"

        # Probe contiguous expert indices on NUMA node 0; the final value is
        # the highest existing expert index, or -1 if none exist.
        max_experts_count = -1
        while self.has_tensor(f"{up_base_key}.{max_experts_count+1}.numa.{0}.weight"):
            max_experts_count += 1
        # BUG FIX: the "nothing found" sentinel is -1, not 0. The old check
        # (`== 0`) wrongly raised for a single-expert layer and silently
        # returned empty lists when no experts existed at all.
        if max_experts_count < 0:
            raise ValueError(f"No experts found for key {base_key}")

        # Probe NUMA shards using expert 0 as the reference.
        max_numa_id = -1
        while self.has_tensor(f"{up_base_key}.{0}.numa.{max_numa_id+1}.weight"):
            max_numa_id += 1

        # Nested result lists indexed [numa_id][expert_id].
        up_weights = [[] for _ in range(max_numa_id + 1)]
        gate_weights = [[] for _ in range(max_numa_id + 1)]
        down_weights = [[] for _ in range(max_numa_id + 1)]
        up_scales = [[] for _ in range(max_numa_id + 1)]
        gate_scales = [[] for _ in range(max_numa_id + 1)]
        down_scales = [[] for _ in range(max_numa_id + 1)]

        def _as_numpy(key):
            # Load via safetensors and hand back a numpy view on `device`.
            # NOTE(review): original comment said "make sure contiguous" —
            # .numpy() does not force contiguity; confirm consumers cope.
            return self.load_tensor(key, device).numpy()

        for numa_id in range(max_numa_id + 1):
            for expert_id in range(max_experts_count + 1):
                suffix = f"{expert_id}.numa.{numa_id}"
                up_weights[numa_id].append(_as_numpy(f"{up_base_key}.{suffix}.weight"))
                gate_weights[numa_id].append(_as_numpy(f"{gate_base_key}.{suffix}.weight"))
                down_weights[numa_id].append(_as_numpy(f"{down_base_key}.{suffix}.weight"))
                up_scales[numa_id].append(_as_numpy(f"{up_base_key}.{suffix}.scale"))
                gate_scales[numa_id].append(_as_numpy(f"{gate_base_key}.{suffix}.scale"))
                down_scales[numa_id].append(_as_numpy(f"{down_base_key}.{suffix}.scale"))

        return {
            "up": up_weights,
            "gate": gate_weights,
            "down": down_weights,
            "up_scale": up_scales,
            "gate_scale": gate_scales,
            "down_scale": down_scales,
        }

    def has_tensor(self, name: str):
        """Return True if ``name`` is indexed in any loaded safetensors file."""
        return name in self.tensor_file_map
class GGUFLoader:
"""
GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)
This is a cleaner implementation compared to manual binary parsing.
"""
def __init__(self, gguf_path: str):
"""
Initialize GGUF loader from a file or directory
Args:
gguf_path: Path to a single GGUF file or a directory containing GGUF files
"""
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF path not found: {gguf_path}")
self.tensor_info = {}
self.metadata = {}
self.tensor_file_map = {}
self.file_data_map = {}
if os.path.isfile(gguf_path) and gguf_path.endswith('.gguf'):
print(f"\n[GGUFLoader] Loading single GGUF file : {os.path.basename(gguf_path)}")
self._load_single_file(gguf_path)
elif os.path.isdir(gguf_path):
print(f"\n[GGUFLoader] Loading GGUF files from directory: {gguf_path}")
self._load_directory(gguf_path)
else:
raise ValueError(f"Path must be a .gguf file or a directory: {gguf_path}")
print(f"[GGUFLoader] Summary:")
print(f" Files loaded: {len(self.file_data_map)}")
print(f" Total tensors: {len(self.tensor_info)}")
print(f" Metadata keys: {len(self.metadata)}")
tensors = ["blk.0.ffn_up_exps.weight", "blk.0.ffn_gate_exps.weight", "blk.0.ffn_down_exps.weight"]
for key in tensors:
if key in self.tensor_info:
info = self.tensor_info[key]
print(f" {'.'.join(key.split('.')[2:-1])}, Dtype: {info['dtype'].name}")
def _load_single_file(self, file_path: str):
"""Load a single GGUF file"""
reader = GGUFReader(file_path)
for key, field in reader.fields.items():
value = field.parts[field.data[0]]
if isinstance(value, bytes):
value = value.decode('utf-8')
elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
try:
value = bytes(value).decode('utf-8')
except:
pass
self.metadata[key] = value
for tensor in reader.tensors:
self.tensor_info[tensor.name] = {
'shape': list(reversed(tensor.shape)), # Reverse to match PyTorch order
'dtype': tensor.tensor_type,
'offset': tensor.data_offset,
'n_elements': tensor.n_elements,
}
self.tensor_file_map[tensor.name] = file_path
self.file_data_map[file_path] = np.memmap(file_path, mode='r')
def _load_directory(self, dir_path: str):
"""Load all GGUF files from a directory (non-recursive)"""
found_gguf = False
for file in sorted(os.listdir(dir_path)):
if file.endswith(".gguf"):
found_gguf = True
file_path = os.path.join(dir_path, file)
print(f" Loading: {file}")
reader = GGUFReader(file_path)
for key, field in reader.fields.items():
value = field.parts[field.data[0]]
if isinstance(value, bytes):
value = value.decode('utf-8')
elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
try:
value = bytes(value).decode('utf-8')
except:
pass
self.metadata[key] = value
for tensor in reader.tensors:
self.tensor_info[tensor.name] = {
'shape': list(reversed(tensor.shape)),
'dtype': tensor.tensor_type,
'offset': tensor.data_offset,
'n_elements': tensor.n_elements,
}
self.tensor_file_map[tensor.name] = file_path
self.file_data_map[file_path] = np.memmap(file_path, mode='r')
if not found_gguf:
raise FileNotFoundError(f"No .gguf files found in directory: {dir_path}")
def get_model_config(self, layer_idx: int = 0):
"""
Extract model configuration from GGUF metadata and tensor shapes.
Args:
layer_idx: Layer index to inspect (default: 0)
Returns:
dict with keys: num_experts, num_experts_per_tok, hidden_size, moe_intermediate_size
"""
config = {}
arch = self.metadata.get("general.architecture", "unknown")
num_experts = None
for key_suffix in [
"expert_count",
"expert.count",
"moe.expert_count",
"expert_feed_forward_length",
]:
key = f"{arch}.{key_suffix}"
if key in self.metadata:
val = self.metadata[key]
num_experts = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
break
num_experts_per_tok = None
for key_suffix in [
"expert_used_count",
"expert.used_count",
"moe.num_experts_per_tok",
]:
key = f"{arch}.{key_suffix}"
if key in self.metadata:
val = self.metadata[key]
num_experts_per_tok = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
break
hidden_size = None
for key_suffix in [
"embedding_length",
"embed_length",
"hidden_size",
]:
key = f"{arch}.{key_suffix}"
if key in self.metadata:
val = self.metadata[key]
hidden_size = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
break
moe_intermediate_size = None
for key_suffix in [
"expert_feed_forward_length",
"feed_forward_length",
"ffn_length",
"intermediate_size",
]:
key = f"{arch}.{key_suffix}"
if key in self.metadata:
val = self.metadata[key]
moe_intermediate_size = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
break
if any(v is None for v in [num_experts, hidden_size, moe_intermediate_size]):
base_key = f"blk.{layer_idx}.ffn_gate_exps.weight"
if base_key in self.tensor_info:
gate_shape = self.tensor_info[base_key]['shape']
print(f" Found tensor '{base_key}' with shape: {gate_shape}")
if len(gate_shape) >= 3:
if num_experts is None:
num_experts = int(gate_shape[0])
if moe_intermediate_size is None:
moe_intermediate_size = int(gate_shape[1])
if hidden_size is None:
hidden_size = int(gate_shape[2])
config = {
"num_experts": num_experts,
"num_experts_per_tok": num_experts_per_tok,
"hidden_size": hidden_size,
"moe_intermediate_size": moe_intermediate_size,
}
return config
def print_metadata(self, filter_keywords=None):
"""
Print GGUF file metadata for debugging.
Args:
filter_keywords: Optional list of keywords to filter metadata keys
"""
print(f"\n[GGUFLoader] GGUF Metadata:")
print(f" Total metadata entries: {len(self.metadata)}")
if filter_keywords:
filtered = {k: v for k, v in self.metadata.items()
if any(kw.lower() in k.lower() for kw in filter_keywords)}
for k, v in sorted(filtered.items()):
print(f" {k}: {v}")
else:
for k, v in sorted(self.metadata.items()):
print(f" {k}: {v}")
def has_tensor(self, name: str):
"""Check if tensor exists"""
name = translate_name_to_gguf(name)
return name in self.tensor_info
def get_ggml_type(self, name: str):
"""Get GGML type of a tensor"""
name = translate_name_to_gguf(name)
if name not in self.tensor_info:
raise KeyError(f"Tensor '{name}' not found in GGUF files")
return self.tensor_info[name]["dtype"]
    def get_undequanted_tensor_and_ggml_type(self, name: str):
        """
        Get tensor data and its GGML type without dequantizing.

        The raw quantized bytes are sliced out of the file's memmap and
        copied into a fresh uint8 torch tensor.

        Args:
            name: Tensor name (in PyTorch format, will be translated to GGUF format)
        Returns:
            (data, ggml_type): Tuple of raw uint8 tensor data and GGML
            quantization type
        Raises:
            KeyError: if the tensor is not present in any loaded file, or if
                its quantization type is missing from GGML_QUANT_SIZES.
        """
        name = translate_name_to_gguf(name)
        if name not in self.tensor_info:
            raise KeyError(f"Tensor '{name}' not found in GGUF files")
        info = self.tensor_info[name]
        file_path = self.tensor_file_map[name]
        mmap_data = self.file_data_map[file_path]
        offset = info['offset']
        n_elements = info['n_elements']
        ggml_type = info['dtype']
        # ggml_type -> (elements per block, bytes per block). Values mirror
        # the ggml block layouts; do not edit without checking ggml headers.
        # NOTE(review): rebuilt on every call — hoisting to module level
        # would avoid that, at the cost of a module-scope constant.
        GGML_QUANT_SIZES = {
            # Unquantized scalar types: 1 element per "block"
            GGMLQuantizationType.F32: (1, 4),
            GGMLQuantizationType.F16: (1, 2),
            GGMLQuantizationType.BF16: (1, 2),
            # Legacy 32-element block quants: 2-byte scale (+optional min/hi bits)
            GGMLQuantizationType.Q4_0: (32, 2 + 16),
            GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
            GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
            GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
            GGMLQuantizationType.Q8_0: (32, 2 + 32),
            GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
            # K-quants: 256-element super-blocks
            GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),
            GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),
            GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),
            GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
            GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
            GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),
            # IQ quants
            GGMLQuantizationType.IQ2_XXS: (256, 2 + 256 // 4),
            GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),
            GGMLQuantizationType.IQ3_XXS: (256, 2 + 256 // 4 + 256 // 8),
            GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),
            GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
            GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),
            GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),
            GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),
            # Plain integer / double types
            GGMLQuantizationType.I8: (1, 1),
            GGMLQuantizationType.I16: (1, 2),
            GGMLQuantizationType.I32: (1, 4),
            GGMLQuantizationType.I64: (1, 8),
            GGMLQuantizationType.F64: (1, 8),
            GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),
        }
        block_size, type_size = GGML_QUANT_SIZES[ggml_type]
        # Total byte length of the quantized payload.
        n_bytes = n_elements * type_size // block_size
        data_bytes = mmap_data[offset : offset + n_bytes]
        # Copy out of the memmap so the returned tensor owns its memory.
        data = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8).copy())
        return data, ggml_type