PPL eval: Transformers FP32 mode

This commit is contained in:
turboderp
2025-09-04 00:39:09 +02:00
parent 59e3304da1
commit d8203063dc

View File

@@ -48,12 +48,16 @@ def main(args):
model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
device_map = "auto",
torch_dtype = torch.half if args.tight else None,
torch_dtype = torch.half if args.tight else torch.float if args.fp32 else None,
)
if args.tight:
free_mem()
model.half()
free_mem()
if args.fp32:
free_mem()
model.float()
free_mem()
vocab_size = tokenizer.actual_vocab_size
@@ -94,6 +98,7 @@ if __name__ == "__main__":
model_init.add_args(parser, cache = False)
parser.add_argument("-r", "--rows", type = int, help = "Number of rows", default = 100)
parser.add_argument("-t", "--tight", action = "store_true", help = "Force FP16 dtype to save memory")
parser.add_argument("-fp32", "--fp32", action = "store_true", help = "Force FP32 dtype")
parser.add_argument("-l", "--length", type = int, help = "Length", default = 2048)
_args = parser.parse_args()
main(_args)