Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-04-20 14:29:22 +00:00
[feat](kt-kernel): support avx2 only inference for bf16 fp8 and gptq int4 (#1892)
* feat: support avx2 bf16 fp8 inference
* feat: support avx2 gptq int4 inference
* fix: numeric issues in fp8 dequant
* Tutorial avx2 (#1900)
  * fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines
  * docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs
* Tutorial avx2 (#1901)
  * fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines
  * docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs
* docs: update README.md

Co-authored-by: Benjamin F <159887351+yyj6666667@users.noreply.github.com>
@@ -16,7 +16,7 @@

KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel/) and [kt-sft](https://github.com/kvcache-ai/ktransformers/tree/main/kt-sft).

## 🔥 Updates

* **Mar 26, 2026**: Support AVX2-only CPU backend for KT-Kernel inference. ([Tutorial](./doc/en/kt-kernel/AVX2-Tutorial.md))
* **Feb 13, 2026**: MiniMax-M2.5 Day0 Support! ([Tutorial](./doc/en/MiniMax-M2.5.md))
* **Feb 12, 2026**: GLM-5 Day0 Support! ([Tutorial](./doc/en/kt-kernel/GLM-5-Tutorial.md))
* **Jan 27, 2026**: Kimi-K2.5 Day0 Support! ([Tutorial](./doc/en/Kimi-K2.5.md)) ([SFT Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.5.md))
@@ -5,7 +5,7 @@

- [For kt-kernel](en/kt-kernel/kt-kernel_intro.md)
- [For kt-sft](en/SFT/KTransformers-Fine-Tuning_User-Guide.md)

# Tutorial

- [kt-sft part](en/SFT/README.md)
- [Injection Tutorial](en/SFT/injection_tutorial.md)
- [kt-sft developer tech notes](en/SFT/KTransformers-Fine-Tuning_Developer-Technical-Notes.md)

@@ -19,6 +19,8 @@

- [Makefile Usage](en/makefile_usage.md) -->
- [kt-kernel part](en/kt-kernel/README.md)
- [kt-cli](en/kt-kernel/kt-cli.md)
- [AVX2 Backend Tutorial](en/kt-kernel/AVX2-Tutorial.md)
- [AVX2 Backend Tutorial (Chinese)](zh/AVX2-Tutorial_zh.md)

# FAQ

- [FAQ](en/FAQ.md)

<!-- # V3 Reproduction

doc/en/kt-kernel/AVX2-Tutorial.md (new file, 188 lines)

@@ -0,0 +1,188 @@
# Running KTransformers on AVX2 CPUs

This tutorial explains how to run KTransformers on machines that only support AVX2 (without AVX512 or AMX).

## Table of Contents

- [Supported Precision Formats](#supported-precision-formats)
- [Hardware Requirements](#hardware-requirements)
- [Installation](#installation)
- [Verification](#verification)
- [Starting the Inference Server](#starting-the-inference-server)
- [Example: Qwen3-30B-A3B (BF16)](#example-qwen3-30b-a3b-bf16)
- [Example: Qwen3.5-35B-A3B-FP8 (FP8)](#example-qwen35-35b-a3b-fp8-fp8)
- [Example: Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)](#example-qwen3-30b-a3b-gptq-int4-gptq_int4)
- [Sending Requests](#sending-requests)
- [Performance Tuning](#performance-tuning)
- [FAQ](#faq)
## Supported Precision Formats

| `--kt-method` | Precision | Description |
|---------------|-----------|-------------|
| `BF16` | BF16 native precision | Zero precision loss, uses BF16 weights directly |
| `FP8` | FP8 block quantization | Roughly half the weight memory of BF16; blocks are dequantized during compute |
| `GPTQ_INT4` | INT4 GPTQ | Roughly a quarter of the weight memory of BF16; requires a GPTQ-quantized checkpoint |
## Hardware Requirements

- **CPU**: x86-64 with AVX2 and FMA (Intel Haswell 2013+ / AMD Zen or newer)
- **GPU**: NVIDIA GPU with 24GB+ VRAM (RTX 3090/4090/5090, etc.)
- **Memory**: At least the size of the model weights (e.g., Qwen3-30B-A3B BF16 requires 64GB+)
- **OS**: Linux
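As a quick sanity check, the RAM floor follows from the parameter count: a BF16 weight takes 2 bytes per parameter. A minimal sketch of the arithmetic (a rough lower bound only; activations, KV cache, and the OS need additional headroom):

```shell
# Rough lower bound on host RAM for CPU-resident weights
params_b=30        # Qwen3-30B-A3B: ~30B parameters
bytes_per_param=2  # BF16 stores 2 bytes per parameter
echo "need >= $(( params_b * bytes_per_param )) GB of RAM"
```

This is why the tutorial recommends 64GB+ for the 30B BF16 example: 60GB of weights plus working headroom.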
## Installation

Build and install from source (one-click install for kt-kernel + SGLang):

```bash
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init --recursive

# One-click install
./install.sh
```

On AVX512 or AMX machines, you can also manually force AVX2 compilation:

```bash
export CPUINFER_CPU_INSTRUCT=AVX2
export CPUINFER_ENABLE_AMX=OFF
./install.sh kt-kernel --manual
```
## Verification

```bash
# Check if the CPU supports AVX2
lscpu | grep -i avx2

# Check the loaded kt-kernel variant
python -c "import kt_kernel; print(kt_kernel.__cpu_variant__)"
# Expected output: avx2

# System diagnostics
kt doctor
```
## Starting the Inference Server

Use `--kt-method BF16`, `FP8`, or `GPTQ_INT4`. KT-Kernel will **automatically detect** the CPU and fall back to the AVX2 backend when AVX512/AMX is unavailable.

### Example: Qwen3-30B-A3B (BF16)

```bash
# Download the model
huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /path/to/Qwen3-30B-A3B

# Check physical core count and NUMA node count
lscpu | grep -E "^CPU\(s\)|Thread\(s\) per core|NUMA node\(s\)"

# Start the server (adjust kt-cpuinfer and kt-threadpool-count based on your hardware)
python -m sglang.launch_server \
  --host 0.0.0.0 --port 30000 \
  --model /path/to/Qwen3-30B-A3B \
  --kt-weight-path /path/to/Qwen3-30B-A3B \
  --kt-cpuinfer 16 \
  --kt-threadpool-count 1 \
  --kt-num-gpu-experts 32 \
  --kt-method BF16 \
  --attention-backend flashinfer \
  --trust-remote-code \
  --mem-fraction-static 0.80 \
  --chunked-prefill-size 8192 \
  --max-running-requests 2 \
  --served-model-name Qwen3 \
  --enable-mixed-chunk \
  --tensor-parallel-size 1 \
  --enable-p2p-check \
  --disable-shared-experts-fusion
```
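`--kt-num-gpu-experts` trades CPU load against VRAM, so it helps to estimate what the GPU-resident experts cost before raising it. A sketch of the arithmetic; every size below is an illustrative placeholder, not the real Qwen3 configuration, so substitute your model's hidden size, MoE intermediate size, expert count, and layer count from its `config.json`:

```shell
# Illustrative VRAM estimate for GPU-resident experts.
# All four numbers are made-up placeholders, NOT the real Qwen3 config.
hidden=2048; moe_inter=768; layers=48; gpu_experts=32
# gate + up + down projections in BF16: 3 * hidden * moe_inter * 2 bytes per expert per layer
bytes=$(( 3 * hidden * moe_inter * 2 * gpu_experts * layers ))
echo "$(( bytes / 1024 / 1024 )) MiB for expert weights alone"
```

Whatever the estimate yields has to fit alongside the attention weights, KV cache, and the fraction reserved by `--mem-fraction-static`.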
### Example: Qwen3.5-35B-A3B-FP8 (FP8)

```bash
# Download the model
huggingface-cli download Qwen/Qwen3.5-35B-A3B-FP8 --local-dir /path/to/Qwen3.5-35B-A3B-FP8

# Start the server
python -m sglang.launch_server \
  --host 0.0.0.0 --port 30000 \
  --model /path/to/Qwen3.5-35B-A3B-FP8 \
  --kt-weight-path /path/to/Qwen3.5-35B-A3B-FP8 \
  --kt-cpuinfer 16 \
  --kt-threadpool-count 1 \
  --kt-num-gpu-experts 2 \
  --kt-method FP8 \
  --kt-gpu-prefill-token-threshold 400 \
  --attention-backend triton \
  --trust-remote-code \
  --mem-fraction-static 0.85 \
  --chunked-prefill-size 4096 \
  --max-running-requests 1 \
  --max-total-tokens 32000 \
  --enable-mixed-chunk \
  --tensor-parallel-size 1 \
  --disable-shared-experts-fusion
```
### Example: Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)

```bash
# Download the model
huggingface-cli download Qwen/Qwen3-30B-A3B-GPTQ-Int4 --local-dir /path/to/Qwen3-30B-A3B-GPTQ-Int4

# Start the server
python -m sglang.launch_server \
  --host 0.0.0.0 --port 30000 \
  --model /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
  --kt-weight-path /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
  --kt-cpuinfer 16 \
  --kt-threadpool-count 1 \
  --kt-num-gpu-experts 2 \
  --kt-method GPTQ_INT4 \
  --attention-backend triton \
  --trust-remote-code \
  --mem-fraction-static 0.85 \
  --chunked-prefill-size 4096 \
  --max-running-requests 1 \
  --max-total-tokens 32000 \
  --enable-mixed-chunk \
  --tensor-parallel-size 1 \
  --disable-shared-experts-fusion
```
### Sending Requests

```bash
# Interactive chat
kt chat

# OpenAI-compatible API
curl http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"Qwen3","messages":[{"role":"user","content":"Hello"}],"stream":true}'
```
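The request body is plain JSON, so for scripting it helps to keep it in a shell variable and validate it before sending. A sketch of a non-streaming variant of the request above (same endpoint and model name as the server started earlier; the trailing `||` keeps the snippet from aborting when no server is listening):

```shell
# Build and validate the request body, then send it without streaming
payload='{"model":"Qwen3","messages":[{"role":"user","content":"Hello"}],"stream":false}'
echo "$payload" | python3 -m json.tool > /dev/null && echo "payload OK"

curl -s --max-time 10 http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d "$payload" || echo "server not reachable"
```

With `"stream":false` the server returns one complete JSON object instead of server-sent events, which is easier to parse in scripts.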
## Performance Tuning

- `--kt-cpuinfer`: set to the number of **physical cores**
- `--kt-threadpool-count`: set to the number of **NUMA nodes**
- `--kt-num-gpu-experts`: higher values reduce CPU load but increase GPU VRAM usage
- Memory bandwidth is often the bottleneck; high-frequency DDR5 memory helps significantly
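The first two rules can be turned into concrete flag values directly from `lscpu` output. A sketch assuming `lscpu` from util-linux, as used elsewhere in this tutorial (each field falls back to 1 when missing, e.g. inside some containers):

```shell
# Derive starting values for --kt-cpuinfer and --kt-threadpool-count
sockets=$(LC_ALL=C lscpu | awk -F: '/^Socket\(s\)/ {gsub(/ /, "", $2); print $2}')
cores=$(LC_ALL=C lscpu | awk -F: '/^Core\(s\) per socket/ {gsub(/ /, "", $2); print $2}')
numa=$(LC_ALL=C lscpu | awk -F: '/^NUMA node\(s\)/ {gsub(/ /, "", $2); print $2}')
phys_cores=$(( ${sockets:-1} * ${cores:-1} ))  # physical cores, excluding SMT threads
echo "--kt-cpuinfer $phys_cores --kt-threadpool-count ${numa:-1}"
```

Physical cores are `Socket(s) * Core(s) per socket`; the top-level `CPU(s)` field counts SMT threads and would oversubscribe the thread pool.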
## FAQ

**GPU OOM**

- Reduce `--kt-num-gpu-experts`, `--chunked-prefill-size`, or `--max-total-tokens`
- Lower `--mem-fraction-static`

For more questions, see the [FAQ](../FAQ.md).

doc/zh/AVX2-Tutorial_zh.md (new file, 188 lines)

@@ -0,0 +1,188 @@
# 在 AVX2 CPU 上使用 KTransformers

本教程介绍如何在仅支持 AVX2 的机器上运行 KTransformers(无需 AVX512 或 AMX)。

## 目录

- [支持的精度格式](#支持的精度格式)
- [硬件要求](#硬件要求)
- [安装](#安装)
- [验证](#验证)
- [启动推理服务](#启动推理服务)
- [示例:Qwen3-30B-A3B (BF16)](#示例qwen3-30b-a3b-bf16)
- [示例:Qwen3.5-35B-A3B-FP8 (FP8)](#示例qwen35-35b-a3b-fp8-fp8)
- [示例:Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)](#示例qwen3-30b-a3b-gptq-int4-gptq_int4)
- [发送请求](#发送请求)
- [性能调优](#性能调优)
- [常见问题](#常见问题)
## 支持的精度格式

| `--kt-method` | 精度 | 说明 |
|---------------|------|------|
| `BF16` | BF16 原精度 | 零精度损失,直接使用 BF16 权重 |
| `FP8` | FP8 分块量化 | 权重体积约为 BF16 的一半,计算时在线反量化 |
| `GPTQ_INT4` | INT4 GPTQ | 权重体积约为 BF16 的四分之一,需使用 GPTQ 量化权重 |
## 硬件要求

- **CPU**:x86-64,需支持 AVX2 与 FMA(Intel Haswell 2013+ / AMD Zen 及更新架构)
- **GPU**:NVIDIA 24GB+ 显存(RTX 3090/4090/5090 等)
- **内存**:不少于模型权重大小(如 Qwen3-30B-A3B BF16 需 64GB+)
- **系统**:Linux
## 安装

从源码编译安装(一键安装 kt-kernel + SGLang):

```bash
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init --recursive

# 一键安装
./install.sh
```

在 AVX512 或 AMX 机器上,也可以手动强制 AVX2 编译:

```bash
export CPUINFER_CPU_INSTRUCT=AVX2
export CPUINFER_ENABLE_AMX=OFF
./install.sh kt-kernel --manual
```
## 验证

```bash
# 检查 CPU 是否支持 AVX2
lscpu | grep -i avx2

# 检查 kt-kernel 加载的变体
python -c "import kt_kernel; print(kt_kernel.__cpu_variant__)"
# 预期输出:avx2

# 系统诊断
kt doctor
```
## 启动推理服务

使用 `--kt-method BF16`、`FP8` 或 `GPTQ_INT4`,KT-Kernel 会**自动检测** CPU 并在缺少 AVX512/AMX 时回退到 AVX2 后端。

### 示例:Qwen3-30B-A3B (BF16)

```bash
# 下载模型
huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /path/to/Qwen3-30B-A3B

# 查看物理核心数和 NUMA 节点数
lscpu | grep -E "^CPU\(s\)|Thread\(s\) per core|NUMA node\(s\)"

# 启动服务(按实际硬件调整 kt-cpuinfer 和 kt-threadpool-count)
python -m sglang.launch_server \
  --host 0.0.0.0 --port 30000 \
  --model /path/to/Qwen3-30B-A3B \
  --kt-weight-path /path/to/Qwen3-30B-A3B \
  --kt-cpuinfer 16 \
  --kt-threadpool-count 1 \
  --kt-num-gpu-experts 32 \
  --kt-method BF16 \
  --attention-backend flashinfer \
  --trust-remote-code \
  --mem-fraction-static 0.80 \
  --chunked-prefill-size 8192 \
  --max-running-requests 2 \
  --served-model-name Qwen3 \
  --enable-mixed-chunk \
  --tensor-parallel-size 1 \
  --enable-p2p-check \
  --disable-shared-experts-fusion
```
### 示例:Qwen3.5-35B-A3B-FP8 (FP8)

```bash
# 下载模型
huggingface-cli download Qwen/Qwen3.5-35B-A3B-FP8 --local-dir /path/to/Qwen3.5-35B-A3B-FP8

# 启动服务
python -m sglang.launch_server \
  --host 0.0.0.0 --port 30000 \
  --model /path/to/Qwen3.5-35B-A3B-FP8 \
  --kt-weight-path /path/to/Qwen3.5-35B-A3B-FP8 \
  --kt-cpuinfer 16 \
  --kt-threadpool-count 1 \
  --kt-num-gpu-experts 2 \
  --kt-method FP8 \
  --kt-gpu-prefill-token-threshold 400 \
  --attention-backend triton \
  --trust-remote-code \
  --mem-fraction-static 0.85 \
  --chunked-prefill-size 4096 \
  --max-running-requests 1 \
  --max-total-tokens 32000 \
  --enable-mixed-chunk \
  --tensor-parallel-size 1 \
  --disable-shared-experts-fusion
```
### 示例:Qwen3-30B-A3B-GPTQ-Int4 (GPTQ_INT4)

```bash
# 下载模型
huggingface-cli download Qwen/Qwen3-30B-A3B-GPTQ-Int4 --local-dir /path/to/Qwen3-30B-A3B-GPTQ-Int4

# 启动服务
python -m sglang.launch_server \
  --host 0.0.0.0 --port 30000 \
  --model /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
  --kt-weight-path /path/to/Qwen3-30B-A3B-GPTQ-Int4 \
  --kt-cpuinfer 16 \
  --kt-threadpool-count 1 \
  --kt-num-gpu-experts 2 \
  --kt-method GPTQ_INT4 \
  --attention-backend triton \
  --trust-remote-code \
  --mem-fraction-static 0.85 \
  --chunked-prefill-size 4096 \
  --max-running-requests 1 \
  --max-total-tokens 32000 \
  --enable-mixed-chunk \
  --tensor-parallel-size 1 \
  --disable-shared-experts-fusion
```
### 发送请求

```bash
# 交互聊天
kt chat

# OpenAI 兼容 API
curl http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"Qwen3","messages":[{"role":"user","content":"你好"}],"stream":true}'
```
## 性能调优

- `--kt-cpuinfer`:设为**物理核心数**
- `--kt-threadpool-count`:设为 **NUMA 节点数**
- `--kt-num-gpu-experts`:越大 CPU 负担越小,但 GPU 显存占用越高
- 内存带宽往往是瓶颈,DDR5 高频内存有明显帮助

## 常见问题

**GPU OOM**

- 减小 `--kt-num-gpu-experts`、`--chunked-prefill-size`、`--max-total-tokens`
- 降低 `--mem-fraction-static`

更多问题参见 [FAQ](../en/FAQ.md)。
@@ -45,6 +45,13 @@ static const bool _is_plain_ = false;

#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#endif
// AVX2 backends — always available on x86_64 (no AMX/AVX512 dependency)
#if defined(__x86_64__)
#include "operators/avx2/bf16-moe.hpp"
#include "operators/avx2/fp8-moe.hpp"
#include "operators/avx2/gptq_int4-moe.hpp"
#endif

#include <pybind11/stl.h>  // std::vector/std::pair/std::string conversions

#include <cstdint>
@@ -578,6 +585,13 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
  bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");
#endif
#endif
  // AVX2 backends — available on all x86_64 (no AMX/AVX512 requirement)
#if defined(__x86_64__)
  bind_moe_module<AVX2_BF16_MOE_TP<avx2::GemmKernelAVX2BF16>>(moe_module, "AVX2BF16_MOE");
  bind_moe_module<AVX2_FP8_MOE_TP<avx2::GemmKernelAVX2FP8>>(moe_module, "AVX2FP8_MOE");
  bind_moe_module<AVX2_GPTQ_INT4_MOE_TP<avx2::GemmKernelAVX2GPTQInt4>>(moe_module, "AVX2GPTQInt4_MOE");
#endif

#if defined(USE_MOE_KERNEL)
  bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
#if defined(__aarch64__) && defined(CPU_USE_KML)

kt-kernel/operators/avx2/avx2_bf16_gemm.hpp (new file, 228 lines)

@@ -0,0 +1,228 @@
/**
 * @Description : AVX2 BF16 GEMM kernel with trivial Buffer abstractions
 * @Author      : Claude
 * @Date        : 2026-03-18
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * Unlike AMX kernels that use packed tile layouts (BufferB with 16x16 transpose),
 * the AVX2 kernel uses row-major storage for all buffers.
 * BufferA/B/C are thin wrappers over raw memory with trivial from_mat/to_mat.
 *
 * GEMM: C[m,n] = sum_k A[m,k] * B[n,k]
 *   A: [M, K] row-major BF16 (input activations)
 *   B: [N, K] row-major BF16 (weights, each row is one output neuron)
 *   C: [M, N] row-major FP32 (output)
 **/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_GEMM_H
#define CPUINFER_OPERATOR_AVX2_BF16_GEMM_H

#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <tuple>
#include <utility>  // std::pair

#include "avx2_bf16_utils.hpp"

namespace avx2 {

// Split range [0, total) among nth threads, return [start, end) for thread ith
static inline std::pair<int, int> split_range(int total, int ith, int nth) {
  int per = total / nth;
  int rem = total % nth;
  int start = ith * per + std::min(ith, rem);
  int end = start + per + (ith < rem ? 1 : 0);
  return {start, end};
}
struct GemmKernelAVX2BF16 {
  using dt = ggml_bf16_t;
  using output_t = float;
  static constexpr int M_STEP = 1;             // No M-direction padding needed (vs AMX 16)
  static constexpr int N_STEP = 8;             // 8-wide FP32 AVX2 (vs AMX 32)
  static constexpr int K_STEP = 8;             // Process 8 K elements at a time
  static constexpr int N_BLOCK = 64;           // N blocking for cache
  static constexpr int K_BLOCK = 256;          // K blocking for cache
  static constexpr double ELEMENT_SIZE = 2.0;  // BF16 = 2 bytes

  // No AMX tile configuration needed
  static void config() {}

  // Thread count for N-dimension parallelism
  // Must return >= 1 to avoid division by zero in moe_base task dispatch
  static int recommended_nth(int n) { return std::max(1, n / N_STEP); }

  // Split N range for multi-threaded GEMM
  static std::pair<int, int> split_range_n(int n, int ith, int nth) { return split_range(n, ith, nth); }
  // ========================================================================
  // BufferA: Input activations [M, K] row-major BF16
  // from_mat() = memcpy (no packing needed for AVX2)
  // ========================================================================
  struct BufferA {
    ggml_bf16_t* data = nullptr;
    size_t max_m = 0;
    size_t k = 0;

    BufferA() = default;
    BufferA(size_t m, size_t k_, void* ptr) : data((ggml_bf16_t*)ptr), max_m(m), k(k_) {}

    static size_t required_size(size_t m, size_t k) { return m * k * sizeof(ggml_bf16_t); }

    void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; }

    // Copy input rows into buffer (trivial memcpy)
    void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) {
      if (ith == 0 && nth == 1) {
        std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t));
      } else {
        // Multi-threaded: split by rows
        auto [m_start, m_end] = split_range(m, ith, nth);
        std::memcpy(data + m_start * k, src + m_start * k, (size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t));
      }
    }
  };
  // ========================================================================
  // BufferB: Weight matrix [N, K] row-major BF16
  // from_mat() = memcpy (no transpose/packing needed)
  // ========================================================================
  struct BufferB {
    ggml_bf16_t* b = nullptr;
    size_t n = 0;
    size_t k = 0;

    BufferB() = default;
    BufferB(size_t n_, size_t k_, void* ptr) : b((ggml_bf16_t*)ptr), n(n_), k(k_) {}

    static size_t required_size(size_t n, size_t k) { return n * k * sizeof(ggml_bf16_t); }

    // Copy weight data (multi-threaded by N dimension)
    void from_mat(const ggml_bf16_t* src, int ith, int nth) {
      auto [n_start, n_end] = split_range((int)n, ith, nth);
      std::memcpy(b + n_start * k, src + n_start * k, (size_t)(n_end - n_start) * k * sizeof(ggml_bf16_t));
    }
  };
  // ========================================================================
  // BufferC: Output matrix [M, N] row-major FP32
  // to_mat() converts FP32 -> BF16 and writes out
  // ========================================================================
  struct BufferC {
    float* data = nullptr;
    size_t max_m = 0;
    size_t n = 0;

    BufferC() = default;
    BufferC(size_t m, size_t n_, void* ptr) : data((float*)ptr), max_m(m), n(n_) {}

    static size_t required_size(size_t m, size_t n) { return m * n * sizeof(float); }

    void set_data(void* ptr) { data = (float*)ptr; }

    // Convert FP32 output to BF16 and write to destination
    void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
      auto [n_start, n_end] = split_range_n((int)n, ith, nth);
      for (int mi = 0; mi < m; mi++) {
        float* src_row = data + mi * n;
        ggml_bf16_t* dst_row = dst + mi * n;
        int j = n_start;
        for (; j + 8 <= n_end; j += 8) {
          __m256 v = _mm256_loadu_ps(src_row + j);
          store_fp32_to_bf16(dst_row + j, v);
        }
        // Scalar tail
        for (; j < n_end; j++) {
          dst_row[j] = GGML_FP32_TO_BF16(src_row[j]);
        }
      }
    }
  };
};
// ============================================================================
// AVX2 BF16 GEMM functions
// C[m,n] = sum_k A[m,k] * B[n,k]
// ============================================================================

// General GEMM (works for both vec_mul m=1 and mat_mul m>1)
static inline void gemm_bf16(int m, int n, int k, GemmKernelAVX2BF16::BufferA& a, GemmKernelAVX2BF16::BufferB& b,
                             GemmKernelAVX2BF16::BufferC& c, int ith, int nth) {
  auto [n_start, n_end] = split_range(n, ith, nth);

  for (int ni = n_start; ni < n_end; ni++) {
    const ggml_bf16_t* b_row = b.b + (size_t)ni * k;

    for (int mi = 0; mi < m; mi++) {
      const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k;

      // AVX2 BF16 dot product (matches ggml_vec_dot_bf16 AVX2 path)
      __m256 c1 = _mm256_setzero_ps();
      __m256 c2 = _mm256_setzero_ps();
      __m256 c3 = _mm256_setzero_ps();
      __m256 c4 = _mm256_setzero_ps();

      int ki = 0;
      for (; ki + 32 <= k; ki += 32) {
        c1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki), load_bf16_to_fp32(b_row + ki), c1);
        c2 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 8), load_bf16_to_fp32(b_row + ki + 8), c2);
        c3 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 16), load_bf16_to_fp32(b_row + ki + 16), c3);
        c4 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + ki + 24), load_bf16_to_fp32(b_row + ki + 24), c4);
      }

      float sum = hsum_avx2(_mm256_add_ps(_mm256_add_ps(c1, c3), _mm256_add_ps(c2, c4)));

      // Scalar tail
      for (; ki < k; ki++) {
        sum += GGML_BF16_TO_FP32(a_row[ki]) * GGML_BF16_TO_FP32(b_row[ki]);
      }

      c.data[mi * n + ni] = sum;
    }
  }
}
// vec_mul: dispatch to gemm_bf16
static inline void vec_mul(int m, int n, int k, std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a,
                           std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b,
                           std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c, int ith, int nth) {
  gemm_bf16(m, n, k, *a, *b, *c, ith, nth);
}

// mat_mul: dispatch to gemm_bf16
static inline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernelAVX2BF16::BufferA>& a,
                           std::shared_ptr<GemmKernelAVX2BF16::BufferB>& b,
                           std::shared_ptr<GemmKernelAVX2BF16::BufferC>& c, int ith, int nth) {
  gemm_bf16(m, n, k, *a, *b, *c, ith, nth);
}

}  // namespace avx2

#endif  // CPUINFER_OPERATOR_AVX2_BF16_GEMM_H
kt-kernel/operators/avx2/avx2_bf16_utils.hpp (new file, 132 lines)

@@ -0,0 +1,132 @@
/**
 * @Description : AVX2 BF16 utility functions (bf16<->fp32 conversion, activation)
 * @Author      : Claude
 * @Date        : 2026-03-18
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * AVX2 ports of the AVX512 utilities in amx/la/utils.hpp and amx/la/amx.hpp.
 * Uses 256-bit SIMD (8 floats) instead of 512-bit (16 floats).
 **/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_UTILS_H
#define CPUINFER_OPERATOR_AVX2_BF16_UTILS_H

#include <immintrin.h>

#include <cmath>

#include "llama.cpp/ggml.h"

namespace avx2 {

// ============================================================================
// BF16 <-> FP32 conversion
// ============================================================================

// Load 8 BF16 values and convert to 8 FP32 values.
// BF16 is the upper 16 bits of FP32, so widen to 32 bits and shift left by 16.
static inline __m256 load_bf16_to_fp32(const ggml_bf16_t* src) {
  __m128i bf16 = _mm_loadu_si128((const __m128i*)src);
  __m256i i32 = _mm256_cvtepu16_epi32(bf16);
  return _mm256_castsi256_ps(_mm256_slli_epi32(i32, 16));
}

// Convert 8 FP32 values to 8 BF16 values with round-to-nearest-even.
// Matches ggml_compute_fp32_to_bf16 semantics (ggml-impl.h:87)
// and the amx/la/utils.hpp:24 tie-bit correction.
static inline void store_fp32_to_bf16(ggml_bf16_t* dst, __m256 src) {
  __m256i i32 = _mm256_castps_si256(src);
  // Round-to-nearest-even: add 0x7FFF + ((val >> 16) & 1)
  __m256i tie_bit = _mm256_and_si256(_mm256_srli_epi32(i32, 16), _mm256_set1_epi32(1));
  __m256i round = _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), tie_bit);
  __m256i rounded = _mm256_add_epi32(i32, round);
  __m256i shifted = _mm256_srli_epi32(rounded, 16);
  // Pack 32-bit -> 16-bit; _mm_packus_epi32 works within 128-bit lanes,
  // so pack the two extracted halves to preserve element order
  __m128i lo = _mm256_castsi256_si128(shifted);
  __m128i hi = _mm256_extracti128_si256(shifted, 1);
  __m128i packed = _mm_packus_epi32(lo, hi);
  _mm_storeu_si128((__m128i*)dst, packed);
}
// Load 16 BF16 -> 2x8 FP32 (corresponds to avx512_32xbf16_to_32xfp32)
static inline void load_16xbf16_to_2x8xfp32(const ggml_bf16_t* src, __m256* out0, __m256* out1) {
  *out0 = load_bf16_to_fp32(src);
  *out1 = load_bf16_to_fp32(src + 8);
}

// Store 2x8 FP32 -> 16 BF16 (corresponds to avx512_32xfp32_to_32xbf16)
static inline void store_2x8xfp32_to_16xbf16(__m256* in0, __m256* in1, ggml_bf16_t* dst) {
  store_fp32_to_bf16(dst, *in0);
  store_fp32_to_bf16(dst + 8, *in1);
}

// ============================================================================
// Horizontal sum for __m256 (8 floats -> 1 float)
// ============================================================================

static inline float hsum_avx2(__m256 v) {
  __m128 hi = _mm256_extractf128_ps(v, 1);
  __m128 lo = _mm256_castps256_ps128(v);
  __m128 sum = _mm_add_ps(lo, hi);
  sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
  sum = _mm_add_ss(sum, _mm_movehdup_ps(sum));
  return _mm_cvtss_f32(sum);
}
// ============================================================================
// Fast exp approximation (AVX2 port of amx::exp_avx512)
// ============================================================================

static inline __m256 exp_avx2(__m256 x) {
  const __m256 log2e = _mm256_set1_ps(1.44269504089f);

  // Split x * log2(e) into integer and fractional parts: exp(x) = 2^i * 2^f
  __m256 y = _mm256_mul_ps(x, log2e);
  __m256i int_part = _mm256_cvtps_epi32(y);
  __m256 frac_part = _mm256_sub_ps(y, _mm256_cvtepi32_ps(int_part));

  const __m256 poly_1 = _mm256_set1_ps(0.9999999995f);
  const __m256 poly_2 = _mm256_set1_ps(0.6931471805f);
  const __m256 poly_3 = _mm256_set1_ps(0.2402265069f);
  const __m256 poly_4 = _mm256_set1_ps(0.0555041087f);
  const __m256 poly_5 = _mm256_set1_ps(0.0096181291f);
  const __m256 poly_6 = _mm256_set1_ps(0.0013333558f);

  __m256 frac_exp = _mm256_fmadd_ps(
      _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4),
                                      frac_part, poly_3),
                      frac_part, poly_2),
      frac_part, poly_1);

  // 2^int_part: AVX2 doesn't have _mm256_scalef_ps, so build the float bit pattern by hand:
  // 2^n = reinterpret((n + 127) << 23) for float.
  // Clamp int_part to [-126, 127] to avoid invalid bit patterns:
  //   int_part < -126 -> biased < 1   -> denorm/zero (scalef_ps would give 0)
  //   int_part > 127  -> biased > 254 -> inf (scalef_ps would give inf)
  __m256i clamped = _mm256_max_epi32(_mm256_min_epi32(int_part, _mm256_set1_epi32(127)), _mm256_set1_epi32(-126));
  __m256i biased = _mm256_add_epi32(clamped, _mm256_set1_epi32(127));
  __m256i shifted = _mm256_slli_epi32(biased, 23);
  __m256 two_pow_i = _mm256_castsi256_ps(shifted);

  return _mm256_mul_ps(two_pow_i, frac_exp);
}
// ============================================================================
// SiLU activation: silu(gate) * up = gate * sigmoid(gate) * up
// AVX2 port of amx::act_fn
// ============================================================================

static inline __m256 act_fn(__m256 gate_val, __m256 up_val) {
  __m256 neg_gate_val = _mm256_sub_ps(_mm256_setzero_ps(), gate_val);
  // Clamp to avoid exp overflow
  const __m256 max_exp_input = _mm256_set1_ps(88.0f);
  neg_gate_val = _mm256_min_ps(neg_gate_val, max_exp_input);
  __m256 exp_neg_gate = exp_avx2(neg_gate_val);
  __m256 denom = _mm256_add_ps(_mm256_set1_ps(1.0f), exp_neg_gate);
  __m256 act_val = _mm256_div_ps(gate_val, denom);

  return _mm256_mul_ps(act_val, up_val);
}

}  // namespace avx2

#endif  // CPUINFER_OPERATOR_AVX2_BF16_UTILS_H
kt-kernel/operators/avx2/bf16-moe.hpp (new file, 327 lines)

@@ -0,0 +1,327 @@
/**
 * @Description : AVX2 BF16 MoE operator (ported from amx/bf16-moe.hpp)
 * @Author      : Claude
 * @Date        : 2026-03-18
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * Simplified from AMX version:
 * - BufferB::from_mat is memcpy (no AMX transpose)
 * - No unpack_nk_block_bf16 (no AMX packed format to undo)
 * - write_weights_to_buffer uses direct memcpy with full TP routing logic
 **/
#ifndef CPUINFER_OPERATOR_AVX2_BF16_MOE_H
#define CPUINFER_OPERATOR_AVX2_BF16_MOE_H

#include "avx2_bf16_gemm.hpp"
#include "moe_base.hpp"

template <class T = avx2::GemmKernelAVX2BF16>
class AVX2_BF16_MOE_TP : public AVX2_MOE_BASE<T, AVX2_BF16_MOE_TP<T>> {
  using Base = AVX2_MOE_BASE<T, AVX2_BF16_MOE_TP<T>>;
  using Base::config_;
  using Base::down_ba_;
  using Base::down_bb_;
  using Base::down_bc_;
  using Base::gate_bb_;
  using Base::gate_bc_;
  using Base::gate_up_ba_;
  using Base::m_local_num_;
  using Base::tp_part_idx;
  using Base::up_bb_;
  using Base::up_bc_;

 public:
  using typename Base::input_t;
  using typename Base::output_t;

  AVX2_BF16_MOE_TP() = default;

  AVX2_BF16_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}

  void derived_init() {
    printf("Created AVX2_BF16_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
  }

  ~AVX2_BF16_MOE_TP() = default;

  // CRTP buffer creation
  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
  size_t buffer_b_required_size_impl(size_t n, size_t k) const { return T::BufferB::required_size(n, k); }
  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }

  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
    return std::make_shared<typename T::BufferA>(m, k, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
    return std::make_shared<typename T::BufferB>(n, k, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
    return std::make_shared<typename T::BufferC>(m, n, data);
  }

  // GEMM dispatch — uses avx2::gemm_bf16
  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    auto& ba = gate_up_ba_[expert_idx];
    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];

    avx2::gemm_bf16(m, config_.intermediate_size, config_.hidden_size, *ba, *bb, *bc, ith, nth);
  }

  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    avx2::gemm_bf16(m, config_.hidden_size, config_.intermediate_size,
                    *down_ba_[expert_idx], *down_bb_[expert_idx], *down_bc_[expert_idx], ith, nth);
  }

  /**
   * Load BF16 weights from contiguous memory layout.
   * BufferB::from_mat is a trivial memcpy for AVX2 (no AMX transpose).
   */
  void load_weights() {
    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
    auto pool = config_.pool->get_subpool(tp_part_idx);

    if (config_.gate_proj == nullptr) {
      throw std::runtime_error("BF16 MOE requires native BF16 weight.");
    }

    // Load gate + up weights
    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;

          gate_bb_[expert_idx]->from_mat(
              (ggml_bf16_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              ith, nth);

          up_bb_[expert_idx]->from_mat(
              (ggml_bf16_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              ith, nth);
        },
        nullptr);

    // Load down weights
    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;

          down_bb_[expert_idx]->from_mat(
              (ggml_bf16_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              ith, nth);
        },
        nullptr);
  }

  /**
   * Write weights to GPU buffer for dynamic expert offload.
   * Preserves full TP routing logic from AMX version but uses direct memcpy
   * instead of unpack_nk_block_bf16 (since BufferB is row-major, not AMX-packed).
   */
  void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
                               const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
                               [[maybe_unused]] const std::vector<uintptr_t>& w13_scale_ptrs,
                               const std::vector<uintptr_t>& w2_weight_ptrs,
                               [[maybe_unused]] const std::vector<uintptr_t>& w2_scale_ptrs) const {
    auto& config = config_;
    auto pool = config.pool->get_subpool(tp_part_idx);

    // W13 (gate+up): Shape [intermediate, hidden], split by N across GPU TPs
    const int cpu_n_w13 = config.intermediate_size;
    const int cpu_k_w13 = config.hidden_size;
    const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
    const int gpu_k_w13 = full_config.hidden_size;
    const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;
    const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;

    // W2 (down): Shape [hidden, intermediate], split by K across GPU TPs
    const int cpu_n_w2 = config.hidden_size;
    const int cpu_k_w2 = config.intermediate_size;
    const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
    const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;

    constexpr int NUM_W13_TASKS = 32;
    constexpr int NUM_W2_TASKS = 32;
    const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS;

    pool->do_work_stealing_job(
        total_tasks, nullptr,
        [=, &w13_weight_ptrs, &w2_weight_ptrs, this](int task_id) {
          if (task_id < NUM_W13_TASKS * 2) {
            // W13 weight task: copy rows from BufferB to GPU buffer
            const bool is_up = task_id >= NUM_W13_TASKS;
            const int chunk_idx = task_id % NUM_W13_TASKS;
            const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];

            const int rows_per_task = (cpu_n_w13 + NUM_W13_TASKS - 1) / NUM_W13_TASKS;
            const int row_start = chunk_idx * rows_per_task;
            const int row_end = std::min(row_start + rows_per_task, cpu_n_w13);
            if (row_start >= cpu_n_w13) return;

            for (int row = row_start; row < row_end; row++) {
              const int global_n = global_n_offset_w13 + row;
              const int target_gpu = global_n / gpu_n_w13;
              const int n_in_gpu = global_n % gpu_n_w13;

              ggml_bf16_t* dst = (ggml_bf16_t*)w13_weight_ptrs[target_gpu];
              const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;

              // BufferB is row-major [N, K], direct copy
              std::memcpy(dst + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13,
                          bb->b + (size_t)row * cpu_k_w13,
                          cpu_k_w13 * sizeof(ggml_bf16_t));
            }
          } else {
            // W2 weight task: copy rows, split K across GPU TPs
            const int chunk_idx = task_id - NUM_W13_TASKS * 2;
            const auto& bb = down_bb_[expert_id];

            const int rows_per_task = (cpu_n_w2 + NUM_W2_TASKS - 1) / NUM_W2_TASKS;
            const int row_start = chunk_idx * rows_per_task;
            const int row_end = std::min(row_start + rows_per_task, cpu_n_w2);
            if (row_start >= cpu_n_w2) return;

            for (int row = row_start; row < row_end; row++) {
              // For W2, K dimension is split across GPU TPs
              // Iterate over all gpu_k_w2-sized slices within this CPU TP's K range
              for (int k_start = 0; k_start < cpu_k_w2; k_start += gpu_k_w2) {
                const int k_slice_end = std::min(k_start + gpu_k_w2, cpu_k_w2);
                const int k_slice_len = k_slice_end - k_start;

                // Map to correct GPU TP
                const int global_k = global_k_offset_w2 + k_start;
                const int target_gpu = global_k / gpu_k_w2;
                const int k_in_gpu = global_k % gpu_k_w2;

                ggml_bf16_t* dst = (ggml_bf16_t*)w2_weight_ptrs[target_gpu];

                std::memcpy(dst + (size_t)row * gpu_k_w2 + k_in_gpu,
                            bb->b + (size_t)row * cpu_k_w2 + k_start,
                            k_slice_len * sizeof(ggml_bf16_t));
              }
            }
          }
        },
        nullptr);
  }
};

// ============================================================================
// TP_MOE specialization for AVX2_BF16_MOE_TP
// Handles per-expert pointer loading and TP weight splitting
// (Ported from amx/bf16-moe.hpp TP_MOE<AMX_BF16_MOE_TP<K>>)
// ============================================================================
template <typename K>
class TP_MOE<AVX2_BF16_MOE_TP<K>> : public TP_MOE<AVX2_MOE_BASE<K, AVX2_BF16_MOE_TP<K>>> {
 public:
  using Base = TP_MOE<AVX2_MOE_BASE<K, AVX2_BF16_MOE_TP<K>>>;
  using Base::Base;

  void load_weights() override {
    auto& config = this->config;
    auto& tps = this->tps;
    auto& tp_count = this->tp_count;
    auto pool = config.pool;
    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;

    if (config.gate_projs.empty() && config.gate_proj == nullptr) {
      throw std::runtime_error("no weight source");
    }

    const bool use_per_expert_ptrs = !config.gate_projs.empty();
    const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;

    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;

      // Allocate temporary BF16 buffers for this TP part
      tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];
      tpc.up_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];
      tpc.down_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];

      const size_t gate_up_weight_src_offset = i * tp_weight_elems;
      const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;

      pool->get_subpool(i)->do_work_stealing_job(
          tpc.expert_num, nullptr,
          [&](int expert_id_) {  // `&` default already captures tpc by reference
            const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);

            ggml_bf16_t* gate_dst = (ggml_bf16_t*)tpc.gate_proj + expert_id * tp_weight_elems;
            ggml_bf16_t* up_dst = (ggml_bf16_t*)tpc.up_proj + expert_id * tp_weight_elems;
            ggml_bf16_t* down_dst = (ggml_bf16_t*)tpc.down_proj + expert_id * tp_weight_elems;

            const ggml_bf16_t* gate_src;
            const ggml_bf16_t* up_src;
            const ggml_bf16_t* down_src;

            if (use_per_expert_ptrs) {
              gate_src = (const ggml_bf16_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
              up_src = (const ggml_bf16_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
              down_src = (const ggml_bf16_t*)config.down_projs[0][expert_id];
            } else {
              gate_src = (const ggml_bf16_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              up_src = (const ggml_bf16_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              down_src = (const ggml_bf16_t*)config.down_proj + expert_id * full_weight_elems;
            }

            // Copy gate and up weights (column-slice for TP)
            std::memcpy(gate_dst, gate_src, tp_weight_elems * sizeof(ggml_bf16_t));
            std::memcpy(up_dst, up_src, tp_weight_elems * sizeof(ggml_bf16_t));

            // Copy down weights (row-wise split: each row picks a slice of columns)
            for (int row = 0; row < config.hidden_size; row++) {
              const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
              const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
              std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset,
                          (size_t)tpc.intermediate_size * sizeof(ggml_bf16_t));
            }
          },
          nullptr);
    });

    // Call per-TP load_weights (which does BufferB::from_mat = memcpy)
    pool->dispense_backend()->do_numa_job([&, this](int i) {
      tps[i]->load_weights();
    });

    // Free temporary buffers
    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      delete[] (ggml_bf16_t*)tpc.gate_proj;
      delete[] (ggml_bf16_t*)tpc.up_proj;
      delete[] (ggml_bf16_t*)tpc.down_proj;
    });

    this->weights_loaded = true;
  }

  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    if (this->weights_loaded == false) throw std::runtime_error("Not Loaded");
    if (this->tps.empty()) throw std::runtime_error("No TP parts initialized");
    if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w2_weight_ptrs.size() != gpu_tp_count) {
      throw std::runtime_error("Weight pointer arrays size must match gpu_tp_count");
    }

    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,
                                            w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
    });
  }
};

#endif  // CPUINFER_OPERATOR_AVX2_BF16_MOE_H
598
kt-kernel/operators/avx2/fp8-moe.hpp
Normal file
@@ -0,0 +1,598 @@
/**
 * @Description : AVX2 FP8 MoE operator (ported from amx/fp8-moe.hpp)
 * @Author      : Claude
 * @Date        : 2026-03-18
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * FP8 E4M3 weights with 128×128 block-wise float32 scales.
 * Dequantization: FP8→FP32 via precomputed 256-entry LUT + AVX2 gather.
 * GEMM: BF16 input × FP32 dequantized weight → FP32 output.
 **/
#ifndef CPUINFER_OPERATOR_AVX2_FP8_MOE_H
#define CPUINFER_OPERATOR_AVX2_FP8_MOE_H

#include "avx2_bf16_gemm.hpp"
#include "avx2_bf16_utils.hpp"
#include "fp8_dequant.hpp"
#include "moe_base.hpp"

namespace avx2 {

inline int div_up(int a, int b) { return (a + b - 1) / b; }

struct GemmKernelAVX2FP8 {
  using dt = ggml_bf16_t;
  using output_t = float;
  static constexpr int M_STEP = 1;
  static constexpr int N_STEP = 8;
  static constexpr int K_STEP = 8;
  static constexpr int BLOCK_SIZE = 128;  // 128×128 block quantization
  static constexpr int N_BLOCK = 128;
  static constexpr int K_BLOCK = 128;
  static constexpr double ELEMENT_SIZE = 1.0;  // FP8 = 1 byte

  static void config() {}

  static int recommended_nth(int n) {
    return std::max(1, div_up(n, N_BLOCK));
  }

  static std::pair<int, int> split_range_n(int n, int ith, int nth) {
    return avx2::split_range(n, ith, nth);
  }

  // ========================================================================
  // BufferA: BF16 activations [M, K] — same as BF16 backend
  // ========================================================================
  struct BufferA {
    ggml_bf16_t* data = nullptr;
    size_t max_m = 0;
    size_t k = 0;

    BufferA() = default;
    BufferA(size_t m, size_t k_, void* ptr) : max_m(m), k(k_), data((ggml_bf16_t*)ptr) {}

    static size_t required_size(size_t m, size_t k) {
      return m * k * sizeof(ggml_bf16_t);
    }

    void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; }

    void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) {
      if (ith == 0 && nth == 1) {
        std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t));
      } else {
        auto [m_start, m_end] = avx2::split_range(m, ith, nth);
        std::memcpy(data + m_start * k, src + m_start * k,
                    (size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t));
      }
    }
  };

  // ========================================================================
  // BufferB: FP8 weights [N, K] + float32 scales [N/BS, K/BS]
  // Row-major, no packing. from_mat = memcpy.
  // ========================================================================
  struct BufferB {
    uint8_t* b = nullptr;  // FP8 weights
    float* d = nullptr;    // Block-wise scales
    size_t n = 0;
    size_t k = 0;
    int block_size = BLOCK_SIZE;

    BufferB() = default;
    BufferB(size_t n_, size_t k_, int bs, void* ptr) : n(n_), k(k_), block_size(bs) {
      b = (uint8_t*)ptr;
      size_t weight_bytes = n * k;
      d = (float*)((uint8_t*)ptr + weight_bytes);
    }

    static size_t required_size(size_t n, size_t k, int bs) {
      size_t n_blocks_n = div_up((int)n, bs);
      size_t n_blocks_k = div_up((int)k, bs);
      return n * k + n_blocks_n * n_blocks_k * sizeof(float);
    }

    void from_mat(const uint8_t* src_weights, const float* src_scales, int ith, int nth) {
      // Copy weights (split by N)
      auto [n_start, n_end] = avx2::split_range((int)n, ith, nth);
      std::memcpy(b + n_start * k, src_weights + n_start * k,
                  (size_t)(n_end - n_start) * k);

      // Copy scales (split by N blocks)
      int n_blocks_k = div_up((int)k, block_size);
      int nb_start = n_start / block_size;
      int nb_end = div_up(n_end, block_size);
      std::memcpy(d + nb_start * n_blocks_k, src_scales + nb_start * n_blocks_k,
                  (size_t)(nb_end - nb_start) * n_blocks_k * sizeof(float));
    }
  };

  // ========================================================================
  // BufferC: FP32 output — same as BF16 backend
  // ========================================================================
  struct BufferC {
    float* data = nullptr;
    size_t max_m = 0;
    size_t n = 0;

    BufferC() = default;
    BufferC(size_t m, size_t n_, void* ptr) : max_m(m), n(n_), data((float*)ptr) {}

    static size_t required_size(size_t m, size_t n) {
      return m * n * sizeof(float);
    }

    void set_data(void* ptr) { data = (float*)ptr; }

    void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
      auto [n_start, n_end] = avx2::split_range((int)n, ith, nth);
      for (int mi = 0; mi < m; mi++) {
        float* src_row = data + mi * n;
        ggml_bf16_t* dst_row = dst + mi * n;
        int j = n_start;
        for (; j + 8 <= n_end; j += 8) {
          __m256 v = _mm256_loadu_ps(src_row + j);
          store_fp32_to_bf16(dst_row + j, v);
        }
        for (; j < n_end; j++) {
          dst_row[j] = GGML_FP32_TO_BF16(src_row[j]);
        }
      }
    }
  };
};

// ============================================================================
// AVX2 FP8 GEMM: C[m,n] = sum_k (A[m,k] * dequant(B[n,k])) * scale[n/BS, k/BS]
// ============================================================================

static inline void gemm_fp8(
    int m, int n, int k,
    GemmKernelAVX2FP8::BufferA& a,
    GemmKernelAVX2FP8::BufferB& b,
    GemmKernelAVX2FP8::BufferC& c,
    int ith, int nth) {

  ensure_fp8_lut_initialized();

  auto [n_start, n_end] = split_range(n, ith, nth);
  const int block_size = b.block_size;
  const int n_blocks_k = div_up(k, block_size);

  for (int ni = n_start; ni < n_end; ni++) {
    const uint8_t* b_row = b.b + (size_t)ni * k;
    const int n_block_idx = ni / block_size;

    for (int mi = 0; mi < m; mi++) {
      const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k;
      float sum = 0.0f;

      for (int kb = 0; kb < k; kb += block_size) {
        int k_len = std::min(block_size, k - kb);
        int k_block_idx = kb / block_size;
        float scale = b.d[n_block_idx * n_blocks_k + k_block_idx];

        // Accumulate within this block
        __m256 acc1 = _mm256_setzero_ps();
        __m256 acc2 = _mm256_setzero_ps();
        __m256 acc3 = _mm256_setzero_ps();
        __m256 acc4 = _mm256_setzero_ps();

        int ki = 0;
        for (; ki + 32 <= k_len; ki += 32) {
          acc1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki),
                                 fp8x8_to_fp32x8(b_row + kb + ki), acc1);
          acc2 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki + 8),
                                 fp8x8_to_fp32x8(b_row + kb + ki + 8), acc2);
          acc3 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki + 16),
                                 fp8x8_to_fp32x8(b_row + kb + ki + 16), acc3);
          acc4 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki + 24),
                                 fp8x8_to_fp32x8(b_row + kb + ki + 24), acc4);
        }
        for (; ki + 8 <= k_len; ki += 8) {
          acc1 = _mm256_fmadd_ps(load_bf16_to_fp32(a_row + kb + ki),
                                 fp8x8_to_fp32x8(b_row + kb + ki), acc1);
        }

        float block_sum = hsum_avx2(_mm256_add_ps(_mm256_add_ps(acc1, acc3),
                                                  _mm256_add_ps(acc2, acc4)));

        // Scalar tail
        for (; ki < k_len; ki++) {
          block_sum += GGML_BF16_TO_FP32(a_row[kb + ki]) * fp8_to_fp32_scalar(b_row[kb + ki]);
        }

        sum += block_sum * scale;
      }

      c.data[mi * n + ni] = sum;
    }
  }
}

}  // namespace avx2

// ============================================================================
|
||||
// AVX2 FP8 MoE operator (CRTP derived from AVX2_MOE_BASE)
|
||||
// ============================================================================
|
||||
|
||||
template <class T = avx2::GemmKernelAVX2FP8>
|
||||
class AVX2_FP8_MOE_TP : public AVX2_MOE_BASE<T, AVX2_FP8_MOE_TP<T>> {
|
||||
using Base = AVX2_MOE_BASE<T, AVX2_FP8_MOE_TP<T>>;
|
||||
using Base::config_;
|
||||
using Base::down_ba_;
|
||||
using Base::down_bb_;
|
||||
using Base::down_bc_;
|
||||
using Base::gate_bb_;
|
||||
using Base::gate_bc_;
|
||||
using Base::gate_up_ba_;
|
||||
using Base::m_local_num_;
|
||||
using Base::tp_part_idx;
|
||||
using Base::up_bb_;
|
||||
using Base::up_bc_;
|
||||
|
||||
public:
|
||||
using typename Base::input_t;
|
||||
using typename Base::output_t;
|
||||
|
||||
AVX2_FP8_MOE_TP() = default;
|
||||
|
||||
AVX2_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}
|
||||
|
||||
void derived_init() {
|
||||
avx2::ensure_fp8_lut_initialized();
|
||||
auto& quant_config = config_.quant_config;
|
||||
if (quant_config.group_size == 0 || quant_config.zero_point) {
|
||||
throw std::runtime_error("AVX2 FP8 MoE only supports block-wise FP8 (group_size > 0, no zero_point)");
|
||||
}
|
||||
printf("Created AVX2_FP8_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
|
||||
}
|
||||
|
||||
~AVX2_FP8_MOE_TP() = default;
|
||||
|
||||
// CRTP buffer creation — with group_size for BufferB
|
||||
size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
|
||||
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
|
||||
return T::BufferB::required_size(n, k, config_.quant_config.group_size);
|
||||
}
|
||||
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
|
||||
|
||||
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
|
||||
return std::make_shared<typename T::BufferA>(m, k, data);
|
||||
}
|
||||
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
|
||||
return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
|
||||
}
|
||||
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
|
||||
return std::make_shared<typename T::BufferC>(m, n, data);
|
||||
}
|
||||
|
||||
// GEMM dispatch
|
||||
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
|
||||
int m = m_local_num_[expert_idx];
|
||||
auto& ba = gate_up_ba_[expert_idx];
|
||||
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
|
||||
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
|
||||
avx2::gemm_fp8(m, config_.intermediate_size, config_.hidden_size, *ba, *bb, *bc, ith, nth);
|
||||
}
|
||||
|
||||
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
|
||||
int m = m_local_num_[expert_idx];
|
||||
avx2::gemm_fp8(m, config_.hidden_size, config_.intermediate_size,
|
||||
*down_ba_[expert_idx], *down_bb_[expert_idx], *down_bc_[expert_idx], ith, nth);
|
||||
}
|
||||
|
||||
// Load FP8 weights + scales from contiguous memory
|
||||
void load_weights() {
|
||||
auto& quant_config = config_.quant_config;
|
||||
int group_size = quant_config.group_size;
|
||||
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
|
||||
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||
|
||||
if (config_.gate_scale == nullptr) {
|
||||
throw std::runtime_error("FP8 MOE requires scale pointers.");
|
||||
}
|
||||
|
||||
// Load gate + up weights
|
||||
int nth = T::recommended_nth(config_.intermediate_size);
|
||||
pool->do_work_stealing_job(
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth, physical_to_logical_map, group_size](int task_id) {
|
||||
uint64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
|
||||
size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;
|
||||
size_t scale_offset = logical_expert_id *
|
||||
avx2::div_up(config_.hidden_size, group_size) *
|
||||
avx2::div_up(config_.intermediate_size, group_size);
|
||||
|
||||
gate_bb_[expert_idx]->from_mat(
|
||||
(uint8_t*)config_.gate_proj + weight_offset,
|
||||
(float*)config_.gate_scale + scale_offset,
|
||||
ith, nth);
|
||||
|
||||
up_bb_[expert_idx]->from_mat(
|
||||
(uint8_t*)config_.up_proj + weight_offset,
|
||||
(float*)config_.up_scale + scale_offset,
|
||||
ith, nth);
|
||||
},
|
||||
nullptr);
|
||||
|
||||
// Load down weights
|
||||
nth = T::recommended_nth(config_.hidden_size);
|
||||
pool->do_work_stealing_job(
|
||||
nth * config_.expert_num, nullptr,
|
||||
[this, nth, physical_to_logical_map, group_size](int task_id) {
|
||||
uint64_t expert_idx = task_id / nth;
|
||||
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||
int ith = task_id % nth;
|
||||
|
||||
size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;
|
||||
size_t scale_offset = logical_expert_id *
|
||||
avx2::div_up(config_.hidden_size, group_size) *
|
||||
avx2::div_up(config_.intermediate_size, group_size);
|
||||
|
||||
down_bb_[expert_idx]->from_mat(
|
||||
(uint8_t*)config_.down_proj + weight_offset,
|
||||
(float*)config_.down_scale + scale_offset,
|
||||
ith, nth);
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
|
||||
// Write weights to GPU buffer (for dynamic expert offload / layerwise prefill)
|
||||
void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
|
||||
const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
|
||||
const std::vector<uintptr_t>& w13_scale_ptrs,
|
||||
const std::vector<uintptr_t>& w2_weight_ptrs,
|
||||
const std::vector<uintptr_t>& w2_scale_ptrs) const {
|
||||
auto& config = config_;
|
||||
auto pool = config.pool->get_subpool(tp_part_idx);
|
||||
int group_size = config.quant_config.group_size;
|
||||
|
||||
// W13 (gate+up)
|
||||
const int cpu_n_w13 = config.intermediate_size;
|
||||
const int cpu_k_w13 = config.hidden_size;
|
||||
const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
|
||||
const int gpu_k_w13 = full_config.hidden_size;
|
||||
const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;
|
||||
const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;
|
||||
const int gpu_n_blocks_k_w13 = avx2::div_up(gpu_k_w13, group_size);
|
||||
const size_t gpu_w13_scale_per_mat = (size_t)avx2::div_up(gpu_n_w13, group_size) * gpu_n_blocks_k_w13;
|
||||
|
||||
// W2 (down)
|
||||
const int cpu_n_w2 = config.hidden_size;
|
||||
const int cpu_k_w2 = config.intermediate_size;
|
||||
const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
|
||||
const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;
|
||||
const int cpu_n_blocks_k_w2 = avx2::div_up(cpu_k_w2, group_size);
|
||||
|
||||
constexpr int NUM_W13_TASKS = 32;
|
||||
constexpr int NUM_W2_TASKS = 32;
|
||||
const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS;
|
||||
|
||||
pool->do_work_stealing_job(
|
||||
total_tasks, nullptr,
|
||||
[=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) {
|
||||
if (task_id < NUM_W13_TASKS * 2) {
|
||||
const bool is_up = task_id >= NUM_W13_TASKS;
|
||||
const int chunk_idx = task_id % NUM_W13_TASKS;
|
||||
const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];
|
||||
|
||||
const int rows_per_task = avx2::div_up(cpu_n_w13, NUM_W13_TASKS);
|
||||
const int row_start = chunk_idx * rows_per_task;
|
||||
const int row_end = std::min(row_start + rows_per_task, cpu_n_w13);
|
||||
if (row_start >= cpu_n_w13) return;
|
||||
|
||||
          for (int row = row_start; row < row_end; row++) {
            const int global_n = global_n_offset_w13 + row;
            const int target_gpu = global_n / gpu_n_w13;
            const int n_in_gpu = global_n % gpu_n_w13;

            // Copy weight row
            uint8_t* w_dst = (uint8_t*)w13_weight_ptrs[target_gpu];
            const size_t expert_w_off = is_up ? gpu_w13_weight_per_mat : 0;
            std::memcpy(w_dst + expert_w_off + (size_t)n_in_gpu * gpu_k_w13,
                        bb->b + (size_t)row * cpu_k_w13,
                        cpu_k_w13);

            // Copy scale row (if at block boundary)
            if (row % group_size == 0) {
              int n_block = row / group_size;
              int gpu_n_block = n_in_gpu / group_size;
              float* s_dst = (float*)w13_scale_ptrs[target_gpu];
              const size_t expert_s_off = is_up ? gpu_w13_scale_per_mat : 0;
              std::memcpy(s_dst + expert_s_off + gpu_n_block * gpu_n_blocks_k_w13,
                          bb->d + n_block * avx2::div_up(cpu_k_w13, group_size),
                          avx2::div_up(cpu_k_w13, group_size) * sizeof(float));
            }
          }
        } else {
          const int chunk_idx = task_id - NUM_W13_TASKS * 2;
          const auto& bb = down_bb_[expert_id];

          const int rows_per_task = avx2::div_up(cpu_n_w2, NUM_W2_TASKS);
          const int row_start = chunk_idx * rows_per_task;
          const int row_end = std::min(row_start + rows_per_task, cpu_n_w2);
          if (row_start >= cpu_n_w2) return;

          for (int row = row_start; row < row_end; row++) {
            // Iterate over all gpu_k_w2-sized slices within this CPU TP's K range
            for (int k_start = 0; k_start < cpu_k_w2; k_start += gpu_k_w2) {
              const int k_slice_len = std::min(gpu_k_w2, cpu_k_w2 - k_start);
              const int global_k = global_k_offset_w2 + k_start;
              const int target_gpu = global_k / gpu_k_w2;
              const int k_in_gpu = global_k % gpu_k_w2;

              uint8_t* w_dst = (uint8_t*)w2_weight_ptrs[target_gpu];
              std::memcpy(w_dst + (size_t)row * gpu_k_w2 + k_in_gpu,
                          bb->b + (size_t)row * cpu_k_w2 + k_start,
                          k_slice_len);

              // Copy scales for down (at block boundaries)
              if (row % group_size == 0) {
                int n_block = row / group_size;
                float* s_dst = (float*)w2_scale_ptrs[target_gpu];
                int gpu_n_blocks_k_w2 = avx2::div_up(gpu_k_w2, group_size);
                int k_block_start = k_in_gpu / group_size;
                int n_blocks_to_copy = std::min(cpu_n_blocks_k_w2, gpu_n_blocks_k_w2 - k_block_start);
                std::memcpy(s_dst + n_block * gpu_n_blocks_k_w2 + k_block_start,
                            bb->d + n_block * cpu_n_blocks_k_w2 + k_start / group_size,
                            n_blocks_to_copy * sizeof(float));
              }
            }  // end k_start loop
          }  // end row loop
        }
      },
      nullptr);
  }
};

// ============================================================================
// TP_MOE specialization — ported from amx/fp8-moe.hpp:628-738
// Handles per-expert pointer loading + TP weight/scale splitting
// ============================================================================
template <typename K>
class TP_MOE<AVX2_FP8_MOE_TP<K>> : public TP_MOE<AVX2_MOE_BASE<K, AVX2_FP8_MOE_TP<K>>> {
 public:
  using Base = TP_MOE<AVX2_MOE_BASE<K, AVX2_FP8_MOE_TP<K>>>;
  using Base::Base;

  void load_weights() override {
    auto& config = this->config;
    auto& tps = this->tps;
    auto& tp_count = this->tp_count;
    auto pool = config.pool;
    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;

    const int group_size = config.quant_config.group_size;
    if (group_size == 0 || config.quant_config.zero_point) {
      throw std::runtime_error("FP8 MoE only supports block-wise (group_size > 0, zero_point=false)");
    }

    if (config.gate_projs.empty() && config.gate_proj == nullptr) {
      throw std::runtime_error("no weight source");
    }
    const bool use_per_expert_ptrs = !config.gate_projs.empty();

    const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;
    const size_t full_scale_elems =
        (size_t)avx2::div_up(config.hidden_size, group_size) * avx2::div_up(config.intermediate_size, group_size);

    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;
      const size_t tp_scale_elems =
          (size_t)avx2::div_up(tpc.intermediate_size, group_size) * avx2::div_up(tpc.hidden_size, group_size);

      // Allocate temporary buffers
      tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
      tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
      tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
      tpc.gate_scale = new float[tpc.expert_num * tp_scale_elems];
      tpc.up_scale = new float[tpc.expert_num * tp_scale_elems];
      tpc.down_scale = new float[tpc.expert_num * tp_scale_elems];

      const size_t gate_up_weight_src_offset = i * tp_weight_elems;
      const size_t gate_up_scale_src_offset = i * tp_scale_elems;
      const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;
      const size_t down_scale_src_block_k_offset = down_weight_src_col_offset / (size_t)group_size;

      pool->get_subpool(i)->do_work_stealing_job(
          tpc.expert_num, nullptr,
          [&](int expert_id_) {  // tpc is reachable through the enclosing [&] capture
            const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);

            uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems;
            uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems;
            uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems;
            float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_scale_elems;
            float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_scale_elems;
            float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_scale_elems;

            const uint8_t* gate_src;
            const uint8_t* up_src;
            const uint8_t* down_src;
            const float* gate_scale_src;
            const float* up_scale_src;
            const float* down_scale_src;

            if (use_per_expert_ptrs) {
              gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
              up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
              down_src = (const uint8_t*)config.down_projs[0][expert_id];
              gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset;
              up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset;
              down_scale_src = (const float*)config.down_scales[0][expert_id];
            } else {
              gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems;
              gate_scale_src = (const float*)config.gate_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
              up_scale_src = (const float*)config.up_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
              down_scale_src = (const float*)config.down_scale + expert_id * full_scale_elems;
            }

            // Copy gate/up weights + scales (column slice)
            std::memcpy(gate_dst, gate_src, tp_weight_elems);
            std::memcpy(up_dst, up_src, tp_weight_elems);
            std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_scale_elems);
            std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_scale_elems);

            // Copy down weights (row-wise split)
            for (int row = 0; row < config.hidden_size; row++) {
              const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
              const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
              std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size);
            }

            // Copy down scales (block-row-wise split)
            const int n_blocks_n = avx2::div_up(config.hidden_size, group_size);
            const int full_n_blocks_k = avx2::div_up(config.intermediate_size, group_size);
            const int tp_n_blocks_k = avx2::div_up(tpc.intermediate_size, group_size);
            for (int bn = 0; bn < n_blocks_n; bn++) {
              const float* src = down_scale_src + (size_t)bn * full_n_blocks_k + down_scale_src_block_k_offset;
              float* dst = down_scale_dst + (size_t)bn * tp_n_blocks_k;
              std::memcpy(dst, src, sizeof(float) * tp_n_blocks_k);
            }
          },
          nullptr);
    });

    // Call per-TP load_weights
    pool->dispense_backend()->do_numa_job([&, this](int i) {
      tps[i]->load_weights();
    });

    // Free temporary buffers
    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      delete[] (uint8_t*)tpc.gate_proj;
      delete[] (uint8_t*)tpc.up_proj;
      delete[] (uint8_t*)tpc.down_proj;
      delete[] (float*)tpc.gate_scale;
      delete[] (float*)tpc.up_scale;
      delete[] (float*)tpc.down_scale;
    });

    this->weights_loaded = true;
  }

  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    if (this->weights_loaded == false) throw std::runtime_error("Not Loaded");
    if (this->tps.empty()) throw std::runtime_error("No TP parts initialized");

    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config,
                                            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
    });
  }
};

#endif  // CPUINFER_OPERATOR_AVX2_FP8_MOE_H
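The TP split above copies each expert's gate/up weights as one contiguous column slice but must walk the down-projection row by row, because the down matrix is row-major `[hidden, intermediate]` and only a column range belongs to each TP rank. A minimal standalone sketch of that per-row column slicing (the helper name and `std::vector` wrapper are illustrative, not the project's API):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Copy TP rank `rank`'s column slice of a row-major [rows, full_cols] byte
// matrix, mirroring the per-row memcpy used for the down-projection split.
std::vector<uint8_t> slice_columns(const std::vector<uint8_t>& src,
                                   int rows, int full_cols, int tp_cols, int rank) {
  std::vector<uint8_t> dst((size_t)rows * tp_cols);
  const size_t col_offset = (size_t)rank * tp_cols;
  for (int row = 0; row < rows; row++) {
    std::memcpy(dst.data() + (size_t)row * tp_cols,
                src.data() + (size_t)row * full_cols + col_offset,
                (size_t)tp_cols);
  }
  return dst;
}
```

The gate/up slice needs no such loop: since N (the output dimension) is the slowest-varying index there, each rank's portion is already contiguous and a single `memcpy` of `tp_weight_elems` suffices.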
86  kt-kernel/operators/avx2/fp8_dequant.hpp  Normal file
@@ -0,0 +1,86 @@
/**
 * @Description : FP8 E4M3 dequantization for AVX2 (LUT-based)
 * @Author : Claude
 * @Date : 2026-03-18
 * @Version : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * AVX512 uses _mm512_permutex2var_epi8 (VBMI) for FP8→BF16 LUT conversion.
 * AVX2 uses a precomputed 256-entry FP8→FP32 lookup table + _mm256_i32gather_ps.
 *
 * FP8 E4M3 format: sign(1) + exponent(4) + mantissa(3)
 * Reference: examples/test_fp8_moe.py:103-116
 **/
#ifndef CPUINFER_OPERATOR_AVX2_FP8_DEQUANT_H
#define CPUINFER_OPERATOR_AVX2_FP8_DEQUANT_H

#include <immintrin.h>

#include <cmath>
#include <cstdint>

namespace avx2 {

// Precomputed FP8 E4M3 → FP32 lookup table (256 entries)
// Initialized once at program startup via ensure_fp8_lut_initialized()
struct FP8LUT {
  alignas(32) float table[256];
  bool initialized = false;

  void init() {
    if (initialized) return;
    for (int i = 0; i < 256; i++) {
      int sign = (i >> 7) & 1;
      int exp = (i >> 3) & 0xF;  // 4-bit exponent (bits 3-6)
      int man = i & 0x7;         // 3-bit mantissa (bits 0-2)

      float val;
      if (exp == 0 && man == 0) {
        val = 0.0f;  // zero
      } else if (exp == 0) {
        val = std::ldexp((float)man / 8.0f, -6);  // subnormal: 2^(-6) * (0.man)
      } else if (exp == 15 && man == 7) {
        val = 0.0f;  // 0x7F/0xFF (exp=15, man=7) are the NaN encodings in E4M3. Treat as 0 to avoid propagation.
                     // E4M3 has no Inf; exp=15 with man=0-6 are valid finite values (256-448).
      } else {
        val = std::ldexp(1.0f + (float)man / 8.0f, exp - 7);  // normal: 2^(exp-7) * (1.man)
      }
      table[i] = sign ? -val : val;
    }
    initialized = true;
  }
};

// Global LUT instance
inline FP8LUT& get_fp8_lut() {
  static FP8LUT lut;
  return lut;
}

// Ensure LUT is initialized (call once at startup)
inline void ensure_fp8_lut_initialized() {
  get_fp8_lut().init();
}

// ============================================================================
// AVX2 FP8→FP32 dequantization: 8 FP8 bytes → 8 FP32 values
// Uses _mm256_i32gather_ps for parallel LUT lookups
// ============================================================================

static inline __m256 fp8x8_to_fp32x8(const uint8_t* src) {
  const float* lut = get_fp8_lut().table;
  // Load 8 bytes, zero-extend to 32-bit indices
  __m128i bytes = _mm_loadl_epi64((const __m128i*)src);
  __m256i indices = _mm256_cvtepu8_epi32(bytes);
  // Gather 8 floats from LUT (scale=4 because float is 4 bytes)
  return _mm256_i32gather_ps(lut, indices, 4);
}

// Scalar fallback for non-aligned or tail elements
static inline float fp8_to_fp32_scalar(uint8_t val) {
  return get_fp8_lut().table[val];
}

}  // namespace avx2

#endif  // CPUINFER_OPERATOR_AVX2_FP8_DEQUANT_H
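The LUT above bakes the E4M3 decode rule into 256 floats. A scalar reference of the same rule, useful for spot-checking table entries (the helper name `fp8_e4m3_decode` is illustrative, not part of the kernel):

```cpp
#include <cmath>
#include <cstdint>

// Scalar E4M3 decode mirroring FP8LUT::init(): bias 7, 3-bit mantissa,
// NaN encodings (exp=15, man=7) mapped to 0 like the LUT does.
inline float fp8_e4m3_decode(uint8_t byte) {
  const int sign = (byte >> 7) & 1;
  const int exp = (byte >> 3) & 0xF;
  const int man = byte & 0x7;
  float val;
  if (exp == 0) {
    val = std::ldexp((float)man / 8.0f, -6);  // zero (man=0) or subnormal
  } else if (exp == 15 && man == 7) {
    val = 0.0f;  // NaN encoding, neutralized
  } else {
    val = std::ldexp(1.0f + (float)man / 8.0f, exp - 7);  // normal
  }
  return sign ? -val : val;
}
```

For instance, 0x38 (exp=7, man=0) decodes to 1.0, and 0x7E (exp=15, man=6) to 448, the largest finite E4M3 value mentioned in the comment above.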
510  kt-kernel/operators/avx2/gptq_int4-moe.hpp  Normal file
@@ -0,0 +1,510 @@
/**
 * @Description : AVX2 GPTQ-Int4 MoE operator (symmetric quantization)
 * @Author : Claude
 * @Date : 2026-03-18
 * @Version : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * Supports GPTQ symmetric (sym=true, desc_act=false) INT4 quantization.
 * qweight [K/8, N] int32 + scales [K/gs, N] fp32. No qzeros needed.
 **/
#ifndef CPUINFER_OPERATOR_AVX2_GPTQ_INT4_MOE_H
#define CPUINFER_OPERATOR_AVX2_GPTQ_INT4_MOE_H

#include "avx2_bf16_gemm.hpp"
#include "avx2_bf16_utils.hpp"
#include "gptq_int4_dequant.hpp"
#include "moe_base.hpp"

namespace avx2 {

struct GemmKernelAVX2GPTQInt4 {
  using dt = ggml_bf16_t;
  using output_t = float;
  static constexpr int M_STEP = 1;
  static constexpr int N_STEP = 8;
  static constexpr int K_STEP = 8;    // 8 INT4 values per int32
  static constexpr int N_BLOCK = 64;
  static constexpr int K_BLOCK = 128; // = group_size typically
  static constexpr double ELEMENT_SIZE = 0.5;  // INT4 = 0.5 byte

  static void config() {}

  static int recommended_nth(int n) {
    return std::max(1, n / N_BLOCK);
  }

  static std::pair<int, int> split_range_n(int n, int ith, int nth) {
    return split_range(n, ith, nth);
  }

  // ========================================================================
  // BufferA: BF16 activations [M, K] — same as BF16/FP8
  // ========================================================================
  struct BufferA {
    ggml_bf16_t* data = nullptr;
    size_t max_m = 0;
    size_t k = 0;

    BufferA() = default;
    BufferA(size_t m, size_t k_, void* ptr) : data((ggml_bf16_t*)ptr), max_m(m), k(k_) {}

    static size_t required_size(size_t m, size_t k) {
      return m * k * sizeof(ggml_bf16_t);
    }

    void set_data(void* ptr) { data = (ggml_bf16_t*)ptr; }

    void from_mat(int m, const ggml_bf16_t* src, int ith, int nth) {
      if (ith == 0 && nth == 1) {
        std::memcpy(data, src, (size_t)m * k * sizeof(ggml_bf16_t));
      } else {
        auto [m_start, m_end] = split_range(m, ith, nth);
        std::memcpy(data + m_start * k, src + m_start * k,
                    (size_t)(m_end - m_start) * k * sizeof(ggml_bf16_t));
      }
    }
  };

  // ========================================================================
  // BufferB: GPTQ INT4 weights [K/8, N] int32 + scales [num_groups, N] fp32
  // ========================================================================
  struct BufferB {
    uint32_t* qweight = nullptr;  // [K/8, N] packed int32
    float* scales = nullptr;      // [num_groups, N] fp32
    int n = 0;
    int k = 0;
    int group_size = 128;
    int num_groups = 0;
    int k_packed = 0;  // = K/8

    BufferB() = default;
    BufferB(size_t n_, size_t k_, int gs, void* ptr)
        : n(n_), k(k_), group_size(gs) {
      k_packed = k / 8;
      num_groups = k / gs;
      qweight = (uint32_t*)ptr;
      scales = (float*)((uint8_t*)ptr + (size_t)k_packed * n * sizeof(uint32_t));
    }

    static size_t required_size(size_t n, size_t k, int gs) {
      size_t k_packed = k / 8;
      size_t num_groups = k / gs;
      return k_packed * n * sizeof(uint32_t) + num_groups * n * sizeof(float);
    }

    // Load qweight and scales from separate source pointers
    void from_mat(const uint32_t* src_qweight, const float* src_scales, int ith, int nth) {
      // Split by N dimension
      auto [n_start, n_end] = split_range(n, ith, nth);
      int n_len = n_end - n_start;

      // Copy qweight rows [K/8 rows, each row = N int32]
      for (int kr = 0; kr < k_packed; kr++) {
        std::memcpy(qweight + kr * n + n_start,
                    src_qweight + kr * n + n_start,
                    n_len * sizeof(uint32_t));
      }

      // Copy scales rows [num_groups rows, each row = N float]
      for (int g = 0; g < num_groups; g++) {
        std::memcpy(scales + g * n + n_start,
                    src_scales + g * n + n_start,
                    n_len * sizeof(float));
      }
    }
  };

  // ========================================================================
  // BufferC: FP32 output — same as BF16/FP8
  // ========================================================================
  struct BufferC {
    float* data = nullptr;
    size_t max_m = 0;
    size_t n = 0;

    BufferC() = default;
    BufferC(size_t m, size_t n_, void* ptr) : data((float*)ptr), max_m(m), n(n_) {}

    static size_t required_size(size_t m, size_t n) {
      return m * n * sizeof(float);
    }

    void set_data(void* ptr) { data = (float*)ptr; }

    void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
      auto [n_start, n_end] = split_range((int)n, ith, nth);
      for (int mi = 0; mi < m; mi++) {
        float* src_row = data + mi * n;
        ggml_bf16_t* dst_row = dst + mi * n;
        int j = n_start;
        for (; j + 8 <= n_end; j += 8) {
          store_fp32_to_bf16(dst_row + j, _mm256_loadu_ps(src_row + j));
        }
        for (; j < n_end; j++) {
          dst_row[j] = GGML_FP32_TO_BF16(src_row[j]);
        }
      }
    }
  };
};

// ============================================================================
// AVX2 GPTQ INT4 GEMM (symmetric)
// C[m,n] = sum_k A_bf16[m,k] * dequant(B_int4[k,n])
// ============================================================================

static inline void gemm_gptq_sym_int4(
    int m, int n, int k,
    GemmKernelAVX2GPTQInt4::BufferA& a,
    GemmKernelAVX2GPTQInt4::BufferB& b,
    GemmKernelAVX2GPTQInt4::BufferC& c,
    int ith, int nth) {
  auto [n_start, n_end] = split_range(n, ith, nth);
  const int group_size = b.group_size;
  const int num_groups = b.num_groups;

  for (int ni = n_start; ni < n_end; ni++) {
    for (int mi = 0; mi < m; mi++) {
      const ggml_bf16_t* a_row = a.data + (size_t)mi * a.k;
      float sum = 0.0f;

      for (int g = 0; g < num_groups; g++) {
        float scale = b.scales[g * n + ni];
        int k_base = g * group_size;

        __m256 acc = _mm256_setzero_ps();

        // group_size/8 iterations (e.g., 128/8 = 16)
        for (int ki = 0; ki < group_size; ki += 8) {
          int k_abs = k_base + ki;
          __m256 a_val = load_bf16_to_fp32(a_row + k_abs);
          uint32_t packed = b.qweight[(k_abs / 8) * n + ni];
          __m256 w_val = gptq_sym_dequant_8x4bit(packed, scale);
          acc = _mm256_fmadd_ps(a_val, w_val, acc);
        }

        sum += hsum_avx2(acc);
      }

      c.data[mi * n + ni] = sum;
    }
  }
}

}  // namespace avx2

// ============================================================================
// AVX2 GPTQ INT4 MoE operator
// ============================================================================

template <class T = avx2::GemmKernelAVX2GPTQInt4>
class AVX2_GPTQ_INT4_MOE_TP : public AVX2_MOE_BASE<T, AVX2_GPTQ_INT4_MOE_TP<T>> {
  using Base = AVX2_MOE_BASE<T, AVX2_GPTQ_INT4_MOE_TP<T>>;
  using Base::config_;
  using Base::down_ba_;
  using Base::down_bb_;
  using Base::down_bc_;
  using Base::gate_bb_;
  using Base::gate_bc_;
  using Base::gate_up_ba_;
  using Base::m_local_num_;
  using Base::tp_part_idx;
  using Base::up_bb_;
  using Base::up_bc_;

 public:
  using typename Base::input_t;
  using typename Base::output_t;

  AVX2_GPTQ_INT4_MOE_TP() = default;
  AVX2_GPTQ_INT4_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}

  void derived_init() {
    auto& qc = config_.quant_config;
    if (qc.group_size == 0) {
      throw std::runtime_error("GPTQ INT4 requires group_size > 0");
    }
    printf("Created AVX2_GPTQ_INT4_MOE_TP %d at numa %d (group_size=%d)\n",
           tp_part_idx, numa_node_of_cpu(sched_getcpu()), qc.group_size);
  }

  ~AVX2_GPTQ_INT4_MOE_TP() = default;

  // CRTP buffer creation
  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
  size_t buffer_b_required_size_impl(size_t n, size_t k) const {
    return T::BufferB::required_size(n, k, config_.quant_config.group_size);
  }
  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }

  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
    return std::make_shared<typename T::BufferA>(m, k, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
    return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
    return std::make_shared<typename T::BufferC>(m, n, data);
  }

  // GEMM dispatch
  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    auto& ba = gate_up_ba_[expert_idx];
    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
    avx2::gemm_gptq_sym_int4(m, config_.intermediate_size, config_.hidden_size, *ba, *bb, *bc, ith, nth);
  }

  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    avx2::gemm_gptq_sym_int4(m, config_.hidden_size, config_.intermediate_size,
                             *down_ba_[expert_idx], *down_bb_[expert_idx], *down_bc_[expert_idx], ith, nth);
  }

  // Load weights from contiguous qweight + scales pointers
  void load_weights() {
    int group_size = config_.quant_config.group_size;
    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
    auto pool = config_.pool->get_subpool(tp_part_idx);

    if (config_.gate_scale == nullptr) {
      throw std::runtime_error("GPTQ INT4 MOE requires scale pointers.");
    }

    // gate + up: qweight [K/8, N=intermediate], scales [K/gs, N=intermediate]
    int gate_up_k = config_.hidden_size;
    int gate_up_n = config_.intermediate_size;
    size_t qw_elems = (size_t)(gate_up_k / 8) * gate_up_n;
    size_t sc_elems = (size_t)(gate_up_k / group_size) * gate_up_n;

    int nth = T::recommended_nth(gate_up_n);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map, qw_elems, sc_elems](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;

          gate_bb_[expert_idx]->from_mat(
              (uint32_t*)config_.gate_proj + logical * qw_elems,
              (float*)config_.gate_scale + logical * sc_elems,
              ith, nth);

          up_bb_[expert_idx]->from_mat(
              (uint32_t*)config_.up_proj + logical * qw_elems,
              (float*)config_.up_scale + logical * sc_elems,
              ith, nth);
        },
        nullptr);

    // down: qweight [K/8, N=hidden] where K=intermediate
    int down_k = config_.intermediate_size;
    int down_n = config_.hidden_size;
    size_t down_qw_elems = (size_t)(down_k / 8) * down_n;
    size_t down_sc_elems = (size_t)(down_k / group_size) * down_n;

    nth = T::recommended_nth(down_n);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map, down_qw_elems, down_sc_elems](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;

          down_bb_[expert_idx]->from_mat(
              (uint32_t*)config_.down_proj + logical * down_qw_elems,
              (float*)config_.down_scale + logical * down_sc_elems,
              ith, nth);
        },
        nullptr);
  }

  // write_weights_to_buffer for layerwise prefill / GPU expert offload
  // Note: GPTQ INT4 GPU offload requires the GPU to support INT4 dequant,
  // which is not wired up yet, so this entry point currently throws.
  void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
                               const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
                               [[maybe_unused]] const std::vector<uintptr_t>& w13_scale_ptrs,
                               const std::vector<uintptr_t>& w2_weight_ptrs,
                               [[maybe_unused]] const std::vector<uintptr_t>& w2_scale_ptrs) const {
    // TODO: Implement GPTQ INT4 GPU offload when needed
    // For now, layerwise prefill with GPTQ INT4 is not supported
    throw std::runtime_error("GPTQ INT4 write_weights_to_buffer not yet implemented");
  }
};

// ============================================================================
|
||||
// TP_MOE specialization for AVX2_GPTQ_INT4_MOE_TP
|
||||
// ============================================================================
|
||||
template <typename K>
|
||||
class TP_MOE<AVX2_GPTQ_INT4_MOE_TP<K>> : public TP_MOE<AVX2_MOE_BASE<K, AVX2_GPTQ_INT4_MOE_TP<K>>> {
|
||||
public:
|
||||
using Base = TP_MOE<AVX2_MOE_BASE<K, AVX2_GPTQ_INT4_MOE_TP<K>>>;
|
||||
using Base::Base;
|
||||
|
||||
void load_weights() override {
|
||||
auto& config = this->config;
|
||||
auto& tps = this->tps;
|
||||
auto& tp_count = this->tp_count;
|
||||
auto pool = config.pool;
|
||||
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
|
||||
|
||||
const int group_size = config.quant_config.group_size;
|
||||
if (group_size == 0) {
|
||||
throw std::runtime_error("GPTQ INT4 requires group_size > 0");
|
||||
}
|
||||
|
||||
if (config.gate_projs.empty() && config.gate_proj == nullptr) {
|
||||
throw std::runtime_error("no weight source");
|
||||
}
|
||||
const bool use_per_expert_ptrs = !config.gate_projs.empty();
|
||||
|
||||
// Full dimensions
|
||||
const int full_intermediate = config.intermediate_size;
|
||||
const int full_hidden = config.hidden_size;
|
||||
|
||||
// gate/up: shape [K=hidden, N=intermediate]
|
||||
// qweight: [hidden/8, intermediate], scales: [hidden/gs, intermediate]
|
||||
const int gate_up_k_packed = full_hidden / 8;
|
||||
const int gate_up_num_groups = full_hidden / group_size;
|
||||
const size_t full_gate_up_qw_elems = (size_t)gate_up_k_packed * full_intermediate;
|
||||
const size_t full_gate_up_sc_elems = (size_t)gate_up_num_groups * full_intermediate;
|
||||
|
||||
// down: shape [K=intermediate, N=hidden]
|
||||
// qweight: [intermediate/8, hidden], scales: [intermediate/gs, hidden]
|
||||
const int down_k_packed = full_intermediate / 8;
|
||||
const int down_num_groups = full_intermediate / group_size;
|
||||
const size_t full_down_qw_elems = (size_t)down_k_packed * full_hidden;
|
||||
const size_t full_down_sc_elems = (size_t)down_num_groups * full_hidden;
|
||||
|
||||
pool->dispense_backend()->do_numa_job([&, this](int i) {
|
||||
auto& tpc = tps[i]->config_;
|
||||
const int tp_intermediate = tpc.intermediate_size;
|
||||
|
||||
// gate/up TP: N=intermediate is split
|
||||
const size_t tp_gate_up_qw_elems = (size_t)gate_up_k_packed * tp_intermediate;
|
||||
const size_t tp_gate_up_sc_elems = (size_t)gate_up_num_groups * tp_intermediate;
|
||||
|
||||
tpc.gate_proj = new uint32_t[tpc.expert_num * tp_gate_up_qw_elems];
|
||||
tpc.up_proj = new uint32_t[tpc.expert_num * tp_gate_up_qw_elems];
|
||||
tpc.gate_scale = new float[tpc.expert_num * tp_gate_up_sc_elems];
|
||||
tpc.up_scale = new float[tpc.expert_num * tp_gate_up_sc_elems];
|
||||
|
||||
// down TP: K=intermediate is split
|
||||
const int tp_down_k_packed = tp_intermediate / 8;
|
||||
const int tp_down_num_groups = tp_intermediate / group_size;
|
||||
const size_t tp_down_qw_elems = (size_t)tp_down_k_packed * full_hidden;
|
||||
const size_t tp_down_sc_elems = (size_t)tp_down_num_groups * full_hidden;
|
||||
|
||||
tpc.down_proj = new uint32_t[tpc.expert_num * tp_down_qw_elems];
|
||||
tpc.down_scale = new float[tpc.expert_num * tp_down_sc_elems];
|
||||
|
||||
const int gate_up_n_offset = i * tp_intermediate;
|
||||
const int down_k_offset_packed = i * tp_down_k_packed;
|
||||
const int down_group_offset = i * tp_down_num_groups;
|
||||
|
||||
pool->get_subpool(i)->do_work_stealing_job(
|
||||
tpc.expert_num, nullptr,
|
||||
[&, &tpc](int expert_id_) {
|
||||
const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
|
||||
|
||||
const uint32_t* gate_qw_src;
|
||||
const uint32_t* up_qw_src;
|
||||
const uint32_t* down_qw_src;
|
||||
const float* gate_sc_src;
|
||||
const float* up_sc_src;
|
||||
const float* down_sc_src;
|
||||
|
||||
if (use_per_expert_ptrs) {
|
||||
gate_qw_src = (const uint32_t*)config.gate_projs[0][expert_id];
|
||||
up_qw_src = (const uint32_t*)config.up_projs[0][expert_id];
|
||||
down_qw_src = (const uint32_t*)config.down_projs[0][expert_id];
|
||||
gate_sc_src = (const float*)config.gate_scales[0][expert_id];
|
||||
up_sc_src = (const float*)config.up_scales[0][expert_id];
|
||||
down_sc_src = (const float*)config.down_scales[0][expert_id];
|
||||
} else {
|
||||
gate_qw_src = (const uint32_t*)config.gate_proj + expert_id * full_gate_up_qw_elems;
|
||||
up_qw_src = (const uint32_t*)config.up_proj + expert_id * full_gate_up_qw_elems;
|
||||
down_qw_src = (const uint32_t*)config.down_proj + expert_id * full_down_qw_elems;
|
||||
gate_sc_src = (const float*)config.gate_scale + expert_id * full_gate_up_sc_elems;
|
||||
up_sc_src = (const float*)config.up_scale + expert_id * full_gate_up_sc_elems;
|
||||
      down_sc_src = (const float*)config.down_scale + expert_id * full_down_sc_elems;
    }

    uint32_t* gate_qw_dst = (uint32_t*)tpc.gate_proj + expert_id * tp_gate_up_qw_elems;
    uint32_t* up_qw_dst = (uint32_t*)tpc.up_proj + expert_id * tp_gate_up_qw_elems;
    float* gate_sc_dst = (float*)tpc.gate_scale + expert_id * tp_gate_up_sc_elems;
    float* up_sc_dst = (float*)tpc.up_scale + expert_id * tp_gate_up_sc_elems;

    // gate/up qweight: [K/8, N] → slice N columns
    for (int kr = 0; kr < gate_up_k_packed; kr++) {
      std::memcpy(gate_qw_dst + kr * tp_intermediate,
                  gate_qw_src + kr * full_intermediate + gate_up_n_offset,
                  tp_intermediate * sizeof(uint32_t));
      std::memcpy(up_qw_dst + kr * tp_intermediate,
                  up_qw_src + kr * full_intermediate + gate_up_n_offset,
                  tp_intermediate * sizeof(uint32_t));
    }

    // gate/up scales: [num_groups, N] → slice N columns
    for (int g = 0; g < gate_up_num_groups; g++) {
      std::memcpy(gate_sc_dst + g * tp_intermediate,
                  gate_sc_src + g * full_intermediate + gate_up_n_offset,
                  tp_intermediate * sizeof(float));
      std::memcpy(up_sc_dst + g * tp_intermediate,
                  up_sc_src + g * full_intermediate + gate_up_n_offset,
                  tp_intermediate * sizeof(float));
    }

    // down qweight: [K/8, N=hidden] row-major → slice contiguous rows (K/8 dim)
    uint32_t* down_qw_dst = (uint32_t*)tpc.down_proj + expert_id * tp_down_qw_elems;
    for (int kr = 0; kr < tp_down_k_packed; kr++) {
      std::memcpy(down_qw_dst + kr * full_hidden,
                  down_qw_src + (down_k_offset_packed + kr) * full_hidden,
                  full_hidden * sizeof(uint32_t));
    }

    // down scales: [K/gs, N=hidden] row-major → slice contiguous rows (K/gs dim)
    float* down_sc_dst = (float*)tpc.down_scale + expert_id * tp_down_sc_elems;
    for (int g = 0; g < tp_down_num_groups; g++) {
      std::memcpy(down_sc_dst + g * full_hidden,
                  down_sc_src + (down_group_offset + g) * full_hidden,
                  full_hidden * sizeof(float));
    }
  },
  nullptr);
    });

    // Call per-TP load_weights
    pool->dispense_backend()->do_numa_job([&, this](int i) { tps[i]->load_weights(); });

    // Free temporary buffers
    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      delete[] (uint32_t*)tpc.gate_proj;
      delete[] (uint32_t*)tpc.up_proj;
      delete[] (uint32_t*)tpc.down_proj;
      delete[] (float*)tpc.gate_scale;
      delete[] (float*)tpc.up_scale;
      delete[] (float*)tpc.down_scale;
    });

    this->weights_loaded = true;
  }

  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    // GPTQ INT4 GPU offload not yet supported
    throw std::runtime_error("GPTQ INT4 write_weight_scale_to_buffer not yet implemented");
  }
};

#endif  // CPUINFER_OPERATOR_AVX2_GPTQ_INT4_MOE_H
kt-kernel/operators/avx2/gptq_int4_dequant.hpp (new file, 50 lines)
@@ -0,0 +1,50 @@
/**
 * @Description : GPTQ-Int4 symmetric dequantization for AVX2
 * @Author      : Claude
 * @Date        : 2026-03-18
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * GPTQ symmetric quantization (sym=true):
 *   dequant[k,n] = (((qweight[k/8,n] >> ((k%8)*4)) & 0xF) - 8) * scale[k/gs, n]
 *
 * qweight layout: [K/8, N] int32, packing 8 x 4-bit values along K dimension
 * scales layout:  [K/group_size, N] fp16 (converted to fp32 at load time)
 * qzeros:         not needed (symmetric, zero_point = 8 for all)
 **/
#ifndef CPUINFER_OPERATOR_AVX2_GPTQ_INT4_DEQUANT_H
#define CPUINFER_OPERATOR_AVX2_GPTQ_INT4_DEQUANT_H

#include <immintrin.h>

#include <cstdint>

namespace avx2 {

// Dequantize 8 x 4-bit values from a packed int32 (symmetric, zero_point=8)
// packed_weight contains 8 nibbles: bits [0:3]=val0, [4:7]=val1, ..., [28:31]=val7
// Result: ((nibble - 8) * scale) for each of the 8 values
static inline __m256 gptq_sym_dequant_8x4bit(uint32_t packed_weight, float scale) {
  // Variable shift: extract each 4-bit nibble into its own 32-bit lane
  // GPTQ packing: bit 0-3 = k_offset 0, bit 4-7 = k_offset 1, ...
  // _mm256_set_epi32 sets lanes in reverse order (lane 7 first), so:
  //   lane 0 = shift 0 (k_offset 0), lane 1 = shift 4 (k_offset 1), ...
  const __m256i shifts = _mm256_set_epi32(28, 24, 20, 16, 12, 8, 4, 0);
  __m256i packed_v = _mm256_set1_epi32(packed_weight);
  __m256i nibbles = _mm256_and_si256(_mm256_srlv_epi32(packed_v, shifts), _mm256_set1_epi32(0xF));

  // (nibble - 8) * scale
  __m256 w = _mm256_cvtepi32_ps(nibbles);
  return _mm256_mul_ps(_mm256_sub_ps(w, _mm256_set1_ps(8.0f)), _mm256_set1_ps(scale));
}

// Scalar version for verification
static inline float gptq_sym_dequant_scalar(uint32_t packed_weight, int k_in_pack, float scale) {
  int nibble = (packed_weight >> (k_in_pack * 4)) & 0xF;
  return (float)(nibble - 8) * scale;
}

}  // namespace avx2

#endif  // CPUINFER_OPERATOR_AVX2_GPTQ_INT4_DEQUANT_H
kt-kernel/operators/avx2/moe_base.hpp (new file, 580 lines)
@@ -0,0 +1,580 @@
/**
 * @Description : AVX2 MoE base class (ported from amx/moe_base.hpp)
 * @Author      : Claude
 * @Date        : 2026-03-18
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * All AVX512 intrinsics (__m512, _mm512_*) replaced with AVX2 (__m256, _mm256_*).
 * AMX tile configuration calls (T::config()) are kept but are no-ops.
 **/
#ifndef CPUINFER_OPERATOR_AVX2_MOE_BASE_H
#define CPUINFER_OPERATOR_AVX2_MOE_BASE_H

#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "avx2_bf16_gemm.hpp"
#include "avx2_bf16_utils.hpp"
#include "llama.cpp/ggml.h"

template <class T, class Derived>
class AVX2_MOE_BASE {
 public:
  int tp_part_idx = 0;

  ggml_bf16_t* m_local_input_ = nullptr;
  ggml_bf16_t* m_local_gate_output_ = nullptr;
  ggml_bf16_t* m_local_up_output_ = nullptr;
  ggml_bf16_t* m_local_down_output_ = nullptr;

  std::vector<std::vector<int>> m_local_pos_;
  std::vector<int> m_local_num_;
  std::vector<int> m_expert_id_map_;
  std::vector<ggml_bf16_t*> m_local_input_ptr_;
  std::vector<ggml_bf16_t*> m_local_gate_output_ptr_;
  std::vector<ggml_bf16_t*> m_local_up_output_ptr_;
  std::vector<ggml_bf16_t*> m_local_down_output_ptr_;

  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;

  size_t pool_count_ = 0;
  size_t gate_up_ba_pool_bytes_ = 0;
  size_t gate_bc_pool_bytes_ = 0;
  size_t up_bc_pool_bytes_ = 0;
  size_t down_ba_pool_bytes_ = 0;
  size_t down_bc_pool_bytes_ = 0;
  void* gate_up_ba_pool_ = nullptr;
  void* gate_bc_pool_ = nullptr;
  void* up_bc_pool_ = nullptr;
  void* down_ba_pool_ = nullptr;
  void* down_bc_pool_ = nullptr;

  GeneralMOEConfig config_;
  using input_t = ggml_bf16_t;
  using output_t = float;
  static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;

  AVX2_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) {
    init();
    derived()->derived_init();
  }

  void init() {
    if (config_.load && config_.path == "") {
      config_.load = false;
    }

    MemoryRequest mem_requests;
    mem_requests.append_pointer(
        &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
    mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                           config_.max_len * config_.intermediate_size);
    mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                         config_.max_len * config_.intermediate_size);
    mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                           config_.max_len * config_.hidden_size);

    m_local_pos_.resize(config_.max_len);
    for (int i = 0; i < config_.max_len; i++) {
      m_local_pos_[i].resize(config_.num_experts_per_tok);
    }
    m_expert_id_map_.resize(config_.expert_num);
    m_local_num_.resize(config_.expert_num);
    m_local_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);

    for (size_t i = 0; i < config_.expert_num; i++) {
      gate_up_ba_.push_back(make_buffer_a(config_.max_len, config_.hidden_size, nullptr));
      gate_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
      up_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
      down_ba_.push_back(make_buffer_a(config_.max_len, config_.intermediate_size, nullptr));
      down_bc_.push_back(make_buffer_c(config_.max_len, config_.hidden_size, nullptr));

      void* gate_bb_ptr =
          std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
      gate_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));

      void* up_bb_ptr = std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
      up_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, up_bb_ptr));

      void* down_bb_ptr =
          std::aligned_alloc(64, buffer_b_required_size(config_.hidden_size, config_.intermediate_size));
      down_bb_.push_back(make_buffer_b(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
    }

    pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::M_STEP;

    gate_up_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
    gate_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
    up_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
    down_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
    down_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;

    mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_);
    mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_);
    mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_);
    mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_);
    mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_);

    shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
  }

  ~AVX2_MOE_BASE() = default;

  void warm_up() {
    int qlen = config_.max_len;
    std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
    std::vector<float> weights(qlen * config_.num_experts_per_tok);
    for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
      expert_ids[i] = i % config_.expert_num;
      weights[i] = 0.01;
    }
    forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
  }

  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
    if (qlen > 1) {
      forward_prefill(qlen, k, expert_ids, weights, input, output);
    } else {
      forward_decode(k, expert_ids, weights, input, output);
    }
  }

  template <typename... Args>
  void load_weights(Args&&... args) {
    derived()->load_weights(std::forward<Args>(args)...);
  }

  template <typename... Args>
  void write_weights_to_buffer(Args&&... args) const {
    derived_const()->write_weights_to_buffer(std::forward<Args>(args)...);
  }

  void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
                       void* output) {
    auto pool = config_.pool->get_subpool(tp_part_idx);

    int activated_expert = 0;
    std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
    for (int i = 0; i < qlen; i++) {
      for (int j = 0; j < k; j++) {
        if (config_.should_skip_expert(expert_ids[i * k + j])) {
          continue;
        }
        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
      }
    }

    for (int i = 0; i < config_.expert_num; i++) {
      if (m_local_num_[i] > 0) {
        m_expert_id_map_[activated_expert] = i;
        activated_expert++;
      }
    }

    // Assign pool memory to buffers
    size_t offset = 0;
    void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
    void* gate_bc_pool_ptr = gate_bc_pool_;
    void* up_bc_pool_ptr = up_bc_pool_;
    void* down_ba_pool_ptr = down_ba_pool_;
    void* down_bc_pool_ptr = down_bc_pool_;
    constexpr size_t M_STEP = T::M_STEP;
    auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };

    for (int i = 0; i < config_.expert_num; i++) {
      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
      offset += m_local_num_[i];

      if (m_local_num_[i] == 0) continue;

      size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP;
      gate_up_ba_[i]->max_m = max_m;
      gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr);
      gate_up_ba_pool_ptr =
          (void*)((uintptr_t)gate_up_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.hidden_size)));

      gate_bc_[i]->max_m = max_m;
      gate_bc_[i]->set_data(gate_bc_pool_ptr);
      gate_bc_pool_ptr =
          (void*)((uintptr_t)gate_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));

      up_bc_[i]->max_m = max_m;
      up_bc_[i]->set_data(up_bc_pool_ptr);
      up_bc_pool_ptr =
          (void*)((uintptr_t)up_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));

      down_ba_[i]->max_m = max_m;
      down_ba_[i]->set_data(down_ba_pool_ptr);
      down_ba_pool_ptr =
          (void*)((uintptr_t)down_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.intermediate_size)));

      down_bc_[i]->max_m = max_m;
      down_bc_[i]->set_data(down_bc_pool_ptr);
      down_bc_pool_ptr =
          (void*)((uintptr_t)down_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.hidden_size)));
    }

    auto direct_or_pool = [&](int count, auto&& fn) {
      if (qlen < 10) {
        for (int i = 0; i < count; i++) fn(i);
      } else {
        pool->do_work_stealing_job(count, nullptr, fn, nullptr);
      }
    };

    // Copy input to per-expert buffers
    direct_or_pool(qlen, [&](int i) {
      for (int j = 0; j < k; j++) {
        if (config_.should_skip_expert(expert_ids[i * k + j])) continue;
        memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
               (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
      }
    });

    // Pack input into BufferA (trivial memcpy for AVX2)
    direct_or_pool(activated_expert, [this](int task_id) {
      int expert_idx = m_expert_id_map_[task_id];
      gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
    });

    // Gate + Up GEMM
    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * activated_expert * 2, [](int) { T::config(); },
        [this, nth, qlen](int task_id2) {
          int task_id = task_id2 / 2;
          bool do_up = task_id2 % 2;
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
          if (do_up) {
            up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
          } else {
            gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
          }
        },
        nullptr);

    // Activation: SiLU(gate) * up — AVX2 version (8 elements at a time)
    apply_activation(activated_expert, nth, qlen);

    // Pack activation output into BufferA for down projection
    pool->do_work_stealing_job(
        activated_expert, nullptr,
        [this](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);

    // Down GEMM
    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * activated_expert, [](int) { T::config(); },
        [this, nth, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          derived()->do_down_gemm(expert_idx, ith, nth, qlen);
          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);

    // Weighted sum of expert outputs — AVX2 version (16 BF16 = 2x8 FP32 at a time)
    pool->do_work_stealing_job(
        qlen, nullptr,
        [this, output, k, expert_ids, weights](int i) {
          for (int e = 0; e < config_.hidden_size; e += 16) {
            __m256 x0 = _mm256_setzero_ps();
            __m256 x1 = _mm256_setzero_ps();
            for (int j = 0; j < k; j++) {
              if (config_.should_skip_expert(expert_ids[i * k + j])) continue;
              __m256 weight = _mm256_set1_ps(weights[i * k + j]);
              __m256 d0, d1;
              avx2::load_16xbf16_to_2x8xfp32(
                  m_local_down_output_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size + e,
                  &d0, &d1);
              x0 = _mm256_fmadd_ps(d0, weight, x0);
              x1 = _mm256_fmadd_ps(d1, weight, x1);
            }
            auto f32out = (__m256*)((float*)output + i * config_.hidden_size + e);
            f32out[0] = x0;
            f32out[1] = x1;
          }
        },
        nullptr);
  }

  void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
    int qlen = 1;
    auto pool = config_.pool->get_subpool(tp_part_idx);

    int activated_expert = 0;
    std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
    for (int i = 0; i < k; i++) {
      if (config_.should_skip_expert(expert_ids[i])) continue;
      m_expert_id_map_[activated_expert] = expert_ids[i];
      m_local_pos_[0][i] = 0;
      m_local_num_[expert_ids[i]] = qlen;
      activated_expert++;
    }

    size_t offset = 0;
    for (int i = 0; i < activated_expert; i++) {
      auto expert_idx = m_expert_id_map_[i];
      m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
      offset += qlen;
    }

    // Assign pool memory for decode
    void* gate_bc_pool_ptr = gate_bc_pool_;
    void* up_bc_pool_ptr = up_bc_pool_;
    void* down_ba_pool_ptr = down_ba_pool_;
    void* down_bc_pool_ptr = down_bc_pool_;
    constexpr size_t M_STEP = T::M_STEP;
    auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };

    for (int i = 0; i < activated_expert; i++) {
      auto expert_idx = m_expert_id_map_[i];
      size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;

      gate_bc_[expert_idx]->max_m = max_m;
      gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);
      gate_bc_pool_ptr =
          (void*)((uintptr_t)gate_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));

      up_bc_[expert_idx]->max_m = max_m;
      up_bc_[expert_idx]->set_data(up_bc_pool_ptr);
      up_bc_pool_ptr =
          (void*)((uintptr_t)up_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.intermediate_size)));

      down_ba_[expert_idx]->max_m = max_m;
      down_ba_[expert_idx]->set_data(down_ba_pool_ptr);
      down_ba_pool_ptr =
          (void*)((uintptr_t)down_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.intermediate_size)));

      down_bc_[expert_idx]->max_m = max_m;
      down_bc_[expert_idx]->set_data(down_bc_pool_ptr);
      down_bc_pool_ptr =
          (void*)((uintptr_t)down_bc_pool_ptr + align64(buffer_c_required_size(max_m, config_.hidden_size)));
    }

    // Pack input into BufferA for each activated expert
    void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
    for (int i = 0; i < activated_expert; i++) {
      auto expert_idx = m_expert_id_map_[i];
      size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
      gate_up_ba_[expert_idx]->max_m = max_m;
      gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);
      gate_up_ba_pool_ptr =
          (void*)((uintptr_t)gate_up_ba_pool_ptr + align64(buffer_a_required_size(max_m, config_.hidden_size)));
      gate_up_ba_[expert_idx]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
    }

    // Gate + Up GEMM
    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * activated_expert * 2, [](int) { T::config(); },
        [this, nth, qlen](int task_id2) {
          int task_id = task_id2 / 2;
          bool do_up = task_id2 % 2;
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
          if (do_up) {
            up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
          } else {
            gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
          }
        },
        nullptr);

    // Activation
    apply_activation(activated_expert, nth, qlen);

    // Pack for down projection
    pool->do_work_stealing_job(
        activated_expert, nullptr,
        [this, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);

    // Down GEMM
    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * activated_expert, [](int) { T::config(); },
        [this, nth, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          derived()->do_down_gemm(expert_idx, ith, nth, qlen);
          down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);

    // Weighted sum — AVX2 (16 BF16 at a time)
    for (int e = 0; e < config_.hidden_size; e += 16) {
      __m256 x0 = _mm256_setzero_ps();
      __m256 x1 = _mm256_setzero_ps();
      for (int j = 0; j < k; j++) {
        if (config_.should_skip_expert(expert_ids[j])) continue;
        __m256 weight = _mm256_set1_ps(weights[j]);
        __m256 d0, d1;
        avx2::load_16xbf16_to_2x8xfp32(
            m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e, &d0, &d1);
        x0 = _mm256_fmadd_ps(d0, weight, x0);
        x1 = _mm256_fmadd_ps(d1, weight, x1);
      }
      auto f32out = (__m256*)((float*)output + e);
      f32out[0] = x0;
      f32out[1] = x1;
    }
  }

 protected:
  Derived* derived() { return static_cast<Derived*>(this); }
  const Derived* derived_const() const { return static_cast<const Derived*>(this); }

  void derived_init() {}

  // Buffer creation/size delegation (CRTP)
  size_t buffer_a_required_size(size_t m, size_t k) const { return derived_const()->buffer_a_required_size_impl(m, k); }
  size_t buffer_b_required_size(size_t n, size_t k) const { return derived_const()->buffer_b_required_size_impl(n, k); }
  size_t buffer_c_required_size(size_t m, size_t n) const { return derived_const()->buffer_c_required_size_impl(m, n); }

  std::shared_ptr<typename T::BufferA> make_buffer_a(size_t m, size_t k, void* data) const {
    return derived_const()->make_buffer_a_impl(m, k, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b(size_t n, size_t k, void* data) const {
    return derived_const()->make_buffer_b_impl(n, k, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c(size_t m, size_t n, void* data) const {
    return derived_const()->make_buffer_c_impl(m, n, data);
  }

  // SiLU activation — AVX2: process 8 BF16 elements at a time
  void apply_activation(int activated_expert, int nth, int qlen) {
    auto pool = config_.pool->get_subpool(tp_part_idx);
    auto fn = [this, nth](int task_id) {
      int expert_idx = m_expert_id_map_[task_id / nth];
      int ith = task_id % nth;
      auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
      for (int i = 0; i < m_local_num_[expert_idx]; i++) {
        ggml_bf16_t* gate_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
        ggml_bf16_t* up_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
        int j = n_start;
        for (; j + 8 <= n_end; j += 8) {
          __m256 gate_val = avx2::load_bf16_to_fp32(gate_ptr + j);
          __m256 up_val = avx2::load_bf16_to_fp32(up_ptr + j);
          __m256 result = avx2::act_fn(gate_val, up_val);
          avx2::store_fp32_to_bf16(gate_ptr + j, result);
        }
        // Scalar tail
        for (; j < n_end; j++) {
          float g = GGML_BF16_TO_FP32(gate_ptr[j]);
          float u = GGML_BF16_TO_FP32(up_ptr[j]);
          float sigmoid_g = 1.0f / (1.0f + expf(-g));
          gate_ptr[j] = GGML_FP32_TO_BF16(g * sigmoid_g * u);
        }
      }
    };

    if (activated_expert == 0) return;

    if (qlen < 10) {
      for (int task_id = 0; task_id < nth * activated_expert; task_id++) fn(task_id);
    } else {
      pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr);
    }
  }
};

// ============================================================================
// TP_MOE specialization for AVX2_MOE_BASE derived classes
// ============================================================================

template <class T, class Derived>
class TP_MOE<AVX2_MOE_BASE<T, Derived>> : public TP_MOE_Common<AVX2_MOE_BASE<T, Derived>> {
 public:
  using TP_MOE_Common<AVX2_MOE_BASE<T, Derived>>::TP_MOE_Common;

  void load_weights() override { throw std::runtime_error("Not Implemented"); }

  void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
                                    const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    throw std::runtime_error("Not Implemented");
  }

  void merge_results(int qlen, void* output, bool incremental) override {
    auto& config = this->config;
    auto& tp_count = this->tp_count;
    auto& local_output_numa = this->local_output_numa;
    auto& tp_configs = this->tp_configs;

    auto merge_fn = [this, output, incremental, &config, &tp_count, &local_output_numa, &tp_configs](int token_nth) {
      float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
      if (incremental) {
        // Convert BF16 output to FP32 and add — AVX2 (16 BF16 at a time)
        for (int e = 0; e < config.hidden_size; e += 16) {
          __m256 x0, x1;
          avx2::load_16xbf16_to_2x8xfp32((ggml_bf16_t*)output + token_nth * config.hidden_size + e, &x0, &x1);
          *((__m256*)(merge_to + e)) = _mm256_add_ps(*((__m256*)(merge_to + e)), x0);
          *((__m256*)(merge_to + e + 8)) = _mm256_add_ps(*((__m256*)(merge_to + e + 8)), x1);
        }
      }
      // Sum across TP parts
      for (int i = 1; i < tp_count; i++) {
        float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
        for (int e = 0; e < tp_configs[i].hidden_size; e += 8) {
          *((__m256*)(merge_to + e)) = _mm256_add_ps(*((__m256*)(merge_to + e)), *((__m256*)(merge_from + e)));
        }
      }
      // Convert FP32 -> BF16 output
      for (int e = 0; e < config.hidden_size; e += 16) {
        __m256 x0 = *(__m256*)(merge_to + e);
        __m256 x1 = *(__m256*)(merge_to + e + 8);
        avx2::store_2x8xfp32_to_16xbf16(&x0, &x1, (ggml_bf16_t*)output + token_nth * config.hidden_size + e);
      }
    };

    auto pool = config.pool;
    if (qlen < 10) {
      for (int i = 0; i < qlen; i++) merge_fn(i);
    } else {
      pool->do_work_stealing_job(qlen, nullptr, merge_fn, nullptr);
    }
  }

  void merge_results(int qlen, void* output) override { merge_results(qlen, output, false); }
};

#endif  // CPUINFER_OPERATOR_AVX2_MOE_BASE_H
|
||||
@@ -93,7 +93,7 @@ class KTMoEWrapper:
|
||||
# Select backend based on method
|
||||
if method in ["AMXINT4", "AMXINT8"]:
|
||||
backend_cls = AMXMoEWrapper
|
||||
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL"]:
|
||||
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4"]:
|
||||
backend_cls = NativeMoEWrapper
|
||||
elif method == "LLAMAFILE":
|
||||
backend_cls = LlamafileMoEWrapper
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader, GPTQSafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
import kt_kernel_ext.moe as _moe_mod
|
||||
|
||||
@@ -15,6 +15,9 @@ AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
|
||||
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
|
||||
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
|
||||
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
|
||||
AVX2BF16_MOE = getattr(_moe_mod, "AVX2BF16_MOE", None)
|
||||
AVX2FP8_MOE = getattr(_moe_mod, "AVX2FP8_MOE", None)
|
||||
AVX2GPTQInt4_MOE = getattr(_moe_mod, "AVX2GPTQInt4_MOE", None)
|
||||
|
||||
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
|
||||
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
|
||||
@@ -22,6 +25,9 @@ _HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
|
||||
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
|
||||
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
|
||||
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
|
||||
_HAS_AVX2_BF16_SUPPORT = AVX2BF16_MOE is not None
|
||||
_HAS_AVX2_FP8_SUPPORT = AVX2FP8_MOE is not None
|
||||
_HAS_AVX2_GPTQ_INT4_SUPPORT = AVX2GPTQInt4_MOE is not None
|
||||
|
||||
|
||||
class AMXMoEWrapper(BaseMoEWrapper):
|
||||
@@ -346,10 +352,11 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 "    - AVX512F + AVX512BW (VNNI optional)\n"
                 "Please recompile kt_kernel_ext with AVX512 enabled."
             )
-        if method == "FP8" and not _HAS_FP8_SUPPORT:
+        if method == "FP8" and not _HAS_FP8_SUPPORT and not _HAS_AVX2_FP8_SUPPORT:
             raise RuntimeError(
                 "FP8 backend not available. Required ISA:\n"
-                "    - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
+                "    - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI (for AMX), or\n"
+                "    - AVX2 + FMA (for AVX2 fallback)\n"
                 "Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
             )
         if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
@@ -358,11 +365,17 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 "    - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
                 "Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
             )
-        if method == "BF16" and not _HAS_BF16_SUPPORT:
+        if method == "BF16" and not _HAS_BF16_SUPPORT and not _HAS_AVX2_BF16_SUPPORT:
             raise RuntimeError(
                 "BF16 backend not available. Required ISA:\n"
-                "    - AVX512F + AVX512BW + AVX512_BF16\n"
-                "Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
+                "    - AVX512F + AVX512BW + AVX512_BF16 (for AMX backend), or\n"
+                "    - AVX2 + FMA (for AVX2 fallback backend)\n"
+                "Please recompile kt_kernel_ext with AVX512+BF16 or AVX2 enabled."
             )
+        if method == "GPTQ_INT4" and not _HAS_AVX2_GPTQ_INT4_SUPPORT:
+            raise RuntimeError(
+                "GPTQ_INT4 backend not available.\n"
+                "Please recompile kt_kernel_ext with AVX2 enabled."
+            )
 
         super().__init__(
@@ -391,6 +404,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
             NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
         elif method == "BF16":
             NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
+        elif method == "GPTQ_INT4":
+            NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
         else:
             raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
         self.loader = NativeMoEWrapper._native_loader_instance
@@ -506,15 +521,31 @@ class NativeMoEWrapper(BaseMoEWrapper):
             moe_config.quant_config.bits = 8
             moe_config.quant_config.group_size = 128
             moe_config.quant_config.zero_point = False
-            self.moe = AMXFP8_MOE(moe_config)
+            if _HAS_FP8_SUPPORT:
+                self.moe = AMXFP8_MOE(moe_config)
+            else:
+                self.moe = AVX2FP8_MOE(moe_config)
         elif self.method == "FP8_PERCHANNEL":
             moe_config.quant_config.bits = 8
             moe_config.quant_config.per_channel = True
             moe_config.quant_config.zero_point = False
             self.moe = AMXFP8PerChannel_MOE(moe_config)
+        elif self.method == "GPTQ_INT4":
+            # GPTQ symmetric INT4: qweight (int32) + scales (fp32)
+            group_size = self.gate_scales[0].shape[0]  # scales shape [K/gs, N], first dim = num_groups
+            # hidden_size / num_groups = group_size
+            actual_gs = self.hidden_size // group_size
+            moe_config.quant_config.bits = 4
+            moe_config.quant_config.group_size = actual_gs
+            moe_config.quant_config.zero_point = False
+            self.moe = AVX2GPTQInt4_MOE(moe_config)
         elif self.method == "BF16":
             # BF16 has no quantization config needed
-            self.moe = AMXBF16_MOE(moe_config)
+            # Prefer AMX backend, fall back to AVX2
+            if _HAS_BF16_SUPPORT:
+                self.moe = AMXBF16_MOE(moe_config)
+            else:
+                self.moe = AVX2BF16_MOE(moe_config)
         t4 = time.time()
 
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
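The GPTQ branch above infers the group size from the scales tensor shape rather than trusting checkpoint metadata. A small standalone sketch of that arithmetic (the sizes are hypothetical, chosen to mirror the `[K/gs, N]` layout noted in the comment):

```python
import torch

hidden_size = 256   # K: model hidden dim (hypothetical)
gs = 64             # GPTQ group size used at quantization time (hypothetical)
n = 512             # output features

# GPTQ stores one scale per (group, output-column) pair: shape [K/gs, N]
scales = torch.ones(hidden_size // gs, n, dtype=torch.float32)

num_groups = scales.shape[0]            # K/gs
actual_gs = hidden_size // num_groups   # recovers the group size
print(actual_gs)  # 64
```

Deriving the value from the tensor itself keeps the kernel config consistent even if `config.json` were missing or stale.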
@@ -961,3 +961,120 @@ class GGUFLoader:
         data = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8).copy())
 
         return data, ggml_type
+
+
+class GPTQSafeTensorLoader(FP8SafeTensorLoader):
+    """Loader for symmetric GPTQ-Int4 expert weights (qweight + scales, no qzeros).
+
+    Only supports sym=true, desc_act=false GPTQ models.
+
+    Tensor keys:
+        - qweight: {prefix}.{id}.{proj}.qweight (int32, packed 8x4-bit along K)
+        - scales:  {prefix}.{id}.{proj}.scales  (fp16 -> converted to fp32)
+    """
+
+    def __init__(self, file_path: str):
+        # Call FP8SafeTensorLoader init (which calls SafeTensorLoader init + format detection)
+        super().__init__(file_path, scale_suffix="scales")
+        # Verify the GPTQ config
+        self._verify_gptq_config(file_path)
+
+    def _detect_format(self):
+        """Override FP8 format detection to look for .qweight instead of .weight."""
+        sample_keys = list(self.tensor_file_map.keys())[:2000]
+
+        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
+            for key in sample_keys:
+                if ".experts." in key and f".{gate}.qweight" in key:
+                    if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
+                        self._detected_format = fmt_name
+                        break
+                    elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
+                        self._detected_format = fmt_name
+                        # Check for VL model (language_model prefix)
+                        if "language_model." in key:
+                            self._is_vl_model = True
+                        break
+                    elif fmt_name == "mistral" and "block_sparse_moe" not in key and "mlp" not in key:
+                        self._detected_format = fmt_name
+                        break
+            if self._detected_format is not None:
+                break
+
+        if self._detected_format is None:
+            self._detected_format = "deepseek"
+
+        vl_str = " (VL model)" if self._is_vl_model else ""
+        print(f"[GPTQSafeTensorLoader] Detected format: {self._detected_format}{vl_str}")
+
+    def _verify_gptq_config(self, file_path):
+        """Check that the model uses sym=true, desc_act=false."""
+        import json
+        import os
+
+        config_path = os.path.join(os.path.dirname(file_path), "config.json")
+        if not os.path.exists(config_path):
+            # Fall back to treating file_path itself as the model directory
+            config_path = os.path.join(file_path, "config.json")
+        if os.path.exists(config_path):
+            with open(config_path) as f:
+                config = json.load(f)
+            qc = config.get("quantization_config", {})
+            if qc.get("quant_method") == "gptq":
+                if qc.get("desc_act", False):
+                    raise NotImplementedError(
+                        "GPTQ desc_act=true is not supported. Only desc_act=false models are supported."
+                    )
+                if not qc.get("sym", True):
+                    raise NotImplementedError(
+                        "GPTQ sym=false (asymmetric) is not supported. Only sym=true models are supported."
+                    )
+                print(f"[GPTQSafeTensorLoader] Verified: sym={qc.get('sym')}, desc_act={qc.get('desc_act')}, "
+                      f"bits={qc.get('bits')}, group_size={qc.get('group_size')}")
+
+    def load_experts(self, base_key: str, device: str = "cpu"):
+        """Load GPTQ expert qweight and scales.
+
+        Returns a dict with keys: gate, up, down (qweight int32) and
+        gate_scale, up_scale, down_scale (fp32).
+        """
+        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)
+        gate_name, up_name, down_name = self._get_proj_names()
+
+        expert_count = 0
+        experts_prefix = None
+        for prefix in experts_prefix_candidates:
+            expert_count = 0
+            while self.has_tensor(f"{prefix}.{expert_count}.{gate_name}.qweight"):
+                expert_count += 1
+            if expert_count > 0:
+                experts_prefix = prefix
+                break
+
+        if expert_count == 0 or experts_prefix is None:
+            raise ValueError(f"No GPTQ experts found for keys: {experts_prefix_candidates}")
+
+        gate_weights = [None] * expert_count
+        up_weights = [None] * expert_count
+        down_weights = [None] * expert_count
+        gate_scales = [None] * expert_count
+        up_scales = [None] * expert_count
+        down_scales = [None] * expert_count
+
+        for exp_id in range(expert_count):
+            gate_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{gate_name}.qweight", device).contiguous()
+            up_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{up_name}.qweight", device).contiguous()
+            down_weights[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{down_name}.qweight", device).contiguous()
+
+            gate_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{gate_name}.scales", device).float().contiguous()
+            up_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{up_name}.scales", device).float().contiguous()
+            down_scales[exp_id] = self.load_tensor(f"{experts_prefix}.{exp_id}.{down_name}.scales", device).float().contiguous()
+
+        print(f"[GPTQSafeTensorLoader] Loaded {expert_count} experts from {experts_prefix}")
+        return {
+            "gate": gate_weights,
+            "up": up_weights,
+            "down": down_weights,
+            "gate_scale": gate_scales,
+            "up_scale": up_scales,
+            "down_scale": down_scales,
+        }
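The docstring above says `qweight` packs eight 4-bit values per int32 along K. The kernel-side unpacking is not shown in this diff, but the layout can be illustrated with a pack/unpack round trip (the row-interleaving below is an assumption based on the common GPTQ convention, not taken from kt-kernel itself):

```python
import torch

K, N = 16, 4
codes = torch.randint(0, 16, (K, N), dtype=torch.int32)  # unsigned 4-bit codes

# Pack: eight consecutive K-rows share one int32 row, one nibble each.
packed = torch.zeros(K // 8, N, dtype=torch.int32)
for i in range(8):
    packed |= codes[i::8] << (4 * i)

# Unpack by shifting each nibble back out and masking off the rest.
unpacked = torch.zeros_like(codes)
for i in range(8):
    unpacked[i::8] = (packed >> (4 * i)) & 0xF

assert torch.equal(unpacked, codes)  # lossless round trip
```

The `& 0xF` mask also discards the sign-extension bits that an arithmetic right shift produces when nibble 7 set the int32 sign bit.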
@@ -579,7 +579,7 @@ class CMakeBuild(build_ext):
                 avx512_extension_enabled = True
 
         # If any AVX512 extension is enabled, ensure base AVX512 is also enabled
-        if avx512_extension_enabled and cpu_mode in ("NATIVE", "FANCY", "AVX512"):
+        if avx512_extension_enabled and cpu_mode in ("NATIVE", "FANCY", "AVX512") and "AVX512" in d["features"]:
             if not any("LLAMA_AVX512=ON" in a for a in cmake_args):
                 cmake_args.append("-DLLAMA_AVX512=ON")
                 print("-- AVX512 extensions enabled; also enabling base AVX512F (-DLLAMA_AVX512=ON)")
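The condition added above is the heart of the AVX2-only fix: `-DLLAMA_AVX512=ON` must only be injected when the detected feature set actually contains AVX512. A standalone sketch of the guard (the shape of `d["features"]` is an assumption inferred from the diff context):

```python
# Hypothetical CPU-detection result on an AVX2-only machine
d = {"features": ["AVX2", "FMA"]}
cpu_mode = "NATIVE"
avx512_extension_enabled = True
cmake_args = []

# Only enable base AVX512 when the CPU actually reports it
if (avx512_extension_enabled
        and cpu_mode in ("NATIVE", "FANCY", "AVX512")
        and "AVX512" in d["features"]):
    if not any("LLAMA_AVX512=ON" in a for a in cmake_args):
        cmake_args.append("-DLLAMA_AVX512=ON")

print(cmake_args)  # [] -- the flag is not injected on AVX2-only hardware
```

Without the `"AVX512" in d["features"]` clause, a `NATIVE` build on AVX2-only hardware would still receive the flag and fail at compile or run time.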
159  kt-kernel/test/per_commit/test_moe_avx2_accuracy_bf16.py  Normal file
@@ -0,0 +1,159 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""AVX2 BF16 MoE accuracy tests for KT-Kernel.
+
+Tests accuracy of AVX2 BF16 MOE operations against a torch reference.
+"""
+
+import os
+import sys
+import time
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+
+import torch
+from kt_kernel import kt_kernel_ext
+
+# Small test parameters for fast validation
+expert_num = 8
+hidden_size = 256
+intermediate_size = 512
+num_experts_per_tok = 2
+max_len = 128
+validation_iter = 3
+CPUINFER_PARAM = 60
+
+
+def act_fn(x):
+    """SiLU activation."""
+    return x / (1.0 + torch.exp(-x))
+
+
+def mlp_torch(input, gate_proj, up_proj, down_proj):
+    """PyTorch reference MLP."""
+    gate_buf = torch.mm(input, gate_proj.t())
+    up_buf = torch.mm(input, up_proj.t())
+    intermediate = act_fn(gate_buf) * up_buf
+    return torch.mm(intermediate, down_proj.t())
+
+
+def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
+    """PyTorch reference MoE."""
+    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
+    cnts.scatter_(1, expert_ids, 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = expert_ids.view(-1).argsort()
+    sorted_tokens = input[idxs // expert_ids.shape[1]]
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+    new_x[idxs] = outs
+    t_output = (
+        new_x.view(*expert_ids.shape, -1)
+        .type(weights.dtype)
+        .mul_(weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return t_output
+
+
+def test_avx2_bf16_accuracy(qlen, label):
+    """Test AVX2 BF16 MoE accuracy."""
+    physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()
+    CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
+
+    with torch.inference_mode():
+        # Generate BF16 weights
+        gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16).contiguous()
+        up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16).contiguous()
+        down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 10.0).to(torch.bfloat16).contiguous()
+
+        # Create MOE config
+        config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
+        config.max_len = max_len
+        config.gate_proj = gate_proj.data_ptr()
+        config.up_proj = up_proj.data_ptr()
+        config.down_proj = down_proj.data_ptr()
+        config.gate_scale = 0
+        config.up_scale = 0
+        config.down_scale = 0
+        config.pool = CPUInfer.backend_
+
+        # Create AVX2 BF16 MOE
+        moe = kt_kernel_ext.moe.AVX2BF16_MOE(config)
+        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
+        CPUInfer.sync()
+
+        print(f"\n--- {label} (qlen={qlen}) ---")
+        for i in range(validation_iter):
+            expert_ids = torch.stack(
+                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
+            ).contiguous()
+            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
+            input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()
+            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
+
+            bsz_tensor = torch.tensor([qlen], dtype=torch.int32)
+
+            # Run AVX2 BF16 MOE
+            CPUInfer.submit(
+                moe.forward_task(
+                    bsz_tensor.data_ptr(),
+                    num_experts_per_tok,
+                    expert_ids.data_ptr(),
+                    weights.data_ptr(),
+                    input_data.data_ptr(),
+                    output.data_ptr(),
+                    False,
+                )
+            )
+            CPUInfer.sync()
+
+            # Run torch reference (in float32 for accuracy)
+            t_output = moe_torch(
+                input_data.float(), expert_ids, weights,
+                gate_proj.float(), up_proj.float(), down_proj.float()
+            ).to(torch.bfloat16)
+
+            # Calculate relative difference
+            diff = torch.mean(torch.abs(output.float() - t_output.float())) / (torch.mean(torch.abs(t_output.float())) + 1e-8)
+            print(f"  Iteration {i}: diff = {diff:.6f}")
+
+            # BF16 should track the reference closely (< 0.02)
+            assert diff < 0.02, f"AVX2 BF16 accuracy test failed: diff={diff:.6f} >= 0.02"
+
+        print("  PASSED")
+
+
+if __name__ == "__main__":
+    print("=" * 60)
+    print("AVX2 BF16 MoE Accuracy Test")
+    print("=" * 60)
+
+    try:
+        # Test decode path (qlen=1)
+        test_avx2_bf16_accuracy(qlen=1, label="Decode")
+
+        # Test prefill path (qlen=16)
+        test_avx2_bf16_accuracy(qlen=16, label="Prefill")
+
+        print("\n" + "=" * 60)
+        print("ALL TESTS PASSED")
+        print("=" * 60)
+    except Exception as e:
+        print(f"\nTEST FAILED: {e}")
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
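Both accuracy tests use the same pass criterion: mean absolute error normalized by the mean reference magnitude, with a small epsilon guarding against division by zero. The metric in isolation, on toy values:

```python
import torch

output = torch.tensor([1.00, 2.02, 2.98])    # kernel result (toy values)
t_output = torch.tensor([1.00, 2.00, 3.00])  # reference

# Mean absolute error, normalized by the mean reference magnitude
diff = torch.mean(torch.abs(output - t_output)) / (torch.mean(torch.abs(t_output)) + 1e-8)
print(f"{diff.item():.6f}")  # ~0.006667, well under the 0.02 BF16 threshold
```

Normalizing by the reference magnitude makes the threshold meaningful regardless of the weight scale (/10.0) and input scale (/100.0) chosen by the tests.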
250  kt-kernel/test/per_commit/test_moe_avx2_accuracy_fp8.py  Normal file
@@ -0,0 +1,250 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""AVX2 FP8 MoE accuracy tests for KT-Kernel."""
+
+import os
+import sys
+import math
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+
+import torch
+from kt_kernel import kt_kernel_ext
+
+expert_num = 8
+hidden_size = 256
+intermediate_size = 512
+num_experts_per_tok = 2
+max_len = 128
+group_size = 128
+validation_iter = 3
+CPUINFER_PARAM = 60
+
+
+def fp8_e4m3_quantize(tensor_bf16):
+    """Quantize a BF16 tensor to FP8 E4M3 with block-wise scales (128x128)."""
+    n, k = tensor_bf16.shape
+    tensor_fp32 = tensor_bf16.float()
+
+    n_blocks_n = (n + group_size - 1) // group_size
+    n_blocks_k = (k + group_size - 1) // group_size
+
+    fp8_data = torch.zeros(n, k, dtype=torch.uint8)
+    scales = torch.zeros(n_blocks_n, n_blocks_k, dtype=torch.float32)
+
+    # FP8 E4M3 max finite value: 2^8 * (1 + 6/8) = 448
+    fp8_max = 448.0
+
+    for bn in range(n_blocks_n):
+        for bk in range(n_blocks_k):
+            n_start = bn * group_size
+            n_end = min(n_start + group_size, n)
+            k_start = bk * group_size
+            k_end = min(k_start + group_size, k)
+
+            block = tensor_fp32[n_start:n_end, k_start:k_end]
+            amax = block.abs().max().item()
+            if amax == 0:
+                scale = 1.0
+            else:
+                scale = amax / fp8_max
+            scales[bn, bk] = scale
+
+            # Quantize
+            for i in range(n_end - n_start):
+                for j in range(k_end - k_start):
+                    val = block[i, j].item() / scale
+                    fp8_data[n_start + i, k_start + j] = float_to_fp8_e4m3(val)
+
+    return fp8_data, scales
+
+
+def float_to_fp8_e4m3(val):
+    """Convert a float to FP8 E4M3."""
+    if math.isnan(val):
+        return 0x7F
+    sign = 1 if val < 0 else 0
+    val = abs(val)
+    if val == 0:
+        return sign << 7
+    # Clamp to the max finite value
+    if val >= 448.0:
+        return (sign << 7) | 0x7E
+    # Find exponent
+    exp = int(math.floor(math.log2(val))) + 7
+    if exp <= 0:
+        # Subnormal
+        man = int(round(val * (2 ** 6) * 8))
+        man = min(man, 7)
+        return (sign << 7) | man
+    # Normal: exp in [1, 15]; values in [256, 448) use exp == 15
+    man = int(round((val / (2 ** (exp - 7)) - 1.0) * 8))
+    if man == 8:  # mantissa rounded up into the next binade
+        man = 0
+        exp += 1
+    if exp > 15 or (exp == 15 and man > 6):
+        return (sign << 7) | 0x7E  # clamp to max finite
+    return (sign << 7) | (exp << 3) | man
+
+
+def fp8_e4m3_to_float(byte_val):
+    """Convert an FP8 E4M3 byte to float."""
+    sign = (byte_val >> 7) & 1
+    exp = (byte_val >> 3) & 0xF
+    man = byte_val & 0x7
+    if exp == 0 and man == 0:
+        return 0.0
+    if exp == 0:
+        # Subnormal
+        val = (2 ** -6) * (man / 8.0)
+    elif exp == 15 and man == 7:
+        # Only exponent 15 with mantissa 7 encodes NaN in E4M3
+        return float("nan")
+    else:
+        val = (2 ** (exp - 7)) * (1.0 + man / 8.0)
+    return -val if sign else val
+
+
+def fp8_dequantize(fp8_data, scales):
+    """Dequantize FP8 data + scales back to float32."""
+    n, k = fp8_data.shape
+    result = torch.zeros(n, k, dtype=torch.float32)
+
+    for i in range(n):
+        for j in range(k):
+            bn = i // group_size
+            bk = j // group_size
+            scale = scales[bn, bk].item()
+            fp8_val = fp8_e4m3_to_float(fp8_data[i, j].item())
+            result[i, j] = fp8_val * scale
+    return result
+
+
+def act_fn(x):
+    return x / (1.0 + torch.exp(-x))
+
+
+def mlp_torch(input, gate_proj, up_proj, down_proj):
+    gate_buf = torch.mm(input, gate_proj.t())
+    up_buf = torch.mm(input, up_proj.t())
+    intermediate = act_fn(gate_buf) * up_buf
+    return torch.mm(intermediate, down_proj.t())
+
+
+def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
+    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
+    cnts.scatter_(1, expert_ids, 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = expert_ids.view(-1).argsort()
+    sorted_tokens = input[idxs // expert_ids.shape[1]]
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens = sorted_tokens[start_idx:end_idx]
+        out = mlp_torch(tokens, gate_proj[i], up_proj[i], down_proj[i])
+        outputs.append(out)
+        start_idx = end_idx
+    outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+    new_x[idxs] = outs
+    return (new_x.view(*expert_ids.shape, -1).float().mul_(weights.unsqueeze(-1)).sum(1)).to(new_x.dtype)
+
+
+def test_avx2_fp8_accuracy(qlen, label):
+    physical_to_logical_map = torch.tensor(range(expert_num), dtype=torch.int64).contiguous()
+    CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
+
+    with torch.inference_mode():
+        # Generate BF16 weights, quantize to FP8
+        gate_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
+        up_bf16 = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
+        down_bf16 = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 10.0).to(torch.bfloat16)
+
+        # Quantize each expert
+        gate_fp8_list, gate_scale_list = [], []
+        up_fp8_list, up_scale_list = [], []
+        down_fp8_list, down_scale_list = [], []
+
+        for e in range(expert_num):
+            gf, gs = fp8_e4m3_quantize(gate_bf16[e])
+            gate_fp8_list.append(gf)
+            gate_scale_list.append(gs)
+            uf, us = fp8_e4m3_quantize(up_bf16[e])
+            up_fp8_list.append(uf)
+            up_scale_list.append(us)
+            df, ds = fp8_e4m3_quantize(down_bf16[e])
+            down_fp8_list.append(df)
+            down_scale_list.append(ds)
+
+        # Stack into contiguous tensors
+        gate_fp8 = torch.stack(gate_fp8_list).contiguous()
+        gate_scales = torch.stack(gate_scale_list).contiguous()
+        up_fp8 = torch.stack(up_fp8_list).contiguous()
+        up_scales = torch.stack(up_scale_list).contiguous()
+        down_fp8 = torch.stack(down_fp8_list).contiguous()
+        down_scales = torch.stack(down_scale_list).contiguous()
+
+        # Dequantize for reference computation
+        gate_deq = torch.stack([fp8_dequantize(gate_fp8_list[e], gate_scale_list[e]) for e in range(expert_num)])
+        up_deq = torch.stack([fp8_dequantize(up_fp8_list[e], up_scale_list[e]) for e in range(expert_num)])
+        down_deq = torch.stack([fp8_dequantize(down_fp8_list[e], down_scale_list[e]) for e in range(expert_num)])
+
+        # Create MOE config
+        config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
+        config.max_len = max_len
+        config.gate_proj = gate_fp8.data_ptr()
+        config.up_proj = up_fp8.data_ptr()
+        config.down_proj = down_fp8.data_ptr()
+        config.gate_scale = gate_scales.data_ptr()
+        config.up_scale = up_scales.data_ptr()
+        config.down_scale = down_scales.data_ptr()
+        config.quant_config.bits = 8
+        config.quant_config.group_size = group_size
+        config.quant_config.zero_point = False
+        config.pool = CPUInfer.backend_
+
+        moe = kt_kernel_ext.moe.AVX2FP8_MOE(config)
+        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
+        CPUInfer.sync()
+
+        print("\n--- %s (qlen=%d) ---" % (label, qlen))
+        for i in range(validation_iter):
+            expert_ids = torch.stack([torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]).contiguous()
+            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
+            input_data = (torch.randn((qlen, hidden_size), dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()
+            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
+
+            bsz_tensor = torch.tensor([qlen], dtype=torch.int32)
+            CPUInfer.submit(moe.forward_task(
+                bsz_tensor.data_ptr(), num_experts_per_tok,
+                expert_ids.data_ptr(), weights.data_ptr(),
+                input_data.data_ptr(), output.data_ptr(), False,
+            ))
+            CPUInfer.sync()
+
+            # Reference: use dequantized FP32 weights
+            t_output = moe_torch(input_data.float(), expert_ids, weights, gate_deq, up_deq, down_deq).to(torch.bfloat16)
+
+            diff = torch.mean(torch.abs(output.float() - t_output.float())) / (torch.mean(torch.abs(t_output.float())) + 1e-8)
+            print("  Iteration %d: diff = %.6f" % (i, diff.item()))
+            assert diff < 0.1, "FP8 accuracy test failed: diff=%.6f >= 0.1" % diff.item()
+
+        print("  PASSED")
+
+
+if __name__ == "__main__":
+    print("=" * 60)
+    print("AVX2 FP8 MoE Accuracy Test")
+    print("=" * 60)
+    try:
+        test_avx2_fp8_accuracy(qlen=1, label="Decode")
+        test_avx2_fp8_accuracy(qlen=16, label="Prefill")
+        print("\n" + "=" * 60)
+        print("ALL TESTS PASSED")
+        print("=" * 60)
+    except Exception as e:
+        print("\nTEST FAILED: %s" % e)
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
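The quantizer above picks one scale per 128x128 block so that the block's absolute maximum lands on the E4M3 max finite value (448), using the full FP8 dynamic range per block. A toy illustration of that scaling step on a small "block":

```python
import torch

fp8_max = 448.0                                  # E4M3 max finite value
block = torch.tensor([[0.5, -2.0], [1.0, 4.0]])  # toy 2x2 "block"

amax = block.abs().max().item()                  # 4.0
scale = amax / fp8_max if amax > 0 else 1.0
scaled = block / scale                           # amax now maps to ~448

print(scaled.abs().max().item())  # ~448.0
```

Dequantization simply multiplies each decoded FP8 value by its block's scale, which is why the test's torch reference runs on `fp8_dequantize(...)` output rather than the original BF16 weights.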