ppl_transformers.py: Explicitly make bfloat16 the default dtype

This commit is contained in:
turboderp
2025-09-18 22:11:19 +02:00
parent 476e591966
commit 3845775650

View File

@@ -48,7 +48,7 @@ def main(args):
model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
device_map = "auto",
torch_dtype = torch.half if args.tight else torch.float if args.fp32 else None,
torch_dtype = torch.half if args.tight else torch.float if args.fp32 else torch.bfloat16,
)
if args.tight:
free_mem()