Add cache quantization

This commit is contained in:
turboderp
2025-04-22 21:52:33 +02:00
parent fc5b39c2bb
commit cf84811485
14 changed files with 1156 additions and 16 deletions

View File

@@ -6,7 +6,7 @@ This is an **early preview release** of ExLlamaV3. Please note: ↙
- The framework <u>is not yet fully optimized</u>. Performance is lacking, especially on Ampere, and there may be a significant CPU bottleneck on slower processors until the extension functions are fully built out.
- AMD GPUs (ROCm) are not yet supported.
- [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) is currently required. I hope to switch over to [FlashInfer](https://github.com/flashinfer-ai/flashinfer/tree/main) in time, but there are some obstacles to overcome first.
- A number of important features are yet to be added, such as cache quantization, tensor parallelism and multimodal support.
- A number of important features are yet to be added, such as tensor parallelism and multimodal support.
- There are no release builds yet.
- Integration into [TabbyAPI](https://github.com/theroyallab/tabbyAPI/) is planned when all the core functionality is in place.
@@ -26,7 +26,6 @@ There's much that still needs to be added and/or ported over from ExLlamaV2. I'v
- Samplers (most notably repetition penalties and min-P are missing)
- Constrained sampling (JSON filters etc.)
- Multimodal support
- Cache quantization
- LoRA support
- ROCm support
- Tensor-parallel inference

BIN
doc/cq_humaneval.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

View File

@@ -1,2 +1,3 @@
from .cache import Cache, CacheLayer
from .fp16 import CacheLayer_fp16
from .fp16 import CacheLayer_fp16
from .quant import CacheLayer_quant

View File

@@ -16,6 +16,7 @@ class CacheLayer(ABC):
config: Config,
attention: Attention,
max_num_tokens: int,
**kwargs
):
self.config = config
self.attention = attention
@@ -30,7 +31,18 @@ class CacheLayer(ABC):
pass
@abstractmethod
def get_kv(self):
def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
pass
@abstractmethod
def update_kv(
self,
cache_seqlens: torch.Tensor,
block_table: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
length: int
):
pass
@abstractmethod
@@ -45,6 +57,7 @@ class Cache:
model: Model,
max_num_tokens: int,
layer_type: Type[CacheLayer] | None = None,
**kwargs
):
"""
Create cache for model
@@ -71,7 +84,7 @@ class Cache:
self.num_layers = len(self.model.get_cache_layers())
self.layers = [
self.layer_type(self.config, attn, self.max_num_tokens)
self.layer_type(self.config, attn, self.max_num_tokens, **kwargs)
for attn in self.model.get_cache_layers()
]
self.attach_to_model()
@@ -107,8 +120,20 @@ class Cache:
module.cache_layers.remove(layer)
def get_layer(self, idx: int) -> tuple:
return self.layers[idx].get_kv()
def get_layer(self, idx: int, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
return self.layers[idx].get_kv(cache_seqlens, block_table)
def update_layer(
self,
idx: int,
cache_seqlens: torch.Tensor,
block_table: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
length: int
):
return self.layers[idx].update_kv(cache_seqlens, block_table, k, v, length)
def copy_page(

View File

@@ -47,16 +47,24 @@ class CacheLayer_fp16(CacheLayer):
@override
def get_kv(self):
def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
return self.k, self.v
@override
def update_kv(
self,
cache_seqlens: torch.Tensor,
block_table: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
length: int
):
pass
@override
def copy_page(self, source: CacheLayer_fp16, from_page: int, to_page: int, num_tokens: int):
assert self.shape == source.shape
kd = self.k[to_page, :num_tokens, :, :]
vd = self.v[to_page, :num_tokens, :, :]
ks = source.k[from_page, :num_tokens, :, :]
vs = source.v[from_page, :num_tokens, :, :]
kd.copy_(ks, non_blocking = True)
vd.copy_(vs, non_blocking = True)
self.k[to_page, :num_tokens, :, :].copy_(source.k[from_page, :num_tokens, :, :], non_blocking = True)
self.v[to_page, :num_tokens, :, :].copy_(source.v[from_page, :num_tokens, :, :], non_blocking = True)

100
exllamav3/cache/quant.py vendored Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
from typing_extensions import override
import torch
import torch.nn.functional as F
from torch import nn
from ..constants import PAGE_SIZE
from ..models import Model, Config
from .cache import CacheLayer
from typing import TYPE_CHECKING
from exllamav3.ext import exllamav3_ext as ext
if TYPE_CHECKING:
from ..modules import Attention
class CacheLayer_quant(CacheLayer):
    """
    Paged K/V cache layer that stores keys and values in quantized form.

    Each group of 32 fp16 values is packed into (bits) int32 bitplanes plus a
    single fp16 scale, via the extension kernels quant_cache_paged /
    dequant_cache_paged. K and V may use different bitrates (2..8 bits each).
    """

    def __init__(
        self,
        config: Config,
        attention: Attention,
        max_num_tokens: int,
        k_bits: int,
        v_bits: int,
    ):
        """
        :param config: Model config
        :param attention: Attention module this layer caches for. May be falsy
            for layers with no K/V state (the shape guards below allow this).
        :param max_num_tokens: Total cache capacity; must be a multiple of PAGE_SIZE
        :param k_bits: Bitrate for quantized keys, 2..8
        :param v_bits: Bitrate for quantized values, 2..8
        """
        super().__init__(config, attention, max_num_tokens)
        assert max_num_tokens % PAGE_SIZE == 0, \
            f"max_num_tokens must be a multiple of {PAGE_SIZE}."
        assert (2 <= k_bits <= 8) and (2 <= v_bits <= 8), "quantized cache must be from 2 to 8 bits"
        self.shape = (
            (max_num_tokens // PAGE_SIZE, PAGE_SIZE, attention.num_kv_heads, attention.head_dim)
            if attention else None
        )
        self.k_bits = k_bits
        self.v_bits = v_bits
        # Fix: guard against a missing attention module, consistent with the
        # `if attention else None` guards on every other derived attribute here
        # (previously this line dereferenced `attention` unconditionally).
        self.token_dim = attention.num_kv_heads * attention.head_dim if attention else None
        # Packed tensors: (bits) int32 bitplanes per 32-value group, plus one
        # fp16 scale per group.
        self.qshape_k = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32 * k_bits) if attention else None)
        self.qshape_v = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32 * v_bits) if attention else None)
        self.qshape_s = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32) if attention else None)
        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None
        self.device = None

    @override
    def alloc(self, device: torch.device):
        # Allocate packed weights and per-group scales on the target device
        self.device = device
        self.qk = torch.zeros(self.qshape_k, dtype = torch.int, device = device) if self.shape else None
        self.qv = torch.zeros(self.qshape_v, dtype = torch.int, device = device) if self.shape else None
        self.sk = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None
        self.sv = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None

    @override
    def free(self):
        # Drop references; memory is reclaimed by the allocator
        self.device = None
        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None

    @override
    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor):
        """
        Dequantize the cached pages referenced by block_table (up to
        cache_seqlens tokens per sequence) into temporary fp16 K/V tensors.
        """
        k = torch.empty(self.shape, dtype = torch.half, device = self.device)
        v = torch.empty(self.shape, dtype = torch.half, device = self.device)
        ext.dequant_cache_paged(self.qk, self.sk, k, self.qv, self.sv, v, cache_seqlens, block_table, PAGE_SIZE)
        return k, v

    @override
    def update_kv(
        self,
        cache_seqlens: torch.Tensor,
        block_table: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        length: int
    ):
        """
        Quantize `length` new token positions (starting at cache_seqlens) from
        the fp16 K/V tensors back into the packed cache.
        """
        ext.quant_cache_paged(
            k, self.qk, self.sk,
            v, self.qv, self.sv,
            cache_seqlens, block_table,
            PAGE_SIZE,
            length
        )

    @override
    def copy_page(self, source: CacheLayer_quant, from_page: int, to_page: int, num_tokens: int):
        # Copy packed weights and scales for the first num_tokens positions of a page
        assert self.qshape_k == source.qshape_k
        assert self.qshape_v == source.qshape_v
        self.qk[to_page, :num_tokens, :].copy_(source.qk[from_page, :num_tokens, :], non_blocking = True)
        self.qv[to_page, :num_tokens, :].copy_(source.qv[from_page, :num_tokens, :], non_blocking = True)
        self.sk[to_page, :num_tokens, :].copy_(source.sk[from_page, :num_tokens, :], non_blocking = True)
        self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)

View File

@@ -21,6 +21,8 @@
#include "generator/sampling_basic.cuh"
#include "generator/gumbel.cuh"
#include "cache/q_cache.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("stloader_read", &stloader_read, "stloader_read");
@@ -56,4 +58,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("partial_strings_match", &partial_strings_match, "partial_strings_match");
m.def("count_match_tensor", &count_match_tensor, "count_match_tensor");
m.def("quant_cache_cont", &quant_cache_cont, "quant_cache_cont");
m.def("dequant_cache_cont", &dequant_cache_cont, "dequant_cache_cont");
m.def("quant_cache_paged", &quant_cache_paged, "quant_cache_paged");
m.def("dequant_cache_paged", &dequant_cache_paged, "dequant_cache_paged");
}

280
exllamav3/exllamav3_ext/cache/q_cache.cu vendored Normal file
View File

@@ -0,0 +1,280 @@
#include "q_cache.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include "../util.h"
#include "../util.cuh"
#include <limits>
#include "../quant/codebook.cuh"
#include "q_cache_kernels.cuh"
/*
Quantize contiguous tensor
in: float16, shape (..., dim)
out: int32, shape (..., dim / 32 * bitrate)
out_scales: float16, shape (..., dim / 32)
*/
void quant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& out,
    const at::Tensor& out_scales
)
{
    const at::cuda::OptionalCUDAGuard device_guard(in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(in, kHalf);
    TORCH_CHECK_DTYPE(out, kInt);
    TORCH_CHECK_DTYPE(out_scales, kHalf);

    // One quantization group per 32 input values; one warp handles one group
    int bsz = in.numel() / 32;
    int head_dim = in.size(-1);
    int head_blocks = head_dim / 32;
    TORCH_CHECK(head_dim == 32 * head_blocks, "head_dim must be a multiple of 32");

    // Infer the bitrate from the packed output width: each 32-value group
    // packs into (bits) int32 bitplanes
    int bits = out.size(-1) / head_blocks;
    TORCH_CHECK(out.numel() == bsz * bits, "out is wrong size");
    TORCH_CHECK(out_scales.numel() == bsz, "out_scales is wrong size");
    TORCH_CHECK(2 <= bits && bits <= 8, "Unsupported K/V cache bitrate");

    // Dispatch to the kernel template instantiation matching the runtime bitrate
    static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
    {
        constexpr int i = decltype(ic)::value;
        if (bits == i)
        {
            quant_cache_cont_kernel<i><<<bsz, 32, 0, stream>>>
            (
                (const half*) in.data_ptr(),
                (uint32_t*) out.data_ptr(),
                (half*) out_scales.data_ptr()
            );
        }
    });
    cuda_check(cudaPeekAtLastError());
}
/*
Dequantize contiguous tensor
in: int32, shape (..., dim / 32 * bitrate)
in_scales: float16, shape (..., dim / 32)
out: float16, shape (..., dim)
*/
void dequant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& in_scales,
    const at::Tensor& out
)
{
    const at::cuda::OptionalCUDAGuard device_guard(in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(in, kInt);
    TORCH_CHECK_DTYPE(in_scales, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);

    // One 32-value group per warp; all sizes are derived from the fp16 output
    int bsz = out.numel() / 32;
    int head_dim = out.size(-1);
    int head_blocks = head_dim / 32;
    TORCH_CHECK(head_dim == 32 * head_blocks, "head_dim must be a multiple of 32");

    // Infer the bitrate from the packed input width
    int bits = in.size(-1) / head_blocks;
    TORCH_CHECK(in.numel() == bsz * bits, "in is wrong size");
    TORCH_CHECK(in_scales.numel() == bsz, "in_scales is wrong size");
    TORCH_CHECK(2 <= bits && bits <= 8, "Unsupported K/V cache bitrate");

    // Dispatch to the kernel template instantiation matching the runtime bitrate
    static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
    {
        constexpr int i = decltype(ic)::value;
        if (bits == i)
        {
            dequant_cache_cont_kernel<i><<<bsz, 32, 0, stream>>>
            (
                (const uint32_t*) in.data_ptr(),
                (const half*) in_scales.data_ptr(),
                (half*) out.data_ptr()
            );
        }
    });
    cuda_check(cudaPeekAtLastError());
}
/*
Quantize paged tensor
k_in, v_in: float16, shape (1, cache_size, dim)
k_out, v_out: int32, shape (1, cache_size, dim / 32 * bitrate)
k_out_scales, v_out_scales: float16, shape (1, cache_size, dim / 32)
cache_seqlens: int32, length of each sequence in batch, k_out and v_out are updated _from_ this point
block_table: int32, shape (bsz, blocks_per_seq)
page_size: 256
seq_len: number of positions (size: dim) to update from end of each sequence
*/
void quant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_out,
    const at::Tensor& k_out_scales,
    const at::Tensor& v_in,
    const at::Tensor& v_out,
    const at::Tensor& v_out_scales,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size,
    int seq_len
)
{
    const at::cuda::OptionalCUDAGuard device_guard(k_in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(k_in, kHalf);
    TORCH_CHECK_DTYPE(k_out, kInt);
    TORCH_CHECK_DTYPE(k_out_scales, kHalf);
    TORCH_CHECK_DTYPE(v_in, kHalf);
    TORCH_CHECK_DTYPE(v_out, kInt);
    TORCH_CHECK_DTYPE(v_out_scales, kHalf);
    TORCH_CHECK_SHAPES_FULL(k_in, v_in);
    TORCH_CHECK_SHAPES_FULL(k_out_scales, v_out_scales);

    // Per-token feature width; 4D input is (pages, page_size, heads, head_dim),
    // 3D input has heads and head_dim already flattened
    int dim;
    if (k_in.dim() == 4)
        dim = k_in.size(2) * k_in.size(3);
    else if (k_in.dim() == 3)
        dim = k_in.size(2);
    else
        TORCH_CHECK(false, "paged cache must be 3D or 4D")

    int warps_per_token = dim / 32;
    TORCH_CHECK(dim == 32 * warps_per_token, "dim must be a multiple of 32");
    int tb_per_token = CEIL_DIVIDE(warps_per_token, MAX_WARPS); // Threadblocks per token position
    int tb_usage = CEIL_DIVIDE(warps_per_token, tb_per_token); // Number of warps to use per threadblock

    TORCH_CHECK(k_out.dim() == 3 && v_out.dim() == 3, "paged q.cache must have shape (num_pages, page_size, dim // 32 * bitrate)")
    // Bitrates are inferred from the packed output widths
    int k_bits = k_out.size(2) / warps_per_token;
    int v_bits = v_out.size(2) / warps_per_token;

    int bsz = block_table.size(0);
    int blocks_per_seq = block_table.size(1);

    // Grid: x covers the token dim, y the seq_len new positions, z the batch
    dim3 blocks(tb_per_token, seq_len, bsz);
    dim3 threads(32 * tb_usage);

    // Dispatch over the (k_bits, v_bits) template instantiation pair
    static_for_pack<2,3,4,5,6,7,8>([&](auto jc)
    {
        constexpr int j = decltype(jc)::value;
        static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
        {
            constexpr int i = decltype(ic)::value;
            if (k_bits == i && v_bits == j)
            {
                quant_cache_paged_kernel<i, j><<<blocks, threads, 0, stream>>>
                (
                    (const half*) k_in.data_ptr(),
                    (uint32_t*) k_out.data_ptr(),
                    (half*) k_out_scales.data_ptr(),
                    (const half*) v_in.data_ptr(),
                    (uint32_t*) v_out.data_ptr(),
                    (half*) v_out_scales.data_ptr(),
                    (const uint32_t*) cache_seqlens.data_ptr(),
                    (const uint32_t*) block_table.data_ptr(),
                    page_size,
                    blocks_per_seq,
                    dim
                );
            }
        });
    });
    cuda_check(cudaPeekAtLastError());
}
/*
Dequantize paged tensor
k_in, v_in: int32, shape (1, cache_size, dim / 32 * bitrate)
k_in_scales, v_in_scales: float16, shape (1, cache_size, dim / 32)
k_out, v_out: float16, shape (1, cache_size, dim)
cache_seqlens: int32, length of each sequence in batch, k_out and v_out are updated _up_to_ this point
block_table: int32, shape (bsz, blocks_per_seq)
page_size: 256
*/
void dequant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_in_scales,
    const at::Tensor& k_out,
    const at::Tensor& v_in,
    const at::Tensor& v_in_scales,
    const at::Tensor& v_out,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size
)
{
    const at::cuda::OptionalCUDAGuard device_guard(k_in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(k_in, kInt);
    TORCH_CHECK_DTYPE(k_in_scales, kHalf);
    TORCH_CHECK_DTYPE(k_out, kHalf);
    TORCH_CHECK_DTYPE(v_in, kInt);
    TORCH_CHECK_DTYPE(v_in_scales, kHalf);
    TORCH_CHECK_DTYPE(v_out, kHalf);
    TORCH_CHECK_SHAPES_FULL(k_in_scales, v_in_scales);
    TORCH_CHECK_SHAPES_FULL(k_out, v_out);

    // Per-token feature width, derived from the fp16 output layout
    int dim;
    if (k_out.dim() == 4)
        dim = k_out.size(2) * k_out.size(3);
    else if (k_out.dim() == 3)
        dim = k_out.size(2);
    else
        TORCH_CHECK(false, "paged cache must be 3D or 4D")

    int warps_per_token = dim / 32;
    TORCH_CHECK(dim == 32 * warps_per_token, "dim must be a multiple of 32");

    int bsz = block_table.size(0);
    int pages_per_seq = block_table.size(1);
    // Launch enough warps to cover every group of every page of a sequence;
    // the kernel early-outs past each sequence's current length
    int warps_per_seq = pages_per_seq * page_size * warps_per_token;
    int num_tb = CEIL_DIVIDE(32 * warps_per_seq, 1024);
    int num_threads = MIN(32 * warps_per_seq, 1024);
    dim3 blocks(num_tb, bsz);
    dim3 threads(num_threads);

    TORCH_CHECK(k_in.dim() == 3 && v_in.dim() == 3, "paged q.cache must have shape (num_pages, page_size, dim // 32 * bitrate)")
    // Bitrates are inferred from the packed input widths
    int k_bits = k_in.size(2) / warps_per_token;
    int v_bits = v_in.size(2) / warps_per_token;

    // Dispatch over the (k_bits, v_bits) template instantiation pair
    static_for_pack<2,3,4,5,6,7,8>([&](auto jc)
    {
        constexpr int j = decltype(jc)::value;
        static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
        {
            constexpr int i = decltype(ic)::value;
            if (k_bits == i && v_bits == j)
            {
                dequant_cache_paged_kernel<i, j><<<blocks, threads, 0, stream>>>
                (
                    (const uint32_t*) k_in.data_ptr(),
                    (const half*) k_in_scales.data_ptr(),
                    (half*) k_out.data_ptr(),
                    (const uint32_t*) v_in.data_ptr(),
                    (const half*) v_in_scales.data_ptr(),
                    (half*) v_out.data_ptr(),
                    (const uint32_t*) cache_seqlens.data_ptr(),
                    (const uint32_t*) block_table.data_ptr(),
                    page_size,
                    pages_per_seq,
                    warps_per_token
                );
            }
        });
    });
    cuda_check(cudaPeekAtLastError());
}

View File

@@ -0,0 +1,44 @@
#pragma once
#include <ATen/Tensor.h>
// Quantize a contiguous fp16 tensor: each 32-value group along the last dim
// becomes (bits) packed int32 bitplanes in `out` plus one fp16 scale in
// `out_scales`. Bitrate is inferred from the shape of `out`.
void quant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& out,
    const at::Tensor& out_scales
);

// Inverse of quant_cache_cont: reconstruct fp16 values from packed bitplanes
// and per-group scales.
void dequant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& in_scales,
    const at::Tensor& out
);

// Quantize `seq_len` new token positions per sequence (starting at
// cache_seqlens) from fp16 K/V into the packed, paged cache, using
// block_table to map logical positions to physical pages.
void quant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_out,
    const at::Tensor& k_out_scales,
    const at::Tensor& v_in,
    const at::Tensor& v_out,
    const at::Tensor& v_out_scales,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size,
    int seq_len
);

// Dequantize all cached positions up to cache_seqlens for each sequence
// from the packed, paged cache into fp16 K/V tensors.
void dequant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_in_scales,
    const at::Tensor& k_out,
    const at::Tensor& v_in,
    const at::Tensor& v_in_scales,
    const at::Tensor& v_out,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size
);

View File

@@ -0,0 +1,192 @@
// In-warp 32-point Hadamard transform (no normalization applied here): each
// lane holds one element, and each butterfly stage exchanges values with
// XOR shuffles while flipping the sign on the appropriate lanes.
__device__ inline float shuffle_had_fx32(float v, int lane_id)
{
    for (int i = 1; i < 32; i <<= 1)
    {
        float pv = __shfl_xor_sync(0xffffffff, v, i);
        // Negate v on lanes where (lane_id & i) is set by XOR-ing the fp32
        // sign bit; sfm is all-ones when the bit is set, else zero
        uint32_t* vi = reinterpret_cast<uint32_t*>(&v);
        int32_t sfm = -static_cast<int16_t>(lane_id & i) >> 31;
        *vi ^= (sfm & 0x80000000);
        v = v + pv;
    }
    return v;
}
// Warp-wide sum reduction; every lane receives the total
__device__ inline float shuffle_sum_fx32(float s)
{
    for (int i = 1; i < 32; i <<= 1)
        s += __shfl_xor_sync(0xffffffff, s, i);
    return s;
}
// Warp-wide max reduction; every lane receives the maximum
__device__ inline float shuffle_max_fx32(float s)
{
    for (int i = 1; i < 32; i <<= 1)
        s = fmaxf(s, __shfl_xor_sync(0xffffffff, s, i));
    return s;
}
// One warp quantizes one group of 32 fp16 values into (bits) int32 bitplanes
// plus a single fp16 scale.
template <int bits>
__device__ inline void quant_block
(
    const half* in,
    uint32_t* out,
    half* out_scales
)
{
    int t = threadIdx.x % 32;

    // Load, rotate and scale 32 values (normalized Hadamard rotation spreads
    // outliers across the group before quantization)
    float v = __half2float(in[t]);
    v = shuffle_had_fx32(v, t);
    v *= 0.17678f; // 0.17678 = 1 / sqrt(32)

    // Group scale = max abs value (epsilon avoids div by zero), rounded
    // through fp16 so quantization and dequantization use the same scale
    float s = shuffle_max_fx32(fabsf(v) + 1e-10);
    half sh = __float2half_rn(s);
    s = __half2float(sh);
    v /= s;

    // Quantize and clamp: shift to the unsigned range [0, 2^bits - 1]
    int m = (1 << (bits - 1));
    v *= __int2float_rn(m);
    int q = lrintf(v) + m;
    q = max(min((1 << bits) - 1, q), 0);

    // Pack bits: ballot of bit i across the 32 lanes forms bitplane i
    register uint32_t bitplanes[bits];
    for (int i = 0, mask = 1; mask <= m; ++i, mask <<= 1)
        bitplanes[i] = __ballot_sync(0xffffffff, q & mask);

    // Write output: lanes 0..bits-1 store the bitplanes, lane (bits) the scale
    if (t < bits)
        out[t] = bitplanes[t];
    if (t == bits)
        *out_scales = sh;
}
#define MAX_WARPS 32

// One warp reconstructs one group of 32 fp16 values from (bits) bitplanes
// and a per-group fp16 scale. Bitplanes are staged through shared memory,
// one row per warp.
template <int bits>
__device__ inline void dequant_block
(
    const uint32_t* in,
    const half* in_scales,
    half* out
)
{
    int t = threadIdx.x % 32;
    int warp_id = threadIdx.x / 32;

    // Load scale and bitplanes
    float s = __half2float(*in_scales);
    __shared__ uint32_t bitplanes[MAX_WARPS][bits];
    if (t < bits)
        bitplanes[warp_id][t] = in[t];
    __syncthreads();

    // Unpack bits: lane t extracts its bit from each plane
    int m = (1 << (bits - 1));
    uint32_t mask = 1 << t;
    int q = 0;
    for (int i = 0; i < bits; ++i)
        q |= ((bitplanes[warp_id][i] & mask) >> t) << i;

    // Dequantize back to the symmetric range
    float v = __int2float_rn(q - m);
    v /= __int2float_rn(m);

    // Scale, then apply the Hadamard rotation again (it is its own inverse
    // with the 1/sqrt(32) normalization) to undo the forward rotation
    v *= s;
    v = shuffle_had_fx32(v, t);
    v *= 0.17678f; // 0.17678 = 1 / sqrt(32)

    // Store
    out[t] = __float2half(v);
}
// Contiguous quantization: one single-warp threadblock per 32-value group
template <int bits>
__global__ __launch_bounds__(1024)
void quant_cache_cont_kernel
(
    const half* __restrict__ in,
    uint32_t* __restrict__ out,
    half* __restrict__ out_scales
)
{
    // blockIdx.x = group index; advance to this block's group
    in += 32 * blockIdx.x;
    out += bits * blockIdx.x;
    out_scales += blockIdx.x;
    quant_block<bits>(in, out, out_scales);
}
// Contiguous dequantization: one single-warp threadblock per 32-value group
template <int bits>
__global__ __launch_bounds__(32)
void dequant_cache_cont_kernel
(
    const uint32_t* __restrict__ in,
    const half* __restrict__ in_scales,
    half* __restrict__ out
)
{
    // blockIdx.x = group index; advance to this block's group
    in += bits * blockIdx.x;
    in_scales += blockIdx.x;
    out += 32 * blockIdx.x;
    dequant_block<bits>(in, in_scales, out);
}
// Paged quantization. Grid: x = threadblocks covering the token dim,
// y = new token position (appended after the current sequence length),
// z = batch index. Each warp quantizes one 32-value group of K and V.
template <int k_bits, int v_bits>
__global__ __launch_bounds__(1024)
void quant_cache_paged_kernel
(
    const half* __restrict__ k_in,
    uint32_t* __restrict__ k_out,
    half* __restrict__ k_out_scales,
    const half* __restrict__ v_in,
    uint32_t* __restrict__ v_out,
    half* __restrict__ v_out_scales,
    const uint32_t* __restrict__ cache_seqlens,
    const uint32_t* __restrict__ block_table,
    int page_size,
    int blocks_per_seq,
    int token_dim
)
{
    int batch_idx = blockIdx.z;
    // Logical position of the token being appended for this sequence
    int token_idx = blockIdx.y + cache_seqlens[batch_idx];
    // Map logical position to a physical cache slot via the block table
    int page_idx = token_idx / page_size;
    int token_pos = block_table[blocks_per_seq * batch_idx + page_idx] * page_size + (token_idx % page_size);
    // Global 32-element group index handled by this warp
    int sub_pos = (token_pos * token_dim + blockDim.x * blockIdx.x + threadIdx.x) / 32;
    quant_block<k_bits>(k_in + sub_pos * 32, k_out + sub_pos * k_bits, k_out_scales + sub_pos);
    quant_block<v_bits>(v_in + sub_pos * 32, v_out + sub_pos * v_bits, v_out_scales + sub_pos);
}
// Paged dequantization. Grid: x spans enough warps to cover every group of
// every page of a sequence, y = batch index. Warps past the sequence's
// current length exit early.
template <int k_bits, int v_bits>
__global__ __launch_bounds__(1024)
void dequant_cache_paged_kernel
(
    const uint32_t* __restrict__ k_in,
    const half* __restrict__ k_in_scales,
    half* __restrict__ k_out,
    const uint32_t* __restrict__ v_in,
    const half* __restrict__ v_in_scales,
    half* __restrict__ v_out,
    const uint32_t* __restrict__ cache_seqlens,
    const uint32_t* __restrict__ block_table,
    int page_size,
    int pages_per_seq,
    int warps_per_token
)
{
    int batch_idx = blockIdx.y;
    int t_warp_id = (blockDim.x * blockIdx.x + threadIdx.x) / 32;
    int token_idx = t_warp_id / warps_per_token;
    int max_token_idx = cache_seqlens[batch_idx];
    // NOTE(review): this early return precedes the __syncthreads() inside
    // dequant_block; if only part of a threadblock returns here the barrier
    // is divergent — confirm this is safe for the launch shapes used
    if (token_idx >= max_token_idx) return;
    // Map logical token position to its physical page via the block table
    int page_idx = token_idx / page_size;
    int page_sub = t_warp_id % (warps_per_token * page_size);
    int mapped_page = block_table[batch_idx * pages_per_seq + page_idx];
    // Group index within the packed cache
    int addr = mapped_page * page_size * warps_per_token + page_sub;
    dequant_block<k_bits>(k_in + addr * k_bits, k_in_scales + addr, k_out + addr * 32);
    dequant_block<v_bits>(v_in + addr * v_bits, v_in_scales + addr, v_out + addr * 32);
}

View File

@@ -1,4 +1,5 @@
from . import Model, Config, Cache, Tokenizer
from .cache import CacheLayer_fp16, CacheLayer_quant
from argparse import ArgumentParser
import torch
@@ -20,6 +21,7 @@ def add_args(
if cache:
parser.add_argument("-cs", "--cache_size", type = int, help = "Total cache size in tokens, default: 8192", default = 8192)
parser.add_argument("-cq", "--cache_quant", type = str, help = "Use quantized cache. Specify either kv_bits or k_bits,v_bits pair")
# TODO:
# parser.add_argument("-tp", "--tensor_parallel", action = "store_true", help = "Load in tensor-parallel mode")
@@ -69,7 +71,30 @@ def init(
model = Model.from_config(config)
# Cache
cache = Cache(model, max_num_tokens = args.cache_size) if "cache_size" in vars(args) else None
if "cache_size" in vars(args):
if args.cache_quant is not None:
split = [int(bits) for bits in args.cache_quant.split(",")]
if len(split) == 1:
k_bits = v_bits = split[0]
elif len(split) == 2:
k_bits, v_bits = tuple(split)
else:
raise ValueError("Specify either one or two bitrates for cache quantization")
cache = Cache(
model,
max_num_tokens = args.cache_size,
layer_type = CacheLayer_quant,
k_bits = k_bits,
v_bits = v_bits
)
else:
cache = Cache(
model,
max_num_tokens = args.cache_size,
layer_type = CacheLayer_fp16
)
else:
cache = None
# Split
if args.gpu_split is None or args.gpu_split == "auto":

View File

@@ -348,7 +348,7 @@ class Attention(Module):
q, k = self.rope.apply(q, k, position, positions, position_ids, in_place = True)
cache_k, cache_v = cache.get_layer(self.layer_idx)
cache_k, cache_v = cache.get_layer(self.layer_idx, cache_seqlens, block_table)
o = flash_attn_with_kvcache(
q = q,
k = k,
@@ -362,6 +362,7 @@ class Attention(Module):
window_size = (self.sliding_window, self.sliding_window),
softcap = self.logit_softcapping
)
cache.update_layer(self.layer_idx, cache_seqlens, block_table, cache_k, cache_v, seqlen)
o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))
# TODO: Store updated cache layer

314
science/kv_quant_exp.py Normal file
View File

@@ -0,0 +1,314 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn.functional as F
from exllamav3 import Config, Model, Tokenizer
from exllamav3.modules import TransformerBlock
from exllamav3.util.hadamard import get_hadamard_dt
from datasets import load_dataset
from exllamav3.util.file import disk_lru_cache, disk_lru_cache_clear
from flash_attn import flash_attn_func
from ref_quant2 import quantquant
from exllamav3.ext import exllamav3_ext as ext
import math
torch.set_printoptions(precision = 8, sci_mode = False, linewidth = 200)
model_dir = "/mnt/str/models/llama3.1-8b-instruct/hf/"
device = "cuda:1"
target_layers = [0]
num_rows = 1
# Create input tensor
@disk_lru_cache("get_test_data")
def get_test_data():
    """Return the wikitext-2-raw test split as one string, documents separated by blank lines."""
    rows = load_dataset("wikitext", "wikitext-2-raw-v1", split = "test")["text"]
    return "\n\n".join(rows)
# Sample Q and K tensors from forward pass, Llama type model
# Sample Q and K tensors from forward pass, Llama type model
@disk_lru_cache("sample_qkv")
def sample_qkv(_model_dir, _target_layers, _num_rows):
    """
    Run a forward pass over wikitext samples and capture the post-RoPE
    (q, k, v) tensors of each target transformer block, sampled right
    before the attention dot product.
    """
    # Load model
    config = Config.from_directory(_model_dir)
    model = Model.from_config(config)
    model.load(device, progressbar = True)
    tokenizer = Tokenizer.from_config(config)
    test_data = get_test_data()[:100000]
    eval_tokens = tokenizer.encode(test_data)

    # Slice the token stream into overlapping evaluation rows
    eval_len = 2048
    eval_stride = 512
    num_tokens = eval_tokens.shape[-1]
    seqs = []
    for a in range(0, num_tokens - eval_len, eval_stride):
        b = a + eval_len
        seqs.append(eval_tokens[:, a:b])
        # NOTE(review): this reads the module-level global num_rows, not the
        # _num_rows parameter — the cache key won't reflect a changed global
        if len(seqs) >= num_rows:
            break
    input_ids = torch.cat(seqs, dim = 0)[:, :]

    _samples_qkv = []
    params = {}
    x = model.prepare_inputs(input_ids, params)
    for idx, module in enumerate(model.modules):
        params["prefill"] = (idx == model.last_kv_module_idx)
        x = module.prepare_for_device(x, params)
        if isinstance(module, TransformerBlock):
            block_idx = int(module.key.split(".")[-1])
            # Stop once all target layers have been sampled
            if block_idx > max(_target_layers):
                break
            if block_idx in _target_layers:
                # Pre-attn norm
                y = module.attn_norm.forward(x, params, out_dtype = torch.half)
                # Projections and RoPE
                attn = module.attn
                bsz, seqlen, _ = y.shape
                position, positions, position_ids = 0, None, None
                q, k, v = attn.project_qkv(y, params)
                q = q.view(bsz, seqlen, attn.num_q_heads, attn.head_dim)
                k = k.view(bsz, seqlen, attn.num_kv_heads, attn.head_dim)
                v = v.view(bsz, seqlen, attn.num_kv_heads, attn.head_dim)
                q, k = attn.rope.apply(q, k, position, positions, position_ids)
                # Sample right before dot product
                _samples_qkv.append((q, k, v))
        # Advance state
        x = module.forward(x, params)
    return _samples_qkv
samples_qkv = sample_qkv(model_dir, target_layers, num_rows)
# Get attention scores and output
# Get attention scores and output
def attn(q, k, v):
    """
    Reference attention pass. Returns the flash-attn output together with
    manually computed, causally masked pre-softmax scores so both output error
    and score divergence can be measured.
    Assumes q is (bsz, q_len, n_heads_q, head_dim) and k/v are
    (bsz, k_len, n_heads_k, head_dim).
    """
    bsz, q_len, n_heads_q, head_dim = q.shape
    _, k_len, n_heads_k, _ = k.shape
    # Expand K heads to match Q heads (grouped-query attention)
    gqa = n_heads_q // n_heads_k
    k_int = k.repeat_interleave(gqa, dim = 2)
    scores = torch.einsum('bqhd,bkhd->bhqk', q, k_int) / math.sqrt(head_dim)

    # Causal mask
    mask = torch.ones((k_len, k_len), dtype = torch.bool, device = q.device).triu(diagonal = 1)
    mask = mask[-q_len:, :]
    scores = scores.masked_fill_(mask, -65504.)  # most negative normal fp16 value

    # Now attention
    o = flash_attn_func(
        q = q,
        k = k,
        v = v,
        causal = True,
    )
    return o, scores
# Reference method
def int_quant(v, bits):
    """
    Symmetric integer quantization reference (quant + dequant round trip).

    One absmax scale per vector along the last dimension; values are quantized
    to `bits` bits in [-2^(bits-1), 2^(bits-1) - 1] and dequantized back.

    :param v: input tensor, any rank >= 1 (groups along dim -1)
    :param bits: integer bitrate
    :return: tensor of same shape as v after the quantization round trip
    """
    m = 1 << (bits - 1)
    # keepdim = True generalizes the original unsqueeze(3), which assumed a
    # 4D input; for 4D tensors the result is identical
    scales = torch.amax(v.abs(), dim = -1, keepdim = True)
    v = v / scales
    vq = (v * m).round().clamp(-m, m - 1)
    vq /= m
    vq *= scales
    return vq
# def quant_nf4(t):
# scales = torch.amax(t.abs(), dim = -1).unsqueeze(3)
# tq = t / scales
# tqq = torch.empty_like(tq)
# ext.test_nf4(tq, tqq)
# tqq *= scales
# return tqq
def quant_fp8(t):
return t.to(torch.float8_e4m3fn).half()
# Kernel equiv reference
def kernel_ref_quant(v, bits):
had32 = get_hadamard_dt(32, v.device, torch.half)
w = v.view(-1, 32)
m = 1 << (bits - 1)
w = w @ had32 / math.sqrt(32)
scales = torch.amax(w.abs(), dim = -1, keepdim = True).half()
w = w / scales
vq = (w * m).round().clamp(-m, m - 1)
vq /= m
vq *= scales
vq = vq @ had32 / math.sqrt(32)
vq = vq.view(v.shape)
return vq
# KL divergence between softmax distributions
def kl_divergence_scores(s, s_prime, dim = -1, eps = 1e-8):
alpha = F.softmax(s.float(), dim = dim)
alpha_hat = F.softmax(s_prime.float(), dim = dim)
kl_elementwise = alpha * (torch.log(alpha + eps) - torch.log(alpha_hat + eps))
kl_per_item = kl_elementwise.sum(dim = dim)
kl_mean = kl_per_item.mean()
return kl_mean
# Normalized MSE
def nmse(o, o_prime):
return (o - o_prime).square().mean() / o_prime.square().mean()
# Do stuff
def test_qkv(label, q, k, v, ref_o, ref_scores, q_rot = False, k_rot = False, v_rot = False):
head_dim = q.shape[-1]
had = get_hadamard_dt(head_dim, device, torch.half)
if q_rot != k_rot: q = (q @ had) / math.sqrt(head_dim)
if v_rot: v = (v @ had) / math.sqrt(head_dim)
test_o, test_scores = attn(q, k, v)
kld = kl_divergence_scores(test_scores, ref_scores)
mse = nmse(test_o, ref_o)
print(f"{label:26} weights_kld: {kld:.6f} output_nmse: {mse:.6f}")
with torch.inference_mode():
    # Compare K/V cache quantization schemes against the unquantized
    # reference, reporting score KLD and output NMSE for each
    head_dim = samples_qkv[0][0].shape[-1]
    had = get_hadamard_dt(head_dim, device, torch.half)
    for idx, (q, k, v) in zip(target_layers, samples_qkv):

        # Unquantized
        ref_o, ref_scores = attn(q, k, v)

        # Q4
        test_qkv(
            "Q4",
            q,
            int_quant(k, 4),
            int_quant(v, 4),
            ref_o,
            ref_scores
        )

        # Q6
        test_qkv(
            "Q6",
            q,
            int_quant(k, 6),
            int_quant(v, 6),
            ref_o,
            ref_scores
        )

        # Q8
        test_qkv(
            "Q8",
            q,
            int_quant(k, 8),
            int_quant(v, 8),
            ref_o,
            ref_scores
        )

        # Rotated Q4 (Hadamard rotation before quantization)
        test_qkv(
            "Rot. Q4",
            q,
            int_quant((k @ had) / math.sqrt(head_dim), 4),
            int_quant((v @ had) / math.sqrt(head_dim), 4),
            ref_o,
            ref_scores,
            False, True, True
        )

        # Rotated Q6
        test_qkv(
            "Rot. Q6",
            q,
            int_quant((k @ had) / math.sqrt(head_dim), 6),
            int_quant((v @ had) / math.sqrt(head_dim), 6),
            ref_o,
            ref_scores,
            False, True, True
        )

        # Channel scales + rotated Q4 (mean-abs per-channel equalization,
        # fully undone after quantization so no rot flags are needed)
        psc_k = k.view(-1, k.shape[-2], k.shape[-1]).abs().mean(dim = 0)
        psc_v = v.view(-1, k.shape[-2], k.shape[-1]).abs().mean(dim = 0)
        test_qkv(
            "Rot. Q4 ch.scales",
            q,
            int_quant(((k / psc_k) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * psc_k,
            int_quant(((v / psc_v) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * psc_v,
            ref_o,
            ref_scores,
            False, False, False
        )

        # Channel scales + rotated Q4 RMS (RMS per-channel equalization)
        pscr_k = k.view(-1, k.shape[-2], k.shape[-1]).square().mean(dim = 0).sqrt()
        pscr_v = v.view(-1, k.shape[-2], k.shape[-1]).square().mean(dim = 0).sqrt()
        test_qkv(
            "Rot. Q4 ch.scales (RMS)",
            q,
            int_quant(((k / pscr_k) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * pscr_k,
            int_quant(((v / pscr_v) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * pscr_v,
            ref_o,
            ref_scores,
            False, False, False
        )

        # Rotated Q4 + Q6 (keys at 4 bits, values at 6 bits)
        test_qkv(
            "Rot. Q4+Q6",
            q,
            int_quant((k @ had) / math.sqrt(head_dim), 4),
            int_quant((v @ had) / math.sqrt(head_dim), 6),
            ref_o,
            ref_scores,
            False, True, True
        )

        # NF4
        # k_nf4 = quant_nf4(k)
        # v_nf4 = quant_nf4(v)
        # test_qkv("NF4", q, k_nf4, v_nf4, ref_o, ref_scores, False, False, False)

        # Rotated NF4
        # k_h = (k @ had) / math.sqrt(128)
        # v_h = (v @ had) / math.sqrt(128)
        # k_h_nf4 = quant_nf4(k_h)
        # v_h_nf4 = quant_nf4(v_h)
        # test_qkv("RNF4", q, k_h_nf4, v_h_nf4, ref_o, ref_scores, False, True, True)

        # FP8
        test_qkv(
            "FP8 e4m3",
            q,
            quant_fp8(k),
            quant_fp8(v),
            ref_o,
            ref_scores,
            False, False, False
        )

        # Kernel (extension quant/dequant round trip at each bitrate)
        for bits in range(2, 9):
            # NOTE(review): the 128 here is presumably head_dim — would break
            # for other head sizes; confirm
            quant_shape = k.shape[:-1] + (128 // 32 * bits,)
            scale_shape = k.shape[:-1] + (128 // 32,)
            k_quant = torch.zeros(quant_shape, dtype = torch.int, device = k.device)
            k_scale = torch.zeros(scale_shape, dtype = torch.half, device = k.device)
            v_quant = torch.zeros(quant_shape, dtype = torch.int, device = k.device)
            v_scale = torch.zeros(scale_shape, dtype = torch.half, device = k.device)
            ext.quant_cache_cont(k, k_quant, k_scale)
            ext.quant_cache_cont(v, v_quant, v_scale)
            k_kern = torch.empty_like(k)
            v_kern = torch.empty_like(v)
            ext.dequant_cache_cont(k_quant, k_scale, k_kern)
            ext.dequant_cache_cont(v_quant, v_scale, v_kern)
            test_qkv(f"Kernel {bits} bits", q, k_kern, v_kern, ref_o, ref_scores, False, False, False)

        # Reference (pure-torch equivalent of the 4-bit kernel)
        test_qkv(f"Kernel ref 4 bits",
            q,
            kernel_ref_quant(k, 4),
            kernel_ref_quant(v, 4),
            ref_o,
            ref_scores,
            False, False, False
        )

144
tests/test_kv_quant.py Normal file
View File

@@ -0,0 +1,144 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import torch
from exllamav3.ext import exllamav3_ext as ext
import random
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 200)
devices = [
"cuda:1"
]
page_size = 256
block_table_sizes = [(1,4), (1,8), (3, 4), (8,2)]
head_dims = [128, 64, 96, 32, 256]
num_kv_headss = [8, 2, 1]
cache_sizes = [32768]
bitss = [8] # Not testing accuracy, so 8-bit only to test the paging logic
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("block_table_size", block_table_sizes)
@pytest.mark.parametrize("head_dim", head_dims)
@pytest.mark.parametrize("num_kv_heads", num_kv_headss)
@pytest.mark.parametrize("cache_size", cache_sizes)
@pytest.mark.parametrize("bits", bitss)
@torch.inference_mode()
def test_kv_quant(device, block_table_size, head_dim, num_kv_heads, cache_size, bits):
    """
    End-to-end test of the paging logic in quant_cache_paged /
    dequant_cache_paged: fills the cache through several append patterns and
    a shuffled block table, verifying the fp16 round trip each time.
    """
    torch.manual_seed(0)
    bsz, pages = block_table_size
    block_table = torch.arange(bsz * pages, dtype = torch.int, device = device).view(bsz, pages)
    cache_seqlens = torch.zeros(size = (bsz,), dtype = torch.int, device = device)
    cache_shape = (cache_size // page_size, page_size, num_kv_heads, head_dim)
    cache_k_tensor = torch.zeros(cache_shape, dtype = torch.half, device = device)
    cache_v_tensor = torch.zeros(cache_shape, dtype = torch.half, device = device)
    cache_k_tensor_out = torch.zeros_like(cache_k_tensor)
    cache_v_tensor_out = torch.zeros_like(cache_v_tensor)
    # Packed layout: (bits) int32 words and one fp16 scale per 32 values
    qcache_shape = (cache_size // page_size, page_size, num_kv_heads * head_dim // 32 * bits)
    qscales_shape = (cache_size // page_size, page_size, num_kv_heads * head_dim // 32)
    cache_k_q = torch.zeros(qcache_shape, dtype = torch.int, device = device)
    cache_v_q = torch.zeros(qcache_shape, dtype = torch.int, device = device)
    cache_k_s = torch.zeros(qscales_shape, dtype = torch.half, device = device)
    cache_v_s = torch.zeros(qscales_shape, dtype = torch.half, device = device)

    # Quantize `length` new positions (from cache_seqlens) into the packed cache
    def q(length):
        ext.quant_cache_paged(
            cache_k_tensor,
            cache_k_q,
            cache_k_s,
            cache_v_tensor,
            cache_v_q,
            cache_v_s,
            cache_seqlens,
            block_table,
            page_size,
            length
        )

    # Dequantize everything up to cache_seqlens into the *_out tensors
    def dq():
        ext.dequant_cache_paged(
            cache_k_q,
            cache_k_s,
            cache_k_tensor_out,
            cache_v_q,
            cache_v_s,
            cache_v_tensor_out,
            cache_seqlens,
            block_table,
            page_size
        )

    # Check the round trip within 8-bit quantization tolerance
    def tq():
        torch.testing.assert_close(cache_k_tensor, cache_k_tensor_out, atol = 0.08, rtol = 0.01)
        torch.testing.assert_close(cache_v_tensor, cache_v_tensor_out, atol = 0.08, rtol = 0.01)

    # Put some stuff in cache
    for i in range(bsz):
        cache_seqlens[i] = i
        for h in range(num_kv_heads):
            cache_k_tensor[block_table[i, 0], i, h, :] = h
            cache_v_tensor[block_table[i, 0], i, h, :] = h + num_kv_heads
    q(1)
    for i in range(bsz):
        cache_seqlens[i] += 1
    dq()
    torch.cuda.synchronize()
    tq()

    # Put more stuff in the cache (random lengths, patterned values)
    new_cache_seqlens = torch.zeros_like(cache_seqlens)
    random.seed(0)
    for i in range(bsz):
        l = random.randint(10, pages * page_size - 2)
        new_cache_seqlens[i] = l
        for j in range(l):
            m = j % 13
            for h in range(num_kv_heads):
                cache_k_tensor[block_table[i, j // page_size], j % page_size, h, :] = h + m
                cache_v_tensor[block_table[i, j // page_size], j % page_size, h, :] = h + m + num_kv_heads
    cache_seqlens[:] = 0
    # NOTE(review): length argument is a 0-dim tensor here; presumably the
    # binding converts it to int — confirm
    q(new_cache_seqlens.amax())
    cache_seqlens.copy_(new_cache_seqlens)
    dq()
    torch.cuda.synchronize()
    tq()

    # Mess up pages: shuffle the block table to exercise nontrivial mappings
    block_table = block_table.flatten()[torch.randperm(block_table.numel())].view(block_table.shape)
    cache_k_q[:, :, :] = 0
    cache_v_q[:, :, :] = 0
    cache_k_s[:, :, :] = 0
    cache_v_s[:, :, :] = 0
    for i in range(bsz):
        l = new_cache_seqlens[i]
        for j in range(l):
            cache_k_tensor[block_table[i, j // page_size], j % page_size, :, :] += 1
            cache_v_tensor[block_table[i, j // page_size], j % page_size, :, :] += 1
    cache_seqlens[:] = 0
    q(new_cache_seqlens.amax())
    cache_seqlens.copy_(new_cache_seqlens)
    dq()
    torch.cuda.synchronize()
    tq()

    # Update five tokens past the current end of each sequence
    for i in range(bsz):
        l = cache_seqlens[i]
        for j in range(5):
            pos = l + j
            cache_k_tensor[block_table[i, pos // page_size], + pos % page_size, :, :] = 32 + j
            cache_v_tensor[block_table[i, pos // page_size], + pos % page_size, :, :] = 32 + j
    q(5)
    for i in range(bsz):
        cache_seqlens[i] += 5
    dq()
    tq()
    xx = 0  # NOTE(review): leftover debug statement