mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
Merge branch 'main' into develop-cht
This commit is contained in:
@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
|
||||
|
||||
<h2 id="Updates">🔥 Updates</h2>
|
||||
|
||||
* **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md))
|
||||
* **Oct 10, 2025**: Integrating into SGLang. ([Roadmap](https://github.com/sgl-project/sglang/issues/11425))
|
||||
* **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md))
|
||||
* **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md))
|
||||
|
||||
0
config.json
Normal file
0
config.json
Normal file
@@ -1,3 +1,23 @@
|
||||
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
|
||||
endif()
|
||||
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}")
|
||||
message(STATUS "ASCEND_HOME_PATH is ${ASCEND_HOME_PATH}")
|
||||
include_directories(${ASCEND_HOME_PATH}/include)
|
||||
|
||||
link_directories(${TORCH_INSTALL_PREFIX}/../torch.libs)
|
||||
# find torch_npu
|
||||
execute_process(
|
||||
COMMAND python -c "import torch; import torch_npu; print(torch_npu.__path__[0])"
|
||||
OUTPUT_VARIABLE TORCH_NPU_PATH
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
message(STATUS "Found PTA at: ${TORCH_NPU_PATH}")
|
||||
find_library(PTA_LIBRARY torch_npu PATH "${TORCH_NPU_PATH}/lib")
|
||||
endif()
|
||||
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
find_program(GCC_COMPILER NAMES g++-13 g++-12 g++-11 g++ REQUIRED)
|
||||
|
||||
@@ -73,8 +73,19 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
|
||||
|
||||
# include_directories(/usr/include/tbb)
|
||||
# link_directories(/usr/lib64)
|
||||
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
|
||||
endif()
|
||||
find_package(TBB REQUIRED)
|
||||
find_package(CUDA REQUIRED)
|
||||
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
# NPU 构建
|
||||
# find_package(CUDA REQUIRED) # NPU 情况不需要 CUDA
|
||||
else()
|
||||
# GPU 构建
|
||||
find_package(CUDA REQUIRED)
|
||||
endif()
|
||||
|
||||
# find_package(prometheus-cpp CONFIG REQUIRED)
|
||||
if(NOT TARGET prometheus-cpp::pull)
|
||||
@@ -83,18 +94,29 @@ else()
|
||||
message(STATUS "prometheus Found!")
|
||||
endif()
|
||||
|
||||
if(CUDA_FOUND)
|
||||
message(STATUS "CUDA Found!")
|
||||
message(STATUS "CUDA Version: ${CUDA_VERSION_STRING}")
|
||||
message(STATUS "CUDA Toolkit Root: ${CUDA_TOOLKIT_ROOT_DIR}")
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
# NPU 情况下不检查 CUDA
|
||||
else()
|
||||
message(FATAL_ERROR "CUDA not found!")
|
||||
if(CUDA_FOUND)
|
||||
message(STATUS "CUDA Found!")
|
||||
message(STATUS "CUDA Version: ${CUDA_VERSION_STRING}")
|
||||
message(STATUS "CUDA Toolkit Root: ${CUDA_TOOLKIT_ROOT_DIR}")
|
||||
else()
|
||||
message(FATAL_ERROR "CUDA not found!")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_subdirectory(src)
|
||||
|
||||
if(BUILD_TEST)
|
||||
add_subdirectory(test)
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
message(STATUS "Build test...")
|
||||
set(THIRD_PARTY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party)
|
||||
add_subdirectory(${THIRD_PARTY_DIR}/spdlog ${CMAKE_CURRENT_SOURCE_DIR}/../build/third_party/spdlog)
|
||||
add_subdirectory(test)
|
||||
else()
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "BUILD_PYTHON_EXT: ${BUILD_PYTHON_EXT}")
|
||||
@@ -126,7 +148,12 @@ endif()
|
||||
|
||||
set(PHOTON_CXX_STANDARD 14 CACHE INTERNAL "C++ standard")
|
||||
|
||||
set(CMAKE_CXX_FLAGS "-O3 -march=native")
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
set(CMAKE_CXX_FLAGS "-O3 -march=armv8.2-a")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS "-O3 -march=native")
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_CXX_FLAGS of PhotonLibOS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
add_subdirectory(${THIRD_PARTY_DIR}/PhotonLibOS ${THIRD_PARTY_BUILD_DIR}/PhotonLibOS)
|
||||
|
||||
@@ -23,24 +23,46 @@ target_link_libraries(cache_entry PUBLIC gpu_cache)
|
||||
|
||||
add_library(gpu_cache gpu_cache.cpp)
|
||||
add_third_party_includes(gpu_cache)
|
||||
target_link_libraries(gpu_cache PUBLIC xxHash::xxhash ${TORCH_LIBRARIES} cuda_stream_manager)
|
||||
|
||||
# gpu_cache
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
target_include_directories(gpu_cache PUBLIC ${TORCH_NPU_PATH}/include)
|
||||
find_package(Python COMPONENTS Interpreter Development REQUIRED)
|
||||
message("python include location: " ${Python_INCLUDE_DIRS} " lib location: " ${Python_LIBRARIES})
|
||||
target_link_libraries(gpu_cache PUBLIC xxHash::xxhash "${TORCH_LIBRARIES}" "${TORCH_PYTHON_LIBRARY}" "${PTA_LIBRARY}" ${Python_LIBRARIES} cuda_stream_manager)
|
||||
else()
|
||||
target_link_libraries(gpu_cache PUBLIC xxHash::xxhash ${TORCH_LIBRARIES} cuda_stream_manager)
|
||||
endif()
|
||||
|
||||
# kvc2
|
||||
add_library(kvc2 prefix.cpp)
|
||||
target_include_directories(kvc2 PRIVATE ${THIRD_PARTY_DIR}/nlohmann/single_include)
|
||||
add_third_party_includes(kvc2)
|
||||
target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_stream_manager page_aligned_memory_pool ${TORCH_LIBRARIES} prometheus-cpp::pull kvc2_metrics)
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_stream_manager page_aligned_memory_pool prometheus-cpp::pull kvc2_metrics)
|
||||
else()
|
||||
target_link_libraries(kvc2 PUBLIC TBB::tbb xxHash::xxhash cache_entry cuda_stream_manager page_aligned_memory_pool ${TORCH_LIBRARIES} prometheus-cpp::pull kvc2_metrics)
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_SOURCE_DIR: " ${CMAKE_SOURCE_DIR})
|
||||
|
||||
# async_store
|
||||
add_library(async_store async_store.cpp)
|
||||
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/nlohmann/single_include)
|
||||
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/PhotonLibOS/include)
|
||||
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/spdlog/include)
|
||||
target_include_directories(async_store PRIVATE ${THIRD_PARTY_DIR}/PhotonLibOS/include)
|
||||
target_link_libraries(async_store PUBLIC photon_static pthread)
|
||||
|
||||
|
||||
|
||||
# cuda_stream_manager
|
||||
add_library(cuda_stream_manager cuda_stream_manager.cpp)
|
||||
target_include_directories(cuda_stream_manager PUBLIC ${THIRD_PARTY_DIR}/nlohmann/single_include)
|
||||
target_include_directories(cuda_stream_manager PUBLIC ${THIRD_PARTY_DIR}/spdlog/include)
|
||||
target_include_directories(cuda_stream_manager PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
|
||||
target_link_libraries(cuda_stream_manager PUBLIC CUDA::cudart)
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}")
|
||||
target_include_directories(cuda_stream_manager PUBLIC ${ASCEND_HOME_PATH}/include)
|
||||
target_link_directories(cuda_stream_manager PUBLIC ${ASCEND_HOME_PATH}/lib64)
|
||||
target_link_libraries(cuda_stream_manager PUBLIC ascendcl)
|
||||
else()
|
||||
target_include_directories(cuda_stream_manager PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
|
||||
target_link_libraries(cuda_stream_manager PUBLIC CUDA::cudart)
|
||||
endif()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#include "cuda_stream_manager.hh"
|
||||
#include <cuda_runtime.h>
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO
|
||||
@@ -9,6 +10,12 @@
|
||||
#define FMT_HEADER_ONLY
|
||||
#include "spdlog/spdlog.h"
|
||||
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
#include "acl/acl_rt.h"
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
CudaStreamManager::CudaStreamManager(const std::vector<size_t>& device_ids, int num_streams_per_device) {
|
||||
for (int device_id : device_ids) {
|
||||
auto x = std::unique_ptr<DeviceInfo>(new DeviceInfo);
|
||||
@@ -18,25 +25,59 @@ CudaStreamManager::CudaStreamManager(const std::vector<size_t>& device_ids, int
|
||||
device_info.stop_flag = false;
|
||||
|
||||
// 设置设备
|
||||
cudaError_t err = cudaSetDevice(device_id);
|
||||
if (err != cudaSuccess) {
|
||||
SPDLOG_WARN("cudaSetDevice failed on device {}: {}", device_id, cudaGetErrorString(err));
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
aclError acl_err = aclrtSetDevice(device_id);
|
||||
if (acl_err != ACL_SUCCESS) {
|
||||
SPDLOG_WARN("aclrtSetDevice failed on device {}: {}", device_id, acl_err);
|
||||
throw std::runtime_error("aclrtSetDevice failed");
|
||||
}
|
||||
#else
|
||||
cudaError_t cuda_err = cudaSetDevice(device_id);
|
||||
if (cuda_err != cudaSuccess) {
|
||||
SPDLOG_WARN("cudaSetDevice failed on device {}: {}", device_id, cudaGetErrorString(cuda_err));
|
||||
throw std::runtime_error("cudaSetDevice failed");
|
||||
}
|
||||
#endif
|
||||
|
||||
// 创建 CUDA 流
|
||||
// 创建流
|
||||
device_info.streams.resize(num_streams_per_device);
|
||||
for (int i = 0; i < num_streams_per_device; ++i) {
|
||||
err = cudaStreamCreate(&device_info.streams[i]);
|
||||
if (err != cudaSuccess) {
|
||||
SPDLOG_WARN("Failed to create CUDA stream on device {}: {}", device_id, cudaGetErrorString(err));
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
acl_err = aclrtCreateStream(&device_info.streams[i]);
|
||||
if (acl_err != ACL_SUCCESS) {
|
||||
SPDLOG_WARN("Failed to create NPU stream on device {}: {}", device_id, acl_err);
|
||||
throw std::runtime_error("Failed to create NPU stream");
|
||||
}
|
||||
#else
|
||||
cuda_err = cudaStreamCreate(&device_info.streams[i]);
|
||||
if (cuda_err != cudaSuccess) {
|
||||
SPDLOG_WARN("Failed to create CUDA stream on device {}: {}", device_id, cudaGetErrorString(cuda_err));
|
||||
throw std::runtime_error("Failed to create CUDA stream");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// 启动设备工作线程
|
||||
// 启动工作线程
|
||||
device_info.worker_thread = std::thread(&CudaStreamManager::deviceWorker, this, std::ref(device_info));
|
||||
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
// NPU需要额外的回调线程
|
||||
device_info.callback_thread = std::thread(&CudaStreamManager::deviceCallback, this, std::ref(device_info));
|
||||
|
||||
// 绑定回调线程
|
||||
for (int i = 0; i < num_streams_per_device; ++i) {
|
||||
std::ostringstream oss;
|
||||
oss << device_info.callback_thread.get_id();
|
||||
uint64_t tid = std::stoull(oss.str());
|
||||
acl_err = aclrtSubscribeReport(tid, device_info.streams[i]);
|
||||
SPDLOG_DEBUG("subscribe stream callback report on device {} with tid {} in idx {}", device_id, tid, i);
|
||||
if (acl_err != ACL_SUCCESS) {
|
||||
SPDLOG_WARN("Failed to subscribe stream callback report on device {}: {}", device_id, acl_err);
|
||||
throw std::runtime_error("Failed to create stream callback job");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
devices_.push_back(std::move(x));
|
||||
}
|
||||
}
|
||||
@@ -52,15 +93,39 @@ CudaStreamManager::~CudaStreamManager() {
|
||||
|
||||
// 等待所有线程结束
|
||||
for (auto& device_info : devices_) {
|
||||
// 等待工作线程结束
|
||||
if (device_info->worker_thread.joinable()) {
|
||||
device_info->worker_thread.join();
|
||||
}
|
||||
|
||||
// 销毁 CUDA 流
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
// NPU需要额外的回调线程处理
|
||||
aclrtSetDevice(device_info->device_id);
|
||||
|
||||
// 解绑callback任务并等待callback线程结束
|
||||
std::ostringstream oss;
|
||||
oss << device_info->callback_thread.get_id();
|
||||
uint64_t tid = std::stoull(oss.str());
|
||||
for (auto& stream: device_info->streams) {
|
||||
aclrtUnSubscribeReport(tid, stream);
|
||||
}
|
||||
if (device_info->callback_thread.joinable()) {
|
||||
device_info->callback_thread.join();
|
||||
}
|
||||
#endif
|
||||
|
||||
// 销毁流
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
for (auto& stream : device_info->streams) {
|
||||
aclrtDestroyStream(stream);
|
||||
aclrtResetDevice(device_info->device_id);
|
||||
}
|
||||
#else
|
||||
cudaSetDevice(device_info->device_id);
|
||||
for (auto& stream : device_info->streams) {
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,28 +142,47 @@ void CudaStreamManager::submitRequest(std::shared_ptr<Request> request) {
|
||||
|
||||
void CudaStreamManager::deviceWorker(DeviceInfo& device_info) {
|
||||
// 设置设备
|
||||
cudaError_t err = cudaSetDevice(device_info.device_id);
|
||||
if (err != cudaSuccess) {
|
||||
SPDLOG_WARN("cudaSetDevice failed in worker thread for device {}: {}", device_info.device_id,
|
||||
cudaGetErrorString(err));
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
aclError acl_err = aclrtSetDevice(device_info.device_id);
|
||||
if (acl_err != ACL_SUCCESS) {
|
||||
SPDLOG_WARN("aclrtSetDevice failed in worker thread for device {}: {}", device_info.device_id, acl_err);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
cudaError_t cuda_err = cudaSetDevice(device_info.device_id);
|
||||
if (cuda_err != cudaSuccess) {
|
||||
SPDLOG_WARN("cudaSetDevice failed in worker thread for device {}: {}", device_info.device_id,
|
||||
cudaGetErrorString(cuda_err));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
while (device_info.stop_flag.load() == false) {
|
||||
auto request = device_info.request_queue.dequeue();
|
||||
if (request->should_exit) {
|
||||
return;
|
||||
}
|
||||
// 处理请求
|
||||
|
||||
SPDLOG_DEBUG("Getting request on device {}, count {}", device_info.device_id, request->host_mem_addresses.size());
|
||||
int stream_index = device_info.next_stream_index;
|
||||
cudaStream_t stream = device_info.streams[stream_index];
|
||||
auto stream = device_info.streams[stream_index];
|
||||
device_info.next_stream_index = (device_info.next_stream_index + 1) % device_info.streams.size();
|
||||
|
||||
size_t num_transfers = request->host_mem_addresses.size();
|
||||
for (size_t i = 0; i < num_transfers; ++i) {
|
||||
void* dst = request->device_mem_addresses[i];
|
||||
void* src = request->host_mem_addresses[i];
|
||||
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
if (request->direction == ACL_MEMCPY_DEVICE_TO_HOST) {
|
||||
std::swap(dst, src);
|
||||
}
|
||||
aclError err = aclrtMemcpyAsync(dst, request->sizes[i], src, request->sizes[i], request->direction, stream);
|
||||
if (err != ACL_SUCCESS) {
|
||||
SPDLOG_WARN("aclrtMemcpyAsync failed on device {}: {}", device_info.device_id, err);
|
||||
continue;
|
||||
}
|
||||
#else
|
||||
if (request->direction == cudaMemcpyDeviceToHost) {
|
||||
std::swap(dst, src);
|
||||
}
|
||||
@@ -109,15 +193,31 @@ void CudaStreamManager::deviceWorker(DeviceInfo& device_info) {
|
||||
// 可以根据需要处理错误,这里简单地继续
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// 添加回调函数,因为是异步,所以需要包起来
|
||||
// 添加回调函数
|
||||
struct CallbackData {
|
||||
std::function<void()> callback;
|
||||
};
|
||||
CallbackData* cb_data = new CallbackData{request->callback};
|
||||
|
||||
err = cudaLaunchHostFunc(
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
aclError err = aclrtLaunchCallback(
|
||||
[](void* data) {
|
||||
CallbackData* cb_data = static_cast<CallbackData*>(data);
|
||||
cb_data->callback();
|
||||
delete cb_data;
|
||||
},
|
||||
cb_data,
|
||||
ACL_CALLBACK_BLOCK,
|
||||
stream);
|
||||
|
||||
if (err != ACL_SUCCESS) {
|
||||
SPDLOG_WARN("aclrtLaunchCallback failed on device {}: {}", device_info.device_id, err);
|
||||
}
|
||||
#else
|
||||
cudaError_t err = cudaLaunchHostFunc(
|
||||
stream,
|
||||
[](void* data) {
|
||||
// SPDLOG_DEBUG("Callback function called");
|
||||
@@ -129,7 +229,30 @@ void CudaStreamManager::deviceWorker(DeviceInfo& device_info) {
|
||||
|
||||
if (err != cudaSuccess) {
|
||||
SPDLOG_WARN("cudaLaunchHostFunc failed on device {}: {}", device_info.device_id, cudaGetErrorString(err));
|
||||
// 根据需要处理错误
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
void CudaStreamManager::deviceCallback(DeviceInfo& device_info) {
|
||||
aclError err = aclrtSetDevice(device_info.device_id);
|
||||
if (err != ACL_SUCCESS) {
|
||||
SPDLOG_WARN("aclrtSetDevice failed in callback thread for device {}: {}", device_info.device_id, err);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
int timeout = 60 * 1000; // ms
|
||||
while (device_info.stop_flag.load() == false) {
|
||||
err = aclrtProcessReport(timeout);
|
||||
if (err != ACL_SUCCESS) {
|
||||
if (err == ACL_ERROR_RT_THREAD_SUBSCRIBE) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
continue;
|
||||
}
|
||||
SPDLOG_WARN("aclrtProcessReport failed in callback thread for device {}: {}", device_info.device_id, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -8,7 +8,6 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
@@ -16,6 +15,11 @@
|
||||
#include <vector>
|
||||
#include "utils/mpsc.hpp"
|
||||
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
#include "acl/acl_mdl.h"
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
class CudaStreamManager {
|
||||
public:
|
||||
// 构造函数,接受要使用的设备 ID 列表和每个设备的流数量
|
||||
@@ -29,7 +33,11 @@ class CudaStreamManager {
|
||||
std::vector<void*> host_mem_addresses;
|
||||
std::vector<void*> device_mem_addresses;
|
||||
std::vector<size_t> sizes;
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
aclrtMemcpyKind direction;
|
||||
#else
|
||||
cudaMemcpyKind direction;
|
||||
#endif
|
||||
std::function<void()> callback;
|
||||
};
|
||||
|
||||
@@ -40,7 +48,12 @@ class CudaStreamManager {
|
||||
struct DeviceInfo {
|
||||
int device_id;
|
||||
std::thread worker_thread;
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
std::thread callback_thread;
|
||||
std::vector<aclrtStream> streams;
|
||||
#else
|
||||
std::vector<cudaStream_t> streams;
|
||||
#endif
|
||||
int next_stream_index;
|
||||
MPSCQueueConsumerLock<std::shared_ptr<Request>> request_queue;
|
||||
std::atomic_bool stop_flag;
|
||||
@@ -51,4 +64,7 @@ class CudaStreamManager {
|
||||
|
||||
// 私有方法
|
||||
void deviceWorker(DeviceInfo& device_info);
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
void deviceCallback(DeviceInfo& device_info); // NPU 专用回调线程函数
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -10,6 +10,24 @@
|
||||
namespace kvc2 {
|
||||
|
||||
GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
size_t gpu_count = c10_npu::device_count();
|
||||
if (gpu_count > 0) {
|
||||
SPDLOG_INFO("Number of available NPUs: {}, want {}", gpu_count, config.gpu_devices_id.size());
|
||||
if (gpu_count < config.gpu_devices_id.size()) {
|
||||
SPDLOG_ERROR("Not enough NPUs available.");
|
||||
exit(0);
|
||||
}
|
||||
for (auto x : config.gpu_devices_id) {
|
||||
std::string device_str = "npu:" + std::to_string(x);
|
||||
// torch_npu::init_npu(device_str); should inited in scheduler
|
||||
gpu_devices.push_back(at::Device(device_str));
|
||||
}
|
||||
} else {
|
||||
SPDLOG_ERROR("NPU is not available on this system.");
|
||||
exit(0);
|
||||
}
|
||||
#else
|
||||
if (torch::cuda::is_available()) {
|
||||
size_t gpu_count = torch::cuda::device_count();
|
||||
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, config.gpu_devices_id.size());
|
||||
@@ -24,6 +42,7 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
|
||||
SPDLOG_ERROR("CUDA is not available on this system.");
|
||||
exit(0);
|
||||
}
|
||||
#endif
|
||||
|
||||
SPDLOG_WARN("Creating GPU Cache");
|
||||
shape.push_back(config.layer_count);
|
||||
@@ -47,7 +66,9 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
|
||||
if (config.k_cache_on) {
|
||||
for (size_t i = 0; i < config.gpu_devices_id.size(); i++) {
|
||||
auto k = torch::zeros(shape, torch::TensorOptions().dtype(config.tensor_type));
|
||||
SPDLOG_INFO("ALLOCATE on CPU OK");
|
||||
k = k.to(gpu_devices[i]);
|
||||
SPDLOG_INFO("ALLOCATE on NPU OK");
|
||||
|
||||
k_cache.push_back(k);
|
||||
|
||||
@@ -95,6 +116,10 @@ GPUPageCache::GPUPageCache(GPUPageCacheConfig& config) : config(config) {
|
||||
std::unique_ptr<CudaStreamManager>(new CudaStreamManager(config.gpu_devices_id, config.num_streams_per_device));
|
||||
}
|
||||
|
||||
kvc2::GPUPageCache::~GPUPageCache() {
|
||||
// torch_npu::finalize_npu()
|
||||
}
|
||||
|
||||
bool GPUPageCache::alloc_col(std::vector<std::vector<std::shared_ptr<CacheBlockEntry>>>& k_entries,
|
||||
std::vector<std::vector<std::shared_ptr<CacheBlockEntry>>>& v_entries, size_t at) {
|
||||
std::lock_guard<std::mutex> lg(lock);
|
||||
@@ -218,8 +243,13 @@ std::vector<std::unique_lock<CacheBlockEntry::MutexT>> GPUPageCache::try_lock_co
|
||||
return re;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<CudaStreamManager::Request>> GPUPageCache::basic_request(cudaMemcpyKind direction,
|
||||
std::function<void()> callback) {
|
||||
std::vector<std::shared_ptr<CudaStreamManager::Request>> GPUPageCache::basic_request(
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
aclrtMemcpyKind direction,
|
||||
#else
|
||||
cudaMemcpyKind direction,
|
||||
#endif
|
||||
std::function<void()> callback) {
|
||||
std::vector<std::shared_ptr<CudaStreamManager::Request>> re;
|
||||
re.resize(config.gpu_devices_id.size(), nullptr);
|
||||
for (size_t i = 0; i < re.size(); i++) {
|
||||
|
||||
@@ -3,12 +3,20 @@
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include "cache_entry.hh"
|
||||
#include "cuda_stream_manager.hh"
|
||||
#include "defs.h"
|
||||
#include "kvc2.h"
|
||||
#include "metrics.h"
|
||||
#include "utils/periodic_task.hpp"
|
||||
|
||||
// 根据设备类型包含不同的头文件和流管理器
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
#include "torch_npu/csrc/libs/torch_npu.h"
|
||||
#include "torch_npu/csrc/libs/init_npu.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
|
||||
#else
|
||||
#include "cuda_stream_manager.hh"
|
||||
#endif
|
||||
|
||||
namespace kvc2 {
|
||||
|
||||
class GPUPageCache {
|
||||
@@ -43,6 +51,7 @@ class GPUPageCache {
|
||||
std::unique_ptr<periodic::PeriodicTask> background_flush_back =nullptr;
|
||||
|
||||
GPUPageCache(GPUPageCacheConfig& config);
|
||||
~GPUPageCache(); // 统一添加析构函数声明
|
||||
|
||||
std::vector<size_t> gpu_only_alloc_col(size_t count);
|
||||
void gpu_only_free_cols(std::vector<size_t> cols);
|
||||
@@ -59,8 +68,14 @@ class GPUPageCache {
|
||||
|
||||
void free_col(size_t at);
|
||||
|
||||
std::vector<std::shared_ptr<CudaStreamManager::Request>> basic_request(cudaMemcpyKind direction,
|
||||
std::function<void()> callback);
|
||||
// 统一内存拷贝类型接口
|
||||
std::vector<std::shared_ptr<CudaStreamManager::Request>> basic_request(
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
aclrtMemcpyKind direction,
|
||||
#else
|
||||
cudaMemcpyKind direction,
|
||||
#endif
|
||||
std::function<void()> callback);
|
||||
|
||||
void submit_requests(std::vector<std::shared_ptr<CudaStreamManager::Request>> reqs);
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
#else
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
#include <tbb/concurrent_hash_map.h>
|
||||
|
||||
#include <algorithm>
|
||||
@@ -986,13 +989,23 @@ struct DoubleCacheHandle : public DoubleCacheHandleInterface {
|
||||
}
|
||||
});
|
||||
|
||||
cudaMemcpyKind direction;
|
||||
if (option == IO_Read || option == IO_ForceRead) {
|
||||
direction = cudaMemcpyHostToDevice;
|
||||
}
|
||||
if (option == IO_Write || option == IO_ForceWrite) {
|
||||
direction = cudaMemcpyDeviceToHost;
|
||||
}
|
||||
#if defined(KTRANSFORMERS_USE_NPU) // NPU平台分支
|
||||
aclrtMemcpyKind direction;
|
||||
if (option == IO_Read || option == IO_ForceRead) {
|
||||
direction = ACL_MEMCPY_HOST_TO_DEVICE;
|
||||
}
|
||||
if (option == IO_Write || option == IO_ForceWrite) {
|
||||
direction = ACL_MEMCPY_DEVICE_TO_HOST;
|
||||
}
|
||||
#else // 默认GPU分支
|
||||
cudaMemcpyKind direction;
|
||||
if (option == IO_Read || option == IO_ForceRead) {
|
||||
direction = cudaMemcpyHostToDevice;
|
||||
}
|
||||
if (option == IO_Write || option == IO_ForceWrite) {
|
||||
direction = cudaMemcpyDeviceToHost;
|
||||
}
|
||||
#endif
|
||||
|
||||
auto reqs = gpu_cache->basic_request(direction, [io_helper]() { io_helper->batch_promise.set(); });
|
||||
|
||||
@@ -1752,7 +1765,11 @@ void GPUPageCache::gpu_background_flush() {
|
||||
std::vector<CacheBlockEntry*> entries;
|
||||
std::vector<std::unique_lock<CacheBlockEntry::MutexT>> uls;
|
||||
BatchPromise promise(config.gpu_devices_id.size());
|
||||
auto reqs = basic_request(cudaMemcpyDeviceToHost, [&promise]() { promise.set(); });
|
||||
#if defined(KTRANSFORMERS_USE_NPU) // NPU分支
|
||||
auto reqs = basic_request(ACL_MEMCPY_DEVICE_TO_HOST, [&promise]() { promise.set(); });
|
||||
#else
|
||||
auto reqs = basic_request(cudaMemcpyDeviceToHost, [&promise]() { promise.set(); });
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < config.total_kvcache_pages; i++) {
|
||||
std::lock_guard<std::mutex> lg(this->lock);
|
||||
|
||||
@@ -4,17 +4,38 @@ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
|
||||
|
||||
set(UTILS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/utils)
|
||||
|
||||
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
|
||||
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
|
||||
endif()
|
||||
|
||||
add_library(sched_metrics metrics.cpp)
|
||||
target_include_directories(sched_metrics PRIVATE ${UTILS_DIR})
|
||||
target_link_libraries(sched_metrics PUBLIC prometheus-cpp::pull)
|
||||
|
||||
|
||||
add_library(sched scheduler.cpp)
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
#target_link_directories(sched PUBLIC ${TORCH_NPU_PATH}/lib)
|
||||
target_include_directories(sched PUBLIC ${TORCH_NPU_PATH}/include)
|
||||
endif()
|
||||
target_include_directories(sched PRIVATE ${SPDLOG_DIR}/include ${FMT_DIR}/include ${UTILS_DIR} ${KVC2_INCLUDE_DIR})
|
||||
target_link_libraries(sched PUBLIC pthread ${TORCH_LIBRARIES} kvc2 async_store sched_metrics)
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
target_link_libraries(sched PUBLIC pthread "${TORCH_LIBRARIES}" "${TORCH_PYTHON_LIBRARY}" "${PTA_LIBRARY}" kvc2 async_store sched_metrics)
|
||||
else()
|
||||
target_link_libraries(sched PUBLIC pthread ${TORCH_LIBRARIES} kvc2 async_store sched_metrics)
|
||||
endif()
|
||||
|
||||
|
||||
pybind11_add_module(sched_ext bind.cpp)
|
||||
target_link_libraries(sched_ext PUBLIC sched ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
#target_link_directories(sched_ext PUBLIC ${TORCH_NPU_PATH}/lib)
|
||||
target_include_directories(sched_ext PUBLIC ${TORCH_NPU_PATH}/include)
|
||||
target_link_libraries(sched_ext PUBLIC "${TORCH_LIBRARIES}" "${TORCH_PYTHON_LIBRARY}" "${PTA_LIBRARY}" sched)
|
||||
else()
|
||||
target_link_libraries(sched_ext PUBLIC sched ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,28 +25,45 @@ using json = nlohmann::json;
|
||||
namespace scheduler {
|
||||
|
||||
void Settings::auto_derive() {
|
||||
// 统一设备数量获取方式
|
||||
gpu_device_count = gpu_device_id.size();
|
||||
if (torch::cuda::is_available()) {
|
||||
size_t gpu_count = torch::cuda::device_count();
|
||||
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count,
|
||||
gpu_device_count);
|
||||
if (gpu_count < gpu_device_count) {
|
||||
SPDLOG_ERROR("Not enough GPUs available.");
|
||||
|
||||
// 设备初始化分支
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
size_t npu_count = c10_npu::device_count();
|
||||
SPDLOG_INFO("Number of available NPUs: {}, want {}", npu_count, gpu_device_count);
|
||||
if (npu_count < gpu_device_count) {
|
||||
SPDLOG_ERROR("Not enough NPUs available.");
|
||||
exit(0);
|
||||
}
|
||||
for (size_t i = 0; i < gpu_device_count; i++) {
|
||||
devices.push_back(torch::Device(torch::kCUDA, gpu_device_id[i]));
|
||||
std::string device_str = "npu:" + std::to_string(gpu_device_id[i]);
|
||||
devices.push_back(torch::Device(device_str));
|
||||
}
|
||||
} else {
|
||||
SPDLOG_ERROR("CUDA is not available on this system.");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (model_settings.num_k_heads % gpu_device_count != 0) {
|
||||
size_t head_per_gpu = model_settings.num_k_heads;
|
||||
#else // GPU模式
|
||||
if (torch::cuda::is_available()) {
|
||||
size_t gpu_count = torch::cuda::device_count();
|
||||
SPDLOG_INFO("Number of available GPUs: {}, want {}", gpu_count, gpu_device_count);
|
||||
if (gpu_count < gpu_device_count) {
|
||||
SPDLOG_ERROR("Not enough GPUs available.");
|
||||
exit(0);
|
||||
}
|
||||
for (size_t i = 0; i < gpu_device_count; i++) {
|
||||
devices.push_back(torch::Device(torch::kCUDA, gpu_device_id[i]));
|
||||
}
|
||||
} else {
|
||||
SPDLOG_ERROR("CUDA is not available on this system.");
|
||||
exit(0);
|
||||
}
|
||||
if (model_settings.num_k_heads % gpu_device_count != 0) {
|
||||
SPDLOG_ERROR("num_k_heads {} is not divisible by gpu_device_count {}",
|
||||
model_settings.num_k_heads, gpu_device_count);
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
// 统一head_per_gpu计算方式(每设备分配头数)
|
||||
size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
|
||||
#endif
|
||||
|
||||
size_t gpu_memory_available = gpu_memory_size * memory_utilization_percentage;
|
||||
if (gpu_memory_available * gpu_device_count <
|
||||
@@ -58,13 +75,12 @@ void Settings::auto_derive() {
|
||||
}
|
||||
|
||||
assert(model_settings.k_head_dim % model_settings.num_k_heads == 0);
|
||||
size_t head_per_gpu = model_settings.num_k_heads / gpu_device_count;
|
||||
size_t gpu_memory_for_kv_cache =
|
||||
gpu_memory_available /*- model_settings.params_nbytes() /
|
||||
gpu_device_count*/
|
||||
;
|
||||
SPDLOG_INFO(
|
||||
"Each GPU Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB",
|
||||
"Each Device Total: {}MiB, Model Params: {}MiB, KVCache: {}MiB, Left: {}MiB",
|
||||
gpu_memory_size / (1 << 20),
|
||||
model_settings.params_nbytes() / gpu_device_count / (1 << 20),
|
||||
gpu_memory_for_kv_cache / (1 << 20),
|
||||
@@ -88,14 +104,18 @@ void Settings::auto_derive() {
|
||||
max_total_kvcache_pages);
|
||||
}
|
||||
|
||||
if (page_size % 256 != 0) {
|
||||
SPDLOG_ERROR("page_size {} is not divisible by 256", page_size);
|
||||
assert(false);
|
||||
}
|
||||
if (page_size < 256) {
|
||||
SPDLOG_ERROR("page_size {} is smaller than 256", page_size);
|
||||
assert(false);
|
||||
}
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
// NPU ND limit block size 16-128
|
||||
#else
|
||||
if (page_size % 256 != 0) {
|
||||
SPDLOG_ERROR("page_size {} is not divisible by 256", page_size);
|
||||
assert(false);
|
||||
}
|
||||
if (page_size < 256) {
|
||||
SPDLOG_ERROR("page_size {} is smaller than 256", page_size);
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string BatchQueryTodo::debug() {
|
||||
@@ -284,7 +304,11 @@ struct KVC2_Maintainer {
|
||||
.full_kv_cache_on_each_gpu = settings.full_kv_cache_on_each_gpu,
|
||||
.k_cache_on = settings.k_cache_on,
|
||||
.v_cache_on = settings.v_cache_on,
|
||||
.tensor_type = torch::kBFloat16,
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
.tensor_type = torch::kFloat16,
|
||||
#else
|
||||
.tensor_type = torch::kBFloat16,
|
||||
#endif
|
||||
};
|
||||
|
||||
auto model_configs_path =
|
||||
|
||||
@@ -6,6 +6,13 @@
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
// 条件编译:仅在NPU环境下引入相关头文件
|
||||
#ifdef KTRANSFORMERS_USE_NPU
|
||||
#include "torch_npu/csrc/libs/torch_npu.h"
|
||||
#include "torch_npu/csrc/libs/init_npu.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
|
||||
#endif
|
||||
|
||||
namespace scheduler {
|
||||
|
||||
using Token = uint32_t;
|
||||
|
||||
@@ -42,6 +42,11 @@ option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA"
|
||||
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
|
||||
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
|
||||
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
|
||||
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
|
||||
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
|
||||
endif()
|
||||
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
@@ -89,6 +94,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
|
||||
endif ()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
|
||||
else()
|
||||
if(KTRANSFORMERS_USE_NPU)
|
||||
list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)
|
||||
endif()
|
||||
check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
|
||||
if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
|
||||
list(APPEND ARCH_FLAGS -mfp16-format=ieee)
|
||||
@@ -116,36 +124,38 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
|
||||
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
||||
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
|
||||
message(STATUS "x86 detected")
|
||||
set(HOST_IS_X86 TRUE)
|
||||
set(HAS_AVX512 TRUE)
|
||||
set(__HAS_AMX__ TRUE)
|
||||
add_compile_definitions(__x86_64__)
|
||||
# check AVX512
|
||||
execute_process(
|
||||
COMMAND lscpu
|
||||
OUTPUT_VARIABLE LSCPU_OUTPUT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
|
||||
|
||||
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
|
||||
if(NOT KTRANSFORMERS_USE_NPU)
|
||||
set(HOST_IS_X86 TRUE)
|
||||
set(HAS_AVX512 TRUE)
|
||||
set(__HAS_AMX__ TRUE)
|
||||
add_compile_definitions(__x86_64__)
|
||||
# check AVX512
|
||||
execute_process(
|
||||
COMMAND lscpu
|
||||
OUTPUT_VARIABLE LSCPU_OUTPUT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
|
||||
|
||||
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
|
||||
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
|
||||
add_compile_definitions(__HAS_AVX512F__)
|
||||
else()
|
||||
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
|
||||
set(HAS_AVX512 False)
|
||||
endif()
|
||||
|
||||
# check AMX
|
||||
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
|
||||
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
|
||||
|
||||
if (COMPILER_SUPPORTS_AVX512F GREATER -1)
|
||||
message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
|
||||
add_compile_definitions(__HAS_AVX512F__)
|
||||
else()
|
||||
message(STATUS "Compiler and/or CPU do NOT support AVX512F")
|
||||
set(HAS_AVX512 False)
|
||||
endif()
|
||||
|
||||
if(COMPILER_SUPPORTS_AMX GREATER -1)
|
||||
message(STATUS "Compiler supports AMX")
|
||||
add_compile_definitions(__HAS_AMX__)
|
||||
else()
|
||||
message(STATUS "Compiler does NOT support AMX")
|
||||
# check AMX
|
||||
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
|
||||
|
||||
if(COMPILER_SUPPORTS_AMX GREATER -1)
|
||||
message(STATUS "Compiler supports AMX")
|
||||
add_compile_definitions(__HAS_AMX__)
|
||||
else()
|
||||
message(STATUS "Compiler does NOT support AMX")
|
||||
endif()
|
||||
endif()
|
||||
if (MSVC)
|
||||
# instruction set detection for MSVC only
|
||||
@@ -306,7 +316,7 @@ elseif (UNIX)
|
||||
endif()
|
||||
elseif (KTRANSFORMERS_USE_XPU)
|
||||
add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
|
||||
else()
|
||||
elseif (KTRANSFORMERS_USE_CUDA)
|
||||
find_package(CUDA REQUIRED)
|
||||
include_directories("${CUDA_INCLUDE_DIRS}")
|
||||
include(CheckLanguage)
|
||||
@@ -325,7 +335,13 @@ endif()
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
|
||||
# aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
|
||||
file(GLOB LLAMAFILE_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/*.cpp")
|
||||
list(REMOVE_ITEM LLAMAFILE_SOURCES
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_arm.cpp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_x86.cpp"
|
||||
)
|
||||
set(SOURCE_DIR4 ${LLAMAFILE_SOURCES})
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
|
||||
|
||||
if (HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__)
|
||||
@@ -365,7 +381,7 @@ elseif(UNIX)
|
||||
elseif(KTRANSFORMERS_USE_MUSA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
|
||||
elseif(KTRANSFORMERS_USE_XPU)
|
||||
else()
|
||||
elseif(KTRANSFORMERS_USE_CUDA AND NOT KTRANSFORMERS_USE_MUSA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
|
||||
endif()
|
||||
endif()
|
||||
@@ -396,4 +412,4 @@ else()
|
||||
else()
|
||||
message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
**/
|
||||
// Python bindings
|
||||
#include "cpu_backend/cpuinfer.h"
|
||||
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)
|
||||
#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) && !defined(KTRANSFORMERS_USE_NPU)
|
||||
#include "device_launch_parameters.h"
|
||||
#endif
|
||||
#include "llamafile/flags.h"
|
||||
|
||||
181
doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md
Normal file
181
doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md
Normal file
@@ -0,0 +1,181 @@
|
||||
# 基准测试结果
|
||||
|
||||
| Prompt length | 1K | 2K | 4K |
|
||||
| --------------------------------- | ------ | ------ | ------ |
|
||||
| KTrans Prefill token/s | 174.68 | 169.52 | 167.15 |
|
||||
| KTrans Decode token/s | 16.07 | 16.12 | 16.48 |
|
||||
|
||||
## 先决条件
|
||||
我们在以下配置下进行了Deepseek-R1最佳性能测试:
|
||||
- 服务器型号:Atlas 2UP
|
||||
- NPU:300I A2
|
||||
- CPU: HUAWEI Kunpeng 920 7270Z
|
||||
- 内存: DDR5服务器内存(1TB)
|
||||
|
||||
# 部署
|
||||
|
||||
## 物理机安装
|
||||
|
||||
部署满血版Deepseek-R1/V3,需要机器物理内存能够存放下全部路由专家的权重,约400GB。
|
||||
|
||||
目前支持的NPU型号:**300I A2**。
|
||||
|
||||
在技术人员的支持下完成硬件安装。
|
||||
|
||||
## 系统安装
|
||||
|
||||
根据网页[昇腾兼容性查询助手](https://www.hiascend.com/hardware/compatibility)查询,选用系统Ubuntu 22.04 for aarch64,内核5.15.0-25-generic,并禁止系统自动更新。系统镜像获取链接:[ubuntu-old-releases](https://mirrors.aliyun.com/oldubuntu-releases/releases/22.04)。
|
||||
|
||||
## HDK安装
|
||||
|
||||
选择[Ascend HDK 25.3.RC1](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/264672545?idAbsPath=fixnode01|23710424|251366513|254884019|261408772|252764743)进行安装,安装方式参考[昇腾社区HDK安装指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/softwareinst/instg/instg_0005.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)。
|
||||
|
||||
|
||||
## Conda部署
|
||||
|
||||
建议按照最新[Installation Guide - kTransformers](https://kvcache-ai.github.io/ktransformers/en/install.html)部署开发环境,此处注意Python版本要求3.11(其他版本未验证),arm平台不需要安装cpufeature包。
|
||||
|
||||
安装conda/miniconda
|
||||
|
||||
```bash
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
|
||||
bash ~/Miniconda3-latest-Linux-aarch64.sh
|
||||
```
|
||||
|
||||
部署Python环境:
|
||||
|
||||
```bash
|
||||
conda create -n py311 python=3.11
|
||||
conda activate py311
|
||||
conda install -c conda-forge libstdcxx-ng # 安装`GLIBCXX-3.4.32`
|
||||
apt install zlib1g-dev libtbb-dev libssl-dev libaio-dev libcurl4-openssl-dev
|
||||
pip3 install numpy==1.26.4 # 适配torch/torch_npu
|
||||
pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
|
||||
pip3 install packaging ninja fire protobuf attrs decorator cloudpickle ml-dtypes scipy tornado absl-py psutil
|
||||
pip3 install sqlalchemy
|
||||
pip3 install transformers==4.57.1 #此处注意运行时transformers版本要求4.57.1(其他版本未验证)
|
||||
#pip3 install cpufeature # only for x86
|
||||
```
|
||||
|
||||
## CANN安装
|
||||
|
||||
选择[CANN 8.3.RC1.alpha003](https://www.hiascend.com/developer/download/community/result?cann=8.3.RC1.alpha003&product=4&model=32)进行安装,安装方式参考[昇腾社区CANN安装指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/softwareinst/instg/instg_quick.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)。
|
||||
|
||||
需要安装ToolKit,Kernel和NNAL。
|
||||
|
||||
## torch_npu安装
|
||||
|
||||
获取最新的仓库代码:[torch_npu Gitcode](https://gitcode.com/Ascend/pytorch)
|
||||
|
||||
由于涉及新增算子,公网pypi内提供的torch_npu暂时无法直接使用,可以下载代码仓库编译,当前适配分支为v2.5.1,编译命令可以参考仓库内文档。
|
||||
编译过程需要保证访问github,gitcode等平台网络畅通并设置如下环境变量:
|
||||
|
||||
```bash
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh # 以实际CANN安装路径为准
|
||||
source /usr/local/Ascend/nnal/atb/set_env.sh # 以实际NNAL安装路径为准
|
||||
```
|
||||
由于环境对于torch_npu版本号有特定要求,使用编译后的torch_npu包需要手动移除版本信息中的哈希后缀,操作如下:
|
||||
使用文本编辑器打开`/usr/local/lib/python3.11/site-packages/torch_npu/version.py`(不同环境python路径可能不同,可以使用`pip show torch_npu`查看安装的python路径)
|
||||
将`__version__ = '2.5.1.post4+git69550dfc'`改为`__version__ = '2.5.1.post4'`
|
||||
|
||||
|
||||
## 权重准备
|
||||
|
||||
目前,为了满足性能和精度的要求,我们需要准备两份权重,并使用提供的权重合并脚本对权重进行合并,最终只会使用合并后的权重。
|
||||
|
||||
Q4权重:[DeepSeek-R1-Q4_K_M](https://modelscope.cn/models/unsloth/DeepSeek-R1-GGUF/files)
|
||||
|
||||
W8A8权重:[DeepSeek-R1-W8A8](https://modelers.cn/models/MindSpore-Lab/DeepSeek-R1-W8A8/tree/main)
|
||||
|
||||
使用[merge_safetensor_gguf.py](../../merge_tensors/merge_safetensor_gguf.py)来合并Q4和W8A8权重:
|
||||
|
||||
```bash
|
||||
python merge_safetensor_gguf.py --safetensor_path /mnt/weights/DeepSeek-R1-Q4_K_M --gguf_path /mnt/weights/DeepSeek-R1-W8A8 --output_path /mnt/weights/DeepSeek-R1-q4km-w8a8
|
||||
```
|
||||
|
||||
## 图下沉部署
|
||||
|
||||
开启图下沉功能,需要添加如下环境变量:
|
||||
|
||||
```bash
|
||||
export TASK_QUEUE_ENABLE=0 # 保证算子下发顺序有序
|
||||
```
|
||||
|
||||
|
||||
## kTransformers部署
|
||||
|
||||
将项目文件部署到机器上:
|
||||
|
||||
- 初始化third_party。由于此过程耗时较多,且容易受网络影响导致仓库克隆失败,建议初始化一次后,将相关文件进行打包,以便后续直接解压使用。
|
||||
```bash
|
||||
git clone https://github.com/kvcache-ai/ktransformers.git
|
||||
cd ktransformers
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
- 对于arm平台,注释掉`./third_party/llamafile/iqk_mul_mat_arm82.cpp`中的
|
||||
```cpp
|
||||
#define iqk_mul_mat iqk_mul_mat_arm82
|
||||
#define iqk_mul_mat_moe iqk_mul_mat_moe_arm82
|
||||
```
|
||||
- 执行`source /usr/local/Ascend/ascend-toolkit/set_env.sh`(以实际CANN-TOOLKIT安装路径为准)。
|
||||
- 执行`apt install cmake libhwloc-dev pkg-config`安装依赖。
|
||||
- 修改项目目录下 /ktransformers/config/config.yaml 中attn部分的page_size: 128 chunk_size: 16384
|
||||
- 执行`USE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh`,等待安装完成。
|
||||
|
||||
此处给出示例balance_serve的启动脚本(由于使用了相对路径,需将该脚本放至项目的根路径下):
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
export USE_MERGE=0
|
||||
export INF_NAN_MODE_FORCE_DISABLE=1
|
||||
export TASK_QUEUE_ENABLE=0
|
||||
export RANK=0
|
||||
export LOCAL_WORLD_SIZE=1
|
||||
#export PROF_DECODE=1
|
||||
#export PROF_PREFILL=1
|
||||
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
source /usr/local/Ascend/nnal/atb/set_env.sh
|
||||
|
||||
python ktransformers/server/main.py \
|
||||
--port 10002 \
|
||||
--model_path <your model path> \
|
||||
--gguf_path <your model path> \
|
||||
--model_name DeepSeekV3ForCausalLM \
|
||||
--cpu_infer 100 \
|
||||
--optimize_config_path ./ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat-300IA2-npu-serve.yaml \
|
||||
--max_new_tokens 1024 \
|
||||
--cache_lens 20480 \
|
||||
--max_batch_size 4 \
|
||||
--use_cuda_graph \
|
||||
--tp 1 \
|
||||
--backend_type balance_serve
|
||||
```
|
||||
|
||||
相关参数说明:
|
||||
|
||||
- `--model_path`:kTransformers原生参数,str,此处用来指定合并后的模型文件路径
|
||||
- `--gguf_path`:kTransformers原生参数,str,此处用来指定合并后的模型文件路径
|
||||
- `--cpu_infer`:kTransformers原生参数,int,用来控制CPU侧实际worker线程数,非必选
|
||||
- `--optimize_config_path`:kTransformers原生参数,str,用来指定所用的模型优化配置文件,需要注意相对路径的使用,此处为**必选**
|
||||
- `--cache_lens`:调度器申请 kvcache 的总长度。所有请求共享指定数量(例如 `20480`)的 tokens 对应的 kvcache 空间,请求完成后会释放其所占用的 kvcache 空间,非必选
|
||||
- `--use_cuda_graph`:kTransformers原生参数,bool,为True表示开启图下沉,为False表示关闭图下沉,非必选
|
||||
- `--max_new_tokens`:kTransformers原生参数,int,当统计到输出的tokens数量达到该值时,会直接中止输出,非必选
|
||||
- `--tp`:新增参数,int,用于开启tensor model parallel功能,目前local_chat只支持tp大小与ws大小相同(不支持local_chat使用多dp),非必选
|
||||
|
||||
|
||||
# 其他问题
|
||||
|
||||
## 可能存在的其他依赖问题
|
||||
|
||||
ImportError: libhccl.so: cannot open shared object file: No such file or directory
|
||||
|
||||
```bash
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh # 以实际CANN安装路径为准
|
||||
```
|
||||
|
||||
ImportError: libascend_hal.so: cannot open shared object file: No such file or directory
|
||||
|
||||
```bash
|
||||
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH # 以实际Driver安装路径为准
|
||||
```
|
||||
@@ -40,4 +40,4 @@ fi
|
||||
# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/
|
||||
# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*
|
||||
|
||||
echo "Installation completed successfully"
|
||||
echo "Installation completed successfully"
|
||||
|
||||
21
install_for_npu.sh
Normal file
21
install_for_npu.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
|
||||
# clear build dirs
|
||||
rm -rf build
|
||||
rm -rf *.egg-info
|
||||
rm -rf csrc/build
|
||||
rm -rf csrc/ktransformers_ext/build
|
||||
rm -rf csrc/ktransformers_ext/cuda/build
|
||||
rm -rf csrc/ktransformers_ext/cuda/dist
|
||||
rm -rf csrc/ktransformers_ext/cuda/*.egg-info
|
||||
rm -rf ~/.ktransformers
|
||||
echo "Installing python dependencies from requirements.txt"
|
||||
pip install -r requirements-local_chat.txt
|
||||
pip install -r ktransformers/server/requirements.txt
|
||||
echo "Installing ktransformers"
|
||||
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation
|
||||
pip install third_party/custom_flashinfer/
|
||||
|
||||
echo "Installation completed successfully"
|
||||
21
kt-kernel/.gitignore
vendored
Normal file
21
kt-kernel/.gitignore
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
debug/
|
||||
debug_prefill/
|
||||
debug_decode/
|
||||
debug1/
|
||||
debug2/
|
||||
.gdbinit
|
||||
bp.gdb
|
||||
.gdb_history
|
||||
build/
|
||||
# local git hooks installer and hooks
|
||||
.clangd
|
||||
.cache
|
||||
tmp*
|
||||
.vscode/
|
||||
*.egg-info/
|
||||
*.pyc
|
||||
*.so
|
||||
sparse_logs/
|
||||
build-cm/
|
||||
*.so
|
||||
sparse_logs/
|
||||
@@ -2,7 +2,33 @@ cmake_minimum_required(VERSION 3.16)
|
||||
|
||||
# Toggle: default to system compilers; optionally use conda toolchain
|
||||
option(USE_CONDA_TOOLCHAIN "Use C/C++ compilers and libraries from active conda env" OFF)
|
||||
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
|
||||
option(LLAMA_AVX "llama: enable AVX" OFF)
|
||||
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
|
||||
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
|
||||
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
|
||||
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
|
||||
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
|
||||
option(LLAMA_FMA "llama: enable FMA" OFF)
|
||||
# in MSVC F16C is implied with AVX2/AVX512
|
||||
if(NOT MSVC)
|
||||
option(LLAMA_F16C "llama: enable F16C" OFF)
|
||||
endif()
|
||||
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
|
||||
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
|
||||
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
|
||||
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
|
||||
option(KTRANSFORMERS_CPU_USE_KML "ktransformers: CPU use KML" OFF)
|
||||
option(KTRANSFORMERS_CPU_USE_AMX_AVX512 "ktransformers: CPU use AMX or AVX512" OFF)
|
||||
option(KTRANSFORMERS_CPU_USE_AMX "ktransformers: CPU use AMX" OFF)
|
||||
option(KTRANSFORMERS_CPU_DEBUG "ktransformers: DEBUG CPU use AMX" OFF)
|
||||
option(KTRANSFORMERS_CPU_MLA "ktransformers: CPU use MLA" OFF)
|
||||
option(KTRANSFORMERS_CPU_MOE_KERNEL "ktransformers: CPU use moe kernel" OFF)
|
||||
option(KTRANSFORMERS_CPU_MOE_AMD "ktransformers: CPU use moe kernel for amd" OFF)
|
||||
# LTO control
|
||||
option(CPUINFER_ENABLE_LTO "Enable link time optimization (IPO)" OFF)
|
||||
|
||||
project(kt_kernel_ext VERSION 0.1.0)
|
||||
# Choose compilers BEFORE project() so CMake honors them
|
||||
if(USE_CONDA_TOOLCHAIN)
|
||||
if(NOT DEFINED ENV{CONDA_PREFIX} OR NOT EXISTS "$ENV{CONDA_PREFIX}")
|
||||
@@ -24,8 +50,6 @@ else()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
project(cpuinfer_ext VERSION 0.1.0)
|
||||
|
||||
|
||||
# If explicitly using conda toolchain, prefer its libraries/headers and RPATH
|
||||
if(USE_CONDA_TOOLCHAIN)
|
||||
@@ -88,7 +112,7 @@ add_compile_definitions(FMT_HEADER_ONLY)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -fsanitize=address -fno-omit-frame-pointer")
|
||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g ")
|
||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
|
||||
# set(CMAKE_BUILD_TYPE "Debug")
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
find_package(OpenMP REQUIRED)
|
||||
@@ -98,7 +122,6 @@ include(CheckCXXCompilerFlag)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
|
||||
option(LLAMA_NATIVE "llama: enable -march=native flag" ON)
|
||||
|
||||
# instruction set specific
|
||||
if(LLAMA_NATIVE)
|
||||
@@ -106,51 +129,10 @@ if(LLAMA_NATIVE)
|
||||
else()
|
||||
set(INS_ENB ON)
|
||||
endif()
|
||||
|
||||
option(LLAMA_AVX "llama: enable AVX" OFF)
|
||||
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
|
||||
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
|
||||
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
|
||||
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
|
||||
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
|
||||
option(LLAMA_FMA "llama: enable FMA" OFF)
|
||||
# in MSVC F16C is implied with AVX2/AVX512
|
||||
if(NOT MSVC)
|
||||
option(LLAMA_F16C "llama: enable F16C" OFF)
|
||||
endif()
|
||||
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
|
||||
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
|
||||
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
|
||||
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
|
||||
option(KTRANSFORMERS_CPU_USE_KML "ktransformers: CPU use KML" OFF)
|
||||
option(KTRANSFORMERS_CPU_USE_AMX_AVX512 "ktransformers: CPU use AMX or AVX512" ON)
|
||||
option(KTRANSFORMERS_CPU_USE_AMX "ktransformers: CPU use AMX" OFF)
|
||||
option(KTRANSFORMERS_CPU_DEBUG "ktransformers: DEBUG CPU use AMX" OFF)
|
||||
option(KTRANSFORMERS_CPU_MLA "ktransformers: CPU use MLA" OFF)
|
||||
# LTO control
|
||||
option(CPUINFER_ENABLE_LTO "Enable link time optimization (IPO)" OFF)
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
if(MSVC)
|
||||
string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
|
||||
message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
|
||||
else()
|
||||
set(CMAKE_GENERATOR_PLATFORM_LWR "")
|
||||
endif()
|
||||
|
||||
if(NOT MSVC)
|
||||
if(LLAMA_STATIC)
|
||||
add_link_options(-static)
|
||||
if(MINGW)
|
||||
add_link_options(-static-libgcc -static-libstdc++)
|
||||
endif()
|
||||
endif()
|
||||
if(LLAMA_GPROF)
|
||||
add_compile_options(-pg)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(ARCH_FLAGS "")
|
||||
|
||||
@@ -266,14 +248,14 @@ elseif(CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR
|
||||
list(APPEND ARCH_FLAGS -mfma)
|
||||
endif()
|
||||
if(LLAMA_AVX)
|
||||
list(APPEND ARCH_FLAGS -mavx)
|
||||
list(APPEND ARCH_FLAGS -mavx -mfma -msse3 -mf16c)
|
||||
message(WARNING "pure AVX is not supported at least avx2")
|
||||
endif()
|
||||
if(LLAMA_AVX2)
|
||||
list(APPEND ARCH_FLAGS -mavx2)
|
||||
list(APPEND ARCH_FLAGS -mavx2 -mfma -msse3 -mf16c)
|
||||
endif()
|
||||
if(LLAMA_AVX512)
|
||||
list(APPEND ARCH_FLAGS -mavx512f)
|
||||
list(APPEND ARCH_FLAGS -mavx512bw)
|
||||
list(APPEND ARCH_FLAGS -mavx512f -mavx512bw -mfma -mf16c)
|
||||
endif()
|
||||
if(LLAMA_AVX512_VBMI)
|
||||
list(APPEND ARCH_FLAGS -mavx512vbmi)
|
||||
@@ -305,14 +287,6 @@ else()
|
||||
message(STATUS "Unknown architecture")
|
||||
endif()
|
||||
|
||||
# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}")
|
||||
# find_package(FindCUDAToolkit REQUIRED)
|
||||
# if(CUDAToolkit_FOUND)
|
||||
# message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}")
|
||||
# else()
|
||||
# message(STATUS "Can't found CUDA lib")
|
||||
# endif()
|
||||
|
||||
if(NOT EXISTS $ENV{ROCM_PATH})
|
||||
if(NOT EXISTS /opt/rocm)
|
||||
set(ROCM_PATH /usr)
|
||||
@@ -338,6 +312,62 @@ endif()
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
|
||||
|
||||
if(KTRANSFORMERS_CPU_MOE_AMD)
|
||||
set(BLIS_ROOT "" CACHE PATH "Root directory of BLIS installation")
|
||||
set(_BLIS_SEARCH_DIRS)
|
||||
if(BLIS_ROOT)
|
||||
list(APPEND _BLIS_SEARCH_DIRS "${BLIS_ROOT}")
|
||||
endif()
|
||||
list(APPEND _BLIS_SEARCH_DIRS "/usr/local" "/usr")
|
||||
|
||||
find_path(BLIS_INCLUDE_DIR
|
||||
NAMES blis.h
|
||||
HINTS ${_BLIS_SEARCH_DIRS}
|
||||
PATH_SUFFIXES include include/blis
|
||||
)
|
||||
find_library(BLIS_LIBRARY
|
||||
NAMES blis
|
||||
HINTS ${_BLIS_SEARCH_DIRS}
|
||||
PATH_SUFFIXES lib lib64
|
||||
)
|
||||
|
||||
if(NOT BLIS_INCLUDE_DIR OR NOT BLIS_LIBRARY)
|
||||
message(FATAL_ERROR "BLIS not found; set BLIS_ROOT or specify BLIS_INCLUDE_DIR/BLIS_LIBRARY")
|
||||
else()
|
||||
message(STATUS "Found BLIS include at ${BLIS_INCLUDE_DIR}")
|
||||
message(STATUS "Found BLIS library ${BLIS_LIBRARY}")
|
||||
endif()
|
||||
target_include_directories(${PROJECT_NAME} PRIVATE ${BLIS_INCLUDE_DIR})
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE ${BLIS_LIBRARY})
|
||||
endif()
|
||||
|
||||
|
||||
if(HOST_IS_X86)
|
||||
if(KTRANSFORMERS_CPU_USE_AMX_AVX512)
|
||||
add_compile_definitions(USE_AMX_AVX_KERNEL=1)
|
||||
if(KTRANSFORMERS_CPU_USE_AMX)
|
||||
add_compile_definitions(HAVE_AMX=1)
|
||||
list(APPEND ARCH_FLAGS -mamx-tile -mamx-bf16 -mamx-int8)
|
||||
message(STATUS "AMX enabled")
|
||||
list(APPEND ARCH_FLAGS -mamx-tile)
|
||||
endif()
|
||||
# add_executable(amx-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/amx-test.cpp)
|
||||
# target_link_libraries(amx-test llama)
|
||||
if(KTRANSFORMERS_CPU_DEBUG)
|
||||
file(GLOB AMX_TEST_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/test/*.cpp")
|
||||
foreach(test_src ${AMX_TEST_SOURCES})
|
||||
# 获取不带扩展名的文件名作为 target 名
|
||||
get_filename_component(test_name ${test_src} NAME_WE)
|
||||
add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)
|
||||
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa)
|
||||
endforeach()
|
||||
endif()
|
||||
list(APPEND ARCH_FLAGS -mfma -mf16c -mavx512bf16 -mavx512vnni)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
|
||||
|
||||
@@ -345,53 +375,48 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/pybind11 ${CMAKE_CURREN
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party)
|
||||
if(WIN32)
|
||||
include_directories("$ENV{CUDA_PATH}/include")
|
||||
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
|
||||
elseif(UNIX)
|
||||
if(KTRANSFORMERS_USE_CUDA)
|
||||
include(CheckLanguage)
|
||||
check_language(CUDA)
|
||||
if(CMAKE_CUDA_COMPILER)
|
||||
message(STATUS "CUDA detected")
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
include_directories(${CUDAToolkit_INCLUDE_DIRS})
|
||||
else()
|
||||
message(FATAL_ERROR "KTRANSFORMERS_USE_CUDA=ON but CUDA compiler not found")
|
||||
endif()
|
||||
message(STATUS "enabling CUDA")
|
||||
enable_language(CUDA)
|
||||
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
|
||||
elseif(KTRANSFORMERS_USE_ROCM)
|
||||
find_package(HIP REQUIRED)
|
||||
if(HIP_FOUND)
|
||||
include_directories("${HIP_INCLUDE_DIRS}")
|
||||
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
|
||||
endif()
|
||||
elseif(KTRANSFORMERS_USE_MUSA)
|
||||
if(NOT EXISTS $ENV{MUSA_PATH})
|
||||
if(NOT EXISTS /opt/musa)
|
||||
set(MUSA_PATH /usr/local/musa)
|
||||
else()
|
||||
set(MUSA_PATH /opt/musa)
|
||||
endif()
|
||||
else()
|
||||
set(MUSA_PATH $ENV{MUSA_PATH})
|
||||
endif()
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
|
||||
|
||||
find_package(MUSAToolkit)
|
||||
if(MUSAToolkit_FOUND)
|
||||
message(STATUS "MUSA Toolkit found")
|
||||
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
|
||||
endif()
|
||||
elseif(KTRANSFORMERS_CPU_USE_KML)
|
||||
message(STATUS "KML CPU detected")
|
||||
if(KTRANSFORMERS_USE_CUDA)
|
||||
include(CheckLanguage)
|
||||
check_language(CUDA)
|
||||
if(CMAKE_CUDA_COMPILER)
|
||||
message(STATUS "CUDA detected")
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
include_directories(${CUDAToolkit_INCLUDE_DIRS})
|
||||
else()
|
||||
message(STATUS "No GPU support enabled, building for CPU only")
|
||||
add_compile_definitions(KTRANSFORMERS_CPU_ONLY=1)
|
||||
message(FATAL_ERROR "KTRANSFORMERS_USE_CUDA=ON but CUDA compiler not found")
|
||||
endif()
|
||||
message(STATUS "enabling CUDA")
|
||||
enable_language(CUDA)
|
||||
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
|
||||
elseif(KTRANSFORMERS_USE_ROCM)
|
||||
find_package(HIP REQUIRED)
|
||||
if(HIP_FOUND)
|
||||
include_directories("${HIP_INCLUDE_DIRS}")
|
||||
add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
|
||||
endif()
|
||||
elseif(KTRANSFORMERS_USE_MUSA)
|
||||
if(NOT EXISTS $ENV{MUSA_PATH})
|
||||
if(NOT EXISTS /opt/musa)
|
||||
set(MUSA_PATH /usr/local/musa)
|
||||
else()
|
||||
set(MUSA_PATH /opt/musa)
|
||||
endif()
|
||||
else()
|
||||
set(MUSA_PATH $ENV{MUSA_PATH})
|
||||
endif()
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
|
||||
|
||||
find_package(MUSAToolkit)
|
||||
if(MUSAToolkit_FOUND)
|
||||
message(STATUS "MUSA Toolkit found")
|
||||
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
|
||||
endif()
|
||||
elseif(KTRANSFORMERS_CPU_USE_KML)
|
||||
message(STATUS "KML CPU detected")
|
||||
else()
|
||||
message(STATUS "No GPU support enabled, building for CPU only")
|
||||
add_compile_definitions(KTRANSFORMERS_CPU_ONLY=1)
|
||||
endif()
|
||||
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
||||
@@ -404,14 +429,31 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
|
||||
# arm64
|
||||
if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kml SOURCE_DIR6)
|
||||
if(NOT KTRANSFORMERS_CPU_MLA)
|
||||
list(REMOVE_ITEM SOURCE_DIR6 ${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/mla/)
|
||||
endif()
|
||||
endif()
|
||||
# message(STATUS "SOURCE_DIR6: ${SOURCE_DIR6}")
|
||||
|
||||
# aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR7)
|
||||
if(KTRANSFORMERS_CPU_MOE_KERNEL)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/la SOURCE_DIR7)
|
||||
if(KTRANSFORMERS_CPU_MOE_AMD)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/aocl_kernel SOURCE_DIR7_KERNEL)
|
||||
add_compile_definitions(USE_MOE_KERNEL_AMD=1)
|
||||
elseif(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel SOURCE_DIR7_KERNEL)
|
||||
endif()
|
||||
list(APPEND SOURCE_DIR7 ${SOURCE_DIR7_KERNEL})
|
||||
if(NOT KTRANSFORMERS_CPU_MLA)
|
||||
list(REMOVE_ITEM SOURCE_DIR7 ${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mla/)
|
||||
endif()
|
||||
add_compile_definitions(USE_MOE_KERNEL=1)
|
||||
endif()
|
||||
message(STATUS "SOURCE_DIR7: ${SOURCE_DIR7}")
|
||||
|
||||
|
||||
|
||||
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
|
||||
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6} ${SOURCE_DIR7})
|
||||
|
||||
file(GLOB_RECURSE FMT_SOURCES
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
|
||||
@@ -437,47 +479,47 @@ if(NOT DEFINED CLANG_FORMAT_BIN)
|
||||
)
|
||||
endif()
|
||||
if(NOT CLANG_FORMAT_BIN)
|
||||
message(FATAL_ERROR "clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
|
||||
endif()
|
||||
execute_process(
|
||||
COMMAND ${CLANG_FORMAT_BIN} --version
|
||||
OUTPUT_VARIABLE _CLANG_FORMAT_VER
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
message(STATUS "CMake PATH: $ENV{PATH}")
|
||||
# Parse version string, e.g. "Ubuntu clang-format version 19.1.0" or "clang-format version 18.1.8"
|
||||
string(REGEX MATCH "version[ ]+([0-9]+(\\.[0-9]+)*)" _CF_VER_MATCH "${_CLANG_FORMAT_VER}")
|
||||
if(NOT _CF_VER_MATCH)
|
||||
message(FATAL_ERROR "Failed to parse clang-format version from: ${_CLANG_FORMAT_VER}")
|
||||
endif()
|
||||
set(CLANG_FORMAT_VERSION "${CMAKE_MATCH_1}")
|
||||
message(STATUS "Using clang-format ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}")
|
||||
if(CLANG_FORMAT_VERSION VERSION_LESS "18.0.0")
|
||||
message(FATAL_ERROR "clang-format >=18.0.0 required (found ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}).\n"
|
||||
"Tip: Ensure your desired clang-format (e.g., conda's ${CONDA_PREFIX}/bin/clang-format) is earlier in PATH when running CMake,\n"
|
||||
"or pass -DCLANG_FORMAT_BIN=/full/path/to/clang-format.")
|
||||
endif()
|
||||
message(WARNING "clang-format not found. Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.")
|
||||
else()
|
||||
execute_process(
|
||||
COMMAND ${CLANG_FORMAT_BIN} --version
|
||||
OUTPUT_VARIABLE _CLANG_FORMAT_VER
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
# message(STATUS "CMake PATH: $ENV{PATH}")
|
||||
# Parse version string, e.g. "Ubuntu clang-format version 19.1.0" or "clang-format version 18.1.8"
|
||||
string(REGEX MATCH "version[ ]+([0-9]+(\\.[0-9]+)*)" _CF_VER_MATCH "${_CLANG_FORMAT_VER}")
|
||||
if(NOT _CF_VER_MATCH)
|
||||
message(WARNING "Failed to parse clang-format version from: ${_CLANG_FORMAT_VER}")
|
||||
endif()
|
||||
set(CLANG_FORMAT_VERSION "${CMAKE_MATCH_1}")
|
||||
message(STATUS "Using clang-format ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}")
|
||||
if(CLANG_FORMAT_VERSION VERSION_LESS "18.0.0")
|
||||
message(WARNING "clang-format >=18.0.0 required (found ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}).\n"
|
||||
"Tip: Ensure your desired clang-format (e.g., conda's ${CONDA_PREFIX}/bin/clang-format) is earlier in PATH when running CMake,\n"
|
||||
"or pass -DCLANG_FORMAT_BIN=/full/path/to/clang-format.")
|
||||
endif()
|
||||
add_custom_target(
|
||||
format
|
||||
COMMAND ${CLANG_FORMAT_BIN}
|
||||
-i
|
||||
-style=file
|
||||
-fallback-style=none
|
||||
${FMT_SOURCES}
|
||||
COMMENT "Running clang-format on all source files"
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
format
|
||||
COMMAND ${CLANG_FORMAT_BIN}
|
||||
-i
|
||||
-style=file
|
||||
-fallback-style=none
|
||||
${FMT_SOURCES}
|
||||
COMMENT "Running clang-format on all source files"
|
||||
)
|
||||
|
||||
# Optional: target to check formatting without modifying files (CI-friendly)
|
||||
add_custom_target(
|
||||
format-check
|
||||
COMMAND ${CLANG_FORMAT_BIN}
|
||||
-n --Werror
|
||||
-style=file
|
||||
-fallback-style=none
|
||||
${FMT_SOURCES}
|
||||
COMMENT "Checking clang-format on all source files"
|
||||
)
|
||||
# Optional: target to check formatting without modifying files (CI-friendly)
|
||||
add_custom_target(
|
||||
format-check
|
||||
COMMAND ${CLANG_FORMAT_BIN}
|
||||
-n --Werror
|
||||
-style=file
|
||||
-fallback-style=none
|
||||
${FMT_SOURCES}
|
||||
COMMENT "Checking clang-format on all source files"
|
||||
)
|
||||
endif()
|
||||
|
||||
include(FindPkgConfig)
|
||||
if(PKG_CONFIG_FOUND)
|
||||
@@ -489,8 +531,6 @@ endif(PKG_CONFIG_FOUND)
|
||||
|
||||
add_library(llamafile STATIC ${SOURCE_DIR4})
|
||||
|
||||
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}")
|
||||
|
||||
if(CPUINFER_ENABLE_LTO)
|
||||
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON)
|
||||
@@ -502,6 +542,7 @@ else()
|
||||
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
|
||||
message(STATUS "LTO: disabled")
|
||||
endif()
|
||||
|
||||
# Ensure the module target also has correct RPATH when conda is active
|
||||
if(TARGET ${PROJECT_NAME} AND DEFINED ENV{CONDA_PREFIX} AND EXISTS "$ENV{CONDA_PREFIX}")
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES
|
||||
@@ -513,44 +554,23 @@ endif()
|
||||
if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
|
||||
message(STATUS "KML CPU detected")
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/prefillgemm)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/prefillgemm)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE prefillint8gemm)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/prefillgemm_int4)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/prefillgemm_int4)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE prefillint4gemm)
|
||||
|
||||
set(DECODE_GEMM_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/la/batch_gemm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/la/batch_gemm_kernels.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/batch_gemm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/batch_gemm_kernels.cpp
|
||||
)
|
||||
add_library(decode_gemm SHARED ${DECODE_GEMM_SOURCES})
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE decode_gemm)
|
||||
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE kml_rt)
|
||||
if(KTRANSFORMERS_CPU_MLA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE kml_rt)
|
||||
endif()
|
||||
target_compile_definitions(${PROJECT_NAME} PRIVATE CPU_USE_KML)
|
||||
endif()
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE llama PkgConfig::HWLOC OpenMP::OpenMP_CXX)
|
||||
|
||||
|
||||
if(HOST_IS_X86)
|
||||
if(KTRANSFORMERS_CPU_USE_AMX_AVX512)
|
||||
if(KTRANSFORMERS_CPU_USE_AMX)
|
||||
add_compile_definitions(HAVE_AMX=1)
|
||||
message(STATUS "AMX enabled")
|
||||
endif()
|
||||
# add_executable(amx-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/amx-test.cpp)
|
||||
# target_link_libraries(amx-test llama)
|
||||
if(KTRANSFORMERS_CPU_DEBUG)
|
||||
file(GLOB AMX_TEST_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/test/*.cpp")
|
||||
foreach(test_src ${AMX_TEST_SOURCES})
|
||||
# 获取不带扩展名的文件名作为 target 名
|
||||
get_filename_component(test_name ${test_src} NAME_WE)
|
||||
add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)
|
||||
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa)
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
|
||||
if(KTRANSFORMERS_CPU_DEBUG)
|
||||
# add_executable(convert-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/convert-test.cpp)
|
||||
@@ -560,27 +580,27 @@ if(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)
|
||||
# 获取不带扩展名的文件名作为 target 名
|
||||
get_filename_component(test_name ${test_src} NAME_WE)
|
||||
add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)
|
||||
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa kml_rt)
|
||||
if(KTRANSFORMERS_CPU_MLA)
|
||||
target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa kml_rt)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
if(WIN32)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib") #CUDA::cudart
|
||||
elseif(UNIX)
|
||||
if(NOT KTRANSFORMERS_USE_MUSA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
|
||||
endif()
|
||||
if(KTRANSFORMERS_USE_ROCM)
|
||||
add_compile_definitions(USE_HIP=1)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
|
||||
message(STATUS "Building for HIP")
|
||||
endif()
|
||||
if(KTRANSFORMERS_USE_MUSA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
|
||||
endif()
|
||||
|
||||
if(KTRANSFORMERS_USE_CUDA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
|
||||
endif()
|
||||
if(KTRANSFORMERS_USE_ROCM)
|
||||
add_compile_definitions(USE_HIP=1)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
|
||||
message(STATUS "Building for HIP")
|
||||
endif()
|
||||
if(KTRANSFORMERS_USE_MUSA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
find_library(NUMA_LIBRARY NAMES numa)
|
||||
@@ -590,3 +610,7 @@ if(NUMA_LIBRARY)
|
||||
else()
|
||||
message(FATAL_ERROR "NUMA library not found, please install NUMA, sudo apt install libnuma-dev")
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}")
|
||||
|
||||
|
||||
47
kt-kernel/CMakePresets.json
Normal file
47
kt-kernel/CMakePresets.json
Normal file
@@ -0,0 +1,47 @@
|
||||
{
|
||||
"version": 3,
|
||||
"cmakeMinimumRequired": {
|
||||
"major": 3,
|
||||
"minor": 19,
|
||||
"patch": 0
|
||||
},
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "avx512",
|
||||
"displayName": "avx512_platform",
|
||||
"description": "for avx512 platform",
|
||||
"cacheVariables": {
|
||||
"KTRANSFORMERS_CPU_USE_AMX": "OFF",
|
||||
"LLAMA_AVX512": "OFF",
|
||||
"LLAMA_AVX2": "OFF",
|
||||
"KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON",
|
||||
"KTRANSFORMERS_USE_CUDA": "ON"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "avx",
|
||||
"displayName": "avx_platform",
|
||||
"description": "for avx platform",
|
||||
"cacheVariables": {
|
||||
"KTRANSFORMERS_CPU_USE_AMX": "OFF",
|
||||
"LLAMA_AVX2": "ON",
|
||||
"KTRANSFORMERS_USE_CUDA": "ON"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "amx",
|
||||
"displayName": "amx_platform",
|
||||
"description": "for amx platform",
|
||||
"cacheVariables": {
|
||||
"KTRANSFORMERS_CPU_USE_AMX": "ON",
|
||||
"LLAMA_AVX512": "OFF",
|
||||
"LLAMA_AVX2": "OFF",
|
||||
"KTRANSFORMERS_CPU_USE_AMX_AVX512": "ON",
|
||||
"KTRANSFORMERS_USE_CUDA": "ON"
|
||||
}
|
||||
}
|
||||
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -105,59 +105,42 @@ python -c "from kt_kernel import AMXMoEWrapper; print('✓ kt-kernel installed s
|
||||
|
||||
## Weight Quantization
|
||||
|
||||
KT-Kernel provides a weight conversion tool to quantize model weights from FP8/FP16/BF16 to INT4/INT8 format optimized for AMX inference.
|
||||
KT-Kernel provides weight quantization tools for CPU-GPU hybrid inference (e.g., integrating with SGLang). Both tools work together to enable heterogeneous expert placement across CPUs and GPUs.
|
||||
|
||||
### Quantization Methods
|
||||
### CPU Weights (for "cold" experts on CPU)
|
||||
|
||||
- **INT4**: 4-bit quantization for maximum memory efficiency
|
||||
- **INT8**: 8-bit quantization for better accuracy
|
||||
|
||||
### Supported Input Formats
|
||||
|
||||
- **FP8**: 8-bit floating point with automatic dequantization
|
||||
- **FP16**: 16-bit floating point
|
||||
- **BF16**: BFloat16 format
|
||||
|
||||
### Basic Usage
|
||||
Quantize weights to INT4/INT8 format optimized for AMX inference:
|
||||
|
||||
```bash
|
||||
# Quantize BF16 model to INT4
|
||||
python scripts/convert_weights.py \
|
||||
--input-path /path/to/bf16/model \
|
||||
python scripts/convert_cpu_weights.py \
|
||||
--input-path /path/to/model \
|
||||
--input-type bf16 \
|
||||
--output /path/to/output \
|
||||
--quant-method int4
|
||||
|
||||
# Quantize FP16 model to INT8
|
||||
python scripts/convert_weights.py \
|
||||
--input-path /path/to/fp16/model \
|
||||
--input-type fp16 \
|
||||
--output /path/to/output \
|
||||
--quant-method int8
|
||||
|
||||
# Quantize FP8 model to INT4
|
||||
python scripts/convert_weights.py \
|
||||
--input-path /path/to/fp8/model \
|
||||
--input-type fp8 \
|
||||
--output /path/to/output \
|
||||
--quant-method int4
|
||||
```
|
||||
|
||||
### Output Format
|
||||
**Supported formats:** FP8, FP16, BF16 → INT4/INT8
|
||||
|
||||
The converted weights are saved in SafeTensors format with NUMA-aware layout:
|
||||
```
|
||||
output_dir/
|
||||
├── model-00001-of-00050.safetensors
|
||||
├── model-00002-of-00050.safetensors
|
||||
├── ...
|
||||
├── config.json
|
||||
└── tokenizer files...
|
||||
### GPU Weights (for "hot" experts on GPU)
|
||||
|
||||
Apply GPTQ quantization to model weights:
|
||||
|
||||
```bash
|
||||
# Install additional dependencies first
|
||||
pip install accelerate transformers llmcompressor datasets
|
||||
|
||||
# Quantize GPU weights
|
||||
python scripts/convert_gpu_weights.py \
|
||||
--model_id /path/to/model \
|
||||
--output_dir /path/to/output \
|
||||
--quant_type W4A16
|
||||
```
|
||||
|
||||
Each expert's weights are split across NUMA nodes for optimal memory access:
|
||||
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.weight`: Quantized weights
|
||||
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.scale`: Quantization scales
|
||||
**Supported types:** W4A16 (GPTQ4), W8A16 (GPTQ8)
|
||||
|
||||
---
|
||||
|
||||
For detailed documentation, advanced options, and low-memory mode, see [scripts/README.md](scripts/README.md).
|
||||
|
||||
## Before Commit!
|
||||
your msg should match: Conventional Commits (https://www.conventionalcommits.org/) <br>and format your code before commit:
|
||||
|
||||
23
kt-kernel/bench/Makefile
Normal file
23
kt-kernel/bench/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
# test bench_moe_kernel_tiling.py
|
||||
kernel_tiling:
|
||||
python3 bench_moe_kernel_tiling.py \
|
||||
--hidden_size 7168 \
|
||||
--intermediate_size 2048 \
|
||||
--num_experts_per_tok 8 \
|
||||
--expert_num 256 \
|
||||
--max_len 51200 \
|
||||
--layer_num 1 \
|
||||
--qlen 1024 \
|
||||
--quant int8 \
|
||||
--warm_up_iter 500 \
|
||||
--test_iter 1000 \
|
||||
--threads 160 \
|
||||
--m_block 320 \
|
||||
|
||||
# --n_block_up_gate 256 \
|
||||
# --n_block_down 128 \
|
||||
# --n_block_up_gate_prefi 256 \
|
||||
# --n_block_down_prefi 128 \
|
||||
|
||||
# --n_block_up_gate 256 \
|
||||
# --n_block_down 512 \
|
||||
@@ -13,7 +13,7 @@ import os, sys
|
||||
import time
|
||||
|
||||
sys.path.append(os.path.dirname(__file__) + "/../build")
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
layer_num = 10
|
||||
@@ -23,16 +23,16 @@ head_dim = 128
|
||||
block_len = 128
|
||||
anchor_num = 1
|
||||
|
||||
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
|
||||
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
|
||||
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
|
||||
anchor_type = kt_kernel_ext.kvcache.AnchorType.DYNAMIC
|
||||
kv_type = kt_kernel_ext.kvcache.ggml_type.FP16
|
||||
retrieval_type = kt_kernel_ext.kvcache.RetrievalType.LAYER
|
||||
layer_step: int = 1
|
||||
token_step: int = 1
|
||||
layer_offset: int = 0
|
||||
max_thread_num: int = 64
|
||||
max_batch_size: int = 1
|
||||
max_block_num: int = 1024
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(max_thread_num)
|
||||
|
||||
warm_up_iter = 1000
|
||||
test_iter = 10000
|
||||
@@ -43,7 +43,7 @@ def bench_linear(cache_seqlen: int):
|
||||
cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
|
||||
seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
|
||||
|
||||
config = cpuinfer_ext.kvcache.KVCacheConfig(
|
||||
config = kt_kernel_ext.kvcache.KVCacheConfig(
|
||||
layer_num,
|
||||
kv_head_num,
|
||||
q_head_num,
|
||||
@@ -60,7 +60,7 @@ def bench_linear(cache_seqlen: int):
|
||||
max_batch_size,
|
||||
max_thread_num,
|
||||
)
|
||||
local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
|
||||
local_kvcache = kt_kernel_ext.kvcache.KVCache(config)
|
||||
block_table = (
|
||||
torch.arange(max_block_num, dtype=torch.int32, device="cpu")
|
||||
.contiguous()
|
||||
|
||||
@@ -13,7 +13,7 @@ import os, sys
|
||||
import time
|
||||
|
||||
sys.path.append(os.path.dirname(__file__) + "/../build")
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
layer_num = 10
|
||||
|
||||
@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
import os, sys
|
||||
import time
|
||||
sys.path.append(os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
input_size = 16384
|
||||
@@ -21,7 +21,7 @@ stride = 16
|
||||
group_max_len = 1024
|
||||
layer_num = 10
|
||||
qlen = 1
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(64)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
warm_up_iter = 1000
|
||||
test_iter = 10000
|
||||
|
||||
@@ -69,8 +69,8 @@ def bench_linear(quant_mode: str):
|
||||
projs = []
|
||||
for _ in range(layer_num):
|
||||
proj = torch.randn((output_size, input_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
|
||||
config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
|
||||
linear = cpuinfer_ext.linear.Linear(config)
|
||||
config = kt_kernel_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
|
||||
linear = kt_kernel_ext.linear.Linear(config)
|
||||
projs.append(proj)
|
||||
linears.append(linear)
|
||||
input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
|
||||
|
||||
@@ -5,8 +5,8 @@ import platform
|
||||
import json
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
from torch import inf, nn
|
||||
from torch.nn import init
|
||||
@@ -47,7 +47,7 @@ rope_scaling = {
|
||||
|
||||
CPUINFER_PARAM = 304
|
||||
# 初始化 CPUInfer(此处使用原始构造函数,可根据需要调整配置参数)
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
|
||||
|
||||
|
||||
warm_up_iter = 20
|
||||
@@ -200,7 +200,7 @@ def bench_mla(quant_mode: str):
|
||||
kv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to('cpu').contiguous()
|
||||
o_proj_weight = o_proj.weight.to(torch.float16).to('cpu').contiguous()
|
||||
|
||||
config = cpuinfer_ext.mla.MLAConfig(
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -236,7 +236,7 @@ def bench_mla(quant_mode: str):
|
||||
|
||||
|
||||
|
||||
mla = cpuinfer_ext.mla.MLA(config)
|
||||
mla = kt_kernel_ext.mla.MLA(config)
|
||||
mla.load_weights()
|
||||
mla.set_local_pages(pages_count)
|
||||
mlas.append(mla)
|
||||
|
||||
@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
import os, sys
|
||||
import time
|
||||
sys.path.append(os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
hidden_size = 5120
|
||||
@@ -21,7 +21,7 @@ stride = 16
|
||||
group_max_len = 1024
|
||||
layer_num = 10
|
||||
qlen = 1
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(64)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
warm_up_iter = 1000
|
||||
test_iter = 10000
|
||||
|
||||
@@ -96,8 +96,8 @@ def bench_mlp(quant_mode: str):
|
||||
gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
|
||||
up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
|
||||
down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
|
||||
config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
|
||||
mlp = cpuinfer_ext.mlp.MLP(config)
|
||||
config = kt_kernel_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
|
||||
mlp = kt_kernel_ext.mlp.MLP(config)
|
||||
gate_projs.append(gate_proj)
|
||||
up_projs.append(up_proj)
|
||||
down_projs.append(down_proj)
|
||||
|
||||
@@ -6,7 +6,7 @@ import subprocess
|
||||
import platform
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -29,7 +29,7 @@ warm_up_iter = 100
|
||||
test_iter = 10000
|
||||
CPUINFER_PARAM = 304
|
||||
# 初始化 CPUInfer(此处使用原始构造函数,可根据需要调整配置参数)
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
|
||||
|
||||
# 获取脚本相关信息,用于生成结果保存文件名
|
||||
script_path = os.path.abspath(__file__)
|
||||
@@ -198,7 +198,7 @@ def bench_moe(quant_mode: str):
|
||||
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device="cpu").to("cpu").contiguous()
|
||||
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float16, device="cpu").to("cpu").contiguous()
|
||||
|
||||
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config.pool = CPUInfer.backend_
|
||||
config.m_block = m_block
|
||||
config.group_min_len = group_min_len
|
||||
@@ -211,7 +211,7 @@ def bench_moe(quant_mode: str):
|
||||
config.down_type = down_type
|
||||
config.hidden_type = hidden_type
|
||||
|
||||
moe = cpuinfer_ext.moe.MOE(config)
|
||||
moe = kt_kernel_ext.moe.MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
CPUInfer.sync()
|
||||
moes.append(moe)
|
||||
|
||||
@@ -1,50 +1,46 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
'''
|
||||
Description :
|
||||
"""
|
||||
Description :
|
||||
Author : chenht2022
|
||||
Date : 2024-07-25 10:32:05
|
||||
Version : 1.0.0
|
||||
LastEditors : chenht2022
|
||||
LastEditors : chenht2022
|
||||
LastEditTime : 2024-08-06 10:41:28
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
'''
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
"""
|
||||
import os, sys, time, json, subprocess, platform
|
||||
|
||||
from tqdm import tqdm
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
|
||||
import cpuinfer_ext
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
import torch
|
||||
import kt_kernel_ext
|
||||
import numpy as np
|
||||
|
||||
# 测试参数设置
|
||||
expert_num = 256
|
||||
expert_num = 16
|
||||
hidden_size = 7168
|
||||
intermediate_size = 2048
|
||||
max_len = 25600
|
||||
max_len = 25600
|
||||
num_experts_per_tok = 8
|
||||
layer_num = 4
|
||||
# qlen = 1024
|
||||
qlen = 1
|
||||
warm_up_iter = 500
|
||||
test_iter = 1000
|
||||
physical_to_logical_map = torch.tensor(
|
||||
data=range(expert_num),
|
||||
device="cpu",
|
||||
dtype=torch.int64
|
||||
).contiguous()
|
||||
layer_num = 2
|
||||
|
||||
qlen = 2048
|
||||
warm_up_iter = 1000
|
||||
test_iter = 2000
|
||||
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
|
||||
|
||||
# 将 CPUInfer 参数设为变量
|
||||
# CPUINFER_PARAM = 257
|
||||
# CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
|
||||
# CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
|
||||
|
||||
worker_config = cpuinfer_ext.WorkerPoolConfig()
|
||||
worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||
worker_config.subpool_count = 2
|
||||
worker_config.subpool_numa_map= [0,1]
|
||||
worker_config.subpool_thread_count = [80,80]
|
||||
CPUINFER_PARAM = 40
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
|
||||
|
||||
worker_config.subpool_numa_map = [0, 1]
|
||||
worker_config.subpool_thread_count = [80, 80]
|
||||
CPUINFER_PARAM = 160
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
|
||||
|
||||
|
||||
def get_git_commit():
|
||||
@@ -81,14 +77,14 @@ def get_system_info():
|
||||
info = {}
|
||||
# 系统名称及主机名
|
||||
uname = platform.uname()
|
||||
info["system_name"] = uname.system # 如 Linux, Windows 等
|
||||
info["node_name"] = uname.node # 主机名称
|
||||
info["system_name"] = uname.system # 如 Linux, Windows 等
|
||||
info["node_name"] = uname.node # 主机名称
|
||||
|
||||
# 获取 CPU 型号(仅 Linux 支持)
|
||||
cpu_model = None
|
||||
if os.path.exists('/proc/cpuinfo'):
|
||||
if os.path.exists("/proc/cpuinfo"):
|
||||
try:
|
||||
with open('/proc/cpuinfo', 'r') as f:
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
for line in f:
|
||||
if "model name" in line:
|
||||
cpu_model = line.split(":", 1)[1].strip()
|
||||
@@ -99,9 +95,9 @@ def get_system_info():
|
||||
|
||||
# 获取内存大小(单位:GB),仅 Linux 支持
|
||||
mem_total_gb = None
|
||||
if os.path.exists('/proc/meminfo'):
|
||||
if os.path.exists("/proc/meminfo"):
|
||||
try:
|
||||
with open('/proc/meminfo', 'r') as f:
|
||||
with open("/proc/meminfo", "r") as f:
|
||||
for line in f:
|
||||
if "MemTotal" in line:
|
||||
mem_kb = float(line.split(":", 1)[1].split()[0])
|
||||
@@ -129,11 +125,13 @@ def get_system_info():
|
||||
|
||||
return info
|
||||
|
||||
|
||||
script_path = os.path.abspath(__file__)
|
||||
script_dir = os.path.dirname(script_path)
|
||||
script_name = os.path.splitext(os.path.basename(script_path))[0]
|
||||
json_path = os.path.join(script_dir, script_name + ".jsonl")
|
||||
|
||||
|
||||
def record_results(result, filename=json_path):
|
||||
"""
|
||||
将结果以 JSON 格式追加到文件中
|
||||
@@ -141,6 +139,7 @@ def record_results(result, filename=json_path):
|
||||
with open(filename, "a") as f:
|
||||
f.write(json.dumps(result) + "\n")
|
||||
|
||||
|
||||
def bench_moe(quant_mode: str):
|
||||
with torch.inference_mode():
|
||||
if quant_mode == "bf16":
|
||||
@@ -157,33 +156,56 @@ def bench_moe(quant_mode: str):
|
||||
up_projs = []
|
||||
down_projs = []
|
||||
for layer_index in range(layer_num):
|
||||
gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
|
||||
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
|
||||
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
expert_num, num_experts_per_tok, hidden_size, intermediate_size,0)
|
||||
gate_proj = (
|
||||
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
up_proj = (
|
||||
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
down_proj = (
|
||||
torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
config.up_proj = up_proj.data_ptr()
|
||||
config.down_proj = down_proj.data_ptr()
|
||||
config.pool = CPUInfer.backend_
|
||||
if quant_mode == "bf16":
|
||||
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
|
||||
elif quant_mode == "int8":
|
||||
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
|
||||
elif quant_mode == "int4":
|
||||
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
CPUInfer.sync()
|
||||
gate_projs.append(gate_proj)
|
||||
up_projs.append(up_proj)
|
||||
down_projs.append(down_proj)
|
||||
moes.append(moe)
|
||||
gen_iter = 3000
|
||||
expert_ids = torch.rand(gen_iter * qlen , expert_num, device="cpu").argsort(dim=-1)[:, :num_experts_per_tok].reshape(gen_iter, qlen * num_experts_per_tok).to("cpu").contiguous()
|
||||
weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").to("cpu").contiguous()
|
||||
input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
|
||||
output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
|
||||
expert_ids = (
|
||||
torch.rand(gen_iter * qlen, expert_num, device="cpu")
|
||||
.argsort(dim=-1)[:, :num_experts_per_tok]
|
||||
.reshape(gen_iter, qlen * num_experts_per_tok)
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
weights = (
|
||||
torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").to("cpu").contiguous()
|
||||
)
|
||||
input_tensor = (
|
||||
torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
|
||||
)
|
||||
output_tensor = (
|
||||
torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
|
||||
)
|
||||
bsz_tensor = torch.tensor([qlen], device="cpu")
|
||||
|
||||
# 预热迭代
|
||||
@@ -193,8 +215,8 @@ def bench_moe(quant_mode: str):
|
||||
moes[i % layer_num].forward_task(
|
||||
bsz_tensor.data_ptr(),
|
||||
num_experts_per_tok,
|
||||
expert_ids[i%gen_iter].data_ptr(),
|
||||
weights[i%gen_iter].data_ptr(),
|
||||
expert_ids[i % gen_iter].data_ptr(),
|
||||
weights[i % gen_iter].data_ptr(),
|
||||
input_tensor[i % layer_num].data_ptr(),
|
||||
output_tensor[i % layer_num].data_ptr(),
|
||||
False,
|
||||
@@ -213,8 +235,8 @@ def bench_moe(quant_mode: str):
|
||||
moes[i % layer_num].forward_task(
|
||||
bsz_tensor.data_ptr(),
|
||||
num_experts_per_tok,
|
||||
expert_ids[i%gen_iter].data_ptr(),
|
||||
weights[i%gen_iter].data_ptr(),
|
||||
expert_ids[i % gen_iter].data_ptr(),
|
||||
weights[i % gen_iter].data_ptr(),
|
||||
input_tensor[i % layer_num].data_ptr(),
|
||||
output_tensor[i % layer_num].data_ptr(),
|
||||
False,
|
||||
@@ -228,16 +250,28 @@ def bench_moe(quant_mode: str):
|
||||
|
||||
# 计算性能指标
|
||||
time_per_iter_us = total_time / test_iter * 1e6
|
||||
bandwidth = hidden_size * intermediate_size * 3 * num_experts_per_tok * (1/8 * 256 * (1-(31/32)**qlen)) * bytes_per_elem * test_iter / total_time / 1e9 # 单位:GB/s
|
||||
flops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12 # 单位:TFLOPS
|
||||
bandwidth = (
|
||||
hidden_size
|
||||
* intermediate_size
|
||||
* 3
|
||||
* num_experts_per_tok
|
||||
* (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
|
||||
* bytes_per_elem
|
||||
* test_iter
|
||||
/ total_time
|
||||
/ 1e9
|
||||
) # 单位:GB/s
|
||||
flops = (
|
||||
hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12
|
||||
) # 单位:TFLOPS
|
||||
|
||||
print('Quant mode: ', quant_mode)
|
||||
print('Time(s): ', total_time)
|
||||
print('Iteration: ', test_iter)
|
||||
print('Time(us) per iteration: ', time_per_iter_us)
|
||||
print('Bandwidth: ', bandwidth, 'GB/s')
|
||||
print('Flops: ', flops, 'TFLOPS')
|
||||
print('')
|
||||
print("Quant mode: ", quant_mode)
|
||||
print("Time(s): ", total_time)
|
||||
print("Iteration: ", test_iter)
|
||||
print("Time(us) per iteration: ", time_per_iter_us)
|
||||
print("Bandwidth: ", bandwidth, "GB/s")
|
||||
print("Flops: ", flops, "TFLOPS")
|
||||
print("")
|
||||
|
||||
# 整理结果记录,包括测试参数
|
||||
result = {
|
||||
@@ -258,8 +292,8 @@ def bench_moe(quant_mode: str):
|
||||
"qlen": qlen,
|
||||
"warm_up_iter": warm_up_iter,
|
||||
"test_iter": test_iter,
|
||||
"CPUInfer_parameter": CPUINFER_PARAM
|
||||
}
|
||||
"CPUInfer_parameter": CPUINFER_PARAM,
|
||||
},
|
||||
}
|
||||
# 添加 git 提交记录信息
|
||||
result.update(get_git_commit())
|
||||
@@ -268,8 +302,9 @@ def bench_moe(quant_mode: str):
|
||||
# 将结果以 JSON 形式追加到文件中
|
||||
record_results(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 选择需要测试的量化模式
|
||||
# bench_moe("bf16")
|
||||
# bench_moe("int8")
|
||||
bench_moe("int4")
|
||||
bench_moe("int8")
|
||||
# bench_moe("int4")
|
||||
|
||||
@@ -13,7 +13,7 @@ import os, sys, time, json, subprocess, platform
|
||||
|
||||
from tqdm import tqdm
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
@@ -37,14 +37,14 @@ physical_to_logical_map = torch.tensor(
|
||||
).contiguous()
|
||||
# 将 CPUInfer 参数设为变量
|
||||
# CPUINFER_PARAM = 257
|
||||
# CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
|
||||
# CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
|
||||
|
||||
worker_config = cpuinfer_ext.WorkerPoolConfig()
|
||||
worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||
worker_config.subpool_count = 2
|
||||
worker_config.subpool_numa_map= [0,1]
|
||||
worker_config.subpool_thread_count = [40,40]
|
||||
CPUINFER_PARAM = 80
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
|
||||
|
||||
|
||||
|
||||
@@ -163,7 +163,7 @@ def bench_moe(quant_mode: str):
|
||||
gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
|
||||
up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
|
||||
down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda").to("cpu").contiguous()
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
expert_num, num_experts_per_tok, hidden_size, intermediate_size,0)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
@@ -171,17 +171,17 @@ def bench_moe(quant_mode: str):
|
||||
config.down_proj = down_proj.data_ptr()
|
||||
config.pool = CPUInfer.backend_
|
||||
if quant_mode == "bf16":
|
||||
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
|
||||
elif quant_mode == "int8":
|
||||
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
|
||||
elif quant_mode == "int4":
|
||||
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
|
||||
elif quant_mode == "int4_1k":
|
||||
config.quant_config.bits = 4
|
||||
config.quant_config.group_size = k_group_size
|
||||
config.quant_config.zero_point = True
|
||||
config.gate_scale = 0
|
||||
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
gate_projs.append(gate_proj)
|
||||
|
||||
331
kt-kernel/bench/bench_moe_kernel.py
Normal file
331
kt-kernel/bench/bench_moe_kernel.py
Normal file
@@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
"""
|
||||
Description :
|
||||
Author : chenht2022
|
||||
Date : 2024-07-25 10:32:05
|
||||
Version : 1.0.0
|
||||
LastEditors : chenht2022
|
||||
LastEditTime : 2024-08-06 10:41:28
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
"""
|
||||
import os, sys, time, json, subprocess, platform
|
||||
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
import torch
|
||||
import kt_kernel_ext
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# 测试参数设置
|
||||
expert_num = 256
|
||||
hidden_size = 7168
|
||||
intermediate_size = 2048
|
||||
max_len = 51200
|
||||
num_experts_per_tok = 8
|
||||
layer_num = 1
|
||||
m_block = 320
|
||||
n_block_up_gate = 32
|
||||
n_block_down = 64
|
||||
n_block_up_gate_prefi = 32
|
||||
n_block_down_prefi = 64
|
||||
qlen = 2048
|
||||
warm_up_iter = 1000
|
||||
test_iter = 1000
|
||||
|
||||
# 将 CPUInfer 参数设为变量
|
||||
CPUINFER_PARAM = 160
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
|
||||
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
|
||||
|
||||
|
||||
# worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||
# worker_config.subpool_count = 4
|
||||
# worker_config.subpool_numa_map= [0,1,2,3]
|
||||
# worker_config.subpool_thread_count = [36,36,36,36]
|
||||
# worker_config.subpool_thread_count = [39,39,39,39]
|
||||
# CPUINFER_PARAM = 156
|
||||
# CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
|
||||
|
||||
|
||||
def get_git_commit():
|
||||
"""
|
||||
获取当前 git 提交记录(commit hash 和提交信息),
|
||||
并检查是否存在未提交的更改(dirty)
|
||||
"""
|
||||
result = {}
|
||||
try:
|
||||
commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
|
||||
commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
|
||||
result["commit"] = commit
|
||||
result["commit_message"] = commit_msg
|
||||
|
||||
# 检查是否存在未提交的更改
|
||||
dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
|
||||
if dirty_output:
|
||||
result["dirty"] = True
|
||||
result["dirty_files"] = dirty_output.splitlines()
|
||||
else:
|
||||
result["dirty"] = False
|
||||
except Exception as e:
|
||||
result["commit"] = None
|
||||
result["commit_message"] = None
|
||||
result["dirty"] = None
|
||||
result["error"] = str(e)
|
||||
return result
|
||||
|
||||
|
||||
def get_system_info():
|
||||
"""
|
||||
获取系统信息,包括系统名称、CPU 型号、内存大小(GB)、CPU 核数及 socket 数量
|
||||
"""
|
||||
info = {}
|
||||
# 系统名称及主机名
|
||||
uname = platform.uname()
|
||||
info["system_name"] = uname.system # 如 Linux, Windows 等
|
||||
info["node_name"] = uname.node # 主机名称
|
||||
|
||||
# 获取 CPU 型号(仅 Linux 支持)
|
||||
cpu_model = None
|
||||
if os.path.exists("/proc/cpuinfo"):
|
||||
try:
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
for line in f:
|
||||
if "model name" in line:
|
||||
cpu_model = line.split(":", 1)[1].strip()
|
||||
break
|
||||
except Exception as e:
|
||||
cpu_model = f"Error: {e}"
|
||||
info["cpu_model"] = cpu_model
|
||||
|
||||
# 获取内存大小(单位:GB),仅 Linux 支持
|
||||
mem_total_gb = None
|
||||
if os.path.exists("/proc/meminfo"):
|
||||
try:
|
||||
with open("/proc/meminfo", "r") as f:
|
||||
for line in f:
|
||||
if "MemTotal" in line:
|
||||
mem_kb = float(line.split(":", 1)[1].split()[0])
|
||||
mem_total_gb = round(mem_kb / (1024 * 1024), 2)
|
||||
break
|
||||
except Exception as e:
|
||||
mem_total_gb = f"Error: {e}"
|
||||
info["memory_size_GB"] = mem_total_gb
|
||||
|
||||
# 获取 CPU 核数(逻辑核数)
|
||||
info["cpu_core_count"] = os.cpu_count()
|
||||
|
||||
# 解析 /proc/cpuinfo 获取 socket 数量
|
||||
sockets = set()
|
||||
if os.path.exists("/proc/cpuinfo"):
|
||||
try:
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
for line in f:
|
||||
if "physical id" in line:
|
||||
sockets.add(line.split(":", 1)[1].strip())
|
||||
except Exception as e:
|
||||
sockets = set()
|
||||
# 如果没有解析到 socket 信息,则默认至少有 1 个 socket
|
||||
info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
|
||||
|
||||
return info
|
||||
|
||||
|
||||
script_path = os.path.abspath(__file__)
|
||||
script_dir = os.path.dirname(script_path)
|
||||
script_name = os.path.splitext(os.path.basename(script_path))[0]
|
||||
json_path = os.path.join(script_dir, "bench_results " + ".jsonl")
|
||||
|
||||
|
||||
def record_results(result, filename=json_path):
|
||||
"""
|
||||
将结果以 JSON 格式追加到文件中
|
||||
"""
|
||||
with open(filename, "a") as f:
|
||||
f.write(json.dumps(result) + "\n")
|
||||
|
||||
|
||||
def bench_moe(quant_mode: str):
|
||||
with torch.inference_mode():
|
||||
if quant_mode == "int8":
|
||||
bytes_per_elem = 1.0
|
||||
elif quant_mode == "int4":
|
||||
bytes_per_elem = 0.5
|
||||
else:
|
||||
raise ValueError("不支持的量化模式")
|
||||
|
||||
moes = []
|
||||
gate_projs = []
|
||||
up_projs = []
|
||||
down_projs = []
|
||||
for layer_index in range(layer_num):
|
||||
gate_proj = (
|
||||
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
up_proj = (
|
||||
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cuda")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
down_proj = (
|
||||
torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cuda")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
config.up_proj = up_proj.data_ptr()
|
||||
config.down_proj = down_proj.data_ptr()
|
||||
config.pool = CPUInfer.backend_
|
||||
if quant_mode == "int8":
|
||||
d = kt_kernel_ext.moe.tiling.get_int8()
|
||||
nbug_prefi = n_block_up_gate_prefi
|
||||
nbd_prefi = n_block_down_prefi
|
||||
kb = d["k_block"]
|
||||
nb = d["n_block"]
|
||||
mb = m_block
|
||||
nbug = n_block_up_gate
|
||||
nbd = n_block_down
|
||||
print(
|
||||
f"Int8 Tiling: nbug {nbug}, nbd {nbd}, nb {nb}, mb {mb}, kb {kb}, nbug_prefi {nbug_prefi}, nbd_prefi {nbd_prefi}"
|
||||
)
|
||||
kt_kernel_ext.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
|
||||
moe = kt_kernel_ext.moe.Int8_KERNEL_MOE(config)
|
||||
elif quant_mode == "int4":
|
||||
moe = kt_kernel_ext.moe.Int4_KERNEL_MOE(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
gate_projs.append(gate_proj)
|
||||
up_projs.append(up_proj)
|
||||
down_projs.append(down_proj)
|
||||
moes.append(moe)
|
||||
|
||||
expert_ids = (
|
||||
torch.rand(test_iter * qlen, expert_num, device="cuda")
|
||||
.argsort(dim=-1)[:, :num_experts_per_tok]
|
||||
.reshape(test_iter, qlen * num_experts_per_tok)
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
weights = (
|
||||
torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cuda")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
input_tensor = (
|
||||
torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
|
||||
)
|
||||
output_tensor = (
|
||||
torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()
|
||||
)
|
||||
bsz_tensor = torch.tensor([qlen], device="cuda").to("cpu").contiguous()
|
||||
|
||||
# 预热迭代
|
||||
for i in tqdm(range(warm_up_iter), desc="Warm-up"):
|
||||
# print(f'warmup iteration {i}')
|
||||
# start_it = time.time_ns()
|
||||
CPUInfer.submit(
|
||||
moes[i % layer_num].forward_task(
|
||||
bsz_tensor.data_ptr(),
|
||||
num_experts_per_tok,
|
||||
expert_ids[i].data_ptr(),
|
||||
weights[i].data_ptr(),
|
||||
input_tensor[i % layer_num].data_ptr(),
|
||||
output_tensor[i % layer_num].data_ptr(),
|
||||
# False,
|
||||
)
|
||||
)
|
||||
CPUInfer.sync()
|
||||
# end_it = time.time_ns()
|
||||
# print('python Time(ns): ', end_it - start_it)
|
||||
|
||||
# 测试迭代
|
||||
start = time.perf_counter()
|
||||
for i in tqdm(range(test_iter), desc="Testing"):
|
||||
# print(f'test iteration {i}')
|
||||
# start_it = time.time_ns()
|
||||
CPUInfer.submit(
|
||||
moes[i % layer_num].forward_task(
|
||||
bsz_tensor.data_ptr(),
|
||||
num_experts_per_tok,
|
||||
expert_ids[i].data_ptr(),
|
||||
weights[i].data_ptr(),
|
||||
input_tensor[i % layer_num].data_ptr(),
|
||||
output_tensor[i % layer_num].data_ptr(),
|
||||
False,
|
||||
)
|
||||
)
|
||||
CPUInfer.sync()
|
||||
# end_it = time.time_ns()
|
||||
# print('python Time(ns): ', end_it - start_it)
|
||||
end = time.perf_counter()
|
||||
total_time = end - start
|
||||
|
||||
# 计算性能指标
|
||||
time_per_iter_us = total_time / test_iter * 1e6
|
||||
bandwidth = (
|
||||
hidden_size
|
||||
* intermediate_size
|
||||
* 3
|
||||
* num_experts_per_tok
|
||||
# * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
|
||||
* qlen
|
||||
* bytes_per_elem
|
||||
* test_iter
|
||||
/ total_time
|
||||
/ 1e9
|
||||
) # 单位:GB/s
|
||||
flops = (
|
||||
hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12
|
||||
) # 单位:TFLOPS
|
||||
|
||||
print("Quant mode: ", quant_mode)
|
||||
print("Time(s): ", total_time)
|
||||
print("Iteration: ", test_iter)
|
||||
print("Time(us) per iteration: ", time_per_iter_us)
|
||||
print("Bandwidth: ", bandwidth, "GB/s")
|
||||
print("Flops: ", flops, "TFLOPS")
|
||||
print("")
|
||||
|
||||
# 整理结果记录,包括测试参数
|
||||
result = {
|
||||
"test_name": os.path.basename(__file__),
|
||||
"quant_mode": quant_mode,
|
||||
"total_time_seconds": total_time,
|
||||
"iterations": test_iter,
|
||||
"time_per_iteration_us": time_per_iter_us,
|
||||
"bandwidth_GBs": bandwidth,
|
||||
"flops_TFLOPS": flops,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
"test_parameters": {
|
||||
"expert_num": expert_num,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size": intermediate_size,
|
||||
"max_len": max_len,
|
||||
"num_experts_per_tok": num_experts_per_tok,
|
||||
"layer_num": layer_num,
|
||||
"qlen": qlen,
|
||||
"warm_up_iter": warm_up_iter,
|
||||
"test_iter": test_iter,
|
||||
"CPUInfer_parameter": CPUINFER_PARAM,
|
||||
},
|
||||
}
|
||||
# 添加 git 提交记录信息
|
||||
result.update(get_git_commit())
|
||||
# 添加系统信息(包括 CPU 核数和 socket 数量)
|
||||
result.update(get_system_info())
|
||||
# 将结果以 JSON 形式追加到文件中
|
||||
record_results(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 选择需要测试的量化模式
|
||||
bench_moe("int8")
|
||||
# bench_moe("int4")
|
||||
232
kt-kernel/bench/bench_moe_kernel_tiling.py
Normal file
232
kt-kernel/bench/bench_moe_kernel_tiling.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
"""
|
||||
Bench MOE kernel with runtime tiling params (N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK)
|
||||
- Demonstrates how to get/set tiling params from Python via kt_kernel_ext.moe.tiling
|
||||
- Runs a small benchmark similar to bench_moe_kernel.py
|
||||
|
||||
Usage examples:
|
||||
# 1) Just run with defaults (int8)
|
||||
python bench_moe_kernel_tiling.py --quant int8
|
||||
|
||||
# 2) Override tiling params for INT8
|
||||
python bench_moe_kernel_tiling.py --quant int8 \
|
||||
--n_block_up_gate 32 --n_block_down 64 --n_block 64 --m_block 320 --k_block 7168
|
||||
|
||||
# 3) Set both INT8 and INT4 tiling params (if INT4 kernel is available on your platform)
|
||||
python bench_moe_kernel_tiling.py --quant int4 --set_all \
|
||||
--n_block_up_gate 256 --n_block_down 1024 --n_block 64 --m_block 320 --k_block 7168
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
|
||||
os.environ.setdefault("BLAS_NUM_THREADS", "1")
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
|
||||
import torch # noqa: E402
|
||||
import kt_kernel_ext as ce # noqa: E402
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
|
||||
def maybe_get_class(module, name):
|
||||
return getattr(module, name) if hasattr(module, name) else None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--quant", choices=["int8", "int4"], default="int8")
|
||||
parser.add_argument("--expert_num", type=int, default=256)
|
||||
parser.add_argument("--hidden_size", type=int, default=7168)
|
||||
parser.add_argument("--intermediate_size", type=int, default=2048)
|
||||
parser.add_argument("--num_experts_per_tok", type=int, default=8)
|
||||
parser.add_argument("--max_len", type=int, default=25600)
|
||||
parser.add_argument("--layer_num", type=int, default=1)
|
||||
parser.add_argument("--qlen", type=int, default=1024)
|
||||
parser.add_argument("--warm_up_iter", type=int, default=200)
|
||||
parser.add_argument("--test_iter", type=int, default=500)
|
||||
parser.add_argument("--threads", type=int, default=160, help="CPUInfer initialization param")
|
||||
|
||||
# Tiling params
|
||||
parser.add_argument("--set_all", action="store_true", help="Apply tiling to both INT8 and INT4 kernels")
|
||||
parser.add_argument("--n_block_up_gate", type=int, default=None)
|
||||
parser.add_argument("--n_block_down", type=int, default=None)
|
||||
parser.add_argument("--n_block", type=int, default=None)
|
||||
parser.add_argument("--m_block", type=int, default=None)
|
||||
parser.add_argument("--k_block", type=int, default=None)
|
||||
parser.add_argument("--n_block_up_gate_prefi", type=int, default=None)
|
||||
parser.add_argument("--n_block_down_prefi", type=int, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Show current tiling defaults
|
||||
if args.quant == "int8":
|
||||
print("[tiling] default int8:", ce.moe.tiling.get_int8())
|
||||
if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
|
||||
print("[tiling] default int4:", ce.moe.tiling.get_int4())
|
||||
|
||||
# Apply overrides if provided
|
||||
if any(v is not None for v in [args.n_block_up_gate, args.n_block_down, args.n_block, args.m_block, args.k_block]):
|
||||
# Fill missing values with current defaults to avoid overwriting unrelated params
|
||||
def fill_defaults(getter):
|
||||
cur = getter()
|
||||
return (
|
||||
args.n_block_up_gate if args.n_block_up_gate is not None else int(cur["n_block_up_gate"]),
|
||||
args.n_block_down if args.n_block_down is not None else int(cur["n_block_down"]),
|
||||
args.n_block if args.n_block is not None else int(cur["n_block"]),
|
||||
args.m_block if args.m_block is not None else int(cur["m_block"]),
|
||||
args.k_block if args.k_block is not None else int(cur["k_block"]),
|
||||
(
|
||||
args.n_block_up_gate_prefi
|
||||
if args.n_block_up_gate_prefi is not None
|
||||
else int(cur["n_block_up_gate_prefi"])
|
||||
),
|
||||
args.n_block_down_prefi if args.n_block_down_prefi is not None else int(cur["n_block_down_prefi"]),
|
||||
)
|
||||
|
||||
if args.set_all and hasattr(ce.moe.tiling, "set_all"):
|
||||
nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
|
||||
ce.moe.tiling.set_all(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
|
||||
print("[tiling] set_all ->", ce.moe.tiling.get_int8())
|
||||
if hasattr(ce.moe.tiling, "get_int4"):
|
||||
print("[tiling] set_all -> int4:", ce.moe.tiling.get_int4())
|
||||
else:
|
||||
if args.quant == "int8":
|
||||
nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
|
||||
ce.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
|
||||
print("[tiling] set_int8 ->", ce.moe.tiling.get_int8())
|
||||
elif args.quant == "int4" and hasattr(ce.moe.tiling, "set_int4"):
|
||||
nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int4)
|
||||
ce.moe.tiling.set_int4(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
|
||||
print("[tiling] set_int4 ->", ce.moe.tiling.get_int4())
|
||||
|
||||
# Warn about divisibility expectations; kernels assume specific blocking
|
||||
# - Some helpers assert n % N_BLOCK == 0, etc. Ensure your dims/tiles align.
|
||||
print("[note] Ensure your selected tiling parameters are compatible with hidden/intermediate sizes and blocking.")
|
||||
|
||||
# Initialize CPUInfer
|
||||
CPUInfer = ce.CPUInfer(args.threads)
|
||||
|
||||
# Select MOE kernel
|
||||
moe_cls = None
|
||||
if args.quant == "int8":
|
||||
moe_cls = maybe_get_class(ce.moe, "Int8_KERNEL_MOE")
|
||||
if moe_cls is None:
|
||||
raise RuntimeError("Int8 kernel binding 'Int8_KERNEL_MOE' not found.")
|
||||
bytes_per_elem = 1.0
|
||||
else:
|
||||
moe_cls = maybe_get_class(ce.moe, "Int4_KERNEL_MOE")
|
||||
if moe_cls is None:
|
||||
raise RuntimeError("Int4 kernel binding 'Int4_KERNEL_MOE' not available on this platform.")
|
||||
bytes_per_elem = 0.5
|
||||
|
||||
# Prepare config/weights
|
||||
expert_num = args.expert_num
|
||||
hidden_size = args.hidden_size
|
||||
intermediate_size = args.intermediate_size
|
||||
num_experts_per_tok = args.num_experts_per_tok
|
||||
layer_num = args.layer_num
|
||||
max_len = args.max_len
|
||||
|
||||
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
|
||||
|
||||
moes = []
|
||||
gate_projs, up_projs, down_projs = [], [], []
|
||||
|
||||
for layer_idx in range(layer_num):
|
||||
gate_proj = torch.randn(
|
||||
(expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
|
||||
).contiguous()
|
||||
up_proj = torch.randn(
|
||||
(expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
|
||||
).contiguous()
|
||||
down_proj = torch.randn(
|
||||
(expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cpu"
|
||||
).contiguous()
|
||||
|
||||
cfg = ce.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
cfg.max_len = max_len
|
||||
cfg.gate_proj = gate_proj.data_ptr()
|
||||
cfg.up_proj = up_proj.data_ptr()
|
||||
cfg.down_proj = down_proj.data_ptr()
|
||||
cfg.pool = CPUInfer.backend_
|
||||
|
||||
moe = moe_cls(cfg)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
|
||||
gate_projs.append(gate_proj)
|
||||
up_projs.append(up_proj)
|
||||
down_projs.append(down_proj)
|
||||
moes.append(moe)
|
||||
|
||||
qlen = args.qlen
|
||||
warm_up_iter = args.warm_up_iter
|
||||
test_iter = args.test_iter
|
||||
|
||||
expert_ids = (
|
||||
torch.rand(test_iter * qlen, expert_num)
|
||||
.argsort(dim=-1)[:, :num_experts_per_tok]
|
||||
.reshape(test_iter, qlen * num_experts_per_tok)
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
weights = torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32).to("cpu").contiguous()
|
||||
input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
|
||||
output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
|
||||
bsz_tensor = torch.tensor([qlen], dtype=torch.int32).to("cpu").contiguous()
|
||||
|
||||
# Warmup
|
||||
for i in tqdm(range(warm_up_iter), desc="Warm-up"):
|
||||
CPUInfer.submit(
|
||||
moes[i % layer_num].forward_task(
|
||||
bsz_tensor.data_ptr(),
|
||||
num_experts_per_tok,
|
||||
expert_ids[i].data_ptr(),
|
||||
weights[i].data_ptr(),
|
||||
input_tensor[i % layer_num].data_ptr(),
|
||||
output_tensor[i % layer_num].data_ptr(),
|
||||
)
|
||||
)
|
||||
CPUInfer.sync()
|
||||
|
||||
# Measure
|
||||
start = time.perf_counter()
|
||||
for i in tqdm(range(test_iter), desc="Testing"):
|
||||
CPUInfer.submit(
|
||||
moes[i % layer_num].forward_task(
|
||||
bsz_tensor.data_ptr(),
|
||||
num_experts_per_tok,
|
||||
expert_ids[i].data_ptr(),
|
||||
weights[i].data_ptr(),
|
||||
input_tensor[i % layer_num].data_ptr(),
|
||||
output_tensor[i % layer_num].data_ptr(),
|
||||
False,
|
||||
)
|
||||
)
|
||||
CPUInfer.sync()
|
||||
end = time.perf_counter()
|
||||
|
||||
total_time = end - start
|
||||
time_per_iter_us = total_time / test_iter * 1e6
|
||||
bandwidth_gbs = (
|
||||
hidden_size * intermediate_size * 3 * num_experts_per_tok * qlen * bytes_per_elem * test_iter / total_time / 1e9
|
||||
)
|
||||
flops_tflops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12
|
||||
|
||||
print("\n=== Results ===")
|
||||
print("quant:", args.quant)
|
||||
if hasattr(ce.moe.tiling, "get_int8") and args.quant == "int8":
|
||||
print("tiling int8:", ce.moe.tiling.get_int8())
|
||||
if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
|
||||
print("tiling int4:", ce.moe.tiling.get_int4())
|
||||
print("time (s):", total_time)
|
||||
print("iter:", test_iter)
|
||||
print("time per iter (us):", time_per_iter_us)
|
||||
print("bandwidth (GB/s):", bandwidth_gbs)
|
||||
print("TFLOPS:", flops_tflops)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -13,7 +13,7 @@ import os, sys, time, json, subprocess, platform
|
||||
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
@@ -33,15 +33,15 @@ test_iter = 10000
|
||||
|
||||
# 将 CPUInfer 参数设为变量
|
||||
CPUINFER_PARAM = 112
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(CPUINFER_PARAM)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
|
||||
|
||||
# worker_config = cpuinfer_ext.WorkerPoolConfig()
|
||||
# worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||
# worker_config.subpool_count = 4
|
||||
# worker_config.subpool_numa_map= [0,1,2,3]
|
||||
# worker_config.subpool_thread_count = [36,36,36,36]
|
||||
# worker_config.subpool_thread_count = [39,39,39,39]
|
||||
# CPUINFER_PARAM = 156
|
||||
# CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
|
||||
# CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
|
||||
|
||||
|
||||
def get_git_commit():
|
||||
@@ -166,16 +166,16 @@ def bench_moe(quant_mode: str):
|
||||
down_proj = torch.randn(
|
||||
(expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cpu"
|
||||
).contiguous()
|
||||
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
config.up_proj = up_proj.data_ptr()
|
||||
config.down_proj = down_proj.data_ptr()
|
||||
config.pool = CPUInfer.backend_
|
||||
if quant_mode == "int8":
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
elif quant_mode == "int4":
|
||||
moe = cpuinfer_ext.moe.KMLInt4_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt4_MOE(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
|
||||
@@ -50,12 +50,12 @@ import torch
|
||||
|
||||
# Try importing both implementations
|
||||
try:
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
KTRANSFORMERS_AVAILABLE = True
|
||||
logger.info("KTransformers cpuinfer_ext loaded successfully")
|
||||
logger.info("KTransformers kt_kernel_ext loaded successfully")
|
||||
except ImportError as e:
|
||||
KTRANSFORMERS_AVAILABLE = False
|
||||
logger.warning(f"KTransformers cpuinfer_ext not available: {e}")
|
||||
logger.warning(f"KTransformers kt_kernel_ext not available: {e}")
|
||||
|
||||
try:
|
||||
from sgl_kernel.common_ops import fused_experts_cpu
|
||||
@@ -472,11 +472,11 @@ def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
# Setup worker config with consistent threads per NUMA
|
||||
worker_config = cpuinfer_ext.WorkerPoolConfig()
|
||||
worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||
worker_config.subpool_count = sys_config.numa_count
|
||||
worker_config.subpool_numa_map = list(range(sys_config.numa_count))
|
||||
worker_config.subpool_thread_count = [thread_config.threads_per_numa] * sys_config.numa_count
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
|
||||
|
||||
# Create MoE layers
|
||||
moes = []
|
||||
@@ -493,7 +493,7 @@ def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
|
||||
down_proj = torch.randn((test_config.expert_num, test_config.hidden_size, test_config.intermediate_size),
|
||||
dtype=torch.float32).contiguous()
|
||||
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
test_config.expert_num, test_config.num_experts_per_tok,
|
||||
test_config.hidden_size, test_config.intermediate_size)
|
||||
config.max_len = test_config.max_len
|
||||
@@ -503,11 +503,11 @@ def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
|
||||
config.pool = CPUInfer.backend_
|
||||
|
||||
if quant_mode == "bf16":
|
||||
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
|
||||
elif quant_mode == "int8":
|
||||
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
|
||||
elif quant_mode == "int4":
|
||||
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ def update_bench_parameters(params):
|
||||
bench.test_iter = params["test_iter"]
|
||||
bench.CPUINFER_PARAM = params["CPUINFER_PARAM"]
|
||||
# 重新初始化 CPUInfer 对象
|
||||
bench.CPUInfer = bench.cpuinfer_ext.CPUInfer(bench.CPUINFER_PARAM)
|
||||
bench.CPUInfer = bench.kt_kernel_ext.CPUInfer(bench.CPUINFER_PARAM)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# CFLAGS += -march=armv8.2-a+fp16+dotprod+sve+bf16 -I/home/test/kt-code/HPCKit_25.0.0_Linux-aarch64/package/KunpengHPCKit-kml.25.0.0/include
|
||||
# CFLAGS += -march=armv8.2-a+fp16+dotprod+sve+bf16 -I/home/test/kt-code/HPCKit_25.0.0_Linux-aarch64/package/KunpengHPCKit-kml.25.0.0/include
|
||||
CFLAGS += -O3
|
||||
CFLAGS += -I/usr/local/include/blis/
|
||||
CFLAGS += -I/usr/local/include/blis/ -fopenmp
|
||||
LDLIBS += -L/usr/local/lib -lblis
|
||||
# LDLIBS += $(shell pkg-config --libs hwloc) -lkml_rt
|
||||
|
||||
CXX = g++
|
||||
CXX = /usr/bin/g++
|
||||
|
||||
# i8_cal: i8_cal.cpp
|
||||
# $(CXX) i8_cal.cpp $(CFLAGS) -o i8_cal $(LDLIBS)
|
||||
@@ -17,9 +17,8 @@ simple_test_build: simple_test.cpp
|
||||
rm -f simple_test
|
||||
BLAS_NUM_THREADS=1 $(CXX) simple_test.cpp $(CFLAGS) -o simple_test $(LDLIBS)
|
||||
|
||||
simple_aocl_build: simple_test_aocl.cpp
|
||||
rm -f simple_test_aocl
|
||||
BLAS_NUM_THREADS=1 $(CXX) simple_test_aocl.cpp $(CFLAGS) -o simple_test_aocl $(LDLIBS)
|
||||
simple_aocl_build: build simple_test_aocl.cpp
|
||||
$(CXX) simple_test_aocl.cpp $(CFLAGS) -o build/simple_test_aocl $(LDLIBS)
|
||||
|
||||
fp16_test_build: fp16-test.cpp
|
||||
rm -f fp16-test
|
||||
@@ -27,6 +26,11 @@ fp16_test_build: fp16-test.cpp
|
||||
bf16_test_build: bf16-test.cpp
|
||||
rm -f bf16-test
|
||||
$(CXX) bf16-test.cpp $(CFLAGS) -o bf16-test $(LDLIBS)
|
||||
|
||||
build: build
|
||||
mkdir -p build
|
||||
bandwidth_build: bench_reorder_bandwidth.cpp
|
||||
$(CXX) bench_reorder_bandwidth.cpp $(CFLAGS) -o build/bench_reorder_bandwidth $(LDLIBS)
|
||||
run: simple_aocl_build
|
||||
LD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH ./simple_test_aocl
|
||||
LD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH ./build/simple_test_aocl
|
||||
run_bandwidth: bandwidth_build
|
||||
LD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH ./build/bench_reorder_bandwidth
|
||||
110
kt-kernel/demo/bench_reorder_bandwidth.cpp
Normal file
110
kt-kernel/demo/bench_reorder_bandwidth.cpp
Normal file
@@ -0,0 +1,110 @@
|
||||
#include <blis.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
namespace {
|
||||
constexpr int kM = 1;
|
||||
constexpr int kK = 7168;
|
||||
constexpr int kN = 512;
|
||||
constexpr int kIters = 10000;
|
||||
|
||||
void fill_random(int8_t* ptr, size_t count) {
|
||||
std::srand(47);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
ptr[i] = static_cast<int8_t>(std::rand() % 30);
|
||||
}
|
||||
}
|
||||
|
||||
void fill_zero(int32_t* ptr, size_t count) { std::memset(ptr, 0, count * sizeof(int32_t)); }
|
||||
|
||||
bool verify(const int8_t* a, const int8_t* b, const int32_t* c) {
|
||||
for (int m = 0; m < kM; ++m) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
int32_t ref = 0;
|
||||
for (int k = 0; k < kK; ++k) {
|
||||
ref += static_cast<int32_t>(a[m * kK + k]) * static_cast<int32_t>(b[n * kK + k]);
|
||||
}
|
||||
if (ref != c[m * kN + n]) {
|
||||
std::printf("Mismatch at (%d, %d): got %d, expect %d\n", m, n, c[m * kN + n], ref);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int main() {
|
||||
int8_t* a = static_cast<int8_t*>(std::aligned_alloc(64, kM * kK));
|
||||
int8_t* b = static_cast<int8_t*>(std::aligned_alloc(64, kK * kN));
|
||||
int32_t* c = static_cast<int32_t*>(std::aligned_alloc(64, kM * kN * sizeof(int32_t)));
|
||||
int32_t* c_tmp = static_cast<int32_t*>(std::aligned_alloc(64, kM * kN * sizeof(int32_t)));
|
||||
|
||||
if (!a || !b || !c || !c_tmp) {
|
||||
std::fprintf(stderr, "Allocation failed.\n");
|
||||
std::free(a);
|
||||
std::free(b);
|
||||
std::free(c);
|
||||
std::free(c_tmp);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
fill_random(a, kM * kK);
|
||||
fill_random(b, kK * kN);
|
||||
fill_zero(c, kM * kN);
|
||||
fill_zero(c_tmp, kM * kN);
|
||||
|
||||
const dim_t reorder_size = aocl_get_reorder_buf_size_s8s8s32os32('r', 't', 'B', kK, kN);
|
||||
int8_t* b_reordered = static_cast<int8_t*>(std::aligned_alloc(64, reorder_size));
|
||||
if (!b_reordered) {
|
||||
std::fprintf(stderr, "Reorder buffer allocation failed.\n");
|
||||
std::free(a);
|
||||
std::free(b);
|
||||
std::free(c);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
aocl_reorder_s8s8s32os32('r', 't', 'B', b, b_reordered, kK, kN, kK);
|
||||
|
||||
// Warm-up GEMM to load kernels.
|
||||
aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, 1, a, kK, 'n', b_reordered, kK, 'r', 0, c_tmp, kN, nullptr);
|
||||
fill_zero(c, kM * kN);
|
||||
|
||||
const double bytes_per_mul = static_cast<double>(kM) * kK * sizeof(int8_t) + // A matrix read
|
||||
static_cast<double>(kK) * kN * sizeof(int8_t); // original B read
|
||||
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
for (int iter = 0; iter < kIters; ++iter) {
|
||||
aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, 1, a, kK, 'n', b_reordered, kK, 'r', 0, c, kN, nullptr);
|
||||
}
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
const double elapsed_seconds = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
|
||||
const double total_bytes = bytes_per_mul * kIters;
|
||||
const double bandwidth_gbps = total_bytes / elapsed_seconds / 1e9;
|
||||
const double ops_per_mul = static_cast<double>(kM) * kN * kK * 2.0;
|
||||
const double tflops = (ops_per_mul * kIters) / elapsed_seconds / 1e12;
|
||||
|
||||
std::printf("Reorder buffer size: %ld bytes\n", static_cast<long>(reorder_size));
|
||||
std::printf("Iterations: %d\n", kIters);
|
||||
std::printf("Elapsed time: %.4f s\n", elapsed_seconds);
|
||||
std::printf("Effective bandwidth: %.2f GB/s\n", bandwidth_gbps);
|
||||
std::printf("Int8 GEMM throughput: %.2f TOPS\n", tflops * 1e3);
|
||||
|
||||
if (!verify(a, b, c)) {
|
||||
std::fprintf(stderr, "Verification failed.\n");
|
||||
} else {
|
||||
std::puts("Verification passed.");
|
||||
}
|
||||
|
||||
std::free(a);
|
||||
std::free(b);
|
||||
std::free(b_reordered);
|
||||
std::free(c);
|
||||
return 0;
|
||||
}
|
||||
@@ -1,119 +1,173 @@
|
||||
#include <blis.h>
|
||||
#include <dlfcn.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
// #define CHECK
|
||||
namespace {
|
||||
// B matrix is in col-major order
|
||||
constexpr int kM = 3;
|
||||
constexpr int kK = 7168;
|
||||
constexpr int kN = 2048;
|
||||
void fill_inputs(int8_t* a, int8_t* b) {
|
||||
srand(static_cast<unsigned>(time(nullptr)));
|
||||
for (int i = 0; i < kM * kK; ++i) {
|
||||
a[i] = static_cast<int8_t>(rand() % 127);
|
||||
}
|
||||
for (int i = 0; i < kK * kN; ++i) {
|
||||
b[i] = static_cast<int8_t>(rand() % 127);
|
||||
}
|
||||
}
|
||||
|
||||
void compute_reference(const int8_t* a, const int8_t* b, int32_t* ref) {
|
||||
for (int m = 0; m < kM; ++m) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
int32_t acc = 0;
|
||||
for (int k = 0; k < kK; ++k) {
|
||||
acc += static_cast<int32_t>(a[m * kK + k]) * static_cast<int32_t>(b[k * kN + n]);
|
||||
}
|
||||
ref[m * kN + n] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool check_result(const int32_t* got, const int32_t* ref) {
|
||||
for (int idx = 0; idx < kM * kN; ++idx) {
|
||||
if (got[idx] != ref[idx]) {
|
||||
std::printf("Mismatch at %d: got %d, expected %d\n", idx, got[idx], ref[idx]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int main() {
|
||||
// 矩阵维度 M 是 1024,K 是 1024,N 是 1024(行主序)
|
||||
int M = 1024; // 行主序时,A 的行长度为 K
|
||||
const int K = 1024; // B 的行长度为 N
|
||||
const int N = 1024; // C 的行长度为 N
|
||||
const int iter = 10000; // 迭代次数
|
||||
err_t err = BLIS_SUCCESS;
|
||||
int8_t* a = static_cast<int8_t*>(bli_malloc_user(kM * kK, &err));
|
||||
int8_t* b = static_cast<int8_t*>(bli_malloc_user(kK * kN, &err));
|
||||
int8_t* b_rowmajor = static_cast<int8_t*>(bli_malloc_user(kK * kN, &err));
|
||||
int8_t* b_reordered = nullptr;
|
||||
int32_t* c = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));
|
||||
int32_t* c_unp = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));
|
||||
int32_t* ref = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));
|
||||
|
||||
// 分配矩阵内存
|
||||
int8_t* A = (int8_t*)malloc(M * K * sizeof(int8_t));
|
||||
int8_t* B = (int8_t*)malloc(K * N * sizeof(int8_t));
|
||||
int32_t* C = (int32_t*)malloc(M * N * sizeof(int32_t));
|
||||
|
||||
// 初始化随机种子
|
||||
srand((unsigned)time(NULL));
|
||||
|
||||
// 随机初始化 A(范围 0 到 255)和 B(范围 -128 到 127)
|
||||
// 初始化矩阵 A 和 B
|
||||
for (int j = 0; j < M * K; j++) {
|
||||
// A[j] = rand() % 256;
|
||||
A[j] = j;
|
||||
}
|
||||
for (int j = 0; j < K * N; j++) {
|
||||
// B[j] = rand() % 256;
|
||||
B[j] = j;
|
||||
}
|
||||
// 初始化矩阵 C
|
||||
for (int j = 0; j < M * N; j++) {
|
||||
C[j] = 0;
|
||||
if (!a || !b || !c || !ref || !c_unp) {
|
||||
std::fprintf(stderr, "Allocation failed\n");
|
||||
bli_free_user(a);
|
||||
bli_free_user(b);
|
||||
bli_free_user(c);
|
||||
bli_free_user(ref);
|
||||
bli_free_user(c_unp);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// 设置 cblas_gemm_s8u8s32 的参数
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
int8_t oa = 0, ob = 0;
|
||||
int32_t oc = 0;
|
||||
|
||||
// 打印矩阵 A、B
|
||||
// printf("A=\n");
|
||||
// for (int i = 0; i < M; i++) {
|
||||
// for (int j = 0; j < K; j++) {
|
||||
// printf("%d ", A[i * K + j]);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// printf("B=\n");
|
||||
// for (int i = 0; i < N; i++) {
|
||||
// for (int j = 0; j < K; j++) {
|
||||
// printf("%d ", B[i * K + j]);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
// printf("format: 'generate end'\n");
|
||||
// 调用 cblas_gemm_s8u8s32 执行矩阵乘法:C = i1(A+ao)(B+bo) + 0*C + oc
|
||||
// 从m=10~256 都测一遍速度,步长是 stride
|
||||
int stride = 2;
|
||||
int start_m = M;
|
||||
for (int m = start_m; m <= M; m += stride) {
|
||||
// 记录开始时间
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
#pragma GCC unroll 8
|
||||
for (int i = 0; i < iter; i++) {
|
||||
// cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N / 2, K, alpha, A, K, oa, B, K,
|
||||
// ob, beta, C, N, &oc);
|
||||
aocl_gemm_s8s8s32os32('r', 'n', 't', m, N / 2, K, (int32_t)alpha, A, K, 'n', B, K, 'n', (int32_t)beta, C, N,
|
||||
nullptr);
|
||||
int8_t* B_high = B + K * N / 2;
|
||||
int32_t* C_high = C + N / 2;
|
||||
// cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N / 2, K, alpha, A, K, oa,
|
||||
// B_high,
|
||||
// K, ob, beta, C_high, N, &oc);
|
||||
aocl_gemm_s8s8s32os32('r', 'n', 't', m, N / 2, K, (int32_t)alpha, A, K, 'n', B_high, K, 'n', (int32_t)beta,
|
||||
C_high, N, nullptr);
|
||||
fill_inputs(a, b);
|
||||
// transform B from col-major to row-major
|
||||
for (int k = 0; k < kK; ++k) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
// original B is in col-major: b[n * ld + k], here ld = kK
|
||||
int8_t val = b[n * kK + k];
|
||||
// target row-major: row index = k, col index = n
|
||||
b_rowmajor[k * kN + n] = val;
|
||||
}
|
||||
}
|
||||
#ifdef CHECK
|
||||
// CHECK: printf inputs
|
||||
std::puts("\nMatrix A:\n");
|
||||
for (int m = 0; m < kM; ++m) {
|
||||
for (int k = 0; k < kK; ++k) {
|
||||
std::printf("%4d ", a[m * kK + k]);
|
||||
}
|
||||
std::puts("");
|
||||
}
|
||||
std::puts("\nMatrix B:\n");
|
||||
for (int k = 0; k < kK; ++k) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
std::printf("%4d ", b[n * kK + k]);
|
||||
}
|
||||
std::puts("");
|
||||
}
|
||||
#endif
|
||||
std::memset(c, 0, kM * kN * sizeof(int32_t));
|
||||
std::memset(c_unp, 0, kM * kN * sizeof(int32_t));
|
||||
std::memset(ref, 0, kM * kN * sizeof(int32_t));
|
||||
compute_reference(a, b_rowmajor, ref);
|
||||
#ifdef CHECK
|
||||
// CHECK: printf reference
|
||||
std::puts("\nReference result:\n");
|
||||
for (int m = 0; m < kM; ++m) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
std::printf("%6d ", ref[m * kN + n]);
|
||||
}
|
||||
std::puts("");
|
||||
}
|
||||
#endif
|
||||
const dim_t reorder_size = aocl_get_reorder_buf_size_s8s8s32os32('c', 'n', 'B', kK, kN);
|
||||
b_reordered = static_cast<int8_t*>(bli_malloc_user(reorder_size, &err));
|
||||
if (!b_reordered) {
|
||||
std::fprintf(stderr, "Reorder buffer allocation failed\n");
|
||||
bli_free_user(a);
|
||||
bli_free_user(b);
|
||||
bli_free_user(c);
|
||||
bli_free_user(ref);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
aocl_reorder_s8s8s32os32('c', 'n', 'B', b, b_reordered, kK, kN, kK);
|
||||
#ifdef CHECK
|
||||
// CHECK: printf reordered B
|
||||
std::puts("\nReordered Matrix B:\n");
|
||||
for (int k = 0; k < kK; ++k) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
std::printf("%4d ", b_reordered[k * kN + n]);
|
||||
}
|
||||
std::puts("");
|
||||
}
|
||||
std::printf("\nReorder buffer size: %zu bytes\n", reorder_size);
|
||||
#endif
|
||||
|
||||
// 打印结果
|
||||
// printf("result:\n");
|
||||
// for (int i = 0; i < M; i++) {
|
||||
// for (int j = 0; j < N; j++) {
|
||||
// printf("%d ", C[i * N + j]);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
const int32_t alpha = 1;
|
||||
const int32_t beta = 0;
|
||||
aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, alpha, a, kK, 'n', b_reordered, kK, 'r', beta, c, kN, nullptr);
|
||||
aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, alpha, a, kK, 'n', b, kK, 'n', beta, c_unp, kN, nullptr);
|
||||
#ifdef CHECK
|
||||
// CHECK: printf AOCL result
|
||||
std::puts("\nAOCL GEMM result (with reordered B):\n");
|
||||
for (int m = 0; m < kM; ++m) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
std::printf("%6d ", c[m * kN + n]);
|
||||
}
|
||||
std::puts("");
|
||||
}
|
||||
std::puts("\nAOCL GEMM result (without reordered B):\n");
|
||||
for (int m = 0; m < kM; ++m) {
|
||||
for (int n = 0; n < kN; ++n) {
|
||||
std::printf("%6d ", c_unp[m * kN + n]);
|
||||
}
|
||||
std::puts("");
|
||||
}
|
||||
#endif
|
||||
|
||||
// 记录结束时间
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// 计算总时长(秒)
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
|
||||
double time_sec = duration.count() / 1e6; // 转换为秒
|
||||
|
||||
// 计算理论浮点运算次数并转换为 TFLOPS
|
||||
double ops = iter * 2.0 * m * N * K;
|
||||
double tflops = ops / (duration.count() * 1e6); // 转换为 TFLOPS
|
||||
|
||||
// 输出结果
|
||||
printf("execute end,m is:%d\n", m);
|
||||
// printf("执行时间: %.4f 秒\n", time_sec);
|
||||
printf("计算性能: %.4f TFLOPS\n", tflops);
|
||||
printf("\n");
|
||||
if (check_result(c, ref)) {
|
||||
std::puts("AOCL GEMM output matches reference.");
|
||||
} else {
|
||||
std::puts("AOCL GEMM output mismatch detected.");
|
||||
}
|
||||
|
||||
// 释放资源
|
||||
free(A);
|
||||
free(B);
|
||||
free(C);
|
||||
if (check_result(c_unp, ref)) {
|
||||
std::puts("unpack AOCL GEMM output matches reference.");
|
||||
} else {
|
||||
std::puts("unpack AOCL GEMM output mismatch detected.");
|
||||
}
|
||||
|
||||
bli_free_user(a);
|
||||
bli_free_user(b);
|
||||
bli_free_user(b_rowmajor);
|
||||
bli_free_user(b_reordered);
|
||||
bli_free_user(c);
|
||||
bli_free_user(c_unp);
|
||||
bli_free_user(ref);
|
||||
return 0;
|
||||
}
|
||||
@@ -3,15 +3,15 @@ import sys
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import torch
|
||||
import ctypes
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.moe import MOEConfig, MOE, AMXBF16_MOE, AMXInt8_MOE, AMXInt4_MOE, AMXInt4_1_MOE
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.moe import MOEConfig, MOE, AMXBF16_MOE, AMXInt8_MOE, AMXInt4_MOE, AMXInt4_1_MOE
|
||||
|
||||
intermediate_size_full = 2048
|
||||
moe_intermediate_size = 3072
|
||||
hidden_size = 7168
|
||||
experts_num = 256
|
||||
num_experts_per_tok = 8
|
||||
cpu_infer = cpuinfer_ext.CPUInfer(97)
|
||||
cpu_infer = kt_kernel_ext.CPUInfer(97)
|
||||
|
||||
up = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device="cpu")
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import os, sys
|
||||
import time
|
||||
|
||||
sys.path.append(os.path.dirname(__file__) + "/../build")
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
from flash_attn import flash_attn_with_kvcache
|
||||
import torch
|
||||
|
||||
@@ -26,20 +26,20 @@ anchor_num = 1
|
||||
cache_seqlen = 8192
|
||||
cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
|
||||
seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")
|
||||
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
|
||||
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
|
||||
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
|
||||
anchor_type = kt_kernel_ext.kvcache.AnchorType.DYNAMIC
|
||||
kv_type = kt_kernel_ext.kvcache.ggml_type.FP16
|
||||
retrieval_type = kt_kernel_ext.kvcache.RetrievalType.LAYER
|
||||
layer_step: int = 1
|
||||
token_step: int = 1
|
||||
layer_offset: int = 0
|
||||
max_thread_num: int = 2
|
||||
max_batch_size: int = 1
|
||||
max_block_num: int = 512
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(max_thread_num)
|
||||
validation_iter = 100
|
||||
|
||||
with torch.inference_mode(mode=True):
|
||||
config = cpuinfer_ext.kvcache.KVCacheConfig(
|
||||
config = kt_kernel_ext.kvcache.KVCacheConfig(
|
||||
layer_num,
|
||||
kv_head_num,
|
||||
q_head_num,
|
||||
@@ -56,7 +56,7 @@ with torch.inference_mode(mode=True):
|
||||
max_batch_size,
|
||||
max_thread_num,
|
||||
)
|
||||
local_kvcache = cpuinfer_ext.kvcache.KVCache(config)
|
||||
local_kvcache = kt_kernel_ext.kvcache.KVCache(config)
|
||||
|
||||
kvcaches = []
|
||||
block_table = (
|
||||
|
||||
@@ -2,7 +2,7 @@ import os, sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
# Set fixed seed for reproducible results
|
||||
@@ -49,7 +49,7 @@ max_len = 25600
|
||||
num_experts_per_tok = 8
|
||||
qlen = 1
|
||||
layer_num = 1
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(40)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(40)
|
||||
validation_iter = 10
|
||||
k_group_size = 64
|
||||
debug_print_count = 16
|
||||
@@ -302,7 +302,7 @@ def test_online_int4_kgroup_moe():
|
||||
|
||||
for _ in range(layer_num):
|
||||
# Create Int4LowKGroup configuration (online quantization)
|
||||
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
config.up_proj = up_proj.data_ptr()
|
||||
@@ -320,7 +320,7 @@ def test_online_int4_kgroup_moe():
|
||||
config.path = "./awq_dump_online"
|
||||
|
||||
# Create Int4LowKGroup MoE (online quantization during load_weights)
|
||||
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
|
||||
# Load weights (performs online quantization)
|
||||
print(f"Physical Map: {physical_to_logical_map.data_ptr()}")
|
||||
@@ -421,7 +421,7 @@ def test_awq_moe():
|
||||
|
||||
for _ in range(layer_num):
|
||||
# Create AWQ MoE configuration
|
||||
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config.max_len = max_len
|
||||
|
||||
# Set quantization config for Int4_1LowKGroup
|
||||
@@ -449,7 +449,7 @@ def test_awq_moe():
|
||||
config.pool = CPUInfer.backend_
|
||||
|
||||
# Create Int4_1LowKGroup MoE
|
||||
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
|
||||
# Load weights
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
|
||||
@@ -2,8 +2,8 @@ import os, sys
|
||||
import time
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
import logging
|
||||
import sys
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger("reader")
|
||||
from gguf.gguf_reader import GGUFReader
|
||||
# load_layers = 6
|
||||
load_layers = None
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(304)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(304)
|
||||
max_qlen = 4096
|
||||
max_kvlen = 4096
|
||||
page_size = 256
|
||||
@@ -136,7 +136,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
rope_theta = json_config["rope_theta"]
|
||||
rope_scaling = json_config["rope_scaling"]
|
||||
|
||||
config = cpuinfer_ext.mla.MLAConfig(
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -191,12 +191,12 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
config.page_count = pages_count
|
||||
|
||||
if q_a_type == "F32":
|
||||
mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
elif q_a_type == "F16":
|
||||
mla = cpuinfer_ext.mla.MLA_F16(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F16(config)
|
||||
elif q_a_type == "BF16":
|
||||
# mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
|
||||
# mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {q_a_type}")
|
||||
|
||||
@@ -207,7 +207,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
|
||||
def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
if f"blk.{layer_idx}.ffn_gate.weight" in gguf_weights: # dense
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["hidden_size"],
|
||||
@@ -227,12 +227,12 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
config.down_proj = down.data_ptr()
|
||||
config.down_type = type_to_ggml_type(down_type)
|
||||
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
moe.load_weights()
|
||||
return moe
|
||||
|
||||
elif f"blk.{layer_idx}.ffn_gate_exps.weight" in gguf_weights:
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
json_config["n_routed_experts"] + json_config["n_shared_experts"],
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["hidden_size"],
|
||||
@@ -267,7 +267,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
config.down_proj = down.data_ptr()
|
||||
config.down_type = type_to_ggml_type(down_type)
|
||||
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
moe.load_weights()
|
||||
return moe
|
||||
|
||||
@@ -276,7 +276,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
|
||||
|
||||
def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
config = cpuinfer_ext.gate.GateConfig(
|
||||
config = kt_kernel_ext.gate.GateConfig(
|
||||
json_config["hidden_size"],
|
||||
json_config["num_experts_per_tok"],
|
||||
json_config["n_routed_experts"],
|
||||
@@ -296,7 +296,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
config.e_score_correction_bias = bias.data_ptr()
|
||||
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
|
||||
|
||||
gate = cpuinfer_ext.gate.MoEGate(config)
|
||||
gate = kt_kernel_ext.gate.MoEGate(config)
|
||||
|
||||
return gate
|
||||
|
||||
@@ -304,7 +304,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
|
||||
def build_llm(json_config, gguf_weights):
|
||||
|
||||
general_config = cpuinfer_ext.GeneralConfig()
|
||||
general_config = kt_kernel_ext.GeneralConfig()
|
||||
general_config.vocab_size = json_config["vocab_size"]
|
||||
general_config.hidden_size = json_config["hidden_size"]
|
||||
general_config.num_experts_per_tok = json_config["num_experts_per_tok"]
|
||||
@@ -326,8 +326,8 @@ def build_llm(json_config, gguf_weights):
|
||||
|
||||
general_config.pool = CPUInfer.backend_
|
||||
|
||||
llm = cpuinfer_ext.DeepseekV3ForCausalLM(general_config)
|
||||
model = cpuinfer_ext.DeepseekV3Model(general_config)
|
||||
llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)
|
||||
model = kt_kernel_ext.DeepseekV3Model(general_config)
|
||||
llm.model = model
|
||||
|
||||
|
||||
@@ -335,7 +335,7 @@ def build_llm(json_config, gguf_weights):
|
||||
real_load_layers = json_config["num_hidden_layers"] if load_layers is None else load_layers
|
||||
|
||||
for i in range(real_load_layers):
|
||||
layer = cpuinfer_ext.DeepseekV3DecoderLayer(general_config,i)
|
||||
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
|
||||
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
|
||||
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ import os, sys
|
||||
import time
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
import logging
|
||||
import sys
|
||||
@@ -21,7 +21,7 @@ logger = logging.getLogger("reader")
|
||||
|
||||
from gguf.gguf_reader import GGUFReader
|
||||
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(304)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(304)
|
||||
max_qlen = 4096
|
||||
max_kvlen = 4096
|
||||
page_size = 256
|
||||
@@ -135,7 +135,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
rope_theta = json_config["rope_theta"]
|
||||
rope_scaling = json_config["rope_scaling"]
|
||||
|
||||
config = cpuinfer_ext.mla.MLAConfig(
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -191,12 +191,12 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
|
||||
|
||||
if q_a_type == "F32":
|
||||
mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
elif q_a_type == "F16":
|
||||
mla = cpuinfer_ext.mla.MLA_F16(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F16(config)
|
||||
elif q_a_type == "BF16":
|
||||
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
|
||||
# mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
|
||||
# mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {q_a_type}")
|
||||
|
||||
@@ -207,7 +207,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
|
||||
def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
if f"blk.{layer_idx}.ffn_gate.weight" in gguf_weights: # dense
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["hidden_size"],
|
||||
@@ -227,12 +227,12 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
config.down_proj = down.data_ptr()
|
||||
config.down_type = type_to_ggml_type(down_type)
|
||||
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
moe.load_weights()
|
||||
return moe
|
||||
|
||||
elif f"blk.{layer_idx}.ffn_gate_exps.weight" in gguf_weights:
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
json_config["n_routed_experts"] + json_config["n_shared_experts"],
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["hidden_size"],
|
||||
@@ -267,7 +267,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
config.down_proj = down.data_ptr()
|
||||
config.down_type = type_to_ggml_type(down_type)
|
||||
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
moe.load_weights()
|
||||
return moe
|
||||
|
||||
@@ -276,7 +276,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
|
||||
|
||||
def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
config = cpuinfer_ext.gate.GateConfig(
|
||||
config = kt_kernel_ext.gate.GateConfig(
|
||||
json_config["hidden_size"],
|
||||
json_config["num_experts_per_tok"],
|
||||
json_config["n_routed_experts"],
|
||||
@@ -296,7 +296,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
config.e_score_correction_bias = bias.data_ptr()
|
||||
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
|
||||
|
||||
gate = cpuinfer_ext.gate.MoEGate(config)
|
||||
gate = kt_kernel_ext.gate.MoEGate(config)
|
||||
|
||||
return gate
|
||||
|
||||
@@ -304,7 +304,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
|
||||
def build_llm(json_config, gguf_weights):
|
||||
|
||||
general_config = cpuinfer_ext.GeneralConfig()
|
||||
general_config = kt_kernel_ext.GeneralConfig()
|
||||
general_config.vocab_size = json_config["vocab_size"]
|
||||
general_config.hidden_size = json_config["hidden_size"]
|
||||
general_config.num_experts_per_tok = json_config["num_experts_per_tok"]
|
||||
@@ -326,8 +326,8 @@ def build_llm(json_config, gguf_weights):
|
||||
|
||||
general_config.pool = CPUInfer.backend_
|
||||
|
||||
llm = cpuinfer_ext.DeepseekV3ForCausalLM(general_config)
|
||||
model = cpuinfer_ext.DeepseekV3Model(general_config)
|
||||
llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)
|
||||
model = kt_kernel_ext.DeepseekV3Model(general_config)
|
||||
llm.model = model
|
||||
|
||||
|
||||
@@ -335,7 +335,7 @@ def build_llm(json_config, gguf_weights):
|
||||
for i in range(json_config["num_hidden_layers"]):
|
||||
# for i in range(6):
|
||||
# for i in [0,1,2,3,4,5,6,7,8,9,10]:
|
||||
layer = cpuinfer_ext.DeepseekV3DecoderLayer(general_config,i)
|
||||
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
|
||||
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
|
||||
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ import os, sys
|
||||
import time
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
import logging
|
||||
import sys
|
||||
@@ -22,11 +22,11 @@ logger = logging.getLogger("reader")
|
||||
from gguf.gguf_reader import GGUFReader
|
||||
# load_layers = 3
|
||||
load_layers = None
|
||||
worker_config = cpuinfer_ext.WorkerPoolConfig()
|
||||
worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||
worker_config.subpool_count = 2
|
||||
worker_config.subpool_numa_map= [0,1]
|
||||
worker_config.subpool_thread_count = [72,72]
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(worker_config)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
|
||||
|
||||
max_qlen = 4096
|
||||
max_kvlen = 4096
|
||||
@@ -141,7 +141,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
rope_theta = json_config["rope_theta"]
|
||||
rope_scaling = json_config["rope_scaling"]
|
||||
|
||||
config = cpuinfer_ext.mla.MLAConfig(
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -196,12 +196,12 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
config.page_count = pages_count
|
||||
|
||||
if q_a_type == "F32":
|
||||
mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
elif q_a_type == "F16":
|
||||
mla = cpuinfer_ext.mla.MLA_F16(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F16(config)
|
||||
elif q_a_type == "BF16":
|
||||
# mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
|
||||
# mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {q_a_type}")
|
||||
|
||||
@@ -212,7 +212,7 @@ def build_mla(layer_idx, json_config, gguf_weights):
|
||||
|
||||
def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
if f"blk.{layer_idx}.ffn_gate.weight" in gguf_weights: # dense
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["hidden_size"],
|
||||
@@ -232,12 +232,12 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
config.down_proj = down.data_ptr()
|
||||
config.down_type = type_to_ggml_type(down_type)
|
||||
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
moe.load_weights()
|
||||
return moe
|
||||
|
||||
elif f"blk.{layer_idx}.ffn_gate_exps.weight" in gguf_weights:
|
||||
config = cpuinfer_ext.moe.MOEConfig(
|
||||
config = kt_kernel_ext.moe.MOEConfig(
|
||||
json_config["n_routed_experts"] + json_config["n_shared_experts"],
|
||||
json_config["num_experts_per_tok"] + json_config["n_shared_experts"],
|
||||
json_config["hidden_size"],
|
||||
@@ -272,7 +272,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
config.down_proj = down.data_ptr()
|
||||
config.down_type = type_to_ggml_type(down_type)
|
||||
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
moe.load_weights()
|
||||
return moe
|
||||
|
||||
@@ -281,7 +281,7 @@ def build_ffn(layer_idx, json_config, gguf_weights):
|
||||
|
||||
|
||||
def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
config = cpuinfer_ext.gate.GateConfig(
|
||||
config = kt_kernel_ext.gate.GateConfig(
|
||||
json_config["hidden_size"],
|
||||
json_config["num_experts_per_tok"],
|
||||
json_config["n_routed_experts"],
|
||||
@@ -301,7 +301,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
config.e_score_correction_bias = bias.data_ptr()
|
||||
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
|
||||
|
||||
gate = cpuinfer_ext.gate.MoEGate(config)
|
||||
gate = kt_kernel_ext.gate.MoEGate(config)
|
||||
|
||||
return gate
|
||||
|
||||
@@ -309,7 +309,7 @@ def build_moegate(layer_idx, json_config, gguf_weights):
|
||||
|
||||
def build_llm(json_config, gguf_weights):
|
||||
|
||||
general_config = cpuinfer_ext.GeneralConfig()
|
||||
general_config = kt_kernel_ext.GeneralConfig()
|
||||
general_config.vocab_size = json_config["vocab_size"]
|
||||
general_config.hidden_size = json_config["hidden_size"]
|
||||
general_config.num_experts_per_tok = json_config["num_experts_per_tok"]
|
||||
@@ -331,8 +331,8 @@ def build_llm(json_config, gguf_weights):
|
||||
|
||||
general_config.pool = CPUInfer.backend_
|
||||
|
||||
llm = cpuinfer_ext.DeepseekV3ForCausalLM(general_config)
|
||||
model = cpuinfer_ext.DeepseekV3Model(general_config)
|
||||
llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)
|
||||
model = kt_kernel_ext.DeepseekV3Model(general_config)
|
||||
llm.model = model
|
||||
|
||||
|
||||
@@ -341,7 +341,7 @@ def build_llm(json_config, gguf_weights):
|
||||
|
||||
for i in range(real_load_layers):
|
||||
# for i in [2,3]:
|
||||
layer = cpuinfer_ext.DeepseekV3DecoderLayer(general_config,i)
|
||||
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
|
||||
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
|
||||
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import time
|
||||
from typing import Optional
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -171,7 +171,7 @@ def torch_gate(hidden_states):
|
||||
|
||||
|
||||
def cpuinfer_gate(hidden_states):
|
||||
config = cpuinfer_ext.gate.GateConfig(
|
||||
config = kt_kernel_ext.gate.GateConfig(
|
||||
hidden_size,
|
||||
num_experts_per_token,
|
||||
n_routed_experts,
|
||||
@@ -179,7 +179,7 @@ def cpuinfer_gate(hidden_states):
|
||||
topk_group,
|
||||
)
|
||||
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(64)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
config.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
config.pool = CPUInfer.backend_
|
||||
@@ -188,7 +188,7 @@ def cpuinfer_gate(hidden_states):
|
||||
config.e_score_correction_bias = bias.data_ptr()
|
||||
config.e_score_correction_bias_type = ggml_type.FP32
|
||||
|
||||
gate = cpuinfer_ext.gate.MoEGate(config)
|
||||
gate = kt_kernel_ext.gate.MoEGate(config)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
import os, sys
|
||||
import time
|
||||
sys.path.append(os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
input_size = 16384
|
||||
@@ -23,7 +23,7 @@ proj_type = 1 # ggml_type::GGML_TYPE_F16
|
||||
hidden_type = 1 # ggml_type::GGML_TYPE_F16
|
||||
qlen = 30
|
||||
layer_num = 10
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(48)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(48)
|
||||
validation_iter = 100
|
||||
|
||||
with torch.inference_mode(mode=True):
|
||||
@@ -31,8 +31,8 @@ with torch.inference_mode(mode=True):
|
||||
projs = []
|
||||
for _ in range(layer_num):
|
||||
proj = torch.randn((output_size, input_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
|
||||
config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
|
||||
linear = cpuinfer_ext.linear.Linear(config)
|
||||
config = kt_kernel_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
|
||||
linear = kt_kernel_ext.linear.Linear(config)
|
||||
projs.append(proj)
|
||||
linears.append(linear)
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import time
|
||||
from typing import Optional
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
from torch import inf, nn
|
||||
from torch.nn import init
|
||||
@@ -110,7 +110,7 @@ rope_scaling = {
|
||||
|
||||
|
||||
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(30)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(30)
|
||||
validation_iter = 100
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ def test_cpu_mla():
|
||||
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
o_proj_weight = o_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
|
||||
config = cpuinfer_ext.mla.MLAConfig(
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -272,12 +272,12 @@ def test_cpu_mla():
|
||||
|
||||
|
||||
if weight_type == torch.float32:
|
||||
mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
elif weight_type == torch.float16:
|
||||
mla = cpuinfer_ext.mla.MLA_F16(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F16(config)
|
||||
elif weight_type == torch.bfloat16:
|
||||
# mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = cpuinfer_ext.mla.MLA_QUAN_F32(config)
|
||||
# mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {weight_type}")
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import time
|
||||
from typing import Optional
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
from torch import inf, nn
|
||||
from torch.nn import init
|
||||
@@ -110,7 +110,7 @@ rope_scaling = {
|
||||
|
||||
|
||||
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(64)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
validation_iter = 100
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ def build_mla():
|
||||
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
o_proj_weight = o_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
|
||||
config = cpuinfer_ext.mla.MLAConfig(
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -271,11 +271,11 @@ def build_mla():
|
||||
|
||||
|
||||
if weight_type == torch.float32:
|
||||
mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
elif weight_type == torch.float16:
|
||||
mla = cpuinfer_ext.mla.MLA_F16(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F16(config)
|
||||
elif weight_type == torch.bfloat16:
|
||||
mla = cpuinfer_ext.mla.MLA_F32(config)
|
||||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {weight_type}")
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import time
|
||||
from typing import Optional
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
from torch import inf, nn
|
||||
from torch.nn import init
|
||||
|
||||
@@ -2,8 +2,8 @@ import os,sys
|
||||
import time
|
||||
from typing import Optional
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
@@ -54,7 +54,7 @@ rope_scaling = {
|
||||
|
||||
|
||||
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(64)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
validation_iter = 100
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ o_proj_weight = o_proj.weight.to(torch.float16).to('cpu').contiguous()
|
||||
|
||||
|
||||
|
||||
config = cpuinfer_ext.mla.MLAConfig(
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -115,7 +115,7 @@ config.pool = CPUInfer.backend_
|
||||
|
||||
|
||||
|
||||
mla = cpuinfer_ext.mla.MLA(config)
|
||||
mla = kt_kernel_ext.mla.MLA(config)
|
||||
mla.load_weights()
|
||||
mla.set_local_pages(pages_count)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
import os, sys
|
||||
import time
|
||||
sys.path.append(os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
hidden_size = 5120
|
||||
@@ -25,7 +25,7 @@ down_type = 1 # ggml_type::GGML_TYPE_F16
|
||||
hidden_type = 1 # ggml_type::GGML_TYPE_F16
|
||||
qlen = 30
|
||||
layer_num = 10
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(48)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(48)
|
||||
validation_iter = 100
|
||||
|
||||
def act_fn(x):
|
||||
@@ -47,8 +47,8 @@ with torch.inference_mode(mode=True):
|
||||
gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
|
||||
up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
|
||||
down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
|
||||
config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
|
||||
mlp = cpuinfer_ext.mlp.MLP(config)
|
||||
config = kt_kernel_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
|
||||
mlp = kt_kernel_ext.mlp.MLP(config)
|
||||
gate_projs.append(gate_proj)
|
||||
up_projs.append(up_proj)
|
||||
down_projs.append(down_proj)
|
||||
|
||||
@@ -12,10 +12,10 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
import os, sys
|
||||
import time
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from cpuinfer_ext.kvcache import ggml_type
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -36,7 +36,7 @@ layer_num = 1
|
||||
# num_experts_per_tok = 8
|
||||
# qlen = 1024
|
||||
# layer_num = 1
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(64)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
validation_iter = 10
|
||||
|
||||
def act_fn(x):
|
||||
@@ -83,10 +83,10 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
|
||||
|
||||
def to_cpuinfer_tensor(tensor, type):
|
||||
size = torch.prod(torch.tensor(tensor.shape, dtype=torch.int32)).item()
|
||||
return cpuinfer_ext.utils.from_float(tensor.data_ptr(), size, type)
|
||||
return kt_kernel_ext.utils.from_float(tensor.data_ptr(), size, type)
|
||||
|
||||
def from_cpuinfer_tensor(tensor, size, type):
|
||||
return cpuinfer_ext.utils.to_float(tensor.data_ptr(), size, type)
|
||||
return kt_kernel_ext.utils.to_float(tensor.data_ptr(), size, type)
|
||||
|
||||
qlens = [1,64] #[64, 512, 2048, 8192, 16384]
|
||||
# gate_types = [ggml_type.FP32, ggml_type.FP16, ggml_type.Q8_0, ggml_type.Q6_K, ggml_type.Q5_K, ggml_type.Q4_K, ggml_type.Q3_K]
|
||||
@@ -118,7 +118,7 @@ for qlen in qlens:
|
||||
up_tensor = to_cpuinfer_tensor(up_proj, up_type)
|
||||
down_tensor = to_cpuinfer_tensor(down_proj, down_type)
|
||||
|
||||
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config.pool = CPUInfer.backend_
|
||||
config.stride = stride
|
||||
config.group_min_len = group_min_len
|
||||
@@ -132,7 +132,7 @@ for qlen in qlens:
|
||||
config.hidden_type = hidden_type
|
||||
|
||||
|
||||
moe = cpuinfer_ext.moe.MOE(config)
|
||||
moe = kt_kernel_ext.moe.MOE(config)
|
||||
gate_projs.append(gate_proj)
|
||||
up_projs.append(up_proj)
|
||||
down_projs.append(down_proj)
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
import os, sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
print("sys.path:", sys.path)
|
||||
|
||||
import cpuinfer_ext
|
||||
import torch
|
||||
import kt_kernel_ext
|
||||
|
||||
expert_num = 256
|
||||
hidden_size = 7168
|
||||
intermediate_size = 2048
|
||||
max_len = 25600
|
||||
num_experts_per_tok = 8
|
||||
# qlen = 1
|
||||
qlen = 640
|
||||
qlen = 1
|
||||
# qlen = 640
|
||||
layer_num = 1
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(40)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(90)
|
||||
# validation_iter = 10000
|
||||
validation_iter = 10
|
||||
validation_iter = 2
|
||||
k_group_size = 64
|
||||
debug_print_count = 16 # Number of values to print in debug output
|
||||
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
|
||||
@@ -126,7 +127,7 @@ def test_moe(quant_mode: str):
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
config.up_proj = up_proj.data_ptr()
|
||||
@@ -134,25 +135,25 @@ def test_moe(quant_mode: str):
|
||||
config.gate_scale = 0
|
||||
config.pool = CPUInfer.backend_
|
||||
if quant_mode == "bf16":
|
||||
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
CPUInfer.submit(moe.warm_up_task())
|
||||
CPUInfer.sync()
|
||||
elif quant_mode == "int8":
|
||||
moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
# CPUInfer.submit(moe.warm_up_task())
|
||||
# CPUInfer.sync()
|
||||
elif quant_mode == "int4":
|
||||
moe = cpuinfer_ext.moe.AMXInt4_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
CPUInfer.submit(moe.warm_up_task())
|
||||
CPUInfer.sync()
|
||||
elif quant_mode == "int4_1":
|
||||
moe = cpuinfer_ext.moe.AMXInt4_1_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_1_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
CPUInfer.submit(moe.warm_up_task())
|
||||
@@ -161,11 +162,12 @@ def test_moe(quant_mode: str):
|
||||
config.quant_config.bits = 4
|
||||
config.quant_config.group_size = k_group_size
|
||||
config.quant_config.zero_point = True
|
||||
moe = cpuinfer_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)
|
||||
# import debugpy
|
||||
# debugpy.listen(("127.0.0.1", 5678))
|
||||
# debugpy.wait_for_client()
|
||||
# debugpy.breakpoint()
|
||||
print(f"the physical_logical map:{physical_to_logical_map.data_ptr()}")
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
# CPUInfer.submit(moe.warm_up_task())
|
||||
@@ -260,7 +262,7 @@ def test_moe(quant_mode: str):
|
||||
# 5. Final output comparison
|
||||
|
||||
# test_moe("bf16")
|
||||
# test_moe("int8")
|
||||
# test_moe("int4")
|
||||
# test_moe("int4_1")
|
||||
test_moe("int8")
|
||||
test_moe("int4")
|
||||
test_moe("int4_1")
|
||||
test_moe("int4_1k")
|
||||
|
||||
187
kt-kernel/examples/test_moe_kernel.py
Normal file
187
kt-kernel/examples/test_moe_kernel.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
"""
|
||||
Description :
|
||||
Author : chenht2022
|
||||
Date : 2024-07-25 10:32:05
|
||||
Version : 1.0.0
|
||||
LastEditors : chenht2022
|
||||
LastEditTime : 2024-08-06 10:38:05
|
||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
"""
|
||||
import os, sys
|
||||
import time
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
import torch
|
||||
import kt_kernel_ext
|
||||
|
||||
|
||||
expert_num = 16
|
||||
hidden_size = 7168
|
||||
intermediate_size = 2048
|
||||
max_len = 4096
|
||||
num_experts_per_tok = 8
|
||||
m_block = 320
|
||||
n_block_up_gate = 32
|
||||
n_block_down = 64
|
||||
n_block_up_gate_prefi = 32
|
||||
n_block_down_prefi = 64
|
||||
# qlen = 1
|
||||
qlen = 1024
|
||||
layer_num = 1
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(160)
|
||||
# validation_iter = 10000
|
||||
validation_iter = 1
|
||||
|
||||
|
||||
def act_fn(x):
|
||||
return x / (1.0 + torch.exp(-x))
|
||||
|
||||
|
||||
def mlp_torch(input, gate_proj, up_proj, down_proj):
|
||||
gate_buf = torch.mm(input, gate_proj.t())
|
||||
up_buf = torch.mm(input, up_proj.t())
|
||||
intermediate = act_fn(gate_buf) * up_buf
|
||||
ret = torch.mm(intermediate, down_proj.t())
|
||||
return ret
|
||||
|
||||
|
||||
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
|
||||
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
|
||||
cnts.scatter_(1, expert_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = expert_ids.view(-1).argsort()
|
||||
sorted_tokens = input[idxs // expert_ids.shape[1]]
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
t_output = (
|
||||
new_x.view(*expert_ids.shape, -1)
|
||||
.type(weights.dtype)
|
||||
.mul_(weights.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return t_output
|
||||
|
||||
|
||||
def test_moe(quant_mode: str):
|
||||
assert quant_mode == "int8" or quant_mode == "int4" or quant_mode == "int4_1"
|
||||
with torch.inference_mode(mode=True):
|
||||
moes = []
|
||||
gate_projs = []
|
||||
up_projs = []
|
||||
down_projs = []
|
||||
for _ in range(layer_num):
|
||||
gate_proj = (
|
||||
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device="cpu")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
up_proj = (
|
||||
torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device="cpu")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
down_proj = (
|
||||
torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16, device="cpu")
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
config.up_proj = up_proj.data_ptr()
|
||||
config.down_proj = down_proj.data_ptr()
|
||||
config.pool = CPUInfer.backend_
|
||||
if quant_mode == "int8":
|
||||
d = kt_kernel_ext.moe.tiling.get_int8()
|
||||
nbug_prefi = n_block_up_gate_prefi
|
||||
nbd_prefi = n_block_down_prefi
|
||||
kb = d["k_block"]
|
||||
nb = d["n_block"]
|
||||
mb = m_block
|
||||
nbug = n_block_up_gate
|
||||
nbd = n_block_down
|
||||
print(
|
||||
f"Int8 Tiling: nbug {nbug}, nbd {nbd}, nb {nb}, mb {mb}, kb {kb}, nbug_prefi {nbug_prefi}, nbd_prefi {nbd_prefi}"
|
||||
)
|
||||
kt_kernel_ext.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
|
||||
moe = kt_kernel_ext.moe.Int8_KERNEL_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
CPUInfer.sync()
|
||||
# CPUInfer.submit(moe.warm_up_task())
|
||||
# CPUInfer.sync()
|
||||
elif quant_mode == "int4":
|
||||
moe = kt_kernel_ext.moe.Int4_KERNEL_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
CPUInfer.sync()
|
||||
CPUInfer.submit(moe.warm_up_task())
|
||||
CPUInfer.sync()
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode: {quant_mode}")
|
||||
gate_projs.append(gate_proj)
|
||||
up_projs.append(up_proj)
|
||||
down_projs.append(down_proj)
|
||||
moes.append(moe)
|
||||
|
||||
# validation
|
||||
for i in range(validation_iter):
|
||||
bsz_tensor = torch.tensor([qlen], device="cpu")
|
||||
expert_ids = torch.stack(
|
||||
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
|
||||
).contiguous()
|
||||
weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
|
||||
input = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
|
||||
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
|
||||
input = input / 100
|
||||
# 打印 input 的内容
|
||||
print("input:", input)
|
||||
moe = moes[i % layer_num]
|
||||
# print('expert ids:',expert_ids)
|
||||
CPUInfer.submit(
|
||||
moe.forward_task(
|
||||
bsz_tensor.data_ptr(),
|
||||
num_experts_per_tok,
|
||||
expert_ids.data_ptr(),
|
||||
weights.data_ptr(),
|
||||
input.data_ptr(),
|
||||
output.data_ptr(),
|
||||
False,
|
||||
)
|
||||
)
|
||||
CPUInfer.sync()
|
||||
print("cpuinfer output", output)
|
||||
|
||||
gate_proj = gate_projs[i % layer_num]
|
||||
up_proj = up_projs[i % layer_num]
|
||||
down_proj = down_projs[i % layer_num]
|
||||
t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)
|
||||
print("torch output", t_output)
|
||||
|
||||
# print(output - t_output)
|
||||
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
|
||||
print("diff = ", diff)
|
||||
if quant_mode == "int4":
|
||||
assert diff < 0.35
|
||||
else:
|
||||
assert diff < 0.05
|
||||
|
||||
|
||||
test_moe("int8")
|
||||
# test_moe("int4")
|
||||
@@ -14,7 +14,7 @@ import time
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
import cpuinfer_ext
|
||||
import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
expert_num = 16
|
||||
@@ -25,7 +25,7 @@ num_experts_per_tok = 8
|
||||
qlen = 512
|
||||
# qlen = 640
|
||||
layer_num = 1
|
||||
CPUInfer = cpuinfer_ext.CPUInfer(112)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(112)
|
||||
# validation_iter = 10000
|
||||
validation_iter = 1
|
||||
|
||||
@@ -97,26 +97,26 @@ def test_moe(quant_mode: str):
|
||||
.to("cpu")
|
||||
.contiguous()
|
||||
)
|
||||
config = cpuinfer_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config.max_len = max_len
|
||||
config.gate_proj = gate_proj.data_ptr()
|
||||
config.up_proj = up_proj.data_ptr()
|
||||
config.down_proj = down_proj.data_ptr()
|
||||
config.pool = CPUInfer.backend_
|
||||
if quant_mode == "bf16":
|
||||
moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
CPUInfer.sync()
|
||||
CPUInfer.submit(moe.warm_up_task())
|
||||
CPUInfer.sync()
|
||||
elif quant_mode == "int8":
|
||||
moe = cpuinfer_ext.moe.KMLInt8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt8_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
CPUInfer.sync()
|
||||
# CPUInfer.submit(moe.warm_up_task())
|
||||
# CPUInfer.sync()
|
||||
elif quant_mode == "int4":
|
||||
moe = cpuinfer_ext.moe.KMLInt4_MOE(config)
|
||||
moe = kt_kernel_ext.moe.KMLInt4_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task())
|
||||
CPUInfer.sync()
|
||||
CPUInfer.submit(moe.warm_up_task())
|
||||
|
||||
@@ -8,12 +8,19 @@
|
||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
**/
|
||||
// Python bindings
|
||||
#include <sys/types.h>
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "cpu_backend/cpuinfer.h"
|
||||
#include "cpu_backend/worker_pool.h"
|
||||
// #include "device_launch_parameters.h"
|
||||
#include "llamafile/flags.h"
|
||||
#include "operators/common.hpp"
|
||||
|
||||
#if defined(USE_MOE_KERNEL)
|
||||
#include "operators/moe_kernel/la/kernel.hpp"
|
||||
#include "operators/moe_kernel/moe.hpp"
|
||||
#endif
|
||||
|
||||
#if defined(__aarch64__) && defined(CPU_USE_KML)
|
||||
#if defined(KTRANSFORMERS_CPU_MLA)
|
||||
#include "operators/kml/deepseekv3.hpp"
|
||||
@@ -22,26 +29,27 @@
|
||||
#include "operators/kml/mla_int8.hpp"
|
||||
#endif
|
||||
#include "operators/kml/moe.hpp"
|
||||
static const bool _is_plain_ = true;
|
||||
#else
|
||||
static const bool _is_plain_ = false;
|
||||
#endif
|
||||
|
||||
#ifdef __x86_64__
|
||||
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||
#include "operators/amx/awq-moe.hpp"
|
||||
#include "operators/amx/la/amx_kernels.hpp"
|
||||
#include "operators/amx/moe.hpp"
|
||||
#endif
|
||||
#include <pybind11/stl.h> // std::vector/std::pair/std::string conversions
|
||||
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "operators/kvcache/kvcache.h"
|
||||
#include "operators/llamafile/linear.h"
|
||||
#include "operators/llamafile/mla.hpp"
|
||||
#include "operators/llamafile/mlp.h"
|
||||
#include "operators/llamafile/moe.hpp"
|
||||
#include "pybind11/functional.h"
|
||||
#include "pybind11/operators.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace pybind11::literals;
|
||||
@@ -160,17 +168,25 @@ class MOEBindings {
|
||||
struct Args {
|
||||
CPUInfer* cpuinfer;
|
||||
TP_MOE<T>* moe;
|
||||
const uint64_t* physical_to_logical_map;
|
||||
};
|
||||
static void inner(void* args) {
|
||||
Args* args_ = (Args*)args;
|
||||
args_->cpuinfer->enqueue(&TP_MOE<T>::load_weights, args_->moe, args_->physical_to_logical_map);
|
||||
args_->cpuinfer->enqueue(&TP_MOE<T>::load_weights, args_->moe);
|
||||
}
|
||||
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe,
|
||||
intptr_t physical_to_logical_map) {
|
||||
Args* args = new Args{nullptr, moe.get(), (const uint64_t*)physical_to_logical_map};
|
||||
const uintptr_t physical_to_logical_map = 0) {
|
||||
Args* args = new Args{nullptr, moe.get()};
|
||||
if (physical_to_logical_map) {
|
||||
printf("debug physical_to_logical_map in arg:%lu\n", physical_to_logical_map);
|
||||
moe->config.physical_to_logical_map = reinterpret_cast<void*>(physical_to_logical_map);
|
||||
printf("moe ptr:%p,confirm: moe->config.physical_to_logical_map:%lu\n", reinterpret_cast<void*>(moe.get()),
|
||||
reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));
|
||||
}
|
||||
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||
}
|
||||
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe) {
|
||||
return cpuinfer_interface(moe, 0);
|
||||
}
|
||||
};
|
||||
class ForwardBindings {
|
||||
public:
|
||||
@@ -196,10 +212,41 @@ class MOEBindings {
|
||||
Args* args = new Args{nullptr, moe.get(), qlen, k, expert_ids, weights, input, output, incremental};
|
||||
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||
}
|
||||
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe, intptr_t qlen, int k,
|
||||
intptr_t expert_ids, intptr_t weights, intptr_t input,
|
||||
intptr_t output) {
|
||||
return cpuinfer_interface(moe, qlen, k, expert_ids, weights, input, output, false);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(cpuinfer_ext, m) {
|
||||
template <typename MoeTP>
|
||||
void bind_moe_module(py::module_& moe_module, const char* name) {
|
||||
using MoeClass = TP_MOE<MoeTP>;
|
||||
using MoeBindings = MOEBindings<MoeTP>;
|
||||
|
||||
py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name)
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task",
|
||||
py::overload_cast<std::shared_ptr<MoeClass>>(&MoeBindings::LoadWeightsBindings::cpuinfer_interface))
|
||||
.def("load_weights_task",
|
||||
py::overload_cast<std::shared_ptr<MoeClass>, const uintptr_t>(
|
||||
&MoeBindings::LoadWeightsBindings::cpuinfer_interface),
|
||||
py::arg("physical_to_logical_map"))
|
||||
// .def("forward_task", &MoeBindings::ForwardBindings::cpuinfer_interface)
|
||||
.def("forward_task",
|
||||
py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t>(
|
||||
&MoeBindings::ForwardBindings::cpuinfer_interface))
|
||||
.def("forward_task",
|
||||
py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t, bool>(
|
||||
&MoeBindings::ForwardBindings::cpuinfer_interface))
|
||||
.def("warm_up", &MoeClass::warm_up)
|
||||
.def("load_weights", &MoeClass::load_weights)
|
||||
.def("forward", &MoeClass::forward_binding);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(kt_kernel_ext, m) {
|
||||
py::class_<WorkerPool>(m, "WorkerPool").def(py::init<int>());
|
||||
py::class_<WorkerPoolConfig>(m, "WorkerPoolConfig")
|
||||
.def(py::init<>())
|
||||
@@ -399,11 +446,15 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
|
||||
auto moe_module = m.def_submodule("moe");
|
||||
|
||||
py::class_<GeneralMOEConfig>(moe_module, "MOEConfig")
|
||||
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) {
|
||||
return GeneralMOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size);
|
||||
}))
|
||||
.def(py::init(
|
||||
[](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int num_gpu_experts) {
|
||||
return GeneralMOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, num_gpu_experts);
|
||||
GeneralMOEConfig cfg(expert_num, routed_expert_num, hidden_size, intermediate_size);
|
||||
cfg.num_gpu_experts = num_gpu_experts;
|
||||
return cfg;
|
||||
}))
|
||||
|
||||
.def_readwrite("layer_idx", &GeneralMOEConfig::layer_idx)
|
||||
.def_readwrite("pool", &GeneralMOEConfig::pool)
|
||||
|
||||
@@ -454,109 +505,92 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
|
||||
|
||||
py::class_<MoE_Interface, std::shared_ptr<MoE_Interface>>(moe_module, "MoE_Interface");
|
||||
|
||||
py::class_<TP_MOE<LLAMA_MOE_TP>, MoE_Interface, std::shared_ptr<TP_MOE<LLAMA_MOE_TP>>>(moe_module, "MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MOEBindings<LLAMA_MOE_TP>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task", &MOEBindings<LLAMA_MOE_TP>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task", &MOEBindings<LLAMA_MOE_TP>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<LLAMA_MOE_TP>::warm_up)
|
||||
.def("load_weights", &TP_MOE<LLAMA_MOE_TP>::load_weights)
|
||||
.def("forward", &TP_MOE<LLAMA_MOE_TP>::forward_binding);
|
||||
bind_moe_module<LLAMA_MOE_TP>(moe_module, "MOE");
|
||||
|
||||
#ifdef __x86_64__
|
||||
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>, MoE_Interface,
|
||||
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>>>(moe_module, "AMXBF16_MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224BF>>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224BF>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224BF>>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>::warm_up)
|
||||
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>::load_weights)
|
||||
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224BF>>::forward_binding);
|
||||
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>, MoE_Interface,
|
||||
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>>>(moe_module, "AMXInt8_MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int8>>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task",
|
||||
&MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int8>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int8>>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>::warm_up)
|
||||
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>::load_weights)
|
||||
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int8>>::forward_binding);
|
||||
|
||||
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>, MoE_Interface,
|
||||
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>>>(moe_module, "AMXInt4_MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4>>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task",
|
||||
&MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4>>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>::warm_up)
|
||||
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>::load_weights)
|
||||
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4>>::forward_binding);
|
||||
|
||||
py::class_<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>, MoE_Interface,
|
||||
std::shared_ptr<TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>>>(moe_module, "AMXInt4_1_MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task",
|
||||
&MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task", &MOEBindings<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::warm_up)
|
||||
.def("load_weights", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::load_weights)
|
||||
.def("forward", &TP_MOE<AMX_MOE_TP<amx::GemmKernel224Int4_1>>::forward_binding);
|
||||
|
||||
// py::class_<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>, MoE_Interface,
|
||||
// std::shared_ptr<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>>>(moe_module, "AMXInt4KGroup_MOE")
|
||||
// .def(py::init<GeneralMOEConfig>())
|
||||
// .def("warm_up_task",
|
||||
// &MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::WarmUpBindings::cpuinfer_interface)
|
||||
// .def("load_weights_task",
|
||||
// &MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
// .def("forward_task",
|
||||
// &MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::ForwardBindings::cpuinfer_interface)
|
||||
// .def("warm_up", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::warm_up)
|
||||
// .def("load_weights", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::load_weights)
|
||||
// .def("forward", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4KGroup>>::forward_binding);
|
||||
|
||||
py::class_<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>, MoE_Interface,
|
||||
std::shared_ptr<TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>>>(moe_module,
|
||||
"AMXInt4_1KGroup_MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task",
|
||||
&MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task",
|
||||
&MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task",
|
||||
&MOEBindings<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::warm_up)
|
||||
.def("load_weights", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::load_weights)
|
||||
.def("forward", &TP_MOE<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>::forward_binding);
|
||||
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE");
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_MOE");
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
|
||||
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
|
||||
#endif
|
||||
|
||||
#if defined(USE_MOE_KERNEL)
|
||||
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
|
||||
#if defined(__aarch64__) && defined(CPU_USE_KML)
|
||||
py::class_<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>, MoE_Interface,
|
||||
std::shared_ptr<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>>>(moe_module, "KMLInt8_MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt8>>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task",
|
||||
&MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt8>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt8>>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>::warm_up)
|
||||
.def("load_weights", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>::load_weights)
|
||||
.def("forward", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt8>>::forward_binding);
|
||||
|
||||
py::class_<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>, MoE_Interface,
|
||||
std::shared_ptr<TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>>>(moe_module, "KMLInt4_MOE")
|
||||
.def(py::init<GeneralMOEConfig>())
|
||||
.def("warm_up_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt4>>::WarmUpBindings::cpuinfer_interface)
|
||||
.def("load_weights_task",
|
||||
&MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt4>>::LoadWeightsBindings::cpuinfer_interface)
|
||||
.def("forward_task", &MOEBindings<KML_MOE_TP<arm_kml::GemmKernelInt4>>::ForwardBindings::cpuinfer_interface)
|
||||
.def("warm_up", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>::warm_up)
|
||||
.def("load_weights", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>::load_weights)
|
||||
.def("forward", &TP_MOE<KML_MOE_TP<arm_kml::GemmKernelInt4>>::forward_binding);
|
||||
// amd have not implemented int4 kernel yet
|
||||
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt4, _is_plain_>>(moe_module, "Int4_KERNEL_MOE");
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Expose kernel tiling/runtime parameters so Python can modify them at runtime
|
||||
{
|
||||
auto tiling_module = moe_module.def_submodule("tiling");
|
||||
#if defined(USE_MOE_KERNEL)
|
||||
tiling_module.def(
|
||||
"get_int8",
|
||||
[]() {
|
||||
auto t = moe_kernel::GemmKernelInt8::get_tiling();
|
||||
py::dict d;
|
||||
d["n_block_up_gate"] = std::get<0>(t);
|
||||
d["n_block_down"] = std::get<1>(t);
|
||||
d["n_block"] = std::get<2>(t);
|
||||
d["m_block"] = std::get<3>(t);
|
||||
d["k_block"] = std::get<4>(t);
|
||||
d["n_block_up_gate_prefi"] = std::get<5>(t);
|
||||
d["n_block_down_prefi"] = std::get<6>(t);
|
||||
return d;
|
||||
},
|
||||
"Get current tiling parameters for INT8 kernel");
|
||||
tiling_module.def(
|
||||
"set_int8",
|
||||
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
|
||||
int n_block_down_prefi) {
|
||||
moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
|
||||
n_block_up_gate_prefi, n_block_down_prefi);
|
||||
},
|
||||
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
|
||||
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"), "Set tiling parameters for INT8 kernel");
|
||||
|
||||
tiling_module.def(
|
||||
"get_int4",
|
||||
[]() {
|
||||
auto t = moe_kernel::GemmKernelInt4::get_tiling();
|
||||
py::dict d;
|
||||
d["n_block_up_gate"] = std::get<0>(t);
|
||||
d["n_block_down"] = std::get<1>(t);
|
||||
d["n_block"] = std::get<2>(t);
|
||||
d["m_block"] = std::get<3>(t);
|
||||
d["k_block"] = std::get<4>(t);
|
||||
d["n_block_up_gate_prefi"] = std::get<5>(t);
|
||||
d["n_block_down_prefi"] = std::get<6>(t);
|
||||
return d;
|
||||
},
|
||||
"Get current tiling parameters for INT4 kernel");
|
||||
tiling_module.def(
|
||||
"set_int4",
|
||||
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
|
||||
int n_block_down_prefi) {
|
||||
moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
|
||||
n_block_up_gate_prefi, n_block_down_prefi);
|
||||
},
|
||||
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
|
||||
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"), "Set tiling parameters for INT4 kernel");
|
||||
|
||||
// Convenience: set both
|
||||
tiling_module.def(
|
||||
"set_all",
|
||||
[](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,
|
||||
int n_block_down_prefi) {
|
||||
moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
|
||||
n_block_up_gate_prefi, n_block_down_prefi);
|
||||
moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,
|
||||
n_block_up_gate_prefi, n_block_down_prefi);
|
||||
},
|
||||
py::arg("n_block_up_gate"), py::arg("n_block_down"), py::arg("n_block"), py::arg("m_block"), py::arg("k_block"),
|
||||
py::arg("n_block_up_gate_prefi"), py::arg("n_block_down_prefi"),
|
||||
"Set tiling parameters for both INT8 and INT4 kernels");
|
||||
#endif
|
||||
}
|
||||
|
||||
auto kvcache_module = m.def_submodule("kvcache");
|
||||
|
||||
@@ -628,18 +662,7 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
|
||||
.def(py::init<KVCacheConfig>())
|
||||
.def("get_cache_total_len", &KVCache::get_cache_total_len)
|
||||
.def("update_cache_total_len",
|
||||
[](KVCache& kvcache, int cache_total_len) { kvcache.update_cache_total_len(cache_total_len); })
|
||||
|
||||
// .def("attn", &KVCacheBindings::AttnBindings::cpuinfer_interface)
|
||||
// .def("get_all_kvcache_one_layer", &KVCacheBindings::GetAllKVCacheOneLayerBindings::cpuinfer_interface)
|
||||
// .def("get_and_update_kvcache_fp16", &KVCacheBindings::GetAndUpdateKVCacheFp16Bindings::cpuinfer_interface)
|
||||
// .def("get_kvcache_fp16", &KVCacheBindings::GetKVCacheFp16Bindings::cpuinfer_interface)
|
||||
// .def("update_kvcache_fp16", &KVCacheBindings::UpdateKVCacheFp16Bindings::cpuinfer_interface)
|
||||
// .def("update_importance", &KVCacheBindings::UpdateImportanceBindings::cpuinfer_interface)
|
||||
// .def("attn_with_kvcache", &KVCacheBindings::AttnWithKVCacheBindings::cpuinfer_interface)
|
||||
// .def("clear_importance_all_layers", &KVCacheBindings::ClearImportanceAllLayersBindings::cpuinfer_interface)
|
||||
// .def("calc_anchor_all_layers", &KVCacheBindings::CalcAnchorAllLayersBindings::cpuinfer_interface)
|
||||
;
|
||||
[](KVCache& kvcache, int cache_total_len) { kvcache.update_cache_total_len(cache_total_len); });
|
||||
|
||||
auto utils = m.def_submodule("utils");
|
||||
|
||||
|
||||
@@ -29,14 +29,10 @@
|
||||
|
||||
#include "../../cpu_backend/shared_mem_buffer.h"
|
||||
#include "../../cpu_backend/worker_pool.h"
|
||||
#include "../common.hpp"
|
||||
#include "../moe-tp.hpp"
|
||||
#include "la/amx.hpp"
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
#include "llama.cpp/ggml-quants.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#include "llamafile/sgemm.h"
|
||||
|
||||
#define expert_map(m, x) (m != nullptr ? m[(x)] : (x))
|
||||
|
||||
template <class T>
|
||||
class AMX_AWQ_MOE_TP {
|
||||
@@ -475,12 +471,6 @@ class AMX_AWQ_MOE_TP {
|
||||
m_local_up_output_ptr_.resize(config_.expert_num);
|
||||
m_local_down_output_ptr_.resize(config_.expert_num);
|
||||
|
||||
// printf("tp part %d alloc layer %d, %f GB, on numa %d\n", tp_part_idx, config_.layer_idx,
|
||||
// 1e-9 * config_.expert_num *
|
||||
// (T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) * 2 +
|
||||
// T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)),
|
||||
// numa_node_of_cpu(sched_getcpu()));
|
||||
|
||||
for (size_t i = 0; i < config_.expert_num; i++) {
|
||||
gate_up_ba_.push_back(
|
||||
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, group_size, nullptr));
|
||||
@@ -524,9 +514,10 @@ class AMX_AWQ_MOE_TP {
|
||||
// shared_mem_buffer_numa.dealloc(this);
|
||||
}
|
||||
|
||||
void load_weights(const uint64_t* physical_to_logical_map) {
|
||||
void load_weights() {
|
||||
auto& quant_config = config_.quant_config;
|
||||
int& group_size = quant_config.group_size;
|
||||
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
|
||||
if (quant_config.group_size == 0 || !quant_config.zero_point) {
|
||||
throw std::runtime_error("AWQ-Quantization AMX MoE only support KGroup Int4_1");
|
||||
}
|
||||
@@ -534,171 +525,12 @@ class AMX_AWQ_MOE_TP {
|
||||
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||
if (config_.gate_projs.size()) {
|
||||
throw std::runtime_error("AMX load weights is not support");
|
||||
// pool->do_work_stealing_job(
|
||||
// config_.expert_num, nullptr,
|
||||
// [this, physical_to_logical_map](int expert_id) {
|
||||
// // printf("Load layer %d [%d/%d]\n", config_.layer_idx, expert_id, config_.expert_num);
|
||||
// uint64_t logical_expert_id = physical_to_logical_map[expert_id];
|
||||
// auto& quant_config = config_.quant_config;
|
||||
// int& group_size = quant_config.group_size;
|
||||
// {
|
||||
// int num_group = config_.hidden_size / group_size;
|
||||
// size_t scale_size = num_group * config_.intermediate_size * sizeof(float);
|
||||
// size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size) -
|
||||
// (scale_size << 1);
|
||||
|
||||
// memcpy(gate_bb_[expert_id]->b, config_.gate_projs[tp_part_idx][logical_expert_id], size);
|
||||
|
||||
// if constexpr (T::BufferB::SCALE) {
|
||||
// memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size);
|
||||
// }
|
||||
|
||||
// memcpy(up_bb_[expert_id]->b, config_.up_projs[tp_part_idx][logical_expert_id], size);
|
||||
|
||||
// if constexpr (T::BufferB::SCALE) {
|
||||
// memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size);
|
||||
// }
|
||||
|
||||
// if (quant_config.zero_point) {
|
||||
// // Convert INT4 zeros to float mins using AVX optimization
|
||||
// size_t num_elements = num_group * config_.intermediate_size;
|
||||
// convert_zeros_to_mins_avx(
|
||||
// (const uint8_t*)config_.gate_zeros[tp_part_idx][logical_expert_id],
|
||||
// (const float*)config_.gate_scales[tp_part_idx][logical_expert_id],
|
||||
// gate_bb_[expert_id]->mins,
|
||||
// num_elements
|
||||
// );
|
||||
// convert_zeros_to_mins_avx(
|
||||
// (const uint8_t*)config_.up_zeros[tp_part_idx][logical_expert_id],
|
||||
// (const float*)config_.up_scales[tp_part_idx][logical_expert_id],
|
||||
// up_bb_[expert_id]->mins,
|
||||
// num_elements
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
|
||||
// {
|
||||
// int num_group = config_.intermediate_size / group_size;
|
||||
// size_t scale_size = num_group * config_.hidden_size * sizeof(float);
|
||||
// size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size) -
|
||||
// (scale_size << 1);
|
||||
|
||||
// memcpy(down_bb_[expert_id]->b, config_.down_projs[tp_part_idx][logical_expert_id], size);
|
||||
|
||||
// if constexpr (T::BufferB::SCALE) {
|
||||
// memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size);
|
||||
// }
|
||||
|
||||
// if (quant_config.zero_point) {
|
||||
// // Convert INT4 zeros to float mins using AVX optimization
|
||||
// size_t num_elements = num_group * config_.hidden_size;
|
||||
// convert_zeros_to_mins_avx(
|
||||
// (const uint8_t*)config_.down_zeros[tp_part_idx][logical_expert_id],
|
||||
// (const float*)config_.down_scales[tp_part_idx][logical_expert_id],
|
||||
// down_bb_[expert_id]->mins,
|
||||
// num_elements
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// nullptr);
|
||||
|
||||
} else {
|
||||
// AWQ Load from file implementation
|
||||
int nth = T::recommended_nth(config_.intermediate_size);
|
||||
static uint8_t mat_type_all = 3, mat_split = 1;
|
||||
if (config_.load) {
|
||||
throw std::runtime_error("AMX load weights from file is not support");
|
||||
// std::cout << "Loading AWQ weights from " << prefix << std::endl;
|
||||
|
||||
// // Use work stealing job for parallel loading
|
||||
// pool->do_work_stealing_job(
|
||||
// config_.expert_num * mat_type_all, nullptr,
|
||||
// [this, physical_to_logical_map, mat_split](int task_id) {
|
||||
// auto& quant_config = config_.quant_config;
|
||||
// int& group_size = quant_config.group_size;
|
||||
|
||||
// int64_t expert_idx = task_id / mat_type_all;
|
||||
// uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
|
||||
// uint8_t mat_class = task_id % mat_type_all;
|
||||
|
||||
// if (mat_class == 0) { // gate projection
|
||||
// int num_group = config_.hidden_size / group_size;
|
||||
// size_t weights_size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size,
|
||||
// group_size) - (2 * num_group * config_.intermediate_size * sizeof(float)); size_t scales_size =
|
||||
// num_group * config_.intermediate_size * sizeof(float); size_t zeros_size = num_group *
|
||||
// config_.intermediate_size / 2; // INT4 packed format
|
||||
|
||||
// // Allocate temporary buffer for zeros
|
||||
// std::vector<uint8_t> zeros_buf(zeros_size);
|
||||
|
||||
// read_awq_weights(prefix, "gate_proj", logical_expert_id,
|
||||
// (char*)gate_bb_[expert_idx]->b,
|
||||
// (float*)gate_bb_[expert_idx]->d,
|
||||
// zeros_buf.data(),
|
||||
// weights_size, scales_size, zeros_size,
|
||||
// mat_split, 0);
|
||||
|
||||
// // Convert INT4 zeros to float mins
|
||||
// if (quant_config.zero_point) {
|
||||
// convert_zeros_to_mins_avx(zeros_buf.data(),
|
||||
// (float*)gate_bb_[expert_idx]->d,
|
||||
// gate_bb_[expert_idx]->mins,
|
||||
// num_group * config_.intermediate_size);
|
||||
// }
|
||||
|
||||
// } else if (mat_class == 1) { // up projection
|
||||
// int num_group = config_.hidden_size / group_size;
|
||||
// size_t weights_size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size,
|
||||
// group_size) - (2 * num_group * config_.intermediate_size * sizeof(float)); size_t scales_size =
|
||||
// num_group * config_.intermediate_size * sizeof(float); size_t zeros_size = num_group *
|
||||
// config_.intermediate_size / 2; // INT4 packed format
|
||||
|
||||
// // Allocate temporary buffer for zeros
|
||||
// std::vector<uint8_t> zeros_buf(zeros_size);
|
||||
|
||||
// read_awq_weights(prefix, "up_proj", logical_expert_id,
|
||||
// (char*)up_bb_[expert_idx]->b,
|
||||
// (float*)up_bb_[expert_idx]->d,
|
||||
// zeros_buf.data(),
|
||||
// weights_size, scales_size, zeros_size,
|
||||
// mat_split, 0);
|
||||
|
||||
// // Convert INT4 zeros to float mins
|
||||
// if (quant_config.zero_point) {
|
||||
// convert_zeros_to_mins_avx(zeros_buf.data(),
|
||||
// (float*)up_bb_[expert_idx]->d,
|
||||
// up_bb_[expert_idx]->mins,
|
||||
// num_group * config_.intermediate_size);
|
||||
// }
|
||||
|
||||
// } else { // down projection
|
||||
// int num_group = config_.intermediate_size / group_size;
|
||||
// size_t weights_size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,
|
||||
// group_size) - (2 * num_group * config_.hidden_size * sizeof(float)); size_t scales_size = num_group
|
||||
// * config_.hidden_size * sizeof(float); size_t zeros_size = num_group * config_.hidden_size / 2; //
|
||||
// INT4 packed format
|
||||
|
||||
// // Allocate temporary buffer for zeros
|
||||
// std::vector<uint8_t> zeros_buf(zeros_size);
|
||||
|
||||
// read_awq_weights(prefix, "down_proj", logical_expert_id,
|
||||
// (char*)down_bb_[expert_idx]->b,
|
||||
// (float*)down_bb_[expert_idx]->d,
|
||||
// zeros_buf.data(),
|
||||
// weights_size, scales_size, zeros_size,
|
||||
// mat_split, 0);
|
||||
|
||||
// // Convert INT4 zeros to float mins
|
||||
// if (quant_config.zero_point) {
|
||||
// convert_zeros_to_mins_avx(zeros_buf.data(),
|
||||
// (float*)down_bb_[expert_idx]->d,
|
||||
// down_bb_[expert_idx]->mins,
|
||||
// num_group * config_.hidden_size);
|
||||
// }
|
||||
// }
|
||||
// },
|
||||
// nullptr);
|
||||
}
|
||||
// check process, store down matrix to check
|
||||
#ifdef CHECK
|
||||
@@ -708,13 +540,12 @@ class AMX_AWQ_MOE_TP {
|
||||
else if (config_.gate_scale != nullptr)
|
||||
#endif
|
||||
{
|
||||
// Loading quantized weights
|
||||
// Loading quantized weights
|
||||
pool->do_work_stealing_job(
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth, physical_to_logical_map](int task_id) {
|
||||
uint64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
// gate part
|
||||
gate_bb_[expert_idx]->from_raw_mat(
|
||||
@@ -734,7 +565,7 @@ class AMX_AWQ_MOE_TP {
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth, physical_to_logical_map](int task_id) {
|
||||
uint64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
// down part
|
||||
down_bb_[expert_idx]->from_raw_mat(
|
||||
@@ -748,7 +579,7 @@ class AMX_AWQ_MOE_TP {
|
||||
config_.expert_num, nullptr,
|
||||
[this, physical_to_logical_map](int task_id) {
|
||||
uint64_t expert_idx = task_id;
|
||||
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
size_t scale_elem_count =
|
||||
(config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;
|
||||
|
||||
@@ -793,7 +624,7 @@ class AMX_AWQ_MOE_TP {
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth, physical_to_logical_map](int task_id) {
|
||||
int64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
// gate part
|
||||
gate_bb_[logical_expert_id]->from_mat(
|
||||
@@ -812,7 +643,7 @@ class AMX_AWQ_MOE_TP {
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth, physical_to_logical_map](int task_id) {
|
||||
int64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
// down part
|
||||
down_bb_[logical_expert_id]->from_mat(
|
||||
@@ -1280,16 +1111,15 @@ template <typename K>
|
||||
class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
|
||||
public:
|
||||
using TP_MOE_Common<AMX_AWQ_MOE_TP<K>>::TP_MOE_Common;
|
||||
void load_weights(const uint64_t* physical_to_logical_map) {
|
||||
void load_weights() {
|
||||
auto& config = this->config;
|
||||
auto& tps = this->tps;
|
||||
auto& tp_count = this->tp_count;
|
||||
auto pool = config.pool;
|
||||
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
|
||||
if (config.gate_projs.empty() == false) {
|
||||
printf("TP Load from loader\n");
|
||||
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
|
||||
this->tps[numa_id]->load_weights(physical_to_logical_map);
|
||||
});
|
||||
DO_TPS_LOAD_WEIGHTS(pool);
|
||||
this->weights_loaded = true;
|
||||
} else if (config.gate_scale != nullptr) {
|
||||
printf("From Packed Int4 with KGroup Scale and Zeros\n");
|
||||
@@ -1314,7 +1144,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&](int expert_id_) {
|
||||
size_t expert_id = expert_id_;
|
||||
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
|
||||
|
||||
// weight TP-slicing
|
||||
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
|
||||
@@ -1384,9 +1214,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
|
||||
}
|
||||
}
|
||||
|
||||
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
|
||||
this->tps[numa_id]->load_weights(physical_to_logical_map);
|
||||
});
|
||||
DO_TPS_LOAD_WEIGHTS(pool);
|
||||
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
@@ -1417,7 +1245,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&](int expert_id_) {
|
||||
size_t expert_id = physical_to_logical_map[expert_id_];
|
||||
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
|
||||
memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,
|
||||
(ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +
|
||||
i * gate_up_elcount,
|
||||
@@ -1438,9 +1266,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
|
||||
}
|
||||
}
|
||||
|
||||
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
|
||||
this->tps[numa_id]->load_weights(physical_to_logical_map);
|
||||
});
|
||||
DO_TPS_LOAD_WEIGHTS(pool);
|
||||
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
@@ -1452,9 +1278,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
|
||||
this->weights_loaded = true;
|
||||
} else if (config.path != "") {
|
||||
printf("TP Load from file\n");
|
||||
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
|
||||
this->tps[numa_id]->load_weights(physical_to_logical_map);
|
||||
});
|
||||
DO_TPS_LOAD_WEIGHTS(pool);
|
||||
this->weights_loaded = true;
|
||||
} else {
|
||||
throw std::runtime_error("no weight source");
|
||||
|
||||
@@ -7,27 +7,15 @@
|
||||
#include <tmmintrin.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
#include "llama.cpp/ggml-quants.h"
|
||||
#include "llamafile/sgemm.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
// Include the split AMX headers
|
||||
#include "amx_buffers.hpp"
|
||||
#include "amx_config.hpp"
|
||||
#include "amx_kernels.hpp"
|
||||
#include "amx_quantization.hpp"
|
||||
#include "amx_utils.hpp"
|
||||
|
||||
namespace amx {
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "amx_config.hpp"
|
||||
#include "amx_utils.hpp"
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
#include "pack.hpp"
|
||||
@@ -48,16 +46,41 @@ struct BufferAImpl {
|
||||
assert(ith == 0 && nth == 1);
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < k; j += 32) {
|
||||
__m512 f0, f1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j), &f0, &f1);
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));
|
||||
__m512 amax_v0 = _mm512_setzero_ps();
|
||||
__m512 amax_v1 = _mm512_setzero_ps();
|
||||
__m512 amax_v2 = _mm512_setzero_ps();
|
||||
__m512 amax_v3 = _mm512_setzero_ps();
|
||||
__m512 amax_v4 = _mm512_setzero_ps();
|
||||
__m512 amax_v5 = _mm512_setzero_ps();
|
||||
__m512 amax_v6 = _mm512_setzero_ps();
|
||||
__m512 amax_v7 = _mm512_setzero_ps();
|
||||
for (int j = 0; j < k; j += 128) {
|
||||
__m512 f0, f1, f2, f3, f4, f5, f6, f7;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 0), &f0, &f1);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 32), &f2, &f3);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 64), &f4, &f5);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 96), &f6, &f7);
|
||||
amax_v0 = vector_abs_max(amax_v0, f0);
|
||||
amax_v1 = vector_abs_max(amax_v1, f1);
|
||||
amax_v2 = vector_abs_max(amax_v2, f2);
|
||||
amax_v3 = vector_abs_max(amax_v3, f3);
|
||||
amax_v4 = vector_abs_max(amax_v4, f4);
|
||||
amax_v5 = vector_abs_max(amax_v5, f5);
|
||||
amax_v6 = vector_abs_max(amax_v6, f6);
|
||||
amax_v7 = vector_abs_max(amax_v7, f7);
|
||||
}
|
||||
amax_v0 = vector_abs_max(amax_v0, amax_v1);
|
||||
amax_v2 = vector_abs_max(amax_v2, amax_v3);
|
||||
amax_v4 = vector_abs_max(amax_v4, amax_v5);
|
||||
amax_v6 = vector_abs_max(amax_v6, amax_v7);
|
||||
amax_v0 = vector_abs_max(amax_v0, amax_v2);
|
||||
amax_v4 = vector_abs_max(amax_v4, amax_v6);
|
||||
amax_v0 = vector_abs_max(amax_v0, amax_v4);
|
||||
float amax = _mm512_reduce_max_ps(amax_v0);
|
||||
d[m_begin + i] = amax / ((1 << 7) - 1);
|
||||
}
|
||||
}
|
||||
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
|
||||
|
||||
@@ -1180,9 +1180,9 @@ struct GemmKernel224Int4 {
|
||||
}
|
||||
|
||||
static void load_a(dt* a, size_t lda) {
|
||||
#ifdef HAVE_AMX
|
||||
_tile_loadd(0, a, lda);
|
||||
_tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);
|
||||
#ifdef __AMX__
|
||||
_tile_stream_loadd(0, a, lda);
|
||||
_tile_stream_loadd(1, offset_pointer(a, lda * TILE_M), lda);
|
||||
#else
|
||||
(void)a;
|
||||
(void)lda;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#ifndef UTILS_HPP
|
||||
#define UTILS_HPP
|
||||
#include <immintrin.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
@@ -18,4 +20,13 @@ static inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512*
|
||||
_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(src) + 1)), 16)));
|
||||
}
|
||||
|
||||
static inline __m512 vector_abs_max(__m512 a, __m512 b) {
|
||||
__m512 a_abs = _mm512_abs_ps(a);
|
||||
__m512 b_abs = _mm512_abs_ps(b);
|
||||
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a_abs, b_abs, _CMP_GT_OS);
|
||||
|
||||
return _mm512_mask_blend_ps(mask, b_abs, a_abs);
|
||||
}
|
||||
|
||||
#endif // UTILS_HPP
|
||||
@@ -29,10 +29,7 @@
|
||||
#include "../../cpu_backend/worker_pool.h"
|
||||
#include "../moe-tp.hpp"
|
||||
#include "la/amx.hpp"
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
#include "llama.cpp/ggml-quants.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#include "llamafile/sgemm.h"
|
||||
|
||||
template <class T>
|
||||
class AMX_MOE_TP {
|
||||
@@ -264,16 +261,15 @@ class AMX_MOE_TP {
|
||||
~AMX_MOE_TP() {
|
||||
// shared_mem_buffer_numa.dealloc(this);
|
||||
}
|
||||
// pack and quant the weights
|
||||
void pack_weights() {}
|
||||
void load_weights(const uint64_t* physical_to_logical_map) {
|
||||
void load_weights() {
|
||||
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
|
||||
if (config_.gate_projs.size()) {
|
||||
pool->do_work_stealing_job(
|
||||
config_.expert_num, nullptr,
|
||||
[this, physical_to_logical_map](int expert_id) {
|
||||
// printf("Load layer %d [%d/%d]\n", config_.layer_idx, expert_id, config_.expert_num);
|
||||
uint64_t logical_expert_id = physical_to_logical_map[expert_id];
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id);
|
||||
{
|
||||
size_t scale_size = config_.intermediate_size * sizeof(float);
|
||||
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size;
|
||||
@@ -311,7 +307,7 @@ class AMX_MOE_TP {
|
||||
std::cout << "Loading from " << prefix << std::endl;
|
||||
for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {
|
||||
int64_t expert_idx = task_id / (mat_type_all * mat_split);
|
||||
uint64_t logical_expert_id = physical_to_logical_map[expert_idx];
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
uint8_t mat_class = (task_id % (mat_type_all * mat_split)) / mat_split;
|
||||
uint8_t mat_split_idex = task_id % mat_split;
|
||||
if (mat_class == 0) { // the up matrix
|
||||
@@ -345,30 +341,32 @@ class AMX_MOE_TP {
|
||||
}
|
||||
pool->do_work_stealing_job(
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth](int task_id) {
|
||||
[this, nth, physical_to_logical_map](int task_id) {
|
||||
int64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
// gate part
|
||||
gate_bb_[expert_idx]->from_mat(
|
||||
(ggml_bf16_t*)config_.gate_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith,
|
||||
nth);
|
||||
gate_bb_[logical_expert_id]->from_mat(
|
||||
(ggml_bf16_t*)config_.gate_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
|
||||
ith, nth);
|
||||
// up part
|
||||
up_bb_[expert_idx]->from_mat(
|
||||
(ggml_bf16_t*)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith,
|
||||
nth);
|
||||
up_bb_[logical_expert_id]->from_mat(
|
||||
(ggml_bf16_t*)config_.up_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,
|
||||
ith, nth);
|
||||
},
|
||||
nullptr);
|
||||
|
||||
nth = T::recommended_nth(config_.hidden_size);
|
||||
pool->do_work_stealing_job(
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth](int task_id) {
|
||||
[this, nth, physical_to_logical_map](int task_id) {
|
||||
int64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
// down part
|
||||
down_bb_[expert_idx]->from_mat(
|
||||
(ggml_bf16_t*)config_.down_proj + expert_idx * config_.hidden_size * config_.intermediate_size, ith,
|
||||
nth);
|
||||
down_bb_[logical_expert_id]->from_mat(
|
||||
(ggml_bf16_t*)config_.down_proj + logical_expert_id * config_.hidden_size * config_.intermediate_size,
|
||||
ith, nth);
|
||||
// printf("load down, expert %ld, ith %d, total nth %d\n", expert_idx, ith, nth);
|
||||
},
|
||||
nullptr);
|
||||
@@ -380,8 +378,9 @@ class AMX_MOE_TP {
|
||||
if (config_.save) {
|
||||
pool->do_work_stealing_job(
|
||||
config_.expert_num * mat_type_all, nullptr,
|
||||
[this](int task_id) {
|
||||
[this, physical_to_logical_map](int task_id) {
|
||||
int64_t expert_idx = task_id / mat_type_all;
|
||||
expert_idx = expert_map(physical_to_logical_map, expert_idx);
|
||||
uint8_t mat_class = task_id % mat_type_all;
|
||||
if (mat_class == 0) { // the up matrix
|
||||
size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);
|
||||
@@ -829,16 +828,16 @@ template <typename K>
|
||||
class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
|
||||
public:
|
||||
using TP_MOE_Common<AMX_MOE_TP<K>>::TP_MOE_Common;
|
||||
void load_weights(const uint64_t* physical_to_logical_map) {
|
||||
void load_weights() {
|
||||
auto& config = this->config;
|
||||
auto& tps = this->tps;
|
||||
auto& tp_count = this->tp_count;
|
||||
auto pool = config.pool;
|
||||
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
|
||||
if (config.gate_projs.empty() == false) {
|
||||
printf("TP Load from loader\n");
|
||||
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
|
||||
this->tps[numa_id]->load_weights(physical_to_logical_map);
|
||||
});
|
||||
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
|
||||
|
||||
this->weights_loaded = true;
|
||||
} else if (config.gate_proj != nullptr) {
|
||||
printf("From BF16\n");
|
||||
@@ -852,7 +851,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&](int expert_id_) {
|
||||
size_t expert_id = expert_id_;
|
||||
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
|
||||
memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,
|
||||
(ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +
|
||||
i * gate_up_elcount,
|
||||
@@ -873,9 +872,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
|
||||
}
|
||||
}
|
||||
|
||||
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
|
||||
this->tps[numa_id]->load_weights(physical_to_logical_map);
|
||||
});
|
||||
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
|
||||
|
||||
for (auto i = 0; i < tp_count; i++) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
@@ -887,9 +884,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
|
||||
this->weights_loaded = true;
|
||||
} else if (config.path != "") {
|
||||
printf("TP Load from file\n");
|
||||
pool->dispense_backend()->do_numa_job([this, pool, physical_to_logical_map](int numa_id) {
|
||||
this->tps[numa_id]->load_weights(physical_to_logical_map);
|
||||
});
|
||||
pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
|
||||
this->weights_loaded = true;
|
||||
} else {
|
||||
throw std::runtime_error("no weight source");
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#ifndef CPUINFER_OPERATOR_COMMON_HPP
|
||||
#define CPUINFER_OPERATOR_COMMON_HPP
|
||||
|
||||
#include "../cpu_backend/shared_mem_buffer.h"
|
||||
#include <map>
|
||||
|
||||
#include "../cpu_backend/worker_pool.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#if defined(__aarch64__) && defined(CPU_USE_KML)
|
||||
#include <arm_sve.h>
|
||||
@@ -13,8 +14,6 @@
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -40,6 +39,14 @@
|
||||
last = end_time; \
|
||||
} while (0)
|
||||
|
||||
#define DO_TPS_LOAD_WEIGHTS(pool) \
|
||||
(pool)->dispense_backend()->do_numa_job([this, pool, config](int numa_id) { \
|
||||
this->tps[numa_id]->config_.physical_to_logical_map = config.physical_to_logical_map; \
|
||||
this->tps[numa_id]->load_weights(); \
|
||||
})
|
||||
|
||||
#define expert_map(m, x) (m != nullptr ? m[(x)] : (x))
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) {
|
||||
return (x + y - 1) / y;
|
||||
@@ -274,12 +281,11 @@ struct GeneralMOEConfig {
|
||||
|
||||
GeneralMOEConfig() {}
|
||||
|
||||
GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int num_gpu_experts)
|
||||
GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size)
|
||||
: expert_num(expert_num),
|
||||
num_experts_per_tok(routed_expert_num),
|
||||
hidden_size(hidden_size),
|
||||
intermediate_size(intermediate_size),
|
||||
num_gpu_experts(num_gpu_experts) {}
|
||||
intermediate_size(intermediate_size) {}
|
||||
|
||||
int max_possible_qlen() { return std::max(max_len, group_max_len); }
|
||||
};
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
#include "batch_gemm_api.hpp"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void decode_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k,
|
||||
const float alpha, const void* a, const size_t lda, const BLASINT8 oa, const void* b,
|
||||
const size_t ldb, const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc) {
|
||||
BLASINT8* ptrA = (BLASINT8*)a;
|
||||
BLASINT8* ptrB = (BLASINT8*)b;
|
||||
int32_t* ptrC = c;
|
||||
size_t split_n = n / N_SIZE;
|
||||
|
||||
for (size_t n_block = 0; n_block < split_n; n_block++) {
|
||||
BLASINT8* cur_ptrA = ptrA;
|
||||
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * k);
|
||||
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
|
||||
gemm_kernel_1x8(cur_ptrA, cur_ptrB, cur_ptrC, ldc, k, COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
void decode_int4_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa,
|
||||
const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc) {
|
||||
BLASINT8* ptrA = (BLASINT8*)a;
|
||||
BLASINT8* ptrB = (BLASINT8*)b;
|
||||
int32_t* ptrC = c;
|
||||
size_t split_n = n / N_SIZE;
|
||||
|
||||
for (size_t n_block = 0; n_block < split_n; n_block++) {
|
||||
BLASINT8* cur_ptrA = ptrA;
|
||||
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * (k / 2));
|
||||
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
|
||||
gemm_kernel_1x8_int4(cur_ptrA, cur_ptrB, cur_ptrC, (ldc / 2), (k / 2), COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,35 +0,0 @@
|
||||
#ifndef _BATCH_GEMM_KERNEL_API_
|
||||
#define _BATCH_GEMM_KERNEL_API_
|
||||
|
||||
#include "utils.hpp"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void decode_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k,
|
||||
const float alpha, const void* a, const size_t lda, const BLASINT8 oa, const void* b,
|
||||
const size_t ldb, const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc);
|
||||
|
||||
void prefill_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa, const CBLAS_TRANSPOSE transb,
|
||||
const CBLAS_OFFSET offsetc, const size_t m, const size_t n, const size_t k,
|
||||
const float alpha, const void* a, const size_t lda, const BLASINT8 oa, const void* b,
|
||||
const size_t ldb, const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc);
|
||||
|
||||
void decode_int4_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa,
|
||||
const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void prefill_int4_cblas_gemm_s8s8s32(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE transa,
|
||||
const CBLAS_TRANSPOSE transb, const CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif /*** _BATCH_GEMM_KERNEL_API_ ***/
|
||||
@@ -1,104 +0,0 @@
|
||||
#pragma once
|
||||
// #include <arm_sve.h>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
// static inline void sve_32xbf16_to_32xfp32(const bfloat16_t *src, float *dst0, float *dst1) {
|
||||
// #ifdef __ARM_FEATURE_SVE
|
||||
// // 全真谓词,对应每个 16‑bit 元素
|
||||
// svbool_t pg_h = svptrue_b16();
|
||||
// // 每次循环处理 svcnth(pg_h) 个 BF16 元素,svcnth(pg_h)==VL/16
|
||||
// size_t offset = 0;
|
||||
// // 我们要生产两段 FP32 输出,每段长度 svcntw(pg_h)==VL/32
|
||||
// // SVE 向量寄存宽度 VL 可以任意 128–2048,但代码与之无关
|
||||
|
||||
// // Load first half BF16→FP32
|
||||
// svbfloat16_t vb0 = svld1(pg_h, &src[offset]); // load BF16
|
||||
// svfloat32_t vf0 = svcvt_f32_bf16_z(pg_h, vb0); // widen→FP32
|
||||
// svst1(pg_h, &dst0[offset/2], vf0); // store
|
||||
|
||||
// offset += svcnth(pg_h); // 移到第二批 BF16 元素
|
||||
|
||||
// // Load second half BF16→FP32
|
||||
// svbfloat16_t vb1 = svld1(pg_h, &src[offset]);
|
||||
// svfloat32_t vf1 = svcvt_f32_bf16_z(pg_h, vb1);
|
||||
// svst1(pg_h, &dst1[offset/2], vf1);
|
||||
// #else
|
||||
// // fallback: scalar or NEON
|
||||
// #endif
|
||||
// }
|
||||
|
||||
// 简单截断模式:直接丢弃低 16 位
|
||||
static inline uint16_t float_to_bf16_trunc(float f) {
|
||||
uint32_t u;
|
||||
// 按位拷贝,避免 strict‑aliasing UB
|
||||
memcpy(&u, &f, sizeof(u)); // :contentReference[oaicite:3]{index=3}
|
||||
return (uint16_t)(u >> 16); // 截断得到高 16 位 :contentReference[oaicite:4]{index=4}
|
||||
}
|
||||
|
||||
static inline void convert_32fp32_to_32bf16_pure_c(const float* src, uint16_t* dst) {
|
||||
// src 已偏移至 token_nth * hidden_size
|
||||
for (int e = 0; e < 32; e++) { // 共 32 个元素
|
||||
// 选择截断或四舍五入
|
||||
dst[e] = float_to_bf16_trunc(src[e]);
|
||||
}
|
||||
}
|
||||
|
||||
// 把 32 个 bf16 元素转换成 32 个 fp32 元素
|
||||
|
||||
static inline void convert_32bf16_to_32fp32_pure_c(const uint16_t* src, float* dst) {
|
||||
for (int e = 0; e < 32; e++) {
|
||||
uint32_t temp = ((uint32_t)src[e]) << 16; // 将 BF16 左移 16 位
|
||||
memcpy(&dst[e], &temp, sizeof(float)); // 将结果复制到 FP32 变量中
|
||||
}
|
||||
}
|
||||
|
||||
// template <typename T> T *offset_pointer(T *ptr, std::size_t byte_offset) {
|
||||
// return reinterpret_cast<T *>(reinterpret_cast<char *>(ptr) + byte_offset);
|
||||
// }
|
||||
|
||||
/*** gemm helper ***/
|
||||
#include <kblas.h>
|
||||
|
||||
#include <iostream>
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#define COMP_SV_LEN 32
|
||||
#define K_SIZE COMP_SV_LEN
|
||||
#define M_SIZE 1
|
||||
#define N_SIZE 8
|
||||
|
||||
#define INDEXING_B(row_idx, col_idx, ldb) ((col_idx) * (ldb) + row_idx)
|
||||
|
||||
#define PROCESS_ACCUM(reg_idx, z_reg_idx, tmp_reg, dst, p) \
|
||||
"mov w" #reg_idx \
|
||||
", #0\n" \
|
||||
"saddv d" #reg_idx ", " #p ", z" #z_reg_idx \
|
||||
".s\n" \
|
||||
"fmov " #tmp_reg ", d" #reg_idx \
|
||||
"\n" \
|
||||
"add x" #reg_idx ", x" #reg_idx ", " #tmp_reg \
|
||||
"\n" \
|
||||
"str w" #reg_idx ", [%[" #dst "]], #4\n"
|
||||
|
||||
#define INT4_CP_MASK_SHIFT_1x8(src_reg, dst_reg, mask_reg1, mask_reg2, shift) \
|
||||
"movprfx z" #dst_reg ", z" #src_reg \
|
||||
"\n" \
|
||||
"lsl z" #dst_reg ".b, p0/m, z" #dst_reg ".b, #" #shift \
|
||||
"\n" \
|
||||
"and z" #src_reg ".b, p0/m, z" #src_reg ".b, z" #mask_reg1 ".b\n"
|
||||
|
||||
void pack_b_1x8(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);
|
||||
void pack_b_1x8_int4(void* bufferB, const void* cur_b_ptr, size_t n, size_t k, size_t ldb, const BLASINT8 ob);
|
||||
|
||||
void gemm_kernel_1x8(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
void gemm_kernel_1x8_int4(const void* lhs_ptr, const void* rhs_ptr, int32_t* accum_ptr, size_t ldc, int64_t k_depth,
|
||||
int64_t sv_len);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
/*** gemm helper ***/
|
||||
@@ -1,430 +0,0 @@
|
||||
#ifndef KML_DEEPSEEKV3_HPP
|
||||
#define KML_DEEPSEEKV3_HPP
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
#include "gate.hpp"
|
||||
#include "kblas.h"
|
||||
#include "la/arm_kml.hpp"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#include "mla.hpp"
|
||||
#include "moe.hpp"
|
||||
|
||||
// #define DEBUG_LAYER_CORRECT
|
||||
|
||||
class DeepseekV3DecoderLayer
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
: protected TimePerf
|
||||
#endif
|
||||
{
|
||||
private:
|
||||
using A = float;
|
||||
|
||||
GeneralConfig config;
|
||||
|
||||
A* attn_norm;
|
||||
A* ffn_norm;
|
||||
|
||||
A* hidden_states;
|
||||
ggml_bf16_t* hidden_states_bf16;
|
||||
int64_t* expert_ids;
|
||||
A* experts_weights;
|
||||
|
||||
size_t layer_idx;
|
||||
|
||||
public:
|
||||
std::shared_ptr<MLA_Interface> self_attn = nullptr;
|
||||
std::shared_ptr<MoEGate> gate = nullptr;
|
||||
std::shared_ptr<MoE_Interface> ffn = nullptr;
|
||||
|
||||
using input_t = float;
|
||||
using output_t = float;
|
||||
|
||||
DeepseekV3DecoderLayer(GeneralConfig config, size_t layer_idx) : config(config), layer_idx(layer_idx) {
|
||||
init_ggml();
|
||||
MemoryRequest mem_requests;
|
||||
PUSH_MEM_REQ(hidden_states, sizeof(A) * config.max_qlen * config.hidden_size);
|
||||
PUSH_MEM_REQ(hidden_states_bf16, sizeof(ggml_bf16_t) * config.max_qlen * config.hidden_size);
|
||||
PUSH_MEM_REQ(expert_ids,
|
||||
sizeof(int64_t) * config.max_qlen * (config.num_experts_per_tok + config.n_shared_experts));
|
||||
PUSH_MEM_REQ(experts_weights, sizeof(A) * config.max_qlen * (config.num_experts_per_tok + config.n_shared_experts));
|
||||
|
||||
shared_mem_buffer_for_decoder_layer.alloc(this, mem_requests);
|
||||
}
|
||||
void load_norm_binding(intptr_t attn_norm_ptr, ggml_type attn_norm_type, intptr_t ffn_norm_ptr,
|
||||
ggml_type mlp_norm_type) {
|
||||
load_norm((void*)attn_norm_ptr, attn_norm_type, (void*)ffn_norm_ptr, mlp_norm_type);
|
||||
}
|
||||
|
||||
void load_norm(const void* attn_norm, ggml_type attn_norm_type, const void* ffn_norm, ggml_type ffn_norm_type) {
|
||||
this->attn_norm = new A[config.hidden_size];
|
||||
this->ffn_norm = new A[config.hidden_size];
|
||||
convert_or_copy(this->attn_norm, (void*)attn_norm, attn_norm_type, config.hidden_size);
|
||||
convert_or_copy(this->ffn_norm, (void*)ffn_norm, ffn_norm_type, config.hidden_size);
|
||||
}
|
||||
|
||||
void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
|
||||
const void* input, void* output) {
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
forward_perf_start();
|
||||
#endif
|
||||
|
||||
size_t seq_len = 0;
|
||||
for (size_t i = 0; i < qlens.size(); i++) {
|
||||
seq_len += qlens[i];
|
||||
}
|
||||
// for (size_t i = 0; i < 5; i++) {
|
||||
// debug_f32((input_t *)input + (seq_len - 5 + i) * config.hidden_size, config.hidden_size);
|
||||
// }
|
||||
// printf("\n");
|
||||
|
||||
#ifdef DEBUG_LAYER_CORRECT
|
||||
std::string prefix = "Layer_" + std::to_string(layer_idx);
|
||||
dump_bin(prefix + "_input", (input_t*)input, seq_len * config.hidden_size);
|
||||
#endif
|
||||
// Residue
|
||||
// printf("convert or copy hidden states, %ld,%ld\n", seq_len, config.hidden_size);
|
||||
convert_or_copy(hidden_states, (input_t*)input, seq_len * config.hidden_size);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("copy residue");
|
||||
#endif
|
||||
// Norm
|
||||
config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
|
||||
A* input_row = (A*)input + task_id * config.hidden_size;
|
||||
RMSNorm<A>::rms_norm_single_with_weights(config.hidden_size, attn_norm, input_row);
|
||||
});
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("before attn norm");
|
||||
#endif
|
||||
// self attention
|
||||
self_attn->forward(qlens, page_tables, kv_lens, input, output);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("self attn");
|
||||
#endif
|
||||
#ifdef DEBUG_LAYER_CORRECT
|
||||
dump_bin(prefix + "_after_attn", (input_t*)output, seq_len * config.hidden_size);
|
||||
#endif
|
||||
|
||||
// Add Residue
|
||||
config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
|
||||
A* hidden_state_row = hidden_states + task_id * config.hidden_size;
|
||||
A* output_row = (A*)output + task_id * config.hidden_size;
|
||||
A* input_row = (A*)input + task_id * config.hidden_size;
|
||||
for (size_t i = 0; i < config.hidden_size; i++) {
|
||||
hidden_state_row[i] += output_row[i];
|
||||
input_row[i] = hidden_state_row[i];
|
||||
}
|
||||
});
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("after attn add residue");
|
||||
#endif
|
||||
|
||||
// Norm
|
||||
config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
|
||||
A* input_row = (A*)input + task_id * config.hidden_size;
|
||||
RMSNorm<A>::rms_norm_single_with_weights(config.hidden_size, ffn_norm, input_row);
|
||||
});
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("after attn norm");
|
||||
#endif
|
||||
|
||||
// Moe
|
||||
// size_t experts_ld = config.num_experts_per_tok;
|
||||
size_t experts_ld = config.num_experts_per_tok + config.n_shared_experts;
|
||||
if (gate != nullptr) {
|
||||
gate->forward(seq_len, (input_t*)input, expert_ids, experts_weights, experts_ld);
|
||||
for (size_t i = 0; i < seq_len; i++) {
|
||||
for (size_t j = 0; j < config.n_shared_experts; j++) {
|
||||
expert_ids[i * (experts_ld) + config.num_experts_per_tok + j] = config.n_routed_experts + j;
|
||||
experts_weights[i * (experts_ld) + config.num_experts_per_tok + j] = 1.0f;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < seq_len; i++) {
|
||||
for (size_t j = 0; j < config.num_experts_per_tok + config.n_shared_experts; j++) {
|
||||
expert_ids[i * (experts_ld) + j] = j;
|
||||
experts_weights[i * (experts_ld) + j] = 1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Debug 打印选中的 expert
|
||||
// printf("chosen experts for layer %ld:\n", layer_idx);
|
||||
// for (int i = 0; i < seq_len; i++) {
|
||||
// for (int j = 0; j < experts_ld; j++) {
|
||||
// printf("%ld ", expert_ids[i * experts_ld + j]);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("moe gate");
|
||||
// // Debug 重置选择专家为固定的连续专家
|
||||
// for (int i = 0; i < seq_len; i++) {
|
||||
// for (int j = 0; j < experts_ld; j++) {
|
||||
// expert_ids[i * experts_ld + j] = j;
|
||||
// }
|
||||
// }
|
||||
// Debug 打印选中的 expert
|
||||
// printf("chosen experts for layer %ld:\n", layer_idx);
|
||||
// for (int i = 0; i < seq_len; i++) {
|
||||
// for (int j = 0; j < experts_ld; j++) {
|
||||
// printf("%ld ", expert_ids[i * experts_ld + j]);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
#endif
|
||||
#ifdef DEBUG_LAYER_CORRECT
|
||||
dump_bin(prefix + "_expert_ids", expert_ids, seq_len * (config.num_experts_per_tok + config.n_shared_experts));
|
||||
dump_bin(prefix + "_expert_weights", experts_weights,
|
||||
seq_len * (config.num_experts_per_tok + config.n_shared_experts));
|
||||
#endif
|
||||
convert_or_copy(hidden_states_bf16, (input_t*)input, seq_len * config.hidden_size);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("convert bf16 time");
|
||||
#endif
|
||||
ffn->forward(seq_len, experts_ld, expert_ids, experts_weights, hidden_states_bf16, output);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("ffn");
|
||||
#endif
|
||||
// Add Residue
|
||||
config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
|
||||
A* hidden_state_row = hidden_states + task_id * config.hidden_size;
|
||||
A* output_row = (A*)output + task_id * config.hidden_size;
|
||||
for (size_t i = 0; i < config.hidden_size; i++) {
|
||||
output_row[i] += hidden_state_row[i];
|
||||
}
|
||||
});
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("add ffn residue");
|
||||
#endif
|
||||
#ifdef DEBUG_LAYER_CORRECT
|
||||
dump_bin(prefix + "_after_mlp", (input_t*)output, seq_len * config.hidden_size);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
time_perf_name = "DeepseekV3DecoderLayer" + std::to_string(layer_idx);
|
||||
perf_report();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
class DeepseekV3Model
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
: protected TimePerf
|
||||
#endif
|
||||
{
|
||||
private:
|
||||
using A = float;
|
||||
GeneralConfig config;
|
||||
A* norm_weights;
|
||||
|
||||
public:
|
||||
using input_t = float;
|
||||
using output_t = float;
|
||||
std::vector<std::shared_ptr<DeepseekV3DecoderLayer>> layers;
|
||||
DeepseekV3Model(GeneralConfig config) : config(config) {
|
||||
init_ggml();
|
||||
norm_weights = new A[config.hidden_size];
|
||||
convert_or_copy(norm_weights, config.norm_weights_ptr, config.norm_weights_type, config.hidden_size);
|
||||
}
|
||||
|
||||
void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
|
||||
const void* input, void* output) {
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
forward_perf_start();
|
||||
#endif
|
||||
size_t seq_len = 0;
|
||||
for (size_t i = 0; i < qlens.size(); i++) {
|
||||
seq_len += qlens[i];
|
||||
}
|
||||
for (size_t i = 0; i < layers.size(); i++) {
|
||||
layers[i]->forward(qlens, page_tables, kv_lens, input, output);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP(std::string("layer ") + std::to_string(i));
|
||||
#endif
|
||||
if (i != layers.size() - 1) {
|
||||
convert_or_copy((A*)input, (A*)output, seq_len * config.hidden_size);
|
||||
} else {
|
||||
config.pool->do_work_stealing_job(seq_len, [&](int task_id) {
|
||||
A* output_row = (A*)output + task_id * config.hidden_size;
|
||||
RMSNorm<A>::rms_norm_single_with_weights(config.hidden_size, norm_weights, output_row);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
time_perf_name = "DeepseekV3Model";
|
||||
perf_report();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
class DeepseekV3ForCausalLM
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
: protected TimePerf
|
||||
#endif
|
||||
{
|
||||
private:
|
||||
using GemmKernel = arm_kml::GemmKernelInt4;
|
||||
using KMatRefA = typename arm_kml::MatRef<int8_t>;
|
||||
using KMatRefB = typename arm_kml::MatRef<GemmKernel::dt>;
|
||||
// using KMatRefAB = typename arm_kml::MatRef<int8_t>;
|
||||
using KMatRefC = typename arm_kml::MatRef<int32_t>;
|
||||
|
||||
using A = float;
|
||||
|
||||
GeneralConfig config;
|
||||
A* input_hidden_states;
|
||||
A* output_hidden_states;
|
||||
|
||||
A* lm_heads_ptr;
|
||||
GemmKernel::BufferA* lm_heads_ba;
|
||||
GemmKernel::BufferB* lm_heads_bb;
|
||||
GemmKernel::BufferC* lm_heads_bc;
|
||||
A* token_embd;
|
||||
|
||||
public:
|
||||
using KMatRef = typename arm_kml::MatRef<float>;
|
||||
using input_t = int64_t;
|
||||
using output_t = float;
|
||||
std::shared_ptr<DeepseekV3Model> model;
|
||||
KMatRefB lm_heads;
|
||||
|
||||
DeepseekV3ForCausalLM(GeneralConfig config) : config(config) {
|
||||
init_ggml();
|
||||
MemoryRequest mem_requests;
|
||||
lm_heads_ba = new GemmKernel::BufferA(config.max_qlen, config.hidden_size);
|
||||
lm_heads_bb = new GemmKernel::BufferB(config.vocab_size, config.hidden_size, true);
|
||||
lm_heads_bc = new GemmKernel::BufferC(config.max_qlen, config.vocab_size);
|
||||
|
||||
mem_requests.append_function([this](void* new_ptr) { lm_heads_ba->set_data(new_ptr); },
|
||||
lm_heads_ba->required_size());
|
||||
lm_heads_bb->set_data(std::aligned_alloc(64, lm_heads_bb->required_size()));
|
||||
mem_requests.append_function([this](void* new_ptr) { lm_heads_bc->set_data(new_ptr); },
|
||||
lm_heads_bc->required_size());
|
||||
shared_mem_buffer.alloc(this, mem_requests);
|
||||
input_hidden_states = new A[config.max_qlen * config.hidden_size];
|
||||
output_hidden_states = new A[config.max_qlen * config.hidden_size];
|
||||
lm_heads_ptr = new A[config.vocab_size * config.hidden_size];
|
||||
token_embd = new A[config.vocab_size * config.hidden_size];
|
||||
convert_or_copy(lm_heads_ptr, config.lm_heads_ptr, config.lm_heads_type, config.vocab_size * config.hidden_size);
|
||||
// 做量化
|
||||
auto pool = config.pool;
|
||||
{
|
||||
size_t nth_lm_b = GemmKernel::recommended_nth(config.vocab_size);
|
||||
|
||||
auto task = [&](int task_id) { lm_heads_bb->from_mat(lm_heads_ptr, task_id, nth_lm_b, -1, true); };
|
||||
pool->do_work_stealing_job(nth_lm_b, task);
|
||||
}
|
||||
lm_heads = KMatRefB(lm_heads_bb->b, config.hidden_size, config.vocab_size, config.hidden_size, CblasColMajor,
|
||||
CblasNoTrans, lm_heads_bb->d);
|
||||
// lm_heads = KMatRef(lm_heads_ptr, config.vocab_size, config.hidden_size, config.hidden_size, CblasRowMajor);
|
||||
convert_or_copy(token_embd, config.token_embd_ptr, config.token_embd_type, config.vocab_size * config.hidden_size);
|
||||
}
|
||||
|
||||
void forward_binding(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
|
||||
intptr_t input, intptr_t output) {
|
||||
forward(qlens, page_tables, kv_lens, (const void*)input, (void*)output);
|
||||
}
|
||||
|
||||
void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
|
||||
const void* input, void* output) {
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
forward_perf_start();
|
||||
#endif
|
||||
|
||||
{
|
||||
size_t qlen_sum = 0;
|
||||
for (size_t i = 0; i < qlens.size(); i++) {
|
||||
qlen_sum += qlens[i];
|
||||
}
|
||||
// printf("DeepseekV3 forward, seq_len %ld\n", qlen_sum);
|
||||
for (size_t i = 0; i < qlen_sum; i++) {
|
||||
convert_or_copy(input_hidden_states + i * config.hidden_size,
|
||||
token_embd + ((input_t*)input)[i] * config.hidden_size, config.hidden_size);
|
||||
}
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("token embd");
|
||||
#endif
|
||||
#ifdef DEBUG_LAYER_CORRECT
|
||||
dump_bin("input_ids", (input_t*)input, qlen_sum);
|
||||
#endif
|
||||
}
|
||||
|
||||
model->forward(qlens, page_tables, kv_lens, input_hidden_states, output_hidden_states);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("model forward");
|
||||
#endif
|
||||
// KMatRef hidden_states_ref =
|
||||
// KMatRef(output_hidden_states, config.hidden_size, config.max_qlen, config.hidden_size, CblasColMajor);
|
||||
|
||||
size_t qlen = 0;
|
||||
for (size_t i = 0; i < qlens.size(); i++) {
|
||||
qlen += qlens[i];
|
||||
}
|
||||
// printf("qlen is: %ld\n", qlen);
|
||||
// KMatRef logits_ref = KMatRef((A *)output, config.vocab_size, config.max_qlen, config.vocab_size, CblasColMajor);
|
||||
KMatRefC logits_ref =
|
||||
KMatRefC(lm_heads_bc->c, config.max_qlen, config.vocab_size, config.vocab_size, CblasRowMajor);
|
||||
KMatRef logits_out_ref = KMatRef((A*)output, config.max_qlen, config.vocab_size, config.vocab_size, CblasRowMajor);
|
||||
|
||||
// 量化输入
|
||||
auto pool = config.pool;
|
||||
{
|
||||
size_t mth = GemmKernel::recommended_mth(qlen);
|
||||
auto task_counter = TaskCounter({mth});
|
||||
auto task = [&](int task_id) {
|
||||
size_t mth_idx = task_counter.at(task_id, 0);
|
||||
lm_heads_ba->from_mat(qlen, output_hidden_states, mth_idx, mth);
|
||||
};
|
||||
DIRECT_OR_POOL_BY(qlen, 10, task_counter.count(), task);
|
||||
}
|
||||
KMatRefA hidden_states_ref = KMatRefA(lm_heads_ba->a, qlen, config.hidden_size, config.hidden_size, CblasRowMajor);
|
||||
|
||||
size_t qlen_sum = 0;
|
||||
for (size_t i = 0; i < qlens.size(); i++) {
|
||||
// auto h = hidden_states_ref.offset_block(0, qlen_sum + qlens[i] - 1, config.hidden_size, 1);
|
||||
auto h = hidden_states_ref.offset_block(qlen_sum + qlens[i] - 1, 0, 1, config.hidden_size);
|
||||
|
||||
{
|
||||
const size_t vocab_block = 256;
|
||||
const size_t vocab_block_count = div_up(config.vocab_size, vocab_block);
|
||||
config.pool->do_work_stealing_job(vocab_block_count, [&](int task_id) {
|
||||
size_t vocab_idx = task_id * vocab_block;
|
||||
size_t vocab_begin = vocab_idx;
|
||||
size_t vocab_end = std::min(vocab_begin + vocab_block, (size_t)config.vocab_size);
|
||||
KMatRefB lm_head_ref = lm_heads.offset_col(vocab_begin, vocab_end - vocab_begin);
|
||||
// KMatRef logits_ref_block = logits_ref.offset_block(vocab_begin, i, vocab_end - vocab_begin, 1);
|
||||
KMatRefC logits_ref_block = logits_ref.offset_block(i, vocab_begin, 1, vocab_end - vocab_begin);
|
||||
|
||||
// arm_kml::decode_mul_mat_clearc(lm_head_ref, h, logits_ref_block);
|
||||
// printf("h.ld: %ld, lm_head_ref.ld: %ld, logits_ref_block.ld: %ld\n", h.ld, lm_head_ref.ld,
|
||||
// logits_ref_block.ld);
|
||||
arm_kml::decode_mul_mat_clearc(h, lm_head_ref, logits_ref_block);
|
||||
GemmKernel::apply_scale(logits_out_ref.data, logits_out_ref.ld, lm_heads_ba, lm_heads_bb, lm_heads_bc,
|
||||
qlen_sum + qlens[i] - 1, qlen_sum + qlens[i], vocab_begin, vocab_end, true,
|
||||
i - (qlen_sum + qlens[i] - 1));
|
||||
});
|
||||
}
|
||||
|
||||
qlen_sum += qlens[i];
|
||||
}
|
||||
#ifdef DEBUG_LAYER_CORRECT
|
||||
dump_bin("output_logits", (output_t*)output, qlens.size() * config.vocab_size);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("lm heads out");
|
||||
#endif
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
|
||||
time_perf_name = "DeepseekV3ForCausalLM";
|
||||
perf_report();
|
||||
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,255 +0,0 @@
|
||||
#ifndef KML_GATE_HPP
|
||||
#define KML_GATE_HPP
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "../common.hpp"
|
||||
#include "kblas.h"
|
||||
#include "la/arm_kml.hpp"
|
||||
#include "llama.cpp/ggml-quants.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
// #define DEBUG_THIS_MOEGATE
|
||||
#ifdef DEBUG_THIS_MOEGATE
|
||||
#include "test/debug.hpp"
|
||||
#endif
|
||||
|
||||
// Mixture-of-Experts routing gate ("noaux_tc" method with sigmoid scoring):
// scores every routed expert per token, restricts the choice to the best
// expert groups, then emits the top-k expert indices and normalized routing
// weights. Only the configuration asserted in the constructor is supported.
class MoEGate
#ifdef FORWARD_TIME_PROFILE
    : protected TimePerf
#endif
{
 private:
  using A = float;
  GeneralGateConfig config;
  using KMatRef = typename arm_kml::MatRef<A>;

  A* bias;             // e_score_correction_bias; used only to *select* experts, not in output weights
  A* weight;           // gate projection weight (owned copy, converted to float)
  KMatRef weight_ref;  // [expert_num, hidden_size]

  A* logits;           // [tokens, expert_num]; overwritten in place with sigmoid probabilities
  A* score_to_choice;  // [tokens, expert_num]; biased scores used for selection only
  A* group_score;      // [tokens, group count]; per-group sum of its two best scores
  size_t* temp_idx;    // [tokens, expert_num]; scratch index buffer for the sorts below

  // NOTE(review): these two constants are never read in this class —
  // presumably leftover tiling knobs; confirm before removing.
  const size_t col_block = 256;
  const size_t row_block = 256;

 public:
  using input_t = float;
  using output_int_t = int64_t;
  using output_t = float;

  // Validates the (single) supported gate configuration, reserves scratch
  // space from the shared buffer, and copies weight + correction bias into
  // owned float arrays (dequantizing if needed).
  // NOTE(review): `weight`/`bias` are never freed (no destructor) — fine only
  // if gates live for the whole process; TODO confirm.
  explicit MoEGate(const GeneralGateConfig& cfg) : config(cfg) {
    ASSERT_RELEASE(config.weight, "cfg.weight must not be null");
    ASSERT_RELEASE(config.n_routed_experts % config.n_group == 0, "E must be divisible by G");
    ASSERT_RELEASE(config.scoring_func == "sigmoid", "Only sigmoid scoring function is supported");
    ASSERT_RELEASE(config.topk_method == "noaux_tc", "Only noaux_tc topk method is supported");
    ASSERT_RELEASE(config.norm_topk_prob, "must normalize topk prob");

    MemoryRequest mem_requests;
    // PUSH_MEM_REQ(input, sizeof(float) * config.max_seqlen * config.hidden_size);
    PUSH_MEM_REQ(logits, sizeof(float) * config.max_seqlen * config.n_routed_experts);
    PUSH_MEM_REQ(score_to_choice, sizeof(float) * config.max_seqlen * config.n_routed_experts);
    PUSH_MEM_REQ(group_score, sizeof(float) * config.max_seqlen * config.n_group);
    PUSH_MEM_REQ(temp_idx, sizeof(size_t) * config.max_seqlen * config.n_routed_experts);

    shared_mem_buffer.alloc(this, mem_requests);

    weight = new A[config.n_routed_experts * config.hidden_size];
    bias = new A[config.n_routed_experts];

    convert_or_copy(weight, config.weight, config.weight_type, config.n_routed_experts * config.hidden_size);
    convert_or_copy(bias, config.e_score_correction_bias, config.e_score_correction_bias_type, config.n_routed_experts);
    weight_ref = KMatRef(weight, config.n_routed_experts, config.hidden_size, config.hidden_size, CblasRowMajor);
  }

  // FFI entry point: reinterprets integer handles as typed pointers and
  // forwards to the typed overload.
  void forward_binding(size_t seq_len, intptr_t input_hidden_states_raw, intptr_t output_topk_idx_raw,
                       intptr_t output_topk_weight_raw) {
    forward(seq_len, (input_t*)input_hidden_states_raw, (output_int_t*)output_topk_idx_raw,
            (output_t*)output_topk_weight_raw);
  }

  // Convenience overload: output rows are packed with stride num_experts_per_tok.
  void forward(size_t seq_len, input_t* input_hidden_states, output_int_t* output_topk_idx,
               output_t* output_topk_weight_raw) {
    forward(seq_len, input_hidden_states, output_topk_idx, output_topk_weight_raw, config.num_experts_per_tok);
  }

  // forward: hidden_states [B,L,H] → (topk_idx, topk_weight) each [B·L,K]
  // ‑ hidden_states must be contiguous row‑major with H fastest.
  // ‑ outputs are written row by row with stride output_ld.
  // Steps per call:
  //   1) logits = gate_weight · hidden_states, tiled over experts × tokens;
  //   2) sigmoid(logits) in place → routing probabilities;
  //   3) selection score = probability + correction bias (bias affects the
  //      *choice* only, never the emitted weight);
  //   4) group-limited mask: rank groups by the sum of their two best scores,
  //      keep topk_group groups, set every other expert's score to -inf;
  //   5) argsort scores, take num_experts_per_tok experts, normalize their
  //      unbiased probabilities and scale by routed_scaling_factor.
  void forward(size_t seq_len, input_t* input_hidden_states, output_int_t* output_topk_idx,
               output_t* output_topk_weight, size_t output_ld) {
#ifdef FORWARD_TIME_PROFILE
    forward_perf_start();
#endif

    // View the input as [hidden_size × seq_len] column-major (one token per column).
    KMatRef input_ref =
        KMatRef((input_t*)input_hidden_states, config.hidden_size, seq_len, config.hidden_size, CblasColMajor);

#ifdef DEBUG_THIS_MOEGATE
    dump_bin("gate_input", input_ref.data, seq_len * config.hidden_size);
#endif
    KMatRef logits_ref = KMatRef(logits, config.n_routed_experts, seq_len, config.n_routed_experts, CblasColMajor);
    {
      // Tiled gate GEMM. Decode (seq_len == 1) uses a small expert block and
      // the dedicated single-token kernel.
      size_t n_routed_experts_block = (seq_len == 1) ? 2 : 64;
      const size_t seq_len_block = 64;
      const size_t n_routed_experts_block_count = div_up(config.n_routed_experts, n_routed_experts_block);
      const size_t seq_len_block_count = div_up(seq_len, seq_len_block);
      auto task_counter = TaskCounter({n_routed_experts_block_count, seq_len_block_count});
      auto task = [&](int task_id) {
        size_t n_routed_experts_block_idx = task_counter.at(task_id, 0);
        size_t seq_len_block_idx = task_counter.at(task_id, 1);
        size_t n_routed_experts_begin = n_routed_experts_block_idx * n_routed_experts_block;
        size_t n_routed_experts_end =
            std::min(n_routed_experts_begin + n_routed_experts_block, config.n_routed_experts);
        size_t seq_len_begin = seq_len_block_idx * seq_len_block;
        size_t seq_len_end = std::min(seq_len_begin + seq_len_block, seq_len);
        if (seq_len == 1) {
          arm_kml::decode_mul_mat_clearc(
              weight_ref.offset_block(n_routed_experts_begin, 0, n_routed_experts_end - n_routed_experts_begin,
                                      config.hidden_size),
              input_ref.offset_block(0, seq_len_begin, config.hidden_size, seq_len_end - seq_len_begin),
              logits_ref.offset_block(n_routed_experts_begin, seq_len_begin,
                                      n_routed_experts_end - n_routed_experts_begin, seq_len_end - seq_len_begin));
        } else {
          arm_kml::mul_mat_clearc(
              weight_ref.offset_block(n_routed_experts_begin, 0, n_routed_experts_end - n_routed_experts_begin,
                                      config.hidden_size),
              input_ref.offset_block(0, seq_len_begin, config.hidden_size, seq_len_end - seq_len_begin),
              logits_ref.offset_block(n_routed_experts_begin, seq_len_begin,
                                      n_routed_experts_end - n_routed_experts_begin, seq_len_end - seq_len_begin));
        }
      };
      config.pool->do_work_stealing_job(task_counter.count(), task);
    }
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("moe gate logits");
#endif

#ifdef DEBUG_THIS_MOEGATE
    dump_bin("gate_logits", logits_ref.data, seq_len * config.n_routed_experts);
    dump_bin("bias", bias, config.n_routed_experts);
#endif
    // In-place sigmoid: from here on, `logits` holds probabilities.
    for (size_t i = 0; i < seq_len; i++) {
      float* logits_row = logits + i * config.n_routed_experts;
      for (size_t j = 0; j < config.n_routed_experts; j++) {
        logits_row[j] = 1.f / (1.f + std::exp(-logits_row[j]));
      }
    }

    // Sum of the two largest values in data[begin, end).
    auto top2 = [](float* data, size_t begin, size_t end) {
      if (end - begin < 2) {
        throw std::invalid_argument("top2 requires at least two elements");
      }

      float first = -std::numeric_limits<float>::infinity();
      float second = -std::numeric_limits<float>::infinity();
      for (size_t i = begin; i < end; ++i) {
        float v = data[i];
        if (v > first) {
          second = first;
          first = v;
        } else if (v > second) {
          second = v;
        }
      }
      return first + second;
    };

    // Selection scores: probability plus per-expert correction bias.
    for (size_t i = 0; i < seq_len; i++) {
      float* logits_row = logits + i * config.n_routed_experts;
      float* scores_row = score_to_choice + i * config.n_routed_experts;
      for (size_t j = 0; j < config.n_routed_experts; j++) {
        scores_row[j] = logits_row[j] + bias[j];
      }
    }

#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("moe gate score to choice");
#endif
#ifdef DEBUG_THIS_MOEGATE
    dump_bin("scores_to_choice", score_to_choice, seq_len * config.n_routed_experts);
#endif
    for (size_t i = 0; i < seq_len; i++) {
      output_int_t* output_topk_idx_row = output_topk_idx + i * output_ld;
      output_t* output_topk_weight_row = output_topk_weight + i * output_ld;
      float* logits_row = logits + i * config.n_routed_experts;
      float* scores_row = score_to_choice + i * config.n_routed_experts;
      float* group_score_row = group_score + i * config.n_group;
      size_t* temp_idx_row = temp_idx + i * config.n_routed_experts;

      // Score each contiguous group of experts by the sum of its two best
      // selection scores, then rank groups descending.
      size_t group_size = config.n_routed_experts / config.n_group;
      for (size_t g = 0; g < config.n_group; g++) {
        size_t group_begin = g * group_size;
        size_t group_end = group_begin + group_size;
        group_score_row[g] = top2(scores_row, group_begin, group_end);
        temp_idx_row[g] = g;
      }
      std::sort(temp_idx_row, temp_idx_row + config.n_group,
                [&](auto& a, auto& b) { return group_score_row[a] > group_score_row[b]; });

      // Mask every expert outside the topk_group best groups.
      for (size_t g = config.topk_group; g < config.n_group; g++) {
        size_t group_begin = temp_idx_row[g] * group_size;
        size_t group_end = group_begin + group_size;
        for (size_t j = group_begin; j < group_end; j++) {
          scores_row[j] = -std::numeric_limits<float>::infinity();
        }
      }

      // Full argsort of the masked scores; the first K entries are the winners.
      for (int j = 0; j < config.n_routed_experts; j++) {
        temp_idx_row[j] = j;
      }
      std::sort(temp_idx_row, temp_idx_row + config.n_routed_experts,
                [&](auto& a, auto& b) { return scores_row[a] > scores_row[b]; });

      // Output weights come from the *unbiased* sigmoid probabilities,
      // renormalized (epsilon guards a zero sum) and scaled.
      float sum = 1e-20f;
      for (int j = 0; j < config.num_experts_per_tok; j++) {
        output_topk_idx_row[j] = temp_idx_row[j];
        output_topk_weight_row[j] = logits_row[temp_idx_row[j]];
        sum += output_topk_weight_row[j];
      }
      for (int j = 0; j < config.num_experts_per_tok; j++) {
        output_topk_weight_row[j] /= sum;
        output_topk_weight_row[j] *= config.routed_scaling_factor;
      }
    }
#ifdef DEBUG_THIS_MOEGATE
    dump_bin("group_scores", group_score, seq_len * config.n_group);
    dump_bin("gate_logits_toped", score_to_choice, seq_len * config.n_routed_experts);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("gate_logits_toped");
#endif
#ifdef FORWARD_TIME_PROFILE
    time_perf_name = "moe gate inner";
    perf_report();
#endif
  }
};
|
||||
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,487 +0,0 @@
|
||||
#include "mla_quan.h"
|
||||
|
||||
#include "../mla-tp.hpp"
|
||||
#include "../reduce.hpp"
|
||||
#include "../rms-norm.hpp"
|
||||
#include "../rope.hpp"
|
||||
#include "../softmax.hpp"
|
||||
#include "ggml-quants.h"
|
||||
#include "ggml.h"
|
||||
#include "kblas.h"
|
||||
#include "la/arm_kml.hpp"
|
||||
|
||||
// #define DEBUG_THIS_MLA
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
#include "test/debug.hpp"
|
||||
#endif
|
||||
#include <arm_sve.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// #define DEBUG_THIS_MLA
|
||||
// #define FORWARD_TIME_PROFILE
|
||||
|
||||
// Run fn(i) for i in [0, var): inline on the calling thread when `what` is
// below `threshold` (dispatch overhead not worth it), otherwise fan out
// through the work-stealing thread pool.
// NOTE: expects a variable named `pool` to be in scope at the expansion site.
#define DIRECT_OR_POOL_BY(what, threshold, var, fn) \
  do { \
    if ((what) < (threshold)) { \
      for (int i = 0; i < (var); i++) { \
        (fn)(i); \
      } \
    } else { \
      pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
    } \
  } while (0)
|
||||
|
||||
// inline void debug_output(const float16_t *data) {
|
||||
// for (int i = 0; i < 10; i++) {
|
||||
// printf("%f ", data[i]);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
// inline void debug_output(const float *data) {
|
||||
// for (int i = 0; i < 10; i++) {
|
||||
// printf("%f ", data[i]);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// inline void debug_output(const bfloat16_t *data) {
|
||||
// for (int i = 0; i < 10; i++) {
|
||||
// float x = 0;
|
||||
// *(bfloat16_t *)(&x) = data[i];
|
||||
// printf("%f ", x);
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
// template <typename T> inline void dump_to(std::string file_name, T *data, size_t count) {
|
||||
// std::ofstream f(file_name);
|
||||
// for (int i = 0; i < count; i++) {
|
||||
// f << data[i] << " ";
|
||||
// }
|
||||
// f.close();
|
||||
// }
|
||||
// Constructor for one tensor-parallel shard of the quantized MLA attention
// module. Registers every scratch/activation buffer with a NUMA-aware shared
// memory arena (allocation callbacks rebuild the matrix views once the arena
// hands out the real pointer), allocates the quantized GEMM weight buffers
// directly with aligned_alloc, and precomputes the RoPE angle table.
template <typename A, class KERNEL>
KML_MLA_TP_QUAN_TEST<A, KERNEL>::KML_MLA_TP_QUAN_TEST(GeneralMLAConfig config, int tp_part_idx)
    : config(config), tp_part_idx(tp_part_idx) {
  init_ggml();
  // Width of a compressed-cache row: latent KV plus the rotary part.
  cc_size = config.kv_lora_rank + config.rope_size;
  MemoryRequest mem_requests;
  // Attention is over the concatenated (nope + rope) head dimension.
  softmax_scale = 1.0 / sqrt(config.rope_size + config.nope_size);
  PUSH_MEM_REQ(q_lora_rank,
               sizeof(std::remove_pointer_t<decltype(q_lora_rank)>) * config.q_lora_rank * config.max_qlen);
  // One quantized decode-output buffer per group of sub_num_heads_decode heads.
  // NOTE(review): sizeof(remove_pointer_t<decltype(qlen_quant_output)>) is the
  // sizeof of the *vector* type, not its element — looks unintended; confirm
  // the required byte count against the consumer.
  qlen_quant_output.resize((config.num_heads + sub_num_heads_decode - 1) / sub_num_heads_decode, nullptr);
  for (int i = 0; i < (config.num_heads + sub_num_heads_decode - 1) / sub_num_heads_decode; i++) {
    mem_requests.append_pointer(&qlen_quant_output[i], sizeof(std::remove_pointer_t<decltype(qlen_quant_output)>) *
                                                           config.hidden_size * config.max_qlen * sub_num_heads_decode);
  }

  // q_nope staging buffer: [nope_size × (num_heads · max_qlen)].
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        q_nope = (A*)new_ptr;
        q_nope_tmp_ref =
            KMatRef(q_nope, config.nope_size, config.num_heads * config.max_qlen, config.nope_size, CblasColMajor);
      },
      sizeof(std::remove_pointer_t<decltype(q_nope)>) * config.num_heads * config.nope_size * config.max_qlen);

  // Query buffer is sized for both the absorbed (kv_lora_rank) and
  // non-absorbed (nope_size) layouts; the refs below are overlapping row
  // slices of the same storage.
  q_ld = std::max(config.kv_lora_rank, config.nope_size) + config.rope_size;
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        q = (A*)new_ptr;
        q_ref = KMatRef(q, q_ld, config.num_heads * config.max_qlen, q_ld, CblasColMajor);
        q_pe_absorb_ref = q_ref.offset_row(config.kv_lora_rank, config.rope_size);
        q_pe_noabsorb_ref = q_ref.offset_row(config.nope_size, config.rope_size);
        q_kv_lora_rank_ref = q_ref.offset_row(0, config.kv_lora_rank);
        q_nope_ref = q_ref.offset_row(0, config.nope_size);
        q_attn_absorb_ref = q_ref.offset_row(0, config.kv_lora_rank + config.rope_size);
        q_attn_noabsorb_ref = q_ref.offset_row(0, config.nope_size + config.rope_size);
      },
      sizeof(std::remove_pointer_t<decltype(q)>) * config.num_heads * config.max_qlen * q_ld);
  // Attention score matrix: [max_kvlen × (max_qlen · num_heads)].
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        attention_weights = (A*)new_ptr;
        attention_weights_ref = KMatRef(attention_weights, config.max_kvlen, config.max_qlen * config.num_heads,
                                        config.max_kvlen, CblasColMajor);
      },
      sizeof(std::remove_pointer_t<decltype(attention_weights)>) * config.max_kvlen * config.max_qlen *
          config.num_heads);

  // Per-head attention output, concatenated along rows.
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        attention_output = (A*)new_ptr;
        attention_output_ref = KMatRef(attention_output, config.nope_size * config.num_heads, config.max_qlen,
                                       config.nope_size * config.num_heads, CblasColMajor);
      },
      (sizeof(std::remove_pointer_t<decltype(attention_output)>) * config.num_heads * config.nope_size *
       config.max_qlen));

  // Decompressed key buffer (nope ‖ rope rows) for the non-absorbed path.
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        k = (A*)new_ptr;
        k_ref = KMatRef(k, config.nope_size + config.rope_size, config.max_kvlen * config.num_heads,
                        config.nope_size + config.rope_size, CblasColMajor);
        k_nope_ref = k_ref.offset_row(0, config.nope_size);
        k_rope_ref = k_ref.offset_row(config.nope_size, config.rope_size);
      },
      sizeof(std::remove_pointer_t<decltype(k)>) * config.num_heads * (config.nope_size + config.rope_size) *
          config.max_kvlen);
  // One buffer shared by two mutually exclusive uses: absorbed output
  // (kv_lora_rank × max_qlen) or decompressed V (nope_size × max_kvlen);
  // size it for the larger of the two.
  size_t o_absorb_or_v_size = std::max(config.kv_lora_rank * config.max_qlen, config.nope_size * config.max_kvlen);
  mem_requests.append_function(
      [this](void* new_ptr) {
        auto& config = this->config;
        o_absorb_or_v = (A*)new_ptr;
        o_absorb_ref = KMatRef(o_absorb_or_v, config.kv_lora_rank, config.max_qlen * config.num_heads,
                               config.kv_lora_rank, CblasColMajor);
        v_ref = KMatRef(o_absorb_or_v, config.nope_size, config.max_kvlen * config.num_heads, config.nope_size,
                        CblasColMajor);
      },
      sizeof(std::remove_pointer_t<decltype(o_absorb_or_v)>) * o_absorb_or_v_size * config.num_heads);

  // Precompute rotary-embedding angles (YaRN-style scaling parameters).
  rope_angle = std::make_unique<T_RopeAngle>(
      config.rope_size, config.max_position_embeddings, config.rope_theta, config.rope_scaling_factor,
      config.rope_scaling_original_max_position_embeddings, config.rope_scaling_beta_fast,
      config.rope_scaling_beta_slow, config.rope_scaling_mscale, config.rope_scaling_mscale_all_dim);
  rope_angle->init(config.max_kvlen);

  // q_a projection: activation buffer A comes from the shared arena, packed
  // weight buffer B is owned (aligned_alloc), output C is arena-backed.
  local_q_a_proj_quant_ba = new typename GemmKernel::BufferA(config.max_qlen, config.hidden_size);
  local_q_a_proj_quant_bb = new typename GemmKernel::BufferB(config.q_lora_rank, config.hidden_size, true);
  local_q_a_proj_quant_bc = new typename GemmKernel::BufferC(config.max_qlen, config.q_lora_rank, true);  // row major

  mem_requests.append_function([this](void* new_ptr) { local_q_a_proj_quant_ba->set_data(new_ptr); },
                               local_q_a_proj_quant_ba->required_size());
  local_q_a_proj_quant_bb->set_data(std::aligned_alloc(64, local_q_a_proj_quant_bb->required_size()));
  mem_requests.append_function([this](void* new_ptr) { local_q_a_proj_quant_bc->set_data(new_ptr); },
                               local_q_a_proj_quant_bc->required_size());

  // kv_a projection (with MQA rope rows): one output C buffer per KV-cache
  // page so the projection can write straight into paged cache storage.
  local_kv_a_proj_with_mqa_quant_ba = new typename GemmKernel::BufferA(config.max_qlen, config.hidden_size);
  local_kv_a_proj_with_mqa_quant_bb = new typename GemmKernel::BufferB(cc_size, config.hidden_size, true);
  for (int i = 0; i < config.page_count; i++) {
    local_kv_a_proj_with_mqa_quant_bc.push_back(
        new typename GemmKernel::BufferC(config.token_count_in_page, cc_size, true));  // row major
  }

  mem_requests.append_function([this](void* new_ptr) { local_kv_a_proj_with_mqa_quant_ba->set_data(new_ptr); },
                               local_kv_a_proj_with_mqa_quant_ba->required_size());
  local_kv_a_proj_with_mqa_quant_bb->set_data(
      std::aligned_alloc(64, local_kv_a_proj_with_mqa_quant_bb->required_size()));

  cc_page_refs_buffer.resize(config.page_count);
  kv_lora_page_refs_buffer.resize(config.page_count);
  rope_page_refs_buffer.resize(config.page_count);

  cc_page_refs_decode_buffer.resize(config.page_count);
  kv_lora_page_refs_decode_buffer.resize(config.page_count);
  rope_page_refs_decode_buffer.resize(config.page_count);
  for (int i = 0; i < config.page_count; i++) {
    // Build both orientations of each page view: column-major [cc_size ×
    // tokens] for prefill and row-major [tokens × cc_size] for decode, over
    // the same page storage. `config` is captured by value here.
    mem_requests.append_function(
        [this, i, config](void* new_ptr) {
          local_kv_a_proj_with_mqa_quant_bc[i]->set_data(new_ptr);
          cc_page_refs_buffer[i] = KMatRefC(local_kv_a_proj_with_mqa_quant_bc[i]->c, cc_size,
                                            config.token_count_in_page, cc_size, CblasColMajor);
          kv_lora_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(0, config.kv_lora_rank);
          rope_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(config.kv_lora_rank, config.rope_size);

          cc_page_refs_decode_buffer[i] = KMatRefC(local_kv_a_proj_with_mqa_quant_bc[i]->c, config.token_count_in_page,
                                                   cc_size, cc_size, CblasRowMajor);
          kv_lora_page_refs_decode_buffer[i] = cc_page_refs_decode_buffer[i].offset_col(0, config.kv_lora_rank);
          rope_page_refs_decode_buffer[i] =
              cc_page_refs_decode_buffer[i].offset_col(config.kv_lora_rank, config.rope_size);
        },
        local_kv_a_proj_with_mqa_quant_bc[i]->required_size());
  }

  // Output projection for decode: one (B, C) buffer pair per head group.
  for (int i = 0; i < div_up(config.num_heads, sub_num_heads_decode); i++) {
    local_w_o_decode_bc.push_back(new typename GemmKernel::BufferC(config.hidden_size, config.max_qlen));  // col major
    local_w_o_decode_bb.push_back(
        new typename GemmKernel::BufferB(config.hidden_size, sub_num_heads_decode * config.nope_size, true));
  }

  // Output projection for prefill: single full-width GEMM.
  local_w_o_quant_ba = new typename GemmKernel::BufferA(config.max_qlen, config.num_heads * config.nope_size);
  local_w_o_quant_bb = new typename GemmKernel::BufferB(config.hidden_size, config.num_heads * config.nope_size, true);
  local_w_o_prefill_bc = new typename GemmKernel::BufferC(config.max_qlen, config.hidden_size, true);  // row major

  mem_requests.append_function([this](void* new_ptr) { local_w_o_quant_ba->set_data(new_ptr); },
                               local_w_o_quant_ba->required_size());
  local_w_o_quant_bb->set_data(std::aligned_alloc(64, local_w_o_quant_bb->required_size()));
  mem_requests.append_function([this](void* new_ptr) { local_w_o_prefill_bc->set_data(new_ptr); },
                               local_w_o_prefill_bc->required_size());
  for (int i = 0; i < div_up(config.num_heads, sub_num_heads_decode); i++) {
    mem_requests.append_function([this, i](void* new_ptr) { local_w_o_decode_bc[i]->set_data(new_ptr); },
                                 local_w_o_decode_bc[i]->required_size());
    local_w_o_decode_bb[i]->set_data(std::aligned_alloc(64, local_w_o_decode_bb[i]->required_size()));
  }

  // Commit all requests; callbacks above fire with the final pointers for
  // this TP shard's NUMA node.
  shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
|
||||
|
||||
// Loads this TP shard's slice of the MLA weights:
//  - converts/dequantizes each projection into a float staging copy,
//  - packs/quantizes the staging copies into the GEMM kernel's BufferB
//    layouts (parallelized over the shard's sub-pool),
//  - builds the matrix refs the forward pass uses, then frees the staging
//    copies that are no longer needed.
// `complete_num_heads` is the unsharded head count (stride inside the full
// o_proj tensor); `offset` is this shard's first head index.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::load_weights(int complete_num_heads, int offset) {
  // fp16 activations require fp16 source weights; fp32 accepts any type
  // convert_or_copy can handle.
  if constexpr (std::is_same_v<A, float16_t>) {
    ASSERT_RELEASE(config.q_a_proj_type == GGML_TYPE_F16, "q_a_proj_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.q_b_proj_type == GGML_TYPE_F16, "q_b_proj_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.kv_a_proj_with_mqa_type == GGML_TYPE_F16, "kv_a_proj_with_mqa_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.kv_b_proj_type == GGML_TYPE_F16, "kv_b_proj_type must be GGML_TYPE_F16");
    ASSERT_RELEASE(config.w_o_type == GGML_TYPE_F16, "w_o_type must be GGML_TYPE_F16");
  } else if constexpr (std::is_same_v<A, float>) {
  } else {
    throw std::runtime_error("Unsupported type for KML_MLA_TP");
  }

  // Precompute one causal mask row per query position: positions > i get -inf.
  // NOTE(review): these rows are never freed here — presumed process-lifetime.
  default_attention_masks.resize(config.max_kvlen);
  for (int i = 0; i < config.max_kvlen; i++) {
    A* mask = new A[config.max_kvlen];

    memset(mask, 0, config.max_kvlen * sizeof(A));
    for (int j = i + 1; j < config.max_kvlen; j++) {
      mask[j] = -std::numeric_limits<float>::infinity();
    }
    default_attention_masks[i] = mask;
  }

  // q_a projection (full, not head-sharded).
  local_q_a_proj = new A[config.hidden_size * config.q_lora_rank];
  convert_or_copy(local_q_a_proj, config.q_a_proj, (ggml_type)config.q_a_proj_type,
                  config.hidden_size * config.q_lora_rank);

  // q_a RMS-norm weight; identity (all ones) when the model provides none.
  local_q_a_norm = new A[config.q_lora_rank];
  if (config.q_a_norm == nullptr) {
    for (size_t i = 0; i < config.q_lora_rank; i++) {
      local_q_a_norm[i] = 1;
    }
  } else {
    convert_or_copy(local_q_a_norm, config.q_a_norm, (ggml_type)config.q_a_norm_type, config.q_lora_rank);
  }

  local_kv_a_proj_with_mqa = new A[config.hidden_size * (config.kv_lora_rank + config.rope_size)];

  convert_or_copy(local_kv_a_proj_with_mqa, config.kv_a_proj_with_mqa, (ggml_type)config.kv_a_proj_with_mqa_type,
                  config.hidden_size * (config.kv_lora_rank + config.rope_size));

  // kv_a RMS-norm weight; identity when absent.
  local_kv_a_norm = new A[config.kv_lora_rank];
  if (config.kv_a_norm == nullptr) {
    for (size_t i = 0; i < config.kv_lora_rank; i++) {
      local_kv_a_norm[i] = 1;
    }
  } else {
    convert_or_copy(local_kv_a_norm, config.kv_a_norm, (ggml_type)config.kv_a_norm_type, config.kv_lora_rank);
  }

  // q_b projection: take this shard's heads starting at `offset`.
  local_q_b_proj = new A[config.num_heads * (config.nope_size + config.rope_size) * config.q_lora_rank];

  convert_or_copy(local_q_b_proj,
                  offset_pointer(config.q_b_proj, offset * (config.nope_size + config.rope_size) * config.q_lora_rank *
                                                      ggml_type_size((ggml_type)config.q_b_proj_type)),
                  (ggml_type)config.q_b_proj_type,
                  config.num_heads * (config.nope_size + config.rope_size) * config.q_lora_rank);

  // kv_b projection stores K and V halves interleaved per head
  // ([nope_size for K | nope_size for V] × kv_lora_rank); split them into
  // separate k_b / v_b arrays for this shard's heads.
  local_k_b_proj = new A[config.num_heads * config.nope_size * config.kv_lora_rank];
  local_v_b_proj = new A[config.num_heads * config.nope_size * config.kv_lora_rank];
  for (size_t i = 0; i < config.num_heads; i++) {
    convert_or_copy(
        local_k_b_proj + i * config.nope_size * config.kv_lora_rank,
        offset_pointer(config.kv_b_proj, (i + offset) * (config.nope_size + config.nope_size) * config.kv_lora_rank *
                                             ggml_type_size((ggml_type)config.kv_b_proj_type)),
        (ggml_type)config.kv_b_proj_type, config.nope_size * config.kv_lora_rank);

    convert_or_copy(
        local_v_b_proj + i * config.nope_size * config.kv_lora_rank,
        offset_pointer(config.kv_b_proj, ((i + offset) * (config.nope_size + config.nope_size) + config.nope_size) *
                                             config.kv_lora_rank * ggml_type_size((ggml_type)config.kv_b_proj_type)),
        (ggml_type)config.kv_b_proj_type, config.nope_size * config.kv_lora_rank);
  }
  local_k_b_proj_ref = KMatRef((A*)local_k_b_proj, config.num_heads * config.nope_size, config.kv_lora_rank,
                               config.kv_lora_rank, CblasRowMajor);
  local_v_b_proj_ref = KMatRef((A*)local_v_b_proj, config.num_heads * config.nope_size, config.kv_lora_rank,
                               config.kv_lora_rank, CblasRowMajor);

  // Output projection: gather this shard's head columns out of the full
  // [hidden_size × complete_num_heads·nope_size] o_proj, row by row.
  local_w_o = new A[config.num_heads * config.hidden_size * config.nope_size];
  for (size_t i = 0; i < config.hidden_size; i++) {
    convert_or_copy(
        local_w_o + i * config.num_heads * config.nope_size,
        offset_pointer(config.o_proj, (i * complete_num_heads * config.nope_size + (offset)*config.nope_size) *
                                          ggml_type_size((ggml_type)config.w_o_type)),
        (ggml_type)config.w_o_type, config.num_heads * config.nope_size);
  }
  // Per head-group copies of the output projection for the decode path.
  size_t sub_num_heads_decode_group = div_up(config.num_heads, sub_num_heads_decode);
  local_w_decode_o.resize(sub_num_heads_decode_group);
  for (size_t h = 0; h < sub_num_heads_decode_group; h++) {
    local_w_decode_o[h] = new A[config.hidden_size * sub_num_heads_decode * config.nope_size];
    for (size_t i = 0; i < config.hidden_size; i++) {
      convert_or_copy(local_w_decode_o[h] + i * config.nope_size * sub_num_heads_decode,
                      offset_pointer(config.o_proj, (i * complete_num_heads * config.nope_size +
                                                     (h * sub_num_heads_decode + offset) * config.nope_size) *
                                                        ggml_type_size((ggml_type)config.w_o_type)),
                      (ggml_type)config.w_o_type, config.nope_size * sub_num_heads_decode);
    }
  }

  // Quantize/pack the float staging copies into the GEMM BufferB layouts.
#ifdef DEBUG_THIS_MLA
  printf("KML_MLA_TP_QUAN_TEST::load_weights: local_q_a_proj_quant_bb quantization\n");
  if (tp_part_idx == 0) {
    dump_bin("local_q_a_proj_quant.bin", local_q_a_proj, config.hidden_size * config.q_lora_rank);
  }
#endif
  auto pool = config.pool->get_subpool(tp_part_idx);
  {
    size_t mth_q_a = GemmKernel::recommended_nth(config.q_lora_rank);
    size_t mth_kv_a = GemmKernel::recommended_nth(config.kv_lora_rank + config.rope_size);
    auto task_counter = TaskCounter({mth_q_a + mth_kv_a});
    auto task = [&](int task_id) {
      // First mth_q_a tasks pack local_q_a_proj; the rest pack
      // local_kv_a_proj_with_mqa.
      size_t mth_idx = task_counter.at(task_id, 0);
      if (mth_idx < mth_q_a) {
        local_q_a_proj_quant_bb->from_mat((A*)local_q_a_proj, mth_idx, mth_q_a, config.q_lora_rank);
      } else {
        mth_idx -= mth_q_a;
        local_kv_a_proj_with_mqa_quant_bb->from_mat((A*)local_kv_a_proj_with_mqa, mth_idx, mth_kv_a,
                                                    config.kv_lora_rank + config.rope_size);
      }
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("KML_MLA_TP_QUAN_TEST::load_weights: local_w_o_quant_bb quantization\n");
#endif
  {
    // Pack the prefill output-projection weight.
    size_t mth_w_o = GemmKernel::recommended_mth(config.hidden_size);
    auto task_counter = TaskCounter({mth_w_o});
    auto task = [&](int task_id) {
      size_t mth_idx = task_counter.at(task_id, 0);
      local_w_o_quant_bb->from_mat((A*)local_w_o, mth_idx, mth_w_o, config.hidden_size);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("KML_MLA_TP_QUAN_TEST::load_weights: local_w_decode_o quantization\n");
#endif
  {
    // Pack the per-head-group decode output-projection weights.
    size_t mth_w_o = GemmKernel::recommended_mth(config.hidden_size);
    auto task_counter = TaskCounter({sub_num_heads_decode_group, mth_w_o});
    auto task = [&](int task_id) {
      size_t h = task_counter.at(task_id, 0);
      size_t mth_idx = task_counter.at(task_id, 1);
      local_w_o_decode_bb[h]->from_mat((A*)local_w_decode_o[h], mth_idx, mth_w_o, config.hidden_size);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }

  // Matrix refs over the packed buffers, used by the forward pass.
  local_q_a_proj_quant_ref =
      KMatRefB(local_q_a_proj_quant_bb->b, config.hidden_size, config.q_lora_rank, config.hidden_size, CblasColMajor,
               CblasNoTrans, local_q_a_proj_quant_bb->if_pack);

  local_q_b_proj_ref = KMatRef((A*)local_q_b_proj, config.num_heads * (config.nope_size + config.rope_size),
                               config.q_lora_rank, config.q_lora_rank, CblasRowMajor);

  local_kv_a_proj_with_mqa_decode_ref =
      KMatRefB(local_kv_a_proj_with_mqa_quant_bb->b, config.hidden_size, config.kv_lora_rank + config.rope_size,
               config.hidden_size, CblasColMajor, CblasNoTrans, local_kv_a_proj_with_mqa_quant_bb->if_pack);

  local_w_o_ref =
      KMatRefB(local_w_o_quant_bb->b, config.hidden_size, config.nope_size * config.num_heads,
               config.nope_size * config.num_heads, CblasRowMajor, CblasNoTrans, local_w_o_quant_bb->if_pack);
  // Float staging copies that were fully packed above are no longer needed.
  delete[] local_w_o;
  delete[] local_q_a_proj;
  delete[] local_kv_a_proj_with_mqa;
  for (int i = 0; i < sub_num_heads_decode_group; i++) {
    delete[] local_w_decode_o[i];
  }
  local_w_decode_o.clear();
}
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::set_pages(std::vector<void*> kv_lora_pages, std::vector<void*> pe_pages) {
|
||||
// this->kv_lora_pages = kv_lora_pages;
|
||||
// this->rope_pages = pe_pages;
|
||||
}
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::set_local_pages(int page_count) {
|
||||
cc_pages.resize(page_count);
|
||||
|
||||
cc_page_refs.resize(page_count);
|
||||
kv_lora_page_refs.resize(page_count);
|
||||
rope_page_refs.resize(page_count);
|
||||
|
||||
// cc_page_refs_buffer.resize(page_count);
|
||||
// kv_lora_page_refs_buffer.resize(page_count);
|
||||
// rope_page_refs_buffer.resize(page_count);
|
||||
|
||||
cc_page_refs_decode_buffer.resize(page_count);
|
||||
kv_lora_page_refs_decode_buffer.resize(page_count);
|
||||
rope_page_refs_decode_buffer.resize(page_count);
|
||||
|
||||
for (int i = 0; i < page_count; i++) {
|
||||
cc_pages[i] = new A[config.token_count_in_page * (config.kv_lora_rank + config.rope_size)];
|
||||
cc_page_refs[i] = KMatRef(cc_pages[i], cc_size, config.token_count_in_page, cc_size, CblasColMajor);
|
||||
// cc_page_refs_buffer[i] = KMatRefC(local_kv_a_proj_with_mqa_deprecated_bc[i]->c, cc_size,
|
||||
// config.token_count_in_page,
|
||||
// cc_size, CblasColMajor);
|
||||
kv_lora_page_refs[i] = cc_page_refs[i].offset_row(0, config.kv_lora_rank);
|
||||
// kv_lora_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(0, config.kv_lora_rank);
|
||||
kv_lora_page_refs_decode_buffer[i] = cc_page_refs_decode_buffer[i].offset_col(0, config.kv_lora_rank);
|
||||
rope_page_refs[i] = cc_page_refs[i].offset_row(config.kv_lora_rank, config.rope_size);
|
||||
// rope_page_refs_buffer[i] = cc_page_refs_buffer[i].offset_row(config.kv_lora_rank, config.rope_size);
|
||||
rope_page_refs_decode_buffer[i] = cc_page_refs_decode_buffer[i].offset_col(config.kv_lora_rank, config.rope_size);
|
||||
}
|
||||
}
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables,
|
||||
std::vector<int> kv_lens, std::vector<void*> attention_masks,
|
||||
const void* input, void* output) {
|
||||
if (qlens[0] <= 1) {
|
||||
forward_decode(qlens, page_tables, kv_lens, attention_masks, (input_t*)input, (output_t*)output);
|
||||
} else {
|
||||
forward_prefill(qlens, page_tables, kv_lens, attention_masks, (input_t*)input, (output_t*)output);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables,
|
||||
std::vector<int> kv_lens, const void* input, void* output) {
|
||||
forward(qlens, page_tables, kv_lens, default_attention_masks, input, output);
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
#include "../mla-tp.hpp"
|
||||
#include "../reduce.hpp"
|
||||
#include "../rms-norm.hpp"
|
||||
#include "../rope.hpp"
|
||||
#include "../softmax.hpp"
|
||||
#include "ggml-quants.h"
|
||||
#include "ggml.h"
|
||||
#include "kblas.h"
|
||||
#include "la/arm_kml.hpp"
|
||||
|
||||
// #define DEBUG_THIS_MLA
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
#include "test/debug.hpp"
|
||||
#endif
|
||||
|
||||
#define DIRECT_OR_POOL_BY(what, threshold, var, fn) \
|
||||
do { \
|
||||
if ((what) < (threshold)) { \
|
||||
for (int i = 0; i < (var); i++) { \
|
||||
(fn)(i); \
|
||||
} \
|
||||
} else { \
|
||||
pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
template <typename A, class KERNEL>
|
||||
class KML_MLA_TP_QUAN_TEST
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
: protected TimePerf
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
using input_t = A;
|
||||
using output_t = A;
|
||||
using quant_t = int8_t;
|
||||
KML_MLA_TP_QUAN_TEST(GeneralMLAConfig config, int tp_part_idx);
|
||||
void load_weights(int complete_num_heads, int offset);
|
||||
void set_pages(std::vector<void*> kv_lora_pages, std::vector<void*> pe_pages);
|
||||
void set_local_pages(int page_count);
|
||||
void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
|
||||
std::vector<void*> attention_masks, const void* input, void* output);
|
||||
void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,
|
||||
const void* input, void* output);
|
||||
|
||||
private:
|
||||
using T_RMSNorm = RMSNorm<A>;
|
||||
using T_RopeAngle = DeepseekV3YarnRotaryEmbedding;
|
||||
using T_RopeApplier = Rope<T_RopeAngle, A>;
|
||||
using T_SoftmaxApplier = Softmax<A>;
|
||||
using KMatRefA = typename arm_kml::MatRef<int8_t>;
|
||||
using KMatRefB = typename arm_kml::MatRef<typename KERNEL::dt>;
|
||||
using KMatRef = typename arm_kml::MatRef<A>;
|
||||
using KMatRefC = typename arm_kml::MatRef<int32_t>;
|
||||
using GemmKernel = KERNEL;
|
||||
|
||||
GeneralMLAConfig config;
|
||||
const size_t col_block = 256;
|
||||
const size_t row_block = 256;
|
||||
|
||||
// for quant
|
||||
const size_t col_block_q_absorb = 512; // 上限kv_lora_rank:512
|
||||
const size_t row_block_q_absorb = 512; // 上限qlen
|
||||
|
||||
const size_t col_block_q_lora_kv_lora_rope = 256;
|
||||
const size_t row_block_q_lora_kv_lora_rope = 256;
|
||||
|
||||
const size_t col_block_q_nope_rope = 512;
|
||||
const size_t row_block_q_nope_rope = 512;
|
||||
|
||||
const size_t col_block_attention_output = 1024; // 上限qlen的大小
|
||||
const size_t row_block_attention_output = 256; // 上限128
|
||||
|
||||
const size_t col_block_out_by_head = 256; // 上限 qlen
|
||||
const size_t row_block_out_by_head = 256; // 上限 hidden_size:7168
|
||||
|
||||
// ==========================================
|
||||
// decode
|
||||
|
||||
const size_t decode_col_block_q_absorb = 256; // 上限kv_lora_rank:512
|
||||
const size_t decode_row_block_q_absorb = 64; // 上限qlen
|
||||
|
||||
const size_t decode_col_block_q_lora_kv_lora_rope = 64;
|
||||
const size_t decode_row_block_q_lora_kv_lora_rope = 64;
|
||||
|
||||
const size_t decode_col_block_q_nope_rope = 512;
|
||||
const size_t decode_row_block_q_nope_rope = 512;
|
||||
|
||||
const size_t decode_col_block_attention_output = 1024; // 上限qlen的大小
|
||||
const size_t decode_row_block_attention_output = 256; // 上限128
|
||||
|
||||
const size_t decode_col_block_out_by_head = 1; // 上限 qlen
|
||||
const size_t decode_row_block_out_by_head = 512; // 上限 hidden_size:7168
|
||||
|
||||
// ==========================================
|
||||
|
||||
const size_t col_block_o_absorb = 256;
|
||||
const size_t row_block_o_absorb = 256;
|
||||
|
||||
const size_t col_block_nope_attention = 256;
|
||||
const size_t row_block_nope_attention = 256;
|
||||
|
||||
const size_t col_block_pe_attention = 256;
|
||||
const size_t row_block_pe_attention = 256;
|
||||
|
||||
int tp_part_idx;
|
||||
std::vector<void*> default_attention_masks;
|
||||
|
||||
// std::vector<void *> kv_lora_pages; // [page_count * page_token_count * nope]
|
||||
// std::vector<void *> rope_pages; // [page_count * page_token_count * nope]
|
||||
std::vector<A*> cc_pages; // [page_count * page_token_count * (kv rank + rope size)]
|
||||
size_t cc_size;
|
||||
// col major:[kv_lora_rank, qlen] or row major:[qlen, kv_lora_rank]
|
||||
std::vector<KMatRef> cc_page_refs, kv_lora_page_refs, rope_page_refs;
|
||||
// col major:[kv_lora_rank, qlen] or [rope_size, qlen]
|
||||
std::vector<KMatRefC> cc_page_refs_buffer, kv_lora_page_refs_buffer, rope_page_refs_buffer;
|
||||
|
||||
// row major:[qlen, kv_lora_rank] or [qlen, rope_size]
|
||||
std::vector<KMatRefC> cc_page_refs_decode_buffer, kv_lora_page_refs_decode_buffer, rope_page_refs_decode_buffer;
|
||||
// weights
|
||||
A* local_q_a_proj; // [hidden_size * q_lora_rank]
|
||||
A* local_q_a_norm;
|
||||
A* local_q_b_proj; // [num_heads * (nope_size + rope_size))]
|
||||
A* local_kv_a_proj_with_mqa; // [hidden_size * (kv_lora_rank + rope)]
|
||||
A* local_kv_a_norm;
|
||||
A* local_k_b_proj;
|
||||
A* local_v_b_proj;
|
||||
// A *local_kv_b_proj; // [(num_heads * (nope_size + nope_size) * kv_lora_rank)],
|
||||
// q_absorb: [num_heads * nope_size * kv_lora_rank]
|
||||
// out_absorb: [num_heads * nope_size * kv_lora_rank]
|
||||
A* local_w_o; // [(num_heads * hidden_size * nope_size)]
|
||||
std::vector<A*>
|
||||
local_w_decode_o; // [num_heads/sub_num_heads_decode]*[(sub_num_heads_decode * hidden_size * nope_size)]
|
||||
|
||||
std::unique_ptr<T_RopeAngle> rope_angle;
|
||||
|
||||
// KMatRefAB local_q_a_proj_ref;
|
||||
KMatRefB local_q_a_proj_quant_ref;
|
||||
KMatRef local_q_b_proj_ref;
|
||||
// KMatRefAB local_kv_a_proj_with_mqa_ref;
|
||||
KMatRefB local_kv_a_proj_with_mqa_decode_ref;
|
||||
KMatRef local_k_b_proj_ref;
|
||||
KMatRef local_v_b_proj_ref;
|
||||
// KMatRefAB local_w_o_decode_ref;
|
||||
KMatRefB local_w_o_ref;
|
||||
|
||||
typename GemmKernel::BufferA* local_q_a_proj_quant_ba; // [max_qlen,hidden_size]
|
||||
typename GemmKernel::BufferB* local_q_a_proj_quant_bb; // [hidden_size, q_lora_rank]
|
||||
typename GemmKernel::BufferC* local_q_a_proj_quant_bc; // [max_qlen,q_lora_rank] (row major)
|
||||
|
||||
typename GemmKernel::BufferA* local_kv_a_proj_with_mqa_quant_ba; // [max_qlen, hidden_size]
|
||||
typename GemmKernel::BufferB* local_kv_a_proj_with_mqa_quant_bb; // [hidden_size, kv_lora_rank + rope_size]
|
||||
std::vector<typename GemmKernel::BufferC*>
|
||||
local_kv_a_proj_with_mqa_quant_bc; // page_count * [page_token_count, rope_size + kv_lora_rank] (row major)
|
||||
|
||||
// 对应local_w_o
|
||||
|
||||
// 对应local_w_o
|
||||
typename GemmKernel::BufferA* local_w_o_quant_ba; // [max_qlen, num_heads * nope_size]
|
||||
typename GemmKernel::BufferB* local_w_o_quant_bb; // [num_heads * nope_size, hidden_size]
|
||||
// qlen_output
|
||||
typename GemmKernel::BufferC* local_w_o_prefill_bc; // [max_qlen, hidden_size]
|
||||
std::vector<typename GemmKernel::BufferC*>
|
||||
local_w_o_decode_bc; // [num_heads/sub_num_heads_decode] *[max_qlen, hidden_size]
|
||||
|
||||
// std::vector<typename GemmKernel::BufferA *>
|
||||
// local_w_o_decode_ba; // [num_heads/sub_num_heads_decode]*[hidden_size,sub_num_heads_decode * nope_size] row
|
||||
// major
|
||||
std::vector<typename GemmKernel::BufferB*>
|
||||
local_w_o_decode_bb; // [num_heads/sub_num_heads_decode]*[sub_num_heads_decode * nope_size,hidden_size] col major
|
||||
|
||||
// for each query
|
||||
A* q_lora_rank; // [qlen_sum, q_lora_rank]
|
||||
A* q_nope; // [num_heads * max_qlen, nope_size]
|
||||
KMatRef q_nope_tmp_ref;
|
||||
|
||||
A* q; // [num_heads * max_qlen, max(kv_lora_rank,nope_size) + rope_size]
|
||||
size_t q_ld;
|
||||
KMatRef q_ref, q_pe_absorb_ref, q_pe_noabsorb_ref, q_nope_ref, q_kv_lora_rank_ref, q_attn_absorb_ref,
|
||||
q_attn_noabsorb_ref;
|
||||
// std::vector<A *> q_pe; // [num_heads * max_qlen * rope_size]
|
||||
// std::vector<A *> q_nope; // [num_heads * max_qlen * nope_size]
|
||||
A* k; // [num_heads * max_kvlen * (nope_size + rope_size)]
|
||||
KMatRef k_ref, k_nope_ref, k_rope_ref;
|
||||
|
||||
A* attention_weights;
|
||||
KMatRef attention_weights_ref; // [max_kvlen, max_qlen* num_heads]
|
||||
// std::vector<A *> q_absorb; // [num_heads, max_qlen, kv_lora_rank], or [num_heads, kv_lora_rank, max_qlen]
|
||||
A* o_absorb_or_v;
|
||||
KMatRef o_absorb_ref; // [num_heads, max_qlen, kv_lora_rank]
|
||||
KMatRef v_ref; // [num_heads,nope,max_kvlen]
|
||||
|
||||
A* attention_output; // [num_heads * max_qlen * nope]
|
||||
KMatRef attention_output_ref;
|
||||
size_t sub_num_heads = 16; // 用于并发的子头设置
|
||||
size_t sub_num_heads_decode = 8; // 用于并发的子头设置
|
||||
|
||||
// std::vector<A *> qlen_decode_output; // [max_qlen * hidden_size]
|
||||
std::vector<A*> qlen_quant_output; // [[num_heads/sub_num_heads] * max_qlen * hidden_size] row major
|
||||
|
||||
A softmax_scale;
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
std::string file_name;
|
||||
#endif
|
||||
|
||||
static bool decide_absorb(size_t qlen, size_t existing_kvlen) {
|
||||
double x = existing_kvlen;
|
||||
return qlen < (-x + sqrt(x * (x + 2048.0 / 3.0)) / 2.0);
|
||||
}
|
||||
// 只保留声明,移除实现
|
||||
void nope_attention_q_absorb(int qlen, int kvlen, const std::vector<int>& page_table, bool increamental = true,
|
||||
bool is_decode = false);
|
||||
void nope_attention_no_absorb(int qlen, int kvlen, const std::vector<int>& page_table, bool increamental = true);
|
||||
void output_absorb(int query, const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
|
||||
const std::vector<int>& kvlens, bool is_decode = false);
|
||||
void output_no_absorb(int query, const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
|
||||
const std::vector<int>& kvlens);
|
||||
void forward_prefill(const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
|
||||
const std::vector<int>& kvlens, const std::vector<void*>& attention_masks,
|
||||
const input_t* input_raw, output_t* output_raw);
|
||||
void forward_decode(const std::vector<int>& qlens, const std::vector<std::vector<int>>& page_tables,
|
||||
const std::vector<int>& kvlens, const std::vector<void*>& attention_masks,
|
||||
const input_t* input_raw, output_t* output_raw);
|
||||
void q_lora_kv_lora_rope_quant(int query, const std::vector<int>& qlens, const std::vector<int>& kvlens,
|
||||
const std::vector<std::vector<int>>& page_tables, std::vector<int>& qlen_split,
|
||||
KMatRefA& input_ref, KMatRefC& q_lora_rank_ref, KMatRef& q_lora_rank_out_ref,
|
||||
bool is_decode = false);
|
||||
};
|
||||
|
||||
template <typename A, class KERNEL>
|
||||
class TP_MLA<KML_MLA_TP_QUAN_TEST<A, KERNEL>> : public TP_MLA_Common<KML_MLA_TP_QUAN_TEST<A, KERNEL>> {
|
||||
public:
|
||||
using TP_MLA_Common<KML_MLA_TP_QUAN_TEST<A, KERNEL>>::TP_MLA_Common;
|
||||
|
||||
void load_weights() {
|
||||
auto pool = this->config.pool;
|
||||
auto tp_num_heads = this->config.num_heads / this->tp_count;
|
||||
pool->dispense_backend()->do_numa_job([this, pool, tp_num_heads](int tp_id) {
|
||||
this->tps[tp_id]->load_weights(this->config.num_heads, tp_id * tp_num_heads);
|
||||
});
|
||||
this->weights_loaded = true;
|
||||
}
|
||||
|
||||
void merge_results(int qlen, void* output_raw) {
|
||||
auto pool = this->config.pool;
|
||||
typename KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_t* output =
|
||||
(typename KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_t*)output_raw;
|
||||
pool->do_work_stealing_job(qlen, [this, output](int token_nth) {
|
||||
auto& tp_count = this->tp_count;
|
||||
auto& config = this->config;
|
||||
auto& local_output_numa = this->local_output_numa;
|
||||
reduce_sum(local_output_numa.data(), tp_count, token_nth * config.hidden_size,
|
||||
token_nth * config.hidden_size + config.hidden_size);
|
||||
memcpy(&output[token_nth * config.hidden_size], &local_output_numa[0][token_nth * config.hidden_size],
|
||||
config.hidden_size * sizeof(output[0]));
|
||||
});
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
// dump_bin_int8("output.bin", output, qlen * this->config.hidden_size);
|
||||
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template class KML_MLA_TP_QUAN_TEST<float, arm_kml::GemmKernelInt8>;
|
||||
template class KML_MLA_TP_QUAN_TEST<float, arm_kml::GemmKernelInt4>;
|
||||
// template class KML_MLA_TP_QUAN_TEST<float16_t>;
|
||||
@@ -1,546 +0,0 @@
|
||||
// #define DEBUG_THIS_MLA
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
#include "test/debug.hpp"
|
||||
#endif
|
||||
#include "mla_quan.h"
|
||||
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::nope_attention_q_absorb(int qlen, int kvlen, const std::vector<int>& page_table,
|
||||
bool increamental, bool is_decode) {
|
||||
// 使用 lambda 包装器
|
||||
auto mul_mat_clearc = [is_decode](auto a, auto b, auto c) {
|
||||
if (is_decode) {
|
||||
arm_kml::decode_mul_mat_clearc(a, b, c);
|
||||
} else {
|
||||
arm_kml::mul_mat_clearc(a, b, c);
|
||||
}
|
||||
};
|
||||
|
||||
auto pool = config.pool->get_subpool(tp_part_idx);
|
||||
{
|
||||
// q absorb
|
||||
size_t qlen_block = div_up((size_t)qlen, row_block);
|
||||
size_t kv_rank_block = div_up(config.kv_lora_rank, col_block);
|
||||
auto task_counter = TaskCounter({config.num_heads, qlen_block, kv_rank_block});
|
||||
|
||||
auto task = [&](int task_id) {
|
||||
auto [head_idx, qlen_block_idx, kv_rank_block_idx] = task_counter.get<3>(task_id);
|
||||
|
||||
size_t qlen_begin = qlen_block_idx * row_block;
|
||||
size_t qlen_end = std::min(qlen_begin + row_block, (size_t)qlen);
|
||||
|
||||
size_t kv_rank_begin = kv_rank_block_idx * col_block;
|
||||
size_t kv_rank_end = std::min(kv_rank_begin + col_block, (size_t)config.kv_lora_rank);
|
||||
|
||||
KMatRef this_local_k_b_proj_ref = local_k_b_proj_ref.offset_block(head_idx * config.nope_size, kv_rank_begin,
|
||||
config.nope_size, kv_rank_end - kv_rank_begin);
|
||||
this_local_k_b_proj_ref = this_local_k_b_proj_ref.t();
|
||||
// printf("q absorb %d [%d,%d),[%d,%d)\n", head_idx, qlen_begin, qlen_end, kv_rank_begin, kv_rank_end);
|
||||
mul_mat_clearc(this_local_k_b_proj_ref,
|
||||
q_nope_tmp_ref.offset_col(head_idx * qlen + qlen_begin, qlen_end - qlen_begin),
|
||||
|
||||
q_kv_lora_rank_ref.offset_block(kv_rank_begin, head_idx * qlen + qlen_begin,
|
||||
kv_rank_end - kv_rank_begin, qlen_end - qlen_begin));
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("q absorb [%d] \n", tp_part_idx);
|
||||
// dump_bin(file_name + "_k_b_lora", (A *)local_kv_b_proj, config.kv_lora_rank * config.nope_size);
|
||||
// dump_bin(file_name + "_q_absorb", (A *)q_absorb[0], config.kv_lora_rank * config.max_qlen);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
if (is_decode) {
|
||||
PROFILE_RECORD_TIME_STAMP("decode q absorb");
|
||||
} else {
|
||||
PROFILE_RECORD_TIME_STAMP("prefill q absorb");
|
||||
}
|
||||
#endif
|
||||
|
||||
{
|
||||
// nope attention
|
||||
|
||||
size_t qlen_block = div_up((size_t)qlen * config.num_heads, col_block);
|
||||
// page size % col_block == 0
|
||||
size_t kvlen_block = div_up((size_t)kvlen + qlen, col_block);
|
||||
|
||||
TaskCounter task_counter({qlen_block, kvlen_block});
|
||||
auto task = [&](int task_id) {
|
||||
auto [qlen_block_idx, kvlen_block_idx] = task_counter.get<2>(task_id);
|
||||
|
||||
size_t qlen_begin = qlen_block_idx * col_block;
|
||||
size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlen * config.num_heads);
|
||||
|
||||
size_t kvlen_begin = kvlen_block_idx * col_block;
|
||||
size_t kvlen_end = std::min(kvlen_begin + col_block, (size_t)kvlen + qlen);
|
||||
|
||||
size_t kvlen_block_size = kvlen_end - kvlen_begin;
|
||||
size_t kv_page = kvlen_begin / config.token_count_in_page;
|
||||
size_t token_at_in_page = kvlen_begin % config.token_count_in_page;
|
||||
|
||||
KMatRef this_cc_ref =
|
||||
KMatRef((A*)cc_pages[page_table[kv_page]], config.token_count_in_page, cc_size, cc_size, CblasRowMajor);
|
||||
this_cc_ref = this_cc_ref.offset_row(token_at_in_page, kvlen_end - kvlen_begin);
|
||||
|
||||
KMatRef this_q_aborb_ref = q_attn_absorb_ref.offset_col(qlen_begin, qlen_end - qlen_begin);
|
||||
|
||||
KMatRef this_attention_weights_ref =
|
||||
attention_weights_ref.offset_block(kvlen_begin, qlen_begin, kvlen_end - kvlen_begin, qlen_end - qlen_begin);
|
||||
|
||||
arm_kml::mul_mat_clearc(this_cc_ref, this_q_aborb_ref, this_attention_weights_ref);
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("attention weights[%d] \n", tp_part_idx);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
if (is_decode) {
|
||||
PROFILE_RECORD_TIME_STAMP("decode nope attention");
|
||||
} else {
|
||||
PROFILE_RECORD_TIME_STAMP("prefill nope attention");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::nope_attention_no_absorb(int qlen, int kvlen, const std::vector<int>& page_table,
|
||||
bool increamental) {
|
||||
auto pool = config.pool->get_subpool(tp_part_idx);
|
||||
{
|
||||
// k nope
|
||||
size_t nope_block = div_up(config.nope_size, row_block);
|
||||
size_t kvlen_block = div_up((size_t)kvlen + qlen, col_block);
|
||||
auto task_counter = TaskCounter({config.num_heads, kvlen_block, nope_block});
|
||||
|
||||
auto task = [&](int task_id) {
|
||||
size_t head_idx = task_counter.at(task_id, 0);
|
||||
size_t kvlen_block_idx = task_counter.at(task_id, 1);
|
||||
size_t nope_block_idx = task_counter.at(task_id, 2);
|
||||
|
||||
size_t kvlen_begin = kvlen_block_idx * col_block;
|
||||
size_t kvlen_end = std::min(kvlen_begin + col_block, (size_t)kvlen + qlen);
|
||||
|
||||
size_t kvlen_block_size = kvlen_end - kvlen_begin;
|
||||
size_t kv_page = kvlen_begin / config.token_count_in_page;
|
||||
size_t token_at_in_page = kvlen_begin % config.token_count_in_page;
|
||||
|
||||
size_t nope_begin = nope_block_idx * row_block;
|
||||
size_t nope_end = std::min(nope_begin + row_block, config.nope_size);
|
||||
|
||||
auto k_b_ref = local_k_b_proj_ref.offset_row(head_idx * config.nope_size + nope_begin, nope_end - nope_begin);
|
||||
|
||||
KMatRef cc_ref = kv_lora_page_refs[page_table[kv_page]];
|
||||
cc_ref = cc_ref.offset_col(token_at_in_page, kvlen_end - kvlen_begin);
|
||||
|
||||
KMatRef this_k_nope_ref = k_nope_ref.offset_block(nope_begin, head_idx * config.max_kvlen + kvlen_begin,
|
||||
nope_end - nope_begin, kvlen_end - kvlen_begin);
|
||||
|
||||
arm_kml::mul_mat_clearc(k_b_ref, cc_ref, this_k_nope_ref);
|
||||
if (nope_block_idx == 0) {
|
||||
auto this_k_rope_ref =
|
||||
k_rope_ref.offset_col(head_idx * config.max_kvlen + kvlen_begin, kvlen_end - kvlen_begin);
|
||||
auto this_rope_page_ref =
|
||||
rope_page_refs[page_table[kv_page]].offset_col(token_at_in_page, kvlen_end - kvlen_begin);
|
||||
for (size_t i = 0; i < this_k_rope_ref.C; i++) {
|
||||
memcpy(this_k_rope_ref.data + this_k_rope_ref.ld * i, this_rope_page_ref.data + this_rope_page_ref.ld * i,
|
||||
config.rope_size * sizeof(A));
|
||||
}
|
||||
}
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("k nope");
|
||||
#endif
|
||||
|
||||
{
|
||||
// nope attention
|
||||
size_t kvlen_block = div_up((size_t)kvlen + qlen, row_block);
|
||||
size_t qlen_block = div_up((size_t)qlen, col_block);
|
||||
auto task_counter = TaskCounter({config.num_heads, kvlen_block, qlen_block});
|
||||
|
||||
auto task = [&](int task_id) {
|
||||
size_t head_idx = task_counter.at(task_id, 0);
|
||||
size_t kvlen_block_idx = task_counter.at(task_id, 1);
|
||||
size_t qlen_block_idx = task_counter.at(task_id, 2);
|
||||
|
||||
size_t kvlen_begin = kvlen_block_idx * row_block;
|
||||
size_t kvlen_end = std::min(kvlen_begin + row_block, (size_t)kvlen + qlen);
|
||||
size_t qlen_begin = qlen_block_idx * col_block;
|
||||
size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlen);
|
||||
|
||||
KMatRef this_k_ref = k_ref.offset_col(head_idx * config.max_kvlen + kvlen_begin, kvlen_end - kvlen_begin);
|
||||
this_k_ref = this_k_ref.t();
|
||||
|
||||
KMatRef this_q_ref = q_attn_noabsorb_ref.offset_col(head_idx * qlen + qlen_begin, qlen_end - qlen_begin);
|
||||
|
||||
KMatRef this_attention_weights_ref = attention_weights_ref.offset_block(
|
||||
kvlen_begin, head_idx * qlen + qlen_begin, kvlen_end - kvlen_begin, qlen_end - qlen_begin);
|
||||
if (increamental) {
|
||||
arm_kml::mul_mat(this_k_ref, this_q_ref, this_attention_weights_ref);
|
||||
} else {
|
||||
arm_kml::mul_mat_clearc(this_k_ref, this_q_ref, this_attention_weights_ref);
|
||||
}
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("nope attention no absorb");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_absorb(int query, const std::vector<int>& qlens,
|
||||
const std::vector<std::vector<int>>& page_tables,
|
||||
const std::vector<int>& kvlens, bool is_decode) {
|
||||
// 使用 lambda 包装器
|
||||
auto mul_mat_clearc = [is_decode](auto a, auto b, auto c) {
|
||||
if (is_decode) {
|
||||
arm_kml::decode_mul_mat_clearc(a, b, c);
|
||||
} else {
|
||||
arm_kml::mul_mat_clearc(a, b, c);
|
||||
}
|
||||
};
|
||||
auto pool = config.pool->get_subpool(tp_part_idx);
|
||||
{
|
||||
// by page
|
||||
size_t page_count = div_up((size_t)kvlens[query] + qlens[query], config.token_count_in_page);
|
||||
for (int kv_page = 0; kv_page < page_count; kv_page++) {
|
||||
// o absorb
|
||||
|
||||
size_t kvlen_begin = kv_page * config.token_count_in_page;
|
||||
size_t kvlen_end = std::min(kvlen_begin + config.token_count_in_page, (size_t)kvlens[query] + qlens[query]);
|
||||
|
||||
size_t page_kv_len = kvlen_end - kvlen_begin;
|
||||
|
||||
size_t qlen_block = div_up((size_t)qlens[query] * config.num_heads, row_block);
|
||||
size_t kv_rank_block = div_up(config.kv_lora_rank, col_block);
|
||||
auto task_counter = TaskCounter({qlen_block, kv_rank_block});
|
||||
auto task = [&](int task_id) {
|
||||
auto [qlen_block_idx, kv_rank_block_idx] = task_counter.get<2>(task_id);
|
||||
|
||||
size_t qlen_begin = qlen_block_idx * row_block;
|
||||
size_t qlen_end = std::min(qlen_begin + row_block, (size_t)qlens[query] * config.num_heads);
|
||||
size_t kv_rank_begin = kv_rank_block_idx * col_block;
|
||||
size_t kv_rank_end = std::min(kv_rank_begin + col_block, (size_t)config.kv_lora_rank);
|
||||
KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][kv_page]];
|
||||
kv_lora_page_ref = kv_lora_page_ref.offset_block(kv_rank_begin, 0, kv_rank_end - kv_rank_begin, page_kv_len);
|
||||
|
||||
KMatRef this_attention_weights_ref =
|
||||
attention_weights_ref.offset_block(kvlen_begin, qlen_begin, page_kv_len, qlen_end - qlen_begin);
|
||||
|
||||
KMatRef this_o_absorb_ref =
|
||||
o_absorb_ref.offset_block(kv_rank_begin, qlen_begin, kv_rank_end - kv_rank_begin, qlen_end - qlen_begin);
|
||||
if (kv_page == 0) {
|
||||
arm_kml::mul_mat_clearc(kv_lora_page_ref, this_attention_weights_ref, this_o_absorb_ref);
|
||||
} else {
|
||||
arm_kml::mul_mat(kv_lora_page_ref, this_attention_weights_ref, this_o_absorb_ref);
|
||||
}
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("o absorb[%d]\n", tp_part_idx);
|
||||
for (size_t i = 0; i < config.num_heads; i++)
|
||||
dump_bin(file_name + "_o_absorb_" + std::to_string(i), (A*)o_absorb_or_v, config.kv_lora_rank * qlens[query]);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
if (is_decode) {
|
||||
PROFILE_RECORD_TIME_STAMP("decode o absorb");
|
||||
} else {
|
||||
PROFILE_RECORD_TIME_STAMP("prefill o absorb");
|
||||
}
|
||||
#endif
|
||||
|
||||
{
|
||||
// attention output
|
||||
auto qlen_block = div_up((size_t)qlens[query], col_block);
|
||||
auto nope_block = div_up((size_t)config.nope_size, row_block);
|
||||
auto task_counter = TaskCounter({config.num_heads, qlen_block, nope_block});
|
||||
auto task = [&](int task_id) {
|
||||
size_t head_idx = task_counter.at(task_id, 0);
|
||||
size_t qlen_block_idx = task_counter.at(task_id, 1);
|
||||
size_t nope_block_idx = task_counter.at(task_id, 2);
|
||||
|
||||
size_t qlen_begin = qlen_block_idx * col_block;
|
||||
size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);
|
||||
size_t nope_begin = nope_block_idx * row_block;
|
||||
size_t nope_end = std::min(nope_begin + row_block, (size_t)config.nope_size);
|
||||
|
||||
KMatRef this_local_v_b_proj_ref =
|
||||
local_v_b_proj_ref.offset_row(head_idx * config.nope_size + nope_begin, nope_end - nope_begin);
|
||||
|
||||
KMatRef this_o_absorb_ref = o_absorb_ref.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
|
||||
|
||||
KMatRef this_attention_output_ref = attention_output_ref.offset_block(
|
||||
head_idx * config.nope_size + nope_begin, qlen_begin, nope_end - nope_begin, qlen_end - qlen_begin);
|
||||
|
||||
mul_mat_clearc(this_local_v_b_proj_ref, this_o_absorb_ref, this_attention_output_ref);
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("attention output[%d]\n", tp_part_idx);
|
||||
dump_bin(file_name + "_attention_output", (A*)attention_output, config.num_heads * config.nope_size * qlens[query]);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
if (is_decode) {
|
||||
PROFILE_RECORD_TIME_STAMP("decode attention output");
|
||||
} else {
|
||||
PROFILE_RECORD_TIME_STAMP("prefill attention output");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Non-absorbed attention output path: materialize per-head V from the paged
// KV-LoRA cache via the v_b projection, then multiply the softmaxed attention
// weights by V to produce the attention output. Work is split into
// (head, page, kv-block, nope-block) tiles and run on this TP part's subpool.
// NOTE(review): assumes v_ref / attention_weights_ref / attention_output_ref
// were sized for this query by the caller — confirm against forward_*().
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::output_no_absorb(int query, const std::vector<int>& qlens,
                                                       const std::vector<std::vector<int>>& page_tables,
                                                       const std::vector<int>& kvlens) {
  auto pool = config.pool->get_subpool(tp_part_idx);
  {
    // v: reconstruct V = v_b_proj * kv_lora for every head, page by page.
    size_t page_count = div_up((size_t)kvlens[query] + qlens[query], config.token_count_in_page);
    size_t nope_block_count = div_up((size_t)config.nope_size, row_block);
    size_t kvlen_in_page_count = div_up(config.token_count_in_page, col_block);
    // Tiling below indexes pages by whole col_blocks; a partial block per page would misalign.
    ASSERT_RELEASE(config.token_count_in_page % col_block == 0, "token_count_in_page must be divisible by col_block");

    auto task_counter = TaskCounter({config.num_heads, page_count, kvlen_in_page_count, nope_block_count});
    auto task = [&](int task_id) {
      size_t head_idx = task_counter.at(task_id, 0);
      size_t page_idx = task_counter.at(task_id, 1);
      size_t kvlen_idx = task_counter.at(task_id, 2);
      size_t nope_block_idx = task_counter.at(task_id, 3);

      // Absolute token range covered by this tile, clipped to the sequence end.
      size_t kvlen_begin = page_idx * config.token_count_in_page + kvlen_idx * col_block;
      size_t kvlen_end = std::min(kvlen_begin + col_block, (size_t)kvlens[query] + qlens[query]);
      if (kvlen_begin >= kvlen_end) return; // skip the extra block

#ifdef DEBUG_THIS_MLA
      printf("v nope[%d] %d %d %d %d\n", tp_part_idx, head_idx, page_idx, kvlen_begin, kvlen_end);
#endif

      size_t kvlen_begin_in_page = kvlen_begin % config.token_count_in_page;

      size_t nope_begin = nope_block_idx * row_block;
      size_t nope_end = std::min(nope_begin + row_block, (size_t)config.nope_size);

      // Slice this head's rows of the v_b projection weight.
      KMatRef this_local_v_b_proj_ref =
          local_v_b_proj_ref.offset_row(head_idx * config.nope_size + nope_begin, nope_end - nope_begin);
      // Slice the cached kv_lora tokens of the physical page backing this tile.
      KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][page_idx]];
      kv_lora_page_ref =
          kv_lora_page_ref.offset_block(0, kvlen_begin_in_page, config.kv_lora_rank, kvlen_end - kvlen_begin);
      // Destination: V laid out per head along the max_kvlen axis.
      KMatRef this_v_ref = v_ref.offset_block(nope_begin, head_idx * config.max_kvlen + kvlen_begin,
                                              nope_end - nope_begin, kvlen_end - kvlen_begin);
      arm_kml::mul_mat_clearc(this_local_v_b_proj_ref, kv_lora_page_ref, this_v_ref);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }

#ifdef DEBUG_THIS_MLA
  printf("v nope[%d] done\n", tp_part_idx);
#endif
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("v nope");
#endif

  {
    // attn output: out = V * attention_weights, tiled over (head, nope, qlen).
    size_t nope_block_count = div_up((size_t)config.nope_size, row_block);
    size_t qlen_block_count = div_up((size_t)qlens[query], col_block);
    auto task_counter = TaskCounter({config.num_heads, nope_block_count, qlen_block_count});
    auto task = [&](int task_id) {
      size_t head_idx = task_counter.at(task_id, 0);
      size_t nope_block_idx = task_counter.at(task_id, 1);
      size_t qlen_block_idx = task_counter.at(task_id, 2);
      size_t nope_begin = nope_block_idx * row_block;
      size_t nope_end = std::min(nope_begin + row_block, (size_t)config.nope_size);
      size_t qlen_begin = qlen_block_idx * col_block;
      size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);
      if (qlen_begin >= qlen_end) return; // skip the extra block

      // This head's V over the full context length (kv + new tokens).
      KMatRef this_v_ref = v_ref.offset_block(nope_begin, head_idx * config.max_kvlen, nope_end - nope_begin,
                                              kvlens[query] + qlens[query]);

      // Attention weights for this head, then the qlen sub-block of them.
      KMatRef this_attention_weights_ref = attention_weights_ref.offset_col(head_idx * qlens[query], qlens[query]);
      this_attention_weights_ref =
          this_attention_weights_ref.offset_block(0, qlen_begin, kvlens[query] + qlens[query], qlen_end - qlen_begin);

      KMatRef this_attention_output_ref = attention_output_ref.offset_block(
          head_idx * config.nope_size + nope_begin, qlen_begin, nope_end - nope_begin, qlen_end - qlen_begin);

      arm_kml::mul_mat_clearc(this_v_ref, this_attention_weights_ref, this_attention_output_ref);
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  dump_bin(file_name + "_attention_output", (A*)attention_output, config.num_heads * config.nope_size * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("attn output no absorb");
#endif
}
|
||||
|
||||
// Fused quantized projections for one query: computes q_lora (input * q_a_proj),
// kv_lora and k_rope (input * kv_a_proj_with_mqa) directly into the paged KV
// cache, applying dequant scales via GemmKernel::apply_scale and RoPE to the
// k_rope slice. Tasks are (page, in-page block, output-column block) tiles; the
// third task axis dispatches between the three projection targets.
// `is_decode` selects the decode GEMM variant for the matmuls.
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::q_lora_kv_lora_rope_quant(int query, const std::vector<int>& qlens,
                                                                const std::vector<int>& kvlens,
                                                                const std::vector<std::vector<int>>& page_tables,
                                                                std::vector<int>& qlen_split, KMatRefA& input_ref,
                                                                KMatRefC& q_lora_rank_ref, KMatRef& q_lora_rank_out_ref,
                                                                bool is_decode) {
  // Lambda wrapper: route to the decode- or prefill-tuned GEMM.
  auto mul_mat_clearc = [is_decode](auto a, auto b, auto c) {
    if (is_decode) {
      arm_kml::decode_mul_mat_clearc(a, b, c);
    } else {
      arm_kml::mul_mat_clearc(a, b, c);
    }
  };
  auto pool = config.pool->get_subpool(tp_part_idx);
  auto total_len = qlens[query] + kvlens[query];
  size_t query_page_count = div_up((size_t)total_len, config.token_count_in_page);

#ifdef DEBUG_THIS_MLA
  file_name = std::to_string(query);
  file_name = "query_" + file_name + "_tp_" + std::to_string(tp_part_idx);

  if (tp_part_idx == 0) {
    printf("qlen %d, kvlen %d, page table: ", qlens[query], kvlens[query]);
    for (auto x : page_tables[query]) {
      printf(" %d,", x);
    }
    printf("\n");
  }
#endif

  // Rows of the (already quantized) batched input belonging to this query.
  // KMatRef qlen_input_ref = input_ref.offset_block(0, qlen_split[query], config.hidden_size, qlens[query]);
  // KMatRefAB qlen_input_ref = input_ref.offset_block(0, qlen_split[query], config.hidden_size, qlens[query]);
  KMatRefA qlen_input_ref = input_ref.offset_row(qlen_split[query], qlens[query]);
  auto qlen_input_ba = local_q_a_proj_quant_ba->offset_row(qlen_split[query], qlens[query]);
  // auto local_q_a_proj_qlen_bb = local_q_a_proj_bb->offset_col(qlen_split[query], qlens[query]);
  {
    // q lora, kv lora, rope — only pages that receive new tokens are touched.
    size_t cc_page_begin = (kvlens[query]) / config.token_count_in_page;
    size_t cc_page_end = div_up((size_t)kvlens[query] + qlens[query], config.token_count_in_page);
    size_t block_per_page = div_up(config.token_count_in_page, row_block_q_lora_kv_lora_rope);

    size_t q_lora_rank_block_count = div_up((size_t)config.q_lora_rank, col_block_q_lora_kv_lora_rope);
    size_t kv_lora_rank_block_count = div_up((size_t)config.kv_lora_rank, col_block_q_lora_kv_lora_rope);
    size_t k_rope_block_count = div_up((size_t)config.rope_size, col_block_q_lora_kv_lora_rope);
    TaskCounter task_counter({cc_page_end - cc_page_begin, block_per_page,
                              q_lora_rank_block_count + kv_lora_rank_block_count + k_rope_block_count});

    auto task = [&](int task_id) {
      size_t cc_page = task_counter.at(task_id, 0) + cc_page_begin;
      size_t in_page_block_idx = task_counter.at(task_id, 1);
      // Clamp the tile to [kvlen, kvlen+qlen): only the new tokens of this query.
      size_t kvlen_begin =
          std::clamp(cc_page * config.token_count_in_page + in_page_block_idx * row_block_q_lora_kv_lora_rope,
                     (size_t)kvlens[query], (size_t)kvlens[query] + qlens[query]);
      size_t kvlen_end =
          std::clamp(cc_page * config.token_count_in_page + (in_page_block_idx + 1) * row_block_q_lora_kv_lora_rope,
                     (size_t)kvlens[query], (size_t)kvlens[query] + qlens[query]);
      // printf("kvlen[%d,%d)\n", kvlen_begin, kvlen_end);

      size_t kvlen_block_size = kvlen_end - kvlen_begin;
      if (kvlen_block_size == 0) {
        return;
      }

      // Corresponding row range within this query's new tokens.
      size_t qlen_begin = kvlen_begin - kvlens[query];
      size_t qlen_end = kvlen_end - kvlens[query];

      auto blocked_input = qlen_input_ref.offset_row(qlen_begin, qlen_end - qlen_begin);
      // auto blocked_input = qlen_input_ref.offset_block(0, qlen_begin, config.hidden_size, qlen_end - qlen_begin);
      // auto local_q_a_proj_qlen_blocked_bb = local_q_a_proj_qlen_bb.offset_col(qlen_begin, qlen_end - qlen_begin);

      // Third axis selects which projection this task computes.
      int q_or_kv_or_krope = task_counter.at(task_id, 2);

      if (q_or_kv_or_krope < q_lora_rank_block_count) {
        // q_lora: input * q_a_proj, then dequantize into q_lora_rank_out_ref.
        size_t q_lora_rank_block_idx = q_or_kv_or_krope;
        size_t q_lora_rank_begin = q_lora_rank_block_idx * col_block_q_lora_kv_lora_rope;
        size_t q_lora_rank_end = std::min(config.q_lora_rank, q_lora_rank_begin + col_block_q_lora_kv_lora_rope);

        mul_mat_clearc(blocked_input,
                       local_q_a_proj_quant_ref.offset_col(q_lora_rank_begin, q_lora_rank_end - q_lora_rank_begin),
                       q_lora_rank_ref.offset_block(qlen_begin, q_lora_rank_begin, qlen_end - qlen_begin,
                                                    q_lora_rank_end - q_lora_rank_begin));

        GemmKernel::apply_scale(q_lora_rank_out_ref.data, q_lora_rank_out_ref.ld, &qlen_input_ba,
                                local_q_a_proj_quant_bb, local_q_a_proj_quant_bc, qlen_begin, qlen_end,
                                q_lora_rank_begin, q_lora_rank_end, true);
#ifdef DEBUG_THIS_MLA
        // Print the first ten entries of q_lora_rank_ref and q_lora_rank_out_ref.
        printf("q_lora_rank_begin:%d, q_lora_rank_end:%d\n", q_lora_rank_begin, q_lora_rank_end);
        printf("qlen_begin:%d, qlen_end:%d\n", qlen_begin, qlen_end);
        if (tp_part_idx == 0) {
          for (int i = 0; i < 10; i++) {
            printf("q_lora_rank_ref:%d, %f\n", q_lora_rank_ref.data[i], q_lora_rank_out_ref.data[i]);
          }
        }
#endif

      } else if (q_or_kv_or_krope < q_lora_rank_block_count + kv_lora_rank_block_count) {
        // kv_lora: project into the decode staging buffer of the target page,
        // then dequantize into the persistent paged cache.
        size_t kv_lora_rank_block_idx = q_or_kv_or_krope - q_lora_rank_block_count;
        size_t kv_lora_rank_begin = kv_lora_rank_block_idx * col_block_q_lora_kv_lora_rope;
        size_t kv_lora_rank_end = std::min(config.kv_lora_rank, kv_lora_rank_begin + col_block_q_lora_kv_lora_rope);
        KMatRefC kv_lora_page_ref = kv_lora_page_refs_decode_buffer[page_tables[query][cc_page]].offset_block(
            kvlen_begin % config.token_count_in_page, kv_lora_rank_begin, kvlen_end - kvlen_begin,
            kv_lora_rank_end - kv_lora_rank_begin);
        mul_mat_clearc(
            blocked_input,
            local_kv_a_proj_with_mqa_decode_ref.offset_col(kv_lora_rank_begin, kv_lora_rank_end - kv_lora_rank_begin),
            kv_lora_page_ref);
        KMatRef kv_lora_page_out_ref = kv_lora_page_refs[page_tables[query][cc_page]];
        // Row offset shifts from query-local rows to in-page token positions.
        GemmKernel::apply_scale(
            kv_lora_page_out_ref.data, kv_lora_page_out_ref.ld, &qlen_input_ba, local_kv_a_proj_with_mqa_quant_bb,
            kv_lora_page_refs_decode_buffer[page_tables[query][cc_page]].data, qlen_begin, qlen_end, kv_lora_rank_begin,
            kv_lora_rank_end, true, (kvlen_begin) % config.token_count_in_page - qlen_begin, 0);
      } else if (q_or_kv_or_krope < q_lora_rank_block_count + kv_lora_rank_block_count + k_rope_block_count) {
        // single block for k rope, no norm
        size_t rope_block_idx = q_or_kv_or_krope - q_lora_rank_block_count - kv_lora_rank_block_count;
        size_t rope_begin = rope_block_idx * col_block_q_lora_kv_lora_rope;
        size_t rope_end = std::min(config.rope_size, rope_begin + col_block_q_lora_kv_lora_rope);
        KMatRefC rope_page_ref = rope_page_refs_decode_buffer[page_tables[query][cc_page]].offset_block(
            kvlen_begin % config.token_count_in_page, rope_begin, kvlen_end - kvlen_begin, rope_end - rope_begin);

        // k_rope columns live after kv_lora_rank inside kv_a_proj_with_mqa.
        mul_mat_clearc(
            blocked_input,
            local_kv_a_proj_with_mqa_decode_ref.offset_col(config.kv_lora_rank + rope_begin, rope_end - rope_begin),
            rope_page_ref);
        KMatRef rope_page_out_ref = rope_page_refs[page_tables[query][cc_page]];
        // -(kv_lora_rank) column offset maps weight columns back to rope-page columns.
        GemmKernel::apply_scale(rope_page_out_ref.data, rope_page_out_ref.ld, &qlen_input_ba,
                                local_kv_a_proj_with_mqa_quant_bb,
                                rope_page_refs_decode_buffer[page_tables[query][cc_page]].data, qlen_begin, qlen_end,
                                config.kv_lora_rank + rope_begin, config.kv_lora_rank + rope_end, true,
                                (kvlen_begin) % config.token_count_in_page - qlen_begin, -(config.kv_lora_rank));
        // Apply rotary embedding in place on the freshly written k_rope slice.
        rope_page_out_ref = rope_page_out_ref.offset_block(rope_begin, kvlen_begin % config.token_count_in_page,
                                                           rope_end - rope_begin, kvlen_end - kvlen_begin);
        T_RopeApplier::apply_multiple(*rope_angle, rope_page_out_ref.data, config.rope_size, rope_page_out_ref.ld,
                                      kvlen_begin, kvlen_block_size);
      } else {
        throw std::runtime_error("task id wrong");
      }
    };
    pool->do_work_stealing_job(task_counter.count(), task);
  }
#ifdef DEBUG_THIS_MLA
  printf("q lora, kv lora, rope[%d]\n", tp_part_idx);
  // dump_bin(file_name + "_input.bin", qlen_input_ref.data, qlens[query] * config.hidden_size);
  dump_bin(file_name + "_qlora.bin", q_lora_rank, qlens[query] * config.q_lora_rank);

  for (int i = 0; i < query_page_count; i++) {
    dump_bin(file_name + "_page_" + std::to_string(i) + "_cc_pages", (A*)cc_pages[page_tables[query][i]],
             config.token_count_in_page * cc_size);
  }
#endif

#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("q lora, kv lora, rope");
#endif
}
|
||||
@@ -1,310 +0,0 @@
|
||||
#include <cstdio>
|
||||
|
||||
#include "kblas.h"
|
||||
#include "la/arm_kml.hpp"
|
||||
#include "mla_quan.h"
|
||||
|
||||
// Decode-path forward for this TP partition, batched over queries:
//   1. quantize the raw input,
//   2. per query: q/kv lora + rope projections, RMS norms, q nope/rope,
//      absorbed nope attention, mask+softmax, absorbed output,
//   3. quantize the attention output, apply w_o per head group,
//   4. reduce the per-head-group partial outputs into output_raw.
// Decode always uses the absorbed formulation (use_absorb is fixed to true).
template <typename A, class KERNEL>
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward_decode(const std::vector<int>& qlens,
                                                     const std::vector<std::vector<int>>& page_tables,
                                                     const std::vector<int>& kvlens,
                                                     const std::vector<void*>& attention_masks,
                                                     const input_t* input_raw, output_t* output_raw) {
#ifdef DEBUG_THIS_MLA
  printf("input raw[%d]\n", tp_part_idx);
#endif

#ifdef FORWARD_TIME_PROFILE
  forward_perf_start();
#endif

  auto pool = config.pool->get_subpool(tp_part_idx);

  // Prefix sums: qlen_split[i] = row offset of query i in the packed batch.
  std::vector<int> qlen_split, total_len_split;
  qlen_split.reserve(qlens.size() + 1);
  qlen_split.push_back(0);
  total_len_split.reserve(qlens.size() + 1);
  int qlen_sum = 0;
  int total_len_sum = 0;
  for (size_t i = 0; i < qlens.size(); i++) {
    qlen_sum += qlens[i];
    qlen_split.push_back(qlen_sum);

    total_len_sum += qlens[i] + kvlens[i];
    total_len_split.push_back(total_len_sum);
  }

  // Quantize the input; the input is hidden_size * qlen and column-major.
  {
    size_t nth = GemmKernel::recommended_nth(qlen_sum);
    auto task_counter = TaskCounter({nth});
    auto task = [&](int task_id) {
      size_t nth_idx = task_counter.at(task_id, 0);
      local_q_a_proj_quant_ba->from_mat(qlen_sum, const_cast<A*>(input_raw), nth_idx, nth);
      // local_q_a_proj_deprecated_bb->from_mat(const_cast<A *>(input_raw), nth_idx, nth, qlen_sum);
    };
    DIRECT_OR_POOL_BY(qlen_sum, 10, task_counter.count(), task);
  }
#ifdef FORWARD_TIME_PROFILE
  PROFILE_RECORD_TIME_STAMP("input_quant");
#endif
  KMatRefA input_ref =
      KMatRefA(local_q_a_proj_quant_ba->a, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
  auto output_ref = KMatRef(output_raw, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
  KMatRefC q_lora_rank_ref =
      KMatRefC(local_q_a_proj_quant_bc->c, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
  KMatRef q_lora_rank_out_ref =
      KMatRef((A*)q_lora_rank, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);

  for (int query = 0; query < qlens.size(); query++) {
    bool use_absorb = true;  // decode path always absorbs
    q_lora_kv_lora_rope_quant(query, qlens, kvlens, page_tables, qlen_split, input_ref, q_lora_rank_ref,
                              q_lora_rank_out_ref, true);
#ifdef DEBUG_THIS_MLA
    if (tp_part_idx == 0) {
      dump_bin("input.bin", const_cast<A*>(input_raw), qlen_sum * config.hidden_size);
    }
#endif

    {
      // norm: RMS-normalize q_lora rows and the new kv_lora cache entries.
      auto task_counter = TaskCounter({2, (size_t)qlens[query]});
      auto task = [&](int task_id) {
        int q_or_k_kpe = task_counter.at(task_id, 0);
        size_t qlen_idx = task_counter.at(task_id, 1);
        if (q_or_k_kpe == 0) {
          T_RMSNorm::rms_norm_single_with_weights(config.q_lora_rank, local_q_a_norm,
                                                  q_lora_rank_out_ref.offset_row(qlen_idx, 1).data);
        } else if (q_or_k_kpe == 1) {
          // Locate this token's slot in the paged kv_lora cache.
          auto kv_page = (qlen_idx + kvlens[query]) / config.token_count_in_page;
          auto token_at_in_page = (qlen_idx + kvlens[query]) % config.token_count_in_page;
          KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][kv_page]];
          T_RMSNorm::rms_norm_single_with_weights(config.kv_lora_rank, local_kv_a_norm,
                                                  kv_lora_page_ref.offset_col(token_at_in_page, 1).data);
        } else {
          throw std::runtime_error("unknown task");
        }
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }
#ifdef DEBUG_THIS_MLA
    printf("q lora norm[%d]\n", tp_part_idx);
    dump_bin(file_name + "_qlora_norm.bin", q_lora_rank, qlens[query] * config.q_lora_rank);
    // for (int i = 0; i < query_page_count; i++) {
    //   dump_bin(file_name + "_page_" + std::to_string(i) + "_kv_lora_rank_norm",
    //            (A *)kv_lora_pages[page_tables[query][i]], config.token_count_in_page * config.kv_lora_rank);
    // }
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("q/kv lora norm");
#endif

    {
      // q rope & nope: q_b projection split into nope and rope halves per head.
      size_t qlen_block = div_up((size_t)qlens[query], col_block);
      TaskCounter task_counter({config.num_heads, 2, qlen_block});

      auto task = [&](int task_id) {
        auto head_idx = task_counter.at(task_id, 0);
        bool nope_or_rope = (task_counter.at(task_id, 1) == 0);
        auto qlen_block_idx = task_counter.at(task_id, 2);
        size_t qlen_begin = qlen_block_idx * col_block;
        size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);

        KMatRef b = q_lora_rank_out_ref.trans_view().offset_col(qlen_begin, qlen_end - qlen_begin);
        if (nope_or_rope) {
          // nope half of this head's q_b_proj rows.
          auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size), config.nope_size);
          KMatRef c = use_absorb ? q_nope_tmp_ref : q_nope_ref;

          c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);

          arm_kml::decode_mul_mat_clearc(a, b, c);
        } else {
          // rope half, followed by in-place rotary embedding at absolute positions.
          auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size) + config.nope_size,
                                                 config.rope_size);
          KMatRef c = use_absorb ? q_pe_absorb_ref : q_pe_noabsorb_ref;
          c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);

          arm_kml::decode_mul_mat_clearc(a, b, c);
          T_RopeApplier::apply_multiple(*rope_angle, c.data, config.rope_size, c.ld, qlen_begin + kvlens[query],
                                        qlen_end - qlen_begin);
        }
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }

#ifdef DEBUG_THIS_MLA
    printf("q nope/rope[%d]\n", tp_part_idx);
    if (use_absorb) {
      dump_bin(file_name + "_q_nope", q_nope, config.nope_size * qlens[query]);
    } else {
      dump_bin(file_name + "_q_nope", q, q_ld * qlens[query]);
    }
    dump_bin(file_name + "_q", q, q_ld * qlens[query]);
    dump_bin(file_name + "_rope_cos", rope_angle->cos(0), config.rope_size / 2 * qlens[query]);
    dump_bin(file_name + "_rope_sin", rope_angle->sin(0), config.rope_size / 2 * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("decode q nope/rope");
#endif

    nope_attention_q_absorb(qlens[query], kvlens[query], page_tables[query], false, true);
#ifdef DEBUG_THIS_MLA
    printf("attention weights[%d] \n", tp_part_idx);
    for (size_t i = 0; i < config.num_heads; i++)
      dump_bin(file_name + "_raw_attention_weights_" + std::to_string(i),
               attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
#endif
    {
      // attention mask & softmax, applied per (head, query-token) row.
      auto task_counter = TaskCounter({config.num_heads, (size_t)qlens[query]});
      auto task = [&](int task_id) {
        size_t head_idx = task_counter.at(task_id, 0);
        size_t qlen_idx = task_counter.at(task_id, 1);
        size_t qlen_from_start = qlen_idx + kvlens[query];
        A* aw = offset_pointer(attention_weights,
                               (config.max_kvlen * qlens[query] * head_idx + config.max_kvlen * qlen_idx) * sizeof(A));
        // Scale logits, then add the (causal) mask row for this token.
        for (int i = 0; i < kvlens[query] + qlens[query]; i++) {
          aw[i] *= softmax_scale;
          aw[i] += static_cast<A*>(attention_masks[qlen_from_start])[i];
        }

        T_SoftmaxApplier::apply_single(aw, kvlens[query] + qlens[query]);
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }

#ifdef DEBUG_THIS_MLA
    printf("attention weights after softmax[%d] \n", tp_part_idx);
    for (size_t i = 0; i < config.num_heads; i++)
      dump_bin(file_name + "_attention_weights_" + std::to_string(i),
               attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
#endif
#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("attention mask & softmax");
#endif
    output_absorb(query, qlens, page_tables, kvlens, true);

#ifdef DEBUG_THIS_MLA
    printf("attn output done[%d]\n", tp_part_idx);

#endif
    // Quantize attention_output_ref->data (attention_output),
    // [config.nope_size * config.num_heads, qlens[query]], column-major.
    {
      size_t nth = GemmKernel::recommended_nth(qlens[query]);
      auto task_counter = TaskCounter({nth});
      auto task = [&](int task_id) {
        size_t nth_idx = task_counter.at(task_id, 0);
        local_w_o_quant_ba->from_mat(qlens[query], attention_output_ref.data, nth_idx, nth);
      };
      DIRECT_OR_POOL_BY(qlens[query], 10, task_counter.count(), task);
    }

    {
      // Output projection (w_o), computed per group of sub_num_heads_decode heads
      // into per-group buffers; dequantized via apply_scale.
      KMatRefA local_w_o_ba_ref = KMatRefA(local_w_o_quant_ba->a, config.nope_size * config.num_heads, qlens[query],
                                           config.nope_size * config.num_heads, CblasColMajor);
      auto qlen_block = div_up((size_t)qlens[query], decode_col_block_out_by_head);
      auto hidden_size_block = div_up((size_t)config.hidden_size, decode_row_block_out_by_head);
      auto task_counter = TaskCounter({div_up(config.num_heads, sub_num_heads_decode), qlen_block, hidden_size_block});
      auto task = [&](int task_id) {
        size_t head_idx = task_counter.at(task_id, 0);  // head-GROUP index, not a single head
        size_t qlen_block_idx = task_counter.at(task_id, 1);
        size_t hidden_size_block_idx = task_counter.at(task_id, 2);

        size_t qlen_begin = qlen_block_idx * decode_col_block_out_by_head;
        size_t qlen_end = std::min(qlen_begin + decode_col_block_out_by_head, (size_t)qlens[query]);
        size_t hidden_size_begin = hidden_size_block_idx * decode_row_block_out_by_head;
        size_t hidden_size_end = std::min(hidden_size_begin + decode_row_block_out_by_head, (size_t)config.hidden_size);
        size_t head_begin = head_idx * config.nope_size * sub_num_heads_decode;
        size_t head_end =
            std::min(head_begin + config.nope_size * sub_num_heads_decode, (size_t)config.nope_size * config.num_heads);
#ifdef DEBUG_THIS_MLA
        // printf("tp idx: "
        //        "%d,head_idx:%d,head_begin:%d,head_end:%d,qlen_begin:%d,qlen_end:%d,hidden_size_begin:%d,hidden_size_"
        //        "end:%d,qlen_sum:%d,qlen_split:%d,qlens[query]:%d\n",
        //        tp_part_idx, head_idx, head_begin, head_end, qlen_begin, qlen_end, hidden_size_begin, hidden_size_end,
        //        qlen_sum, qlen_split[query], qlens[query]);
#endif

        auto output_ref_buffer =
            KMatRefC(local_w_o_decode_bc[head_idx]->c, config.hidden_size, qlen_sum, config.hidden_size, CblasColMajor);
        KMatRefC qlen_output_ref_buffer = output_ref_buffer.offset_col(qlen_split[query], qlens[query]);
        KMatRefB this_local_w_o_ref =
            KMatRefB(local_w_o_decode_bb[head_idx]->b, config.hidden_size, sub_num_heads_decode * config.nope_size,
                     sub_num_heads_decode * config.nope_size, CblasRowMajor, CblasNoTrans,
                     local_w_o_decode_bb[head_idx]->if_pack)
                .offset_row(hidden_size_begin, hidden_size_end - hidden_size_begin);
        KMatRefA this_attention_output_ref =
            local_w_o_ba_ref.offset_block(head_begin, qlen_begin, head_end - head_begin, qlen_end - qlen_begin);

        KMatRefC qlen_output_ref = qlen_output_ref_buffer.offset_block(
            hidden_size_begin, qlen_begin, hidden_size_end - hidden_size_begin, qlen_end - qlen_begin);

        arm_kml::decode_mul_mat_clearc(this_attention_output_ref.trans_view(), this_local_w_o_ref.trans_view(),
                                       qlen_output_ref.trans_view());
#ifdef DEBUG_THIS_MLA
        dump_bin(file_name + "_output_by_head_quant_" + std::to_string(head_idx), local_w_o_decode_bc[head_idx]->c,
                 qlens[query] * config.hidden_size);

#endif
        KMatRef qlen_output_out_ref =
            KMatRef(qlen_quant_output[head_idx], config.hidden_size, qlens[query], config.hidden_size, CblasColMajor);
        // GemmKernel::apply_scale(qlen_output_out_ref.data, qlen_output_out_ref.ld, local_w_o_decode_bb[head_idx],
        //                         local_w_o_quant_ba, local_w_o_decode_bc[head_idx], hidden_size_begin,
        //                         hidden_size_end, qlen_begin, qlen_end, false, 0, qlen_split[query]);
        GemmKernel::apply_scale(qlen_output_out_ref.data, qlen_output_out_ref.ld, local_w_o_quant_ba,
                                local_w_o_decode_bb[head_idx], local_w_o_decode_bc[head_idx], qlen_begin, qlen_end,
                                hidden_size_begin, hidden_size_end, true, qlen_split[query], 0);
      };
      pool->do_work_stealing_job(task_counter.count(), task);
    }

#ifdef DEBUG_THIS_MLA
    printf("decode output by head done[%d]\n", tp_part_idx);
    for (int i = 0; i < div_up(config.num_heads, sub_num_heads_decode); i++) {
      dump_bin(file_name + "_output_by_head_dequant_" + std::to_string(i), qlen_quant_output[i],
               qlens[query] * config.hidden_size);
    }
    dump_bin(file_name + "_local_w_o", local_w_o, config.hidden_size * config.nope_size * config.num_heads);

#endif

#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("output by head");
#endif

    {
      // merge output: sum the per-head-group partial outputs (accumulated into
      // qlen_quant_output[0] by reduce_sum), then copy into the caller's buffer.
      KMatRef reduce_qlen_output_ref = output_ref.offset_row(qlen_split[query], qlens[query]);

      const size_t sum_block = 1024;
      const size_t sum_block_count = div_up(config.hidden_size, sum_block);
      auto task_counter = TaskCounter({sum_block_count, (size_t)qlens[query]});
      pool->do_work_stealing_job(task_counter.count(), [&](int task_id) {
        size_t hidden_idx = task_counter.at(task_id, 0);
        size_t hidden_begin = hidden_idx * sum_block;
        size_t hidden_end = std::min(hidden_begin + sum_block, (size_t)config.hidden_size);
        size_t qlen_idx = task_counter.at(task_id, 1);
        reduce_sum(qlen_quant_output.data(), div_up(config.num_heads, sub_num_heads_decode),
                   qlen_idx * config.hidden_size + hidden_begin, qlen_idx * config.hidden_size + hidden_end);
      });

      memcpy(reduce_qlen_output_ref.data, qlen_quant_output[0], qlens[query] * config.hidden_size * sizeof(A));
    }
#ifdef DEBUG_THIS_MLA
    dump_bin(file_name + "_qlen_output", (A*)output_raw, qlens[query] * config.hidden_size);
#endif

#ifdef FORWARD_TIME_PROFILE
    PROFILE_RECORD_TIME_STAMP("decode merge output tp");
#endif

#ifdef FORWARD_TIME_PROFILE
    time_perf_name = "[mla] layer " + std::to_string(config.layer_idx) +
                     " tp_part_idx: " + std::to_string(tp_part_idx) + ", query: " + std::to_string(query);
    perf_report();

#endif
  }
}
|
||||
@@ -1,296 +0,0 @@
|
||||
// #define DEBUG_THIS_MLA
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
#include "test/debug.hpp"
|
||||
#endif
|
||||
#include <cstdio>
|
||||
|
||||
#include "la/arm_kml.hpp"
|
||||
#include "mla_quan.h"
|
||||
|
||||
template <typename A, class KERNEL>
|
||||
void KML_MLA_TP_QUAN_TEST<A, KERNEL>::forward_prefill(const std::vector<int>& qlens,
|
||||
const std::vector<std::vector<int>>& page_tables,
|
||||
const std::vector<int>& kvlens,
|
||||
const std::vector<void*>& attention_masks,
|
||||
const input_t* input_raw, output_t* output_raw) {
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("input raw[%d]\n", tp_part_idx);
|
||||
#endif
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
forward_perf_start();
|
||||
#endif
|
||||
|
||||
auto pool = config.pool->get_subpool(tp_part_idx);
|
||||
|
||||
std::vector<int> qlen_split, total_len_split;
|
||||
qlen_split.reserve(qlens.size() + 1);
|
||||
qlen_split.push_back(0);
|
||||
total_len_split.reserve(qlens.size() + 1);
|
||||
int qlen_sum = 0;
|
||||
int total_len_sum = 0;
|
||||
for (size_t i = 0; i < qlens.size(); i++) {
|
||||
qlen_sum += qlens[i];
|
||||
qlen_split.push_back(qlen_sum);
|
||||
|
||||
total_len_sum += qlens[i] + kvlens[i];
|
||||
total_len_split.push_back(total_len_sum);
|
||||
}
|
||||
|
||||
// 输入进行量化,输入是 hidden_size * qlen 且是 col major
|
||||
{
|
||||
size_t nth = GemmKernel::recommended_nth(qlen_sum);
|
||||
auto task_counter = TaskCounter({nth});
|
||||
auto task = [&](int task_id) {
|
||||
size_t nth_idx = task_counter.at(task_id, 0);
|
||||
// local_q_a_proj_deprecated_bb->from_mat(const_cast<A *>(input_raw), nth_idx, nth, qlen_sum);
|
||||
local_q_a_proj_quant_ba->from_mat(qlen_sum, const_cast<A*>(input_raw), nth_idx, nth);
|
||||
};
|
||||
DIRECT_OR_POOL_BY(qlen_sum, 10, task_counter.count(), task);
|
||||
}
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("input_quant");
|
||||
#endif
|
||||
auto input_ref =
|
||||
KMatRefA(local_q_a_proj_quant_ba->a, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
|
||||
auto output_ref = KMatRef(output_raw, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
|
||||
|
||||
auto output_ref_buffer =
|
||||
KMatRefC(local_w_o_prefill_bc->c, qlen_sum, config.hidden_size, config.hidden_size, CblasRowMajor);
|
||||
KMatRefC q_lora_rank_ref =
|
||||
KMatRefC(local_q_a_proj_quant_bc->c, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
|
||||
KMatRef q_lora_rank_out_ref =
|
||||
KMatRef((A*)q_lora_rank, qlen_sum, config.q_lora_rank, config.q_lora_rank, CblasRowMajor);
|
||||
|
||||
for (int query = 0; query < qlens.size(); query++) {
|
||||
bool use_absorb = decide_absorb(qlens[query], kvlens[query]);
|
||||
q_lora_kv_lora_rope_quant(query, qlens, kvlens, page_tables, qlen_split, input_ref, q_lora_rank_ref,
|
||||
q_lora_rank_out_ref);
|
||||
{
|
||||
// norm
|
||||
auto task_counter = TaskCounter({2, (size_t)qlens[query]});
|
||||
auto task = [&](int task_id) {
|
||||
int q_or_k_kpe = task_counter.at(task_id, 0);
|
||||
size_t qlen_idx = task_counter.at(task_id, 1);
|
||||
if (q_or_k_kpe == 0) {
|
||||
T_RMSNorm::rms_norm_single_with_weights(config.q_lora_rank, local_q_a_norm,
|
||||
q_lora_rank_out_ref.offset_row(qlen_idx, 1).data);
|
||||
} else if (q_or_k_kpe == 1) {
|
||||
auto kv_page = (qlen_idx + kvlens[query]) / config.token_count_in_page;
|
||||
auto token_at_in_page = (qlen_idx + kvlens[query]) % config.token_count_in_page;
|
||||
KMatRef kv_lora_page_ref = kv_lora_page_refs[page_tables[query][kv_page]];
|
||||
T_RMSNorm::rms_norm_single_with_weights(config.kv_lora_rank, local_kv_a_norm,
|
||||
kv_lora_page_ref.offset_col(token_at_in_page, 1).data);
|
||||
} else {
|
||||
throw std::runtime_error("unknown task");
|
||||
}
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("q lora norm[%d]\n", tp_part_idx);
|
||||
dump_bin(file_name + "_qlora_norm.bin", q_lora_rank, qlens[query] * config.q_lora_rank);
|
||||
// for (int i = 0; i < query_page_count; i++) {
|
||||
// dump_bin(file_name + "_page_" + std::to_string(i) + "_kv_lora_rank_norm",
|
||||
// (A *)kv_lora_pages[page_tables[query][i]], config.token_count_in_page * config.kv_lora_rank);
|
||||
// }
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("q/kv lora norm");
|
||||
#endif
|
||||
|
||||
{
|
||||
// q rope & nope
|
||||
size_t qlen_block = div_up((size_t)qlens[query], col_block);
|
||||
TaskCounter task_counter({config.num_heads, 2, qlen_block});
|
||||
|
||||
auto task = [&](int task_id) {
|
||||
auto head_idx = task_counter.at(task_id, 0);
|
||||
bool nope_or_rope = (task_counter.at(task_id, 1) == 0);
|
||||
auto qlen_block_idx = task_counter.at(task_id, 2);
|
||||
size_t qlen_begin = qlen_block_idx * col_block;
|
||||
size_t qlen_end = std::min(qlen_begin + col_block, (size_t)qlens[query]);
|
||||
|
||||
KMatRef b = q_lora_rank_out_ref.trans_view().offset_col(qlen_begin, qlen_end - qlen_begin);
|
||||
if (nope_or_rope) {
|
||||
auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size), config.nope_size);
|
||||
KMatRef c = use_absorb ? q_nope_tmp_ref : q_nope_ref;
|
||||
|
||||
c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
|
||||
|
||||
arm_kml::mul_mat_clearc(a, b, c);
|
||||
} else {
|
||||
auto a = local_q_b_proj_ref.offset_row(head_idx * (config.nope_size + config.rope_size) + config.nope_size,
|
||||
config.rope_size);
|
||||
KMatRef c = use_absorb ? q_pe_absorb_ref : q_pe_noabsorb_ref;
|
||||
c = c.offset_col(head_idx * qlens[query] + qlen_begin, qlen_end - qlen_begin);
|
||||
|
||||
arm_kml::mul_mat_clearc(a, b, c);
|
||||
T_RopeApplier::apply_multiple(*rope_angle, c.data, config.rope_size, c.ld, qlen_begin + kvlens[query],
|
||||
qlen_end - qlen_begin);
|
||||
}
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("q nope/rope[%d]\n", tp_part_idx);
|
||||
if (use_absorb) {
|
||||
dump_bin(file_name + "_q_nope", q_nope, config.nope_size * qlens[query]);
|
||||
} else {
|
||||
dump_bin(file_name + "_q_nope", q, q_ld * qlens[query]);
|
||||
}
|
||||
dump_bin(file_name + "_q", q, q_ld * qlens[query]);
|
||||
dump_bin(file_name + "_rope_cos", rope_angle->cos(0), config.rope_size / 2 * qlens[query]);
|
||||
dump_bin(file_name + "_rope_sin", rope_angle->sin(0), config.rope_size / 2 * qlens[query]);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("q nope/rope");
|
||||
#endif
|
||||
|
||||
if (use_absorb) {
|
||||
nope_attention_q_absorb(qlens[query], kvlens[query], page_tables[query], false);
|
||||
} else {
|
||||
nope_attention_no_absorb(qlens[query], kvlens[query], page_tables[query], false);
|
||||
}
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("attention weights[%d] \n", tp_part_idx);
|
||||
for (size_t i = 0; i < config.num_heads; i++)
|
||||
dump_bin(file_name + "_raw_attention_weights_" + std::to_string(i),
|
||||
attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
|
||||
#endif
|
||||
{
|
||||
// attentino mask & soft max
|
||||
auto task_counter = TaskCounter({config.num_heads, (size_t)qlens[query]});
|
||||
auto task = [&](int task_id) {
|
||||
size_t head_idx = task_counter.at(task_id, 0);
|
||||
size_t qlen_idx = task_counter.at(task_id, 1);
|
||||
size_t qlen_from_start = qlen_idx + kvlens[query];
|
||||
A* aw = offset_pointer(attention_weights,
|
||||
(config.max_kvlen * qlens[query] * head_idx + config.max_kvlen * qlen_idx) * sizeof(A));
|
||||
for (int i = 0; i < kvlens[query] + qlens[query]; i++) {
|
||||
aw[i] *= softmax_scale;
|
||||
aw[i] += static_cast<A*>(attention_masks[qlen_from_start])[i];
|
||||
}
|
||||
|
||||
T_SoftmaxApplier::apply_single(aw, kvlens[query] + qlens[query]);
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("attention weights after softmax[%d] \n", tp_part_idx);
|
||||
for (size_t i = 0; i < config.num_heads; i++)
|
||||
dump_bin(file_name + "_attention_weights_" + std::to_string(i),
|
||||
attention_weights_ref.offset_col(i * qlens[query], qlens[query]).data, config.max_kvlen * qlens[query]);
|
||||
#endif
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("attention mask & softmax");
|
||||
#endif
|
||||
if (use_absorb) {
|
||||
output_absorb(query, qlens, page_tables, kvlens);
|
||||
} else {
|
||||
output_no_absorb(query, qlens, page_tables, kvlens);
|
||||
}
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("attn output done[%d]\n", tp_part_idx);
|
||||
|
||||
#endif
|
||||
// 量化输入attention_output_ref->data (attention_output) [config.nope_size * config.num_heads, qlens[query]] col
|
||||
// major
|
||||
{
|
||||
size_t nth = GemmKernel::recommended_nth(qlens[query]);
|
||||
auto task_counter = TaskCounter({nth});
|
||||
auto task = [&](int task_id) {
|
||||
size_t nth_idx = task_counter.at(task_id, 0);
|
||||
// local_w_o_bb->from_mat(attention_output_ref.data, nth_idx, nth, qlens[query]);
|
||||
local_w_o_quant_ba->from_mat(qlens[query], attention_output_ref.data, nth_idx, nth);
|
||||
};
|
||||
DIRECT_OR_POOL_BY(qlens[query], 10, task_counter.count(), task);
|
||||
}
|
||||
|
||||
{
|
||||
// output
|
||||
|
||||
auto qlen_block = div_up((size_t)qlens[query], col_block_out_by_head);
|
||||
auto hidden_size_block = div_up((size_t)config.hidden_size, row_block_out_by_head);
|
||||
KMatRefA local_w_o_ba_ref = KMatRefA(local_w_o_quant_ba->a, config.nope_size * config.num_heads, qlens[query],
|
||||
config.nope_size * config.num_heads, CblasColMajor);
|
||||
KMatRef reduce_qlen_output_ref = output_ref.offset_row(qlen_split[query], qlens[query]).trans_view();
|
||||
KMatRefC reduce_qlen_output_ref_buffer =
|
||||
output_ref_buffer.offset_row(qlen_split[query], qlens[query]).trans_view();
|
||||
|
||||
for (size_t mhead_idx = 0; mhead_idx < config.num_heads / sub_num_heads; mhead_idx++) {
|
||||
auto task_counter = TaskCounter({qlen_block, hidden_size_block});
|
||||
|
||||
auto task = [&](int task_id) {
|
||||
size_t head_begin = config.nope_size * mhead_idx * sub_num_heads;
|
||||
size_t head_end = head_begin + config.nope_size * sub_num_heads;
|
||||
|
||||
size_t qlen_block_idx = task_counter.at(task_id, 0);
|
||||
size_t hidden_size_block_idx = task_counter.at(task_id, 1);
|
||||
|
||||
size_t qlen_begin = qlen_block_idx * col_block_out_by_head;
|
||||
size_t qlen_end = std::min(qlen_begin + col_block_out_by_head, (size_t)qlens[query]);
|
||||
size_t hidden_size_begin = hidden_size_block_idx * row_block_out_by_head;
|
||||
size_t hidden_size_end = std::min(hidden_size_begin + row_block_out_by_head, (size_t)config.hidden_size);
|
||||
KMatRefC this_qlen_output_ref_buffer = reduce_qlen_output_ref_buffer.offset_block(
|
||||
hidden_size_begin, qlen_begin, hidden_size_end - hidden_size_begin, qlen_end - qlen_begin);
|
||||
|
||||
KMatRefB this_local_w_o_ref = local_w_o_ref.offset_block(
|
||||
hidden_size_begin, head_begin, hidden_size_end - hidden_size_begin, head_end - head_begin);
|
||||
KMatRefA this_attention_output_ref =
|
||||
local_w_o_ba_ref.offset_block(head_begin, qlen_begin, head_end - head_begin, qlen_end - qlen_begin);
|
||||
if (mhead_idx == 0) {
|
||||
// arm_kml::mul_mat_clearc(this_local_w_o_ref, this_attention_output_ref, this_qlen_output_ref_buffer);
|
||||
arm_kml::mul_mat_clearc(this_attention_output_ref.trans_view(), this_local_w_o_ref.trans_view(),
|
||||
this_qlen_output_ref_buffer.trans_view());
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
dump_bin(file_name + "_output_by_head_" + std::to_string(mhead_idx), this_qlen_output_ref_buffer.data,
|
||||
qlens[query] * config.hidden_size);
|
||||
// printf("if pack: %d\n", this_local_w_o_ref.trans_view().if_pack);
|
||||
#endif
|
||||
} else {
|
||||
// arm_kml::mul_mat(this_local_w_o_ref, this_attention_output_ref, this_qlen_output_ref_buffer);
|
||||
arm_kml::mul_mat(this_attention_output_ref.trans_view(), this_local_w_o_ref.trans_view(),
|
||||
this_qlen_output_ref_buffer.trans_view());
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
dump_bin(file_name + "_output_by_head_" + std::to_string(mhead_idx), this_qlen_output_ref_buffer.data,
|
||||
qlens[query] * config.hidden_size);
|
||||
// printf("if pack: %d\n", this_local_w_o_ref.trans_view().if_pack);
|
||||
#endif
|
||||
}
|
||||
if (mhead_idx == config.num_heads / sub_num_heads - 1) {
|
||||
GemmKernel::apply_scale(reduce_qlen_output_ref.data, reduce_qlen_output_ref.ld, local_w_o_quant_ba,
|
||||
local_w_o_quant_bb, local_w_o_prefill_bc, qlen_begin, qlen_end, hidden_size_begin,
|
||||
hidden_size_end, true, qlen_split[query], 0);
|
||||
// GemmKernel::apply_scale(reduce_qlen_output_ref.data, reduce_qlen_output_ref.ld, local_w_o_quant_bb,
|
||||
// local_w_o_quant_ba, local_w_o_prefill_bc, hidden_size_begin, hidden_size_end,
|
||||
// qlen_begin, qlen_end, false, 0, qlen_split[query]);
|
||||
}
|
||||
};
|
||||
pool->do_work_stealing_job(task_counter.count(), task);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_THIS_MLA
|
||||
printf("output by head done[%d]\n", tp_part_idx);
|
||||
|
||||
// dump_bin(file_name + "_local_w_o", local_w_o, config.hidden_size * config.nope_size * config.num_heads);
|
||||
dump_bin(file_name + "_qlen_output", (A*)output_raw, qlens[query] * config.hidden_size);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
PROFILE_RECORD_TIME_STAMP("output by head");
|
||||
#endif
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
time_perf_name = "[mla] layer " + std::to_string(config.layer_idx) +
|
||||
" tp_part_idx: " + std::to_string(tp_part_idx) + ", query: " + std::to_string(query);
|
||||
perf_report();
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,31 +11,20 @@
|
||||
#ifndef CPUINFER_OPERATOR_KVCACHE_H
|
||||
#define CPUINFER_OPERATOR_KVCACHE_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <condition_variable>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "../../cpu_backend/worker_pool.h"
|
||||
#include "llama.cpp/ggml-common.h"
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
#include "llama.cpp/ggml-quants.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#include "llamafile/sgemm.h"
|
||||
|
||||
#define CHUNK_SIZE 32
|
||||
|
||||
|
||||
@@ -9,8 +9,11 @@
|
||||
**/
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "kvcache.h"
|
||||
#include "llamafile/sgemm.h"
|
||||
|
||||
void KVCache::attention_kvhead_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,
|
||||
WorkerPool* backend) {
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
**/
|
||||
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "kvcache.h"
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "kvcache.h"
|
||||
|
||||
void KVCache::get_anchor_one_block(ggml_fp16_t* anchor, int layer_id, int block_idx, WorkerPool* backend) {
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "kvcache.h"
|
||||
|
||||
std::string ggml_type_to_string(ggml_type type) {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#ifndef LLAMAFILE_MOE_HPP
|
||||
#define LLAMAFILE_MOE_HPP
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
#include <fmt/format.h>
|
||||
#endif
|
||||
#include <numa.h>
|
||||
#include <numaif.h>
|
||||
|
||||
@@ -10,8 +12,6 @@
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "../../cpu_backend/shared_mem_buffer.h"
|
||||
@@ -737,8 +737,7 @@ class TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {
|
||||
public:
|
||||
using TP_MOE_Common<LLAMA_MOE_TP>::TP_MOE_Common;
|
||||
|
||||
void load_weights(const uint64_t* physical_to_logical_map) {
|
||||
throw std::runtime_error("[llamafile] physical_to_logical_map not supported");
|
||||
void load_weights() {
|
||||
auto pool = this->config.pool;
|
||||
auto inter = this->config.intermediate_size / this->tp_count;
|
||||
pool->dispense_backend()->do_numa_job([this, pool, inter](int tp_id) {
|
||||
|
||||
@@ -20,7 +20,6 @@ concept MOE_TP_PART = requires(T t, int qlen, int k, const int64_t* expert_ids,
|
||||
template <MOE_TP_PART T>
|
||||
class TP_MOE_Common : public MoE_Interface {
|
||||
protected:
|
||||
GeneralMOEConfig config;
|
||||
std::vector<GeneralMOEConfig> tp_configs;
|
||||
int tp_count;
|
||||
int me_numa_id;
|
||||
@@ -35,6 +34,7 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
size_t forward_count = 0;
|
||||
#endif
|
||||
public:
|
||||
GeneralMOEConfig config;
|
||||
using input_t = typename T::input_t;
|
||||
TP_MOE_Common(GeneralMOEConfig config) : config(config) {
|
||||
printf("TP MOE layer %d, pool: 0x%lx, expert num: %d, num_experts_per_tok: %d\n", config.layer_idx,
|
||||
@@ -142,7 +142,7 @@ class TP_MOE_Common : public MoE_Interface {
|
||||
#endif
|
||||
}
|
||||
|
||||
virtual void load_weights(const uint64_t* physical_to_logical_map) = 0;
|
||||
virtual void load_weights() = 0;
|
||||
|
||||
virtual void merge_results(int qlen, void* output) = 0;
|
||||
virtual void merge_results(int qlen, void* output, bool incremental) {
|
||||
|
||||
63
kt-kernel/operators/moe_kernel/api/common.h
Normal file
63
kt-kernel/operators/moe_kernel/api/common.h
Normal file
@@ -0,0 +1,63 @@
|
||||
// BOOST_STRONG_TYPEDEF(int8_t, int4_2_t);
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
|
||||
#include "llama.cpp/ggml.h"
|
||||
#if !defined(CPUINFER_HAS_FLOAT16_T)
|
||||
using float16_t = ggml_fp16_t;
|
||||
#define CPUINFER_HAS_FLOAT16_T 1
|
||||
#endif
|
||||
|
||||
#if !defined(CPUINFER_HAS_BFLOAT16_T)
|
||||
using bfloat16_t = ggml_bf16_t;
|
||||
#define CPUINFER_HAS_BFLOAT16_T 1
|
||||
#endif // CPUINFER_HAS_BFLOAT16_T
|
||||
const bool PACKED = true;
|
||||
#if defined(__aarch64__) || defined(__arm__) || defined(CPU_USE_KML)
|
||||
#ifndef CPU_USE_KML
|
||||
#define CPU_USE_KML
|
||||
#endif
|
||||
#endif // USE_MOE_KERNEL_AMD or CPU_USE_KML
|
||||
|
||||
#define STRONG_TYPEDEF(T, D) \
|
||||
struct D { \
|
||||
T t; \
|
||||
explicit D(const T &v) : t(v) {} \
|
||||
D() = default; \
|
||||
D(const D &) = default; \
|
||||
D &operator=(const D &) = default; \
|
||||
D &operator=(const T &rhs) { \
|
||||
t = rhs; \
|
||||
return *this; \
|
||||
} \
|
||||
operator const T &() const { return t; } \
|
||||
operator T &() { return t; } \
|
||||
bool operator==(const D &rhs) const { return t == rhs.t; } \
|
||||
bool operator!=(const D &rhs) const { return t != rhs.t; } \
|
||||
bool operator<(const D &rhs) const { return t < rhs.t; } \
|
||||
};
|
||||
STRONG_TYPEDEF(int8_t, int4_2_t)
|
||||
typedef int8_t BLASINT8;
|
||||
|
||||
/* matrix transpose or conjugate transpose */
|
||||
typedef enum KERNEL_CBLAS_TRANSPOSE {
|
||||
KernelCblasNoTrans = 111,
|
||||
KernelCblasTrans = 112,
|
||||
KernelCblasConjTrans = 113,
|
||||
KernelCblasConjNoTrans = 114
|
||||
} KERNEL_CBLAS_TRANSPOSE;
|
||||
/* matrix stored in rows or cols */
|
||||
typedef enum KERNEL_CBLAS_ORDER { KernelCblasRowMajor = 101, KernelCblasColMajor = 102 } KERNEL_CBLAS_ORDER;
|
||||
/* matrix position is left or right */
|
||||
typedef enum KERNEL_CBLAS_SIDE { KernelCblasLeft = 141, KernelCblasRight = 142 } KERNEL_CBLAS_SIDE;
|
||||
typedef KERNEL_CBLAS_ORDER KERNEL_CBLAS_LAYOUT;
|
||||
typedef enum KERNEL_CBLAS_OFFSET {
|
||||
KernelCblasRowOffset = 171,
|
||||
KernelCblasColOffset = 172,
|
||||
KernelCblasFixOffset = 173
|
||||
} KERNEL_CBLAS_OFFSET;
|
||||
|
||||
enum class MatKernelVariant {
|
||||
Decode,
|
||||
Prefill,
|
||||
};
|
||||
30
kt-kernel/operators/moe_kernel/api/mat_kernel.h
Normal file
30
kt-kernel/operators/moe_kernel/api/mat_kernel.h
Normal file
@@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
using GemmFn = void (*)(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta, int32_t* c,
|
||||
const size_t ldc, const int32_t* oc);
|
||||
|
||||
struct MatKernelSelection {
|
||||
GemmFn fn;
|
||||
int divide_elements_size;
|
||||
};
|
||||
|
||||
MatKernelSelection select_kernel_for_int4(MatKernelVariant variant);
|
||||
MatKernelSelection select_kernel_for_int8(MatKernelVariant variant);
|
||||
|
||||
template <typename T>
|
||||
MatKernelSelection select_mat_kernel(MatKernelVariant variant) {
|
||||
if constexpr (std::is_same_v<typename T::dt, int4_2_t>) {
|
||||
return select_kernel_for_int4(variant);
|
||||
} else {
|
||||
return select_kernel_for_int8(variant);
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,5 @@
|
||||
#ifndef CPUINFER_OPERATOR_KML_LA_HPP
|
||||
#define CPUINFER_OPERATOR_KML_LA_HPP
|
||||
|
||||
#include "batch_gemm_api.hpp"
|
||||
// #include <boost/serialization/strong_typedef.hpp>
|
||||
|
||||
// #include "../../common.hpp"
|
||||
#include <arm_sve.h>
|
||||
#ifndef CPUINFER_OPERATOR_KERNEL_LA_HPP
|
||||
#define CPUINFER_OPERATOR_KERNEL_LA_HPP
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -13,35 +7,16 @@
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "kblas.h"
|
||||
#include "../api/common.h"
|
||||
#include "../mat_kernel/batch_gemm_api.hpp"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
// BOOST_STRONG_TYPEDEF(int8_t, int4_2_t);
|
||||
#define STRONG_TYPEDEF(T, D) \
|
||||
struct D { \
|
||||
T t; \
|
||||
explicit D(const T &v) : t(v) {} \
|
||||
D() = default; \
|
||||
D(const D &) = default; \
|
||||
D &operator=(const D &) = default; \
|
||||
D &operator=(const T &rhs) { \
|
||||
t = rhs; \
|
||||
return *this; \
|
||||
} \
|
||||
operator const T &() const { return t; } \
|
||||
operator T &() { return t; } \
|
||||
bool operator==(const D &rhs) const { return t == rhs.t; } \
|
||||
bool operator!=(const D &rhs) const { return t != rhs.t; } \
|
||||
bool operator<(const D &rhs) const { return t < rhs.t; } \
|
||||
};
|
||||
STRONG_TYPEDEF(int8_t, int4_2_t);
|
||||
|
||||
namespace arm_kml {
|
||||
static const size_t MAX_Nth_B = 1024, MAX_N_B = 1024, MAX_K_B = 10240;
|
||||
namespace moe_kernel {
|
||||
template <typename T>
|
||||
T *offset_pointer(T *ptr, size_t byte_offset) {
|
||||
return reinterpret_cast<T *>(reinterpret_cast<char *>(ptr) + byte_offset);
|
||||
@@ -73,7 +48,8 @@ struct BufferAImpl {
|
||||
|
||||
static constexpr int M_STEP = K::M_STEP;
|
||||
static constexpr int K_STEP = K::K_STEP;
|
||||
static constexpr int K_BLOCK = K::K_BLOCK;
|
||||
// K_BLOCK is runtime-configurable via kernel tiling; expose as function to avoid constexpr requirements
|
||||
static inline int K_BLOCK() { return K::K_BLOCK; }
|
||||
static constexpr int PACK_SIZE_M = K::PACK_SIZE_M;
|
||||
static constexpr int PACK_SIZE_K = K::PACK_SIZE_K;
|
||||
|
||||
@@ -265,37 +241,6 @@ struct BufferAImpl {
|
||||
}
|
||||
}
|
||||
|
||||
void from_mat(int m, float16_t *src, int ith, int mth) {
|
||||
assert(m <= max_m);
|
||||
assert(ith == 0 && mth == 1);
|
||||
if (!(ith == 0 && mth == 1)) {
|
||||
throw std::runtime_error("m must be a multiple of M_STEP");
|
||||
}
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
float amax = 0;
|
||||
// TODO: 后续用 SVE 来加速
|
||||
for (int j = 0; j < k; j++) {
|
||||
// 先把 src 转换成 float
|
||||
float f = src[(m_begin + i) * k + j];
|
||||
f = f < 0 ? -f : f;
|
||||
if (f > amax) {
|
||||
amax = f;
|
||||
}
|
||||
}
|
||||
d[m_begin + i] = amax / ((1 << 7) - 1);
|
||||
// TODO: 后续用 SVE 来加速
|
||||
// 通过这个 amax 来量化这一行
|
||||
for (int j = 0; j < k; j++) {
|
||||
// 先把 src 转换成 float
|
||||
float f = src[(m_begin + i) * k + j];
|
||||
// 这里的 amax 是当前行的最大值
|
||||
a[(m_begin + i) * k + j] = static_cast<int8_t>(std::round(f / d[m_begin + i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 反量化
|
||||
void to_mat(int m, float *dst, int ith, int mth) {
|
||||
auto [m_start, m_end] = K::split_range_m(m, ith, mth);
|
||||
@@ -323,7 +268,8 @@ struct BufferCImpl {
|
||||
|
||||
static constexpr int M_STEP = K::M_STEP;
|
||||
static constexpr int N_STEP = K::N_STEP;
|
||||
static constexpr int N_BLOCK = K::N_BLOCK;
|
||||
// N_BLOCK is runtime-configurable via kernel tiling; expose as function to avoid constexpr requirements
|
||||
static inline int N_BLOCK() { return K::N_BLOCK; }
|
||||
|
||||
static size_t required_size(int max_m, int n) { return sizeof(int32_t) * max_m * n; }
|
||||
|
||||
@@ -347,27 +293,6 @@ struct BufferCImpl {
|
||||
// }
|
||||
};
|
||||
|
||||
// struct MLAGemmKernelInt8 {
|
||||
// using dt = int8_t;
|
||||
// using output_t = int32_t;
|
||||
// struct BufferA {
|
||||
// int8_t *a;
|
||||
// float *d;
|
||||
// int max_m, k;
|
||||
|
||||
// BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {
|
||||
// assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
// assert(max_m % GemmKernelInt8::M_STEP == 0);
|
||||
// assert(k % GemmKernelInt8::K_STEP == 0);
|
||||
// a = reinterpret_cast<int8_t *>(ptr);
|
||||
// d = reinterpret_cast<float *>(a + max_m * k);
|
||||
// }
|
||||
|
||||
// void from_mat(int m, float16_t *src, int ith, int nth) {}
|
||||
// float *get_scale(int m, int m_begin) { return d + m_begin; }
|
||||
// };
|
||||
// };
|
||||
|
||||
struct GemmKernelInt8 {
|
||||
using dt = int8_t;
|
||||
using output_t = int32_t;
|
||||
@@ -385,12 +310,31 @@ struct GemmKernelInt8 {
|
||||
static const int K_STEP = 1;
|
||||
|
||||
// static inline const int N_BLOCK = 1024;
|
||||
static inline const int N_BLOCK_UP_GATE = 256;
|
||||
static inline const int N_BLOCK_DOWN = 1024;
|
||||
static inline const int N_BLOCK = 64;
|
||||
static inline const int M_BLOCK = 64;
|
||||
// Make tiling params runtime-configurable (modifiable via Python bindings)
|
||||
static inline int N_BLOCK_UP_GATE = 32;
|
||||
static inline int N_BLOCK_DOWN = 64;
|
||||
static inline int N_BLOCK_UP_GATE_PREFI = 32;
|
||||
static inline int N_BLOCK_DOWN_PREFI = 64;
|
||||
static inline int N_BLOCK = 64;
|
||||
static inline int M_BLOCK = 320;
|
||||
// static inline const int N_BLOCK = 32;
|
||||
static inline const int K_BLOCK = 7168;
|
||||
static inline int K_BLOCK = 7168;
|
||||
|
||||
// Setter/getter for runtime tiling configuration
|
||||
static void set_tiling(int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block,
|
||||
int n_block_up_gate_prefi, int n_block_down_prefi) {
|
||||
N_BLOCK_UP_GATE = n_block_up_gate;
|
||||
N_BLOCK_DOWN = n_block_down;
|
||||
N_BLOCK = n_block;
|
||||
M_BLOCK = m_block;
|
||||
K_BLOCK = k_block;
|
||||
N_BLOCK_UP_GATE_PREFI = n_block_up_gate_prefi;
|
||||
N_BLOCK_DOWN_PREFI = n_block_down_prefi;
|
||||
}
|
||||
static std::tuple<int, int, int, int, int, int, int> get_tiling() {
|
||||
return std::make_tuple(N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK, N_BLOCK_UP_GATE_PREFI,
|
||||
N_BLOCK_DOWN_PREFI);
|
||||
}
|
||||
|
||||
static inline const int PACK_SIZE_N = 8;
|
||||
static inline const int PACK_SIZE_M = 8;
|
||||
@@ -398,26 +342,44 @@ struct GemmKernelInt8 {
|
||||
|
||||
static std::string name() { return "INT8"; }
|
||||
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
|
||||
|
||||
static int recommended_nth_down(int n) {
|
||||
assert(n % N_BLOCK == 0);
|
||||
return n / N_BLOCK_DOWN;
|
||||
// type_: d for decode, p for prefill
|
||||
static int recommended_nth_down(int n, char type_ = 'd') {
|
||||
if (type_ == 'p') {
|
||||
if (n % N_BLOCK_DOWN_PREFI != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_DOWN_PREFI in prefill");
|
||||
}
|
||||
return n / N_BLOCK_DOWN_PREFI;
|
||||
} else {
|
||||
if (n % N_BLOCK_DOWN != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_DOWN in decode");
|
||||
}
|
||||
return n / N_BLOCK_DOWN;
|
||||
}
|
||||
}
|
||||
|
||||
static int recommended_nth_up_gate(int n) {
|
||||
assert(n % N_BLOCK_UP_GATE == 0);
|
||||
return n / N_BLOCK_UP_GATE;
|
||||
static int recommended_nth_up_gate(int n, char type_ = 'd') {
|
||||
if (type_ == 'p') {
|
||||
if (n % N_BLOCK_UP_GATE_PREFI != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_UP_GATE_PREFI in prefill");
|
||||
}
|
||||
return n / N_BLOCK_UP_GATE_PREFI;
|
||||
} else {
|
||||
if (n % N_BLOCK_UP_GATE != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_UP_GATE in decode");
|
||||
}
|
||||
return n / N_BLOCK_UP_GATE;
|
||||
}
|
||||
}
|
||||
|
||||
static int recommended_mth(int m) { return (m + M_BLOCK - 1) / M_BLOCK; }
|
||||
|
||||
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
|
||||
int n_start = N_BLOCK * ith;
|
||||
int n_end = std::min(n, N_BLOCK * (ith + 1));
|
||||
static std::pair<int, int> split_range_n(int n, int ith, int nth, int block_size = N_BLOCK) {
|
||||
int n_start = block_size * ith;
|
||||
int n_end = std::min(n, block_size * (ith + 1));
|
||||
return {n_start, n_end};
|
||||
}
|
||||
|
||||
static std::pair<int, int> split_range_m(int m, int ith, int mth) {
|
||||
static std::pair<int, int> split_range_m(int m, int ith, int mth = 0) {
|
||||
int m_start = M_BLOCK * ith;
|
||||
int m_end = std::min(m, M_BLOCK * (ith + 1));
|
||||
return {m_start, m_end};
|
||||
@@ -434,26 +396,83 @@ struct GemmKernelInt8 {
|
||||
|
||||
struct BufferB {
|
||||
int8_t *b;
|
||||
std::vector<int8_t *> b_pack; // b_pack[i] -> the ith block (the ith packed matrix of B)
|
||||
size_t reorder_B_size;
|
||||
size_t nth_B; // number of blocks of B
|
||||
size_t block_size; // size of each block of B
|
||||
float *d;
|
||||
int n, k;
|
||||
static constexpr bool SCALE = true;
|
||||
bool if_pack = false;
|
||||
|
||||
static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k + sizeof(float) * n; }
|
||||
|
||||
BufferB(int n, int k, void *ptr, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
b = reinterpret_cast<int8_t *>(ptr);
|
||||
d = reinterpret_cast<float *>(b + n * k);
|
||||
// n for normal, u for up_gate, d for down
|
||||
static size_t required_size(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) {
|
||||
int nth, n_block;
|
||||
if (if_pack && !plain) {
|
||||
switch (mat_type) {
|
||||
case 'n':
|
||||
nth = recommended_nth(n);
|
||||
n_block = N_BLOCK;
|
||||
break;
|
||||
case 'u':
|
||||
nth = recommended_nth_up_gate(n);
|
||||
n_block = N_BLOCK_UP_GATE;
|
||||
break;
|
||||
case 'd':
|
||||
nth = recommended_nth_down(n);
|
||||
n_block = N_BLOCK_DOWN;
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Invalid mat_type");
|
||||
}
|
||||
size_t reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);
|
||||
return sizeof(int8_t) * nth * reorder_B_size + sizeof(float) * n;
|
||||
} else {
|
||||
return sizeof(int8_t) * n * k + sizeof(float) * n;
|
||||
}
|
||||
}
|
||||
BufferB(int n, int k, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
|
||||
BufferB(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) : n(n), k(k), if_pack(if_pack) {
|
||||
int nth, n_block;
|
||||
if (if_pack && !plain) {
|
||||
switch (mat_type) {
|
||||
case 'n':
|
||||
nth = recommended_nth(n);
|
||||
n_block = N_BLOCK;
|
||||
break;
|
||||
case 'u':
|
||||
nth = recommended_nth_up_gate(n);
|
||||
n_block = N_BLOCK_UP_GATE;
|
||||
break;
|
||||
case 'd':
|
||||
nth = recommended_nth_down(n);
|
||||
n_block = N_BLOCK_DOWN;
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Invalid mat_type");
|
||||
}
|
||||
reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);
|
||||
nth_B = nth;
|
||||
block_size = n_block;
|
||||
b_pack.resize(nth);
|
||||
}
|
||||
if (n % N_STEP != 0 || k % K_STEP != 0) {
|
||||
throw std::runtime_error("n and k must be multiples of N_STEP and K_STEP respectively");
|
||||
}
|
||||
}
|
||||
void set_data(void *ptr) {
|
||||
b = reinterpret_cast<int8_t *>(ptr);
|
||||
d = reinterpret_cast<float *>(b + n * k);
|
||||
BufferB(int n, int k, void *ptr, bool if_pack = false, char mat_type = 'n', bool plain = true)
|
||||
: BufferB(n, k, if_pack, mat_type, plain) {
|
||||
set_data(ptr, plain);
|
||||
// printf("mat_type:%c,nth_B:%zu,b_pack_ptr[0]:%p,d_ptr:%p,ptr:%p\n", mat_type, nth_B, b_pack[0], d, ptr);
|
||||
}
|
||||
void set_data(void *ptr, bool plain = true) {
|
||||
if (if_pack && !plain) {
|
||||
for (size_t i = 0; i < nth_B; i++) {
|
||||
b_pack[i] = reinterpret_cast<int8_t *>(ptr) + i * reorder_B_size;
|
||||
}
|
||||
d = reinterpret_cast<float *>((int8_t *)ptr + nth_B * reorder_B_size);
|
||||
} else {
|
||||
b = reinterpret_cast<int8_t *>(ptr);
|
||||
d = reinterpret_cast<float *>(b + n * k);
|
||||
}
|
||||
}
|
||||
size_t required_size() const { return sizeof(int8_t) * n * k + sizeof(float) * n; }
|
||||
BufferB offset_col(size_t col_begin, size_t col_block) {
|
||||
@@ -462,13 +481,17 @@ struct GemmKernelInt8 {
|
||||
return bufferb;
|
||||
}
|
||||
// B 矩阵是 K * N 的矩阵,存储在 b 中, 是列主序的 (column major)
|
||||
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1,
|
||||
bool if_pack = false) { // CHECK: nth has no usage
|
||||
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false,
|
||||
bool plain = true) { // CHECK: nth has no usage
|
||||
if (n_new > 0) {
|
||||
n = n_new; // 如果 n_new 大于 0,则使用 n_new
|
||||
}
|
||||
// 这里将 src 转换成 int8_t 的形式,按照k 维度量化 (也就是按列量化)
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
int8_t *b_t = nullptr;
|
||||
if ((if_pack || this->if_pack) && !plain) {
|
||||
b_t = (int8_t *)malloc(sizeof(int8_t) * n * k);
|
||||
}
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth, block_size);
|
||||
int n_block_begin = n_start;
|
||||
int n_block_size = n_end - n_block_begin;
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
@@ -489,7 +512,7 @@ struct GemmKernelInt8 {
|
||||
for (int j = 0; j < k; j++) {
|
||||
// 先把 src 转换成 float
|
||||
float f = bf16_to_fp32(src[(n_block_begin + n_begin + i) * k + j]);
|
||||
if (if_pack || this->if_pack) {
|
||||
if ((if_pack || this->if_pack) && plain) {
|
||||
size_t split_n = (n_begin + i) / PACK_SIZE_N;
|
||||
size_t n_idx = (n_begin + i) % PACK_SIZE_N;
|
||||
size_t split_k = j / PACK_SIZE_K;
|
||||
@@ -498,58 +521,20 @@ struct GemmKernelInt8 {
|
||||
size_t buff_idx = n_block_begin * k + split_n * PACK_SIZE_N * k + split_k * PACK_SIZE_N * PACK_SIZE_K +
|
||||
n_idx * PACK_SIZE_K + k_idx;
|
||||
b[buff_idx] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
|
||||
} else {
|
||||
} else if ((if_pack || this->if_pack) && !plain) {
|
||||
// 这里的 amax 是当前列的最大值
|
||||
b_t[(n_begin + i) * k + j] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
|
||||
} else {
|
||||
b[(n_block_begin + n_begin + i) * k + j] =
|
||||
static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void from_mat(float16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false) { // CHECK: nth has no usage
|
||||
if (n_new > 0) {
|
||||
n = n_new; // 如果 n_new 大于 0,则使用 n_new
|
||||
}
|
||||
// 这里将 src 转换成 int8_t 的形式,按照k 维度量化 (也就是按列量化)
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
int n_block_begin = n_start;
|
||||
int n_block_size = n_end - n_block_begin;
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {
|
||||
float amax = 0;
|
||||
// TODO: 后续用 SVE 来加速
|
||||
for (int j = 0; j < k; j++) {
|
||||
// 先把 src 转换成 float
|
||||
float f = src[(n_block_begin + n_begin + i) * k + j];
|
||||
f = f < 0 ? -f : f;
|
||||
if (f > amax) {
|
||||
amax = f;
|
||||
}
|
||||
}
|
||||
d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);
|
||||
// TODO: 后续用 SVE 来加速
|
||||
// 通过这个 amax 来量化这一列
|
||||
for (int j = 0; j < k; j++) {
|
||||
// 先把 src 转换成 float
|
||||
float f = src[(n_block_begin + n_begin + i) * k + j];
|
||||
if (if_pack || this->if_pack) {
|
||||
size_t split_n = (n_begin + i) / PACK_SIZE_N;
|
||||
size_t n_idx = (n_begin + i) % PACK_SIZE_N;
|
||||
size_t split_k = j / PACK_SIZE_K;
|
||||
size_t k_idx = j % PACK_SIZE_K;
|
||||
|
||||
size_t buff_idx = n_block_begin * k + split_n * PACK_SIZE_N * k + split_k * PACK_SIZE_N * PACK_SIZE_K +
|
||||
n_idx * PACK_SIZE_K + k_idx;
|
||||
b[buff_idx] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
|
||||
} else {
|
||||
// 这里的 amax 是当前列的最大值
|
||||
b[(n_block_begin + n_begin + i) * k + j] =
|
||||
static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((if_pack || this->if_pack) && !plain) {
|
||||
// 在这里调用 AMD 的reorder函数
|
||||
reorder_B_gemm(KernelCblasColMajor, KernelCblasNoTrans, k, n_block_size, k, b_t, b_pack[ith]);
|
||||
free(b_t);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -704,12 +689,19 @@ struct GemmKernelInt8 {
|
||||
}
|
||||
|
||||
// 对第二个维度分块的 apply scale
|
||||
static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc, int ith, int nth, int block) {
|
||||
static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc, int ith, int nth, int block,
|
||||
int jth = -1) {
|
||||
// printf("use split apply scale\n");
|
||||
auto [n_start, n_end] = split_range_n_block(n, ith, nth, block);
|
||||
int m_start = 0, m_end = m;
|
||||
if (jth != -1) {
|
||||
auto tmp = split_range_m(m, jth);
|
||||
m_start = tmp.first;
|
||||
m_end = tmp.second;
|
||||
}
|
||||
// TODO: 后续用 SVE 来加速
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {
|
||||
float *scale_a = ba->get_scale(m, m_begin + i);
|
||||
for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {
|
||||
for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {
|
||||
@@ -811,12 +803,31 @@ struct GemmKernelInt4 {
|
||||
static const int K_STEP = 1;
|
||||
|
||||
// static inline const int N_BLOCK = 1024;
|
||||
static inline const int N_BLOCK_UP_GATE = 256;
|
||||
static inline const int N_BLOCK_DOWN = 1024;
|
||||
static inline const int N_BLOCK = 64;
|
||||
static inline const int M_BLOCK = 64;
|
||||
// Make tiling params runtime-configurable (modifiable via Python bindings)
|
||||
static inline int N_BLOCK_UP_GATE = 256;
|
||||
static inline int N_BLOCK_DOWN = 1024;
|
||||
static inline int N_BLOCK_UP_GATE_PREFI = 256;
|
||||
static inline int N_BLOCK_DOWN_PREFI = 1024;
|
||||
static inline int N_BLOCK = 64;
|
||||
static inline int M_BLOCK = 320;
|
||||
// static inline const int N_BLOCK = 32;
|
||||
static inline const int K_BLOCK = 7168;
|
||||
static inline int K_BLOCK = 7168;
|
||||
|
||||
// Setter/getter for runtime tiling configuration
|
||||
static void set_tiling(int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block,
|
||||
int n_block_up_gate_prefi, int n_block_down_prefi) {
|
||||
N_BLOCK_UP_GATE = n_block_up_gate;
|
||||
N_BLOCK_DOWN = n_block_down;
|
||||
N_BLOCK = n_block;
|
||||
M_BLOCK = m_block;
|
||||
K_BLOCK = k_block;
|
||||
N_BLOCK_UP_GATE_PREFI = n_block_up_gate_prefi;
|
||||
N_BLOCK_DOWN_PREFI = n_block_down_prefi;
|
||||
}
|
||||
static std::tuple<int, int, int, int, int, int, int> get_tiling() {
|
||||
return std::make_tuple(N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK, N_BLOCK_UP_GATE_PREFI,
|
||||
N_BLOCK_DOWN_PREFI);
|
||||
}
|
||||
|
||||
static inline const int PACK_SIZE_N = 8;
|
||||
static inline const int PACK_SIZE_K = 32;
|
||||
@@ -825,15 +836,33 @@ struct GemmKernelInt4 {
|
||||
static std::string name() { return "INT4"; }
|
||||
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
|
||||
|
||||
static int recommended_nth_down(int n) {
|
||||
assert(n % N_BLOCK == 0);
|
||||
return n / N_BLOCK_DOWN;
|
||||
static int recommended_nth_down(int n, char type_ = 'd') {
|
||||
if (type_ == 'p') {
|
||||
if (n % N_BLOCK_DOWN_PREFI != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_DOWN_PREFI in prefill");
|
||||
}
|
||||
return n / N_BLOCK_DOWN_PREFI;
|
||||
} else {
|
||||
if (n % N_BLOCK_DOWN != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_DOWN in decode");
|
||||
}
|
||||
return n / N_BLOCK_DOWN;
|
||||
}
|
||||
}
|
||||
static int recommended_mth(int m) { return (m + M_BLOCK - 1) / M_BLOCK; }
|
||||
|
||||
static int recommended_nth_up_gate(int n) {
|
||||
assert(n % N_BLOCK_UP_GATE == 0);
|
||||
return n / N_BLOCK_UP_GATE;
|
||||
static int recommended_nth_up_gate(int n, char type_ = 'd') {
|
||||
if (type_ == 'p') {
|
||||
if (n % N_BLOCK_UP_GATE_PREFI != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_UP_GATE_PREFI in prefill");
|
||||
}
|
||||
return n / N_BLOCK_UP_GATE_PREFI;
|
||||
} else {
|
||||
if (n % N_BLOCK_UP_GATE != 0) {
|
||||
throw std::invalid_argument("n must be multiple of N_BLOCK_UP_GATE in decode");
|
||||
}
|
||||
return n / N_BLOCK_UP_GATE;
|
||||
}
|
||||
}
|
||||
|
||||
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
|
||||
@@ -860,38 +889,63 @@ struct GemmKernelInt4 {
|
||||
dt *b;
|
||||
float *d;
|
||||
int n, k;
|
||||
std::vector<int8_t *> b_pack; // b_pack[i] -> the ith block (the ith packed matrix of B)
|
||||
static constexpr bool SCALE = true;
|
||||
bool if_pack = false;
|
||||
|
||||
static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }
|
||||
|
||||
BufferB(int n, int k, void *ptr, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(n % N_STEP == 0);
|
||||
assert(k % K_STEP == 0);
|
||||
b = reinterpret_cast<dt *>(ptr);
|
||||
d = reinterpret_cast<float *>(arm_kml::offset_pointer(b, n * k / 2));
|
||||
// static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }
|
||||
static size_t required_size(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) {
|
||||
int nth, n_block;
|
||||
if (if_pack && !plain) {
|
||||
switch (mat_type) {
|
||||
case 'n':
|
||||
nth = recommended_nth(n);
|
||||
n_block = N_BLOCK;
|
||||
break;
|
||||
case 'u':
|
||||
nth = recommended_nth_up_gate(n);
|
||||
n_block = N_BLOCK_UP_GATE;
|
||||
break;
|
||||
case 'd':
|
||||
nth = recommended_nth_down(n);
|
||||
n_block = N_BLOCK_DOWN;
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Invalid mat_type");
|
||||
}
|
||||
size_t reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);
|
||||
return sizeof(int8_t) * nth * reorder_B_size + sizeof(float) * n;
|
||||
} else {
|
||||
return sizeof(int8_t) * n * k / 2 + sizeof(float) * n;
|
||||
}
|
||||
}
|
||||
|
||||
BufferB(int n, int k, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
|
||||
// assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(n % N_STEP == 0);
|
||||
assert(k % K_STEP == 0);
|
||||
// BufferB(int n, int k, void *ptr, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {
|
||||
// b = reinterpret_cast<dt *>(ptr);
|
||||
// d = reinterpret_cast<float *>(moe_kernel::offset_pointer(b, n * k / 2));
|
||||
// }
|
||||
BufferB(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) : n(n), k(k), if_pack(if_pack) {
|
||||
if (n % N_STEP != 0 || k % K_STEP != 0) {
|
||||
throw std::runtime_error("n and k must be multiples of N_STEP and K_STEP respectively");
|
||||
}
|
||||
}
|
||||
void set_data(void *ptr) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
BufferB(int n, int k, void *ptr, bool if_pack = false, char mat_type = 'n', bool plain = true)
|
||||
: BufferB(n, k, if_pack, mat_type, plain) {
|
||||
set_data(ptr, plain);
|
||||
}
|
||||
void set_data(void *ptr, bool plain = true) {
|
||||
b = reinterpret_cast<dt *>(ptr);
|
||||
d = reinterpret_cast<float *>(arm_kml::offset_pointer(b, n * k / 2));
|
||||
d = reinterpret_cast<float *>(moe_kernel::offset_pointer(b, n * k / 2));
|
||||
}
|
||||
size_t required_size() const { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }
|
||||
BufferB offset_col(size_t col_begin, size_t col_block) {
|
||||
auto bufferb = BufferB(col_block, k, arm_kml::offset_pointer(b, (col_begin * k) / 2), if_pack);
|
||||
auto bufferb = BufferB(col_block, k, moe_kernel::offset_pointer(b, (col_begin * k) / 2), if_pack);
|
||||
bufferb.d = d + col_begin;
|
||||
return bufferb;
|
||||
}
|
||||
// B 矩阵是 K * N 的矩阵,存储在 b 中, 是列主序的 (column major)
|
||||
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1,
|
||||
bool if_pack = false) { // CHECK: nth has no usage
|
||||
void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false,
|
||||
bool plain = true) { // CHECK: nth has no usage
|
||||
if (!if_pack && !this->if_pack) throw std::runtime_error("from mat for buffer should be packed");
|
||||
if (n_new > 0) {
|
||||
n = n_new; // 如果 n_new 大于 0,则使用 n_new
|
||||
@@ -939,55 +993,6 @@ struct GemmKernelInt4 {
|
||||
}
|
||||
}
|
||||
|
||||
void from_mat(float16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false) { // CHECK: nth has no usage
|
||||
if (!if_pack && !this->if_pack) throw std::runtime_error("from mat for buffer should be packed");
|
||||
if (n_new > 0) {
|
||||
n = n_new; // 如果 n_new 大于 0,则使用 n_new
|
||||
}
|
||||
// 这里将 src 转换成 int8_t 的形式,按照k 维度量化 (也就是按列量化)
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
int n_block_begin = n_start;
|
||||
int n_block_size = n_end - n_block_begin;
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {
|
||||
float amax = 0;
|
||||
// TODO: 后续用 SVE 来加速
|
||||
for (int j = 0; j < k; j++) {
|
||||
// 先把 src 转换成 float
|
||||
float f = src[(n_block_begin + n_begin + i) * k + j];
|
||||
f = f < 0 ? -f : f;
|
||||
if (f > amax) {
|
||||
amax = f;
|
||||
}
|
||||
}
|
||||
d[n_block_begin + n_begin + i] = amax / 112.0;
|
||||
// TODO: 后续用 SVE 来加速
|
||||
// 通过这个 amax 来量化这一列
|
||||
for (int k_start = 0; k_start < k; k_start += (PACK_SIZE_K * 2)) {
|
||||
for (int j = 0; j < PACK_SIZE_K; j++) {
|
||||
size_t split_n = (n_begin + i) / PACK_SIZE_N;
|
||||
size_t n_idx = (n_begin + i) % PACK_SIZE_N;
|
||||
size_t split_k = k_start / (PACK_SIZE_K * 2);
|
||||
size_t k_idx = j;
|
||||
|
||||
size_t buff_idx = n_block_begin * k / 2 + split_n * PACK_SIZE_N * k / 2 +
|
||||
split_k * PACK_SIZE_N * PACK_SIZE_K + n_idx * PACK_SIZE_K + k_idx;
|
||||
|
||||
float f0 = (src[(n_block_begin + n_begin + i) * k + k_start + j]);
|
||||
float f1 = (src[(n_block_begin + n_begin + i) * k + k_start + j + PACK_SIZE_K]);
|
||||
// static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));
|
||||
int8_t b0 = static_cast<int8_t>(std::round((f0 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);
|
||||
int8_t b1 = static_cast<int8_t>(std::round((f1 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);
|
||||
int8_t b01 = (b0 & 0xF0) | ((b1 >> 4) & 0x0F);
|
||||
// int8_t b01 = ((b0 << 4) & 0xF0) | ((b1)&0x0F);
|
||||
|
||||
b[buff_idx] = b01;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void from_mat(float *src, int ith, int nth, int n_new = -1, bool if_pack = false) { // CHECK: nth has no usage
|
||||
if (!if_pack && !this->if_pack) throw std::runtime_error("from mat for buffer should be packed");
|
||||
if (n_new > 0) {
|
||||
@@ -1190,212 +1195,6 @@ struct GemmKernelInt4 {
|
||||
}
|
||||
};
|
||||
|
||||
inline CBLAS_TRANSPOSE flip_trans(CBLAS_TRANSPOSE trans) {
|
||||
if (trans == CblasNoTrans) {
|
||||
return CblasTrans;
|
||||
} else if (trans == CblasTrans) {
|
||||
return CblasNoTrans;
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported transpose");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
struct MatRef {
|
||||
F *data = nullptr;
|
||||
size_t R, C, ld;
|
||||
CBLAS_ORDER order;
|
||||
CBLAS_TRANSPOSE trans;
|
||||
bool if_pack = false;
|
||||
static inline const int PACK_SIZE_N = 8;
|
||||
static inline const int PACK_SIZE_M = 8;
|
||||
static inline const int PACK_SIZE_K = 32;
|
||||
|
||||
MatRef() {}
|
||||
MatRef(F *data, size_t R, size_t C, size_t ld, CBLAS_ORDER order, CBLAS_TRANSPOSE trans = CblasNoTrans,
|
||||
bool if_pack = false)
|
||||
: data(data), R(R), C(C), ld(ld), order(order), trans(trans), if_pack(if_pack) {}
|
||||
|
||||
MatRef t() {
|
||||
MatRef re = *this;
|
||||
std::swap(re.R, re.C);
|
||||
CBLAS_ORDER new_order = (order == CblasRowMajor) ? CblasColMajor : CblasRowMajor;
|
||||
re.order = new_order;
|
||||
return re;
|
||||
}
|
||||
|
||||
CBLAS_TRANSPOSE trans_from(CBLAS_ORDER order) {
|
||||
if (order == this->order) {
|
||||
return trans;
|
||||
} else {
|
||||
return flip_trans(trans);
|
||||
}
|
||||
}
|
||||
|
||||
MatRef offset_block(size_t row, size_t col, size_t r_block, size_t c_block) {
|
||||
if (trans == CblasTrans) {
|
||||
std::swap(row, col);
|
||||
std::swap(r_block, c_block);
|
||||
}
|
||||
int devide_elements_size = 1;
|
||||
if constexpr (std::is_same_v<F, int4_2_t>) devide_elements_size = 2;
|
||||
// printf("devide_elements_size : %d\n", devide_elements_size);
|
||||
if (order == CblasRowMajor) {
|
||||
if (if_pack) {
|
||||
// if (devide_elements_size == 2)
|
||||
// printf("data:%p,after: %p,offset: %d\n", data, data + (row * ld + col * PACK_SIZE_M) /
|
||||
// devide_elements_size,
|
||||
// (row * ld + col * PACK_SIZE_M) / devide_elements_size);
|
||||
return MatRef(data + (row * ld + col * PACK_SIZE_M) / devide_elements_size, r_block, c_block, ld, order,
|
||||
CblasNoTrans, if_pack);
|
||||
} else {
|
||||
return MatRef(data + (row * ld + col) / devide_elements_size, r_block, c_block, ld, order, CblasNoTrans,
|
||||
if_pack);
|
||||
}
|
||||
} else if (order == CblasColMajor) {
|
||||
if (if_pack) {
|
||||
// if (devide_elements_size == 2)
|
||||
// printf("data:%p,after: %p,offset: %d\n", data, data + (col * ld + row * PACK_SIZE_N) /
|
||||
// devide_elements_size,
|
||||
// (col * ld + row * PACK_SIZE_N) / devide_elements_size);
|
||||
return MatRef(data + (col * ld + row * PACK_SIZE_N) / devide_elements_size, r_block, c_block, ld, order,
|
||||
CblasNoTrans, if_pack);
|
||||
} else {
|
||||
return MatRef(data + (col * ld + row) / devide_elements_size, r_block, c_block, ld, order, CblasNoTrans,
|
||||
if_pack);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported order");
|
||||
}
|
||||
}
|
||||
|
||||
inline MatRef trans_view() {
|
||||
if (order == CblasRowMajor) {
|
||||
return MatRef(data, C, R, ld, CblasColMajor, trans, if_pack);
|
||||
} else {
|
||||
return MatRef(data, C, R, ld, CblasRowMajor, trans, if_pack);
|
||||
}
|
||||
}
|
||||
|
||||
MatRef offset_row(size_t row_begin, size_t row_block) { return offset_block(row_begin, 0, row_block, C); }
|
||||
|
||||
MatRef offset_col(size_t col_begin, size_t col_block) { return offset_block(0, col_begin, R, col_block); }
|
||||
|
||||
F &at(size_t row, size_t col) {
|
||||
if (trans == CblasTrans) {
|
||||
throw std::runtime_error("Unsupported trans");
|
||||
}
|
||||
|
||||
if (order == CblasRowMajor) {
|
||||
return data[row * ld + col];
|
||||
} else if (order == CblasColMajor) {
|
||||
return data[col * ld + row];
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported order");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename A, typename B, typename C>
|
||||
static void mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c, C alpha, C beta) {
|
||||
assert(a.C == b.R);
|
||||
assert(a.R == c.R);
|
||||
assert(b.C == c.C);
|
||||
// assert(a.order == b.order);
|
||||
// assert(a.order == c.order);
|
||||
assert(c.trans == CblasNoTrans);
|
||||
BLASINT8 oa = 0, ob = 0;
|
||||
int32_t oc = 0;
|
||||
|
||||
if constexpr (std::is_same_v<A, float> && std::is_same_v<B, float> && std::is_same_v<C, float>) {
|
||||
cblas_sgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data, b.ld,
|
||||
beta, c.data, c.ld);
|
||||
|
||||
} else if constexpr (std::is_same_v<A, float16_t> && std::is_same_v<B, float16_t> && std::is_same_v<C, float16_t>) {
|
||||
cblas_hgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data, b.ld,
|
||||
beta, c.data, c.ld);
|
||||
} else if constexpr (std::is_same_v<A, float16_t> && std::is_same_v<B, float16_t> && std::is_same_v<C, float>) {
|
||||
cblas_shgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data,
|
||||
b.ld, beta, c.data, c.ld);
|
||||
} else if constexpr (std::is_same_v<A, bfloat16_t> && std::is_same_v<B, bfloat16_t> &&
|
||||
std::is_same_v<C, bfloat16_t>) {
|
||||
cblas_bgemm(c.order, a.trans_from(c.order), b.trans_from(c.order), c.R, c.C, a.C, alpha, a.data, a.ld, b.data, b.ld,
|
||||
beta, c.data, c.ld);
|
||||
} else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int8_t> && std::is_same_v<C, int32_t>) {
|
||||
if (b.if_pack) {
|
||||
prefill_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C,
|
||||
alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
|
||||
} else {
|
||||
cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C, alpha,
|
||||
a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
|
||||
}
|
||||
|
||||
} else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int4_2_t> && std::is_same_v<C, int32_t>) {
|
||||
// throw std::runtime_error("INT4 does not support cblas_gemm_s8s8s32, please use decode_cblas_gemm_s8s8s32");
|
||||
if (b.if_pack) {
|
||||
prefill_int4_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C,
|
||||
a.C, alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"INT4 does not support cblas_gemm_s8s8s32 for unpack, please use decode_cblas_gemm_s8s8s32");
|
||||
}
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename C>
|
||||
static void decode_mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c, C alpha, C beta) {
|
||||
assert(a.C == b.R);
|
||||
assert(a.R == c.R);
|
||||
assert(b.C == c.C);
|
||||
// assert(a.order == b.order);
|
||||
// assert(a.order == c.order);
|
||||
assert(c.trans == CblasNoTrans);
|
||||
BLASINT incX = 1, incY = 1;
|
||||
BLASINT8 oa = 0, ob = 0;
|
||||
int32_t oc = 0;
|
||||
if constexpr (std::is_same_v<A, float> && std::is_same_v<B, float> && std::is_same_v<C, float>) {
|
||||
cblas_sgemv(a.order, a.trans, a.R, a.C, alpha, a.data, a.ld, b.data, incX, beta, c.data, incY);
|
||||
} else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int8_t> && std::is_same_v<C, int32_t>) {
|
||||
// printf("debug: c.order: %d, a.order: %d, b.order: %d,c.R: %zu, c.C: %zu, a.C: %zu, alpha: %d, a.ld: %ld, oa: %d,
|
||||
// b.ld: %ld, ob: %d, beta: %d, c.ld: %ld, oc: %d\n",
|
||||
// c.order, a.order, b.order, c.R, c.C, a.C, alpha, a.ld, oa, b.ld, ob, beta, c.ld, oc);
|
||||
if (b.if_pack)
|
||||
decode_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C,
|
||||
alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
|
||||
else
|
||||
throw std::runtime_error("Unsupported type");
|
||||
|
||||
} else if constexpr (std::is_same_v<A, int8_t> && std::is_same_v<B, int4_2_t> && std::is_same_v<C, int32_t>) {
|
||||
decode_int4_cblas_gemm_s8s8s32(c.order, a.trans_from(c.order), b.trans_from(c.order), CblasFixOffset, c.R, c.C, a.C,
|
||||
alpha, a.data, a.ld, oa, b.data, b.ld, ob, beta, c.data, c.ld, &oc);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename C>
|
||||
static void mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
|
||||
mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(1.0));
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename C>
|
||||
static void mul_mat_clearc(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
|
||||
mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(0.0));
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename C>
|
||||
static void decode_mul_mat_clearc(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
|
||||
decode_mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(0.0));
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename C>
|
||||
static void decode_mul_mat(MatRef<A> a, MatRef<B> b, MatRef<C> c) {
|
||||
decode_mul_mat(a, b, c, static_cast<C>(1.0), static_cast<C>(1.0));
|
||||
}
|
||||
|
||||
} // namespace arm_kml
|
||||
} // namespace moe_kernel
|
||||
|
||||
#endif
|
||||
54
kt-kernel/operators/moe_kernel/la/mat_kernel.cpp
Normal file
54
kt-kernel/operators/moe_kernel/la/mat_kernel.cpp
Normal file
@@ -0,0 +1,54 @@
|
||||
#include "../api/mat_kernel.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace {
|
||||
constexpr int kInt4ElementDivisor = 2;
|
||||
constexpr int kInt8ElementDivisor = 1;
|
||||
} // namespace
|
||||
extern "C" {
|
||||
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta,
|
||||
int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta,
|
||||
int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const int8_t oa, const void* b, const size_t ldb, const int8_t ob,
|
||||
const float beta, int32_t* c, const size_t ldc, const int32_t* oc);
|
||||
|
||||
void prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const int8_t oa, const void* b, const size_t ldb,
|
||||
const int8_t ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc);
|
||||
}
|
||||
|
||||
MatKernelSelection select_kernel_for_int4(MatKernelVariant variant) {
|
||||
switch (variant) {
|
||||
case MatKernelVariant::Decode:
|
||||
return {decode_int4_cblas_gemm_s8s8s32, kInt4ElementDivisor};
|
||||
case MatKernelVariant::Prefill:
|
||||
return {prefill_int4_cblas_gemm_s8s8s32, kInt4ElementDivisor};
|
||||
}
|
||||
return {nullptr, 0};
|
||||
}
|
||||
|
||||
MatKernelSelection select_kernel_for_int8(MatKernelVariant variant) {
|
||||
switch (variant) {
|
||||
case MatKernelVariant::Decode:
|
||||
return {decode_cblas_gemm_s8s8s32, kInt8ElementDivisor};
|
||||
case MatKernelVariant::Prefill:
|
||||
return {prefill_cblas_gemm_s8s8s32, kInt8ElementDivisor};
|
||||
}
|
||||
return {nullptr, 0};
|
||||
}
|
||||
29
kt-kernel/operators/moe_kernel/la/utils.hpp
Normal file
29
kt-kernel/operators/moe_kernel/la/utils.hpp
Normal file
@@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
// #include <arm_sve.h>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
// 简单截断模式:直接丢弃低 16 位
|
||||
static inline uint16_t float_to_bf16_trunc(float f) {
|
||||
uint32_t u;
|
||||
// 按位拷贝,避免 strict‑aliasing UB
|
||||
memcpy(&u, &f, sizeof(u)); // :contentReference[oaicite:3]{index=3}
|
||||
return (uint16_t)(u >> 16); // 截断得到高 16 位 :contentReference[oaicite:4]{index=4}
|
||||
}
|
||||
|
||||
static inline void convert_32fp32_to_32bf16_pure_c(const float* src, uint16_t* dst) {
|
||||
// src 已偏移至 token_nth * hidden_size
|
||||
for (int e = 0; e < 32; e++) { // 共 32 个元素
|
||||
// 选择截断或四舍五入
|
||||
dst[e] = float_to_bf16_trunc(src[e]);
|
||||
}
|
||||
}
|
||||
|
||||
// 把 32 个 bf16 元素转换成 32 个 fp32 元素
|
||||
|
||||
static inline void convert_32bf16_to_32fp32_pure_c(const uint16_t* src, float* dst) {
|
||||
for (int e = 0; e < 32; e++) {
|
||||
uint32_t temp = ((uint32_t)src[e]) << 16; // 将 BF16 左移 16 位
|
||||
memcpy(&dst[e], &temp, sizeof(float)); // 将结果复制到 FP32 变量中
|
||||
}
|
||||
}
|
||||
100
kt-kernel/operators/moe_kernel/mat_kernel/aocl_kernel/kernel.cpp
Normal file
100
kt-kernel/operators/moe_kernel/mat_kernel/aocl_kernel/kernel.cpp
Normal file
@@ -0,0 +1,100 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#include "../batch_gemm_api.hpp"
|
||||
#include "blis.h"
|
||||
|
||||
namespace {
|
||||
|
||||
char ToAoclOrder(KERNEL_CBLAS_LAYOUT layout) {
|
||||
switch (layout) {
|
||||
case KernelCblasRowMajor:
|
||||
return 'r';
|
||||
case KernelCblasColMajor:
|
||||
return 'c';
|
||||
}
|
||||
throw std::invalid_argument("Unsupported KERNEL_CBLAS_LAYOUT value");
|
||||
}
|
||||
|
||||
char ToAoclTranspose(KERNEL_CBLAS_TRANSPOSE transpose) {
|
||||
switch (transpose) {
|
||||
case KernelCblasNoTrans:
|
||||
return 'n';
|
||||
case KernelCblasTrans:
|
||||
return 't';
|
||||
case KernelCblasConjTrans:
|
||||
case KernelCblasConjNoTrans:
|
||||
break;
|
||||
}
|
||||
throw std::invalid_argument("Unsupported KERNEL_CBLAS_TRANSPOSE value");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// 映射表,layout 从KERNEL_CBLAS_ORDER 映射到'r'或者'c',以及将KERNEL_CBLAS_TRANSPOSE映射到'n'或者't'
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
|
||||
int32_t* c, const size_t ldc, const int32_t* oc) {
|
||||
const char order = ToAoclOrder(layout);
|
||||
const char op_a = ToAoclTranspose(transa);
|
||||
const char op_b = ToAoclTranspose(transb);
|
||||
(void)offsetc;
|
||||
aocl_gemm_s8s8s32os32(order, op_a, op_b, static_cast<dim_t>(m), static_cast<dim_t>(n), static_cast<dim_t>(k),
|
||||
static_cast<int32_t>(alpha), static_cast<const int8_t*>(a), static_cast<dim_t>(lda), 'n',
|
||||
static_cast<const int8_t*>(b), static_cast<dim_t>(ldb), 'r', static_cast<int32_t>(beta), c,
|
||||
static_cast<dim_t>(ldc), nullptr);
|
||||
}
|
||||
|
||||
void prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
|
||||
int32_t* c, const size_t ldc, const int32_t* oc) {
|
||||
const char order = ToAoclOrder(layout);
|
||||
const char op_a = ToAoclTranspose(transa);
|
||||
const char op_b = ToAoclTranspose(transb);
|
||||
(void)offsetc;
|
||||
aocl_gemm_s8s8s32os32(order, op_a, op_b, static_cast<dim_t>(m), static_cast<dim_t>(n), static_cast<dim_t>(k),
|
||||
static_cast<int32_t>(alpha), static_cast<const int8_t*>(a), static_cast<dim_t>(lda), 'n',
|
||||
static_cast<const int8_t*>(b), static_cast<dim_t>(ldb), 'r', static_cast<int32_t>(beta), c,
|
||||
static_cast<dim_t>(ldc), nullptr);
|
||||
}
|
||||
|
||||
void prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc) {
|
||||
throw std::runtime_error("int4 not support prefill");
|
||||
}
|
||||
|
||||
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc) {
|
||||
throw std::runtime_error("int4 not support decode");
|
||||
}
|
||||
|
||||
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
|
||||
const size_t n, const size_t ldb, const void* b, void* b_reordered) {
|
||||
const char order = ToAoclOrder(layout);
|
||||
const char op_b = ToAoclTranspose(transb);
|
||||
aocl_reorder_s8s8s32os32(order, op_b, 'B', static_cast<const int8_t*>(b), static_cast<int8_t*>(b_reordered), k, n,
|
||||
ldb);
|
||||
}
|
||||
|
||||
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
|
||||
const size_t n) {
|
||||
return aocl_get_reorder_buf_size_s8s8s32os32(ToAoclOrder(layout), ToAoclTranspose(transb), 'B', k, n);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
42
kt-kernel/operators/moe_kernel/mat_kernel/batch_gemm_api.hpp
Normal file
42
kt-kernel/operators/moe_kernel/mat_kernel/batch_gemm_api.hpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#pragma once

#include <cstddef>
#include <cstdint>  // int32_t appears in every GEMM signature below; was missing

// Classic include guard kept alongside #pragma once for toolchains that honour
// only one of the two. Renamed from `_BATCH_GEMM_KERNEL_API_`: identifiers
// beginning with an underscore followed by an uppercase letter are reserved
// for the implementation in C and C++.
#ifndef BATCH_GEMM_KERNEL_API_H
#define BATCH_GEMM_KERNEL_API_H

#include "../api/common.h"

#ifdef __cplusplus
extern "C" {
#endif

// Batched int8 GEMM entry points shared by the CPU matrix-kernel backends.
// All GEMM functions follow the cblas_gemm_s8s8s32 parameter convention:
//   layout, transa, transb  - matrix storage order and transpose flags
//   offsetc                 - how the `oc` offset vector is applied to C
//   m, n, k                 - GEMM dimensions (C is m x n, inner dimension k)
//   alpha, beta             - scaling factors for A*B and the prior C contents
//   a, lda, oa              - int8 A matrix, its leading dimension, zero-point
//   b, ldb, ob              - int8 B matrix, its leading dimension, zero-point
//   c, ldc, oc              - int32 C matrix, its leading dimension, C offsets
// NOTE(review): this header is wrapped in extern "C" yet includes <cstddef>/
// <cstdint>; if it ever needs to compile as plain C, switch to
// <stddef.h>/<stdint.h>.

// int8 x int8 -> int32 GEMM variant tuned for the decode phase.
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                               const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
                               const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                               const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
                               int32_t* c, const size_t ldc, const int32_t* oc);

// int8 x int8 -> int32 GEMM variant tuned for the prefill phase.
void prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                                const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
                                const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
                                const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
                                int32_t* c, const size_t ldc, const int32_t* oc);

// Decode-phase GEMM over an int4-packed B operand (two 4-bit values per byte;
// presumably k counts unpacked elements -- confirm against the packing code).
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                                    const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
                                    const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
                                    const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                                    const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
                                    const int32_t* oc);

// Prefill-phase GEMM over an int4-packed B operand.
void prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
                                     const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
                                     const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
                                     const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
                                     const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
                                     const int32_t* oc);

// Repack the k x n B matrix (leading dimension ldb) into the backend's
// preferred layout, writing into b_reordered. Size the destination buffer
// with get_reorder_B_size().
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                    const size_t n, const size_t ldb, const void* b, void* b_reordered);

// Number of bytes required for the b_reordered buffer of reorder_B_gemm().
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                          const size_t n);

#ifdef __cplusplus
}
#endif

#endif /* BATCH_GEMM_KERNEL_API_H */
|
||||
@@ -0,0 +1,56 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#include "../batch_gemm_api.hpp"
|
||||
#include "utils.hpp"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,
|
||||
const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,
|
||||
const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,
|
||||
int32_t* c, const size_t ldc, const int32_t* oc) {
|
||||
BLASINT8* ptrA = (BLASINT8*)a;
|
||||
BLASINT8* ptrB = (BLASINT8*)b;
|
||||
int32_t* ptrC = c;
|
||||
size_t split_n = n / N_SIZE;
|
||||
|
||||
for (size_t n_block = 0; n_block < split_n; n_block++) {
|
||||
BLASINT8* cur_ptrA = ptrA;
|
||||
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * k);
|
||||
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
|
||||
gemm_kernel_1x8(cur_ptrA, cur_ptrB, cur_ptrC, ldc, k, COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
void decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,
|
||||
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,
|
||||
const size_t m, const size_t n, const size_t k, const float alpha, const void* a,
|
||||
const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,
|
||||
const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,
|
||||
const int32_t* oc) {
|
||||
BLASINT8* ptrA = (BLASINT8*)a;
|
||||
BLASINT8* ptrB = (BLASINT8*)b;
|
||||
int32_t* ptrC = c;
|
||||
size_t split_n = n / N_SIZE;
|
||||
|
||||
for (size_t n_block = 0; n_block < split_n; n_block++) {
|
||||
BLASINT8* cur_ptrA = ptrA;
|
||||
BLASINT8* cur_ptrB = ptrB + n_block * (N_SIZE * (k / 2));
|
||||
int32_t* cur_ptrC = ptrC + n_block * N_SIZE;
|
||||
gemm_kernel_1x8_int4(cur_ptrA, cur_ptrB, cur_ptrC, (ldc / 2), (k / 2), COMP_SV_LEN);
|
||||
}
|
||||
}
|
||||
// Placeholder: B-matrix reordering is not implemented for this backend
// (the AOCL-backed build provides a real implementation).
// NOTE(review): this function is declared with C linkage (inside the file's
// extern "C" block) yet throws a C++ exception; propagating an exception
// across a C-linkage boundary is undefined behaviour if a C caller exists --
// confirm all callers are C++.
void reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                    const size_t n, const size_t ldb, const void* b, void* b_reordered) {
  throw std::runtime_error("haven't supported reorder");
}
|
||||
|
||||
// Placeholder: querying the reorder-buffer size is not implemented for this
// backend (the AOCL-backed build provides a real implementation).
// NOTE(review): declared with C linkage yet throws a C++ exception --
// undefined behaviour if ever invoked from C; confirm callers are C++.
size_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,
                          const size_t n) {
  throw std::runtime_error("haven't supported reorder");
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "batch_gemm_api.hpp"
|
||||
#include "prefillgemm_int4/integer_gemm_kernels.h"
|
||||
#include "utils.hpp"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user