from exllamav2.model import \
(
    ExLlamaV2Embedding,
    ExLlamaV2Attention,
    ExLlamaV2MLP,
    ExLlamaV2MoEMLP,
    ExLlamaV2Linear,
    ExLlamaV2RMSNorm,
    ExLlamaV2LayerNorm
)

from safetensors import safe_open
from safetensors.torch import save_file
from conversion.qparams import QParams, qparams_headoptions, qparams_attn, qparams_mlp, get_qparams_reduced
from conversion.adaptivegptq import AdaptiveGPTQ
import torch
from torch import nn
import os, time, math, json
import torch.nn.functional as F
import gc


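# Debug helper: prints a summary of all tensors currently tracked by the garbage
# collector, grouped by (size, dtype, device), to help track down VRAM leaks
# between layers. Left here as a debugging aid.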
def list_live_tensors():

    tensors = {}
    gc.collect()
    torch.cuda.empty_cache()

    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                d = str(obj.size()) + ", " + str(obj.dtype) + ", " + str(obj.device)
                if d in tensors.keys():
                    tensors[d] += 1
                else:
                    tensors[d] = 1
        except:
            pass

    print("-----------")
    for k, v in tensors.items():
        print(f"{v} : {k}")


# Quantize

def quant_linear(job: dict,
                 source: ExLlamaV2Linear,
                 lq: AdaptiveGPTQ,
                 qparams: dict,
                 drop = False):
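
    # Quantize one linear layer with AdaptiveGPTQ using the settings in qparams,
    # pack the result to out_tensor/<key>.safetensors, then reload the packed
    # tensors and verify that the dequantized weights match. Finally the source
    # module's weights are overwritten with the dequantized reconstruction, so
    # layers further down the model are calibrated against the quantization error.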

    qp = QParams.from_dict(qparams)
    print(f" -- Linear: {source.key} -> {qp.get_desc()}, {qp.bpw(source.linear.weight.T.shape):.2f} bpw")

    # Quantize

    lq.configure(qp.group_size, qp.bits, qp.bits_prop, qp.scale_bits)
    lq.quantize(keep_qweight = True, apply = True)

    # Pack and save quantized layer

    packed_dict = lq.pack(source.key, qp)
    tensorfile = os.path.join(job["out_dir"], "out_tensor/" + source.key + ".safetensors")
    save_file(packed_dict, tensorfile)

    # Drop buffers from quantizer to free VRAM

    if drop: lq.drop_buffers()

    # Reconstruct from packed layer

    recons_linear = ExLlamaV2Linear(source.model, source.key, source.in_features, source.out_features, source.has_bias)
    recons_linear.device_idx = source.device_idx
    recons_dict = {}
    recons_keys = ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]
    if source.has_bias: recons_keys += ["bias"]
    for k in recons_keys:
        recons_dict[k] = packed_dict[source.key + "." + k]
    recons_dict["q_perm"] = torch.argsort(recons_dict["q_invperm"]).to(torch.int)
    recons_linear.load(recons_dict)

    # Sanity test to ensure reconstructed matrix matches unpacked matrix
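    #
    # diff1 compares the module's weight matrix against the dequantized
    # reconstruction directly; diff2 pushes an identity matrix through
    # recons_linear.forward(), so the packed tensors are also checked through the
    # quantized matmul kernel itself.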

    quant_w = source.linear.weight.T
    recons_w = recons_linear.get_weight_tensor_dq()

    ident = torch.eye(recons_linear.in_features, dtype = torch.half).cuda()
    recons_w2 = recons_linear.forward(ident, force_cuda = True)

    recons_w2.sub_(quant_w)
    if recons_linear.has_bias: recons_w2.sub_(recons_dict["bias"])
    recons_w2.abs_()
    diff2 = torch.max(recons_w2)

    quant_w.sub_(recons_w)
    quant_w.abs_()
    diff1 = torch.max(quant_w)
    quant_w = None

    # TODO: Investigate why this might fail for the first QKV projections of certain models

    if diff1 > 0.05 or diff2 > 0.05:
        print(" ## Quantization error (2)")
        os._exit(0)
    elif diff1 > 0.01 or diff2 > 0.01:
        print(f" !! Warning, difference of ({diff1:.6f}, {diff2:.6f}) between unpacked and dequantized matrices")

    # Free reconstructed linear layer

    recons_linear.unload()

    # Apply reconstructed matrix to source layer

    source.linear.weight.data = recons_w.T


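# Quantize the four attention projections. Q, K and V all consume the same
# post-layernorm input, so the Hessian accumulated for q_proj is shared with
# k_proj and v_proj via reuse_h(); o_proj has its own Hessian collected from
# the attention output.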
def quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):

    quantizers["q_proj"].prepare()
    quantizers["k_proj"].reuse_h(quantizers["q_proj"])
    quantizers["v_proj"].reuse_h(quantizers["q_proj"])
    quantizers["o_proj"].prepare()

    quant_linear(job, module.q_proj, quantizers["q_proj"], strat["q_proj"])
    quant_linear(job, module.k_proj, quantizers["k_proj"], strat["k_proj"])
    quant_linear(job, module.v_proj, quantizers["v_proj"], strat["v_proj"])
    quant_linear(job, module.o_proj, quantizers["o_proj"], strat["o_proj"])


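# Quantize the MLP projections. gate_proj and up_proj share the post-layernorm
# input, so up_proj reuses the gate_proj Hessian; down_proj uses the pre-down
# activations. Each quantizer is deleted as soon as its projection is done to
# keep peak memory usage down.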
def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):

    quantizers["gate_proj"].prepare()
    quantizers["up_proj"].reuse_h(quantizers["gate_proj"])

    quant_linear(job, module.gate_proj, quantizers["gate_proj"], strat["gate_proj"])
    del quantizers["gate_proj"]
    quant_linear(job, module.up_proj, quantizers["up_proj"], strat["up_proj"])
    del quantizers["up_proj"]

    quantizers["down_proj"].prepare()

    quant_linear(job, module.down_proj, quantizers["down_proj"], strat["down_proj"])
    del quantizers["down_proj"]


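# Quantize a block-sparse MoE MLP. Every expert's w1/w3 projection sees the same
# post-layernorm hidden states, so the Hessian collected for w1.0 is reused for
# all other w1 and w3 experts; each expert's w2 (down) projection has its own
# Hessian built from only the tokens routed to that expert.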
def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):

    num_experts = module.model.config.num_experts

    quantizers["w1.0"].prepare()
    for i in range(num_experts):
        if i > 0: quantizers[f"w1.{i}"].reuse_h(quantizers["w1.0"])
        quantizers[f"w2.{i}"].prepare()
        quantizers[f"w3.{i}"].reuse_h(quantizers["w1.0"])

    for i in range(num_experts):
        quant_linear(job, module.w1[i], quantizers[f"w1.{i}"], strat["w1"])
        del quantizers[f"w1.{i}"]
        quant_linear(job, module.w3[i], quantizers[f"w3.{i}"], strat["w3"])
        del quantizers[f"w3.{i}"]
        quant_linear(job, module.w2[i], quantizers[f"w2.{i}"], strat["w2"])
        del quantizers[f"w2.{i}"]


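# Quantize the output head. The bitrate preset is selected from qparams_headoptions
# by job["head_bits"], and drop = True releases the quantizer's buffers right away
# since no further layers follow.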
def quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params):

    quantizers["lm_head"].prepare()

    qp = qparams_headoptions[job["head_bits"]]
    quant_linear(job, module, quantizers["lm_head"], qp.get_dict(), drop = True)


# def testc(module, states, target_states, norm, layers):
#
#     rows = len(states)
#     cols = states[0].shape[1]
#     dim = module.model.config.hidden_size
#
#     a_batch = torch.empty((rows * cols, dim), dtype = torch.float, device = "cuda:0")
#     b_batch = torch.empty((rows * cols, dim), dtype = torch.float, device = "cuda:0")
#
#     r = 0
#     for state, target_state in zip(states, target_states):
#         a = norm.forward(state.to("cuda:0"))
#         b = norm.forward(target_state.to("cuda:0"))
#         a_batch[r:r+cols] = a.view(-1, dim)
#         b_batch[r:r+cols] = b.view(-1, dim)
#         r += cols
#
#     # diff = F.mse_loss(b_batch, a_batch)
#     m_a = torch.mean(a_batch.abs(), dim = 0)
#     m_b = torch.mean(b_batch.abs(), dim = 0)
#     m_ab = m_b / m_a
#     # a_batch *= m_ab
#     # diff = F.mse_loss(b_batch, a_batch)
#     norm.weight.data *= m_ab
#
#     # s = torch.linalg.lstsq(a_batch, b_batch)
#     # s = s.solution
#     #
#     # for linear in layers:
#     #     m = torch.matmul(s, linear.linear.weight.data.T.float())
#     #     linear.linear.weight.data = nn.Parameter(m.T.half())
#
#     xx = 0


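# Main quantization pass: step through the model's modules in order, accumulate
# Hessians with a reference forward pass over the calibration hidden states,
# quantize each module's linear submodules, then re-run the quantized module to
# produce the hidden states fed to the next layer. Progress is checkpointed so
# an interrupted job can resume from job["q_last_module_idx"].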
@torch.inference_mode()
def quant(job, save_fn, model):

    snapshot_interval = 10
    temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
    states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
    strategy = job["strategy"]

    # Quantize

    if not "q_last_module_idx" in job:
        job["q_last_module_idx"] = 0

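    # Load the calibration hidden states produced so far (one "row.*" tensor per
    # calibration batch) from the working safetensors file.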
    hidden_states = []
    # hidden_i_states = []
    with safe_open(states_filename, framework = "pt", device = "cpu") as f:
        for k in sorted(f.keys()):
            if k.startswith("row"):
                hidden_states.append(f.get_tensor(k))
            # elif k.startswith("i_row"):
            #     hidden_i_states.append(f.get_tensor(k))

    index = job["q_last_module_idx"]
    while True:

        index += 1
        if index >= len(model.modules): break

        # Prepare module

        module = model.modules[index]
        module.load()

        print(f" -- Layer: {module.key} ({module.name})")

        # Create quantizers

        quantizers = {}

        if isinstance(module, ExLlamaV2Attention):
            mode = "self_attn"
            # if index > 1: testc(module, hidden_states, hidden_i_states, module.input_layernorm, [module.q_proj, module.k_proj, module.v_proj])
            quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear)
            quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear)
            quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear)
            quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear)

        elif isinstance(module, ExLlamaV2MLP):
            mode = "mlp"
            # testc(module, hidden_states, hidden_i_states, module.post_attention_layernorm, [module.gate_proj, module.up_proj])
            quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
            quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
            quantizers["down_proj"] = AdaptiveGPTQ(module.down_proj.linear)

        elif isinstance(module, ExLlamaV2MoEMLP):
            mode = "block_sparse_moe"
            for i in range(model.config.num_experts):
                quantizers[f"w1.{i}"] = AdaptiveGPTQ(module.w1[i].linear)
                quantizers[f"w3.{i}"] = AdaptiveGPTQ(module.w3[i].linear)
                quantizers[f"w2.{i}"] = AdaptiveGPTQ(module.w2[i].linear)

        elif isinstance(module, ExLlamaV2Linear):
            mode = "linear"
            assert module.key == "lm_head"
            quantizers["lm_head"] = AdaptiveGPTQ(module.linear)

        elif isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):
            mode = "norm"

        # Reference forward pass

        cache = None
        attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) if mode == "self_attn" else None

        target_states = []
        if mode == "block_sparse_moe":
            uncalibrated_experts = [0 for _ in range(model.config.num_experts)]

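        # Run each calibration row through the still-unquantized module with
        # intermediates = True. AdaptiveGPTQ.add_batch() accumulates the input
        # statistics (the GPTQ Hessian) for each linear projection from the
        # captured activations, and the module outputs are kept as reference
        # targets for measuring quantization error below.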
        for i in range(len(hidden_states)):

            x = hidden_states[i].to("cuda:0")
            outputs = module.forward(x, cache, attn_params, intermediates = True)

            # Hessians

            if mode == "self_attn":
                quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
                quantizers["o_proj"].add_batch(outputs["attn_output"])

            if mode == "mlp":
                quantizers["gate_proj"].add_batch(outputs["post_norm"])  # Reuse H for up_proj
                quantizers["down_proj"].add_batch(outputs["pre_down"])

            if mode == "block_sparse_moe":
                for j in range(model.config.num_experts):
                    if f"pre_down.{j}" in outputs:
                        quantizers[f"w1.{j}"].add_batch(outputs["post_norm"])
                        quantizers[f"w2.{j}"].add_batch(outputs[f"pre_down.{j}"])
                        if outputs[f"pre_down.{j}"].shape[0] < outputs["post_norm"].shape[0] / 10:
                            uncalibrated_experts[j] += 1
                    else:
                        uncalibrated_experts[j] += 1

            if mode == "linear":
                quantizers["lm_head"].add_batch(x)

            if mode != "linear":
                target_states.append(outputs["hidden_states"].to("cpu"))

        # For MoE layers, warn if any expert saw less than 10% of the tokens (or none
        # at all) in more than 10% of the calibration rows

        if mode == "block_sparse_moe":
            for j in range(model.config.num_experts):
                ue = uncalibrated_experts[j]
                if ue > len(hidden_states) * 0.10:
                    print(f" !! Warning: w2.{j} has less than 10% calibration for {ue}/{len(hidden_states)} rows")

        # Conversion

        if mode == "self_attn":
            strat = strategy[module.key + "." + mode]
            quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)

        if mode == "mlp":
            strat = strategy[module.key + "." + mode]
            quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)

        if mode == "block_sparse_moe":
            strat = strategy[module.key + "." + mode]
            quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)

        if mode == "linear":
            quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params)

        quantizers.clear()
        gc.collect()
        torch.cuda.empty_cache()

        # Post-quantization forward pass
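        #
        # Re-run the calibration rows through the now-quantized module. For hidden
        # layers the error metric is the relative Frobenius norm
        # ||y_quant - y_ref||_F / ||y_ref||_F averaged over rows; for the head it is
        # the perplexity of the calibration token ids under the quantized logits.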

        if mode == "linear":
            with safe_open(job["cal_filename"], framework = "pt", device = "cpu") as f:
                cal_ids = f.get_tensor("input_ids")

        rfn_sum = 0
        rfn_count = 0
        logprob_sum = 0.0
        logprob_count = 0

        q_states = []
        for i in range(len(hidden_states)):

            if mode != "linear":

                x = hidden_states[i].to("cuda:0")
                output = module.forward(x, cache, attn_params)
                q_states.append(output.to("cpu"))

                output = output[0].float()
                output_ref = target_states[i].to("cuda:0")
                output_ref = output_ref[0].float()

                rfn_sum += (torch.linalg.norm(output - output_ref, 'fro') / torch.linalg.norm(output_ref, 'fro')).item()
                rfn_count += 1

                output_ref = None
                output = None

            elif i < job["measurement_rows"]:

                x = hidden_states[i].to("cuda:0")
                output = module.forward(x, cache, attn_params)
                if module.padding > 0: output = output[:, :, :-module.padding]

                logits = output[:, :-1, :]
                logits = logits.float() + 1e-10
                target_ids = cal_ids[i:i+1, 1:].to("cuda:0")

                log_probs = F.log_softmax(logits, dim = -1)
                token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
                logprob_sum += token_log_probs.sum().item()
                logprob_count += target_ids.numel()

                output = None
                logits = None
                token_log_probs = None

        if mode != "linear":

            err = rfn_sum / rfn_count
            print(f" -- Module quantized, rfn_error: {err:1.6f}")

        else:

            mean_log_prob = logprob_sum / logprob_count
            perplexity = math.exp(-mean_log_prob)

            print(f" -- Module quantized, calibration perplexity (quant): {perplexity:.4f}")

        # Unload module

        module.unload()
        torch.cuda.empty_cache()

        # Advance
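        #
        # The quantized outputs become the hidden states for the next module, so each
        # layer is calibrated against the accumulated error of all previously
        # quantized layers rather than against the original full-precision states.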

        if mode != "linear":
            # hidden_i_states = hidden_states
            # hidden_states = target_states
            # hidden_states = [(x + y) / 2 for x, y in zip(target_states, q_states)]
            hidden_states = q_states
            q_states = None

        # Checkpoint
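        #
        # Every snapshot_interval modules (and after the last module) the current
        # hidden states are written to a temp file, the job is flagged invalid while
        # the on-disk state is inconsistent, the temp file then replaces the states
        # file, and the flag is cleared once q_last_module_idx has been updated.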

        if index % snapshot_interval == 0 or index == len(model.modules) - 1:

            if mode != "linear":
                save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}
                # save_dict.update( {f"i_row.{idx:05}": h for idx, h in enumerate(hidden_i_states)} )
                save_file(save_dict, temp_filename)
                save_dict = None

            job["invalid"] = True
            save_fn()

            if mode != "linear":
                os.replace(temp_filename, states_filename)

            job["q_last_module_idx"] = index

            del job["invalid"]
            save_fn()