Mirror of https://github.com/turboderp-org/exllamav3.git
PPL eval: Transformers FP32 mode
@@ -48,12 +48,16 @@ def main(args):
     model = AutoModelForCausalLM.from_pretrained(
         args.model_dir,
         device_map = "auto",
-        torch_dtype = torch.half if args.tight else None,
+        torch_dtype = torch.half if args.tight else torch.float if args.fp32 else None,
     )
     if args.tight:
         free_mem()
         model.half()
         free_mem()
+    if args.fp32:
+        free_mem()
+        model.float()
+        free_mem()

     vocab_size = tokenizer.actual_vocab_size

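For context, the hunk above amounts to the following loading pattern, shown here as a self-contained sketch. free_mem() is a helper defined elsewhere in the repo; the gc.collect() plus torch.cuda.empty_cache() body below is an assumption about what it does, not the repo's actual implementation.

import gc
import torch
from transformers import AutoModelForCausalLM

def free_mem():
    # Assumed stand-in for the repo's free_mem() helper: collect Python
    # garbage and release cached CUDA allocations between conversions.
    gc.collect()
    torch.cuda.empty_cache()

def load_model(model_dir: str, tight: bool = False, fp32: bool = False):
    # Load dtype: FP16 for --tight, FP32 for --fp32, otherwise None so
    # Transformers keeps whatever dtype the checkpoint was saved in.
    dtype = torch.half if tight else torch.float if fp32 else None
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        device_map = "auto",
        torch_dtype = dtype,
    )
    # Casting again after loading guarantees the final dtype even if some
    # modules were materialized differently; free_mem() brackets the cast
    # to limit the peak footprint while two copies may briefly exist.
    if tight:
        free_mem()
        model.half()
        free_mem()
    if fp32:
        free_mem()
        model.float()
        free_mem()
    return model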
@@ -94,6 +98,7 @@ if __name__ == "__main__":
     model_init.add_args(parser, cache = False)
     parser.add_argument("-r", "--rows", type = int, help = "Number of rows", default = 100)
     parser.add_argument("-t", "--tight", action = "store_true", help = "Force FP16 dtype to save memory")
+    parser.add_argument("-fp32", "--fp32", action = "store_true", help = "Force FP32 dtype")
     parser.add_argument("-l", "--length", type = int, help = "Length", default = 2048)
     _args = parser.parse_args()
     main(_args)
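A subtlety when both flags are passed: the chained conditional gives --tight precedence at load time, but the post-load casts in main() run in source order, so model.float() executes after model.half() and the fully loaded model ends up in FP32. A small illustrative sketch of the dtype selection (the assertions are mine, not from the repo):

import torch

def load_dtype(tight: bool, fp32: bool):
    # Same chained conditional as the diff: --tight wins at load time,
    # then --fp32; None keeps the checkpoint's stored dtype.
    return torch.half if tight else torch.float if fp32 else None

assert load_dtype(True, False) is torch.float16
assert load_dtype(False, True) is torch.float32
assert load_dtype(False, False) is None
# Both flags set: FP16 at load time, but the later model.float() cast
# leaves the final model in FP32.
assert load_dtype(True, True) is torch.float16

Note also that -fp32 works as written because argparse accepts multi-character options with a single dash.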