mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-05-11 16:30:25 +00:00
313 lines
10 KiB
Python
313 lines
10 KiB
Python
from exllamav2.model import ExLlamaV2Embedding, ExLlamaV2Attention, ExLlamaV2MLP, ExLlamaV2Linear, ExLlamaV2RMSNorm
|
|
from safetensors import safe_open
|
|
from safetensors.torch import save_file
|
|
from conversion.qparams import QParams, qparams_headoptions, qparams_attn, qparams_mlp, get_qparams_reduced
|
|
from conversion.adaptivegptq import AdaptiveGPTQ
|
|
import torch
|
|
from torch import nn
|
|
import os, time, math, json
|
|
import torch.nn.functional as F
|
|
import gc
|
|
|
|
def list_live_tensors():
    """Print how many live tensors exist, grouped by size, dtype and device.

    Debugging aid. Runs a garbage collection pass and empties the CUDA cache,
    then walks every object tracked by the garbage collector. An object is
    counted when it is a tensor itself or when it exposes a ``data`` attribute
    that is a tensor. One line per distinct "size, dtype, device" description
    is printed after a separator line, in the form ``<count> : <description>``.

    Returns nothing; the report is written to stdout.
    """

    tensors = {}
    gc.collect()
    torch.cuda.empty_cache()

    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                d = str(obj.size()) + ", " + str(obj.dtype) + ", " + str(obj.device)
                tensors[d] = tensors.get(d, 0) + 1
        except Exception:
            # Best effort: inspecting arbitrary objects can raise from their
            # attribute hooks. Skip those objects, but let KeyboardInterrupt
            # and SystemExit propagate instead of swallowing them.
            continue

    print("-----------")
    for k, v in tensors.items():
        print(f"{v} : {k}")
|
|
|
|
|
|
# Quantize
|
|
|
|
def quant_linear(job: dict,
                 source: ExLlamaV2Linear,
                 lq: AdaptiveGPTQ,
                 qparams: dict,
                 drop = False):
    """Quantize one linear layer, save it, and swap in the dequantized weights.

    Steps, in order:
      1. Configure and run the quantizer ``lq`` with the given parameters.
      2. Pack the result and save it to
         ``<job["out_dir"]>/out_tensor/<source.key>.safetensors``.
      3. Rebuild a temporary ``ExLlamaV2Linear`` from the packed tensors and
         check that both its dequantized weight and the output of a forward
         pass over an identity matrix match the quantizer-applied source
         weight to within 0.01 (max absolute difference).
      4. Replace ``source.linear.weight.data`` with the reconstructed weight.

    Args:
        job: Job dictionary; only ``job["out_dir"]`` is read here.
        source: Linear layer to quantize. Its weight is modified.
        lq: Quantizer for ``source``; expected to be prepared by the caller.
        qparams: Dictionary form of the quantization parameters, converted
            with ``QParams.from_dict``.
        drop: Forwarded to ``lq.quantize``; when true, ``lq.drop_buffers()``
            is also called after the packed tensors have been saved.

    If the sanity check fails, an error message is printed and the process is
    terminated immediately with exit status 1.
    """

    qp = QParams.from_dict(qparams)
    print(f" -- Linear: {source.key} -> {qp.get_desc()}, {qp.bpw(source.linear.weight.T.shape):.2f} bpw")

    # Quantize

    lq.configure(qp.group_size, qp.bits, qp.bits_prop, qp.scale_bits)
    lq.quantize(keep_qweight = True, apply = True, drop = drop)

    # Pack and save quantized layer

    packed_dict = lq.pack(source.key, qp)
    tensorfile = os.path.join(job["out_dir"], "out_tensor/" + source.key + ".safetensors")
    save_file(packed_dict, tensorfile)

    # Drop buffers from quantizer to free VRAM

    if drop: lq.drop_buffers()

    # Reconstruct from packed layer

    recons_linear = ExLlamaV2Linear(source.model, source.key, source.in_features, source.out_features, False)
    recons_linear.device_idx = source.device_idx
    recons_dict = {}
    for k in ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]:
        recons_dict[k] = packed_dict[source.key + "." + k]
    # The forward permutation is derived as the argsort of the packed inverse permutation
    recons_dict["q_perm"] = torch.argsort(recons_dict["q_invperm"]).to(torch.int)
    recons_linear.load(recons_dict)

    # Sanity test to ensure reconstructed matrix matches unpacked matrix

    # NOTE: assuming the weight is a torch tensor, `.T` is a view, so the in-place ops on quant_w
    # below overwrite the source weight's storage. The weight is replaced at the end of this function.
    quant_w = source.linear.weight.T
    recons_w = recons_linear.get_weight_tensor_dq()

    # Forward pass over an identity matrix yields the layer's effective weight matrix
    ident = torch.eye(recons_linear.in_features, dtype = torch.half).cuda()
    recons_w2 = recons_linear.forward(ident, force_cuda = True)

    # This comparison must run before quant_w is overwritten by the second one
    recons_w2.sub_(quant_w)
    recons_w2.abs_()
    diff2 = torch.max(recons_w2)

    quant_w.sub_(recons_w)
    quant_w.abs_()
    diff1 = torch.max(quant_w)
    quant_w = None

    if diff1 > 0.01 or diff2 > 0.01:
        # os._exit() does not flush stdio buffers, so flush the message explicitly, and report
        # the failure with a non-zero exit status
        print(" ## Quantization error (2)", flush = True)
        os._exit(1)

    # Free reconstructed linear layer

    recons_linear.unload()

    # Apply reconstructed matrix to source layer

    source.linear.weight.data = recons_w.T
|
|
|
|
|
|
def quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
    """Quantize the four projections of an attention module.

    The ``q_proj`` and ``o_proj`` quantizers are prepared directly; the
    ``k_proj`` and ``v_proj`` quantizers reuse the Hessian of ``q_proj``.
    Each projection is then quantized via ``quant_linear`` with the matching
    entry of ``strat``, in the order q, k, v, o.

    ``hidden_states``, ``target_states``, ``cache`` and ``attn_mask`` are
    accepted but not used by this function.
    """

    shared_quantizer = quantizers["q_proj"]
    shared_quantizer.prepare()
    for proj_name in ("k_proj", "v_proj"):
        quantizers[proj_name].reuse_h(shared_quantizer)
    quantizers["o_proj"].prepare()

    for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        quant_linear(job, getattr(module, proj_name), quantizers[proj_name], strat[proj_name])
|
|
|
|
|
|
def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat):
    """Quantize the three projections of an MLP module.

    The ``gate_proj`` and ``down_proj`` quantizers are prepared directly; the
    ``up_proj`` quantizer reuses the Hessian of ``gate_proj``. Each projection
    is then quantized via ``quant_linear`` with the matching entry of
    ``strat``, in the order gate, up, down.

    ``hidden_states``, ``target_states``, ``cache`` and ``attn_mask`` are
    accepted but not used by this function.
    """

    gate_quantizer = quantizers["gate_proj"]
    gate_quantizer.prepare()
    quantizers["up_proj"].reuse_h(gate_quantizer)
    quantizers["down_proj"].prepare()

    for proj_name in ("gate_proj", "up_proj", "down_proj"):
        quant_linear(job, getattr(module, proj_name), quantizers[proj_name], strat[proj_name])
|
|
|
|
|
|
def quant_lm_head(job, module, hidden_states, quantizers, cache, attn_mask):
    """Quantize the output head layer.

    Prepares the ``lm_head`` quantizer, looks up the quantization parameters
    in ``qparams_headoptions`` by ``job["head_bits"]``, and quantizes
    ``module`` via ``quant_linear`` with those parameters in dictionary form.

    ``hidden_states``, ``cache`` and ``attn_mask`` are accepted but not used
    by this function.
    """

    head_quantizer = quantizers["lm_head"]
    head_quantizer.prepare()

    head_params = qparams_headoptions[job["head_bits"]]
    quant_linear(job, module, head_quantizer, head_params.get_dict())
|
|
|
|
|
|
@torch.inference_mode()
def quant(job, save_fn, model):
    """Quantize the model module by module, checkpointing progress in ``job``.

    Hidden states are loaded from ``<out_dir>/hidden_states.safetensors`` and
    pushed through each module in ``model.modules``, starting after
    ``job["q_last_module_idx"]`` (initialized to 0 when absent). For each
    module:

      1. A reference forward pass collects the unquantized outputs and feeds
         the inputs of each linear layer to its quantizer.
      2. The module's linear layers are quantized (attention, MLP or head).
      3. A second forward pass through the now-quantized module produces the
         hidden states for the next module. For non-head modules the mean
         relative Frobenius-norm error against the reference outputs is
         printed; for the head, the perplexity on the calibration token ids
         is printed.
      4. Every ``snapshot_interval`` modules, and after the last module, the
         hidden states are written back to disk and the job is saved.

    Args:
        job: Job dictionary. Reads ``out_dir``, ``strategy``,
            ``cal_filename`` and ``measurement_rows`` (plus ``head_bits`` via
            ``quant_lm_head``); reads and updates ``q_last_module_idx``;
            temporarily sets ``invalid`` while a checkpoint is being written.
        save_fn: Zero-argument callable invoked to persist ``job``.
        model: Model exposing ``modules`` and ``build_attn_mask``.

    Raises:
        TypeError: If a module is not an attention, MLP, linear or RMS norm
            module.
    """

    snapshot_interval = 10
    temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
    states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")
    strategy = job["strategy"]

    # Quantize

    if "q_last_module_idx" not in job:
        job["q_last_module_idx"] = 0

    with safe_open(states_filename, framework = "pt", device = "cpu") as f:
        hidden_states = [f.get_tensor(k) for k in sorted(f.keys())]

    index = job["q_last_module_idx"]
    while True:

        index += 1
        if index >= len(model.modules): break

        # Prepare module

        module = model.modules[index]
        module.load()

        print(f" -- Layer: {module.key} ({module.name})")

        # Create quantizers

        quantizers = {}

        if isinstance(module, ExLlamaV2Attention):
            mode = "self_attn"
            quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear)
            quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear)
            quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear)
            quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear)

        elif isinstance(module, ExLlamaV2MLP):
            mode = "mlp"
            quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
            quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
            quantizers["down_proj"] = AdaptiveGPTQ(module.down_proj.linear)

        elif isinstance(module, ExLlamaV2Linear):
            mode = "linear"
            assert module.key == "lm_head"
            quantizers["lm_head"] = AdaptiveGPTQ(module.linear)

        elif isinstance(module, ExLlamaV2RMSNorm):
            mode = "norm"

        else:
            # Without this branch `mode` would be unbound, or stale from the previous module
            raise TypeError(f"Unsupported module type for quantization: {type(module).__name__}")

        # Reference forward pass

        cache = None
        attn_mask = model.build_attn_mask(1, hidden_states[0].shape[1], 0, None, "cuda:0") if mode == "self_attn" else None

        target_states = []
        for state in hidden_states:

            x = state.to("cuda:0")
            outputs = module.forward(x, cache, attn_mask, intermediates = True)

            # Hessians

            if mode == "self_attn":
                quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
                quantizers["o_proj"].add_batch(outputs["attn_output"])

            elif mode == "mlp":
                quantizers["gate_proj"].add_batch(outputs["post_norm"])  # Reuse H for up_proj
                quantizers["down_proj"].add_batch(outputs["pre_down"])

            elif mode == "linear":
                quantizers["lm_head"].add_batch(x)

            # Reference outputs are only compared against for non-head modules

            if mode != "linear":
                target_states.append(outputs["hidden_states"].to("cpu"))

        # Quantization ("norm" modules have nothing to quantize)

        if mode == "self_attn":
            strat = strategy[module.key + "." + mode]
            quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)

        elif mode == "mlp":
            strat = strategy[module.key + "." + mode]
            quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_mask, strat)

        elif mode == "linear":
            quant_lm_head(job, module, hidden_states, quantizers, cache, attn_mask)

        # Post-quantization forward pass

        if mode == "linear":
            with safe_open(job["cal_filename"], framework = "pt", device = "cpu") as f:
                cal_ids = f.get_tensor("input_ids")

        rfn_sum = 0
        rfn_count = 0
        logprob_sum = 0.0
        logprob_count = 0

        q_states = []
        for i, state in enumerate(hidden_states):

            if mode != "linear":

                x = state.to("cuda:0")
                output = module.forward(x, cache, attn_mask)
                q_states.append(output.to("cpu"))

                output = output[0].float()
                output_ref = target_states[i].to("cuda:0")
                output_ref = output_ref[0].float()

                # Relative Frobenius norm of the difference to the reference output
                rfn_sum += (torch.linalg.norm(output - output_ref, 'fro') / torch.linalg.norm(output_ref, 'fro')).item()
                rfn_count += 1

            elif i < job["measurement_rows"]:

                x = state.to("cuda:0")
                output = module.forward(x, cache, attn_mask)
                # Strip padded columns from the logits (fix: the original assigned the result to
                # the unrelated `outputs` variable, leaving `output` padded)
                if module.padding > 0: output = output[:, :, :-module.padding]

                # Logits at position t are scored against the token at position t + 1
                logits = output[:, :-1, :]
                logits = logits.float() + 1e-10
                target_ids = cal_ids[i:i+1, 1:].to("cuda:0")

                log_probs = F.log_softmax(logits, dim = -1)
                token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
                logprob_sum += token_log_probs.sum().item()
                logprob_count += target_ids.numel()

        if mode != "linear":

            err = rfn_sum / rfn_count
            print(f" -- Module quantized, rfn_error: {err:1.6f}")

        else:

            mean_log_prob = logprob_sum / logprob_count
            perplexity = math.exp(-mean_log_prob)

            print(f" -- Module quantized, calibration perplexity (quant): {perplexity:.4f}")

        # Unload module

        module.unload()
        torch.cuda.empty_cache()

        # Advance

        hidden_states = q_states

        # Checkpoint

        if index % snapshot_interval == 0 or index == len(model.modules) - 1:

            if mode != "linear":
                save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}
                save_file(save_dict, temp_filename)
                save_dict = None

            # The job is saved with an "invalid" marker while the states file is being swapped
            # NOTE(review): presumably a resumed job checks this marker - confirm against caller

            job["invalid"] = True
            save_fn()

            if mode != "linear":
                # Single-step replacement, so the states file is never missing in between
                os.replace(temp_filename, states_filename)

            job["q_last_module_idx"] = index

            del job["invalid"]
            save_fn()
|
|
|