# Mirror of https://github.com/turboderp-org/exllamav3.git
# (synced 2026-04-20 06:19:10 +00:00)
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pytest
import torch
from exllamav3 import Config, Model
from exllamav3.ext import exllamav3_ext as ext
from exllamav3.modules.quant.exl3_lib.quantize import quantize_tiles
from util import assert_close_mr
import torch.nn.functional as F
import torch.testing
import math

torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 200)
|
|
|
|
device = "cuda:2"
|
|
test_model = "/mnt/str/models/llama3.1-8b-instruct/hf/"
|
|
test_keys = [
|
|
"model.layers.0.self_attn.q_proj",
|
|
"model.layers.0.self_attn.k_proj",
|
|
"model.layers.0.self_attn.v_proj",
|
|
"model.layers.0.self_attn.o_proj",
|
|
"model.layers.0.mlp.up_proj",
|
|
"model.layers.0.mlp.gate_proj",
|
|
"model.layers.0.mlp.down_proj",
|
|
]
|
|
|
|
config = Config.from_directory(test_model)
|
|
model = Model.from_config(config)
|
|
|
|
max_mse_per_K = {
|
|
1: 0.3,
|
|
2: 0.1,
|
|
3: 0.1,
|
|
4: 0.1,
|
|
5: 0.1,
|
|
6: 0.07,
|
|
7: 0.05,
|
|
8: 0.04,
|
|
}
|
|
|
|
max_proxy_err_per_K = {
|
|
1: 0.5,
|
|
2: 0.1,
|
|
3: 0.05,
|
|
4: 0.01,
|
|
5: 0.005,
|
|
6: 0.005,
|
|
7: 0.005,
|
|
8: 0.005,
|
|
}
|
|
|
|
w_tol_per_K = {
|
|
1: (0.5, 0.5),
|
|
2: (0.1, 0.1),
|
|
3: (0.08, 0.08),
|
|
4: (0.06, 0.06),
|
|
5: (0.04, 0.04),
|
|
6: (0.03, 0.03),
|
|
7: (0.02, 0.02),
|
|
8: (0.02, 0.02),
|
|
}
|
|
|
|
|
|
@pytest.mark.parametrize("cb", [(False, False, 1.24371088), (True, False, 1.24371088), (False, True, 1.0)])
|
|
@pytest.mark.parametrize("batch_size", [1, 16, 17, 128])
|
|
@pytest.mark.parametrize("K", [1, 2, 3, 4, 5, 6, 7, 8])
|
|
@torch.inference_mode()
|
|
def test_encode(batch_size, K, cb):
|
|
|
|
torch.manual_seed(0)
|
|
mcg, mul1, scale = cb
|
|
in_tile = torch.randn((batch_size, 256), device = device) * scale
|
|
out_tile, out_idx = quantize_tiles(
|
|
in_tile,
|
|
{
|
|
"K": K,
|
|
"mcg": mcg,
|
|
"mul1": mul1,
|
|
}
|
|
)
|
|
|
|
# Test tail-biting
|
|
first_col = out_idx[:, 0].to(torch.int32) & 0xFFFF
|
|
last_col = out_idx[:, 255].to(torch.int32) & 0xFFFF
|
|
first_col = first_col >> K
|
|
last_col = last_col & ((1 << (16 - K)) - 1)
|
|
assert torch.equal(first_col, last_col)
|
|
|
|
# Test MSE
|
|
mse = F.mse_loss(in_tile / scale, out_tile / scale).item()
|
|
assert mse < max_mse_per_K[K]
|
|
|
|
|
|
@pytest.mark.parametrize("cb", [(False, False, 1.24371088), (True, False, 1.24371088), (False, True, 1.0)])
|
|
@pytest.mark.parametrize("batch_size", [1, 64])
|
|
@pytest.mark.parametrize("K", [1, 2, 3, 4, 5, 6, 7, 8])
|
|
@torch.inference_mode()
|
|
def test_encode_ideal(batch_size, K, cb):
|
|
|
|
# Create random, valid, tail-biting encoding
|
|
torch.manual_seed(0)
|
|
mcg, mul1, scale = cb
|
|
encoded = torch.randint(low = 0, high = 65535, size = (batch_size, 256), device = device)
|
|
for i in range(256):
|
|
x = encoded[:, i]
|
|
x = x & ((1 << K) - 1)
|
|
for shift in range(1, int(math.ceil(16 / K))):
|
|
j = (i + 256 - shift) % 256
|
|
y = encoded[:, j]
|
|
y = y & ((1 << K) - 1)
|
|
x = x | (y << (K * shift))
|
|
encoded[:, i] = x & 0xffff
|
|
encoded = encoded.to(torch.short)
|
|
|
|
# Decode
|
|
decoded = torch.empty_like(encoded, dtype = torch.float)
|
|
ext.decode(encoded, decoded, mcg, mul1)
|
|
|
|
# Should quantize with zero loss
|
|
out_tile, out_idx = quantize_tiles(
|
|
decoded,
|
|
{
|
|
"K": K,
|
|
"mcg": mcg,
|
|
"mul1": mul1,
|
|
}
|
|
)
|
|
torch.testing.assert_close(out_tile, decoded, rtol = 1e-6, atol = 1e-6)
|
|
|
|
|
|
@pytest.mark.parametrize("cb", [(False, False, 1.24371088), (True, False, 1.24371088), (False, True, 1.0)])
|
|
@pytest.mark.parametrize("K", [1, 2, 3, 4, 5, 6, 7, 8])
|
|
@pytest.mark.parametrize("test_key", test_keys)
|
|
@torch.inference_mode()
|
|
def test_quant_dequant(K, test_key, cb):
|
|
|
|
mcg, mul1, scale = cb
|
|
|
|
# Grab unquantized linear layer from model
|
|
linear = model.find_module(test_key)
|
|
linear.load(device = device)
|
|
|
|
# Forward some random data through the layer to capture Hessian
|
|
bsz = 2048
|
|
torch.manual_seed(0)
|
|
state = torch.randn((1, bsz, linear.in_features), dtype = torch.float16, device = device)
|
|
capture_H = {}
|
|
params = {
|
|
"attn_mode": "flash_attn_nc",
|
|
"capture": capture_H
|
|
}
|
|
rs = linear.prepare_for_device(state, params)
|
|
ref_out = linear.forward(rs, params)
|
|
|
|
# Copy the original weight since layer will be quantized in-place
|
|
weight_orig = linear.inner.get_weight_tensor().clone()
|
|
|
|
# Quantize the layer
|
|
quant_args = {
|
|
"K": K,
|
|
"seed": 1,
|
|
"apply_out_scales": None,
|
|
"mcg": mcg,
|
|
"mul1": mul1,
|
|
"devices": [device]
|
|
}
|
|
proxy_err, weight_q = linear.convert_exl3(capture_H[linear.qmap], quant_args, return_weight_q = True)
|
|
weight_q = weight_q.half()
|
|
|
|
# Test proxy_err
|
|
assert proxy_err < max_proxy_err_per_K[K]
|
|
|
|
# Test max absolute weight difference from original, allow for 1% outliers
|
|
rtol, atol = w_tol_per_K[K]
|
|
assert_close_mr(weight_q, weight_orig, rtol = rtol, atol = atol, mismatch_ratio = 0.01)
|
|
|
|
# Reconstruct from encoded/packed tensors. Some tolerance needed because the quantizer works in float32
|
|
# while reconstruction reverses the regularization in float16
|
|
weight_recons = linear.inner.get_weight_tensor()
|
|
assert_close_mr(weight_q, weight_recons, rtol = 1e-3, atol = 1e-3, mismatch_ratio = 0.001)
|
|
|
|
# Cleanup
|
|
linear.unload() |