mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
Refactor KTMoEWrapper backend (#1587)
* universal backend for cpu inference * expert defer
This commit is contained in:
42
kt-kernel/install.sh
Executable file
42
kt-kernel/install.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env bash
# Build/install helper for kt-kernel CPU inference.
# Selects the instruction-set backend (AVX2 or AMX512) via environment
# variables consumed by the build, then runs an editable pip install.
#
# Usage: ./install.sh [avx|amx]

# -e: abort on any command failure; -u: treat unset variables as errors;
# -o pipefail: a failing command anywhere in a pipeline fails the pipeline.
set -euo pipefail

# Print usage to stderr and exit non-zero.
usage() {
    echo "Usage: $0 [avx|amx]" >&2
    exit 1
}

# Exactly one positional argument (the mode) is required.
if [ $# -ne 1 ]; then
    usage
fi

MODE="$1"
case "$MODE" in
    avx)
        # Generic AVX2 path; AMX kernels disabled.
        export CPUINFER_CPU_INSTRUCT=AVX2
        export CPUINFER_ENABLE_AMX=OFF
        ;;
    amx)
        # Intel AMX path (requires AMX-capable CPU).
        export CPUINFER_CPU_INSTRUCT=AMX512
        export CPUINFER_ENABLE_AMX=ON
        ;;
    *)
        # Unknown mode: report on stderr, then show usage and exit.
        echo "Error: unknown mode '$MODE'" >&2
        usage
        ;;
esac

# Common build settings regardless of mode.
export CPUINFER_BUILD_TYPE=Release
export CPUINFER_PARALLEL=8
export CPUINFER_VERBOSE=1

echo "Building in mode: $MODE"
echo "Environment:"
echo "  CPUINFER_CPU_INSTRUCT=$CPUINFER_CPU_INSTRUCT"
echo "  CPUINFER_ENABLE_AMX=$CPUINFER_ENABLE_AMX"
echo "  CPUINFER_BUILD_TYPE=$CPUINFER_BUILD_TYPE"
echo "  CPUINFER_PARALLEL=$CPUINFER_PARALLEL"
echo "  CPUINFER_VERBOSE=$CPUINFER_VERBOSE"

# Editable install; the build backend picks up the CPUINFER_* variables.
pip install -e . -v
|
||||
|
||||
@@ -315,21 +315,35 @@ class LLAMA_MOE_TP {
|
||||
|
||||
#endif
|
||||
|
||||
int activated_expert = 0;
|
||||
for (int i = 0; i < k; i++) {
|
||||
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
m_expert_id_map_[activated_expert] = expert_ids[i];
|
||||
activated_expert++;
|
||||
}
|
||||
|
||||
int nth = config_.intermediate_size / config_.m_block;
|
||||
|
||||
pool->do_work_stealing_job(
|
||||
nth * k, nullptr,
|
||||
[&](int task_id) {
|
||||
int expert_idx = task_id / nth;
|
||||
int64_t expert_id = expert_ids[expert_idx];
|
||||
int ith = task_id % nth;
|
||||
// Only process activated (CPU) experts; skip GPU experts entirely to keep buffers aligned.
|
||||
if (activated_expert > 0) {
|
||||
pool->do_work_stealing_job(
|
||||
nth * activated_expert, nullptr,
|
||||
[&](int task_id) {
|
||||
int act_idx = task_id / nth;
|
||||
int64_t expert_id = m_expert_id_map_[act_idx];
|
||||
if (expert_id == -1) {
|
||||
return;
|
||||
}
|
||||
int ith = task_id % nth;
|
||||
|
||||
void* gate_proj_ptr =
|
||||
(uint8_t*)m_local_gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
|
||||
config_.hidden_size * ggml_type_size((ggml_type)config_.gate_type) /
|
||||
ggml_blck_size((ggml_type)config_.gate_type);
|
||||
|
||||
float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.m_block;
|
||||
float* gate_output_ptr = s_gate_output_[act_idx] + ith * config_.m_block;
|
||||
auto ok = llamafile_sgemm(config_.m_block, 1,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_proj_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_input_ptr,
|
||||
@@ -340,33 +354,29 @@ class LLAMA_MOE_TP {
|
||||
if (ok == false) [[unlikely]] {
|
||||
throw std::runtime_error("llamafile not supported");
|
||||
}
|
||||
// printf("gate output: ");
|
||||
// debug_f32(gate_output_ptr);
|
||||
|
||||
void* up_proj_ptr =
|
||||
(uint8_t*)m_local_up_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *
|
||||
config_.hidden_size * ggml_type_size((ggml_type)config_.up_type) /
|
||||
ggml_blck_size((ggml_type)config_.up_type);
|
||||
|
||||
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.m_block;
|
||||
float* up_output_ptr = s_up_output_[act_idx] + ith * config_.m_block;
|
||||
llamafile_sgemm(config_.m_block, 1, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type),
|
||||
up_proj_ptr, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_input_ptr,
|
||||
config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_output_ptr,
|
||||
config_.m_block, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.up_type,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type, GGML_TYPE_F32,
|
||||
GGML_PREC_DEFAULT);
|
||||
// printf("up output: ");
|
||||
// debug_f32(up_output_ptr);
|
||||
|
||||
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
|
||||
s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
|
||||
s_intermediate_fp32_[act_idx][i] = act_fn(s_gate_output_[act_idx][i]) * s_up_output_[act_idx][i];
|
||||
}
|
||||
if (config_.m_block %
|
||||
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) ==
|
||||
0) {
|
||||
float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.m_block;
|
||||
float* intermediate_fp32_ptr = s_intermediate_fp32_[act_idx] + ith * config_.m_block;
|
||||
void* down_input_ptr =
|
||||
s_down_input_[expert_idx] +
|
||||
s_down_input_[act_idx] +
|
||||
ith * config_.m_block *
|
||||
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /
|
||||
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
|
||||
@@ -375,10 +385,11 @@ class LLAMA_MOE_TP {
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
|
||||
if (config_.m_block % ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) !=
|
||||
0) {
|
||||
for (int i = 0; i < k; i++) {
|
||||
for (int i = 0; i < activated_expert; i++) {
|
||||
from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size,
|
||||
ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);
|
||||
}
|
||||
@@ -400,8 +411,11 @@ class LLAMA_MOE_TP {
|
||||
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
|
||||
output[i] = 0;
|
||||
}
|
||||
for (int expert_idx = 0; expert_idx < k; expert_idx++) {
|
||||
int64_t expert_id = expert_ids[expert_idx];
|
||||
for (int expert_idx = 0; expert_idx < activated_expert; expert_idx++) {
|
||||
int64_t expert_id = m_expert_id_map_[expert_idx];
|
||||
if (expert_id == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto expert_offset = expert_id * config_.hidden_size * config_.intermediate_size;
|
||||
auto m_block_offset = ith * config_.m_block * config_.intermediate_size;
|
||||
@@ -418,8 +432,16 @@ class LLAMA_MOE_TP {
|
||||
ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type, GGML_TYPE_F32,
|
||||
GGML_PREC_DEFAULT);
|
||||
|
||||
float expert_weight = 0.0f;
|
||||
for (int j = 0; j < k; j++) {
|
||||
if (expert_ids[j] == expert_id) {
|
||||
expert_weight = weights[j];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {
|
||||
output[i] += s_down_output_[expert_idx][i] * weights[expert_idx];
|
||||
output[i] += s_down_output_[expert_idx][i] * expert_weight;
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -452,6 +474,12 @@ class LLAMA_MOE_TP {
|
||||
}
|
||||
for (int i = 0; i < qlen; i++) {
|
||||
for (int j = 0; j < k; j++) {
|
||||
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
if (expert_ids[i * k + j] == -1) {
|
||||
continue;
|
||||
}
|
||||
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
|
||||
}
|
||||
}
|
||||
@@ -539,6 +567,12 @@ class LLAMA_MOE_TP {
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < k; j++) {
|
||||
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
if (expert_ids[i * k + j] == -1) {
|
||||
continue;
|
||||
}
|
||||
memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] +
|
||||
m_local_pos_[i][j] * config_.hidden_size *
|
||||
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /
|
||||
@@ -683,6 +717,12 @@ class LLAMA_MOE_TP {
|
||||
m_output_fp32_[i][e] = 0;
|
||||
}
|
||||
for (int j = 0; j < k; j++) {
|
||||
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
if (expert_ids[i * k + j] == -1) {
|
||||
continue;
|
||||
}
|
||||
for (int e = 0; e < config_.hidden_size; e++) {
|
||||
m_output_fp32_[i][e] +=
|
||||
m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] *
|
||||
@@ -739,24 +779,38 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
|
||||
|
||||
void load_weights() {
|
||||
auto pool = this->config.pool;
|
||||
auto inter = this->config.intermediate_size / this->tp_count;
|
||||
pool->dispense_backend()->do_numa_job([this, pool, inter](int tp_id) {
|
||||
this->tps[tp_id]->load_weights(this->config.intermediate_size, tp_id * inter);
|
||||
|
||||
std::vector<int> tp_offsets(this->tp_count);
|
||||
int accumulated_offset = 0;
|
||||
for (int i = 0; i < this->tp_count; i++) {
|
||||
tp_offsets[i] = accumulated_offset;
|
||||
accumulated_offset += this->tp_configs[i].intermediate_size;
|
||||
}
|
||||
|
||||
pool->dispense_backend()->do_numa_job([this, pool, tp_offsets](int tp_id) {
|
||||
this->tps[tp_id]->load_weights(this->config.intermediate_size, tp_offsets[tp_id]);
|
||||
});
|
||||
this->weights_loaded = true;
|
||||
}
|
||||
|
||||
// Non-incremental convenience overload: forwards to the three-argument
// merge_results with incremental=false.
void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
|
||||
|
||||
void merge_results(int qlen, void* output, bool incremental) {
|
||||
if (incremental) {
|
||||
throw std::runtime_error("Not Implemented");
|
||||
}
|
||||
void merge_results(int qlen, void *output, bool incremental) {
|
||||
auto pool = this->config.pool;
|
||||
pool->do_work_stealing_job(
|
||||
qlen, nullptr,
|
||||
[this, output](int token_nth) {
|
||||
auto& tp_count = this->tp_count;
|
||||
[this, output, incremental](int token_nth) {
|
||||
if (incremental) {
|
||||
to_float((uint8_t *)output + token_nth * config.hidden_size *
|
||||
ggml_type_size((ggml_type)config.hidden_type) /
|
||||
ggml_blck_size((ggml_type)config.hidden_type),
|
||||
local_output + token_nth * config.hidden_size, config.hidden_size, (ggml_type)config.hidden_type);
|
||||
for (int e = 0; e < config.hidden_size; e++) {
|
||||
local_output_numa[0][token_nth * config.hidden_size + e] +=
|
||||
local_output[token_nth * config.hidden_size + e];
|
||||
}
|
||||
}
|
||||
auto &tp_count = this->tp_count;
|
||||
for (int i = 1; i < tp_count; i++) {
|
||||
for (int e = 0; e < config.hidden_size; e++) {
|
||||
local_output_numa[0][token_nth * config.hidden_size + e] +=
|
||||
@@ -764,11 +818,12 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
|
||||
}
|
||||
}
|
||||
from_float(local_output_numa[0] + token_nth * config.hidden_size,
|
||||
(uint8_t*)output + token_nth * config.hidden_size * ggml_type_size((ggml_type)config.hidden_type) /
|
||||
ggml_blck_size((ggml_type)config.hidden_type),
|
||||
(uint8_t *)output + token_nth * config.hidden_size *
|
||||
ggml_type_size((ggml_type)config.hidden_type) /
|
||||
ggml_blck_size((ggml_type)config.hidden_type),
|
||||
config.hidden_size, (ggml_type)config.hidden_type);
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -5,9 +5,12 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include <type_traits>
|
||||
#include "common.hpp"
|
||||
|
||||
// Forward declaration for Llamafile backend type checking
|
||||
class LLAMA_MOE_TP;
|
||||
|
||||
template <typename T>
|
||||
concept MOE_TP_PART = requires(T t, int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
|
||||
void* output, GeneralMOEConfig config, int tp_idx) {
|
||||
@@ -26,6 +29,7 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
std::vector<std::unique_ptr<T>> tps;
|
||||
|
||||
std::vector<typename T::output_t*> local_output_numa;
|
||||
T::output_t *local_output = nullptr;
|
||||
|
||||
bool weights_loaded = false;
|
||||
|
||||
@@ -53,11 +57,65 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
"multiple of NUMA node count");
|
||||
}
|
||||
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
tps.push_back(nullptr);
|
||||
GeneralMOEConfig tp_config = config;
|
||||
tp_config.intermediate_size /= tp_count;
|
||||
tp_configs.push_back(tp_config);
|
||||
|
||||
// Check if this is Llamafile backend using compile-time type checking
|
||||
constexpr bool is_llamafile = std::is_same<T, LLAMA_MOE_TP>::value;
|
||||
#ifndef QK_K
|
||||
#define QK_K 256
|
||||
#endif
|
||||
|
||||
if (is_llamafile) {
|
||||
// For Llamafile backend: use QK_K-aligned TP splitting
|
||||
if (config.intermediate_size % QK_K != 0) {
|
||||
printf("intermediate_size %d must be divisible by QK_K %d for Llamafile backend\n",
|
||||
config.intermediate_size, QK_K);
|
||||
throw std::runtime_error("intermediate_size must be divisible by QK_K (256) for Llamafile backend");
|
||||
}
|
||||
|
||||
int num_blocks = config.intermediate_size / QK_K;
|
||||
int base_blocks = num_blocks / tp_count;
|
||||
int extra_blocks = num_blocks % tp_count;
|
||||
|
||||
if (base_blocks == 0) {
|
||||
printf("intermediate_size %d is too small for tp_count %d (num_blocks=%d)\n",
|
||||
config.intermediate_size, tp_count, num_blocks);
|
||||
throw std::runtime_error("intermediate_size too small: cannot distribute blocks to all TP instances");
|
||||
}
|
||||
|
||||
printf("Llamafile TP splitting: intermediate_size=%d, tp_count=%d, QK_K=%d\n",
|
||||
config.intermediate_size, tp_count, QK_K);
|
||||
printf(" num_blocks=%d, base_blocks=%d, extra_blocks=%d\n", num_blocks, base_blocks, extra_blocks);
|
||||
|
||||
int current_offset = 0;
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
tps.push_back(nullptr);
|
||||
GeneralMOEConfig tp_config = config;
|
||||
|
||||
// First extra_blocks TPs get one more block
|
||||
int num_blocks_for_this_tp = base_blocks + (i < extra_blocks ? 1 : 0);
|
||||
tp_config.intermediate_size = num_blocks_for_this_tp * QK_K;
|
||||
|
||||
printf(" TP %d: intermediate_size=%d, offset=%d, blocks=%d\n",
|
||||
i, tp_config.intermediate_size, current_offset, num_blocks_for_this_tp);
|
||||
|
||||
tp_configs.push_back(tp_config);
|
||||
current_offset += tp_config.intermediate_size;
|
||||
}
|
||||
} else {
|
||||
// For non-Llamafile backends: use simple equal division
|
||||
if (config.intermediate_size % tp_count != 0) {
|
||||
printf("intermediate_size %d, tp count %d\n", config.intermediate_size, tp_count);
|
||||
throw std::runtime_error(
|
||||
"For TP, intermediate_size must be a "
|
||||
"multiple of NUMA node count");
|
||||
}
|
||||
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
tps.push_back(nullptr);
|
||||
GeneralMOEConfig tp_config = config;
|
||||
tp_config.intermediate_size /= tp_count;
|
||||
tp_configs.push_back(tp_config);
|
||||
}
|
||||
}
|
||||
|
||||
config.pool->dispense_backend()->do_numa_job(
|
||||
@@ -70,6 +128,8 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
&local_output_numa[i],
|
||||
(size_t)sizeof(typename T::output_t) * tp_configs[i].max_possible_qlen() * tp_configs[i].hidden_size);
|
||||
}
|
||||
mem_requests.append_pointer((void **)&local_output, sizeof(typename T::output_t) * tp_configs[0].max_possible_qlen() *
|
||||
tp_configs[0].hidden_size);
|
||||
// printf("local output tp, %d,\n", tp_configs[0].max_possible_qlen());
|
||||
shared_mem_buffer.alloc(this, mem_requests);
|
||||
}
|
||||
@@ -144,7 +204,9 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
|
||||
virtual void load_weights() = 0;
|
||||
|
||||
|
||||
virtual void merge_results(int qlen, void* output) = 0;
|
||||
|
||||
virtual void merge_results(int qlen, void* output, bool incremental) {
|
||||
if (incremental == false) {
|
||||
merge_results(qlen, output);
|
||||
|
||||
@@ -23,7 +23,7 @@ Example usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .experts import AMXMoEWrapper
|
||||
from .experts import KTMoEWrapper
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = ["AMXMoEWrapper"]
|
||||
__all__ = ["KTMoEWrapper"]
|
||||
|
||||
@@ -1,220 +1,51 @@
|
||||
# Wrapper for AMX MoE CPU inference operations
|
||||
# Wrapper for MoE CPU inference operations
|
||||
# This module encapsulates CPU inference engine, weight loading, and buffer management
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
Expert wrappers for CPU-based MoE inference.
|
||||
|
||||
This module provides high-level Python wrappers around the low-level C++ kernel
|
||||
implementations, handling weight loading, buffer management, and forward inference.
|
||||
This module provides the main factory interface (KTMoEWrapper) that automatically
|
||||
selects the appropriate backend implementation based on the method parameter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from safetensors import safe_open
|
||||
import os
|
||||
import ctypes
|
||||
from typing import List, Optional
|
||||
|
||||
# Import the C++ extension module (compiled as kt_kernel_ext)
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.moe import MOEConfig, AMXInt4_MOE, AMXInt8_MOE
|
||||
# Import base infrastructure
|
||||
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
|
||||
|
||||
# Import backend implementations
|
||||
from .utils.amx import AMXMoEWrapper
|
||||
from .utils.llamafile import LlamafileMoEWrapper
|
||||
|
||||
|
||||
class SafeTensorLoader:
|
||||
tensor_file_map: dict
|
||||
tensor_type_map: dict
|
||||
file_handle_map: dict
|
||||
tensor_device_map: dict
|
||||
class KTMoEWrapper:
|
||||
"""
|
||||
Factory interface for MoE CPU inference operations.
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
self.__load_tensor_file_map(file_path)
|
||||
This class serves as the main entry point for external code. It automatically
|
||||
selects the appropriate backend implementation based on the `method` parameter.
|
||||
|
||||
def __load_tensor_file_map(self, file_path: str):
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Path not found: {file_path}")
|
||||
if os.path.isfile(file_path):
|
||||
folder_path = os.path.dirname(file_path)
|
||||
else:
|
||||
folder_path = file_path
|
||||
self.file_handle_map = {}
|
||||
self.tensor_file_map = {}
|
||||
self.tensor_type_map = {}
|
||||
self.tensor_device_map = {}
|
||||
|
||||
found_safetensor = False
|
||||
for root, _, files in os.walk(folder_path):
|
||||
files = sorted(files)
|
||||
for file in files:
|
||||
if file.endswith(".safetensors"):
|
||||
found_safetensor = True
|
||||
file_path = os.path.join(root, file)
|
||||
if file not in self.file_handle_map:
|
||||
try:
|
||||
handle = safe_open(file_path, framework="pt")
|
||||
self.file_handle_map[file] = handle
|
||||
except Exception as e:
|
||||
print(f"Error opening Safetensor file {file_path}: {e}")
|
||||
continue
|
||||
|
||||
f = self.file_handle_map.get(file)
|
||||
if f is None:
|
||||
continue
|
||||
try:
|
||||
for key in f.keys():
|
||||
self.tensor_file_map[key] = file
|
||||
except Exception as e:
|
||||
print(f"Error reading Safetensor file {file_path}: {e}")
|
||||
|
||||
if not found_safetensor:
|
||||
raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
|
||||
|
||||
def load_tensor(self, key: str, device: str = "cpu"):
|
||||
if key not in self.tensor_file_map:
|
||||
raise KeyError(f"Key {key} not found in Safetensor files")
|
||||
file = self.tensor_file_map[key]
|
||||
f = self.file_handle_map.get(file)
|
||||
if f is None:
|
||||
raise FileNotFoundError(f"File {file} not found in Safetensor files")
|
||||
tensor = f.get_tensor(key)
|
||||
return tensor.to(device)
|
||||
|
||||
def close_all_handles(self):
|
||||
for handle in self.file_handle_map.values():
|
||||
handle.close()
|
||||
self.file_handle_map.clear()
|
||||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||
# base_key: blk.{layer_index}
|
||||
# blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.weight
|
||||
up_base_key = f"{base_key}.ffn_up_exps"
|
||||
gate_base_key = f"{base_key}.ffn_gate_exps"
|
||||
down_base_key = f"{base_key}.ffn_down_exps"
|
||||
max_numa_id = -1
|
||||
max_experts_count = -1
|
||||
while self.has_tensor(f"{up_base_key}.{max_experts_count+1}.numa.{0}.weight"):
|
||||
max_experts_count += 1
|
||||
if max_experts_count == 0:
|
||||
raise ValueError(f"No experts found for key {base_key}")
|
||||
while self.has_tensor(f"{up_base_key}.{0}.numa.{max_numa_id+1}.weight"):
|
||||
max_numa_id += 1
|
||||
# Initialize empty lists to store tensors for each projection type
|
||||
up_weights = [[] for _ in range(max_numa_id + 1)]
|
||||
gate_weights = [[] for _ in range(max_numa_id + 1)]
|
||||
down_weights = [[] for _ in range(max_numa_id + 1)]
|
||||
up_scales = [[] for _ in range(max_numa_id + 1)]
|
||||
gate_scales = [[] for _ in range(max_numa_id + 1)]
|
||||
down_scales = [[] for _ in range(max_numa_id + 1)]
|
||||
for numa_id in range(max_numa_id + 1):
|
||||
for expert_id in range(max_experts_count + 1):
|
||||
up_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.weight"
|
||||
gate_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.weight"
|
||||
down_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.weight"
|
||||
up_scale_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.scale"
|
||||
gate_scale_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.scale"
|
||||
down_scale_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.scale"
|
||||
# make sure contiguous
|
||||
up_tensor = self.load_tensor(up_key, device).numpy()
|
||||
gate_tensor = self.load_tensor(gate_key, device).numpy()
|
||||
down_tensor = self.load_tensor(down_key, device).numpy()
|
||||
up_scale_tensor = self.load_tensor(up_scale_key, device).numpy()
|
||||
gate_scale_tensor = self.load_tensor(gate_scale_key, device).numpy()
|
||||
down_scale_tensor = self.load_tensor(down_scale_key, device).numpy()
|
||||
|
||||
up_weights[numa_id].append(up_tensor)
|
||||
gate_weights[numa_id].append(gate_tensor)
|
||||
down_weights[numa_id].append(down_tensor)
|
||||
up_scales[numa_id].append(up_scale_tensor)
|
||||
gate_scales[numa_id].append(gate_scale_tensor)
|
||||
down_scales[numa_id].append(down_scale_tensor)
|
||||
return {
|
||||
"up": up_weights,
|
||||
"gate": gate_weights,
|
||||
"down": down_weights,
|
||||
"up_scale": up_scales,
|
||||
"gate_scale": gate_scales,
|
||||
"down_scale": down_scales,
|
||||
}
|
||||
|
||||
def has_tensor(self, name: str):
|
||||
return name in self.tensor_file_map
|
||||
|
||||
|
||||
class KExpertsCPUBuffer:
|
||||
capture_bs: List = list()
|
||||
capture_buffers: Dict = dict()
|
||||
temp_bs: int = 0
|
||||
temp_buffer: tuple = tuple()
|
||||
buffer_depth: int = 2
|
||||
|
||||
@classmethod
|
||||
def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if batch_size in cls.capture_buffers:
|
||||
return cls.capture_buffers[batch_size]
|
||||
if batch_size == cls.temp_bs:
|
||||
return cls.temp_buffer
|
||||
|
||||
input_tensor_cpu = [
|
||||
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
immediate_experts_ids_cpu = [
|
||||
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
deferred_experts_ids_cpu = [
|
||||
torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True)
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
weights_cpu = [
|
||||
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
output_cpu = [
|
||||
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
bsz_tensor_cpu = [
|
||||
torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True)
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
output_gpu = [
|
||||
torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
for _ in range(cls.buffer_depth)
|
||||
]
|
||||
|
||||
cur_buffer = (
|
||||
input_tensor_cpu,
|
||||
immediate_experts_ids_cpu,
|
||||
deferred_experts_ids_cpu,
|
||||
weights_cpu,
|
||||
output_cpu,
|
||||
bsz_tensor_cpu,
|
||||
output_gpu,
|
||||
Usage:
|
||||
wrapper = KTMoEWrapper(
|
||||
layer_idx=0,
|
||||
num_experts=8,
|
||||
num_experts_per_tok=2,
|
||||
hidden_size=4096,
|
||||
moe_intermediate_size=14336,
|
||||
num_gpu_experts=2,
|
||||
cpuinfer_threads=32,
|
||||
threadpool_count=2,
|
||||
weight_path="/path/to/weights",
|
||||
chunked_prefill_size=512,
|
||||
method="AMXINT4" # or "AMXINT8", "LLAMAFILE"
|
||||
)
|
||||
if batch_size in cls.capture_bs:
|
||||
cls.capture_buffers[batch_size] = cur_buffer
|
||||
cls.temp_bs = batch_size
|
||||
cls.temp_buffer = cur_buffer
|
||||
return cur_buffer
|
||||
|
||||
|
||||
class AMXMoEWrapper:
|
||||
"""
|
||||
Wrapper for AMX MoE CPU inference operations.
|
||||
Manages CPU inference engine, weight loading, and buffer management.
|
||||
"""
|
||||
|
||||
_cpu_infer_instance = None
|
||||
_safetensor_loader_instance = None
|
||||
_layer_has_pending_deferred: Dict[int, bool] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
def __new__(
|
||||
cls,
|
||||
layer_idx: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
@@ -223,14 +54,14 @@ class AMXMoEWrapper:
|
||||
num_gpu_experts: int,
|
||||
cpuinfer_threads: int,
|
||||
threadpool_count: int,
|
||||
amx_weight_path: str,
|
||||
weight_path: str,
|
||||
chunked_prefill_size: int,
|
||||
cpu_save: bool = False,
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
amx_method: str = "AMXINT4",
|
||||
method: str = "AMXINT4",
|
||||
):
|
||||
"""
|
||||
Initialize AMX MoE Wrapper.
|
||||
Factory method to create the appropriate backend implementation.
|
||||
|
||||
Args:
|
||||
layer_idx: Layer index
|
||||
@@ -241,425 +72,41 @@ class AMXMoEWrapper:
|
||||
num_gpu_experts: Number of experts to run on GPU
|
||||
cpuinfer_threads: Number of CPU inference threads
|
||||
threadpool_count: Number of NUMA subpools
|
||||
amx_weight_path: Path to AMX weights
|
||||
weight_path: Path to weights
|
||||
chunked_prefill_size: Maximum prefill chunk size
|
||||
cpu_save: Whether to save weights to CPU memory
|
||||
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
|
||||
amx_method: AMX quantization method ("AMXINT4" or "AMXINT8")
|
||||
"""
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = num_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.hidden_size = hidden_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_gpu_experts = num_gpu_experts
|
||||
self.amx_weight_path = amx_weight_path
|
||||
self.chunked_prefill_size = chunked_prefill_size
|
||||
self.cpu_save = cpu_save
|
||||
self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
|
||||
|
||||
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
|
||||
self.amx_method = amx_method
|
||||
|
||||
# Initialize CPU inference engine (singleton)
|
||||
if AMXMoEWrapper._cpu_infer_instance is None:
|
||||
worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||
|
||||
subpool_numa_map = list(range(threadpool_count))
|
||||
subpool_thread_count = [
|
||||
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
|
||||
for i in range(threadpool_count)
|
||||
]
|
||||
|
||||
worker_config.subpool_count = threadpool_count
|
||||
worker_config.subpool_numa_map = subpool_numa_map
|
||||
worker_config.subpool_thread_count = subpool_thread_count
|
||||
AMXMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)
|
||||
|
||||
self.cpu_infer = AMXMoEWrapper._cpu_infer_instance
|
||||
|
||||
# Check if we should load merged safetensor weights
|
||||
self.load_merged_weight = False
|
||||
import glob
|
||||
|
||||
if glob.glob(os.path.join(amx_weight_path, "*.safetensors")):
|
||||
self.load_merged_weight = True
|
||||
|
||||
# Initialize SafeTensor loader (singleton)
|
||||
if self.load_merged_weight:
|
||||
if AMXMoEWrapper._safetensor_loader_instance is None:
|
||||
AMXMoEWrapper._safetensor_loader_instance = SafeTensorLoader(amx_weight_path)
|
||||
self.safetensor_loader = AMXMoEWrapper._safetensor_loader_instance
|
||||
|
||||
self.moe = None
|
||||
self.gate_weights = None
|
||||
self.up_weights = None
|
||||
self.down_weights = None
|
||||
self.gate_scales = None
|
||||
self.up_scales = None
|
||||
self.down_scales = None
|
||||
|
||||
def load_weights_from_tensors(
|
||||
self,
|
||||
gate_proj: torch.Tensor,
|
||||
up_proj: torch.Tensor,
|
||||
down_proj: torch.Tensor,
|
||||
physical_to_logical_map_cpu: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Load and quantize weights from BF16/FP16 tensors (online quantization).
|
||||
|
||||
Args:
|
||||
gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
|
||||
up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
|
||||
down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
|
||||
physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
|
||||
"""
|
||||
# Store tensors as instance variables to keep them alive
|
||||
self.gate_proj = gate_proj.contiguous()
|
||||
self.up_proj = up_proj.contiguous()
|
||||
self.down_proj = down_proj.contiguous()
|
||||
|
||||
# Configure MoE with online quantization (cpu_save mode)
|
||||
moe_config = MOEConfig(
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
self.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
self.num_gpu_experts,
|
||||
)
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
moe_config.max_len = self.chunked_prefill_size
|
||||
|
||||
# Enable save mode for online quantization
|
||||
moe_config.save = True
|
||||
moe_config.load = False
|
||||
|
||||
# Set weight pointers
|
||||
moe_config.gate_proj = self.gate_proj.data_ptr()
|
||||
moe_config.up_proj = self.up_proj.data_ptr()
|
||||
moe_config.down_proj = self.down_proj.data_ptr()
|
||||
|
||||
# Set output path for quantized weights
|
||||
moe_config.path = self.amx_weight_path
|
||||
|
||||
# Create MoE module based on AMX method
|
||||
if self.amx_method == "AMXINT4":
|
||||
self.moe = AMXInt4_MOE(moe_config)
|
||||
elif self.amx_method == "AMXINT8":
|
||||
self.moe = AMXInt8_MOE(moe_config)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")
|
||||
|
||||
# Submit quantization and save task
|
||||
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
|
||||
self.cpu_infer.sync()
|
||||
|
||||
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
    """
    Load weights for this layer and initialize the MoE module.

    Two loading paths exist:
      * merged-safetensor path (``self.load_merged_weight``): pre-quantized
        per-NUMA weight/scale arrays are loaded and passed to the kernel as
        nested lists of raw pointers;
      * cpu_save path (``self.cpu_save``): raw BF16/FP16 expert tensors are
        concatenated and handed to the kernel for online quantization+save.

    Args:
        physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
            (only its data pointer is passed to the C++ task).
    """
    # Scalar pointer fields default to 0 (NULL) unless the cpu_save path fills them.
    gate_ptr = 0
    up_ptr = 0
    down_ptr = 0

    # Per-NUMA nested pointer lists for pre-quantized weights (merged path only).
    gate_ptrs = []
    up_ptrs = []
    down_ptrs = []

    gate_scale_ptrs = []
    up_scale_ptrs = []
    down_scale_ptrs = []

    if self.load_merged_weight:
        # GGUF-style key layout: "blk.<layer>".
        base_key = f"blk.{self.layer_idx}"
        w = self.safetensor_loader.load_experts(base_key)

        # Keep references on self so the arrays stay alive while the C++
        # side reads through the raw pointers below.
        self.gate_weights = w["gate"]
        self.up_weights = w["up"]
        self.down_weights = w["down"]
        self.gate_scales = w["gate_scale"]
        self.up_scales = w["up_scale"]
        self.down_scales = w["down_scale"]

        # Extract raw data pointers: outer list = NUMA node, inner list = expert.
        # NOTE(review): assumes each `et` is a numpy array (has .ctypes.data) — confirm loader contract.
        gate_ptrs = [
            [
                ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for et in numa_array
            ]
            for numa_array in self.gate_weights
        ]

        up_ptrs = [
            [
                ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for et in numa_array
            ]
            for numa_array in self.up_weights
        ]

        down_ptrs = [
            [
                ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for et in numa_array
            ]
            for numa_array in self.down_weights
        ]

        gate_scale_ptrs = [
            [
                ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for et in numa_array
            ]
            for numa_array in self.gate_scales
        ]

        up_scale_ptrs = [
            [
                ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for et in numa_array
            ]
            for numa_array in self.up_scales
        ]

        down_scale_ptrs = [
            [
                ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for et in numa_array
            ]
            for numa_array in self.down_scales
        ]

    # Configure MoE kernel parameters.
    moe_config = MOEConfig(
        self.num_experts,
        self.num_experts_per_tok,
        self.hidden_size,
        self.moe_intermediate_size,
        self.num_gpu_experts,
    )
    moe_config.layer_idx = self.layer_idx
    moe_config.pool = self.cpu_infer.backend_
    moe_config.max_len = self.chunked_prefill_size

    moe_config.gate_proj = gate_ptr
    moe_config.up_proj = up_ptr
    moe_config.down_proj = down_ptr
    moe_config.gate_projs = gate_ptrs
    moe_config.up_projs = up_ptrs
    moe_config.down_projs = down_ptrs
    moe_config.gate_scales = gate_scale_ptrs
    moe_config.up_scales = up_scale_ptrs
    moe_config.down_scales = down_scale_ptrs

    if self.cpu_save:
        # Online quantization: read raw HF-layout weights and save quantized output.
        moe_config.save = True
        moe_config.load = False
        # HF-style key layout: "model.layers.<layer>".
        base_key = f"model.layers.{self.layer_idx}"
        w = self.safetensor_loader.load_experts(base_key)

        # Concatenate per-shard expert weights; keep alive on self for the
        # duration of the quantization task below.
        self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
        self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
        self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous()

        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()
    else:
        moe_config.load = True

    # Merged weights carry their own data; otherwise point the kernel at the
    # on-disk AMX weight directory.
    if not self.load_merged_weight:
        moe_config.path = self.amx_weight_path

    # Create MoE module based on AMX method
    if self.amx_method == "AMXINT4":
        self.moe = AMXInt4_MOE(moe_config)
    elif self.amx_method == "AMXINT8":
        self.moe = AMXInt8_MOE(moe_config)
    else:
        raise NotImplementedError(f"Unsupported AMX method: {self.amx_method}")

    # Load weights synchronously (submit + sync) before returning.
    self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
    self.cpu_infer.sync()

    # Clean up temporary weight storage if using merged weights; the kernel
    # has copied/ingested them during load_weights_task.
    if self.load_merged_weight:
        del self.gate_weights
        del self.up_weights
        del self.down_weights
        del self.gate_scales
        del self.up_scales
        del self.down_scales
|
||||
|
||||
def select_deferred_experts(
    self,
    expert_ids: torch.Tensor,
    expert_scores: torch.Tensor,
    protected_k: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Partition the routed experts of each token into an immediate set and a
    deferred set.

    The ``protected_k`` highest-scoring slots per token determine a set of
    protected expert IDs; any slot (in any token) routed to a protected ID
    runs immediately, all remaining slots are deferred. Masked-out slots are
    marked with -1 in the respective output tensor.

    Args:
        expert_ids: Routed expert IDs, shape [batch, topk].
        expert_scores: Router scores aligned with ``expert_ids``.
        protected_k: Number of top-scoring slots per token to protect
            (clamped into [0, topk]).

    Returns:
        (immediate_ids, deferred_ids) — both shaped like ``expert_ids``,
        with -1 in the positions belonging to the other set.
    """
    n_rows, n_cols = expert_ids.shape
    keep_k = min(max(int(protected_k), 0), n_cols)

    if keep_k == 0:
        # Nothing is protected: defer every slot.
        return torch.full_like(expert_ids, -1), expert_ids.clone()

    # Indices of the top-scoring slots per token (order irrelevant).
    top_idx = torch.topk(expert_scores, k=keep_k, dim=-1, largest=True, sorted=False).indices
    kept_ids = expert_ids.gather(-1, top_idx)

    # Global protected-ID flag table: an ID protected by any token is
    # protected everywhere in this batch.
    flags = torch.zeros((self.num_experts,), dtype=torch.int32, device=expert_ids.device)
    flags.scatter_(0, kept_ids.reshape(-1), 1)

    is_protected = flags.gather(0, expert_ids.reshape(-1)).ne(0).view(n_rows, n_cols)

    run_now = expert_ids.clone().masked_fill(~is_protected, -1)
    run_later = expert_ids.clone().masked_fill(is_protected, -1)

    return run_now, run_later
|
||||
|
||||
def submit_forward(
    self,
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    cuda_stream,
):
    """
    Submit forward inference task to CPU (non-blocking).

    Copies inputs into pinned CPU buffers and enqueues one or two kernel
    tasks on the CPU inference engine: the immediate-expert task writes to
    this layer's output slot; an optional deferred-expert task writes to the
    NEXT slot so its result can be merged during a later layer's sync.

    Args:
        hidden_states: Input hidden states [batch_size, hidden_size]
        topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
        topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
        cuda_stream: CUDA stream for synchronization
    """
    flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    batch_size = flat_hidden_states.shape[0]

    (
        input_tensor_cpu,
        immediate_experts_ids_cpu,
        deferred_experts_ids_cpu,
        weights_cpu,
        output_cpu,
        bsz_tensor_cpu,
        _output_gpu,
    ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)

    # Double-buffered slots: this layer's slot plus the next one for
    # deferred output (consumed by the following layer's pass).
    current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
    next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth

    bsz_slot_tensor = bsz_tensor_cpu[current_slot]
    bsz_slot_tensor.fill_(batch_size)
    # Reset the deferred-ID slot: -1 marks "no expert" for the kernel.
    deferred_experts_ids_cpu[current_slot].fill_(-1)

    topk_ids_long = topk_ids.to(torch.long)
    immediate_ids: torch.Tensor
    deferred_ids: Optional[torch.Tensor]
    if self.max_deferred_experts_per_token > 0:
        # Protect the top (topk - max_deferred) experts; defer the rest.
        protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token

        immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
    else:
        immediate_ids = topk_ids_long
        deferred_ids = None

    # Async H2D->pinned copies; ordering vs. the submitted task is handled
    # by submit_with_cuda_stream on the given stream.
    input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
    weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
    immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)

    # If the previous layer left a deferred task pending, this task must
    # accumulate into (not overwrite) its output slot.
    incremental = AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
    self.cpu_infer.submit_with_cuda_stream(
        cuda_stream,
        self.moe.forward_task(
            bsz_slot_tensor.data_ptr(),
            immediate_experts_ids_cpu[current_slot].size(-1),
            immediate_experts_ids_cpu[current_slot].data_ptr(),
            weights_cpu[current_slot].data_ptr(),
            input_tensor_cpu[current_slot].data_ptr(),
            output_cpu[current_slot].data_ptr(),
            incremental,
        ),
    )

    AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
    if deferred_ids is not None:
        deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
        # Deferred task writes into the NEXT slot so it can overlap with the
        # next layer's compute; incremental=False because it owns that slot.
        self.cpu_infer.submit_with_cuda_stream(
            cuda_stream,
            self.moe.forward_task(
                bsz_slot_tensor.data_ptr(),
                deferred_experts_ids_cpu[current_slot].size(-1),
                deferred_experts_ids_cpu[current_slot].data_ptr(),
                weights_cpu[current_slot].data_ptr(),
                input_tensor_cpu[current_slot].data_ptr(),
                output_cpu[next_slot].data_ptr(),
                False,
            ),
        )
        AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True
|
||||
|
||||
def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
    """
    Synchronize and retrieve forward inference results.

    Waits for the CPU task(s) submitted by ``submit_forward`` for this
    layer's slot (allowing a deferred task to remain pending when one was
    scheduled), then copies the CPU output into the GPU output buffer.

    Args:
        hidden_states: Original input hidden states (for getting buffer)
        cuda_stream: CUDA stream for synchronization

    Returns:
        output_gpu: Output tensor on GPU for this layer's buffer slot.
    """
    # Fix: removed diff-merge artifacts — a backend-dispatch branch referencing
    # an undefined `method` variable and stray factory-method docstring lines
    # that belonged to a different function.
    flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    (
        _input_tensor_cpu,
        _immediate_experts_ids_cpu,
        _deferred_experts_ids_cpu,
        _weights_cpu,
        output_cpu,
        _bsz_tensor_cpu,
        output_gpu,
    ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)

    current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
    # A pending deferred task (writing to the next slot) is allowed to keep
    # running past this sync point.
    allow_pending = 1 if AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
    self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
    output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
    return output_gpu[current_slot]
|
||||
|
||||
def forward(
    self,
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    cuda_stream,
) -> torch.Tensor:
    """
    Execute forward inference synchronously (submit + sync).

    Args:
        hidden_states: Input hidden states [batch_size, hidden_size]
        topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
        topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
        cuda_stream: CUDA stream for synchronization

    Returns:
        Output tensor on GPU
    """
    # Fix: removed unreachable `return backend_cls(...)` block after the real
    # return — a diff-merge artifact from a factory function that referenced
    # names (backend_cls, layer_idx, ...) undefined in this scope.
    self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)
    return self.sync_forward(hidden_states, cuda_stream)
|
||||
|
||||
# Forward static methods to the base class
|
||||
@staticmethod
def set_capture_batch_sizes(capture_bs: List[int]):
    """
    Set batch sizes to capture and cache buffers for.

    Delegates to BaseMoEWrapper so all backends share one buffer
    configuration.

    Args:
        capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])

    Example:
        >>> AMXMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])
    """
    # Fix: resolved diff-merge artifact that left both the old direct
    # KExpertsCPUBuffer write and the new delegation in the body; the
    # refactored code forwards to the base class.
    BaseMoEWrapper.set_capture_batch_sizes(capture_bs)
|
||||
|
||||
@staticmethod
def get_capture_batch_sizes() -> List[int]:
    """
    Get currently configured capture batch sizes.

    Returns:
        List of batch sizes that are being captured
    """
    # Fix: resolved diff-merge artifact that rendered both the old and new
    # return statements; the refactored code delegates to the base class.
    return BaseMoEWrapper.get_capture_batch_sizes()
|
||||
|
||||
@staticmethod
def clear_buffer_cache():
    """
    Clear all cached buffers.

    This frees up memory by clearing the buffer cache. Useful when you want
    to reset the buffer state or free memory.
    """
    # Fix: resolved diff-merge artifact that kept both the old direct
    # KExpertsCPUBuffer mutations and the new delegation; the refactored
    # code forwards to the base class.
    BaseMoEWrapper.clear_buffer_cache()
|
||||
|
||||
394
kt-kernel/python/experts_base.py
Normal file
394
kt-kernel/python/experts_base.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# Base classes for MoE CPU inference operations
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
Base infrastructure for CPU-based MoE inference.
|
||||
|
||||
This module contains base classes and utilities shared across all backend implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import ctypes
|
||||
|
||||
import kt_kernel_ext
|
||||
|
||||
|
||||
|
||||
class KExpertsCPUBuffer:
    """
    CPU buffer management for expert computation.

    Manages pinned memory buffers for efficient GPU-CPU data transfer.
    Buffers are keyed by batch size: sizes listed in ``capture_bs`` are
    cached permanently in ``capture_buffers``; any other size is kept in a
    single-entry temp cache (``temp_bs``/``temp_buffer``) that is replaced
    on the next new batch size. Each buffer set is allocated ``buffer_depth``
    deep so consecutive layers can double-buffer.
    """

    # Batch sizes to pre-cache buffers for (set via set_capture_batch_sizes).
    capture_bs: List = list()
    # batch_size -> buffer tuple, for captured sizes.
    capture_buffers: Dict = dict()
    # Last non-captured batch size and its (single) cached buffer tuple.
    temp_bs: int = 0
    temp_buffer: tuple = tuple()
    # Number of slots per buffer (double-buffering across layers).
    buffer_depth: int = 2

    @classmethod
    def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
        # Returns a 7-tuple of buffer lists (each buffer_depth long):
        # (input, immediate_ids, deferred_ids, weights, output_cpu, bsz, output_gpu).
        hidden_size = hidden_states.shape[-1]
        batch_size = hidden_states.shape[0]

        # Fast paths: captured cache first, then the one-slot temp cache.
        if batch_size in cls.capture_buffers:
            return cls.capture_buffers[batch_size]
        if batch_size == cls.temp_bs:
            return cls.temp_buffer

        # Pinned host buffers enable async (non_blocking) GPU<->CPU copies.
        input_tensor_cpu = [
            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
            for _ in range(cls.buffer_depth)
        ]
        immediate_experts_ids_cpu = [
            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        # Deferred-ID buffers start at -1 ("no expert").
        deferred_experts_ids_cpu = [
            torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        weights_cpu = [
            torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        output_cpu = [
            torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
            for _ in range(cls.buffer_depth)
        ]
        bsz_tensor_cpu = [
            torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True)
            for _ in range(cls.buffer_depth)
        ]
        output_gpu = [
            torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
            for _ in range(cls.buffer_depth)
        ]

        cur_buffer = (
            input_tensor_cpu,
            immediate_experts_ids_cpu,
            deferred_experts_ids_cpu,
            weights_cpu,
            output_cpu,
            bsz_tensor_cpu,
            output_gpu,
        )
        # Captured sizes are cached forever; everything else also lands in the
        # temp slot (even captured sizes — harmless, next lookup hits the
        # capture cache first).
        if batch_size in cls.capture_bs:
            cls.capture_buffers[batch_size] = cur_buffer
        cls.temp_bs = batch_size
        cls.temp_buffer = cur_buffer
        return cur_buffer
|
||||
|
||||
|
||||
class BaseMoEWrapper(ABC):
    """
    Base class for MoE CPU inference operations.
    Provides common functionality for all backend implementations.

    Subclasses implement weight loading (``load_weights`` /
    ``load_weights_from_tensors``) and set ``self.moe`` to a kernel module
    exposing ``forward_task``; the submit/sync/forward machinery here is
    shared by every backend.
    """

    # Process-wide CPU inference engine (one thread pool for all layers).
    _cpu_infer_instance = None
    # layer_idx -> True while that layer has a deferred-expert task still
    # running (its output lands in the next buffer slot).
    _layer_has_pending_deferred: Dict[int, bool] = {}

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "AMXINT4",
    ):
        """
        Initialize base MoE Wrapper.

        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size
            num_gpu_experts: Number of experts to run on GPU
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools
            weight_path: Path to weights
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Whether to save weights to CPU memory
            max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
            method: Backend method string
        """
        print(f"Init {self.__class__.__name__}")
        self.layer_idx = layer_idx
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_size = hidden_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_gpu_experts = num_gpu_experts
        self.weight_path = weight_path
        self.chunked_prefill_size = chunked_prefill_size
        self.cpu_save = cpu_save
        # None means "no deferral"; normalize to 0 so comparisons are simple.
        self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0

        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
        self.method = method

        # Initialize CPU inference engine (singleton)
        if BaseMoEWrapper._cpu_infer_instance is None:
            worker_config = kt_kernel_ext.WorkerPoolConfig()

            # One subpool per NUMA node; threads distributed as evenly as
            # possible (the first `cpuinfer_threads % threadpool_count`
            # subpools get one extra thread).
            subpool_numa_map = list(range(threadpool_count))
            subpool_thread_count = [
                cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
                for i in range(threadpool_count)
            ]

            worker_config.subpool_count = threadpool_count
            worker_config.subpool_numa_map = subpool_numa_map
            worker_config.subpool_thread_count = subpool_thread_count
            BaseMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)

        self.cpu_infer = BaseMoEWrapper._cpu_infer_instance

        # Backend-specific initialization happens in subclasses
        self.moe = None

    @abstractmethod
    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        pass

    @abstractmethod
    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
        """
        Load weights for this layer and initialize the MoE module.

        Args:
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        pass

    def select_deferred_experts(
        self,
        expert_ids: torch.Tensor,
        expert_scores: torch.Tensor,
        protected_k: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Split each token's routed experts into an immediate set (the
        # protected_k highest-scoring slots, by expert ID, batch-wide) and a
        # deferred set; -1 marks slots belonging to the other set.
        batch, topk = expert_ids.shape
        device = expert_ids.device

        # Clamp protected_k into [0, topk].
        protected_k = max(0, min(int(protected_k), topk))
        if protected_k == 0:
            # Nothing protected: defer everything.
            deferred_ids = expert_ids.clone()
            immediate_ids = torch.full_like(expert_ids, -1)
            return immediate_ids, deferred_ids

        topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)
        protected_indices = topk_result.indices
        protected_ids = torch.gather(expert_ids, -1, protected_indices)

        # Flag table over all expert IDs: an ID protected by any token is
        # treated as protected for every token in the batch.
        protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)
        protected_flag.scatter_(0, protected_ids.reshape(-1), 1)

        protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)
        protected_mask = protected_mask_flat.view(batch, topk)

        immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)
        deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)

        return immediate_ids, deferred_ids

    def submit_forward(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        cuda_stream,
    ):
        """
        Submit forward inference task to CPU (non-blocking).

        Copies inputs into pinned buffers and enqueues one kernel task for
        the immediate experts (output -> this layer's slot) plus, when
        deferral is enabled, a second task for the deferred experts
        (output -> the NEXT slot, merged during a later layer's pass).

        Args:
            hidden_states: Input hidden states [batch_size, hidden_size]
            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
            topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
            cuda_stream: CUDA stream for synchronization
        """
        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        batch_size = flat_hidden_states.shape[0]

        (
            input_tensor_cpu,
            immediate_experts_ids_cpu,
            deferred_experts_ids_cpu,
            weights_cpu,
            output_cpu,
            bsz_tensor_cpu,
            _output_gpu,
        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)

        # Double-buffered slot indices for this and the next layer.
        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
        next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth

        bsz_slot_tensor = bsz_tensor_cpu[current_slot]
        bsz_slot_tensor.fill_(batch_size)
        # -1 marks "no expert" for the kernel.
        deferred_experts_ids_cpu[current_slot].fill_(-1)

        topk_ids_long = topk_ids.to(torch.long)
        immediate_ids: torch.Tensor
        deferred_ids: Optional[torch.Tensor]
        if self.max_deferred_experts_per_token > 0:
            # Protect the top (topk - max_deferred) experts; defer the rest.
            protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token

            immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
        else:
            immediate_ids = topk_ids_long
            deferred_ids = None

        input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
        weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
        immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)

        # If the previous layer left a deferred task writing into our slot,
        # this task must accumulate into the output rather than overwrite it.
        incremental = BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
        self.cpu_infer.submit_with_cuda_stream(
            cuda_stream,
            self.moe.forward_task(
                bsz_slot_tensor.data_ptr(),
                immediate_experts_ids_cpu[current_slot].size(-1),
                immediate_experts_ids_cpu[current_slot].data_ptr(),
                weights_cpu[current_slot].data_ptr(),
                input_tensor_cpu[current_slot].data_ptr(),
                output_cpu[current_slot].data_ptr(),
                incremental,
            ),
        )

        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
        if deferred_ids is not None:
            deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
            # Deferred task targets the next slot; incremental=False because
            # it is the first writer of that slot.
            self.cpu_infer.submit_with_cuda_stream(
                cuda_stream,
                self.moe.forward_task(
                    bsz_slot_tensor.data_ptr(),
                    deferred_experts_ids_cpu[current_slot].size(-1),
                    deferred_experts_ids_cpu[current_slot].data_ptr(),
                    weights_cpu[current_slot].data_ptr(),
                    input_tensor_cpu[current_slot].data_ptr(),
                    output_cpu[next_slot].data_ptr(),
                    False,
                ),
            )
            BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True

    def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
        """
        Synchronize and retrieve forward inference results.

        Args:
            hidden_states: Original input hidden states (for getting buffer)
            cuda_stream: CUDA stream for synchronization

        Returns:
            output_gpu: Output tensor on GPU
        """
        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        (
            _input_tensor_cpu,
            _immediate_experts_ids_cpu,
            _deferred_experts_ids_cpu,
            _weights_cpu,
            output_cpu,
            _bsz_tensor_cpu,
            output_gpu,
        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)

        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
        # Let this layer's deferred task (if any) keep running past the sync.
        allow_pending = 1 if BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
        self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
        output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
        return output_gpu[current_slot]

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        cuda_stream,
    ) -> torch.Tensor:
        """
        Execute forward inference synchronously (submit + sync).

        Args:
            hidden_states: Input hidden states [batch_size, hidden_size]
            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]
            topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
            cuda_stream: CUDA stream for synchronization

        Returns:
            Output tensor on GPU
        """
        self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)
        return self.sync_forward(hidden_states, cuda_stream)

    @staticmethod
    def set_capture_batch_sizes(capture_bs: List[int]):
        """
        Set batch sizes to capture and cache buffers for.

        This allows pre-allocation of CPU buffers for specific batch sizes,
        improving performance by avoiding buffer re-allocation during inference.

        Args:
            capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])

        Example:
            >>> BaseMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])
        """
        KExpertsCPUBuffer.capture_bs = capture_bs

    @staticmethod
    def get_capture_batch_sizes() -> List[int]:
        """
        Get currently configured capture batch sizes.

        Returns:
            List of batch sizes that are being captured
        """
        return KExpertsCPUBuffer.capture_bs

    @staticmethod
    def clear_buffer_cache():
        """
        Clear all cached buffers.

        This frees up memory by clearing the buffer cache. Useful when you want
        to reset the buffer state or free memory.
        """
        KExpertsCPUBuffer.capture_buffers.clear()
        KExpertsCPUBuffer.temp_bs = 0
        KExpertsCPUBuffer.temp_buffer = tuple()
||||
301
kt-kernel/python/utils/amx.py
Normal file
301
kt-kernel/python/utils/amx.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import os
|
||||
import torch
|
||||
import ctypes
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
|
||||
_HAS_AMX_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SUPPORT = False
|
||||
AMXInt4_MOE, AMXInt8_MOE = None, None
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AMXMoEWrapper(BaseMoEWrapper):
|
||||
"""
|
||||
AMX-based MoE wrapper implementation.
|
||||
Supports AMXINT4 and AMXINT8 quantization methods.
|
||||
"""
|
||||
|
||||
_safetensor_loader_instance = None # Singleton SafeTensorLoader
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_idx: int,
|
||||
num_experts: int,
|
||||
num_experts_per_tok: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_gpu_experts: int,
|
||||
cpuinfer_threads: int,
|
||||
threadpool_count: int,
|
||||
weight_path: str,
|
||||
chunked_prefill_size: int,
|
||||
cpu_save: bool = False,
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
method: str = "AMXINT4",
|
||||
):
|
||||
"""
|
||||
Initialize AMX MoE Wrapper.
|
||||
|
||||
Args:
|
||||
layer_idx: Layer index
|
||||
num_experts: Total number of experts
|
||||
num_experts_per_tok: Number of experts per token (top-k)
|
||||
hidden_size: Hidden dimension size
|
||||
moe_intermediate_size: MoE intermediate size
|
||||
num_gpu_experts: Number of experts to run on GPU
|
||||
cpuinfer_threads: Number of CPU inference threads
|
||||
threadpool_count: Number of NUMA subpools
|
||||
weight_path: Path to AMX weights (SafeTensor format)
|
||||
chunked_prefill_size: Maximum prefill chunk size
|
||||
cpu_save: Whether to save weights to CPU memory
|
||||
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
|
||||
method: AMX quantization method ("AMXINT4" or "AMXINT8")
|
||||
"""
|
||||
if not _HAS_AMX_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"AMX backend not available. kt_kernel_ext was not compiled with AMX support.\n"
|
||||
"Please recompile with AMX enabled."
|
||||
)
|
||||
|
||||
# Initialize base class
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=num_experts,
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_gpu_experts=num_gpu_experts,
|
||||
cpuinfer_threads=cpuinfer_threads,
|
||||
threadpool_count=threadpool_count,
|
||||
weight_path=weight_path,
|
||||
chunked_prefill_size=chunked_prefill_size,
|
||||
cpu_save=cpu_save,
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
)
|
||||
|
||||
# AMX-specific: Check if we should load merged safetensor weights
|
||||
self.load_merged_weight = False
|
||||
import glob
|
||||
|
||||
if glob.glob(os.path.join(weight_path, "*.safetensors")):
|
||||
self.load_merged_weight = True
|
||||
|
||||
# Initialize SafeTensor loader (singleton)
|
||||
if self.load_merged_weight:
|
||||
if AMXMoEWrapper._safetensor_loader_instance is None:
|
||||
AMXMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path)
|
||||
self.safetensor_loader = AMXMoEWrapper._safetensor_loader_instance
|
||||
|
||||
# AMX-specific weight storage
|
||||
self.gate_weights = None
|
||||
self.up_weights = None
|
||||
self.down_weights = None
|
||||
self.gate_scales = None
|
||||
self.up_scales = None
|
||||
self.down_scales = None
|
||||
|
||||
def load_weights_from_tensors(
    self,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    physical_to_logical_map_cpu: torch.Tensor,
):
    """
    Load and quantize weights from BF16/FP16 tensors (online quantization).

    The C++ backend reads the tensors through raw ``data_ptr()`` values,
    quantizes them with the configured AMX method, and saves the result to
    ``self.weight_path``.

    Args:
        gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
        up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
        down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
        physical_to_logical_map_cpu: Mapping from physical to logical expert IDs

    Raises:
        NotImplementedError: if ``self.method`` is not "AMXINT4" or "AMXINT8".
    """
    # Store tensors as instance variables to keep them alive: the C++ side
    # reads them through the raw pointers taken below, so they must not be
    # garbage-collected before sync() returns.
    self.gate_proj = gate_proj.contiguous()
    self.up_proj = up_proj.contiguous()
    self.down_proj = down_proj.contiguous()

    # Configure MoE with online quantization (cpu_save mode)
    moe_config = MOEConfig(
        self.num_experts,
        self.num_experts_per_tok,
        self.hidden_size,
        self.moe_intermediate_size,
        self.num_gpu_experts,
    )
    moe_config.layer_idx = self.layer_idx
    moe_config.pool = self.cpu_infer.backend_
    moe_config.max_len = self.chunked_prefill_size

    # Enable save mode for online quantization (quantize + write, not load)
    moe_config.save = True
    moe_config.load = False

    # Set weight pointers (valid only while self.*_proj stay referenced)
    moe_config.gate_proj = self.gate_proj.data_ptr()
    moe_config.up_proj = self.up_proj.data_ptr()
    moe_config.down_proj = self.down_proj.data_ptr()

    # Set output path for quantized weights
    moe_config.path = self.weight_path

    # Create MoE module based on AMX method
    if self.method == "AMXINT4":
        self.moe = AMXInt4_MOE(moe_config)
    elif self.method == "AMXINT8":
        self.moe = AMXInt8_MOE(moe_config)
    else:
        raise NotImplementedError(f"Unsupported AMX method: {self.method}")

    # Submit quantization-and-save task; sync() blocks until it finishes.
    self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
    self.cpu_infer.sync()
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
    """
    Load weights for this layer and initialize the MoE module.

    Depending on configuration this either:
      * reads pre-quantized, NUMA-sharded expert weights from merged
        safetensor files (``self.load_merged_weight``), or
      * re-quantizes BF16 weights and saves them (``self.cpu_save``), or
      * lets the C++ backend load per-expert files from ``self.weight_path``.

    Args:
        physical_to_logical_map_cpu: Mapping from physical to logical expert
            IDs; its raw pointer is handed to the C++ loader task.

    Raises:
        NotImplementedError: if ``self.method`` is not "AMXINT4" or "AMXINT8".
    """

    def _nested_ptrs(numa_expert_arrays):
        # One raw data pointer per (numa, expert) numpy array. Replaces six
        # copy-pasted nested comprehensions from the original implementation.
        return [
            [
                ctypes.addressof(ctypes.cast(arr.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
                for arr in numa_array
            ]
            for numa_array in numa_expert_arrays
        ]

    # Defaults when no merged weights are present: the backend loads from path.
    gate_ptr = 0
    up_ptr = 0
    down_ptr = 0

    gate_ptrs = []
    up_ptrs = []
    down_ptrs = []

    gate_scale_ptrs = []
    up_scale_ptrs = []
    down_scale_ptrs = []

    if self.load_merged_weight:
        base_key = f"blk.{self.layer_idx}"
        w = self.safetensor_loader.load_experts(base_key)

        # Keep the numpy arrays alive on self while the C++ loader consumes
        # the raw pointers below; they are deleted after sync().
        self.gate_weights = w["gate"]
        self.up_weights = w["up"]
        self.down_weights = w["down"]
        self.gate_scales = w["gate_scale"]
        self.up_scales = w["up_scale"]
        self.down_scales = w["down_scale"]

        gate_ptrs = _nested_ptrs(self.gate_weights)
        up_ptrs = _nested_ptrs(self.up_weights)
        down_ptrs = _nested_ptrs(self.down_weights)
        gate_scale_ptrs = _nested_ptrs(self.gate_scales)
        up_scale_ptrs = _nested_ptrs(self.up_scales)
        down_scale_ptrs = _nested_ptrs(self.down_scales)

    # Configure MoE
    moe_config = MOEConfig(
        self.num_experts,
        self.num_experts_per_tok,
        self.hidden_size,
        self.moe_intermediate_size,
        self.num_gpu_experts,
    )
    moe_config.layer_idx = self.layer_idx
    moe_config.pool = self.cpu_infer.backend_
    moe_config.max_len = self.chunked_prefill_size

    moe_config.gate_proj = gate_ptr
    moe_config.up_proj = up_ptr
    moe_config.down_proj = down_ptr
    moe_config.gate_projs = gate_ptrs
    moe_config.up_projs = up_ptrs
    moe_config.down_projs = down_ptrs
    moe_config.gate_scales = gate_scale_ptrs
    moe_config.up_scales = up_scale_ptrs
    moe_config.down_scales = down_scale_ptrs

    if self.cpu_save:
        # Online re-quantization path: read BF16 shards, quantize, and save.
        moe_config.save = True
        moe_config.load = False
        base_key = f"model.layers.{self.layer_idx}"
        # NOTE(review): this branch indexes w["gate_weight"]/... while the
        # merged branch uses w["gate"]/..., and self.safetensor_loader is only
        # created when merged safetensors were found -- confirm cpu_save is
        # only used together with them.
        w = self.safetensor_loader.load_experts(base_key)

        self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
        self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
        self.down_proj = torch.cat(w["down_weight"], dim=0).contiguous()

        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()
    else:
        moe_config.load = True

    if not self.load_merged_weight:
        moe_config.path = self.weight_path

    # Create MoE module based on AMX method
    if self.method == "AMXINT4":
        self.moe = AMXInt4_MOE(moe_config)
    elif self.method == "AMXINT8":
        self.moe = AMXInt8_MOE(moe_config)
    else:
        raise NotImplementedError(f"Unsupported AMX method: {self.method}")

    # Load weights synchronously on the CPU inference pool.
    self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
    self.cpu_infer.sync()

    # Release temporary weight storage; the raw pointers handed to C++ must
    # not be used after this point.
    if self.load_merged_weight:
        del self.gate_weights
        del self.up_weights
        del self.down_weights
        del self.gate_scales
        del self.up_scales
        del self.down_scales
225
kt-kernel/python/utils/llamafile.py
Normal file
225
kt-kernel/python/utils/llamafile.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
import os
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import GGUFLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
try:
|
||||
from kt_kernel_ext.moe import MOE
|
||||
_HAS_LLAMAFILE_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_LLAMAFILE_SUPPORT = False
|
||||
MOE = None
|
||||
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
|
||||
class LlamafileMoEWrapper(BaseMoEWrapper):
    """
    Llamafile-based MoE wrapper implementation.

    Runs CPU experts with the llamafile backend using pre-quantized GGUF
    weights loaded through a process-wide GGUFLoader singleton. Online
    quantization is not supported.
    """

    # Process-wide GGUFLoader shared by every layer instance.
    # NOTE(review): bound to the first weight_path seen; an instance created
    # later with a different path silently reuses the first loader -- confirm
    # all layers share one weight directory.
    _gguf_loader_instance = None  # Singleton GGUFLoader

    def __init__(
        self,
        layer_idx: int,
        num_experts: int,
        num_experts_per_tok: int,
        hidden_size: int,
        moe_intermediate_size: int,
        num_gpu_experts: int,
        cpuinfer_threads: int,
        threadpool_count: int,
        weight_path: str,
        chunked_prefill_size: int,
        cpu_save: bool = False,
        max_deferred_experts_per_token: Optional[int] = None,
        method: str = "LLAMAFILE",
    ):
        """
        Initialize Llamafile MoE Wrapper.

        Args:
            layer_idx: Layer index
            num_experts: Total number of experts
            num_experts_per_tok: Number of experts per token (top-k)
            hidden_size: Hidden dimension size
            moe_intermediate_size: MoE intermediate size; must be a multiple
                of QK_K (256) so it can be sharded across NUMA subpools
            num_gpu_experts: Number of experts to run on GPU
            cpuinfer_threads: Number of CPU inference threads
            threadpool_count: Number of NUMA subpools (TP count)
            weight_path: Path to GGUF weights (file or directory)
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Not supported for Llamafile backend
            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
            method: Should be "LLAMAFILE"

        Raises:
            RuntimeError: if the extension was built without Llamafile support.
            FileNotFoundError: if weight_path does not exist.
            ValueError: if moe_intermediate_size cannot be split across
                threadpool_count subpools in whole QK_K blocks.
        """
        if not _HAS_LLAMAFILE_SUPPORT:
            raise RuntimeError(
                "Llamafile backend not available. kt_kernel_ext was not compiled with Llamafile support.\n"
                "Please recompile with Llamafile enabled."
            )

        if not os.path.exists(weight_path):
            raise FileNotFoundError(f"GGUF weight path not found: {weight_path}")

        # Initialize GGUF loader (singleton)
        if LlamafileMoEWrapper._gguf_loader_instance is None:
            LlamafileMoEWrapper._gguf_loader_instance = GGUFLoader(weight_path)
        self.gguf_loader = LlamafileMoEWrapper._gguf_loader_instance

        # Validate TP configuration with QK_K alignment. QK_K = 256 is the
        # GGML super-block size: every NUMA shard of the intermediate
        # dimension must cover a whole number of QK_K blocks.
        QK_K = 256

        # Check if intermediate_size is divisible by QK_K
        if moe_intermediate_size % QK_K != 0:
            raise ValueError(
                f"intermediate_size ({moe_intermediate_size}) must be divisible by QK_K ({QK_K}) "
                f"for Llamafile backend"
            )

        # Calculate TP splits with QK_K alignment
        num_blocks = moe_intermediate_size // QK_K
        base_blocks = num_blocks // threadpool_count
        extra_blocks = num_blocks % threadpool_count

        # Every subpool needs at least one QK_K block.
        if base_blocks == 0:
            valid_tp_counts = list(range(1, num_blocks + 1))
            raise ValueError(
                f"intermediate_size ({moe_intermediate_size}) is too small for threadpool_count ({threadpool_count}).\n"
                f"Total blocks: {num_blocks} (intermediate_size / QK_K)\n"
                f"Cannot distribute to {threadpool_count} TPs (each TP needs at least 1 block).\n"
                f"Valid threadpool_count values: {valid_tp_counts}"
            )

        # Log TP split information
        print(f"[LlamafileMoEWrapper] Layer {layer_idx} TP configuration:")
        print(f" intermediate_size: {moe_intermediate_size}")
        print(f" threadpool_count: {threadpool_count}")
        print(f" QK_K: {QK_K}")
        print(f" Total blocks: {num_blocks}")
        print(f" Base blocks per TP: {base_blocks}")
        print(f" Extra blocks (distributed to first TPs): {extra_blocks}")

        # Log-only loop: prints the per-TP sizes/offsets; nothing is stored.
        current_offset = 0
        for tp_id in range(threadpool_count):
            tp_blocks = base_blocks + (1 if tp_id < extra_blocks else 0)
            tp_size = tp_blocks * QK_K
            print(f" TP {tp_id}: size={tp_size}, offset={current_offset}, blocks={tp_blocks}")
            current_offset += tp_size

        # Initialize base class
        super().__init__(
            layer_idx=layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=moe_intermediate_size,
            num_gpu_experts=num_gpu_experts,
            cpuinfer_threads=cpuinfer_threads,
            threadpool_count=threadpool_count,
            weight_path=weight_path,
            chunked_prefill_size=chunked_prefill_size,
            cpu_save=cpu_save,
            max_deferred_experts_per_token=max_deferred_experts_per_token,
            method=method,
        )

        # Holds the (gate, up, down) GGUF tensors so the memory backing the
        # raw pointers handed to C++ stays alive for the wrapper's lifetime.
        self.weights_to_keep = None

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Online quantization is not supported for the Llamafile backend.

        Raises:
            NotImplementedError: always; use pre-quantized GGUF weights and
                call load_weights() instead.
        """
        raise NotImplementedError(
            "Llamafile backend does not support online quantization (load_weights_from_tensors).\n"
            "Please use pre-quantized GGUF weights and call load_weights() instead."
        )

    def load_weights(self, physical_to_logical_map_cpu: Optional[torch.Tensor] = None):
        """
        Load weights for this layer from GGUF files and initialize the MoE module.

        Args:
            physical_to_logical_map_cpu: Optional mapping from physical to
                logical expert IDs, shape [num_experts], dtype int32.
                If None, uses the identity mapping [0, 1, ..., num_experts-1].

        Raises:
            RuntimeError: if the extension lacks Llamafile support.
            KeyError: if a required expert tensor is missing from the GGUF files.
        """
        if not _HAS_LLAMAFILE_SUPPORT:
            raise RuntimeError(
                "Llamafile backend not available. kt_kernel_ext was not compiled with Llamafile support.\n"
                "Please recompile with Llamafile enabled."
            )

        if physical_to_logical_map_cpu is None:
            physical_to_logical_map_cpu = torch.arange(
                self.num_experts,
                dtype=torch.int32,
                device="cpu"
            )
            print(f" Using default identity mapping for {self.num_experts} experts")

        base_key = f"blk.{self.layer_idx}"

        # Load quantized tensors from GGUF (raw bytes, no dequantization)
        gate_data, gate_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
            f"{base_key}.ffn_gate_exps.weight"
        )

        up_data, up_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
            f"{base_key}.ffn_up_exps.weight"
        )

        down_data, down_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
            f"{base_key}.ffn_down_exps.weight"
        )

        # Keep tensors alive: the C++ side reads them via raw data_ptr()s.
        self.weights_to_keep = (gate_data, up_data, down_data)

        # Activations are exchanged with the backend as BF16.
        hidden_type = ggml_type.BF16

        # Configure MoE
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_

        # Llamafile-specific configuration
        moe_config.m_block = 32  # Parallel block size
        moe_config.group_min_len = 10  # Use forward_one when qlen < 10
        moe_config.max_len = self.chunked_prefill_size
        moe_config.group_max_len = max(1, int(self.chunked_prefill_size))

        # Set weight pointers
        moe_config.gate_proj = gate_data.data_ptr()
        moe_config.up_proj = up_data.data_ptr()
        moe_config.down_proj = down_data.data_ptr()

        # Set quantization types
        moe_config.gate_type = gate_type
        moe_config.up_type = up_type
        moe_config.down_type = down_type
        moe_config.hidden_type = hidden_type

        # Create MoE module
        self.moe = MOE(moe_config)

        # Load weights synchronously on the CPU inference pool.
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()
522
kt-kernel/python/utils/loader.py
Normal file
522
kt-kernel/python/utils/loader.py
Normal file
@@ -0,0 +1,522 @@
|
||||
"""
|
||||
Weight loaders for different formats.
|
||||
|
||||
This module provides loaders for:
|
||||
- SafeTensor format (for AMX quantized weights)
|
||||
- GGUF format (for Llamafile quantized weights)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from enum import IntEnum
|
||||
from safetensors import safe_open
|
||||
from gguf.gguf_reader import GGUFReader
|
||||
|
||||
|
||||
class GGMLQuantizationType(IntEnum):
    """GGML quantization type enumeration.

    Numeric values match ggml's quantization type IDs as read from GGUF
    tensor headers (compare against ``tensor.tensor_type`` from GGUFReader).
    Values 4 and 5 are intentionally absent from this table.
    """
    # Unquantized floating point
    F32 = 0
    F16 = 1
    # Legacy block quants
    Q4_0 = 2
    Q4_1 = 3
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    # K-quants (256-element super-blocks)
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15
    # IQ quants
    IQ2_XXS = 16
    IQ2_XS = 17
    IQ3_XXS = 18
    IQ1_S = 19
    IQ4_NL = 20
    IQ3_S = 21
    IQ2_S = 22
    IQ4_XS = 23
    # Plain integer types
    I8 = 24
    I16 = 25
    I32 = 26
    I64 = 27
    # Wide float / remaining types
    F64 = 28
    IQ1_M = 29
    BF16 = 30
|
||||
|
||||
def translate_name_to_gguf(name):
    """
    Translate a PyTorch-style tensor name to its GGUF equivalent.

    Note: these are plain substring substitutions applied in order, so more
    specific patterns must run before any shorter prefix they contain.

    Args:
        name: Tensor name in HF/PyTorch convention (e.g. "model.layers.0.mlp.gate_proj.weight").

    Returns:
        The translated GGUF tensor name (e.g. "blk.0.ffn_gate.weight").
    """
    name = name.replace("lm_head.", "output.")
    name = name.replace("model.embed_tokens.", "token_embd.")
    name = name.replace("model.norm.", "output_norm.")
    name = name.replace("model.layers.", "blk.")
    name = name.replace(".input_layernorm", ".attn_norm")
    name = name.replace(".mlp.down_proj", ".ffn_down")
    name = name.replace(".mlp.gate_proj", ".ffn_gate")
    name = name.replace(".mlp.up_proj", ".ffn_up")
    name = name.replace(".post_attention_layernorm", ".ffn_norm")
    name = name.replace(".self_attn.q_proj", ".attn_q")
    name = name.replace(".self_attn.k_proj", ".attn_k")
    name = name.replace(".self_attn.v_proj", ".attn_v")
    name = name.replace(".self_attn.o_proj", ".attn_output")
    name = name.replace(".self_attn.qkv_proj", ".attn_qkv")
    name = name.replace(".self_attn.kv_a_proj_with_mqa", ".attn_kv_a_mqa")
    name = name.replace(".self_attn.kv_a_layernorm", ".attn_kv_a_norm")
    name = name.replace(".self_attn.kv_b_proj", ".attn_kv_b")
    name = name.replace(".self_attn.q_a_proj", ".attn_q_a")
    name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm")
    name = name.replace(".self_attn.q_b_proj", ".attn_q_b")
    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
    name = name.replace(".self_attn.k_norm", ".attn_k_norm")
    name = name.replace(".shared_expert.", ".shared_experts.")
    name = name.replace(".shared_expert_", ".shared_experts_")
    # Fixed: keep the trailing '.' in the replacement so
    # "x.gate_up_proj.weight" becomes "x.up_proj.weight" instead of the
    # mangled "x.up_projweight".
    name = name.replace(".gate_up_proj.", ".up_proj.")
    name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp")
    name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias")
    name = name.replace(".mlp.gate", ".ffn_gate_inp")
    name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
    name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
    name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")
    name = name.replace(".mlp.experts", "")
    # The three replacements below can no longer match once ".mlp.experts"
    # has been stripped above; the final names come out the same either way,
    # so they are kept only for safety.
    name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
    name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
    name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
    name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
    name = name.replace(".block_sparse_moe.experts", "")
    name = name.replace(".feed_forward.experts", "")
    name = name.replace(".feed_forward.router", ".ffn_gate_inp")
    name = name.replace(".feed_forward.shared_experts.down_proj", ".ffn_down_shexp")
    name = name.replace(".feed_forward.shared_experts.gate_proj", ".ffn_gate_shexp")
    name = name.replace(".feed_forward.shared_experts.up_proj", ".ffn_up_shexp")
    return name
|
||||
|
||||
class SafeTensorLoader:
    """
    SafeTensor format loader for AMX quantized weights.

    Indexes every ``.safetensors`` file under a path and loads tensors on
    demand, including NUMA-sharded per-expert weights and scales.
    """
    # tensor key -> name of the file that holds it
    tensor_file_map: dict
    tensor_type_map: dict
    # file name -> open safe_open handle
    file_handle_map: dict
    tensor_device_map: dict

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        """Walk the directory (or the file's directory) and index all tensor keys.

        Raises:
            FileNotFoundError: if the path does not exist or no .safetensors
                files are found under it.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path
        self.file_handle_map = {}
        self.tensor_file_map = {}
        self.tensor_type_map = {}
        self.tensor_device_map = {}

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            for file in sorted(files):
                if not file.endswith(".safetensors"):
                    continue
                found_safetensor = True
                # Use a distinct name; the original shadowed the file_path parameter.
                full_path = os.path.join(root, file)
                if file not in self.file_handle_map:
                    try:
                        handle = safe_open(full_path, framework="pt")
                        self.file_handle_map[file] = handle
                    except Exception as e:
                        # Best-effort: skip unreadable files but keep indexing the rest.
                        print(f"Error opening Safetensor file {full_path}: {e}")
                        continue

                f = self.file_handle_map.get(file)
                if f is None:
                    continue
                try:
                    for key in f.keys():
                        self.tensor_file_map[key] = file
                except Exception as e:
                    print(f"Error reading Safetensor file {full_path}: {e}")

        if not found_safetensor:
            raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str = "cpu"):
        """Load a single tensor by key and move it to the given device.

        Raises:
            KeyError: if the key is not indexed.
            FileNotFoundError: if the owning file handle is missing.
        """
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        """Close every open safetensors handle and drop the handle map."""
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_experts(self, base_key: str, device: str = "cpu"):
        """
        Load expert weights from SafeTensor files.

        Expected key format:
        - blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.weight
        - blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.scale

        Args:
            base_key: Base key like "blk.{layer_index}"
            device: Target device for tensors

        Returns:
            Dictionary with keys: up, gate, down, up_scale, gate_scale, down_scale.
            Each value is a list of lists: [numa_id][expert_id] -> numpy array.

        Raises:
            ValueError: if no expert tensors exist for base_key.
        """
        up_base_key = f"{base_key}.ffn_up_exps"
        gate_base_key = f"{base_key}.ffn_gate_exps"
        down_base_key = f"{base_key}.ffn_down_exps"

        # Probe the highest expert index / NUMA shard index present.
        # Both counters are "highest index found"; -1 means none at all.
        max_numa_id = -1
        max_experts_count = -1
        while self.has_tensor(f"{up_base_key}.{max_experts_count + 1}.numa.{0}.weight"):
            max_experts_count += 1
        # Fixed: the "nothing found" sentinel is -1, not 0. The old `== 0`
        # check both missed the empty case (silently returning empty lists)
        # and would have rejected a valid single-expert layer.
        if max_experts_count < 0:
            raise ValueError(f"No experts found for key {base_key}")
        while self.has_tensor(f"{up_base_key}.{0}.numa.{max_numa_id + 1}.weight"):
            max_numa_id += 1

        # Initialize empty lists to store tensors for each projection type
        up_weights = [[] for _ in range(max_numa_id + 1)]
        gate_weights = [[] for _ in range(max_numa_id + 1)]
        down_weights = [[] for _ in range(max_numa_id + 1)]
        up_scales = [[] for _ in range(max_numa_id + 1)]
        gate_scales = [[] for _ in range(max_numa_id + 1)]
        down_scales = [[] for _ in range(max_numa_id + 1)]
        for numa_id in range(max_numa_id + 1):
            for expert_id in range(max_experts_count + 1):
                up_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.weight"
                gate_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.weight"
                down_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.weight"
                up_scale_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.scale"
                gate_scale_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.scale"
                down_scale_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.scale"
                # Convert to numpy so callers can take raw .ctypes.data pointers.
                up_tensor = self.load_tensor(up_key, device).numpy()
                gate_tensor = self.load_tensor(gate_key, device).numpy()
                down_tensor = self.load_tensor(down_key, device).numpy()
                up_scale_tensor = self.load_tensor(up_scale_key, device).numpy()
                gate_scale_tensor = self.load_tensor(gate_scale_key, device).numpy()
                down_scale_tensor = self.load_tensor(down_scale_key, device).numpy()

                up_weights[numa_id].append(up_tensor)
                gate_weights[numa_id].append(gate_tensor)
                down_weights[numa_id].append(down_tensor)
                up_scales[numa_id].append(up_scale_tensor)
                gate_scales[numa_id].append(gate_scale_tensor)
                down_scales[numa_id].append(down_scale_tensor)
        return {
            "up": up_weights,
            "gate": gate_weights,
            "down": down_weights,
            "up_scale": up_scales,
            "gate_scale": gate_scales,
            "down_scale": down_scales,
        }

    def has_tensor(self, name: str):
        """Return True if the key was indexed from any safetensors file."""
        return name in self.tensor_file_map
||||
|
||||
|
||||
class GGUFLoader:
|
||||
"""
|
||||
GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)
|
||||
|
||||
This is a cleaner implementation compared to manual binary parsing.
|
||||
"""
|
||||
|
||||
def __init__(self, gguf_path: str):
    """
    Initialize GGUF loader from a file or directory.

    Args:
        gguf_path: Path to a single GGUF file or a directory containing GGUF files.

    Raises:
        FileNotFoundError: if gguf_path does not exist.
        ValueError: if gguf_path is neither a .gguf file nor a directory.
    """
    if not os.path.exists(gguf_path):
        raise FileNotFoundError(f"GGUF path not found: {gguf_path}")

    # tensor name -> {'shape', 'dtype', 'offset', 'n_elements'}
    self.tensor_info = {}
    # metadata key -> decoded value (strings decoded from bytes where possible)
    self.metadata = {}
    # tensor name -> path of the file that holds it
    self.tensor_file_map = {}
    # file path -> read-only np.memmap over the whole file
    self.file_data_map = {}

    if os.path.isfile(gguf_path) and gguf_path.endswith('.gguf'):
        print(f"\n[GGUFLoader] Loading single GGUF file : {os.path.basename(gguf_path)}")
        self._load_single_file(gguf_path)
    elif os.path.isdir(gguf_path):
        print(f"\n[GGUFLoader] Loading GGUF files from directory: {gguf_path}")
        self._load_directory(gguf_path)
    else:
        raise ValueError(f"Path must be a .gguf file or a directory: {gguf_path}")

    print(f"[GGUFLoader] Summary:")
    print(f" Files loaded: {len(self.file_data_map)}")
    print(f" Total tensors: {len(self.tensor_info)}")
    print(f" Metadata keys: {len(self.metadata)}")
    # Spot-check layer 0 expert tensors so the quantization in use shows up in logs.
    tensors = ["blk.0.ffn_up_exps.weight", "blk.0.ffn_gate_exps.weight", "blk.0.ffn_down_exps.weight"]
    for key in tensors:
        if key in self.tensor_info:
            info = self.tensor_info[key]
            print(f" {'.'.join(key.split('.')[2:-1])}, Dtype: {info['dtype'].name}")
|
||||
def _load_single_file(self, file_path: str):
    """Index one GGUF file: metadata, tensor headers, and a memmap of the bytes.

    Args:
        file_path: Path to a .gguf file.
    """
    reader = GGUFReader(file_path)

    # Decode metadata fields; string payloads arrive as bytes or uint8 arrays.
    for key, field in reader.fields.items():
        value = field.parts[field.data[0]]
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
            try:
                value = bytes(value).decode('utf-8')
            # Fixed: narrowed from a bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit) to the only error decode raises.
            except UnicodeDecodeError:
                pass  # not a UTF-8 string payload; keep the raw array
        self.metadata[key] = value

    # Record tensor headers; data itself stays on disk behind the memmap.
    for tensor in reader.tensors:
        self.tensor_info[tensor.name] = {
            'shape': list(reversed(tensor.shape)),  # Reverse to match PyTorch order
            'dtype': tensor.tensor_type,
            'offset': tensor.data_offset,
            'n_elements': tensor.n_elements,
        }
        self.tensor_file_map[tensor.name] = file_path

    self.file_data_map[file_path] = np.memmap(file_path, mode='r')
|
||||
def _load_directory(self, dir_path: str):
    """Load all GGUF files from a directory (non-recursive).

    Delegates per-file parsing to _load_single_file; the original duplicated
    that method's entire body here.

    Args:
        dir_path: Directory to scan for .gguf files.

    Raises:
        FileNotFoundError: if the directory contains no .gguf files.
    """
    gguf_files = [f for f in sorted(os.listdir(dir_path)) if f.endswith(".gguf")]
    if not gguf_files:
        raise FileNotFoundError(f"No .gguf files found in directory: {dir_path}")

    for file in gguf_files:
        print(f" Loading: {file}")
        self._load_single_file(os.path.join(dir_path, file))
|
||||
def get_model_config(self, layer_idx: int = 0):
    """
    Extract model configuration from GGUF metadata and tensor shapes.

    Metadata keys are tried under the "general.architecture" prefix; any
    value still missing is inferred from the layer's gate-experts tensor
    shape [num_experts, intermediate_size, hidden_size].

    Args:
        layer_idx: Layer index to inspect (default: 0)

    Returns:
        dict with keys: num_experts, num_experts_per_tok, hidden_size,
        moe_intermediate_size (values may be None if undiscoverable).
    """
    arch = self.metadata.get("general.architecture", "unknown")

    def _meta_int(key_suffixes):
        # Return the first "{arch}.{suffix}" metadata hit as an int,
        # unwrapping single-element list/array values; None if no key matches.
        for key_suffix in key_suffixes:
            key = f"{arch}.{key_suffix}"
            if key in self.metadata:
                val = self.metadata[key]
                return int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
        return None

    num_experts = _meta_int([
        "expert_count",
        "expert.count",
        "moe.expert_count",
        "expert_feed_forward_length",
    ])
    num_experts_per_tok = _meta_int([
        "expert_used_count",
        "expert.used_count",
        "moe.num_experts_per_tok",
    ])
    hidden_size = _meta_int([
        "embedding_length",
        "embed_length",
        "hidden_size",
    ])
    moe_intermediate_size = _meta_int([
        "expert_feed_forward_length",
        "feed_forward_length",
        "ffn_length",
        "intermediate_size",
    ])

    # Fall back to the gate-experts tensor shape for anything still unknown.
    if any(v is None for v in [num_experts, hidden_size, moe_intermediate_size]):
        base_key = f"blk.{layer_idx}.ffn_gate_exps.weight"
        if base_key in self.tensor_info:
            gate_shape = self.tensor_info[base_key]['shape']
            print(f" Found tensor '{base_key}' with shape: {gate_shape}")

            if len(gate_shape) >= 3:
                if num_experts is None:
                    num_experts = int(gate_shape[0])
                if moe_intermediate_size is None:
                    moe_intermediate_size = int(gate_shape[1])
                if hidden_size is None:
                    hidden_size = int(gate_shape[2])

    config = {
        "num_experts": num_experts,
        "num_experts_per_tok": num_experts_per_tok,
        "hidden_size": hidden_size,
        "moe_intermediate_size": moe_intermediate_size,
    }

    return config
|
||||
def print_metadata(self, filter_keywords=None):
    """
    Print GGUF file metadata for debugging.

    Args:
        filter_keywords: Optional list of keywords; only metadata keys
            containing at least one keyword (case-insensitive) are printed.
    """
    print(f"\n[GGUFLoader] GGUF Metadata:")
    print(f" Total metadata entries: {len(self.metadata)}")

    entries = self.metadata.items()
    if filter_keywords:
        lowered = [kw.lower() for kw in filter_keywords]
        entries = [
            (k, v) for k, v in entries
            if any(kw in k.lower() for kw in lowered)
        ]
    for k, v in sorted(entries):
        print(f" {k}: {v}")
|
||||
def has_tensor(self, name: str):
    """Return True if a tensor with the given (PyTorch-style) name exists.

    The name is translated to its GGUF equivalent before the lookup.
    """
    return translate_name_to_gguf(name) in self.tensor_info
|
||||
|
||||
def get_ggml_type(self, name: str):
    """Return the GGML quantization type of a tensor.

    Args:
        name: Tensor name in PyTorch format; translated to GGUF form first.

    Raises:
        KeyError: if the tensor is not present in any loaded GGUF file.
    """
    name = translate_name_to_gguf(name)
    info = self.tensor_info.get(name)
    if info is None:
        raise KeyError(f"Tensor '{name}' not found in GGUF files")
    return info["dtype"]
|
||||
|
||||
def get_undequanted_tensor_and_ggml_type(self, name: str):
    """
    Get tensor data and its GGML type without dequantizing

    The raw (still-quantized) bytes are sliced out of the memory-mapped
    GGUF file and returned as a flat uint8 torch tensor, so the caller can
    hand them to a backend that consumes quantized blocks directly.

    Args:
        name: Tensor name (in PyTorch format, will be translated to GGUF format)

    Returns:
        (data, ggml_type): Tuple of tensor data and GGML quantization type

    Raises:
        KeyError: if the tensor is not found, or if its quantization type
            has no entry in the size table below.
    """
    name = translate_name_to_gguf(name)

    if name not in self.tensor_info:
        raise KeyError(f"Tensor '{name}' not found in GGUF files")

    info = self.tensor_info[name]
    # Map the tensor back to the file that holds it and that file's mmap.
    file_path = self.tensor_file_map[name]
    mmap_data = self.file_data_map[file_path]

    offset = info['offset']
    n_elements = info['n_elements']
    ggml_type = info['dtype']

    # (elements_per_block, bytes_per_block) for each GGML quantization
    # format. The byte expressions mirror the C struct layouts, e.g.
    # Q4_0 = fp16 scale (2) + 16 packed nibble bytes.
    # NOTE(review): table is duplicated here rather than imported — keep it
    # in sync with the gguf library's GGML_QUANT_SIZES if that is available.
    GGML_QUANT_SIZES = {
        GGMLQuantizationType.F32: (1, 4),
        GGMLQuantizationType.F16: (1, 2),
        GGMLQuantizationType.BF16: (1, 2),
        GGMLQuantizationType.Q4_0: (32, 2 + 16),
        GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
        GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
        GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
        GGMLQuantizationType.Q8_0: (32, 2 + 32),
        GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
        GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),
        GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),
        GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),
        GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
        GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
        GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),
        GGMLQuantizationType.IQ2_XXS: (256, 2 + 256 // 4),
        GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),
        GGMLQuantizationType.IQ3_XXS: (256, 2 + 256 // 4 + 256 // 8),
        GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),
        GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
        GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),
        GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),
        GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),
        GGMLQuantizationType.I8: (1, 1),
        GGMLQuantizationType.I16: (1, 2),
        GGMLQuantizationType.I32: (1, 4),
        GGMLQuantizationType.I64: (1, 8),
        GGMLQuantizationType.F64: (1, 8),
        GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),
    }

    block_size, type_size = GGML_QUANT_SIZES[ggml_type]
    # Quantized tensors store whole blocks, so total bytes scale by
    # type_size per block_size elements (integer math is exact here
    # because n_elements is a multiple of block_size for valid tensors).
    n_bytes = n_elements * type_size // block_size

    data_bytes = mmap_data[offset : offset + n_bytes]
    # .copy() detaches from the mmap so the returned tensor owns its
    # memory (frombuffer alone would alias the read-only mapping).
    data = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8).copy())

    return data, ggml_type
|
||||
Reference in New Issue
Block a user