from exllamav2.model import (
    ExLlamaV2Embedding,
    ExLlamaV2Attention,
    ExLlamaV2MLP,
    ExLlamaV2MoEMLP,
    ExLlamaV2Linear,
    ExLlamaV2RMSNorm,
    ExLlamaV2LayerNorm
)

from safetensors import safe_open
from safetensors.torch import save_file
from conversion.qparams import QParams, qparams_headoptions, qparams_attn, qparams_mlp, get_qparams_reduced
from conversion.adaptivegptq import AdaptiveGPTQ
import torch
from torch import nn
import os, time, math, json
import torch.nn.functional as F
import gc


def list_live_tensors():

    tensors = {}
    gc.collect()
    torch.cuda.empty_cache()

    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                d = str(obj.size()) + ", " + str(obj.dtype) + ", " + str(obj.device)
                if d in tensors.keys():
                    tensors[d] += 1
                else:
                    tensors[d] = 1
        except:
            pass

    print("-----------")
    for k, v in tensors.items():
        print(f"{v} : {k}")


# Quantize

def quant_linear(job: dict, source: ExLlamaV2Linear, lq: AdaptiveGPTQ, qparams: dict, drop = False):

    qp = QParams.from_dict(qparams)
    print(f" -- Linear: {source.key} -> {qp.get_desc()}, {qp.bpw(source.linear.weight.T.shape):.2f} bpw")

    # Quantize

    lq.configure(qp.group_size, qp.bits, qp.bits_prop, qp.scale_bits)
    lq.quantize(keep_qweight = True, apply = True)

    # Pack and save quantized layer

    packed_dict = lq.pack(source.key, qp)
    tensorfile = os.path.join(job["out_dir"], "out_tensor/" + source.key + ".safetensors")
    save_file(packed_dict, tensorfile)

    # Drop buffers from quantizer to free VRAM

    if drop: lq.drop_buffers()

    # Reconstruct from packed layer

    recons_linear = ExLlamaV2Linear(source.model, source.key, source.in_features, source.out_features, source.has_bias)
    recons_linear.device_idx = source.device_idx
    recons_dict = {}
    recons_keys = ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups"]
    if source.has_bias: recons_keys += ["bias"]
    for k in recons_keys:
        recons_dict[k] = packed_dict[source.key + "." + k]
    recons_dict["q_perm"] = torch.argsort(recons_dict["q_invperm"]).to(torch.int)
    recons_linear.load(recons_dict)

    # Sanity test to ensure reconstructed matrix matches unpacked matrix

    quant_w = source.linear.weight.T
    recons_w = recons_linear.get_weight_tensor_dq()

    ident = torch.eye(recons_linear.in_features, dtype = torch.half).cuda()
    recons_w2 = recons_linear.forward(ident, force_cuda = True)
    recons_w2.sub_(quant_w)
    if recons_linear.has_bias: recons_w2.sub_(recons_dict["bias"])
    recons_w2.abs_()
    diff2 = torch.max(recons_w2)

    quant_w.sub_(recons_w)
    quant_w.abs_()
    diff1 = torch.max(quant_w)
    quant_w = None

    # TODO: Investigate why this might fail for the first QKV projections of certain models

    if diff1 > 0.05 or diff2 > 0.05:
        print(" ## Quantization error (2)")
        os._exit(0)
    elif diff1 > 0.01 or diff2 > 0.01:
        print(f" !! Warning, difference of ({diff1:.6f}, {diff2:.6f}) between unpacked and dequantized matrices")

    # Free reconstructed linear layer

    recons_linear.unload()

    # Apply reconstructed matrix to source layer

    source.linear.weight.data = recons_w.T


def quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):

    quantizers["q_proj"].prepare()
    quantizers["k_proj"].reuse_h(quantizers["q_proj"])
    quantizers["v_proj"].reuse_h(quantizers["q_proj"])
    quantizers["o_proj"].prepare()

    quant_linear(job, module.q_proj, quantizers["q_proj"], strat["q_proj"])
    quant_linear(job, module.k_proj, quantizers["k_proj"], strat["k_proj"])
    quant_linear(job, module.v_proj, quantizers["v_proj"], strat["v_proj"])
    quant_linear(job, module.o_proj, quantizers["o_proj"], strat["o_proj"])


def quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):

    has_mlp = module.model.config.arch.mlp_gate

    quantizers["up_proj"].prepare()
    if has_mlp:
        quantizers["gate_proj"].reuse_h(quantizers["up_proj"])
        quant_linear(job, module.gate_proj, quantizers["gate_proj"], strat["gate_proj"])
        del quantizers["gate_proj"]
    quant_linear(job, module.up_proj, quantizers["up_proj"], strat["up_proj"])
    del quantizers["up_proj"]

    quantizers["down_proj"].prepare()
    quant_linear(job, module.down_proj, quantizers["down_proj"], strat["down_proj"])
    del quantizers["down_proj"]


def quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat):

    num_experts = module.model.config.num_experts

    quantizers["w1.0"].prepare()
    for i in range(num_experts):
        if i > 0: quantizers[f"w1.{i}"].reuse_h(quantizers["w1.0"])
        quantizers[f"w2.{i}"].prepare()
        quantizers[f"w3.{i}"].reuse_h(quantizers["w1.0"])

    for i in range(num_experts):
        quant_linear(job, module.w1[i], quantizers[f"w1.{i}"], strat["w1"])
        del quantizers[f"w1.{i}"]
        quant_linear(job, module.w3[i], quantizers[f"w3.{i}"], strat["w3"])
        del quantizers[f"w3.{i}"]
        quant_linear(job, module.w2[i], quantizers[f"w2.{i}"], strat["w2"])
        del quantizers[f"w2.{i}"]


def quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params):

    quantizers["lm_head"].prepare()
    qp = qparams_headoptions[job["head_bits"]]
    quant_linear(job, module, quantizers["lm_head"], qp.get_dict(), drop = True)


# def testc(module, states, target_states, norm, layers):
#
#     rows = len(states)
#     cols = states[0].shape[1]
#     dim = module.model.config.hidden_size
#
#     a_batch = torch.empty((rows * cols, dim), dtype = torch.float, device = "cuda:0")
#     b_batch = torch.empty((rows * cols, dim), dtype = torch.float, device = "cuda:0")
#
#     r = 0
#     for state, target_state in zip(states, target_states):
#         a = norm.forward(state.to("cuda:0"))
#         b = norm.forward(target_state.to("cuda:0"))
#         a_batch[r:r+cols] = a.view(-1, dim)
#         b_batch[r:r+cols] = b.view(-1, dim)
#         r += cols
#
#     # diff = F.mse_loss(b_batch, a_batch)
#     m_a = torch.mean(a_batch.abs(), dim = 0)
#     m_b = torch.mean(b_batch.abs(), dim = 0)
#     m_ab = m_b / m_a
#
#     a_batch *= m_ab
#     # diff = F.mse_loss(b_batch, a_batch)
#     norm.weight.data *= m_ab
#
#     # s = torch.linalg.lstsq(a_batch, b_batch)
#     # s = s.solution
#     #
#     # for linear in layers:
#     #     m = torch.matmul(s, linear.linear.weight.data.T.float())
#     #     linear.linear.weight.data = nn.Parameter(m.T.half())
#
#     xx = 0


@torch.inference_mode()
def quant(job, save_fn, model):

    snapshot_interval = 10
    temp_filename = os.path.join(job["out_dir"], "hidden_states_temp.safetensors")
    states_filename = os.path.join(job["out_dir"], "hidden_states.safetensors")

    strategy = job["strategy"]

    # Quantize

    if "q_last_module_idx" not in job:
        job["q_last_module_idx"] = 0

    hidden_states = []
    # hidden_i_states = []
    with safe_open(states_filename, framework = "pt", device = "cpu") as f:
        for k in sorted(f.keys()):
            if k.startswith("row"):
                hidden_states.append(f.get_tensor(k))
            # elif k.startswith("i_row"):
            #     hidden_i_states.append(f.get_tensor(k))

    index = job["q_last_module_idx"]
    while True:
        index += 1
        if index >= len(model.modules): break

        # Prepare module

        module = model.modules[index]
        module.load()

        print(f" -- Layer: {module.key} ({module.name})")

        # Create quantizers

        quantizers = {}

        if isinstance(module, ExLlamaV2Attention):
            mode = "self_attn"
            # if index > 1: testc(module, hidden_states, hidden_i_states, module.input_layernorm, [module.q_proj, module.k_proj, module.v_proj])
            quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear)
            quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear)
            quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear)
            quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear)

        elif isinstance(module, ExLlamaV2MLP):
            mode = "mlp"
            has_mlp = model.config.arch.mlp_gate
            # testc(module, hidden_states, hidden_i_states, module.post_attention_layernorm, [module.gate_proj, module.up_proj])
            if has_mlp: quantizers["gate_proj"] = AdaptiveGPTQ(module.gate_proj.linear)
            quantizers["up_proj"] = AdaptiveGPTQ(module.up_proj.linear)
            quantizers["down_proj"] = AdaptiveGPTQ(module.down_proj.linear)

        elif isinstance(module, ExLlamaV2MoEMLP):
            mode = "block_sparse_moe"
            for i in range(model.config.num_experts):
                quantizers[f"w1.{i}"] = AdaptiveGPTQ(module.w1[i].linear)
                quantizers[f"w3.{i}"] = AdaptiveGPTQ(module.w3[i].linear)
                quantizers[f"w2.{i}"] = AdaptiveGPTQ(module.w2[i].linear)

        elif isinstance(module, ExLlamaV2Linear):
            mode = "linear"
            assert module.key == "lm_head"
            quantizers["lm_head"] = AdaptiveGPTQ(module.linear)

        elif isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):
            mode = "norm"

        # Reference forward pass

        cache = None
        attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) if mode == "self_attn" else None

        target_states = []

        if mode == "block_sparse_moe":
            uncalibrated_experts = [0 for _ in range(model.config.num_experts)]

        for i in range(len(hidden_states)):

            x = hidden_states[i].to("cuda:0")
            outputs = module.forward(x, cache, attn_params, intermediates = True)

            # Hessians

            if mode == "self_attn":
                quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
                quantizers["o_proj"].add_batch(outputs["attn_output"])

            if mode == "mlp":
                quantizers["up_proj"].add_batch(outputs["post_norm"])  # Reuse H for gate_proj
                quantizers["down_proj"].add_batch(outputs["pre_down"])

            if mode == "block_sparse_moe":
                for j in range(model.config.num_experts):
                    if f"pre_down.{j}" in outputs:
                        quantizers[f"w1.{j}"].add_batch(outputs["post_norm"])
                        quantizers[f"w2.{j}"].add_batch(outputs[f"pre_down.{j}"])
                        if outputs[f"pre_down.{j}"].shape[0] < outputs["post_norm"].shape[0] / 10:
                            uncalibrated_experts[j] += 1
                    else:
                        uncalibrated_experts[j] += 1

            if mode == "linear":
                quantizers["lm_head"].add_batch(x)

            if mode != "linear":
                target_states.append(outputs["hidden_states"].to("cpu"))

        # For MoE layers, warn if any expert received less than 10% of a calibration batch

        if mode == "block_sparse_moe":
            for j in range(model.config.num_experts):
                ue = uncalibrated_experts[j]
                if ue > len(hidden_states) * 0.10:
                    print(f" !! Warning: w2.{j} has less than 10% calibration for {ue}/{len(hidden_states)} rows")

        # Conversion

        if mode == "self_attn":
            strat = strategy[module.key + "." + mode]
            quant_attn(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)

        if mode == "mlp":
            strat = strategy[module.key + "." + mode]
            quant_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)

        if mode == "block_sparse_moe":
            strat = strategy[module.key + "." + mode]
            quant_moe_mlp(job, module, hidden_states, target_states, quantizers, cache, attn_params, strat)

        if mode == "linear":
            model.drop_device_tensors()
            gc.collect()  # shrug
            torch.cuda.empty_cache()
            quant_lm_head(job, module, hidden_states, quantizers, cache, attn_params)

        quantizers.clear()
        gc.collect()
        torch.cuda.empty_cache()

        # Post-quantization forward pass

        if mode == "linear":
            with safe_open(job["cal_filename"], framework = "pt", device = "cpu") as f:
                cal_ids = f.get_tensor("input_ids")

        rfn_sum = 0
        rfn_count = 0
        logprob_sum = 0.0
        logprob_count = 0

        q_states = []

        for i in range(len(hidden_states)):

            if mode != "linear":

                x = hidden_states[i].to("cuda:0")
                output = module.forward(x, cache, attn_params)
                x = None

                q_states.append(output.to("cpu"))

                output = output[0].float()
                output_ref = target_states[i].to("cuda:0")
                output_ref = output_ref[0].float()
                rfn_sum += (torch.linalg.norm(output - output_ref, 'fro') / torch.linalg.norm(output_ref, 'fro')).item()
                rfn_count += 1

                output_ref = None
                output = None

            elif i < job["measurement_rows"]:

                x = hidden_states[i].to("cuda:0")
                output = module.forward(x, cache, attn_params)

                if module.padding > 0: output = output[:, :, :-module.padding]

                logits = output[:, :-1, :]
                logits = logits.float() + 1e-10
                target_ids = cal_ids[i:i+1, 1:].to("cuda:0")

                log_probs = F.log_softmax(logits, dim = -1)
                token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
                logprob_sum += token_log_probs.sum().item()
                logprob_count += target_ids.numel()

                output = None
                logits = None
                token_log_probs = None

        if mode != "linear":
            err = rfn_sum / rfn_count
            print(f" -- Module quantized, rfn_error: {err:1.6f}")
        else:
            mean_log_prob = logprob_sum / logprob_count
            perplexity = math.exp(-mean_log_prob)
            print(f" -- Module quantized, calibration perplexity (quant): {perplexity:.4f}")

        # Unload module

        module.unload()
        torch.cuda.empty_cache()

        # Advance

        if mode != "linear":
            # hidden_i_states = hidden_states
            # hidden_states = target_states
            # hidden_states = [(x + y) / 2 for x, y in zip(target_states, q_states)]
            hidden_states = q_states
            q_states = None

        # Checkpoint

        if index % snapshot_interval == 0 or index == len(model.modules) - 1:

            if mode != "linear":
                save_dict = {f"row.{idx:05}": h for idx, h in enumerate(hidden_states)}
                # save_dict.update({f"i_row.{idx:05}": h for idx, h in enumerate(hidden_i_states)})
                save_file(save_dict, temp_filename)
                save_dict = None

            job["invalid"] = True
            save_fn()

            if mode != "linear":
                os.replace(temp_filename, states_filename)

            job["q_last_module_idx"] = index

            del job["invalid"]
            save_fn()