[feat](kt-kernel): support AVX2-only inference for BF16, FP8, and GPTQ INT4 (#1892)

* feat: support avx2 bf16 fp8 inference

* feat: support avx2 gptq int4 inference

* fix: numeric issues in fp8 dequant

* Tutorial avx2 (#1900)

* fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines

* docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs

* Tutorial avx2 (#1901)

* fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines

* docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs

* docs: update README.md

---------

Co-authored-by: Benjamin F <159887351+yyj6666667@users.noreply.github.com>
Author: mrhaoxx
Date: 2026-03-27 14:45:02 +08:00 (committed by GitHub)
Parent: 8561a71dd1
Commit: 7a9daf0cd4
19 changed files with 3472 additions and 12 deletions


@@ -16,7 +16,7 @@
KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel/) and [kt-sft](https://github.com/kvcache-ai/ktransformers/tree/main/kt-sft).
## 🔥 Updates
* **Mar 26, 2026**: Support AVX2-only CPU backend for KT-Kernel inference. ([Tutorial](./doc/en/kt-kernel/AVX2-Tutorial.md))
* **Feb 13, 2026**: MiniMax-M2.5 Day0 Support! ([Tutorial](./doc/en/MiniMax-M2.5.md))
* **Feb 12, 2026**: GLM-5 Day0 Support! ([Tutorial](./doc/en/kt-kernel/GLM-5-Tutorial.md))
* **Jan 27, 2026**: Kimi-K2.5 Day0 Support! ([Tutorial](./doc/en/Kimi-K2.5.md)) ([SFT Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.5.md))


@@ -5,7 +5,7 @@
- [For kt-kernel](en/kt-kernel/kt-kernel_intro.md)
- [For kt-sft](en/SFT/KTransformers-Fine-Tuning_User-Guide.md)
# Tutorial
- [kt-sft part](en/SFT/README.md)
- [Injection Tutorial](en/SFT/injection_tutorial.md)
- [kt-sft developer tech notes](en/SFT/KTransformers-Fine-Tuning_Developer-Technical-Notes.md)
@@ -19,6 +19,8 @@
- [Makefile Usage](en/makefile_usage.md) -->
- [kt-kernel part](en/kt-kernel/README.md)
- [kt-cli](en/kt-kernel/kt-cli.md)
- [AVX2 Backend Tutorial](en/kt-kernel/AVX2-Tutorial.md)
- [AVX2 后端教程(中文)](zh/AVX2-Tutorial_zh.md)
# FAQ
- [FAQ](en/FAQ.md)
<!-- # V3 Reproduction


@@ -0,0 +1,188 @@
# Running KTransformers on AVX2 CPUs
This tutorial explains how to run KTransformers on machines that only support AVX2 (without AVX512 or AMX).
## Table of Contents
- [Supported Precision Formats](#supported-precision-formats)
- [Hardware Requirements](#hardware-requirements)
- [Installation](#installation)
- [Verification](#verification)
- [Starting the Inference Server](#starting-the-inference-server)
- [Example: Qwen3-30B-A3B (BF16)](#example-qwen3-30b-a3b-bf16)
- [Example: Qwen3.5-35B-A3B-FP8 (FP8)](#example-qwen35-35b-a3b-fp8-fp8)
- [Example: Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)](#example-qwen3-30b-a3b-gptq-int4-gptq_int4)
- [Sending Requests](#sending-requests)
- [Performance Tuning](#performance-tuning)
- [FAQ](#faq)
## Supported Precision Formats
| `--kt-method` | Precision | Description |
|---------------|-----------|-------------|
| `BF16` | BF16 native precision | Zero precision loss; uses BF16 weights directly |
| `FP8` | FP8 block quantization | Roughly half the weight memory of BF16; block-wise scales preserve dynamic range |
| `GPTQ_INT4` | INT4 GPTQ | Roughly a quarter of BF16 weight memory (plus small per-group scale/zero-point overhead) |
## Hardware Requirements
- **CPU**: x86-64 with AVX2 + FMA (Intel Haswell 2013+ / AMD Zen or later)
- **GPU**: NVIDIA 24GB+ VRAM (RTX 3090/4090/5090, etc.)
- **Memory**: At least the size of the model weights (e.g., Qwen3-30B-A3B BF16 requires 64GB+)
- **OS**: Linux
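The memory requirement above can be sanity-checked with a quick estimate. A sketch (assumptions: ~30.5B is the published total parameter count for Qwen3-30B-A3B; per-group quantization scales and runtime overhead are ignored):

```python
def weight_gib(params_billions: float, bytes_per_param: float) -> float:
    """Approximate weight footprint in GiB: parameter count times storage width."""
    return params_billions * 1e9 * bytes_per_param / 2**30

# Qwen3-30B-A3B (~30.5B params) under each --kt-method
for method, bpp in [("BF16", 2.0), ("FP8", 1.0), ("GPTQ_INT4", 0.5)]:
    print(f"{method}: ~{weight_gib(30.5, bpp):.0f} GiB of weights")
```

BF16 comes out around 57 GiB of weights, which is why the requirement above asks for 64GB+ of RAM.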
## Installation
Build and install from source (one-click install for kt-kernel + SGLang):
```bash
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init --recursive
# One-click install
./install.sh
```
On AVX512 or AMX machines, you can also manually force AVX2 compilation:
```bash
export CPUINFER_CPU_INSTRUCT=AVX2
export CPUINFER_ENABLE_AMX=OFF
./install.sh kt-kernel --manual
```
## Verification
```bash
# Check if the CPU supports AVX2
lscpu | grep -i avx2
# Check the loaded kt-kernel variant
python -c "import kt_kernel; print(kt_kernel.__cpu_variant__)"
# Expected output: avx2
# System diagnostics
kt doctor
```
## Starting the Inference Server
Use `--kt-method BF16`, `FP8`, or `GPTQ_INT4`. KT-Kernel will **automatically detect** the CPU and fall back to the AVX2 backend when AVX512/AMX is unavailable.
### Example: Qwen3-30B-A3B (BF16)
```bash
# Download the model
huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /path/to/Qwen3-30B-A3B
# Check physical core count and NUMA node count
lscpu | grep -E "^CPU\(s\)|Thread\(s\) per core|NUMA node\(s\)"
# Start the server (adjust kt-cpuinfer and kt-threadpool-count based on your hardware)
python -m sglang.launch_server \
--host 0.0.0.0 --port 30000 \
--model /path/to/Qwen3-30B-A3B \
--kt-weight-path /path/to/Qwen3-30B-A3B \
--kt-cpuinfer 16 \
--kt-threadpool-count 1 \
--kt-num-gpu-experts 32 \
--kt-method BF16 \
--attention-backend flashinfer \
--trust-remote-code \
--mem-fraction-static 0.80 \
--chunked-prefill-size 8192 \
--max-running-requests 2 \
--served-model-name Qwen3 \
--enable-mixed-chunk \
--tensor-parallel-size 1 \
--enable-p2p-check \
--disable-shared-experts-fusion
```
### Example: Qwen3.5-35B-A3B-FP8 (FP8)
```bash
# Download the model
huggingface-cli download Qwen/Qwen3.5-35B-A3B-FP8 --local-dir /path/to/Qwen3.5-35B-A3B-FP8
# Start the server
python -m sglang.launch_server \
--host 0.0.0.0 --port 30000 \
--model /path/to/Qwen3.5-35B-A3B-FP8 \
--kt-weight-path /path/to/Qwen3.5-35B-A3B-FP8 \
--kt-cpuinfer 16 \
--kt-threadpool-count 1 \
--kt-num-gpu-experts 2 \
--kt-method FP8 \
--kt-gpu-prefill-token-threshold 400 \
--attention-backend triton \
--trust-remote-code \
--mem-fraction-static 0.85 \
--chunked-prefill-size 4096 \
--max-running-requests 1 \
--max-total-tokens 32000 \
--enable-mixed-chunk \
--tensor-parallel-size 1 \
--disable-shared-experts-fusion
```
### Example: Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)
```bash
# Download the model
huggingface-cli download Qwen/Qwen3-30B-A3B-GPTQ-Int4 --local-dir /path/to/Qwen3-30B-A3B-GPTQ-Int4
# Start the server
python -m sglang.launch_server \
--host 0.0.0.0 --port 30000 \
--model /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
--kt-weight-path /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
--kt-cpuinfer 16 \
--kt-threadpool-count 1 \
--kt-num-gpu-experts 2 \
--kt-method GPTQ_INT4 \
--attention-backend triton \
--trust-remote-code \
--mem-fraction-static 0.85 \
--chunked-prefill-size 4096 \
--max-running-requests 1 \
--max-total-tokens 32000 \
--enable-mixed-chunk \
--tensor-parallel-size 1 \
--disable-shared-experts-fusion
```
### Sending Requests
```bash
# Interactive chat
kt chat
# OpenAI-compatible API
curl http://localhost:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model":"Qwen3","messages":[{"role":"user","content":"Hello"}],"stream":true}'
```
## Performance Tuning
- `--kt-cpuinfer`: set to the number of **physical cores**
- `--kt-threadpool-count`: set to the number of **NUMA nodes**
- `--kt-num-gpu-experts`: higher values reduce CPU load but increase GPU VRAM usage
- Memory bandwidth is often the bottleneck; high-frequency DDR5 memory helps significantly
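The bandwidth point can be quantified: during decode, every active expert weight byte must stream through the CPU once per token, so tokens/s is bounded by bandwidth divided by active-weight bytes. A sketch with illustrative numbers (the 80 GB/s figure is an assumed dual-channel DDR5 bandwidth, not a measurement):

```python
def decode_tps_bound(bandwidth_gbs: float, active_params_billions: float,
                     bytes_per_param: float) -> float:
    """Memory-bandwidth upper bound on decode tokens/s (ignores cache reuse and compute)."""
    bytes_per_token = active_params_billions * 1e9 * bytes_per_param
    return bandwidth_gbs * 1e9 / bytes_per_token

# ~3B active params (the 'A3B' in Qwen3-30B-A3B) in BF16 at an assumed 80 GB/s
print(f"~{decode_tps_bound(80.0, 3.0, 2.0):.1f} tokens/s upper bound")
```

Real throughput lands below this bound, but it explains why FP8/INT4 (fewer bytes per token) and faster memory both help decode speed directly.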
## FAQ
**GPU OOM**
- Reduce `--kt-num-gpu-experts`, `--chunked-prefill-size`, `--max-total-tokens`
- Lower `--mem-fraction-static`
For more questions, see [FAQ](../FAQ.md).

doc/zh/AVX2-Tutorial_zh.md (new file, 188 lines)

@@ -0,0 +1,188 @@
# 在 AVX2 CPU 上使用 KTransformers
本教程介绍如何在仅支持 AVX2 的机器上运行 KTransformers(无需 AVX512 或 AMX)。
## 目录
- [支持的精度格式](#支持的精度格式)
- [硬件要求](#硬件要求)
- [安装](#安装)
- [验证](#验证)
- [启动推理服务](#启动推理服务)
- [示例:Qwen3-30B-A3B (BF16)](#示例qwen3-30b-a3b-bf16)
- [示例:Qwen3.5-35B-A3B-FP8 (FP8)](#示例qwen35-35b-a3b-fp8-fp8)
- [示例:Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)](#示例qwen3-30b-a3b-gptq-int4-gptq_int4)
- [发送请求](#发送请求)
- [性能调优](#性能调优)
- [常见问题](#常见问题)
## 支持的精度格式
| `--kt-method` | 精度 | 说明 |
|---------------|------|------|
| `BF16` | BF16 原生精度 | 零精度损失,直接使用 BF16 权重 |
| `FP8` | FP8 分块量化 | 权重内存约为 BF16 的一半,分块 scale 保持动态范围 |
| `GPTQ_INT4` | INT4 GPTQ | 权重内存约为 BF16 的四分之一(另有少量分组 scale/zero-point 开销) |
## 硬件要求
- **CPU**:x86-64,支持 AVX2 + FMA(Intel Haswell 2013+ / AMD Zen 及以后)
- **GPU**:NVIDIA 24GB+ 显存(RTX 3090/4090/5090 等)
- **内存**:不少于模型权重大小(如 Qwen3-30B-A3B BF16 需 64GB+)
- **系统**:Linux
## 安装
从源码编译安装(一键安装 kt-kernel + SGLang):
```bash
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init --recursive
# 一键安装
./install.sh
```
在 AVX512 或 AMX 机器上,也可以手动强制 AVX2 编译:
```bash
export CPUINFER_CPU_INSTRUCT=AVX2
export CPUINFER_ENABLE_AMX=OFF
./install.sh kt-kernel --manual
```
## 验证
```bash
# 检查 CPU 是否支持 AVX2
lscpu | grep -i avx2
# 检查 kt-kernel 加载的变体
python -c "import kt_kernel; print(kt_kernel.__cpu_variant__)"
# 预期输出:avx2
# 系统诊断
kt doctor
```
## 启动推理服务
使用 `--kt-method BF16`、`FP8` 或 `GPTQ_INT4`。KT-Kernel 会**自动检测** CPU,并在缺少 AVX512/AMX 时回退到 AVX2 后端。
### 示例:Qwen3-30B-A3B (BF16)
```bash
# 下载模型
huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /path/to/Qwen3-30B-A3B
# 查看物理核心数和 NUMA 节点数
lscpu | grep -E "^CPU\(s\)|Thread\(s\) per core|NUMA node\(s\)"
# 启动服务(按实际硬件调整 kt-cpuinfer 和 kt-threadpool-count)
python -m sglang.launch_server \
--host 0.0.0.0 --port 30000 \
--model /path/to/Qwen3-30B-A3B \
--kt-weight-path /path/to/Qwen3-30B-A3B \
--kt-cpuinfer 16 \
--kt-threadpool-count 1 \
--kt-num-gpu-experts 32 \
--kt-method BF16 \
--attention-backend flashinfer \
--trust-remote-code \
--mem-fraction-static 0.80 \
--chunked-prefill-size 8192 \
--max-running-requests 2 \
--served-model-name Qwen3 \
--enable-mixed-chunk \
--tensor-parallel-size 1 \
--enable-p2p-check \
--disable-shared-experts-fusion
```
### 示例:Qwen3.5-35B-A3B-FP8 (FP8)
```bash
# 下载模型
huggingface-cli download Qwen/Qwen3.5-35B-A3B-FP8 --local-dir /path/to/Qwen3.5-35B-A3B-FP8
# 启动服务
python -m sglang.launch_server \
--host 0.0.0.0 --port 30000 \
--model /path/to/Qwen3.5-35B-A3B-FP8 \
--kt-weight-path /path/to/Qwen3.5-35B-A3B-FP8 \
--kt-cpuinfer 16 \
--kt-threadpool-count 1 \
--kt-num-gpu-experts 2 \
--kt-method FP8 \
--kt-gpu-prefill-token-threshold 400 \
--attention-backend triton \
--trust-remote-code \
--mem-fraction-static 0.85 \
--chunked-prefill-size 4096 \
--max-running-requests 1 \
--max-total-tokens 32000 \
--enable-mixed-chunk \
--tensor-parallel-size 1 \
--disable-shared-experts-fusion
```
### 示例:Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)
```bash
# 下载模型
huggingface-cli download Qwen/Qwen3-30B-A3B-GPTQ-Int4 --local-dir /path/to/Qwen3-30B-A3B-GPTQ-Int4
# 启动服务
python -m sglang.launch_server \
--host 0.0.0.0 --port 30000 \
--model /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
--kt-weight-path /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
--kt-cpuinfer 16 \
--kt-threadpool-count 1 \
--kt-num-gpu-experts 2 \
--kt-method GPTQ_INT4 \
--attention-backend triton \
--trust-remote-code \
--mem-fraction-static 0.85 \
--chunked-prefill-size 4096 \
--max-running-requests 1 \
--max-total-tokens 32000 \
--enable-mixed-chunk \
--tensor-parallel-size 1 \
--disable-shared-experts-fusion
```
### 发送请求
```bash
# 交互聊天
kt chat
# OpenAI 兼容 API
curl http://localhost:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model":"Qwen3","messages":[{"role":"user","content":"你好"}],"stream":true}'
```
## 性能调优
- `--kt-cpuinfer`:设为**物理核心数**
- `--kt-threadpool-count`:设为 **NUMA 节点数**
- `--kt-num-gpu-experts`:越大 CPU 负担越小,但 GPU 显存占用越高
- 内存带宽往往是瓶颈,高频 DDR5 内存有明显帮助
## 常见问题
**GPU OOM**
- 减小 `--kt-num-gpu-experts`、`--chunked-prefill-size`、`--max-total-tokens`
- 降低 `--mem-fraction-static`
更多问题参见 [FAQ](../en/FAQ.md)。


@@ -45,6 +45,13 @@ static const bool _is_plain_ = false;
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#endif
// AVX2 backends — always available on x86_64 (no AMX/AVX512 dependency)
#if defined(__x86_64__)
#include "operators/avx2/bf16-moe.hpp"
#include "operators/avx2/fp8-moe.hpp"
#include "operators/avx2/gptq_int4-moe.hpp"
#endif
#include <pybind11/stl.h> // std::vector/std::pair/std::string conversions
#include <cstdint>
@@ -578,6 +585,13 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");
#endif
#endif
// AVX2 backends — available on all x86_64 (no AMX/AVX512 requirement)
#if defined(__x86_64__)
bind_moe_module<AVX2_BF16_MOE_TP<avx2::GemmKernelAVX2BF16>>(moe_module, "AVX2BF16_MOE");
bind_moe_module<AVX2_FP8_MOE_TP<avx2::GemmKernelAVX2FP8>>(moe_module, "AVX2FP8_MOE");
bind_moe_module<AVX2_GPTQ_INT4_MOE_TP<avx2::GemmKernelAVX2GPTQInt4>>(moe_module, "AVX2GPTQInt4_MOE");
#endif
#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
#if defined(__aarch64__) && defined(CPU_USE_KML)


@@ -0,0 +1,228 @@
/**
* @Description : AVX2 BF16 GEMM kernel with trivial Buffer abstractions
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* Unlike AMX kernels that use packed tile layouts (BufferB with 16x16 transpose),
* the AVX2 kernel uses row-major storage for all buffers.
* BufferA/B/C are thin wrappers over raw memory with trivial from_mat/to_mat.
*
* GEMM: C[m,n] = sum_k A[m,k] * B[n,k]
* A: [M, K] row-major BF16 (input activations)
* B: [N, K] row-major BF16 (weights, each row is one output neuron)
* C: [M, N] row-major FP32 (output)
**/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_GEMM_H
#define CPUINFER_OPERATOR_AVX2_BF16_GEMM_H
#include <immintrin.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <tuple>
#include "avx2_bf16_utils.hpp"
namespace avx2 {
// Split range [0, total) among nth threads, return [start, end) for thread ith
static inline std::pair<int, int> split_range(int total, int ith, int nth) {
int per = total / nth;
int rem = total % nth;
int start = ith * per + std::min(ith, rem);
int end = start + per + (ith < rem ? 1 : 0);
return {start, end};
}
struct GemmKernelAVX2BF16 {
using dt = ggml_bf16_t;
using output_t = float;
static constexpr int M_STEP = 1; // No M-direction padding needed (vs AMX 16)
static constexpr int N_STEP = 8; // 8-wide FP32 AVX2 (vs AMX 32)
static constexpr int K_STEP = 8; // Process 8 K elements at a time
static constexpr int N_BLOCK = 64; // N blocking for cache
static constexpr int K_BLOCK = 256; // K blocking for cache
static constexpr double ELEMENT_SIZE = 2.0; // BF16 = 2 bytes
// No AMX tile configuration needed
static void config() {}
// Thread count for N-dimension parallelism
// Must return >= 1 to avoid division by zero in moe_base task dispatch
static int recommended_nth(int n) {
return std::max(1, n / N_STEP);
}
// Split N range for multi-threaded GEMM
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
return split_range(n, ith, nth);
}
// ========================================================================
// BufferA: Input activations [M, K] row-major BF16
// from_mat() = memcpy (no packing needed for AVX2)
// ========================================================================
struct BufferA {
ggml_bf16_t* data = nullptr;
size_t max_m = 0;
size_t k = 0;
BufferA() = default;
BufferA(size_t m, size_t k_, void* ptr) : max_m(m), k(k_), data((ggml_bf16_t*)ptr) {}
static size_t required_size(size_t m, size_t k) {
return m * k * sizeof(ggml_bf16_t);
}
void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; }
// Copy input rows into buffer (trivial memcpy)
void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) {
if (ith == 0 && nth == 1) {
std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t));
} else {
// Multi-threaded: split by rows
auto [m_start, m_end] = split_range(m, ith, nth);
std::memcpy(data + m_start * k, src + m_start * k,
(size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t));
}
}
};
// ========================================================================
// BufferB: Weight matrix [N, K] row-major BF16
// from_mat() = memcpy (no transpose/packing needed)
// ========================================================================
struct BufferB {
ggml_bf16_t* b = nullptr;
size_t n = 0;
size_t k = 0;
BufferB() = default;
BufferB(size_t n_, size_t k_, void* ptr) : n(n_), k(k_), b((ggml_bf16_t*)ptr) {}
static size_t required_size(size_t n, size_t k) {
return n * k * sizeof(ggml_bf16_t);
}
// Copy weight data (multi-threaded by N dimension)
void from_mat(const ggml_bf16_t* src, int ith, int nth) {
auto [n_start, n_end] = split_range((int)n, ith, nth);
std::memcpy(b + n_start * k, src + n_start * k,
(size_t)(n_end - n_start) * k * sizeof(ggml_bf16_t));
}
};
// ========================================================================
// BufferC: Output matrix [M, N] row-major FP32
// to_mat() converts FP32 -> BF16 and writes out
// ========================================================================
struct BufferC {
float* data = nullptr;
size_t max_m = 0;
size_t n = 0;
BufferC() = default;
BufferC(size_t m, size_t n_, void* ptr) : max_m(m), n(n_), data((float*)ptr) {}
static size_t required_size(size_t m, size_t n) {
return m * n * sizeof(float);
}
void set_data(void* ptr) { data = (float*)ptr; }
// Convert FP32 output to BF16 and write to destination
void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
auto [n_start, n_end] = split_range_n((int)n, ith, nth);
for (int mi = 0; mi < m; mi++) {
float* src_row = data + mi * n;
ggml_bf16_t* dst_row = dst + mi * n;
int j = n_start;
for (; j + 8 <= n_end; j += 8) {
__m256 v = _mm256_loadu_ps(src_row + j);
store_fp32_to_bf16(dst_row + j, v);
}
// Scalar tail
for (; j < n_end; j++) {
dst_row[j] = GGML_FP32_TO_BF16(src_row[j]);
}
}
}
};
};
// ============================================================================
// AVX2 BF16 GEMM functions
// C[m,n] = sum_k A[m,k] * B[n,k]
// ============================================================================
// General GEMM (works for both vec_mul m=1 and mat_mul m>1)
static inline void gemm_bf16(
int m, int n, int k,
GemmKernelAVX2BF16::BufferA& a,
GemmKernelAVX2BF16::BufferB& b,
GemmKernelAVX2BF16::BufferC& c,
int ith, int nth) {
auto [n_start, n_end] = split_range(n, ith, nth);
for (int ni = n_start; ni < n_end; ni++) {
const ggml_bf16_t* b_row = b.b + (size_t)ni * k;
for (int mi = 0; mi < m; mi++) {
const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k;
// AVX2 BF16 dot product (matches ggml_vec_dot_bf16 AVX2 path)
__m256 c1 = _mm256_setzero_ps();
__m256 c2 = _mm256_setzero_ps();
__m256 c3 = _mm256_setzero_ps();
__m256 c4 = _mm256_setzero_ps();
int ki = 0;
for (; ki + 32 <= k; ki += 32) {
c1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki), load_bf16_to_fp32(b_row + ki), c1);
c2 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 8), load_bf16_to_fp32(b_row + ki + 8), c2);
c3 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 16), load_bf16_to_fp32(b_row + ki + 16), c3);
c4 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 24), load_bf16_to_fp32(b_row + ki + 24), c4);
}
float sum = hsum_avx2(_mm256_add_ps(_mm256_add_ps(c1, c3), _mm256_add_ps(c2, c4)));
// Scalar tail
for (; ki < k; ki++) {
sum += GGML_BF16_TO_FP32(a_row[ki]) * GGML_BF16_TO_FP32(b_row[ki]);
}
c.data[mi * n + ni] = sum;
}
}
}
// vec_mul: dispatch to gemm_bf16
static inline void vec_mul(
int m, int n, int k,
std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a,
std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b,
std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c,
int ith, int nth) {
gemm_bf16(m, n, k, *a, *b, *c, ith, nth);
}
// mat_mul: dispatch to gemm_bf16
static inline void mat_mul(
int m, int n, int k,
std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a,
std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b,
std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c,
int ith, int nth) {
gemm_bf16(m, n, k, *a, *b, *c, ith, nth);
}
} // namespace avx2
#endif // CPUINFER_OPERATOR_AVX2_BF16_GEMM_H
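The kernel's thread partitioning and contraction can be checked with a few lines of Python (a sketch: `split_range` mirrors the C++ helper above, and `gemm_ref` computes the same C = A·Bᵀ contraction with B stored row-major [N, K]; it models the indexing, not the SIMD path):

```python
def split_range(total: int, ith: int, nth: int):
    """Divide [0, total) into nth contiguous chunks; the first (total % nth) get one extra."""
    per, rem = divmod(total, nth)
    start = ith * per + min(ith, rem)
    return start, start + per + (1 if ith < rem else 0)

def gemm_ref(A, B, m, n, k):
    """C[mi][ni] = sum_k A[mi][ki] * B[ni][ki] -- B row-major [N, K], i.e. C = A @ B.T."""
    C = [[0.0] * n for _ in range(m)]
    for ith in range(3):                      # emulate 3 workers splitting the N dimension
        n_start, n_end = split_range(n, ith, 3)
        for ni in range(n_start, n_end):
            for mi in range(m):
                C[mi][ni] = sum(A[mi][ki] * B[ni][ki] for ki in range(k))
    return C

A = [[1.0, 2.0], [3.0, 4.0]]                  # [M=2, K=2] activations
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # [N=3, K=2] weights, one row per output
print(gemm_ref(A, B, 2, 3, 2))                # [[1.0, 2.0, 3.0], [3.0, 4.0, 7.0]]
```

Note the chunks tile [0, n) exactly and differ in size by at most one element, which is why `recommended_nth` only has to guarantee `nth >= 1`.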


@@ -0,0 +1,132 @@
/**
* @Description : AVX2 BF16 utility functions (bf16<->fp32 conversion, activation)
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* AVX2 ports of the AVX512 utilities in amx/la/utils.hpp and amx/la/amx.hpp.
* Uses 256-bit SIMD (8 floats) instead of 512-bit (16 floats).
**/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_UTILS_H
#define CPUINFER_OPERATOR_AVX2_BF16_UTILS_H
#include <immintrin.h>
#include <cmath>
#include "llama.cpp/ggml.h"
namespace avx2 {
// ============================================================================
// BF16 <-> FP32 conversion
// ============================================================================
// Load 8 BF16 values and convert to 8 FP32 values
// BF16 is the upper 16 bits of FP32, so shift left by 16
static inline __m256 load_bf16_to_fp32(const ggml_bf16_t* src) {
__m128i bf16 = _mm_loadu_si128((const __m128i*)src);
__m256i i32 = _mm256_cvtepu16_epi32(bf16);
return _mm256_castsi256_ps(_mm256_slli_epi32(i32, 16));
}
// Convert 8 FP32 values to 8 BF16 values with round-to-nearest-even
// Matches ggml_compute_fp32_to_bf16 semantics (ggml-impl.h:87)
// and amx/la/utils.hpp:24 tie-bit correction
static inline void store_fp32_to_bf16(ggml_bf16_t* dst, __m256 src) {
__m256i i32 = _mm256_castps_si256(src);
// Round-to-nearest-even: add 0x7FFF + ((val >> 16) & 1)
__m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1));
__m256i round = _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), tie_bit);
__m256i rounded = _mm256_add_epi32(i32, round);
__m256i shifted = _mm256_srli_epi32(rounded, 16);
// Pack 32-bit -> 16-bit: _mm_packus_epi32(lo, hi) packs lo's four dwords into the
// low four words and hi's into the high four words (values fit in 16 bits after the
// >>16, so the unsigned saturation never triggers)
__m128i lo = _mm256_castsi256_si128(shifted);
__m128i hi = _mm256_extracti128_si256(shifted, 1);
__m128i packed = _mm_packus_epi32(lo, hi);
_mm_storeu_si128((__m128i*)dst, packed);
}
// Load 16 BF16 -> 2x8 FP32 (corresponds to avx512_32xbf16_to_32xfp32)
static inline void load_16xbf16_to_2x8xfp32(const ggml_bf16_t* src, __m256* out0, __m256* out1) {
*out0 = load_bf16_to_fp32(src);
*out1 = load_bf16_to_fp32(src + 8);
}
// Store 2x8 FP32 -> 16 BF16 (corresponds to avx512_32xfp32_to_32xbf16)
static inline void store_2x8xfp32_to_16xbf16(__m256* in0, __m256* in1, ggml_bf16_t* dst) {
store_fp32_to_bf16(dst, *in0);
store_fp32_to_bf16(dst + 8, *in1);
}
// ============================================================================
// Horizontal sum for __m256 (8 floats -> 1 float)
// ============================================================================
static inline float hsum_avx2(__m256 v) {
__m128 hi = _mm256_extractf128_ps(v, 1);
__m128 lo = _mm256_castps256_ps128(v);
__m128 sum = _mm_add_ps(lo, hi);
sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
sum = _mm_add_ss(sum, _mm_movehdup_ps(sum));
return _mm_cvtss_f32(sum);
}
// ============================================================================
// Fast exp approximation (AVX2 port of amx::exp_avx512)
// ============================================================================
static inline __m256 exp_avx2(__m256 x) {
const __m256 log2e = _mm256_set1_ps(1.44269504089f);
__m256 y = _mm256_mul_ps(x, log2e);
__m256i int_part = _mm256_cvtps_epi32(y);
__m256 frac_part = _mm256_sub_ps(y, _mm256_cvtepi32_ps(int_part));
const __m256 poly_1 = _mm256_set1_ps(0.9999999995f);
const __m256 poly_2 = _mm256_set1_ps(0.6931471805f);
const __m256 poly_3 = _mm256_set1_ps(0.2402265069f);
const __m256 poly_4 = _mm256_set1_ps(0.0555041087f);
const __m256 poly_5 = _mm256_set1_ps(0.0096181291f);
const __m256 poly_6 = _mm256_set1_ps(0.0013333558f);
__m256 frac_exp = _mm256_fmadd_ps(
_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4),
frac_part, poly_3),
frac_part, poly_2),
frac_part, poly_1);
// 2^int_part: AVX2 doesn't have _mm256_scalef_ps, use manual construction
// 2^n = reinterpret((n + 127) << 23) for float
// Clamp int_part to [-126, 127] to avoid invalid bit patterns:
// int_part < -126 → biased < 1 → denorm/zero (scalef_ps would give 0)
// int_part > 127 → biased > 254 → inf (scalef_ps would give inf)
__m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, _mm256_set1_epi32(127)),
_mm256_set1_epi32(-126));
__m256i biased = _mm256_add_epi32(clamped, _mm256_set1_epi32(127));
__m256i shifted = _mm256_slli_epi32(biased, 23);
__m256 two_pow_i = _mm256_castsi256_ps(shifted);
return _mm256_mul_ps(two_pow_i, frac_exp);
}
// ============================================================================
// SiLU activation: silu(gate) * up = gate * sigmoid(gate) * up
// AVX2 port of amx::act_fn
// ============================================================================
static inline __m256 act_fn(__m256 gate_val, __m256 up_val) {
__m256 neg_gate_val = _mm256_sub_ps(_mm256_setzero_ps(), gate_val);
// Clamp to avoid exp overflow
const __m256 max_exp_input = _mm256_set1_ps(88.0f);
neg_gate_val = _mm256_min_ps(neg_gate_val, max_exp_input);
__m256 exp_neg_gate = exp_avx2(neg_gate_val);
__m256 denom = _mm256_add_ps(_mm256_set1_ps(1.0f), exp_neg_gate);
__m256 act_val = _mm256_div_ps(gate_val, denom);
return _mm256_mul_ps(act_val, up_val);
}
} // namespace avx2
#endif // CPUINFER_OPERATOR_AVX2_BF16_UTILS_H
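The three tricks in this header — tie-bit rounding, the 2^i exponent construction, and the SiLU formulation — can each be checked with a scalar Python model (a sketch over raw bit patterns; Python's round() matches the round-to-nearest-even of `_mm256_cvtps_epi32`, and NaN inputs are not handled, same as the SIMD code):

```python
import math
import struct

def fp32_to_bf16_bits(x: float) -> int:
    """store_fp32_to_bf16 for one lane: round-to-nearest-even on the truncated 16 bits."""
    i32 = struct.unpack("<I", struct.pack("<f", x))[0]
    tie = (i32 >> 16) & 1                     # LSB of the kept bits breaks ties toward even
    return ((i32 + 0x7FFF + tie) >> 16) & 0xFFFF

def exp_approx(x: float) -> float:
    """exp_avx2 for one lane: e^x = 2^i * P(f), where x*log2(e) = i + f, f in [-0.5, 0.5]."""
    y = x * 1.44269504089
    i = round(y)                              # round-to-nearest-even, like cvtps_epi32
    f = y - i
    p = 0.0013333558                          # Horner evaluation of degree-5 P(f) ~= 2^f
    for c in (0.0096181291, 0.0555041087, 0.2402265069, 0.6931471805, 0.9999999995):
        p = p * f + c
    return math.ldexp(p, max(-126, min(127, i)))   # 2^i via the same exponent clamp

def act_fn(gate: float, up: float) -> float:
    """SiLU(gate) * up = gate / (1 + e^-gate) * up, with the same exp-input clamp."""
    return gate / (1.0 + exp_approx(min(-gate, 88.0))) * up

print(hex(fp32_to_bf16_bits(1.0)))            # 0x3f80
```

The coefficients of P are the degree-5 Taylor terms of 2^f (ln2^k / k!), so on f ∈ [-0.5, 0.5] the relative error stays in the 1e-6 range, well below BF16 resolution.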


@@ -0,0 +1,327 @@
/**
* @Description : AVX2 BF16 MoE operator (ported from amx/bf16-moe.hpp)
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* Simplified from AMX version:
* - BufferB::from_mat is memcpy (no AMX transpose)
* - No unpack_nk_block_bf16 (no AMX packed format to undo)
* - write_weights_to_buffer uses direct memcpy with full TP routing logic
**/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_MOE_H
#define CPUINFER_OPERATOR_AVX2_BF16_MOE_H
#include "avx2_bf16_gemm.hpp"
#include "moe_base.hpp"
template <class T = avx2::GemmKernelAVX2BF16>
class AVX2_BF16_MOE_TP : public AVX2_MOE_BASE<T, AVX2_BF16_MOE_TP<T>> {
using Base = AVX2_MOE_BASE<T, AVX2_BF16_MOE_TP<T>>;
using Base::config_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
public:
using typename Base::input_t;
using typename Base::output_t;
AVX2_BF16_MOE_TP() = default;
AVX2_BF16_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}
void derived_init() {
printf("Created AVX2_BF16_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
}
~AVX2_BF16_MOE_TP() = default;
// CRTP buffer creation
size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
size_t buffer_b_required_size_impl(size_t n, size_t k) const { return T::BufferB::required_size(n, k); }
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
return std::make_shared<typename T::BufferB>(n, k, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// GEMM dispatch — uses avx2::gemm_bf16
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
avx2::gemm_bf16(m, config_.intermediate_size, config_.hidden_size, *ba, *bb, *bc, ith, nth);
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
avx2::gemm_bf16(m, config_.hidden_size, config_.intermediate_size,
*down_ba_[expert_idx], *down_bb_[expert_idx], *down_bc_[expert_idx], ith, nth);
}
/**
* Load BF16 weights from contiguous memory layout.
* BufferB::from_mat is a trivial memcpy for AVX2 (no AMX transpose).
*/
void load_weights() {
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (config_.gate_proj == nullptr) {
throw std::runtime_error("BF16 MOE requires native BF16 weight.");
}
// Load gate + up weights
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
gate_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
ith, nth);
up_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
ith, nth);
},
nullptr);
// Load down weights
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
down_bb_[expert_idx]->from_mat(
(ggml_bf16_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
ith, nth);
},
nullptr);
}
/**
* Write weights to GPU buffer for dynamic expert offload.
* Preserves full TP routing logic from AMX version but uses direct memcpy
* instead of unpack_nk_block_bf16 (since BufferB is row-major, not AMX-packed).
*/
void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
[[maybe_unused]] const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
[[maybe_unused]] const std::vector<uintptr_t>& w2_scale_ptrs) const {
auto& config = config_;
auto pool = config.pool->get_subpool(tp_part_idx);
// W13 (gate+up): Shape [intermediate, hidden], split by N across GPU TPs
const int cpu_n_w13 = config.intermediate_size;
const int cpu_k_w13 = config.hidden_size;
const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
const int gpu_k_w13 = full_config.hidden_size;
const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;
const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;
// W2 (down): Shape [hidden, intermediate], split by K across GPU TPs
const int cpu_n_w2 = config.hidden_size;
const int cpu_k_w2 = config.intermediate_size;
const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;
constexpr int NUM_W13_TASKS = 32;
constexpr int NUM_W2_TASKS = 32;
const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS;
pool->do_work_stealing_job(
total_tasks, nullptr,
[=, &w13_weight_ptrs, &w2_weight_ptrs, this](int task_id) {
if (task_id < NUM_W13_TASKS * 2) {
// W13 weight task: copy rows from BufferB to GPU buffer
const bool is_up = task_id >= NUM_W13_TASKS;
const int chunk_idx = task_id % NUM_W13_TASKS;
const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];
const int rows_per_task = (cpu_n_w13 + NUM_W13_TASKS - 1) / NUM_W13_TASKS;
const int row_start = chunk_idx * rows_per_task;
const int row_end = std::min(row_start + rows_per_task, cpu_n_w13);
if (row_start >= cpu_n_w13) return;
for (int row = row_start; row < row_end; row++) {
const int global_n = global_n_offset_w13 + row;
const int target_gpu = global_n / gpu_n_w13;
const int n_in_gpu = global_n % gpu_n_w13;
ggml_bf16_t* dst = (ggml_bf16_t*)w13_weight_ptrs[target_gpu];
const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;
// BufferB is row-major [N, K], direct copy
std::memcpy(dst + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13,
bb->b + (size_t)row * cpu_k_w13,
cpu_k_w13 * sizeof(ggml_bf16_t));
}
} else {
// W2 weight task: copy rows, split K across GPU TPs
const int chunk_idx = task_id - NUM_W13_TASKS * 2;
const auto& bb = down_bb_[expert_id];
const int rows_per_task = (cpu_n_w2 + NUM_W2_TASKS - 1) / NUM_W2_TASKS;
const int row_start = chunk_idx * rows_per_task;
const int row_end = std::min(row_start + rows_per_task, cpu_n_w2);
if (row_start >= cpu_n_w2) return;
for (int row = row_start; row < row_end; row++) {
// For W2, K dimension is split across GPU TPs
// Iterate over all gpu_k_w2-sized slices within this CPU TP's K range
for (int k_start = 0; k_start < cpu_k_w2; k_start += gpu_k_w2) {
const int k_slice_end = std::min(k_start + gpu_k_w2, cpu_k_w2);
const int k_slice_len = k_slice_end - k_start;
// Map to correct GPU TP
const int global_k = global_k_offset_w2 + k_start;
const int target_gpu = global_k / gpu_k_w2;
const int k_in_gpu = global_k % gpu_k_w2;
ggml_bf16_t* dst = (ggml_bf16_t*)w2_weight_ptrs[target_gpu];
std::memcpy(dst + (size_t)row * gpu_k_w2 + k_in_gpu,
bb->b + (size_t)row * cpu_k_w2 + k_start,
k_slice_len * sizeof(ggml_bf16_t));
}
}
}
},
nullptr);
}
};
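The N-offset arithmetic in `write_weights_to_buffer` above (local CPU-TP row → target GPU TP and in-GPU row) can be checked in isolation. A minimal sketch, with hypothetical sizes (`cpu_n = 128`, `gpu_n = 256`) that stand in for real model dimensions:

```cpp
#include <cassert>

// Mirrors the W13 mapping above: a CPU TP part's local `row` maps to global
// row `global_n = tp_part_idx * cpu_n + row`, which lands on GPU TP
// `global_n / gpu_n` at local index `global_n % gpu_n`.
struct W13Dst { int target_gpu; int n_in_gpu; };

inline W13Dst map_w13_row(int tp_part_idx, int cpu_n, int gpu_n, int row) {
    const int global_n = tp_part_idx * cpu_n + row;
    return { global_n / gpu_n, global_n % gpu_n };
}

inline void check_w13_mapping() {
    // Hypothetical split: 4 CPU TPs (cpu_n = 128), 2 GPU TPs (gpu_n = 256).
    const int cpu_n = 128, gpu_n = 256;
    // CPU part 0, row 0 -> GPU 0, local row 0.
    assert(map_w13_row(0, cpu_n, gpu_n, 0).target_gpu == 0);
    // CPU part 2, row 0 -> global row 256 -> GPU 1, local row 0.
    W13Dst d = map_w13_row(2, cpu_n, gpu_n, 0);
    assert(d.target_gpu == 1 && d.n_in_gpu == 0);
    // CPU part 1, row 127 -> global row 255 -> still GPU 0.
    assert(map_w13_row(1, cpu_n, gpu_n, 127).target_gpu == 0);
}
```

Because each W13 row carries the full K dimension, every row is a single contiguous memcpy into exactly one GPU TP buffer.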
// ============================================================================
// TP_MOE specialization for AVX2_BF16_MOE_TP
// Handles per-expert pointer loading and TP weight splitting
// (Ported from amx/bf16-moe.hpp TP_MOE<AMX_BF16_MOE_TP<K>>)
// ============================================================================
template <typename K>
class TP_MOE<AVX2_BF16_MOE_TP<K>> : public TP_MOE<AVX2_MOE_BASE<K, AVX2_BF16_MOE_TP<K>>> {
public:
using Base = TP_MOE<AVX2_MOE_BASE<K, AVX2_BF16_MOE_TP<K>>>;
using Base::Base;
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_projs.empty() && config.gate_proj == nullptr) {
throw std::runtime_error("no weight source");
}
const bool use_per_expert_ptrs = !config.gate_projs.empty();
const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;
// Allocate temporary BF16 buffers for this TP part
tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];
tpc.up_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];
tpc.down_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];
const size_t gate_up_weight_src_offset = i * tp_weight_elems;
const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
ggml_bf16_t* gate_dst = (ggml_bf16_t*)tpc.gate_proj + expert_id * tp_weight_elems;
ggml_bf16_t* up_dst = (ggml_bf16_t*)tpc.up_proj + expert_id * tp_weight_elems;
ggml_bf16_t* down_dst = (ggml_bf16_t*)tpc.down_proj + expert_id * tp_weight_elems;
const ggml_bf16_t* gate_src;
const ggml_bf16_t* up_src;
const ggml_bf16_t* down_src;
if (use_per_expert_ptrs) {
gate_src = (const ggml_bf16_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
up_src = (const ggml_bf16_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
down_src = (const ggml_bf16_t*)config.down_projs[0][expert_id];
} else {
gate_src = (const ggml_bf16_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
up_src = (const ggml_bf16_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
down_src = (const ggml_bf16_t*)config.down_proj + expert_id * full_weight_elems;
}
// Copy gate and up weights (column-slice for TP)
std::memcpy(gate_dst, gate_src, tp_weight_elems * sizeof(ggml_bf16_t));
std::memcpy(up_dst, up_src, tp_weight_elems * sizeof(ggml_bf16_t));
// Copy down weights (row-wise split: each row picks a slice of columns)
for (int row = 0; row < config.hidden_size; row++) {
const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset,
(size_t)tpc.intermediate_size * sizeof(ggml_bf16_t));
}
},
nullptr);
});
// Call per-TP load_weights (which does BufferB::from_mat = memcpy)
pool->dispense_backend()->do_numa_job([&, this](int i) {
tps[i]->load_weights();
});
// Free temporary buffers
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
delete[] (ggml_bf16_t*)tpc.gate_proj;
delete[] (ggml_bf16_t*)tpc.up_proj;
delete[] (ggml_bf16_t*)tpc.down_proj;
});
this->weights_loaded = true;
}
void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
if (this->weights_loaded == false) throw std::runtime_error("Not Loaded");
if (this->tps.empty()) throw std::runtime_error("No TP parts initialized");
if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w2_weight_ptrs.size() != gpu_tp_count) {
throw std::runtime_error("Weight pointer arrays size must match gpu_tp_count");
}
this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,
w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
});
}
};
#endif // CPUINFER_OPERATOR_AVX2_BF16_MOE_H
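The down-proj split in `TP_MOE<AVX2_BF16_MOE_TP<K>>::load_weights` takes, for every `[hidden, intermediate]` row, the column slice owned by TP part `i`. A standalone sketch of that copy, using `float` and hypothetical sizes for brevity:

```cpp
#include <cassert>
#include <cstring>
#include <vector>

// Mirrors the down_proj TP split: part `part` takes columns
// [part*tp_inter, (part+1)*tp_inter) of every row of the full
// [hidden, inter] matrix.
inline std::vector<float> slice_down_proj(const std::vector<float>& full,
                                          int hidden, int inter,
                                          int tp_inter, int part) {
    std::vector<float> out((size_t)hidden * tp_inter);
    for (int row = 0; row < hidden; row++) {
        std::memcpy(out.data() + (size_t)row * tp_inter,
                    full.data() + (size_t)row * inter + (size_t)part * tp_inter,
                    (size_t)tp_inter * sizeof(float));
    }
    return out;
}
```

Gate/up need no such per-row loop: their TP slice is a contiguous block of rows, so a single memcpy per expert suffices, as in the code above.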



@@ -0,0 +1,598 @@
/**
* @Description : AVX2 FP8 MoE operator (ported from amx/fp8-moe.hpp)
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* FP8 E4M3 weights with 128×128 block-wise float32 scales.
* Dequantization: FP8→FP32 via precomputed 256-entry LUT + AVX2 gather.
* GEMM: BF16 input × FP32 dequantized weight → FP32 output.
**/
#ifndef CPUINFER_OPERATOR_AVX2_FP8_MOE_H
#define CPUINFER_OPERATOR_AVX2_FP8_MOE_H
#include "avx2_bf16_gemm.hpp"
#include "avx2_bf16_utils.hpp"
#include "fp8_dequant.hpp"
#include "moe_base.hpp"
namespace avx2 {
inline int div_up(int a, int b) { return (a + b - 1) / b; }
struct GemmKernelAVX2FP8 {
using dt = ggml_bf16_t;
using output_t = float;
static constexpr int M_STEP = 1;
static constexpr int N_STEP = 8;
static constexpr int K_STEP = 8;
static constexpr int BLOCK_SIZE = 128; // 128×128 block quantization
static constexpr int N_BLOCK = 128;
static constexpr int K_BLOCK = 128;
static constexpr double ELEMENT_SIZE = 1.0; // FP8 = 1 byte
static void config() {}
static int recommended_nth(int n) {
return std::max(1, div_up(n, N_BLOCK));
}
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
return avx2::split_range(n, ith, nth);
}
// ========================================================================
// BufferA: BF16 activations [M, K] — same as BF16 backend
// ========================================================================
struct BufferA {
ggml_bf16_t* data = nullptr;
size_t max_m = 0;
size_t k = 0;
BufferA() = default;
BufferA(size_t m, size_t k_, void* ptr) : max_m(m), k(k_), data((ggml_bf16_t*)ptr) {}
static size_t required_size(size_t m, size_t k) {
return m * k * sizeof(ggml_bf16_t);
}
void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; }
void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) {
if (ith == 0 && nth == 1) {
std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t));
} else {
auto [m_start, m_end] = avx2::split_range(m, ith, nth);
std::memcpy(data + m_start * k, src + m_start * k,
(size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t));
}
}
};
// ========================================================================
// BufferB: FP8 weights [N, K] + float32 scales [N/BS, K/BS]
// Row-major, no packing. from_mat = memcpy.
// ========================================================================
struct BufferB {
uint8_t* b = nullptr; // FP8 weights
float* d = nullptr; // Block-wise scales
size_t n = 0;
size_t k = 0;
int block_size = BLOCK_SIZE;
BufferB() = default;
BufferB(size_t n_, size_t k_, int bs, void* ptr) : n(n_), k(k_), block_size(bs) {
b = (uint8_t*)ptr;
size_t weight_bytes = n * k;
d = (float*)((uint8_t*)ptr + weight_bytes);
}
static size_t required_size(size_t n, size_t k, int bs) {
size_t n_blocks_n = div_up((int)n, bs);
size_t n_blocks_k = div_up((int)k, bs);
return n * k + n_blocks_n * n_blocks_k * sizeof(float);
}
void from_mat(const uint8_t* src_weights, const float* src_scales, int ith, int nth) {
// Copy weights (split by N)
auto [n_start, n_end] = avx2::split_range((int)n, ith, nth);
std::memcpy(b + n_start * k, src_weights + n_start * k,
(size_t)(n_end - n_start) * k);
// Copy scales (split by N blocks)
int n_blocks_k = div_up((int)k, block_size);
int nb_start = n_start / block_size;
int nb_end = div_up(n_end, block_size);
std::memcpy(d + nb_start * n_blocks_k, src_scales + nb_start * n_blocks_k,
(size_t)(nb_end - nb_start) * n_blocks_k * sizeof(float));
}
};
// ========================================================================
// BufferC: FP32 output — same as BF16 backend
// ========================================================================
struct BufferC {
float* data = nullptr;
size_t max_m = 0;
size_t n = 0;
BufferC() = default;
BufferC(size_t m, size_t n_, void* ptr) : max_m(m), n(n_), data((float*)ptr) {}
static size_t required_size(size_t m, size_t n) {
return m * n * sizeof(float);
}
void set_data(void* ptr) { data = (float*)ptr; }
void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
auto [n_start, n_end] = avx2::split_range((int)n, ith, nth);
for (int mi = 0; mi < m; mi++) {
float* src_row = data + mi * n;
ggml_bf16_t* dst_row = dst + mi * n;
int j = n_start;
for (; j + 8 <= n_end; j += 8) {
__m256 v = _mm256_loadu_ps(src_row + j);
store_fp32_to_bf16(dst_row + j, v);
}
for (; j < n_end; j++) {
dst_row[j] = GGML_FP32_TO_BF16(src_row[j]);
}
}
}
};
};
// ============================================================================
// AVX2 FP8 GEMM: C[m,n] = sum_k (A[m,k] * dequant(B[n,k])) * scale[n/BS, k/BS]
// ============================================================================
static inline void gemm_fp8(
int m, int n, int k,
GemmKernelAVX2FP8::BufferA& a,
GemmKernelAVX2FP8::BufferB& b,
GemmKernelAVX2FP8::BufferC& c,
int ith, int nth) {
ensure_fp8_lut_initialized();
auto [n_start, n_end] = split_range(n, ith, nth);
const int block_size = b.block_size;
const int n_blocks_k = div_up(k, block_size);
for (int ni = n_start; ni < n_end; ni++) {
const uint8_t* b_row = b.b + (size_t)ni * k;
const int n_block_idx = ni / block_size;
for (int mi = 0; mi < m; mi++) {
const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k;
float sum = 0.0f;
for (int kb = 0; kb < k; kb += block_size) {
int k_len = std::min(block_size, k - kb);
int k_block_idx = kb / block_size;
float scale = b.d[n_block_idx * n_blocks_k + k_block_idx];
// Accumulate within this block
__m256 acc1 = _mm256_setzero_ps();
__m256 acc2 = _mm256_setzero_ps();
__m256 acc3 = _mm256_setzero_ps();
__m256 acc4 = _mm256_setzero_ps();
int ki = 0;
for (; ki + 32 <= k_len; ki += 32) {
acc1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki),
fp8x8_to_fp32x8(b_row + kb + ki), acc1);
acc2 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki + 8),
fp8x8_to_fp32x8(b_row + kb + ki + 8), acc2);
acc3 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki + 16),
fp8x8_to_fp32x8(b_row + kb + ki + 16), acc3);
acc4 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki + 24),
fp8x8_to_fp32x8(b_row + kb + ki + 24), acc4);
}
for (; ki + 8 <= k_len; ki += 8) {
acc1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki),
fp8x8_to_fp32x8(b_row + kb + ki), acc1);
}
float block_sum = hsum_avx2(_mm256_add_ps(_mm256_add_ps(acc1, acc3),
_mm256_add_ps(acc2, acc4)));
// Scalar tail
for (; ki < k_len; ki++) {
block_sum += GGML_BF16_TO_FP32(a_row[kb + ki]) * fp8_to_fp32_scalar(b_row[kb + ki]);
}
sum += block_sum * scale;
}
c.data[mi * n + ni] = sum;
}
}
}
} // namespace avx2
// ============================================================================
// AVX2 FP8 MoE operator (CRTP derived from AVX2_MOE_BASE)
// ============================================================================
template <class T = avx2::GemmKernelAVX2FP8>
class AVX2_FP8_MOE_TP : public AVX2_MOE_BASE<T, AVX2_FP8_MOE_TP<T>> {
using Base = AVX2_MOE_BASE<T, AVX2_FP8_MOE_TP<T>>;
using Base::config_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
public:
using typename Base::input_t;
using typename Base::output_t;
AVX2_FP8_MOE_TP() = default;
AVX2_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}
void derived_init() {
avx2::ensure_fp8_lut_initialized();
auto& quant_config = config_.quant_config;
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("AVX2 FP8 MoE only supports block-wise FP8 (group_size > 0, no zero_point)");
}
printf("Created AVX2_FP8_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
}
~AVX2_FP8_MOE_TP() = default;
// CRTP buffer creation — with group_size for BufferB
size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k, config_.quant_config.group_size);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// GEMM dispatch
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
avx2::gemm_fp8(m, config_.intermediate_size, config_.hidden_size, *ba, *bb, *bc, ith, nth);
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
avx2::gemm_fp8(m, config_.hidden_size, config_.intermediate_size,
*down_ba_[expert_idx], *down_bb_[expert_idx], *down_bc_[expert_idx], ith, nth);
}
// Load FP8 weights + scales from contiguous memory
void load_weights() {
auto& quant_config = config_.quant_config;
int group_size = quant_config.group_size;
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (config_.gate_scale == nullptr) {
throw std::runtime_error("FP8 MOE requires scale pointers.");
}
// Load gate + up weights
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map, group_size](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;
size_t scale_offset = logical_expert_id *
avx2::div_up(config_.hidden_size, group_size) *
avx2::div_up(config_.intermediate_size, group_size);
gate_bb_[expert_idx]->from_mat(
(uint8_t*)config_.gate_proj + weight_offset,
(float*)config_.gate_scale + scale_offset,
ith, nth);
up_bb_[expert_idx]->from_mat(
(uint8_t*)config_.up_proj + weight_offset,
(float*)config_.up_scale + scale_offset,
ith, nth);
},
nullptr);
// Load down weights
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map, group_size](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;
size_t scale_offset = logical_expert_id *
avx2::div_up(config_.hidden_size, group_size) *
avx2::div_up(config_.intermediate_size, group_size);
down_bb_[expert_idx]->from_mat(
(uint8_t*)config_.down_proj + weight_offset,
(float*)config_.down_scale + scale_offset,
ith, nth);
},
nullptr);
}
// Write weights to GPU buffer (for dynamic expert offload / layerwise prefill)
void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) const {
auto& config = config_;
auto pool = config.pool->get_subpool(tp_part_idx);
int group_size = config.quant_config.group_size;
// W13 (gate+up)
const int cpu_n_w13 = config.intermediate_size;
const int cpu_k_w13 = config.hidden_size;
const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
const int gpu_k_w13 = full_config.hidden_size;
const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;
const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;
const int gpu_n_blocks_k_w13 = avx2::div_up(gpu_k_w13, group_size);
const size_t gpu_w13_scale_per_mat = (size_t)avx2::div_up(gpu_n_w13, group_size) * gpu_n_blocks_k_w13;
// W2 (down)
const int cpu_n_w2 = config.hidden_size;
const int cpu_k_w2 = config.intermediate_size;
const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;
const int cpu_n_blocks_k_w2 = avx2::div_up(cpu_k_w2, group_size);
constexpr int NUM_W13_TASKS = 32;
constexpr int NUM_W2_TASKS = 32;
const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS;
pool->do_work_stealing_job(
total_tasks, nullptr,
[=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) {
if (task_id < NUM_W13_TASKS * 2) {
const bool is_up = task_id >= NUM_W13_TASKS;
const int chunk_idx = task_id % NUM_W13_TASKS;
const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];
const int rows_per_task = avx2::div_up(cpu_n_w13, NUM_W13_TASKS);
const int row_start = chunk_idx * rows_per_task;
const int row_end = std::min(row_start + rows_per_task, cpu_n_w13);
if (row_start >= cpu_n_w13) return;
for (int row = row_start; row < row_end; row++) {
const int global_n = global_n_offset_w13 + row;
const int target_gpu = global_n / gpu_n_w13;
const int n_in_gpu = global_n % gpu_n_w13;
// Copy weight row
uint8_t* w_dst = (uint8_t*)w13_weight_ptrs[target_gpu];
const size_t expert_w_off = is_up ? gpu_w13_weight_per_mat : 0;
std::memcpy(w_dst + expert_w_off + (size_t)n_in_gpu * gpu_k_w13,
bb->b + (size_t)row * cpu_k_w13,
cpu_k_w13);
// Copy scale row (if at block boundary)
if (row % group_size == 0) {
int n_block = row / group_size;
int gpu_n_block = n_in_gpu / group_size;
float* s_dst = (float*)w13_scale_ptrs[target_gpu];
const size_t expert_s_off = is_up ? gpu_w13_scale_per_mat : 0;
std::memcpy(s_dst + expert_s_off + gpu_n_block * gpu_n_blocks_k_w13,
bb->d + n_block * avx2::div_up(cpu_k_w13, group_size),
avx2::div_up(cpu_k_w13, group_size) * sizeof(float));
}
}
} else {
const int chunk_idx = task_id - NUM_W13_TASKS * 2;
const auto& bb = down_bb_[expert_id];
const int rows_per_task = avx2::div_up(cpu_n_w2, NUM_W2_TASKS);
const int row_start = chunk_idx * rows_per_task;
const int row_end = std::min(row_start + rows_per_task, cpu_n_w2);
if (row_start >= cpu_n_w2) return;
for (int row = row_start; row < row_end; row++) {
// Iterate over all gpu_k_w2-sized slices within this CPU TP's K range
for (int k_start = 0; k_start < cpu_k_w2; k_start += gpu_k_w2) {
const int k_slice_len = std::min(gpu_k_w2, cpu_k_w2 - k_start);
const int global_k = global_k_offset_w2 + k_start;
const int target_gpu = global_k / gpu_k_w2;
const int k_in_gpu = global_k % gpu_k_w2;
uint8_t* w_dst = (uint8_t*)w2_weight_ptrs[target_gpu];
std::memcpy(w_dst + (size_t)row * gpu_k_w2 + k_in_gpu,
bb->b + (size_t)row * cpu_k_w2 + k_start,
k_slice_len);
// Copy scales for down (at block boundaries)
if (row % group_size == 0) {
int n_block = row / group_size;
float* s_dst = (float*)w2_scale_ptrs[target_gpu];
int gpu_n_blocks_k_w2 = avx2::div_up(gpu_k_w2, group_size);
int k_block_start = k_in_gpu / group_size;
int n_blocks_to_copy = std::min(cpu_n_blocks_k_w2, gpu_n_blocks_k_w2 - k_block_start);
std::memcpy(s_dst + n_block * gpu_n_blocks_k_w2 + k_block_start,
bb->d + n_block * cpu_n_blocks_k_w2 + k_start / group_size,
n_blocks_to_copy * sizeof(float));
}
} // end k_start loop
} // end row loop
}
},
nullptr);
}
};
// ============================================================================
// TP_MOE specialization — ported from amx/fp8-moe.hpp:628-738
// Handles per-expert pointer loading + TP weight/scale splitting
// ============================================================================
template <typename K>
class TP_MOE<AVX2_FP8_MOE_TP<K>> : public TP_MOE<AVX2_MOE_BASE<K, AVX2_FP8_MOE_TP<K>>> {
public:
using Base = TP_MOE<AVX2_MOE_BASE<K, AVX2_FP8_MOE_TP<K>>>;
using Base::Base;
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
const int group_size = config.quant_config.group_size;
if (group_size == 0 || config.quant_config.zero_point) {
throw std::runtime_error("FP8 MoE only supports block-wise (group_size > 0, zero_point=false)");
}
if (config.gate_projs.empty() && config.gate_proj == nullptr) {
throw std::runtime_error("no weight source");
}
const bool use_per_expert_ptrs = !config.gate_projs.empty();
const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;
const size_t full_scale_elems =
(size_t)avx2::div_up(config.hidden_size, group_size) * avx2::div_up(config.intermediate_size, group_size);
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;
const size_t tp_scale_elems =
(size_t)avx2::div_up(tpc.intermediate_size, group_size) * avx2::div_up(tpc.hidden_size, group_size);
// Allocate temporary buffers
tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.gate_scale = new float[tpc.expert_num * tp_scale_elems];
tpc.up_scale = new float[tpc.expert_num * tp_scale_elems];
tpc.down_scale = new float[tpc.expert_num * tp_scale_elems];
const size_t gate_up_weight_src_offset = i * tp_weight_elems;
const size_t gate_up_scale_src_offset = i * tp_scale_elems;
const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;
const size_t down_scale_src_block_k_offset = down_weight_src_col_offset / (size_t)group_size;
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {

const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems;
uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems;
uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems;
float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_scale_elems;
float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_scale_elems;
float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_scale_elems;
const uint8_t* gate_src;
const uint8_t* up_src;
const uint8_t* down_src;
const float* gate_scale_src;
const float* up_scale_src;
const float* down_scale_src;
if (use_per_expert_ptrs) {
gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
down_src = (const uint8_t*)config.down_projs[0][expert_id];
gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset;
up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset;
down_scale_src = (const float*)config.down_scales[0][expert_id];
} else {
gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems;
gate_scale_src = (const float*)config.gate_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
up_scale_src = (const float*)config.up_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
down_scale_src = (const float*)config.down_scale + expert_id * full_scale_elems;
}
// Copy gate/up weights + scales (column slice)
std::memcpy(gate_dst, gate_src, tp_weight_elems);
std::memcpy(up_dst, up_src, tp_weight_elems);
std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_scale_elems);
std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_scale_elems);
// Copy down weights (row-wise split)
for (int row = 0; row < config.hidden_size; row++) {
const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size);
}
// Copy down scales (block-row-wise split)
const int n_blocks_n = avx2::div_up(config.hidden_size, group_size);
const int full_n_blocks_k = avx2::div_up(config.intermediate_size, group_size);
const int tp_n_blocks_k = avx2::div_up(tpc.intermediate_size, group_size);
for (int bn = 0; bn < n_blocks_n; bn++) {
const float* src = down_scale_src + (size_t)bn * full_n_blocks_k + down_scale_src_block_k_offset;
float* dst = down_scale_dst + (size_t)bn * tp_n_blocks_k;
std::memcpy(dst, src, sizeof(float) * tp_n_blocks_k);
}
},
nullptr);
});
// Call per-TP load_weights
pool->dispense_backend()->do_numa_job([&, this](int i) {
tps[i]->load_weights();
});
// Free temporary buffers
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)tpc.gate_proj;
delete[] (uint8_t*)tpc.up_proj;
delete[] (uint8_t*)tpc.down_proj;
delete[] (float*)tpc.gate_scale;
delete[] (float*)tpc.up_scale;
delete[] (float*)tpc.down_scale;
});
this->weights_loaded = true;
}
void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
if (this->weights_loaded == false) throw std::runtime_error("Not Loaded");
if (this->tps.empty()) throw std::runtime_error("No TP parts initialized");
this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config,
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
});
}
};
#endif // CPUINFER_OPERATOR_AVX2_FP8_MOE_H
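The K-block accumulation pattern in `gemm_fp8` (a partial dot product per `block_size`-wide K block, scaled by that block's factor, then added to the running sum) can be sketched independently of the FP8 decode; plain floats stand in for the dequantized weights here:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Reference for the K-block loop in gemm_fp8: accumulate inside each
// block_size-wide K block, multiply by that block's scale, then add to the
// running sum. Equivalent to scaling each weight element individually.
inline float blockwise_dot(const std::vector<float>& a,
                           const std::vector<float>& w,
                           const std::vector<float>& scales,
                           int block_size) {
    const int k = (int)a.size();
    float sum = 0.0f;
    for (int kb = 0; kb < k; kb += block_size) {
        const int k_len = std::min(block_size, k - kb);
        float block_sum = 0.0f;
        for (int ki = 0; ki < k_len; ki++) block_sum += a[kb + ki] * w[kb + ki];
        sum += block_sum * scales[kb / block_size];
    }
    return sum;
}
```

Scaling once per block rather than per element is what lets the inner loop stay pure FMA over dequantized values.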


@@ -0,0 +1,86 @@
/**
* @Description : FP8 E4M3 dequantization for AVX2 (LUT-based)
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* AVX512 uses _mm512_permutex2var_epi8 (VBMI) for FP8→BF16 LUT conversion.
* AVX2 uses a precomputed 256-entry FP8→FP32 lookup table + _mm256_i32gather_ps.
*
* FP8 E4M3 format: sign(1) + exponent(4) + mantissa(3)
* Reference: examples/test_fp8_moe.py:103-116
**/
#ifndef CPUINFER_OPERATOR_AVX2_FP8_DEQUANT_H
#define CPUINFER_OPERATOR_AVX2_FP8_DEQUANT_H
#include <immintrin.h>
#include <cmath>
#include <cstdint>
namespace avx2 {
// Precomputed FP8 E4M3 → FP32 lookup table (256 entries)
// Initialized once at program startup via init_fp8_lut()
struct FP8LUT {
alignas(32) float table[256];
bool initialized = false;
void init() {
if (initialized) return;
for (int i = 0; i < 256; i++) {
int sign = (i >> 7) & 1;
int exp = (i >> 3) & 0xF; // 4-bit exponent (bits 3-6)
int man = i & 0x7; // 3-bit mantissa (bits 0-2)
float val;
if (exp == 0 && man == 0) {
val = 0.0f; // zero
} else if (exp == 0) {
val = std::ldexp((float)man / 8.0f, -6); // subnormal: 2^(-6) * (0.man)
} else if (exp == 15 && man == 7) {
val = 0.0f; // exp=15, man=7 (0x7F and 0xFF) is NaN in E4M3. Treat as 0 to avoid propagation.
// E4M3 has no Inf; exp=15 with man=0-6 are valid finite values (256-448).
} else {
val = std::ldexp(1.0f + (float)man / 8.0f, exp - 7); // normal: 2^(exp-7) * (1.man)
}
table[i] = sign ? -val : val;
}
initialized = true;
}
};
// Global LUT instance
inline FP8LUT& get_fp8_lut() {
static FP8LUT lut;
return lut;
}
// Ensure LUT is initialized (call once at startup)
inline void ensure_fp8_lut_initialized() {
get_fp8_lut().init();
}
// ============================================================================
// AVX2 FP8→FP32 dequantization: 8 FP8 bytes → 8 FP32 values
// Uses _mm256_i32gather_ps for parallel LUT lookups
// ============================================================================
static inline __m256 fp8x8_to_fp32x8(const uint8_t* src) {
const float* lut = get_fp8_lut().table;
// Load 8 bytes, zero-extend to 32-bit indices
__m128i bytes = _mm_loadl_epi64((const __m128i*)src);
__m256i indices = _mm256_cvtepu8_epi32(bytes);
// Gather 8 floats from LUT (scale=4 because float is 4 bytes)
return _mm256_i32gather_ps(lut, indices, 4);
}
// Scalar fallback for non-aligned or tail elements
static inline float fp8_to_fp32_scalar(uint8_t val) {
return get_fp8_lut().table[val];
}
} // namespace avx2
#endif // CPUINFER_OPERATOR_AVX2_FP8_DEQUANT_H
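The decode formula in `FP8LUT::init` can be spot-checked with a scalar decoder; the expected values follow directly from the E4M3 definition (bias 7, 3-bit mantissa, subnormals at exp==0):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Scalar E4M3 decode mirroring FP8LUT::init(): sign(1) | exp(4) | man(3),
// bias 7, subnormals at exp==0, NaN (exp==15 && man==7) mapped to 0 as in
// the LUT.
inline float e4m3_decode(uint8_t v) {
    const int sign = (v >> 7) & 1;
    const int exp = (v >> 3) & 0xF;
    const int man = v & 0x7;
    float val;
    if (exp == 0)
        val = std::ldexp((float)man / 8.0f, -6);          // subnormal (and zero)
    else if (exp == 15 && man == 7)
        val = 0.0f;                                       // NaN -> 0
    else
        val = std::ldexp(1.0f + (float)man / 8.0f, exp - 7);  // normal
    return sign ? -val : val;
}
```

For example, 0x38 (exp=7, man=0) decodes to 1.0 and 0x7E (exp=15, man=6) to 448, the E4M3 maximum.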


@@ -0,0 +1,510 @@
/**
* @Description : AVX2 GPTQ-Int4 MoE operator (symmetric quantization)
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* Supports GPTQ symmetric (sym=true, desc_act=false) INT4 quantization.
* qweight [K/8, N] int32 + scales [K/gs, N] fp32. No qzeros needed.
**/
#ifndef CPUINFER_OPERATOR_AVX2_GPTQ_INT4_MOE_H
#define CPUINFER_OPERATOR_AVX2_GPTQ_INT4_MOE_H
#include "avx2_bf16_gemm.hpp"
#include "avx2_bf16_utils.hpp"
#include "gptq_int4_dequant.hpp"
#include "moe_base.hpp"
namespace avx2 {
struct GemmKernelAVX2GPTQInt4 {
using dt = ggml_bf16_t;
using output_t = float;
static constexpr int M_STEP = 1;
static constexpr int N_STEP = 8;
static constexpr int K_STEP = 8; // 8 INT4 values per int32
static constexpr int N_BLOCK = 64;
static constexpr int K_BLOCK = 128; // = group_size typically
static constexpr double ELEMENT_SIZE = 0.5; // INT4 = 0.5 byte
static void config() {}
static int recommended_nth(int n) {
return std::max(1, n / N_BLOCK);
}
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
return split_range(n, ith, nth);
}
// ========================================================================
// BufferA: BF16 activations [M, K] — same as BF16/FP8
// ========================================================================
struct BufferA {
ggml_bf16_t* data = nullptr;
size_t max_m = 0;
size_t k = 0;
BufferA() = default;
BufferA(size_t m, size_t k_, void* ptr) : max_m(m), k(k_), data((ggml_bf16_t*)ptr) {}
static size_t required_size(size_t m, size_t k) {
return m * k * sizeof(ggml_bf16_t);
}
void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; }
void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) {
if (ith == 0 && nth == 1) {
std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t));
} else {
auto [m_start, m_end] = split_range(m, ith, nth);
std::memcpy(data + m_start * k, src + m_start * k,
(size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t));
}
}
};
// ========================================================================
// BufferB: GPTQ INT4 weights [K/8, N] int32 + scales [num_groups, N] fp32
// ========================================================================
struct BufferB {
uint32_t* qweight = nullptr; // [K/8, N] packed int32
float* scales = nullptr; // [num_groups, N] fp32
int n = 0;
int k = 0;
int group_size = 128;
int num_groups = 0;
int k_packed = 0; // = K/8
BufferB() = default;
BufferB(size_t n_, size_t k_, int gs, void* ptr)
: n(n_), k(k_), group_size(gs) {
k_packed = k / 8;
num_groups = k / gs;
qweight = (uint32_t*)ptr;
scales = (float*)((uint8_t*)ptr + (size_t)k_packed * n * sizeof(uint32_t));
}
static size_t required_size(size_t n, size_t k, int gs) {
size_t k_packed = k / 8;
size_t num_groups = k / gs;
return k_packed * n * sizeof(uint32_t) + num_groups * n * sizeof(float);
}
// Load qweight and scales from separate source pointers
void from_mat(const uint32_t* src_qweight, const float* src_scales, int ith, int nth) {
// Split by N dimension
auto [n_start, n_end] = split_range(n, ith, nth);
int n_len = n_end - n_start;
// Copy qweight rows [K/8 rows, each row = N int32]
for (int kr = 0; kr < k_packed; kr++) {
std::memcpy(qweight + kr * n + n_start,
src_qweight + kr * n + n_start,
n_len * sizeof(uint32_t));
}
// Copy scales rows [num_groups rows, each row = N float]
for (int g = 0; g < num_groups; g++) {
std::memcpy(scales + g * n + n_start,
src_scales + g * n + n_start,
n_len * sizeof(float));
}
}
};
// ========================================================================
// BufferC: FP32 output — same as BF16/FP8
// ========================================================================
struct BufferC {
float* data = nullptr;
size_t max_m = 0;
size_t n = 0;
BufferC() = default;
BufferC(size_t m, size_t n_, void* ptr) : data((float*)ptr), max_m(m), n(n_) {}
static size_t required_size(size_t m, size_t n) {
return m * n * sizeof(float);
}
void set_data(void* ptr) { data = (float*)ptr; }
void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
auto [n_start, n_end] = split_range((int)n, ith, nth);
for (int mi = 0; mi < m; mi++) {
float* src_row = data + mi * n;
ggml_bf16_t* dst_row = dst + mi * n;
int j = n_start;
for (; j + 8 <= n_end; j += 8) {
store_fp32_to_bf16(dst_row + j, _mm256_loadu_ps(src_row + j));
}
for (; j < n_end; j++) {
dst_row[j] = GGML_FP32_TO_BF16(src_row[j]);
}
}
}
};
};
// ============================================================================
// AVX2 GPTQ INT4 GEMM (symmetric)
// C[m,n] = sum_k A_bf16[m,k] * dequant(B_int4[k,n])
// ============================================================================
static inline void gemm_gptq_sym_int4(
int m, int n, int k,
GemmKernelAVX2GPTQInt4::BufferA& a,
GemmKernelAVX2GPTQInt4::BufferB& b,
GemmKernelAVX2GPTQInt4::BufferC& c,
int ith, int nth) {
auto [n_start, n_end] = split_range(n, ith, nth);
const int group_size = b.group_size;
const int num_groups = b.num_groups;
for (int ni = n_start; ni < n_end; ni++) {
for (int mi = 0; mi < m; mi++) {
const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k;
float sum = 0.0f;
for (int g = 0; g < num_groups; g++) {
float scale = b.scales[g * n + ni];
int k_base = g * group_size;
__m256 acc1 = _mm256_setzero_ps();
// group_size/8 iterations (e.g., 128/8 = 16)
for (int ki = 0; ki < group_size; ki += 8) {
int k_abs = k_base + ki;
__m256 a_val = load_bf16_to_fp32(a_row + k_abs);
uint32_t packed = b.qweight[(k_abs / 8) * n + ni];
__m256 w_val = gptq_sym_dequant_8x4bit(packed, scale);
acc1 = _mm256_fmadd_ps(a_val, w_val, acc1);
}
sum += hsum_avx2(acc1);
}
c.data[mi * n + ni] = sum;
}
}
}
} // namespace avx2
// ============================================================================
// AVX2 GPTQ INT4 MoE operator
// ============================================================================
template <class T = avx2::GemmKernelAVX2GPTQInt4>
class AVX2_GPTQ_INT4_MOE_TP : public AVX2_MOE_BASE<T, AVX2_GPTQ_INT4_MOE_TP<T>> {
using Base = AVX2_MOE_BASE<T, AVX2_GPTQ_INT4_MOE_TP<T>>;
using Base::config_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
public:
using typename Base::input_t;
using typename Base::output_t;
AVX2_GPTQ_INT4_MOE_TP() = default;
AVX2_GPTQ_INT4_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}
void derived_init() {
auto& qc = config_.quant_config;
if (qc.group_size == 0) {
throw std::runtime_error("GPTQ INT4 requires group_size > 0");
}
printf("Created AVX2_GPTQ_INT4_MOE_TP %d at numa %d (group_size=%d)\n",
tp_part_idx, numa_node_of_cpu(sched_getcpu()), qc.group_size);
}
~AVX2_GPTQ_INT4_MOE_TP() = default;
// CRTP buffer creation
size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k, config_.quant_config.group_size);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// GEMM dispatch
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
avx2::gemm_gptq_sym_int4(m, config_.intermediate_size, config_.hidden_size, *ba, *bb, *bc, ith, nth);
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
avx2::gemm_gptq_sym_int4(m, config_.hidden_size, config_.intermediate_size,
*down_ba_[expert_idx], *down_bb_[expert_idx], *down_bc_[expert_idx], ith, nth);
}
// Load weights from contiguous qweight + scales pointers
void load_weights() {
int group_size = config_.quant_config.group_size;
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (config_.gate_scale == nullptr) {
throw std::runtime_error("GPTQ INT4 MOE requires scale pointers.");
}
// gate + up: qweight [K/8, N=intermediate], scales [K/gs, N=intermediate]
int gate_up_k = config_.hidden_size;
int gate_up_n = config_.intermediate_size;
size_t qw_elems = (size_t)(gate_up_k / 8) * gate_up_n;
size_t sc_elems = (size_t)(gate_up_k / group_size) * gate_up_n;
int nth = T::recommended_nth(gate_up_n);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map, qw_elems, sc_elems](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
gate_bb_[expert_idx]->from_mat(
(uint32_t*)config_.gate_proj + logical * qw_elems,
(float*)config_.gate_scale + logical * sc_elems,
ith, nth);
up_bb_[expert_idx]->from_mat(
(uint32_t*)config_.up_proj + logical * qw_elems,
(float*)config_.up_scale + logical * sc_elems,
ith, nth);
},
nullptr);
// down: qweight [K/8, N=hidden] where K=intermediate
int down_k = config_.intermediate_size;
int down_n = config_.hidden_size;
size_t down_qw_elems = (size_t)(down_k / 8) * down_n;
size_t down_sc_elems = (size_t)(down_k / group_size) * down_n;
nth = T::recommended_nth(down_n);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map, down_qw_elems, down_sc_elems](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
down_bb_[expert_idx]->from_mat(
(uint32_t*)config_.down_proj + logical * down_qw_elems,
(float*)config_.down_scale + logical * down_sc_elems,
ith, nth);
},
nullptr);
}
// write_weights_to_buffer for layerwise prefill / GPU expert offload
// Note: GPTQ INT4 GPU offload requires the GPU to support INT4 dequant,
// so this path is not implemented yet and throws unconditionally.
void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
[[maybe_unused]] const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
[[maybe_unused]] const std::vector<uintptr_t>& w2_scale_ptrs) const {
// TODO: Implement GPTQ INT4 GPU offload when needed
// For now, layerwise prefill with GPTQ INT4 is not supported
throw std::runtime_error("GPTQ INT4 write_weights_to_buffer not yet implemented");
}
};
// ============================================================================
// TP_MOE specialization for AVX2_GPTQ_INT4_MOE_TP
// ============================================================================
template <typename K>
class TP_MOE<AVX2_GPTQ_INT4_MOE_TP<K>> : public TP_MOE<AVX2_MOE_BASE<K, AVX2_GPTQ_INT4_MOE_TP<K>>> {
public:
using Base = TP_MOE<AVX2_MOE_BASE<K, AVX2_GPTQ_INT4_MOE_TP<K>>>;
using Base::Base;
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
const int group_size = config.quant_config.group_size;
if (group_size == 0) {
throw std::runtime_error("GPTQ INT4 requires group_size > 0");
}
if (config.gate_projs.empty() && config.gate_proj == nullptr) {
throw std::runtime_error("no weight source");
}
const bool use_per_expert_ptrs = !config.gate_projs.empty();
// Full dimensions
const int full_intermediate = config.intermediate_size;
const int full_hidden = config.hidden_size;
// gate/up: shape [K=hidden, N=intermediate]
// qweight: [hidden/8, intermediate], scales: [hidden/gs, intermediate]
const int gate_up_k_packed = full_hidden / 8;
const int gate_up_num_groups = full_hidden / group_size;
const size_t full_gate_up_qw_elems = (size_t)gate_up_k_packed * full_intermediate;
const size_t full_gate_up_sc_elems = (size_t)gate_up_num_groups * full_intermediate;
// down: shape [K=intermediate, N=hidden]
// qweight: [intermediate/8, hidden], scales: [intermediate/gs, hidden]
const int down_k_packed = full_intermediate / 8;
const int down_num_groups = full_intermediate / group_size;
const size_t full_down_qw_elems = (size_t)down_k_packed * full_hidden;
const size_t full_down_sc_elems = (size_t)down_num_groups * full_hidden;
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
const int tp_intermediate = tpc.intermediate_size;
// gate/up TP: N=intermediate is split
const size_t tp_gate_up_qw_elems = (size_t)gate_up_k_packed * tp_intermediate;
const size_t tp_gate_up_sc_elems = (size_t)gate_up_num_groups * tp_intermediate;
tpc.gate_proj = new uint32_t[tpc.expert_num * tp_gate_up_qw_elems];
tpc.up_proj = new uint32_t[tpc.expert_num * tp_gate_up_qw_elems];
tpc.gate_scale = new float[tpc.expert_num * tp_gate_up_sc_elems];
tpc.up_scale = new float[tpc.expert_num * tp_gate_up_sc_elems];
// down TP: K=intermediate is split
const int tp_down_k_packed = tp_intermediate / 8;
const int tp_down_num_groups = tp_intermediate / group_size;
const size_t tp_down_qw_elems = (size_t)tp_down_k_packed * full_hidden;
const size_t tp_down_sc_elems = (size_t)tp_down_num_groups * full_hidden;
tpc.down_proj = new uint32_t[tpc.expert_num * tp_down_qw_elems];
tpc.down_scale = new float[tpc.expert_num * tp_down_sc_elems];
const int gate_up_n_offset = i * tp_intermediate;
const int down_k_offset_packed = i * tp_down_k_packed;
const int down_group_offset = i * tp_down_num_groups;
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) {
const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
const uint32_t* gate_qw_src;
const uint32_t* up_qw_src;
const uint32_t* down_qw_src;
const float* gate_sc_src;
const float* up_sc_src;
const float* down_sc_src;
if (use_per_expert_ptrs) {
gate_qw_src = (const uint32_t*)config.gate_projs[0][expert_id];
up_qw_src = (const uint32_t*)config.up_projs[0][expert_id];
down_qw_src = (const uint32_t*)config.down_projs[0][expert_id];
gate_sc_src = (const float*)config.gate_scales[0][expert_id];
up_sc_src = (const float*)config.up_scales[0][expert_id];
down_sc_src = (const float*)config.down_scales[0][expert_id];
} else {
gate_qw_src = (const uint32_t*)config.gate_proj + expert_id * full_gate_up_qw_elems;
up_qw_src = (const uint32_t*)config.up_proj + expert_id * full_gate_up_qw_elems;
down_qw_src = (const uint32_t*)config.down_proj + expert_id * full_down_qw_elems;
gate_sc_src = (const float*)config.gate_scale + expert_id * full_gate_up_sc_elems;
up_sc_src = (const float*)config.up_scale + expert_id * full_gate_up_sc_elems;
down_sc_src = (const float*)config.down_scale + expert_id * full_down_sc_elems;
}
uint32_t* gate_qw_dst = (uint32_t*)tpc.gate_proj + expert_id * tp_gate_up_qw_elems;
uint32_t* up_qw_dst = (uint32_t*)tpc.up_proj + expert_id * tp_gate_up_qw_elems;
float* gate_sc_dst = (float*)tpc.gate_scale + expert_id * tp_gate_up_sc_elems;
float* up_sc_dst = (float*)tpc.up_scale + expert_id * tp_gate_up_sc_elems;
// gate/up qweight: [K/8, N] → slice N columns
for (int kr = 0; kr < gate_up_k_packed; kr++) {
std::memcpy(gate_qw_dst + kr * tp_intermediate,
gate_qw_src + kr * full_intermediate + gate_up_n_offset,
tp_intermediate * sizeof(uint32_t));
std::memcpy(up_qw_dst + kr * tp_intermediate,
up_qw_src + kr * full_intermediate + gate_up_n_offset,
tp_intermediate * sizeof(uint32_t));
}
// gate/up scales: [num_groups, N] → slice N columns
for (int g = 0; g < gate_up_num_groups; g++) {
std::memcpy(gate_sc_dst + g * tp_intermediate,
gate_sc_src + g * full_intermediate + gate_up_n_offset,
tp_intermediate * sizeof(float));
std::memcpy(up_sc_dst + g * tp_intermediate,
up_sc_src + g * full_intermediate + gate_up_n_offset,
tp_intermediate * sizeof(float));
}
// down qweight: [K/8, N=hidden] row-major → slice contiguous rows (K/8 dim)
uint32_t* down_qw_dst = (uint32_t*)tpc.down_proj + expert_id * tp_down_qw_elems;
for (int kr = 0; kr < tp_down_k_packed; kr++) {
std::memcpy(down_qw_dst + kr * full_hidden,
down_qw_src + (down_k_offset_packed + kr) * full_hidden,
full_hidden * sizeof(uint32_t));
}
// down scales: [K/gs, N=hidden] row-major → slice contiguous rows (K/gs dim)
float* down_sc_dst = (float*)tpc.down_scale + expert_id * tp_down_sc_elems;
for (int g = 0; g < tp_down_num_groups; g++) {
std::memcpy(down_sc_dst + g * full_hidden,
down_sc_src + (down_group_offset + g) * full_hidden,
full_hidden * sizeof(float));
}
},
nullptr);
});
// Call per-TP load_weights
pool->dispense_backend()->do_numa_job([&, this](int i) {
tps[i]->load_weights();
});
// Free temporary buffers
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
delete[] (uint32_t*)tpc.gate_proj;
delete[] (uint32_t*)tpc.up_proj;
delete[] (uint32_t*)tpc.down_proj;
delete[] (float*)tpc.gate_scale;
delete[] (float*)tpc.up_scale;
delete[] (float*)tpc.down_scale;
});
this->weights_loaded = true;
}
void write_weight_scale_to_buffer([[maybe_unused]] int gpu_tp_count, [[maybe_unused]] int expert_id,
[[maybe_unused]] const std::vector<uintptr_t>& w13_weight_ptrs,
[[maybe_unused]] const std::vector<uintptr_t>& w13_scale_ptrs,
[[maybe_unused]] const std::vector<uintptr_t>& w2_weight_ptrs,
[[maybe_unused]] const std::vector<uintptr_t>& w2_scale_ptrs) {
// GPTQ INT4 GPU offload not yet supported
throw std::runtime_error("GPTQ INT4 write_weight_scale_to_buffer not yet implemented");
}
};
#endif // CPUINFER_OPERATOR_AVX2_GPTQ_INT4_MOE_H

@@ -0,0 +1,50 @@
/**
* @Description : GPTQ-Int4 symmetric dequantization for AVX2
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* GPTQ symmetric quantization (sym=true):
* dequant[k,n] = (((qweight[k/8,n] >> ((k%8)*4)) & 0xF) - 8) * scale[k/gs, n]
*
* qweight layout: [K/8, N] int32, packing 8 x 4-bit values along K dimension
* scales layout: [K/group_size, N] fp16 (converted to fp32 at load time)
* qzeros: not needed (symmetric, zero_point = 8 for all)
**/
#ifndef CPUINFER_OPERATOR_AVX2_GPTQ_INT4_DEQUANT_H
#define CPUINFER_OPERATOR_AVX2_GPTQ_INT4_DEQUANT_H
#include <immintrin.h>
#include <cstdint>
namespace avx2 {
// Dequantize 8 x 4-bit values from a packed int32 (symmetric, zero_point=8)
// packed_weight contains 8 nibbles: bits [0:3]=val0, [4:7]=val1, ..., [28:31]=val7
// Result: ((nibble - 8) * scale) for each of the 8 values
static inline __m256 gptq_sym_dequant_8x4bit(uint32_t packed_weight, float scale) {
// Variable shift: extract each 4-bit nibble into its own 32-bit lane
// GPTQ packing: bit 0-3 = k_offset 0, bit 4-7 = k_offset 1, ...
// _mm256_set_epi32 sets lanes in reverse order (lane 7 first), so:
// lane 0 = shift 0 (k_offset 0), lane 1 = shift 4 (k_offset 1), ...
const __m256i shifts = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
__m256i packed_v = _mm256_set1_epi32(packed_weight);
__m256i nibbles = _mm256_and_si256(_mm256_srlv_epi32(packed_v, shifts),
_mm256_set1_epi32(0xF));
// (nibble - 8) * scale
__m256 w = _mm256_cvtepi32_ps(nibbles);
return _mm256_mul_ps(_mm256_sub_ps(w, _mm256_set1_ps(8.0f)),
_mm256_set1_ps(scale));
}
// Scalar version for verification
static inline float gptq_sym_dequant_scalar(uint32_t packed_weight, int k_in_pack, float scale) {
int nibble = (packed_weight >> (k_in_pack * 4)) & 0xF;
return (float)(nibble - 8) * scale;
}
} // namespace avx2
#endif // CPUINFER_OPERATOR_AVX2_GPTQ_INT4_DEQUANT_H

@@ -0,0 +1,580 @@
/**
* @Description : AVX2 MoE base class (ported from amx/moe_base.hpp)
* @Author : Claude
* @Date : 2026-03-18
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* All AVX512 intrinsics (__m512, _mm512_*) replaced with AVX2 (__m256, _mm256_*).
* AMX tile configuration calls (T::config()) are kept but are no-ops.
**/
#ifndef CPUINFER_OPERATOR_AVX2_MOE_BASE_H
#define CPUINFER_OPERATOR_AVX2_MOE_BASE_H
#include <immintrin.h>
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "avx2_bf16_gemm.hpp"
#include "avx2_bf16_utils.hpp"
#include "llama.cpp/ggml.h"
template <class T, class Derived>
class AVX2_MOE_BASE {
public:
int tp_part_idx = 0;
ggml_bf16_t* m_local_input_ = nullptr;
ggml_bf16_t* m_local_gate_output_ = nullptr;
ggml_bf16_t* m_local_up_output_ = nullptr;
ggml_bf16_t* m_local_down_output_ = nullptr;
std::vector<std::vector<int>> m_local_pos_;
std::vector<int> m_local_num_;
std::vector<int> m_expert_id_map_;
std::vector<ggml_bf16_t*> m_local_input_ptr_;
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_;
std::vector<ggml_bf16_t*> m_local_up_output_ptr_;
std::vector<ggml_bf16_t*> m_local_down_output_ptr_;
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
size_t pool_count_ = 0;
size_t gate_up_ba_pool_bytes_ = 0;
size_t gate_bc_pool_bytes_ = 0;
size_t up_bc_pool_bytes_ = 0;
size_t down_ba_pool_bytes_ = 0;
size_t down_bc_pool_bytes_ = 0;
void* gate_up_ba_pool_ = nullptr;
void* gate_bc_pool_ = nullptr;
void* up_bc_pool_ = nullptr;
void* down_ba_pool_ = nullptr;
void* down_bc_pool_ = nullptr;
GeneralMOEConfig config_;
using input_t = ggml_bf16_t;
using output_t = float;
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
AVX2_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) {
init();
derived()->derived_init();
}
void init() {
if (config_.load && config_.path == "") {
config_.load = false;
}
MemoryRequest mem_requests;
mem_requests.append_pointer(
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.hidden_size);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.num_experts_per_tok);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
for (size_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(make_buffer_a(config_.max_len, config_.hidden_size, nullptr));
gate_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
up_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
down_ba_.push_back(make_buffer_a(config_.max_len, config_.intermediate_size, nullptr));
down_bc_.push_back(make_buffer_c(config_.max_len, config_.hidden_size, nullptr));
void* gate_bb_ptr =
std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
gate_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
void* up_bb_ptr = std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
up_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
void* down_bb_ptr =
std::aligned_alloc(64, buffer_b_required_size(config_.hidden_size, config_.intermediate_size));
down_bb_.push_back(make_buffer_b(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
}
pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::M_STEP;
gate_up_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
gate_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
up_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
down_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
down_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_);
mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_);
mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_);
mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_);
mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_);
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
~AVX2_MOE_BASE() = default;
void warm_up() {
int qlen = config_.max_len;
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
std::vector<float> weights(qlen * config_.num_experts_per_tok);
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
expert_ids[i] = i % config_.expert_num;
weights[i] = 0.01f;
}
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
}
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
if (qlen > 1) {
forward_prefill(qlen, k, expert_ids, weights, input, output);
} else {
forward_decode(k, expert_ids, weights, input, output);
}
}
template <typename... Args>
void load_weights(Args&&... args) {
derived()->load_weights(std::forward<Args>(args)...);
}
template <typename... Args>
void write_weights_to_buffer(Args&&... args) const {
derived_const()->write_weights_to_buffer(std::forward<Args>(args)...);
}
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output) {
auto pool = config_.pool->get_subpool(tp_part_idx);
int activated_expert = 0;
std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
if (config_.should_skip_expert(expert_ids[i * k + j])) {
continue;
}
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
for (int i = 0; i < config_.expert_num; i++) {
if (m_local_num_[i] > 0) {
m_expert_id_map_[activated_expert] = i;
activated_expert++;
}
}
// Assign pool memory to buffers
size_t offset = 0;
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
void* gate_bc_pool_ptr = gate_bc_pool_;
void* up_bc_pool_ptr = up_bc_pool_;
void* down_ba_pool_ptr = down_ba_pool_;
void* down_bc_pool_ptr = down_bc_pool_;
constexpr size_t M_STEP = T::M_STEP;
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
for (int i = 0; i < config_.expert_num; i++) {
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
if (m_local_num_[i] == 0) continue;
size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP;
gate_up_ba_[i]->max_m = max_m;
gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr);
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.hidden_size)));
gate_bc_[i]->max_m = max_m;
gate_bc_[i]->set_data(gate_bc_pool_ptr);
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));
up_bc_[i]->max_m = max_m;
up_bc_[i]->set_data(up_bc_pool_ptr);
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));
down_ba_[i]->max_m = max_m;
down_ba_[i]->set_data(down_ba_pool_ptr);
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.intermediate_size)));
down_bc_[i]->max_m = max_m;
down_bc_[i]->set_data(down_bc_pool_ptr);
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.hidden_size)));
}
auto direct_or_pool = [&](int count, auto&& fn) {
if (qlen < 10) {
for (int i = 0; i < count; i++) fn(i);
} else {
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
}
};
// Copy input to per-expert buffers
direct_or_pool(qlen, [&](int i) {
for (int j = 0; j < k; j++) {
if (config_.should_skip_expert(expert_ids[i * k + j])) continue;
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
});
// Pack input into BufferA (trivial memcpy for AVX2)
direct_or_pool(activated_expert, [this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
});
// Gate + Up GEMM
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int) { T::config(); },
[this, nth, qlen](int task_id2) {
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
if (do_up) {
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
// Activation: SiLU(gate) * up — AVX2 version (8 elements at a time)
apply_activation(activated_expert, nth, qlen);
// Pack activation output into BufferA for down projection
pool->do_work_stealing_job(
activated_expert, nullptr,
[this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
// Down GEMM
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int) { T::config(); },
[this, nth, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_down_gemm(expert_idx, ith, nth, qlen);
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
// Weighted sum of expert outputs — AVX2 version (16 BF16 = 2x8 FP32 at a time)
pool->do_work_stealing_job(
qlen, nullptr,
[this, output, k, expert_ids, weights](int i) {
for (int e = 0; e < config_.hidden_size; e += 16) {
__m256 x0 = _mm256_setzero_ps();
__m256 x1 = _mm256_setzero_ps();
for (int j = 0; j < k; j++) {
if (config_.should_skip_expert(expert_ids[i * k + j])) continue;
__m256 weight = _mm256_set1_ps(weights[i * k + j]);
__m256 d0, d1;
avx2::load_16xbf16_to_2x8xfp32(
m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e,
&d0, &d1);
x0 = _mm256_fmadd_ps(d0, weight, x0);
x1 = _mm256_fmadd_ps(d1, weight, x1);
}
// Use unaligned stores: the output row is not guaranteed 32-byte aligned.
_mm256_storeu_ps((float*)output + i * config_.hidden_size + e, x0);
_mm256_storeu_ps((float*)output + i * config_.hidden_size + e + 8, x1);
}
},
nullptr);
}
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
int qlen = 1;
auto pool = config_.pool->get_subpool(tp_part_idx);
int activated_expert = 0;
std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
for (int i = 0; i < k; i++) {
if (config_.should_skip_expert(expert_ids[i])) continue;
m_expert_id_map_[activated_expert] = expert_ids[i];
m_local_pos_[0][i] = 0;
m_local_num_[expert_ids[i]] = qlen;
activated_expert++;
}
size_t offset = 0;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
offset += qlen;
}
// Assign pool memory for decode
void* gate_bc_pool_ptr = gate_bc_pool_;
void* up_bc_pool_ptr = up_bc_pool_;
void* down_ba_pool_ptr = down_ba_pool_;
void* down_bc_pool_ptr = down_bc_pool_;
constexpr size_t M_STEP = T::M_STEP;
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
gate_bc_[expert_idx]->max_m = max_m;
gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));
up_bc_[expert_idx]->max_m = max_m;
up_bc_[expert_idx]->set_data(up_bc_pool_ptr);
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));
down_ba_[expert_idx]->max_m = max_m;
down_ba_[expert_idx]->set_data(down_ba_pool_ptr);
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.intermediate_size)));
down_bc_[expert_idx]->max_m = max_m;
down_bc_[expert_idx]->set_data(down_bc_pool_ptr);
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.hidden_size)));
}
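The pool-carving loop above is a bump allocator: each expert's scratch buffer starts at the current pool cursor, and the cursor advances by the 64-byte-aligned buffer size. A minimal Python sketch of the same scheme (helper names hypothetical):

```python
def align64(v):
    """Round v up to the next multiple of 64 bytes (cache-line alignment)."""
    return (v + 63) & ~63

def carve(pool_base, sizes):
    """Bump-allocate 64B-aligned sub-buffers from one pool, as the decode path
    does for the per-expert BufferA/BufferC scratch regions. Returns the start
    offset of each buffer and the total bytes consumed."""
    offsets, cur = [], pool_base
    for s in sizes:
        offsets.append(cur)
        cur += align64(s)
    return offsets, cur - pool_base
```

Because every size is rounded up independently, the pool must be sized with the same `align64` rounding when it is reserved.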
// Pack input into BufferA for each activated expert
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
gate_up_ba_[expert_idx]->max_m = max_m;
gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.hidden_size)));
gate_up_ba_[expert_idx]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
}
// Gate + Up GEMM
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int) { T::config(); },
[this, nth, qlen](int task_id2) {
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
if (do_up) {
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
// Activation
apply_activation(activated_expert, nth, qlen);
// Pack for down projection
pool->do_work_stealing_job(
activated_expert, nullptr,
[this, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
// Down GEMM
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int) { T::config(); },
[this, nth, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_down_gemm(expert_idx, ith, nth, qlen);
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
// Weighted sum — AVX2 (16 BF16 at a time)
for (int e = 0; e < config_.hidden_size; e += 16) {
__m256 x0 = _mm256_setzero_ps();
__m256 x1 = _mm256_setzero_ps();
for (int j = 0; j < k; j++) {
if (config_.should_skip_expert(expert_ids[j])) continue;
__m256 weight = _mm256_set1_ps(weights[j]);
__m256 d0, d1;
avx2::load_16xbf16_to_2x8xfp32(
m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e,
&d0, &d1);
x0 = _mm256_fmadd_ps(d0, weight, x0);
x1 = _mm256_fmadd_ps(d1, weight, x1);
}
auto f32out = (__m256*)((float*)output + e);
f32out[0] = x0;
f32out[1] = x1;
}
}
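The decode-path weighted sum above computes, per element, output[e] = sum_j weights[j] * down_out[expert_ids[j]][e], skipping experts filtered by should_skip_expert. A scalar numpy sketch of the same reduction (names hypothetical, not the kernel's API):

```python
import numpy as np

def combine_experts(down_out, expert_ids, weights, skip=()):
    """Scalar reference for the decode combine: a weighted sum of each routed
    expert's down-projection output, skipping masked experts."""
    hidden = next(iter(down_out.values())).shape[-1]
    acc = np.zeros(hidden, dtype=np.float32)
    for eid, w in zip(expert_ids, weights):
        if eid in skip:
            continue
        acc += np.float32(w) * down_out[eid].astype(np.float32)
    return acc
```

The AVX2 code performs the same accumulation 16 lanes at a time, converting each BF16 input to FP32 before the fused multiply-add.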
protected:
Derived* derived() { return static_cast<Derived*>(this); }
const Derived* derived_const() const { return static_cast<const Derived*>(this); }
void derived_init() {}
// Buffer creation/size delegation (CRTP)
size_t buffer_a_required_size(size_t m, size_t k) const { return derived_const()->buffer_a_required_size_impl(m, k); }
size_t buffer_b_required_size(size_t n, size_t k) const { return derived_const()->buffer_b_required_size_impl(n, k); }
size_t buffer_c_required_size(size_t m, size_t n) const { return derived_const()->buffer_c_required_size_impl(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a(size_t m, size_t k, void* data) const {
return derived_const()->make_buffer_a_impl(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b(size_t n, size_t k, void* data) const {
return derived_const()->make_buffer_b_impl(n, k, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c(size_t m, size_t n, void* data) const {
return derived_const()->make_buffer_c_impl(m, n, data);
}
// SiLU activation — AVX2: process 8 BF16 elements at a time
void apply_activation(int activated_expert, int nth, int qlen) {
auto pool = config_.pool->get_subpool(tp_part_idx);
auto fn = [this, nth](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t* gate_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
int j = n_start;
for (; j + 8 <= n_end; j += 8) {
__m256 gate_val = avx2::load_bf16_to_fp32(gate_ptr + j);
__m256 up_val = avx2::load_bf16_to_fp32(up_ptr + j);
__m256 result = avx2::act_fn(gate_val, up_val);
avx2::store_fp32_to_bf16(gate_ptr + j, result);
}
// Scalar tail
for (; j < n_end; j++) {
float g = GGML_BF16_TO_FP32(gate_ptr[j]);
float u = GGML_BF16_TO_FP32(up_ptr[j]);
float sigmoid_g = 1.0f / (1.0f + expf(-g));
gate_ptr[j] = GGML_FP32_TO_BF16(g * sigmoid_g * u);
}
}
};
if (activated_expert == 0) return;
if (qlen < 10) {
for (int task_id = 0; task_id < nth * activated_expert; task_id++) fn(task_id);
} else {
pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr);
}
}
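apply_activation fuses the SwiGLU gating in place, gate[j] = SiLU(gate[j]) * up[j], 8 lanes at a time with a scalar tail. A numpy equivalent of the scalar tail, for reference:

```python
import numpy as np

def silu_mul(gate, up):
    """Reference for the fused activation: SiLU(gate) * up,
    matching the scalar-tail formula g * sigmoid(g) * u."""
    gate = np.asarray(gate, dtype=np.float32)
    return gate / (1.0 + np.exp(-gate)) * np.asarray(up, dtype=np.float32)
```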
};
// ============================================================================
// TP_MOE specialization for AVX2_MOE_BASE derived classes
// ============================================================================
template <class T, class Derived>
class TP_MOE<AVX2_MOE_BASE<T, Derived>> : public TP_MOE_Common<AVX2_MOE_BASE<T, Derived>> {
public:
using TP_MOE_Common<AVX2_MOE_BASE<T, Derived>>::TP_MOE_Common;
void load_weights() override { throw std::runtime_error("Not Implemented"); }
void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
throw std::runtime_error("Not Implemented");
}
void merge_results(int qlen, void* output, bool incremental) override {
auto& config = this->config;
auto& tp_count = this->tp_count;
auto& local_output_numa = this->local_output_numa;
auto& tp_configs = this->tp_configs;
auto merge_fn = [this, output, incremental, &config, &tp_count, &local_output_numa, &tp_configs](int token_nth) {
float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
if (incremental) {
// Convert BF16 output to FP32 and add — AVX2 (16 BF16 at a time)
for (int e = 0; e < config.hidden_size; e += 16) {
__m256 x0, x1;
avx2::load_16xbf16_to_2x8xfp32((ggml_bf16_t*)output + token_nth * config.hidden_size + e, &x0, &x1);
*((__m256*)(merge_to + e)) = _mm256_add_ps(*((__m256*)(merge_to + e)), x0);
*((__m256*)(merge_to + e + 8)) = _mm256_add_ps(*((__m256*)(merge_to + e + 8)), x1);
}
}
// Sum across TP parts
for (int i = 1; i < tp_count; i++) {
float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
for (int e = 0; e < tp_configs[i].hidden_size; e += 8) {
*((__m256*)(merge_to + e)) = _mm256_add_ps(*((__m256*)(merge_to + e)), *((__m256*)(merge_from + e)));
}
}
// Convert FP32 -> BF16 output
for (int e = 0; e < config.hidden_size; e += 16) {
__m256 x0 = *(__m256*)(merge_to + e);
__m256 x1 = *(__m256*)(merge_to + e + 8);
avx2::store_2x8xfp32_to_16xbf16(&x0, &x1, (ggml_bf16_t*)output + token_nth * config.hidden_size + e);
}
};
auto pool = config.pool;
if (qlen < 10) {
for (int i = 0; i < qlen; i++) merge_fn(i);
} else {
pool->do_work_stealing_job(qlen, nullptr, merge_fn, nullptr);
}
}
void merge_results(int qlen, void* output) override { merge_results(qlen, output, false); }
};
#endif // CPUINFER_OPERATOR_AVX2_MOE_BASE_H
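The avx2::load_16xbf16_to_2x8xfp32 / store_2x8xfp32_to_16xbf16 helpers used throughout this header rely on BF16 being the top 16 bits of an FP32 word. A scalar numpy sketch of that round trip (truncating on store; the actual kernel's store may round to nearest even instead):

```python
import numpy as np

def fp32_to_bf16_bits(x):
    """Truncate FP32 to BF16 bit patterns by keeping the high 16 bits."""
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_fp32(b):
    """Widen BF16 to FP32 by placing the 16 bits in the high half of a
    32-bit word -- the scalar analogue of the AVX2 load helpers."""
    return (np.asarray(b, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)
```

Values whose mantissa fits in BF16's 7 bits survive the round trip exactly; others lose low-order mantissa bits.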

View File

@@ -93,7 +93,7 @@ class KTMoEWrapper:
# Select backend based on method
if method in ["AMXINT4", "AMXINT8"]:
backend_cls = AMXMoEWrapper
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL"]:
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4"]:
backend_cls = NativeMoEWrapper
elif method == "LLAMAFILE":
backend_cls = LlamafileMoEWrapper

View File

@@ -5,7 +5,7 @@ from typing import Optional
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader, GPTQSafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
import kt_kernel_ext.moe as _moe_mod
@@ -15,6 +15,9 @@ AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
AVX2BF16_MOE = getattr(_moe_mod, "AVX2BF16_MOE", None)
AVX2FP8_MOE = getattr(_moe_mod, "AVX2FP8_MOE", None)
AVX2GPTQInt4_MOE = getattr(_moe_mod, "AVX2GPTQInt4_MOE", None)
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
@@ -22,6 +25,9 @@ _HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
_HAS_AVX2_BF16_SUPPORT = AVX2BF16_MOE is not None
_HAS_AVX2_FP8_SUPPORT = AVX2FP8_MOE is not None
_HAS_AVX2_GPTQ_INT4_SUPPORT = AVX2GPTQInt4_MOE is not None
class AMXMoEWrapper(BaseMoEWrapper):
@@ -346,10 +352,11 @@ class NativeMoEWrapper(BaseMoEWrapper):
" - AVX512F + AVX512BW (VNNI optional)\n"
"Please recompile kt_kernel_ext with AVX512 enabled."
)
if method == "FP8" and not _HAS_FP8_SUPPORT:
if method == "FP8" and not _HAS_FP8_SUPPORT and not _HAS_AVX2_FP8_SUPPORT:
raise RuntimeError(
"FP8 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI (for AMX), or\n"
" - AVX2 + FMA (for AVX2 fallback)\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
)
if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
@@ -358,11 +365,17 @@ class NativeMoEWrapper(BaseMoEWrapper):
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
)
if method == "BF16" and not _HAS_BF16_SUPPORT:
if method == "BF16" and not _HAS_BF16_SUPPORT and not _HAS_AVX2_BF16_SUPPORT:
raise RuntimeError(
"BF16 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
" - AVX512F + AVX512BW + AVX512_BF16 (for AMX backend), or\n"
" - AVX2 + FMA (for AVX2 fallback backend)\n"
"Please recompile kt_kernel_ext with AVX512+BF16 or AVX2 enabled."
)
if method == "GPTQ_INT4" and not _HAS_AVX2_GPTQ_INT4_SUPPORT:
raise RuntimeError(
"GPTQ_INT4 backend not available.\n"
"Please recompile kt_kernel_ext with AVX2 enabled."
)
super().__init__(
@@ -391,6 +404,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
elif method == "BF16":
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
elif method == "GPTQ_INT4":
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
self.loader = NativeMoEWrapper._native_loader_instance
@@ -506,15 +521,31 @@ class NativeMoEWrapper(BaseMoEWrapper):
moe_config.quant_config.bits = 8
moe_config.quant_config.group_size = 128
moe_config.quant_config.zero_point = False
self.moe = AMXFP8_MOE(moe_config)
if _HAS_FP8_SUPPORT:
self.moe = AMXFP8_MOE(moe_config)
else:
self.moe = AVX2FP8_MOE(moe_config)
elif self.method == "FP8_PERCHANNEL":
moe_config.quant_config.bits = 8
moe_config.quant_config.per_channel = True
moe_config.quant_config.zero_point = False
self.moe = AMXFP8PerChannel_MOE(moe_config)
elif self.method == "GPTQ_INT4":
# GPTQ symmetric INT4: qweight (int32) + scales (fp32)
            num_groups = self.gate_scales[0].shape[0]  # scales shape [K/gs, N]; first dim = number of groups
            actual_gs = self.hidden_size // num_groups  # group_size = hidden_size / num_groups
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = actual_gs
moe_config.quant_config.zero_point = False
self.moe = AVX2GPTQInt4_MOE(moe_config)
elif self.method == "BF16":
# BF16 has no quantization config needed
self.moe = AMXBF16_MOE(moe_config)
# Prefer AMX backend, fall back to AVX2
if _HAS_BF16_SUPPORT:
self.moe = AMXBF16_MOE(moe_config)
else:
self.moe = AVX2BF16_MOE(moe_config)
t4 = time.time()
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))

View File

@@ -961,3 +961,120 @@ class GGUFLoader:
data = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8).copy())
return data, ggml_type
class GPTQSafeTensorLoader(FP8SafeTensorLoader):
"""Loader for symmetric GPTQ-Int4 expert weights (qweight + scales, no qzeros).
Only supports sym=true, desc_act=false GPTQ models.
Tensor keys:
- qweight: {prefix}.{id}.{proj}.qweight (int32, packed 8x4-bit along K)
- scales: {prefix}.{id}.{proj}.scales (fp16 -> converted to fp32)
"""
def __init__(self, file_path: str):
# Call FP8SafeTensorLoader init (which calls SafeTensorLoader init + format detection)
super().__init__(file_path, scale_suffix="scales")
# Verify GPTQ config
self._verify_gptq_config(file_path)
def _detect_format(self):
"""Override FP8 format detection to look for .qweight instead of .weight."""
sample_keys = list(self.tensor_file_map.keys())[:2000]
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
for key in sample_keys:
if ".experts." in key and f".{gate}.qweight" in key:
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
self._detected_format = fmt_name
break
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
self._detected_format = fmt_name
# Check for VL model (language_model prefix)
if "language_model." in key:
self._is_vl_model = True
break
elif fmt_name == "mistral" and "block_sparse_moe" not in key and "mlp" not in key:
self._detected_format = fmt_name
break
if self._detected_format is not None:
break
if self._detected_format is None:
self._detected_format = "deepseek"
vl_str = " (VL model)" if self._is_vl_model else ""
print(f"[GPTQSafeTensorLoader] Detected format: {self._detected_format}{vl_str}")
def _verify_gptq_config(self, file_path):
"""Check that the model uses sym=true, desc_act=false."""
import json
import os
config_path = os.path.join(os.path.dirname(file_path), "config.json")
if not os.path.exists(config_path):
            # Fall back to treating file_path itself as the model directory
            config_path = os.path.join(file_path, "config.json")
if os.path.exists(config_path):
with open(config_path) as f:
config = json.load(f)
qc = config.get("quantization_config", {})
if qc.get("quant_method") == "gptq":
if qc.get("desc_act", False):
raise NotImplementedError(
"GPTQ desc_act=true is not supported. Only desc_act=false models are supported."
)
if not qc.get("sym", True):
raise NotImplementedError(
"GPTQ sym=false (asymmetric) is not supported. Only sym=true models are supported."
)
print(f"[GPTQSafeTensorLoader] Verified: sym={qc.get('sym')}, desc_act={qc.get('desc_act')}, "
f"bits={qc.get('bits')}, group_size={qc.get('group_size')}")
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load GPTQ expert qweight and scales.
Returns dict with keys: gate, up, down (qweight int32), gate_scale, up_scale, down_scale (fp32).
"""
experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
gate_name, up_name, down_name = self._get_proj_names()
expert_count = 0
experts_prefix = None
for prefix in experts_prefix_candidates:
expert_count = 0
while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.qweight"):
expert_count += 1
if expert_count > 0:
experts_prefix = prefix
break
if expert_count == 0 or experts_prefix is None:
raise ValueError(f"No GPTQ experts found for keys: {experts_prefix_candidates}")
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
down_weights = [None] * expert_count
gate_scales = [None] * expert_count
up_scales = [None] * expert_count
down_scales = [None] * expert_count
for exp_id in range(expert_count):
gate_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{gate_name}.qweight", device).contiguous()
up_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{up_name}.qweight", device).contiguous()
down_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{down_name}.qweight", device).contiguous()
gate_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{gate_name}.scales", device).float().contiguous()
up_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{up_name}.scales", device).float().contiguous()
down_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{down_name}.scales", device).float().contiguous()
print(f"[GPTQSafeTensorLoader] Loaded {expert_count} experts from {experts_prefix}")
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
"gate_scale": gate_scales,
"up_scale": up_scales,
"down_scale": down_scales,
}
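For reference, the symmetric GPTQ-Int4 layout described in the docstring (qweight int32 with 8 nibbles packed along K, per-group fp32 scales, no qzeros) can be dequantized as in this numpy sketch. It assumes the implicit symmetric zero point of 8; this is an illustration of the layout, not the kernel's actual unpack path:

```python
import numpy as np

def dequant_gptq_sym(qweight, scales, group_size):
    """Sketch: unpack a [K//8, N] int32 qweight (8 x 4-bit along K) and apply
    per-group scales, assuming sym=true with an implicit zero point of 8."""
    q = np.asarray(qweight).view(np.uint32)
    k8, n = q.shape
    # nibble i of word k covers row k*8 + i along the K dimension
    nibbles = np.stack([(q >> (4 * i)) & 0xF for i in range(8)], axis=1)
    w = nibbles.reshape(k8 * 8, n).astype(np.float32) - 8.0
    return w * np.repeat(np.asarray(scales, dtype=np.float32), group_size, axis=0)
```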

View File

@@ -579,7 +579,7 @@ class CMakeBuild(build_ext):
avx512_extension_enabled = True
# If any AVX512 extension is enabled, ensure base AVX512 is also enabled
if avx512_extension_enabled and cpu_mode in ("NATIVE", "FANCY", "AVX512"):
if avx512_extension_enabled and cpu_mode in ("NATIVE", "FANCY", "AVX512") and "AVX512" in d["features"]:
if not any("LLAMA_AVX512=ON" in a for a in cmake_args):
cmake_args.append("-DLLAMA_AVX512=ON")
print("-- AVX512 extensions enabled; also enabling base AVX512F (-DLLAMA_AVX512=ON)")

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env python
# coding=utf-8
"""AVX2 BF16 MoE accuracy tests for KT-Kernel.
Tests accuracy of AVX2 BF16 MOE operations against torch reference.
"""
import os
import sys
import time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from kt_kernel import kt_kernel_ext
# Small test parameters for fast validation
expert_num = 8
hidden_size = 256
intermediate_size = 512
num_experts_per_tok = 2
max_len = 128
validation_iter = 3
CPUINFER_PARAM = 60
def act_fn(x):
"""SiLU activation."""
return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
"""PyTorch reference MLP."""
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
return torch.mm(intermediate, down_proj.t())
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
"""PyTorch reference MoE."""
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
t_output = (
new_x.view(*expert_ids.shape, -1)
.type(weights.dtype)
.mul_(weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return t_output
def test_avx2_bf16_accuracy(qlen, label):
"""Test AVX2 BF16 MoE accuracy."""
physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
with torch.inference_mode():
# Generate BF16 weights
gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16).contiguous()
up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16).contiguous()
down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 10.0).to(torch.bfloat16).contiguous()
# Create MOE config
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
config.down_proj = down_proj.data_ptr()
config.gate_scale = 0
config.up_scale = 0
config.down_scale = 0
config.pool = CPUInfer.backend_
# Create AVX2 BF16 MOE
moe = kt_kernel_ext.moe.AVX2BF16_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
print(f"\n--- {label} (qlen={qlen}) ---")
for i in range(validation_iter):
expert_ids = torch.stack(
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
).contiguous()
weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
bsz_tensor = torch.tensor([qlen], dtype=torch.int32)
# Run AVX2 BF16 MOE
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids.data_ptr(),
weights.data_ptr(),
input_data.data_ptr(),
output.data_ptr(),
False,
)
)
CPUInfer.sync()
# Run torch reference (in float32 for accuracy)
t_output = moe_torch(
input_data.float(), expert_ids, weights,
gate_proj.float(), up_proj.float(), down_proj.float()
).to(torch.bfloat16)
# Calculate relative difference
diff = torch.mean(torch.abs(output.float() - t_output.float())) / (torch.mean(torch.abs(t_output.float())) + 1e-8)
print(f" Iteration {i}: diff = {diff:.6f}")
            # BF16 should stay close to the FP32 reference (tolerance 0.02)
assert diff < 0.02, f"AVX2 BF16 accuracy test failed: diff={diff:.6f} >= 0.02"
print(f" PASSED")
if __name__ == "__main__":
print("=" * 60)
print("AVX2 BF16 MoE Accuracy Test")
print("=" * 60)
try:
# Test decode path (qlen=1)
test_avx2_bf16_accuracy(qlen=1, label="Decode")
# Test prefill path (qlen=16)
test_avx2_bf16_accuracy(qlen=16, label="Prefill")
print("\n" + "=" * 60)
print("ALL TESTS PASSED")
print("=" * 60)
except Exception as e:
print(f"\nTEST FAILED: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python
# coding=utf-8
"""AVX2 FP8 MoE accuracy tests for KT-Kernel."""
import os
import sys
import math
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from kt_kernel import kt_kernel_ext
expert_num = 8
hidden_size = 256
intermediate_size = 512
num_experts_per_tok = 2
max_len = 128
group_size = 128
validation_iter = 3
CPUINFER_PARAM = 60
def fp8_e4m3_quantize(tensor_bf16):
"""Quantize BF16 tensor to FP8 E4M3 with block-wise scales (128x128)."""
n, k = tensor_bf16.shape
tensor_fp32 = tensor_bf16.float()
n_blocks_n = (n + group_size - 1) // group_size
n_blocks_k = (k + group_size - 1) // group_size
fp8_data = torch.zeros(n, k, dtype=torch.uint8)
scales = torch.zeros(n_blocks_n, n_blocks_k, dtype=torch.float32)
    # FP8 E4M3 (E4M3FN) max finite value: 2^8 * (1 + 6/8) = 448
fp8_max = 448.0
for bn in range(n_blocks_n):
for bk in range(n_blocks_k):
n_start = bn * group_size
n_end = min(n_start + group_size, n)
k_start = bk * group_size
k_end = min(k_start + group_size, k)
block = tensor_fp32[n_start:n_end, k_start:k_end]
amax = block.abs().max().item()
if amax == 0:
scale = 1.0
else:
scale = amax / fp8_max
scales[bn, bk] = scale
# Quantize
for i in range(n_end - n_start):
for j in range(k_end - k_start):
val = block[i, j].item() / scale
fp8_data[n_start + i, k_start + j] = float_to_fp8_e4m3(val)
return fp8_data, scales
def float_to_fp8_e4m3(val):
"""Convert float to FP8 E4M3."""
if math.isnan(val):
return 0x7F
sign = 1 if val < 0 else 0
val = abs(val)
if val == 0:
return sign << 7
# Clamp to max
if val >= 448.0:
return (sign << 7) | 0x7E # max finite
# Find exponent
exp = int(math.floor(math.log2(val))) + 7
if exp <= 0:
# Subnormal
man = int(round(val * (2**6) * 8))
man = min(man, 7)
return (sign << 7) | man
    if exp > 15:
        return (sign << 7) | 0x7E  # clamp to max finite (448)
    # Normal: round the mantissa, carrying into the exponent on overflow
    man = int(round((val / (2**(exp - 7)) - 1.0) * 8))
    if man == 8:
        exp += 1
        man = 0
    if exp > 15 or (exp == 15 and man > 6):  # exp=15, man=7 would encode NaN
        return (sign << 7) | 0x7E
    return (sign << 7) | (exp << 3) | man
def fp8_e4m3_to_float(byte_val):
"""Convert FP8 E4M3 byte to float."""
sign = (byte_val >> 7) & 1
exp = (byte_val >> 3) & 0xF
man = byte_val & 0x7
if exp == 0 and man == 0:
return 0.0
if exp == 0:
val = (2**-6) * (man / 8.0)
    elif exp == 15 and man == 7:
        return float("nan")  # E4M3FN: only exp=15, man=7 is NaN; man < 7 are normal values
else:
val = (2**(exp-7)) * (1.0 + man / 8.0)
return -val if sign else val
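The scalar decoder above can be cross-checked against a vectorized version. E4M3FN has bias 7, subnormals at exp=0, a single NaN encoding at exp=15 with man=7, and no infinities, so the max finite value is 2^8 * 1.75 = 448:

```python
import numpy as np

def e4m3fn_to_fp32(codes):
    """Vectorized E4M3FN decode: bias 7, subnormals at exp=0,
    NaN only at exp=15 & man=7, no infinities."""
    b = np.asarray(codes, dtype=np.uint8)
    sign = np.where(b >> 7, np.float32(-1.0), np.float32(1.0))
    exp = ((b >> 3) & 0xF).astype(np.int32)
    man = (b & 0x7).astype(np.float32)
    sub = np.ldexp(man / 8.0, -6)             # exp == 0: subnormal values
    nrm = np.ldexp(1.0 + man / 8.0, exp - 7)  # normal values
    out = np.where(exp == 0, sub, nrm) * sign
    return np.where((exp == 15) & (man == 7), np.float32(np.nan), out)
```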
def fp8_dequantize(fp8_data, scales):
"""Dequantize FP8 + scales back to float32."""
n, k = fp8_data.shape
result = torch.zeros(n, k, dtype=torch.float32)
n_blocks_n = scales.shape[0]
n_blocks_k = scales.shape[1]
for i in range(n):
for j in range(k):
bn = i // group_size
bk = j // group_size
scale = scales[bn, bk].item()
fp8_val = fp8_e4m3_to_float(fp8_data[i, j].item())
result[i, j] = fp8_val * scale
return result
def act_fn(x):
return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
return torch.mm(intermediate, down_proj.t())
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens = sorted_tokens[start_idx:end_idx]
out = mlp_torch(tokens, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
return (new_x.view(*expert_ids.shape, -1).float().mul_(weights.unsqueeze(-1)).sum(1)).to(new_x.dtype)
def test_avx2_fp8_accuracy(qlen, label):
physical_to_logical_map = torch.tensor(range(expert_num), dtype=torch.int64).contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
with torch.inference_mode():
# Generate BF16 weights, quantize to FP8
gate_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
up_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
down_bf16 = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
# Quantize each expert
gate_fp8_list, gate_scale_list = [], []
up_fp8_list, up_scale_list = [], []
down_fp8_list, down_scale_list = [], []
for e in range(expert_num):
gf, gs = fp8_e4m3_quantize(gate_bf16[e])
gate_fp8_list.append(gf)
gate_scale_list.append(gs)
uf, us = fp8_e4m3_quantize(up_bf16[e])
up_fp8_list.append(uf)
up_scale_list.append(us)
df, ds = fp8_e4m3_quantize(down_bf16[e])
down_fp8_list.append(df)
down_scale_list.append(ds)
# Stack into contiguous tensors
gate_fp8 = torch.stack(gate_fp8_list).contiguous()
gate_scales = torch.stack(gate_scale_list).contiguous()
up_fp8 = torch.stack(up_fp8_list).contiguous()
up_scales = torch.stack(up_scale_list).contiguous()
down_fp8 = torch.stack(down_fp8_list).contiguous()
down_scales = torch.stack(down_scale_list).contiguous()
# Dequantize for reference computation
gate_deq = torch.stack([fp8_dequantize(gate_fp8_list[e], gate_scale_list[e]) for e in range(expert_num)])
up_deq = torch.stack([fp8_dequantize(up_fp8_list[e], up_scale_list[e]) for e in range(expert_num)])
down_deq = torch.stack([fp8_dequantize(down_fp8_list[e], down_scale_list[e]) for e in range(expert_num)])
# Create MOE config
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_fp8.data_ptr()
config.up_proj = up_fp8.data_ptr()
config.down_proj = down_fp8.data_ptr()
config.gate_scale = gate_scales.data_ptr()
config.up_scale = up_scales.data_ptr()
config.down_scale = down_scales.data_ptr()
config.quant_config.bits = 8
config.quant_config.group_size = group_size
config.quant_config.zero_point = False
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AVX2FP8_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
print("\n--- %s (qlen=%d) ---" % (label, qlen))
for i in range(validation_iter):
expert_ids = torch.stack([torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]).contiguous()
weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
bsz_tensor = torch.tensor([qlen], dtype=torch.int32)
CPUInfer.submit(moe.forward_task(
bsz_tensor.data_ptr(), num_experts_per_tok,
expert_ids.data_ptr(), weights.data_ptr(),
input_data.data_ptr(), output.data_ptr(), False,
))
CPUInfer.sync()
# Reference: use dequantized FP32 weights
t_output = moe_torch(input_data.float(), expert_ids, weights, gate_deq, up_deq, down_deq).to(torch.bfloat16)
diff = torch.mean(torch.abs(output.float() - t_output.float())) / (torch.mean(torch.abs(t_output.float())) + 1e-8)
print(" Iteration %d: diff = %.6f" % (i, diff.item()))
assert diff < 0.1, "FP8 accuracy test failed: diff=%.6f >= 0.1" % diff.item()
print(" PASSED")
if __name__ == "__main__":
print("=" * 60)
print("AVX2 FP8 MoE Accuracy Test")
print("=" * 60)
try:
test_avx2_fp8_accuracy(qlen=1, label="Decode")
test_avx2_fp8_accuracy(qlen=16, label="Prefill")
print("\n" + "=" * 60)
print("ALL TESTS PASSED")
print("=" * 60)
except Exception as e:
print("\nTEST FAILED: %s" % e)
import traceback
traceback.print_exc()
sys.exit(1)