Add some code for evaluating FPx (not enabled)

This commit is contained in:
turboderp
2024-09-28 16:06:59 +02:00
parent d393bfe4a7
commit be3de0fa85
3 changed files with 253 additions and 0 deletions

View File

@@ -0,0 +1,249 @@
import torch
from torch import Tensor
import gc
# From https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/quantization/utils/fp6_utils.py
def _n_ones(n: int) -> int:
return (1 << n) - 1
# Width in bits of the exponent and mantissa fields of a float32.
EBITS_F32 = 8
MBITS_F32 = 23
# Exponent bias of float32: all ones in (EBITS_F32 - 1) bits, i.e. 127.
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
# Precomputed _n_ones(0) .. _n_ones(7), so callers can look the values up
# instead of evaluating a shift.
_ONES_TABLE = list(map(_n_ones, range(8)))
def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Quantize FP32 values to a sub-byte float format, one value per byte.

    Args:
        x: tensor of dtype ``torch.float``.
        ebits: number of exponent bits in the target format.
        mbits: number of mantissa bits in the target format.

    Returns:
        A ``torch.uint8`` tensor with the same shape as ``x``. Each byte holds
        one encoded value (sign, exponent, mantissa) in its
        ``1 + ebits + mbits`` least significant bits; the remaining high bits
        are zero. Numbering bits from the most significant one, that means:
            fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
            fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding

    Note:
        Special values (NaN, inf) are not supported. Magnitudes at or above
        the largest representable FPx value are clamped to that maximum, with
        the sign preserved. Rounding is round-to-nearest-even.

    Code below is an adaptation of https://fburl.com/code/ciwofcg4
    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
    """
    assert x.dtype == torch.float
    assert 1 + ebits + mbits <= 8

    # ---- constants derived from the target format ----
    mant_shift = MBITS_F32 - mbits  # mantissa bits dropped by the conversion
    exp_bias = _n_ones(ebits - 1)
    sign_mask = 1 << (ebits + mbits)
    max_int = _n_ones(ebits + mbits)  # encoding with all E and M bits set

    # TODO document this better
    magic_adder = _n_ones(mant_shift - 1)

    # largest magnitude: all E bits and M bits are 1s
    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
    # smallest normal magnitude: E bits = 1, M bits = 0
    min_normal = 2 ** (1 - exp_bias)

    # exponent of the FP32 constant used by the denormal branch:
    #   exp bias difference between formats
    # + mantissa length difference between formats
    # + one, because denormals are encoded with exponent field 0 but use the
    #   exponent of field 1
    denorm_exp = (F32_EXP_BIAS - exp_bias) + mant_shift + 1
    denorm_mask_int = denorm_exp << MBITS_F32
    # same bit pattern, reinterpreted as float32
    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)

    # ---- split off the sign ----
    # torch.uint32 exists, but some ops (e.g. CPU bit shifts) do not support
    # it, so the bit twiddling stays in int32.
    raw_bits = x.view(torch.int32)
    sign = raw_bits & 0x80000000
    # work on magnitudes only; the sign is OR-ed back in at the end
    # TODO: can the floating point comparisons below be done without
    # converting back to float? probably, but needs verification
    magnitude = (raw_bits ^ sign).view(torch.float)

    # ---- classify every element (no data dependent control flow, which is
    # friendlier to compilers) ----
    saturate_mask = magnitude >= max_normal
    denormal_mask = ~saturate_mask & (magnitude < min_normal)
    normal_mask = ~(saturate_mask | denormal_mask)

    # branch 1 (saturate to max value) needs no computation: max_int is the
    # default fill value when the branches are combined below.

    # ---- branch 2: conversion to denormal, including rounding up to the
    # smallest normal ----
    denormal_bits = (magnitude + denorm_mask_float).view(torch.int32) - denorm_mask_int
    denormal_x = denormal_bits.to(torch.uint8)

    # ---- branch 3: stay in the normal range, adjust exponent and round ----
    normal_bits = magnitude.view(torch.int32)
    # 1 where the truncated mantissa would be odd
    mant_odd = (normal_bits >> mant_shift) & 1
    # exponent re-bias plus rounding bias part 1 ...
    rebias_and_round = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
    # ... and rounding bias part 2 (ties go to even)
    normal_bits = normal_bits + rebias_and_round + mant_odd
    # keep the bits that survive the conversion
    normal_x = (normal_bits >> mant_shift).to(torch.uint8)

    # ---- combine the branches ----
    encoded = torch.full_like(magnitude, max_int, dtype=torch.uint8)
    encoded = torch.where(denormal_mask, denormal_x, encoded)
    encoded = torch.where(normal_mask, normal_x, encoded)

    # ---- add the sign back ----
    # Right-shifting a negative signed integer may fill the vacated high bits
    # with 1s or 0s depending on the implementation, so everything except the
    # FPx sign bit is masked out after the shift.
    sign_lp = (sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)).to(torch.uint8) & sign_mask
    return (encoded | sign_lp).to(torch.uint8)
def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Decode unpacked sub-byte floating point values to FP32.

    Args:
        x: ``torch.uint8`` tensor; each byte holds one encoded value (sign,
            exponent, mantissa) in its ``1 + ebits + mbits`` least significant
            bits. Numbering bits from the most significant one, that means:
                fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
                fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
        ebits: number of exponent bits in the source format.
        mbits: number of mantissa bits in the source format.

    Returns:
        A ``torch.float`` tensor of the same shape with the dequantized values.
    """
    assert x.dtype == torch.uint8
    assert 1 + ebits + mbits <= 8

    exp_bias = _n_ones(ebits - 1)
    mantissa_mask = _n_ones(mbits)
    sign_mask = 1 << (ebits + mbits)

    # Separate the sign from the magnitude; the sign is restored at the end.
    sign_lp = x & sign_mask
    magnitude = x ^ sign_lp

    # Classify elements: zero, denormal (exponent field 0, mantissa != 0),
    # everything else is normal.
    exp_field = magnitude >> mbits
    is_zero = magnitude == 0
    is_denormal = (magnitude > 0) & (exp_field == 0)

    # Normal path, computed for every element and patched afterwards:
    # re-bias the exponent and move it into the FP32 exponent field ...
    exp_f32 = (exp_field - exp_bias + F32_EXP_BIAS).to(torch.int32) << MBITS_F32
    # ... and left-align the mantissa inside the FP32 mantissa field.
    mantissa_lp_int32 = (magnitude & mantissa_mask).to(torch.int32)
    bits = exp_f32 | (mantissa_lp_int32 << (MBITS_F32 - mbits))

    # Zero path: overwrite with the all-zero bit pattern.
    bits.masked_fill_(is_zero, 0)

    # Denormal path.
    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
    if mbits == 1:
        # fast path: with one mantissa bit there is a single denormal value.
        # without this, performance for FP4_E2M1 is slower by 2x
        bits.masked_fill_(is_denormal, (denormal_exp_biased - mbits) << MBITS_F32)
    else:
        # Enumerate every non-zero mantissa value and map it to the FP32 bit
        # pattern of the same number expressed as a normal float.
        for mantissa_cmp in range(1, 1 << mbits):
            # index of the highest set bit of this mantissa value
            top_bit = mantissa_cmp.bit_length() - 1
            # shift left until the leading 1 falls off (it becomes the
            # implicit 1) and lower the exponent by the same amount
            left_shift = mbits - top_bit
            frac_f32 = (mantissa_cmp - (1 << top_bit)) << (left_shift + MBITS_F32 - mbits)
            exp_bits = (denormal_exp_biased - left_shift) << MBITS_F32
            # In-place is safe: the written patterns never equal any of the
            # small mantissa values still to be compared against.
            # torch.compile() may complain unsupported operand type(s) for |:
            # 'SymInt' and 'int', thus + is used instead of | here.
            mantissa_lp_int32.masked_fill_(mantissa_lp_int32 == mantissa_cmp, exp_bits + frac_f32)
        bits = torch.where(is_denormal, mantissa_lp_int32, bits)

    # Move the sign bit up to bit 31 and merge it in.
    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
    return (bits | sign_f32).view(torch.float)
def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> tuple[Tensor, Tensor]:
    """Quantize ``tensor`` to unpacked FPx with one scale per slice along dim 1.

    The absolute maximum is taken along dimension 1 and mapped onto the
    largest representable FPx magnitude before encoding.

    Args:
        tensor: input tensor; it is converted to float32 before scaling.
        ebits: number of exponent bits of the target format.
        mbits: number of mantissa bits of the target format.

    Returns:
        ``(codes, scale)`` where ``codes`` is the uint8 output of
        ``_f32_to_fpx_unpacked`` and ``scale`` is the scale converted to
        float16.
    """
    # _n_ones() is not compatible with torch.compile() due to << operator
    # https://github.com/pytorch/pytorch/issues/119152
    # so the equivalent of
    #   exp_bias = _n_ones(ebits - 1)
    #   max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
    # is evaluated through the global lookup table instead.
    exp_bias = _ONES_TABLE[ebits - 1]
    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))

    values = tensor.float()
    # The clamp keeps the scale non-zero for all-zero slices.
    absmax = values.abs().amax(1)
    scale = absmax.clamp(min=1e-12) / max_normal
    codes = _f32_to_fpx_unpacked(values / scale.view(-1, 1), ebits, mbits)

    # Bit-packing is not wired in; the unpacked codes are returned as-is.
    # tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits)
    # return tensor_tc_fpx, scale.half()
    return codes, scale.half()
def from_scaled_tc_fpx(fpx_unpacked: Tensor, ebits: int, mbits: int, scale = None) -> Tensor:
    """Decode unpacked FPx codes to float32, optionally applying a scale.

    Args:
        fpx_unpacked: uint8 tensor of unpacked FPx codes.
        ebits: number of exponent bits of the source format.
        mbits: number of mantissa bits of the source format.
        scale: optional tensor of scales; when given it is converted to
            float32, reshaped to a column and multiplied into the result.

    Returns:
        The decoded float32 tensor.
    """
    # Unpacking is not wired in; the input is expected to be unpacked already.
    # fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits)
    decoded = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits)
    if scale is None:
        return decoded
    column_scale = scale.float().view(-1, 1)
    return decoded * column_scale
def fpxify(
    tensor: torch.Tensor,
    exponent: int,
    mantissa: int,
    device: "torch.device | str" = "cuda:0",
) -> torch.Tensor:
    """Convert to eXmY and back again.

    Quantizes ``tensor`` to a scaled FPx format with ``exponent`` exponent
    bits and ``mantissa`` mantissa bits, dequantizes it straight away and
    returns the result, so the precision loss of that format can be evaluated.

    Args:
        tensor: input tensor. The quantizer reduces along dimension 1, so a
            2-D tensor is expected.
        exponent: number of exponent bits (the X in eXmY).
        mantissa: number of mantissa bits (the Y in eXmY).
        device: device on which the round trip is computed. Defaults to
            ``"cuda:0"``; pass e.g. ``"cpu"`` or ``tensor.device`` to run
            without that GPU.

    Returns:
        The round-tripped tensor as float16, on the device of ``tensor``.
    """
    work = tensor.to(device).float()
    codes, scale = to_scaled_tc_fpx(work, exponent, mantissa)
    restored = from_scaled_tc_fpx(codes, exponent, mantissa, scale)
    return restored.half().to(tensor.device)

View File

@@ -1989,6 +1989,7 @@ class ExLlamaV2DynamicJob:
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
# Stop if we reach max_new_tokens
# TODO: Auto-extend option
if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens:
return emit(results, emit_eos = True, emit_held = True, eos_reason = "max_new_tokens")

View File

@@ -9,6 +9,7 @@ from exllamav2.compat import safe_move_tensor
from exllamav2.tensor_p import BROADCAST_VC
from exllamav2.util import unpack_4bit, pack_4bit
import gc
from exllamav2.experimental.fpx import fpxify
from typing import TYPE_CHECKING
@@ -170,6 +171,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
elif isinstance(w, nn.Parameter):
assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
# w = nn.Parameter(fpxify(w.data, 2, 3), requires_grad = False)
if self.normalize_unq:
w = self.normalize(w)
if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
@@ -188,6 +190,7 @@ class ExLlamaV2Linear(ExLlamaV2Module):
if self.normalize_unq:
w = self.normalize(w[0]), w[1]
ww = w[0]
# ww = nn.Parameter(fpxify(ww.data, 2, 3), requires_grad = False)
wb = w[1]
if self.padding > 0:
ww = nn.Parameter(F.pad(ww.data, (0, 0, 0, self.padding)).contiguous())