compare_q.py: Some fixes

This commit is contained in:
turboderp
2025-05-16 00:33:48 +02:00
parent 48747ba09d
commit a19538cf1e
2 changed files with 3 additions and 1 deletion

View File

@@ -222,6 +222,7 @@ def plot(results, args):
plt.show()
@torch.inference_mode()
def main(args):
with open(args.dataspec, "r", encoding = "utf8") as f:
test_data_spec = json.load(f)

View File

@@ -11,7 +11,7 @@ def get_storage_info(model):
head_bpw = 0
head_numel = 0
for module in model:
if module.key == "lm_head":
if module.key.endswith("lm_head"):
head_bpw = get_tensor_size(module.get_tensors()) / module.weights_numel()
head_numel = module.weights_numel()
elif isinstance(module, Linear):
@@ -35,4 +35,5 @@ def load_exllamav3(model_dir: str | list):
def fwd_exllamav3(model_instance, input_ids: torch.Tensor):
    """Run one forward pass through an ExLlamaV3 model and return its logits.

    Token IDs are moved to the CPU before the call, and every logit column at
    index >= model_instance.config.vocab_size is overwritten with -inf so those
    positions can never win an argmax / sampling step (presumably these are
    vocab-padding columns — confirm against the model's embedding layout).
    """
    # Model expects host-resident token IDs.
    tokens = input_ids.cpu()
    logits = model_instance.forward(tokens, {"attn_mode": "flash_attn_nc"})
    # Mask everything past the declared vocabulary size.
    logits[..., model_instance.config.vocab_size:] = float("-inf")
    return logits