Check length of gpu_split in model_init

This commit is contained in:
turboderp
2025-01-09 11:36:25 +01:00
parent c8fa853c89
commit d0413b06f8

View File

@@ -1,5 +1,6 @@
import argparse, sys, os, glob, time
import torch
from exllamav2 import(
ExLlamaV2,
@@ -167,6 +168,9 @@ def post_init_load(
split = None
if args.gpu_split and args.gpu_split != "auto":
split = [float(alloc) for alloc in args.gpu_split.split(",")]
if len(split) > torch.cuda.device_count():
print(f" ## Error: Too many entries in gpu_split. {torch.cuda.device_count()} CUDA devices are available.")
sys.exit()
if args.tensor_parallel:
if args.gpu_split == "auto": split = None