mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-26 01:08:58 +00:00
model_diff.py: Limit batch size (prevent OoM on output layer)
This commit is contained in:
@@ -55,6 +55,23 @@ def get_test_tokens(tokenizer, rows, eval_len = 2048, eval_stride = 512):
|
||||
return torch.cat(seqs, dim = 0)[:, :]
|
||||
|
||||
|
||||
def ppl(input_ids_, logits_):
|
||||
logprob_sum_ = 0.0
|
||||
logprob_count_ = 0
|
||||
chunksize = logits_.shape[1] * 10240 // logits_.shape[1]
|
||||
b_ = 0
|
||||
while b_ < logits_.shape[1]:
|
||||
a_ = b_
|
||||
b_ = min(b_ + chunksize, logits_.shape[1])
|
||||
logits_f = logits_[a_:b_, :].float() + 1e-10
|
||||
target_ids = input_ids_[a_ + 1:b_ + 1].to(logits_.device)
|
||||
log_probs = F.log_softmax(logits_f, dim = -1)
|
||||
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
|
||||
logprob_sum_ += token_log_probs.sum().item()
|
||||
logprob_count_ += target_ids.numel()
|
||||
return logprob_sum_, logprob_count_
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
|
||||
@@ -90,148 +107,177 @@ def main(args):
|
||||
config_a.stc = vstc
|
||||
|
||||
# Dataset
|
||||
eval_ids = get_test_tokens(tokenizer, args.rows)
|
||||
state_a = eval_ids
|
||||
state_b = eval_ids
|
||||
all_eval_ids = get_test_tokens(tokenizer, args.rows)
|
||||
|
||||
# Inputs
|
||||
states_a = list(all_eval_ids.split(args.batch_size))
|
||||
states_b = list(all_eval_ids.split(args.batch_size))
|
||||
all_eval_ids = list(all_eval_ids.split(args.batch_size))
|
||||
|
||||
# Save input IDs
|
||||
if args.save_input_ids:
|
||||
print(f" -- Saving input IDs to: {args.save_input_ids}")
|
||||
save_tensor(eval_ids, args.save_input_ids, "input_ids")
|
||||
save_tensor(all_eval_ids, args.save_input_ids, "input_ids")
|
||||
|
||||
# Output logits
|
||||
save_logits_a = []
|
||||
save_logits_b = []
|
||||
|
||||
# Inference
|
||||
for idx, (module_a, module_b) in enumerate(zip(model_a.modules, model_b.modules)):
|
||||
|
||||
logits_layer = module_a == model_a.modules[-1]
|
||||
|
||||
# Load modules
|
||||
config_a.stc.begin_deferred_load()
|
||||
module_a.load(device if not module_a.caps.get("prefer_cpu") else "cpu")
|
||||
config_a.stc.end_deferred_load()
|
||||
params_a = {}
|
||||
state_a = module_a.prepare_for_device(state_a, params_a)
|
||||
state_a = module_a.forward(state_a, params_a)
|
||||
module_a.unload()
|
||||
config_a.stc.close()
|
||||
free_mem()
|
||||
|
||||
config_b.stc.begin_deferred_load()
|
||||
module_b.load(device if not module_b.caps.get("prefer_cpu") else "cpu")
|
||||
config_b.stc.end_deferred_load()
|
||||
params_b = {}
|
||||
state_b = module_b.prepare_for_device(state_b, params_b)
|
||||
state_b = module_b.forward(state_b, params_b)
|
||||
module_b.unload()
|
||||
config_b.stc.close()
|
||||
free_mem()
|
||||
|
||||
if idx < args.keep_b:
|
||||
state_a = state_b.clone()
|
||||
|
||||
# Error measures
|
||||
max_diff = 0
|
||||
rfn_error_sum = 0
|
||||
cos_error_sum = 0
|
||||
sqnr_sum = 0
|
||||
rows = state_a.shape[0]
|
||||
for i in range(rows):
|
||||
sa = state_a[i].to(float, copy = True)
|
||||
sb = state_b[i].to(float)
|
||||
cos_error_sum += cosine_error(sa, sb)
|
||||
sqnr_sum += sqnr(sa, sb)
|
||||
sa -= sb
|
||||
rfn_error_sum += (torch.linalg.norm(sa, 'fro') / torch.linalg.norm(sb, 'fro').mean()).item()
|
||||
sa.abs_()
|
||||
md = ((sa.max().item()) / torch.linalg.norm(sb, 'fro').mean()).item()
|
||||
max_diff = max(max_diff, md)
|
||||
|
||||
del sa, sb
|
||||
rfn_error = rfn_error_sum / rows
|
||||
cos_error = cos_error_sum / rows
|
||||
sqnr_ = sqnr_sum / rows
|
||||
print(
|
||||
f" -- {module_a.key:40}"
|
||||
f" rfn_err: {rfn_error:.6f}"
|
||||
f" max_diff/norm: {max_diff:.6f}"
|
||||
f" sqnr: {sqnr_:9.6f}"
|
||||
f" cos_err: {cos_error:.6f}"
|
||||
)
|
||||
# Similarity measures
|
||||
topk_max = args.topk_max
|
||||
logprob_sum = [0, 0]
|
||||
logprob_count = [0, 0]
|
||||
kl_div_sum_ab = 0
|
||||
kl_div_sum_ba = 0
|
||||
topk_hits_sum = [[0] * topk_max, [0] * topk_max]
|
||||
topk_hits_count = [[0] * topk_max, [0] * topk_max]
|
||||
topk_agreement_sum = [0] * topk_max
|
||||
topk_agreement_count = [0] * topk_max
|
||||
|
||||
# Save logits
|
||||
if args.save_logits_a:
|
||||
print(f" -- Saving model A logits to: {args.save_logits_a}")
|
||||
save_tensor(state_a, args.save_logits_a, "logits")
|
||||
if args.save_logits_b:
|
||||
print(f" -- Saving model B logits to: {args.save_logits_b}")
|
||||
save_tensor(state_b, args.save_logits_b, "logits")
|
||||
for b in range(len(states_a)):
|
||||
|
||||
# Compare logits
|
||||
topk_max = args.topk_max
|
||||
logprob_sum = [0, 0]
|
||||
logprob_count = [0, 0]
|
||||
kl_div_sum_ab = 0
|
||||
kl_div_sum_ba = 0
|
||||
topk_hits_sum = [[0] * topk_max, [0] * topk_max]
|
||||
topk_hits_count = [[0] * topk_max, [0] * topk_max]
|
||||
topk_agreement_sum = [0] * topk_max
|
||||
topk_agreement_count = [0] * topk_max
|
||||
# Advance state
|
||||
state_a = states_a[b]
|
||||
state_b = states_b[b]
|
||||
eval_ids = all_eval_ids[b]
|
||||
|
||||
def ppl(input_ids_, logits_):
|
||||
nonlocal logprob_sum, logprob_count
|
||||
logprob_sum_ = 0.0
|
||||
logprob_count_ = 0
|
||||
chunksize = logits_.shape[1] * 10240 // logits_.shape[1]
|
||||
b_ = 0
|
||||
while b_ < logits_.shape[1]:
|
||||
a_ = b_
|
||||
b_ = min(b_ + chunksize, logits_.shape[1])
|
||||
logits_f = logits_[a_:b_, :].float() + 1e-10
|
||||
target_ids = input_ids_[a_ + 1:b_ + 1].to(logits_.device)
|
||||
log_probs = F.log_softmax(logits_f, dim=-1)
|
||||
token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
|
||||
logprob_sum_ += token_log_probs.sum().item()
|
||||
logprob_count_ += target_ids.numel()
|
||||
return logprob_sum_, logprob_count_
|
||||
params_a = {}
|
||||
state_a = module_a.prepare_for_device(state_a, params_a)
|
||||
state_a = module_a.forward(state_a, params_a)
|
||||
|
||||
rows = state_a.shape[0]
|
||||
for j in range(rows):
|
||||
x = (state_a[j], state_b[j])
|
||||
input_ids = eval_ids[j]
|
||||
top_indices = []
|
||||
params_b = {}
|
||||
state_b = module_b.prepare_for_device(state_b, params_b)
|
||||
state_b = module_b.forward(state_b, params_b)
|
||||
|
||||
for i in [0, 1]:
|
||||
logits = x[i][:-1, :]
|
||||
logprob_sum__, logprob_count__ = ppl(input_ids, logits)
|
||||
logprob_sum[i] += logprob_sum__
|
||||
logprob_count[i] += logprob_count__
|
||||
# Optionally override model A state for first layers
|
||||
if idx < args.keep_b:
|
||||
state_a = state_b.clone()
|
||||
|
||||
_, top_index = torch.topk(logits, topk_max, dim = -1)
|
||||
top_index = top_index.cpu().view(-1, topk_max)
|
||||
top_indices.append(top_index)
|
||||
targets = input_ids[1:].view(-1, 1)
|
||||
# Drop logits on last iteration
|
||||
if not logits_layer:
|
||||
states_a[b] = state_a
|
||||
states_b[b] = state_b
|
||||
|
||||
for t in range(topk_max):
|
||||
top_slice = top_index[:, :t + 1]
|
||||
hits = torch.eq(targets, top_slice)
|
||||
row_hits = hits.any(dim = 1)
|
||||
topk_hits_sum[i][t] += row_hits.sum().item()
|
||||
topk_hits_count[i][t] += top_slice.shape[0]
|
||||
# Copy logits to CPU if saving
|
||||
else:
|
||||
if save_logits_a:
|
||||
save_logits_a.append(state_a.cpu().split(1))
|
||||
if save_logits_b:
|
||||
save_logits_b.append(state_b.cpu().split(1))
|
||||
|
||||
for t in range(topk_max):
|
||||
top_slice_a = top_indices[0][:, :t + 1]
|
||||
top_slice_b = top_indices[1][:, :t + 1]
|
||||
hits = torch.eq(top_slice_a, top_slice_b)
|
||||
row_hits = hits.all(dim = 1)
|
||||
topk_agreement_sum[t] += row_hits.sum().item()
|
||||
topk_agreement_count[t] += top_slice_a.shape[0]
|
||||
# Measure error
|
||||
if not logits_layer:
|
||||
rows = state_a.shape[0]
|
||||
for j in range(rows):
|
||||
sa = state_a[j].to(float)
|
||||
sb = state_b[j].to(float)
|
||||
cos_error_sum += cosine_error(sa, sb)
|
||||
sqnr_sum += sqnr(sa, sb)
|
||||
sa -= sb
|
||||
rfn_error_sum += (torch.linalg.norm(sa, 'fro') / torch.linalg.norm(sb, 'fro').mean()).item()
|
||||
sa.abs_()
|
||||
md = ((sa.max().item()) / torch.linalg.norm(sb, 'fro').mean()).item()
|
||||
max_diff = max(max_diff, md)
|
||||
del sa, sb
|
||||
|
||||
epsilon = 1e-10
|
||||
probs_a = torch.softmax(x[0].float(), dim = -1)
|
||||
probs_b = torch.softmax(x[1].float(), dim = -1)
|
||||
kl_div = F.kl_div(torch.log(probs_a + epsilon), probs_b, reduction = 'none')
|
||||
kl_div_sum_ab += kl_div.sum(dim = -1).mean().item()
|
||||
kl_div = F.kl_div(torch.log(probs_b + epsilon), probs_a, reduction = 'none')
|
||||
kl_div_sum_ba += kl_div.sum(dim = -1).mean().item()
|
||||
# Perplexity, KL-div
|
||||
if logits_layer:
|
||||
rows = state_a.shape[0]
|
||||
for j in range(rows):
|
||||
x = (state_a[j], state_b[j])
|
||||
input_ids = eval_ids[j]
|
||||
top_indices = []
|
||||
|
||||
perplexity = [math.exp(-logprob_sum[i] / logprob_count[i]) for i in (0, 1)]
|
||||
kl_div_ab = kl_div_sum_ab / rows
|
||||
kl_div_ba = kl_div_sum_ba / rows
|
||||
for i in [0, 1]:
|
||||
logits = x[i][:-1, :]
|
||||
logprob_sum__, logprob_count__ = ppl(input_ids, logits)
|
||||
logprob_sum[i] += logprob_sum__
|
||||
logprob_count[i] += logprob_count__
|
||||
|
||||
_, top_index = torch.topk(logits, topk_max, dim = -1)
|
||||
top_index = top_index.cpu().view(-1, topk_max)
|
||||
top_indices.append(top_index)
|
||||
targets = input_ids[1:].view(-1, 1)
|
||||
|
||||
for t in range(topk_max):
|
||||
top_slice = top_index[:, :t + 1]
|
||||
hits = torch.eq(targets, top_slice)
|
||||
row_hits = hits.any(dim = 1)
|
||||
topk_hits_sum[i][t] += row_hits.sum().item()
|
||||
topk_hits_count[i][t] += top_slice.shape[0]
|
||||
|
||||
for t in range(topk_max):
|
||||
top_slice_a = top_indices[0][:, :t + 1]
|
||||
top_slice_b = top_indices[1][:, :t + 1]
|
||||
hits = torch.eq(top_slice_a, top_slice_b)
|
||||
row_hits = hits.all(dim = 1)
|
||||
topk_agreement_sum[t] += row_hits.sum().item()
|
||||
topk_agreement_count[t] += top_slice_a.shape[0]
|
||||
|
||||
epsilon = 1e-10
|
||||
probs_a = torch.softmax(x[0].float(), dim = -1)
|
||||
probs_b = torch.softmax(x[1].float(), dim = -1)
|
||||
kl_div = F.kl_div(torch.log(probs_a + epsilon), probs_b, reduction = 'none')
|
||||
kl_div_sum_ab += kl_div.sum(dim = -1).mean().item()
|
||||
kl_div = F.kl_div(torch.log(probs_b + epsilon), probs_a, reduction = 'none')
|
||||
kl_div_sum_ba += kl_div.sum(dim = -1).mean().item()
|
||||
|
||||
# Print error
|
||||
if not logits_layer:
|
||||
rfn_error = rfn_error_sum / args.rows
|
||||
cos_error = cos_error_sum / args.rows
|
||||
sqnr_ = sqnr_sum / args.rows
|
||||
print(
|
||||
f" -- {module_a.key:40}"
|
||||
f" rfn_err: {rfn_error:.6f}"
|
||||
f" max_diff/norm: {max_diff:.6f}"
|
||||
f" sqnr: {sqnr_:9.6f}"
|
||||
f" cos_err: {cos_error:.6f}"
|
||||
)
|
||||
|
||||
# Save logits
|
||||
if logits_layer:
|
||||
if args.save_logits_a:
|
||||
print(f" -- Saving model A logits to: {args.save_logits_a}")
|
||||
save_tensor(state_a, args.save_logits_a, "logits")
|
||||
if args.save_logits_b:
|
||||
print(f" -- Saving model B logits to: {args.save_logits_b}")
|
||||
save_tensor(state_b, args.save_logits_b, "logits")
|
||||
|
||||
# Final ppl, kld
|
||||
if logits_layer:
|
||||
perplexity = [math.exp(-logprob_sum[i] / logprob_count[i]) for i in (0, 1)]
|
||||
kl_div_ab = kl_div_sum_ab / args.rows
|
||||
kl_div_ba = kl_div_sum_ba / args.rows
|
||||
|
||||
# Unload modules
|
||||
module_a.unload()
|
||||
config_a.stc.close()
|
||||
free_mem()
|
||||
|
||||
module_b.unload()
|
||||
config_b.stc.close()
|
||||
free_mem()
|
||||
|
||||
# Perplexity for each model
|
||||
print(f" -- A perplexity: {perplexity[0]:11.8f}")
|
||||
@@ -270,5 +316,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-si", "--save_input_ids", type = str, help = "Save input IDs (filename)", default = None)
|
||||
parser.add_argument("-sla", "--save_logits_a", type = str, help = "Save model A logits (filename)", default = None)
|
||||
parser.add_argument("-slb", "--save_logits_b", type = str, help = "Save model B logits (filename)", default = None)
|
||||
parser.add_argument("-bsz", "--batch_size", type = int, help = "Batch size", default = 1)
|
||||
_args = parser.parse_args()
|
||||
main(_args)
|
||||
|
||||
Reference in New Issue
Block a user