mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-23 07:49:17 +00:00
Fix kt-kernel for new wrapper (#1588)
* update README for kt-kernel
* style: format C++ and Python code in kt-kernel
  - Format C++ files: task_queue, ext_bindings, and MoE operators
  - Format Python utility modules: amx, llamafile, and loader
  - Improve code readability and consistency
This commit is contained in:
@@ -5,9 +5,13 @@ High-performance kernel operations for KTransformers, featuring CPU-optimized Mo
|
||||
## Features
|
||||
|
||||
- **AMX Optimization**: Intel AMX (Advanced Matrix Extensions) support for INT4/INT8 quantized MoE inference
|
||||
- **Multi-Backend**: AVX512, AVX2, and ARM KML support
|
||||
- **Multi-Backend**: Unified `KTMoEWrapper` API supporting multiple backends (AMXINT4, AMXINT8, LLAMAFILE*)
|
||||
- **Flexible Backends**: AVX512, AVX2 via pluggable backend architecture
|
||||
- **Efficient MoE**: Optimized Mixture-of-Experts operations with NUMA-aware memory management
|
||||
- **Easy Integration**: Clean Python API with `AMXMoEWrapper` and future wrapper support
|
||||
- **Async Execution**: Non-blocking `submit_forward` / `sync_forward` API for improved pipelining
|
||||
- **Easy Integration**: Clean Python API with automatic backend selection
|
||||
|
||||
**Note**: \*LLAMAFILE backend support is currently in preview and is not yet complete.
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -42,10 +46,10 @@ pip install -r requirements.txt
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from kt_kernel import AMXMoEWrapper
|
||||
from kt_kernel import KTMoEWrapper
|
||||
|
||||
# Initialize the MoE wrapper
|
||||
wrapper = AMXMoEWrapper(
|
||||
wrapper = KTMoEWrapper(
|
||||
layer_idx=0,
|
||||
num_experts=8,
|
||||
num_experts_per_tok=2,
|
||||
@@ -53,16 +57,55 @@ wrapper = AMXMoEWrapper(
|
||||
moe_intermediate_size=14336,
|
||||
num_gpu_experts=2,
|
||||
cpuinfer_threads=32,
|
||||
subpool_count=2,
|
||||
amx_weight_path="/path/to/weights",
|
||||
chunked_prefill_size=512
|
||||
threadpool_count=2,
|
||||
weight_path="/path/to/weights",
|
||||
chunked_prefill_size=512,
|
||||
method="AMXINT4" # Options: "AMXINT4", "AMXINT8", "LLAMAFILE" (preview)
|
||||
)
|
||||
|
||||
# Load weights
|
||||
# Load weights (from disk - pre-quantized)
|
||||
wrapper.load_weights(physical_to_logical_map)
|
||||
|
||||
# Or load weights from tensors (online quantization)
|
||||
wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)
|
||||
|
||||
# Run inference
|
||||
output = wrapper.forward(hidden_states, topk_ids, topk_weights, cuda_stream)
|
||||
|
||||
# Or use async API for better performance
|
||||
wrapper.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)
|
||||
# ... do other work ...
|
||||
output = wrapper.sync_forward(hidden_states, cuda_stream)
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
```python
|
||||
# Initialize with additional options
|
||||
wrapper = KTMoEWrapper(
|
||||
layer_idx=0,
|
||||
num_experts=8,
|
||||
num_experts_per_tok=2,
|
||||
hidden_size=4096,
|
||||
moe_intermediate_size=14336,
|
||||
num_gpu_experts=2,
|
||||
cpuinfer_threads=32,
|
||||
threadpool_count=2,
|
||||
weight_path="/path/to/weights",
|
||||
chunked_prefill_size=512,
|
||||
method="AMXINT4",
|
||||
cpu_save=False, # Keep weights in CPU memory after loading
|
||||
max_deferred_experts_per_token=0 # Number of experts to defer (for pipelined execution)
|
||||
)
|
||||
|
||||
# Pre-allocate buffers for specific batch sizes (improves performance)
|
||||
KTMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])
|
||||
|
||||
# Query captured batch sizes
|
||||
batch_sizes = KTMoEWrapper.get_capture_batch_sizes()
|
||||
|
||||
# Clear buffer cache to free memory
|
||||
KTMoEWrapper.clear_buffer_cache()
|
||||
```
|
||||
|
||||
## Build Configuration
|
||||
@@ -100,7 +143,7 @@ pip install .
|
||||
## Verification
|
||||
|
||||
```bash
|
||||
python -c "from kt_kernel import AMXMoEWrapper; print('✓ kt-kernel installed successfully')"
|
||||
python -c "from kt_kernel import KTMoEWrapper; print('✓ kt-kernel installed successfully')"
|
||||
```
|
||||
|
||||
## Weight Quantization
|
||||
|
||||
@@ -44,8 +44,7 @@ void TaskQueue::enqueue(std::function<void()> task) {
|
||||
|
||||
void TaskQueue::sync(size_t allow_n_pending) {
|
||||
// Spin until the pending task count drops to the allowed threshold.
|
||||
while (pending.load(std::memory_order_acquire) > allow_n_pending)
|
||||
;
|
||||
while (pending.load(std::memory_order_acquire) > allow_n_pending);
|
||||
}
|
||||
|
||||
void TaskQueue::worker() {
|
||||
|
||||
@@ -180,7 +180,7 @@ class MOEBindings {
|
||||
// printf("debug physical_to_logical_map in arg:%lu\n", physical_to_logical_map);
|
||||
moe->config.physical_to_logical_map = reinterpret_cast<void*>(physical_to_logical_map);
|
||||
// printf("moe ptr:%p,confirm: moe->config.physical_to_logical_map:%lu\n", reinterpret_cast<void*>(moe.get()),
|
||||
// reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));
|
||||
// reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));
|
||||
}
|
||||
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||
}
|
||||
|
||||
@@ -338,53 +338,53 @@ class LLAMA_MOE_TP {
|
||||
}
|
||||
int ith = task_id % nth;
|
||||
|
||||
void* gate_proj_ptr =
|
||||
(uint8_t*)m_local_gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
|
||||
config_.hidden_size * ggml_type_size((ggml_type)config_.gate_type) /
|
||||
ggml_blck_size((ggml_type)config_.gate_type);
|
||||
void* gate_proj_ptr =
|
||||
(uint8_t*)m_local_gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
|
||||
config_.hidden_size * ggml_type_size((ggml_type)config_.gate_type) /
|
||||
ggml_blck_size((ggml_type)config_.gate_type);
|
||||
|
||||
float* gate_output_ptr = s_gate_output_[act_idx] + ith * config_.m_block;
|
||||
auto ok = llamafile_sgemm(config_.m_block, 1,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_proj_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_input_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_output_ptr,
|
||||
config_.m_block, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.gate_type,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type,
|
||||
GGML_TYPE_F32, GGML_PREC_DEFAULT);
|
||||
if (ok == false) [[unlikely]] {
|
||||
throw std::runtime_error("llamafile not supported");
|
||||
}
|
||||
float* gate_output_ptr = s_gate_output_[act_idx] + ith * config_.m_block;
|
||||
auto ok = llamafile_sgemm(
|
||||
config_.m_block, 1, config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_proj_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_input_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_output_ptr, config_.m_block, 0,
|
||||
1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.gate_type,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type, GGML_TYPE_F32,
|
||||
GGML_PREC_DEFAULT);
|
||||
if (ok == false) [[unlikely]] {
|
||||
throw std::runtime_error("llamafile not supported");
|
||||
}
|
||||
|
||||
void* up_proj_ptr =
|
||||
(uint8_t*)m_local_up_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
|
||||
config_.hidden_size * ggml_type_size((ggml_type)config_.up_type) /
|
||||
ggml_blck_size((ggml_type)config_.up_type);
|
||||
void* up_proj_ptr =
|
||||
(uint8_t*)m_local_up_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
|
||||
config_.hidden_size * ggml_type_size((ggml_type)config_.up_type) /
|
||||
ggml_blck_size((ggml_type)config_.up_type);
|
||||
|
||||
float* up_output_ptr = s_up_output_[act_idx] + ith * config_.m_block;
|
||||
llamafile_sgemm(config_.m_block, 1, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type),
|
||||
up_proj_ptr, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_input_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_output_ptr,
|
||||
config_.m_block, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.up_type,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type, GGML_TYPE_F32,
|
||||
GGML_PREC_DEFAULT);
|
||||
float* up_output_ptr = s_up_output_[act_idx] + ith * config_.m_block;
|
||||
llamafile_sgemm(config_.m_block, 1, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type),
|
||||
up_proj_ptr, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_input_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_output_ptr,
|
||||
config_.m_block, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.up_type,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type, GGML_TYPE_F32,
|
||||
GGML_PREC_DEFAULT);
|
||||
|
||||
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
|
||||
s_intermediate_fp32_[act_idx][i] = act_fn(s_gate_output_[act_idx][i]) * s_up_output_[act_idx][i];
|
||||
}
|
||||
if (config_.m_block %
|
||||
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) ==
|
||||
0) {
|
||||
float* intermediate_fp32_ptr = s_intermediate_fp32_[act_idx] + ith * config_.m_block;
|
||||
void* down_input_ptr =
|
||||
s_down_input_[act_idx] +
|
||||
ith * config_.m_block *
|
||||
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /
|
||||
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
|
||||
from_float(intermediate_fp32_ptr, down_input_ptr, config_.m_block,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
|
||||
s_intermediate_fp32_[act_idx][i] = act_fn(s_gate_output_[act_idx][i]) * s_up_output_[act_idx][i];
|
||||
}
|
||||
if (config_.m_block %
|
||||
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) ==
|
||||
0) {
|
||||
float* intermediate_fp32_ptr = s_intermediate_fp32_[act_idx] + ith * config_.m_block;
|
||||
void* down_input_ptr =
|
||||
s_down_input_[act_idx] +
|
||||
ith * config_.m_block *
|
||||
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /
|
||||
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
|
||||
from_float(intermediate_fp32_ptr, down_input_ptr, config_.m_block,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
|
||||
if (config_.m_block % ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) !=
|
||||
@@ -795,22 +795,21 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
|
||||
|
||||
void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
|
||||
|
||||
void merge_results(int qlen, void *output, bool incremental) {
|
||||
void merge_results(int qlen, void* output, bool incremental) {
|
||||
auto pool = this->config.pool;
|
||||
pool->do_work_stealing_job(
|
||||
qlen, nullptr,
|
||||
[this, output, incremental](int token_nth) {
|
||||
if (incremental) {
|
||||
to_float((uint8_t *)output + token_nth * config.hidden_size *
|
||||
ggml_type_size((ggml_type)config.hidden_type) /
|
||||
ggml_blck_size((ggml_type)config.hidden_type),
|
||||
to_float((uint8_t*)output + token_nth * config.hidden_size * ggml_type_size((ggml_type)config.hidden_type) /
|
||||
ggml_blck_size((ggml_type)config.hidden_type),
|
||||
local_output + token_nth * config.hidden_size, config.hidden_size, (ggml_type)config.hidden_type);
|
||||
for (int e = 0; e < config.hidden_size; e++) {
|
||||
local_output_numa[0][token_nth * config.hidden_size + e] +=
|
||||
local_output[token_nth * config.hidden_size + e];
|
||||
}
|
||||
}
|
||||
auto &tp_count = this->tp_count;
|
||||
auto& tp_count = this->tp_count;
|
||||
for (int i = 1; i < tp_count; i++) {
|
||||
for (int e = 0; e < config.hidden_size; e++) {
|
||||
local_output_numa[0][token_nth * config.hidden_size + e] +=
|
||||
@@ -818,9 +817,8 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
|
||||
}
|
||||
}
|
||||
from_float(local_output_numa[0] + token_nth * config.hidden_size,
|
||||
(uint8_t *)output + token_nth * config.hidden_size *
|
||||
ggml_type_size((ggml_type)config.hidden_type) /
|
||||
ggml_blck_size((ggml_type)config.hidden_type),
|
||||
(uint8_t*)output + token_nth * config.hidden_size * ggml_type_size((ggml_type)config.hidden_type) /
|
||||
ggml_blck_size((ggml_type)config.hidden_type),
|
||||
config.hidden_size, (ggml_type)config.hidden_type);
|
||||
},
|
||||
nullptr);
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <type_traits>
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
// Forward declaration for Llamafile backend type checking
|
||||
@@ -29,7 +30,7 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
std::vector<std::unique_ptr<T>> tps;
|
||||
|
||||
std::vector<typename T::output_t*> local_output_numa;
|
||||
T::output_t *local_output = nullptr;
|
||||
T::output_t* local_output = nullptr;
|
||||
|
||||
bool weights_loaded = false;
|
||||
|
||||
@@ -57,18 +58,17 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
"multiple of NUMA node count");
|
||||
}
|
||||
|
||||
|
||||
// Check if this is Llamafile backend using compile-time type checking
|
||||
constexpr bool is_llamafile = std::is_same<T, LLAMA_MOE_TP>::value;
|
||||
#ifndef QK_K
|
||||
#define QK_K 256
|
||||
#endif
|
||||
#ifndef QK_K
|
||||
#define QK_K 256
|
||||
#endif
|
||||
|
||||
if (is_llamafile) {
|
||||
// For Llamafile backend: use QK_K-aligned TP splitting
|
||||
if (config.intermediate_size % QK_K != 0) {
|
||||
printf("intermediate_size %d must be divisible by QK_K %d for Llamafile backend\n",
|
||||
config.intermediate_size, QK_K);
|
||||
printf("intermediate_size %d must be divisible by QK_K %d for Llamafile backend\n", config.intermediate_size,
|
||||
QK_K);
|
||||
throw std::runtime_error("intermediate_size must be divisible by QK_K (256) for Llamafile backend");
|
||||
}
|
||||
|
||||
@@ -77,13 +77,13 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
int extra_blocks = num_blocks % tp_count;
|
||||
|
||||
if (base_blocks == 0) {
|
||||
printf("intermediate_size %d is too small for tp_count %d (num_blocks=%d)\n",
|
||||
config.intermediate_size, tp_count, num_blocks);
|
||||
printf("intermediate_size %d is too small for tp_count %d (num_blocks=%d)\n", config.intermediate_size,
|
||||
tp_count, num_blocks);
|
||||
throw std::runtime_error("intermediate_size too small: cannot distribute blocks to all TP instances");
|
||||
}
|
||||
|
||||
printf("Llamafile TP splitting: intermediate_size=%d, tp_count=%d, QK_K=%d\n",
|
||||
config.intermediate_size, tp_count, QK_K);
|
||||
printf("Llamafile TP splitting: intermediate_size=%d, tp_count=%d, QK_K=%d\n", config.intermediate_size, tp_count,
|
||||
QK_K);
|
||||
printf(" num_blocks=%d, base_blocks=%d, extra_blocks=%d\n", num_blocks, base_blocks, extra_blocks);
|
||||
|
||||
int current_offset = 0;
|
||||
@@ -95,8 +95,8 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
int num_blocks_for_this_tp = base_blocks + (i < extra_blocks ? 1 : 0);
|
||||
tp_config.intermediate_size = num_blocks_for_this_tp * QK_K;
|
||||
|
||||
printf(" TP %d: intermediate_size=%d, offset=%d, blocks=%d\n",
|
||||
i, tp_config.intermediate_size, current_offset, num_blocks_for_this_tp);
|
||||
printf(" TP %d: intermediate_size=%d, offset=%d, blocks=%d\n", i, tp_config.intermediate_size, current_offset,
|
||||
num_blocks_for_this_tp);
|
||||
|
||||
tp_configs.push_back(tp_config);
|
||||
current_offset += tp_config.intermediate_size;
|
||||
@@ -128,8 +128,9 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
&local_output_numa[i],
|
||||
(size_t)sizeof(typename T::output_t) * tp_configs[i].max_possible_qlen() * tp_configs[i].hidden_size);
|
||||
}
|
||||
mem_requests.append_pointer((void **)&local_output, sizeof(typename T::output_t) * tp_configs[0].max_possible_qlen() *
|
||||
tp_configs[0].hidden_size);
|
||||
mem_requests.append_pointer(
|
||||
(void**)&local_output,
|
||||
sizeof(typename T::output_t) * tp_configs[0].max_possible_qlen() * tp_configs[0].hidden_size);
|
||||
// printf("local output tp, %d,\n", tp_configs[0].max_possible_qlen());
|
||||
shared_mem_buffer.alloc(this, mem_requests);
|
||||
}
|
||||
@@ -204,7 +205,6 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
|
||||
virtual void load_weights() = 0;
|
||||
|
||||
|
||||
virtual void merge_results(int qlen, void* output) = 0;
|
||||
|
||||
virtual void merge_results(int qlen, void* output, bool incremental) {
|
||||
|
||||
@@ -6,8 +6,8 @@ KT-Kernel provides high-performance kernel operations for KTransformers,
|
||||
including CPU-optimized MoE inference with AMX, AVX, and KML support.
|
||||
|
||||
Example usage:
|
||||
>>> from kt_kernel import AMXMoEWrapper
|
||||
>>> wrapper = AMXMoEWrapper(
|
||||
>>> from kt_kernel import KTMoEWrapper
|
||||
>>> wrapper = KTMoEWrapper(
|
||||
... layer_idx=0,
|
||||
... num_experts=8,
|
||||
... num_experts_per_tok=2,
|
||||
@@ -15,9 +15,10 @@ Example usage:
|
||||
... moe_intermediate_size=14336,
|
||||
... num_gpu_experts=2,
|
||||
... cpuinfer_threads=32,
|
||||
... subpool_count=2,
|
||||
... amx_weight_path="/path/to/weights",
|
||||
... chunked_prefill_size=512
|
||||
... threadpool_count=2,
|
||||
... weight_path="/path/to/weights",
|
||||
... chunked_prefill_size=512,
|
||||
... method="AMXINT4"
|
||||
... )
|
||||
"""
|
||||
|
||||
|
||||
@@ -18,13 +18,13 @@ import ctypes
|
||||
import kt_kernel_ext
|
||||
|
||||
|
||||
|
||||
class KExpertsCPUBuffer:
|
||||
"""
|
||||
CPU buffer management for expert computation.
|
||||
|
||||
Manages pinned memory buffers for efficient GPU-CPU data transfer.
|
||||
"""
|
||||
|
||||
capture_bs: List = list()
|
||||
capture_buffers: Dict = dict()
|
||||
temp_bs: int = 0
|
||||
@@ -62,8 +62,7 @@ class KExpertsCPUBuffer:
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
bsz_tensor_cpu = [
|
||||
torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True)
|
||||
for _ in range(cls.buffer_depth)
|
||||
torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True) for _ in range(cls.buffer_depth)
|
||||
]
|
||||
output_gpu = [
|
||||
torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
@@ -129,7 +128,6 @@ class BaseMoEWrapper(ABC):
|
||||
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
|
||||
method: Backend method string
|
||||
"""
|
||||
print(f"Init {self.__class__.__name__}")
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = num_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
@@ -139,7 +137,9 @@ class BaseMoEWrapper(ABC):
|
||||
self.weight_path = weight_path
|
||||
self.chunked_prefill_size = chunked_prefill_size
|
||||
self.cpu_save = cpu_save
|
||||
self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
|
||||
self.max_deferred_experts_per_token = (
|
||||
int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
|
||||
)
|
||||
|
||||
BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
|
||||
self.method = method
|
||||
|
||||
@@ -6,15 +6,17 @@ import ctypes
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
|
||||
|
||||
_HAS_AMX_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SUPPORT = False
|
||||
AMXInt4_MOE, AMXInt8_MOE = None, None
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
|
||||
class AMXMoEWrapper(BaseMoEWrapper):
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import GGUFLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import MOE
|
||||
|
||||
_HAS_LLAMAFILE_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_LLAMAFILE_SUPPORT = False
|
||||
@@ -14,6 +17,7 @@ except (ImportError, AttributeError):
|
||||
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
|
||||
|
||||
class LlamafileMoEWrapper(BaseMoEWrapper):
|
||||
"""
|
||||
Llamafile-based MoE wrapper implementation.
|
||||
@@ -162,27 +166,17 @@ class LlamafileMoEWrapper(BaseMoEWrapper):
|
||||
)
|
||||
|
||||
if physical_to_logical_map_cpu is None:
|
||||
physical_to_logical_map_cpu = torch.arange(
|
||||
self.num_experts,
|
||||
dtype=torch.int32,
|
||||
device="cpu"
|
||||
)
|
||||
physical_to_logical_map_cpu = torch.arange(self.num_experts, dtype=torch.int32, device="cpu")
|
||||
print(f" Using default identity mapping for {self.num_experts} experts")
|
||||
|
||||
base_key = f"blk.{self.layer_idx}"
|
||||
|
||||
# Load quantized tensors from GGUF
|
||||
gate_data, gate_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
|
||||
f"{base_key}.ffn_gate_exps.weight"
|
||||
)
|
||||
gate_data, gate_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f"{base_key}.ffn_gate_exps.weight")
|
||||
|
||||
up_data, up_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
|
||||
f"{base_key}.ffn_up_exps.weight"
|
||||
)
|
||||
up_data, up_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f"{base_key}.ffn_up_exps.weight")
|
||||
|
||||
down_data, down_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
|
||||
f"{base_key}.ffn_down_exps.weight"
|
||||
)
|
||||
down_data, down_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f"{base_key}.ffn_down_exps.weight")
|
||||
|
||||
# Keep tensors alive
|
||||
self.weights_to_keep = (gate_data, up_data, down_data)
|
||||
|
||||
@@ -18,35 +18,36 @@ from gguf.gguf_reader import GGUFReader
|
||||
|
||||
class GGMLQuantizationType(IntEnum):
|
||||
"""GGML quantization type enumeration"""
|
||||
F32 = 0
|
||||
F16 = 1
|
||||
Q4_0 = 2
|
||||
Q4_1 = 3
|
||||
Q5_0 = 6
|
||||
Q5_1 = 7
|
||||
Q8_0 = 8
|
||||
Q8_1 = 9
|
||||
Q2_K = 10
|
||||
Q3_K = 11
|
||||
Q4_K = 12
|
||||
Q5_K = 13
|
||||
Q6_K = 14
|
||||
Q8_K = 15
|
||||
|
||||
F32 = 0
|
||||
F16 = 1
|
||||
Q4_0 = 2
|
||||
Q4_1 = 3
|
||||
Q5_0 = 6
|
||||
Q5_1 = 7
|
||||
Q8_0 = 8
|
||||
Q8_1 = 9
|
||||
Q2_K = 10
|
||||
Q3_K = 11
|
||||
Q4_K = 12
|
||||
Q5_K = 13
|
||||
Q6_K = 14
|
||||
Q8_K = 15
|
||||
IQ2_XXS = 16
|
||||
IQ2_XS = 17
|
||||
IQ2_XS = 17
|
||||
IQ3_XXS = 18
|
||||
IQ1_S = 19
|
||||
IQ4_NL = 20
|
||||
IQ3_S = 21
|
||||
IQ2_S = 22
|
||||
IQ4_XS = 23
|
||||
I8 = 24
|
||||
I16 = 25
|
||||
I32 = 26
|
||||
I64 = 27
|
||||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
IQ1_S = 19
|
||||
IQ4_NL = 20
|
||||
IQ3_S = 21
|
||||
IQ2_S = 22
|
||||
IQ4_XS = 23
|
||||
I8 = 24
|
||||
I16 = 25
|
||||
I32 = 26
|
||||
I64 = 27
|
||||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
|
||||
|
||||
def translate_name_to_gguf(name):
|
||||
@@ -104,6 +105,7 @@ class SafeTensorLoader:
|
||||
|
||||
Supports loading tensors from .safetensors files with NUMA-sharded expert weights.
|
||||
"""
|
||||
|
||||
tensor_file_map: dict
|
||||
tensor_type_map: dict
|
||||
file_handle_map: dict
|
||||
@@ -257,7 +259,7 @@ class GGUFLoader:
|
||||
self.tensor_file_map = {}
|
||||
self.file_data_map = {}
|
||||
|
||||
if os.path.isfile(gguf_path) and gguf_path.endswith('.gguf'):
|
||||
if os.path.isfile(gguf_path) and gguf_path.endswith(".gguf"):
|
||||
print(f"\n[GGUFLoader] Loading single GGUF file : {os.path.basename(gguf_path)}")
|
||||
self._load_single_file(gguf_path)
|
||||
elif os.path.isdir(gguf_path):
|
||||
@@ -283,24 +285,24 @@ class GGUFLoader:
|
||||
for key, field in reader.fields.items():
|
||||
value = field.parts[field.data[0]]
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8')
|
||||
value = value.decode("utf-8")
|
||||
elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
|
||||
try:
|
||||
value = bytes(value).decode('utf-8')
|
||||
value = bytes(value).decode("utf-8")
|
||||
except:
|
||||
pass
|
||||
self.metadata[key] = value
|
||||
|
||||
for tensor in reader.tensors:
|
||||
self.tensor_info[tensor.name] = {
|
||||
'shape': list(reversed(tensor.shape)), # Reverse to match PyTorch order
|
||||
'dtype': tensor.tensor_type,
|
||||
'offset': tensor.data_offset,
|
||||
'n_elements': tensor.n_elements,
|
||||
"shape": list(reversed(tensor.shape)), # Reverse to match PyTorch order
|
||||
"dtype": tensor.tensor_type,
|
||||
"offset": tensor.data_offset,
|
||||
"n_elements": tensor.n_elements,
|
||||
}
|
||||
self.tensor_file_map[tensor.name] = file_path
|
||||
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode='r')
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode="r")
|
||||
|
||||
def _load_directory(self, dir_path: str):
|
||||
"""Load all GGUF files from a directory (non-recursive)"""
|
||||
@@ -317,24 +319,24 @@ class GGUFLoader:
|
||||
for key, field in reader.fields.items():
|
||||
value = field.parts[field.data[0]]
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8')
|
||||
value = value.decode("utf-8")
|
||||
elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
|
||||
try:
|
||||
value = bytes(value).decode('utf-8')
|
||||
value = bytes(value).decode("utf-8")
|
||||
except:
|
||||
pass
|
||||
self.metadata[key] = value
|
||||
|
||||
for tensor in reader.tensors:
|
||||
self.tensor_info[tensor.name] = {
|
||||
'shape': list(reversed(tensor.shape)),
|
||||
'dtype': tensor.tensor_type,
|
||||
'offset': tensor.data_offset,
|
||||
'n_elements': tensor.n_elements,
|
||||
"shape": list(reversed(tensor.shape)),
|
||||
"dtype": tensor.tensor_type,
|
||||
"offset": tensor.data_offset,
|
||||
"n_elements": tensor.n_elements,
|
||||
}
|
||||
self.tensor_file_map[tensor.name] = file_path
|
||||
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode='r')
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode="r")
|
||||
|
||||
if not found_gguf:
|
||||
raise FileNotFoundError(f"No .gguf files found in directory: {dir_path}")
|
||||
@@ -407,7 +409,7 @@ class GGUFLoader:
|
||||
|
||||
base_key = f"blk.{layer_idx}.ffn_gate_exps.weight"
|
||||
if base_key in self.tensor_info:
|
||||
gate_shape = self.tensor_info[base_key]['shape']
|
||||
gate_shape = self.tensor_info[base_key]["shape"]
|
||||
print(f" Found tensor '{base_key}' with shape: {gate_shape}")
|
||||
|
||||
if len(gate_shape) >= 3:
|
||||
@@ -438,8 +440,9 @@ class GGUFLoader:
|
||||
print(f" Total metadata entries: {len(self.metadata)}")
|
||||
|
||||
if filter_keywords:
|
||||
filtered = {k: v for k, v in self.metadata.items()
|
||||
if any(kw.lower() in k.lower() for kw in filter_keywords)}
|
||||
filtered = {
|
||||
k: v for k, v in self.metadata.items() if any(kw.lower() in k.lower() for kw in filter_keywords)
|
||||
}
|
||||
for k, v in sorted(filtered.items()):
|
||||
print(f" {k}: {v}")
|
||||
else:
|
||||
@@ -477,40 +480,40 @@ class GGUFLoader:
|
||||
file_path = self.tensor_file_map[name]
|
||||
mmap_data = self.file_data_map[file_path]
|
||||
|
||||
offset = info['offset']
|
||||
n_elements = info['n_elements']
|
||||
ggml_type = info['dtype']
|
||||
offset = info["offset"]
|
||||
n_elements = info["n_elements"]
|
||||
ggml_type = info["dtype"]
|
||||
|
||||
GGML_QUANT_SIZES = {
|
||||
GGMLQuantizationType.F32: (1, 4),
|
||||
GGMLQuantizationType.F16: (1, 2),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
GGMLQuantizationType.Q4_0: (32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
|
||||
GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q8_0: (32, 2 + 32),
|
||||
GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
|
||||
GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),
|
||||
GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),
|
||||
GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),
|
||||
GGMLQuantizationType.F32: (1, 4),
|
||||
GGMLQuantizationType.F16: (1, 2),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
GGMLQuantizationType.Q4_0: (32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
|
||||
GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q8_0: (32, 2 + 32),
|
||||
GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
|
||||
GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),
|
||||
GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),
|
||||
GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),
|
||||
GGMLQuantizationType.IQ2_XXS: (256, 2 + 256 // 4),
|
||||
GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),
|
||||
GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),
|
||||
GGMLQuantizationType.IQ3_XXS: (256, 2 + 256 // 4 + 256 // 8),
|
||||
GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
|
||||
GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),
|
||||
GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),
|
||||
GGMLQuantizationType.I8: (1, 1),
|
||||
GGMLQuantizationType.I16: (1, 2),
|
||||
GGMLQuantizationType.I32: (1, 4),
|
||||
GGMLQuantizationType.I64: (1, 8),
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),
|
||||
GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
|
||||
GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),
|
||||
GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),
|
||||
GGMLQuantizationType.I8: (1, 1),
|
||||
GGMLQuantizationType.I16: (1, 2),
|
||||
GGMLQuantizationType.I32: (1, 4),
|
||||
GGMLQuantizationType.I64: (1, 8),
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),
|
||||
}
|
||||
|
||||
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
|
||||
|
||||
Reference in New Issue
Block a user