Mirror of https://github.com/turboderp-org/exllamav2.git

Revert some changes, calibrate to q state again (fixes 70B low bitrate)
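
What "calibrate to q state" means here: as the diff below shows, the commit goes back to saving each module's post-quantization outputs as the next module's input states, so later layers are quantized against the error the earlier quantized layers introduced, rather than against the original model's FP16 states. A minimal sketch of the difference, with made-up shapes and a crude round-to-nearest quantizer standing in for EXL2 quantization (this is illustrative, not the repository's code):

    import torch

    torch.manual_seed(0)

    def quantize_weight(w, bits = 4):
        # Crude symmetric round-to-nearest quantizer; a stand-in for EXL2
        # quantization, NOT the algorithm this repo implements.
        scale = w.abs().max() / (2 ** (bits - 1) - 1)
        return (w / scale).round() * scale

    dim, n_layers = 64, 8
    x_ref = torch.randn(16, dim)      # activations of the reference model
    x_q = x_ref.clone()               # activations of the quantized model

    for i in range(n_layers):
        w = torch.randn(dim, dim) / dim ** 0.5
        wq = quantize_weight(w)
        # "Calibrate to original state": quantize layer i against x_ref, so
        # error from earlier quantized layers is never seen downstream.
        x_ref = x_ref @ w
        # "Calibrate to q state": layer i sees the quantized model's own
        # activations, so downstream layers compensate for accumulated error.
        x_q = x_q @ wq

    rfn = torch.linalg.norm(x_q - x_ref, 'fro') / torch.linalg.norm(x_ref, 'fro')
    print(f"End-to-end rfn error after {n_layers} layers: {rfn:.6f}")
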
@@ -574,10 +574,6 @@ def quant(job, save_fn, model):
     for layer in job["measurement"]:
         qparams[layer["key"]] = layer["best_option"]["qparams"]
 
-    #qparams["model.layers.0.mlp.gate_proj"] = QParams(32, [6], [1], 4).get_dict()
-    #qparams["model.layers.0.mlp.up_proj"] = QParams(32, [6], [1], 4).get_dict()
-    #qparams["model.layers.0.mlp.down_proj"] = QParams(64, [8, 4, 3], [0.05, 0.1, 0.85], 4).get_dict()
-
     # Quantize
 
     if not "q_last_module_idx" in job:
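
The loop kept in the hunk above maps each measured layer to the quantization settings its optimizer pass selected. A minimal sketch of that structure; the field names "key", "best_option" and "qparams" mirror the diff, while the "options"/"err" layout and the scoring rule are illustrative assumptions:

    # Pretend measurement data; only the field names come from the diff.
    measurement = [
        { "key": "model.layers.0.self_attn.q_proj",
          "options": [ { "qparams": { "bits": 4 }, "err": 0.011 },
                       { "qparams": { "bits": 6 }, "err": 0.004 } ] },
        { "key": "model.layers.0.mlp.down_proj",
          "options": [ { "qparams": { "bits": 3 }, "err": 0.030 },
                       { "qparams": { "bits": 4 }, "err": 0.012 } ] },
    ]

    # Assume an earlier pass already picked "best_option"; here, lowest error.
    for layer in measurement:
        layer["best_option"] = min(layer["options"], key = lambda o: o["err"])

    qparams = {}
    for layer in measurement:
        qparams[layer["key"]] = layer["best_option"]["qparams"]

    print(qparams)
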
@@ -588,9 +584,9 @@ def quant(job, save_fn, model):
 
     page_rows = (job["gpu_rows"] < job["dataset_rows"])
 
-    index = job["q_last_module_idx"]
     while True:
 
+        index = job["q_last_module_idx"]
         index += 1
         if index >= len(model.modules): break
 
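
Reading the saved index at the top of every iteration is the resume pattern: a job reloaded from disk picks up at the first unprocessed module. A small self-contained sketch; "q_last_module_idx" is the only name taken from the diff, the module list and save_fn are stand-ins:

    modules = ["embed_tokens", "layers.0", "layers.1", "layers.2", "lm_head"]
    job = { "q_last_module_idx": 0 }

    def save_fn():
        pass  # stand-in: convert.py persists the job dict to disk here

    while True:
        index = job["q_last_module_idx"]   # re-read each pass, as in the new code
        index += 1
        if index >= len(modules): break
        print("quantizing", modules[index])
        job["q_last_module_idx"] = index   # a restart resumes after this module
        save_fn()
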
@@ -746,122 +742,111 @@ def quant(job, save_fn, model):
         logprob_sum = 0.0
         logprob_count = 0
 
         # Post-quantization forward pass
 
         out_name = os.path.join(job["out_dir"], "output_states.safetensors")
 
-        if job.get("diagnostics", False):
-            with torch.inference_mode():
-
-                rfn_sum = 0.0
-                # error_states_list = output_states_list.copy()
-
-                for b in range(input_states.shape[0]):
-
-                    x = input_states[b:b+1, :, :].to("cuda:0")
-                    cache = None
-                    attn_mask = None
-                    if isinstance(module, ExLlamaV2Attention):
-                        attn_mask = model.build_attn_mask(1, x.shape[1], 0, None, "cuda:0")
-
-                    outputs = module.forward(x, cache, attn_mask)
-
-                    # Clamp state values to FP16 range
-
-                    outputs[outputs == -float('inf')] = -65504.0
-                    outputs[outputs == float('inf')] = 65504.0
-
-                    # Compute perplexity for head layer without saving output state
-
-                    # if module.key == "lm_head" and b < job["measurement_rows"]:
-                    #
-                    #     if module.padding > 0: outputs = outputs[:, :, :-module.padding]
-                    #
-                    #     logits = outputs[:, :-1, :]
-                    #     logits = logits.float() + 1e-10
-                    #     target_ids = cal_ids[b:b+1, 1:].to("cuda:0")
-                    #
-                    #     log_probs = F.log_softmax(logits, dim = -1)
-                    #     token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
-                    #     logprob_sum += token_log_probs.sum().item()
-                    #     logprob_count += target_ids.numel()
-
-                    # Measure error
-
-                    if module.key != "lm_head":
-                        target = output_states_list[b]
-                        if target.device == torch.device("cpu"): target = target.to("cuda:0")
-                        a_ = outputs.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
-                        b_ = target.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
-                        a_0 = a_[0].float()
-                        b_0 = b_[0].float()
-                        norm_ab = torch.linalg.norm(a_0 - b_0, 'fro')
-                        norm_b = torch.linalg.norm(b_0, 'fro')
-                        rfn = norm_ab / norm_b
-                        # print(rfn)
-                        rfn_sum += rfn
-                        target = None
-
-                    # if page_rows: outputs = outputs.to("cpu") #########################
-                    # output_states_list[b] = outputs
-
-                    outputs = None
-                    x = None
-                    attn_mask = None
-
-                rfn_avg = rfn_sum / input_states.shape[0]
-                print(f" -- Layer rfn_error: {rfn_avg:.6f}")
-
-                if math.isnan(rfn_avg) or rfn_avg > 1.0:
-                    print(" ## Quantization error (3)")
-                    os._exit(0)
-
-        input_states = None
-        del input_states
-
-        # Save progress
-
-        if module.key != "lm_head":
-            output_states = torch.cat(output_states_list, dim = 0)
-            output_states_list = None
-            save_file({ "hidden_state": output_states }, out_name)
-
-        # # Perplexity
-        #
-        # if module.key == "lm_head" and isinstance(module, ExLlamaV2Linear):
-        #
-        #     mean_log_prob = logprob_sum / logprob_count
-        #     perplexity = math.exp(-mean_log_prob)
-        #
-        #     print(f" -- Calibration perplexity (quant): {perplexity:.4f}")
-        #     job["cal_perplexity"] = perplexity
+        rfn_sum = 0.0
+        # error_states_list = output_states_list.copy()
+
+        with torch.inference_mode():
+            for b in range(input_states.shape[0]):
+
+                x = input_states[b:b+1, :, :].to("cuda:0")
+                cache = None
+                attn_mask = None
+                if isinstance(module, ExLlamaV2Attention):
+                    attn_mask = model.build_attn_mask(1, x.shape[1], 0, None, "cuda:0")
+
+                outputs = module.forward(x, cache, attn_mask)
+
+                # Clamp state values to FP16 range
+
+                outputs[outputs == -float('inf')] = -65504.0
+                outputs[outputs == float('inf')] = 65504.0
+
+                # Compute perplexity for head layer without saving output state
+
+                if module.key == "lm_head" and b < job["measurement_rows"]:
+
+                    if module.padding > 0: outputs = outputs[:, :, :-module.padding]
+
+                    logits = outputs[:, :-1, :]
+                    logits = logits.float() + 1e-10
+                    target_ids = cal_ids[b:b+1, 1:].to("cuda:0")
+
+                    log_probs = F.log_softmax(logits, dim = -1)
+                    token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
+                    logprob_sum += token_log_probs.sum().item()
+                    logprob_count += target_ids.numel()
+
+                # Measure error
+
+                if module.key != "lm_head":
+                    target = output_states_list[b]
+                    if target.device == torch.device("cpu"): target = target.to("cuda:0")
+                    a_ = outputs.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
+                    b_ = target.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
+                    a_0 = a_[0].float()
+                    b_0 = b_[0].float()
+                    norm_ab = torch.linalg.norm(a_0 - b_0, 'fro')
+                    norm_b = torch.linalg.norm(b_0, 'fro')
+                    rfn = norm_ab / norm_b
+                    # print(rfn)
+                    rfn_sum += rfn
+                    target = None
+
+                    if page_rows: outputs = outputs.to("cpu")
+                    output_states_list[b] = outputs
+
+                outputs = None
+                x = None
+                attn_mask = None
+
+        if module.key != "lm_head":
+
+            rfn_avg = rfn_sum / input_states.shape[0]
+            print(f" -- Layer rfn_error: {rfn_avg:.6f}")
+
+            if math.isnan(rfn_avg) or rfn_avg > 1.0:
+                print(" ## Quantization error (3)")
+                os._exit(0)
+
+            output_states = torch.cat(output_states_list, dim = 0)
+
+        input_states = None
+        del input_states
+
+        # Perplexity
+
+        if module.key == "lm_head" and isinstance(module, ExLlamaV2Linear):
+
+            mean_log_prob = logprob_sum / logprob_count
+            perplexity = math.exp(-mean_log_prob)
+
+            print(f" -- Calibration perplexity (quant): {perplexity:.4f}")
+            job["cal_perplexity"] = perplexity
 
         # Unload module
 
+        output_states_list = None
         module.unload()
 
-        # Snapshot
-
-        if job["snapshot_interval"] > 0 and index % job["snapshot_interval"] == 0:
-
-            print(" -- Saving snapshot...")
-
-            job["invalid"] = True
-            save_fn()
-
-            job["q_last_module_idx"] = index
-
-            if module.key != "lm_head":
-                os.remove(in_name)
-                os.rename(out_name, in_name)
-
-            if "invalid" in job: del job["invalid"]
-            save_fn()
+        # Advance
+
+        job["invalid"] = True
+        save_fn()
+
+        if module.key != "lm_head":
+            save_file({ "hidden_state": output_states }, out_name)
+
+        job["q_last_module_idx"] = index
+
+        if module.key != "lm_head":
+            os.remove(in_name)
+            os.rename(out_name, in_name)
+
+        if "invalid" in job: del job["invalid"]
+        save_fn()
 
         # Report time taken
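
The post-quantization pass above tracks two quantities: a relative Frobenius-norm ("rfn") error per layer, and calibration perplexity at the head. Both in isolation, as a minimal sketch with made-up shapes; torch, torch.nn.functional as F and math are assumed to be imported at the top of the file, since the imports are outside this hunk:

    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    # Relative Frobenius-norm error between a layer's quantized output and
    # the stored target state (shapes invented for the example):
    outputs = torch.randn(1, 128, 4096)
    target = outputs + 0.01 * torch.randn_like(outputs)
    rfn = torch.linalg.norm(outputs[0].float() - target[0].float(), 'fro') / \
          torch.linalg.norm(target[0].float(), 'fro')
    print(f" -- Layer rfn_error: {rfn:.6f}")

    # Calibration perplexity from summed token log-probabilities, with the
    # same one-token shift between logits and target ids as the lm_head branch:
    vocab = 32000
    cal_ids = torch.randint(0, vocab, (1, 128))
    logits = torch.randn(1, 128, vocab)[:, :-1, :].float()
    target_ids = cal_ids[:, 1:]
    log_probs = F.log_softmax(logits, dim = -1)
    token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    perplexity = math.exp(-token_log_probs.sum().item() / target_ids.numel())
    print(f" -- Calibration perplexity (quant): {perplexity:.4f}")
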
convert.py (+6 -6)
@@ -31,8 +31,8 @@ parser.add_argument("-rs", "--rope_scale", type = float, default = 1.0, help = "
 parser.add_argument("-ra", "--rope_alpha", type = float, default = 1.0, help = "RoPE alpha value (NTK)")
 parser.add_argument("-kld", "--kld_estimate", action = "store_true", help = "Do not use measurement, instead optimize according to built-in KL divergence estimates")
 parser.add_argument("-gq", "--gigaquant", action = "store_true", help = "Gigaquant mode (don't use this)")
-parser.add_argument("-d", "--diagnostics", action = "store_true", help = "Output more diagnostic data while quantizing (slower)")
-parser.add_argument("-si", "--snapshot_interval", type = int, default = 10, help = "Snapshot every this many layers while quantizing")
+# parser.add_argument("-d", "--diagnostics", action = "store_true", help = "Output more diagnostic data while quantizing (slower)")
+# parser.add_argument("-si", "--snapshot_interval", type = int, default = 10, help = "Snapshot every this many layers while quantizing")
 
 args = parser.parse_args()
@@ -95,8 +95,8 @@ rope_scale = args.rope_scale
 rope_alpha = args.rope_alpha
 kld_estimate = args.kld_estimate
 gigaquant = args.gigaquant
-diagnostics = args.diagnostics
-snapshot_interval = args.snapshot_interval
+# diagnostics = args.diagnostics
+# snapshot_interval = args.snapshot_interval
 
 compile_full = args.compile_full
@@ -167,8 +167,8 @@ if no_resume or not os.path.exists(job_file):
     "rope_alpha": rope_alpha,
     "kld_estimate": kld_estimate,
     "gigaquant": gigaquant,
-    "diagnostics": diagnostics,
-    "snapshot_interval": snapshot_interval
+    # "diagnostics": diagnostics,
+    # "snapshot_interval": snapshot_interval
 }
 
 if reuse_measurement is not None:
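
The convert.py changes track the quantize loop: CLI flags flow into a persistent job dict, so once a flag is dropped from the dict its consumers must stop indexing the key. A tiny sketch of that pattern; the flag names match the diff, everything else is illustrative:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-gq", "--gigaquant", action = "store_true")
    # parser.add_argument("-d", "--diagnostics", action = "store_true")
    args = parser.parse_args([])

    job = { "gigaquant": args.gigaquant }

    print(job.get("diagnostics", False))   # safe even though the key is gone
    # print(job["diagnostics"])            # would raise KeyError
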