Merge branch 'main' into develop-cht

This commit is contained in:
chenht2022
2025-11-03 14:35:44 +00:00
192 changed files with 22265 additions and 12592 deletions

View File

@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
<h2 id="Updates">🔥 Updates</h2>
* **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md))
* **Oct 10, 2025**: Integrating into SGLang. ([Roadmap](https://github.com/sgl-project/sglang/issues/11425))
* **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md))
* **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md))

0
config.json Normal file
View File

View File

@@ -1,3 +1,23 @@
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
if(KTRANSFORMERS_USE_NPU)
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()
if(KTRANSFORMERS_USE_NPU)
set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}")
message(STATUS "ASCEND_HOME_PATH is ${ASCEND_HOME_PATH}")
include_directories(${ASCEND_HOME_PATH}/include)
link_directories(${TORCH_INSTALL_PREFIX}/../torch.libs)
# find torch_npu
execute_process(
COMMAND python -c "import torch; import torch_npu; print(torch_npu.__path__[0])"
OUTPUT_VARIABLE TORCH_NPU_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "Found PTA at: ${TORCH_NPU_PATH}")
find_library(PTA_LIBRARY torch_npu PATH "${TORCH_NPU_PATH}/lib")
endif()
cmake_minimum_required(VERSION 3.21)
find_program(GCC_COMPILER NAMES g++-13 g++-12 g++-11 g++ REQUIRED)

View File

@@ -73,8 +73,19 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
# include_directories(/usr/include/tbb)
# link_directories(/usr/lib64)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
if(KTRANSFORMERS_USE_NPU)
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()
find_package(TBB REQUIRED)
find_package(CUDA REQUIRED)
if(KTRANSFORMERS_USE_NPU)
# NPU 构建
# find_package(CUDA REQUIRED) # NPU 情况不需要 CUDA
else()
# GPU 构建
find_package(CUDA REQUIRED)
endif()
# find_package(prometheus-cpp CONFIG REQUIRED)
if(NOT TARGET prometheus-cpp::pull)
@@ -83,18 +94,29 @@ else()
message(STATUS "prometheus Found!")
endif()
if(CUDA_FOUND)
message(STATUS "CUDA Found!")
message(STATUS "CUDA Version: ${CUDA_VERSION_STRING}")
message(STATUS "CUDA Toolkit Root: ${CUDA_TOOLKIT_ROOT_DIR}")
if(KTRANSFORMERS_USE_NPU)
# NPU 情况下不检查 CUDA
else()
message(FATAL_ERROR "CUDA not found!")
if(CUDA_FOUND)
message(STATUS "CUDA Found!")
message(STATUS "CUDA Version: ${CUDA_VERSION_STRING}")
message(STATUS "CUDA Toolkit Root: ${CUDA_TOOLKIT_ROOT_DIR}")
else()
message(FATAL_ERROR "CUDA not found!")
endif()
endif()
add_subdirectory(src)
if(BUILD_TEST)
add_subdirectory(test)
if(KTRANSFORMERS_USE_NPU)
message(STATUS "Build test...")
set(THIRD_PARTY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party)
add_subdirectory(${THIRD_PARTY_DIR}/spdlog ${CMAKE_CURRENT_SOURCE_DIR}/../build/third_party/spdlog)
add_subdirectory(test)
else()
add_subdirectory(test)
endif()
endif()
message(STATUS "BUILD_PYTHON_EXT: ${BUILD_PYTHON_EXT}")
@@ -126,7 +148,12 @@ endif()
set(PHOTON_CXX_STANDARD 14 CACHE INTERNAL "C++ standard")
set(CMAKE_CXX_FLAGS "-O3 -march=native")
if(KTRANSFORMERS_USE_NPU)
set(CMAKE_CXX_FLAGS "-O3 -march=armv8.2-a")
else()
set(CMAKE_CXX_FLAGS "-O3 -march=native")
endif()
message(STATUS "CMAKE_CXX_FLAGS of PhotonLibOS: ${CMAKE_CXX_FLAGS}")
add_subdirectory(${THIRD_PARTY_DIR}/PhotonLibOS ${THIRD_PARTY_BUILD_DIR}/PhotonLibOS)

View File

@@ -23,24 +23,46 @@ target_link_libraries(cache_entry PUBLIC gpu_cache)
add_library(gpu_cache gpu_cache.cpp)
add_third_party_includes(gpu_cache)
target_link_libraries(gpu_cache PUBLIC xxHash::xxhash ${TORCH_LIBRARIES} cuda_stream_manager)
# gpu_cache
if(KTRANSFORMERS_USE_NPU)
target_include_directories(gpu_cache PUBLIC ${TORCH_NPU_PATH}/include)
find_package(Python COMPONENTS Interpreter Development REQUIRED)
message("python include location: " ${Python_INCLUDE_DIRS} " lib location: " ${Python_LIBRARIES})
target_link_libraries(gpu_cache PUBLIC xxHash::xxhash "${TORCH_LIBRARIES}" "${TORCH_PYTHON_LIBRARY}" "${PTA_LIBRARY}" ${Python_LIBRARIES} cuda_stream_manager)
else()
target_link_libraries(gpu_cache PUBLIC xxHash::xxhash ${TORCH_LIBRARIES} cuda_stream_manager)
endif()
# kvc2
add_library(kvc2 prefix.cpp)
target_include_directories(kvc2 PRIVATE ${THIRD_PARTY_DIR}/nlohmann/single_include)
add_third_party_includes(kvc2)
target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_stream_manager page_aligned_memory_pool ${TORCH_LIBRARIES} prometheus-cpp::pull kvc2_metrics)
if(KTRANSFORMERS_USE_NPU)
target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_stream_manager page_aligned_memory_pool prometheus-cpp::pull kvc2_metrics)
else()
target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_stream_manager page_aligned_memory_pool ${TORCH_LIBRARIES} prometheus-cpp::pull kvc2_metrics)
endif()
message(STATUS "CMAKE_SOURCE_DIR: " ${CMAKE_SOURCE_DIR})
# async_store
add_library(async_store async_store.cpp)
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/nlohmann/single_include)
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/PhotonLibOS/include)
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/spdlog/include)
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/PhotonLibOS/include)
target_link_libraries(async_store PUBLIC photon_static pthread)
# cuda_stream_manager
add_library(cuda_stream_manager cuda_stream_manager.cpp)
target_include_directories(cuda_stream_manager PUBLIC ${THIRD_PARTY_DIR}/nlohmann/single_include)
target_include_directories(cuda_stream_manager PUBLIC ${THIRD_PARTY_DIR}/spdlog/include)
target_include_directories(cuda_stream_manager PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
target_link_libraries(cuda_stream_manager PUBLIC CUDA::cudart)
if(KTRANSFORMERS_USE_NPU)
set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}")
target_include_directories(cuda_stream_manager PUBLIC ${ASCEND_HOME_PATH}/include)
target_link_directories(cuda_stream_manager PUBLIC ${ASCEND_HOME_PATH}/lib64)
target_link_libraries(cuda_stream_manager PUBLIC ascendcl)
else()
target_include_directories(cuda_stream_manager PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
target_link_libraries(cuda_stream_manager PUBLIC CUDA::cudart)
endif()

View File

@@ -1,7 +1,8 @@
#include "cuda_stream_manager.hh"
#include <cuda_runtime.h>
#include <functional>
#include <chrono>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO
@@ -9,6 +10,12 @@
#define FMT_HEADER_ONLY
#include "spdlog/spdlog.h"
#ifdef KTRANSFORMERS_USE_NPU
#include "acl/acl_rt.h"
#else
#include <cuda_runtime.h>
#endif
CudaStreamManager::CudaStreamManager(const std::vector<size_t>& device_ids, int num_streams_per_device) {
for (int device_id : device_ids) {
auto x = std::unique_ptr<DeviceInfo>(new DeviceInfo);
@@ -18,25 +25,59 @@ CudaStreamManager::CudaStreamManager(const std::vector<size_t>& device_ids, int
device_info.stop_flag = false;
// 设置设备
cudaError_t err = cudaSetDevice(device_id);
if (err != cudaSuccess) {
SPDLOG_WARN("cudaSetDevice failed on device {}: {}", device_id, cudaGetErrorString(err));
#ifdef KTRANSFORMERS_USE_NPU
aclError acl_err = aclrtSetDevice(device_id);
if (acl_err != ACL_SUCCESS) {
SPDLOG_WARN("aclrtSetDevice failed on device {}: {}", device_id, acl_err);
throw std::runtime_error("aclrtSetDevice failed");
}
#else
cudaError_t cuda_err = cudaSetDevice(device_id);
if (cuda_err != cudaSuccess) {
SPDLOG_WARN("cudaSetDevice failed on device {}: {}", device_id, cudaGetErrorString(cuda_err));
throw std::runtime_error("cudaSetDevice failed");
}
#endif
// 创建 CUDA
// 创建流
device_info.streams.resize(num_streams_per_device);
for (int i = 0; i < num_streams_per_device; ++i) {
err = cudaStreamCreate(&device_info.streams[i]);
if (err != cudaSuccess) {
SPDLOG_WARN("Failed to create CUDA stream on device {}: {}", device_id, cudaGetErrorString(err));
#ifdef KTRANSFORMERS_USE_NPU
acl_err = aclrtCreateStream(&device_info.streams[i]);
if (acl_err != ACL_SUCCESS) {
SPDLOG_WARN("Failed to create NPU stream on device {}: {}", device_id, acl_err);
throw std::runtime_error("Failed to create NPU stream");
}
#else
cuda_err = cudaStreamCreate(&device_info.streams[i]);
if (cuda_err != cudaSuccess) {
SPDLOG_WARN("Failed to create CUDA stream on device {}: {}", device_id, cudaGetErrorString(cuda_err));
throw std::runtime_error("Failed to create CUDA stream");
}
#endif
}
// 启动设备工作线程
// 启动工作线程
device_info.worker_thread = std::thread(&CudaStreamManager::deviceWorker, this, std::ref(device_info));
#ifdef KTRANSFORMERS_USE_NPU
// NPU需要额外的回调线程
device_info.callback_thread = std::thread(&CudaStreamManager::deviceCallback, this, std::ref(device_info));
// 绑定回调线程
for (int i = 0; i < num_streams_per_device; ++i) {
std::ostringstream oss;
oss << device_info.callback_thread.get_id();
uint64_t tid = std::stoull(oss.str());
acl_err = aclrtSubscribeReport(tid, device_info.streams[i]);
SPDLOG_DEBUG("subscribe stream callback report on device {} with tid {} in idx {}", device_id, tid, i);
if (acl_err != ACL_SUCCESS) {
SPDLOG_WARN("Failed to subscribe stream callback report on device {}: {}", device_id, acl_err);
throw std::runtime_error("Failed to create stream callback job");
}
}
#endif
devices_.push_back(std::move(x));
}
}
@@ -52,15 +93,39 @@ CudaStreamManager::~CudaStreamManager() {
// 等待所有线程结束
for (auto& device_info : devices_) {
// 等待工作线程结束
if (device_info->worker_thread.joinable()) {
device_info->worker_thread.join();
}
// 销毁 CUDA 流
#ifdef KTRANSFORMERS_USE_NPU
// NPU需要额外的回调线程处理
aclrtSetDevice(device_info->device_id);
// 解绑callback任务并等待callback线程结束
std::ostringstream oss;
oss << device_info->callback_thread.get_id();
uint64_t tid = std::stoull(oss.str());
for (auto& stream: device_info->streams) {
aclrtUnSubscribeReport(tid, stream);
}
if (device_info->callback_thread.joinable()) {
device_info->callback_thread.join();
}
#endif
// 销毁流
#ifdef KTRANSFORMERS_USE_NPU
for (auto& stream : device_info->streams) {
aclrtDestroyStream(stream);
aclrtResetDevice(device_info->device_id);
}
#else
cudaSetDevice(device_info->device_id);
for (auto& stream : device_info->streams) {
cudaStreamDestroy(stream);
}
#endif
}
}
@@ -77,28 +142,47 @@ void CudaStreamManager::submitRequest(std::shared_ptr<Request> request) {
void CudaStreamManager::deviceWorker(DeviceInfo& device_info) {
// 设置设备
cudaError_t err = cudaSetDevice(device_info.device_id);
if (err != cudaSuccess) {
SPDLOG_WARN("cudaSetDevice failed in worker thread for device {}: {}", device_info.device_id,
cudaGetErrorString(err));
#ifdef KTRANSFORMERS_USE_NPU
aclError acl_err = aclrtSetDevice(device_info.device_id);
if (acl_err != ACL_SUCCESS) {
SPDLOG_WARN("aclrtSetDevice failed in worker thread for device {}: {}", device_info.device_id, acl_err);
return;
}
#else
cudaError_t cuda_err = cudaSetDevice(device_info.device_id);
if (cuda_err != cudaSuccess) {
SPDLOG_WARN("cudaSetDevice failed in worker thread for device {}: {}", device_info.device_id,
cudaGetErrorString(cuda_err));
return;
}
#endif
while (device_info.stop_flag.load() == false) {
auto request = device_info.request_queue.dequeue();
if (request->should_exit) {
return;
}
// 处理请求
SPDLOG_DEBUG("Getting request on device {}, count {}", device_info.device_id, request->host_mem_addresses.size());
int stream_index = device_info.next_stream_index;
cudaStream_t stream = device_info.streams[stream_index];
auto stream = device_info.streams[stream_index];
device_info.next_stream_index = (device_info.next_stream_index + 1) % device_info.streams.size();
size_t num_transfers = request->host_mem_addresses.size();
for (size_t i = 0; i < num_transfers; ++i) {
void* dst = request->device_mem_addresses[i];
void* src = request->host_mem_addresses[i];
#ifdef KTRANSFORMERS_USE_NPU
if (request->direction == ACL_MEMCPY_DEVICE_TO_HOST) {
std::swap(dst, src);
}
aclError err = aclrtMemcpyAsync(dst, request->sizes[i], src, request->sizes[i], request->direction, stream);
if (err != ACL_SUCCESS) {
SPDLOG_WARN("aclrtMemcpyAsync failed on device {}: {}", device_info.device_id, err);
continue;
}
#else
if (request->direction == cudaMemcpyDeviceToHost) {
std::swap(dst, src);
}
@@ -109,15 +193,31 @@ void CudaStreamManager::deviceWorker(DeviceInfo& device_info) {
// 可以根据需要处理错误,这里简单地继续
continue;
}
#endif
}
// 添加回调函数,因为是异步,所以需要包起来
// 添加回调函数
struct CallbackData {
std::function<void()> callback;
};
CallbackData* cb_data = new CallbackData{request->callback};
err = cudaLaunchHostFunc(
#ifdef KTRANSFORMERS_USE_NPU
aclError err = aclrtLaunchCallback(
[](void* data) {
CallbackData* cb_data = static_cast<CallbackData*>(data);
cb_data->callback();
delete cb_data;
},
cb_data,
ACL_CALLBACK_BLOCK,
stream);
if (err != ACL_SUCCESS) {
SPDLOG_WARN("aclrtLaunchCallback failed on device {}: {}", device_info.device_id, err);
}
#else
cudaError_t err = cudaLaunchHostFunc(
stream,
[](void* data) {
// SPDLOG_DEBUG("Callback function called");
@@ -129,7 +229,30 @@ void CudaStreamManager::deviceWorker(DeviceInfo& device_info) {
if (err != cudaSuccess) {
SPDLOG_WARN("cudaLaunchHostFunc failed on device {}: {}", device_info.device_id, cudaGetErrorString(err));
// 根据需要处理错误
}
#endif
}
}
#ifdef KTRANSFORMERS_USE_NPU
void CudaStreamManager::deviceCallback(DeviceInfo& device_info) {
aclError err = aclrtSetDevice(device_info.device_id);
if (err != ACL_SUCCESS) {
SPDLOG_WARN("aclrtSetDevice failed in callback thread for device {}: {}", device_info.device_id, err);
return;
}
int timeout = 60 * 1000; // ms
while (device_info.stop_flag.load() == false) {
err = aclrtProcessReport(timeout);
if (err != ACL_SUCCESS) {
if (err == ACL_ERROR_RT_THREAD_SUBSCRIBE) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
SPDLOG_WARN("aclrtProcessReport failed in callback thread for device {}: {}", device_info.device_id, err);
}
}
}
#endif

View File

@@ -8,7 +8,6 @@
*/
#pragma once
#include <cuda_runtime.h>
#include <atomic>
#include <functional>
#include <memory>
@@ -16,6 +15,11 @@
#include <vector>
#include "utils/mpsc.hpp"
#ifdef KTRANSFORMERS_USE_NPU
#include "acl/acl_mdl.h"
#else
#include <cuda_runtime.h>
#endif
class CudaStreamManager {
public:
// 构造函数,接受要使用的设备 ID 列表和每个设备的流数量
@@ -29,7 +33,11 @@ class CudaStreamManager {
std::vector<void*> host_mem_addresses;
std::vector<void*> device_mem_addresses;
std::vector<size_t> sizes;
#ifdef KTRANSFORMERS_USE_NPU
aclrtMemcpyKind direction;
#else
cudaMemcpyKind direction;
#endif
std::function<void()> callback;
};
@@ -40,7 +48,12 @@ class CudaStreamManager {
struct DeviceInfo {
int device_id;
std::thread worker_thread;
#ifdef KTRANSFORMERS_USE_NPU
std::thread callback_thread;
std::vector<aclrtStream> streams;
#else
std::vector<cudaStream_t> streams;
#endif
int next_stream_index;
MPSCQueueConsumerLock<std::shared_ptr<Request>> request_queue;
std::atomic_bool stop_flag;
@@ -51,4 +64,7 @@ class CudaStreamManager {
// 私有方法
void deviceWorker(DeviceInfo& device_info);
#ifdef KTRANSFORMERS_USE_NPU
void deviceCallback(DeviceInfo& device_info); // NPU 专用回调线程函数
#endif
};

View File

@@ -10,6 +10,24 @@
namespace kvc2 {
GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
#ifdef KTRANSFORMERS_USE_NPU
size_t gpu_count = c10_npu::device_count();
if (gpu_count > 0) {
SPDLOG_INFO("Number of available NPUs: {}, want {}", gpu_count, config.gpu_devices_id.size());
if (gpu_count < config.gpu_devices_id.size()) {
SPDLOG_ERROR("Not enough NPUs available.");
exit(0);
}
for (auto x : config.gpu_devices_id) {
std::string device_str = "npu:" + std::to_string(x);
// torch_npu::init_npu(device_str); should inited in scheduler
gpu_devices.push_back(at::Device(device_str));
}
} else {
SPDLOG_ERROR("NPU is not available on this system.");
exit(0);
}
#else
if (torch::cuda::is_available()) {
size_t gpu_count = torch::cuda::device_count();
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, config.gpu_devices_id.size());
@@ -24,6 +42,7 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
SPDLOG_ERROR("CUDA is not available on this system.");
exit(0);
}
#endif
SPDLOG_WARN("Creating GPU Cache");
shape.push_back(config.layer_count);
@@ -47,7 +66,9 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
if (config.k_cache_on) {
for (size_t i = 0; i < config.gpu_devices_id.size(); i++) {
auto k = torch::zeros(shape, torch::TensorOptions().dtype(config.tensor_type));
SPDLOG_INFO("ALLOCATE on CPU OK");
k = k.to(gpu_devices[i]);
SPDLOG_INFO("ALLOCATE on NPU OK");
k_cache.push_back(k);
@@ -95,6 +116,10 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
std::unique_ptr<CudaStreamManager>(new CudaStreamManager(config.gpu_devices_id, config.num_streams_per_device));
}
kvc2::GPUPageCache::~GPUPageCache() {
// torch_npu::finalize_npu()
}
bool GPUPageCache::alloc_col(std::vector<std::vector<std::shared_ptr<CacheBlockEntry>>>& k_entries,
std::vector<std::vector<std::shared_ptr<CacheBlockEntry>>>& v_entries, size_t at) {
std::lock_guard<std::mutex> lg(lock);
@@ -218,8 +243,13 @@ std::vector<std::unique_lock<CacheBlockEntry::MutexT>> GPUPageCache::try_lock_co
return re;
}
std::vector<std::shared_ptr<CudaStreamManager::Request>> GPUPageCache::basic_request(cudaMemcpyKind direction,
std::function<void()> callback) {
std::vector<std::shared_ptr<CudaStreamManager::Request>> GPUPageCache::basic_request(
#ifdef KTRANSFORMERS_USE_NPU
aclrtMemcpyKind direction,
#else
cudaMemcpyKind direction,
#endif
std::function<void()> callback) {
std::vector<std::shared_ptr<CudaStreamManager::Request>> re;
re.resize(config.gpu_devices_id.size(), nullptr);
for (size_t i = 0; i < re.size(); i++) {

View File

@@ -3,12 +3,20 @@
#include <torch/torch.h>
#include "cache_entry.hh"
#include "cuda_stream_manager.hh"
#include "defs.h"
#include "kvc2.h"
#include "metrics.h"
#include "utils/periodic_task.hpp"
// 根据设备类型包含不同的头文件和流管理器
#ifdef KTRANSFORMERS_USE_NPU
#include "torch_npu/csrc/libs/torch_npu.h"
#include "torch_npu/csrc/libs/init_npu.h"
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
#else
#include "cuda_stream_manager.hh"
#endif
namespace kvc2 {
class GPUPageCache {
@@ -43,6 +51,7 @@ class GPUPageCache {
std::unique_ptr<periodic::PeriodicTask> background_flush_back =nullptr;
GPUPageCache(GPUPageCacheConfig& config);
~GPUPageCache(); // 统一添加析构函数声明
std::vector<size_t> gpu_only_alloc_col(size_t count);
void gpu_only_free_cols(std::vector<size_t> cols);
@@ -59,8 +68,14 @@ class GPUPageCache {
void free_col(size_t at);
std::vector<std::shared_ptr<CudaStreamManager::Request>> basic_request(cudaMemcpyKind direction,
std::function<void()> callback);
// 统一内存拷贝类型接口
std::vector<std::shared_ptr<CudaStreamManager::Request>> basic_request(
#ifdef KTRANSFORMERS_USE_NPU
aclrtMemcpyKind direction,
#else
cudaMemcpyKind direction,
#endif
std::function<void()> callback);
void submit_requests(std::vector<std::shared_ptr<CudaStreamManager::Request>> reqs);

View File

@@ -1,4 +1,7 @@
#ifdef KTRANSFORMERS_USE_NPU
#else
#include <immintrin.h>
#endif
#include <tbb/concurrent_hash_map.h>
#include <algorithm>
@@ -986,13 +989,23 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
}
});
cudaMemcpyKind direction;
if (option == IO_Read || option == IO_ForceRead) {
direction = cudaMemcpyHostToDevice;
}
if (option == IO_Write || option == IO_ForceWrite) {
direction = cudaMemcpyDeviceToHost;
}
#if defined(KTRANSFORMERS_USE_NPU) // NPU平台分支
aclrtMemcpyKind direction;
if (option == IO_Read || option == IO_ForceRead) {
direction = ACL_MEMCPY_HOST_TO_DEVICE;
}
if (option == IO_Write || option == IO_ForceWrite) {
direction = ACL_MEMCPY_DEVICE_TO_HOST;
}
#else // 默认GPU分支
cudaMemcpyKind direction;
if (option == IO_Read || option == IO_ForceRead) {
direction = cudaMemcpyHostToDevice;
}
if (option == IO_Write || option == IO_ForceWrite) {
direction = cudaMemcpyDeviceToHost;
}
#endif
auto reqs = gpu_cache->basic_request(direction, [io_helper]() { io_helper->batch_promise.set(); });
@@ -1752,7 +1765,11 @@ void GPUPageCache::gpu_background_flush() {
std::vector<CacheBlockEntry*> entries;
std::vector<std::unique_lock<CacheBlockEntry::MutexT>> uls;
BatchPromise promise(config.gpu_devices_id.size());
auto reqs = basic_request(cudaMemcpyDeviceToHost, [&promise]() { promise.set(); });
#if defined(KTRANSFORMERS_USE_NPU) // NPU分支
auto reqs = basic_request(ACL_MEMCPY_DEVICE_TO_HOST, [&promise]() { promise.set(); });
#else
auto reqs = basic_request(cudaMemcpyDeviceToHost, [&promise]() { promise.set(); });
#endif
for (size_t i = 0; i < config.total_kvcache_pages; i++) {
std::lock_guard<std::mutex> lg(this->lock);

View File

@@ -4,17 +4,38 @@ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
set(UTILS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/utils)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
if(KTRANSFORMERS_USE_NPU)
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()
add_library(sched_metrics metrics.cpp)
target_include_directories(sched_metrics PRIVATE ${UTILS_DIR})
target_link_libraries(sched_metrics PUBLIC prometheus-cpp::pull)
add_library(sched scheduler.cpp)
if(KTRANSFORMERS_USE_NPU)
#target_link_directories(sched PUBLIC ${TORCH_NPU_PATH}/lib)
target_include_directories(sched PUBLIC ${TORCH_NPU_PATH}/include)
endif()
target_include_directories(sched PRIVATE ${SPDLOG_DIR}/include ${FMT_DIR}/include ${UTILS_DIR} ${KVC2_INCLUDE_DIR})
target_link_libraries(sched PUBLIC pthread ${TORCH_LIBRARIES} kvc2 async_store sched_metrics)
if(KTRANSFORMERS_USE_NPU)
target_link_libraries(sched PUBLIC pthread "${TORCH_LIBRARIES}" "${TORCH_PYTHON_LIBRARY}" "${PTA_LIBRARY}" kvc2 async_store sched_metrics)
else()
target_link_libraries(sched PUBLIC pthread ${TORCH_LIBRARIES} kvc2 async_store sched_metrics)
endif()
pybind11_add_module(sched_ext bind.cpp)
target_link_libraries(sched_ext PUBLIC sched ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
if(KTRANSFORMERS_USE_NPU)
#target_link_directories(sched_ext PUBLIC ${TORCH_NPU_PATH}/lib)
target_include_directories(sched_ext PUBLIC ${TORCH_NPU_PATH}/include)
target_link_libraries(sched_ext PUBLIC "${TORCH_LIBRARIES}" "${TORCH_PYTHON_LIBRARY}" "${PTA_LIBRARY}" sched)
else()
target_link_libraries(sched_ext PUBLIC sched ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
endif()

View File

@@ -25,28 +25,45 @@ using json = nlohmann::json;
namespace scheduler {
void Settings::auto_derive() {
// 统一设备数量获取方式
gpu_device_count = gpu_device_id.size();
if (torch::cuda::is_available()) {
size_t gpu_count = torch::cuda::device_count();
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count,
gpu_device_count);
if (gpu_count < gpu_device_count) {
SPDLOG_ERROR("Not enough GPUs available.");
// 设备初始化分支
#ifdef KTRANSFORMERS_USE_NPU
size_t npu_count = c10_npu::device_count();
SPDLOG_INFO("Number of available NPUs: {}, want {}", npu_count, gpu_device_count);
if (npu_count < gpu_device_count) {
SPDLOG_ERROR("Not enough NPUs available.");
exit(0);
}
for (size_t i = 0; i < gpu_device_count; i++) {
devices.push_back(torch::Device(torch::kCUDA, gpu_device_id[i]));
std::string device_str = "npu:" + std::to_string(gpu_device_id[i]);
devices.push_back(torch::Device(device_str));
}
} else {
SPDLOG_ERROR("CUDA is not available on this system.");
exit(0);
}
if (model_settings.num_k_heads % gpu_device_count != 0) {
size_t head_per_gpu = model_settings.num_k_heads;
#else // GPU模式
if (torch::cuda::is_available()) {
size_t gpu_count = torch::cuda::device_count();
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, gpu_device_count);
if (gpu_count < gpu_device_count) {
SPDLOG_ERROR("Not enough GPUs available.");
exit(0);
}
for (size_t i = 0; i < gpu_device_count; i++) {
devices.push_back(torch::Device(torch::kCUDA, gpu_device_id[i]));
}
} else {
SPDLOG_ERROR("CUDA is not available on this system.");
exit(0);
}
if (model_settings.num_k_heads % gpu_device_count != 0) {
SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}",
model_settings.num_k_heads, gpu_device_count);
assert(false);
}
}
// 统一head_per_gpu计算方式每设备分配头数
size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
#endif
size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage;
if (gpu_memory_available * gpu_device_count <
@@ -58,13 +75,12 @@ void Settings::auto_derive() {
}
assert(model_settings.k_head_dim % model_settings.num_k_heads == 0);
size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
size_t gpu_memory_for_kv_cache =
gpu_memory_available /*- model_settings.params_nbytes() /
gpu_device_count*/
;
SPDLOG_INFO(
"Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB",
"Each Device Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB",
gpu_memory_size / (1 << 20),
model_settings.params_nbytes() / gpu_device_count / (1 << 20),
gpu_memory_for_kv_cache / (1 << 20),
@@ -88,14 +104,18 @@ void Settings::auto_derive() {
max_total_kvcache_pages);
}
if (page_size % 256 != 0) {
SPDLOG_ERROR("page_size {} is not divisible by 256", page_size);
assert(false);
}
if (page_size < 256) {
SPDLOG_ERROR("page_size {} is smaller than 256", page_size);
assert(false);
}
#ifdef KTRANSFORMERS_USE_NPU
// NPU ND limit block size 16-128
#else
if (page_size % 256 != 0) {
SPDLOG_ERROR("page_size {} is not divisible by 256", page_size);
assert(false);
}
if (page_size < 256) {
SPDLOG_ERROR("page_size {} is smaller than 256", page_size);
assert(false);
}
#endif
}
std::string BatchQueryTodo::debug() {
@@ -284,7 +304,11 @@ struct KVC2_Maintainer {
.full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu,
.k_cache_on = settings.k_cache_on,
.v_cache_on = settings.v_cache_on,
.tensor_type = torch::kBFloat16,
#ifdef KTRANSFORMERS_USE_NPU
.tensor_type = torch::kFloat16,
#else
.tensor_type = torch::kBFloat16,
#endif
};
auto model_configs_path =

View File

@@ -6,6 +6,13 @@
#include <torch/torch.h>
#include <vector>
// 条件编译仅在NPU环境下引入相关头文件
#ifdef KTRANSFORMERS_USE_NPU
#include "torch_npu/csrc/libs/torch_npu.h"
#include "torch_npu/csrc/libs/init_npu.h"
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
#endif
namespace scheduler {
using Token = uint32_t;

View File

@@ -42,6 +42,11 @@ option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA"
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
if(KTRANSFORMERS_USE_NPU)
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -89,6 +94,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
endif ()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
else()
if(KTRANSFORMERS_USE_NPU)
list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)
endif()
check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
list(APPEND ARCH_FLAGS -mfp16-format=ieee)
@@ -116,36 +124,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
message(STATUS "x86 detected")
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(__HAS_AMX__ TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if(NOT KTRANSFORMERS_USE_NPU)
set(HOST_IS_X86 TRUE)
set(HAS_AVX512 TRUE)
set(__HAS_AMX__ TRUE)
add_compile_definitions(__x86_64__)
# check AVX512
execute_process(
COMMAND lscpu
OUTPUT_VARIABLE LSCPU_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
add_compile_definitions(__HAS_AVX512F__)
else()
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
set(HAS_AVX512 False)
endif()
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(__HAS_AMX__)
else()
message(STATUS "Compiler does NOT support AMX")
# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
if(COMPILER_SUPPORTS_AMX GREATER -1)
message(STATUS "Compiler supports AMX")
add_compile_definitions(__HAS_AMX__)
else()
message(STATUS "Compiler does NOT support AMX")
endif()
endif()
if (MSVC)
# instruction set detection for MSVC only
@@ -306,7 +316,7 @@ elseif (UNIX)
endif()
elseif (KTRANSFORMERS_USE_XPU)
add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
else()
elseif (KTRANSFORMERS_USE_CUDA)
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
include(CheckLanguage)
@@ -325,7 +335,13 @@ endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
# aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
file(GLOB LLAMAFILE_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/*.cpp")
list(REMOVE_ITEM LLAMAFILE_SOURCES
"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_arm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_x86.cpp"
)
set(SOURCE_DIR4 ${LLAMAFILE_SOURCES})
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
if (HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__)
@@ -365,7 +381,7 @@ elseif(UNIX)
elseif(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
elseif(KTRANSFORMERS_USE_XPU)
else()
elseif(KTRANSFORMERS_USE_CUDA AND NOT KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
endif()
@@ -396,4 +412,4 @@ else()
else()
message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
endif()
endif()
endif()

View File

@@ -9,7 +9,7 @@
**/
// Python bindings
#include "cpu_backend/cpuinfer.h"
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) && !defined(KTRANSFORMERS_USE_NPU)
#include "device_launch_parameters.h"
#endif
#include "llamafile/flags.h"

View File

@@ -0,0 +1,181 @@
# 基准测试结果
| Prompt length | 1K | 2K | 4K |
| --------------------------------- | ------ | ------ | ------ |
| KTrans Prefill token/s | 174.68 | 169.52 | 167.15 |
| KTrans Decode token/s | 16.07 | 16.12 | 16.48 |
## 先决条件
我们在以下配置下进行了Deepseek-R1最佳性能测试
- 服务器型号Atlas 2UP
- NPU300I A2
- CPU: HUAWEI Kunpeng 920 7270Z
- 内存: DDR5服务器内存1TB
# 部署
## 物理机安装
部署满血版Deepseek-R1/V3需要机器物理内存能够存放下全部路由专家的权重约400GB。
目前支持的NPU型号**300I A2**。
在技术人员的支持下完成硬件安装。
## 系统安装
根据网页[昇腾兼容性查询助手](https://www.hiascend.com/hardware/compatibility)查询选用系统Ubuntu 22.04 for aarch64内核5.15.0-25-generic并禁止系统自动更新。系统镜像获取链接[ubuntu-old-releases](https://mirrors.aliyun.com/oldubuntu-releases/releases/22.04)。
## HDK安装
选择[Ascend HDK 25.3.RC1](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/264672545?idAbsPath=fixnode01|23710424|251366513|254884019|261408772|252764743)进行安装,安装方式参考[昇腾社区HDK安装指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/softwareinst/instg/instg_0005.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)。
## Conda部署
建议按照最新[Installation Guide - kTransformers](https://kvcache-ai.github.io/ktransformers/en/install.html)部署开发环境此处注意Python版本要求3.11其他版本未验证arm平台不需要安装cpufeature包。
安装conda/miniconda
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
bash ~/Miniconda3-latest-Linux-aarch64.sh
```
部署Python环境
```bash
conda create -n py311 python=3.11
conda activate py311
conda install -c conda-forge libstdcxx-ng # 安装`GLIBCXX-3.4.32`
apt install zlib1g-dev libtbb-dev libssl-dev libaio-dev libcurl4-openssl-dev
pip3 install numpy==1.26.4 # 适配torch/torch_npu
pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
pip3 install packaging ninja fire protobuf attrs decorator cloudpickle ml-dtypes scipy tornado absl-py psutil
pip3 install sqlalchemy
pip3 install transformers==4.57.1 #此处注意运行时transformers版本要求4.57.1(其他版本未验证)
#pip3 install cpufeature # only for x86
```
## CANN安装
选择[CANN 8.3.RC1.alpha003](https://www.hiascend.com/developer/download/community/result?cann=8.3.RC1.alpha003&product=4&model=32)进行安装,安装方式参考[昇腾社区CANN安装指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/softwareinst/instg/instg_quick.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)。
需要安装ToolKitKernel和NNAL。
## torch_npu安装
获取最新的仓库代码:[torch_npu Gitcode](https://gitcode.com/Ascend/pytorch)
由于涉及新增算子公网pypi内提供的torch_npu暂时无法直接使用可以下载代码仓库编译当前适配分支为v2.5.1,编译命令可以参考仓库内文档。
编译过程需要保证访问githubgitcode等平台网络畅通并设置如下环境变量
```bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh # 以实际CANN安装路径为准
source /usr/local/Ascend/nnal/atb/set_env.sh # 以实际NNAL安装路径为准
```
由于环境对于torch_npu版本号有特定要求使用编译后的torch_npu包需要手动移除版本信息中的哈希后缀操作如下
使用文本编辑器打开`/usr/local/lib/python3.11/site-packages/torch_npu/version.py`(不同环境python路径可能不同可以使用`pip show torch_npu`查看安装的python路径)
`__version__ = '2.5.1.post4+git69550dfc'`改为`__version__ = '2.5.1.post4'`
## 权重准备
目前,为了满足性能和精度的要求,我们需要准备两份权重,并使用提供的权重合并脚本对权重进行合并,最终只会使用合并后的权重。
Q4权重[DeepSeek-R1-Q4_K_M](https://modelscope.cn/models/unsloth/DeepSeek-R1-GGUF/files)
W8A8权重[DeepSeek-R1-W8A8](https://modelers.cn/models/MindSpore-Lab/DeepSeek-R1-W8A8/tree/main)
使用[merge_safetensor_gguf.py](../../merge_tensors/merge_safetensor_gguf.py)来合并Q4和W8A8权重
```bash
python merge_safetensor_gguf.py --safetensor_path /mnt/weights/DeepSeek-R1-Q4_K_M --gguf_path /mnt/weights/DeepSeek-R1-W8A8 --output_path /mnt/weights/DeepSeek-R1-q4km-w8a8
```
## 图下沉部署
开启图下沉功能,需要添加如下环境变量:
```bash
export TASK_QUEUE_ENABLE=0 # 保证算子下发顺序有序
```
## kTransformers部署
将项目文件部署到机器上:
- 初始化third_party。由于此过程耗时较多且容易受网络影响导致仓库克隆失败建议初始化一次后将相关文件进行打包以便后续直接解压使用。
```bash
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init --recursive
```
- 对于arm平台注释掉`./third_party/llamafile/iqk_mul_mat_arm82.cpp`中的
```cpp
#define iqk_mul_mat iqk_mul_mat_arm82
#define iqk_mul_mat_moe iqk_mul_mat_moe_arm82
```
- 执行`source /usr/local/Ascend/ascend-toolkit/set_env.sh`以实际CANN-TOOLKIT安装路径为准
- 执行`apt install cmake libhwloc-dev pkg-config`安装依赖。
- 修改项目目录下 /ktransformers/config/config.yaml 中attn部分的page_size: 128 chunk_size: 16384
- 执行`USE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh`,等待安装完成。
此处给出示例balance_serve的启动脚本由于使用了相对路径需将该脚本放至项目的根路径下
```bash
#!/bin/bash
export USE_MERGE=0
export INF_NAN_MODE_FORCE_DISABLE=1
export TASK_QUEUE_ENABLE=0
export RANK=0
export LOCAL_WORLD_SIZE=1
#export PROF_DECODE=1
#export PROF_PREFILL=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
python ktransformers/server/main.py \
--port 10002 \
--model_path <your model path> \
--gguf_path <your model path> \
--model_name DeepSeekV3ForCausalLM \
--cpu_infer 100 \
--optimize_config_path ./ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat-300IA2-npu-serve.yaml \
--max_new_tokens 1024 \
--cache_lens 20480 \
--max_batch_size 4 \
--use_cuda_graph \
--tp 1 \
--backend_type balance_serve
```
相关参数说明:
- `--model_path`kTransformers原生参数str此处用来指定合并后的模型文件路径
- `--gguf_path`kTransformers原生参数str此处用来指定合并后的模型文件路径
- `--cpu_infer`kTransformers原生参数int用来控制CPU侧实际worker线程数非必选
- `--optimize_config_path`kTransformers原生参数str用来指定所用的模型优化配置文件需要注意相对路径的使用此处为**必选**
- `--cache_lens`:调度器申请 kvcache 的总长度。所有请求共享指定数量(例如 `20480`)的 tokens 对应的 kvcache 空间,请求完成后会释放其所占用的 kvcache 空间,非必选
- `--use_cuda_graph`kTransformers原生参数bool为True表示开启图下沉为False表示关闭图下沉非必选
- `--max_new_tokens`kTransformers原生参数int当统计到输出的tokens数量达到该值时会直接中止输出非必选
- `--tp`新增参数int用于开启tensor model parallel功能目前local_chat只支持tp大小与ws大小相同不支持local_chat使用多dp非必选
# 其他问题
## 可能存在的其他依赖问题
ImportError: libhccl.so: cannot open shared object file: No such file or directory
```bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh # 以实际CANN安装路径为准
```
ImportError: libascend_hal.so: cannot open shared object file: No such file or directory
```bash
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH # 以实际Driver安装路径为准
```

View File

@@ -40,4 +40,4 @@ fi
# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/
# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*
echo "Installation completed successfully"
echo "Installation completed successfully"

21
install_for_npu.sh Normal file
View File

@@ -0,0 +1,21 @@
#!/bin/bash
# install_for_npu.sh - Build and install ktransformers on an Ascend NPU host.
#
# Steps:
#   1. Load the CANN toolkit environment (required for the NPU build).
#   2. Remove all stale build artifacts so the native extension is rebuilt cleanly.
#   3. Install Python dependencies, then ktransformers itself (forced source build),
#      then the bundled custom_flashinfer package.
#
# Environment:
#   ASCEND_TOOLKIT_SETENV - path to CANN's set_env.sh
#                           (default: /usr/local/Ascend/ascend-toolkit/set_env.sh)
set -e

# CANN install locations vary; allow override and fail fast with a clear message
# instead of letting `source` die on a missing file.
ASCEND_TOOLKIT_SETENV="${ASCEND_TOOLKIT_SETENV:-/usr/local/Ascend/ascend-toolkit/set_env.sh}"
if [ ! -f "$ASCEND_TOOLKIT_SETENV" ]; then
    echo "Error: CANN set_env.sh not found at $ASCEND_TOOLKIT_SETENV" >&2
    echo "Set ASCEND_TOOLKIT_SETENV to your CANN toolkit's set_env.sh and retry." >&2
    exit 1
fi
# shellcheck disable=SC1090
source "$ASCEND_TOOLKIT_SETENV"

# clear build dirs so KTRANSFORMERS_FORCE_BUILD really rebuilds from scratch
rm -rf build
rm -rf *.egg-info
rm -rf csrc/build
rm -rf csrc/ktransformers_ext/build
rm -rf csrc/ktransformers_ext/cuda/build
rm -rf csrc/ktransformers_ext/cuda/dist
rm -rf csrc/ktransformers_ext/cuda/*.egg-info
rm -rf ~/.ktransformers

echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
pip install -r ktransformers/server/requirements.txt

echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation
pip install third_party/custom_flashinfer/

echo "Installation completed successfully"

21
kt-kernel/.gitignore vendored Normal file
View File

@@ -0,0 +1,21 @@
debug/
debug_prefill/
debug_decode/
debug1/
debug2/
.gdbinit
bp.gdb
.gdb_history
build/
# local git hooks installer and hooks
.clangd
.cache
tmp*
.vscode/
*.egg-info/
*.pyc
*.so
sparse_logs/
build-cm/

View File

@@ -2,7 +2,33 @@ cmake_minimum_required(VERSION 3.16)
# Toggle: default to system compilers; optionally use conda toolchain
option(USE_CONDA_TOOLCHAIN "Use C/C++ compilers and libraries from active conda env" OFF)
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
option(LLAMA_AVX "llama: enable AVX" OFF)
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
option(LLAMA_FMA "llama: enable FMA" OFF)
# in MSVC F16C is implied with AVX2/AVX512
if(NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
option(KTRANSFORMERS_CPU_USE_KML "ktransformers: CPU use KML" OFF)
option(KTRANSFORMERS_CPU_USE_AMX_AVX512 "ktransformers: CPU use AMX or AVX512" OFF)
option(KTRANSFORMERS_CPU_USE_AMX "ktransformers: CPU use AMX" OFF)
option(KTRANSFORMERS_CPU_DEBUG "ktransformers: DEBUG CPU use AMX" OFF)
option(KTRANSFORMERS_CPU_MLA "ktransformers: CPU use MLA" OFF)
option(KTRANSFORMERS_CPU_MOE_KERNEL "ktransformers: CPU use moe kernel" OFF)
option(KTRANSFORMERS_CPU_MOE_AMD "ktransformers: CPU use moe kernel for amd" OFF)
# LTO control
option(CPUINFER_ENABLE_LTO "Enable link time optimization (IPO)" OFF)
project(kt_kernel_ext VERSION 0.1.0)
# Choose compilers BEFORE project() so CMake honors them
if(USE_CONDA_TOOLCHAIN)
if(NOT DEFINED ENV{CONDA_PREFIX} OR NOT EXISTS "$ENV{CONDA_PREFIX}")
@@ -24,8 +50,6 @@ else()
endif()
endif()
project(cpuinfer_ext VERSION 0.1.0)
# If explicitly using conda toolchain, prefer its libraries/headers and RPATH
if(USE_CONDA_TOOLCHAIN)
@@ -88,7 +112,7 @@ add_compile_definitions(FMT_HEADER_ONLY)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
set(CMAKE_BUILD_TYPE "Release")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -fsanitize=address -fno-omit-frame-pointer")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g ")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
# set(CMAKE_BUILD_TYPE "Debug")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
find_package(OpenMP REQUIRED)
@@ -98,7 +122,6 @@ include(CheckCXXCompilerFlag)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
option(LLAMA_NATIVE "llama: enable -march=native flag" ON)
# instruction set specific
if(LLAMA_NATIVE)
@@ -106,51 +129,10 @@ if(LLAMA_NATIVE)
else()
set(INS_ENB ON)
endif()
option(LLAMA_AVX "llama: enable AVX" OFF)
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
option(LLAMA_FMA "llama: enable FMA" OFF)
# in MSVC F16C is implied with AVX2/AVX512
if(NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
option(KTRANSFORMERS_CPU_USE_KML "ktransformers: CPU use KML" OFF)
option(KTRANSFORMERS_CPU_USE_AMX_AVX512 "ktransformers: CPU use AMX or AVX512" ON)
option(KTRANSFORMERS_CPU_USE_AMX "ktransformers: CPU use AMX" OFF)
option(KTRANSFORMERS_CPU_DEBUG "ktransformers: DEBUG CPU use AMX" OFF)
option(KTRANSFORMERS_CPU_MLA "ktransformers: CPU use MLA" OFF)
# LTO control
option(CPUINFER_ENABLE_LTO "Enable link time optimization (IPO)" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if(MSVC)
string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
else()
set(CMAKE_GENERATOR_PLATFORM_LWR "")
endif()
if(NOT MSVC)
if(LLAMA_STATIC)
add_link_options(-static)
if(MINGW)
add_link_options(-static-libgcc -static-libstdc++)
endif()
endif()
if(LLAMA_GPROF)
add_compile_options(-pg)
endif()
endif()
set(ARCH_FLAGS "")
@@ -266,14 +248,14 @@ elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR
list(APPEND ARCH_FLAGS -mfma)
endif()
if(LLAMA_AVX)
list(APPEND ARCH_FLAGS -mavx)
list(APPEND ARCH_FLAGS -mavx -mfma -msse3 -mf16c)
message(WARNING "pure AVX is not supported at least avx2")
endif()
if(LLAMA_AVX2)
list(APPEND ARCH_FLAGS -mavx2)
list(APPEND ARCH_FLAGS -mavx2 -mfma -msse3 -mf16c)
endif()
if(LLAMA_AVX512)
list(APPEND ARCH_FLAGS -mavx512f)
list(APPEND ARCH_FLAGS -mavx512bw)
list(APPEND ARCH_FLAGS -mavx512f -mavx512bw -mfma -mf16c)
endif()
if(LLAMA_AVX512_VBMI)
list(APPEND ARCH_FLAGS -mavx512vbmi)
@@ -305,14 +287,6 @@ else()
message(STATUS "Unknown architecture")
endif()
# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}")
# find_package(FindCUDAToolkit REQUIRED)
# if(CUDAToolkit_FOUND)
# message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}")
# else()
# message(STATUS "Can't found CUDA lib")
# endif()
if(NOT EXISTS $ENV{ROCM_PATH})
if(NOT EXISTS /opt/rocm)
set(ROCM_PATH /usr)
@@ -338,6 +312,62 @@ endif()
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
# AMD MoE kernel backend: locate the BLIS BLAS library and attach it to the
# Python extension target. Searched under BLIS_ROOT (if given), then /usr/local
# and /usr, with both plain and lib64/include/blis layouts.
if(KTRANSFORMERS_CPU_MOE_AMD)
# User-overridable install root; empty means "search default prefixes only".
set(BLIS_ROOT "" CACHE PATH "Root directory of BLIS installation")
set(_BLIS_SEARCH_DIRS)
if(BLIS_ROOT)
list(APPEND _BLIS_SEARCH_DIRS "${BLIS_ROOT}")
endif()
list(APPEND _BLIS_SEARCH_DIRS "/usr/local" "/usr")
find_path(BLIS_INCLUDE_DIR
NAMES blis.h
HINTS ${_BLIS_SEARCH_DIRS}
PATH_SUFFIXES include include/blis
)
find_library(BLIS_LIBRARY
NAMES blis
HINTS ${_BLIS_SEARCH_DIRS}
PATH_SUFFIXES lib lib64
)
# Hard requirement when the AMD MoE kernels are enabled: abort configure if
# BLIS is missing rather than failing later at link time.
if(NOT BLIS_INCLUDE_DIR OR NOT BLIS_LIBRARY)
message(FATAL_ERROR "BLIS not found; set BLIS_ROOT or specify BLIS_INCLUDE_DIR/BLIS_LIBRARY")
else()
message(STATUS "Found BLIS include at ${BLIS_INCLUDE_DIR}")
message(STATUS "Found BLIS library ${BLIS_LIBRARY}")
endif()
# NOTE(review): these target_* calls require the ${PROJECT_NAME} module target
# to already exist at this point in the file — confirm ordering relative to
# pybind11_add_module in the full CMakeLists.txt.
target_include_directories(${PROJECT_NAME} PRIVATE ${BLIS_INCLUDE_DIR})
target_link_libraries(${PROJECT_NAME} PRIVATE ${BLIS_LIBRARY})
endif()
# x86-only kernel selection: enable the AMX/AVX-512 code path and, optionally,
# Intel AMX tile instructions plus debug test executables.
if(HOST_IS_X86)
if(KTRANSFORMERS_CPU_USE_AMX_AVX512)
# Compile in the AMX/AVX-512 kernel variants.
add_compile_definitions(USE_AMX_AVX_KERNEL=1)
if(KTRANSFORMERS_CPU_USE_AMX)
add_compile_definitions(HAVE_AMX=1)
list(APPEND ARCH_FLAGS -mamx-tile -mamx-bf16 -mamx-int8)
message(STATUS "AMX enabled")
# NOTE(review): -mamx-tile is already appended two lines above; this second
# append is redundant (harmless duplicate flag) — possibly leftover from a
# merge; confirm against the upstream CMakeLists.txt and drop if so.
list(APPEND ARCH_FLAGS -mamx-tile)
endif()
# add_executable(amx-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/amx-test.cpp)
# target_link_libraries(amx-test llama)
# Debug builds: turn every operators/amx/test/*.cpp into its own executable.
if(KTRANSFORMERS_CPU_DEBUG)
file(GLOB AMX_TEST_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/test/*.cpp")
foreach(test_src ${AMX_TEST_SOURCES})
# Use the file name without extension as the target name.
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa)
endforeach()
endif()
# Baseline ISA flags required by the AVX-512 kernels (FMA, F16C, BF16, VNNI).
list(APPEND ARCH_FLAGS -mfma -mf16c -mavx512bf16 -mavx512vnni)
endif()
endif()
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
@@ -345,53 +375,48 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/pybind11 ${CMAKE_CURREN
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party)
if(WIN32)
include_directories("$ENV{CUDA_PATH}/include")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif(UNIX)
if(KTRANSFORMERS_USE_CUDA)
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
else()
message(FATAL_ERROR "KTRANSFORMERS_USE_CUDA=ON but CUDA compiler not found")
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif(KTRANSFORMERS_USE_ROCM)
find_package(HIP REQUIRED)
if(HIP_FOUND)
include_directories("${HIP_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
endif()
elseif(KTRANSFORMERS_USE_MUSA)
if(NOT EXISTS $ENV{MUSA_PATH})
if(NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa)
else()
set(MUSA_PATH /opt/musa)
endif()
else()
set(MUSA_PATH $ENV{MUSA_PATH})
endif()
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
find_package(MUSAToolkit)
if(MUSAToolkit_FOUND)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
elseif(KTRANSFORMERS_CPU_USE_KML)
message(STATUS "KML CPU detected")
if(KTRANSFORMERS_USE_CUDA)
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA detected")
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
else()
message(STATUS "No GPU support enabled, building for CPU only")
add_compile_definitions(KTRANSFORMERS_CPU_ONLY=1)
message(FATAL_ERROR "KTRANSFORMERS_USE_CUDA=ON but CUDA compiler not found")
endif()
message(STATUS "enabling CUDA")
enable_language(CUDA)
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif(KTRANSFORMERS_USE_ROCM)
find_package(HIP REQUIRED)
if(HIP_FOUND)
include_directories("${HIP_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
endif()
elseif(KTRANSFORMERS_USE_MUSA)
if(NOT EXISTS $ENV{MUSA_PATH})
if(NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa)
else()
set(MUSA_PATH /opt/musa)
endif()
else()
set(MUSA_PATH $ENV{MUSA_PATH})
endif()
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
find_package(MUSAToolkit)
if(MUSAToolkit_FOUND)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
elseif(KTRANSFORMERS_CPU_USE_KML)
message(STATUS "KML CPU detected")
else()
message(STATUS "No GPU support enabled, building for CPU only")
add_compile_definitions(KTRANSFORMERS_CPU_ONLY=1)
endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
@@ -404,14 +429,31 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
# arm64
if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kml SOURCE_DIR6)
if(NOT KTRANSFORMERS_CPU_MLA)
list(REMOVE_ITEM SOURCE_DIR6 ${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/mla/)
endif()
endif()
# message(STATUS "SOURCE_DIR6: ${SOURCE_DIR6}")
# aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR7)
if(KTRANSFORMERS_CPU_MOE_KERNEL)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/la SOURCE_DIR7)
if(KTRANSFORMERS_CPU_MOE_AMD)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/aocl_kernel SOURCE_DIR7_KERNEL)
add_compile_definitions(USE_MOE_KERNEL_AMD=1)
elseif(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel SOURCE_DIR7_KERNEL)
endif()
list(APPEND SOURCE_DIR7 ${SOURCE_DIR7_KERNEL})
if(NOT KTRANSFORMERS_CPU_MLA)
list(REMOVE_ITEM SOURCE_DIR7 ${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mla/)
endif()
add_compile_definitions(USE_MOE_KERNEL=1)
endif()
message(STATUS "SOURCE_DIR7: ${SOURCE_DIR7}")
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6} ${SOURCE_DIR7})
file(GLOB_RECURSE FMT_SOURCES
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
@@ -437,47 +479,47 @@ if(NOT DEFINED CLANG_FORMAT_BIN)
)
endif()
if(NOT CLANG_FORMAT_BIN)
message(FATAL_ERROR "clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
endif()
execute_process(
COMMAND ${CLANG_FORMAT_BIN} --version
OUTPUT_VARIABLE _CLANG_FORMAT_VER
OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "CMake PATH: $ENV{PATH}")
# Parse version string, e.g. "Ubuntu clang-format version 19.1.0" or "clang-format version 18.1.8"
string(REGEX MATCH "version[ ]+([0-9]+(\\.[0-9]+)*)" _CF_VER_MATCH "${_CLANG_FORMAT_VER}")
if(NOT _CF_VER_MATCH)
message(FATAL_ERROR "Failed to parse clang-format version from: ${_CLANG_FORMAT_VER}")
endif()
set(CLANG_FORMAT_VERSION "${CMAKE_MATCH_1}")
message(STATUS "Using clang-format ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}")
if(CLANG_FORMAT_VERSION VERSION_LESS "18.0.0")
message(FATAL_ERROR "clang-format >=18.0.0 required (found ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}).\n"
"Tip: Ensure your desired clang-format (e.g., conda's ${CONDA_PREFIX}/bin/clang-format) is earlier in PATH when running CMake,\n"
"or pass -DCLANG_FORMAT_BIN=/full/path/to/clang-format.")
endif()
message(WARNING "clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
else()
execute_process(
COMMAND ${CLANG_FORMAT_BIN} --version
OUTPUT_VARIABLE _CLANG_FORMAT_VER
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "CMake PATH: $ENV{PATH}")
# Parse version string, e.g. "Ubuntu clang-format version 19.1.0" or "clang-format version 18.1.8"
string(REGEX MATCH "version[ ]+([0-9]+(\\.[0-9]+)*)" _CF_VER_MATCH "${_CLANG_FORMAT_VER}")
if(NOT _CF_VER_MATCH)
message(WARNING "Failed to parse clang-format version from: ${_CLANG_FORMAT_VER}")
endif()
set(CLANG_FORMAT_VERSION "${CMAKE_MATCH_1}")
message(STATUS "Using clang-format ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}")
if(CLANG_FORMAT_VERSION VERSION_LESS "18.0.0")
message(WARNING "clang-format >=18.0.0 required (found ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}).\n"
"Tip: Ensure your desired clang-format (e.g., conda's ${CONDA_PREFIX}/bin/clang-format) is earlier in PATH when running CMake,\n"
"or pass -DCLANG_FORMAT_BIN=/full/path/to/clang-format.")
endif()
add_custom_target(
format
COMMAND ${CLANG_FORMAT_BIN}
-i
-style=file
-fallback-style=none
${FMT_SOURCES}
COMMENT "Running clang-format on all source files"
)
add_custom_target(
format
COMMAND ${CLANG_FORMAT_BIN}
-i
-style=file
-fallback-style=none
${FMT_SOURCES}
COMMENT "Running clang-format on all source files"
)
# Optional: target to check formatting without modifying files (CI-friendly)
add_custom_target(
format-check
COMMAND ${CLANG_FORMAT_BIN}
-n --Werror
-style=file
-fallback-style=none
${FMT_SOURCES}
COMMENT "Checking clang-format on all source files"
)
# Optional: target to check formatting without modifying files (CI-friendly)
add_custom_target(
format-check
COMMAND ${CLANG_FORMAT_BIN}
-n --Werror
-style=file
-fallback-style=none
${FMT_SOURCES}
COMMENT "Checking clang-format on all source files"
)
endif()
include(FindPkgConfig)
if(PKG_CONFIG_FOUND)
@@ -489,8 +531,6 @@ endif(PKG_CONFIG_FOUND)
add_library(llamafile STATIC ${SOURCE_DIR4})
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}")
if(CPUINFER_ENABLE_LTO)
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON)
@@ -502,6 +542,7 @@ else()
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
message(STATUS "LTO: disabled")
endif()
# Ensure the module target also has correct RPATH when conda is active
if(TARGET ${PROJECT_NAME} AND DEFINED ENV{CONDA_PREFIX} AND EXISTS "$ENV{CONDA_PREFIX}")
set_target_properties(${PROJECT_NAME} PROPERTIES
@@ -513,44 +554,23 @@ endif()
if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
message(STATUS "KML CPU detected")
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/prefillgemm)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/prefillgemm)
target_link_libraries(${PROJECT_NAME} PRIVATE prefillint8gemm)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/prefillgemm_int4)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/prefillgemm_int4)
target_link_libraries(${PROJECT_NAME} PRIVATE prefillint4gemm)
set(DECODE_GEMM_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/la/batch_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/la/batch_gemm_kernels.cpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/batch_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/batch_gemm_kernels.cpp
)
add_library(decode_gemm SHARED ${DECODE_GEMM_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE decode_gemm)
target_link_libraries(${PROJECT_NAME} PRIVATE kml_rt)
if(KTRANSFORMERS_CPU_MLA)
target_link_libraries(${PROJECT_NAME} PRIVATE kml_rt)
endif()
target_compile_definitions(${PROJECT_NAME} PRIVATE CPU_USE_KML)
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE llama PkgConfig::HWLOC OpenMP::OpenMP_CXX)
if(HOST_IS_X86)
if(KTRANSFORMERS_CPU_USE_AMX_AVX512)
if(KTRANSFORMERS_CPU_USE_AMX)
add_compile_definitions(HAVE_AMX=1)
message(STATUS "AMX enabled")
endif()
# add_executable(amx-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/amx-test.cpp)
# target_link_libraries(amx-test llama)
if(KTRANSFORMERS_CPU_DEBUG)
file(GLOB AMX_TEST_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/test/*.cpp")
foreach(test_src ${AMX_TEST_SOURCES})
# 获取不带扩展名的文件名作为 target 名
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa)
endforeach()
endif()
endif()
endif()
if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
if(KTRANSFORMERS_CPU_DEBUG)
# add_executable(convert-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/convert-test.cpp)
@@ -560,27 +580,27 @@ if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
# 获取不带扩展名的文件名作为 target 名
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa kml_rt)
if(KTRANSFORMERS_CPU_MLA)
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa kml_rt)
endif()
endforeach()
endif()
endif()
if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib") #CUDA::cudart
elseif(UNIX)
if(NOT KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
if(KTRANSFORMERS_USE_ROCM)
add_compile_definitions(USE_HIP=1)
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
message(STATUS "Building for HIP")
endif()
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
if(KTRANSFORMERS_USE_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
endif()
if(KTRANSFORMERS_USE_ROCM)
add_compile_definitions(USE_HIP=1)
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
message(STATUS "Building for HIP")
endif()
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
find_library(NUMA_LIBRARY NAMES numa)
@@ -590,3 +610,7 @@ if(NUMA_LIBRARY)
else()
message(FATAL_ERROR "NUMA library not found, please install NUMA, sudo apt install libnuma-dev")
endif()
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}")

View File

@@ -0,0 +1,47 @@
{
"version": 3,
"cmakeMinimumRequired": {
"major": 3,
"minor": 19,
"patch": 0
},
"configurePresets": [
{
"name": "avx512",
"displayName": "avx512_platform",
"description": "for avx512 platform",
"cacheVariables": {
"KTRANSFORMERS_CPU_USE_AMX": "OFF",
"LLAMA_AVX512": "OFF",
"LLAMA_AVX2": "OFF",
"KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON",
"KTRANSFORMERS_USE_CUDA": "ON"
}
},
{
"name": "avx",
"displayName": "avx_platform",
"description": "for avx platform",
"cacheVariables": {
"KTRANSFORMERS_CPU_USE_AMX": "OFF",
"LLAMA_AVX2": "ON",
"KTRANSFORMERS_USE_CUDA": "ON"
}
},
{
"name": "amx",
"displayName": "amx_platform",
"description": "for amx platform",
"cacheVariables": {
"KTRANSFORMERS_CPU_USE_AMX": "ON",
"LLAMA_AVX512": "OFF",
"LLAMA_AVX2": "OFF",
"KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON",
"KTRANSFORMERS_USE_CUDA": "ON"
}
}
]
}

View File

@@ -105,59 +105,42 @@ python -c "from kt_kernel import AMXMoEWrapper; print('✓ kt-kernel installed s
## Weight Quantization
KT-Kernel provides a weight conversion tool to quantize model weights from FP8/FP16/BF16 to INT4/INT8 format optimized for AMX inference.
KT-Kernel provides weight quantization tools for CPU-GPU hybrid inference (e.g., integrating with SGLang). Both tools work together to enable heterogeneous expert placement across CPUs and GPUs.
### Quantization Methods
### CPU Weights (for "cold" experts on CPU)
- **INT4**: 4-bit quantization for maximum memory efficiency
- **INT8**: 8-bit quantization for better accuracy
### Supported Input Formats
- **FP8**: 8-bit floating point with automatic dequantization
- **FP16**: 16-bit floating point
- **BF16**: BFloat16 format
### Basic Usage
Quantize weights to INT4/INT8 format optimized for AMX inference:
```bash
# Quantize BF16 model to INT4
python scripts/convert_weights.py \
--input-path /path/to/bf16/model \
python scripts/convert_cpu_weights.py \
--input-path /path/to/model \
--input-type bf16 \
--output /path/to/output \
--quant-method int4
# Quantize FP16 model to INT8
python scripts/convert_weights.py \
--input-path /path/to/fp16/model \
--input-type fp16 \
--output /path/to/output \
--quant-method int8
# Quantize FP8 model to INT4
python scripts/convert_weights.py \
--input-path /path/to/fp8/model \
--input-type fp8 \
--output /path/to/output \
--quant-method int4
```
### Output Format
**Supported formats:** FP8, FP16, BF16 → INT4/INT8
The converted weights are saved in SafeTensors format with NUMA-aware layout:
```
output_dir/
├── model-00001-of-00050.safetensors
├── model-00002-of-00050.safetensors
├── ...
├── config.json
└── tokenizer files...
### GPU Weights (for "hot" experts on GPU)
Apply GPTQ quantization to model weights:
```bash
# Install additional dependencies first
pip install accelerate transformers llmcompressor datasets
# Quantize GPU weights
python scripts/convert_gpu_weights.py \
--model_id /path/to/model \
--output_dir /path/to/output \
--quant_type W4A16
```
Each expert's weights are split across NUMA nodes for optimal memory access:
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.weight`: Quantized weights
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.scale`: Quantization scales
**Supported types:** W4A16 (GPTQ4), W8A16 (GPTQ8)
---
For detailed documentation, advanced options, and low-memory mode, see [scripts/README.md](scripts/README.md).
## Before Commit!
your msg should match: Conventional Commits (https://www.conventionalcommits.org/) <br>and format your code before commit:

23
kt-kernel/bench/Makefile Normal file
View File

@@ -0,0 +1,23 @@
# Benchmark driver for the MoE kernel tiling sweep.
# Runs bench_moe_kernel_tiling.py with a DeepSeek-V3-sized MoE configuration
# (hidden 7168, 256 experts, top-8 routing) at int8 quantization.
#
# Usage: make kernel_tiling

# kernel_tiling is a command, not a file: mark it phony so a stray file named
# "kernel_tiling" can never mask the recipe.
.PHONY: kernel_tiling

# test bench_moe_kernel_tiling.py
kernel_tiling:
	python3 bench_moe_kernel_tiling.py \
		--hidden_size 7168 \
		--intermediate_size 2048 \
		--num_experts_per_tok 8 \
		--expert_num 256 \
		--max_len 51200 \
		--layer_num 1 \
		--qlen 1024 \
		--quant int8 \
		--warm_up_iter 500 \
		--test_iter 1000 \
		--threads 160 \
		--m_block 320

# Optional tiling overrides — append to the command above to use them.
# (Kept outside the recipe: the original left a trailing backslash on the last
# argument, which spliced these comments into the shell command line.)
#   --n_block_up_gate 256
#   --n_block_down 128          # alternative: 512
#   --n_block_up_gate_prefi 256
#   --n_block_down_prefi 128

View File

@@ -13,7 +13,7 @@ import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import kt_kernel_ext
import torch
layer_num = 10
@@ -23,16 +23,16 @@ head_dim = 128
block_len = 128
anchor_num = 1
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
anchor_type = kt_kernel_ext.kvcache.AnchorType.DYNAMIC
kv_type = kt_kernel_ext.kvcache.ggml_type.FP16
retrieval_type = kt_kernel_ext.kvcache.RetrievalType.LAYER
layer_step: int = 1
token_step: int = 1
layer_offset: int = 0
max_thread_num: int = 64
max_batch_size: int = 1
max_block_num: int = 1024
CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
CPUInfer = kt_kernel_ext.CPUInfer(max_thread_num)
warm_up_iter = 1000
test_iter = 10000
@@ -43,7 +43,7 @@ def bench_linear(cache_seqlen: int):
cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
config = cpuinfer_ext.kvcache.KVCacheConfig(
config = kt_kernel_ext.kvcache.KVCacheConfig(
layer_num,
kv_head_num,
q_head_num,
@@ -60,7 +60,7 @@ def bench_linear(cache_seqlen: int):
max_batch_size,
max_thread_num,
)
local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
local_kvcache = kt_kernel_ext.kvcache.KVCache(config)
block_table = (
torch.arange(max_block_num, dtype=torch.int32, device="cpu")
.contiguous()

View File

@@ -13,7 +13,7 @@ import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import kt_kernel_ext
import torch
layer_num = 10

View File

@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import kt_kernel_ext
import torch
input_size = 16384
@@ -21,7 +21,7 @@ stride = 16
group_max_len = 1024
layer_num = 10
qlen = 1
CPUInfer = cpuinfer_ext.CPUInfer(64)
CPUInfer = kt_kernel_ext.CPUInfer(64)
warm_up_iter = 1000
test_iter = 10000
@@ -69,8 +69,8 @@ def bench_linear(quant_mode: str):
projs = []
for _ in range(layer_num):
proj = torch.randn((output_size, input_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
linear = cpuinfer_ext.linear.Linear(config)
config = kt_kernel_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
linear = kt_kernel_ext.linear.Linear(config)
projs.append(proj)
linears.append(linear)
input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()

View File

@@ -5,8 +5,8 @@ import platform
import json
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import inf, nn
from torch.nn import init
@@ -47,7 +47,7 @@ rope_scaling = {
CPUINFER_PARAM = 304
# 初始化 CPUInfer此处使用原始构造函数可根据需要调整配置参数
CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
warm_up_iter = 20
@@ -200,7 +200,7 @@ def bench_mla(quant_mode: str):
kv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to('cpu').contiguous()
o_proj_weight = o_proj.weight.to(torch.float16).to('cpu').contiguous()
config = cpuinfer_ext.mla.MLAConfig(
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
@@ -236,7 +236,7 @@ def bench_mla(quant_mode: str):
mla = cpuinfer_ext.mla.MLA(config)
mla = kt_kernel_ext.mla.MLA(config)
mla.load_weights()
mla.set_local_pages(pages_count)
mlas.append(mla)

View File

@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import kt_kernel_ext
import torch
hidden_size = 5120
@@ -21,7 +21,7 @@ stride = 16
group_max_len = 1024
layer_num = 10
qlen = 1
CPUInfer = cpuinfer_ext.CPUInfer(64)
CPUInfer = kt_kernel_ext.CPUInfer(64)
warm_up_iter = 1000
test_iter = 10000
@@ -96,8 +96,8 @@ def bench_mlp(quant_mode: str):
gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
mlp = cpuinfer_ext.mlp.MLP(config)
config = kt_kernel_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
mlp = kt_kernel_ext.mlp.MLP(config)
gate_projs.append(gate_proj)
up_projs.append(up_proj)
down_projs.append(down_proj)

View File

@@ -6,7 +6,7 @@ import subprocess
import platform
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
import cpuinfer_ext
import kt_kernel_ext
import torch
from tqdm import tqdm
@@ -29,7 +29,7 @@ warm_up_iter = 100
test_iter = 10000
CPUINFER_PARAM = 304
# 初始化 CPUInfer此处使用原始构造函数可根据需要调整配置参数
CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
# 获取脚本相关信息,用于生成结果保存文件名
script_path = os.path.abspath(__file__)
@@ -198,7 +198,7 @@ def bench_moe(quant_mode: str):
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device="cpu").to("cpu").contiguous()
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float16, device="cpu").to("cpu").contiguous()
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config.pool = CPUInfer.backend_
config.m_block = m_block
config.group_min_len = group_min_len
@@ -211,7 +211,7 @@ def bench_moe(quant_mode: str):
config.down_type = down_type
config.hidden_type = hidden_type
moe = cpuinfer_ext.moe.MOE(config)
moe = kt_kernel_ext.moe.MOE(config)
CPUInfer.submit(moe.load_weights_task())
CPUInfer.sync()
moes.append(moe)

View File

@@ -1,50 +1,46 @@
#!/usr/bin/env python
# coding=utf-8
'''
Description :
"""
Description :
Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditors : chenht2022
LastEditTime : 2024-08-06 10:41:28
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys, time, json, subprocess, platform
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
import cpuinfer_ext
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch
import kt_kernel_ext
import numpy as np
# 测试参数设置
expert_num = 256
expert_num = 16
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
max_len = 25600
num_experts_per_tok = 8
layer_num = 4
# qlen = 1024
qlen = 1
warm_up_iter = 500
test_iter = 1000
physical_to_logical_map = torch.tensor(
data=range(expert_num),
device="cpu",
dtype=torch.int64
).contiguous()
layer_num = 2
qlen = 2048
warm_up_iter = 1000
test_iter = 2000
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
# 将 CPUInfer 参数设为变量
# CPUINFER_PARAM = 257
# CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
# CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
worker_config = cpuinfer_ext.WorkerPoolConfig()
worker_config = kt_kernel_ext.WorkerPoolConfig()
worker_config.subpool_count = 2
worker_config.subpool_numa_map= [0,1]
worker_config.subpool_thread_count = [80,80]
CPUINFER_PARAM = 40
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
worker_config.subpool_numa_map = [0, 1]
worker_config.subpool_thread_count = [80, 80]
CPUINFER_PARAM = 160
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
def get_git_commit():
@@ -81,14 +77,14 @@ def get_system_info():
info = {}
# 系统名称及主机名
uname = platform.uname()
info["system_name"] = uname.system # 如 Linux, Windows 等
info["node_name"] = uname.node # 主机名称
info["system_name"] = uname.system # 如 Linux, Windows 等
info["node_name"] = uname.node # 主机名称
# 获取 CPU 型号(仅 Linux 支持)
cpu_model = None
if os.path.exists('/proc/cpuinfo'):
if os.path.exists("/proc/cpuinfo"):
try:
with open('/proc/cpuinfo', 'r') as f:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "model name" in line:
cpu_model = line.split(":", 1)[1].strip()
@@ -99,9 +95,9 @@ def get_system_info():
# 获取内存大小单位GB仅 Linux 支持
mem_total_gb = None
if os.path.exists('/proc/meminfo'):
if os.path.exists("/proc/meminfo"):
try:
with open('/proc/meminfo', 'r') as f:
with open("/proc/meminfo", "r") as f:
for line in f:
if "MemTotal" in line:
mem_kb = float(line.split(":", 1)[1].split()[0])
@@ -129,11 +125,13 @@ def get_system_info():
return info
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")
def record_results(result, filename=json_path):
"""
将结果以 JSON 格式追加到文件中
@@ -141,6 +139,7 @@ def record_results(result, filename=json_path):
with open(filename, "a") as f:
f.write(json.dumps(result) + "\n")
def bench_moe(quant_mode: str):
with torch.inference_mode():
if quant_mode == "bf16":
@@ -157,33 +156,56 @@ def bench_moe(quant_mode: str):
up_projs = []
down_projs = []
for layer_index in range(layer_num):
gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
config = cpuinfer_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size,0)
gate_proj = (
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
.to("cpu")
.contiguous()
)
up_proj = (
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
.to("cpu")
.contiguous()
)
down_proj = (
torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda")
.to("cpu")
.contiguous()
)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.pool = CPUInfer.backend_
if quant_mode == "bf16":
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
elif quant_mode == "int8":
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
elif quant_mode == "int4":
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
CPUInfer.submit(moe.load_weights_task())
CPUInfer.sync()
gate_projs.append(gate_proj)
up_projs.append(up_proj)
down_projs.append(down_proj)
moes.append(moe)
gen_iter = 3000
expert_ids = torch.rand(gen_iter * qlen , expert_num, device="cpu").argsort(dim=-1)[:, :num_experts_per_tok].reshape(gen_iter, qlen * num_experts_per_tok).to("cpu").contiguous()
weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").to("cpu").contiguous()
input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
expert_ids = (
torch.rand(gen_iter * qlen, expert_num, device="cpu")
.argsort(dim=-1)[:, :num_experts_per_tok]
.reshape(gen_iter, qlen * num_experts_per_tok)
.to("cpu")
.contiguous()
)
weights = (
torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").to("cpu").contiguous()
)
input_tensor = (
torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
)
output_tensor = (
torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
)
bsz_tensor = torch.tensor([qlen], device="cpu")
# 预热迭代
@@ -193,8 +215,8 @@ def bench_moe(quant_mode: str):
moes[i % layer_num].forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i%gen_iter].data_ptr(),
weights[i%gen_iter].data_ptr(),
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor[i % layer_num].data_ptr(),
output_tensor[i % layer_num].data_ptr(),
False,
@@ -213,8 +235,8 @@ def bench_moe(quant_mode: str):
moes[i % layer_num].forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i%gen_iter].data_ptr(),
weights[i%gen_iter].data_ptr(),
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor[i % layer_num].data_ptr(),
output_tensor[i % layer_num].data_ptr(),
False,
@@ -228,16 +250,28 @@ def bench_moe(quant_mode: str):
# 计算性能指标
time_per_iter_us = total_time / test_iter * 1e6
bandwidth = hidden_size * intermediate_size * 3 * num_experts_per_tok * (1/8 * 256 * (1-(31/32)**qlen)) * bytes_per_elem * test_iter / total_time / 1e9 # 单位GB/s
flops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12 # 单位TFLOPS
bandwidth = (
hidden_size
* intermediate_size
* 3
* num_experts_per_tok
* (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
* bytes_per_elem
* test_iter
/ total_time
/ 1e9
) # 单位GB/s
flops = (
hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12
) # 单位TFLOPS
print('Quant mode: ', quant_mode)
print('Time(s): ', total_time)
print('Iteration: ', test_iter)
print('Time(us) per iteration: ', time_per_iter_us)
print('Bandwidth: ', bandwidth, 'GB/s')
print('Flops: ', flops, 'TFLOPS')
print('')
print("Quant mode: ", quant_mode)
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", time_per_iter_us)
print("Bandwidth: ", bandwidth, "GB/s")
print("Flops: ", flops, "TFLOPS")
print("")
# 整理结果记录,包括测试参数
result = {
@@ -258,8 +292,8 @@ def bench_moe(quant_mode: str):
"qlen": qlen,
"warm_up_iter": warm_up_iter,
"test_iter": test_iter,
"CPUInfer_parameter": CPUINFER_PARAM
}
"CPUInfer_parameter": CPUINFER_PARAM,
},
}
# 添加 git 提交记录信息
result.update(get_git_commit())
@@ -268,8 +302,9 @@ def bench_moe(quant_mode: str):
# 将结果以 JSON 形式追加到文件中
record_results(result)
if __name__ == "__main__":
# 选择需要测试的量化模式
# bench_moe("bf16")
# bench_moe("int8")
bench_moe("int4")
bench_moe("int8")
# bench_moe("int4")

View File

@@ -13,7 +13,7 @@ import os, sys, time, json, subprocess, platform
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
import cpuinfer_ext
import kt_kernel_ext
import torch
import numpy as np
@@ -37,14 +37,14 @@ physical_to_logical_map = torch.tensor(
).contiguous()
# 将 CPUInfer 参数设为变量
# CPUINFER_PARAM = 257
# CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
# CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
worker_config = cpuinfer_ext.WorkerPoolConfig()
worker_config = kt_kernel_ext.WorkerPoolConfig()
worker_config.subpool_count = 2
worker_config.subpool_numa_map= [0,1]
worker_config.subpool_thread_count = [40,40]
CPUINFER_PARAM = 80
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
@@ -163,7 +163,7 @@ def bench_moe(quant_mode: str):
gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size,0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
@@ -171,17 +171,17 @@ def bench_moe(quant_mode: str):
config.down_proj = down_proj.data_ptr()
config.pool = CPUInfer.backend_
if quant_mode == "bf16":
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
elif quant_mode == "int8":
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
elif quant_mode == "int4":
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
elif quant_mode == "int4_1k":
config.quant_config.bits = 4
config.quant_config.group_size = k_group_size
config.quant_config.zero_point = True
config.gate_scale = 0
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
gate_projs.append(gate_proj)

View File

@@ -0,0 +1,331 @@
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-08-06 10:41:28
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys, time, json, subprocess, platform
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch
import kt_kernel_ext
import numpy as np
from tqdm import tqdm
# Benchmark parameters (MoE shape, kernel tiling blocks, iteration counts).
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
max_len = 51200
num_experts_per_tok = 8
layer_num = 1
# Kernel tiling block sizes (decode and prefill variants); fed to
# kt_kernel_ext.moe.tiling.set_int8 inside bench_moe.
m_block = 320
n_block_up_gate = 32
n_block_down = 64
n_block_up_gate_prefi = 32
n_block_down_prefi = 64
qlen = 2048
warm_up_iter = 1000
test_iter = 1000
# CPUInfer thread-count kept in a variable so it can be recorded in results.
CPUINFER_PARAM = 160
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
# Identity expert mapping: physical expert i == logical expert i.
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
# Alternative multi-NUMA worker-pool configuration, kept for reference:
# worker_config = kt_kernel_ext.WorkerPoolConfig()
# worker_config.subpool_count = 4
# worker_config.subpool_numa_map= [0,1,2,3]
# worker_config.subpool_thread_count = [36,36,36,36]
# worker_config.subpool_thread_count = [39,39,39,39]
# CPUINFER_PARAM = 156
# CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
def get_git_commit():
    """Collect the current git state: commit hash, commit message, and
    whether the working tree has uncommitted changes (dirty).

    Returns a dict with keys ``commit``, ``commit_message``, ``dirty`` and,
    when dirty, ``dirty_files``. On any failure (e.g. not inside a git
    repository) the three main values are None and an ``error`` string is
    included instead.
    """

    def run(*argv):
        # Run a command and return its stripped stdout decoded as UTF-8.
        return subprocess.check_output(list(argv)).decode("utf-8").strip()

    info = {}
    try:
        info["commit"] = run("git", "rev-parse", "HEAD")
        info["commit_message"] = run("git", "log", "-1", "--pretty=%B")
        # Non-empty porcelain status means there are uncommitted changes.
        status = run("git", "status", "--porcelain")
        info["dirty"] = bool(status)
        if status:
            info["dirty_files"] = status.splitlines()
    except Exception as exc:
        info["commit"] = None
        info["commit_message"] = None
        info["dirty"] = None
        info["error"] = str(exc)
    return info
def get_system_info():
    """Gather host information: system name, hostname, CPU model, total
    memory (GB), logical CPU count, and CPU socket count.

    CPU model / memory / socket details are read from /proc, so they are
    only available on Linux; unavailable values are reported as None, and
    read failures are reported as an "Error: ..." string.
    """

    def first_field(path, marker):
        # Value after ':' on the first line of `path` containing `marker`;
        # None when the file is missing or no line matches, an error string
        # when the file could not be read.
        if not os.path.exists(path):
            return None
        try:
            with open(path, "r") as fh:
                for row in fh:
                    if marker in row:
                        return row.split(":", 1)[1].strip()
        except Exception as exc:
            return f"Error: {exc}"
        return None

    uname = platform.uname()
    info = {}
    info["system_name"] = uname.system  # e.g. Linux, Windows
    info["node_name"] = uname.node  # hostname

    info["cpu_model"] = first_field("/proc/cpuinfo", "model name")

    # /proc/meminfo reports MemTotal in kB; convert to GB, 2 decimals.
    mem_value = first_field("/proc/meminfo", "MemTotal")
    if isinstance(mem_value, str) and not mem_value.startswith("Error:"):
        try:
            mem_value = round(float(mem_value.split()[0]) / (1024 * 1024), 2)
        except Exception as exc:
            mem_value = f"Error: {exc}"
    info["memory_size_GB"] = mem_value

    info["cpu_core_count"] = os.cpu_count()  # logical core count

    # Distinct "physical id" values in /proc/cpuinfo == socket count;
    # fall back to 1 when that information is unavailable.
    sockets = set()
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as fh:
                sockets = {row.split(":", 1)[1].strip() for row in fh if "physical id" in row}
        except Exception:
            sockets = set()
    info["cpu_socket_count"] = len(sockets) if sockets else 1
    return info
# Derive the JSONL results path from this script's location.
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
# Fix: was '"bench_results " + ".jsonl"', whose stray trailing space produced
# the odd file name "bench_results .jsonl". Write "bench_results.jsonl".
json_path = os.path.join(script_dir, "bench_results.jsonl")
def record_results(result, filename=json_path):
    """Append `result` as one JSON line (JSONL format) to `filename`.

    `filename` defaults to the module-level `json_path` next to this script.
    """
    encoded = json.dumps(result)
    with open(filename, "a") as sink:
        sink.write(encoded + "\n")
def bench_moe(quant_mode: str):
    """Benchmark the kt_kernel_ext MoE kernel for one quantization mode.

    quant_mode: "int8" or "int4" (selects Int8_KERNEL_MOE / Int4_KERNEL_MOE).
    Builds `layer_num` MoE layers with random weights, runs `warm_up_iter`
    untimed and `test_iter` timed forward passes of `qlen` tokens each, then
    prints bandwidth/FLOPS estimates and appends a JSON record via
    record_results(). Uses the module-level benchmark parameters and the
    global CPUInfer worker pool.
    """
    with torch.inference_mode():
        # Bytes per weight element, used only for the bandwidth estimate.
        if quant_mode == "int8":
            bytes_per_elem = 1.0
        elif quant_mode == "int4":
            bytes_per_elem = 0.5
        else:
            raise ValueError("不支持的量化模式")
        moes = []
        gate_projs = []
        up_projs = []
        down_projs = []
        for layer_index in range(layer_num):
            # Random fp32 expert weights, generated on GPU then moved to host
            # memory; the CPU kernel reads them through raw data pointers.
            gate_proj = (
                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
                .to("cpu")
                .contiguous()
            )
            up_proj = (
                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
                .to("cpu")
                .contiguous()
            )
            down_proj = (
                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda")
                .to("cpu")
                .contiguous()
            )
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.gate_proj = gate_proj.data_ptr()
            config.up_proj = up_proj.data_ptr()
            config.down_proj = down_proj.data_ptr()
            config.pool = CPUInfer.backend_
            if quant_mode == "int8":
                # Apply the module-level tiling overrides, keeping the k/n
                # block sizes from the library defaults.
                d = kt_kernel_ext.moe.tiling.get_int8()
                nbug_prefi = n_block_up_gate_prefi
                nbd_prefi = n_block_down_prefi
                kb = d["k_block"]
                nb = d["n_block"]
                mb = m_block
                nbug = n_block_up_gate
                nbd = n_block_down
                print(
                    f"Int8 Tiling: nbug {nbug}, nbd {nbd}, nb {nb}, mb {mb}, kb {kb}, nbug_prefi {nbug_prefi}, nbd_prefi {nbd_prefi}"
                )
                kt_kernel_ext.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
                moe = kt_kernel_ext.moe.Int8_KERNEL_MOE(config)
            elif quant_mode == "int4":
                moe = kt_kernel_ext.moe.Int4_KERNEL_MOE(config)
            else:
                # Unreachable: quant_mode was already validated above.
                raise ValueError(f"Unsupported quantization mode: {quant_mode}")
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            # Keep host tensors alive: the kernel holds raw pointers to them.
            gate_projs.append(gate_proj)
            up_projs.append(up_proj)
            down_projs.append(down_proj)
            moes.append(moe)
        # Pre-generate one routing (top-k expert ids + weights) per test
        # iteration. NOTE(review): the warm-up loop below also indexes these
        # with `i`, so it assumes warm_up_iter <= test_iter — this holds for
        # the current 1000/1000 settings; confirm if parameters change.
        expert_ids = (
            torch.rand(test_iter * qlen, expert_num, device="cuda")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(test_iter, qlen * num_experts_per_tok)
            .to("cpu")
            .contiguous()
        )
        weights = (
            torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cuda")
            .to("cpu")
            .contiguous()
        )
        input_tensor = (
            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
        )
        # Output buffer starts uninitialized; its contents are never checked.
        output_tensor = (
            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
        )
        bsz_tensor = torch.tensor([qlen], device="cuda").to("cpu").contiguous()
        # Warm-up iterations (not timed).
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            # print(f'warmup iteration {i}')
            # start_it = time.time_ns()
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i].data_ptr(),
                    weights[i].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    # False,
                )
            )
            CPUInfer.sync()
            # end_it = time.time_ns()
            # print('python Time(ns): ', end_it - start_it)
        # Timed iterations.
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            # print(f'test iteration {i}')
            # start_it = time.time_ns()
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i].data_ptr(),
                    weights[i].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
            # end_it = time.time_ns()
            # print('python Time(ns): ', end_it - start_it)
        end = time.perf_counter()
        total_time = end - start
        # Derive performance metrics.
        time_per_iter_us = total_time / test_iter * 1e6
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            # * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
            * qlen
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )  # unit: GB/s
        flops = (
            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12
        )  # unit: TFLOPS
        print("Quant mode: ", quant_mode)
        print("Time(s): ", total_time)
        print("Iteration: ", test_iter)
        print("Time(us) per iteration: ", time_per_iter_us)
        print("Bandwidth: ", bandwidth, "GB/s")
        print("Flops: ", flops, "TFLOPS")
        print("")
        # Assemble the result record, including all test parameters.
        result = {
            "test_name": os.path.basename(__file__),
            "quant_mode": quant_mode,
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": flops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "max_len": max_len,
                "num_experts_per_tok": num_experts_per_tok,
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
            },
        }
        # Attach git commit info.
        result.update(get_git_commit())
        # Attach system info (CPU cores, sockets, memory, ...).
        result.update(get_system_info())
        # Append the record as a JSON line.
        record_results(result)
if __name__ == "__main__":
    # Pick the quantization mode(s) to benchmark.
    bench_moe("int8")
    # bench_moe("int4")

View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python
# coding=utf-8
"""
Bench MOE kernel with runtime tiling params (N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK)
- Demonstrates how to get/set tiling params from Python via kt_kernel_ext.moe.tiling
- Runs a small benchmark similar to bench_moe_kernel.py
Usage examples:
# 1) Just run with defaults (int8)
python bench_moe_kernel_tiling.py --quant int8
# 2) Override tiling params for INT8
python bench_moe_kernel_tiling.py --quant int8 \
--n_block_up_gate 32 --n_block_down 64 --n_block 64 --m_block 320 --k_block 7168
# 3) Set both INT8 and INT4 tiling params (if INT4 kernel is available on your platform)
python bench_moe_kernel_tiling.py --quant int4 --set_all \
--n_block_up_gate 256 --n_block_down 1024 --n_block 64 --m_block 320 --k_block 7168
"""
import os
import sys
import time
import argparse
os.environ.setdefault("BLAS_NUM_THREADS", "1")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch # noqa: E402
import kt_kernel_ext as ce # noqa: E402
from tqdm import tqdm # noqa: E402
def maybe_get_class(module, name):
    """Return attribute `name` of `module`, or None when it does not exist.

    Used to probe optional kernel bindings (e.g. Int4 kernels that only
    exist on some platforms) without raising AttributeError.
    """
    # getattr with a default replaces the hasattr()+getattr() two-step.
    return getattr(module, name, None)
def main():
    """CLI driver: parse arguments, optionally override the MoE kernel tiling
    parameters via kt_kernel_ext.moe.tiling, then run a warm-up + timed
    benchmark of the selected (int8/int4) kernel and print throughput metrics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant", choices=["int8", "int4"], default="int8")
    parser.add_argument("--expert_num", type=int, default=256)
    parser.add_argument("--hidden_size", type=int, default=7168)
    parser.add_argument("--intermediate_size", type=int, default=2048)
    parser.add_argument("--num_experts_per_tok", type=int, default=8)
    parser.add_argument("--max_len", type=int, default=25600)
    parser.add_argument("--layer_num", type=int, default=1)
    parser.add_argument("--qlen", type=int, default=1024)
    parser.add_argument("--warm_up_iter", type=int, default=200)
    parser.add_argument("--test_iter", type=int, default=500)
    parser.add_argument("--threads", type=int, default=160, help="CPUInfer initialization param")
    # Tiling params (None = keep the library's current default for that field)
    parser.add_argument("--set_all", action="store_true", help="Apply tiling to both INT8 and INT4 kernels")
    parser.add_argument("--n_block_up_gate", type=int, default=None)
    parser.add_argument("--n_block_down", type=int, default=None)
    parser.add_argument("--n_block", type=int, default=None)
    parser.add_argument("--m_block", type=int, default=None)
    parser.add_argument("--k_block", type=int, default=None)
    parser.add_argument("--n_block_up_gate_prefi", type=int, default=None)
    parser.add_argument("--n_block_down_prefi", type=int, default=None)
    args = parser.parse_args()
    # Show current tiling defaults
    if args.quant == "int8":
        print("[tiling] default int8:", ce.moe.tiling.get_int8())
    if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
        print("[tiling] default int4:", ce.moe.tiling.get_int4())
    # Apply overrides if provided.
    # NOTE(review): only the five non-prefi flags gate this branch, so passing
    # a *_prefi flag alone silently applies nothing — confirm this is intended.
    if any(v is not None for v in [args.n_block_up_gate, args.n_block_down, args.n_block, args.m_block, args.k_block]):
        # Fill missing values with current defaults to avoid overwriting unrelated params
        def fill_defaults(getter):
            cur = getter()
            return (
                args.n_block_up_gate if args.n_block_up_gate is not None else int(cur["n_block_up_gate"]),
                args.n_block_down if args.n_block_down is not None else int(cur["n_block_down"]),
                args.n_block if args.n_block is not None else int(cur["n_block"]),
                args.m_block if args.m_block is not None else int(cur["m_block"]),
                args.k_block if args.k_block is not None else int(cur["k_block"]),
                (
                    args.n_block_up_gate_prefi
                    if args.n_block_up_gate_prefi is not None
                    else int(cur["n_block_up_gate_prefi"])
                ),
                args.n_block_down_prefi if args.n_block_down_prefi is not None else int(cur["n_block_down_prefi"]),
            )
        if args.set_all and hasattr(ce.moe.tiling, "set_all"):
            nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
            ce.moe.tiling.set_all(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
            print("[tiling] set_all ->", ce.moe.tiling.get_int8())
            if hasattr(ce.moe.tiling, "get_int4"):
                print("[tiling] set_all -> int4:", ce.moe.tiling.get_int4())
        else:
            if args.quant == "int8":
                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
                ce.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
                print("[tiling] set_int8 ->", ce.moe.tiling.get_int8())
            elif args.quant == "int4" and hasattr(ce.moe.tiling, "set_int4"):
                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int4)
                ce.moe.tiling.set_int4(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
                print("[tiling] set_int4 ->", ce.moe.tiling.get_int4())
    # Warn about divisibility expectations; kernels assume specific blocking
    # - Some helpers assert n % N_BLOCK == 0, etc. Ensure your dims/tiles align.
    print("[note] Ensure your selected tiling parameters are compatible with hidden/intermediate sizes and blocking.")
    # Initialize CPUInfer
    CPUInfer = ce.CPUInfer(args.threads)
    # Select MOE kernel
    moe_cls = None
    if args.quant == "int8":
        moe_cls = maybe_get_class(ce.moe, "Int8_KERNEL_MOE")
        if moe_cls is None:
            raise RuntimeError("Int8 kernel binding 'Int8_KERNEL_MOE' not found.")
        bytes_per_elem = 1.0
    else:
        moe_cls = maybe_get_class(ce.moe, "Int4_KERNEL_MOE")
        if moe_cls is None:
            raise RuntimeError("Int4 kernel binding 'Int4_KERNEL_MOE' not available on this platform.")
        bytes_per_elem = 0.5
    # Prepare config/weights
    expert_num = args.expert_num
    hidden_size = args.hidden_size
    intermediate_size = args.intermediate_size
    num_experts_per_tok = args.num_experts_per_tok
    layer_num = args.layer_num
    max_len = args.max_len
    # Identity expert mapping: physical expert i == logical expert i.
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    moes = []
    gate_projs, up_projs, down_projs = [], [], []
    for layer_idx in range(layer_num):
        # Random fp32 expert weights; the kernel reads them via raw data
        # pointers, so the host tensors must stay alive (appended below).
        gate_proj = torch.randn(
            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        up_proj = torch.randn(
            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        down_proj = torch.randn(
            (expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        cfg = ce.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
        cfg.max_len = max_len
        cfg.gate_proj = gate_proj.data_ptr()
        cfg.up_proj = up_proj.data_ptr()
        cfg.down_proj = down_proj.data_ptr()
        cfg.pool = CPUInfer.backend_
        moe = moe_cls(cfg)
        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
        CPUInfer.sync()
        gate_projs.append(gate_proj)
        up_projs.append(up_proj)
        down_projs.append(down_proj)
        moes.append(moe)
    qlen = args.qlen
    warm_up_iter = args.warm_up_iter
    test_iter = args.test_iter
    # One pre-generated routing (top-k ids + weights) per test iteration.
    # NOTE(review): the warm-up loop below indexes expert_ids/weights with
    # `i`, so it assumes warm_up_iter <= test_iter (true for defaults 200/500).
    expert_ids = (
        torch.rand(test_iter * qlen, expert_num)
        .argsort(dim=-1)[:, :num_experts_per_tok]
        .reshape(test_iter, qlen * num_experts_per_tok)
        .to("cpu")
        .contiguous()
    )
    weights = torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32).to("cpu").contiguous()
    input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
    output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
    bsz_tensor = torch.tensor([qlen], dtype=torch.int32).to("cpu").contiguous()
    # Warmup
    for i in tqdm(range(warm_up_iter), desc="Warm-up"):
        CPUInfer.submit(
            moes[i % layer_num].forward_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids[i].data_ptr(),
                weights[i].data_ptr(),
                input_tensor[i % layer_num].data_ptr(),
                output_tensor[i % layer_num].data_ptr(),
            )
        )
        CPUInfer.sync()
    # Measure
    start = time.perf_counter()
    for i in tqdm(range(test_iter), desc="Testing"):
        CPUInfer.submit(
            moes[i % layer_num].forward_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids[i].data_ptr(),
                weights[i].data_ptr(),
                input_tensor[i % layer_num].data_ptr(),
                output_tensor[i % layer_num].data_ptr(),
                False,
            )
        )
        CPUInfer.sync()
    end = time.perf_counter()
    total_time = end - start
    # Throughput metrics.
    time_per_iter_us = total_time / test_iter * 1e6
    bandwidth_gbs = (
        hidden_size * intermediate_size * 3 * num_experts_per_tok * qlen * bytes_per_elem * test_iter / total_time / 1e9
    )
    flops_tflops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12
    print("\n=== Results ===")
    print("quant:", args.quant)
    if hasattr(ce.moe.tiling, "get_int8") and args.quant == "int8":
        print("tiling int8:", ce.moe.tiling.get_int8())
    if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
        print("tiling int4:", ce.moe.tiling.get_int4())
    print("time (s):", total_time)
    print("iter:", test_iter)
    print("time per iter (us):", time_per_iter_us)
    print("bandwidth (GB/s):", bandwidth_gbs)
    print("TFLOPS:", flops_tflops)
if __name__ == "__main__":
    # CLI entry point.
    main()

View File

@@ -13,7 +13,7 @@ import os, sys, time, json, subprocess, platform
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import cpuinfer_ext
import kt_kernel_ext
import torch
import numpy as np
from tqdm import tqdm
@@ -33,15 +33,15 @@ test_iter = 10000
# 将 CPUInfer 参数设为变量
CPUINFER_PARAM = 112
CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
# worker_config = cpuinfer_ext.WorkerPoolConfig()
# worker_config = kt_kernel_ext.WorkerPoolConfig()
# worker_config.subpool_count = 4
# worker_config.subpool_numa_map= [0,1,2,3]
# worker_config.subpool_thread_count = [36,36,36,36]
# worker_config.subpool_thread_count = [39,39,39,39]
# CPUINFER_PARAM = 156
# CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
# CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
def get_git_commit():
@@ -166,16 +166,16 @@ def bench_moe(quant_mode: str):
down_proj = torch.randn(
(expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cpu"
).contiguous()
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.pool = CPUInfer.backend_
if quant_mode == "int8":
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
elif quant_mode == "int4":
moe = cpuinfer_ext.moe.KMLInt4_MOE(config)
moe = kt_kernel_ext.moe.KMLInt4_MOE(config)
else:
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
CPUInfer.submit(moe.load_weights_task())

View File

@@ -50,12 +50,12 @@ import torch
# Try importing both implementations
try:
import cpuinfer_ext
import kt_kernel_ext
KTRANSFORMERS_AVAILABLE = True
logger.info("KTransformers cpuinfer_ext loaded successfully")
logger.info("KTransformers kt_kernel_ext loaded successfully")
except ImportError as e:
KTRANSFORMERS_AVAILABLE = False
logger.warning(f"KTransformers cpuinfer_ext not available: {e}")
logger.warning(f"KTransformers kt_kernel_ext not available: {e}")
try:
from sgl_kernel.common_ops import fused_experts_cpu
@@ -472,11 +472,11 @@ def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
try:
with torch.inference_mode():
# Setup worker config with consistent threads per NUMA
worker_config = cpuinfer_ext.WorkerPoolConfig()
worker_config = kt_kernel_ext.WorkerPoolConfig()
worker_config.subpool_count = sys_config.numa_count
worker_config.subpool_numa_map = list(range(sys_config.numa_count))
worker_config.subpool_thread_count = [thread_config.threads_per_numa] * sys_config.numa_count
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
# Create MoE layers
moes = []
@@ -493,7 +493,7 @@ def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
down_proj = torch.randn((test_config.expert_num, test_config.hidden_size, test_config.intermediate_size),
dtype=torch.float32).contiguous()
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
test_config.expert_num, test_config.num_experts_per_tok,
test_config.hidden_size, test_config.intermediate_size)
config.max_len = test_config.max_len
@@ -503,11 +503,11 @@ def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
config.pool = CPUInfer.backend_
if quant_mode == "bf16":
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
elif quant_mode == "int8":
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
elif quant_mode == "int4":
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
else:
raise ValueError(f"Unsupported quantization mode: {quant_mode}")

View File

@@ -90,7 +90,7 @@ def update_bench_parameters(params):
bench.test_iter = params["test_iter"]
bench.CPUINFER_PARAM = params["CPUINFER_PARAM"]
# 重新初始化 CPUInfer 对象
bench.CPUInfer = bench.cpuinfer_ext.CPUInfer(bench.CPUINFER_PARAM)
bench.CPUInfer = bench.kt_kernel_ext.CPUInfer(bench.CPUINFER_PARAM)
def main():

View File

@@ -2,11 +2,11 @@
# CFLAGS += -march=armv8.2-a+fp16+dotprod+sve+bf16 -I/home/test/kt-code/HPCKit_25.0.0_Linux-aarch64/package/KunpengHPCKit-kml.25.0.0/include
# CFLAGS += -march=armv8.2-a+fp16+dotprod+sve+bf16 -I/home/test/kt-code/HPCKit_25.0.0_Linux-aarch64/package/KunpengHPCKit-kml.25.0.0/include
CFLAGS += -O3
CFLAGS += -I/usr/local/include/blis/
CFLAGS += -I/usr/local/include/blis/ -fopenmp
LDLIBS += -L/usr/local/lib -lblis
# LDLIBS += $(shell pkg-config --libs hwloc) -lkml_rt
CXX = g++
CXX = /usr/bin/g++
# i8_cal: i8_cal.cpp
# $(CXX) i8_cal.cpp $(CFLAGS) -o i8_cal $(LDLIBS)
@@ -17,9 +17,8 @@ simple_test_build: simple_test.cpp
rm -f simple_test
BLAS_NUM_THREADS=1 $(CXX) simple_test.cpp $(CFLAGS) -o simple_test $(LDLIBS)
simple_aocl_build: simple_test_aocl.cpp
rm -f simple_test_aocl
BLAS_NUM_THREADS=1 $(CXX) simple_test_aocl.cpp $(CFLAGS) -o simple_test_aocl $(LDLIBS)
simple_aocl_build: build simple_test_aocl.cpp
$(CXX) simple_test_aocl.cpp $(CFLAGS) -o build/simple_test_aocl $(LDLIBS)
fp16_test_build: fp16-test.cpp
rm -f fp16-test
@@ -27,6 +26,11 @@ fp16_test_build: fp16-test.cpp
bf16_test_build: bf16-test.cpp
rm -f bf16-test
$(CXX) bf16-test.cpp $(CFLAGS) -o bf16-test $(LDLIBS)
build: build
mkdir -p build
bandwidth_build: bench_reorder_bandwidth.cpp
$(CXX) bench_reorder_bandwidth.cpp $(CFLAGS) -o build/bench_reorder_bandwidth $(LDLIBS)
run: simple_aocl_build
LD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH ./simple_test_aocl
LD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH ./build/simple_test_aocl
run_bandwidth: bandwidth_build
LD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH ./build/bench_reorder_bandwidth

View File

@@ -0,0 +1,110 @@
#include <blis.h>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
namespace {

// Benchmark problem size: C(kM x kN) = A(kM x kK) * B^T, repeated kIters times.
constexpr int kM = 1;
constexpr int kK = 7168;
constexpr int kN = 512;
constexpr int kIters = 10000;

// Deterministically fill `count` bytes with values in [0, 29] (fixed seed 47
// so every run benchmarks identical data).
void fill_random(int8_t* ptr, size_t count) {
  std::srand(47);
  size_t i = 0;
  while (i < count) {
    ptr[i] = static_cast<int8_t>(std::rand() % 30);
    ++i;
  }
}

// Zero `count` int32 entries.
void fill_zero(int32_t* ptr, size_t count) {
  std::memset(ptr, 0, count * sizeof(int32_t));
}

// Reference check: c must equal a * b^T, where b holds kN rows of length kK
// (i.e. B is stored transposed). Prints the first mismatch and returns false;
// returns true when every entry agrees.
bool verify(const int8_t* a, const int8_t* b, const int32_t* c) {
  for (int row = 0; row < kM; ++row) {
    for (int col = 0; col < kN; ++col) {
      int32_t expect = 0;
      for (int k = 0; k < kK; ++k) {
        expect += static_cast<int32_t>(a[row * kK + k]) * static_cast<int32_t>(b[col * kK + k]);
      }
      const int32_t got = c[row * kN + col];
      if (got != expect) {
        std::printf("Mismatch at (%d, %d): got %d, expect %d\n", row, col, got, expect);
        return false;
      }
    }
  }
  return true;
}

}  // namespace
// Benchmark driver: times kIters int8 GEMMs (C = A * B^T) through AOCL-BLAS
// with a pre-reordered B, then reports effective bandwidth, throughput, and
// verifies the last result against a scalar reference.
int main() {
  int8_t* a = static_cast<int8_t*>(std::aligned_alloc(64, kM * kK));
  int8_t* b = static_cast<int8_t*>(std::aligned_alloc(64, kK * kN));
  int32_t* c = static_cast<int32_t*>(std::aligned_alloc(64, kM * kN * sizeof(int32_t)));
  int32_t* c_tmp = static_cast<int32_t*>(std::aligned_alloc(64, kM * kN * sizeof(int32_t)));
  if (!a || !b || !c || !c_tmp) {
    std::fprintf(stderr, "Allocation failed.\n");
    std::free(a);
    std::free(b);
    std::free(c);
    std::free(c_tmp);
    return EXIT_FAILURE;
  }
  fill_random(a, kM * kK);
  fill_random(b, kK * kN);
  fill_zero(c, kM * kN);
  fill_zero(c_tmp, kM * kN);

  // Pack B into AOCL's blocked layout ('B' = reorder the B operand); the
  // timed GEMMs below read the packed copy ('r' storage flag).
  const dim_t reorder_size = aocl_get_reorder_buf_size_s8s8s32os32('r', 't', 'B', kK, kN);
  int8_t* b_reordered = static_cast<int8_t*>(std::aligned_alloc(64, reorder_size));
  if (!b_reordered) {
    std::fprintf(stderr, "Reorder buffer allocation failed.\n");
    std::free(a);
    std::free(b);
    std::free(c);
    std::free(c_tmp);  // fix: c_tmp was leaked on this error path
    return EXIT_FAILURE;
  }
  aocl_reorder_s8s8s32os32('r', 't', 'B', b, b_reordered, kK, kN, kK);

  // Warm-up GEMM to load kernels.
  aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, 1, a, kK, 'n', b_reordered, kK, 'r', 0, c_tmp, kN, nullptr);
  fill_zero(c, kM * kN);

  // Traffic counted per multiply: the A row(s) plus the (unpacked-size) B.
  const double bytes_per_mul = static_cast<double>(kM) * kK * sizeof(int8_t) +  // A matrix read
                               static_cast<double>(kK) * kN * sizeof(int8_t);   // original B read

  auto start = std::chrono::high_resolution_clock::now();
  for (int iter = 0; iter < kIters; ++iter) {
    aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, 1, a, kK, 'n', b_reordered, kK, 'r', 0, c, kN, nullptr);
  }
  auto end = std::chrono::high_resolution_clock::now();
  const double elapsed_seconds = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();

  const double total_bytes = bytes_per_mul * kIters;
  const double bandwidth_gbps = total_bytes / elapsed_seconds / 1e9;
  const double ops_per_mul = static_cast<double>(kM) * kN * kK * 2.0;  // each MAC = 2 ops
  const double tflops = (ops_per_mul * kIters) / elapsed_seconds / 1e12;

  std::printf("Reorder buffer size: %ld bytes\n", static_cast<long>(reorder_size));
  std::printf("Iterations: %d\n", kIters);
  std::printf("Elapsed time: %.4f s\n", elapsed_seconds);
  std::printf("Effective bandwidth: %.2f GB/s\n", bandwidth_gbps);
  // fix: tflops is already in tera-ops/s; the previous `tflops * 1e3`
  // overstated the figure by 1000x while labeling it TOPS.
  std::printf("Int8 GEMM throughput: %.2f TOPS\n", tflops);

  if (!verify(a, b, c)) {
    std::fprintf(stderr, "Verification failed.\n");
  } else {
    std::puts("Verification passed.");
  }

  std::free(a);
  std::free(b);
  std::free(b_reordered);
  std::free(c);
  std::free(c_tmp);  // fix: c_tmp was leaked at normal exit
  return 0;
}

View File

@@ -1,119 +1,173 @@
#include <blis.h>
#include <dlfcn.h>
#include <unistd.h>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
// #define CHECK
namespace {
// B matrix is in col-major order
constexpr int kM = 3;
constexpr int kK = 7168;
constexpr int kN = 2048;
void fill_inputs(int8_t* a, int8_t* b) {
srand(static_cast<unsigned>(time(nullptr)));
for (int i = 0; i < kM * kK; ++i) {
a[i] = static_cast<int8_t>(rand() % 127);
}
for (int i = 0; i < kK * kN; ++i) {
b[i] = static_cast<int8_t>(rand() % 127);
}
}
void compute_reference(const int8_t* a, const int8_t* b, int32_t* ref) {
for (int m = 0; m < kM; ++m) {
for (int n = 0; n < kN; ++n) {
int32_t acc = 0;
for (int k = 0; k < kK; ++k) {
acc += static_cast<int32_t>(a[m * kK + k]) * static_cast<int32_t>(b[k * kN + n]);
}
ref[m * kN + n] = acc;
}
}
}
bool check_result(const int32_t* got, const int32_t* ref) {
for (int idx = 0; idx < kM * kN; ++idx) {
if (got[idx] != ref[idx]) {
std::printf("Mismatch at %d: got %d, expected %d\n", idx, got[idx], ref[idx]);
return false;
}
}
return true;
}
} // namespace
int main() {
// 矩阵维度 M 是 1024K 是 1024N 是 1024行主序
int M = 1024; // 行主序时A 的行长度为 K
const int K = 1024; // B 的行长度为 N
const int N = 1024; // C 的行长度为 N
const int iter = 10000; // 迭代次数
err_t err = BLIS_SUCCESS;
int8_t* a = static_cast<int8_t*>(bli_malloc_user(kM * kK, &err));
int8_t* b = static_cast<int8_t*>(bli_malloc_user(kK * kN, &err));
int8_t* b_rowmajor = static_cast<int8_t*>(bli_malloc_user(kK * kN, &err));
int8_t* b_reordered = nullptr;
int32_t* c = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));
int32_t* c_unp = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));
int32_t* ref = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));
// 分配矩阵内存
int8_t* A = (int8_t*)malloc(M * K * sizeof(int8_t));
int8_t* B = (int8_t*)malloc(K * N * sizeof(int8_t));
int32_t* C = (int32_t*)malloc(M * N * sizeof(int32_t));
// 初始化随机种子
srand((unsigned)time(NULL));
// 随机初始化 A范围 0 到 255和 B范围 -128 到 127
// 初始化矩阵 A 和 B
for (int j = 0; j < M * K; j++) {
// A[j] = rand() % 256;
A[j] = j;
}
for (int j = 0; j < K * N; j++) {
// B[j] = rand() % 256;
B[j] = j;
}
// 初始化矩阵 C
for (int j = 0; j < M * N; j++) {
C[j] = 0;
if (!a || !b || !c || !ref || !c_unp) {
std::fprintf(stderr, "Allocation failed\n");
bli_free_user(a);
bli_free_user(b);
bli_free_user(c);
bli_free_user(ref);
bli_free_user(c_unp);
return EXIT_FAILURE;
}
// 设置 cblas_gemm_s8u8s32 的参数
float alpha = 1.0f;
float beta = 0.0f;
int8_t oa = 0, ob = 0;
int32_t oc = 0;
// 打印矩阵 A、B
// printf("A=\n");
// for (int i = 0; i < M; i++) {
// for (int j = 0; j < K; j++) {
// printf("%d ", A[i * K + j]);
// }
// printf("\n");
// }
// printf("B=\n");
// for (int i = 0; i < N; i++) {
// for (int j = 0; j < K; j++) {
// printf("%d ", B[i * K + j]);
// }
// printf("\n");
// }
// printf("format: 'generate end'\n");
// 调用 cblas_gemm_s8u8s32 执行矩阵乘法C = i1(A+ao)(B+bo) + 0*C + oc
// 从m=10256 都测一遍速度,步长是 stride
int stride = 2;
int start_m = M;
for (int m = start_m; m <= M; m += stride) {
// 记录开始时间
auto start = std::chrono::high_resolution_clock::now();
#pragma GCC unroll 8
for (int i = 0; i < iter; i++) {
// cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N / 2, K, alpha, A, K, oa, B, K,
// ob, beta, C, N, &oc);
aocl_gemm_s8s8s32os32('r', 'n', 't', m, N / 2, K, (int32_t)alpha, A, K, 'n', B, K, 'n', (int32_t)beta, C, N,
nullptr);
int8_t* B_high = B + K * N / 2;
int32_t* C_high = C + N / 2;
// cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N / 2, K, alpha, A, K, oa,
// B_high,
// K, ob, beta, C_high, N, &oc);
aocl_gemm_s8s8s32os32('r', 'n', 't', m, N / 2, K, (int32_t)alpha, A, K, 'n', B_high, K, 'n', (int32_t)beta,
C_high, N, nullptr);
fill_inputs(a, b);
// transform B from col-major to row-major
for (int k = 0; k < kK; ++k) {
for (int n = 0; n < kN; ++n) {
// original B is in col-major: b[n * ld + k], here ld = kK
int8_t val = b[n * kK + k];
// target row-major: row index = k, col index = n
b_rowmajor[k * kN + n] = val;
}
}
#ifdef CHECK
// CHECK: printf inputs
std::puts("\nMatrix A:\n");
for (int m = 0; m < kM; ++m) {
for (int k = 0; k < kK; ++k) {
std::printf("%4d ", a[m * kK + k]);
}
std::puts("");
}
std::puts("\nMatrix B:\n");
for (int k = 0; k < kK; ++k) {
for (int n = 0; n < kN; ++n) {
std::printf("%4d ", b[n * kK + k]);
}
std::puts("");
}
#endif
std::memset(c, 0, kM * kN * sizeof(int32_t));
std::memset(c_unp, 0, kM * kN * sizeof(int32_t));
std::memset(ref, 0, kM * kN * sizeof(int32_t));
compute_reference(a, b_rowmajor, ref);
#ifdef CHECK
// CHECK: printf reference
std::puts("\nReference result:\n");
for (int m = 0; m < kM; ++m) {
for (int n = 0; n < kN; ++n) {
std::printf("%6d ", ref[m * kN + n]);
}
std::puts("");
}
#endif
const dim_t reorder_size = aocl_get_reorder_buf_size_s8s8s32os32('c', 'n', 'B', kK, kN);
b_reordered = static_cast<int8_t*>(bli_malloc_user(reorder_size, &err));
if (!b_reordered) {
std::fprintf(stderr, "Reorder buffer allocation failed\n");
bli_free_user(a);
bli_free_user(b);
bli_free_user(c);
bli_free_user(ref);
return EXIT_FAILURE;
}
aocl_reorder_s8s8s32os32('c', 'n', 'B', b, b_reordered, kK, kN, kK);
#ifdef CHECK
// CHECK: printf reordered B
std::puts("\nReordered Matrix B:\n");
for (int k = 0; k < kK; ++k) {
for (int n = 0; n < kN; ++n) {
std::printf("%4d ", b_reordered[k * kN + n]);
}
std::puts("");
}
std::printf("\nReorder buffer size: %zu bytes\n", reorder_size);
#endif
// 打印结果
// printf("result:\n");
// for (int i = 0; i < M; i++) {
// for (int j = 0; j < N; j++) {
// printf("%d ", C[i * N + j]);
// }
// printf("\n");
// }
const int32_t alpha = 1;
const int32_t beta = 0;
aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, alpha, a, kK, 'n', b_reordered, kK, 'r', beta, c, kN, nullptr);
aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, alpha, a, kK, 'n', b, kK, 'n', beta, c_unp, kN, nullptr);
#ifdef CHECK
// CHECK: printf AOCL result
std::puts("\nAOCL GEMM result (with reordered B):\n");
for (int m = 0; m < kM; ++m) {
for (int n = 0; n < kN; ++n) {
std::printf("%6d ", c[m * kN + n]);
}
std::puts("");
}
std::puts("\nAOCL GEMM result (without reordered B):\n");
for (int m = 0; m < kM; ++m) {
for (int n = 0; n < kN; ++n) {
std::printf("%6d ", c_unp[m * kN + n]);
}
std::puts("");
}
#endif
// 记录结束时间
auto end = std::chrono::high_resolution_clock::now();
// 计算总时长(秒)
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
double time_sec = duration.count() / 1e6; // 转换为秒
// 计算理论浮点运算次数并转换为 TFLOPS
double ops = iter * 2.0 * m * N * K;
double tflops = ops / (duration.count() * 1e6); // 转换为 TFLOPS
// 输出结果
printf("execute end,m is:%d\n", m);
// printf("执行时间: %.4f 秒\n", time_sec);
printf("计算性能: %.4f TFLOPS\n", tflops);
printf("\n");
if (check_result(c, ref)) {
std::puts("AOCL GEMM output matches reference.");
} else {
std::puts("AOCL GEMM output mismatch detected.");
}
// 释放资源
free(A);
free(B);
free(C);
if (check_result(c_unp, ref)) {
std::puts("unpack AOCL GEMM output matches reference.");
} else {
std::puts("unpack AOCL GEMM output mismatch detected.");
}
bli_free_user(a);
bli_free_user(b);
bli_free_user(b_rowmajor);
bli_free_user(b_reordered);
bli_free_user(c);
bli_free_user(c_unp);
bli_free_user(ref);
return 0;
}

View File

@@ -3,15 +3,15 @@ import sys
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import torch
import ctypes
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE, AMXBF16_MOE, AMXInt8_MOE, AMXInt4_MOE, AMXInt4_1_MOE
import kt_kernel_ext
from kt_kernel_ext.moe import MOEConfig, MOE, AMXBF16_MOE, AMXInt8_MOE, AMXInt4_MOE, AMXInt4_1_MOE
intermediate_size_full = 2048
moe_intermediate_size = 3072
hidden_size = 7168
experts_num = 256
num_experts_per_tok = 8
cpu_infer = cpuinfer_ext.CPUInfer(97)
cpu_infer = kt_kernel_ext.CPUInfer(97)
up = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device="cpu")

View File

@@ -13,7 +13,7 @@ import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import kt_kernel_ext
from flash_attn import flash_attn_with_kvcache
import torch
@@ -26,20 +26,20 @@ anchor_num = 1
cache_seqlen = 8192
cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
anchor_type = kt_kernel_ext.kvcache.AnchorType.DYNAMIC
kv_type = kt_kernel_ext.kvcache.ggml_type.FP16
retrieval_type = kt_kernel_ext.kvcache.RetrievalType.LAYER
layer_step: int = 1
token_step: int = 1
layer_offset: int = 0
max_thread_num: int = 2
max_batch_size: int = 1
max_block_num: int = 512
CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
CPUInfer = kt_kernel_ext.CPUInfer(max_thread_num)
validation_iter = 100
with torch.inference_mode(mode=True):
config = cpuinfer_ext.kvcache.KVCacheConfig(
config = kt_kernel_ext.kvcache.KVCacheConfig(
layer_num,
kv_head_num,
q_head_num,
@@ -56,7 +56,7 @@ with torch.inference_mode(mode=True):
max_batch_size,
max_thread_num,
)
local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
local_kvcache = kt_kernel_ext.kvcache.KVCache(config)
kvcaches = []
block_table = (

View File

@@ -2,7 +2,7 @@ import os, sys
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import kt_kernel_ext
import torch
# Set fixed seed for reproducible results
@@ -49,7 +49,7 @@ max_len = 25600
num_experts_per_tok = 8
qlen = 1
layer_num = 1
CPUInfer = cpuinfer_ext.CPUInfer(40)
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 10
k_group_size = 64
debug_print_count = 16
@@ -302,7 +302,7 @@ def test_online_int4_kgroup_moe():
for _ in range(layer_num):
# Create Int4LowKGroup configuration (online quantization)
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
@@ -320,7 +320,7 @@ def test_online_int4_kgroup_moe():
config.path = "./awq_dump_online"
# Create Int4LowKGroup MoE (online quantization during load_weights)
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
# Load weights (performs online quantization)
print(f"Physical Map: {physical_to_logical_map.data_ptr()}")
@@ -421,7 +421,7 @@ def test_awq_moe():
for _ in range(layer_num):
# Create AWQ MoE configuration
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
# Set quantization config for Int4_1LowKGroup
@@ -449,7 +449,7 @@ def test_awq_moe():
config.pool = CPUInfer.backend_
# Create Int4_1LowKGroup MoE
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
# Load weights
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))

View File

@@ -2,8 +2,8 @@ import os, sys
import time
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
import logging
import sys
@@ -22,7 +22,7 @@ logger = logging.getLogger("reader")
from gguf.gguf_reader import GGUFReader
# load_layers = 6
load_layers = None
CPUInfer = cpuinfer_ext.CPUInfer(304)
CPUInfer = kt_kernel_ext.CPUInfer(304)
max_qlen = 4096
max_kvlen = 4096
page_size = 256
@@ -136,7 +136,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
rope_theta = json_config["rope_theta"]
rope_scaling = json_config["rope_scaling"]
config = cpuinfer_ext.mla.MLAConfig(
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
@@ -191,12 +191,12 @@ def build_mla(layer_idx, json_config, gguf_weights):
config.page_count = pages_count
if q_a_type == "F32":
mla = cpuinfer_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_F32(config)
elif q_a_type == "F16":
mla = cpuinfer_ext.mla.MLA_F16(config)
mla = kt_kernel_ext.mla.MLA_F16(config)
elif q_a_type == "BF16":
# mla = cpuinfer_ext.mla.MLA_F32(config)
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
# mla = kt_kernel_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
else:
raise ValueError(f"Unsupported data type: {q_a_type}")
@@ -207,7 +207,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
def build_ffn(layer_idx, json_config, gguf_weights):
if f"blk.{layer_idx}.ffn_gate.weight" in gguf_weights: # dense
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["hidden_size"],
@@ -227,12 +227,12 @@ def build_ffn(layer_idx, json_config, gguf_weights):
config.down_proj = down.data_ptr()
config.down_type = type_to_ggml_type(down_type)
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
moe.load_weights()
return moe
elif f"blk.{layer_idx}.ffn_gate_exps.weight" in gguf_weights:
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
json_config["n_routed_experts"] + json_config["n_shared_experts"],
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["hidden_size"],
@@ -267,7 +267,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
config.down_proj = down.data_ptr()
config.down_type = type_to_ggml_type(down_type)
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
moe.load_weights()
return moe
@@ -276,7 +276,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
def build_moegate(layer_idx, json_config, gguf_weights):
config = cpuinfer_ext.gate.GateConfig(
config = kt_kernel_ext.gate.GateConfig(
json_config["hidden_size"],
json_config["num_experts_per_tok"],
json_config["n_routed_experts"],
@@ -296,7 +296,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
config.e_score_correction_bias = bias.data_ptr()
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
gate = cpuinfer_ext.gate.MoEGate(config)
gate = kt_kernel_ext.gate.MoEGate(config)
return gate
@@ -304,7 +304,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
def build_llm(json_config, gguf_weights):
general_config = cpuinfer_ext.GeneralConfig()
general_config = kt_kernel_ext.GeneralConfig()
general_config.vocab_size = json_config["vocab_size"]
general_config.hidden_size = json_config["hidden_size"]
general_config.num_experts_per_tok = json_config["num_experts_per_tok"]
@@ -326,8 +326,8 @@ def build_llm(json_config, gguf_weights):
general_config.pool = CPUInfer.backend_
llm = cpuinfer_ext.DeepseekV3ForCausalLM(general_config)
model = cpuinfer_ext.DeepseekV3Model(general_config)
llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)
model = kt_kernel_ext.DeepseekV3Model(general_config)
llm.model = model
@@ -335,7 +335,7 @@ def build_llm(json_config, gguf_weights):
real_load_layers = json_config["num_hidden_layers"] if load_layers is None else load_layers
for i in range(real_load_layers):
layer = cpuinfer_ext.DeepseekV3DecoderLayer(general_config,i)
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")

View File

@@ -2,8 +2,8 @@ import os, sys
import time
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
import logging
import sys
@@ -21,7 +21,7 @@ logger = logging.getLogger("reader")
from gguf.gguf_reader import GGUFReader
CPUInfer = cpuinfer_ext.CPUInfer(304)
CPUInfer = kt_kernel_ext.CPUInfer(304)
max_qlen = 4096
max_kvlen = 4096
page_size = 256
@@ -135,7 +135,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
rope_theta = json_config["rope_theta"]
rope_scaling = json_config["rope_scaling"]
config = cpuinfer_ext.mla.MLAConfig(
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
@@ -191,12 +191,12 @@ def build_mla(layer_idx, json_config, gguf_weights):
if q_a_type == "F32":
mla = cpuinfer_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_F32(config)
elif q_a_type == "F16":
mla = cpuinfer_ext.mla.MLA_F16(config)
mla = kt_kernel_ext.mla.MLA_F16(config)
elif q_a_type == "BF16":
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
# mla = cpuinfer_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
# mla = kt_kernel_ext.mla.MLA_F32(config)
else:
raise ValueError(f"Unsupported data type: {q_a_type}")
@@ -207,7 +207,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
def build_ffn(layer_idx, json_config, gguf_weights):
if f"blk.{layer_idx}.ffn_gate.weight" in gguf_weights: # dense
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["hidden_size"],
@@ -227,12 +227,12 @@ def build_ffn(layer_idx, json_config, gguf_weights):
config.down_proj = down.data_ptr()
config.down_type = type_to_ggml_type(down_type)
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
moe.load_weights()
return moe
elif f"blk.{layer_idx}.ffn_gate_exps.weight" in gguf_weights:
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
json_config["n_routed_experts"] + json_config["n_shared_experts"],
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["hidden_size"],
@@ -267,7 +267,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
config.down_proj = down.data_ptr()
config.down_type = type_to_ggml_type(down_type)
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
moe.load_weights()
return moe
@@ -276,7 +276,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
def build_moegate(layer_idx, json_config, gguf_weights):
config = cpuinfer_ext.gate.GateConfig(
config = kt_kernel_ext.gate.GateConfig(
json_config["hidden_size"],
json_config["num_experts_per_tok"],
json_config["n_routed_experts"],
@@ -296,7 +296,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
config.e_score_correction_bias = bias.data_ptr()
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
gate = cpuinfer_ext.gate.MoEGate(config)
gate = kt_kernel_ext.gate.MoEGate(config)
return gate
@@ -304,7 +304,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
def build_llm(json_config, gguf_weights):
general_config = cpuinfer_ext.GeneralConfig()
general_config = kt_kernel_ext.GeneralConfig()
general_config.vocab_size = json_config["vocab_size"]
general_config.hidden_size = json_config["hidden_size"]
general_config.num_experts_per_tok = json_config["num_experts_per_tok"]
@@ -326,8 +326,8 @@ def build_llm(json_config, gguf_weights):
general_config.pool = CPUInfer.backend_
llm = cpuinfer_ext.DeepseekV3ForCausalLM(general_config)
model = cpuinfer_ext.DeepseekV3Model(general_config)
llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)
model = kt_kernel_ext.DeepseekV3Model(general_config)
llm.model = model
@@ -335,7 +335,7 @@ def build_llm(json_config, gguf_weights):
for i in range(json_config["num_hidden_layers"]):
# for i in range(6):
# for i in [0,1,2,3,4,5,6,7,8,9,10]:
layer = cpuinfer_ext.DeepseekV3DecoderLayer(general_config,i)
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")

View File

@@ -2,8 +2,8 @@ import os, sys
import time
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
import logging
import sys
@@ -22,11 +22,11 @@ logger = logging.getLogger("reader")
from gguf.gguf_reader import GGUFReader
# load_layers = 3
load_layers = None
worker_config = cpuinfer_ext.WorkerPoolConfig()
worker_config = kt_kernel_ext.WorkerPoolConfig()
worker_config.subpool_count = 2
worker_config.subpool_numa_map= [0,1]
worker_config.subpool_thread_count = [72,72]
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
max_qlen = 4096
max_kvlen = 4096
@@ -141,7 +141,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
rope_theta = json_config["rope_theta"]
rope_scaling = json_config["rope_scaling"]
config = cpuinfer_ext.mla.MLAConfig(
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
@@ -196,12 +196,12 @@ def build_mla(layer_idx, json_config, gguf_weights):
config.page_count = pages_count
if q_a_type == "F32":
mla = cpuinfer_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_F32(config)
elif q_a_type == "F16":
mla = cpuinfer_ext.mla.MLA_F16(config)
mla = kt_kernel_ext.mla.MLA_F16(config)
elif q_a_type == "BF16":
# mla = cpuinfer_ext.mla.MLA_F32(config)
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
# mla = kt_kernel_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
else:
raise ValueError(f"Unsupported data type: {q_a_type}")
@@ -212,7 +212,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
def build_ffn(layer_idx, json_config, gguf_weights):
if f"blk.{layer_idx}.ffn_gate.weight" in gguf_weights: # dense
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["hidden_size"],
@@ -232,12 +232,12 @@ def build_ffn(layer_idx, json_config, gguf_weights):
config.down_proj = down.data_ptr()
config.down_type = type_to_ggml_type(down_type)
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
moe.load_weights()
return moe
elif f"blk.{layer_idx}.ffn_gate_exps.weight" in gguf_weights:
config = cpuinfer_ext.moe.MOEConfig(
config = kt_kernel_ext.moe.MOEConfig(
json_config["n_routed_experts"] + json_config["n_shared_experts"],
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
json_config["hidden_size"],
@@ -272,7 +272,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
config.down_proj = down.data_ptr()
config.down_type = type_to_ggml_type(down_type)
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
moe.load_weights()
return moe
@@ -281,7 +281,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
def build_moegate(layer_idx, json_config, gguf_weights):
config = cpuinfer_ext.gate.GateConfig(
config = kt_kernel_ext.gate.GateConfig(
json_config["hidden_size"],
json_config["num_experts_per_tok"],
json_config["n_routed_experts"],
@@ -301,7 +301,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
config.e_score_correction_bias = bias.data_ptr()
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
gate = cpuinfer_ext.gate.MoEGate(config)
gate = kt_kernel_ext.gate.MoEGate(config)
return gate
@@ -309,7 +309,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
def build_llm(json_config, gguf_weights):
general_config = cpuinfer_ext.GeneralConfig()
general_config = kt_kernel_ext.GeneralConfig()
general_config.vocab_size = json_config["vocab_size"]
general_config.hidden_size = json_config["hidden_size"]
general_config.num_experts_per_tok = json_config["num_experts_per_tok"]
@@ -331,8 +331,8 @@ def build_llm(json_config, gguf_weights):
general_config.pool = CPUInfer.backend_
llm = cpuinfer_ext.DeepseekV3ForCausalLM(general_config)
model = cpuinfer_ext.DeepseekV3Model(general_config)
llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)
model = kt_kernel_ext.DeepseekV3Model(general_config)
llm.model = model
@@ -341,7 +341,7 @@ def build_llm(json_config, gguf_weights):
for i in range(real_load_layers):
# for i in [2,3]:
layer = cpuinfer_ext.DeepseekV3DecoderLayer(general_config,i)
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")

View File

@@ -4,8 +4,8 @@ import time
from typing import Optional
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import nn
@@ -171,7 +171,7 @@ def torch_gate(hidden_states):
def cpuinfer_gate(hidden_states):
config = cpuinfer_ext.gate.GateConfig(
config = kt_kernel_ext.gate.GateConfig(
hidden_size,
num_experts_per_token,
n_routed_experts,
@@ -179,7 +179,7 @@ def cpuinfer_gate(hidden_states):
topk_group,
)
CPUInfer = cpuinfer_ext.CPUInfer(64)
CPUInfer = kt_kernel_ext.CPUInfer(64)
config.routed_scaling_factor = routed_scaling_factor
config.pool = CPUInfer.backend_
@@ -188,7 +188,7 @@ def cpuinfer_gate(hidden_states):
config.e_score_correction_bias = bias.data_ptr()
config.e_score_correction_bias_type = ggml_type.FP32
gate = cpuinfer_ext.gate.MoEGate(config)
gate = kt_kernel_ext.gate.MoEGate(config)

View File

@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import kt_kernel_ext
import torch
input_size = 16384
@@ -23,7 +23,7 @@ proj_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
qlen = 30
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
CPUInfer = kt_kernel_ext.CPUInfer(48)
validation_iter = 100
with torch.inference_mode(mode=True):
@@ -31,8 +31,8 @@ with torch.inference_mode(mode=True):
projs = []
for _ in range(layer_num):
proj = torch.randn((output_size, input_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
linear = cpuinfer_ext.linear.Linear(config)
config = kt_kernel_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
linear = kt_kernel_ext.linear.Linear(config)
projs.append(proj)
linears.append(linear)

View File

@@ -4,8 +4,8 @@ import time
from typing import Optional
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import inf, nn
from torch.nn import init
@@ -110,7 +110,7 @@ rope_scaling = {
CPUInfer = cpuinfer_ext.CPUInfer(30)
CPUInfer = kt_kernel_ext.CPUInfer(30)
validation_iter = 100
@@ -214,7 +214,7 @@ def test_cpu_mla():
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to('cpu').contiguous()
o_proj_weight = o_proj.weight.to(weight_type).to('cpu').contiguous()
config = cpuinfer_ext.mla.MLAConfig(
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
@@ -272,12 +272,12 @@ def test_cpu_mla():
if weight_type == torch.float32:
mla = cpuinfer_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_F32(config)
elif weight_type == torch.float16:
mla = cpuinfer_ext.mla.MLA_F16(config)
mla = kt_kernel_ext.mla.MLA_F16(config)
elif weight_type == torch.bfloat16:
# mla = cpuinfer_ext.mla.MLA_F32(config)
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
# mla = kt_kernel_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
else:
raise ValueError(f"Unsupported data type: {weight_type}")

View File

@@ -4,8 +4,8 @@ import time
from typing import Optional
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import inf, nn
from torch.nn import init
@@ -110,7 +110,7 @@ rope_scaling = {
CPUInfer = cpuinfer_ext.CPUInfer(64)
CPUInfer = kt_kernel_ext.CPUInfer(64)
validation_iter = 100
@@ -214,7 +214,7 @@ def build_mla():
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to('cpu').contiguous()
o_proj_weight = o_proj.weight.to(weight_type).to('cpu').contiguous()
config = cpuinfer_ext.mla.MLAConfig(
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
@@ -271,11 +271,11 @@ def build_mla():
if weight_type == torch.float32:
mla = cpuinfer_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_F32(config)
elif weight_type == torch.float16:
mla = cpuinfer_ext.mla.MLA_F16(config)
mla = kt_kernel_ext.mla.MLA_F16(config)
elif weight_type == torch.bfloat16:
mla = cpuinfer_ext.mla.MLA_F32(config)
mla = kt_kernel_ext.mla.MLA_F32(config)
else:
raise ValueError(f"Unsupported data type: {weight_type}")

View File

@@ -4,8 +4,8 @@ import time
from typing import Optional
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import inf, nn
from torch.nn import init

View File

@@ -2,8 +2,8 @@ import os,sys
import time
from typing import Optional
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
from cpuinfer_ext.kvcache import ggml_type
import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import nn
from torch.nn import init
@@ -54,7 +54,7 @@ rope_scaling = {
CPUInfer = cpuinfer_ext.CPUInfer(64)
CPUInfer = kt_kernel_ext.CPUInfer(64)
validation_iter = 100
@@ -79,7 +79,7 @@ o_proj_weight = o_proj.weight.to(torch.float16).to('cpu').contiguous()
config = cpuinfer_ext.mla.MLAConfig(
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
@@ -115,7 +115,7 @@ config.pool = CPUInfer.backend_
mla = cpuinfer_ext.mla.MLA(config)
mla = kt_kernel_ext.mla.MLA(config)
mla.load_weights()
mla.set_local_pages(pages_count)

View File

@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import kt_kernel_ext
import torch
hidden_size = 5120
@@ -25,7 +25,7 @@ down_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
qlen = 30
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
CPUInfer = kt_kernel_ext.CPUInfer(48)
validation_iter = 100
def act_fn(x):
@@ -47,8 +47,8 @@ with torch.inference_mode(mode=True):
gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
mlp = cpuinfer_ext.mlp.MLP(config)
config = kt_kernel_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
mlp = kt_kernel_ext.mlp.MLP(config)
gate_projs.append(gate_proj)
up_projs.append(up_proj)
down_projs.append(down_proj)

View File

@@ -12,10 +12,10 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import os, sys
import time
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import kt_kernel_ext
import torch
from tqdm import tqdm
from cpuinfer_ext.kvcache import ggml_type
from kt_kernel_ext.kvcache import ggml_type
torch.manual_seed(0)
@@ -36,7 +36,7 @@ layer_num = 1
# num_experts_per_tok = 8
# qlen = 1024
# layer_num = 1
CPUInfer = cpuinfer_ext.CPUInfer(64)
CPUInfer = kt_kernel_ext.CPUInfer(64)
validation_iter = 10
def act_fn(x):
@@ -83,10 +83,10 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
def to_cpuinfer_tensor(tensor, type):
size = torch.prod(torch.tensor(tensor.shape, dtype=torch.int32)).item()
return cpuinfer_ext.utils.from_float(tensor.data_ptr(), size, type)
return kt_kernel_ext.utils.from_float(tensor.data_ptr(), size, type)
def from_cpuinfer_tensor(tensor, size, type):
return cpuinfer_ext.utils.to_float(tensor.data_ptr(), size, type)
return kt_kernel_ext.utils.to_float(tensor.data_ptr(), size, type)
qlens = [1,64] #[64, 512, 2048, 8192, 16384]
# gate_types = [ggml_type.FP32, ggml_type.FP16, ggml_type.Q8_0, ggml_type.Q6_K, ggml_type.Q5_K, ggml_type.Q4_K, ggml_type.Q3_K]
@@ -118,7 +118,7 @@ for qlen in qlens:
up_tensor = to_cpuinfer_tensor(up_proj, up_type)
down_tensor = to_cpuinfer_tensor(down_proj, down_type)
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config.pool = CPUInfer.backend_
config.stride = stride
config.group_min_len = group_min_len
@@ -132,7 +132,7 @@ for qlen in qlens:
config.hidden_type = hidden_type
moe = cpuinfer_ext.moe.MOE(config)
moe = kt_kernel_ext.moe.MOE(config)
gate_projs.append(gate_proj)
up_projs.append(up_proj)
down_projs.append(down_proj)

View File

@@ -1,21 +1,22 @@
import os, sys
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
print("sys.path:", sys.path)
import cpuinfer_ext
import torch
import kt_kernel_ext
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
num_experts_per_tok = 8
# qlen = 1
qlen = 640
qlen = 1
# qlen = 640
layer_num = 1
CPUInfer = cpuinfer_ext.CPUInfer(40)
CPUInfer = kt_kernel_ext.CPUInfer(90)
# validation_iter = 10000
validation_iter = 10
validation_iter = 2
k_group_size = 64
debug_print_count = 16 # Number of values to print in debug output
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
@@ -126,7 +127,7 @@ def test_moe(quant_mode: str):
.to("cpu")
.contiguous()
)
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
@@ -134,25 +135,25 @@ def test_moe(quant_mode: str):
config.gate_scale = 0
config.pool = CPUInfer.backend_
if quant_mode == "bf16":
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
CPUInfer.submit(moe.warm_up_task())
CPUInfer.sync()
elif quant_mode == "int8":
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
# CPUInfer.submit(moe.warm_up_task())
# CPUInfer.sync()
elif quant_mode == "int4":
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
CPUInfer.submit(moe.warm_up_task())
CPUInfer.sync()
elif quant_mode == "int4_1":
moe = cpuinfer_ext.moe.AMXInt4_1_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_1_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
CPUInfer.submit(moe.warm_up_task())
@@ -161,11 +162,12 @@ def test_moe(quant_mode: str):
config.quant_config.bits = 4
config.quant_config.group_size = k_group_size
config.quant_config.zero_point = True
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
# import debugpy
# debugpy.listen(("127.0.0.1", 5678))
# debugpy.wait_for_client()
# debugpy.breakpoint()
print(f"the physical_logical map:{physical_to_logical_map.data_ptr()}")
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
# CPUInfer.submit(moe.warm_up_task())
@@ -260,7 +262,7 @@ def test_moe(quant_mode: str):
# 5. Final output comparison
# test_moe("bf16")
# test_moe("int8")
# test_moe("int4")
# test_moe("int4_1")
test_moe("int8")
test_moe("int4")
test_moe("int4_1")
test_moe("int4_1k")

View File

@@ -0,0 +1,187 @@
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-08-06 10:38:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys
import time
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
os.environ["BLAS_NUM_THREADS"] = "1"
import torch
import kt_kernel_ext
expert_num = 16
hidden_size = 7168
intermediate_size = 2048
max_len = 4096
num_experts_per_tok = 8
m_block = 320
n_block_up_gate = 32
n_block_down = 64
n_block_up_gate_prefi = 32
n_block_down_prefi = 64
# qlen = 1
qlen = 1024
layer_num = 1
CPUInfer = kt_kernel_ext.CPUInfer(160)
# validation_iter = 10000
validation_iter = 1
def act_fn(x):
return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
ret = torch.mm(intermediate, down_proj.t())
return ret
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
t_output = (
new_x.view(*expert_ids.shape, -1)
.type(weights.dtype)
.mul_(weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return t_output
def test_moe(quant_mode: str):
assert quant_mode == "int8" or quant_mode == "int4" or quant_mode == "int4_1"
with torch.inference_mode(mode=True):
moes = []
gate_projs = []
up_projs = []
down_projs = []
for _ in range(layer_num):
gate_proj = (
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device="cpu")
.to("cpu")
.contiguous()
)
up_proj = (
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device="cpu")
.to("cpu")
.contiguous()
)
down_proj = (
torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16, device="cpu")
.to("cpu")
.contiguous()
)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.pool = CPUInfer.backend_
if quant_mode == "int8":
d = kt_kernel_ext.moe.tiling.get_int8()
nbug_prefi = n_block_up_gate_prefi
nbd_prefi = n_block_down_prefi
kb = d["k_block"]
nb = d["n_block"]
mb = m_block
nbug = n_block_up_gate
nbd = n_block_down
print(
f"Int8 Tiling: nbug {nbug}, nbd {nbd}, nb {nb}, mb {mb}, kb {kb}, nbug_prefi {nbug_prefi}, nbd_prefi {nbd_prefi}"
)
kt_kernel_ext.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
moe = kt_kernel_ext.moe.Int8_KERNEL_MOE(config)
CPUInfer.submit(moe.load_weights_task())
CPUInfer.sync()
# CPUInfer.submit(moe.warm_up_task())
# CPUInfer.sync()
elif quant_mode == "int4":
moe = kt_kernel_ext.moe.Int4_KERNEL_MOE(config)
CPUInfer.submit(moe.load_weights_task())
CPUInfer.sync()
CPUInfer.submit(moe.warm_up_task())
CPUInfer.sync()
else:
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
gate_projs.append(gate_proj)
up_projs.append(up_proj)
down_projs.append(down_proj)
moes.append(moe)
# validation
for i in range(validation_iter):
bsz_tensor = torch.tensor([qlen], device="cpu")
expert_ids = torch.stack(
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
).contiguous()
weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
input = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
input = input / 100
# 打印 input 的内容
print("input:", input)
moe = moes[i % layer_num]
# print('expert ids:',expert_ids)
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids.data_ptr(),
weights.data_ptr(),
input.data_ptr(),
output.data_ptr(),
False,
)
)
CPUInfer.sync()
print("cpuinfer output", output)
gate_proj = gate_projs[i % layer_num]
up_proj = up_projs[i % layer_num]
down_proj = down_projs[i % layer_num]
t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)
print("torch output", t_output)
# print(output - t_output)
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print("diff = ", diff)
if quant_mode == "int4":
assert diff < 0.35
else:
assert diff < 0.05
test_moe("int8")
# test_moe("int4")

View File

@@ -14,7 +14,7 @@ import time
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
os.environ["BLAS_NUM_THREADS"] = "1"
import cpuinfer_ext
import kt_kernel_ext
import torch
expert_num = 16
@@ -25,7 +25,7 @@ num_experts_per_tok = 8
qlen = 512
# qlen = 640
layer_num = 1
CPUInfer = cpuinfer_ext.CPUInfer(112)
CPUInfer = kt_kernel_ext.CPUInfer(112)
# validation_iter = 10000
validation_iter = 1
@@ -97,26 +97,26 @@ def test_moe(quant_mode: str):
.to("cpu")
.contiguous()
)
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.pool = CPUInfer.backend_
if quant_mode == "bf16":
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
CPUInfer.submit(moe.load_weights_task())
CPUInfer.sync()
CPUInfer.submit(moe.warm_up_task())
CPUInfer.sync()
elif quant_mode == "int8":
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
CPUInfer.submit(moe.load_weights_task())
CPUInfer.sync()
# CPUInfer.submit(moe.warm_up_task())
# CPUInfer.sync()
elif quant_mode == "int4":
moe = cpuinfer_ext.moe.KMLInt4_MOE(config)
moe = kt_kernel_ext.moe.KMLInt4_MOE(config)
CPUInfer.submit(moe.load_weights_task())
CPUInfer.sync()
CPUInfer.submit(moe.warm_up_task())

View File

@@ -8,12 +8,19 @@
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
// Python bindings
#include <sys/types.h>
#include <cstddef>
#include "cpu_backend/cpuinfer.h"
#include "cpu_backend/worker_pool.h"
// #include "device_launch_parameters.h"
#include "llamafile/flags.h"
#include "operators/common.hpp"
#if defined(USE_MOE_KERNEL)
#include "operators/moe_kernel/la/kernel.hpp"
#include "operators/moe_kernel/moe.hpp"
#endif
#if defined(__aarch64__) && defined(CPU_USE_KML)
#if defined(KTRANSFORMERS_CPU_MLA)
#include "operators/kml/deepseekv3.hpp"
@@ -22,26 +29,27 @@
#include "operators/kml/mla_int8.hpp"
#endif
#include "operators/kml/moe.hpp"
static const bool _is_plain_ = true;
#else
static const bool _is_plain_ = false;
#endif
#ifdef __x86_64__
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#endif
#include <pybind11/stl.h> // std::vector/std::pair/std::string conversions
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
#include "operators/llamafile/mla.hpp"
#include "operators/llamafile/mlp.h"
#include "operators/llamafile/moe.hpp"
#include "pybind11/functional.h"
#include "pybind11/operators.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
using namespace pybind11::literals;
@@ -160,17 +168,25 @@ class MOEBindings {
struct Args {
CPUInfer* cpuinfer;
TP_MOE<T>* moe;
const uint64_t* physical_to_logical_map;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE<T>::load_weights, args_->moe, args_->physical_to_logical_map);
args_->cpuinfer->enqueue(&TP_MOE<T>::load_weights, args_->moe);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe,
intptr_t physical_to_logical_map) {
Args* args = new Args{nullptr, moe.get(), (const uint64_t*)physical_to_logical_map};
const uintptr_t physical_to_logical_map = 0) {
Args* args = new Args{nullptr, moe.get()};
if (physical_to_logical_map) {
printf("debug physical_to_logical_map in arg:%lu\n", physical_to_logical_map);
moe->config.physical_to_logical_map = reinterpret_cast<void*>(physical_to_logical_map);
printf("moe ptr:%p,confirm: moe->config.physical_to_logical_map:%lu\n", reinterpret_cast<void*>(moe.get()),
reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));
}
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe) {
return cpuinfer_interface(moe, 0);
}
};
class ForwardBindings {
public:
@@ -196,10 +212,41 @@ class MOEBindings {
Args* args = new Args{nullptr, moe.get(), qlen, k, expert_ids, weights, input, output, incremental};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe, intptr_t qlen, int k,
intptr_t expert_ids, intptr_t weights, intptr_t input,
intptr_t output) {
return cpuinfer_interface(moe, qlen, k, expert_ids, weights, input, output, false);
}
};
};
PYBIND11_MODULE(cpuinfer_ext, m) {
template <typename MoeTP>
void bind_moe_module(py::module_& moe_module, const char* name) {
using MoeClass = TP_MOE<MoeTP>;
using MoeBindings = MOEBindings<MoeTP>;
py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name)
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
py::overload_cast<std::shared_ptr<MoeClass>>(&MoeBindings::LoadWeightsBindings::cpuinfer_interface))
.def("load_weights_task",
py::overload_cast<std::shared_ptr<MoeClass>, const uintptr_t>(
&MoeBindings::LoadWeightsBindings::cpuinfer_interface),
py::arg("physical_to_logical_map"))
// .def("forward_task", &MoeBindings::ForwardBindings::cpuinfer_interface)
.def("forward_task",
py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t>(
&MoeBindings::ForwardBindings::cpuinfer_interface))
.def("forward_task",
py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t, bool>(
&MoeBindings::ForwardBindings::cpuinfer_interface))
.def("warm_up", &MoeClass::warm_up)
.def("load_weights", &MoeClass::load_weights)
.def("forward", &MoeClass::forward_binding);
}
PYBIND11_MODULE(kt_kernel_ext, m) {
py::class_<WorkerPool>(m, "WorkerPool").def(py::init<int>());
py::class_<WorkerPoolConfig>(m, "WorkerPoolConfig")
.def(py::init<>())
@@ -399,11 +446,15 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
auto moe_module = m.def_submodule("moe");
py::class_<GeneralMOEConfig>(moe_module, "MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) {
return GeneralMOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size);
}))
.def(py::init(
[](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int num_gpu_experts) {
return GeneralMOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, num_gpu_experts);
GeneralMOEConfig cfg(expert_num, routed_expert_num, hidden_size, intermediate_size);
cfg.num_gpu_experts = num_gpu_experts;
return cfg;
}))
.def_readwrite("layer_idx", &GeneralMOEConfig::layer_idx)
.def_readwrite("pool", &GeneralMOEConfig::pool)
@@ -454,109 +505,92 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
py::class_<MoE_Interface, std::shared_ptr<MoE_Interface>>(moe_module, "MoE_Interface");
py::class_<TP_MOE<LLAMA_MOE_TP>, MoE_Interface, std::shared_ptr<TP_MOE<LLAMA_MOE_TP>>>(moe_module, "MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MOEBindings<LLAMA_MOE_TP>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task", &MOEBindings<LLAMA_MOE_TP>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task", &MOEBindings<LLAMA_MOE_TP>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<LLAMA_MOE_TP>::warm_up)
.def("load_weights", &TP_MOE<LLAMA_MOE_TP>::load_weights)
.def("forward", &TP_MOE<LLAMA_MOE_TP>::forward_binding);
bind_moe_module<LLAMA_MOE_TP>(moe_module, "MOE");
#ifdef __x86_64__
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>, MoE_Interface,
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>>>(moe_module, "AMXBF16_MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224BF>>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224BF>>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224BF>>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>::warm_up)
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>::load_weights)
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>::forward_binding);
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>, MoE_Interface,
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>>>(moe_module, "AMXInt8_MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int8>>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
&MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int8>>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int8>>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>::warm_up)
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>::load_weights)
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>::forward_binding);
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>, MoE_Interface,
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>>>(moe_module, "AMXInt4_MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4>>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
&MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4>>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4>>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>::warm_up)
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>::load_weights)
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>::forward_binding);
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>, MoE_Interface,
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>>>(moe_module, "AMXInt4_1_MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
&MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::warm_up)
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::load_weights)
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::forward_binding);
// py::class_<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>, MoE_Interface,
// std::shared_ptr<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>>>(moe_module, "AMXInt4KGroup_MOE")
// .def(py::init<GeneralMOEConfig>())
// .def("warm_up_task",
// &MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::WarmUpBindings::cpuinfer_interface)
// .def("load_weights_task",
// &MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::LoadWeightsBindings::cpuinfer_interface)
// .def("forward_task",
// &MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::ForwardBindings::cpuinfer_interface)
// .def("warm_up", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::warm_up)
// .def("load_weights", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::load_weights)
// .def("forward", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::forward_binding);
py::class_<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>, MoE_Interface,
std::shared_ptr<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>>>(moe_module,
"AMXInt4_1KGroup_MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task",
&MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
&MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task",
&MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::warm_up)
.def("load_weights", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::load_weights)
.def("forward", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::forward_binding);
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
#endif
#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
#if defined(__aarch64__) && defined(CPU_USE_KML)
py::class_<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>, MoE_Interface,
std::shared_ptr<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>>>(moe_module, "KMLInt8_MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt8>>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
&MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt8>>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt8>>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>::warm_up)
.def("load_weights", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>::load_weights)
.def("forward", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>::forward_binding);
py::class_<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>, MoE_Interface,
std::shared_ptr<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>>>(moe_module, "KMLInt4_MOE")
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt4>>::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
&MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt4>>::LoadWeightsBindings::cpuinfer_interface)
.def("forward_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt4>>::ForwardBindings::cpuinfer_interface)
.def("warm_up", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>::warm_up)
.def("load_weights", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>::load_weights)
.def("forward", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>::forward_binding);
// amd have not implemented int4 kernel yet
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt4, _is_plain_>>(moe_module, "Int4_KERNEL_MOE");
#endif
#endif
// Expose kernel tiling/runtime parameters so Python can modify them at runtime
{
auto tiling_module = moe_module.def_submodule("tiling");
#if defined(USE_MOE_KERNEL)
tiling_module.def(
"get_int8",
[]() {
auto t = moe_kernel::GemmKernelInt8::get_tiling();
py::dict d;
d["n_block_up_gate"] = std::get<0>(t);
d["n_block_down"] = std::get<1>(t);
d["n_block"] = std::get<2>(t);
d["m_block"] = std::get<3>(t);
d["k_block"] = std::get<4>(t);
d["n_block_up_gate_prefi"] = std::get<5>(t);
d["n_block_down_prefi"] = std::get<6>(t);
return d;
},
"Get current tiling parameters for INT8 kernel");
tiling_module.def(
"set_int8",
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
int n_block_down_prefi) {
moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
},
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"), "Set tiling parameters for INT8 kernel");
tiling_module.def(
"get_int4",
[]() {
auto t = moe_kernel::GemmKernelInt4::get_tiling();
py::dict d;
d["n_block_up_gate"] = std::get<0>(t);
d["n_block_down"] = std::get<1>(t);
d["n_block"] = std::get<2>(t);
d["m_block"] = std::get<3>(t);
d["k_block"] = std::get<4>(t);
d["n_block_up_gate_prefi"] = std::get<5>(t);
d["n_block_down_prefi"] = std::get<6>(t);
return d;
},
"Get current tiling parameters for INT4 kernel");
tiling_module.def(
"set_int4",
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
int n_block_down_prefi) {
moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
},
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"), "Set tiling parameters for INT4 kernel");
// Convenience: set both
tiling_module.def(
"set_all",
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
int n_block_down_prefi) {
moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
n_block_up_gate_prefi, n_block_down_prefi);
},
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"),
"Set tiling parameters for both INT8 and INT4 kernels");
#endif
}
auto kvcache_module = m.def_submodule("kvcache");
@@ -628,18 +662,7 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
.def(py::init<KVCacheConfig>())
.def("get_cache_total_len", &KVCache::get_cache_total_len)
.def("update_cache_total_len",
[](KVCache& kvcache, int cache_total_len) { kvcache.update_cache_total_len(cache_total_len); })
// .def("attn", &KVCacheBindings::AttnBindings::cpuinfer_interface)
// .def("get_all_kvcache_one_layer", &KVCacheBindings::GetAllKVCacheOneLayerBindings::cpuinfer_interface)
// .def("get_and_update_kvcache_fp16", &KVCacheBindings::GetAndUpdateKVCacheFp16Bindings::cpuinfer_interface)
// .def("get_kvcache_fp16", &KVCacheBindings::GetKVCacheFp16Bindings::cpuinfer_interface)
// .def("update_kvcache_fp16", &KVCacheBindings::UpdateKVCacheFp16Bindings::cpuinfer_interface)
// .def("update_importance", &KVCacheBindings::UpdateImportanceBindings::cpuinfer_interface)
// .def("attn_with_kvcache", &KVCacheBindings::AttnWithKVCacheBindings::cpuinfer_interface)
// .def("clear_importance_all_layers", &KVCacheBindings::ClearImportanceAllLayersBindings::cpuinfer_interface)
// .def("calc_anchor_all_layers", &KVCacheBindings::CalcAnchorAllLayersBindings::cpuinfer_interface)
;
[](KVCache& kvcache, int cache_total_len) { kvcache.update_cache_total_len(cache_total_len); });
auto utils = m.def_submodule("utils");

View File

@@ -29,14 +29,10 @@
#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#define expert_map(m, x) (m != nullptr ? m[(x)] : (x))
template <class T>
class AMX_AWQ_MOE_TP {
@@ -475,12 +471,6 @@ class AMX_AWQ_MOE_TP {
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
// printf("tp part %d alloc layer %d, %f GB, on numa %d\n", tp_part_idx, config_.layer_idx,
// 1e-9 * config_.expert_num *
// (T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) * 2 +
// T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)),
// numa_node_of_cpu(sched_getcpu()));
for (size_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, group_size, nullptr));
@@ -524,9 +514,10 @@ class AMX_AWQ_MOE_TP {
// shared_mem_buffer_numa.dealloc(this);
}
void load_weights(const uint64_t* physical_to_logical_map) {
void load_weights() {
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
if (quant_config.group_size == 0 || !quant_config.zero_point) {
throw std::runtime_error("AWQ-Quantization AMX MoE only support KGroup Int4_1");
}
@@ -534,171 +525,12 @@ class AMX_AWQ_MOE_TP {
auto pool = config_.pool->get_subpool(tp_part_idx);
if (config_.gate_projs.size()) {
throw std::runtime_error("AMX load weights is not support");
// pool->do_work_stealing_job(
// config_.expert_num, nullptr,
// [this, physical_to_logical_map](int expert_id) {
// // printf("Load layer %d [%d/%d]\n", config_.layer_idx, expert_id, config_.expert_num);
// uint64_t logical_expert_id = physical_to_logical_map[expert_id];
// auto& quant_config = config_.quant_config;
// int& group_size = quant_config.group_size;
// {
// int num_group = config_.hidden_size / group_size;
// size_t scale_size = num_group * config_.intermediate_size * sizeof(float);
// size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size) -
// (scale_size << 1);
// memcpy(gate_bb_[expert_id]->b, config_.gate_projs[tp_part_idx][logical_expert_id], size);
// if constexpr (T::BufferB::SCALE) {
// memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size);
// }
// memcpy(up_bb_[expert_id]->b, config_.up_projs[tp_part_idx][logical_expert_id], size);
// if constexpr (T::BufferB::SCALE) {
// memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size);
// }
// if (quant_config.zero_point) {
// // Convert INT4 zeros to float mins using AVX optimization
// size_t num_elements = num_group * config_.intermediate_size;
// convert_zeros_to_mins_avx(
// (const uint8_t*)config_.gate_zeros[tp_part_idx][logical_expert_id],
// (const float*)config_.gate_scales[tp_part_idx][logical_expert_id],
// gate_bb_[expert_id]->mins,
// num_elements
// );
// convert_zeros_to_mins_avx(
// (const uint8_t*)config_.up_zeros[tp_part_idx][logical_expert_id],
// (const float*)config_.up_scales[tp_part_idx][logical_expert_id],
// up_bb_[expert_id]->mins,
// num_elements
// );
// }
// }
// {
// int num_group = config_.intermediate_size / group_size;
// size_t scale_size = num_group * config_.hidden_size * sizeof(float);
// size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size) -
// (scale_size << 1);
// memcpy(down_bb_[expert_id]->b, config_.down_projs[tp_part_idx][logical_expert_id], size);
// if constexpr (T::BufferB::SCALE) {
// memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size);
// }
// if (quant_config.zero_point) {
// // Convert INT4 zeros to float mins using AVX optimization
// size_t num_elements = num_group * config_.hidden_size;
// convert_zeros_to_mins_avx(
// (const uint8_t*)config_.down_zeros[tp_part_idx][logical_expert_id],
// (const float*)config_.down_scales[tp_part_idx][logical_expert_id],
// down_bb_[expert_id]->mins,
// num_elements
// );
// }
// }
// },
// nullptr);
} else {
// AWQ Load from file implementation
int nth = T::recommended_nth(config_.intermediate_size);
static uint8_t mat_type_all = 3, mat_split = 1;
if (config_.load) {
throw std::runtime_error("AMX load weights from file is not support");
// std::cout << "Loading AWQ weights from " << prefix << std::endl;
// // Use work stealing job for parallel loading
// pool->do_work_stealing_job(
// config_.expert_num * mat_type_all, nullptr,
// [this, physical_to_logical_map, mat_split](int task_id) {
// auto& quant_config = config_.quant_config;
// int& group_size = quant_config.group_size;
// int64_t expert_idx = task_id / mat_type_all;
// uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
// uint8_t mat_class = task_id % mat_type_all;
// if (mat_class == 0) { // gate projection
// int num_group = config_.hidden_size / group_size;
// size_t weights_size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size,
// group_size) - (2 * num_group * config_.intermediate_size * sizeof(float)); size_t scales_size =
// num_group * config_.intermediate_size * sizeof(float); size_t zeros_size = num_group *
// config_.intermediate_size / 2; // INT4 packed format
// // Allocate temporary buffer for zeros
// std::vector<uint8_t> zeros_buf(zeros_size);
// read_awq_weights(prefix, "gate_proj", logical_expert_id,
// (char*)gate_bb_[expert_idx]->b,
// (float*)gate_bb_[expert_idx]->d,
// zeros_buf.data(),
// weights_size, scales_size, zeros_size,
// mat_split, 0);
// // Convert INT4 zeros to float mins
// if (quant_config.zero_point) {
// convert_zeros_to_mins_avx(zeros_buf.data(),
// (float*)gate_bb_[expert_idx]->d,
// gate_bb_[expert_idx]->mins,
// num_group * config_.intermediate_size);
// }
// } else if (mat_class == 1) { // up projection
// int num_group = config_.hidden_size / group_size;
// size_t weights_size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size,
// group_size) - (2 * num_group * config_.intermediate_size * sizeof(float)); size_t scales_size =
// num_group * config_.intermediate_size * sizeof(float); size_t zeros_size = num_group *
// config_.intermediate_size / 2; // INT4 packed format
// // Allocate temporary buffer for zeros
// std::vector<uint8_t> zeros_buf(zeros_size);
// read_awq_weights(prefix, "up_proj", logical_expert_id,
// (char*)up_bb_[expert_idx]->b,
// (float*)up_bb_[expert_idx]->d,
// zeros_buf.data(),
// weights_size, scales_size, zeros_size,
// mat_split, 0);
// // Convert INT4 zeros to float mins
// if (quant_config.zero_point) {
// convert_zeros_to_mins_avx(zeros_buf.data(),
// (float*)up_bb_[expert_idx]->d,
// up_bb_[expert_idx]->mins,
// num_group * config_.intermediate_size);
// }
// } else { // down projection
// int num_group = config_.intermediate_size / group_size;
// size_t weights_size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,
// group_size) - (2 * num_group * config_.hidden_size * sizeof(float)); size_t scales_size = num_group
// * config_.hidden_size * sizeof(float); size_t zeros_size = num_group * config_.hidden_size / 2; //
// INT4 packed format
// // Allocate temporary buffer for zeros
// std::vector<uint8_t> zeros_buf(zeros_size);
// read_awq_weights(prefix, "down_proj", logical_expert_id,
// (char*)down_bb_[expert_idx]->b,
// (float*)down_bb_[expert_idx]->d,
// zeros_buf.data(),
// weights_size, scales_size, zeros_size,
// mat_split, 0);
// // Convert INT4 zeros to float mins
// if (quant_config.zero_point) {
// convert_zeros_to_mins_avx(zeros_buf.data(),
// (float*)down_bb_[expert_idx]->d,
// down_bb_[expert_idx]->mins,
// num_group * config_.hidden_size);
// }
// }
// },
// nullptr);
}
// check process, store down matrix to check
#ifdef CHECK
@@ -708,13 +540,12 @@ class AMX_AWQ_MOE_TP {
else if (config_.gate_scale != nullptr)
#endif
{
// Loading quantized weights
// Loading quantized weights
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[expert_idx]->from_raw_mat(
@@ -734,7 +565,7 @@ class AMX_AWQ_MOE_TP {
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[expert_idx]->from_raw_mat(
@@ -748,7 +579,7 @@ class AMX_AWQ_MOE_TP {
config_.expert_num, nullptr,
[this, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id;
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
size_t scale_elem_count =
(config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;
@@ -793,7 +624,7 @@ class AMX_AWQ_MOE_TP {
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[logical_expert_id]->from_mat(
@@ -812,7 +643,7 @@ class AMX_AWQ_MOE_TP {
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[logical_expert_id]->from_mat(
@@ -1280,16 +1111,15 @@ template <typename K>
class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
public:
using TP_MOE_Common<AMX_AWQ_MOE_TP<K>>::TP_MOE_Common;
void load_weights(const uint64_t* physical_to_logical_map) {
void load_weights() {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_projs.empty() == false) {
printf("TP Load from loader\n");
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
this->tps[numa_id]->load_weights(physical_to_logical_map);
});
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else if (config.gate_scale != nullptr) {
printf("From Packed Int4 with KGroup Scale and Zeros\n");
@@ -1314,7 +1144,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
size_t expert_id = expert_id_;
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
// weight TP-slicing
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
@@ -1384,9 +1214,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
}
}
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
this->tps[numa_id]->load_weights(physical_to_logical_map);
});
DO_TPS_LOAD_WEIGHTS(pool);
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
@@ -1417,7 +1245,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
size_t expert_id = physical_to_logical_map[expert_id_];
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,
(ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +
i * gate_up_elcount,
@@ -1438,9 +1266,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
}
}
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
this->tps[numa_id]->load_weights(physical_to_logical_map);
});
DO_TPS_LOAD_WEIGHTS(pool);
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
@@ -1452,9 +1278,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
this->weights_loaded = true;
} else if (config.path != "") {
printf("TP Load from file\n");
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
this->tps[numa_id]->load_weights(physical_to_logical_map);
});
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else {
throw std::runtime_error("no weight source");

View File

@@ -7,27 +7,15 @@
#include <tmmintrin.h>
#include <unistd.h>
#include <array>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llamafile/sgemm.h"
#include "utils.hpp"
// Include the split AMX headers
#include "amx_buffers.hpp"
#include "amx_config.hpp"
#include "amx_kernels.hpp"
#include "amx_quantization.hpp"
#include "amx_utils.hpp"
namespace amx {

View File

@@ -5,9 +5,7 @@
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>
#include "amx_config.hpp"
#include "amx_utils.hpp"
#include "llama.cpp/ggml-impl.h"
#include "pack.hpp"
@@ -48,16 +46,41 @@ struct BufferAImpl {
assert(ith == 0 && nth == 1);
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
float amax = 0.0f;
for (int j = 0; j < k; j += 32) {
__m512 f0, f1;
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j), &f0, &f1);
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));
__m512 amax_v0 = _mm512_setzero_ps();
__m512 amax_v1 = _mm512_setzero_ps();
__m512 amax_v2 = _mm512_setzero_ps();
__m512 amax_v3 = _mm512_setzero_ps();
__m512 amax_v4 = _mm512_setzero_ps();
__m512 amax_v5 = _mm512_setzero_ps();
__m512 amax_v6 = _mm512_setzero_ps();
__m512 amax_v7 = _mm512_setzero_ps();
for (int j = 0; j < k; j += 128) {
__m512 f0, f1, f2, f3, f4, f5, f6, f7;
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 0), &f0, &f1);
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 32), &f2, &f3);
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 64), &f4, &f5);
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 96), &f6, &f7);
amax_v0 = vector_abs_max(amax_v0, f0);
amax_v1 = vector_abs_max(amax_v1, f1);
amax_v2 = vector_abs_max(amax_v2, f2);
amax_v3 = vector_abs_max(amax_v3, f3);
amax_v4 = vector_abs_max(amax_v4, f4);
amax_v5 = vector_abs_max(amax_v5, f5);
amax_v6 = vector_abs_max(amax_v6, f6);
amax_v7 = vector_abs_max(amax_v7, f7);
}
amax_v0 = vector_abs_max(amax_v0, amax_v1);
amax_v2 = vector_abs_max(amax_v2, amax_v3);
amax_v4 = vector_abs_max(amax_v4, amax_v5);
amax_v6 = vector_abs_max(amax_v6, amax_v7);
amax_v0 = vector_abs_max(amax_v0, amax_v2);
amax_v4 = vector_abs_max(amax_v4, amax_v6);
amax_v0 = vector_abs_max(amax_v0, amax_v4);
float amax = _mm512_reduce_max_ps(amax_v0);
d[m_begin + i] = amax / ((1 << 7) - 1);
}
}
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {

View File

@@ -1180,9 +1180,9 @@ struct GemmKernel224Int4 {
}
static void load_a(dt* a, size_t lda) {
#ifdef HAVE_AMX
_tile_loadd(0, a, lda);
_tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);
#ifdef __AMX__
_tile_stream_loadd(0, a, lda);
_tile_stream_loadd(1, offset_pointer(a, lda * TILE_M), lda);
#else
(void)a;
(void)lda;

View File

@@ -1,5 +1,7 @@
#ifndef UTILS_HPP
#define UTILS_HPP
#include <immintrin.h>
#include <cstddef>
#include <cstdint>
@@ -18,4 +20,13 @@ static inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512*
_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(src) + 1)), 16)));
}
static inline __m512 vector_abs_max(__m512 a, __m512 b) {
__m512 a_abs = _mm512_abs_ps(a);
__m512 b_abs = _mm512_abs_ps(b);
__mmask16 mask = _mm512_cmp_ps_mask(a_abs, b_abs, _CMP_GT_OS);
return _mm512_mask_blend_ps(mask, b_abs, a_abs);
}
#endif // UTILS_HPP

View File

@@ -29,10 +29,7 @@
#include "../../cpu_backend/worker_pool.h"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
template <class T>
class AMX_MOE_TP {
@@ -264,16 +261,15 @@ class AMX_MOE_TP {
~AMX_MOE_TP() {
// shared_mem_buffer_numa.dealloc(this);
}
// pack and quant the weights
void pack_weights() {}
void load_weights(const uint64_t* physical_to_logical_map) {
void load_weights() {
auto pool = config_.pool->get_subpool(tp_part_idx);
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
if (config_.gate_projs.size()) {
pool->do_work_stealing_job(
config_.expert_num, nullptr,
[this, physical_to_logical_map](int expert_id) {
// printf("Load layer %d [%d/%d]\n", config_.layer_idx, expert_id, config_.expert_num);
uint64_t logical_expert_id = physical_to_logical_map[expert_id];
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id);
{
size_t scale_size = config_.intermediate_size * sizeof(float);
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size;
@@ -311,7 +307,7 @@ class AMX_MOE_TP {
std::cout << "Loading from " << prefix << std::endl;
for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {
int64_t expert_idx = task_id / (mat_type_all * mat_split);
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
uint8_t mat_class = (task_id % (mat_type_all * mat_split)) / mat_split;
uint8_t mat_split_idex = task_id % mat_split;
if (mat_class == 0) { // the up matrix
@@ -345,30 +341,32 @@ class AMX_MOE_TP {
}
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth](int task_id) {
[this, nth, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.gate_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith,
nth);
gate_bb_[logical_expert_id]->from_mat(
(ggml_bf16_t*)config_.gate_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
ith, nth);
// up part
up_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith,
nth);
up_bb_[logical_expert_id]->from_mat(
(ggml_bf16_t*)config_.up_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
ith, nth);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth](int task_id) {
[this, nth, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.down_proj + expert_idx * config_.hidden_size * config_.intermediate_size, ith,
nth);
down_bb_[logical_expert_id]->from_mat(
(ggml_bf16_t*)config_.down_proj + logical_expert_id * config_.hidden_size * config_.intermediate_size,
ith, nth);
// printf("load down, expert %ld, ith %d, total nth %d\n", expert_idx, ith, nth);
},
nullptr);
@@ -380,8 +378,9 @@ class AMX_MOE_TP {
if (config_.save) {
pool->do_work_stealing_job(
config_.expert_num * mat_type_all, nullptr,
[this](int task_id) {
[this, physical_to_logical_map](int task_id) {
int64_t expert_idx = task_id / mat_type_all;
expert_idx = expert_map(physical_to_logical_map, expert_idx);
uint8_t mat_class = task_id % mat_type_all;
if (mat_class == 0) { // the up matrix
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
@@ -829,16 +828,16 @@ template <typename K>
class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
public:
using TP_MOE_Common<AMX_MOE_TP<K>>::TP_MOE_Common;
void load_weights(const uint64_t* physical_to_logical_map) {
void load_weights() {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_projs.empty() == false) {
printf("TP Load from loader\n");
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
this->tps[numa_id]->load_weights(physical_to_logical_map);
});
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
this->weights_loaded = true;
} else if (config.gate_proj != nullptr) {
printf("From BF16\n");
@@ -852,7 +851,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
size_t expert_id = expert_id_;
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,
(ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +
i * gate_up_elcount,
@@ -873,9 +872,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
}
}
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
this->tps[numa_id]->load_weights(physical_to_logical_map);
});
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
@@ -887,9 +884,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
this->weights_loaded = true;
} else if (config.path != "") {
printf("TP Load from file\n");
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
this->tps[numa_id]->load_weights(physical_to_logical_map);
});
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
this->weights_loaded = true;
} else {
throw std::runtime_error("no weight source");

View File

@@ -1,9 +1,10 @@
#ifndef CPUINFER_OPERATOR_COMMON_HPP
#define CPUINFER_OPERATOR_COMMON_HPP
#include "../cpu_backend/shared_mem_buffer.h"
#include <map>
#include "../cpu_backend/worker_pool.h"
#include "llama.cpp/ggml.h"
#include "ggml.h"
#if defined(__aarch64__) && defined(CPU_USE_KML)
#include <arm_sve.h>
@@ -13,8 +14,6 @@
#include <cmath>
#include <cstdio>
#include <cstring>
#include <limits>
#include <memory>
#include <stdexcept>
#include <type_traits>
@@ -40,6 +39,14 @@
last = end_time; \
} while (0)
#define DO_TPS_LOAD_WEIGHTS(pool) \
(pool)->dispense_backend()->do_numa_job([this, pool, config](int numa_id) { \
this->tps[numa_id]->config_.physical_to_logical_map = config.physical_to_logical_map; \
this->tps[numa_id]->load_weights(); \
})
#define expert_map(m, x) (m != nullptr ? m[(x)] : (x))
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) {
return (x + y - 1) / y;
@@ -274,12 +281,11 @@ struct GeneralMOEConfig {
GeneralMOEConfig() {}
GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int num_gpu_experts)
GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size)
: expert_num(expert_num),
num_experts_per_tok(routed_expert_num),
hidden_size(hidden_size),
intermediate_size(intermediate_size),
num_gpu_experts(num_gpu_experts) {}
intermediate_size(intermediate_size) {}
int max_possible_qlen() { return std::max(max_len, group_max_len); }
};

View File

@@ -1,43 +0,0 @@
#include "batch_gemm_api.hpp"
#ifdef __cplusplus
extern "C" {
#endif
void decode_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k,
const float alpha, const void* a, const size_t lda, const BLASINT8 oa, const void* b,
const size_t ldb, const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
const int32_t* oc) {
BLASINT8* ptrA = (BLASINT8*)a;
BLASINT8* ptrB = (BLASINT8*)b;
int32_t* ptrC = c;
size_t split_n = n / N_SIZE;
for (size_t n_block = 0; n_block < split_n; n_block++) {
BLASINT8* cur_ptrA = ptrA;
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * k);
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
gemm_kernel_1x8(cur_ptrA, cur_ptrB, cur_ptrC, ldc, k, COMP_SV_LEN);
}
}
void decode_int4_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa,
const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc, const size_t m,
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc) {
BLASINT8* ptrA = (BLASINT8*)a;
BLASINT8* ptrB = (BLASINT8*)b;
int32_t* ptrC = c;
size_t split_n = n / N_SIZE;
for (size_t n_block = 0; n_block < split_n; n_block++) {
BLASINT8* cur_ptrA = ptrA;
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * (k / 2));
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
gemm_kernel_1x8_int4(cur_ptrA, cur_ptrB, cur_ptrC, (ldc / 2), (k / 2), COMP_SV_LEN);
}
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,35 +0,0 @@
#ifndef _BATCH_GEMM_KERNEL_API_
#define _BATCH_GEMM_KERNEL_API_
#include "utils.hpp"
#ifdef __cplusplus
extern "C" {
#endif
// s8*s8 -> s32 GEMM entry points built on the 1x8 SVE micro-kernels.
// All four share a cblas_gemm_s8s8s32-style signature; the decode_* variants
// are the single-row (m == 1) fast paths, the prefill_* variants handle m > 1.
// The *_int4_* variants expect B packed as two 4-bit values per byte.
void decode_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                               const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k,
                               const float alpha, const void* a, const size_t lda, const BLASINT8 oa, const void* b,
                               const size_t ldb, const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
                               const int32_t* oc);
void prefill_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
                                const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k,
                                const float alpha, const void* a, const size_t lda, const BLASINT8 oa, const void* b,
                                const size_t ldb, const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
                                const int32_t* oc);
void decode_int4_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa,
                                    const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc, const size_t m,
                                    const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                                    const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
                                    const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void prefill_int4_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa,
                                     const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc, const size_t m,
                                     const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                                     const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
                                     const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
#ifdef __cplusplus
}
#endif
#endif /*** _BATCH_GEMM_KERNEL_API_ ***/

View File

@@ -1,104 +0,0 @@
#pragma once
// #include <arm_sve.h>
#include <cstdint>
#include <cstring>
// static inline void sve_32xbf16_to_32xfp32(const bfloat16_t *src, float *dst0, float *dst1) {
// #ifdef __ARM_FEATURE_SVE
// // 全真谓词,对应每个 16bit 元素
// svbool_t pg_h = svptrue_b16();
// // 每次循环处理 svcnth(pg_h) 个 BF16 元素svcnth(pg_h)==VL/16
// size_t offset = 0;
// // 我们要生产两段 FP32 输出,每段长度 svcntw(pg_h)==VL/32
// // SVE 向量寄存宽度 VL 可以任意 1282048但代码与之无关
// // Load first half BF16→FP32
// svbfloat16_t vb0 = svld1(pg_h, &src[offset]); // load BF16
// svfloat32_t vf0 = svcvt_f32_bf16_z(pg_h, vb0); // widen→FP32
// svst1(pg_h, &dst0[offset/2], vf0); // store
// offset += svcnth(pg_h); // 移到第二批 BF16 元素
// // Load second half BF16→FP32
// svbfloat16_t vb1 = svld1(pg_h, &src[offset]);
// svfloat32_t vf1 = svcvt_f32_bf16_z(pg_h, vb1);
// svst1(pg_h, &dst1[offset/2], vf1);
// #else
// // fallback: scalar or NEON
// #endif
// }
// Truncating float -> bfloat16 conversion: keep the top 16 bits of the IEEE-754
// single, drop the low 16 (no rounding).
static inline uint16_t float_to_bf16_trunc(float f) {
  uint32_t bits;
  memcpy(&bits, &f, sizeof(bits));  // bitwise copy avoids strict-aliasing UB
  return (uint16_t)(bits >> 16);    // high half is the bf16 encoding
}
// Convert exactly 32 fp32 values to bf16 by truncation (float_to_bf16_trunc).
// `src` is expected to already point at the row to convert.
static inline void convert_32fp32_to_32bf16_pure_c(const float* src, uint16_t* dst) {
  for (int i = 0; i < 32; ++i) {
    dst[i] = float_to_bf16_trunc(src[i]);
  }
}
// Widen exactly 32 bf16 values to fp32: each 16-bit value becomes the high
// half of a 32-bit word, the low half is zero.
static inline void convert_32bf16_to_32fp32_pure_c(const uint16_t* src, float* dst) {
  for (int i = 0; i < 32; ++i) {
    uint32_t widened = ((uint32_t)src[i]) << 16;
    memcpy(dst + i, &widened, sizeof(float));  // bit-exact reinterpretation
  }
}
// template <typename T> T *offset_pointer(T *ptr, std::size_t byte_offset) {
//   return reinterpret_cast<T *>(reinterpret_cast<char *>(ptr) + byte_offset);
// }
/*** gemm helper ***/
#include <kblas.h>
#include <iostream>
#ifdef __cplusplus
extern "C" {
#endif
// Micro-kernel tile geometry: 1 row of A times 8 columns of B, consuming K in
// 32-element int8 chunks (one SVE compute vector's worth).
#define COMP_SV_LEN 32
#define K_SIZE COMP_SV_LEN
#define M_SIZE 1
#define N_SIZE 8
// Column-major index into B: element (row_idx, col_idx) with leading dim ldb.
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)
// Inline-asm fragment: horizontally reduce accumulator vector z<z_reg_idx>
// (SADDV), add the reduction to running register x<reg_idx>, then store one
// int32 result through the dst operand with a 4-byte post-increment.
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
  "mov w" #reg_idx                                         \
  ", #0\n"                                                 \
  "saddv d" #reg_idx ", " #p ", z" #z_reg_idx              \
  ".s\n"                                                   \
  "fmov " #tmp_reg ", d" #reg_idx                          \
  "\n"                                                     \
  "add x" #reg_idx ", x" #reg_idx ", " #tmp_reg            \
  "\n"                                                     \
  "str w" #reg_idx ", [%[" #dst "]], #4\n"
// Inline-asm fragment for int4 unpacking: copy src vector into dst (MOVPRFX),
// left-shift dst's bytes by `shift`, and mask src's bytes with mask_reg1.
#define INT4_CP_MASK_SHIFT_1x8(src_reg, dst_reg, mask_reg1, mask_reg2, shift) \
  "movprfx z" #dst_reg ", z" #src_reg                                         \
  "\n"                                                                        \
  "lsl z" #dst_reg ".b, p0/m, z" #dst_reg ".b, #" #shift                      \
  "\n"                                                                        \
  "and z" #src_reg ".b, p0/m, z" #src_reg ".b, z" #mask_reg1 ".b\n"
// Repack B into the layout expected by the 1x8 kernels (int8 / packed int4).
void pack_b_1x8(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);
void pack_b_1x8_int4(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);
// 1x8 int8 (and packed-int4) micro-kernels: accumulate one row of A against
// eight columns of B over k_depth and write int32 results with stride ldc.
void gemm_kernel_1x8(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                     int64_t sv_len);
void gemm_kernel_1x8_int4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
                          int64_t sv_len);
#ifdef __cplusplus
}
#endif
/*** gemm helper ***/

View File

@@ -1,430 +0,0 @@
#ifndef KML_DEEPSEEKV3_HPP
#define KML_DEEPSEEKV3_HPP
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include "gate.hpp"
#include "kblas.h"
#include "la/arm_kml.hpp"
#include "llama.cpp/ggml.h"
#include "mla.hpp"
#include "moe.hpp"
// #define DEBUG_LAYER_CORRECT
// One DeepseekV3 decoder layer: RMSNorm -> MLA self-attention -> residual add
// -> RMSNorm -> MoE gate + expert FFN -> residual add.
// NOTE(review): forward() normalizes the rows behind `input` in place (the
// const void* is cast away), so callers must not rely on `input` surviving.
class DeepseekV3DecoderLayer
#ifdef FORWARD_TIME_PROFILE
    : protected TimePerf
#endif
{
 private:
  using A = float;
  GeneralConfig config;
  A* attn_norm;                     // RMSNorm weights applied before attention (fp32)
  A* ffn_norm;                      // RMSNorm weights applied before the FFN (fp32)
  A* hidden_states;                 // residual buffer, [max_qlen, hidden_size]
  ggml_bf16_t* hidden_states_bf16;  // bf16 copy of the FFN input
  int64_t* expert_ids;              // [max_qlen, routed + shared] chosen expert ids
  A* experts_weights;               // [max_qlen, routed + shared] per-expert weights
  size_t layer_idx;

 public:
  std::shared_ptr<MLA_Interface> self_attn = nullptr;
  std::shared_ptr<MoEGate> gate = nullptr;  // may stay null: then all experts run with weight 1
  std::shared_ptr<MoE_Interface> ffn = nullptr;
  using input_t = float;
  using output_t = float;
  // Registers all scratch buffers with the shared decoder-layer memory pool.
  DeepseekV3DecoderLayer(GeneralConfig config, size_t layer_idx) : config(config), layer_idx(layer_idx) {
    init_ggml();
    MemoryRequest mem_requests;
    PUSH_MEM_REQ(hidden_states, sizeof(A) * config.max_qlen * config.hidden_size);
    PUSH_MEM_REQ(hidden_states_bf16, sizeof(ggml_bf16_t) * config.max_qlen * config.hidden_size);
    PUSH_MEM_REQ(expert_ids,
                 sizeof(int64_t) * config.max_qlen * (config.num_experts_per_tok + config.n_shared_experts));
    PUSH_MEM_REQ(experts_weights, sizeof(A) * config.max_qlen * (config.num_experts_per_tok + config.n_shared_experts));
    shared_mem_buffer_for_decoder_layer.alloc(this, mem_requests);
  }
  // Pointer-based wrapper for language bindings.
  void load_norm_binding(intptr_t attn_norm_ptr, ggml_type attn_norm_type, intptr_t ffn_norm_ptr,
                         ggml_type mlp_norm_type) {
    load_norm((void*)attn_norm_ptr, attn_norm_type, (void*)ffn_norm_ptr, mlp_norm_type);
  }
  // Converts both norm weight tensors (arbitrary ggml type) into fp32 copies.
  void load_norm(const void* attn_norm, ggml_type attn_norm_type, const void* ffn_norm, ggml_type ffn_norm_type) {
    this->attn_norm = new A[config.hidden_size];
    this->ffn_norm = new A[config.hidden_size];
    convert_or_copy(this->attn_norm, (void*)attn_norm, attn_norm_type, config.hidden_size);
    convert_or_copy(this->ffn_norm, (void*)ffn_norm, ffn_norm_type, config.hidden_size);
  }
  // qlens: per-request query lengths (their sum is the token count);
  // page_tables/kv_lens describe the paged KV cache per request — passed
  // through to self_attn.  input/output hold sum(qlens) x hidden_size fp32 rows.
  void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
               const void* input, void* output) {
#ifdef FORWARD_TIME_PROFILE
    forward_perf_start();
#endif
    size_t seq_len = 0;
    for (size_t i = 0; i < qlens.size(); i++) {
      seq_len += qlens[i];
    }
    // for (size_t i = 0; i < 5; i++) {
    //   debug_f32((input_t *)input + (seq_len - 5 + i) * config.hidden_size, config.hidden_size);
    // }
    // printf("\n");
#ifdef DEBUG_LAYER_CORRECT
    std::string prefix = "Layer_" + std::to_string(layer_idx);
    dump_bin(prefix + "_input", (input_t*)input, seq_len * config.hidden_size);
#endif
    // Residue: stash a copy of the raw input before it is normalized in place.
    // printf("convert or copy hidden states, %ld,%ld\n", seq_len, config.hidden_size);
    convert_or_copy(hidden_states, (input_t*)input, seq_len * config.hidden_size);
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("copy residue");
#endif
    // Norm: pre-attention RMSNorm, applied row by row, in place on `input`.
    config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
      A* input_row = (A*)input + task_id * config.hidden_size;
      RMSNorm<A>::rms_norm_single_with_weights(config.hidden_size, attn_norm, input_row);
    });
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("before attn norm");
#endif
    // self attention (writes its result into `output`)
    self_attn->forward(qlens, page_tables, kv_lens, input, output);
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("self attn");
#endif
#ifdef DEBUG_LAYER_CORRECT
    dump_bin(prefix + "_after_attn", (input_t*)output, seq_len * config.hidden_size);
#endif
    // Add Residue: residual += attention output; also refresh `input` with the
    // post-attention hidden state for the next norm.
    config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
      A* hidden_state_row = hidden_states + task_id * config.hidden_size;
      A* output_row = (A*)output + task_id * config.hidden_size;
      A* input_row = (A*)input + task_id * config.hidden_size;
      for (size_t i = 0; i < config.hidden_size; i++) {
        hidden_state_row[i] += output_row[i];
        input_row[i] = hidden_state_row[i];
      }
    });
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("after attn add residue");
#endif
    // Norm: pre-FFN RMSNorm, again in place on `input`.
    config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
      A* input_row = (A*)input + task_id * config.hidden_size;
      RMSNorm<A>::rms_norm_single_with_weights(config.hidden_size, ffn_norm, input_row);
    });
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("after attn norm");
#endif
    // Moe: each token uses its routed experts plus the shared experts, which
    // are appended after the routed slots with weight 1.
    // size_t experts_ld = config.num_experts_per_tok;
    size_t experts_ld = config.num_experts_per_tok + config.n_shared_experts;
    if (gate != nullptr) {
      gate->forward(seq_len, (input_t*)input, expert_ids, experts_weights, experts_ld);
      for (size_t i = 0; i < seq_len; i++) {
        for (size_t j = 0; j < config.n_shared_experts; j++) {
          expert_ids[i * (experts_ld) + config.num_experts_per_tok + j] = config.n_routed_experts + j;
          experts_weights[i * (experts_ld) + config.num_experts_per_tok + j] = 1.0f;
        }
      }
    } else {
      // No gate configured: run the first (routed + shared) experts uniformly.
      for (size_t i = 0; i < seq_len; i++) {
        for (size_t j = 0; j < config.num_experts_per_tok + config.n_shared_experts; j++) {
          expert_ids[i * (experts_ld) + j] = j;
          experts_weights[i * (experts_ld) + j] = 1.0f;
        }
      }
    }
    // Debug: print the chosen experts
    // printf("chosen experts for layer %ld:\n", layer_idx);
    // for (int i = 0; i < seq_len; i++) {
    //   for (int j = 0; j < experts_ld; j++) {
    //     printf("%ld ", expert_ids[i * experts_ld + j]);
    //   }
    //   printf("\n");
    // }
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("moe gate");
    // // Debug: reset the chosen experts to fixed consecutive experts
    // for (int i = 0; i < seq_len; i++) {
    //   for (int j = 0; j < experts_ld; j++) {
    //     expert_ids[i * experts_ld + j] = j;
    //   }
    // }
    // Debug: print the chosen experts
    // printf("chosen experts for layer %ld:\n", layer_idx);
    // for (int i = 0; i < seq_len; i++) {
    //   for (int j = 0; j < experts_ld; j++) {
    //     printf("%ld ", expert_ids[i * experts_ld + j]);
    //   }
    //   printf("\n");
    // }
#endif
#ifdef DEBUG_LAYER_CORRECT
    dump_bin(prefix + "_expert_ids", expert_ids, seq_len * (config.num_experts_per_tok + config.n_shared_experts));
    dump_bin(prefix + "_expert_weights", experts_weights,
             seq_len * (config.num_experts_per_tok + config.n_shared_experts));
#endif
    // The expert FFN consumes bf16 activations.
    convert_or_copy(hidden_states_bf16, (input_t*)input, seq_len * config.hidden_size);
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("convert bf16 time");
#endif
    ffn->forward(seq_len, experts_ld, expert_ids, experts_weights, hidden_states_bf16, output);
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("ffn");
#endif
    // Add Residue: output += post-attention residual.
    config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
      A* hidden_state_row = hidden_states + task_id * config.hidden_size;
      A* output_row = (A*)output + task_id * config.hidden_size;
      for (size_t i = 0; i < config.hidden_size; i++) {
        output_row[i] += hidden_state_row[i];
      }
    });
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("add ffn residue");
#endif
#ifdef DEBUG_LAYER_CORRECT
    dump_bin(prefix + "_after_mlp", (input_t*)output, seq_len * config.hidden_size);
#endif
#ifdef FORWARD_TIME_PROFILE
    time_perf_name = "DeepseekV3DecoderLayer" + std::to_string(layer_idx);
    perf_report();
#endif
  }
};
// Stack of decoder layers plus the model-final RMSNorm.  `input` and `output`
// act as ping/pong buffers: after every layer except the last, the layer's
// output is copied back into `input`; after the last layer, `output` is
// normalized in place.
class DeepseekV3Model
#ifdef FORWARD_TIME_PROFILE
    : protected TimePerf
#endif
{
 private:
  using A = float;
  GeneralConfig config;
  A* norm_weights;  // final RMSNorm weights, converted to fp32 once at construction

 public:
  using input_t = float;
  using output_t = float;
  std::vector<std::shared_ptr<DeepseekV3DecoderLayer>> layers;
  DeepseekV3Model(GeneralConfig config) : config(config) {
    init_ggml();
    norm_weights = new A[config.hidden_size];
    convert_or_copy(norm_weights, config.norm_weights_ptr, config.norm_weights_type, config.hidden_size);
  }
  // qlens/page_tables/kv_lens are forwarded to every layer; input and output
  // each hold sum(qlens) x hidden_size fp32 rows.  Both buffers are mutated.
  void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
               const void* input, void* output) {
#ifdef FORWARD_TIME_PROFILE
    forward_perf_start();
#endif
    size_t seq_len = 0;
    for (size_t i = 0; i < qlens.size(); i++) {
      seq_len += qlens[i];
    }
    for (size_t i = 0; i < layers.size(); i++) {
      layers[i]->forward(qlens, page_tables, kv_lens, input, output);
#ifdef FORWARD_TIME_PROFILE
      PROFILE_RECORD_TIME_STAMP(std::string("layer ") + std::to_string(i));
#endif
      if (i != layers.size() - 1) {
        // Feed this layer's output back in as the next layer's input.
        convert_or_copy((A*)input, (A*)output, seq_len * config.hidden_size);
      } else {
        // Last layer: apply the final RMSNorm row by row on the output.
        config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
          A* output_row = (A*)output + task_id * config.hidden_size;
          RMSNorm<A>::rms_norm_single_with_weights(config.hidden_size, norm_weights, output_row);
        });
      }
    }
#ifdef FORWARD_TIME_PROFILE
    time_perf_name = "DeepseekV3Model";
    perf_report();
#endif
  }
};
// Top-level causal LM: token embedding lookup -> DeepseekV3Model -> quantized
// int4 lm_head GEMM producing logits for the LAST token of each request only.
class DeepseekV3ForCausalLM
#ifdef FORWARD_TIME_PROFILE
    : protected TimePerf
#endif
{
 private:
  using GemmKernel = arm_kml::GemmKernelInt4;
  using KMatRefA = typename arm_kml::MatRef<int8_t>;
  using KMatRefB = typename arm_kml::MatRef<GemmKernel::dt>;
  // using KMatRefAB = typename arm_kml::MatRef<int8_t>;
  using KMatRefC = typename arm_kml::MatRef<int32_t>;
  using A = float;
  GeneralConfig config;
  A* input_hidden_states;        // fp32 embeddings fed to the model
  A* output_hidden_states;       // fp32 hidden states returned by the model
  A* lm_heads_ptr;               // fp32 copy of the lm_head weights
  GemmKernel::BufferA* lm_heads_ba;  // quantized activations for the lm_head GEMM
  GemmKernel::BufferB* lm_heads_bb;  // quantized lm_head weights
  GemmKernel::BufferC* lm_heads_bc;  // int32 GEMM accumulator
  A* token_embd;                 // fp32 embedding table [vocab_size, hidden_size]

 public:
  using KMatRef = typename arm_kml::MatRef<float>;
  using input_t = int64_t;   // forward() consumes token ids
  using output_t = float;    // and produces fp32 logits
  std::shared_ptr<DeepseekV3Model> model;
  KMatRefB lm_heads;
  // Converts embedding/lm_head weights to fp32, then quantizes the lm_head
  // into the int4 GemmKernel layout.
  DeepseekV3ForCausalLM(GeneralConfig config) : config(config) {
    init_ggml();
    MemoryRequest mem_requests;
    lm_heads_ba = new GemmKernel::BufferA(config.max_qlen, config.hidden_size);
    lm_heads_bb = new GemmKernel::BufferB(config.vocab_size, config.hidden_size, true);
    lm_heads_bc = new GemmKernel::BufferC(config.max_qlen, config.vocab_size);
    mem_requests.append_function([this](void* new_ptr) { lm_heads_ba->set_data(new_ptr); },
                                 lm_heads_ba->required_size());
    lm_heads_bb->set_data(std::aligned_alloc(64, lm_heads_bb->required_size()));
    mem_requests.append_function([this](void* new_ptr) { lm_heads_bc->set_data(new_ptr); },
                                 lm_heads_bc->required_size());
    shared_mem_buffer.alloc(this, mem_requests);
    input_hidden_states = new A[config.max_qlen * config.hidden_size];
    output_hidden_states = new A[config.max_qlen * config.hidden_size];
    lm_heads_ptr = new A[config.vocab_size * config.hidden_size];
    token_embd = new A[config.vocab_size * config.hidden_size];
    convert_or_copy(lm_heads_ptr, config.lm_heads_ptr, config.lm_heads_type, config.vocab_size * config.hidden_size);
    // Quantize the lm_head weights into BufferB, split across the pool.
    auto pool = config.pool;
    {
      size_t nth_lm_b = GemmKernel::recommended_nth(config.vocab_size);
      auto task = [&](int task_id) { lm_heads_bb->from_mat(lm_heads_ptr, task_id, nth_lm_b, -1, true); };
      pool->do_work_stealing_job(nth_lm_b, task);
    }
    lm_heads = KMatRefB(lm_heads_bb->b, config.hidden_size, config.vocab_size, config.hidden_size, CblasColMajor,
                        CblasNoTrans, lm_heads_bb->d);
    // lm_heads = KMatRef(lm_heads_ptr, config.vocab_size, config.hidden_size, config.hidden_size, CblasRowMajor);
    convert_or_copy(token_embd, config.token_embd_ptr, config.token_embd_type, config.vocab_size * config.hidden_size);
  }
  // Pointer-based wrapper for language bindings.
  void forward_binding(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
                       intptr_t input, intptr_t output) {
    forward(qlens, page_tables, kv_lens, (const void*)input, (void*)output);
  }
  // input: int64 token ids, sum(qlens) of them; output: fp32 logits — one
  // vocab_size row per request, computed from the last token of each request.
  void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
               const void* input, void* output) {
#ifdef FORWARD_TIME_PROFILE
    forward_perf_start();
#endif
    {
      size_t qlen_sum = 0;
      for (size_t i = 0; i < qlens.size(); i++) {
        qlen_sum += qlens[i];
      }
      // printf("DeepseekV3 forward, seq_len %ld\n", qlen_sum);
      // Embedding lookup: copy one table row per input token id.
      for (size_t i = 0; i < qlen_sum; i++) {
        convert_or_copy(input_hidden_states + i * config.hidden_size,
                        token_embd + ((input_t*)input)[i] * config.hidden_size, config.hidden_size);
      }
#ifdef FORWARD_TIME_PROFILE
      PROFILE_RECORD_TIME_STAMP("token embd");
#endif
#ifdef DEBUG_LAYER_CORRECT
      dump_bin("input_ids", (input_t*)input, qlen_sum);
#endif
    }
    model->forward(qlens, page_tables, kv_lens, input_hidden_states, output_hidden_states);
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("model forward");
#endif
    // KMatRef hidden_states_ref =
    //     KMatRef(output_hidden_states, config.hidden_size, config.max_qlen, config.hidden_size, CblasColMajor);
    size_t qlen = 0;
    for (size_t i = 0; i < qlens.size(); i++) {
      qlen += qlens[i];
    }
    // printf("qlen is: %ld\n", qlen);
    // KMatRef logits_ref = KMatRef((A *)output, config.vocab_size, config.max_qlen, config.vocab_size, CblasColMajor);
    KMatRefC logits_ref =
        KMatRefC(lm_heads_bc->c, config.max_qlen, config.vocab_size, config.vocab_size, CblasRowMajor);
    KMatRef logits_out_ref = KMatRef((A*)output, config.max_qlen, config.vocab_size, config.vocab_size, CblasRowMajor);
    // Quantize the input hidden states into BufferA.
    auto pool = config.pool;
    {
      size_t mth = GemmKernel::recommended_mth(qlen);
      auto task_counter = TaskCounter({mth});
      auto task = [&](int task_id) {
        size_t mth_idx = task_counter.at(task_id, 0);
        lm_heads_ba->from_mat(qlen, output_hidden_states, mth_idx, mth);
      };
      DIRECT_OR_POOL_BY(qlen, 10, task_counter.count(), task);
    }
    KMatRefA hidden_states_ref = KMatRefA(lm_heads_ba->a, qlen, config.hidden_size, config.hidden_size, CblasRowMajor);
    // For each request i, compute logits only for its last token
    // (row qlen_sum + qlens[i] - 1), tiled over the vocab in 256-wide blocks.
    size_t qlen_sum = 0;
    for (size_t i = 0; i < qlens.size(); i++) {
      // auto h = hidden_states_ref.offset_block(0, qlen_sum + qlens[i] - 1, config.hidden_size, 1);
      auto h = hidden_states_ref.offset_block(qlen_sum + qlens[i] - 1, 0, 1, config.hidden_size);
      {
        const size_t vocab_block = 256;
        const size_t vocab_block_count = div_up(config.vocab_size, vocab_block);
        config.pool->do_work_stealing_job(vocab_block_count, [&](int task_id) {
          size_t vocab_idx = task_id * vocab_block;
          size_t vocab_begin = vocab_idx;
          size_t vocab_end = std::min(vocab_begin + vocab_block, (size_t)config.vocab_size);
          KMatRefB lm_head_ref = lm_heads.offset_col(vocab_begin, vocab_end - vocab_begin);
          // KMatRef logits_ref_block = logits_ref.offset_block(vocab_begin, i, vocab_end - vocab_begin, 1);
          KMatRefC logits_ref_block = logits_ref.offset_block(i, vocab_begin, 1, vocab_end - vocab_begin);
          // arm_kml::decode_mul_mat_clearc(lm_head_ref, h, logits_ref_block);
          // printf("h.ld: %ld, lm_head_ref.ld: %ld, logits_ref_block.ld: %ld\n", h.ld, lm_head_ref.ld,
          //        logits_ref_block.ld);
          arm_kml::decode_mul_mat_clearc(h, lm_head_ref, logits_ref_block);
          // Dequantize: apply A/B scales and write fp32 logits into `output`.
          GemmKernel::apply_scale(logits_out_ref.data, logits_out_ref.ld, lm_heads_ba, lm_heads_bb, lm_heads_bc,
                                  qlen_sum + qlens[i] - 1, qlen_sum + qlens[i], vocab_begin, vocab_end, true,
                                  i - (qlen_sum + qlens[i] - 1));
        });
      }
      qlen_sum += qlens[i];
    }
#ifdef DEBUG_LAYER_CORRECT
    dump_bin("output_logits", (output_t*)output, qlens.size() * config.vocab_size);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("lm heads out");
#endif
#ifdef FORWARD_TIME_PROFILE
    time_perf_name = "DeepseekV3ForCausalLM";
    perf_report();
#endif
  }
};
#endif

View File

@@ -1,255 +0,0 @@
#ifndef KML_GATE_HPP
#define KML_GATE_HPP
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>
#include <vector>
#include "../common.hpp"
#include "kblas.h"
#include "la/arm_kml.hpp"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
// #define DEBUG_THIS_MOEGATE
#ifdef DEBUG_THIS_MOEGATE
#include "test/debug.hpp"
#endif
// DeepseekV3 MoE router ("noaux_tc"): sigmoid-scored logits, bias-corrected
// group-limited top-k selection, then weight normalization and scaling.
class MoEGate
#ifdef FORWARD_TIME_PROFILE
    : protected TimePerf
#endif
{
 private:
  using A = float;
  GeneralGateConfig config;
  using KMatRef = typename arm_kml::MatRef<A>;
  A* bias;             // e_score_correction_bias, fp32, [n_routed_experts]
  A* weight;           // router weights, fp32
  KMatRef weight_ref;  // [expert_num, hidden_size]
  A* logits;           // [tokens, expert_num] — sigmoid scores after activation
  A* score_to_choice;  // [tokens, expert_num] — bias-corrected selection scores
  A* group_score;      // [tokens, group count]
  size_t* temp_idx;    // scratch index array used for both group and expert sorts
  const size_t col_block = 256;
  const size_t row_block = 256;

 public:
  using input_t = float;
  using output_int_t = int64_t;
  using output_t = float;
  // Validates the supported configuration and converts weights/bias to fp32.
  explicit MoEGate(const GeneralGateConfig& cfg) : config(cfg) {
    ASSERT_RELEASE(config.weight, "cfg.weight must not be null");
    ASSERT_RELEASE(config.n_routed_experts % config.n_group == 0, "E must be divisible by G");
    ASSERT_RELEASE(config.scoring_func == "sigmoid", "Only sigmoid scoring function is supported");
    ASSERT_RELEASE(config.topk_method == "noaux_tc", "Only noaux_tc topk method is supported");
    ASSERT_RELEASE(config.norm_topk_prob, "must normalize topk prob");
    MemoryRequest mem_requests;
    // PUSH_MEM_REQ(input, sizeof(float) * config.max_seqlen * config.hidden_size);
    PUSH_MEM_REQ(logits, sizeof(float) * config.max_seqlen * config.n_routed_experts);
    PUSH_MEM_REQ(score_to_choice, sizeof(float) * config.max_seqlen * config.n_routed_experts);
    PUSH_MEM_REQ(group_score, sizeof(float) * config.max_seqlen * config.n_group);
    PUSH_MEM_REQ(temp_idx, sizeof(size_t) * config.max_seqlen * config.n_routed_experts);
    shared_mem_buffer.alloc(this, mem_requests);
    weight = new A[config.n_routed_experts * config.hidden_size];
    bias = new A[config.n_routed_experts];
    convert_or_copy(weight, config.weight, config.weight_type, config.n_routed_experts * config.hidden_size);
    convert_or_copy(bias, config.e_score_correction_bias, config.e_score_correction_bias_type, config.n_routed_experts);
    weight_ref = KMatRef(weight, config.n_routed_experts, config.hidden_size, config.hidden_size, CblasRowMajor);
  }
  // Pointer-based wrapper for language bindings.
  void forward_binding(size_t seq_len, intptr_t input_hidden_states_raw, intptr_t output_topk_idx_raw,
                       intptr_t output_topk_weight_raw) {
    forward(seq_len, (input_t*)input_hidden_states_raw, (output_int_t*)output_topk_idx_raw,
            (output_t*)output_topk_weight_raw);
  }
  // Convenience overload: output rows are exactly num_experts_per_tok wide.
  void forward(size_t seq_len, input_t* input_hidden_states, output_int_t* output_topk_idx,
               output_t* output_topk_weight_raw) {
    forward(seq_len, input_hidden_states, output_topk_idx, output_topk_weight_raw, config.num_experts_per_tok);
  }
  // forward: hidden_states [B,L,H] → (topk_idx, topk_weight) each [B·L,K]
  // hidden_states must be contiguous row‑major with H fastest.
  // outputs are flattened (token first) and resized inside.
  // output_ld is the row stride of the output arrays (callers may reserve
  // extra slots per token, e.g. for shared experts).
  void forward(size_t seq_len, input_t* input_hidden_states, output_int_t* output_topk_idx,
               output_t* output_topk_weight, size_t output_ld) {
#ifdef FORWARD_TIME_PROFILE
    forward_perf_start();
#endif
    KMatRef input_ref =
        KMatRef((input_t*)input_hidden_states, config.hidden_size, seq_len, config.hidden_size, CblasColMajor);
#ifdef DEBUG_THIS_MOEGATE
    dump_bin("gate_input", input_ref.data, seq_len * config.hidden_size);
#endif
    KMatRef logits_ref = KMatRef(logits, config.n_routed_experts, seq_len, config.n_routed_experts, CblasColMajor);
    // Router GEMM: logits = weight * hidden_states, tiled over experts and
    // tokens; a narrower expert block is used on the single-token decode path.
    {
      size_t n_routed_experts_block = (seq_len == 1) ? 2 : 64;
      const size_t seq_len_block = 64;
      const size_t n_routed_experts_block_count = div_up(config.n_routed_experts, n_routed_experts_block);
      const size_t seq_len_block_count = div_up(seq_len, seq_len_block);
      auto task_counter = TaskCounter({n_routed_experts_block_count, seq_len_block_count});
      auto task = [&](int task_id) {
        size_t n_routed_experts_block_idx = task_counter.at(task_id, 0);
        size_t seq_len_block_idx = task_counter.at(task_id, 1);
        size_t n_routed_experts_begin = n_routed_experts_block_idx * n_routed_experts_block;
        size_t n_routed_experts_end =
            std::min(n_routed_experts_begin + n_routed_experts_block, config.n_routed_experts);
        size_t seq_len_begin = seq_len_block_idx * seq_len_block;
        size_t seq_len_end = std::min(seq_len_begin + seq_len_block, seq_len);
        if (seq_len == 1) {
          arm_kml::decode_mul_mat_clearc(
              weight_ref.offset_block(n_routed_experts_begin, 0, n_routed_experts_end - n_routed_experts_begin,
                                      config.hidden_size),
              input_ref.offset_block(0, seq_len_begin, config.hidden_size, seq_len_end - seq_len_begin),
              logits_ref.offset_block(n_routed_experts_begin, seq_len_begin,
                                      n_routed_experts_end - n_routed_experts_begin, seq_len_end - seq_len_begin));
        } else {
          arm_kml::mul_mat_clearc(
              weight_ref.offset_block(n_routed_experts_begin, 0, n_routed_experts_end - n_routed_experts_begin,
                                      config.hidden_size),
              input_ref.offset_block(0, seq_len_begin, config.hidden_size, seq_len_end - seq_len_begin),
              logits_ref.offset_block(n_routed_experts_begin, seq_len_begin,
                                      n_routed_experts_end - n_routed_experts_begin, seq_len_end - seq_len_begin));
        }
      };
      config.pool->do_work_stealing_job(task_counter.count(), task);
    }
    // arm_kml::mul_mat_clearc(weight_ref, input_ref, logits_ref);
    // if (seq_len == 1) {
    //   for (int i = 0; i < config.n_routed_experts; i++) {
    //     printf("%f ", logits[i]);
    //   }
    //   printf("\n");
    //   throw std::runtime_error("end");
    // }
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("moe gate logits");
#endif
#ifdef DEBUG_THIS_MOEGATE
    dump_bin("gate_logits", logits_ref.data, seq_len * config.n_routed_experts);
    dump_bin("bias", bias, config.n_routed_experts);
#endif
    // Sigmoid activation, in place on the logits.
    for (size_t i = 0; i < seq_len; i++) {
      float* logits_row = logits + i * config.n_routed_experts;
      for (size_t j = 0; j < config.n_routed_experts; j++) {
        logits_row[j] = 1.f / (1.f + std::exp(-logits_row[j]));
      }
    }
    // Sum of the two largest values in data[begin, end) — a group's score.
    auto top2 = [](float* data, size_t begin, size_t end) {
      if (end - begin < 2) {
        throw std::invalid_argument("top2 requires at least two elements");
      }
      float first = -std::numeric_limits<float>::infinity();
      float second = -std::numeric_limits<float>::infinity();
      for (size_t i = begin; i < end; ++i) {
        float v = data[i];
        if (v > first) {
          second = first;
          first = v;
        } else if (v > second) {
          second = v;
        }
      }
      return first + second;
    };
    // Selection scores = sigmoid scores + per-expert correction bias.
    // The bias affects which experts are chosen, but the returned weights
    // come from the unbiased sigmoid scores below.
    for (size_t i = 0; i < seq_len; i++) {
      float* logits_row = logits + i * config.n_routed_experts;
      float* scores_row = score_to_choice + i * config.n_routed_experts;
      for (size_t j = 0; j < config.n_routed_experts; j++) {
        scores_row[j] = logits_row[j] + bias[j];
      }
    }
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("moe gate score to choice");
#endif
#ifdef DEBUG_THIS_MOEGATE
    dump_bin("scores_to_choice", score_to_choice, seq_len * config.n_routed_experts);
#endif
    // Per token: keep the topk_group best expert groups (ranked by top2 sum),
    // mask out the rest, then take the top num_experts_per_tok experts and
    // normalize/scale their sigmoid weights.
    for (size_t i = 0; i < seq_len; i++) {
      output_int_t* output_topk_idx_row = output_topk_idx + i * output_ld;
      output_t* output_topk_weight_row = output_topk_weight + i * output_ld;
      float* logits_row = logits + i * config.n_routed_experts;
      float* scores_row = score_to_choice + i * config.n_routed_experts;
      float* group_score_row = group_score + i * config.n_group;
      size_t* temp_idx_row = temp_idx + i * config.n_routed_experts;
      size_t group_size = config.n_routed_experts / config.n_group;
      for (size_t g = 0; g < config.n_group; g++) {
        size_t group_begin = g * group_size;
        size_t group_end = group_begin + group_size;
        group_score_row[g] = top2(scores_row, group_begin, group_end);
        temp_idx_row[g] = g;
      }
      std::sort(temp_idx_row, temp_idx_row + config.n_group,
                [&](auto& a, auto& b) { return group_score_row[a] > group_score_row[b]; });
      // Mask every expert outside the topk_group best groups.
      for (size_t g = config.topk_group; g < config.n_group; g++) {
        size_t group_begin = temp_idx_row[g] * group_size;
        size_t group_end = group_begin + group_size;
        for (size_t j = group_begin; j < group_end; j++) {
          scores_row[j] = -std::numeric_limits<float>::infinity();
        }
      }
      for (int j = 0; j < config.n_routed_experts; j++) {
        temp_idx_row[j] = j;
      }
      std::sort(temp_idx_row, temp_idx_row + config.n_routed_experts,
                [&](auto& a, auto& b) { return scores_row[a] > scores_row[b]; });
      float sum = 1e-20f;  // epsilon guards the division below
      for (int j = 0; j < config.num_experts_per_tok; j++) {
        output_topk_idx_row[j] = temp_idx_row[j];
        output_topk_weight_row[j] = logits_row[temp_idx_row[j]];
        sum += output_topk_weight_row[j];
      }
      for (int j = 0; j < config.num_experts_per_tok; j++) {
        output_topk_weight_row[j] /= sum;
        output_topk_weight_row[j] *= config.routed_scaling_factor;
      }
    }
#ifdef DEBUG_THIS_MOEGATE
    dump_bin("group_scores", group_score, seq_len * config.n_group);
    dump_bin("gate_logits_toped", score_to_choice, seq_len * config.n_routed_experts);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("gate_logits_toped");
#endif
    // printf("Gate Forward Done\n");
    // for(size_t i=0;i<seq_len;i++){
    //   printf("seq %ld, topk: ", i);
    //   for(size_t j=0;j<config.num_experts_per_tok;j++){
    //     printf("%ld ", output_topk_idx[i * output_ld + j]);
    //   }
    //   for(size_t j=0;j<config.num_experts_per_tok;j++){
    //     printf("%f ", output_topk_weight[i * output_ld + j]);
    //   }
    //   printf("\n");
    // }
#ifdef FORWARD_TIME_PROFILE
    time_perf_name = "moe gate inner";
    perf_report();
#endif
  }
};
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -1,487 +0,0 @@
#include "mla_quan.h"
#include "../mla-tp.hpp"
#include "../reduce.hpp"
#include "../rms-norm.hpp"
#include "../rope.hpp"
#include "../softmax.hpp"
#include "ggml-quants.h"
#include "ggml.h"
#include "kblas.h"
#include "la/arm_kml.hpp"
// #define DEBUG_THIS_MLA
#ifdef DEBUG_THIS_MLA
#include "test/debug.hpp"
#endif
#include <arm_sve.h>
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
// #define DEBUG_THIS_MLA
// #define FORWARD_TIME_PROFILE
// Run fn(0..var-1) either inline on the current thread (when the workload
// measure `what` is below `threshold`) or distributed over the thread pool
// via work stealing.  `pool` must be in scope at the expansion site.
#define DIRECT_OR_POOL_BY(what, threshold, var, fn)     \
  do {                                                  \
    if ((what) < (threshold)) {                         \
      for (int i = 0; i < (var); i++) {                 \
        (fn)(i);                                        \
      }                                                 \
    } else {                                            \
      pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
    }                                                   \
  } while (0)
// inline void debug_output(const float16_t *data) {
// for (int i = 0; i < 10; i++) {
// printf("%f ", data[i]);
// }
// printf("\n");
// }
// inline void debug_output(const float *data) {
// for (int i = 0; i < 10; i++) {
// printf("%f ", data[i]);
// }
// printf("\n");
// }
// inline void debug_output(const bfloat16_t *data) {
// for (int i = 0; i < 10; i++) {
// float x = 0;
// *(bfloat16_t *)(&x) = data[i];
// printf("%f ", x);
// }
// printf("\n");
// }
// template <typename T> inline void dump_to(std::string file_name, T *data, size_t count) {
// std::ofstream f(file_name);
// for (int i = 0; i < count; i++) {
// f << data[i] << " ";
// }
// f.close();
// }
// Constructor for one tensor-parallel shard (tp_part_idx) of the quantized
// MLA attention operator.
//
// It computes the buffer geometry, registers every per-activation scratch
// buffer with `mem_requests` (actual allocation is deferred to the NUMA-aware
// shared allocator at the very end), and eagerly allocates the long-lived
// weight-quantization BufferB objects with 64-byte-aligned storage.
//
// Fix vs. previous revision: the qlen_quant_output memory requests were sized
// with sizeof(std::remove_pointer_t<decltype(qlen_quant_output)>), which is
// sizeof(std::vector<A*>) (remove_pointer on a non-pointer is a no-op), not
// sizeof(A) — over-allocating each request by sizeof(std::vector)/sizeof(A).
// The element (pointee) type is now taken via value_type, matching how every
// other request in this function is sized.
template <typename A, class KERNEL>
KML_MLA_TP_QUAN_TEST<A, KERNEL>::KML_MLA_TP_QUAN_TEST(GeneralMLAConfig config, int tp_part_idx)
    : config(config), tp_part_idx(tp_part_idx) {
  init_ggml();
  // Width of one compressed-cache ("cc") row: latent KV rank plus RoPE part.
  cc_size = config.kv_lora_rank + config.rope_size;
  MemoryRequest mem_requests;
  // Standard 1/sqrt(d_head) attention scaling; d_head = rope + nope sizes.
  softmax_scale = 1.0 / sqrt(config.rope_size + config.nope_size);
  PUSH_MEM_REQ(q_lora_rank,
               sizeof(std::remove_pointer_t<decltype(q_lora_rank)>) * config.q_lora_rank * config.max_qlen);
  // One output scratch buffer per decode head-group (ceil division).
  qlen_quant_output.resize((config.num_heads + sub_num_heads_decode - 1) / sub_num_heads_decode, nullptr);
  for (int i = 0; i < (config.num_heads + sub_num_heads_decode - 1) / sub_num_heads_decode; i++) {
    mem_requests.append_pointer(
        &qlen_quant_output[i],
        sizeof(std::remove_pointer_t<typename decltype(qlen_quant_output)::value_type>) * config.hidden_size *
            config.max_qlen * sub_num_heads_decode);
  }
  // q_nope scratch: [nope_size x (num_heads * max_qlen)], column major.
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        q_nope = (A*)new_ptr;
        q_nope_tmp_ref =
            KMatRef(q_nope, config.nope_size, config.num_heads * config.max_qlen, config.nope_size, CblasColMajor);
      },
      sizeof(std::remove_pointer_t<decltype(q_nope)>) * config.num_heads * config.nope_size * config.max_qlen);
  // Leading dimension big enough for both the absorbed (kv_lora_rank + rope)
  // and non-absorbed (nope + rope) query layouts, so one buffer serves both.
  q_ld = std::max(config.kv_lora_rank, config.nope_size) + config.rope_size;
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        q = (A*)new_ptr;
        q_ref = KMatRef(q, q_ld, config.num_heads * config.max_qlen, q_ld, CblasColMajor);
        // Row-sliced views of the same storage for the two attention paths.
        q_pe_absorb_ref = q_ref.offset_row(config.kv_lora_rank, config.rope_size);
        q_pe_noabsorb_ref = q_ref.offset_row(config.nope_size, config.rope_size);
        q_kv_lora_rank_ref = q_ref.offset_row(0, config.kv_lora_rank);
        q_nope_ref = q_ref.offset_row(0, config.nope_size);
        q_attn_absorb_ref = q_ref.offset_row(0, config.kv_lora_rank + config.rope_size);
        q_attn_noabsorb_ref = q_ref.offset_row(0, config.nope_size + config.rope_size);
      },
      sizeof(std::remove_pointer_t<decltype(q)>) * config.num_heads * config.max_qlen * q_ld);
  // Attention score matrix: [max_kvlen x (max_qlen * num_heads)].
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        attention_weights = (A*)new_ptr;
        attention_weights_ref = KMatRef(attention_weights, config.max_kvlen, config.max_qlen * config.num_heads,
                                        config.max_kvlen, CblasColMajor);
      },
      sizeof(std::remove_pointer_t<decltype(attention_weights)>) * config.max_kvlen * config.max_qlen *
          config.num_heads);
  // Per-head attention output before the output projection.
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        attention_output = (A*)new_ptr;
        attention_output_ref = KMatRef(attention_output, config.nope_size * config.num_heads, config.max_qlen,
                                       config.nope_size * config.num_heads, CblasColMajor);
      },
      (sizeof(std::remove_pointer_t<decltype(attention_output)>) * config.num_heads * config.nope_size *
       config.max_qlen));
  // Decompressed keys for the non-absorbed path: [(nope + rope) x kv x heads].
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        k = (A*)new_ptr;
        k_ref = KMatRef(k, config.nope_size + config.rope_size, config.max_kvlen * config.num_heads,
                        config.nope_size + config.rope_size, CblasColMajor);
        k_nope_ref = k_ref.offset_row(0, config.nope_size);
        k_rope_ref = k_ref.offset_row(config.nope_size, config.rope_size);
      },
      sizeof(std::remove_pointer_t<decltype(k)>) * config.num_heads * (config.nope_size + config.rope_size) *
          config.max_kvlen);
  // o_absorb (absorbed path) and v (non-absorbed path) are never live at the
  // same time, so they share one buffer sized for the larger of the two.
  size_t o_absorb_or_v_size = std::max(config.kv_lora_rank * config.max_qlen, config.nope_size * config.max_kvlen);
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        o_absorb_or_v = (A*)new_ptr;
        o_absorb_ref = KMatRef(o_absorb_or_v, config.kv_lora_rank, config.max_qlen * config.num_heads,
                               config.kv_lora_rank, CblasColMajor);
        v_ref = KMatRef(o_absorb_or_v, config.nope_size, config.max_kvlen * config.num_heads, config.nope_size,
                        CblasColMajor);
      },
      sizeof(std::remove_pointer_t<decltype(o_absorb_or_v)>) * o_absorb_or_v_size * config.num_heads);
  // Precompute YaRN rotary-embedding angles up to the maximum KV length.
  rope_angle = std::make_unique<T_RopeAngle>(
      config.rope_size, config.max_position_embeddings, config.rope_theta, config.rope_scaling_factor,
      config.rope_scaling_original_max_position_embeddings, config.rope_scaling_beta_fast,
      config.rope_scaling_beta_slow, config.rope_scaling_mscale, config.rope_scaling_mscale_all_dim);
  rope_angle->init(config.max_kvlen);
  // q_a projection: quantized activations (BufferA) x packed weights
  // (BufferB) -> int32 results (BufferC, row major).
  local_q_a_proj_quant_ba = new typename GemmKernel::BufferA(config.max_qlen, config.hidden_size);
  local_q_a_proj_quant_bb = new typename GemmKernel::BufferB(config.q_lora_rank, config.hidden_size, true);
  local_q_a_proj_quant_bc = new typename GemmKernel::BufferC(config.max_qlen, config.q_lora_rank, true);  // row major
  mem_requests.append_function([this](void* new_ptr) { local_q_a_proj_quant_ba->set_data(new_ptr); },
                               local_q_a_proj_quant_ba->required_size());
  // Weight buffers live for the operator's lifetime -> eager aligned alloc.
  local_q_a_proj_quant_bb->set_data(std::aligned_alloc(64, local_q_a_proj_quant_bb->required_size()));
  mem_requests.append_function([this](void* new_ptr) { local_q_a_proj_quant_bc->set_data(new_ptr); },
                               local_q_a_proj_quant_bc->required_size());
  // kv_a projection with MQA: one BufferC per KV-cache page (row major).
  local_kv_a_proj_with_mqa_quant_ba = new typename GemmKernel::BufferA(config.max_qlen, config.hidden_size);
  local_kv_a_proj_with_mqa_quant_bb = new typename GemmKernel::BufferB(cc_size, config.hidden_size, true);
  for (int i = 0; i < config.page_count; i++) {
    local_kv_a_proj_with_mqa_quant_bc.push_back(
        new typename GemmKernel::BufferC(config.token_count_in_page, cc_size, true));  // row major
  }
  mem_requests.append_function([this](void* new_ptr) { local_kv_a_proj_with_mqa_quant_ba->set_data(new_ptr); },
                               local_kv_a_proj_with_mqa_quant_ba->required_size());
  local_kv_a_proj_with_mqa_quant_bb->set_data(
      std::aligned_alloc(64, local_kv_a_proj_with_mqa_quant_bb->required_size()));
  cc_page_refs_buffer.resize(config.page_count);
  kv_lora_page_refs_buffer.resize(config.page_count);
  rope_page_refs_buffer.resize(config.page_count);
  cc_page_refs_decode_buffer.resize(config.page_count);
  kv_lora_page_refs_decode_buffer.resize(config.page_count);
  rope_page_refs_decode_buffer.resize(config.page_count);
  // Each page's BufferC doubles as the compressed cache; build both the
  // col-major (prefill) and row-major (decode) views over the same storage
  // once the allocator hands over the memory.
  for (int i = 0; i < config.page_count; i++) {
    mem_requests.append_function(
        [this, i, config](void* new_ptr) {
          local_kv_a_proj_with_mqa_quant_bc[i]->set_data(new_ptr);
          cc_page_refs_buffer[i] = KMatRefC(local_kv_a_proj_with_mqa_quant_bc[i]->c, cc_size,
                                            config.token_count_in_page, cc_size, CblasColMajor);
          kv_lora_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(0, config.kv_lora_rank);
          rope_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(config.kv_lora_rank, config.rope_size);
          cc_page_refs_decode_buffer[i] = KMatRefC(local_kv_a_proj_with_mqa_quant_bc[i]->c, config.token_count_in_page,
                                                   cc_size, cc_size, CblasRowMajor);
          kv_lora_page_refs_decode_buffer[i] = cc_page_refs_decode_buffer[i].offset_col(0, config.kv_lora_rank);
          rope_page_refs_decode_buffer[i] =
              cc_page_refs_decode_buffer[i].offset_col(config.kv_lora_rank, config.rope_size);
        },
        local_kv_a_proj_with_mqa_quant_bc[i]->required_size());
  }
  // Output projection buffers: one BufferB/BufferC pair per decode head-group
  // plus a single prefill pair covering all heads at once.
  for (int i = 0; i < div_up(config.num_heads, sub_num_heads_decode); i++) {
    local_w_o_decode_bc.push_back(new typename GemmKernel::BufferC(config.hidden_size, config.max_qlen));  // col major
    local_w_o_decode_bb.push_back(
        new typename GemmKernel::BufferB(config.hidden_size, sub_num_heads_decode * config.nope_size, true));
  }
  local_w_o_quant_ba = new typename GemmKernel::BufferA(config.max_qlen, config.num_heads * config.nope_size);
  local_w_o_quant_bb = new typename GemmKernel::BufferB(config.hidden_size, config.num_heads * config.nope_size, true);
  local_w_o_prefill_bc = new typename GemmKernel::BufferC(config.max_qlen, config.hidden_size, true);  // row major
  mem_requests.append_function([this](void* new_ptr) { local_w_o_quant_ba->set_data(new_ptr); },
                               local_w_o_quant_ba->required_size());
  local_w_o_quant_bb->set_data(std::aligned_alloc(64, local_w_o_quant_bb->required_size()));
  mem_requests.append_function([this](void* new_ptr) { local_w_o_prefill_bc->set_data(new_ptr); },
                               local_w_o_prefill_bc->required_size());
  for (int i = 0; i < div_up(config.num_heads, sub_num_heads_decode); i++) {
    mem_requests.append_function([this, i](void* new_ptr) { local_w_o_decode_bc[i]->set_data(new_ptr); },
                                 local_w_o_decode_bc[i]->required_size());
    local_w_o_decode_bb[i]->set_data(std::aligned_alloc(64, local_w_o_decode_bb[i]->required_size()));
  }
  // Hand all collected requests to the NUMA-local shared buffer; this invokes
  // the registered callbacks with the final pointers.
  shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
// Loads this shard's slice of the model weights.
//
// complete_num_heads: total head count of the full (un-sharded) model; used
//   as the row stride when gathering this shard's columns of the output
//   projection. offset: index of this shard's first head.
//
// Dequantizes weights from their GGML storage types into A, splits the fused
// kv_b projection into separate k_b / v_b halves per head, re-gathers w_o into
// shard-local layouts, then packs/quantizes the large projections into the
// GEMM kernel's BufferB format in parallel. Temporary dequantized copies of
// w_o / q_a_proj / kv_a_proj are freed at the end; q_b/k_b/v_b stay resident
// because the *_ref views point directly into them.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::load_weights(int complete_num_heads, int offset) {
  // float16 activations require all weights to already be stored as F16;
  // float activations accept any GGML type (convert_or_copy dequantizes).
  if constexpr (std::is_same_v<A, float16_t>) {
    ASSERT_RELEASE(config.q_a_proj_type == GGML_TYPE_F16, "q_a_proj_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.q_b_proj_type == GGML_TYPE_F16, "q_b_proj_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.kv_a_proj_with_mqa_type == GGML_TYPE_F16, "kv_a_proj_with_mqa_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.kv_b_proj_type == GGML_TYPE_F16, "kv_b_proj_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.w_o_type == GGML_TYPE_F16, "w_o_type must be GGML_TYPE_F16");
  } else if constexpr (std::is_same_v<A, float>) {
  } else {
    throw std::runtime_error("Unsupported type for KML_MLA_TP");
  }
  // Causal masks: row i allows positions [0, i], -inf beyond. memset(0) is a
  // valid IEEE-754 zero fill for float types.
  // NOTE(review): these new[] allocations are never freed (no matching
  // delete[] visible in this file) — confirm lifetime is process-long.
  default_attention_masks.resize(config.max_kvlen);
  for (int i = 0; i < config.max_kvlen; i++) {
    A* mask = new A[config.max_kvlen];
    memset(mask, 0, config.max_kvlen * sizeof(A));
    for (int j = i + 1; j < config.max_kvlen; j++) {
      mask[j] = -std::numeric_limits<float>::infinity();
    }
    default_attention_masks[i] = mask;
  }
  // q_a projection and its RMS-norm weights (norm defaults to all-ones).
  local_q_a_proj = new A[config.hidden_size * config.q_lora_rank];
  convert_or_copy(local_q_a_proj, config.q_a_proj, (ggml_type)config.q_a_proj_type,
                  config.hidden_size * config.q_lora_rank);
  local_q_a_norm = new A[config.q_lora_rank];
  if (config.q_a_norm == nullptr) {
    for (size_t i = 0; i < config.q_lora_rank; i++) {
      local_q_a_norm[i] = 1;
    }
  } else {
    convert_or_copy(local_q_a_norm, config.q_a_norm, (ggml_type)config.q_a_norm_type, config.q_lora_rank);
  }
  // kv_a projection (with MQA rope columns) and its RMS-norm weights.
  local_kv_a_proj_with_mqa = new A[config.hidden_size * (config.kv_lora_rank + config.rope_size)];
  convert_or_copy(local_kv_a_proj_with_mqa, config.kv_a_proj_with_mqa, (ggml_type)config.kv_a_proj_with_mqa_type,
                  config.hidden_size * (config.kv_lora_rank + config.rope_size));
  local_kv_a_norm = new A[config.kv_lora_rank];
  if (config.kv_a_norm == nullptr) {
    for (size_t i = 0; i < config.kv_lora_rank; i++) {
      local_kv_a_norm[i] = 1;
    }
  } else {
    convert_or_copy(local_kv_a_norm, config.kv_a_norm, (ggml_type)config.kv_a_norm_type, config.kv_lora_rank);
  }
  // q_b projection: only this shard's heads, starting at `offset`.
  local_q_b_proj = new A[config.num_heads * (config.nope_size + config.rope_size) * config.q_lora_rank];
  convert_or_copy(local_q_b_proj,
                  offset_pointer(config.q_b_proj, offset * (config.nope_size + config.rope_size) * config.q_lora_rank *
                                                      ggml_type_size((ggml_type)config.q_b_proj_type)),
                  (ggml_type)config.q_b_proj_type,
                  config.num_heads * (config.nope_size + config.rope_size) * config.q_lora_rank);
  // The fused kv_b projection stores, per head, a k_b half followed by a v_b
  // half (each nope_size x kv_lora_rank); split them into separate arrays.
  local_k_b_proj = new A[config.num_heads * config.nope_size * config.kv_lora_rank];
  local_v_b_proj = new A[config.num_heads * config.nope_size * config.kv_lora_rank];
  for (size_t i = 0; i < config.num_heads; i++) {
    convert_or_copy(
        local_k_b_proj + i * config.nope_size * config.kv_lora_rank,
        offset_pointer(config.kv_b_proj, (i + offset) * (config.nope_size + config.nope_size) * config.kv_lora_rank *
                                             ggml_type_size((ggml_type)config.kv_b_proj_type)),
        (ggml_type)config.kv_b_proj_type, config.nope_size * config.kv_lora_rank);
    convert_or_copy(
        local_v_b_proj + i * config.nope_size * config.kv_lora_rank,
        offset_pointer(config.kv_b_proj, ((i + offset) * (config.nope_size + config.nope_size) + config.nope_size) *
                                             config.kv_lora_rank * ggml_type_size((ggml_type)config.kv_b_proj_type)),
        (ggml_type)config.kv_b_proj_type, config.nope_size * config.kv_lora_rank);
  }
  local_k_b_proj_ref = KMatRef((A*)local_k_b_proj, config.num_heads * config.nope_size, config.kv_lora_rank,
                               config.kv_lora_rank, CblasRowMajor);
  local_v_b_proj_ref = KMatRef((A*)local_v_b_proj, config.num_heads * config.nope_size, config.kv_lora_rank,
                               config.kv_lora_rank, CblasRowMajor);
  // Output projection: gather this shard's head columns row by row, since the
  // full matrix is laid out with complete_num_heads * nope_size per row.
  local_w_o = new A[config.num_heads * config.hidden_size * config.nope_size];
  for (size_t i = 0; i < config.hidden_size; i++) {
    convert_or_copy(
        local_w_o + i * config.num_heads * config.nope_size,
        offset_pointer(config.o_proj, (i * complete_num_heads * config.nope_size + (offset)*config.nope_size) *
                                          ggml_type_size((ggml_type)config.w_o_type)),
        (ggml_type)config.w_o_type, config.num_heads * config.nope_size);
  }
  // Same gather, but sliced per decode head-group of sub_num_heads_decode.
  size_t sub_num_heads_decode_group = div_up(config.num_heads, sub_num_heads_decode);
  local_w_decode_o.resize(sub_num_heads_decode_group);
  for (size_t h = 0; h < sub_num_heads_decode_group; h++) {
    local_w_decode_o[h] = new A[config.hidden_size * sub_num_heads_decode * config.nope_size];
    for (size_t i = 0; i < config.hidden_size; i++) {
      convert_or_copy(local_w_decode_o[h] + i * config.nope_size * sub_num_heads_decode,
                      offset_pointer(config.o_proj, (i * complete_num_heads * config.nope_size +
                                                     (h * sub_num_heads_decode + offset) * config.nope_size) *
                                                        ggml_type_size((ggml_type)config.w_o_type)),
                      (ggml_type)config.w_o_type, config.nope_size * sub_num_heads_decode);
    }
  }
  // Quantize/pack the big projections into the GEMM kernel's BufferB layout.
#ifdef DEBUG_THIS_MLA
  printf("KML_MLA_TP_QUAN_TEST::load_weights: local_q_a_proj_quant_bb quantization\n");
  if (tp_part_idx == 0) {
    dump_bin("local_q_a_proj_quant.bin", local_q_a_proj, config.hidden_size * config.q_lora_rank);
  }
#endif
  auto pool = config.pool->get_subpool(tp_part_idx);
  {
    // NOTE(review): recommended_nth here vs recommended_mth for w_o below —
    // confirm both naming and intent against the GemmKernel API.
    size_t mth_q_a = GemmKernel::recommended_nth(config.q_lora_rank);
    size_t mth_kv_a = GemmKernel::recommended_nth(config.kv_lora_rank + config.rope_size);
    auto task_counter = TaskCounter({mth_q_a + mth_kv_a});
    auto task = [&](int task_id) {
      // The first mth_q_a tasks pack local_q_a_proj; the remaining mth_kv_a
      // tasks pack local_kv_a_proj_with_mqa.
      size_t mth_idx = task_counter.at(task_id, 0);
      if (mth_idx < mth_q_a) {
        local_q_a_proj_quant_bb->from_mat((A*)local_q_a_proj, mth_idx, mth_q_a, config.q_lora_rank);
      } else {
        mth_idx -= mth_q_a;
        local_kv_a_proj_with_mqa_quant_bb->from_mat((A*)local_kv_a_proj_with_mqa, mth_idx, mth_kv_a,
                                                    config.kv_lora_rank + config.rope_size);
      }
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("KML_MLA_TP_QUAN_TEST::load_weights: local_w_o_quant_bb quantization\n");
#endif
  {
    // Pack the full prefill output projection.
    size_t mth_w_o = GemmKernel::recommended_mth(config.hidden_size);
    auto task_counter = TaskCounter({mth_w_o});
    auto task = [&](int task_id) {
      size_t mth_idx = task_counter.at(task_id, 0);
      local_w_o_quant_bb->from_mat((A*)local_w_o, mth_idx, mth_w_o, config.hidden_size);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("KML_MLA_TP_QUAN_TEST::load_weights: local_w_decode_o quantization\n");
#endif
  {
    // Pack each decode head-group's slice of the output projection.
    size_t mth_w_o = GemmKernel::recommended_mth(config.hidden_size);
    auto task_counter = TaskCounter({sub_num_heads_decode_group, mth_w_o});
    auto task = [&](int task_id) {
      size_t h = task_counter.at(task_id, 0);
      size_t mth_idx = task_counter.at(task_id, 1);
      local_w_o_decode_bb[h]->from_mat((A*)local_w_decode_o[h], mth_idx, mth_w_o, config.hidden_size);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
  // Matrix views over the packed/raw weights used by the forward paths.
  local_q_a_proj_quant_ref =
      KMatRefB(local_q_a_proj_quant_bb->b, config.hidden_size, config.q_lora_rank, config.hidden_size, CblasColMajor,
               CblasNoTrans, local_q_a_proj_quant_bb->if_pack);
  local_q_b_proj_ref = KMatRef((A*)local_q_b_proj, config.num_heads * (config.nope_size + config.rope_size),
                               config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
  local_kv_a_proj_with_mqa_decode_ref =
      KMatRefB(local_kv_a_proj_with_mqa_quant_bb->b, config.hidden_size, config.kv_lora_rank + config.rope_size,
               config.hidden_size, CblasColMajor, CblasNoTrans, local_kv_a_proj_with_mqa_quant_bb->if_pack);
  local_w_o_ref =
      KMatRefB(local_w_o_quant_bb->b, config.hidden_size, config.nope_size * config.num_heads,
               config.nope_size * config.num_heads, CblasRowMajor, CblasNoTrans, local_w_o_quant_bb->if_pack);
  // The dequantized sources of the packed buffers are no longer needed.
  delete[] local_w_o;
  delete[] local_q_a_proj;
  delete[] local_kv_a_proj_with_mqa;
  for (int i = 0; i < sub_num_heads_decode_group; i++) {
    delete[] local_w_decode_o[i];
  }
  local_w_decode_o.clear();
}
// Intentionally a no-op in this quantized variant: KV-cache pages are owned
// locally (allocated in set_local_pages / the per-page BufferC objects)
// rather than supplied by the caller. The signature is kept so this class
// stays interchangeable with other TP_MLA backends.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::set_pages(std::vector<void*> kv_lora_pages, std::vector<void*> pe_pages) {
  // this->kv_lora_pages = kv_lora_pages;
  // this->rope_pages = pe_pages;
}
// Allocates `page_count` local compressed-cache pages (each holding
// token_count_in_page rows of kv_lora_rank + rope_size values) and builds the
// sliced matrix views used by the attention kernels: column-major refs over
// the new pages, plus row-major "decode buffer" refs re-derived from the
// views set up in the constructor.
// NOTE(review): repeated calls leak the previous cc_pages (new[] without
// delete[]), and resizing *_decode_buffer with a different page_count than
// the constructor used would leave new entries default-constructed — confirm
// this is only called once with the constructor's page_count.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::set_local_pages(int page_count) {
  cc_pages.resize(page_count);
  cc_page_refs.resize(page_count);
  kv_lora_page_refs.resize(page_count);
  rope_page_refs.resize(page_count);
  // cc_page_refs_buffer.resize(page_count);
  // kv_lora_page_refs_buffer.resize(page_count);
  // rope_page_refs_buffer.resize(page_count);
  cc_page_refs_decode_buffer.resize(page_count);
  kv_lora_page_refs_decode_buffer.resize(page_count);
  rope_page_refs_decode_buffer.resize(page_count);
  for (int i = 0; i < page_count; i++) {
    // One page: [cc_size x token_count_in_page], column major.
    cc_pages[i] = new A[config.token_count_in_page * (config.kv_lora_rank + config.rope_size)];
    cc_page_refs[i] = KMatRef(cc_pages[i], cc_size, config.token_count_in_page, cc_size, CblasColMajor);
    // cc_page_refs_buffer[i] = KMatRefC(local_kv_a_proj_with_mqa_deprecated_bc[i]->c, cc_size,
    // config.token_count_in_page,
    //                                   cc_size, CblasColMajor);
    // Row slices: latent KV part first, then the RoPE part.
    kv_lora_page_refs[i] = cc_page_refs[i].offset_row(0, config.kv_lora_rank);
    // kv_lora_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(0, config.kv_lora_rank);
    kv_lora_page_refs_decode_buffer[i] = cc_page_refs_decode_buffer[i].offset_col(0, config.kv_lora_rank);
    rope_page_refs[i] = cc_page_refs[i].offset_row(config.kv_lora_rank, config.rope_size);
    // rope_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(config.kv_lora_rank, config.rope_size);
    rope_page_refs_decode_buffer[i] = cc_page_refs_decode_buffer[i].offset_col(config.kv_lora_rank, config.rope_size);
  }
}
// Entry point: routes a batch to the decode or prefill implementation.
// NOTE(review): the decision looks only at qlens[0] — presumably batches are
// homogeneous (all decode or all prefill); confirm against callers.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables,
                                              std::vector<int> kv_lens, std::vector<void*> attention_masks,
                                              const void* input, void* output) {
  // A first request carrying at most one new token means this is a decode
  // step; anything longer is a prefill.
  const bool is_decode = qlens[0] <= 1;
  if (is_decode) {
    forward_decode(qlens, page_tables, kv_lens, attention_masks, (input_t*)input, (output_t*)output);
    return;
  }
  forward_prefill(qlens, page_tables, kv_lens, attention_masks, (input_t*)input, (output_t*)output);
}
// Convenience overload: forwards with the default causal attention masks
// built in load_weights.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables,
                                              std::vector<int> kv_lens, const void* input, void* output) {
  // The by-value parameters are local copies, so moving them on is safe.
  this->forward(std::move(qlens), std::move(page_tables), std::move(kv_lens), default_attention_masks, input, output);
}

View File

@@ -1,268 +0,0 @@
#include "../mla-tp.hpp"
#include "../reduce.hpp"
#include "../rms-norm.hpp"
#include "../rope.hpp"
#include "../softmax.hpp"
#include "ggml-quants.h"
#include "ggml.h"
#include "kblas.h"
#include "la/arm_kml.hpp"
// #define DEBUG_THIS_MLA
#ifdef DEBUG_THIS_MLA
#include "test/debug.hpp"
#endif
// Run fn(i) for i in [0, var): inline on the calling thread when `what` is
// below `threshold` (small problems aren't worth the pool dispatch overhead),
// otherwise fan out through the work-stealing thread pool. Relies on a
// variable named `pool` being in scope at the expansion site.
#define DIRECT_OR_POOL_BY(what, threshold, var, fn)                                                                    \
  do {                                                                                                                 \
    if ((what) < (threshold)) {                                                                                        \
      for (int i = 0; i < (var); i++) {                                                                                \
        (fn)(i);                                                                                                       \
      }                                                                                                                \
    } else {                                                                                                           \
      pool->do_work_stealing_job((var), nullptr, (fn), nullptr);                                                       \
    }                                                                                                                  \
  } while (0)
// One tensor-parallel (TP) shard of a quantized (int8/int4) Multi-head Latent
// Attention operator built on KML/ARM GEMM kernels. Owns this shard's
// dequantized/packed weights, paged compressed KV-cache, and all scratch
// buffers; TP_MLA<KML_MLA_TP_QUAN_TEST<...>> fans batches across shards and
// reduces their partial outputs.
template <typename A, class KERNEL>
class KML_MLA_TP_QUAN_TEST
#ifdef FORWARD_TIME_PROFILE
    : protected TimePerf
#endif
{
public:
  using input_t = A;
  using output_t = A;
  using quant_t = int8_t;
  KML_MLA_TP_QUAN_TEST(GeneralMLAConfig config, int tp_part_idx);
  // complete_num_heads: head count of the full model; offset: first head
  // owned by this shard.
  void load_weights(int complete_num_heads, int offset);
  // No-op in this variant; pages are owned locally (see set_local_pages).
  void set_pages(std::vector<void*> kv_lora_pages, std::vector<void*> pe_pages);
  void set_local_pages(int page_count);
  void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
               std::vector<void*> attention_masks, const void* input, void* output);
  // Overload using the default causal masks built in load_weights.
  void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
               const void* input, void* output);
private:
  using T_RMSNorm = RMSNorm<A>;
  using T_RopeAngle = DeepseekV3YarnRotaryEmbedding;
  using T_RopeApplier = Rope<T_RopeAngle, A>;
  using T_SoftmaxApplier = Softmax<A>;
  // Matrix views over the GEMM kernel's quantized A (int8), packed B
  // (kernel-native dt), activation (A), and int32 accumulator storage.
  using KMatRefA = typename arm_kml::MatRef<int8_t>;
  using KMatRefB = typename arm_kml::MatRef<typename KERNEL::dt>;
  using KMatRef = typename arm_kml::MatRef<A>;
  using KMatRefC = typename arm_kml::MatRef<int32_t>;
  using GemmKernel = KERNEL;
  GeneralMLAConfig config;
  // Tiling parameters for the prefill GEMM loops.
  const size_t col_block = 256;
  const size_t row_block = 256;
  // for quant
  const size_t col_block_q_absorb = 512;  // upper bound: kv_lora_rank (512)
  const size_t row_block_q_absorb = 512;  // upper bound: qlen
  const size_t col_block_q_lora_kv_lora_rope = 256;
  const size_t row_block_q_lora_kv_lora_rope = 256;
  const size_t col_block_q_nope_rope = 512;
  const size_t row_block_q_nope_rope = 512;
  const size_t col_block_attention_output = 1024;  // upper bound: qlen
  const size_t row_block_attention_output = 256;   // upper bound: 128
  const size_t col_block_out_by_head = 256;  // upper bound: qlen
  const size_t row_block_out_by_head = 256;  // upper bound: hidden_size (7168)
  // ==========================================
  // Tiling parameters for the decode path.
  const size_t decode_col_block_q_absorb = 256;  // upper bound: kv_lora_rank (512)
  const size_t decode_row_block_q_absorb = 64;   // upper bound: qlen
  const size_t decode_col_block_q_lora_kv_lora_rope = 64;
  const size_t decode_row_block_q_lora_kv_lora_rope = 64;
  const size_t decode_col_block_q_nope_rope = 512;
  const size_t decode_row_block_q_nope_rope = 512;
  const size_t decode_col_block_attention_output = 1024;  // upper bound: qlen
  const size_t decode_row_block_attention_output = 256;   // upper bound: 128
  const size_t decode_col_block_out_by_head = 1;    // upper bound: qlen
  const size_t decode_row_block_out_by_head = 512;  // upper bound: hidden_size (7168)
  // ==========================================
  const size_t col_block_o_absorb = 256;
  const size_t row_block_o_absorb = 256;
  const size_t col_block_nope_attention = 256;
  const size_t row_block_nope_attention = 256;
  const size_t col_block_pe_attention = 256;
  const size_t row_block_pe_attention = 256;
  int tp_part_idx;
  // Per-position causal masks built in load_weights (A* stored as void*).
  std::vector<void*> default_attention_masks;
  // std::vector<void *> kv_lora_pages; // [page_count * page_token_count * nope]
  // std::vector<void *> rope_pages;    // [page_count * page_token_count * nope]
  std::vector<A*> cc_pages;  // [page_count * page_token_count * (kv rank + rope size)]
  size_t cc_size;  // kv_lora_rank + rope_size (width of one compressed-cache row)
  // col major:[kv_lora_rank, qlen] or row major:[qlen, kv_lora_rank]
  std::vector<KMatRef> cc_page_refs, kv_lora_page_refs, rope_page_refs;
  // col major:[kv_lora_rank, qlen] or [rope_size, qlen]
  std::vector<KMatRefC> cc_page_refs_buffer, kv_lora_page_refs_buffer, rope_page_refs_buffer;
  // row major:[qlen, kv_lora_rank] or [qlen, rope_size]
  std::vector<KMatRefC> cc_page_refs_decode_buffer, kv_lora_page_refs_decode_buffer, rope_page_refs_decode_buffer;
  // weights
  A* local_q_a_proj;  // [hidden_size * q_lora_rank]
  A* local_q_a_norm;
  A* local_q_b_proj;  // [num_heads * (nope_size + rope_size))]
  A* local_kv_a_proj_with_mqa;  // [hidden_size * (kv_lora_rank + rope)]
  A* local_kv_a_norm;
  A* local_k_b_proj;
  A* local_v_b_proj;
  // A *local_kv_b_proj; // [(num_heads * (nope_size + nope_size) * kv_lora_rank)],
  // q_absorb:   [num_heads * nope_size * kv_lora_rank]
  // out_absorb: [num_heads * nope_size * kv_lora_rank]
  A* local_w_o;  // [(num_heads * hidden_size * nope_size)]
  std::vector<A*>
      local_w_decode_o;  // [num_heads/sub_num_heads_decode]*[(sub_num_heads_decode * hidden_size * nope_size)]
  std::unique_ptr<T_RopeAngle> rope_angle;
  // KMatRefAB local_q_a_proj_ref;
  KMatRefB local_q_a_proj_quant_ref;
  KMatRef local_q_b_proj_ref;
  // KMatRefAB local_kv_a_proj_with_mqa_ref;
  KMatRefB local_kv_a_proj_with_mqa_decode_ref;
  KMatRef local_k_b_proj_ref;
  KMatRef local_v_b_proj_ref;
  // KMatRefAB local_w_o_decode_ref;
  KMatRefB local_w_o_ref;
  // Quantization buffers: BufferA = quantized activations, BufferB = packed
  // weights, BufferC = GEMM results.
  typename GemmKernel::BufferA* local_q_a_proj_quant_ba;  // [max_qlen,hidden_size]
  typename GemmKernel::BufferB* local_q_a_proj_quant_bb;  // [hidden_size, q_lora_rank]
  typename GemmKernel::BufferC* local_q_a_proj_quant_bc;  // [max_qlen,q_lora_rank] (row major)
  typename GemmKernel::BufferA* local_kv_a_proj_with_mqa_quant_ba;  // [max_qlen, hidden_size]
  typename GemmKernel::BufferB* local_kv_a_proj_with_mqa_quant_bb;  // [hidden_size, kv_lora_rank + rope_size]
  std::vector<typename GemmKernel::BufferC*>
      local_kv_a_proj_with_mqa_quant_bc;  // page_count * [page_token_count, rope_size + kv_lora_rank] (row major)
  // Quantized counterparts of local_w_o.
  typename GemmKernel::BufferA* local_w_o_quant_ba;  // [max_qlen, num_heads * nope_size]
  typename GemmKernel::BufferB* local_w_o_quant_bb;  // [num_heads * nope_size, hidden_size]
  // qlen_output
  typename GemmKernel::BufferC* local_w_o_prefill_bc;  // [max_qlen, hidden_size]
  std::vector<typename GemmKernel::BufferC*>
      local_w_o_decode_bc;  // [num_heads/sub_num_heads_decode] *[max_qlen, hidden_size]
  // std::vector<typename GemmKernel::BufferA *>
  //     local_w_o_decode_ba; // [num_heads/sub_num_heads_decode]*[hidden_size,sub_num_heads_decode * nope_size] row
  //     major
  std::vector<typename GemmKernel::BufferB*>
      local_w_o_decode_bb;  // [num_heads/sub_num_heads_decode]*[sub_num_heads_decode * nope_size,hidden_size] col major
  // for each query
  A* q_lora_rank;  // [qlen_sum, q_lora_rank]
  A* q_nope;       // [num_heads * max_qlen, nope_size]
  KMatRef q_nope_tmp_ref;
  A* q;  // [num_heads * max_qlen, max(kv_lora_rank,nope_size) + rope_size]
  size_t q_ld;  // leading dimension of q, sized for both attention paths
  KMatRef q_ref, q_pe_absorb_ref, q_pe_noabsorb_ref, q_nope_ref, q_kv_lora_rank_ref, q_attn_absorb_ref,
      q_attn_noabsorb_ref;
  // std::vector<A *> q_pe;   // [num_heads * max_qlen * rope_size]
  // std::vector<A *> q_nope; // [num_heads * max_qlen * nope_size]
  A* k;  // [num_heads * max_kvlen * (nope_size + rope_size)]
  KMatRef k_ref, k_nope_ref, k_rope_ref;
  A* attention_weights;
  KMatRef attention_weights_ref;  // [max_kvlen, max_qlen* num_heads]
  // std::vector<A *> q_absorb; // [num_heads, max_qlen, kv_lora_rank], or [num_heads, kv_lora_rank, max_qlen]
  A* o_absorb_or_v;  // shared storage: o_absorb (absorbed path) or v (non-absorbed)
  KMatRef o_absorb_ref;  // [num_heads, max_qlen, kv_lora_rank]
  KMatRef v_ref;         // [num_heads,nope,max_kvlen]
  A* attention_output;  // [num_heads * max_qlen * nope]
  KMatRef attention_output_ref;
  size_t sub_num_heads = 16;        // head-group size for intra-op parallelism (prefill)
  size_t sub_num_heads_decode = 8;  // head-group size for intra-op parallelism (decode)
  // std::vector<A *> qlen_decode_output; // [max_qlen * hidden_size]
  std::vector<A*> qlen_quant_output;  // [[num_heads/sub_num_heads] * max_qlen * hidden_size] row major
  A softmax_scale;  // 1/sqrt(rope_size + nope_size)
#ifdef DEBUG_THIS_MLA
  std::string file_name;
#endif
  // Heuristic: use the weight-absorbed attention path when the new query is
  // short relative to the existing KV length.
  // NOTE(review): the /2.0 divides only the sqrt term, not (-x + sqrt(...)) —
  // confirm the parenthesization matches the intended threshold formula.
  static bool decide_absorb(size_t qlen, size_t existing_kvlen) {
    double x = existing_kvlen;
    return qlen < (-x + sqrt(x * (x + 2048.0 / 3.0)) / 2.0);
  }
  // Declarations only; implementations live in the corresponding .cpp.
  void nope_attention_q_absorb(int qlen, int kvlen, const std::vector<int>& page_table, bool increamental = true,
                               bool is_decode = false);
  void nope_attention_no_absorb(int qlen, int kvlen, const std::vector<int>& page_table, bool increamental = true);
  void output_absorb(int query, const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
                     const std::vector<int>& kvlens, bool is_decode = false);
  void output_no_absorb(int query, const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
                        const std::vector<int>& kvlens);
  void forward_prefill(const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
                       const std::vector<int>& kvlens, const std::vector<void*>& attention_masks,
                       const input_t* input_raw, output_t* output_raw);
  void forward_decode(const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
                      const std::vector<int>& kvlens, const std::vector<void*>& attention_masks,
                      const input_t* input_raw, output_t* output_raw);
  void q_lora_kv_lora_rope_quant(int query, const std::vector<int>& qlens, const std::vector<int>& kvlens,
                                 const std::vector<std::vector<int>>& page_tables, std::vector<int>& qlen_split,
                                 KMatRefA& input_ref, KMatRefC& q_lora_rank_ref, KMatRef& q_lora_rank_out_ref,
                                 bool is_decode = false);
};
// Tensor-parallel wrapper specialization: splits heads evenly across the
// shards created by TP_MLA_Common, loads each shard's weights on its own
// NUMA node, and sums the shards' partial outputs after a forward pass.
template <typename A, class KERNEL>
class TP_MLA<KML_MLA_TP_QUAN_TEST<A, KERNEL>> : public TP_MLA_Common<KML_MLA_TP_QUAN_TEST<A, KERNEL>> {
public:
  using TP_MLA_Common<KML_MLA_TP_QUAN_TEST<A, KERNEL>>::TP_MLA_Common;
  // Shard i owns heads [i * num_heads / tp_count, (i + 1) * num_heads /
  // tp_count); each shard loads its slice on its local NUMA node.
  // NOTE(review): assumes num_heads is divisible by tp_count — confirm.
  void load_weights() {
    auto pool = this->config.pool;
    auto tp_num_heads = this->config.num_heads / this->tp_count;
    pool->dispense_backend()->do_numa_job([this, pool, tp_num_heads](int tp_id) {
      this->tps[tp_id]->load_weights(this->config.num_heads, tp_id * tp_num_heads);
    });
    this->weights_loaded = true;
  }
  // Reduces the per-shard partial outputs token by token: sums all shards'
  // hidden vectors into shard 0's buffer, then copies the result out.
  void merge_results(int qlen, void* output_raw) {
    auto pool = this->config.pool;
    typename KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_t* output =
        (typename KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_t*)output_raw;
    pool->do_work_stealing_job(qlen, [this, output](int token_nth) {
      auto& tp_count = this->tp_count;
      auto& config = this->config;
      auto& local_output_numa = this->local_output_numa;
      // In-place reduction across shards over this token's hidden slice.
      reduce_sum(local_output_numa.data(), tp_count, token_nth * config.hidden_size,
                 token_nth * config.hidden_size + config.hidden_size);
      memcpy(&output[token_nth * config.hidden_size], &local_output_numa[0][token_nth * config.hidden_size],
             config.hidden_size * sizeof(output[0]));
    });
#ifdef DEBUG_THIS_MLA
    // dump_bin_int8("output.bin", output, qlen * this->config.hidden_size);
#endif
  }
};
// Explicit instantiations: float activations with the int8 and int4 ARM KML
// GEMM kernels. float16 activations are not currently instantiated.
template class KML_MLA_TP_QUAN_TEST<float, arm_kml::GemmKernelInt8>;
template class KML_MLA_TP_QUAN_TEST<float, arm_kml::GemmKernelInt4>;
// template class KML_MLA_TP_QUAN_TEST<float16_t>;

View File

@@ -1,546 +0,0 @@
// #define DEBUG_THIS_MLA
#ifdef DEBUG_THIS_MLA
#include "test/debug.hpp"
#endif
#include "mla_quan.h"
// Computes the "nope" (non-RoPE) attention scores along the absorbed-Q path:
//  1) q absorb:  q_kv_lora_rank = k_b_proj^T * q_nope_tmp, per head — folds K's
//     up-projection into Q so scores can be computed directly against the
//     compressed (kv_lora_rank-dim) cache.
//  2) scores:    attention_weights = cc_page * q_attn_absorb, tiled per KV page.
// qlen / kvlen are token counts for the current request; page_table maps logical
// KV pages to physical cache pages. is_decode selects the decode-tuned GEMM.
// NOTE(review): `increamental` is accepted for signature parity with the
// no-absorb path but is not used here.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::nope_attention_q_absorb(int qlen, int kvlen, const std::vector<int>& page_table,
                                                              bool increamental, bool is_decode) {
  // Wrap the GEMM so decode uses the latency-optimized kernel variant.
  auto mul_mat_clearc = [is_decode](auto a, auto b, auto c) {
    if (is_decode) {
      arm_kml::decode_mul_mat_clearc(a, b, c);
    } else {
      arm_kml::mul_mat_clearc(a, b, c);
    }
  };
  auto pool = config.pool->get_subpool(tp_part_idx);
  {
    // q absorb: one task per (head, qlen block, kv_lora_rank block).
    size_t qlen_block = div_up((size_t)qlen, row_block);
    size_t kv_rank_block = div_up(config.kv_lora_rank, col_block);
    auto task_counter = TaskCounter({config.num_heads, qlen_block, kv_rank_block});
    auto task = [&](int task_id) {
      auto [head_idx, qlen_block_idx, kv_rank_block_idx] = task_counter.get<3>(task_id);
      size_t qlen_begin = qlen_block_idx * row_block;
      size_t qlen_end = std::min(qlen_begin + row_block, (size_t)qlen);
      size_t kv_rank_begin = kv_rank_block_idx * col_block;
      size_t kv_rank_end = std::min(kv_rank_begin + col_block, (size_t)config.kv_lora_rank);
      KMatRef this_local_k_b_proj_ref = local_k_b_proj_ref.offset_block(head_idx * config.nope_size, kv_rank_begin,
                                                                        config.nope_size, kv_rank_end - kv_rank_begin);
      this_local_k_b_proj_ref = this_local_k_b_proj_ref.t();
      // printf("q absorb %d [%d,%d),[%d,%d)\n", head_idx, qlen_begin, qlen_end, kv_rank_begin, kv_rank_end);
      mul_mat_clearc(this_local_k_b_proj_ref,
                     q_nope_tmp_ref.offset_col(head_idx * qlen + qlen_begin, qlen_end - qlen_begin),
                     q_kv_lora_rank_ref.offset_block(kv_rank_begin, head_idx * qlen + qlen_begin,
                                                     kv_rank_end - kv_rank_begin, qlen_end - qlen_begin));
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("q absorb [%d] \n", tp_part_idx);
  // dump_bin(file_name + "_k_b_lora", (A *)local_kv_b_proj, config.kv_lora_rank * config.nope_size);
  // dump_bin(file_name + "_q_absorb", (A *)q_absorb[0], config.kv_lora_rank * config.max_qlen);
#endif
#ifdef FORWARD_TIME_PROFILE
  if (is_decode) {
    PROFILE_RECORD_TIME_STAMP("decode q absorb");
  } else {
    PROFILE_RECORD_TIME_STAMP("prefill q absorb");
  }
#endif
  {
    // nope attention: score GEMM over all (q-row, kv-column) tiles. Tiles never
    // cross a cache-page boundary because page size % col_block == 0.
    size_t qlen_block = div_up((size_t)qlen * config.num_heads, col_block);
    size_t kvlen_block = div_up((size_t)kvlen + qlen, col_block);
    TaskCounter task_counter({qlen_block, kvlen_block});
    auto task = [&](int task_id) {
      auto [qlen_block_idx, kvlen_block_idx] = task_counter.get<2>(task_id);
      size_t qlen_begin = qlen_block_idx * col_block;
      size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlen * config.num_heads);
      size_t kvlen_begin = kvlen_block_idx * col_block;
      size_t kvlen_end = std::min(kvlen_begin + col_block, (size_t)kvlen + qlen);
      size_t kv_page = kvlen_begin / config.token_count_in_page;
      size_t token_at_in_page = kvlen_begin % config.token_count_in_page;
      KMatRef this_cc_ref =
          KMatRef((A*)cc_pages[page_table[kv_page]], config.token_count_in_page, cc_size, cc_size, CblasRowMajor);
      this_cc_ref = this_cc_ref.offset_row(token_at_in_page, kvlen_end - kvlen_begin);
      KMatRef this_q_aborb_ref = q_attn_absorb_ref.offset_col(qlen_begin, qlen_end - qlen_begin);
      KMatRef this_attention_weights_ref =
          attention_weights_ref.offset_block(kvlen_begin, qlen_begin, kvlen_end - kvlen_begin, qlen_end - qlen_begin);
      arm_kml::mul_mat_clearc(this_cc_ref, this_q_aborb_ref, this_attention_weights_ref);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("attention weights[%d] \n", tp_part_idx);
#endif
#ifdef FORWARD_TIME_PROFILE
  if (is_decode) {
    PROFILE_RECORD_TIME_STAMP("decode nope attention");
  } else {
    PROFILE_RECORD_TIME_STAMP("prefill nope attention");
  }
#endif
}
// Computes the "nope" attention scores along the non-absorbed path:
//  1) k nope: decompress K from the paged cache (k_nope = k_b_proj * kv_lora);
//     the first nope block of each kv tile also copies the cached RoPE half of K
//     out of the rope pages.
//  2) scores: attention_weights = k^T * q, per head.
// When `increamental` is set, the score GEMM accumulates into the existing
// attention_weights buffer (mul_mat) instead of clearing it first (mul_mat_clearc).
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::nope_attention_no_absorb(int qlen, int kvlen, const std::vector<int>& page_table,
                                                               bool increamental) {
  auto pool = config.pool->get_subpool(tp_part_idx);
  {
    // k nope: one task per (head, kv tile, nope block).
    size_t nope_block = div_up(config.nope_size, row_block);
    size_t kvlen_block = div_up((size_t)kvlen + qlen, col_block);
    auto task_counter = TaskCounter({config.num_heads, kvlen_block, nope_block});
    auto task = [&](int task_id) {
      size_t head_idx = task_counter.at(task_id, 0);
      size_t kvlen_block_idx = task_counter.at(task_id, 1);
      size_t nope_block_idx = task_counter.at(task_id, 2);
      size_t kvlen_begin = kvlen_block_idx * col_block;
      size_t kvlen_end = std::min(kvlen_begin + col_block, (size_t)kvlen + qlen);
      size_t kv_page = kvlen_begin / config.token_count_in_page;
      size_t token_at_in_page = kvlen_begin % config.token_count_in_page;
      size_t nope_begin = nope_block_idx * row_block;
      size_t nope_end = std::min(nope_begin + row_block, config.nope_size);
      auto k_b_ref = local_k_b_proj_ref.offset_row(head_idx * config.nope_size + nope_begin, nope_end - nope_begin);
      KMatRef cc_ref = kv_lora_page_refs[page_table[kv_page]];
      cc_ref = cc_ref.offset_col(token_at_in_page, kvlen_end - kvlen_begin);
      KMatRef this_k_nope_ref = k_nope_ref.offset_block(nope_begin, head_idx * config.max_kvlen + kvlen_begin,
                                                        nope_end - nope_begin, kvlen_end - kvlen_begin);
      arm_kml::mul_mat_clearc(k_b_ref, cc_ref, this_k_nope_ref);
      if (nope_block_idx == 0) {
        // Copy the RoPE half of K exactly once per (head, kv tile).
        auto this_k_rope_ref =
            k_rope_ref.offset_col(head_idx * config.max_kvlen + kvlen_begin, kvlen_end - kvlen_begin);
        auto this_rope_page_ref =
            rope_page_refs[page_table[kv_page]].offset_col(token_at_in_page, kvlen_end - kvlen_begin);
        for (size_t i = 0; i < this_k_rope_ref.C; i++) {
          memcpy(this_k_rope_ref.data + this_k_rope_ref.ld * i, this_rope_page_ref.data + this_rope_page_ref.ld * i,
                 config.rope_size * sizeof(A));
        }
      }
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("k nope");
#endif
  {
    // nope attention: per-head score GEMM over (kv block, q block) tiles.
    size_t kvlen_block = div_up((size_t)kvlen + qlen, row_block);
    size_t qlen_block = div_up((size_t)qlen, col_block);
    auto task_counter = TaskCounter({config.num_heads, kvlen_block, qlen_block});
    auto task = [&](int task_id) {
      size_t head_idx = task_counter.at(task_id, 0);
      size_t kvlen_block_idx = task_counter.at(task_id, 1);
      size_t qlen_block_idx = task_counter.at(task_id, 2);
      size_t kvlen_begin = kvlen_block_idx * row_block;
      size_t kvlen_end = std::min(kvlen_begin + row_block, (size_t)kvlen + qlen);
      size_t qlen_begin = qlen_block_idx * col_block;
      size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlen);
      KMatRef this_k_ref = k_ref.offset_col(head_idx * config.max_kvlen + kvlen_begin, kvlen_end - kvlen_begin);
      this_k_ref = this_k_ref.t();
      KMatRef this_q_ref = q_attn_noabsorb_ref.offset_col(head_idx * qlen + qlen_begin, qlen_end - qlen_begin);
      KMatRef this_attention_weights_ref = attention_weights_ref.offset_block(
          kvlen_begin, head_idx * qlen + qlen_begin, kvlen_end - kvlen_begin, qlen_end - qlen_begin);
      if (increamental) {
        arm_kml::mul_mat(this_k_ref, this_q_ref, this_attention_weights_ref);
      } else {
        arm_kml::mul_mat_clearc(this_k_ref, this_q_ref, this_attention_weights_ref);
      }
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("nope attention no absorb");
#endif
}
// Computes the attention output along the absorbed path:
//  1) o absorb: o_absorb = kv_lora_page * attention_weights, accumulated page by
//     page (the first page clears the accumulator, later pages add into it).
//  2) output:   attention_output = v_b_proj * o_absorb, per head.
// `query` indexes qlens/kvlens/page_tables for the request being processed;
// is_decode selects the decode-tuned GEMM kernel.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_absorb(int query, const std::vector<int>& qlens,
                                                    const std::vector<std::vector<int>>& page_tables,
                                                    const std::vector<int>& kvlens, bool is_decode) {
  // Wrap the GEMM so decode uses the latency-optimized kernel variant.
  auto mul_mat_clearc = [is_decode](auto a, auto b, auto c) {
    if (is_decode) {
      arm_kml::decode_mul_mat_clearc(a, b, c);
    } else {
      arm_kml::mul_mat_clearc(a, b, c);
    }
  };
  auto pool = config.pool->get_subpool(tp_part_idx);
  {
    // by page: accumulate o_absorb across all KV-cache pages of this request.
    size_t page_count = div_up((size_t)kvlens[query] + qlens[query], config.token_count_in_page);
    // size_t loop index: page_count is size_t (was int — signed/unsigned mismatch).
    for (size_t kv_page = 0; kv_page < page_count; kv_page++) {
      // o absorb
      size_t kvlen_begin = kv_page * config.token_count_in_page;
      size_t kvlen_end = std::min(kvlen_begin + config.token_count_in_page, (size_t)kvlens[query] + qlens[query]);
      size_t page_kv_len = kvlen_end - kvlen_begin;
      size_t qlen_block = div_up((size_t)qlens[query] * config.num_heads, row_block);
      size_t kv_rank_block = div_up(config.kv_lora_rank, col_block);
      auto task_counter = TaskCounter({qlen_block, kv_rank_block});
      auto task = [&](int task_id) {
        auto [qlen_block_idx, kv_rank_block_idx] = task_counter.get<2>(task_id);
        size_t qlen_begin = qlen_block_idx * row_block;
        size_t qlen_end = std::min(qlen_begin + row_block, (size_t)qlens[query] * config.num_heads);
        size_t kv_rank_begin = kv_rank_block_idx * col_block;
        size_t kv_rank_end = std::min(kv_rank_begin + col_block, (size_t)config.kv_lora_rank);
        KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][kv_page]];
        kv_lora_page_ref = kv_lora_page_ref.offset_block(kv_rank_begin, 0, kv_rank_end - kv_rank_begin, page_kv_len);
        KMatRef this_attention_weights_ref =
            attention_weights_ref.offset_block(kvlen_begin, qlen_begin, page_kv_len, qlen_end - qlen_begin);
        KMatRef this_o_absorb_ref =
            o_absorb_ref.offset_block(kv_rank_begin, qlen_begin, kv_rank_end - kv_rank_begin, qlen_end - qlen_begin);
        if (kv_page == 0) {
          // First page initializes the accumulator; subsequent pages accumulate.
          arm_kml::mul_mat_clearc(kv_lora_page_ref, this_attention_weights_ref, this_o_absorb_ref);
        } else {
          arm_kml::mul_mat(kv_lora_page_ref, this_attention_weights_ref, this_o_absorb_ref);
        }
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }
  }
#ifdef DEBUG_THIS_MLA
  printf("o absorb[%d]\n", tp_part_idx);
  for (size_t i = 0; i < config.num_heads; i++)
    dump_bin(file_name + "_o_absorb_" + std::to_string(i), (A*)o_absorb_or_v, config.kv_lora_rank * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
  if (is_decode) {
    PROFILE_RECORD_TIME_STAMP("decode o absorb");
  } else {
    PROFILE_RECORD_TIME_STAMP("prefill o absorb");
  }
#endif
  {
    // attention output: per-head up-projection of o_absorb through v_b_proj.
    auto qlen_block = div_up((size_t)qlens[query], col_block);
    auto nope_block = div_up((size_t)config.nope_size, row_block);
    auto task_counter = TaskCounter({config.num_heads, qlen_block, nope_block});
    auto task = [&](int task_id) {
      size_t head_idx = task_counter.at(task_id, 0);
      size_t qlen_block_idx = task_counter.at(task_id, 1);
      size_t nope_block_idx = task_counter.at(task_id, 2);
      size_t qlen_begin = qlen_block_idx * col_block;
      size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);
      size_t nope_begin = nope_block_idx * row_block;
      size_t nope_end = std::min(nope_begin + row_block, (size_t)config.nope_size);
      KMatRef this_local_v_b_proj_ref =
          local_v_b_proj_ref.offset_row(head_idx * config.nope_size + nope_begin, nope_end - nope_begin);
      KMatRef this_o_absorb_ref = o_absorb_ref.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
      KMatRef this_attention_output_ref = attention_output_ref.offset_block(
          head_idx * config.nope_size + nope_begin, qlen_begin, nope_end - nope_begin, qlen_end - qlen_begin);
      mul_mat_clearc(this_local_v_b_proj_ref, this_o_absorb_ref, this_attention_output_ref);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("attention output[%d]\n", tp_part_idx);
  dump_bin(file_name + "_attention_output", (A*)attention_output, config.num_heads * config.nope_size * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
  if (is_decode) {
    PROFILE_RECORD_TIME_STAMP("decode attention output");
  } else {
    PROFILE_RECORD_TIME_STAMP("prefill attention output");
  }
#endif
}
// Computes the attention output along the non-absorbed path:
//  1) v: materialize V from the paged KV-LoRA cache (v = v_b_proj * kv_lora),
//     tiled so a tile never crosses a cache-page boundary.
//  2) output: attention_output = v * attention_weights, per head.
// `query` indexes qlens/kvlens/page_tables for the request being processed.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_no_absorb(int query, const std::vector<int>& qlens,
                                                       const std::vector<std::vector<int>>& page_tables,
                                                       const std::vector<int>& kvlens) {
  auto pool = config.pool->get_subpool(tp_part_idx);
  {
    // v: one task per (head, page, in-page kv block, nope block).
    size_t page_count = div_up((size_t)kvlens[query] + qlens[query], config.token_count_in_page);
    size_t nope_block_count = div_up((size_t)config.nope_size, row_block);
    size_t kvlen_in_page_count = div_up(config.token_count_in_page, col_block);
    // Required so kv tiles align with page boundaries.
    ASSERT_RELEASE(config.token_count_in_page % col_block == 0, "token_count_in_page must be divisible by col_block");
    auto task_counter = TaskCounter({config.num_heads, page_count, kvlen_in_page_count, nope_block_count});
    auto task = [&](int task_id) {
      size_t head_idx = task_counter.at(task_id, 0);
      size_t page_idx = task_counter.at(task_id, 1);
      size_t kvlen_idx = task_counter.at(task_id, 2);
      size_t nope_block_idx = task_counter.at(task_id, 3);
      size_t kvlen_begin = page_idx * config.token_count_in_page + kvlen_idx * col_block;
      size_t kvlen_end = std::min(kvlen_begin + col_block, (size_t)kvlens[query] + qlens[query]);
      if (kvlen_begin >= kvlen_end) return;  // skip the extra block
#ifdef DEBUG_THIS_MLA
      printf("v nope[%d] %d %d %d %d\n", tp_part_idx, head_idx, page_idx, kvlen_begin, kvlen_end);
#endif
      size_t kvlen_begin_in_page = kvlen_begin % config.token_count_in_page;
      size_t nope_begin = nope_block_idx * row_block;
      size_t nope_end = std::min(nope_begin + row_block, (size_t)config.nope_size);
      KMatRef this_local_v_b_proj_ref =
          local_v_b_proj_ref.offset_row(head_idx * config.nope_size + nope_begin, nope_end - nope_begin);
      KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][page_idx]];
      kv_lora_page_ref =
          kv_lora_page_ref.offset_block(0, kvlen_begin_in_page, config.kv_lora_rank, kvlen_end - kvlen_begin);
      KMatRef this_v_ref = v_ref.offset_block(nope_begin, head_idx * config.max_kvlen + kvlen_begin,
                                              nope_end - nope_begin, kvlen_end - kvlen_begin);
      arm_kml::mul_mat_clearc(this_local_v_b_proj_ref, kv_lora_page_ref, this_v_ref);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("v nope[%d] done\n", tp_part_idx);
#endif
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("v nope");
#endif
  {
    // attn output: per-head GEMM of V against the softmaxed attention weights.
    size_t nope_block_count = div_up((size_t)config.nope_size, row_block);
    size_t qlen_block_count = div_up((size_t)qlens[query], col_block);
    auto task_counter = TaskCounter({config.num_heads, nope_block_count, qlen_block_count});
    auto task = [&](int task_id) {
      size_t head_idx = task_counter.at(task_id, 0);
      size_t nope_block_idx = task_counter.at(task_id, 1);
      size_t qlen_block_idx = task_counter.at(task_id, 2);
      size_t nope_begin = nope_block_idx * row_block;
      size_t nope_end = std::min(nope_begin + row_block, (size_t)config.nope_size);
      size_t qlen_begin = qlen_block_idx * col_block;
      size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);
      if (qlen_begin >= qlen_end) return;  // skip the extra block
      KMatRef this_v_ref = v_ref.offset_block(nope_begin, head_idx * config.max_kvlen, nope_end - nope_begin,
                                              kvlens[query] + qlens[query]);
      KMatRef this_attention_weights_ref = attention_weights_ref.offset_col(head_idx * qlens[query], qlens[query]);
      this_attention_weights_ref =
          this_attention_weights_ref.offset_block(0, qlen_begin, kvlens[query] + qlens[query], qlen_end - qlen_begin);
      KMatRef this_attention_output_ref = attention_output_ref.offset_block(
          head_idx * config.nope_size + nope_begin, qlen_begin, nope_end - nope_begin, qlen_end - qlen_begin);
      arm_kml::mul_mat_clearc(this_v_ref, this_attention_weights_ref, this_attention_output_ref);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  dump_bin(file_name + "_attention_output", (A*)attention_output, config.num_heads * config.nope_size * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("attn output no absorb");
#endif
}
// Fused first stage of MLA on quantized activations for one request (`query`):
// projects the hidden states into q_lora, kv_lora and the RoPE part of K in a
// single task grid, writing kv_lora / rope results directly into the paged KV
// cache, dequantizing each block via GemmKernel::apply_scale, and applying RoPE
// to the freshly written rope slice.
//  - qlen_split[i] is request i's row offset into the batched input.
//  - input_ref / qlen_input_ba: quantized activations and their scale metadata.
//  - q_lora_rank_ref / q_lora_rank_out_ref: quantized GEMM output and its
//    dequantized counterpart for the q_lora projection.
//  - is_decode selects the decode-tuned GEMM kernel.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::q_lora_kv_lora_rope_quant(int query, const std::vector<int>& qlens,
                                                                const std::vector<int>& kvlens,
                                                                const std::vector<std::vector<int>>& page_tables,
                                                                std::vector<int>& qlen_split, KMatRefA& input_ref,
                                                                KMatRefC& q_lora_rank_ref, KMatRef& q_lora_rank_out_ref,
                                                                bool is_decode) {
  // Wrap the GEMM so the decode-optimized kernel is used when decoding.
  auto mul_mat_clearc = [is_decode](auto a, auto b, auto c) {
    if (is_decode) {
      arm_kml::decode_mul_mat_clearc(a, b, c);
    } else {
      arm_kml::mul_mat_clearc(a, b, c);
    }
  };
  auto pool = config.pool->get_subpool(tp_part_idx);
  auto total_len = qlens[query] + kvlens[query];
  size_t query_page_count = div_up((size_t)total_len, config.token_count_in_page);
#ifdef DEBUG_THIS_MLA
  file_name = std::to_string(query);
  file_name = "query_" + file_name + "_tp_" + std::to_string(tp_part_idx);
  if (tp_part_idx == 0) {
    printf("qlen %d, kvlen %d, page table: ", qlens[query], kvlens[query]);
    for (auto x : page_tables[query]) {
      printf(" %d,", x);
    }
    printf("\n");
  }
#endif
  // KMatRef qlen_input_ref = input_ref.offset_block(0, qlen_split[query], config.hidden_size, qlens[query]);
  // KMatRefAB qlen_input_ref = input_ref.offset_block(0, qlen_split[query], config.hidden_size, qlens[query]);
  KMatRefA qlen_input_ref = input_ref.offset_row(qlen_split[query], qlens[query]);
  auto qlen_input_ba = local_q_a_proj_quant_ba->offset_row(qlen_split[query], qlens[query]);
  // auto local_q_a_proj_qlen_bb = local_q_a_proj_bb->offset_col(qlen_split[query], qlens[query]);
  {
    // q lora, kv lora, rope: one task per (page, in-page row block, output column
    // block). The third counter dimension is partitioned into q_lora, then
    // kv_lora, then k-rope column blocks.
    size_t cc_page_begin = (kvlens[query]) / config.token_count_in_page;
    size_t cc_page_end = div_up((size_t)kvlens[query] + qlens[query], config.token_count_in_page);
    size_t block_per_page = div_up(config.token_count_in_page, row_block_q_lora_kv_lora_rope);
    size_t q_lora_rank_block_count = div_up((size_t)config.q_lora_rank, col_block_q_lora_kv_lora_rope);
    size_t kv_lora_rank_block_count = div_up((size_t)config.kv_lora_rank, col_block_q_lora_kv_lora_rope);
    size_t k_rope_block_count = div_up((size_t)config.rope_size, col_block_q_lora_kv_lora_rope);
    TaskCounter task_counter({cc_page_end - cc_page_begin, block_per_page,
                              q_lora_rank_block_count + kv_lora_rank_block_count + k_rope_block_count});
    auto task = [&](int task_id) {
      size_t cc_page = task_counter.at(task_id, 0) + cc_page_begin;
      size_t in_page_block_idx = task_counter.at(task_id, 1);
      // Clamp the row range to this request's new tokens, [kvlen, kvlen + qlen).
      size_t kvlen_begin =
          std::clamp(cc_page * config.token_count_in_page + in_page_block_idx * row_block_q_lora_kv_lora_rope,
                     (size_t)kvlens[query], (size_t)kvlens[query] + qlens[query]);
      size_t kvlen_end =
          std::clamp(cc_page * config.token_count_in_page + (in_page_block_idx + 1) * row_block_q_lora_kv_lora_rope,
                     (size_t)kvlens[query], (size_t)kvlens[query] + qlens[query]);
      // printf("kvlen[%d,%d)\n", kvlen_begin, kvlen_end);
      size_t kvlen_block_size = kvlen_end - kvlen_begin;
      if (kvlen_block_size == 0) {
        // This row block lies entirely outside the new tokens.
        return;
      }
      size_t qlen_begin = kvlen_begin - kvlens[query];
      size_t qlen_end = kvlen_end - kvlens[query];
      auto blocked_input = qlen_input_ref.offset_row(qlen_begin, qlen_end - qlen_begin);
      // auto blocked_input = qlen_input_ref.offset_block(0, qlen_begin, config.hidden_size, qlen_end - qlen_begin);
      // auto local_q_a_proj_qlen_blocked_bb = local_q_a_proj_qlen_bb.offset_col(qlen_begin, qlen_end - qlen_begin);
      int q_or_kv_or_krope = task_counter.at(task_id, 2);
      if (q_or_kv_or_krope < q_lora_rank_block_count) {
        // q_lora projection: quantized GEMM, then dequantize into q_lora_rank_out.
        size_t q_lora_rank_block_idx = q_or_kv_or_krope;
        size_t q_lora_rank_begin = q_lora_rank_block_idx * col_block_q_lora_kv_lora_rope;
        size_t q_lora_rank_end = std::min(config.q_lora_rank, q_lora_rank_begin + col_block_q_lora_kv_lora_rope);
        mul_mat_clearc(blocked_input,
                       local_q_a_proj_quant_ref.offset_col(q_lora_rank_begin, q_lora_rank_end - q_lora_rank_begin),
                       q_lora_rank_ref.offset_block(qlen_begin, q_lora_rank_begin, qlen_end - qlen_begin,
                                                    q_lora_rank_end - q_lora_rank_begin));
        GemmKernel::apply_scale(q_lora_rank_out_ref.data, q_lora_rank_out_ref.ld, &qlen_input_ba,
                                local_q_a_proj_quant_bb, local_q_a_proj_quant_bc, qlen_begin, qlen_end,
                                q_lora_rank_begin, q_lora_rank_end, true);
#ifdef DEBUG_THIS_MLA
        // Print the first ten entries of q_lora_rank_ref and q_lora_rank_out_ref.
        printf("q_lora_rank_begin:%d, q_lora_rank_end:%d\n", q_lora_rank_begin, q_lora_rank_end);
        printf("qlen_begin:%d, qlen_end:%d\n", qlen_begin, qlen_end);
        if (tp_part_idx == 0) {
          for (int i = 0; i < 10; i++) {
            printf("q_lora_rank_ref:%d, %f\n", q_lora_rank_ref.data[i], q_lora_rank_out_ref.data[i]);
          }
        }
#endif
      } else if (q_or_kv_or_krope < q_lora_rank_block_count + kv_lora_rank_block_count) {
        // kv_lora projection, written straight into the paged KV cache.
        size_t kv_lora_rank_block_idx = q_or_kv_or_krope - q_lora_rank_block_count;
        size_t kv_lora_rank_begin = kv_lora_rank_block_idx * col_block_q_lora_kv_lora_rope;
        size_t kv_lora_rank_end = std::min(config.kv_lora_rank, kv_lora_rank_begin + col_block_q_lora_kv_lora_rope);
        KMatRefC kv_lora_page_ref = kv_lora_page_refs_decode_buffer[page_tables[query][cc_page]].offset_block(
            kvlen_begin % config.token_count_in_page, kv_lora_rank_begin, kvlen_end - kvlen_begin,
            kv_lora_rank_end - kv_lora_rank_begin);
        mul_mat_clearc(
            blocked_input,
            local_kv_a_proj_with_mqa_decode_ref.offset_col(kv_lora_rank_begin, kv_lora_rank_end - kv_lora_rank_begin),
            kv_lora_page_ref);
        KMatRef kv_lora_page_out_ref = kv_lora_page_refs[page_tables[query][cc_page]];
        // NOTE(review): the trailing row offset realigns batch-local rows to
        // page-local rows — confirm against apply_scale's offset contract.
        GemmKernel::apply_scale(
            kv_lora_page_out_ref.data, kv_lora_page_out_ref.ld, &qlen_input_ba, local_kv_a_proj_with_mqa_quant_bb,
            kv_lora_page_refs_decode_buffer[page_tables[query][cc_page]].data, qlen_begin, qlen_end, kv_lora_rank_begin,
            kv_lora_rank_end, true, (kvlen_begin) % config.token_count_in_page - qlen_begin, 0);
      } else if (q_or_kv_or_krope < q_lora_rank_block_count + kv_lora_rank_block_count + k_rope_block_count) {
        // single block for k rope, no norm
        size_t rope_block_idx = q_or_kv_or_krope - q_lora_rank_block_count - kv_lora_rank_block_count;
        size_t rope_begin = rope_block_idx * col_block_q_lora_kv_lora_rope;
        size_t rope_end = std::min(config.rope_size, rope_begin + col_block_q_lora_kv_lora_rope);
        KMatRefC rope_page_ref = rope_page_refs_decode_buffer[page_tables[query][cc_page]].offset_block(
            kvlen_begin % config.token_count_in_page, rope_begin, kvlen_end - kvlen_begin, rope_end - rope_begin);
        mul_mat_clearc(
            blocked_input,
            local_kv_a_proj_with_mqa_decode_ref.offset_col(config.kv_lora_rank + rope_begin, rope_end - rope_begin),
            rope_page_ref);
        KMatRef rope_page_out_ref = rope_page_refs[page_tables[query][cc_page]];
        // The rope columns live past kv_lora_rank in the projection weights; the
        // final -(kv_lora_rank) shifts output columns back to a rope-local index.
        GemmKernel::apply_scale(rope_page_out_ref.data, rope_page_out_ref.ld, &qlen_input_ba,
                                local_kv_a_proj_with_mqa_quant_bb,
                                rope_page_refs_decode_buffer[page_tables[query][cc_page]].data, qlen_begin, qlen_end,
                                config.kv_lora_rank + rope_begin, config.kv_lora_rank + rope_end, true,
                                (kvlen_begin) % config.token_count_in_page - qlen_begin, -(config.kv_lora_rank));
        rope_page_out_ref = rope_page_out_ref.offset_block(rope_begin, kvlen_begin % config.token_count_in_page,
                                                           rope_end - rope_begin, kvlen_end - kvlen_begin);
        // RoPE is applied in place on the dequantized cache slice.
        T_RopeApplier::apply_multiple(*rope_angle, rope_page_out_ref.data, config.rope_size, rope_page_out_ref.ld,
                                      kvlen_begin, kvlen_block_size);
      } else {
        throw std::runtime_error("task id wrong");
      }
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("q lora, kv lora, rope[%d]\n", tp_part_idx);
  // dump_bin(file_name + "_input.bin", qlen_input_ref.data, qlens[query] * config.hidden_size);
  dump_bin(file_name + "_qlora.bin", q_lora_rank, qlens[query] * config.q_lora_rank);
  for (int i = 0; i < query_page_count; i++) {
    dump_bin(file_name + "_page_" + std::to_string(i) + "_cc_pages", (A*)cc_pages[page_tables[query][i]],
             config.token_count_in_page * cc_size);
  }
#endif
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("q lora, kv lora, rope");
#endif
}

View File

@@ -1,310 +0,0 @@
#include <cstdio>
#include "kblas.h"
#include "la/arm_kml.hpp"
#include "mla_quan.h"
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward_decode(const std::vector<int>& qlens,
const std::vector<std::vector<int>>& page_tables,
const std::vector<int>& kvlens,
const std::vector<void*>& attention_masks,
const input_t* input_raw, output_t* output_raw) {
#ifdef DEBUG_THIS_MLA
printf("input raw[%d]\n", tp_part_idx);
#endif
#ifdef FORWARD_TIME_PROFILE
forward_perf_start();
#endif
auto pool = config.pool->get_subpool(tp_part_idx);
std::vector<int> qlen_split, total_len_split;
qlen_split.reserve(qlens.size() + 1);
qlen_split.push_back(0);
total_len_split.reserve(qlens.size() + 1);
int qlen_sum = 0;
int total_len_sum = 0;
for (size_t i = 0; i < qlens.size(); i++) {
qlen_sum += qlens[i];
qlen_split.push_back(qlen_sum);
total_len_sum += qlens[i] + kvlens[i];
total_len_split.push_back(total_len_sum);
}
// 输入进行量化,输入是 hidden_size * qlen 且是 col major
{
size_t nth = GemmKernel::recommended_nth(qlen_sum);
auto task_counter = TaskCounter({nth});
auto task = [&](int task_id) {
size_t nth_idx = task_counter.at(task_id, 0);
local_q_a_proj_quant_ba->from_mat(qlen_sum, const_cast<A*>(input_raw), nth_idx, nth);
// local_q_a_proj_deprecated_bb->from_mat(const_cast<A *>(input_raw), nth_idx, nth, qlen_sum);
};
DIRECT_OR_POOL_BY(qlen_sum, 10, task_counter.count(), task);
}
#ifdef FORWARD_TIME_PROFILE
PROFILE_RECORD_TIME_STAMP("input_quant");
#endif
KMatRefA input_ref =
KMatRefA(local_q_a_proj_quant_ba->a, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
auto output_ref = KMatRef(output_raw, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
KMatRefC q_lora_rank_ref =
KMatRefC(local_q_a_proj_quant_bc->c, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
KMatRef q_lora_rank_out_ref =
KMatRef((A*)q_lora_rank, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
for (int query = 0; query < qlens.size(); query++) {
bool use_absorb = true;
q_lora_kv_lora_rope_quant(query, qlens, kvlens, page_tables, qlen_split, input_ref, q_lora_rank_ref,
q_lora_rank_out_ref, true);
#ifdef DEBUG_THIS_MLA
if (tp_part_idx == 0) {
dump_bin("input.bin", const_cast<A*>(input_raw), qlen_sum * config.hidden_size);
}
#endif
{
// norm
auto task_counter = TaskCounter({2, (size_t)qlens[query]});
auto task = [&](int task_id) {
int q_or_k_kpe = task_counter.at(task_id, 0);
size_t qlen_idx = task_counter.at(task_id, 1);
if (q_or_k_kpe == 0) {
T_RMSNorm::rms_norm_single_with_weights(config.q_lora_rank, local_q_a_norm,
q_lora_rank_out_ref.offset_row(qlen_idx, 1).data);
} else if (q_or_k_kpe == 1) {
auto kv_page = (qlen_idx + kvlens[query]) / config.token_count_in_page;
auto token_at_in_page = (qlen_idx + kvlens[query]) % config.token_count_in_page;
KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][kv_page]];
T_RMSNorm::rms_norm_single_with_weights(config.kv_lora_rank, local_kv_a_norm,
kv_lora_page_ref.offset_col(token_at_in_page, 1).data);
} else {
throw std::runtime_error("unknown task");
}
};
pool->do_work_stealing_job(task_counter.count(), task);
}
#ifdef DEBUG_THIS_MLA
printf("q lora norm[%d]\n", tp_part_idx);
dump_bin(file_name + "_qlora_norm.bin", q_lora_rank, qlens[query] * config.q_lora_rank);
// for (int i = 0; i < query_page_count; i++) {
// dump_bin(file_name + "_page_" + std::to_string(i) + "_kv_lora_rank_norm",
// (A *)kv_lora_pages[page_tables[query][i]], config.token_count_in_page * config.kv_lora_rank);
// }
#endif
#ifdef FORWARD_TIME_PROFILE
PROFILE_RECORD_TIME_STAMP("q/kv lora norm");
#endif
{
// q rope & nope
size_t qlen_block = div_up((size_t)qlens[query], col_block);
TaskCounter task_counter({config.num_heads, 2, qlen_block});
auto task = [&](int task_id) {
auto head_idx = task_counter.at(task_id, 0);
bool nope_or_rope = (task_counter.at(task_id, 1) == 0);
auto qlen_block_idx = task_counter.at(task_id, 2);
size_t qlen_begin = qlen_block_idx * col_block;
size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);
KMatRef b = q_lora_rank_out_ref.trans_view().offset_col(qlen_begin, qlen_end - qlen_begin);
if (nope_or_rope) {
auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size), config.nope_size);
KMatRef c = use_absorb ? q_nope_tmp_ref : q_nope_ref;
c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
arm_kml::decode_mul_mat_clearc(a, b, c);
} else {
auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size) + config.nope_size,
config.rope_size);
KMatRef c = use_absorb ? q_pe_absorb_ref : q_pe_noabsorb_ref;
c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
arm_kml::decode_mul_mat_clearc(a, b, c);
T_RopeApplier::apply_multiple(*rope_angle, c.data, config.rope_size, c.ld, qlen_begin + kvlens[query],
qlen_end - qlen_begin);
}
};
pool->do_work_stealing_job(task_counter.count(), task);
}
#ifdef DEBUG_THIS_MLA
printf("q nope/rope[%d]\n", tp_part_idx);
if (use_absorb) {
dump_bin(file_name + "_q_nope", q_nope, config.nope_size * qlens[query]);
} else {
dump_bin(file_name + "_q_nope", q, q_ld * qlens[query]);
}
dump_bin(file_name + "_q", q, q_ld * qlens[query]);
dump_bin(file_name + "_rope_cos", rope_angle->cos(0), config.rope_size / 2 * qlens[query]);
dump_bin(file_name + "_rope_sin", rope_angle->sin(0), config.rope_size / 2 * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
PROFILE_RECORD_TIME_STAMP("decode q nope/rope");
#endif
nope_attention_q_absorb(qlens[query], kvlens[query], page_tables[query], false, true);
#ifdef DEBUG_THIS_MLA
printf("attention weights[%d] \n", tp_part_idx);
for (size_t i = 0; i < config.num_heads; i++)
dump_bin(file_name + "_raw_attention_weights_" + std::to_string(i),
attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
#endif
{
// attentino mask & soft max
auto task_counter = TaskCounter({config.num_heads, (size_t)qlens[query]});
auto task = [&](int task_id) {
size_t head_idx = task_counter.at(task_id, 0);
size_t qlen_idx = task_counter.at(task_id, 1);
size_t qlen_from_start = qlen_idx + kvlens[query];
A* aw = offset_pointer(attention_weights,
(config.max_kvlen * qlens[query] * head_idx + config.max_kvlen * qlen_idx) * sizeof(A));
for (int i = 0; i < kvlens[query] + qlens[query]; i++) {
aw[i] *= softmax_scale;
aw[i] += static_cast<A*>(attention_masks[qlen_from_start])[i];
}
T_SoftmaxApplier::apply_single(aw, kvlens[query] + qlens[query]);
};
pool->do_work_stealing_job(task_counter.count(), task);
}
#ifdef DEBUG_THIS_MLA
printf("attention weights after softmax[%d] \n", tp_part_idx);
for (size_t i = 0; i < config.num_heads; i++)
dump_bin(file_name + "_attention_weights_" + std::to_string(i),
attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
PROFILE_RECORD_TIME_STAMP("attention mask & softmax");
#endif
output_absorb(query, qlens, page_tables, kvlens, true);
#ifdef DEBUG_THIS_MLA
printf("attn output done[%d]\n", tp_part_idx);
#endif
// 量化输入attention_output_ref->data (attention_output) [config.nope_size * config.num_heads, qlens[query]] col
// major
{
size_t nth = GemmKernel::recommended_nth(qlens[query]);
auto task_counter = TaskCounter({nth});
auto task = [&](int task_id) {
size_t nth_idx = task_counter.at(task_id, 0);
local_w_o_quant_ba->from_mat(qlens[query], attention_output_ref.data, nth_idx, nth);
};
DIRECT_OR_POOL_BY(qlens[query], 10, task_counter.count(), task);
}
{
KMatRefA local_w_o_ba_ref = KMatRefA(local_w_o_quant_ba->a, config.nope_size * config.num_heads, qlens[query],
config.nope_size * config.num_heads, CblasColMajor);
auto qlen_block = div_up((size_t)qlens[query], decode_col_block_out_by_head);
auto hidden_size_block = div_up((size_t)config.hidden_size, decode_row_block_out_by_head);
auto task_counter = TaskCounter({div_up(config.num_heads, sub_num_heads_decode), qlen_block, hidden_size_block});
auto task = [&](int task_id) {
size_t head_idx = task_counter.at(task_id, 0);
size_t qlen_block_idx = task_counter.at(task_id, 1);
size_t hidden_size_block_idx = task_counter.at(task_id, 2);
size_t qlen_begin = qlen_block_idx * decode_col_block_out_by_head;
size_t qlen_end = std::min(qlen_begin + decode_col_block_out_by_head, (size_t)qlens[query]);
size_t hidden_size_begin = hidden_size_block_idx * decode_row_block_out_by_head;
size_t hidden_size_end = std::min(hidden_size_begin + decode_row_block_out_by_head, (size_t)config.hidden_size);
size_t head_begin = head_idx * config.nope_size * sub_num_heads_decode;
size_t head_end =
std::min(head_begin + config.nope_size * sub_num_heads_decode, (size_t)config.nope_size * config.num_heads);
#ifdef DEBUG_THIS_MLA
// printf("tp idx: "
// "%d,head_idx:%d,head_begin:%d,head_end:%d,qlen_begin:%d,qlen_end:%d,hidden_size_begin:%d,hidden_size_"
// "end:%d,qlen_sum:%d,qlen_split:%d,qlens[query]:%d\n",
// tp_part_idx, head_idx, head_begin, head_end, qlen_begin, qlen_end, hidden_size_begin, hidden_size_end,
// qlen_sum, qlen_split[query], qlens[query]);
#endif
auto output_ref_buffer =
KMatRefC(local_w_o_decode_bc[head_idx]->c, config.hidden_size, qlen_sum, config.hidden_size, CblasColMajor);
KMatRefC qlen_output_ref_buffer = output_ref_buffer.offset_col(qlen_split[query], qlens[query]);
KMatRefB this_local_w_o_ref =
KMatRefB(local_w_o_decode_bb[head_idx]->b, config.hidden_size, sub_num_heads_decode * config.nope_size,
sub_num_heads_decode * config.nope_size, CblasRowMajor, CblasNoTrans,
local_w_o_decode_bb[head_idx]->if_pack)
.offset_row(hidden_size_begin, hidden_size_end - hidden_size_begin);
KMatRefA this_attention_output_ref =
local_w_o_ba_ref.offset_block(head_begin, qlen_begin, head_end - head_begin, qlen_end - qlen_begin);
KMatRefC qlen_output_ref = qlen_output_ref_buffer.offset_block(
hidden_size_begin, qlen_begin, hidden_size_end - hidden_size_begin, qlen_end - qlen_begin);
arm_kml::decode_mul_mat_clearc(this_attention_output_ref.trans_view(), this_local_w_o_ref.trans_view(),
qlen_output_ref.trans_view());
#ifdef DEBUG_THIS_MLA
dump_bin(file_name + "_output_by_head_quant_" + std::to_string(head_idx), local_w_o_decode_bc[head_idx]->c,
qlens[query] * config.hidden_size);
#endif
KMatRef qlen_output_out_ref =
KMatRef(qlen_quant_output[head_idx], config.hidden_size, qlens[query], config.hidden_size, CblasColMajor);
// GemmKernel::apply_scale(qlen_output_out_ref.data, qlen_output_out_ref.ld, local_w_o_decode_bb[head_idx],
// local_w_o_quant_ba, local_w_o_decode_bc[head_idx], hidden_size_begin,
// hidden_size_end, qlen_begin, qlen_end, false, 0, qlen_split[query]);
GemmKernel::apply_scale(qlen_output_out_ref.data, qlen_output_out_ref.ld, local_w_o_quant_ba,
local_w_o_decode_bb[head_idx], local_w_o_decode_bc[head_idx], qlen_begin, qlen_end,
hidden_size_begin, hidden_size_end, true, qlen_split[query], 0);
};
pool->do_work_stealing_job(task_counter.count(), task);
}
#ifdef DEBUG_THIS_MLA
printf("decode output by head done[%d]\n", tp_part_idx);
for (int i = 0; i < div_up(config.num_heads, sub_num_heads_decode); i++) {
dump_bin(file_name + "_output_by_head_dequant_" + std::to_string(i), qlen_quant_output[i],
qlens[query] * config.hidden_size);
}
dump_bin(file_name + "_local_w_o", local_w_o, config.hidden_size * config.nope_size * config.num_heads);
#endif
#ifdef FORWARD_TIME_PROFILE
PROFILE_RECORD_TIME_STAMP("output by head");
#endif
{
// merge output
KMatRef reduce_qlen_output_ref = output_ref.offset_row(qlen_split[query], qlens[query]);
const size_t sum_block = 1024;
const size_t sum_block_count = div_up(config.hidden_size, sum_block);
auto task_counter = TaskCounter({sum_block_count, (size_t)qlens[query]});
pool->do_work_stealing_job(task_counter.count(), [&](int task_id) {
size_t hidden_idx = task_counter.at(task_id, 0);
size_t hidden_begin = hidden_idx * sum_block;
size_t hidden_end = std::min(hidden_begin + sum_block, (size_t)config.hidden_size);
size_t qlen_idx = task_counter.at(task_id, 1);
reduce_sum(qlen_quant_output.data(), div_up(config.num_heads, sub_num_heads_decode),
qlen_idx * config.hidden_size + hidden_begin, qlen_idx * config.hidden_size + hidden_end);
});
memcpy(reduce_qlen_output_ref.data, qlen_quant_output[0], qlens[query] * config.hidden_size * sizeof(A));
}
#ifdef DEBUG_THIS_MLA
dump_bin(file_name + "_qlen_output", (A*)output_raw, qlens[query] * config.hidden_size);
#endif
#ifdef FORWARD_TIME_PROFILE
PROFILE_RECORD_TIME_STAMP("decode merge output tp");
#endif
#ifdef FORWARD_TIME_PROFILE
time_perf_name = "[mla] layer " + std::to_string(config.layer_idx) +
" tp_part_idx: " + std::to_string(tp_part_idx) + ", query: " + std::to_string(query);
perf_report();
#endif
}
}

View File

@@ -1,296 +0,0 @@
// #define DEBUG_THIS_MLA
#ifdef DEBUG_THIS_MLA
#include "test/debug.hpp"
#endif
#include <cstdio>
#include "la/arm_kml.hpp"
#include "mla_quan.h"
// Prefill forward pass of the quantized, tensor-parallel KML MLA attention operator.
// Per query in the batch the pipeline is: quantize input -> q_a/kv_a(+rope) projection
// -> RMS norm -> q_b projection split into nope/rope halves (rope applied) -> attention
// (absorbed or non-absorbed path chosen by decide_absorb) -> mask + softmax ->
// attention-value product -> re-quantize -> W_O output projection accumulated over head
// groups, with dequant scales applied on the last group.
//
// qlens/kvlens: per-query new-token and cached-token counts; page_tables: paged KV
// indices per query; attention_masks: per-position additive mask rows (type A);
// input_raw/output_raw: activation in/out buffers.
// NOTE(review): input is treated as hidden_size x qlen, column major (see quant step
// below) — confirm against the class header, which is outside this file.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward_prefill(const std::vector<int>& qlens,
                                                      const std::vector<std::vector<int>>& page_tables,
                                                      const std::vector<int>& kvlens,
                                                      const std::vector<void*>& attention_masks,
                                                      const input_t* input_raw, output_t* output_raw) {
#ifdef DEBUG_THIS_MLA
  printf("input raw[%d]\n", tp_part_idx);
#endif
#ifdef FORWARD_TIME_PROFILE
  forward_perf_start();
#endif
  auto pool = config.pool->get_subpool(tp_part_idx);
  // Prefix sums over the batch: qlen_split[i] is the row offset of query i in the
  // packed activation; total_len_split tracks qlen+kvlen per query.
  std::vector<int> qlen_split, total_len_split;
  qlen_split.reserve(qlens.size() + 1);
  qlen_split.push_back(0);
  total_len_split.reserve(qlens.size() + 1);
  int qlen_sum = 0;
  int total_len_sum = 0;
  for (size_t i = 0; i < qlens.size(); i++) {
    qlen_sum += qlens[i];
    qlen_split.push_back(qlen_sum);
    total_len_sum += qlens[i] + kvlens[i];
    total_len_split.push_back(total_len_sum);
  }
  // Quantize the input; the input is hidden_size * qlen and column major.
  {
    size_t nth = GemmKernel::recommended_nth(qlen_sum);
    auto task_counter = TaskCounter({nth});
    auto task = [&](int task_id) {
      size_t nth_idx = task_counter.at(task_id, 0);
      // local_q_a_proj_deprecated_bb->from_mat(const_cast<A *>(input_raw), nth_idx, nth, qlen_sum);
      local_q_a_proj_quant_ba->from_mat(qlen_sum, const_cast<A*>(input_raw), nth_idx, nth);
    };
    // Run inline for tiny batches, otherwise spread over the pool.
    DIRECT_OR_POOL_BY(qlen_sum, 10, task_counter.count(), task);
  }
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("input_quant");
#endif
  // Matrix views over the quantized input, the final output, and the staging buffers.
  auto input_ref =
      KMatRefA(local_q_a_proj_quant_ba->a, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
  auto output_ref = KMatRef(output_raw, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
  auto output_ref_buffer =
      KMatRefC(local_w_o_prefill_bc->c, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
  KMatRefC q_lora_rank_ref =
      KMatRefC(local_q_a_proj_quant_bc->c, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
  KMatRef q_lora_rank_out_ref =
      KMatRef((A*)q_lora_rank, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
  for (int query = 0; query < qlens.size(); query++) {
    bool use_absorb = decide_absorb(qlens[query], kvlens[query]);
    // q_a / kv_a projections plus rope input, quantized, for this query.
    q_lora_kv_lora_rope_quant(query, qlens, kvlens, page_tables, qlen_split, input_ref, q_lora_rank_ref,
                              q_lora_rank_out_ref);
    {
      // norm: task dim 0 selects q-lora (0) vs kv-lora (1) RMS norm per token.
      auto task_counter = TaskCounter({2, (size_t)qlens[query]});
      auto task = [&](int task_id) {
        int q_or_k_kpe = task_counter.at(task_id, 0);
        size_t qlen_idx = task_counter.at(task_id, 1);
        if (q_or_k_kpe == 0) {
          T_RMSNorm::rms_norm_single_with_weights(config.q_lora_rank, local_q_a_norm,
                                                  q_lora_rank_out_ref.offset_row(qlen_idx, 1).data);
        } else if (q_or_k_kpe == 1) {
          // kv-lora rows live in paged KV storage; locate the page and slot.
          auto kv_page = (qlen_idx + kvlens[query]) / config.token_count_in_page;
          auto token_at_in_page = (qlen_idx + kvlens[query]) % config.token_count_in_page;
          KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][kv_page]];
          T_RMSNorm::rms_norm_single_with_weights(config.kv_lora_rank, local_kv_a_norm,
                                                  kv_lora_page_ref.offset_col(token_at_in_page, 1).data);
        } else {
          throw std::runtime_error("unknown task");
        }
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }
#ifdef DEBUG_THIS_MLA
    printf("q lora norm[%d]\n", tp_part_idx);
    dump_bin(file_name + "_qlora_norm.bin", q_lora_rank, qlens[query] * config.q_lora_rank);
    // for (int i = 0; i < query_page_count; i++) {
    //   dump_bin(file_name + "_page_" + std::to_string(i) + "_kv_lora_rank_norm",
    //            (A *)kv_lora_pages[page_tables[query][i]], config.token_count_in_page * config.kv_lora_rank);
    // }
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("q/kv lora norm");
#endif
    {
      // q rope & nope: q_b projection, split per head into the nope part and the
      // rope part (rotary embedding applied to the latter).
      size_t qlen_block = div_up((size_t)qlens[query], col_block);
      TaskCounter task_counter({config.num_heads, 2, qlen_block});
      auto task = [&](int task_id) {
        auto head_idx = task_counter.at(task_id, 0);
        bool nope_or_rope = (task_counter.at(task_id, 1) == 0);
        auto qlen_block_idx = task_counter.at(task_id, 2);
        size_t qlen_begin = qlen_block_idx * col_block;
        size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);
        KMatRef b = q_lora_rank_out_ref.trans_view().offset_col(qlen_begin, qlen_end - qlen_begin);
        if (nope_or_rope) {
          // nope half of this head's q_b weight rows.
          auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size), config.nope_size);
          KMatRef c = use_absorb ? q_nope_tmp_ref : q_nope_ref;
          c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
          arm_kml::mul_mat_clearc(a, b, c);
        } else {
          // rope half: project, then apply rotary embedding at absolute positions.
          auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size) + config.nope_size,
                                                 config.rope_size);
          KMatRef c = use_absorb ? q_pe_absorb_ref : q_pe_noabsorb_ref;
          c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
          arm_kml::mul_mat_clearc(a, b, c);
          T_RopeApplier::apply_multiple(*rope_angle, c.data, config.rope_size, c.ld, qlen_begin + kvlens[query],
                                        qlen_end - qlen_begin);
        }
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }
#ifdef DEBUG_THIS_MLA
    printf("q nope/rope[%d]\n", tp_part_idx);
    if (use_absorb) {
      dump_bin(file_name + "_q_nope", q_nope, config.nope_size * qlens[query]);
    } else {
      dump_bin(file_name + "_q_nope", q, q_ld * qlens[query]);
    }
    dump_bin(file_name + "_q", q, q_ld * qlens[query]);
    dump_bin(file_name + "_rope_cos", rope_angle->cos(0), config.rope_size / 2 * qlens[query]);
    dump_bin(file_name + "_rope_sin", rope_angle->sin(0), config.rope_size / 2 * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("q nope/rope");
#endif
    // Attention score computation: absorbed path multiplies q through the kv_b
    // weight first; the non-absorbed path materializes full per-head K/V.
    if (use_absorb) {
      nope_attention_q_absorb(qlens[query], kvlens[query], page_tables[query], false);
    } else {
      nope_attention_no_absorb(qlens[query], kvlens[query], page_tables[query], false);
    }
#ifdef DEBUG_THIS_MLA
    printf("attention weights[%d] \n", tp_part_idx);
    for (size_t i = 0; i < config.num_heads; i++)
      dump_bin(file_name + "_raw_attention_weights_" + std::to_string(i),
               attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
#endif
    {
      // attention mask & softmax: scale scores, add the additive mask row for this
      // absolute position, then softmax over the valid kv range.
      auto task_counter = TaskCounter({config.num_heads, (size_t)qlens[query]});
      auto task = [&](int task_id) {
        size_t head_idx = task_counter.at(task_id, 0);
        size_t qlen_idx = task_counter.at(task_id, 1);
        size_t qlen_from_start = qlen_idx + kvlens[query];
        A* aw = offset_pointer(attention_weights,
                               (config.max_kvlen * qlens[query] * head_idx + config.max_kvlen * qlen_idx) * sizeof(A));
        for (int i = 0; i < kvlens[query] + qlens[query]; i++) {
          aw[i] *= softmax_scale;
          aw[i] += static_cast<A*>(attention_masks[qlen_from_start])[i];
        }
        T_SoftmaxApplier::apply_single(aw, kvlens[query] + qlens[query]);
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }
#ifdef DEBUG_THIS_MLA
    printf("attention weights after softmax[%d] \n", tp_part_idx);
    for (size_t i = 0; i < config.num_heads; i++)
      dump_bin(file_name + "_attention_weights_" + std::to_string(i),
               attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("attention mask & softmax");
#endif
    // Attention-weights x V product, matching the path chosen above.
    if (use_absorb) {
      output_absorb(query, qlens, page_tables, kvlens);
    } else {
      output_no_absorb(query, qlens, page_tables, kvlens);
    }
#ifdef DEBUG_THIS_MLA
    printf("attn output done[%d]\n", tp_part_idx);
#endif
    // Quantize the attention output: attention_output_ref->data is
    // [config.nope_size * config.num_heads, qlens[query]], column major.
    {
      size_t nth = GemmKernel::recommended_nth(qlens[query]);
      auto task_counter = TaskCounter({nth});
      auto task = [&](int task_id) {
        size_t nth_idx = task_counter.at(task_id, 0);
        // local_w_o_bb->from_mat(attention_output_ref.data, nth_idx, nth, qlens[query]);
        local_w_o_quant_ba->from_mat(qlens[query], attention_output_ref.data, nth_idx, nth);
      };
      DIRECT_OR_POOL_BY(qlens[query], 10, task_counter.count(), task);
    }
    {
      // output: W_O projection accumulated over head groups of sub_num_heads; the
      // first group clears C, later groups accumulate, and the dequant scales are
      // applied once on the last group.
      auto qlen_block = div_up((size_t)qlens[query], col_block_out_by_head);
      auto hidden_size_block = div_up((size_t)config.hidden_size, row_block_out_by_head);
      KMatRefA local_w_o_ba_ref = KMatRefA(local_w_o_quant_ba->a, config.nope_size * config.num_heads, qlens[query],
                                           config.nope_size * config.num_heads, CblasColMajor);
      KMatRef reduce_qlen_output_ref = output_ref.offset_row(qlen_split[query], qlens[query]).trans_view();
      KMatRefC reduce_qlen_output_ref_buffer =
          output_ref_buffer.offset_row(qlen_split[query], qlens[query]).trans_view();
      for (size_t mhead_idx = 0; mhead_idx < config.num_heads / sub_num_heads; mhead_idx++) {
        auto task_counter = TaskCounter({qlen_block, hidden_size_block});
        auto task = [&](int task_id) {
          size_t head_begin = config.nope_size * mhead_idx * sub_num_heads;
          size_t head_end = head_begin + config.nope_size * sub_num_heads;
          size_t qlen_block_idx = task_counter.at(task_id, 0);
          size_t hidden_size_block_idx = task_counter.at(task_id, 1);
          size_t qlen_begin = qlen_block_idx * col_block_out_by_head;
          size_t qlen_end = std::min(qlen_begin + col_block_out_by_head, (size_t)qlens[query]);
          size_t hidden_size_begin = hidden_size_block_idx * row_block_out_by_head;
          size_t hidden_size_end = std::min(hidden_size_begin + row_block_out_by_head, (size_t)config.hidden_size);
          KMatRefC this_qlen_output_ref_buffer = reduce_qlen_output_ref_buffer.offset_block(
              hidden_size_begin, qlen_begin, hidden_size_end - hidden_size_begin, qlen_end - qlen_begin);
          KMatRefB this_local_w_o_ref = local_w_o_ref.offset_block(
              hidden_size_begin, head_begin, hidden_size_end - hidden_size_begin, head_end - head_begin);
          KMatRefA this_attention_output_ref =
              local_w_o_ba_ref.offset_block(head_begin, qlen_begin, head_end - head_begin, qlen_end - qlen_begin);
          if (mhead_idx == 0) {
            // First head group: overwrite C (clearc variant).
            // arm_kml::mul_mat_clearc(this_local_w_o_ref, this_attention_output_ref, this_qlen_output_ref_buffer);
            arm_kml::mul_mat_clearc(this_attention_output_ref.trans_view(), this_local_w_o_ref.trans_view(),
                                    this_qlen_output_ref_buffer.trans_view());
#ifdef DEBUG_THIS_MLA
            dump_bin(file_name + "_output_by_head_" + std::to_string(mhead_idx), this_qlen_output_ref_buffer.data,
                     qlens[query] * config.hidden_size);
            // printf("if pack: %d\n", this_local_w_o_ref.trans_view().if_pack);
#endif
          } else {
            // Later head groups: accumulate into C.
            // arm_kml::mul_mat(this_local_w_o_ref, this_attention_output_ref, this_qlen_output_ref_buffer);
            arm_kml::mul_mat(this_attention_output_ref.trans_view(), this_local_w_o_ref.trans_view(),
                             this_qlen_output_ref_buffer.trans_view());
#ifdef DEBUG_THIS_MLA
            dump_bin(file_name + "_output_by_head_" + std::to_string(mhead_idx), this_qlen_output_ref_buffer.data,
                     qlens[query] * config.hidden_size);
            // printf("if pack: %d\n", this_local_w_o_ref.trans_view().if_pack);
#endif
          }
          if (mhead_idx == config.num_heads / sub_num_heads - 1) {
            // Last group: dequantize the int32 accumulator into the final output.
            GemmKernel::apply_scale(reduce_qlen_output_ref.data, reduce_qlen_output_ref.ld, local_w_o_quant_ba,
                                    local_w_o_quant_bb, local_w_o_prefill_bc, qlen_begin, qlen_end, hidden_size_begin,
                                    hidden_size_end, true, qlen_split[query], 0);
            // GemmKernel::apply_scale(reduce_qlen_output_ref.data, reduce_qlen_output_ref.ld, local_w_o_quant_bb,
            //                         local_w_o_quant_ba, local_w_o_prefill_bc, hidden_size_begin, hidden_size_end,
            //                         qlen_begin, qlen_end, false, 0, qlen_split[query]);
          }
        };
        pool->do_work_stealing_job(task_counter.count(), task);
      }
    }
#ifdef DEBUG_THIS_MLA
    printf("output by head done[%d]\n", tp_part_idx);
    // dump_bin(file_name + "_local_w_o", local_w_o, config.hidden_size * config.nope_size * config.num_heads);
    dump_bin(file_name + "_qlen_output", (A*)output_raw, qlens[query] * config.hidden_size);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("output by head");
#endif
#ifdef FORWARD_TIME_PROFILE
    time_perf_name = "[mla] layer " + std::to_string(config.layer_idx) +
                     " tp_part_idx: " + std::to_string(tp_part_idx) + ", query: " + std::to_string(query);
    perf_report();
#endif
  }
}

File diff suppressed because it is too large Load Diff

View File

@@ -11,31 +11,20 @@
#ifndef CPUINFER_OPERATOR_KVCACHE_H
#define CPUINFER_OPERATOR_KVCACHE_H
#include <algorithm>
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <random>
#include <stdexcept>
#include <thread>
#include <vector>
#include "../../cpu_backend/worker_pool.h"
#include "llama.cpp/ggml-common.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#define CHUNK_SIZE 32

View File

@@ -9,8 +9,11 @@
**/
#include <chrono>
#include <cmath>
#include "ggml-impl.h"
#include "kvcache.h"
#include "llamafile/sgemm.h"
void KVCache::attention_kvhead_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,
WorkerPool* backend) {

View File

@@ -9,6 +9,8 @@
**/
#include <chrono>
#include <fstream>
#include <iostream>
#include "kvcache.h"

View File

@@ -10,6 +10,7 @@
#include <chrono>
#include "ggml-impl.h"
#include "kvcache.h"
void KVCache::get_anchor_one_block(ggml_fp16_t* anchor, int layer_id, int block_idx, WorkerPool* backend) {

View File

@@ -10,6 +10,7 @@
#include <chrono>
#include "ggml-impl.h"
#include "kvcache.h"
std::string ggml_type_to_string(ggml_type type) {

View File

@@ -1,6 +1,8 @@
#ifndef LLAMAFILE_MOE_HPP
#define LLAMAFILE_MOE_HPP
#ifdef FORWARD_TIME_PROFILE
#include <fmt/format.h>
#endif
#include <numa.h>
#include <numaif.h>
@@ -10,8 +12,6 @@
#include <cstdint>
#include <cstdio>
#include <functional>
#include <iostream>
#include <mutex>
#include <vector>
#include "../../cpu_backend/shared_mem_buffer.h"
@@ -737,8 +737,7 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
public:
using TP_MOE_Common<LLAMA_MOE_TP>::TP_MOE_Common;
void load_weights(const uint64_t* physical_to_logical_map) {
throw std::runtime_error("[llamafile] physical_to_logical_map not supported");
void load_weights() {
auto pool = this->config.pool;
auto inter = this->config.intermediate_size / this->tp_count;
pool->dispense_backend()->do_numa_job([this, pool, inter](int tp_id) {

View File

@@ -20,7 +20,6 @@ concept MOE_TP_PART = requires(T t, int qlen, int k, const int64_t* expert_ids,
template <MOE_TP_PART T>
class TP_MOE_Common : public MoE_Interface {
protected:
GeneralMOEConfig config;
std::vector<GeneralMOEConfig> tp_configs;
int tp_count;
int me_numa_id;
@@ -35,6 +34,7 @@ class TP_MOE_Common : public MoE_Interface {
size_t forward_count = 0;
#endif
public:
GeneralMOEConfig config;
using input_t = typename T::input_t;
TP_MOE_Common(GeneralMOEConfig config) : config(config) {
printf("TP MOE layer %d, pool: 0x%lx, expert num: %d, num_experts_per_tok: %d\n", config.layer_idx,
@@ -142,7 +142,7 @@ class TP_MOE_Common : public MoE_Interface {
#endif
}
virtual void load_weights(const uint64_t* physical_to_logical_map) = 0;
virtual void load_weights() = 0;
virtual void merge_results(int qlen, void* output) = 0;
virtual void merge_results(int qlen, void* output, bool incremental) {

View File

@@ -0,0 +1,63 @@
// Shared scalar-type aliases and CBLAS-style enums for the CPU kernel layer.
// (Hand-rolled STRONG_TYPEDEF below replaces BOOST_STRONG_TYPEDEF(int8_t, int4_2_t).)
#pragma once
#include <cstdint>
#include "llama.cpp/ggml.h"
// Map the kernel-local half-precision aliases onto ggml's storage types, guarded
// so another header can define them first.
#if !defined(CPUINFER_HAS_FLOAT16_T)
using float16_t = ggml_fp16_t;
#define CPUINFER_HAS_FLOAT16_T 1
#endif
#if !defined(CPUINFER_HAS_BFLOAT16_T)
using bfloat16_t = ggml_bf16_t;
#define CPUINFER_HAS_BFLOAT16_T 1
#endif  // CPUINFER_HAS_BFLOAT16_T
// Compile-time flag consumed by the kernel templates.
// NOTE(review): presumably selects the packed B-matrix layout — confirm at use sites.
const bool PACKED = true;
// On Arm targets, force the KML code path even if the build system did not set it.
#if defined(__aarch64__) || defined(__arm__) || defined(CPU_USE_KML)
#ifndef CPU_USE_KML
#define CPU_USE_KML
#endif
#endif  // __aarch64__ || __arm__ || CPU_USE_KML
// Strong typedef: wraps T in a distinct struct D so overload resolution can tell
// otherwise-identical element types apart (e.g. int4 pairs vs plain int8).
#define STRONG_TYPEDEF(T, D)                                     \
  struct D {                                                     \
    T t;                                                         \
    explicit D(const T &v) : t(v) {}                             \
    D() = default;                                               \
    D(const D &) = default;                                      \
    D &operator=(const D &) = default;                           \
    D &operator=(const T &rhs) {                                 \
      t = rhs;                                                   \
      return *this;                                              \
    }                                                            \
    operator const T &() const { return t; }                     \
    operator T &() { return t; }                                 \
    bool operator==(const D &rhs) const { return t == rhs.t; }   \
    bool operator!=(const D &rhs) const { return t != rhs.t; }   \
    bool operator<(const D &rhs) const { return t < rhs.t; }     \
  };
// Element tag for int4 data; name suggests two 4-bit values packed per byte —
// TODO confirm against the int4 kernels.
STRONG_TYPEDEF(int8_t, int4_2_t)
typedef int8_t BLASINT8;
/* matrix transpose or conjugate transpose */
typedef enum KERNEL_CBLAS_TRANSPOSE {
  KernelCblasNoTrans = 111,
  KernelCblasTrans = 112,
  KernelCblasConjTrans = 113,
  KernelCblasConjNoTrans = 114
} KERNEL_CBLAS_TRANSPOSE;
/* matrix stored in rows or cols */
typedef enum KERNEL_CBLAS_ORDER { KernelCblasRowMajor = 101, KernelCblasColMajor = 102 } KERNEL_CBLAS_ORDER;
/* matrix position is left or right */
typedef enum KERNEL_CBLAS_SIDE { KernelCblasLeft = 141, KernelCblasRight = 142 } KERNEL_CBLAS_SIDE;
typedef KERNEL_CBLAS_ORDER KERNEL_CBLAS_LAYOUT;
/* which operand the integer offset vector applies to (enumerator values mirror
   the standard CBLAS_OFFSET constants) */
typedef enum KERNEL_CBLAS_OFFSET {
  KernelCblasRowOffset = 171,
  KernelCblasColOffset = 172,
  KernelCblasFixOffset = 173
} KERNEL_CBLAS_OFFSET;
// Selects between decode-tuned and prefill-tuned kernel variants.
enum class MatKernelVariant {
  Decode,
  Prefill,
};

View File

@@ -0,0 +1,30 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include "common.h"
// Function-pointer type for an integer GEMM kernel in the extended-CBLAS style
// (layout/trans/offset flags, m x n x k sizes, alpha/beta, int8 operand offsets
// oa/ob, int32 accumulator C with offset vector oc selected by offsetc).
// NOTE(review): the exact oa/ob/oc semantics depend on the kernel implementation,
// which is not visible here — confirm before relying on them.
using GemmFn = void (*)(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                        const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
                        const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                        const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta, int32_t* c,
                        const size_t ldc, const int32_t* oc);
// A selected kernel plus its preferred work-splitting granularity in elements.
struct MatKernelSelection {
  GemmFn fn;
  int divide_elements_size;
};
// Kernel lookup per element type; declared here, implemented elsewhere.
MatKernelSelection select_kernel_for_int4(MatKernelVariant variant);
MatKernelSelection select_kernel_for_int8(MatKernelVariant variant);
// Pick the GEMM kernel matching the buffer type T's element type (T::dt):
// int4_2_t buffers get the int4 kernel table, everything else the int8 one.
template <typename T>
MatKernelSelection select_mat_kernel(MatKernelVariant variant) {
  constexpr bool uses_int4 = std::is_same_v<typename T::dt, int4_2_t>;
  return uses_int4 ? select_kernel_for_int4(variant) : select_kernel_for_int8(variant);
}

View File

@@ -1,11 +1,5 @@
#ifndef CPUINFER_OPERATOR_KML_LA_HPP
#define CPUINFER_OPERATOR_KML_LA_HPP
#include "batch_gemm_api.hpp"
// #include <boost/serialization/strong_typedef.hpp>
// #include "../../common.hpp"
#include <arm_sve.h>
#ifndef CPUINFER_OPERATOR_KERNEL_LA_HPP
#define CPUINFER_OPERATOR_KERNEL_LA_HPP
#include <algorithm>
#include <cassert>
@@ -13,35 +7,16 @@
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
#include "kblas.h"
#include "../api/common.h"
#include "../mat_kernel/batch_gemm_api.hpp"
#include "llama.cpp/ggml.h"
#include "utils.hpp"
// BOOST_STRONG_TYPEDEF(int8_t, int4_2_t);
#define STRONG_TYPEDEF(T, D) \
struct D { \
T t; \
explicit D(const T &v) : t(v) {} \
D() = default; \
D(const D &) = default; \
D &operator=(const D &) = default; \
D &operator=(const T &rhs) { \
t = rhs; \
return *this; \
} \
operator const T &() const { return t; } \
operator T &() { return t; } \
bool operator==(const D &rhs) const { return t == rhs.t; } \
bool operator!=(const D &rhs) const { return t != rhs.t; } \
bool operator<(const D &rhs) const { return t < rhs.t; } \
};
STRONG_TYPEDEF(int8_t, int4_2_t);
namespace arm_kml {
static const size_t MAX_Nth_B = 1024, MAX_N_B = 1024, MAX_K_B = 10240;
namespace moe_kernel {
template <typename T>
T *offset_pointer(T *ptr, size_t byte_offset) {
return reinterpret_cast<T *>(reinterpret_cast<char *>(ptr) + byte_offset);
@@ -73,7 +48,8 @@ struct BufferAImpl {
static constexpr int M_STEP = K::M_STEP;
static constexpr int K_STEP = K::K_STEP;
static constexpr int K_BLOCK = K::K_BLOCK;
// K_BLOCK is runtime-configurable via kernel tiling; expose as function to avoid constexpr requirements
static inline int K_BLOCK() { return K::K_BLOCK; }
static constexpr int PACK_SIZE_M = K::PACK_SIZE_M;
static constexpr int PACK_SIZE_K = K::PACK_SIZE_K;
@@ -265,37 +241,6 @@ struct BufferAImpl {
}
}
// Quantize an m x k row-major float16 matrix `src` into the int8 buffer `a`,
// storing one symmetric per-row scale in `d` (scale = row amax / 127).
// Only the single-threaded case (ith == 0, mth == 1) is implemented.
void from_mat(int m, float16_t *src, int ith, int mth) {
  assert(m <= max_m);
  if (!(ith == 0 && mth == 1)) {
    // Fix: the old message ("m must be a multiple of M_STEP") did not describe
    // this condition at all.
    throw std::runtime_error("from_mat(float16_t): only ith == 0 && mth == 1 is supported");
  }
  for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
    for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
      const int row = m_begin + i;
      // Pass 1: per-row absolute maximum.
      // TODO: accelerate with SVE later.
      float amax = 0;
      for (int j = 0; j < k; j++) {
        // widen src to float first
        float f = src[row * k + j];
        f = f < 0 ? -f : f;
        if (f > amax) {
          amax = f;
        }
      }
      d[row] = amax / ((1 << 7) - 1);
      // An all-zero row yields a zero scale; emit zeros instead of computing
      // 0/0 -> NaN, whose int8 cast is undefined behavior.
      if (d[row] == 0) {
        for (int j = 0; j < k; j++) {
          a[row * k + j] = 0;
        }
        continue;
      }
      // Pass 2: quantize the row with its scale.
      // TODO: accelerate with SVE later.
      for (int j = 0; j < k; j++) {
        float f = src[row * k + j];
        a[row * k + j] = static_cast<int8_t>(std::round(f / d[row]));
      }
    }
  }
}
// 反量化
void to_mat(int m, float *dst, int ith, int mth) {
auto [m_start, m_end] = K::split_range_m(m, ith, mth);
@@ -323,7 +268,8 @@ struct BufferCImpl {
static constexpr int M_STEP = K::M_STEP;
static constexpr int N_STEP = K::N_STEP;
static constexpr int N_BLOCK = K::N_BLOCK;
// N_BLOCK is runtime-configurable via kernel tiling; expose as function to avoid constexpr requirements
static inline int N_BLOCK() { return K::N_BLOCK; }
static size_t required_size(int max_m, int n) { return sizeof(int32_t) * max_m * n; }
@@ -347,27 +293,6 @@ struct BufferCImpl {
// }
};
// struct MLAGemmKernelInt8 {
// using dt = int8_t;
// using output_t = int32_t;
// struct BufferA {
// int8_t *a;
// float *d;
// int max_m, k;
// BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {
// assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
// assert(max_m % GemmKernelInt8::M_STEP == 0);
// assert(k % GemmKernelInt8::K_STEP == 0);
// a = reinterpret_cast<int8_t *>(ptr);
// d = reinterpret_cast<float *>(a + max_m * k);
// }
// void from_mat(int m, float16_t *src, int ith, int nth) {}
// float *get_scale(int m, int m_begin) { return d + m_begin; }
// };
// };
struct GemmKernelInt8 {
using dt = int8_t;
using output_t = int32_t;
@@ -385,12 +310,31 @@ struct GemmKernelInt8 {
static const int K_STEP = 1;
// static inline const int N_BLOCK = 1024;
static inline const int N_BLOCK_UP_GATE = 256;
static inline const int N_BLOCK_DOWN = 1024;
static inline const int N_BLOCK = 64;
static inline const int M_BLOCK = 64;
// Make tiling params runtime-configurable (modifiable via Python bindings)
static inline int N_BLOCK_UP_GATE = 32;
static inline int N_BLOCK_DOWN = 64;
static inline int N_BLOCK_UP_GATE_PREFI = 32;
static inline int N_BLOCK_DOWN_PREFI = 64;
static inline int N_BLOCK = 64;
static inline int M_BLOCK = 320;
// static inline const int N_BLOCK = 32;
static inline const int K_BLOCK = 7168;
static inline int K_BLOCK = 7168;
// Setter for the runtime GEMM tiling configuration (exposed through the Python
// bindings so block sizes can be tuned without recompiling). Overwrites all
// seven static tiling parameters at once; no validation is performed here.
static void set_tiling(int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block,
                       int n_block_up_gate_prefi, int n_block_down_prefi) {
  N_BLOCK_UP_GATE = n_block_up_gate;
  N_BLOCK_DOWN = n_block_down;
  N_BLOCK = n_block;
  M_BLOCK = m_block;
  K_BLOCK = k_block;
  N_BLOCK_UP_GATE_PREFI = n_block_up_gate_prefi;
  N_BLOCK_DOWN_PREFI = n_block_down_prefi;
}
static std::tuple<int, int, int, int, int, int, int> get_tiling() {
return std::make_tuple(N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK, N_BLOCK_UP_GATE_PREFI,
N_BLOCK_DOWN_PREFI);
}
static inline const int PACK_SIZE_N = 8;
static inline const int PACK_SIZE_M = 8;
@@ -398,26 +342,44 @@ struct GemmKernelInt8 {
static std::string name() { return "INT8"; }
// Number of N_BLOCK-wide column blocks needed to cover n columns (ceiling division).
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static int recommended_nth_down(int n) {
assert(n % N_BLOCK == 0);
return n / N_BLOCK_DOWN;
// Number of down-projection column blocks for n columns.
// type_: 'd' for decode (default), 'p' for prefill; n must divide evenly by the
// corresponding block size, otherwise std::invalid_argument is thrown.
static int recommended_nth_down(int n, char type_ = 'd') {
  const bool prefill = (type_ == 'p');
  const int block = prefill ? N_BLOCK_DOWN_PREFI : N_BLOCK_DOWN;
  if (n % block != 0) {
    throw std::invalid_argument(prefill ? "n must be multiple of N_BLOCK_DOWN_PREFI in prefill"
                                        : "n must be multiple of N_BLOCK_DOWN in decode");
  }
  return n / block;
}
static int recommended_nth_up_gate(int n) {
assert(n % N_BLOCK_UP_GATE == 0);
return n / N_BLOCK_UP_GATE;
// Number of up/gate-projection column blocks for n columns.
// type_: 'd' for decode (default), 'p' for prefill; n must divide evenly by the
// corresponding block size, otherwise std::invalid_argument is thrown.
static int recommended_nth_up_gate(int n, char type_ = 'd') {
  const bool prefill = (type_ == 'p');
  const int block = prefill ? N_BLOCK_UP_GATE_PREFI : N_BLOCK_UP_GATE;
  if (n % block != 0) {
    throw std::invalid_argument(prefill ? "n must be multiple of N_BLOCK_UP_GATE_PREFI in prefill"
                                        : "n must be multiple of N_BLOCK_UP_GATE in decode");
  }
  return n / block;
}
// Number of M_BLOCK-tall row blocks needed to cover m rows (ceiling division).
static int recommended_mth(int m) { return (m + M_BLOCK - 1) / M_BLOCK; }
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_start = N_BLOCK * ith;
int n_end = std::min(n, N_BLOCK * (ith + 1));
// Half-open column range [start, end) handled by worker ith, in units of
// block_size columns, clamped to n. The nth argument is unused (kept for
// signature compatibility with callers).
static std::pair<int, int> split_range_n(int n, int ith, int nth, int block_size = N_BLOCK) {
  const int lo = block_size * ith;
  const int hi = std::min(n, lo + block_size);
  return {lo, hi};
}
static std::pair<int, int> split_range_m(int m, int ith, int mth) {
static std::pair<int, int> split_range_m(int m, int ith, int mth = 0) {
int m_start = M_BLOCK * ith;
int m_end = std::min(m, M_BLOCK * (ith + 1));
return {m_start, m_end};
@@ -434,26 +396,83 @@ struct GemmKernelInt8 {
struct BufferB {
int8_t *b;
std::vector<int8_t *> b_pack; // b_pack[i] -> the ith block (the ith packed matrix of B)
size_t reorder_B_size;
size_t nth_B; // number of blocks of B
size_t block_size; // size of each block of B
float *d;
int n, k;
static constexpr bool SCALE = true;
bool if_pack = false;
static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k + sizeof(float) * n; }
BufferB(int n, int k, void *ptr, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
b = reinterpret_cast<int8_t *>(ptr);
d = reinterpret_cast<float *>(b + n * k);
// Bytes needed for a BufferB of n columns x k rows.
// mat_type: 'n' for normal, 'u' for up_gate, 'd' for down — selects which block
// size drives the packed layout. When if_pack && !plain, B is stored as nth
// reordered blocks (size from get_reorder_B_size) plus n float scales;
// otherwise as a plain n*k int8 matrix plus n float scales.
// NOTE(review): this switch mirrors the constructor below — keep the two in sync.
static size_t required_size(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) {
  int nth, n_block;
  if (if_pack && !plain) {
    switch (mat_type) {
      case 'n':
        nth = recommended_nth(n);
        n_block = N_BLOCK;
        break;
      case 'u':
        nth = recommended_nth_up_gate(n);
        n_block = N_BLOCK_UP_GATE;
        break;
      case 'd':
        nth = recommended_nth_down(n);
        n_block = N_BLOCK_DOWN;
        break;
      default:
        throw std::invalid_argument("Invalid mat_type");
    }
    size_t reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);
    return sizeof(int8_t) * nth * reorder_B_size + sizeof(float) * n;
  } else {
    return sizeof(int8_t) * n * k + sizeof(float) * n;
  }
}
BufferB(int n, int k, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
// Construct an unbound BufferB (no data pointer yet; call set_data later).
// When if_pack && !plain, precompute the packed-block layout: per-block reorder
// size, number of blocks (nth_B), the block width (block_size), and the b_pack
// pointer table. n and k must be multiples of N_STEP / K_STEP.
// NOTE(review): reorder size is computed with KernelCblasRowMajor here while
// from_mat reorders with KernelCblasColMajor — confirm the two sizes agree.
BufferB(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) : n(n), k(k), if_pack(if_pack) {
  int nth, n_block;
  if (if_pack && !plain) {
    // mat_type selects the tiling family: 'n' normal, 'u' up_gate, 'd' down.
    switch (mat_type) {
      case 'n':
        nth = recommended_nth(n);
        n_block = N_BLOCK;
        break;
      case 'u':
        nth = recommended_nth_up_gate(n);
        n_block = N_BLOCK_UP_GATE;
        break;
      case 'd':
        nth = recommended_nth_down(n);
        n_block = N_BLOCK_DOWN;
        break;
      default:
        throw std::invalid_argument("Invalid mat_type");
    }
    reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);
    nth_B = nth;
    block_size = n_block;
    b_pack.resize(nth);
  }
  if (n % N_STEP != 0 || k % K_STEP != 0) {
    throw std::runtime_error("n and k must be multiples of N_STEP and K_STEP respectively");
  }
}
void set_data(void *ptr) {
b = reinterpret_cast<int8_t *>(ptr);
d = reinterpret_cast<float *>(b + n * k);
// Convenience constructor: set up the layout (delegating to the main
// constructor) and immediately bind the backing storage at ptr.
BufferB(int n, int k, void *ptr, bool if_pack = false, char mat_type = 'n', bool plain = true)
    : BufferB(n, k, if_pack, mat_type, plain) {
  set_data(ptr, plain);
  // printf("mat_type:%c,nth_B:%zu,b_pack_ptr[0]:%p,d_ptr:%p,ptr:%p\n", mat_type, nth_B, b_pack[0], d, ptr);
}
// Bind this buffer to the storage at ptr. Packed non-plain layout: nth_B
// reordered blocks of reorder_B_size bytes each, followed by the n float
// scales. Plain layout: an n*k int8 matrix followed by the n float scales.
void set_data(void *ptr, bool plain = true) {
  if (if_pack && !plain) {
    for (size_t i = 0; i < nth_B; i++) {
      b_pack[i] = reinterpret_cast<int8_t *>(ptr) + i * reorder_B_size;
    }
    // Scales live directly after the last packed block.
    d = reinterpret_cast<float *>((int8_t *)ptr + nth_B * reorder_B_size);
  } else {
    b = reinterpret_cast<int8_t *>(ptr);
    d = reinterpret_cast<float *>(b + n * k);
  }
}
size_t required_size() const { return sizeof(int8_t) * n * k + sizeof(float) * n; }
BufferB offset_col(size_t col_begin, size_t col_block) {
@@ -462,13 +481,17 @@ struct GemmKernelInt8 {
return bufferb;
}
// B 矩阵是 K * N 的矩阵,存储在 b 中, 是列主序的 (column major)
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1,
bool if_pack = false) { // CHECK: nth has no usage
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false,
bool plain = true) { // CHECK: nth has no usage
if (n_new > 0) {
n = n_new; // 如果 n_new 大于 0则使用 n_new
}
// 这里将 src 转换成 int8_t 的形式按照k 维度量化 (也就是按列量化)
auto [n_start, n_end] = split_range_n(n, ith, nth);
int8_t *b_t = nullptr;
if ((if_pack || this->if_pack) && !plain) {
b_t = (int8_t *)malloc(sizeof(int8_t) * n * k);
}
auto [n_start, n_end] = split_range_n(n, ith, nth, block_size);
int n_block_begin = n_start;
int n_block_size = n_end - n_block_begin;
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
@@ -489,7 +512,7 @@ struct GemmKernelInt8 {
for (int j = 0; j < k; j++) {
// 先把 src 转换成 float
float f = bf16_to_fp32(src[(n_block_begin + n_begin + i) * k + j]);
if (if_pack || this->if_pack) {
if ((if_pack || this->if_pack) && plain) {
size_t split_n = (n_begin + i) / PACK_SIZE_N;
size_t n_idx = (n_begin + i) % PACK_SIZE_N;
size_t split_k = j / PACK_SIZE_K;
@@ -498,58 +521,20 @@ struct GemmKernelInt8 {
size_t buff_idx = n_block_begin * k + split_n * PACK_SIZE_N * k + split_k * PACK_SIZE_N * PACK_SIZE_K +
n_idx * PACK_SIZE_K + k_idx;
b[buff_idx] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
} else {
} else if ((if_pack || this->if_pack) && !plain) {
// 这里的 amax 是当前列的最大值
b_t[(n_begin + i) * k + j] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
} else {
b[(n_block_begin + n_begin + i) * k + j] =
static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
}
}
}
}
}
void from_mat(float16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false) { // CHECK: nth has no usage
if (n_new > 0) {
n = n_new; // 如果 n_new 大于 0则使用 n_new
}
// 这里将 src 转换成 int8_t 的形式按照k 维度量化 (也就是按列量化)
auto [n_start, n_end] = split_range_n(n, ith, nth);
int n_block_begin = n_start;
int n_block_size = n_end - n_block_begin;
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {
float amax = 0;
// TODO: 后续用 SVE 来加速
for (int j = 0; j < k; j++) {
// 先把 src 转换成 float
float f = src[(n_block_begin + n_begin + i) * k + j];
f = f < 0 ? -f : f;
if (f > amax) {
amax = f;
}
}
d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);
// TODO: 后续用 SVE 来加速
// 通过这个 amax 来量化这一列
for (int j = 0; j < k; j++) {
// 先把 src 转换成 float
float f = src[(n_block_begin + n_begin + i) * k + j];
if (if_pack || this->if_pack) {
size_t split_n = (n_begin + i) / PACK_SIZE_N;
size_t n_idx = (n_begin + i) % PACK_SIZE_N;
size_t split_k = j / PACK_SIZE_K;
size_t k_idx = j % PACK_SIZE_K;
size_t buff_idx = n_block_begin * k + split_n * PACK_SIZE_N * k + split_k * PACK_SIZE_N * PACK_SIZE_K +
n_idx * PACK_SIZE_K + k_idx;
b[buff_idx] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
} else {
// 这里的 amax 是当前列的最大值
b[(n_block_begin + n_begin + i) * k + j] =
static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
}
}
}
if ((if_pack || this->if_pack) && !plain) {
// 在这里调用 AMD 的reorder函数
reorder_B_gemm(KernelCblasColMajor, KernelCblasNoTrans, k, n_block_size, k, b_t, b_pack[ith]);
free(b_t);
}
}
@@ -704,12 +689,19 @@ struct GemmKernelInt8 {
}
// 对第二个维度分块的 apply scale
static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc, int ith, int nth, int block) {
static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc, int ith, int nth, int block,
int jth = -1) {
// printf("use split apply scale\n");
auto [n_start, n_end] = split_range_n_block(n, ith, nth, block);
int m_start = 0, m_end = m;
if (jth != -1) {
auto tmp = split_range_m(m, jth);
m_start = tmp.first;
m_end = tmp.second;
}
// TODO: 后续用 SVE 来加速
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {
for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {
float *scale_a = ba->get_scale(m, m_begin + i);
for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {
for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {
@@ -811,12 +803,31 @@ struct GemmKernelInt4 {
static const int K_STEP = 1;
// static inline const int N_BLOCK = 1024;
static inline const int N_BLOCK_UP_GATE = 256;
static inline const int N_BLOCK_DOWN = 1024;
static inline const int N_BLOCK = 64;
static inline const int M_BLOCK = 64;
// Make tiling params runtime-configurable (modifiable via Python bindings)
static inline int N_BLOCK_UP_GATE = 256;
static inline int N_BLOCK_DOWN = 1024;
static inline int N_BLOCK_UP_GATE_PREFI = 256;
static inline int N_BLOCK_DOWN_PREFI = 1024;
static inline int N_BLOCK = 64;
static inline int M_BLOCK = 320;
// static inline const int N_BLOCK = 32;
static inline const int K_BLOCK = 7168;
static inline int K_BLOCK = 7168;
// Setter/getter for runtime tiling configuration
// Overwrite the runtime-tunable tiling parameters (exposed via Python bindings).
// Assignments are independent, grouped here by dimension rather than call order.
static void set_tiling(int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block,
                       int n_block_up_gate_prefi, int n_block_down_prefi) {
  N_BLOCK = n_block;
  M_BLOCK = m_block;
  K_BLOCK = k_block;
  N_BLOCK_UP_GATE = n_block_up_gate;
  N_BLOCK_UP_GATE_PREFI = n_block_up_gate_prefi;
  N_BLOCK_DOWN = n_block_down;
  N_BLOCK_DOWN_PREFI = n_block_down_prefi;
}
static std::tuple<int, int, int, int, int, int, int> get_tiling() {
return std::make_tuple(N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK, N_BLOCK_UP_GATE_PREFI,
N_BLOCK_DOWN_PREFI);
}
static inline const int PACK_SIZE_N = 8;
static inline const int PACK_SIZE_K = 32;
@@ -825,15 +836,33 @@ struct GemmKernelInt4 {
static std::string name() { return "INT4"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static int recommended_nth_down(int n) {
assert(n % N_BLOCK == 0);
return n / N_BLOCK_DOWN;
// Thread count for the down projection: n divided by the down-tile width.
// type_ 'p' selects the prefill tile, anything else the decode tile; n must
// divide evenly or std::invalid_argument is thrown.
static int recommended_nth_down(int n, char type_ = 'd') {
  const bool prefill = (type_ == 'p');
  const int divisor = prefill ? N_BLOCK_DOWN_PREFI : N_BLOCK_DOWN;
  if (n % divisor != 0) {
    throw std::invalid_argument(prefill ? "n must be multiple of N_BLOCK_DOWN_PREFI in prefill"
                                        : "n must be multiple of N_BLOCK_DOWN in decode");
  }
  return n / divisor;
}
static int recommended_mth(int m) { return (m + M_BLOCK - 1) / M_BLOCK; }
static int recommended_nth_up_gate(int n) {
assert(n % N_BLOCK_UP_GATE == 0);
return n / N_BLOCK_UP_GATE;
// Thread count for the up/gate projection: n divided by the up/gate tile width.
// type_ 'p' selects the prefill tile, anything else the decode tile; n must
// divide evenly or std::invalid_argument is thrown.
static int recommended_nth_up_gate(int n, char type_ = 'd') {
  const bool prefill = (type_ == 'p');
  const int divisor = prefill ? N_BLOCK_UP_GATE_PREFI : N_BLOCK_UP_GATE;
  if (n % divisor != 0) {
    throw std::invalid_argument(prefill ? "n must be multiple of N_BLOCK_UP_GATE_PREFI in prefill"
                                        : "n must be multiple of N_BLOCK_UP_GATE in decode");
  }
  return n / divisor;
}
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
@@ -860,38 +889,63 @@ struct GemmKernelInt4 {
dt *b;
float *d;
int n, k;
std::vector<int8_t *> b_pack; // b_pack[i] -> the ith block (the ith packed matrix of B)
static constexpr bool SCALE = true;
bool if_pack = false;
static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }
BufferB(int n, int k, void *ptr, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
assert(n % N_STEP == 0);
assert(k % K_STEP == 0);
b = reinterpret_cast<dt *>(ptr);
d = reinterpret_cast<float *>(arm_kml::offset_pointer(b, n * k / 2));
// static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }
// Total bytes a BufferB needs for an n x k int4 weight matrix.
// When the packed, non-plain layout is used, mat_type selects the tile width
// used to size the reordered panels ('n' generic, 'u' up/gate, 'd' down):
// the result is nth reordered panels plus one fp32 scale per column.
// Otherwise it is the plain layout: n*k/2 packed int4 bytes plus the scales.
// NOTE(review): the int8 path reorders with column-major; RowMajor here —
// confirm this asymmetry is intended.
static size_t required_size(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) {
  int nth, n_block;
  if (if_pack && !plain) {
    switch (mat_type) {
      case 'n':  // generic matrix
        nth = recommended_nth(n);
        n_block = N_BLOCK;
        break;
      case 'u':  // up/gate projection
        nth = recommended_nth_up_gate(n);
        n_block = N_BLOCK_UP_GATE;
        break;
      case 'd':  // down projection
        nth = recommended_nth_down(n);
        n_block = N_BLOCK_DOWN;
        break;
      default:
        throw std::invalid_argument("Invalid mat_type");
    }
    size_t reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);
    return sizeof(int8_t) * nth * reorder_B_size + sizeof(float) * n;
  } else {
    return sizeof(int8_t) * n * k / 2 + sizeof(float) * n;
  }
}
BufferB(int n, int k, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
// assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
assert(n % N_STEP == 0);
assert(k % K_STEP == 0);
// BufferB(int n, int k, void *ptr, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
// b = reinterpret_cast<dt *>(ptr);
// d = reinterpret_cast<float *>(moe_kernel::offset_pointer(b, n * k / 2));
// }
// Construct metadata only; storage must be attached later with set_data().
// mat_type and plain are accepted for interface symmetry with required_size()
// but are not stored here.
BufferB(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) : n(n), k(k), if_pack(if_pack) {
  if (n % N_STEP != 0 || k % K_STEP != 0) {
    throw std::runtime_error("n and k must be multiples of N_STEP and K_STEP respectively");
  }
}
void set_data(void *ptr) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
// Convenience overload: construct and immediately bind external storage.
BufferB(int n, int k, void *ptr, bool if_pack = false, char mat_type = 'n', bool plain = true)
    : BufferB(n, k, if_pack, mat_type, plain) {
  set_data(ptr, plain);
}
void set_data(void *ptr, bool plain = true) {
b = reinterpret_cast<dt *>(ptr);
d = reinterpret_cast<float *>(arm_kml::offset_pointer(b, n * k / 2));
d = reinterpret_cast<float *>(moe_kernel::offset_pointer(b, n * k / 2));
}
size_t required_size() const { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }
BufferB offset_col(size_t col_begin, size_t col_block) {
auto bufferb = BufferB(col_block, k, arm_kml::offset_pointer(b, (col_begin * k) / 2), if_pack);
auto bufferb = BufferB(col_block, k, moe_kernel::offset_pointer(b, (col_begin * k) / 2), if_pack);
bufferb.d = d + col_begin;
return bufferb;
}
// B 矩阵是 K * N 的矩阵,存储在 b 中, 是列主序的 (column major)
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1,
bool if_pack = false) { // CHECK: nth has no usage
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false,
bool plain = true) { // CHECK: nth has no usage
if (!if_pack && !this->if_pack) throw std::runtime_error("from mat for buffer should be packed");
if (n_new > 0) {
n = n_new; // 如果 n_new 大于 0则使用 n_new
@@ -939,55 +993,6 @@ struct GemmKernelInt4 {
}
}
// Quantize a column-major fp16 (K x N) matrix into this packed int4 buffer.
// Per-column symmetric quantization along k: d[col] = amax / 112 (112 = 7*16,
// since nibble magnitudes 0..7 are stored pre-shifted by 16 below); two int4
// values are packed per byte in the PACK_SIZE_N x PACK_SIZE_K panel layout.
void from_mat(float16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false) { // CHECK: nth has no usage
  if (!if_pack && !this->if_pack) throw std::runtime_error("from mat for buffer should be packed");
  if (n_new > 0) {
    n = n_new; // if n_new > 0, adopt it as the effective column count
  }
  // Convert src to int4, quantizing along the k dimension (per column).
  auto [n_start, n_end] = split_range_n(n, ith, nth);
  int n_block_begin = n_start;
  int n_block_size = n_end - n_block_begin;
  for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
    for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {
      float amax = 0;
      // TODO: accelerate with SVE later
      for (int j = 0; j < k; j++) {
        // widen src to float first
        float f = src[(n_block_begin + n_begin + i) * k + j];
        f = f < 0 ? -f : f;
        if (f > amax) {
          amax = f;
        }
      }
      d[n_block_begin + n_begin + i] = amax / 112.0;
      // TODO: accelerate with SVE later
      // Quantize this column with amax, packing pairs of values one byte apart
      // (partner element is PACK_SIZE_K entries further along k).
      for (int k_start = 0; k_start < k; k_start += (PACK_SIZE_K * 2)) {
        for (int j = 0; j < PACK_SIZE_K; j++) {
          size_t split_n = (n_begin + i) / PACK_SIZE_N;
          size_t n_idx = (n_begin + i) % PACK_SIZE_N;
          size_t split_k = k_start / (PACK_SIZE_K * 2);
          size_t k_idx = j;
          size_t buff_idx = n_block_begin * k / 2 + split_n * PACK_SIZE_N * k / 2 +
                            split_k * PACK_SIZE_N * PACK_SIZE_K + n_idx * PACK_SIZE_K + k_idx;
          float f0 = (src[(n_block_begin + n_begin + i) * k + k_start + j]);
          float f1 = (src[(n_block_begin + n_begin + i) * k + k_start + j + PACK_SIZE_K]);
          // round each value onto a multiple of 16 so one nibble carries it
          int8_t b0 = static_cast<int8_t>(std::round((f0 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);
          int8_t b1 = static_cast<int8_t>(std::round((f1 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);
          // pack: high nibble from f0, low nibble from f1
          int8_t b01 = (b0 & 0xF0) | ((b1 >> 4) & 0x0F);
          b[buff_idx] = b01;
        }
      }
    }
  }
}
void from_mat(float *src, int ith, int nth, int n_new = -1, bool if_pack = false) { // CHECK: nth has no usage
if (!if_pack && !this->if_pack) throw std::runtime_error("from mat for buffer should be packed");
if (n_new > 0) {
@@ -1190,212 +1195,6 @@ struct GemmKernelInt4 {
}
};
// Invert a CBLAS transpose flag (NoTrans <-> Trans); conjugate variants are
// rejected with std::runtime_error.
inline CBLAS_TRANSPOSE flip_trans(CBLAS_TRANSPOSE trans) {
  switch (trans) {
    case CblasNoTrans:
      return CblasTrans;
    case CblasTrans:
      return CblasNoTrans;
    default:
      throw std::runtime_error("Unsupported transpose");
  }
}
// Lightweight non-owning view of a matrix stored at `data`: logical shape
// R x C, leading dimension ld (in elements), CBLAS storage order and a
// transpose flag. if_pack marks the panel-packed layout used by the custom
// kernels. For int4 payloads (int4_2_t) two elements share one byte, so byte
// offsets are halved.
template <typename F>
struct MatRef {
  F *data = nullptr;      // not owned
  size_t R, C, ld;        // rows, cols, leading dimension
  CBLAS_ORDER order;      // row- or column-major storage
  CBLAS_TRANSPOSE trans;  // logical transpose flag
  bool if_pack = false;   // true for the reordered/packed panel layout
  static inline const int PACK_SIZE_N = 8;
  static inline const int PACK_SIZE_M = 8;
  static inline const int PACK_SIZE_K = 32;
  MatRef() {}
  MatRef(F *data, size_t R, size_t C, size_t ld, CBLAS_ORDER order, CBLAS_TRANSPOSE trans = CblasNoTrans,
         bool if_pack = false)
      : data(data), R(R), C(C), ld(ld), order(order), trans(trans), if_pack(if_pack) {}
  // Logical transpose: swap the shape and flip the storage order; data and ld
  // are untouched. (Same effect as trans_view().)
  MatRef t() {
    MatRef re = *this;
    std::swap(re.R, re.C);
    CBLAS_ORDER new_order = (order == CblasRowMajor) ? CblasColMajor : CblasRowMajor;
    re.order = new_order;
    return re;
  }
  // Transpose flag this view presents when interpreted in `order`: unchanged
  // if the requested order matches the stored one, otherwise flipped.
  CBLAS_TRANSPOSE trans_from(CBLAS_ORDER order) {
    if (order == this->order) {
      return trans;
    } else {
      return flip_trans(trans);
    }
  }
  // Sub-view of an r_block x c_block tile starting at (row, col). Under
  // CblasTrans the coordinates are swapped first. In the packed layout the
  // minor-dimension offset advances by whole panels (PACK_SIZE_M /
  // PACK_SIZE_N); int4_2_t element offsets are divided by 2 (two per byte).
  MatRef offset_block(size_t row, size_t col, size_t r_block, size_t c_block) {
    if (trans == CblasTrans) {
      std::swap(row, col);
      std::swap(r_block, c_block);
    }
    int devide_elements_size = 1;
    if constexpr (std::is_same_v<F, int4_2_t>) devide_elements_size = 2;
    if (order == CblasRowMajor) {
      if (if_pack) {
        return MatRef(data + (row * ld + col * PACK_SIZE_M) / devide_elements_size, r_block, c_block, ld, order,
                      CblasNoTrans, if_pack);
      } else {
        return MatRef(data + (row * ld + col) / devide_elements_size, r_block, c_block, ld, order, CblasNoTrans,
                      if_pack);
      }
    } else if (order == CblasColMajor) {
      if (if_pack) {
        return MatRef(data + (col * ld + row * PACK_SIZE_N) / devide_elements_size, r_block, c_block, ld, order,
                      CblasNoTrans, if_pack);
      } else {
        return MatRef(data + (col * ld + row) / devide_elements_size, r_block, c_block, ld, order, CblasNoTrans,
                      if_pack);
      }
    } else {
      throw std::runtime_error("Unsupported order");
    }
  }
  // Reinterpret the same bytes with the opposite storage order and swapped
  // shape (no data movement).
  inline MatRef trans_view() {
    if (order == CblasRowMajor) {
      return MatRef(data, C, R, ld, CblasColMajor, trans, if_pack);
    } else {
      return MatRef(data, C, R, ld, CblasRowMajor, trans, if_pack);
    }
  }
  // Row/column slices expressed through offset_block.
  MatRef offset_row(size_t row_begin, size_t row_block) { return offset_block(row_begin, 0, row_block, C); }
  MatRef offset_col(size_t col_begin, size_t col_block) { return offset_block(0, col_begin, R, col_block); }
  // Element access for non-transposed, non-packed views only.
  F &at(size_t row, size_t col) {
    if (trans == CblasTrans) {
      throw std::runtime_error("Unsupported trans");
    }
    if (order == CblasRowMajor) {
      return data[row * ld + col];
    } else if (order == CblasColMajor) {
      return data[col * ld + row];
    } else {
      throw std::runtime_error("Unsupported order");
    }
  }
};
// C = alpha * A * B + beta * C, dispatched at compile time on the element
// types: fp32/fp16/bf16 combinations map onto the corresponding cblas_*gemm;
// int8 x int8 -> int32 and int8 x int4 -> int32 route to the prefill kernels
// when B is in the packed (reordered) layout. All integer offsets are zero.
template <typename A, typename B, typename C>
static void mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c, C alpha, C beta) {
  assert(a.C == b.R);
  assert(a.R == c.R);
  assert(b.C == c.C);
  assert(c.trans == CblasNoTrans);
  BLASINT8 oa = 0, ob = 0;
  int32_t oc = 0;
  if constexpr (std::is_same_v<A, float> && std::is_same_v<B, float> && std::is_same_v<C, float>) {
    cblas_sgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data, b.ld,
                beta, c.data, c.ld);
  } else if constexpr (std::is_same_v<A, float16_t> && std::is_same_v<B, float16_t> && std::is_same_v<C, float16_t>) {
    cblas_hgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data, b.ld,
                beta, c.data, c.ld);
  } else if constexpr (std::is_same_v<A, float16_t> && std::is_same_v<B, float16_t> && std::is_same_v<C, float>) {
    // mixed-precision: fp16 inputs accumulated into fp32
    cblas_shgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data,
                 b.ld, beta, c.data, c.ld);
  } else if constexpr (std::is_same_v<A, bfloat16_t> && std::is_same_v<B, bfloat16_t> &&
                       std::is_same_v<C, bfloat16_t>) {
    cblas_bgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data, b.ld,
                beta, c.data, c.ld);
  } else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int8_t> && std::is_same_v<C, int32_t>) {
    // packed B uses the prefill kernel; plain B falls back to generic cblas
    if (b.if_pack) {
      prefill_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C,
                                 alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
    } else {
      cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C, alpha,
                         a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
    }
  } else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int4_2_t> && std::is_same_v<C, int32_t>) {
    // int4 B is only supported in the packed layout
    if (b.if_pack) {
      prefill_int4_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C,
                                      a.C, alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
    } else {
      throw std::runtime_error(
          "INT4 does not support cblas_gemm_s8s8s32 for unpack, please use decode_cblas_gemm_s8s8s32");
    }
  } else {
    throw std::runtime_error("Unsupported type");
  }
}
// Decode-path C = alpha * A * B + beta * C. fp32 uses cblas_sgemv (incX/incY
// apply only to that branch); int8 x int8 requires a packed B and int8 x int4
// always routes to the int4 decode kernel. Integer offsets are zero.
template <typename A, typename B, typename C>
static void decode_mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c, C alpha, C beta) {
  assert(a.C == b.R);
  assert(a.R == c.R);
  assert(b.C == c.C);
  assert(c.trans == CblasNoTrans);
  BLASINT incX = 1, incY = 1;
  BLASINT8 oa = 0, ob = 0;
  int32_t oc = 0;
  if constexpr (std::is_same_v<A, float> && std::is_same_v<B, float> && std::is_same_v<C, float>) {
    cblas_sgemv(a.order, a.trans, a.R, a.C, alpha, a.data, a.ld, b.data, incX, beta, c.data, incY);
  } else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int8_t> && std::is_same_v<C, int32_t>) {
    // NOTE(review): the error below fires for an unpacked B, though the
    // message says "Unsupported type" — consider a more specific message.
    if (b.if_pack)
      decode_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C,
                                alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
    else
      throw std::runtime_error("Unsupported type");
  } else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int4_2_t> && std::is_same_v<C, int32_t>) {
    decode_int4_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C,
                                   alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
  } else {
    throw std::runtime_error("Unsupported type");
  }
}
// Accumulating GEMM: C += A * B (alpha = 1, beta = 1).
template <typename A, typename B, typename C>
static void mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
  mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(1.0));
}
// Overwriting GEMM: C = A * B (alpha = 1, beta = 0 discards prior C).
template <typename A, typename B, typename C>
static void mul_mat_clearc(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
  mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(0.0));
}
// Overwriting decode-path GEMM: C = A * B (alpha = 1, beta = 0).
template <typename A, typename B, typename C>
static void decode_mul_mat_clearc(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
  decode_mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(0.0));
}
// Accumulating decode-path GEMM: C += A * B (alpha = 1, beta = 1).
template <typename A, typename B, typename C>
static void decode_mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
  decode_mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(1.0));
}
} // namespace arm_kml
} // namespace moe_kernel
#endif

View File

@@ -0,0 +1,54 @@
#include "../api/mat_kernel.h"
#include <cassert>
namespace {
constexpr int kInt4ElementDivisor = 2;
constexpr int kInt8ElementDivisor = 1;
} // namespace
extern "C" {
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta,
int32_t* c, const size_t ldc, const int32_t* oc);
void prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta,
int32_t* c, const size_t ldc, const int32_t* oc);
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const int8_t oa, const void* b, const size_t ldb, const int8_t ob,
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
void prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const int8_t oa, const void* b, const size_t ldb,
const int8_t ob, const float beta, int32_t* c, const size_t ldc,
const int32_t* oc);
}
// Pick the int4 s8s8s32 kernel for the requested variant; an unknown variant
// yields a null kernel with divisor 0.
MatKernelSelection select_kernel_for_int4(MatKernelVariant variant) {
  if (variant == MatKernelVariant::Decode) {
    return {decode_int4_cblas_gemm_s8s8s32, kInt4ElementDivisor};
  }
  if (variant == MatKernelVariant::Prefill) {
    return {prefill_int4_cblas_gemm_s8s8s32, kInt4ElementDivisor};
  }
  return {nullptr, 0};
}
// Pick the int8 s8s8s32 kernel for the requested variant; an unknown variant
// yields a null kernel with divisor 0.
MatKernelSelection select_kernel_for_int8(MatKernelVariant variant) {
  if (variant == MatKernelVariant::Decode) {
    return {decode_cblas_gemm_s8s8s32, kInt8ElementDivisor};
  }
  if (variant == MatKernelVariant::Prefill) {
    return {prefill_cblas_gemm_s8s8s32, kInt8ElementDivisor};
  }
  return {nullptr, 0};
}

View File

@@ -0,0 +1,29 @@
#pragma once
// #include <arm_sve.h>
#include <cstdint>
#include <cstring>
// Simple truncation mode: drop the low 16 bits directly.
// Truncating fp32 -> bf16 conversion: keep the upper 16 bits of the IEEE-754
// encoding and drop the low 16 (no rounding).
static inline uint16_t float_to_bf16_trunc(float f) {
  uint32_t bits = 0;
  memcpy(&bits, &f, sizeof(bits));  // bitwise copy avoids strict-aliasing UB
  return (uint16_t)(bits >> 16);
}
// Convert 32 contiguous fp32 values to bf16 by truncation (drop the low
// 16 bits of each IEEE-754 encoding). Caller has already offset src/dst.
static inline void convert_32fp32_to_32bf16_pure_c(const float* src, uint16_t* dst) {
  for (int i = 0; i < 32; ++i) {
    uint32_t bits = 0;
    memcpy(&bits, &src[i], sizeof(bits));  // bit-copy avoids strict-aliasing UB
    dst[i] = (uint16_t)(bits >> 16);
  }
}
// Convert 32 bf16 elements into 32 fp32 elements.
// Widen 32 bf16 elements to fp32: shift each value into the high 16 bits of
// a 32-bit word (low mantissa bits become zero) and reinterpret as float.
static inline void convert_32bf16_to_32fp32_pure_c(const uint16_t* src, float* dst) {
  for (int i = 0; i < 32; ++i) {
    uint32_t widened = ((uint32_t)src[i]) << 16;
    memcpy(&dst[i], &widened, sizeof(float));  // bit-copy avoids strict-aliasing UB
  }
}

View File

@@ -0,0 +1,100 @@
#include <stdexcept>
#include "../batch_gemm_api.hpp"
#include "blis.h"
namespace {
// Map the kernel layout enum onto AOCL's character form ('r'/'c'); any other
// value is rejected with std::invalid_argument.
char ToAoclOrder(KERNEL_CBLAS_LAYOUT layout) {
  if (layout == KernelCblasRowMajor) {
    return 'r';
  }
  if (layout == KernelCblasColMajor) {
    return 'c';
  }
  throw std::invalid_argument("Unsupported KERNEL_CBLAS_LAYOUT value");
}
// Map the kernel transpose enum onto AOCL's character form ('n'/'t').
// Conjugate variants have no AOCL equivalent and raise std::invalid_argument.
char ToAoclTranspose(KERNEL_CBLAS_TRANSPOSE transpose) {
  if (transpose == KernelCblasNoTrans) {
    return 'n';
  }
  if (transpose == KernelCblasTrans) {
    return 't';
  }
  throw std::invalid_argument("Unsupported KERNEL_CBLAS_TRANSPOSE value");
}
} // namespace
// Mapping helpers: KERNEL_CBLAS_LAYOUT maps to 'r'/'c', and KERNEL_CBLAS_TRANSPOSE maps to 'n'/'t'.
#ifdef __cplusplus
extern "C" {
#endif
// Decode-path s8 x s8 -> s32 GEMM mapped onto AOCL BLIS. A is passed
// unreordered ('n'); B must already be in AOCL's reordered layout ('r', see
// reorder_B_gemm). offsetc/oa/ob/oc are ignored — the call runs offset-free.
// NOTE(review): alpha/beta are truncated to int32, so fractional scale
// factors are lost; confirm callers only pass integral values.
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                               const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
                               const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                               const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
                               int32_t* c, const size_t ldc, const int32_t* oc) {
  const char order = ToAoclOrder(layout);
  const char op_a = ToAoclTranspose(transa);
  const char op_b = ToAoclTranspose(transb);
  (void)offsetc;
  const int32_t int_alpha = static_cast<int32_t>(alpha);
  const int32_t int_beta = static_cast<int32_t>(beta);
  aocl_gemm_s8s8s32os32(order, op_a, op_b, static_cast<dim_t>(m), static_cast<dim_t>(n), static_cast<dim_t>(k),
                        int_alpha, static_cast<const int8_t*>(a), static_cast<dim_t>(lda), 'n',
                        static_cast<const int8_t*>(b), static_cast<dim_t>(ldb), 'r', int_beta, c,
                        static_cast<dim_t>(ldc), nullptr);
}
// Prefill-path s8 x s8 -> s32 GEMM mapped onto AOCL BLIS; identical mapping
// to the decode path (A unreordered 'n', B pre-reordered 'r', offsets
// ignored, alpha/beta truncated to int32).
void prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                                const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
                                const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                                const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
                                int32_t* c, const size_t ldc, const int32_t* oc) {
  const char order = ToAoclOrder(layout);
  const char op_a = ToAoclTranspose(transa);
  const char op_b = ToAoclTranspose(transb);
  (void)offsetc;
  const int32_t int_alpha = static_cast<int32_t>(alpha);
  const int32_t int_beta = static_cast<int32_t>(beta);
  aocl_gemm_s8s8s32os32(order, op_a, op_b, static_cast<dim_t>(m), static_cast<dim_t>(n), static_cast<dim_t>(k),
                        int_alpha, static_cast<const int8_t*>(a), static_cast<dim_t>(lda), 'n',
                        static_cast<const int8_t*>(b), static_cast<dim_t>(ldb), 'r', int_beta, c,
                        static_cast<dim_t>(ldc), nullptr);
}
// int4 prefill is not implemented on the AOCL backend; this stub exists only
// to satisfy the shared kernel API and always throws.
void prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                                     const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
                                     const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
                                     const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                                     const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
                                     const int32_t* oc) {
  throw std::runtime_error("int4 not support prefill");
}
// int4 decode is not implemented on the AOCL backend; this stub exists only
// to satisfy the shared kernel API and always throws.
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                                    const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
                                    const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
                                    const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                                    const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
                                    const int32_t* oc) {
  throw std::runtime_error("int4 not support decode");
}
// Repack B (k x n, leading dimension ldb) into AOCL's blocked layout so the
// GEMM entry points above can consume it with the 'r' (reordered) flag.
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                    const size_t n, const size_t ldb, const void* b, void* b_reordered) {
  const char aocl_order = ToAoclOrder(layout);
  const char aocl_trans_b = ToAoclTranspose(transb);
  aocl_reorder_s8s8s32os32(aocl_order, aocl_trans_b, 'B', static_cast<const int8_t*>(b),
                           static_cast<int8_t*>(b_reordered), k, n, ldb);
}
// Query the buffer size (bytes) AOCL needs for the reordered copy of a
// k x n B matrix; pair with reorder_B_gemm.
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                          const size_t n) {
  const char aocl_order = ToAoclOrder(layout);
  const char aocl_trans_b = ToAoclTranspose(transb);
  return aocl_get_reorder_buf_size_s8s8s32os32(aocl_order, aocl_trans_b, 'B', k, n);
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,42 @@
#pragma once
#include <cstddef>
#ifndef _BATCH_GEMM_KERNEL_API_
#define _BATCH_GEMM_KERNEL_API_
#include "../api/common.h"
#ifdef __cplusplus
extern "C" {
#endif
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
int32_t* c, const size_t ldc, const int32_t* oc);
void prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
int32_t* c, const size_t ldc, const int32_t* oc);
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
const int32_t* oc);
void prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
const int32_t* oc);
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
const size_t n, const size_t ldb, const void* b, void* b_reordered);
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
const size_t n);
#ifdef __cplusplus
}
#endif
#endif /*** _BATCH_GEMM_KERNEL_API_ ***/

View File

@@ -0,0 +1,56 @@
#include <stdexcept>
#include "../batch_gemm_api.hpp"
#include "utils.hpp"
#ifdef __cplusplus
extern "C" {
#endif
// Decode-path s8 x s8 -> s32 GEMM for the SVE backend. B must be pre-packed
// as contiguous N_SIZE-column panels of length k; each panel is multiplied by
// the single row of A via gemm_kernel_1x8 (the kernel name suggests m == 1 —
// TODO confirm callers never pass m > 1, since m is ignored here).
// layout/transa/transb/offsetc/alpha/lda/oa/ob/beta/oc are all unused.
// NOTE(review): any remainder n % N_SIZE is silently skipped — confirm n is
// always a multiple of N_SIZE.
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                               const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
                               const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                               const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
                               int32_t* c, const size_t ldc, const int32_t* oc) {
  BLASINT8* ptrA = (BLASINT8*)a;
  BLASINT8* ptrB = (BLASINT8*)b;
  int32_t* ptrC = c;
  size_t split_n = n / N_SIZE;
  for (size_t n_block = 0; n_block < split_n; n_block++) {
    BLASINT8* cur_ptrA = ptrA;
    BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * k);
    int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
    gemm_kernel_1x8(cur_ptrA, cur_ptrB, cur_ptrC, ldc, k, COMP_SV_LEN);
  }
}
// Decode-path s8 x s4 -> s32 GEMM for the SVE backend. B panels hold two int4
// values per byte, so panel strides and the k argument are halved.
// layout/transa/transb/offsetc/alpha/lda/oa/ob/beta/oc are all unused.
// NOTE(review): ldc is also halved here although C is int32 (not nibble
// packed) — confirm gemm_kernel_1x8_int4 really expects ldc/2.
// NOTE(review): any remainder n % N_SIZE is silently skipped.
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                                    const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
                                    const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
                                    const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                                    const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
                                    const int32_t* oc) {
  BLASINT8* ptrA = (BLASINT8*)a;
  BLASINT8* ptrB = (BLASINT8*)b;
  int32_t* ptrC = c;
  size_t split_n = n / N_SIZE;
  for (size_t n_block = 0; n_block < split_n; n_block++) {
    BLASINT8* cur_ptrA = ptrA;
    BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * (k / 2));
    int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
    gemm_kernel_1x8_int4(cur_ptrA, cur_ptrB, cur_ptrC, (ldc / 2), (k / 2), COMP_SV_LEN);
  }
}
// B reordering is not implemented on this backend; stub kept for API parity
// with the AOCL path and always throws.
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                    const size_t n, const size_t ldb, const void* b, void* b_reordered) {
  throw std::runtime_error("haven't supported reorder");
}
// Companion stub to reorder_B_gemm above: reordering is unsupported on this
// backend, so the size query always throws.
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                          const size_t n) {
  throw std::runtime_error("haven't supported reorder");
}
#ifdef __cplusplus
}
#endif

View File

@@ -1,4 +1,5 @@
#include "batch_gemm_api.hpp"
#include "prefillgemm_int4/integer_gemm_kernels.h"
#include "utils.hpp"
#ifdef __cplusplus
extern "C" {
#endif

Some files were not shown because too many files have changed in this diff Show More