mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-05-04 05:01:39 +00:00
374 lines
10 KiB
Python
374 lines
10 KiB
Python
import math
|
|
|
|
class QParams:
|
|
|
|
group_size: dict
|
|
bits: list
|
|
bits_prop: list
|
|
scale_bits: int
|
|
|
|
desc: str
|
|
|
|
def __init__(self, group_size, bits, bits_prop, scale_bits):
|
|
|
|
self.bits = bits
|
|
self.bits_prop = bits_prop
|
|
self.scale_bits = scale_bits
|
|
|
|
# Allow group size per bitrate
|
|
|
|
if isinstance(group_size, dict):
|
|
self.group_size = { int(b): g for b, g in group_size.items() }
|
|
elif isinstance(group_size, list):
|
|
assert len(group_size) == len(bits)
|
|
self.group_size = { b: g for g, b in zip(group_size, bits) }
|
|
else:
|
|
self.group_size = { b: group_size for b in bits }
|
|
|
|
self.desc = self.get_desc()
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
if len(set(self.group_size.values())) == 1:
|
|
_gs = str(list(self.group_size.values())[0])
|
|
else:
|
|
_gs = "[" + ", ".join(str(self.group_size[b]) for b in self.bits) + "]"
|
|
_b = "[" + ", ".join(str(b) for b in self.bits) + "]"
|
|
_bp = "[" + ", ".join(str(bp) for bp in self.bits_prop) + "]"
|
|
_sb = "4"
|
|
return "QParams(" + _gs + ", " + _b + ", " + _bp + ", " + _sb + ")"
|
|
|
|
|
|
def get_dict(self):
|
|
|
|
return { "group_size": self.group_size,
|
|
"bits": self.bits,
|
|
"bits_prop": self.bits_prop,
|
|
"scale_bits": self.scale_bits }
|
|
|
|
|
|
@staticmethod
|
|
def from_dict(qp_dict):
|
|
|
|
return QParams(qp_dict["group_size"],
|
|
qp_dict["bits"],
|
|
qp_dict["bits_prop"],
|
|
qp_dict["scale_bits"])
|
|
|
|
|
|
def total_bits(self, shape, bias_shape = None):
|
|
|
|
rows = shape[0]
|
|
columns = shape[1]
|
|
numel = rows * columns
|
|
|
|
groups = 0
|
|
remaining_rows = rows
|
|
bits_groups = []
|
|
for b, p in zip(self.bits, self.bits_prop):
|
|
gsz = self.group_size[b]
|
|
g = math.ceil(min(rows * p, remaining_rows) / gsz)
|
|
groups += g
|
|
remaining_rows -= g * gsz
|
|
bits_groups.append(g)
|
|
|
|
assert remaining_rows <= 0
|
|
|
|
total_bits = 0
|
|
tr = rows
|
|
|
|
for g, b in zip(bits_groups, self.bits):
|
|
|
|
r = self.group_size[b] * g
|
|
c = columns
|
|
if r > tr: r = tr
|
|
tr -= r
|
|
total_bits += r * c * b # q_weight
|
|
|
|
total_bits += groups * 16 # q_scale_max
|
|
total_bits += groups * (16 + 16) # q_groups
|
|
total_bits += groups * columns * self.scale_bits # q_scale
|
|
total_bits += rows * 32 # q_invperm
|
|
|
|
if bias_shape is not None:
|
|
bias_numel = 1
|
|
for d in bias_shape: bias_numel *= d
|
|
total_bits += 16 * d
|
|
|
|
return total_bits
|
|
|
|
|
|
def bpw(self, shape, bias_shape = None):
|
|
|
|
rows = shape[0]
|
|
columns = shape[1]
|
|
numel = rows * columns
|
|
|
|
if bias_shape is not None:
|
|
bias_numel = 1
|
|
for d in bias_shape: bias_numel *= d
|
|
numel += d
|
|
|
|
return self.total_bits(shape, bias_shape) / numel
|
|
|
|
|
|
def get_desc(self, filename = False):
|
|
|
|
s = ""
|
|
for b, p in zip(self.bits, self.bits_prop):
|
|
if s != "": s += ("__" if filename else "/")
|
|
g = self.group_size[b]
|
|
s += f"{p}{'___' if filename else ':'}{b}b_{g}g"
|
|
|
|
s += f" s{self.scale_bits}"
|
|
|
|
return s
|
|
|
|
|
|
# kernels require groupsize divisible by 32, except 2-bit groups which can be divisible by 16
|
|
|
|
qparams_attn = \
|
|
[
|
|
[
|
|
QParams(64, [3, 2], [0.05, 0.95], 4),
|
|
QParams(64, [3, 2], [0.05, 0.95], 4),
|
|
QParams(64, [3, 2], [0.05, 0.95], 4),
|
|
QParams(64, [3, 2], [0.05, 0.95], 4),
|
|
],
|
|
[
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.25, 0.75], 4),
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(64, [4, 3], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(128, [4], [1], 4),
|
|
QParams(128, [4], [1], 4),
|
|
QParams(128, [4], [1], 4),
|
|
QParams(128, [4], [1], 4),
|
|
],
|
|
[
|
|
QParams(128, [4], [1], 4),
|
|
QParams(128, [4], [1], 4),
|
|
QParams(64, [4], [1], 4),
|
|
QParams(128, [4], [1], 4),
|
|
],
|
|
[
|
|
QParams(64, [4], [1], 4),
|
|
QParams(64, [4], [1], 4),
|
|
QParams(32, [4], [1], 4),
|
|
QParams(64, [4], [1], 4),
|
|
],
|
|
[
|
|
QParams(32, [4], [1], 4),
|
|
QParams(32, [4], [1], 4),
|
|
QParams(32, [4], [1], 4),
|
|
QParams(32, [4], [1], 4),
|
|
],
|
|
[
|
|
QParams(128, [5, 4], [0.1, 0.9], 4),
|
|
QParams(128, [5, 4], [0.1, 0.9], 4),
|
|
QParams(64, [5, 4], [0.1, 0.9], 4),
|
|
QParams(128, [5, 4], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(64, [5, 4], [0.1, 0.9], 4),
|
|
QParams(64, [5, 4], [0.1, 0.9], 4),
|
|
QParams(32, [5, 4], [0.1, 0.9], 4),
|
|
QParams(64, [5, 4], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(64, [5, 4], [0.1, 0.9], 4),
|
|
QParams(64, [5, 4], [0.1, 0.9], 4),
|
|
QParams(64, [5], [1], 4),
|
|
QParams(64, [5, 4], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(32, [5, 4], [0.1, 0.9], 4),
|
|
QParams(32, [5, 4], [0.1, 0.9], 4),
|
|
QParams(32, [5], [1], 4),
|
|
QParams(32, [5, 4], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(128, [6, 5], [0.1, 0.9], 4),
|
|
QParams(128, [6, 5], [0.1, 0.9], 4),
|
|
QParams(128, [6], [1], 4),
|
|
QParams(128, [6, 5], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(32, [6, 5], [0.1, 0.9], 4),
|
|
QParams(32, [6, 5], [0.1, 0.9], 4),
|
|
QParams(32, [6], [1], 4),
|
|
QParams(32, [6, 5], [0.1, 0.9], 4),
|
|
],
|
|
[
|
|
QParams(128, [6], [1], 4),
|
|
QParams(128, [6], [1], 4),
|
|
QParams(128, [6], [1], 4),
|
|
QParams(128, [6], [1], 4),
|
|
],
|
|
[
|
|
QParams(32, [6], [1], 4),
|
|
QParams(32, [6], [1], 4),
|
|
QParams(32, [8], [1], 4),
|
|
QParams(32, [6], [1], 4),
|
|
],
|
|
[
|
|
QParams(128, [8], [1], 4),
|
|
QParams(128, [8], [1], 4),
|
|
QParams(128, [8], [1], 4),
|
|
QParams(128, [8], [1], 4),
|
|
],
|
|
]
|
|
|
|
qparams_mlp = \
|
|
[
|
|
[
|
|
QParams(64, [3, 2], [0.05, 0.95], 4),
|
|
QParams(64, [3, 2], [0.05, 0.95], 4),
|
|
QParams([32, 64, 64], [6, 3, 2], [0.05, 0.2, 0.75], 4),
|
|
],
|
|
[
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.25, 0.75], 4),
|
|
QParams([32, 64, 64], [6, 3, 2], [0.05, 0.2, 0.75], 4),
|
|
],
|
|
[
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.3, 0.7], 4),
|
|
QParams(32, [5, 3], [0.05, 0.95], 4),
|
|
],
|
|
[
|
|
QParams(64, [3, 2], [0.1, 0.9], 4),
|
|
QParams(64, [3, 2], [0.3, 0.7], 4),
|
|
QParams(32, [5, 4], [0.05, 0.95], 4),
|
|
],
|
|
[
|
|
QParams(128, [4, 3], [0.1, 0.9], 4),
|
|
QParams(128, [4, 3], [0.25, 0.75], 4),
|
|
QParams([32, 128, 128], [8, 4, 3], [0.05, 0.1, 0.85], 4),
|
|
],
|
|
[
|
|
QParams(32, [4, 3], [0.1, 0.9], 4),
|
|
QParams(32, [4, 3], [0.25, 0.75], 4),
|
|
QParams([32, 32, 32], [8, 4, 3], [0.05, 0.1, 0.85], 4),
|
|
],
|
|
[
|
|
QParams(32, [4, 3], [0.1, 0.9], 4),
|
|
QParams(32, [4, 3], [0.25, 0.75], 4),
|
|
QParams([32, 128], [8, 4], [0.05, 0.95], 4),
|
|
],
|
|
[
|
|
QParams(128, [4], [1], 4),
|
|
QParams(32, [4], [1], 4),
|
|
QParams([32, 128], [8, 4], [0.05, 0.95], 4),
|
|
],
|
|
[
|
|
QParams(32, [4], [1], 4),
|
|
QParams(32, [4], [1], 4),
|
|
QParams([32, 32], [8, 4], [0.05, 0.95], 4),
|
|
],
|
|
[
|
|
QParams(128, [5, 4], [0.1, 0.9], 4),
|
|
QParams(128, [5, 4], [0.25, 0.75], 4),
|
|
QParams([32, 128, 128], [8, 5, 4], [0.05, 0.1, 0.85], 4),
|
|
],
|
|
[
|
|
QParams(32, [5, 4], [0.1, 0.9], 4),
|
|
QParams(32, [5, 4], [0.25, 0.75], 4),
|
|
QParams([32, 32, 32], [8, 5, 4], [0.05, 0.1, 0.85], 4),
|
|
],
|
|
[
|
|
QParams(128, [6, 5], [0.1, 0.9], 4),
|
|
QParams(128, [6, 5], [0.25, 0.75], 4),
|
|
QParams([32, 128, 128], [8, 6, 5], [0.05, 0.1, 0.85], 4),
|
|
],
|
|
[
|
|
QParams(32, [6, 5], [0.1, 0.9], 4),
|
|
QParams(32, [6, 5], [0.25, 0.75], 4),
|
|
QParams([32, 32, 32], [8, 6, 5], [0.05, 0.1, 0.85], 4),
|
|
],
|
|
[
|
|
QParams(128, [6], [1], 4),
|
|
QParams(128, [6], [1], 4),
|
|
QParams([32, 128], [8, 6], [0.05, 0.95], 4),
|
|
],
|
|
[
|
|
QParams(128, [8, 6], [0.1, 0.9], 4),
|
|
QParams(128, [8, 6], [0.1, 0.9], 4),
|
|
QParams(128, [8, 6], [0.15, 0.85], 4),
|
|
],
|
|
[
|
|
QParams(128, [8, 6], [0.1, 0.9], 4),
|
|
QParams(128, [8, 6], [0.1, 0.9], 4),
|
|
QParams(128, [8], [1], 4),
|
|
],
|
|
[
|
|
QParams(128, [8], [1], 4),
|
|
QParams(128, [8], [1], 4),
|
|
QParams(128, [8], [1], 4),
|
|
]
|
|
]
|
|
|
|
qparams_headoptions = \
|
|
{
|
|
2: QParams(32, [4, 2], [0.3, 0.7], 4),
|
|
3: QParams(32, [4, 3], [0.15, 0.85], 4),
|
|
4: QParams(32, [6, 4], [0.15, 0.85], 4),
|
|
5: QParams(128, [6, 5], [0.15, 0.85], 4),
|
|
6: QParams(128, [8, 6], [0.15, 0.85], 4),
|
|
8: QParams(128, [8], [1.0], 4),
|
|
# 16: None
|
|
}
|
|
|
|
def get_qparams_reduced(options, ignore_gate = False):
|
|
|
|
num_options = len(options)
|
|
dim = len(options[0])
|
|
assert all(len(o) == dim for o in options)
|
|
|
|
desc_to_idx = [{} for _ in range(dim)]
|
|
idx_to_qp = [[] for _ in range(dim)]
|
|
maps = []
|
|
|
|
for o in options:
|
|
m = []
|
|
for idx, qp in enumerate(o):
|
|
if ignore_gate and idx == 0: continue
|
|
desc = qp.get_desc()
|
|
if desc not in desc_to_idx[idx]:
|
|
j = len(idx_to_qp[idx])
|
|
desc_to_idx[idx][desc] = j
|
|
idx_to_qp[idx].append(qp)
|
|
else:
|
|
j = desc_to_idx[idx][desc]
|
|
m.append(j)
|
|
maps.append(m)
|
|
|
|
return idx_to_qp, maps
|