Commit: Add cache quantization
@@ -6,7 +6,7 @@ This is an **early preview release** of ExLlamaV3. Please note:
 - The framework is not yet fully optimized. Performance is lacking, especially on Ampere, and there may be a significant CPU bottleneck on slower processors until the extension functions are fully built out.
 - AMD GPUs (ROCm) are not yet supported.
 - [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) is currently required. I hope to switch over to [FlashInfer](https://github.com/flashinfer-ai/flashinfer/tree/main) in time, but there are some obstacles to overcome first.
-- A number of important features are yet to be added, such as cache quantization, tensor parallelism and multimodal support.
+- A number of important features are yet to be added, such as tensor parallelism and multimodal support.
 - There are no release builds yet.
 - Integration into [TabbyAPI](https://github.com/theroyallab/tabbyAPI/) is planned when all the core functionality is in place.
@@ -26,7 +26,6 @@ There's much that still needs to be added and/or ported over from ExLlamaV2. I'v
 - Samplers (most notably repetition penalties and min-P are missing)
 - Constrained sampling (JSON filters etc.)
 - Multimodal support
-- Cache quantization
 - LoRA support
 - ROCm support
 - Tensor-parallel inference
BIN doc/cq_humaneval.png (new binary file, 60 KiB)
Binary file not shown.
exllamav3/cache/__init__.py
@@ -1,2 +1,3 @@
 from .cache import Cache, CacheLayer
 from .fp16 import CacheLayer_fp16
+from .quant import CacheLayer_quant
exllamav3/cache/cache.py
@@ -16,6 +16,7 @@ class CacheLayer(ABC):
         config: Config,
         attention: Attention,
         max_num_tokens: int,
+        **kwargs
     ):
         self.config = config
         self.attention = attention
@@ -30,7 +31,18 @@ class CacheLayer(ABC):
         pass

     @abstractmethod
-    def get_kv(self):
+    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
         pass

+    @abstractmethod
+    def update_kv(
+        self,
+        cache_seqlens: torch.Tensor,
+        block_table: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        length: int
+    ):
+        pass
+
     @abstractmethod
@@ -45,6 +57,7 @@ class Cache:
         model: Model,
         max_num_tokens: int,
         layer_type: Type[CacheLayer] | None = None,
+        **kwargs
     ):
         """
         Create cache for model
@@ -71,7 +84,7 @@

         self.num_layers = len(self.model.get_cache_layers())
         self.layers = [
-            self.layer_type(self.config, attn, self.max_num_tokens)
+            self.layer_type(self.config, attn, self.max_num_tokens, **kwargs)
             for attn in self.model.get_cache_layers()
         ]
         self.attach_to_model()
@@ -107,8 +120,20 @@
             module.cache_layers.remove(layer)


-    def get_layer(self, idx: int) -> tuple:
-        return self.layers[idx].get_kv()
+    def get_layer(self, idx: int, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
+        return self.layers[idx].get_kv(cache_seqlens, block_table)
+
+
+    def update_layer(
+        self,
+        idx: int,
+        cache_seqlens: torch.Tensor,
+        block_table: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        length: int
+    ):
+        return self.layers[idx].update_kv(cache_seqlens, block_table, k, v, length)


     def copy_page(
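The net effect of these interface changes is that extra per-layer arguments now flow from the `Cache` constructor through `**kwargs` into each layer. A minimal usage sketch, with argument names taken from this diff (model loading elided):

```python
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant

# Default fp16 cache, as before
cache = Cache(model, max_num_tokens = 8192, layer_type = CacheLayer_fp16)

# Quantized cache: k_bits / v_bits are forwarded via **kwargs to each CacheLayer_quant
cache = Cache(
    model,
    max_num_tokens = 8192,          # must be a multiple of PAGE_SIZE (256)
    layer_type = CacheLayer_quant,
    k_bits = 4,
    v_bits = 4
)
```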
exllamav3/cache/fp16.py
@@ -47,16 +47,24 @@ class CacheLayer_fp16(CacheLayer):


     @override
-    def get_kv(self):
+    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
         return self.k, self.v


+    @override
+    def update_kv(
+        self,
+        cache_seqlens: torch.Tensor,
+        block_table: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        length: int
+    ):
+        pass
+
+
     @override
     def copy_page(self, source: CacheLayer_fp16, from_page: int, to_page: int, num_tokens: int):
         assert self.shape == source.shape
-        kd = self.k[to_page, :num_tokens, :, :]
-        vd = self.v[to_page, :num_tokens, :, :]
-        ks = source.k[from_page, :num_tokens, :, :]
-        vs = source.v[from_page, :num_tokens, :, :]
-        kd.copy_(ks, non_blocking = True)
-        vd.copy_(vs, non_blocking = True)
+        self.k[to_page, :num_tokens, :, :].copy_(source.k[from_page, :num_tokens, :, :], non_blocking = True)
+        self.v[to_page, :num_tokens, :, :].copy_(source.v[from_page, :num_tokens, :, :], non_blocking = True)
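The fp16 layer's `update_kv` appears to be deliberately a no-op: its `get_kv` returns the cache tensors themselves, so the attention function writes new keys/values into them in place and there is nothing to copy back. The quantized layer below is what gives this hook real work to do.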
exllamav3/cache/quant.py (new file)
@@ -0,0 +1,100 @@
from __future__ import annotations
from typing_extensions import override
import torch
import torch.nn.functional as F
from torch import nn
from ..constants import PAGE_SIZE
from ..models import Model, Config
from .cache import CacheLayer
from typing import TYPE_CHECKING
from exllamav3.ext import exllamav3_ext as ext
if TYPE_CHECKING:
    from ..modules import Attention


class CacheLayer_quant(CacheLayer):

    def __init__(
        self,
        config: Config,
        attention: Attention,
        max_num_tokens: int,
        k_bits: int,
        v_bits: int,
    ):
        super().__init__(config, attention, max_num_tokens)

        assert max_num_tokens % PAGE_SIZE == 0, \
            f"max_num_tokens must be a multiple of {PAGE_SIZE}."
        assert (2 <= k_bits <= 8) and (2 <= v_bits <= 8), "quantized cache must be from 2 to 8 bits"

        self.shape = (
            (max_num_tokens // PAGE_SIZE, PAGE_SIZE, attention.num_kv_heads, attention.head_dim)
            if attention else None
        )

        self.k_bits = k_bits
        self.v_bits = v_bits
        self.token_dim = attention.num_kv_heads * attention.head_dim
        self.qshape_k = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32 * k_bits) if attention else None)
        self.qshape_v = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32 * v_bits) if attention else None)
        self.qshape_s = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32) if attention else None)

        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None
        self.device = None


    @override
    def alloc(self, device: torch.device):
        self.device = device
        self.qk = torch.zeros(self.qshape_k, dtype = torch.int, device = device) if self.shape else None
        self.qv = torch.zeros(self.qshape_v, dtype = torch.int, device = device) if self.shape else None
        self.sk = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None
        self.sv = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None


    @override
    def free(self):
        self.device = None
        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None


    @override
    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor):
        k = torch.empty(self.shape, dtype = torch.half, device = self.device)
        v = torch.empty(self.shape, dtype = torch.half, device = self.device)
        ext.dequant_cache_paged(self.qk, self.sk, k, self.qv, self.sv, v, cache_seqlens, block_table, PAGE_SIZE)
        return k, v


    @override
    def update_kv(
        self,
        cache_seqlens: torch.Tensor,
        block_table: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        length: int
    ):
        ext.quant_cache_paged(
            k, self.qk, self.sk,
            v, self.qv, self.sv,
            cache_seqlens, block_table,
            PAGE_SIZE,
            length
        )


    @override
    def copy_page(self, source: CacheLayer_quant, from_page: int, to_page: int, num_tokens: int):
        assert self.qshape_k == source.qshape_k
        assert self.qshape_v == source.qshape_v
        self.qk[to_page, :num_tokens, :].copy_(source.qk[from_page, :num_tokens, :], non_blocking = True)
        self.qv[to_page, :num_tokens, :].copy_(source.qv[from_page, :num_tokens, :], non_blocking = True)
        self.sk[to_page, :num_tokens, :].copy_(source.sk[from_page, :num_tokens, :], non_blocking = True)
        self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)
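Rough storage arithmetic implied by these shapes (a sketch for illustration, not part of the commit): per token, the quantized cache stores `token_dim // 32 * bits` packed int32 words plus one fp16 scale per 32-element group, for K and for V each.

```python
def cache_bytes_per_token(token_dim: int, k_bits: int, v_bits: int) -> int:
    groups = token_dim // 32                  # 32-element quantization groups
    q_bytes = groups * (k_bits + v_bits) * 4  # packed int32 bitplane words, 4 bytes each
    s_bytes = groups * 2 * 2                  # one fp16 scale per group, for K and V
    return q_bytes + s_bytes

# Example: 8 KV heads x head_dim 128 -> token_dim 1024 (Llama-3.1-8B geometry)
print(cache_bytes_per_token(1024, 4, 4))  # 1152 bytes/token vs. 4096 for fp16 K+V
```

At 4 bits for both K and V this works out to roughly 28% of the fp16 cache footprint, with the scales accounting for the overhead beyond the raw 25%.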
@@ -21,6 +21,8 @@
 #include "generator/sampling_basic.cuh"
 #include "generator/gumbel.cuh"

+#include "cache/q_cache.cuh"
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("stloader_read", &stloader_read, "stloader_read");
@@ -56,4 +58,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)

     m.def("partial_strings_match", &partial_strings_match, "partial_strings_match");
     m.def("count_match_tensor", &count_match_tensor, "count_match_tensor");
+
+    m.def("quant_cache_cont", &quant_cache_cont, "quant_cache_cont");
+    m.def("dequant_cache_cont", &dequant_cache_cont, "dequant_cache_cont");
+    m.def("quant_cache_paged", &quant_cache_paged, "quant_cache_paged");
+    m.def("dequant_cache_paged", &dequant_cache_paged, "dequant_cache_paged");
 }
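The contiguous variants are the easiest way to exercise the new bindings in isolation; the science script further down uses exactly this pattern. A sketch of the round trip, with shape conventions from the comments in q_cache.cu:

```python
import torch
from exllamav3.ext import exllamav3_ext as ext

bits = 4
k = torch.randn(1, 2048, 8, 128, dtype = torch.half, device = "cuda")  # (..., head_dim)

k_quant = torch.zeros(k.shape[:-1] + (128 // 32 * bits,), dtype = torch.int, device = k.device)
k_scale = torch.zeros(k.shape[:-1] + (128 // 32,), dtype = torch.half, device = k.device)
ext.quant_cache_cont(k, k_quant, k_scale)

k_deq = torch.empty_like(k)
ext.dequant_cache_cont(k_quant, k_scale, k_deq)  # lossy round trip back to fp16
```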
exllamav3/exllamav3_ext/cache/q_cache.cu (new file)
@@ -0,0 +1,280 @@
#include "q_cache.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include "../util.h"
#include "../util.cuh"
#include <limits>
#include "../quant/codebook.cuh"
#include "q_cache_kernels.cuh"

/*
Quantize contiguous tensor

in: float16, shape (..., dim)
out: int32, shape (..., dim / 32 * bitrate)
out_scales: float16, shape (..., dim / 32)
*/

void quant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& out,
    const at::Tensor& out_scales
)
{
    const at::cuda::OptionalCUDAGuard device_guard(in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK_DTYPE(in, kHalf);
    TORCH_CHECK_DTYPE(out, kInt);
    TORCH_CHECK_DTYPE(out_scales, kHalf);

    int bsz = in.numel() / 32;
    int head_dim = in.size(-1);
    int head_blocks = head_dim / 32;
    TORCH_CHECK(head_dim == 32 * head_blocks, "head_dim must be a multiple of 32");
    int bits = out.size(-1) / head_blocks;
    TORCH_CHECK(out.numel() == bsz * bits, "out is wrong size");
    TORCH_CHECK(out_scales.numel() == bsz, "out_scales is wrong size");

    TORCH_CHECK(2 <= bits && bits <= 8, "Unsupported K/V cache bitrate");

    static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
    {
        constexpr int i = decltype(ic)::value;
        if (bits == i)
        {
            quant_cache_cont_kernel<i><<<bsz, 32, 0, stream>>>
            (
                (const half*) in.data_ptr(),
                (uint32_t*) out.data_ptr(),
                (half*) out_scales.data_ptr()
            );
        }
    });
    cuda_check(cudaPeekAtLastError());
}

/*
Dequantize contiguous tensor

in: int32, shape (..., dim / 32 * bitrate)
in_scales: float16, shape (..., dim / 32)
out: float16, shape (..., dim)
*/

void dequant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& in_scales,
    const at::Tensor& out
)
{
    const at::cuda::OptionalCUDAGuard device_guard(in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK_DTYPE(in, kInt);
    TORCH_CHECK_DTYPE(in_scales, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);

    int bsz = out.numel() / 32;
    int head_dim = out.size(-1);
    int head_blocks = head_dim / 32;
    TORCH_CHECK(head_dim == 32 * head_blocks, "head_dim must be a multiple of 32");
    int bits = in.size(-1) / head_blocks;
    TORCH_CHECK(in.numel() == bsz * bits, "in is wrong size");
    TORCH_CHECK(in_scales.numel() == bsz, "in_scales is wrong size");

    TORCH_CHECK(2 <= bits && bits <= 8, "Unsupported K/V cache bitrate");

    static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
    {
        constexpr int i = decltype(ic)::value;
        if (bits == i)
        {
            dequant_cache_cont_kernel<i><<<bsz, 32, 0, stream>>>
            (
                (const uint32_t*) in.data_ptr(),
                (const half*) in_scales.data_ptr(),
                (half*) out.data_ptr()
            );
        }
    });
    cuda_check(cudaPeekAtLastError());
}

/*
Quantize paged tensor

k_in, v_in: float16, shape (1, cache_size, dim)
k_out, v_out: int32, shape (1, cache_size, dim / 32 * bitrate)
k_out_scales, v_out_scales: float16, shape (1, cache_size, dim / 32)
cache_seqlens: int32, length of each sequence in batch, k_out and v_out are updated _from_ this point
block_table: int32, shape (bsz, blocks_per_seq)
page_size: 256
seq_len: number of positions (size: dim) to update from end of each sequence
*/

void quant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_out,
    const at::Tensor& k_out_scales,
    const at::Tensor& v_in,
    const at::Tensor& v_out,
    const at::Tensor& v_out_scales,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size,
    int seq_len
)
{
    const at::cuda::OptionalCUDAGuard device_guard(k_in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(k_in, kHalf);
    TORCH_CHECK_DTYPE(k_out, kInt);
    TORCH_CHECK_DTYPE(k_out_scales, kHalf);
    TORCH_CHECK_DTYPE(v_in, kHalf);
    TORCH_CHECK_DTYPE(v_out, kInt);
    TORCH_CHECK_DTYPE(v_out_scales, kHalf);
    TORCH_CHECK_SHAPES_FULL(k_in, v_in);
    TORCH_CHECK_SHAPES_FULL(k_out_scales, v_out_scales);

    int dim;
    if (k_in.dim() == 4)
        dim = k_in.size(2) * k_in.size(3);
    else if (k_in.dim() == 3)
        dim = k_in.size(2);
    else
        TORCH_CHECK(false, "paged cache must be 3D or 4D");

    int warps_per_token = dim / 32;
    TORCH_CHECK(dim == 32 * warps_per_token, "dim must be a multiple of 32");
    int tb_per_token = CEIL_DIVIDE(warps_per_token, MAX_WARPS);  // Threadblocks per token position
    int tb_usage = CEIL_DIVIDE(warps_per_token, tb_per_token);   // Number of warps to use per threadblock

    TORCH_CHECK(k_out.dim() == 3 && v_out.dim() == 3, "paged q.cache must have shape (num_pages, page_size, dim // 32 * bitrate)");
    int k_bits = k_out.size(2) / warps_per_token;
    int v_bits = v_out.size(2) / warps_per_token;

    int bsz = block_table.size(0);
    int blocks_per_seq = block_table.size(1);

    dim3 blocks(tb_per_token, seq_len, bsz);
    dim3 threads(32 * tb_usage);

    static_for_pack<2,3,4,5,6,7,8>([&](auto jc)
    {
        constexpr int j = decltype(jc)::value;
        static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
        {
            constexpr int i = decltype(ic)::value;
            if (k_bits == i && v_bits == j)
            {
                quant_cache_paged_kernel<i, j><<<blocks, threads, 0, stream>>>
                (
                    (const half*) k_in.data_ptr(),
                    (uint32_t*) k_out.data_ptr(),
                    (half*) k_out_scales.data_ptr(),
                    (const half*) v_in.data_ptr(),
                    (uint32_t*) v_out.data_ptr(),
                    (half*) v_out_scales.data_ptr(),
                    (const uint32_t*) cache_seqlens.data_ptr(),
                    (const uint32_t*) block_table.data_ptr(),
                    page_size,
                    blocks_per_seq,
                    dim
                );
            }
        });
    });
    cuda_check(cudaPeekAtLastError());
}

/*
Dequantize paged tensor

k_in, v_in: int32, shape (1, cache_size, dim / 32 * bitrate)
k_in_scales, v_in_scales: float16, shape (1, cache_size, dim / 32)
k_out, v_out: float16, shape (1, cache_size, dim)
cache_seqlens: int32, length of each sequence in batch, k_out and v_out are updated _up_to_ this point
block_table: int32, shape (bsz, blocks_per_seq)
page_size: 256
*/

void dequant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_in_scales,
    const at::Tensor& k_out,
    const at::Tensor& v_in,
    const at::Tensor& v_in_scales,
    const at::Tensor& v_out,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size
)
{
    const at::cuda::OptionalCUDAGuard device_guard(k_in.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(k_in, kInt);
    TORCH_CHECK_DTYPE(k_in_scales, kHalf);
    TORCH_CHECK_DTYPE(k_out, kHalf);
    TORCH_CHECK_DTYPE(v_in, kInt);
    TORCH_CHECK_DTYPE(v_in_scales, kHalf);
    TORCH_CHECK_DTYPE(v_out, kHalf);
    TORCH_CHECK_SHAPES_FULL(k_in_scales, v_in_scales);
    TORCH_CHECK_SHAPES_FULL(k_out, v_out);

    int dim;
    if (k_out.dim() == 4)
        dim = k_out.size(2) * k_out.size(3);
    else if (k_out.dim() == 3)
        dim = k_out.size(2);
    else
        TORCH_CHECK(false, "paged cache must be 3D or 4D");

    int warps_per_token = dim / 32;
    TORCH_CHECK(dim == 32 * warps_per_token, "dim must be a multiple of 32");

    int bsz = block_table.size(0);
    int pages_per_seq = block_table.size(1);
    int warps_per_seq = pages_per_seq * page_size * warps_per_token;

    int num_tb = CEIL_DIVIDE(32 * warps_per_seq, 1024);
    int num_threads = MIN(32 * warps_per_seq, 1024);
    dim3 blocks(num_tb, bsz);
    dim3 threads(num_threads);

    TORCH_CHECK(k_in.dim() == 3 && v_in.dim() == 3, "paged q.cache must have shape (num_pages, page_size, dim // 32 * bitrate)");
    int k_bits = k_in.size(2) / warps_per_token;
    int v_bits = v_in.size(2) / warps_per_token;

    static_for_pack<2,3,4,5,6,7,8>([&](auto jc)
    {
        constexpr int j = decltype(jc)::value;
        static_for_pack<2,3,4,5,6,7,8>([&](auto ic)
        {
            constexpr int i = decltype(ic)::value;
            if (k_bits == i && v_bits == j)
            {
                dequant_cache_paged_kernel<i, j><<<blocks, threads, 0, stream>>>
                (
                    (const uint32_t*) k_in.data_ptr(),
                    (const half*) k_in_scales.data_ptr(),
                    (half*) k_out.data_ptr(),
                    (const uint32_t*) v_in.data_ptr(),
                    (const half*) v_in_scales.data_ptr(),
                    (half*) v_out.data_ptr(),
                    (const uint32_t*) cache_seqlens.data_ptr(),
                    (const uint32_t*) block_table.data_ptr(),
                    page_size,
                    pages_per_seq,
                    warps_per_token
                );
            }
        });
    });
    cuda_check(cudaPeekAtLastError());
}
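The paged variants address the cache through a block table, the same indirection FlashAttention's paged KV cache uses. In Python terms, the physical slot of logical token `t` in sequence `b` works out to the following (a sketch of the index math in the kernels above):

```python
def physical_slot(block_table, b: int, t: int, page_size: int = 256) -> int:
    # block_table[b, j] maps the j-th logical page of sequence b to a physical page
    page = block_table[b][t // page_size]
    return page * page_size + t % page_size
```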
exllamav3/exllamav3_ext/cache/q_cache.cuh (new file)
@@ -0,0 +1,44 @@
#pragma once

#include <ATen/Tensor.h>

void quant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& out,
    const at::Tensor& out_scales
);

void dequant_cache_cont
(
    const at::Tensor& in,
    const at::Tensor& in_scales,
    const at::Tensor& out
);

void quant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_out,
    const at::Tensor& k_out_scales,
    const at::Tensor& v_in,
    const at::Tensor& v_out,
    const at::Tensor& v_out_scales,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size,
    int seq_len
);

void dequant_cache_paged
(
    const at::Tensor& k_in,
    const at::Tensor& k_in_scales,
    const at::Tensor& k_out,
    const at::Tensor& v_in,
    const at::Tensor& v_in_scales,
    const at::Tensor& v_out,
    const at::Tensor& cache_seqlens,
    const at::Tensor& block_table,
    int page_size
);
exllamav3/exllamav3_ext/cache/q_cache_kernels.cuh (new file)
@@ -0,0 +1,192 @@

__device__ inline float shuffle_had_fx32(float v, int lane_id)
{
    for (int i = 1; i < 32; i <<= 1)
    {
        float pv = __shfl_xor_sync(0xffffffff, v, i);
        uint32_t* vi = reinterpret_cast<uint32_t*>(&v);
        int32_t sfm = -static_cast<int16_t>(lane_id & i) >> 31;
        *vi ^= (sfm & 0x80000000);
        v = v + pv;
    }
    return v;
}

__device__ inline float shuffle_sum_fx32(float s)
{
    for (int i = 1; i < 32; i <<= 1)
        s += __shfl_xor_sync(0xffffffff, s, i);
    return s;
}

__device__ inline float shuffle_max_fx32(float s)
{
    for (int i = 1; i < 32; i <<= 1)
        s = fmaxf(s, __shfl_xor_sync(0xffffffff, s, i));
    return s;
}

template <int bits>
__device__ inline void quant_block
(
    const half* in,
    uint32_t* out,
    half* out_scales
)
{
    int t = threadIdx.x % 32;

    // Load, rotate and scale 32 values
    float v = __half2float(in[t]);
    v = shuffle_had_fx32(v, t);
    v *= 0.17678f;  // 0.17678 = 1 / sqrt(32)
    float s = shuffle_max_fx32(fabsf(v) + 1e-10);
    half sh = __float2half_rn(s);
    s = __half2float(sh);
    v /= s;

    // Quantize and clamp
    int m = (1 << (bits - 1));
    v *= __int2float_rn(m);
    int q = lrintf(v) + m;
    q = max(min((1 << bits) - 1, q), 0);

    // Pack bits
    register uint32_t bitplanes[bits];
    for (int i = 0, mask = 1; mask <= m; ++i, mask <<= 1)
        bitplanes[i] = __ballot_sync(0xffffffff, q & mask);

    // Write output
    if (t < bits)
        out[t] = bitplanes[t];
    if (t == bits)
        *out_scales = sh;
}

#define MAX_WARPS 32

template <int bits>
__device__ inline void dequant_block
(
    const uint32_t* in,
    const half* in_scales,
    half* out
)
{
    int t = threadIdx.x % 32;
    int warp_id = threadIdx.x / 32;

    // Load scale and bitplanes
    float s = __half2float(*in_scales);
    __shared__ uint32_t bitplanes[MAX_WARPS][bits];
    if (t < bits)
        bitplanes[warp_id][t] = in[t];
    __syncthreads();

    // Unpack bits
    int m = (1 << (bits - 1));
    uint32_t mask = 1 << t;
    int q = 0;
    for (int i = 0; i < bits; ++i)
        q |= ((bitplanes[warp_id][i] & mask) >> t) << i;

    // Dequantize
    float v = __int2float_rn(q - m);
    v /= __int2float_rn(m);

    // Scale and rotate
    v *= s;
    v = shuffle_had_fx32(v, t);
    v *= 0.17678f;  // 0.17678 = 1 / sqrt(32)

    // Store
    out[t] = __float2half(v);
}

template <int bits>
__global__ __launch_bounds__(1024)
void quant_cache_cont_kernel
(
    const half* __restrict__ in,
    uint32_t* __restrict__ out,
    half* __restrict__ out_scales
)
{
    in += 32 * blockIdx.x;
    out += bits * blockIdx.x;
    out_scales += blockIdx.x;
    quant_block<bits>(in, out, out_scales);
}

template <int bits>
__global__ __launch_bounds__(32)
void dequant_cache_cont_kernel
(
    const uint32_t* __restrict__ in,
    const half* __restrict__ in_scales,
    half* __restrict__ out
)
{
    in += bits * blockIdx.x;
    in_scales += blockIdx.x;
    out += 32 * blockIdx.x;
    dequant_block<bits>(in, in_scales, out);
}

template <int k_bits, int v_bits>
__global__ __launch_bounds__(1024)
void quant_cache_paged_kernel
(
    const half* __restrict__ k_in,
    uint32_t* __restrict__ k_out,
    half* __restrict__ k_out_scales,
    const half* __restrict__ v_in,
    uint32_t* __restrict__ v_out,
    half* __restrict__ v_out_scales,
    const uint32_t* __restrict__ cache_seqlens,
    const uint32_t* __restrict__ block_table,
    int page_size,
    int blocks_per_seq,
    int token_dim
)
{
    int batch_idx = blockIdx.z;
    int token_idx = blockIdx.y + cache_seqlens[batch_idx];
    int page_idx = token_idx / page_size;
    int token_pos = block_table[blocks_per_seq * batch_idx + page_idx] * page_size + (token_idx % page_size);
    int sub_pos = (token_pos * token_dim + blockDim.x * blockIdx.x + threadIdx.x) / 32;

    quant_block<k_bits>(k_in + sub_pos * 32, k_out + sub_pos * k_bits, k_out_scales + sub_pos);
    quant_block<v_bits>(v_in + sub_pos * 32, v_out + sub_pos * v_bits, v_out_scales + sub_pos);
}

template <int k_bits, int v_bits>
__global__ __launch_bounds__(1024)
void dequant_cache_paged_kernel
(
    const uint32_t* __restrict__ k_in,
    const half* __restrict__ k_in_scales,
    half* __restrict__ k_out,
    const uint32_t* __restrict__ v_in,
    const half* __restrict__ v_in_scales,
    half* __restrict__ v_out,
    const uint32_t* __restrict__ cache_seqlens,
    const uint32_t* __restrict__ block_table,
    int page_size,
    int pages_per_seq,
    int warps_per_token
)
{
    int batch_idx = blockIdx.y;
    int t_warp_id = (blockDim.x * blockIdx.x + threadIdx.x) / 32;
    int token_idx = t_warp_id / warps_per_token;
    int max_token_idx = cache_seqlens[batch_idx];
    if (token_idx >= max_token_idx) return;
    int page_idx = token_idx / page_size;
    int page_sub = t_warp_id % (warps_per_token * page_size);
    int mapped_page = block_table[batch_idx * pages_per_seq + page_idx];
    int addr = mapped_page * page_size * warps_per_token + page_sub;

    dequant_block<k_bits>(k_in + addr * k_bits, k_in_scales + addr, k_out + addr * 32);
    dequant_block<v_bits>(v_in + addr * v_bits, v_in_scales + addr, v_out + addr * 32);
}
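The packing format is bitplane-oriented: each warp quantizes 32 values, and `__ballot_sync` gathers bit `i` of all 32 lanes into a single uint32, so a 32-value block occupies exactly `bits` words. A plain-Python sketch of the same pack/unpack, for illustration only (not the kernel):

```python
def pack_bitplanes(q: list[int], bits: int) -> list[int]:
    # q: 32 unsigned values, each < 2**bits; lane index becomes the bit position
    return [sum(((q[lane] >> i) & 1) << lane for lane in range(32)) for i in range(bits)]

def unpack_bitplanes(planes: list[int], bits: int) -> list[int]:
    # Reassemble each lane's value from its bit in every plane
    return [sum(((planes[i] >> lane) & 1) << i for i in range(bits)) for lane in range(32)]

q = [x % 16 for x in range(32)]  # 4-bit example values
assert unpack_bitplanes(pack_bitplanes(q, 4), 4) == q
```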
@@ -1,4 +1,5 @@
 from . import Model, Config, Cache, Tokenizer
+from .cache import CacheLayer_fp16, CacheLayer_quant
 from argparse import ArgumentParser
 import torch

@@ -20,6 +21,7 @@ def add_args(

     if cache:
         parser.add_argument("-cs", "--cache_size", type = int, help = "Total cache size in tokens, default: 8192", default = 8192)
+        parser.add_argument("-cq", "--cache_quant", type = str, help = "Use quantized cache. Specify either kv_bits or k_bits,v_bits pair")

     # TODO:
     # parser.add_argument("-tp", "--tensor_parallel", action = "store_true", help = "Load in tensor-parallel mode")
@@ -69,7 +71,30 @@ def init(
     model = Model.from_config(config)

     # Cache
-    cache = Cache(model, max_num_tokens = args.cache_size) if "cache_size" in vars(args) else None
+    if "cache_size" in vars(args):
+        if args.cache_quant is not None:
+            split = [int(bits) for bits in args.cache_quant.split(",")]
+            if len(split) == 1:
+                k_bits = v_bits = split[0]
+            elif len(split) == 2:
+                k_bits, v_bits = tuple(split)
+            else:
+                raise ValueError("Specify either one or two bitrates for cache quantization")
+            cache = Cache(
+                model,
+                max_num_tokens = args.cache_size,
+                layer_type = CacheLayer_quant,
+                k_bits = k_bits,
+                v_bits = v_bits
+            )
+        else:
+            cache = Cache(
+                model,
+                max_num_tokens = args.cache_size,
+                layer_type = CacheLayer_fp16
+            )
+    else:
+        cache = None

     # Split
     if args.gpu_split is None or args.gpu_split == "auto":
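In practice, any script that goes through the shared `add_args`/`init` helpers picks up the new flag: `-cq 4` gives 4-bit keys and values, `-cq 4,8` gives 4-bit keys with 8-bit values, and omitting `-cq` keeps the fp16 cache.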
@@ -348,7 +348,7 @@ class Attention(Module):

         q, k = self.rope.apply(q, k, position, positions, position_ids, in_place = True)

-        cache_k, cache_v = cache.get_layer(self.layer_idx)
+        cache_k, cache_v = cache.get_layer(self.layer_idx, cache_seqlens, block_table)
         o = flash_attn_with_kvcache(
             q = q,
             k = k,
@@ -362,6 +362,7 @@ class Attention(Module):
             window_size = (self.sliding_window, self.sliding_window),
             softcap = self.logit_softcapping
         )
+        cache.update_layer(self.layer_idx, cache_seqlens, block_table, cache_k, cache_v, seqlen)
         o = o.view((bsz, seqlen, self.num_q_heads * self.head_dim))

         # TODO: Store updated cache layer
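Reading the two hunks together, the per-layer flow with a quantized cache appears to be: dequantize the layer's pages into scratch fp16 tensors, let FlashAttention append the new keys/values into that scratch buffer, then requantize the updated positions back into the packed cache. A sketch (the diff truncates the `flash_attn_with_kvcache` call, so the cache kwargs here are assumptions):

```python
# Inside Attention.forward, per layer (names from the diff)
cache_k, cache_v = cache.get_layer(self.layer_idx, cache_seqlens, block_table)  # dequant -> fp16 scratch
o = flash_attn_with_kvcache(
    q = q, k = k, v = v,
    k_cache = cache_k, v_cache = cache_v,  # assumed: scratch tensors passed as the paged KV cache
    cache_seqlens = cache_seqlens,
    block_table = block_table,
)
# Requantize only the seqlen newly appended positions
cache.update_layer(self.layer_idx, cache_seqlens, block_table, cache_k, cache_v, seqlen)
```

For the fp16 layer this degenerates to the old behavior, since `get_kv` hands back the real cache tensors and `update_kv` does nothing.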
science/kv_quant_exp.py (new file)
@@ -0,0 +1,314 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn.functional as F
from exllamav3 import Config, Model, Tokenizer
from exllamav3.modules import TransformerBlock
from exllamav3.util.hadamard import get_hadamard_dt
from datasets import load_dataset
from exllamav3.util.file import disk_lru_cache, disk_lru_cache_clear
from flash_attn import flash_attn_func
from ref_quant2 import quantquant
from exllamav3.ext import exllamav3_ext as ext
import math

torch.set_printoptions(precision = 8, sci_mode = False, linewidth = 200)

model_dir = "/mnt/str/models/llama3.1-8b-instruct/hf/"
device = "cuda:1"
target_layers = [0]
num_rows = 1

# Create input tensor
@disk_lru_cache("get_test_data")
def get_test_data():
    return "\n\n".join(
        load_dataset("wikitext", "wikitext-2-raw-v1", split = "test")
        ["text"]
    )

# Sample Q and K tensors from forward pass, Llama type model
@disk_lru_cache("sample_qkv")
def sample_qkv(_model_dir, _target_layers, _num_rows):

    # Load model
    config = Config.from_directory(_model_dir)
    model = Model.from_config(config)
    model.load(device, progressbar = True)
    tokenizer = Tokenizer.from_config(config)

    test_data = get_test_data()[:100000]
    eval_tokens = tokenizer.encode(test_data)
    eval_len = 2048
    eval_stride = 512
    num_tokens = eval_tokens.shape[-1]
    seqs = []
    for a in range(0, num_tokens - eval_len, eval_stride):
        b = a + eval_len
        seqs.append(eval_tokens[:, a:b])
        if len(seqs) >= num_rows:
            break
    input_ids = torch.cat(seqs, dim = 0)[:, :]

    _samples_qkv = []
    params = {}
    x = model.prepare_inputs(input_ids, params)
    for idx, module in enumerate(model.modules):
        params["prefill"] = (idx == model.last_kv_module_idx)
        x = module.prepare_for_device(x, params)
        if isinstance(module, TransformerBlock):
            block_idx = int(module.key.split(".")[-1])
            if block_idx > max(_target_layers):
                break
            if block_idx in _target_layers:
                # Pre-attn norm
                y = module.attn_norm.forward(x, params, out_dtype = torch.half)
                # Projections and RoPE
                attn = module.attn
                bsz, seqlen, _ = y.shape
                position, positions, position_ids = 0, None, None
                q, k, v = attn.project_qkv(y, params)
                q = q.view(bsz, seqlen, attn.num_q_heads, attn.head_dim)
                k = k.view(bsz, seqlen, attn.num_kv_heads, attn.head_dim)
                v = v.view(bsz, seqlen, attn.num_kv_heads, attn.head_dim)
                q, k = attn.rope.apply(q, k, position, positions, position_ids)
                # Sample right before dot product
                _samples_qkv.append((q, k, v))
        # Advance state
        x = module.forward(x, params)

    return _samples_qkv

samples_qkv = sample_qkv(model_dir, target_layers, num_rows)

# Get attention scores and output
def attn(q, k, v):
    bsz, q_len, n_heads_q, head_dim = q.shape
    _, k_len, n_heads_k, _ = k.shape
    gqa = n_heads_q // n_heads_k
    k_int = k.repeat_interleave(gqa, dim = 2)
    scores = torch.einsum('bqhd,bkhd->bhqk', q, k_int) / math.sqrt(head_dim)

    # Causal mask
    mask = torch.ones((k_len, k_len), dtype = torch.bool, device = q.device).triu(diagonal = 1)
    mask = mask[-q_len:, :]
    scores = scores.masked_fill_(mask, -65504.)

    # Now attention
    o = flash_attn_func(
        q = q,
        k = k,
        v = v,
        causal = True,
    )
    return o, scores

# Reference method
def int_quant(v, bits):
    m = 1 << (bits - 1)
    scales = torch.amax(v.abs(), dim = -1).unsqueeze(3)
    v = v / scales
    vq = (v * m).round().clamp(-m, m - 1)
    vq /= m
    vq *= scales
    return vq

# def quant_nf4(t):
#     scales = torch.amax(t.abs(), dim = -1).unsqueeze(3)
#     tq = t / scales
#     tqq = torch.empty_like(tq)
#     ext.test_nf4(tq, tqq)
#     tqq *= scales
#     return tqq

def quant_fp8(t):
    return t.to(torch.float8_e4m3fn).half()


# Kernel equiv reference
def kernel_ref_quant(v, bits):
    had32 = get_hadamard_dt(32, v.device, torch.half)
    w = v.view(-1, 32)
    m = 1 << (bits - 1)
    w = w @ had32 / math.sqrt(32)
    scales = torch.amax(w.abs(), dim = -1, keepdim = True).half()
    w = w / scales
    vq = (w * m).round().clamp(-m, m - 1)
    vq /= m
    vq *= scales
    vq = vq @ had32 / math.sqrt(32)
    vq = vq.view(v.shape)
    return vq

# KL divergence between softmax distributions
def kl_divergence_scores(s, s_prime, dim = -1, eps = 1e-8):
    alpha = F.softmax(s.float(), dim = dim)
    alpha_hat = F.softmax(s_prime.float(), dim = dim)
    kl_elementwise = alpha * (torch.log(alpha + eps) - torch.log(alpha_hat + eps))
    kl_per_item = kl_elementwise.sum(dim = dim)
    kl_mean = kl_per_item.mean()
    return kl_mean

# Normalized MSE
def nmse(o, o_prime):
    return (o - o_prime).square().mean() / o_prime.square().mean()


# Do stuff
def test_qkv(label, q, k, v, ref_o, ref_scores, q_rot = False, k_rot = False, v_rot = False):
    head_dim = q.shape[-1]
    had = get_hadamard_dt(head_dim, device, torch.half)
    if q_rot != k_rot: q = (q @ had) / math.sqrt(head_dim)
    if v_rot: v = (v @ had) / math.sqrt(head_dim)
    test_o, test_scores = attn(q, k, v)
    kld = kl_divergence_scores(test_scores, ref_scores)
    mse = nmse(test_o, ref_o)
    print(f"{label:26} weights_kld: {kld:.6f} output_nmse: {mse:.6f}")

with torch.inference_mode():

    head_dim = samples_qkv[0][0].shape[-1]
    had = get_hadamard_dt(head_dim, device, torch.half)

    for idx, (q, k, v) in zip(target_layers, samples_qkv):

        # Unquantized
        ref_o, ref_scores = attn(q, k, v)

        # Q4
        test_qkv(
            "Q4",
            q,
            int_quant(k, 4),
            int_quant(v, 4),
            ref_o,
            ref_scores
        )

        # Q6
        test_qkv(
            "Q6",
            q,
            int_quant(k, 6),
            int_quant(v, 6),
            ref_o,
            ref_scores
        )

        # Q8
        test_qkv(
            "Q8",
            q,
            int_quant(k, 8),
            int_quant(v, 8),
            ref_o,
            ref_scores
        )

        # Rotated Q4
        test_qkv(
            "Rot. Q4",
            q,
            int_quant((k @ had) / math.sqrt(head_dim), 4),
            int_quant((v @ had) / math.sqrt(head_dim), 4),
            ref_o,
            ref_scores,
            False, True, True
        )

        # Rotated Q6
        test_qkv(
            "Rot. Q6",
            q,
            int_quant((k @ had) / math.sqrt(head_dim), 6),
            int_quant((v @ had) / math.sqrt(head_dim), 6),
            ref_o,
            ref_scores,
            False, True, True
        )

        # Channel scales + rotated Q4
        psc_k = k.view(-1, k.shape[-2], k.shape[-1]).abs().mean(dim = 0)
        psc_v = v.view(-1, k.shape[-2], k.shape[-1]).abs().mean(dim = 0)
        test_qkv(
            "Rot. Q4 ch.scales",
            q,
            int_quant(((k / psc_k) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * psc_k,
            int_quant(((v / psc_v) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * psc_v,
            ref_o,
            ref_scores,
            False, False, False
        )

        # Channel scales + rotated Q4 RMS
        pscr_k = k.view(-1, k.shape[-2], k.shape[-1]).square().mean(dim = 0).sqrt()
        pscr_v = v.view(-1, k.shape[-2], k.shape[-1]).square().mean(dim = 0).sqrt()
        test_qkv(
            "Rot. Q4 ch.scales (RMS)",
            q,
            int_quant(((k / pscr_k) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * pscr_k,
            int_quant(((v / pscr_v) @ had) / math.sqrt(head_dim), 4) @ had / math.sqrt(head_dim) * pscr_v,
            ref_o,
            ref_scores,
            False, False, False
        )

        # Rotated Q4 + Q6
        test_qkv(
            "Rot. Q4+Q6",
            q,
            int_quant((k @ had) / math.sqrt(head_dim), 4),
            int_quant((v @ had) / math.sqrt(head_dim), 6),
            ref_o,
            ref_scores,
            False, True, True
        )

        # NF4
        # k_nf4 = quant_nf4(k)
        # v_nf4 = quant_nf4(v)
        # test_qkv("NF4", q, k_nf4, v_nf4, ref_o, ref_scores, False, False, False)

        # Rotated NF4
        # k_h = (k @ had) / math.sqrt(128)
        # v_h = (v @ had) / math.sqrt(128)
        # k_h_nf4 = quant_nf4(k_h)
        # v_h_nf4 = quant_nf4(v_h)
        # test_qkv("RNF4", q, k_h_nf4, v_h_nf4, ref_o, ref_scores, False, True, True)

        # FP8
        test_qkv(
            "FP8 e4m3",
            q,
            quant_fp8(k),
            quant_fp8(v),
            ref_o,
            ref_scores,
            False, False, False
        )

        # Kernel
        for bits in range(2, 9):
            quant_shape = k.shape[:-1] + (128 // 32 * bits,)
            scale_shape = k.shape[:-1] + (128 // 32,)
            k_quant = torch.zeros(quant_shape, dtype = torch.int, device = k.device)
            k_scale = torch.zeros(scale_shape, dtype = torch.half, device = k.device)
            v_quant = torch.zeros(quant_shape, dtype = torch.int, device = k.device)
            v_scale = torch.zeros(scale_shape, dtype = torch.half, device = k.device)
            ext.quant_cache_cont(k, k_quant, k_scale)
            ext.quant_cache_cont(v, v_quant, v_scale)
            k_kern = torch.empty_like(k)
            v_kern = torch.empty_like(v)
            ext.dequant_cache_cont(k_quant, k_scale, k_kern)
            ext.dequant_cache_cont(v_quant, v_scale, v_kern)
            test_qkv(f"Kernel {bits} bits", q, k_kern, v_kern, ref_o, ref_scores, False, False, False)

        # Reference
        test_qkv(f"Kernel ref 4 bits",
            q,
            kernel_ref_quant(k, 4),
            kernel_ref_quant(v, 4),
            ref_o,
            ref_scores,
            False, False, False
        )
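For reference, the two metrics printed by `test_qkv` restate the code in math, with $s$ the reference attention scores, $s'$ the scores from quantized K, and $o$, $o'$ the corresponding attention outputs:

```latex
\alpha = \mathrm{softmax}(s), \qquad \hat\alpha = \mathrm{softmax}(s'), \qquad
\mathrm{KLD} = \mathbb{E}\!\left[ \sum_k \alpha_k \big( \log(\alpha_k + \epsilon) - \log(\hat\alpha_k + \epsilon) \big) \right]

\mathrm{NMSE}(o, o') = \frac{\mathbb{E}\!\left[ (o - o')^2 \right]}{\mathbb{E}\!\left[ o'^2 \right]}
```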
tests/test_kv_quant.py (new file)
@@ -0,0 +1,144 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
import torch
from exllamav3.ext import exllamav3_ext as ext
import random

torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 200)

devices = [
    "cuda:1"
]

page_size = 256
block_table_sizes = [(1, 4), (1, 8), (3, 4), (8, 2)]
head_dims = [128, 64, 96, 32, 256]
num_kv_headss = [8, 2, 1]
cache_sizes = [32768]
bitss = [8]  # Not testing accuracy, so 8-bit only to test the paging logic

@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("block_table_size", block_table_sizes)
@pytest.mark.parametrize("head_dim", head_dims)
@pytest.mark.parametrize("num_kv_heads", num_kv_headss)
@pytest.mark.parametrize("cache_size", cache_sizes)
@pytest.mark.parametrize("bits", bitss)
@torch.inference_mode()
def test_kv_quant(device, block_table_size, head_dim, num_kv_heads, cache_size, bits):

    torch.manual_seed(0)

    bsz, pages = block_table_size

    block_table = torch.arange(bsz * pages, dtype = torch.int, device = device).view(bsz, pages)
    cache_seqlens = torch.zeros(size = (bsz,), dtype = torch.int, device = device)

    cache_shape = (cache_size // page_size, page_size, num_kv_heads, head_dim)
    cache_k_tensor = torch.zeros(cache_shape, dtype = torch.half, device = device)
    cache_v_tensor = torch.zeros(cache_shape, dtype = torch.half, device = device)
    cache_k_tensor_out = torch.zeros_like(cache_k_tensor)
    cache_v_tensor_out = torch.zeros_like(cache_v_tensor)

    qcache_shape = (cache_size // page_size, page_size, num_kv_heads * head_dim // 32 * bits)
    qscales_shape = (cache_size // page_size, page_size, num_kv_heads * head_dim // 32)
    cache_k_q = torch.zeros(qcache_shape, dtype = torch.int, device = device)
    cache_v_q = torch.zeros(qcache_shape, dtype = torch.int, device = device)
    cache_k_s = torch.zeros(qscales_shape, dtype = torch.half, device = device)
    cache_v_s = torch.zeros(qscales_shape, dtype = torch.half, device = device)


    def q(length):
        ext.quant_cache_paged(
            cache_k_tensor,
            cache_k_q,
            cache_k_s,
            cache_v_tensor,
            cache_v_q,
            cache_v_s,
            cache_seqlens,
            block_table,
            page_size,
            length
        )

    def dq():
        ext.dequant_cache_paged(
            cache_k_q,
            cache_k_s,
            cache_k_tensor_out,
            cache_v_q,
            cache_v_s,
            cache_v_tensor_out,
            cache_seqlens,
            block_table,
            page_size
        )

    def tq():
        torch.testing.assert_close(cache_k_tensor, cache_k_tensor_out, atol = 0.08, rtol = 0.01)
        torch.testing.assert_close(cache_v_tensor, cache_v_tensor_out, atol = 0.08, rtol = 0.01)

    # Put some stuff in cache
    for i in range(bsz):
        cache_seqlens[i] = i
        for h in range(num_kv_heads):
            cache_k_tensor[block_table[i, 0], i, h, :] = h
            cache_v_tensor[block_table[i, 0], i, h, :] = h + num_kv_heads
    q(1)
    for i in range(bsz):
        cache_seqlens[i] += 1
    dq()
    torch.cuda.synchronize()
    tq()

    # Put more stuff in the cache
    new_cache_seqlens = torch.zeros_like(cache_seqlens)
    random.seed(0)
    for i in range(bsz):
        l = random.randint(10, pages * page_size - 2)
        new_cache_seqlens[i] = l
        for j in range(l):
            m = j % 13
            for h in range(num_kv_heads):
                cache_k_tensor[block_table[i, j // page_size], j % page_size, h, :] = h + m
                cache_v_tensor[block_table[i, j // page_size], j % page_size, h, :] = h + m + num_kv_heads
    cache_seqlens[:] = 0
    q(new_cache_seqlens.amax())
    cache_seqlens.copy_(new_cache_seqlens)
    dq()
    torch.cuda.synchronize()
    tq()

    # Mess up pages
    block_table = block_table.flatten()[torch.randperm(block_table.numel())].view(block_table.shape)
    cache_k_q[:, :, :] = 0
    cache_v_q[:, :, :] = 0
    cache_k_s[:, :, :] = 0
    cache_v_s[:, :, :] = 0
    for i in range(bsz):
        l = new_cache_seqlens[i]
        for j in range(l):
            cache_k_tensor[block_table[i, j // page_size], j % page_size, :, :] += 1
            cache_v_tensor[block_table[i, j // page_size], j % page_size, :, :] += 1
    cache_seqlens[:] = 0
    q(new_cache_seqlens.amax())
    cache_seqlens.copy_(new_cache_seqlens)
    dq()
    torch.cuda.synchronize()
    tq()

    # Update five tokens
    for i in range(bsz):
        l = cache_seqlens[i]
        for j in range(5):
            pos = l + j
            cache_k_tensor[block_table[i, pos // page_size], pos % page_size, :, :] = 32 + j
            cache_v_tensor[block_table[i, pos // page_size], pos % page_size, :, :] = 32 + j
    q(5)
    for i in range(bsz):
        cache_seqlens[i] += 5
    dq()
    tq()
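Note that `devices` (like `device` in the science script above) is hard-coded to `"cuda:1"`; on a single-GPU machine, change it to `"cuda:0"` before running `pytest tests/test_kv_quant.py`.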