Revert some changes, calibrate to q state again (fixes 70B low bitrate)

This commit is contained in:
turboderp
2023-12-10 17:34:18 +01:00
parent 91e35b8ce5
commit 3c43bad57f
2 changed files with 81 additions and 96 deletions

View File

@@ -574,10 +574,6 @@ def quant(job, save_fn, model):
for layer in job["measurement"]:
qparams[layer["key"]] = layer["best_option"]["qparams"]
#qparams["model.layers.0.mlp.gate_proj"] = QParams(32, [6], [1], 4).get_dict()
#qparams["model.layers.0.mlp.up_proj"] = QParams(32, [6], [1], 4).get_dict()
#qparams["model.layers.0.mlp.down_proj"] = QParams(64, [8, 4, 3], [0.05, 0.1, 0.85], 4).get_dict()
# Quantize
if not "q_last_module_idx" in job:
@@ -588,9 +584,9 @@ def quant(job, save_fn, model):
page_rows = (job["gpu_rows"] < job["dataset_rows"])
index = job["q_last_module_idx"]
while True:
index = job["q_last_module_idx"]
index += 1
if index >= len(model.modules): break
@@ -746,122 +742,111 @@ def quant(job, save_fn, model):
logprob_sum = 0.0
logprob_count = 0
# Post-quantization forward pass
out_name = os.path.join(job["out_dir"], "output_states.safetensors")
if job.get("diagnostics", False):
with torch.inference_mode():
# Post-quantization forward pass
rfn_sum = 0.0
# error_states_list = output_states_list.copy()
with torch.inference_mode():
for b in range(input_states.shape[0]):
rfn_sum = 0.0
# error_states_list = output_states_list.copy()
x = input_states[b:b+1, :, :].to("cuda:0")
cache = None
attn_mask = None
if isinstance(module, ExLlamaV2Attention):
attn_mask = model.build_attn_mask(1, x.shape[1], 0, None, "cuda:0")
for b in range(input_states.shape[0]):
outputs = module.forward(x, cache, attn_mask)
x = input_states[b:b+1, :, :].to("cuda:0")
cache = None
attn_mask = None
if isinstance(module, ExLlamaV2Attention):
attn_mask = model.build_attn_mask(1, x.shape[1], 0, None, "cuda:0")
# Clamp state values to FP16 range
outputs = module.forward(x, cache, attn_mask)
outputs[outputs == -float('inf')] = -65504.0
outputs[outputs == float('inf')] = 65504.0
# Clamp state values to FP16 range
# Compute perplexity for head layer without saving output state
outputs[outputs == -float('inf')] = -65504.0
outputs[outputs == float('inf')] = 65504.0
if module.key == "lm_head" and b < job["measurement_rows"]:
# Compute perplexity for head layer without saving output state
if module.padding > 0: outputs = outputs[:, :, :-module.padding]
# if module.key == "lm_head" and b < job["measurement_rows"]:
#
# if module.padding > 0: outputs = outputs[:, :, :-module.padding]
#
# logits = outputs[:, :-1, :]
# logits = logits.float() + 1e-10
# target_ids = cal_ids[b:b+1, 1:].to("cuda:0")
#
# log_probs = F.log_softmax(logits, dim = -1)
# token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
# logprob_sum += token_log_probs.sum().item()
# logprob_count += target_ids.numel()
logits = outputs[:, :-1, :]
logits = logits.float() + 1e-10
target_ids = cal_ids[b:b+1, 1:].to("cuda:0")
# Measure error
log_probs = F.log_softmax(logits, dim = -1)
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
logprob_sum += token_log_probs.sum().item()
logprob_count += target_ids.numel()
if module.key != "lm_head":
target = output_states_list[b]
if target.device == torch.device("cpu"): target = target.to("cuda:0")
a_ = outputs.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
b_ = target.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
a_0 = a_[0].float()
b_0 = b_[0].float()
norm_ab = torch.linalg.norm(a_0 - b_0, 'fro')
norm_b = torch.linalg.norm(b_0, 'fro')
rfn = norm_ab / norm_b
# print(rfn)
rfn_sum += rfn
target = None
# if page_rows: outputs = outputs.to("cpu") #########################
# output_states_list[b] = outputs
outputs = None
x = None
attn_mask = None
# Measure error
if module.key != "lm_head":
target = output_states_list[b]
if target.device == torch.device("cpu"): target = target.to("cuda:0")
a_ = outputs.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
b_ = target.narrow(-1, 0, min(target.shape[-1], outputs.shape[-1]))
a_0 = a_[0].float()
b_0 = b_[0].float()
norm_ab = torch.linalg.norm(a_0 - b_0, 'fro')
norm_b = torch.linalg.norm(b_0, 'fro')
rfn = norm_ab / norm_b
# print(rfn)
rfn_sum += rfn
target = None
rfn_avg = rfn_sum / input_states.shape[0]
print(f" -- Layer rfn_error: {rfn_avg:.6f}")
if page_rows: outputs = outputs.to("cpu")
output_states_list[b] = outputs
if math.isnan(rfn_avg) or rfn_avg > 1.0:
print(" ## Quantization error (3)")
os._exit(0)
outputs = None
x = None
attn_mask = None
input_states = None
del input_states
if module.key != "lm_head":
# Save progress
rfn_avg = rfn_sum / input_states.shape[0]
print(f" -- Layer rfn_error: {rfn_avg:.6f}")
if module.key != "lm_head":
output_states = torch.cat(output_states_list, dim = 0)
if math.isnan(rfn_avg) or rfn_avg > 1.0:
print(" ## Quantization error (3)")
os._exit(0)
output_states_list = None
output_states = torch.cat(output_states_list, dim = 0)
save_file({ "hidden_state": output_states }, out_name)
input_states = None
del input_states
# Perplexity
if module.key == "lm_head" and isinstance(module, ExLlamaV2Linear):
mean_log_prob = logprob_sum / logprob_count
perplexity = math.exp(-mean_log_prob)
print(f" -- Calibration perplexity (quant): {perplexity:.4f}")
job["cal_perplexity"] = perplexity
# Unload module
output_states_list = None
module.unload()
# Snapshot
# Advance
if job["snapshot_interval"] > 0 and index % job["snapshot_interval"] == 0:
job["invalid"] = True
save_fn()
if module.key != "lm_head":
print(" -- Saving snapshot...")
save_file({ "hidden_state": output_states }, out_name)
job["q_last_module_idx"] = index
# # Perplexity
#
# if module.key == "lm_head" and isinstance(module, ExLlamaV2Linear):
#
# mean_log_prob = logprob_sum / logprob_count
# perplexity = math.exp(-mean_log_prob)
#
# print(f" -- Calibration perplexity (quant): {perplexity:.4f}")
# job["cal_perplexity"] = perplexity
if module.key != "lm_head":
os.remove(in_name)
os.rename(out_name, in_name)
job["invalid"] = True
save_fn()
job["q_last_module_idx"] = index
if module.key != "lm_head":
os.remove(in_name)
os.rename(out_name, in_name)
if "invalid" in job: del job["invalid"]
save_fn()
if "invalid" in job: del job["invalid"]
save_fn()
# Report time taken

View File

@@ -31,8 +31,8 @@ parser.add_argument("-rs", "--rope_scale", type = float, default = 1.0, help = "
parser.add_argument("-ra", "--rope_alpha", type = float, default = 1.0, help = "RoPE alpha value (NTK)")
parser.add_argument("-kld", "--kld_estimate", action = "store_true", help = "Do not use measurement, instead optimize according to built-in KL divergence estimates")
parser.add_argument("-gq", "--gigaquant", action = "store_true", help = "Gigaquant mode (don't use this)")
parser.add_argument("-d", "--diagnostics", action = "store_true", help = "Output more diagnostic data while quantizing (slower)")
parser.add_argument("-si", "--snapshot_interval", type = int, default = 10, help = "Snapshot every this many layers while quantizing")
# parser.add_argument("-d", "--diagnostics", action = "store_true", help = "Output more diagnostic data while quantizing (slower)")
# parser.add_argument("-si", "--snapshot_interval", type = int, default = 10, help = "Snapshot every this many layers while quantizing")
args = parser.parse_args()
@@ -95,8 +95,8 @@ rope_scale = args.rope_scale
rope_alpha = args.rope_alpha
kld_estimate = args.kld_estimate
gigaquant = args.gigaquant
diagnostics = args.diagnostics
snapshot_interval = args.snapshot_interval
# diagnostics = args.diagnostics
# snapshot_interval = args.snapshot_interval
compile_full = args.compile_full
@@ -167,8 +167,8 @@ if no_resume or not os.path.exists(job_file):
"rope_alpha": rope_alpha,
"kld_estimate": kld_estimate,
"gigaquant": gigaquant,
"diagnostics": diagnostics,
"snapshot_interval": snapshot_interval
# "diagnostics": diagnostics,
# "snapshot_interval": snapshot_interval
}
if reuse_measurement is not None: