Mirror of https://github.com/turboderp-org/exllamav3.git — synced 2026-04-20 14:29:51 +00:00
compare_q.py: Some fixes
This commit is contained in:
@@ -222,6 +222,7 @@ def plot(results, args):
     plt.show()
 @torch.inference_mode()
 def main(args):
     with open(args.dataspec, "r", encoding = "utf8") as f:
         test_data_spec = json.load(f)
@@ -11,7 +11,7 @@ def get_storage_info(model):
     head_bpw = 0
     head_numel = 0
     for module in model:
-        if module.key == "lm_head":
+        if module.key.endswith("lm_head"):
             head_bpw = get_tensor_size(module.get_tensors()) / module.weights_numel()
             head_numel = module.weights_numel()
         elif isinstance(module, Linear):
@@ -35,4 +35,5 @@ def load_exllamav3(model_dir: str | list):
def fwd_exllamav3(model_instance, input_ids: torch.Tensor):
|
||||
input_ids = input_ids.cpu()
|
||||
output = model_instance.forward(input_ids, {"attn_mode": "flash_attn_nc"})
|
||||
output[..., model_instance.config.vocab_size:] = float("-inf")
|
||||
return output
Reference in new issue
Block a user