mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Add some code for evaluating FPx (not enabled)
This commit is contained in:
249
exllamav2/experimental/fpx.py
Normal file
249
exllamav2/experimental/fpx.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import gc
|
||||
|
||||
# From https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/quantization/utils/fp6_utils.py
|
||||
|
||||
def _n_ones(n: int) -> int:
|
||||
return (1 << n) - 1
|
||||
|
||||
EBITS_F32, MBITS_F32 = 8, 23
|
||||
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
|
||||
|
||||
_ONES_TABLE = [_n_ones(i) for i in range(8)]
|
||||
|
||||
def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
|
||||
"""Convert FP32 numbers to sub-byte floating point numbers with the given
|
||||
number of exponent and mantissa bits.
|
||||
Input: torch.Tensor of dtype torch.float
|
||||
Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
|
||||
in the least significant bits. e.g.
|
||||
fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
|
||||
fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
|
||||
Note: there are no special values (NaN, inf) support in this code. Values
|
||||
outside the representable range of FPx after rounding are clamped to the
|
||||
maximum FPx magnitude (sign is preserved).
|
||||
Code below is an adaptation of https://fburl.com/code/ciwofcg4
|
||||
Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501
|
||||
Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
|
||||
"""
|
||||
assert x.dtype == torch.float
|
||||
assert 1 + ebits + mbits <= 8
|
||||
|
||||
# calculate constants
|
||||
exp_bias = _n_ones(ebits - 1)
|
||||
max_int = _n_ones(ebits + mbits)
|
||||
sign_mask = 1 << (ebits + mbits)
|
||||
|
||||
# TODO document this better
|
||||
magic_adder = _n_ones(MBITS_F32 - mbits - 1)
|
||||
|
||||
# all E bits and M bits are 1s
|
||||
max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
|
||||
|
||||
# E bits = 1, M bits = 0
|
||||
min_normal = 2 ** (1 - exp_bias)
|
||||
|
||||
denorm_exp = (
|
||||
# exp bias conversion between formats
|
||||
(F32_EXP_BIAS - exp_bias)
|
||||
# mantissa length difference between formats
|
||||
+ (MBITS_F32 - mbits)
|
||||
# add one to encoded exponent for denormalized numbers
|
||||
+ 1
|
||||
)
|
||||
denorm_mask_int = denorm_exp << MBITS_F32
|
||||
|
||||
# reinterpret int32 as float32
|
||||
denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
|
||||
|
||||
# save the sign
|
||||
# Note that we have torch.uint32, but some ops like cpu bit shifts
|
||||
# do not work on it. So, we stay in int32.
|
||||
x = x.view(torch.int32)
|
||||
sign = x & 0x80000000
|
||||
|
||||
# set everything to positive, will add sign back at the end
|
||||
x = x ^ sign
|
||||
|
||||
# TODO: can the branch floating point comparisons below be done without
|
||||
# converting to float? probably but need to verify
|
||||
x = x.view(torch.float)
|
||||
|
||||
# rewrite saturate/denorm/norm branches without explicit data dependent
|
||||
# control flow, to be more compiler friendly
|
||||
saturate_mask = x >= max_normal
|
||||
denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
|
||||
normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
|
||||
|
||||
#
|
||||
# branch 1: saturate to max val - handled later in the code which combines
|
||||
# the branches
|
||||
#
|
||||
|
||||
#
|
||||
# branch 2: to conversion to denormal as well as rounding up to normal
|
||||
#
|
||||
denormal_x = x + denorm_mask_float
|
||||
denormal_x = denormal_x.view(torch.int32)
|
||||
denormal_x -= denorm_mask_int
|
||||
denormal_x = denormal_x.to(torch.uint8)
|
||||
|
||||
#
|
||||
# branch 3: stay in normal range, adjust the exponent and round
|
||||
#
|
||||
normal_x = x.view(torch.int32)
|
||||
# resulting mantissa is odd
|
||||
mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
|
||||
# update exponent, rounding bias part 1
|
||||
val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
|
||||
normal_x += val_to_add
|
||||
# rounding bias part 2
|
||||
normal_x += mant_odd
|
||||
# take the bits!
|
||||
normal_x = normal_x >> (MBITS_F32 - mbits)
|
||||
normal_x = normal_x.to(torch.uint8)
|
||||
|
||||
#
|
||||
# combine the branches
|
||||
#
|
||||
x = torch.full_like(x, max_int, dtype=torch.uint8)
|
||||
x = torch.where(denormal_mask, denormal_x, x)
|
||||
x = torch.where(normal_mask, normal_x, x)
|
||||
|
||||
# add sign back
|
||||
sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
|
||||
sign_lp = sign_lp.to(torch.uint8)
|
||||
# Right shift of a negative signed integer can fill the least significant
|
||||
# bits with either 1s or 0s, depending on the implementation. Since PyTorch
|
||||
# doesn't have an uint32 dtype, we mask out these bits to get just the
|
||||
# f4 sign bit
|
||||
sign_lp = sign_lp & sign_mask
|
||||
x = x | sign_lp
|
||||
|
||||
return x.to(torch.uint8)
|
||||
|
||||
|
||||
def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
|
||||
"""Convert sub-byte floating point numbers with the given number of exponent
|
||||
and mantissa bits to FP32.
|
||||
Input: torch.Tensor of dtype uint8, where the bit encoding is stored
|
||||
in the least significant bits. e.g.
|
||||
fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
|
||||
fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
|
||||
Output: torch.Tensor of dtype fp32 with the dequantized value
|
||||
"""
|
||||
assert x.dtype == torch.uint8
|
||||
assert 1 + ebits + mbits <= 8
|
||||
|
||||
sign_mask = 1 << (ebits + mbits)
|
||||
exp_bias = _n_ones(ebits - 1)
|
||||
mantissa_mask = _n_ones(mbits)
|
||||
|
||||
# save the sign
|
||||
sign_lp = x & sign_mask
|
||||
|
||||
# set everything to positive, will add sign back at the end
|
||||
x_pos = x ^ sign_lp
|
||||
|
||||
#
|
||||
# 1. Calculate zero mask
|
||||
#
|
||||
zero_mask = x_pos == 0
|
||||
|
||||
#
|
||||
# 2. Calculate the denormal path mask
|
||||
#
|
||||
denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
|
||||
|
||||
#
|
||||
# 3. Calculate the normal path
|
||||
#
|
||||
|
||||
# calculate the new exponent and shift it to bits 2:9 of the result
|
||||
exp_biased_lp = x_pos >> mbits
|
||||
exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
|
||||
exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
|
||||
|
||||
# shift the mantissa to bits 10:32 of the result
|
||||
mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
|
||||
mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
|
||||
result = exp_biased_f32 | mantissa_f32
|
||||
|
||||
#
|
||||
# 4. Add the zero and denormal casts to the already casted normal path
|
||||
#
|
||||
result[zero_mask] = 0
|
||||
|
||||
denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
|
||||
|
||||
# fast path.
|
||||
# without this, performance for FP4_E2M1 is slower by 2x
|
||||
if mbits == 1:
|
||||
result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
|
||||
|
||||
else:
|
||||
# iterate over all possible values of mantissa
|
||||
# i=0, j=1
|
||||
# i=1, j=10,11
|
||||
# i=2, j=100,101,110,111
|
||||
# and so on
|
||||
for i in range(mbits):
|
||||
for mantissa_cmp in range(1 << i, 1 << (i+1)):
|
||||
# left shift mantissa until it overflows (create an implicit 1)
|
||||
# subtract exponent by the same amount
|
||||
left_shift = mbits - i
|
||||
mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
|
||||
exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
|
||||
|
||||
# we can update this in-place since the values won't overlap
|
||||
# torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
|
||||
# thus we use + instead of | here
|
||||
mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32
|
||||
|
||||
result = torch.where(denormal_mask, mantissa_lp_int32, result)
|
||||
|
||||
# add sign back
|
||||
sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
|
||||
result = result | sign_f32
|
||||
|
||||
return result.view(torch.float)
|
||||
|
||||
|
||||
def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> tuple[Tensor, Tensor]:
|
||||
# _n_ones() is not compatible with torch.compile() due to << operator
|
||||
# https://github.com/pytorch/pytorch/issues/119152
|
||||
# exp_bias = _n_ones(ebits - 1)
|
||||
# max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
|
||||
|
||||
# workaround: global lookup table
|
||||
exp_bias = _ONES_TABLE[ebits - 1]
|
||||
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
|
||||
|
||||
tensor = tensor.float()
|
||||
scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
|
||||
tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
|
||||
return tensor_fpx, scale.half()
|
||||
# tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits)
|
||||
# return tensor_tc_fpx, scale.half()
|
||||
|
||||
|
||||
def from_scaled_tc_fpx(fpx_unpacked: Tensor, ebits: int, mbits: int, scale = None) -> Tensor:
|
||||
# fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits)
|
||||
tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits)
|
||||
if scale is not None:
|
||||
tensor = tensor * scale.float().view(-1, 1)
|
||||
return tensor
|
||||
|
||||
|
||||
|
||||
def fpxify(tensor: torch.Tensor, exponent: int, mantissa: int) -> torch.Tensor:
|
||||
"""
|
||||
Convert to eXmY and back again
|
||||
"""
|
||||
|
||||
a = tensor.to("cuda:0").float()
|
||||
b, scale = to_scaled_tc_fpx(a, exponent, mantissa)
|
||||
c = from_scaled_tc_fpx(b, exponent, mantissa, scale)
|
||||
d = c.half().to(tensor.device)
|
||||
return d
|
||||
@@ -1989,6 +1989,7 @@ class ExLlamaV2DynamicJob:
|
||||
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
|
||||
|
||||
# Stop if we reach max_new_tokens
|
||||
# TODO: Auto-extend option
|
||||
|
||||
if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens:
|
||||
return emit(results, emit_eos = True, emit_held = True, eos_reason = "max_new_tokens")
|
||||
|
||||
@@ -9,6 +9,7 @@ from exllamav2.compat import safe_move_tensor
|
||||
from exllamav2.tensor_p import BROADCAST_VC
|
||||
from exllamav2.util import unpack_4bit, pack_4bit
|
||||
import gc
|
||||
from exllamav2.experimental.fpx import fpxify
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -170,6 +171,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
|
||||
elif isinstance(w, nn.Parameter):
|
||||
assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
|
||||
# w = nn.Parameter(fpxify(w.data, 2, 3), requires_grad = False)
|
||||
if self.normalize_unq:
|
||||
w = self.normalize(w)
|
||||
if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
|
||||
@@ -188,6 +190,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
if self.normalize_unq:
|
||||
w = self.normalize(w[0]), w[1]
|
||||
ww = w[0]
|
||||
# ww = nn.Parameter(fpxify(ww.data, 2, 3), requires_grad = False)
|
||||
wb = w[1]
|
||||
if self.padding > 0:
|
||||
ww = nn.Parameter(F.pad(ww.data, (0, 0, 0, self.padding)).contiguous())
|
||||
|
||||
Reference in New Issue
Block a user